In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

BATCH_SIZE = 512

In [4]:
from collections import defaultdict
import random 

def get_client_loaders(num_clients=3):
    cifar_train = datasets.CIFAR10(root='./data_cifar', train=True,
                                   download=True, transform=transform)
    client_data = []
    samples_per_client = len(cifar_train) // num_clients
    for i in range(num_clients):
        indices = range(i * samples_per_client, (i+1) * samples_per_client)
        client_data.append(Subset(cifar_train, indices))
    
    client_loaders = [DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False) for data in client_data]
    return client_loaders

def get_imbalanced_client_loaders(num_clients=3):
    client_data = []
    cifar_train = datasets.CIFAR10(root='./data_cifar', train=True,
                                   download=True, transform=transform)
    cifar_common = Subset(cifar_train, range(5000))
    cifar_train = Subset(cifar_train, range(5000, len(cifar_train)))
    
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(cifar_train):
        class_indices[label].append(idx)
    for i in range(num_clients):
        client_indices = []
        client_classes = random.sample(range(10), 7)
        for obj_class in client_classes:
            client_indices += random.sample(class_indices[obj_class], 50)
        client_data.append(Subset(cifar_train, client_indices))
    client_loaders = [DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False) for data in client_data]
    common_loader = DataLoader(cifar_common, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
    return client_loaders, common_loader


In [5]:
def client_train(model, loader, epochs=1):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

In [6]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            _, predicted = torch.max(logits, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

In [7]:
num_clients = 10

In [8]:
validation_loder = DataLoader(datasets.CIFAR10(root='./data_cifar', train=False, download=True, transform=transform), batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

In [9]:
client_models = [SimpleCNN().to(device) for _ in range(num_clients)]
client_loaders, common_loader = get_imbalanced_client_loaders(num_clients)

In [10]:
class ClientState:
    def __init__(self, model, lmbd = 0.1, gamma = 0.1):
        self.weights = []
        self.state = [torch.zeros_like(p) for p in model.parameters()]
        self.prev_state = [torch.zeros_like(p) for p in model.parameters()]
        self.lmbd = lmbd
        self.gamma = gamma
    
    def local_step(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] += param.grad - self.prev_state[idx]
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)
    
    def global_step(self, model):
        for idx, param in enumerate(model.parameters()):
            self.state[idx] = param.grad
            self.prev_state[idx] = param.grad
            param.grad = torch.zeros_like(param.grad)

    def set_weights(self, model):
        self.weights = [p for p in model.parameters()]

    def get_reg_term(self, model):
        l2_regularization = torch.tensor(0., requires_grad=True)
        for idx, (param, x, g) in enumerate(zip(model.parameters(), self.weights, self.state)):
            l2_regularization = l2_regularization + torch.norm(param - (x - self.gamma * self.lmbd * g), p=2)
        return l2_regularization * 0.5


In [11]:
from torch.utils.tensorboard import SummaryWriter
import torch.utils.tensorboard as tensorboard

writer = SummaryWriter()

In [12]:
def comp_top_k(logits, k = 3):
    size = logits.size()
    _, indices = torch.topk(logits, k, dim= -1)
    sparse_logits = torch.gather(logits, dim=-1, index=indices).to(logits.device)
    return sparse_logits, indices, size

def decomp_top_k(sparse_logits, indices, size):
    decomp_logits  = torch.zeros(size, dtype= sparse_logits.dtype, device= sparse_logits.device)
    decomp_logits.scatter_(-1, indices, sparse_logits)
    return decomp_logits

In [None]:
def dithering(logits, b = 2, s = 8, p = 2):
    norm_logits = torch.norm(logits.float(), p, dim=-1, keepdim=True)
    normalized_logits = logits.abs() / norm_logits
    levels = b ** -torch.arange(1, s + 1, dtype=logits.dtype, device=logits.device)  
    levels = torch.cat([torch.tensor([1.0], device=logits.device, dtype=logits.dtype), levels, torch.tensor([0.0], device=logits.device, dtype=logits.dtype)])
    u = torch.sum(normalized_logits.unsqueeze(-1) <= levels.unsqueeze(0), dim=-1, dtype= float) 
    lower = b ** -u
    upper = b ** -(u - 1)
    lower = torch.where(u == 10, 0.0, lower)
    probs_upper = (normalized_logits - lower) / (upper - lower)
    bern = torch.bernoulli(probs_upper)
    quantized = torch.where(bern == 1, upper, lower)
    result = norm_logits * torch.sign(logits) * quantized
    return result

In [None]:
def comp_rand_k(logits, k = 3):
    seed = random.randint(0, 1000)
    torch.manual_seed(seed)
    size = logits.size()
    new_size = size[:-1] + (k,)
    sparse_logits = torch.zeros(new_size, dtype= logits.dtype)
    for i in range(size[0]):
        for j in range(size[1]):
            indices = torch.randperm(size[-1], device=logits.device)[:k]
            indices, _ = torch.sort(indices)
            sparse_logits[i, j] = logits[i, j, indices]
    sparse_logits *= size[-1] / k
    return sparse_logits, seed, size

def decomp_rand_k(sparse_logits, seed, size, k = 3):
    torch.manual_seed(seed)
    decomp_logits  = torch.full(size, float('-inf') ,dtype= sparse_logits.dtype)
    for i in range(size[0]):
        for j in range(size[1]):
            indices = torch.randperm(size[-1], device=sparse_logits.device)[:k]
            indices, _ = torch.sort(indices)
            decomp_logits[i, j, indices] = sparse_logits[i, j]
    return decomp_logits

In [13]:
import random

clients_states = [ClientState(model) for model in client_models]
client_optimizers = [
    optim.Adam(model.parameters())
    for model in client_models
]
criterion = nn.CrossEntropyLoss()

class DifferentiableCompress(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x#decomp_top_k(*comp_top_k(x))
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def sync_clients(batch):
    inputs, target = batch
    inputs = inputs.to(device)
    target = target.to(device)
    logits = []
    for model in client_models:
        raw_out = model(inputs)
        compressed = DifferentiableCompress.apply(raw_out)
        logits.append(compressed)
    
    logits = torch.stack(logits)
    logits = logits.mean(0)
    loss = criterion(logits, target)
    loss.backward()

def run_step(p, iter_idx):
    if random.random() > p:
        batch = next(iter(common_loader))
        sync_clients(batch)
        for model, state in zip(client_models, clients_states):
            state.global_step(model)
    else:
        for model, state, loader in zip(client_models, clients_states, client_loaders):
            inputs, target = next(iter(loader))
            inputs = inputs.to(device)
            target = target.to(device)
            logits = model(inputs)
            loss = criterion(logits, target)
            loss.backward()
            state.local_step(model)
    
    for model, state, loader, optimizer in zip(client_models, clients_states, client_loaders, client_optimizers):
        state.set_weights(model)
        for _ in range(10):
            inputs, target = next(iter(loader))
            inputs = inputs.to(device)
            target = target.to(device)
            logits = model(inputs)
            loss = criterion(logits, target) * state.gamma
            reg = state.get_reg_term(model)
            loss += reg
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [14]:
def logging(iter_idx):
    local_accuracy = torch.tensor([evaluate(model, loader) for model, loader in zip(client_models, client_loaders)])
    single_accuracy = torch.tensor([evaluate(model, validation_loder) for model in client_models])
    writer.add_scalars(
        main_tag='Accuracy',
        tag_scalar_dict={'local': local_accuracy.mean(), 'single': single_accuracy.mean()},
        global_step=iter_idx
    )
    print(f"Single accuracy = {single_accuracy.mean():.2f} +- {single_accuracy.std():.2f} || Local accuracy = {local_accuracy.mean():.2f} +- {local_accuracy.std():.2f}")

In [15]:
for i in range(100):
    if i % 10 == 0:
        logging(i)
    run_step(0.2, i)

Single accuracy = 0.10 +- 0.01 || Local accuracy = 0.08 +- 0.07
Single accuracy = 0.29 +- 0.02 || Local accuracy = 0.79 +- 0.07
Single accuracy = 0.28 +- 0.02 || Local accuracy = 1.00 +- 0.00
Single accuracy = 0.27 +- 0.02 || Local accuracy = 1.00 +- 0.00


KeyboardInterrupt: 

In [16]:
writer.close()