# Federated learning with YOLOv8

## Before you start

First, confirm that a GPU is available by running the `nvidia-smi` command. If no GPU is listed, go to `Runtime` -> `Change Runtime Type` -> `Hardware accelerator`, select `GPU`, and click `Save`.

In [1]:
!nvidia-smi

Wed Jul 17 16:46:08 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
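
The same check can be done from Python, which is handy when the rest of the notebook should stop early on a CPU-only runtime. This is a minimal sketch: the helper name `gpu_visible` is made up for illustration, and it assumes that `nvidia-smi -L` (list GPUs) is available wherever the NVIDIA driver is installed.

```python
import shutil
import subprocess

def gpu_visible():
    """Return True if `nvidia-smi -L` runs and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # No NVIDIA driver tools on the PATH
        return False
    try:
        out = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    # Each detected device is printed on a line starting with "GPU <index>:"
    return out.returncode == 0 and "GPU" in out.stdout

print("GPU visible:", gpu_visible())
```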
                                                                    

## Install the ultralytics library

In [None]:
!pip install ultralytics

## Download the datasets

Clone the repository from GitHub, keep only its `datasets` folder, and delete the rest of the clone.

In [None]:
%cd /content
!git clone https://github.com/losor2002/FederatedLearningYOLOv8.git
!mv FederatedLearningYOLOv8/datasets .
!rm -r FederatedLearningYOLOv8
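
The training cell reads `datasets/{DATASET}/global/data.yaml` and one `datasets/{DATASET}/client{i}/data.yaml` per client, with `i` running from 0 to `TOTAL_FEDERATED_CLIENTS - 1`. A quick way to confirm that layout after the download is sketched below; the helper name `missing_data_yamls` is hypothetical, and the check only looks for the `data.yaml` files, not for the images or labels they reference.

```python
from pathlib import Path

def missing_data_yamls(root, dataset, total_clients):
    """List the data.yaml files the training cell expects but cannot find."""
    base = Path(root) / dataset
    expected = [base / "global" / "data.yaml"]
    expected += [base / f"client{i}" / "data.yaml" for i in range(total_clients)]
    return [str(p) for p in expected if not p.is_file()]

# Values taken from the run configuration used in this notebook
missing = missing_data_yamls("/content/datasets", "augmented1000", 5)
print("missing files:", missing)
```

An empty list means every expected file is present.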

## Train the model

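Configure the run with the constants at the top of the cell, then train the model. With `FEDERATED = False` only the global model is trained. With `FEDERATED = True`, each round trains a random subset of the clients, averages the weights of the clients with the best mAP, and trains the averaged model on the global dataset.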
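
The client-related constants in the cell below depend on each other: `rng.choice(..., replace=False)` cannot draw more clients than exist, and a `BEST_MAPS` value larger than `ACTIVE_FEDERATED_CLIENTS` asks for more client results than a round produces. The following sketch checks those relationships up front; `validate_config` is a hypothetical helper, not part of the notebook.

```python
def validate_config(total_clients, active_clients, best_maps):
    """Return a list of problems; an empty list means the values are consistent."""
    problems = []
    if active_clients > total_clients:
        problems.append("ACTIVE_FEDERATED_CLIENTS exceeds TOTAL_FEDERATED_CLIENTS")
    if best_maps > active_clients:
        problems.append("BEST_MAPS exceeds ACTIVE_FEDERATED_CLIENTS")
    if min(total_clients, active_clients, best_maps) < 1:
        problems.append("client counts and BEST_MAPS must be at least 1")
    return problems

# The configuration used in this notebook passes the check
print(validate_config(total_clients=5, active_clients=3, best_maps=3))  # → []
```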

In [4]:
%cd /content

from ultralytics import YOLO
import copy
import torch
from numpy import random
import os

# Configure the run
PROJECT = 'test'
DATASET = 'augmented1000'
GLOBAL_MODEL_EPOCHS = 5
FEDERATED = False
GLOBAL_FEDERATED_EPOCHS = 5
CLIENT_EPOCHS = 20
TOTAL_FEDERATED_CLIENTS = 5
ACTIVE_FEDERATED_CLIENTS = 3
BEST_MAPS = 3
GLOBAL_AVG_EPOCHS = 1

# Define the averaging function
def average_weights(w):
  """Return the element-wise mean of a list of state dicts (every client counts equally)."""
  w_avg = copy.deepcopy(w[0])
  for key in w_avg.keys():
    for i in range(1, len(w)):
      w_avg[key] += w[i][key]
    w_avg[key] = torch.div(w_avg[key], len(w))
  return w_avg

def main():
  # Check if the project already exists
  if os.path.exists('/content/runs/' + PROJECT):
    raise Exception(f'Project {PROJECT} already exists, choose a new project name')

  # Print the run configuration
  print('PROJECT : ' + PROJECT)
  print('DATASET : ' + DATASET)
  print('GLOBAL_MODEL_EPOCHS : ' + str(GLOBAL_MODEL_EPOCHS))
  print('FEDERATED : ' + str(FEDERATED))
  print('GLOBAL_FEDERATED_EPOCHS : ' + str(GLOBAL_FEDERATED_EPOCHS))
  print('CLIENT_EPOCHS : ' + str(CLIENT_EPOCHS))
  print('TOTAL_FEDERATED_CLIENTS : ' + str(TOTAL_FEDERATED_CLIENTS))
  print('ACTIVE_FEDERATED_CLIENTS : ' + str(ACTIVE_FEDERATED_CLIENTS))
  print('BEST_MAPS : ' + str(BEST_MAPS))
  print('GLOBAL_AVG_EPOCHS : ' + str(GLOBAL_AVG_EPOCHS))

  # Initialize the global model YOLOv8 for object detection
  global_model = YOLO('yolov8n.pt')

  # Train the global model
  results = global_model.train(data=f'datasets/{DATASET}/global/data.yaml',
                               epochs=GLOBAL_MODEL_EPOCHS, name='global0',
                               project=f'runs/{PROJECT}')

  # Print the mAP
  print('global0 mAP = ' + str(results.box.map))

  if not FEDERATED:
    return

  # Initialize the random number generator
  rng = random.default_rng()

  # Start the federated learning
  for epoch in range(1, GLOBAL_FEDERATED_EPOCHS + 1):
    print(f'\n | Global Training Round : {epoch} |\n')

    # Array of tuples (client, client mAP)
    maps = []

    # Select the clients
    clients = rng.choice(TOTAL_FEDERATED_CLIENTS, ACTIVE_FEDERATED_CLIENTS,
                         replace=False)
    clients.sort()
    print('Clients : ' + str(clients))

    # Train the clients
    for client in clients:
      # Set up the client model
      client_model = YOLO(f'runs/{PROJECT}/global{epoch - 1}/weights/best.pt')

      # Train the client model
      client_res = client_model.train(data=f'datasets/{DATASET}/client{client}/data.yaml',
                                      epochs=CLIENT_EPOCHS, name=f'client{epoch}{client}',
                                      project=f'runs/{PROJECT}')

      # Print the client mAP
      print(f'client{epoch}{client} mAP = ' + str(client_res.box.map))

      # Save the mAP
      maps.append((client, client_res.box.map))

    clients_weights = []

    # Sort by mAP (highest first) and take the BEST_MAPS best clients,
    # or all of them if fewer were trained this round
    maps.sort(key=lambda tup: tup[1], reverse=True)
    for i in range(min(BEST_MAPS, len(maps))):
      client_model = YOLO(f'runs/{PROJECT}/client{epoch}{maps[i][0]}/weights/best.pt')
      clients_weights.append(client_model.state_dict())
      print(f'chosen client{epoch}{maps[i][0]} mAP = ' + str(maps[i][1]))

    # Average clients best weights
    avg_weights = average_weights(clients_weights)

    # Load the previous global model, apply the averaged weights, then train
    # briefly on the global dataset to evaluate and save them
    global_model = YOLO(f'runs/{PROJECT}/global{epoch - 1}/weights/best.pt')
    global_model.load_state_dict(avg_weights)
    results = global_model.train(data=f'datasets/{DATASET}/global/data.yaml',
                                 epochs=GLOBAL_AVG_EPOCHS, name=f'global{epoch}',
                                 project=f'runs/{PROJECT}')
    print('global' + str(epoch) + ' mAP = ' + str(results.box.map))

try:
  main()
finally:
  # Restart the runtime to clean the memory
  exit()

/content
PROJECT : test
DATASET : augmented1000
GLOBAL_MODEL_EPOCHS : 5
FEDERATED : False
GLOBAL_FEDERATED_EPOCHS : 5
CLIENT_EPOCHS : 20
TOTAL_FEDERATED_CLIENTS : 5
ACTIVE_FEDERATED_CLIENTS : 3
BEST_MAPS : 3
GLOBAL_AVG_EPOCHS : 1
Downloading https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt to 'yolov8n.pt'...


100%|██████████| 6.25M/6.25M [00:00<00:00, 168MB/s]


Ultralytics YOLOv8.2.58 🚀 Python-3.10.12 torch-2.3.1+cu121 CUDA:0 (Tesla T4, 15102MiB)
engine/trainer: task=detect, mode=train, model=yolov8n.pt, data=datasets/augmented1000/global/data.yaml, epochs=5, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=runs/test, name=global0, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_la

100%|██████████| 755k/755k [00:00<00:00, 42.0MB/s]


Overriding model.yaml nc=80 with nc=1

                   from  n    params  module                                       arguments                     
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]                 
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]                
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]             
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]                
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]             
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]           
  7                  -1  1    295424  ultralytics

train: Scanning /content/datasets/augmented1000/global/train/labels... 800 images, 0 backgrounds, 0 corrupt: 100%|██████████| 800/800 [00:00<00:00, 1835.13it/s]
train: New cache created: /content/datasets/augmented1000/global/train/labels.cache
albumentations: Blur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01), CLAHE(p=0.01, clip_limit=(1, 4.0), tile_grid_size=(8, 8))
val: Scanning /content/datasets/augmented1000/global/valid/labels... 200 images, 0 backgrounds, 0 corrupt: 100%|██████████| 200/200 [00:00<00:00, 665.36it/s]
val: New cache created: /content/datasets/augmented1000/global/valid/labels.cache
Plotting labels to runs/test/global0/labels.jpg... 
optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
optimizer: AdamW(lr=0.002, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
TensorBoard: model graph visualization added ✅
Image sizes 640 train, 640 val
Using 2 dataloader workers
Logging results to runs/test/global0
Starting training for 5 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


        1/5      2.59G      1.635      2.588      1.563         49        640: 100%|██████████| 50/50 [00:25<00:00,  1.97it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 7/7 [00:06<00:00,  1.06it/s]

                   all        200        503      0.886      0.093      0.333      0.168






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


        2/5      2.17G      1.525       1.93      1.463         53        640: 100%|██████████| 50/50 [00:19<00:00,  2.51it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 7/7 [00:02<00:00,  2.64it/s]


                   all        200        503      0.461      0.352      0.357      0.182

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


        3/5      2.29G      1.516      1.786      1.449        100        640: 100%|██████████| 50/50 [00:19<00:00,  2.60it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 7/7 [00:02<00:00,  2.54it/s]


                   all        200        503      0.477      0.499      0.461      0.236

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


        4/5      2.29G      1.465      1.675      1.407         69        640: 100%|██████████| 50/50 [00:21<00:00,  2.28it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 7/7 [00:02<00:00,  2.36it/s]

                   all        200        503      0.652      0.497      0.573      0.329






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


        5/5      2.17G      1.378      1.507      1.352         57        640: 100%|██████████| 50/50 [00:19<00:00,  2.56it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 7/7 [00:04<00:00,  1.72it/s]

                   all        200        503      0.703      0.674      0.726      0.442






5 epochs completed in 0.037 hours.
Optimizer stripped from runs/test/global0/weights/last.pt, 6.2MB
Optimizer stripped from runs/test/global0/weights/best.pt, 6.2MB

Validating runs/test/global0/weights/best.pt...
Ultralytics YOLOv8.2.58 🚀 Python-3.10.12 torch-2.3.1+cu121 CUDA:0 (Tesla T4, 15102MiB)
Model summary (fused): 168 layers, 3,005,843 parameters, 0 gradients, 8.1 GFLOPs


                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 7/7 [00:05<00:00,  1.28it/s]


                   all        200        503      0.703      0.674      0.726      0.442
Speed: 0.2ms preprocess, 3.7ms inference, 0.0ms loss, 6.2ms postprocess per image
Results saved to runs/test/global0
global0 mAP = 0.44167053836123393
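
The run above has `FEDERATED = False`, so only the initial global training is shown. To make the federated branch of the training cell easier to follow, here is a toy sketch of a single round: sample the active clients, "train" each one, keep the best by mAP, and average their weights. It is not the notebook's implementation. It swaps in Python's `random.sample` for NumPy's `rng.choice`, plain lists of floats for model tensors, and a made-up `fake_train` callback for YOLO training, so the names `federated_round` and `fake_train` and the mAP values are illustrative assumptions.

```python
import random

def average_weights(weights):
    """Element-wise mean of a list of dicts mapping names to lists of floats."""
    n = len(weights)
    return {
        key: [sum(w[key][j] for w in weights) / n for j in range(len(weights[0][key]))]
        for key in weights[0]
    }

def federated_round(rng, total_clients, active_clients, best_maps, train_client):
    """Run one round and return (chosen client ids, averaged weights).

    train_client is a hypothetical callback returning (weights, mAP).
    """
    # Select the active clients without replacement
    clients = sorted(rng.sample(range(total_clients), active_clients))
    # Train every selected client: each entry is (client, weights, mAP)
    results = [(c, *train_client(c)) for c in clients]
    # Sort by mAP, highest first, and keep the best ones
    results.sort(key=lambda r: r[2], reverse=True)
    chosen = results[:best_maps]
    return [c for c, _, _ in chosen], average_weights([w for _, w, _ in chosen])

def fake_train(client):
    """Stand-in for training: client i yields weights filled with i."""
    weights = {"layer.weight": [float(client)] * 3}
    return weights, 0.1 * (client + 1)  # made-up mAP, higher for higher ids

rng = random.Random(0)
chosen, avg = federated_round(rng, total_clients=5, active_clients=3,
                              best_maps=2, train_client=fake_train)
print("chosen clients:", chosen)
print("averaged weights:", avg["layer.weight"])
```

Because client `i` contributes a vector filled with `i`, every averaged value equals the mean of the chosen client ids, which makes the result easy to verify by eye.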


## Clean the runtime

Run this cell to restart the runtime and free its memory if needed.

In [None]:
exit()
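
Restarting the runtime discards all Python state, including imported modules and variables. When that is too drastic, a lighter option is to run the garbage collector and ask PyTorch to release its cached GPU memory. The sketch below assumes nothing about the notebook beyond PyTorch possibly being installed; `free_memory` is a hypothetical helper, and it cannot release memory that is still referenced by live objects, so a restart remains the more thorough cleanup.

```python
import gc

def free_memory():
    """Run garbage collection and, if PyTorch with CUDA is present, clear its cache."""
    collected = gc.collect()
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so there is no CUDA cache to clear
        return collected
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return collected

print("unreachable objects collected:", free_memory())
```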

## Zip the runs folder and download it

In [None]:
!zip -r /content/runs.zip /content/runs

from google.colab import files
files.download('/content/runs.zip')
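
`google.colab` exists only inside Colab. When the notebook runs elsewhere, the archive can be built with the standard library instead. This is a minimal sketch with a hypothetical helper `zip_runs`; unlike the `zip -r` command above, it stores paths relative to the parent of the runs folder (`runs/...` rather than `content/runs/...`).

```python
import shutil
from pathlib import Path

def zip_runs(runs_dir, archive_base):
    """Zip runs_dir into archive_base + '.zip' and return the archive path."""
    runs_dir = Path(runs_dir)
    if not runs_dir.is_dir():
        raise FileNotFoundError(f"{runs_dir} does not exist")
    return shutil.make_archive(str(archive_base), "zip",
                               root_dir=runs_dir.parent, base_dir=runs_dir.name)

# Only attempt the archive when the Colab runs folder is actually present
if Path("/content/runs").is_dir():
    print(zip_runs("/content/runs", "/content/runs"))
```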