# Gradient transfer between two models

In [1]:
import torch
from asyncfl.network import MNIST_CNN
from asyncfl.network import flatten, flatten_g, unflatten, unflatten_g
from asyncfl.dataloader import afl_dataset

### Create two models

In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Dataset used in this notebook.
dataset_name = 'mnist'
train_set, test_set = afl_dataset(dataset_name)

# First model: the one the gradients are computed on.
model = MNIST_CNN().to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.5,
)

# Second model: an independently initialised copy of the same architecture
# that later receives the gradient vector.
model2 = MNIST_CNN().to(device)
loss_function2 = torch.nn.CrossEntropyLoss()
optimizer2 = torch.optim.SGD(
    params=model2.parameters(),
    lr=0.01,
    momentum=0.5,
)

# Display the architecture of the first model.
model

MNIST_CNN(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
  (criterion): CrossEntropyLoss()
)

### Train
Run a forward and backward pass on a single batch (no optimizer step is taken) and flatten the resulting gradients into one vector

In [3]:
# flatten(model) presumably returns the model's weights as one vector (TODO
# confirm in asyncfl.network); here it only serves as a template so that the
# gradient buffer gets a matching size, dtype and device.
w_flat = flatten(model)
g_flat = torch.zeros_like(w_flat)

# Draw one batch; if the iterator is exhausted, rebuild the dataset and retry.
try:
    inputs, labels = next(train_set)
except StopIteration:
    train_set, test_set = afl_dataset(dataset_name)
    inputs, labels = next(train_set)

# Move the batch to the same device as the model (GPU or CPU).
inputs, labels = inputs.to(device), labels.to(device)

# Clear stale gradients, then run a single forward/backward pass.
# No optimizer step is taken: only the gradients are needed here.
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()

# Write the model's gradients into the pre-allocated flat buffer.
flatten_g(model, g_flat)

# Display the flattened gradient vector.
g_flat

tensor([-0.0041, -0.0091, -0.0163,  ..., -0.0523, -0.0091,  0.0472],
       device='cuda:0')

### Transfer Vector
Transfer the gradient vector to the second model

In [4]:
# Load the flattened gradient vector `g_flat` (computed on `model`) into `model2`.
# NOTE(review): presumably unflatten_g copies each slice of the vector into the
# matching parameter's .grad — confirm in asyncfl.network.
unflatten_g(model2, g_flat)

In [8]:
# Read the gradients back out of the second model into a fresh flat buffer so
# they can be compared with the vector that was transferred.
w_flat2 = flatten(model2)
g_flat2 = torch.zeros_like(w_flat2)
flatten_g(model2, g_flat2)

# torch.equal is True only when both tensors have the same size and elements.
grads_match = torch.equal(g_flat, g_flat2)
print(f'gradient tensors are equal ? {grads_match}')
# Fail loudly under "Run All" instead of relying on someone reading the output.
assert grads_match, 'gradient vector read back from model2 differs from g_flat'

# Display the gradient vector recovered from the second model.
g_flat2

gradient tensors are equal ? True


tensor([-0.0041, -0.0091, -0.0163,  ..., -0.0523, -0.0091,  0.0472],
       device='cuda:0')