In [None]:
# Mount Google Drive so the project directory under /content/gdrive is reachable.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
%cd /content/gdrive/MyDrive/research/SWIL-Comparisons
! pip3 install -r requirements.txt

/content/gdrive/MyDrive/research/SWIL-Comparisons


In [None]:
# Make the project's `src/` directory importable (custom `utils` package) in Google Colab.
import sys

src_dir = '/content/gdrive/MyDrive/research/SWIL-Comparisons/src'
# Guard so re-running this cell does not append duplicate entries to sys.path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

#from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import *
from utils.dataset_tools import split_training_data, reorder_classes
from utils.feature_extractor import *
from utils.cosine_similarity import *
from utils.gen_dataset import *

In [None]:
# Train on the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
# Paths: checkpoints live on Drive; logs/data/datasets are relative to the
# working directory set by the %cd cell.
model_dir = '/content/gdrive/MyDrive/research/SWIL-Comparisons/src/models/'
log_dir = './logs/'
data_dir = './data/'
datasets_dir = './datasets/'

# NOTE(review): 'vgg' is listed as an option, but the model-building cells only
# handle 'linear', 'cnn-demo' and 'cnn'.
model_selection = 'cnn' # linear | cnn | vgg
dataset_selection = 'cifar10' # cifar10 | fashionmnist

# Classes held out when the base checkpoint was trained; `new_class` is
# presumably the class added back in this notebook.
holdout_classes = [8, 9]
new_class = 8

# Derive the checkpoint name from `holdout_classes` so the two cannot drift apart
# (str([8, 9]) == '[8, 9]', matching the previously hard-coded filename).
ckpt_file = model_dir + model_selection + '_' + dataset_selection + '_' + f'holdout_{holdout_classes}.pt'
gen_dataset_path = datasets_dir + 'g_' + dataset_selection + '/annotations.csv'
print(gen_dataset_path)

batch_size = 10
num_classes = 9

./datasets/g_cifar10/annotations.csv


#### Hyperparameters

In [None]:
num_epochs = 10

initial_learning_rate = 0.001
final_learning_rate = 0.0001

# Per-epoch multiplicative LR decay, chosen so that
#   initial_learning_rate * decay_rate ** num_epochs == final_learning_rate
decay_rate = pow(final_learning_rate / initial_learning_rate, 1 / num_epochs)

# Shared by the FIL and SWIL training loops; optimizers/schedulers are built
# per model once the models exist.
loss_fn = torch.nn.CrossEntropyLoss()

# Data Preparation

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) per channel then maps them to [-1, 1].
if dataset_selection == 'fashionmnist':
    # Grayscale images -> 1 channel. `(0.5,)` is a 1-tuple; `(0.5)` would be a bare float.
    norm_mean, norm_std = (0.5,), (0.5,)
else:
    # RGB images -> 3 channels.
    norm_mean, norm_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

In [None]:
# Download (if needed) and load the train/test splits with the shared transform.
dataset_classes = {'cifar10': CIFAR10, 'fashionmnist': FashionMNIST}
if dataset_selection not in dataset_classes:
    # Fail loudly here instead of with a NameError on `train_data` in a later cell.
    raise ValueError(f"Unknown dataset_selection: {dataset_selection!r}")

dataset_cls = dataset_classes[dataset_selection]
# `data_dir` ('./data/') is the same directory the hard-coded './data' pointed to.
train_data = dataset_cls(root=data_dir, train=True, download=True, transform=transform)
test_data = dataset_cls(root=data_dir, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Number of distinct labels in the training split (np.unique returns a flat array).
# NOTE(review): not referenced by the visible cells below.
total_classes = np.unique(train_data.targets).size

## FIL (Fully Interleaved Learning)

In [None]:
# Build the FIL model from the holdout checkpoint via `add_output_nodes`
# (presumably adding output node(s) for the new class), then freeze the
# early layers so only the remaining layers are trained.
if model_selection == 'linear':
    fil_model = add_output_nodes(ckpt_file, device, arch='linear')
    fil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    fil_model = add_output_nodes(ckpt_file, device, arch='cnn-demo')
    fil_model.conv1.requires_grad_(False)
    fil_model.conv2.requires_grad_(False)
    fil_model.fc1.requires_grad_(False)
elif model_selection == 'cnn':
    fil_model = add_output_nodes(ckpt_file, device, arch='cnn')
    # TODO: what exactly should be frozen in this new model?
    # (Freezing only parts of conv_block3 was tried earlier and dropped.)
    fil_model.conv_block1.requires_grad_(False)
    fil_model.conv_block2.requires_grad_(False)
    fil_model.conv_block3.requires_grad_(False)
    # NOTE(review): requires_grad_(False) stops gradient updates only. The
    # BatchNorm layers inside these blocks still update running_mean/var while
    # the model is in train mode unless those modules are put in eval().
else:
    # Without this, an unsupported choice (e.g. 'vgg') surfaced as a NameError below.
    raise ValueError(f"Unsupported model_selection for FIL: {model_selection!r}")

fil_model = fil_model.to(device)

cuda


In [None]:
# Relabel CIFAR10 so torchvision's class order matches the paper.
# Presumably key = torchvision label and value = (new label, flag); the meaning
# of the flag is defined by `reorder_classes` -- TODO confirm. Under that
# reading, cat (3) -> 8 and automobile (1) -> 9, i.e. the two held-out
# classes; cat is added back later.
ordering = {
    old_label: (new_label, False)
    for old_label, new_label in enumerate([5, 9, 0, 8, 1, 2, 3, 4, 6, 7])
}

In [None]:
# Apply `ordering` to both splits, replacing their labels and class names.
# Guarded with a marker attribute: reorder_classes presumably remaps the
# dataset's *current* labels, so re-running this cell would permute them again.
# Re-running the dataset-loading cell creates fresh objects and clears the marker.
for split in (train_data, test_data):
    if getattr(split, '_swil_reordered', False):
        continue
    targets, classes = reorder_classes(split, ordering)
    split.targets = targets
    split.classes = classes
    split._swil_reordered = True

In [None]:
# Adam over all parameters (frozen ones receive no gradients and are skipped),
# with the learning rate decayed by `decay_rate` each time the scheduler steps
# (once per epoch in the FIL training loop).
fil_optimizer = torch.optim.Adam(params=fil_model.parameters(), lr=initial_learning_rate)
fil_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(fil_optimizer, gamma=decay_rate)

In [None]:
# FIL: train on fully interleaved data -- every class except 9, i.e. the old
# classes plus the newly added class 8.
# TODO(review): confirm 9 is the class that should stay excluded for CIFAR10.
included_data, excluded_data = split_training_data(train_data, [9])
train_fil_loader = DataLoader(
    dataset=included_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Evaluate on the same 9 classes taken from the test split (9 still excluded).
included_data, excluded_data = split_training_data(test_data, [9])
test_fil_loader = DataLoader(
    dataset=included_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Suffix recorded in the checkpoint name; presumably means none of the linear
# (fc) layers are frozen. For model_selection == 'cnn' the conv blocks ARE
# frozen -- TODO confirm the name reflects the intended setup.
linear_frozen = "_none_frozen"

# e.g. <model_dir>cnn_cifar10_holdout_[9]_fil_none_frozen.pt
model_file_fil = f"{model_dir}{model_selection}_{dataset_selection}_holdout_[9]_fil{linear_frozen}.pt"

### Training Loop

In [None]:
# Fully interleaved training: each epoch runs one training pass, one evaluation
# pass, records the values returned by train()/test(), and decays the LR.
train_losses, test_losses = [], []
y_preds, y_actuals = [], []  # NOTE(review): never filled in this cell
t = range(num_epochs)

# TODO(review): no label swapping needed here for the cnn, right?

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")

    train_loss = train(train_fil_loader, fil_model, loss_fn, fil_optimizer, device)
    test_loss = test(test_fil_loader, fil_model, loss_fn, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    fil_lr_scheduler.step()

# Persist only the final weights after all epochs.
torch.save(fil_model.state_dict(), model_file_fil)

print("Done!")

Epoch 1
-------------------------------
loss: 31.672812  [    0/45000]
loss: 0.415219  [32000/45000]
Test Error: 
 Accuracy: 81.1%, Avg loss: 0.723884 

Epoch 2
-------------------------------
loss: 0.247373  [    0/45000]
loss: 0.034443  [32000/45000]
Test Error: 
 Accuracy: 81.8%, Avg loss: 0.848789 

Epoch 3
-------------------------------
loss: 0.095087  [    0/45000]
loss: 0.020888  [32000/45000]
Test Error: 
 Accuracy: 82.0%, Avg loss: 0.891962 

Epoch 4
-------------------------------
loss: 0.143121  [    0/45000]
loss: 0.080008  [32000/45000]
Test Error: 
 Accuracy: 82.0%, Avg loss: 0.975927 

Epoch 5
-------------------------------
loss: 0.021674  [    0/45000]
loss: 0.001383  [32000/45000]
Test Error: 
 Accuracy: 82.1%, Avg loss: 1.132829 

Epoch 6
-------------------------------
loss: 0.029534  [    0/45000]
loss: 0.004255  [32000/45000]
Test Error: 
 Accuracy: 82.2%, Avg loss: 1.228700 

Epoch 7
-------------------------------
loss: 0.003566  [    0/45000]
loss: 0.025231  [

## SWIL

In [None]:
# Build the SWIL model exactly like the FIL model: load the holdout checkpoint
# via `add_output_nodes` and freeze the same early layers.
if model_selection == 'linear':
    swil_model = add_output_nodes(ckpt_file, device, arch='linear')
    swil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    # Mirrors the FIL cell, which already supports this architecture.
    swil_model = add_output_nodes(ckpt_file, device, arch='cnn-demo')
    swil_model.conv1.requires_grad_(False)
    swil_model.conv2.requires_grad_(False)
    swil_model.fc1.requires_grad_(False)
elif model_selection == 'cnn':
    swil_model = add_output_nodes(ckpt_file, device, arch='cnn')
    # TODO: what exactly should be frozen in this new model?
    swil_model.conv_block1.requires_grad_(False)
    swil_model.conv_block2.requires_grad_(False)
    swil_model.conv_block3.requires_grad_(False)
    # NOTE(review): BatchNorm running stats in these blocks still update while
    # the model is in train mode; requires_grad_(False) only stops gradients.
else:
    # Without this, an unsupported choice (e.g. 'vgg') surfaced as a NameError below.
    raise ValueError(f"Unsupported model_selection for SWIL: {model_selection!r}")

swil_model = swil_model.to(device)

cuda


In [None]:
# Same optimizer / exponential LR schedule as the FIL setup, applied to the SWIL model.
# The scheduler keeps the generic name `lr_scheduler`, which the SWIL training loop steps.
swil_optimizer = torch.optim.Adam(params=swil_model.parameters(), lr=initial_learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(swil_optimizer, gamma=decay_rate)

### Generate average activations

In [None]:
# Class indices whose activations are averaged (9 classes; class 9 is not included).
ccidx = [5, 0, 8, 1, 2, 3, 4, 6, 7]

# Average activations of layer 'fc_block.fc3' for the classes in `ccidx`
# (presumably one vector per class -- see get_avg_activations).
V = get_avg_activations(swil_model, train_data, ccidx, ['fc_block.fc3'], device)

In [None]:
cifar10_sim_vec = get_similarity_vec(V)

# Persist one similarity score per line (str() of each entry followed by '\n').
# NOTE(review): the SWIL sampling cell reads ./data/cifar10_cnn_sim_scores.txt,
# not this *_fc_block_fc3.txt file, so these freshly computed scores do not
# drive the sample sizes -- confirm this is intended.
with open('/content/gdrive/MyDrive/research/SWIL-Comparisons/data/cifar10_cnn_sim_scores_fc_block_fc3.txt', 'w') as f:
    f.writelines(str(score) + '\n' for score in cifar10_sim_vec)

## Generate SWIL DataLoaders and Train

In [None]:
# Classes 0-8 take part in SWIL (old classes plus the new class 8); 9 stays excluded.
cifar10_classes = list(range(9))

# NOTE(review): train_swil_loader is not used by the SWIL training loop (which
# trains on the similarity-weighted swil_train_dl), and test_swil_loader is
# rebuilt identically in the sampling cell.
included_data, excluded_data = split_training_data(train_data, [9])
train_swil_loader = DataLoader(dataset=included_data, batch_size=batch_size, shuffle=True, num_workers=2)

included_data, excluded_data = split_training_data(test_data, [9])
test_swil_loader = DataLoader(dataset=included_data, batch_size=batch_size, shuffle=True, num_workers=2)

# Per-class subsets/indices of the training split; class_idxs[i] is used below
# as the pool of indices for class i.
class_subsets, class_idxs, subset_size = generate_dls(train_data, cifar10_classes)

In [None]:
if dataset_selection == 'cifar10':
  scores_file = 'cifar10_cnn_sim_scores.txt'
  newclass_sample_size = 720
  multiplier = 1
else:
  scores_file = 'fmnist_sim_scores_boot.txt'
  newclass_sample_size = 75
  multiplier = 7

with open(r'./data/' + scores_file, 'r') as fp:
    sim_scores = [float(i) for i in fp.readlines()]

sim_sum = sum(sim_scores)

sim_norms = [x/sim_sum for x in sim_scores]

if dataset_selection == 'cifar10':
  # appears that the cap is diffrent for the cnn
  sim_sample_sizes = [int(x * newclass_sample_size * 3.8) for x in sim_norms] + [newclass_sample_size]
else:
  sim_sample_sizes = [27 if x < 0.2 else int(x * newclass_sample_size * 3.52) for x in sim_norms] + [newclass_sample_size]

In [None]:
sim_sample_sizes = [i*multiplier for i in sim_sample_sizes]
print(sim_sample_sizes) 
# [365, 469, 308, 260, 317, 316, 378, 319, 720]
# this is not replicative of the cnn in the paper
# all too similar
 # these may be wrong if 
# just fc2: [390, 401, 326, 314, 321, 312, 356, 313, 720]

# just fc3: [395, 426, 284, 289, 338, 323, 379, 299, 720]

[365, 474, 307, 258, 319, 313, 379, 317, 720]


In [None]:
from random import Random, sample  # `sample` kept in the namespace for other cells

# Fixed seed so the SWIL training subset is reproducible across runs;
# change it to draw a different subset. (DataLoader shuffling still uses
# torch's global RNG.)
swil_sample_seed = 0
swil_rng = Random(swil_sample_seed)

# Draw sim_sample_sizes[i] distinct training indices (without replacement)
# from the index pool of class i.
sampled_idxs = []
for i in range(len(cifar10_classes)):
    idx_sample = swil_rng.sample(class_idxs[i].tolist(), sim_sample_sizes[i])
    sampled_idxs += idx_sample

swil_train_subset = torch.utils.data.Subset(train_data, sampled_idxs)
# Sanity check: the subset holds exactly the requested number of samples.
assert len(swil_train_subset) == sum(sim_sample_sizes[:len(cifar10_classes)])

# batch_size=1 (not `batch_size`) as in the original; presumably deliberate -- TODO confirm.
swil_train_dl = torch.utils.data.DataLoader(swil_train_subset, batch_size=1, shuffle=True, num_workers=2)

# Same test loader as in the previous SWIL cell (class 9 still excluded).
included_data, excluded_data = split_training_data(test_data, [9])
test_swil_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Ensure the model is on the selected device. `.cuda(device="cpu")` raises on
# CPU-only runtimes; `.to(device)` works for both and returns the model, so the
# cell still displays its repr.
swil_model.to(device)

CNN_3B(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (conv_block1): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU()
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (mpool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv_block2): Sequential(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU()
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (mpool): MaxPool2d

### Training Loop

In [None]:
# Checkpoint path for the SWIL model.
# NOTE(review): the FIL checkpoint is tagged 'holdout_[9]' while this one uses
# '[8]', although both setups exclude class 9 -- confirm the intended name
# before changing it (other code may load this path).
model_file = model_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_swil.pt'
t = range(num_epochs)

# Record per-epoch values returned by train()/test(), mirroring the FIL
# loop's train_losses / test_losses.
swil_train_losses = []
swil_test_losses = []

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(swil_train_dl, swil_model, loss_fn, swil_optimizer, device)
    test_loss = test(test_swil_loader, swil_model, loss_fn, device)
    swil_train_losses.append(train_loss)
    swil_test_losses.append(test_loss)

    # Decay the learning rate once per epoch.
    lr_scheduler.step()

# Persist only the final weights after all epochs.
torch.save(swil_model.state_dict(), model_file)

print("Done!")

Epoch 1
-------------------------------
loss: 0.000000  [    0/ 3452]
loss: 0.000149  [ 1000/ 3452]
loss: 0.000000  [ 2000/ 3452]
loss: 0.025637  [ 3000/ 3452]
Test Error: 
 Accuracy: 74.0%, Avg loss: 1.599184 

Epoch 2
-------------------------------
loss: 0.544327  [    0/ 3452]
loss: 0.000000  [ 1000/ 3452]
loss: 0.657801  [ 2000/ 3452]
loss: 2.027863  [ 3000/ 3452]
Test Error: 
 Accuracy: 74.8%, Avg loss: 1.762350 

Epoch 3
-------------------------------
loss: 0.996023  [    0/ 3452]
loss: 0.565895  [ 1000/ 3452]
loss: 0.000000  [ 2000/ 3452]
loss: 0.000000  [ 3000/ 3452]
Test Error: 
 Accuracy: 75.6%, Avg loss: 2.415757 

Epoch 4
-------------------------------
loss: 0.000000  [    0/ 3452]
loss: 0.000000  [ 1000/ 3452]
loss: 0.000000  [ 2000/ 3452]
loss: 0.017456  [ 3000/ 3452]
Test Error: 
 Accuracy: 76.0%, Avg loss: 2.238202 

Epoch 5
-------------------------------
loss: 0.000740  [    0/ 3452]
loss: 0.000000  [ 1000/ 3452]
loss: 0.002001  [ 2000/ 3452]
loss: 0.000013  [ 3000