In [1]:
%cd ../..

/Users/davideleo/Desktop/Projects/research/papers/fl_wavelet_v0


### Get gradients

In [2]:
import random
import torch 
import numpy as np 
from src.data.cifar100 import get_federation 
from src.federated_learning.standard.fedavg import Client as Client
from src.models.neural_networks import LeNet5
from copy import deepcopy
from tqdm import tqdm

random.seed(42)
np.random.seed(42)
torch.random.manual_seed(42)

# Federation
model = LeNet5(in_channels = 3, in_padding = 0, num_classes = 100)

test_dataset = get_federation(
    num_shards = 1,
    alpha = 1000,
    attacks = [],
    attacks_proba = 0
)[0]["test"]

In [3]:
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

# Train VAE 
num_features = 200
num_epochs = 100 
batch_size = 64 
device = "cpu"

gradients = []
test_dataset.load()

dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
model = model.to(device)
optim = Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

for epoch in tqdm(range(num_epochs)): 
    for X, y in dataloader: 
        theta0_sd = model.cpu().state_dict()["_mlp.5.weight"].clone()
        y_hat = model(X.to(device))
        loss = criterion(y_hat, y.to(device))
        loss.backward()
        optim.step()
        optim.zero_grad()
        dtheta = model.cpu().state_dict()["_mlp.5.weight"] - theta0_sd
        flat = dtheta.flatten()
        indices = torch.randperm(flat.size(0))[:num_features]
        sampled = flat[indices].clone()
        gradients.append(sampled)

100%|██████████| 10/10 [00:08<00:00,  1.12it/s]


In [4]:
import torch.nn.functional as F
from torch.utils.data import Dataset
from src.models.neural_networks import VAE

class TensorListDataset(Dataset):
    def __init__(self, tensor_list):
        self.data = tensor_list 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Loss function
def vae_loss_function(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

# Setup
num_epochs = 100 
batch_size = 128 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = TensorListDataset(gradients)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
model = VAE(num_features).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    for batch in dataloader:
        batch = batch.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch)
        loss = vae_loss_function(recon_batch, batch, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    avg_loss = train_loss / len(dataloader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

Epoch 1/100 - Loss: 6.0630
Epoch 2/100 - Loss: 2.6886
Epoch 3/100 - Loss: 1.5033
Epoch 4/100 - Loss: 1.0817
Epoch 5/100 - Loss: 0.8609
Epoch 6/100 - Loss: 0.7351
Epoch 7/100 - Loss: 0.6233
Epoch 8/100 - Loss: 0.5426
Epoch 9/100 - Loss: 0.4697
Epoch 10/100 - Loss: 0.4195
Epoch 11/100 - Loss: 0.3640
Epoch 12/100 - Loss: 0.3193
Epoch 13/100 - Loss: 0.2839
Epoch 14/100 - Loss: 0.2485
Epoch 15/100 - Loss: 0.2232
Epoch 16/100 - Loss: 0.1995
Epoch 17/100 - Loss: 0.1767
Epoch 18/100 - Loss: 0.1587
Epoch 19/100 - Loss: 0.1452
Epoch 20/100 - Loss: 0.1309
Epoch 21/100 - Loss: 0.1183
Epoch 22/100 - Loss: 0.1079
Epoch 23/100 - Loss: 0.0982
Epoch 24/100 - Loss: 0.0895
Epoch 25/100 - Loss: 0.0824
Epoch 26/100 - Loss: 0.0753
Epoch 27/100 - Loss: 0.0698
Epoch 28/100 - Loss: 0.0645
Epoch 29/100 - Loss: 0.0590
Epoch 30/100 - Loss: 0.0552
Epoch 31/100 - Loss: 0.0507
Epoch 32/100 - Loss: 0.0473
Epoch 33/100 - Loss: 0.0437
Epoch 34/100 - Loss: 0.0411
Epoch 35/100 - Loss: 0.0380
Epoch 36/100 - Loss: 0.0358
E

In [5]:
torch.save(model, "notebooks/cifar100/results/vae.pth")