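First, we split the training set (`data/train.csv`) into three roughly equal subsets, one per client. This only needs to be done once: the split is skipped when all three client files already exist. We use a fixed seed to ensure consistent results.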

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from pathlib import Path
import time

from utils.pytorch_models import ResNet18
from models.poolformer import create_poolformer_s12
from models.ConvMixer import create_convmixer_1024_20
from models.MLPMixer import create_mlp_mixer
from utils.clients import GlobalClient
from utils.pytorch_utils import start_cuda

In [2]:
# Destination files for the three per-client training subsets.
c1_path = Path("data/c1_train.csv")
c2_path = Path("data/c2_train.csv")
c3_path = Path("data/c3_train.csv")

# Only (re)create the split when at least one subset file is missing, so
# re-running the notebook reuses the files that are already on disk.
if not (c1_path.exists() and c2_path.exists() and c3_path.exists()):
    # The source CSV is read without a header row.
    df = pd.read_csv("data/train.csv", header=None)

    # Fixed seed so that both splits below are reproducible.
    seed = 42

    # Keep one third of the rows for client 1, then halve the remaining two
    # thirds between clients 2 and 3.
    df_c1, temp = train_test_split(df, test_size=(2/3), random_state=seed)
    df_c2, df_c3 = train_test_split(temp, test_size=0.5, random_state=seed)

    # Write to the same Path objects that were checked above, so the
    # existence test and the output locations cannot drift apart.
    # header=False keeps the output header-less, like the input file.
    df_c1.to_csv(c1_path, index=False, header=False)
    df_c2.to_csv(c2_path, index=False, header=False)
    df_c3.to_csv(c3_path, index=False, header=False)

In [3]:
# --- Model input/output dimensions -----------------------------------------
# Passed to every model constructor in the training sections.
num_classes = 19
channels = 10

# --- Training schedule ------------------------------------------------------
# Both values are forwarded to GlobalClient.train(...) in each training cell.
communication_rounds = 20
epochs = 10

# --- Hardware / data loading ------------------------------------------------
# NOTE(review): these three are not referenced by the cells shown here;
# presumably they are consumed elsewhere (e.g. by start_cuda or a data
# loader) -- confirm before removing or changing them.
num_workers = 0
batch_size = 128
cuda_no = 1

# NOTE(review): also not referenced by the cells shown here -- confirm its use.
dataset_filter = "serbia"

## Train ConvMixer

In [None]:
# Build an untrained ConvMixer (pretrained=False) from the configured
# `channels` and `num_classes`.
convmixer = create_convmixer_1024_20(channels, num_classes, pretrained=False)

# Wrap the model in a GlobalClient together with the LMDB data path, the
# three per-client training CSVs and the CSV used as `val_path`.
# NOTE(review): `val_path` points at data/test.csv -- confirm that validating
# on the test split is intended.
global_client_convmixer = GlobalClient(
    lmdb_path="data/BigEarth_Serbia_Summer_S2.lmdb",
    csv_paths=[
        "data/c1_train.csv",
        "data/c2_train.csv",
        "data/c3_train.csv",
    ],
    val_path="data/test.csv",
    model=convmixer,
)

In [None]:
# Train with the configured number of communication rounds and epochs.
# `train` returns two values, unpacked here; judging by the names they are
# the global results and the per-client results -- TODO confirm.
(
    global_convmixer_results,
    global_convmixer_client_results,
) = global_client_convmixer.train(
    communication_rounds=communication_rounds,
    epochs=epochs,
)


## Train PoolFormer

In [None]:
# Build a PoolFormer-S12 from the configured `channels` (as `in_chans`) and
# `num_classes`.
poolformer_s12 = create_poolformer_s12(
    in_chans=channels,
    num_classes=num_classes,
)

# Wrap the model in a GlobalClient together with the LMDB data path, the
# three per-client training CSVs and the CSV used as `val_path`.
# NOTE(review): `val_path` points at data/test.csv -- confirm that validating
# on the test split is intended.
global_client_poolformer = GlobalClient(
    lmdb_path="data/BigEarth_Serbia_Summer_S2.lmdb",
    csv_paths=[
        "data/c1_train.csv",
        "data/c2_train.csv",
        "data/c3_train.csv",
    ],
    val_path="data/test.csv",
    model=poolformer_s12,
)

In [None]:
# Train with the configured number of communication rounds and epochs.
# `train` returns two values, unpacked here; judging by the names they are
# the global results and the per-client results -- TODO confirm.
(
    global_poolformer_results,
    global_poolformer_client_results,
) = global_client_poolformer.train(
    communication_rounds=communication_rounds,
    epochs=epochs,
)

## Train ResNet18

In [4]:
# Build an untrained ResNet18 (pretrained=False) from the configured
# `num_classes` and `channels`.
resnet18 = ResNet18(
    num_cls=num_classes,
    channels=channels,
    pretrained=False,
)

# Wrap the model in a GlobalClient together with the LMDB data path, the
# three per-client training CSVs and the CSV used as `val_path`.
# NOTE(review): `val_path` points at data/test.csv -- confirm that validating
# on the test split is intended.
global_client_resnet18 = GlobalClient(
    lmdb_path="data/BigEarth_Serbia_Summer_S2.lmdb",
    csv_paths=[
        "data/c1_train.csv",
        "data/c2_train.csv",
        "data/c3_train.csv",
    ],
    val_path="data/test.csv",
    model=resnet18,
)



In [5]:
# Train with the configured number of communication rounds and epochs.
# `train` returns two values, unpacked here; judging by the names they are
# the global results and the per-client results -- TODO confirm.
(
    global_resnet18_results,
    global_resnet18_client_results,
) = global_client_resnet18.train(
    communication_rounds=communication_rounds,
    epochs=epochs,
)

Round 1/20
----------
Epoch 1/10
----------


training:   0%|          | 0/21 [00:00<?, ?it/s]

training: 100%|██████████| 21/21 [00:15<00:00,  1.35it/s]


Epoch 2/10
----------


training:  14%|█▍        | 3/21 [00:01<00:09,  1.80it/s]

## Train MLP-Mixer

In [None]:
# Build an MLP-Mixer from the configured `channels` and `num_classes`.
mlp_mixer = create_mlp_mixer(channels, num_classes)

# Wrap the model in a GlobalClient together with the LMDB data path, the
# three per-client training CSVs and the CSV used as `val_path`.
# NOTE(review): `val_path` points at data/test.csv -- confirm that validating
# on the test split is intended.
global_client_mlp_mixer = GlobalClient(
    lmdb_path="data/BigEarth_Serbia_Summer_S2.lmdb",
    csv_paths=[
        "data/c1_train.csv",
        "data/c2_train.csv",
        "data/c3_train.csv",
    ],
    val_path="data/test.csv",
    model=mlp_mixer,
)

In [None]:
# Train with the configured number of communication rounds and epochs.
# `train` returns two values, unpacked here; judging by the names they are
# the global results and the per-client results -- TODO confirm.
(
    global_mlp_mixer_results,
    global_mlp_mixer_client_results,
) = global_client_mlp_mixer.train(
    communication_rounds=communication_rounds,
    epochs=epochs,
)