In [None]:
import sys
import os

def running_in_colab():
    """Return True when the notebook appears to be running on Google Colab.

    Two signals are accepted: the ``google.colab`` module is already loaded
    in this interpreter, or a ``/content`` directory exists on disk.
    """
    if 'google.colab' in sys.modules:
        return True
    return os.path.exists('/content')

# GitHub coordinates of the project repository to clone on Colab.
username = "giovanna-brod-zamojska"
repo = "federated-learning-project"
branch = "main"

# When True, the clone step prompts for a GitHub token.
is_private = True


def _run_and_echo(cmd, label, secret=None):
    """Run ``cmd`` (an argv list), print its combined output, raise on failure.

    Args:
        cmd: Command as a list of arguments; no shell is involved.
        label: Short human-readable name used in the error message. The full
            argv is deliberately not echoed because it may contain ``secret``.
        secret: Optional string masked as ``***`` in the printed output.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
    """
    import subprocess

    result = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )
    output = result.stdout or ""
    if secret:
        output = output.replace(secret, "***")
    if output:
        print(output)
    if result.returncode != 0:
        raise RuntimeError(f"{label} failed with exit code {result.returncode}")


def clone_repo_if_needed(exists_ok: bool, username: str, repository: str, is_private: bool, branch: str = None):
    """Clone the project repository into ``/content`` on Colab and install its requirements.

    Outside Colab this only prints a message. On Colab:

    * if ``/content/<repository>/`` exists and ``exists_ok`` is true, nothing
      is done (requirements are not re-installed);
    * otherwise any existing copy is removed, the repository is cloned into
      ``/content/<repository>/`` and ``colab-requirements.txt`` from the clone
      is installed with pip.

    Args:
        exists_ok: Reuse an existing clone instead of replacing it.
        username: GitHub account that owns the repository.
        repository: Repository name; also the directory name under ``/content``.
        is_private: Prompt for a GitHub token and authenticate the clone with it.
        branch: Optional branch to check out; the remote default when falsy.

    Raises:
        RuntimeError: If ``git clone`` or ``pip install`` exits with an error.
    """
    import shutil

    if not running_in_colab():
        print("Not running in Google Colab. Skipping repository cloning.")#
        return

    colab_repo_path = f'/content/{repository}/'

    if os.path.exists(colab_repo_path):
        if exists_ok:
            print(f"Repository already exists at {colab_repo_path}")
            return
        # Remove any existing repo
        print(f"Removing content of {colab_repo_path}")
        shutil.rmtree(colab_repo_path, ignore_errors=True)

    # Show the actual directory listing (not a shell exit status).
    print("Current directory files and folders:", sorted(os.listdir()))

    print("Cloning GitHub repo...")

    # Build the URL from the `repository` argument, not from a notebook global.
    public_url = f"https://github.com/{username}/{repository}.git"
    clone_url = public_url
    token = None
    if is_private:
        from getpass import getpass

        # Prompt for GitHub token (ensure token has access to the repo)
        token = getpass('Enter GitHub token: ')
        # NOTE(review): git stores this URL, token included, as the clone's
        # `origin` remote; treat the Colab filesystem as sensitive.
        clone_url = f"https://{username}:{token}@github.com/{username}/{repository}.git"

    clone_cmd = ["git", "clone"]
    if branch:
        clone_cmd += ["--branch", branch]
    # Explicit target directory: the clone must end up where the rest of this
    # function (and the sys.path setup) looks for it, whatever the cwd is.
    clone_cmd += [clone_url, colab_repo_path]
    _run_and_echo(clone_cmd, f"git clone of {public_url}", secret=token)

    requirements_path = os.path.join(colab_repo_path, "colab-requirements.txt")
    if not os.path.isfile(requirements_path):
        print(f"No requirements file found at {requirements_path}; skipping pip install.")
        return
    # Use the kernel's own interpreter so packages land in its environment.
    _run_and_echo(
        [sys.executable, "-m", "pip", "install", "-r", requirements_path],
        f"pip install -r {requirements_path}",
    )



def setup_notebook(repo_root_name: str = "federated-learning-project"):
    """Put the project root at the front of ``sys.path`` so its modules import.

    On Colab the root is ``/content/<repo_root_name>/``. Elsewhere it is taken
    to be two directory levels above the current working directory. Nothing is
    inserted when the path is already present.
    """
    import sys
    from pathlib import Path

    if not running_in_colab():
        # Local run: the project root is the grandparent of the working dir.
        project_root = Path().absolute().parent.parent
        if str(project_root) in sys.path:
            return
        sys.path.insert(0, str(project_root))
        print(f"Added {project_root} to Python path")
        return

    print("Sys.path: ", sys.path)

    # Colab run: the repository lives directly under /content.
    colab_repo_path = f'/content/{repo_root_name}/'
    if str(colab_repo_path) in sys.path:
        return
    sys.path.insert(0, colab_repo_path)
    print(f"Added {colab_repo_path} to sys.path")

        
# Fetch the repository when on Colab; an existing clone is reused as-is.
clone_repo_if_needed(
    exists_ok=True,
    username=username,
    repository=repo,
    is_private=is_private,
    branch=branch,
)

# Make the project's modules importable from this notebook.
setup_notebook()

    

In [None]:
from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms

import flwr
from flwr.client import ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner
from flwr.common import Context

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on CPU-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
# Print the installed Flower and PyTorch versions.
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")


# --- Config ---

# Reference setup noted for this experiment: FedAvg with I.I.D. sharding
#
# - K = 100: number of clients (fixed)
# - C = 0.1: fraction of clients used per round (fixed)
# - J = 4: number of local steps (fixed)
# - Nc = 100: number of labels for each client (fixed); i.i.d. means each
#   client receives an equal number of labels
# - number of rounds: "proper" value is up to us to define, based on
#   convergence and the time/compute budget
#   Open question: with C = 10% and K = 100, should there be at least 10
#   rounds, each one using 10 clients different from the previous round?
#
# NOTE(review): the values below do not match the reference setup above:
# PARTITION_TYPE is "dirichlet" (not I.I.D.) and LOCAL_STEPS is 5 (not J = 4).
# Confirm which configuration is intended before reporting results.

SEED = 42

# Federation
NUM_CLIENTS = 100  # K
NC = 100  # Nc
NUM_ROUNDS = 2
CLIENT_FRACTION_PER_ROUND = 0.1  # C
PARTITION_TYPE = "dirichlet"  # "iid" or "dirichlet"

# Data loading
BATCH_SIZE = 32
# This name used to be assigned twice (2, then 4); only the last assignment
# ever took effect, so that value is kept.
NUM_WORKERS = 4

# Local optimisation (SGD)
LR = 0.05
MOM = 0.9
WD = 0.0005
LOCAL_STEPS = 5  # J

# --- Dataset partitioning ---
# Build the partitioner that splits the training set across NUM_CLIENTS clients.
# An unrecognised PARTITION_TYPE is rejected instead of silently falling back
# to the Dirichlet split.
if PARTITION_TYPE == "iid":
    partitioner = IidPartitioner(num_partitions=NUM_CLIENTS)
elif PARTITION_TYPE == "dirichlet":
    partitioner = DirichletPartitioner(
        num_partitions=NUM_CLIENTS, alpha=0.5, partition_by="fine_label"
    )
else:
    raise ValueError(
        f"Unknown PARTITION_TYPE {PARTITION_TYPE!r}; expected 'iid' or 'dirichlet'"
    )


def load_datasets(partition_id: int):
    """Build the data loaders for one client partition of CIFAR-100.

    The partition selected by ``partition_id`` is split 80/20 into a local
    training and validation set; the dataset's ``test`` split is loaded
    unpartitioned.

    Returns:
        ``(trainloader, valloader, testloader, num_classes)`` where
        ``num_classes`` is read from the ``fine_label`` feature.
    """
    federated = FederatedDataset(
        dataset="uoft-cs/cifar100",
        partitioners={"train": partitioner},
        seed=SEED,
    )
    client_data = federated.load_partition(partition_id)
    global_test = federated.load_split("test")

    # Divide data on each node: 80% train, 20% test
    local_splits = client_data.train_test_split(test_size=0.2, seed=SEED)
    num_classes = local_splits["train"].features["fine_label"].num_classes

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # Training images are augmented; evaluation images are only resized.
    augmenting = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    deterministic = transforms.Compose(
        [transforms.Resize(224), transforms.ToTensor(), normalize]
    )

    def batch_mapper(pipeline):
        """Return a function applying ``pipeline`` to every image of a batch."""

        def _apply(batch):
            batch["img"] = [pipeline(img) for img in batch["img"]]
            return batch

        return _apply

    def make_loader(split, pipeline, shuffle):
        """Wrap ``split`` in a DataLoader that transforms images on access."""
        return DataLoader(
            split.with_transform(batch_mapper(pipeline)),
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
        )

    trainloader = make_loader(local_splits["train"], augmenting, True)
    valloader = make_loader(local_splits["test"], deterministic, False)
    testloader = make_loader(global_test, deterministic, False)

    return trainloader, valloader, testloader, num_classes


def load_dino_model(num_classes: int = 100):
    """Load DINO ViT-S/16 from torch hub and prepare it as a linear probe.

    The model's ``head`` is replaced by a new ``nn.Linear(384, num_classes)``,
    the model is moved to CUDA when available (CPU otherwise), and every
    parameter except those of the new head is frozen.

    Returns:
        ``(model, device)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vit = torch.hub.load("facebookresearch/dino:main", "dino_vits16")
    vit.head = nn.Linear(384, num_classes)
    vit.to(device)
    print(f"DINO ViT-S/16 model instantiated. Using device: {device}")

    # Freeze everything, then re-enable gradients for the head only.
    vit.requires_grad_(False)
    vit.head.requires_grad_(True)

    return vit, device


def get_parameters(model) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``model`` as a NumPy array, in order."""
    arrays = []
    for tensor in model.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(model, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``model``, matching them to ``state_dict`` keys by position.

    Args:
        model: Module whose ``state_dict`` defines the key order.
        parameters: One array per ``state_dict`` entry, in the same order.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when keys are
            missing (too few arrays) or shapes do not match.
    """
    keys = model.state_dict().keys()
    # torch.tensor keeps each array's dtype and shape. The legacy
    # torch.Tensor(...) constructor used before always produces float32 and
    # mishandles 0-d arrays (e.g. BatchNorm's num_batches_tracked).
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    model.load_state_dict(state_dict, strict=True)


# Flower client implementation
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates a model on one partition.

    The model is expected to expose a ``head`` sub-module and to return
    backbone features from ``model(x)``; logits are obtained as
    ``model.head(model(x))`` in both training and evaluation.

    Batches are dicts with an ``"img"`` tensor and a ``"fine_label"`` tensor.
    """

    def __init__(
        self,
        cid: str,
        train_loader: DataLoader,
        val_loader: DataLoader,
        model,
        device,
    ):
        """Store the loaders and model, and build the loss, optimizer and LR schedule.

        Only parameters with ``requires_grad=True`` are handed to SGD. The
        cosine schedule spans ``LOCAL_STEPS`` optimizer steps.
        """
        self.cid = cid
        self.model = model
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        # Batches are moved to `device`; keep the model there as well.
        self.model.to(self.device)
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = SGD(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=LR,
            momentum=MOM,
            weight_decay=WD,
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=LOCAL_STEPS)
        print(f"Client {cid} instantiated.")

    def _forward(self, x):
        """Return class logits: backbone features passed through ``model.head``."""
        return self.model.head(self.model(x))

    def get_parameters(self, config: Dict = None) -> List:
        """Return all ``state_dict`` entries of the model as NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters: List, config: Dict = None) -> None:
        """Load ``parameters`` into the model, matched to ``state_dict`` keys by position."""
        print(f"[Client {self.cid}] set_parameters, config: {config}")
        params_dict = zip(self.model.state_dict().keys(), parameters)
        # torch.tensor preserves dtype/shape of each array (including 0-d
        # buffers), unlike the legacy float-only torch.Tensor constructor.
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in params_dict)
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: List, config: Dict = None) -> Tuple[List, int, Dict]:
        """Run exactly ``LOCAL_STEPS`` SGD steps starting from ``parameters``.

        The training loader is cycled when it holds fewer than ``LOCAL_STEPS``
        batches. An empty loader results in zero steps.

        Returns:
            ``(updated parameters, size of the training dataset, {})``.
        """
        print(f"[Client {self.cid}] fit, config: {config}")

        self.set_parameters(parameters)
        self.model.train()

        steps_done = 0
        correct, total = 0, 0

        while steps_done < LOCAL_STEPS:
            steps_before_pass = steps_done
            for batch in self.train_loader:
                x = batch["img"].to(self.device)
                y = batch["fine_label"].to(self.device)

                outputs = self._forward(x)
                loss = self.loss_fn(outputs, y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                # Running training accuracy, for logging only.
                correct += outputs.argmax(dim=1).eq(y).sum().item()
                total += y.size(0)

                steps_done += 1
                # Stop mid-epoch: the budget is counted in steps, not epochs,
                # and the LR schedule is only defined for LOCAL_STEPS steps.
                if steps_done >= LOCAL_STEPS:
                    break
            if steps_done == steps_before_pass:
                # The loader yielded nothing; avoid spinning forever.
                break

        accuracy = correct / total if total else 0.0
        print("Accuracy:", accuracy)

        return self.get_parameters(), len(self.train_loader.dataset), {}

    def evaluate(
        self, parameters: List, config: Dict = None
    ) -> Tuple[float, int, Dict]:
        """Evaluate ``parameters`` on the local validation set.

        Returns:
            ``(mean cross-entropy loss, number of examples, {"accuracy": ...})``.
            With an empty validation loader this is ``(0.0, 0, {"accuracy": 0.0})``.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        self.set_parameters(parameters)
        self.model.eval()

        loss_sum, correct, total = 0.0, 0, 0

        with torch.no_grad():
            for batch in self.val_loader:
                x = batch["img"].to(self.device)
                y = batch["fine_label"].to(self.device)

                # Same forward path as in fit: the head must be applied here
                # too, otherwise argmax runs over backbone features.
                outputs = self._forward(x)
                batch_size = y.size(0)
                # loss_fn averages over the batch; weight by batch size so the
                # final value is a per-example mean.
                loss_sum += self.loss_fn(outputs, y).item() * batch_size
                correct += (outputs.argmax(dim=1) == y).sum().item()
                total += batch_size

        if total == 0:
            return 0.0, 0, {"accuracy": 0.0}

        accuracy = correct / total
        return float(loss_sum / total), total, {"accuracy": accuracy}


# --- Client factory ---


def client_fn(context):
    """Build the Flower client for the partition named in ``context.node_config``.

    Reads ``"partition-id"`` from the node config, loads that partition's
    train/validation loaders, instantiates the DINO model, and wraps
    everything in a ``FlowerClient``.
    """
    cid = context.node_config["partition-id"]

    train_loader, val_loader, _, num_classes = load_datasets(partition_id=cid)

    # Use the device the model was actually placed on, so that model and
    # batches can never end up on different devices.
    model, device = load_dino_model(num_classes=num_classes)

    return FlowerClient(
        str(cid),
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        device=device,
    ).to_client()


# --- Server factory ---


# --- Server metrics ---
def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate client ``"accuracy"`` values, weighted by example count.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs, one per client.

    Returns:
        ``{"accuracy": weighted mean}``, or an empty dict when the clients
        reported no examples at all.
    """
    total = sum(num for num, _ in metrics)
    if total == 0:
        return {}
    acc = sum(num * m["accuracy"] for num, m in metrics) / total
    return {"accuracy": acc}


def server_fn(context: Context):
    """Create the FedAvg strategy and server config for the simulation.

    A fraction ``CLIENT_FRACTION_PER_ROUND`` of the clients is sampled for
    both fit and evaluate in each of the ``NUM_ROUNDS`` rounds. Client
    accuracies are aggregated with ``weighted_average``.
    """
    print("Calling server_fn")

    strategy = FedAvg(
        fraction_fit=CLIENT_FRACTION_PER_ROUND,
        fraction_evaluate=CLIENT_FRACTION_PER_ROUND,
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=2,
        evaluate_fn=None,
        # Without this the "accuracy" metric returned by clients is dropped.
        evaluate_metrics_aggregation_fn=weighted_average,
    )

    config = ServerConfig(num_rounds=NUM_ROUNDS)

    return ServerAppComponents(strategy=strategy, config=config)


# --- Run ---

# Wrap the client and server factories in Flower apps.
server_app = ServerApp(server_fn=server_fn)
client_app = ClientApp(client_fn=client_fn)

# Resources reserved for each simulated client. With None, Flower's default
# applies (2x CPU and 0x GPUs per client); on CUDA each client gets one GPU
# and one CPU.
backend_config = {
    "client_resources": (
        {"num_gpus": 1, "num_cpus": 1} if DEVICE.type == "cuda" else None
    )
}

run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
    verbose_logging=True,
)