In [None]:
import random

import numpy as np
import torch
import torch.nn as nn

from splitters import get_imbalanced_client_loaders
from states import ClientState, AdamClientState
from LogitsTransmition.src.src.algorithms.compresors import CompressedModel, IdenticalCompressor, ExpDitheringCompressor, RandKCompressor, TopKCompressor
from validation_utils import total_evaluation, evaluate
from algorithms.feen import feen_training_step
from algorithms.fedavg import fed_avg_training_step
from models import ResNet18

In [2]:
# Experiment configuration: number of federated clients and the hyper-parameters
# forwarded to AdamClientState (their meaning is defined in states.py).
num_clients = 10
gamma = 1e-4
lmbd = 0.1

# Seed every RNG we can reach so weight init and (presumably) the client split /
# loader shuffling are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Preferred GPU. Guard on device_count() as well as is_available(): on a machine with
# fewer than GPU_INDEX + 1 GPUs, "cuda:7" would fail with an invalid device ordinal.
GPU_INDEX = 7
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > GPU_INDEX:
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cuda")

# One ResNet18 per client, each wrapped in CompressedModel with a TopKCompressor.
# NOTE(review): num_clients is passed as ResNet18's first argument (presumably the
# number of output classes); this only lines up while num_clients equals the
# dataset's class count — confirm against models.py.
models = [CompressedModel(ResNet18(num_clients), TopKCompressor()).to(device) for _ in range(num_clients)]
# NOTE(review): client_optimizers is not used by any cell visible here.
client_optimizers = [torch.optim.Adam(model.parameters(), lr=1e-5) for model in models]
clients_states = [AdamClientState(model, lmbd, gamma) for model in models]
criterion = nn.CrossEntropyLoss()


In [3]:
# Split the data across clients. Each client is limited to CLASSES_PER_CLIENT classes
# (per the splitter's keyword); the splitter also returns a shared `common_loader`
# and a `validation_loader`.
# NOTE(review): assumes the four loaders come back in exactly this order — confirm in splitters.py.
CLASSES_PER_CLIENT = 3
(
    local_train_loader,
    local_test_loader,
    common_loader,
    validation_loader,
) = get_imbalanced_client_loaders(num_clients, classes_per_client=CLASSES_PER_CLIENT)

In [4]:
# Centralised baseline: a single ResNet18 trained only on the shared `common_loader`.
# NOTE(review): num_clients is passed as ResNet18's first argument (presumably the
# number of classes) — confirm this is intended.
model = ResNet18(num_clients)
model = model.to(device)
# lr=1e-3 is Adam's default, written out explicitly for the reader.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [5]:
# Train the centralised baseline on `common_loader` for 10 epochs,
# printing validation accuracy after each epoch.
for _ in range(10):
    # Re-enter training mode every epoch. NOTE(review): `evaluate` may leave the
    # model in eval mode (frozen BatchNorm stats / disabled Dropout); if it does
    # not, this call is a no-op.
    model.train()
    for inputs, target in common_loader:
        inputs = inputs.to(device)
        target = target.to(device)

        logits = model(inputs)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Accuracy: {evaluate(model, validation_loader, device)}")

Accuracy0.1297
Accuracy0.0938
Accuracy0.1002
Accuracy0.1291
Accuracy0.1458
Accuracy0.1266
Accuracy0.1361
Accuracy0.1842
Accuracy0.2114
Accuracy0.2336


In [6]:
# Continue training the centralised baseline on `common_loader` for another
# 10 epochs, printing validation accuracy after each epoch.
for _ in range(10):
    # Re-enter training mode every epoch. NOTE(review): `evaluate` may leave the
    # model in eval mode (frozen BatchNorm stats / disabled Dropout); if it does
    # not, this call is a no-op.
    model.train()
    for inputs, target in common_loader:
        inputs = inputs.to(device)
        target = target.to(device)

        logits = model(inputs)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Accuracy: {evaluate(model, validation_loader, device)}")

Accuracy0.2329
Accuracy0.2538
Accuracy0.2819
Accuracy0.2499
Accuracy0.2922
Accuracy0.2922
Accuracy0.286
Accuracy0.3057
Accuracy0.3059
Accuracy0.2995


In [7]:
# Continue training the centralised baseline on `common_loader` for a further
# 10 epochs, printing validation accuracy after each epoch.
for _ in range(10):
    # Re-enter training mode every epoch. NOTE(review): `evaluate` may leave the
    # model in eval mode (frozen BatchNorm stats / disabled Dropout); if it does
    # not, this call is a no-op.
    model.train()
    for inputs, target in common_loader:
        inputs = inputs.to(device)
        target = target.to(device)

        logits = model(inputs)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Accuracy: {evaluate(model, validation_loader, device)}")

Accuracy0.3304
Accuracy0.3246
Accuracy0.3201
Accuracy0.3324
Accuracy0.3336
Accuracy0.3389
Accuracy0.3436
Accuracy0.3281
Accuracy0.3335
Accuracy0.3186
