In [1]:
from serverbase import *
from userbase import *
from serverDoVeRA import *
from userDoVeRA import *
import numpy as np
import torch

In [2]:
# Number of CUDA devices torch can see (0 when no GPU is available).
num_gpus = torch.cuda.device_count()

# Experiment settings gathered in a single mapping.
config = dict(
    num_user=3,
    batch_size=64,
    global_L=True,
    dim_L=2,
    num_gpus=num_gpus,
    global_epochs=20,
    user_ratio=1,
)

# Display the detected device count as the cell output.
num_gpus

4

In [3]:
from torchvision import  datasets
from torchvision import transforms
from torch.utils.data import DataLoader

# Mini-batch size used by every DataLoader built in this notebook.
BATCH_SIZE = 640

# MNIST train/test splits, downloaded into ./data when missing; ToTensor converts each image to a tensor.
train_dataset = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not benefit from shuffling, and a shuffled loader draws from the
# global torch RNG on every pass, so the test loader is kept in dataset order.
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Peek at a single training batch to confirm the tensor shapes.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)

Image batch dimensions: torch.Size([640, 1, 28, 28])
Image label dimensions: torch.Size([640])


In [4]:
# Placeholder bookkeeping values; `server` is rebound to a ServerMLP instance in a later cell.
server = 0
clients = list()
nb_client = 3

In [5]:
def build_label_loader(dataset, labels, batch_size):
    """Return a DataLoader over the samples of ``dataset`` whose target is in ``labels``.

    Args:
        dataset: dataset exposing a ``targets`` sequence aligned with its samples.
        labels: iterable of integer class ids to keep.
        batch_size: mini-batch size of the returned loader.

    Returns:
        An unshuffled ``DataLoader`` over a ``Subset`` restricted to the matching samples,
        in their original dataset order.
    """
    # Vectorised membership test; avoids comparing every target against the list in Python.
    keep_mask = torch.isin(torch.as_tensor(dataset.targets), torch.as_tensor(labels))
    indices = keep_mask.nonzero(as_tuple=True)[0].tolist()
    return torch.utils.data.DataLoader(torch.utils.data.Subset(dataset, indices),
                                       batch_size=batch_size)


# Disjoint label groups: each group becomes the private training data of one client.
CLIENT_LABEL_GROUPS = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]

train_loader_list = [
    build_label_loader(train_dataset, label_group, BATCH_SIZE)
    for label_group in CLIENT_LABEL_GROUPS
]

train_loader_list

[<torch.utils.data.dataloader.DataLoader at 0x7f5d141b37c0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f5d141b3eb0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f5bf1497760>]

In [6]:
from torch import nn
class MultiLayerPerceptron(nn.Module):
    """Fully connected network with two hidden ReLU layers.

    The modules live in ``self.layers`` (an ``nn.Sequential``) in the order
    Linear, ReLU, Linear, ReLU, Linear, so the Linear layers sit at indices 0, 2 and 4.

    Args:
        num_features: size of each input vector.
        num_hidden_1: width of the first hidden layer.
        num_hidden_2: width of the second hidden layer.
        num_classes: size of the output vector.
    """

    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        widths = (num_features, num_hidden_1, num_hidden_2, num_classes)
        stack = []
        # Each consecutive pair of widths defines one Linear layer followed by a ReLU.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        # The output layer has no activation, so the trailing ReLU is dropped.
        self.layers = nn.Sequential(*stack[:-1])

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the final Linear layer's output."""
        return self.layers(x)

In [7]:
# Network dimensions: flattened 28x28 input, two hidden widths, ten output classes.
num_features = 28 * 28
num_hidden_1, num_hidden_2 = 128, 256
num_classes = 10

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Optimiser step size and number of epochs used by the training cells.
learning_rate = 0.005
num_epoches = 5

In [8]:
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader``, in percent.

    The model is put into evaluation mode while predictions are made, so layers
    such as dropout or batch-norm behave deterministically, and its previous
    training/eval mode is restored before returning. Note that the restore goes
    through ``model.train(mode)``, which applies one mode to all submodules.

    Args:
        model: module mapping a ``(batch, 28*28)`` tensor to per-class scores.
        data_loader: iterable of ``(features, targets)`` batches; every feature
            batch is reshaped to ``(-1, 28*28)`` before the forward pass.
        device: device the batches are moved to before the forward pass.

    Returns:
        A 0-dim float tensor holding ``100 * correct / total``.

    Raises:
        ValueError: if ``data_loader`` yields no examples.
    """
    was_training = model.training
    model.eval()
    correct_pred, num_examples = 0, 0

    try:
        with torch.no_grad():
            for features, targets in data_loader:
                features = features.view(-1, 28*28).to(device)
                targets = targets.to(device)

                logits = model(features)
                # The index of the largest score is the predicted class.
                _, predicted_labels = torch.max(logits, 1)
                num_examples += targets.size(0)
                correct_pred += (predicted_labels == targets).sum()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if num_examples == 0:
        raise ValueError("compute_accuracy received a data_loader with no examples")

    return correct_pred.float() / num_examples * 100

# MLP Model

# Pre-trained Model

In [9]:
# Fix the torch seeds (CPU and CUDA) before the weights are initialised.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Build the MLP and move it to the selected device; Module.to returns the module itself.
model = MultiLayerPerceptron(num_features, num_hidden_1, num_hidden_2, num_classes).to(device)

# Adam over every parameter of the freshly created model.
optimizer_pretrained = torch.optim.Adam(model.parameters(), lr=learning_rate)

# One object per line: device, architecture, optimiser settings.
print(device, model, optimizer_pretrained, sep='\n')

cuda
MultiLayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.005
    maximize: False
    weight_decay: 0
)


In [10]:
import time
import torch.nn.functional as F
def train(num_epoches, model, optimizer, train_loader, device, log_interval=400):
    """Train ``model`` with cross-entropy loss and print progress.

    After every epoch the training-set accuracy is computed with
    ``compute_accuracy`` (defined elsewhere in this notebook) and the elapsed
    wall-clock time is printed.

    Args:
        num_epoches: number of passes over ``train_loader``.
        model: module mapping a ``(batch, 28*28)`` tensor to per-class scores.
        optimizer: optimiser that updates the model's parameters.
        train_loader: iterable of ``(features, targets)`` batches; features are
            reshaped to ``(-1, 28*28)`` before the forward pass.
        device: device the batches are moved to.
        log_interval: print the batch loss every ``log_interval`` batches;
            a falsy value (0 or None) disables per-batch logging.

    Returns:
        None. The model is updated in place.
    """
    start_time = time.time()
    for epoch in range(num_epoches):
        # Make sure the model is in training mode for this epoch, whatever mode
        # the caller or an evaluation helper left it in.
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.view(-1, 28*28).to(device)
            targets = targets.to(device)

            logits = model(features)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_interval and batch_idx % log_interval == 0:
                print('Epoch: %03d/%03d|Batch %03d/%03d| Loss: %.4f' % (epoch+1, num_epoches, batch_idx, len(train_loader), loss.item()))

        # Accuracy is measured without tracking gradients.
        with torch.no_grad():
            print('Epoch: %03d/%03d training accuracy: %.2f%%' % (epoch+1, num_epoches, compute_accuracy(model, train_loader, device)))

        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

In [11]:
# Fit the MLP on the full training loader, then report its accuracy on the test loader.
train(num_epoches, model, optimizer_pretrained, train_loader, device)
print('Test accuracy: %.2f%%' % compute_accuracy(model, test_loader, device))

Epoch: 001/005|Batch 000/094| Loss: 2.3053
Epoch: 001/005 training accuracy: 95.21%
Time elapsed: 0.11 min
Epoch: 002/005|Batch 000/094| Loss: 0.1342
Epoch: 002/005 training accuracy: 97.14%
Time elapsed: 0.22 min
Epoch: 003/005|Batch 000/094| Loss: 0.1267
Epoch: 003/005 training accuracy: 97.42%
Time elapsed: 0.33 min
Epoch: 004/005|Batch 000/094| Loss: 0.0813
Epoch: 004/005 training accuracy: 98.35%
Time elapsed: 0.43 min
Epoch: 005/005|Batch 000/094| Loss: 0.0412
Epoch: 005/005 training accuracy: 98.80%
Time elapsed: 0.54 min
Total Training Time: 0.54 min
Test accuracy: 97.49%


# Federated MLP Model

In [12]:
import copy

# Independent deep copy of the pretrained `model`; the federated run below is built on this copy.
server_model = copy.deepcopy(model)

In [13]:
import random
from tqdm import tqdm

# Seed Python's `random` as well as torch: random.sample below decides the order in
# which clients train, and it draws from Python's generator, not torch's.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

server = ServerMLP(model=server_model, test_loader=test_loader)

# One client per label-partitioned loader; the client count comes from `config`.
# NOTE(review): every client is handed the same `server_model` object; presumably
# UserMLP makes its own copy. Confirm in userbase.
user_list = [
    UserMLP(train_loader=train_loader_list[user_id], model=server_model, user_id=user_id, local_epochs=5)
    for user_id in range(config["num_user"])
]

# Number of clients trained per round (all of them when user_ratio is 1).
num_selected = int(config["user_ratio"] * config["num_user"])

for _ in tqdm(range(config["global_epochs"]), desc="Progress"):
    # Distribute the current server model to the users
    server.distribute_model(user_list)

    # Sub-sample users
    sub_user_list = random.sample(user_list, num_selected)

    # Train each selected user and accumulate the loss it reports
    users_loss = 0.0
    for user in sub_user_list:
        users_loss += user.user_train()

    # Aggregate weights on server
    server.aggregate_weights(sub_user_list)

    # Calculate avg loss on selected users
    train_loss = users_loss / len(sub_user_list)
    val_loss = server.model_eval()

    # wandb.log({"train_loss": train_loss, "val_loss": val_loss})
    print(f'Test accuracy of server: {server.compute_accuracy()}%')

Progress:   0%|          | 0/20 [00:00<?, ?it/s]

layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.07399697927758098


Progress:   5%|▌         | 1/20 [00:21<06:49, 21.55s/it]

Test accuracy of server: 97.8499984741211%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.07547288434579968


Progress:  10%|█         | 2/20 [00:41<06:10, 20.58s/it]

Test accuracy of server: 97.79999542236328%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.08223380986601114


Progress:  15%|█▌        | 3/20 [01:01<05:44, 20.29s/it]

Test accuracy of server: 97.61000061035156%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.08928364282473922


Progress:  20%|██        | 4/20 [01:21<05:22, 20.14s/it]

Test accuracy of server: 97.40999603271484%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.09891076711937785


Progress:  25%|██▌       | 5/20 [01:41<05:01, 20.07s/it]

Test accuracy of server: 97.0999984741211%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.0993972725700587


Progress:  30%|███       | 6/20 [02:01<04:42, 20.16s/it]

Test accuracy of server: 97.19000244140625%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.1102590449154377


Progress:  35%|███▌      | 7/20 [02:21<04:21, 20.12s/it]

Test accuracy of server: 96.97000122070312%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.1121402932330966


Progress:  40%|████      | 8/20 [02:41<04:00, 20.08s/it]

Test accuracy of server: 96.87000274658203%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.12939141085371375


Progress:  45%|████▌     | 9/20 [03:01<03:40, 20.01s/it]

Test accuracy of server: 96.59000396728516%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.1322789746336639


Progress:  50%|█████     | 10/20 [03:21<03:19, 19.92s/it]

Test accuracy of server: 96.43000030517578%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.13430627342313528


Progress:  55%|█████▌    | 11/20 [03:41<02:59, 19.95s/it]

Test accuracy of server: 96.44000244140625%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.13616208778694272


Progress:  60%|██████    | 12/20 [04:01<02:39, 19.97s/it]

Test accuracy of server: 96.62999725341797%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.14684918895363808


Progress:  65%|██████▌   | 13/20 [04:21<02:19, 19.98s/it]

Test accuracy of server: 96.4000015258789%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.13582701445557177


Progress:  70%|███████   | 14/20 [04:41<01:59, 19.93s/it]

Test accuracy of server: 96.70999908447266%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.16833082539960742


Progress:  75%|███████▌  | 15/20 [05:00<01:39, 19.92s/it]

Test accuracy of server: 96.15999603271484%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.14374617487192154


Progress:  80%|████████  | 16/20 [05:20<01:19, 19.90s/it]

Test accuracy of server: 96.63999938964844%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.151554049924016


Progress:  85%|████████▌ | 17/20 [05:41<01:00, 20.08s/it]

Test accuracy of server: 96.52999877929688%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.15025727823376656


Progress:  90%|█████████ | 18/20 [06:01<00:40, 20.14s/it]

Test accuracy of server: 96.51000213623047%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.18111255392432213


Progress:  95%|█████████▌| 19/20 [06:22<00:20, 20.23s/it]

Test accuracy of server: 96.01000213623047%
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
layers.0.weight
layers.0.bias
layers.2.weight
layers.2.bias
layers.4.weight
layers.4.bias
val_loss: 0.16727013047784567


Progress: 100%|██████████| 20/20 [06:41<00:00, 20.09s/it]

Test accuracy of server: 96.48999786376953%





In [14]:
def freeze_linear_layers(model):
    """Turn off gradients for every ``nn.Linear`` contained in ``model`` (in place).

    All modules reachable from ``model`` are visited, including ``model`` itself,
    so passing a bare ``nn.Linear`` freezes it too. Parameters that do not belong
    to an ``nn.Linear`` are left untouched.

    Args:
        model: the module whose Linear layers should stop receiving gradients.

    Returns:
        None.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            for param in module.parameters():
                param.requires_grad = False

# LoRA

In [15]:
class LoRALayer(nn.Module):
    """Trainable low-rank map computing ``alpha * (x @ A @ B)``.

    ``A`` has shape ``(in_dim, rank)`` and is drawn from a normal distribution
    scaled by ``1 / sqrt(rank)``; ``B`` has shape ``(rank, out_dim)`` and starts
    at zero, so a freshly built layer outputs zeros.

    Args:
        in_dim: size of the input vectors.
        out_dim: size of the output vectors.
        rank: inner dimension of the factorisation.
        alpha: constant multiplier applied to the low-rank output.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.alpha = alpha
        init_scale = 1.0 / torch.tensor(rank).float().sqrt()
        down_proj = torch.randn(in_dim, rank) * init_scale
        up_proj = torch.zeros(rank, out_dim)
        # Registration order (A, then B) fixes the order reported by named_parameters().
        self.A = nn.Parameter(down_proj)
        self.B = nn.Parameter(up_proj)

    def forward(self, x):
        """Project ``x`` through A and then B, and scale the result by ``alpha``."""
        low_rank = (x @ self.A) @ self.B
        return self.alpha * low_rank

In [16]:
from torch.nn import functional as F

class LinearWithLoRAMerged(nn.Module):
    """Wrap an ``nn.Linear`` and add a LoRA update to its weight at forward time.

    The effective weight is ``linear.weight + alpha * (A @ B).T``; the wrapped
    layer's bias is used unchanged.

    Args:
        linear: the ``nn.Linear`` to wrap (stored, not copied).
        rank: rank handed to the ``LoRALayer``.
        alpha: scale handed to the ``LoRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        """Apply the wrapped layer with the LoRA update merged into its weight."""
        adapter = self.lora
        # A @ B is (in, out); transpose to match nn.Linear's (out, in) weight layout.
        delta = (adapter.A @ adapter.B).T
        merged_weight = self.linear.weight + adapter.alpha * delta
        return F.linear(x, merged_weight, self.linear.bias)

In [17]:
import copy

# Start from an independent copy of the pretrained MLP.
model_lora = copy.deepcopy(model)

# The Linear layers sit at indices 0, 2 and 4 of the Sequential; wrap each one in place.
for layer_idx in (0, 2, 4):
    model_lora.layers[layer_idx] = LinearWithLoRAMerged(model_lora.layers[layer_idx], rank=4, alpha=8)

model_lora.to(device)
# NOTE(review): this optimiser is created before the Linear layers are frozen and
# is not referenced by the federated loop visible in this notebook.
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
print(model_lora)

MultiLayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRAMerged(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRAMerged(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRAMerged(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)


In [18]:
# Freeze the wrapped nn.Linear weights and biases; the LoRA factors are not
# nn.Linear parameters, so they remain trainable.
freeze_linear_layers(model_lora)
# List every parameter with its requires_grad flag to confirm what will be trained.
for name, param in model_lora.named_parameters():
    print(f'{name}: {param.requires_grad}')

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [19]:
import random
from tqdm import tqdm

# Seed Python's `random` as well as torch: random.sample below decides the order in
# which clients train, and it draws from Python's generator, not torch's.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

server = ServerMLP(model=model_lora, test_loader=test_loader)

# One client per label-partitioned loader; the client count comes from `config`.
# NOTE(review): every client is handed the same `model_lora` object; presumably
# UserMLP makes its own copy. Confirm in userbase.
user_list = [
    UserMLP(train_loader=train_loader_list[user_id], model=model_lora, user_id=user_id, local_epochs=5)
    for user_id in range(config["num_user"])
]

# Number of clients trained per round (all of them when user_ratio is 1).
num_selected = int(config["user_ratio"] * config["num_user"])

for _ in tqdm(range(config["global_epochs"]), desc="Progress"):
    # Distribute the current server model to the users
    server.distribute_model(user_list)

    # Sub-sample users
    sub_user_list = random.sample(user_list, num_selected)

    # Train each selected user and accumulate the loss it reports
    users_loss = 0.0
    for user in sub_user_list:
        users_loss += user.user_train()

    # Aggregate weights on server
    server.aggregate_weights(sub_user_list)

    # Calculate avg loss on selected users
    train_loss = users_loss / len(sub_user_list)
    val_loss = server.model_eval()

    # wandb.log({"train_loss": train_loss, "val_loss": val_loss})
    print(f'Test accuracy of server: {server.compute_accuracy()}%')

Progress:   0%|          | 0/20 [00:00<?, ?it/s]

layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
val_loss: 0.08553802664391696


Progress:   5%|▌         | 1/20 [00:21<06:43, 21.22s/it]

Test accuracy of server: 97.31999969482422%
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
val_loss: 0.08775392640382051


Progress:  10%|█         | 2/20 [00:42<06:18, 21.05s/it]

Test accuracy of server: 97.39999389648438%
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
val_loss: 0.09048473229631782


Progress:  15%|█▌        | 3/20 [01:03<05:58, 21.07s/it]

Test accuracy of server: 97.16999816894531%
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
val_loss: 0.09238581359386444


Progress:  20%|██        | 4/20 [01:24<05:37, 21.12s/it]

Test accuracy of server: 97.16999816894531%
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
layers.0.linear.weight
layers.0.linear.bias
layers.0.lora.A
layers.0.lora.B
layers.2.linear.weight
layers.2.linear.bias
layers.2.lora.A
layers.2.lora.B
layers.4.linear.weight
layers.4.linear.bias
layers.4.lora.A
layers.4.lora.B
val_loss: 0.09461246151477098


Progress:  25%|██▌       | 5/20 [01:45<05:16, 21.13s/it]

Test accuracy of server: 97.04000091552734%


Progress:  25%|██▌       | 5/20 [01:46<05:18, 21.23s/it]


KeyboardInterrupt: 

# DoVeRA

In [None]:
class VeRALayer(nn.Module):
    """Low-rank map with frozen random factors and trainable scaling vectors.

    Computes ``alpha * (x @ A @ diag(d) @ B @ diag(b))``. ``A`` (``in_dim x rank``)
    and ``B`` (``rank x out_dim``) are drawn once from ``N(0, 1/rank)`` and never
    receive gradients; only ``d`` (length ``rank``, initialised to ones) and ``b``
    (length ``out_dim``, initialised to zeros) are trainable. Because ``b`` starts
    at zero, a freshly built layer outputs zeros.

    Args:
        in_dim: size of the input vectors.
        out_dim: size of the output vectors.
        rank: inner dimension of the factorisation.
        alpha: constant multiplier applied to the output.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # Standard deviation 1/sqrt(rank), computed in float32 and passed on as a plain float.
        std_dev = (1 / torch.sqrt(torch.tensor(rank).float())).item()
        # A is drawn before B; keep this order so a fixed seed reproduces the same factors.
        self.A = nn.Parameter(torch.normal(size=(in_dim, rank), mean=0., std=std_dev), requires_grad=False)
        self.d = nn.Parameter(torch.ones(rank))
        self.B = nn.Parameter(torch.normal(size=(rank, out_dim), mean=0., std=std_dev), requires_grad=False)
        self.b = nn.Parameter(torch.zeros(out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Return ``alpha * (x @ A @ diag(d) @ B @ diag(b))`` for a batch ``x``."""
        # Multiplying by diag(v) on the right scales the columns by v, which
        # broadcasting does directly without building dense diagonal matrices.
        hidden = (x @ self.A) * self.d
        return self.alpha * ((hidden @ self.B) * self.b)

In [None]:
class LinearWithDoVeRAMerged(nn.Module):
    """Wrap an ``nn.Linear`` with a VeRA update and a learnable magnitude vector.

    The effective weight is ``m * V / ||V||`` where
    ``V = linear.weight + alpha * (A @ diag(d) @ B @ diag(b)).T`` and the norm is
    the L2 norm taken over dim 0 of the ``(out, in)`` weight. ``m`` is initialised
    to that same norm of the wrapped layer's weight, so while ``b`` is zero the
    wrapper reproduces the wrapped layer's output.

    Args:
        linear: the ``nn.Linear`` to wrap (stored, not copied).
        rank: rank handed to the ``VeRALayer``.
        alpha: scale handed to the ``VeRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.vera = VeRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
        # Magnitude starts at the norm of the wrapped weight; detach so the new
        # parameter is built from plain data rather than a graph node.
        self.m = nn.Parameter(self.linear.weight.detach().norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        """Apply the wrapped layer using the re-normalised, magnitude-scaled weight."""
        vera = self.vera
        # (A * d) @ B * b equals A @ diag(d) @ B @ diag(b) without materialising
        # the dense diagonal matrices; result has shape (in, out).
        delta = ((vera.A * vera.d) @ vera.B) * vera.b
        numerator = self.linear.weight + vera.alpha * delta.T
        denominator = numerator.norm(p=2, dim=0, keepdim=True)
        directional_component = numerator / denominator
        new_weight = self.m * directional_component
        return F.linear(x, new_weight, self.linear.bias)

In [None]:
# Reseed torch so the random VeRA factors drawn below are repeatable.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Start from an independent copy of the pretrained MLP.
model_dovera = copy.deepcopy(model)

# The Linear layers sit at indices 0, 2 and 4 of the Sequential; wrap each one in place.
for layer_idx in (0, 2, 4):
    model_dovera.layers[layer_idx] = LinearWithDoVeRAMerged(model_dovera.layers[layer_idx], rank=4, alpha=8)

model_dovera.to(device)
optimizer_dovera = torch.optim.Adam(model_dovera.parameters(), lr=learning_rate)

print(model_dovera)

MultiLayerPerceptron(
  (layers): Sequential(
    (0): LinearWithDoVeRAMerged(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (vera): VeRALayer()
    )
    (1): ReLU()
    (2): LinearWithDoVeRAMerged(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (vera): VeRALayer()
    )
    (3): ReLU()
    (4): LinearWithDoVeRAMerged(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (vera): VeRALayer()
    )
  )
)


In [None]:
# Freeze the wrapped nn.Linear weights and biases; the other parameters of each
# wrapper keep whatever requires_grad flag they were created with.
freeze_linear_layers(model_dovera)
# List every parameter with its requires_grad flag to confirm what will be trained.
for name, param in model_dovera.named_parameters():
    print(f'{name}: {param.requires_grad}')

layers.0.m: True
layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.vera.A: False
layers.0.vera.d: True
layers.0.vera.B: False
layers.0.vera.b: True
layers.2.m: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.vera.A: False
layers.2.vera.d: True
layers.2.vera.B: False
layers.2.vera.b: True
layers.4.m: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.vera.A: False
layers.4.vera.d: True
layers.4.vera.B: False
layers.4.vera.b: True


In [None]:
import random
from tqdm import tqdm

# Seed Python's `random` as well as torch: random.sample below decides the order in
# which clients train, and it draws from Python's generator, not torch's.
random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

server = ServerMLP(model=model_dovera, test_loader=test_loader)

# One client per label-partitioned loader; the client count comes from `config`.
# NOTE(review): every client is handed the same `model_dovera` object; presumably
# UserMLP makes its own copy. Confirm in userbase.
user_list = [
    UserMLP(train_loader=train_loader_list[user_id], model=model_dovera, user_id=user_id, local_epochs=5)
    for user_id in range(config["num_user"])
]

# Number of clients trained per round (all of them when user_ratio is 1).
num_selected = int(config["user_ratio"] * config["num_user"])

for _ in tqdm(range(config["global_epochs"]), desc="Progress"):
    # Distribute the current server model to the users
    server.distribute_model(user_list)

    # Sub-sample users
    sub_user_list = random.sample(user_list, num_selected)

    # Train each selected user and accumulate the loss it reports
    users_loss = 0.0
    for user in sub_user_list:
        users_loss += user.user_train()

    # Aggregate weights on server
    server.aggregate_weights(sub_user_list)

    # Calculate avg loss on selected users
    train_loss = users_loss / len(sub_user_list)
    val_loss = server.model_eval()

    # wandb.log({"train_loss": train_loss, "val_loss": val_loss})
    print(f'Test accuracy of server: {server.compute_accuracy()}%')

Progress:   0%|          | 0/20 [00:00<?, ?it/s]

val_loss: 0.07653172709979117


Progress:   5%|▌         | 1/20 [00:23<07:27, 23.57s/it]

Test accuracy of server: 97.80999755859375%
val_loss: 0.076780412113294


Progress:  10%|█         | 2/20 [00:46<07:01, 23.40s/it]

Test accuracy of server: 97.7699966430664%
val_loss: 0.07780138798989356


Progress:  15%|█▌        | 3/20 [01:10<06:41, 23.62s/it]

Test accuracy of server: 97.69999694824219%
val_loss: 0.07961395150050521


Progress:  20%|██        | 4/20 [01:34<06:18, 23.65s/it]

Test accuracy of server: 97.7199935913086%
val_loss: 0.08090309845283628


Progress:  25%|██▌       | 5/20 [01:59<06:02, 24.17s/it]

Test accuracy of server: 97.66999816894531%
val_loss: 0.0816525318659842


Progress:  30%|███       | 6/20 [02:23<05:35, 23.98s/it]

Test accuracy of server: 97.67999267578125%
val_loss: 0.08376806182786822


Progress:  35%|███▌      | 7/20 [02:46<05:09, 23.80s/it]

Test accuracy of server: 97.69999694824219%
val_loss: 0.08381795464083552


Progress:  40%|████      | 8/20 [03:10<04:44, 23.71s/it]

Test accuracy of server: 97.7199935913086%
val_loss: 0.08524775435216725


Progress:  45%|████▌     | 9/20 [03:34<04:23, 23.98s/it]

Test accuracy of server: 97.7199935913086%
val_loss: 0.08545513963326812


Progress:  50%|█████     | 10/20 [03:59<04:03, 24.39s/it]

Test accuracy of server: 97.73999786376953%
val_loss: 0.08651515864767134


Progress:  55%|█████▌    | 11/20 [04:23<03:38, 24.26s/it]

Test accuracy of server: 97.7199935913086%
val_loss: 0.08824604540131986


Progress:  60%|██████    | 12/20 [04:49<03:17, 24.68s/it]

Test accuracy of server: 97.70999908447266%
val_loss: 0.0878830412402749


Progress:  65%|██████▌   | 13/20 [05:13<02:51, 24.56s/it]

Test accuracy of server: 97.68999481201172%
val_loss: 0.08727805956732482


Progress:  70%|███████   | 14/20 [05:40<02:30, 25.04s/it]

Test accuracy of server: 97.66999816894531%
val_loss: 0.09041420766152442


Progress:  75%|███████▌  | 15/20 [06:06<02:07, 25.54s/it]

Test accuracy of server: 97.63999938964844%
val_loss: 0.09022549306973815


Progress:  80%|████████  | 16/20 [06:34<01:44, 26.21s/it]

Test accuracy of server: 97.5999984741211%
val_loss: 0.09096389613114297


Progress:  85%|████████▌ | 17/20 [06:58<01:16, 25.42s/it]

Test accuracy of server: 97.55999755859375%
val_loss: 0.09108555503189564


Progress:  90%|█████████ | 18/20 [07:21<00:49, 24.84s/it]

Test accuracy of server: 97.53999328613281%
val_loss: 0.09169882838614285


Progress:  95%|█████████▌| 19/20 [07:46<00:24, 24.98s/it]

Test accuracy of server: 97.53999328613281%
val_loss: 0.093390446389094


Progress: 100%|██████████| 20/20 [08:10<00:00, 24.52s/it]

Test accuracy of server: 97.55999755859375%





In [None]:
import torchvision
# Build a ResNet-50 (no weights argument is passed) and show its module tree as the cell output.
resnet = torchvision.models.resnet50()
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Print the direct children of the ResNet that are nn.Linear. Only the top level
# is inspected here; nested blocks are not descended into.
for child in resnet.children():
        if isinstance(child, nn.Linear):
            print(child)

Linear(in_features=2048, out_features=1000, bias=True)


In [None]:
from transformers import pipeline, set_seed
# Text-generation pipeline built from the 'gpt2' checkpoint.
generator = pipeline('text-generation', model='gpt2')

In [None]:
from transformers import GPT2Tokenizer, GPT2Model
# Load the GPT-2 tokenizer and base model from the 'gpt2' checkpoint.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# NOTE(review): this rebinds `model`, which earlier cells deep-copy as the
# pretrained MLP; re-running those cells after this one would pick up GPT-2 instead.
model = GPT2Model.from_pretrained('gpt2')
# Print the direct children of the GPT-2 model that are nn.Linear (top level only).
for child in model.children():
        if isinstance(child, nn.Linear):
            print(child)