## Load data

In [1]:
import os

# Move from the notebook's folder to the repository root so `src.*` imports resolve.
# Only step up when `src` is not already visible: a bare os.chdir('..') climbs one
# more level every time this cell is re-run.
# NOTE(review): assumes the notebook directory itself has no `src` folder — confirm.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
from src.dataset import *
from src.model import *
from src.train_test import *
import matplotlib.pyplot as plt
import pandas as pd
from torch import optim
from src.vis import *
from src.gradient import *

In [3]:
batch_size = 32
# torch.cuda.is_available is a function and must be called: the bare attribute is
# always truthy, which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# load data
# CIFAR-10 loaders for both splits, sharing the batch size configured above.
# The generator is unpacked in order: the train split (True) is built first, then test (False).
train_data_loader, test_data_loader = (
    data_loader(dataset_name='CIFAR10', batch_size=batch_size, train=is_train)
    for is_train in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


## Original model

In [6]:
model_name = 'CIFAR17_add000'

# Look up the registered model for this name and move it to the selected device.
model = KNOWN_MODELS[model_name].to(device)

# Plot training accuracy for this model on CIFAR-10.
# total_trails=30 is presumably the number of training trials — confirm in the helper.
plot_training_acc(
    model,
    train_data_loader,
    model_name=model_name,
    data_name='CIFAR10',
    total_trails=30,
)

### Get `pos_grad_dict` and `neg_grad_dict` for the anchor model

In [73]:
model_name = 'CIFAR17_add210'

# Anchor model: registered model moved to the selected device.
model = KNOWN_MODELS[model_name].to(device)

# Gradient dictionaries for the positive / negative sample groups of the anchor model.
pos_grad_dict, neg_grad_dict = get_graddict(
    model=model,
    model_name=model_name,
    data_name='CIFAR10',
    train_data_loader=train_data_loader,
    num_trail=5,
    pos_thre=5,
    neg_thre=0,
    # NOTE(review): compute_index=True was tried earlier; what it recomputes is
    # defined inside get_graddict — confirm before toggling.
    compute_index=False,
    vis=True,
    save=True,
)

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 16/26775 [00:00<02:54, 153.76it/s]

Number of postive samples:  26775
Number of negative samples:  4902
Use trail 1 to compute conflicting gradients


100%|██████████| 26775/26775 [03:08<00:00, 142.33it/s]
  0%|          | 16/4902 [00:00<00:32, 150.16it/s]

Length of data 26775


100%|██████████| 4902/4902 [00:32<00:00, 149.78it/s]

Length of data 4902





### Get `pos_grad_dict` and `neg_grad_dict` for the neighbor model

In [None]:
model_name = 'CIFAR17_add210'
# Neighbor architecture: differs from the anchor 'CIFAR17_add210' in the second digit.
neighbor_name = 'CIFAR17_add220'

neighbor_model = KNOWN_MODELS[neighbor_name].to(device)

# NOTE(review): this rebinds pos_grad_dict / neg_grad_dict, replacing the anchor
# model's dictionaries; the analysis cells below use whichever cell ran last.
pos_grad_dict, neg_grad_dict = get_neighbor_graddict(
    model_name=model_name,
    neighbor_model=neighbor_model,
    neighbor_model_name=neighbor_name,
    data_name='CIFAR10',
    train_data_loader=train_data_loader,
)

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 33/26775 [00:00<01:21, 328.19it/s]

Number of postive samples:  26775
Number of negative samples:  4902
Use trail 1 to compute conflicting gradients


 47%|████▋     | 12508/26775 [00:37<01:35, 149.20it/s]

### Weight contradiction visualization

In [76]:
# Disabled: weight_contradict visualizations of pos_grad_dict vs. neg_grad_dict,
# once with method='sign' and once with method='level'. Uncomment to run.
# weight_contradict(pos_grad_dict, neg_grad_dict, method='sign')
# weight_contradict(pos_grad_dict, neg_grad_dict, method='level')

### Per-layer contradiction level (mean, max, variance)

In [None]:
# Average weight contradiction level for each layer.
# Per entry: |grad_pos - grad_neg| where the two gradients have different signs, else 0.
# Entries are averaged within each slice along dim 0, then averaged across slices.
for name, grad_pos in pos_grad_dict.items():
    if 'weight' not in name:
        continue
    grad_neg = neg_grad_dict[name]

    conflict_level = (torch.sign(grad_pos) != torch.sign(grad_neg)) * torch.abs(grad_pos - grad_neg)

    # Mean (not sum) over every dim except dim 0 — for conv weights presumably the
    # input-channel and kernel dims, assuming PyTorch's (out, in, ...) layout.
    # 1-D weights (e.g. normalization scales) are left as-is: PyTorch treats an
    # empty dim=() as "reduce over all dims", not as a no-op.
    if conflict_level.dim() > 1:
        conflict_level = conflict_level.mean(dim=tuple(range(1, conflict_level.dim())))

    layer_conflict_level = conflict_level.mean().item()
    print(name, layer_conflict_level)

In [71]:
# Maximum weight contradiction level for each layer.
# Per entry: |grad_pos - grad_neg| where the two gradients have different signs, else 0.
# Entries are averaged within each slice along dim 0; the max is taken across slices.
for name, grad_pos in pos_grad_dict.items():
    if 'weight' not in name:
        continue
    grad_neg = neg_grad_dict[name]

    conflict_level = (torch.sign(grad_pos) != torch.sign(grad_neg)) * torch.abs(grad_pos - grad_neg)

    # Mean (not sum) over every dim except dim 0 — for conv weights presumably the
    # input-channel and kernel dims, assuming PyTorch's (out, in, ...) layout.
    # 1-D weights (e.g. normalization scales) are left as-is: PyTorch treats an
    # empty dim=() as "reduce over all dims", not as a no-op.
    if conflict_level.dim() > 1:
        conflict_level = conflict_level.mean(dim=tuple(range(1, conflict_level.dim())))

    layer_conflict_level = conflict_level.max().item()
    print(name, layer_conflict_level)

body.cnn1.conv.weight 0.6552368998527527
body.cnn2.conv.weight 0.15954236686229706
body.cnn3.conv.weight 0.18937832117080688
head.dense.fc1.weight 0.08281862735748291
head.dense.fc2.weight 0.04943676292896271


In [72]:
# Variance of the weight contradiction level for each layer.
# Per entry: |grad_pos - grad_neg| where the two gradients have different signs, else 0.
# Entries are averaged within each slice along dim 0. The variance across slices uses
# torch's default (Bessel-corrected) estimator, so a layer with one slice gives nan.
for name, grad_pos in pos_grad_dict.items():
    if 'weight' not in name:
        continue
    grad_neg = neg_grad_dict[name]

    conflict_level = (torch.sign(grad_pos) != torch.sign(grad_neg)) * torch.abs(grad_pos - grad_neg)

    # Mean (not sum) over every dim except dim 0 — for conv weights presumably the
    # input-channel and kernel dims, assuming PyTorch's (out, in, ...) layout.
    # 1-D weights (e.g. normalization scales) are left as-is: PyTorch treats an
    # empty dim=() as "reduce over all dims", not as a no-op.
    if conflict_level.dim() > 1:
        conflict_level = conflict_level.mean(dim=tuple(range(1, conflict_level.dim())))

    layer_conflict_level = conflict_level.var().item()
    print(name, layer_conflict_level)

body.cnn1.conv.weight 0.043868374079465866
body.cnn2.conv.weight 0.0041794911958277225
body.cnn3.conv.weight 0.0036977888084948063
head.dense.fc1.weight 0.000469876074930653
head.dense.fc2.weight 0.000319293380016461


## Evolution Trace

In [22]:
# One-off migration: rename checkpoint folders from the old 'CIFAR17-CIFAR10-model*'
# prefix to 'CIFAR17_add000-CIFAR10-model*', keeping the rest of the folder name.
# Safe to re-run: renamed folders no longer start with the old prefix.
old_prefix = 'CIFAR17-CIFAR10-model'
new_prefix = 'CIFAR17_add000-CIFAR10-model'
for folder in os.listdir('checkpoints/'):
    if folder.startswith(old_prefix):
        # Slice off exactly the prefix. split(prefix)[1] would drop everything after a
        # second occurrence of the prefix text.
        suffix = folder[len(old_prefix):]
        os.rename(os.path.join('checkpoints/', folder),
                  os.path.join('checkpoints/', new_prefix + suffix))

In [23]:
# Model evolution trace: four hard-coded starting models, followed by every model
# name announced in the training log, in log order.
initial_trace = ['CIFAR17_add000', 'CIFAR17_add010', 'CIFAR17_add110', 'CIFAR17_add210']

# Read the log inside a context manager so the file handle is closed promptly.
with open('log/train_trace.log') as log_fh:
    logfile = log_fh.readlines()

for line in logfile:
    if 'INFO:trace:Update model name' in line:
        # The model name is the last whitespace-separated token on the line.
        initial_trace.append(line.split()[-1])

In [24]:
# Display the full trace: the hard-coded starting models plus names parsed from the log.
initial_trace

['CIFAR17_add000',
 'CIFAR17_add010',
 'CIFAR17_add110',
 'CIFAR17_add210',
 'CIFAR17_add211',
 'CIFAR17_add212',
 'CIFAR17_add222',
 'CIFAR17_add232',
 'CIFAR17_add332',
 'CIFAR17_add333',
 'CIFAR17_add334',
 'CIFAR17_add434']

In [15]:
# For each consecutive pair of models in the trace, report which layer was grown.
# The three digits after 'add' are per-layer indices (e.g. 'CIFAR17_add210' -> (2, 1, 0)).
# Print the first layer (1-based) whose index increased by exactly one.
# Pairs are formed explicitly with zip. The old `'lasti' in locals()` check reused
# stale values from a previous run of this cell.
def _layer_indices(trace_name):
    """Return the three single-digit layer indices that follow 'add' in trace_name."""
    digits = trace_name.split('add')[1]
    return int(digits[0]), int(digits[1]), int(digits[2])


for prev_name, curr_name in zip(initial_trace, initial_trace[1:]):
    prev_idx = _layer_indices(prev_name)
    curr_idx = _layer_indices(curr_name)
    for ct, (prev, curr) in enumerate(zip(prev_idx, curr_idx), start=1):
        if curr - prev == 1:
            print('Increase {}th layer'.format(ct))
            break

Increase 2th layer
Increase 1th layer
Increase 1th layer
Increase 3th layer
Increase 3th layer
Increase 2th layer
Increase 2th layer
Increase 1th layer
Increase 3th layer
Increase 3th layer
Increase 1th layer
