In [52]:
import argparse
import os
import random
import shutil
import time
import warnings
import json

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

from model_analysis_nets import *
from CKA import linear_CKA, kernel_CKA

In [53]:
# MNIST test split, converted to tensors and normalised per channel.
# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean / std —
# confirm they match the preprocessing the checkpoints were trained with.
trans_mnist = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
dataset_test = datasets.MNIST(
    root='data/mnist/',
    train=False,
    transform=trans_mnist,
    download=True,
)
# No shuffle argument is given, so the loader's default ordering applies.
test_loader = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=32,
    num_workers=2,
    pin_memory=True,
)

In [54]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [55]:
# Attack / normal checkpoint directories of one recorded run; both live under
# the same run directory and differ only in the final path component.
# NOTE(review): the path names a cifar10/vgg run and neither variable is read
# by the other cells visible in this notebook — confirm they are still needed.
attack_pth, normal_pth = (
    "/mnt/sda3/docker_space/Code/TDA-NN/3090/LG-FedAvg/save_attack_ub/cifar10/"
    "vgg_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-48-57/"
    "local_{}_save".format(kind)
    for kind in ("attack", "normal")
)


In [103]:
# Checkpoints compared below: going by the file names, two "attack" and two
# "normal" local models saved at iteration 8 of the same mnist/lenet run.
#
# Alternative cifar10/resnet20 checkpoints (currently unused):
# md1_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/cifar10/resnet20_iidTrue_num100_C0.8_le2/shard2/pattern12-01--00-32-01/local_attack_save/iter_18_attack_0.pt"
# md2_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/cifar10/resnet20_iidTrue_num100_C0.8_le2/shard2/pattern12-01--00-32-01/local_normal_save/iter_18_normal_21.pt"
md1_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/mnist/lenet_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-49-20/local_attack_save/iter_8_attack_1.pt"
md2_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/mnist/lenet_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-49-20/local_normal_save/iter_8_normal_24.pt"
md3_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/mnist/lenet_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-49-20/local_attack_save/iter_8_attack_2.pt"
md4_pth = "../TDA-NN/3090/LG-FedAvg/save_attack_ub/mnist/lenet_iidTrue_num100_C0.8_le2/shard2/pattern11-30--17-49-20/local_normal_save/iter_8_normal_27.pt"

# map_location remaps the stored tensors onto `device`, so checkpoints that
# were saved from a GPU still load when this notebook runs on the CPU fallback.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model1 = LeNet().to(device)
model1.load_state_dict(torch.load(md1_pth, map_location=device))
model2 = LeNet().to(device)
model2.load_state_dict(torch.load(md2_pth, map_location=device))
model3 = LeNet().to(device)
model3.load_state_dict(torch.load(md3_pth, map_location=device))
model4 = LeNet().to(device)
# Kept as the last expression so the cell still displays the key-match report.
model4.load_state_dict(torch.load(md4_pth, map_location=device))

<All keys matched successfully>

In [104]:
# Print the module structure of model1 and model3 (both loaded from the
# "attack" checkpoint files) for a visual check of the architecture.
print(model1, model3)


LeNet(
  (layer1): Sequential(
    (0): Conv2d(1, 25, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer2): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (0): Conv2d(25, 50, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer4): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=1250, out_features=2048, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=2048, out_features=1024, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=1024, out_features=128, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
) 

In [105]:
def _batch_gradients(model, inputs, targets, criterion):
    """Return the loss gradient of every named parameter of ``model`` on one batch.

    Args:
        model: network to differentiate. It is run in whatever train/eval
            mode it is currently in; the mode is not changed here.
        inputs: input batch, already on the same device as ``model``.
        targets: labels for ``inputs``, on the same device.
        criterion: loss callable, applied as ``criterion(model(inputs), targets)``.

    Returns:
        dict mapping each parameter name to a detached copy of its gradient,
        or to ``None`` for a parameter that received no gradient.
    """
    # Clear any gradient left over from an earlier run of this cell so the
    # result reflects this batch only.
    model.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Copy the gradients: storing `param.grad` itself would alias the live
    # buffer, which a later zero_grad()/backward() on the model would change.
    return {
        name: None if param.grad is None else param.grad.detach().clone()
        for name, param in model.named_parameters()
    }


criterion = nn.CrossEntropyLoss()

# Fetch the first test batch once and reuse it for every model, so all four
# gradient sets are computed on exactly the same inputs.
X_test, Y_test = next(iter(test_loader))
X_test = X_test.to(device)
Y_test = Y_test.to(device)

# NOTE(review): the models are never switched to eval(), so BatchNorm
# presumably uses batch statistics here — confirm this is intended.
gradients1 = _batch_gradients(model1, X_test, Y_test, criterion)
gradients2 = _batch_gradients(model2, X_test, Y_test, criterion)
gradients3 = _batch_gradients(model3, X_test, Y_test, criterion)
gradients4 = _batch_gradients(model4, X_test, Y_test, criterion)

In [109]:
# List the parameter names for which model1's gradients were recorded; these
# are the keys available for the layer selection in the next cell.
grad_names = gradients1.keys()
print(grad_names)

dict_keys(['layer1.0.weight', 'layer1.0.bias', 'layer1.1.weight', 'layer1.1.bias', 'layer3.0.weight', 'layer3.0.bias', 'layer3.1.weight', 'layer3.1.bias', 'fc.0.weight', 'fc.0.bias', 'fc.2.weight', 'fc.2.bias', 'fc.4.weight', 'fc.4.bias', 'fc.6.weight', 'fc.6.bias'])


In [110]:
# Pick the weight gradient of the `fc.2` linear layer from each model.
# Despite the "acts" names, these are gradients, not activations.
acts1, acts2, acts3, acts4 = (
    grads["fc.2.weight"]
    for grads in (gradients1, gradients2, gradients3, gradients4)
)

In [111]:
# Sanity check: both tensors must have the same shape before they are compared.
# (The label says "activation", but the tensors are the selected gradients.)
print("activation shapes", *(grad.shape for grad in (acts1, acts2)))

activation shapes torch.Size([1024, 2048]) torch.Size([1024, 2048])


In [112]:
# Bring all four gradient tensors to host memory before the CKA computation.
# NOTE(review): presumably the CKA helpers operate on CPU data — confirm.
acts1, acts2, acts3, acts4 = (
    grad.cpu() for grad in (acts1, acts2, acts3, acts4)
)


In [113]:
# Similarity of the fc.2 weight gradients of model1 and model2.
# The tensors are transposed before being handed to the CKA helpers.
# NOTE(review): the recorded result is exactly 1.0 for both measures — worth
# checking that the two checkpoints really differ.
print(f'Linear CKA: {linear_CKA(acts1.T, acts2.T)}')
print(f'RBF Kernel: {kernel_CKA(acts1.T, acts2.T)}')

Linear CKA: 1.0
RBF Kernel: 1.0


In [114]:
# Similarity of the fc.2 weight gradients of model1 and model3
# (both loaded from "attack" checkpoint files).
print(f'Linear CKA: {linear_CKA(acts1.T, acts3.T)}')
print(f'RBF Kernel: {kernel_CKA(acts1.T, acts3.T)}')

Linear CKA: 0.7925365230940459
RBF Kernel: 0.7706235972869823


In [115]:
# Similarity of the fc.2 weight gradients of model2 and model4
# (both loaded from "normal" checkpoint files).
print('Linear CKA: {}'.format(linear_CKA(acts2.T, acts4.T)))
# This line reports kernel_CKA, so it is labelled "RBF Kernel" like the
# sibling comparison cells (it was previously mislabelled "Linear CKA").
print('RBF Kernel: {}'.format(kernel_CKA(acts2.T, acts4.T)))

Linear CKA: 0.7449127272374326
Linear CKA: 0.7440736646438636
