# Vision Transformer in Federated Learning - Setup Guide

This notebook will guide you through the process of setting up and running the ViT-FL model based on the paper ["Rethinking Architecture Design for Tackling Data Heterogeneity in Federated Learning"](https://arxiv.org/abs/2106.06047).

We'll go through the following steps:
1. Installing Required Dependencies
2. Downloading and Preparing the Dataset
3. Setting up Pre-trained Models
4. Running the Model Training

Let's get started!


In [1]:
import torch
print(torch.cuda.is_available())

True


## 1. Installing Required Dependencies

Install the packages listed in the repository's `requirements.txt`:

In [1]:
import os

# Install the dependencies only if we are in the repository root
if os.path.exists('requirements.txt'):
    print("requirements.txt found. Installing dependencies...")
    %pip install -r requirements.txt
else:
    print("requirements.txt not found. Please make sure you're in the correct directory.")


requirements.txt found. Installing dependencies...
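
To confirm that the install actually produced the pinned versions (a mismatch in timm turns out to matter later in this guide), a small check can compare each exact `package==version` pin with what the current interpreter has installed. This is a minimal sketch that assumes simple pins; `check_pins` is a hypothetical helper, and other specifiers (`>=`, `~=`) are skipped.

```python
# Sketch: compare exact "pkg==version" pins from requirements.txt with the
# versions installed in the current interpreter. Other specifiers are skipped.
import os
from importlib.metadata import PackageNotFoundError, version


def check_pins(requirement_lines):
    """Return {package: (pinned, installed_or_None)} for pins that do not match."""
    mismatches = {}
    for raw in requirement_lines:
        # Drop comments and environment markers, then surrounding whitespace
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if "==" not in line:
            continue  # only exact pins are checked in this sketch
        name, pinned = (part.strip() for part in line.split("==", 1))
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # not installed in this interpreter
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    if os.path.exists("requirements.txt"):
        with open("requirements.txt") as f:
            for pkg, (want, have) in check_pins(f).items():
                print(f"{pkg}: pinned {want}, installed {have}")
```

An empty result means every exact pin matches the running interpreter.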


## 2. Downloading and Preparing the Dataset

For this guide, we'll use the CIFAR-10 dataset. You need to:

1. Download the data partitions from [Google Drive](https://drive.google.com/drive/folders/1ZErR7RMSVImkzYzz0hLl25f9agJwp0Zx?usp=sharing)
2. Place the downloaded `cifar10.npy` file in the `data` subdirectory

Let's create the data directory if it doesn't exist and check for the dataset:


In [1]:
import os

# Create the data directory if it doesn't exist
if not os.path.exists('data'):
    os.makedirs('data')
    print("Created 'data' directory")

# Check whether the dataset file is in place
if os.path.exists('data/cifar10.npy'):
    print("CIFAR-10 dataset found!")
else:
    print("Please download cifar10.npy from the Google Drive link and place it in the 'data' directory")


CIFAR-10 dataset found!
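
Finding the file does not guarantee the download is intact. The sketch below reads only the `.npy` header with the standard library, which confirms the file is a real NumPy file and reports its dtype and shape. The internal layout of `cifar10.npy` is not described in this guide, so treat the output as informational. A `descr` of `'|O'` would mean the file stores pickled Python objects, which `np.load` only opens with `allow_pickle=True`. `read_npy_header` is a hypothetical helper name.

```python
# Sketch: read a .npy header without NumPy to check the file and see its dtype/shape.
import ast
import os
import struct

NPY_MAGIC = b"\x93NUMPY"


def read_npy_header(path):
    """Return the header dict ('descr', 'fortran_order', 'shape') of a .npy file."""
    with open(path, "rb") as f:
        if f.read(6) != NPY_MAGIC:
            raise ValueError(f"{path} is not a .npy file")
        major, _minor = f.read(2)  # format version bytes
        # Format 1.0 stores the header length in 2 bytes; 2.0 and 3.0 use 4 bytes
        size_fmt = "<H" if major == 1 else "<I"
        (header_len,) = struct.unpack(size_fmt, f.read(struct.calcsize(size_fmt)))
        encoding = "utf-8" if major >= 3 else "latin1"
        header = f.read(header_len).decode(encoding)
    return ast.literal_eval(header)  # the header is a Python dict literal


if __name__ == "__main__":
    if os.path.exists("data/cifar10.npy"):
        print(read_npy_header("data/cifar10.npy"))
```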


## 3. Setting up Pre-trained Models

For ViTs, we need to change the pre-trained model download links in the timm library. The models we'll be using are pre-trained on ImageNet-1K. Here are the steps:

1. Locate the `vision_transformer.py` file in your timm installation
2. Modify the `default_cfgs` dictionary with the correct URLs for the pre-trained models

For this example, we'll use ViT-small. The URL should be:
```python
'vit_small_patch16_224': _cfg(
    url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz')
```

Let's check the timm installation and locate the file:


In [1]:
import timm
import os

# Get timm installation directory
timm_dir = os.path.dirname(timm.__file__)
vit_file = os.path.join(timm_dir, 'models', 'vision_transformer.py')

if os.path.exists(vit_file):
    print(f"Found vision_transformer.py at: {vit_file}")
    print("\nPlease modify the default_cfgs dictionary in this file with the correct URL for ViT-small")
else:
    print("Could not find vision_transformer.py. Please check your timm installation")


Found vision_transformer.py at: d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\venv\lib\site-packages\timm\models\vision_transformer.py

Please modify the default_cfgs dictionary in this file with the correct URL for ViT-small
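
After editing the file, you can check that the `vit_small_patch16_224` entry points at the expected URL without opening it by hand. The sketch below extracts the `url=` value with a regular expression. It assumes the entry has the `'name': _cfg(url='...')` shape shown above with `url` as the first argument. Newer timm releases lay out `default_cfgs` differently, and for those files it returns `None`. `find_cfg_url` and `EXPECTED_URL` are hypothetical names.

```python
# Sketch: read the url= value of a model's _cfg(...) entry in vision_transformer.py.
import os
import re

# URL from the ViT-small entry shown above
EXPECTED_URL = ("https://storage.googleapis.com/vit_models/augreg/"
                "S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz")


def find_cfg_url(source, model_name):
    """Return the url= value of model_name's _cfg(...) entry, or None if not found.

    Assumes the entry looks like 'name': _cfg(url='...') with url as the first argument.
    """
    pattern = re.compile(
        r"['\"]" + re.escape(model_name) + r"['\"]\s*:\s*_cfg\(\s*url\s*=\s*['\"]([^'\"]*)['\"]"
    )
    match = pattern.search(source)
    return match.group(1) if match else None


if __name__ == "__main__":
    try:
        import timm
    except ImportError:
        timm = None  # timm not installed in this interpreter
    if timm is not None:
        vit_file = os.path.join(os.path.dirname(timm.__file__), "models", "vision_transformer.py")
        if os.path.exists(vit_file):
            with open(vit_file, encoding="utf-8") as f:
                current = find_cfg_url(f.read(), "vit_small_patch16_224")
            print("Current URL:", current)
            print("Matches expected URL:", current == EXPECTED_URL)
```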


## 4. Running the Model Training

Now that everything is set up, we can run the model training. We'll use the ViT-CWT (cyclic weight transfer) implementation in `train_CWT.py` with the following configuration:
- Dataset: CIFAR-10
- Split type: split_2
- Network: ViT-small
- Local epochs: 1
- Communication rounds: 100
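
To show what `E_epoch` and `max_communication_rounds` control, here is a toy sketch of cyclic weight transfer. In CWT, a single model is trained on each client in turn and handed to the next client, instead of being averaged across clients. The sketch assumes that one communication round is one full pass over all clients. `train_CWT.py` may count rounds differently. `local_train`, `fake_train`, and the integer "weights" are hypothetical stand-ins for real training.

```python
def cyclic_weight_transfer(weights, clients, local_train, rounds, local_epochs):
    """Toy CWT loop: one model visits the clients in a fixed order every round.

    local_train(weights, client) is a hypothetical callback that runs one local
    epoch on `client` and returns the updated weights.
    """
    for _ in range(rounds):                # max_communication_rounds (assumed: one full cycle)
        for client in clients:             # weights are passed from client to client
            for _ in range(local_epochs):  # E_epoch local epochs per visit
                weights = local_train(weights, client)
    return weights


# Usage with a stand-in "training" step that records the visiting order.
visits = []


def fake_train(weights, client):
    visits.append(client)
    return weights + 1  # pretend each local epoch changes the model once


final = cyclic_weight_transfer(0, ["client_0", "client_1", "client_2"], fake_train,
                               rounds=2, local_epochs=1)
print(final, visits)
```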

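Before launching training, check that the installed timm version still provides `_pil_interp`. The repository's `utils/data_utils.py` imports it from `timm.data.transforms`, and recent timm releases no longer include it: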


In [2]:
# To change the timm version, uncomment the two %pip lines below, then restart the kernel
# so that the newly installed version is imported. timm 0.3.2 (reported below) provides _pil_interp.
print("Installing compatible version of timm...")
# %pip uninstall -y timm
# %pip install timm==0.3.2

# Verify timm version
import timm
print(f"\nInstalled timm version: {timm.__version__}")

# Check if _pil_interp is available
try:
    from timm.data.transforms import _pil_interp
    print("Successfully imported _pil_interp from timm.data.transforms!")
except ImportError as e:
    print(f"Error importing _pil_interp: {e}")


Installing compatible version of timm...

Installed timm version: 0.3.2
Successfully imported _pil_interp from timm.data.transforms!
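
The training cells build the command as a single shell string and quote the interpreter path by hand, because that path contains spaces. As an alternative, you can build the argument list directly. This sketch uses only the flags that appear in this guide. `build_train_argv` is a hypothetical helper, and its defaults are assumptions.

```python
import sys


def build_train_argv(script="train_CWT.py", save_model=True, python=sys.executable, **options):
    """Build [python, script, --key, value, ..., --save_model_flag] for subprocess."""
    argv = [python, script]
    for key, value in options.items():
        argv += [f"--{key}", str(value)]
    if save_model:
        argv.append("--save_model_flag")  # boolean switch, takes no value
    return argv


argv = build_train_argv(
    FL_platform="ViT-CWT",
    net_name="ViT-small",
    dataset="cifar10",
    E_epoch=1,
    max_communication_rounds=100,
    split_type="split_2",
)
print(" ".join(argv))  # for display only; the list itself is what gets executed
```

Passing `argv` to `subprocess.Popen(argv, ...)` without `shell=True` hands each argument to the process as-is, so the spaces in the interpreter path need no quoting.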


In [None]:
# Run training with the same interpreter as the notebook, now that timm provides _pil_interp.
# Note: this run uses ViT-tiny; the configuration listed above uses ViT-small.
import subprocess
import sys

python_executable = sys.executable
cmd = f'"{python_executable}" train_CWT.py --FL_platform ViT-CWT --net_name ViT-tiny --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --split_type split_2 --save_model_flag'

print("Running training command with explicit Python path:")
print(cmd)
print("\nExecuting the command...")

try:
    # Merge stderr into stdout: reading only stdout while stderr fills its own
    # pipe buffer (e.g. from progress bars) can deadlock both processes.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, shell=True)

    # Print output in real time as the script produces it
    for line in process.stdout:
        print(line.rstrip())

    # Report a non-zero exit code once the script finishes
    return_code = process.wait()
    if return_code != 0:
        print(f"\nTraining exited with code {return_code}")

except Exception as e:
    print(f"An error occurred: {str(e)}")


Running training command with explicit Python path:
"d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\venv\Scripts\python.exe" train_CWT.py --FL_platform ViT-CWT --net_name ViT-tiny --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --split_type split_2 --save_model_flag

Executing the command...
We use ViT tiny
sgd
++++++++++++++++ Other Train related parameters ++++++++++++++++
E_epoch: 1
FL_platform: ViT-CWT
Pretrained: True
batch_size: 32
cfg: configs/swin_tiny_patch4_window7_224.yaml
data_path: ./data/
dataset: cifar10
decay_type: cosine
device: cuda:0
gpu_ids: 0
grad_clip: True
img_size: 224
learning_rate: 0.003
max_communication_rounds: 100
max_grad_norm: 1.0
name: ViT-tiny_split_2_lr_0.003_Pretrained_True_optimizer_sgd_WUP_500_Round_100_Eepochs_1_Seed_42
net_name: ViT-tiny
num_classes: 10
num_workers: 4
optimizer_type: sgd
output_dir: output\ViT-CWT\cifar10\ViT-tiny_split_2_lr_0.003_Pretrained_True_optimizer_sgd_WUP_500_Round_100_Eepochs_1

In [12]:
# Let's verify our Python environment and numpy installation
import sys
print("Python executable:", sys.executable)
print("\nChecking if numpy is installed correctly...")
try:
    import numpy
    print(f"Numpy is installed! Version: {numpy.__version__}")
except ImportError:
    print("Numpy is not installed in the current environment. Let's install it...")
    %pip install numpy --upgrade


Python executable: d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\.venv\Scripts\python.exe

Checking if numpy is installed correctly...
Numpy is installed! Version: 2.2.6


In [15]:
# Let's modify the training command to use the same Python interpreter as our notebook
import subprocess
import sys

python_executable = sys.executable
cmd = f'"{python_executable}" train_CWT.py --FL_platform ViT-CWT --net_name ViT-small --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --split_type split_2 --save_model_flag'

print("Running training command with explicit Python path:")
print(cmd)
print("\nExecuting the command...")

try:
    # Merge stderr into stdout so a full stderr pipe cannot block the training process
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, shell=True)

    # Print output in real time as the script produces it
    for line in process.stdout:
        print(line.rstrip())

    # Report a non-zero exit code once the script finishes
    return_code = process.wait()
    if return_code != 0:
        print(f"\nTraining exited with code {return_code}")

except Exception as e:
    print(f"An error occurred: {str(e)}")


Running training command with explicit Python path:
"d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\.venv\Scripts\python.exe" train_CWT.py --FL_platform ViT-CWT --net_name ViT-small --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --split_type split_2 --save_model_flag

Executing the command...

Errors:
Traceback (most recent call last):
  File "d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\train_CWT.py", line 16, in <module>
    from utils.data_utils import DatasetFLViT, create_dataset_and_evalmetrix
  File "d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\utils\data_utils.py", line 10, in <module>
    from timm.data.transforms import _pil_interp
ImportError: cannot import name '_pil_interp' from 'timm.data.transforms' (d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\.venv\Lib\site-pa

In [9]:
import subprocess

# This attempt calls a bare `python3` from PATH, which can resolve to a different
# interpreter than the notebook kernel. Here it picked one without numpy
# (ModuleNotFoundError below); the cells that pass sys.executable avoid this.
cmd = "python3 train_CWT.py --FL_platform ViT-CWT --net_name ViT-small --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --split_type split_2 --save_model_flag"

print("Running training command:")
print(cmd)
print("\nExecuting the command...")

try:
    # Merge stderr into stdout so a full stderr pipe cannot block the training process
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

    # Print output in real time as the script produces it
    for line in process.stdout:
        print(line.rstrip())

    # Report a non-zero exit code once the script finishes
    return_code = process.wait()
    if return_code != 0:
        print(f"\nTraining exited with code {return_code}")

except Exception as e:
    print(f"An error occurred: {str(e)}")


Running training command:
python3 train_CWT.py --FL_platform ViT-CWT --net_name ViT-small --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --split_type split_2 --save_model_flag

Executing the command...

Errors:
Traceback (most recent call last):
  File "d:\lucas\5 - cadeiras\1 periodo - Mestrado\Visao computacional\ViT-FL-FedBABU\train_CWT.py", line 7, in <module>
    import numpy as np
ModuleNotFoundError: No module named 'numpy'

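This failure comes from `python3` on PATH resolving to an interpreter without numpy. To see which interpreters can import a given package, you can start each one and try the import. In this minimal sketch, `can_import` is a hypothetical helper, and the candidate names `python` and `python3` are only examples of what might be on PATH.

```python
import subprocess
import sys


def can_import(python_exe, module):
    """Return True if the interpreter at python_exe can import module."""
    # Pass the module name as an argument instead of formatting it into the code
    code = "import importlib, sys; importlib.import_module(sys.argv[1])"
    try:
        result = subprocess.run([python_exe, "-c", code, module], capture_output=True)
    except OSError:
        return False  # the interpreter could not be started (e.g. not on PATH)
    return result.returncode == 0


if __name__ == "__main__":
    for exe in (sys.executable, "python", "python3"):
        print(f"{exe}: numpy importable = {can_import(exe, 'numpy')}")
```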


## Notes and Troubleshooting

1. All checkpoints, results, and log files will be saved to the `output_dir` folder printed at the start of the training log
2. The final performance will be saved in `log_file.txt`
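3. If you encounter any errors:
   - Make sure all dependencies are properly installed
   - Verify that the CIFAR-10 dataset is in the correct location
   - Check that the pre-trained model URLs are properly configured in the timm library
   - Ensure you have sufficient disk space for model checkpoints and results
   - `ModuleNotFoundError: No module named 'numpy'` usually means the script was launched with a different interpreter than the notebook's; launch it with `sys.executable` as in the cells above
   - `ImportError: cannot import name '_pil_interp'` means the timm version in the environment running `train_CWT.py` no longer provides that function; timm 0.3.2 imported it successfully in the version check above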
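
To find the results after a run, you can search the output root for `log_file.txt`. This sketch assumes the root is `output` (as in the `output_dir` printed in the training log) and that the log file sits somewhere beneath it. `find_log_files` is a hypothetical helper.

```python
from pathlib import Path


def find_log_files(output_root="output"):
    """Return every log_file.txt under output_root, most recently modified first."""
    root = Path(output_root)
    if not root.is_dir():
        return []
    return sorted(root.rglob("log_file.txt"), key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    for log in find_log_files("output")[:1]:
        print(f"Most recent log: {log}")
        print(log.read_text(encoding="utf-8", errors="replace")[-2000:])  # tail of the log
```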

You can also try the FedAVG implementation using a similar command:
```bash
python train_FedAVG.py --FL_platform ViT-FedAVG --net_name ViT-small --dataset cifar10 --E_epoch 1 --max_communication_rounds 100 --num_local_clients -1 --split_type split_2 --save_model_flag
```
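
FedAVG differs from CWT in how client updates are combined: clients train in parallel, and the server averages their weights. The toy sketch below shows the usual FedAvg aggregation step, where each client's parameters are weighted by its local sample count. The dict-of-lists "model" is a stand-in, and this is not necessarily how `train_FedAVG.py` implements aggregation.

```python
def fedavg(client_weights, client_sizes):
    """Weighted average of client parameters (toy FedAvg aggregation).

    client_weights: list of dicts mapping parameter name -> list of floats
    client_sizes:   number of local training samples per client
    """
    total = sum(client_sizes)
    averaged = {}
    for name, first in client_weights[0].items():
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(len(first))
        ]
    return averaged


# Two clients; the second holds three times as much data, so it dominates the average.
global_weights = fedavg([{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}], [1, 3])
print(global_weights)  # {'w': [2.5, 3.5]}
```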
