In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd / content/gdrive/MyDrive/cs394n_project/CS394N
! pip3 install -r requirements.txt

In [None]:
# Update path for custom module support in Google Colab
import sys
sys.path.append('/content/gdrive/MyDrive/cs394n_project/CS394N/src')

In [4]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

#from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import *
from utils.dataset_tools import split_training_data
from utils.feature_extractor import *
from utils.cosine_similarity import *
from utils.gen_dataset import *

In [5]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [6]:
model_dir = './models/'
log_dir = './logs/'
data_dir = './data/'
datasets_dir = './datasets/'

model_selection = 'cnn' # linear | cnn | vgg
dataset_selection = 'cifar10' # cifar10 | fashionmnist

ckpt_file = model_dir + model_selection + '_' + dataset_selection + '_' + 'holdout_[8, 9].pt'
gen_dataset_path = datasets_dir + "g_" + dataset_selection + '/annotations'  + '.csv'
print(gen_dataset_path)

holdout_classes = [8, 9]
new_class = 8

batch_size = 32
num_classes = 9

./datasets/g_cifar10/annotations.csv


#### Hyperparameters

In [7]:
num_epochs = 15

initial_learning_rate = 0.001
final_learning_rate = 0.0001

# initial_lr * decay_rate^num_epochs = final_lr
decay_rate = (final_learning_rate/initial_learning_rate)**(1/num_epochs)

loss_fn = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate)
#lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

# Data Preparation

In [9]:
if dataset_selection == 'fashionmnist':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),]) # Images are grayscale -> 1 channel
else:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [10]:
if dataset_selection == 'cifar10':
    train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_data = CIFAR10(root='./data', train=False, download=True, transform=transform)
elif dataset_selection == 'fashionmnist':
    train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_data = FashionMNIST(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
total_classes = len(np.unique(train_data.targets))

## FOL

In [8]:
if model_selection == 'linear':
    fol_model = add_output_nodes(ckpt_file, device, arch='linear')
    fol_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    fol_model = add_output_nodes(ckpt_file, device, arch='cnn-demo')
    fol_model.conv1.requires_grad_(False)
    fol_model.conv2.requires_grad_(False)
    fol_model.fc1.requires_grad_(False)
elif model_selection == 'cnn':
    fol_model = add_output_nodes(ckpt_file, device, arch='cnn')
    fol_model.conv_block1.requires_grad_(False)
    fol_model.conv_block2.requires_grad_(False)
    fol_model.conv_block3.Conv5.requires_grad_(False)
    fol_model.conv_block3.Relu5.requires_grad_(False)
    fol_model.conv_block3.BN5.requires_grad_(False)
    
fol_model = fol_model.to(device)

In [12]:
fol_optimizer = torch.optim.Adam(fol_model.parameters(), lr=initial_learning_rate)
fol_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=fol_optimizer, gamma=decay_rate)

In [None]:
# just train on the new class
included_data, excluded_data = split_training_data(train_data, [new_class]) 
train_fol_loader = DataLoader(excluded_data, batch_size=batch_size, shuffle=True, num_workers=2)

# but test on the full 9 classes (old classes + new one, still excluding one)
included_data, excluded_data = split_training_data(test_data, [8])
test_fol_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
model_file_fol = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'fol' + '.pt'

recall_file_fol = log_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'fol' + '_recall.npy'
train_losses_file_fol = log_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'fol' + '_train_loss.txt'
test_losses_file_fol = log_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'fol' + '_test_loss.txt'

### Training Loop

In [None]:
train_losses = []
test_losses = []
#t = trange(num_epochs)
t = range(num_epochs)
y_preds = []
y_actuals = []

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_fol_loader, fol_model, loss_fn, fol_optimizer, device)
    test_loss, y_pred, y_actual = test(test_fol_loader, fol_model, loss_fn, device, swap=True, swap_labels=[9,8])
    print('y_pred:', y_pred[:2])
    print('y_actual:', y_actual[:2])
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    y_preds.append(y_pred)
    y_actuals.append(y_actual)
    
    fol_lr_scheduler.step()
    
torch.save(fol_model.state_dict(), model_file_fol)


recalls = get_recall_per_epoch(y_actuals, y_preds, num_classes)
np.save(recall_file_fol, recalls)

with open(train_losses_file_fol, 'w') as fp:
    for s in train_losses:
        fp.write("%s\n" % s)
        
with open(test_losses_file_fol, 'w') as fp:
    for x in test_losses:
        fp.write("%s\n" % x)

print("Done!")

# Ok I'm struggling with getting data loading correct for FOL but I'm not sure we even need it

## SWIL

In [8]:
if model_selection == 'linear':
    swil_model = add_output_nodes(ckpt_file, arch='linear')
    swil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    swil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    swil_model.conv1.requires_grad_(False)
    swil_model.conv2.requires_grad_(False)
    swil_model.fc1.requires_grad_(False)
    
swil_model = swil_model.to(device)

TypeError: add_output_nodes() missing 1 required positional argument: 'device'

In [None]:
swil_optimizer = torch.optim.Adam(swil_model.parameters(), lr=initial_learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=swil_optimizer, gamma=decay_rate)

In [None]:
fmnist_classes = list(range(8)) + [9]

# might not need these
FMNIST_trainloader_gen = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
FMNIST_testloader_gen = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

class_subsets, class_idxs, subset_size = generate_dls(train_data, fmnist_classes)

In [None]:
with open(r'./data/fmnist_sim_scores_boot.txt', 'r') as fp:
    sim_scores = [float(i) for i in fp.readlines()]

sim_sum = sum(sim_scores)

sim_norms = [x/sim_sum for x in sim_scores]

boots_sample_size = 75
sim_sample_sizes = [27 if x < 0.2 else int(x * boots_sample_size*3.52) for x in sim_norms] + [75]

In [None]:
from random import sample

sampled_idxs = []

for i in range(len(fmnist_classes)):
    idx_sample = sample(class_idxs[i].tolist(), sim_sample_sizes[i])
    sampled_idxs += idx_sample

swil_train_subset = torch.utils.data.Subset(train_data, sampled_idxs)

swil_train_dl = torch.utils.data.DataLoader(swil_train_subset, batch_size=1, shuffle=True, num_workers=2)

included_data, excluded_data = split_training_data(test_data, [8])
test_swil_loader = DataLoader(included_data, batch_size=batch_size, shuffle=True, num_workers=2)

### Training Loop

In [None]:
weight_dir = './weights/'
model_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_swil.pt'
recall_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_swil_recall.npy'
train_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_swil_train_loss.txt'
test_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'swil_test_loss.txt'
train_losses = []
test_losses = []
#t = trange(num_epochs)
t = range(num_epochs)
y_preds = []
y_actuals = []

for epoch in t:
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(swil_train_dl, swil_model, loss_fn, swil_optimizer, device, swap=True, swap_labels=[9,8])
    test_loss, y_pred, y_actual = test(test_swil_loader, swil_model, loss_fn, device, swap=True, swap_labels=[9,8])
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    y_preds.append(y_pred)
    y_actuals.append(y_actual)
    
    #t.set_description(f"Epoch {epoch} train loss: {epoch_loss_train[-1]:.3f}")
    lr_scheduler.step()
    
torch.save(swil_model.state_dict(), model_file)

recalls = get_recall_per_epoch(y_actuals, y_preds, 9)
np.save(recall_file, recalls)

with open(train_losses_file, 'w') as fp:
    for s in train_losses:
        fp.write("%s\n" % s)
        
with open(test_losses_file, 'w') as fp:
    for x in test_losses:
        fp.write("%s\n" % x)

print("Done!")

## G-SWIL

In [12]:
if model_selection == 'linear':
    gswil_model = add_output_nodes(ckpt_file, arch='linear', device=device)
    gswil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    gswil_model = add_output_nodes(ckpt_file, arch='cnn-demo', device=device)
    gswil_model.conv1.requires_grad_(False)
    gswil_model.conv2.requires_grad_(False)
    gswil_model.fc1.requires_grad_(False)
elif model_selection == 'cnn':
    gswil_model = add_output_nodes(ckpt_file, arch='cnn', device=device)
    gswil_model.conv_block1.requires_grad_(False)
    gswil_model.conv_block2.requires_grad_(False)
    gswil_model.conv_block3.Conv5.requires_grad_(False)
    gswil_model.conv_block3.Relu5.requires_grad_(False)
    gswil_model.conv_block3.BN5.requires_grad_(False)
    
gswil_model = gswil_model.to(device)

In [13]:
gswil_optimizer = torch.optim.Adam(gswil_model.parameters(), lr=initial_learning_rate)
gswil_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=gswil_optimizer, gamma=decay_rate)

In [59]:
# this will be different for cnn 

sim_scores = np.load('./data/cifar10_sim_scores_cat.npy').tolist()

sim_sum = sum(sim_scores)

sim_norms = [x/sim_sum for x in sim_scores]

cat_sample_size = 235
sim_sample_sizes = [int(x * cat_sample_size*2.34) for x in sim_norms] + [cat_sample_size]
print(sim_sample_sizes)
print(sum(sim_sample_sizes))

[59, 60, 80, 64, 55, 42, 41, 46, 100, 235]
782


In [40]:
from torch.utils.data import Dataset
import os
from PIL import Image

import pandas as pd

class GenDataset(Dataset):
    """All-purpose (FMNIST or CIFAR10) generated dataset."""

    def __init__(self, csv_file, root_dir, data_name='img', class_name='label', transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.files = self.landmarks_frame[data_name].values
        self.data = self._load_data()
        self.targets = self.landmarks_frame[class_name].values
        
    def _load_data(self):
        #data = np.empty((len(self.files),28,28))
        data = np.zeros([len(self.files),3,32,32])
        
        i = 0
        
        for f in self.files:
            img_name = os.path.join(self.root_dir, f)
            pil_img = Image.open(img_name)
            img = np.asarray(pil_img)
            #plt.imshow(pil_img)
            #plt.show()
            
            #img = torch.tensor(img)
            #img1 = img / 2 + 0.5     # unnormalize
            #npimg = img1.numpy()
            #plt.imshow(np.transpose(npimg, (1, 2, 0)))
            #plt.show()
            
            if self.transform is not None:
                img = Image.open(img_name)
                img = self.transform(img)
                
            #img2 = img / 2 + 0.5     # unnormalize
            #npimg = img2.numpy()
            #plt.imshow(np.transpose(npimg, (1, 2, 0)))
            #plt.show()
                
            img = np.asarray(img)
            #print(img[0][0])
            
            #np.append(data, img)
            data[i,:,:,:] = img
            i += 1
            #if i == 10:
            #    break
        
        return data

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        #if torch.is_tensor(idx):
        #    idx = idx.tolist()
        
        img, target = self.data[idx], self.targets[idx]
    
        
        return torch.tensor(img, dtype=torch.float32), torch.tensor(target)
    

In [41]:
if dataset_selection == 'fashionmnist':
    transform = transforms.Compose([
        #transforms.Resize(28),
        #transforms.CenterCrop(28), 
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Normalize((0.5), (0.5))]) # Images are grayscale -> 1 channel
else:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

gcifar10_dataset = GenDataset(csv_file=gen_dataset_path,
                                    root_dir=datasets_dir + 'g_cifar10', data_name='filename', class_name='target', transform=transform)


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    #print(nimg.shape)
    #plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.imshow(npimg[0])
    plt.show()
    
i = 0

"""
for item, idx in gcifar10_dataset_test:
    print(item[0][0], idx)
    print("shape", item.shape)
    imshow(item)
    i += 1
    if i == 3:
        break
    
for item1, idx1 in test_data:
    print(item1[0][0])
    imshow(item1)
    break
    """

'\nfor item, idx in gcifar10_dataset_test:\n    print(item[0][0], idx)\n    print("shape", item.shape)\n    imshow(item)\n    i += 1\n    if i == 3:\n        break\n    \nfor item1, idx1 in test_data:\n    print(item1[0][0])\n    imshow(item1)\n    break\n    '

In [42]:
from copy import copy

def reorder_classes(dataset, new_order):
    '''
        new_order: Dict<k:INT, v:(INT, BOOL)>
            - key: Current class number
            - val[0]: New class number
            - val[1]: Swap T/F
                - T: bidirectional update between two classes ex. 1<-->3
                - F: unidirectional update ex. 1->3 3->4 4->1

        IMPORTANT: Assumes class numbers are 0-indexed
    '''

    old_indices = dict()
    unique_targets = np.unique(dataset.targets).tolist()
    all_targets = np.array(dataset.targets)
    labels = dataset.classes
    new_labels = copy(labels)

    # Gather initial target mappings
    for target in unique_targets:
        old_indices[target] = np.where(all_targets == target)[0]

    for old_target, new_target in new_order.items():
        # Assign new target value to datapoints
        np.put(all_targets, old_indices[old_target], new_target[0])
        # Update label list
        new_labels[new_target[0]] = labels[old_target]

        if new_target[1]: # Swap = True
            # Target and label update if swapping with another class
            np.put(all_targets, old_indices[new_target[0]], old_target)
            new_labels[old_target] = labels[new_target[0]]

    return all_targets.tolist(), new_labels

if dataset_selection == 'fashionmnist':
    # FashionMNIST match torchvision with paper
    ordering = {
        5: (6, True),
        8: (9, True),
    }
elif dataset_selection == 'cifar10':
    # CIFAR10 match torchvision with paper
    ordering = {
        0:(5, False),
        1:(9, False),
        2: (0, False),
        3: (8, False),
        4: (1, False),
        5: (2, False),
        6: (3, False),
        7: (4, False),
        8: (6, False),
        9: (7, False),
    }
    
reordered_targets, reordered_class = reorder_classes(train_data, ordering)
reordered_targets_t, reordered_class_t = reorder_classes(test_data, ordering)
train_data.targets = reordered_targets
test_data.targets = reordered_targets_t
train_data.classes = reordered_class
test_data.classes = reordered_class_t

In [61]:
from random import sample

cifar10_classes = list(range(8)) # just get old classes from generated data

class_subsets, old_class_idxs, subset_size = generate_dls(gcifar10_dataset, cifar10_classes)

sampled_idxs = []
for i in range(len(cifar10_classes)):
    idx_sample = sample(old_class_idxs[i].tolist(), sim_sample_sizes[i])
    sampled_idxs += idx_sample
    
gswil_train_old_subset = torch.utils.data.Subset(gcifar10_dataset, sampled_idxs)

gswil_new_subset, new_class_idx, subset_size = generate_dls(train_data, [8])
new_class_idx_sample = sample(new_class_idx[0].tolist(), sim_sample_sizes[-1])
gswil_train_new_subset = torch.utils.data.Subset(train_data, new_class_idx_sample)


gswil_combined_subset = torch.utils.data.ConcatDataset([gswil_train_old_subset, gswil_train_new_subset])

gswil_train_dl = torch.utils.data.DataLoader(gswil_combined_subset, batch_size=1, shuffle=True, num_workers=2)

test_idx = np.where((np.array(test_data.targets) != 9))[0]
test_subset = torch.utils.data.Subset(test_data, test_idx)
gswil_test_dl = torch.utils.data.DataLoader(test_subset, batch_size=1, shuffle=True, num_workers=2)


100
447
235
682


In [62]:
def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(npimg) #np.transpose(npimg, (1, 2, 0)))
    plt.show()

i = 0

j = 0
#for img, label in 

for img, label in gswil_train_dl:
    #imshow(img)
    print(img.shape)
    print(img.dtype)
    print(label)
    #print(img, label)
    i += 1
    if i == 20:
        break

torch.Size([1, 3, 32, 32])
torch.float32
tensor([6])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([2])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([7])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([0])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([2])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([8])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([8])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([8])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([8])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([7])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([1])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([6])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([4])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([0])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([8])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([8])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([0])
torch.Size([1, 3, 32, 32])
torch.float32
tensor([6])
torch.Size([1, 3, 32, 32])
torch.float32
tenso

### Training Loop

In [66]:
def train(dataloader, model, loss_fn, optimizer, device, swap=False, swap_labels=[]) -> float:
    '''
        Model training loop. Performs a single epoch of model updates.
        
        * USAGE *
        Within a training loop of range(num_epochs).

        * PARAMETERS *
        dataloader: A torch.utils.data.DataLoader object
        model: A torch model which subclasses torch.nn.Module
        loss_fn: A torch loss function, such as torch.nn.CrossEntropyLoss
        optimizer: A torch.optim optimizer
        device: 'cuda' or 'cpu'

        * RETURNS *
        float: The model's average epoch loss 
    '''

    size = len(dataloader.dataset)
    train_loss = 0
    
    model.conv_block1.requires_grad_(False)
    model.conv_block2.requires_grad_(False)
    model.conv_block3.Conv5.requires_grad_(False)
    model.conv_block3.Relu5.requires_grad_(False)
    model.conv_block3.BN5.requires_grad_(False)

    #model.train()
    for batch, (X, y) in enumerate(dataloader):
        #print(X, y)
        #print(X.shape)
        if swap:
            for i in range(len(y)):
                if y[i] == swap_labels[0]:
                    y[i] = swap_labels[1]
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        
        # Compute prediction error
        pred = model(X)

        # Backpropagation
        #import pdb; pdb.set_trace()
        #print('pred', pred)
        #print('y', y)
        loss = loss_fn(pred, y)
        
        loss.backward()
        optimizer.step()

        # Append lists
        train_loss += loss.item()

        if batch % 1000 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return train_loss/len(dataloader)



In [67]:
def test(dataloader, model, loss_fn, device, swap=False, swap_labels=[], classes = 9) -> float:
    '''
        Model test loop. Performs a single epoch of model updates.

        * USAGE *
        Within a training loop of range(num_epochs) to perform epoch validation, or after training to perform testing.

        * PARAMETERS *
        dataloader: A torch.utils.data.DataLoader object
        model: A torch model which subclasses torch.nn.Module
        loss_fn: A torch loss function, such as torch.nn.CrossEntropyLoss
        optimizer: A torch.optim optimizer
        device: 'cuda' or 'cpu'

        * RETURNS *
        float: The average test loss
    '''

    # TODO: can the swap stuff be removed now?
    
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    y_pred_list, targets = [], []

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            if swap:
                for i in range(len(y)):
                    if y[i] == swap_labels[0]:
                        y[i] = swap_labels[1]
            X, y = X.to(device), y.to(device)
            pred = model(X)
            targets.append(y.tolist())
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            
            _, y_pred_tags = torch.max(pred, dim=1)
            y_pred_list.append(y_pred_tags.cpu().numpy())
            
    y_pred_list = [a.squeeze().tolist() for a in y_pred_list]
    #y_pred_list = [item for sublist in y_pred_list for item in sublist]
    
    targets = [item for sublist in targets for item in sublist]

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return test_loss, np.asarray(y_pred_list), np.asarray(targets), 100*correct

In [72]:
for j in range(4):
    if model_selection == 'linear':
        gswil_model = add_output_nodes(ckpt_file, arch='linear', device=device)
        gswil_model.input_layer.requires_grad_(False)
    elif model_selection == 'cnn-demo':
        gswil_model = add_output_nodes(ckpt_file, arch='cnn-demo', device=device)
        gswil_model.conv1.requires_grad_(False)
        gswil_model.conv2.requires_grad_(False)
        gswil_model.fc1.requires_grad_(False)
    elif model_selection == 'cnn':
        gswil_model = add_output_nodes(ckpt_file, arch='cnn', device=device)
        gswil_model.conv_block1.requires_grad_(False)
        gswil_model.conv_block2.requires_grad_(False)
        gswil_model.conv_block3.Conv5.requires_grad_(False)
        gswil_model.conv_block3.Relu5.requires_grad_(False)
        gswil_model.conv_block3.BN5.requires_grad_(False)

    gswil_model = gswil_model.to(device)

    gswil_optimizer = torch.optim.Adam(gswil_model.parameters(), lr=initial_learning_rate)
    gswil_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=gswil_optimizer, gamma=decay_rate)

    cifar10_classes = list(range(8)) # just get old classes from generated data

    class_subsets, old_class_idxs, subset_size = generate_dls(gcifar10_dataset, cifar10_classes)

    sampled_idxs = []
    for i in range(len(cifar10_classes)):
        idx_sample = sample(old_class_idxs[i].tolist(), sim_sample_sizes[i])
        sampled_idxs += idx_sample

    gswil_train_old_subset = torch.utils.data.Subset(gcifar10_dataset, sampled_idxs)

    gswil_new_subset, new_class_idx, subset_size = generate_dls(train_data, [8])
    new_class_idx_sample = sample(new_class_idx[0].tolist(), sim_sample_sizes[-1])
    gswil_train_new_subset = torch.utils.data.Subset(train_data, new_class_idx_sample)


    gswil_combined_subset = torch.utils.data.ConcatDataset([gswil_train_old_subset, gswil_train_new_subset])

    gswil_train_dl = torch.utils.data.DataLoader(gswil_combined_subset, batch_size=1, shuffle=True, num_workers=2)

    test_idx = np.where((np.array(test_data.targets) != 9))[0]
    test_subset = torch.utils.data.Subset(test_data, test_idx)
    gswil_test_dl = torch.utils.data.DataLoader(test_subset, batch_size=1, shuffle=True, num_workers=2)

    weight_dir = './weights/'
    model_file = weight_dir + model_selection + '_' + dataset_selection + '_gswil' + str(j) + '.pt'
    recall_file = weight_dir + model_selection + '_' + dataset_selection +  '_gswil_recall'+ str(j) +'.npy'
    train_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_gswil_train_loss'+ str(j) +'.txt'
    test_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_gswil_test_loss'+ str(j) +'.txt'
    accuracies_file = weight_dir + model_selection + '_' + dataset_selection + '_gswil_accuracy' + str(j) +'.txt'

    train_losses = []
    test_losses = []
    accuracies = []
    t = range(num_epochs)
    y_preds = []
    y_actuals = []

    for epoch in t:
        print(f"Epoch {epoch+1}\n-------------------------------")
        train_loss = train(gswil_train_dl, gswil_model, loss_fn, gswil_optimizer, device, swap=True, swap_labels=[9,8])
        test_loss, y_pred, y_actual, accuracy = test(gswil_test_dl, gswil_model, loss_fn, device, swap=True, swap_labels=[9,8])
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        accuracies.append(accuracy)
        y_preds.append(y_pred)
        y_actuals.append(y_actual)

        gswil_scheduler.step()

    torch.save(gswil_model.state_dict(), model_file)

    recalls = get_recall_per_epoch(y_actuals, y_preds, 9)
    np.save(recall_file, recalls)

    with open(train_losses_file, 'w') as fp:
        for s in train_losses:
            fp.write("%s\n" % s)

    with open(test_losses_file, 'w') as fp:
        for x in test_losses:
            fp.write("%s\n" % x)

    with open(accuracies_file, 'w') as fp:
        for x in accuracies:
            fp.write("%s\n" % x)

    print("Done!")

Epoch 1
-------------------------------
loss: 7.005485  [    0/  682]
Test Error: 
 Accuracy: 13.0%, Avg loss: 26.706254 

Epoch 2
-------------------------------
loss: 7.142556  [    0/  682]
Test Error: 
 Accuracy: 15.1%, Avg loss: 20.696906 

Epoch 3
-------------------------------
loss: 0.000000  [    0/  682]
Test Error: 
 Accuracy: 15.2%, Avg loss: 23.521334 

Epoch 4
-------------------------------
loss: 0.000000  [    0/  682]
Test Error: 
 Accuracy: 15.1%, Avg loss: 29.701057 

Epoch 5
-------------------------------
loss: 0.000220  [    0/  682]
Test Error: 
 Accuracy: 15.2%, Avg loss: 29.894157 

Epoch 6
-------------------------------
loss: 0.008674  [    0/  682]
Test Error: 
 Accuracy: 15.2%, Avg loss: 29.939462 

Epoch 7
-------------------------------
loss: 0.023498  [    0/  682]
Test Error: 
 Accuracy: 15.2%, Avg loss: 30.004676 

Epoch 8
-------------------------------
loss: 0.009207  [    0/  682]
Test Error: 
 Accuracy: 15.2%, Avg loss: 30.024080 

Epoch 9
--------

KeyboardInterrupt: 