In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [7]:
df = pd.read_csv("data/train.csv", header=None)

seed = 42

df_c1, temp = train_test_split(df, test_size=0.7, random_state=seed)
df_c2, df_c3 = train_test_split(df, test_size=0.5, random_state=seed)

df_c1.to_csv("data/c1_train.csv", index=False, header=None)
df_c2.to_csv("data/c2_train.csv", index=False, header=None)
df_c3.to_csv("data/c3_train.csv", index=False, header=None)


In [11]:
cuda_no = 1
batch_size = 128
num_workers = 0
epochs = 1

channels = 10
num_classes = 19
dataset_filter = "serbia"

In [29]:
import copy
from utils.pytorch_datasets import Ben19Dataset
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

class FLCLient:
    def __init__(self, model, lmdb_path, csv_path) -> None:
        self.model = model
        self.dataset = Ben19Dataset(lmdb_path, csv_path)
        self.train_loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True,
            pin_memory=True,
        )
        
    def set_model(self, model):
        self.model = copy.deepcopy(model)
        
    def train_one_round(self):
        state_before = self.model.state_dict()
        
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0)
        criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
        
        for epoch in range(1, epochs + 1):
            print("Epoch {}/{}".format(epoch, epochs))
            print("-" * 10)

            self.train_epoch(criterion, optimizer)
        
        state_after = self.model.state_dict()
        
        model_update = {}
        for key, value_before in state_before.items():
            value_after = state_after[key]
            diff = value_after - value_before
            model_update[key] = diff
        
        return model_update
    
    def train_epoch(self, criterion, optimizer):
        self.model.train()
        for idx, batch in enumerate(tqdm(self.train_loader, desc="training")):
            data, labels, index = batch["data"], batch["label"], batch["index"]
            data = data
            labels = labels
            optimizer.zero_grad()

            logits = self.model(data)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            break

In [40]:
import numpy as np
import copy
from utils.pytorch_utils import init_results, get_classification_report, update_results, print_micro_macro

class GlobalClient:
    def __init__(self, model, lmdb_path, csv_paths, val_path) -> None:
        self.model = model
        self.clients = [
            FLCLient(copy.deepcopy(self.model), lmdb_path, csv_path) for csv_path in csv_paths
        ]
        self.validation_set = Ben19Dataset(
            lmdb_path=lmdb_path, csv_path=val_path, img_transform="default"
        )
        self.val_loader = DataLoader(
            self.validation_set,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
        )
    
    def train(self, communication_rounds):
        results = init_results(num_classes)
        
        for com_round in range(1, communication_rounds + 1):
            print("Round {}/{}".format(com_round, communication_rounds))
            print("-" * 10)
            
            self.communication_round()
            report = self.validation_round()
            
            results = update_results(results, report, num_classes)
            print_micro_macro(report)
            
            for client in self.clients:
                client.set_model(self.model)
        
        return results
    
    def validation_round(self):
        self.model.eval()
        y_true = []
        predicted_probs = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(self.val_loader, desc="test")):
                data = batch["data"]
                labels = batch["label"].numpy()

                logits = self.model(data)
                probs = torch.sigmoid(logits).cpu().numpy()

                predicted_probs += list(probs)

                y_true += list(labels)

        predicted_probs = np.asarray(predicted_probs)
        y_predicted = (predicted_probs >= 0.5).astype(np.float32)

        y_true = np.asarray(y_true)
        report = get_classification_report(
            y_true, y_predicted, predicted_probs, dataset_filter
        )
        return report
    
    def communication_round(self):
        # here the clients train
        # TODO: could be parallelized
        model_updates = [client.train_one_round() for client in self.clients]
        
        # parameter aggregation
        update_aggregation = {}
        for key in model_updates[0].keys():
            params = torch.stack([update[key] for update in model_updates], dim=0)
            avg = torch.mean(params, dim=0)
            update_aggregation[key] = avg
        
        # update the global model
        self.model.load_state_dict(update_aggregation)
        # self.model.load_state_dict(model_updates[0])
        

In [41]:
from utils.pytorch_models import ResNet18

model = ResNet18(num_cls=num_classes, channels=channels, pretrained=True)

global_client = GlobalClient(model, "data/BigEarth_Serbia_Summer_S2.lmdb", ["data/c1_train.csv", "data/c2_train.csv", "data/c3_train.csv"], "data/test.csv")
results = global_client.train(1)



Round 1/1
----------
Epoch 1/1
----------


training:   0%|          | 0/19 [00:06<?, ?it/s]


Epoch 1/1
----------


training:   0%|          | 0/31 [00:06<?, ?it/s]


Epoch 1/1
----------


training:   0%|          | 0/31 [00:06<?, ?it/s]


RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long