In [41]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10, FashionMNIST

from tqdm.autonotebook import tqdm, trange

from utils.nets import *
from utils.model_tools import *
from utils.dataset_tools import split_training_data
from utils.feature_extractor import *
from utils.cosine_similarity import *

In [42]:
model_dir = './models/'
log_dir = './logs'

model_selection = 'linear' # linear | cnn | vgg
dataset_selection = 'fashionmnist' # cifar10 | fashionmnist

ckpt_file = model_dir+'linear_fashionmnist_holdout_[8, 9].pt'

holdout_classes = [8, 9]
new_class = 8

batch_size = 32
num_classes = 9

In [43]:
num_epochs = 15

initial_learning_rate = 0.001
final_learning_rate = 0.0001

# initial_lr * decay_rate^num_epochs = final_lr
decay_rate = (final_learning_rate/initial_learning_rate)**(1/num_epochs)

loss_fn = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate)
#lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)

In [44]:
if dataset_selection == 'fashionmnist':
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)),]) # Images are grayscale -> 1 channel
else:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [45]:
if dataset_selection == 'cifar10':
    train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_data = CIFAR10(root='./data', train=False, download=True, transform=transform)
elif dataset_selection == 'fashionmnist':
    train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_data = FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [46]:
total_classes = len(np.unique(train_data.targets))

In [47]:
device = 'cpu'
if model_selection == 'linear':
    gswil_model = add_output_nodes(ckpt_file, arch='linear', device=device)
    gswil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    gswil_model = add_output_nodes(ckpt_file, arch='cnn-demo', device=device)
    gswil_model.conv1.requires_grad_(False)
    gswil_model.conv2.requires_grad_(False)
    gswil_model.fc1.requires_grad_(False)
elif model_selection == 'cnn':
    gswil_model = add_output_nodes(ckpt_file, arch='cnn', device=device)
    gswil_model.conv_block1.requires_grad_(False)
    gswil_model.conv_block2.requires_grad_(False)
    gswil_model.conv_block3.Conv5.requires_grad_(False)
    gswil_model.conv_block3.Relu5.requires_grad_(False)
    gswil_model.conv_block3.BN5.requires_grad_(False)
    
gswil_model = gswil_model.to(device)


input_size 784
num_outputs 9


In [48]:
gswil_optimizer = torch.optim.Adam(gswil_model.parameters(), lr=initial_learning_rate)
gswil_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=gswil_optimizer, gamma=decay_rate)


In [49]:
sim_sample_sizes = np.load('./data/fmnist_sim_scores_boot2.npy')
sim_sample_sizes = np.array(sim_sample_sizes.tolist(),dtype='uint8')

In [50]:
from torch.utils.data import Dataset
import os
from PIL import Image

import pandas as pd

class GenDataset(Dataset):
    """All-purpose (FMNIST or CIFAR10) generated dataset."""

    def __init__(self, csv_file, root_dir, data_name='img', class_name='label', transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.files = self.landmarks_frame[data_name].values
        self.data = self._load_data()
        self.targets = self.landmarks_frame[class_name].values
        
    def _load_data(self):
        #data = np.empty((len(self.files),28,28))
        data = np.zeros([len(self.files),1,28,28])
        
        i = 0
        
        for f in self.files:
            img_name = os.path.join(self.root_dir, f)
            pil_img = Image.open(img_name)
            img = np.asarray(pil_img)
            #plt.imshow(pil_img)
            #plt.show()
            
            #img = torch.tensor(img)
            #img1 = img / 2 + 0.5     # unnormalize
            #npimg = img1.numpy()
            #plt.imshow(np.transpose(npimg, (1, 2, 0)))
            #plt.show()
            
            if self.transform is not None:
                img = Image.open(img_name)
                img = self.transform(img)
                
            #img2 = img / 2 + 0.5     # unnormalize
            #npimg = img2.numpy()
            #plt.imshow(np.transpose(npimg, (1, 2, 0)))
            #plt.show()
                
            img = np.asarray(img)
            #print(img[0][0])
            
            #np.append(data, img)
            data[i,:,:,:] = img
            i += 1
            #if i == 10:
            #    break
        
        return data

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        #if torch.is_tensor(idx):
        #    idx = idx.tolist()
        
        img, target = self.data[idx], self.targets[idx]
        
        # this is done for FMNIST, need for CIFAR10?
        #img = Image.fromarray((img[:,:,:3] * [0.2989, 0.5870, 0.1140]).sum(axis=2), mode="L")
        
        return torch.tensor(img, dtype=torch.float32), torch.tensor(target)

In [51]:
if dataset_selection == 'fashionmnist':
    transform = transforms.Compose([
        #transforms.Resize(28),
        #transforms.CenterCrop(28), 
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Normalize((0.5), (0.5))]) # Images are grayscale -> 1 channel
else:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

gfmnist_dataset_test = GenDataset(csv_file='data2/g_fashionmnist/annotations.csv',
                                    root_dir='data2/g_fashionmnist/', data_name='filename', class_name='target', transform=transform)


In [52]:
from random import sample

gfmnist_dataset = gfmnist_dataset_test
fmnist_classes = list(range(8)) # just get old classes from generated data

class_subsets, old_class_idxs, subset_size = generate_dls(gfmnist_dataset, fmnist_classes)

sampled_idxs = []
for i in range(len(fmnist_classes)):
    idx_sample = sample(old_class_idxs[i].tolist(), sim_sample_sizes[i])
    sampled_idxs += idx_sample
    
#print("sampled_idxs", sampled_idxs)
    
gswil_train_old_subset = torch.utils.data.Subset(gfmnist_dataset, sampled_idxs)

gswil_new_subset, new_class_idx, subset_size = generate_dls(train_data, [9])
new_class_idx_sample = sample(new_class_idx[0].tolist(), sim_sample_sizes[-1])
gswil_train_new_subset = torch.utils.data.Subset(train_data, new_class_idx_sample)


gswil_combined_subset = torch.utils.data.ConcatDataset([gswil_train_old_subset, gswil_train_new_subset])
#gswil_combined_subset_test = torch.utils.data.ConcatDataset([gfmnist_dataset_test, gswil_train_new_subset])
print(len(gswil_train_old_subset))
print(len(gswil_train_new_subset))
print(len(gswil_combined_subset))

gswil_train_dl = torch.utils.data.DataLoader(gswil_combined_subset, batch_size=1, shuffle=True, num_workers=2)

no_bag_test_idx = np.where((np.array(test_data.targets) != 8))[0]
no_bag_test_subset = torch.utils.data.Subset(test_data, no_bag_test_idx)
gswil_test_dl = torch.utils.data.DataLoader(no_bag_test_subset, batch_size=1, shuffle=True, num_workers=2)

252
73
325


In [53]:
def test(dataloader, model, loss_fn, device, swap=False, swap_labels=[], classes = 9) -> float:
    '''
        Model test loop. Performs a single epoch of model updates.

        * USAGE *
        Within a training loop of range(num_epochs) to perform epoch validation, or after training to perform testing.

        * PARAMETERS *
        dataloader: A torch.utils.data.DataLoader object
        model: A torch model which subclasses torch.nn.Module
        loss_fn: A torch loss function, such as torch.nn.CrossEntropyLoss
        optimizer: A torch.optim optimizer
        device: 'cuda' or 'cpu'

        * RETURNS *
        float: The average test loss
    '''

    # TODO: can the swap stuff be removed now?
    
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    y_pred_list, targets = [], []

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            if swap:
                for i in range(len(y)):
                    if y[i] == swap_labels[0]:
                        y[i] = swap_labels[1]
            X, y = X.to(device), y.to(device)
            pred = model(X)
            targets.append(y.tolist())
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            
            _, y_pred_tags = torch.max(pred, dim=1)
            y_pred_list.append(y_pred_tags.cpu().numpy())
            
    y_pred_list = [a.squeeze().tolist() for a in y_pred_list]
    y_pred_list = [item for sublist in y_pred_list for item in sublist]
    
    targets = [item for sublist in targets for item in sublist]

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return test_loss, np.asarray(y_pred_list), np.asarray(targets), 100*correct

In [54]:
if model_selection == 'linear':
    swil_model = add_output_nodes(ckpt_file, arch='linear', device=device)
    swil_model.input_layer.requires_grad_(False)
elif model_selection == 'cnn-demo':
    swil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
    swil_model.conv1.requires_grad_(False)
    swil_model.conv2.requires_grad_(False)
    swil_model.fc1.requires_grad_(False)
elif model_selection == 'cnn6':
    swil_model = add_output_nodes(ckpt_file, device, arch='cnn')
    swil_model.conv_block1.requires_grad_(False)
    swil_model.conv_block2.requires_grad_(False)
    swil_model.conv_block3.Conv5.requires_grad_(False)
    swil_model.conv_block3.Relu5.requires_grad_(False)
    swil_model.conv_block3.BN5.requires_grad_(False)    
    
swil_model = swil_model.to(device)

input_size 784
num_outputs 9


In [56]:
ntrials = 10
#swil_optimizer = torch.optim.Adam(swil_model.parameters(), lr=initial_learning_rate)
#lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=swil_optimizer, gamma=decay_rate)

#fmnist_classes = list(range(8)) + [9]

#with open(r'./data/fmnist_sim_scores_boot.txt', 'r') as fp:
#    sim_scores = [float(i) for i in fp.readlines()]

#sim_sum = sum(sim_scores)
#sim_norms = [x/sim_sum for x in sim_scores]

#boots_sample_size = 75
#sim_sample_sizes = [27 if x < 0.2 else int(x * boots_sample_size*3.52) for x in sim_norms] + [75]

#sim_sample_sizes = np.load('./data/cifar10_sim_scores_cat.npy')
#sim_sample_sizes = np.array(sim_sample_sizes.tolist(),dtype='uint8')
sim_sample_sizes = np.load('./data/fmnist_sim_scores_boot2.npy')
sim_sample_sizes = np.array(sim_sample_sizes.tolist(),dtype='uint8')
#sim_sample_sizes = np.zeros(9); sim_sample_sizes[-1] = 6000;
#sim_sample_sizes = sim_sample_sizes.astype(int)
#sim_sample_sizes = np.array(sim_sample_sizes.tolist(),dtype='uint8')
print(sim_sample_sizes)

from random import sample
num_epochs = 15
for nt in range(ntrials):
    print(nt)
    
    if model_selection == 'linear':
        gswil_model = add_output_nodes(ckpt_file, arch='linear', device=device)
        gswil_model.input_layer.requires_grad_(False)
    elif model_selection == 'cnn-demo':
        swil_model = add_output_nodes(ckpt_file, arch='cnn-demo')
        swil_model.conv1.requires_grad_(False)
        swil_model.conv2.requires_grad_(False)
        swil_model.fc1.requires_grad_(False)
    elif model_selection == 'cnn6':
        swil_model = add_output_nodes(ckpt_file, device, arch='cnn')
        swil_model.conv_block1.requires_grad_(False)
        swil_model.conv_block2.requires_grad_(False)
        swil_model.conv_block3.Conv5.requires_grad_(False)
        swil_model.conv_block3.Relu5.requires_grad_(False)
        swil_model.conv_block3.BN5.requires_grad_(False) 

    gswil_model = swil_model.to(device)
    gswil_optimizer = torch.optim.Adam(gswil_model.parameters(), lr=initial_learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=gswil_optimizer, gamma=decay_rate)

    gfmnist_dataset = gfmnist_dataset_test
    fmnist_classes = list(range(8)) # just get old classes from generated data

    class_subsets, old_class_idxs, subset_size = generate_dls(gfmnist_dataset, fmnist_classes)

    sampled_idxs = []
    for i in range(len(fmnist_classes)):
        idx_sample = sample(old_class_idxs[i].tolist(), sim_sample_sizes[i])
        sampled_idxs += idx_sample

    gswil_train_old_subset = torch.utils.data.Subset(gfmnist_dataset, sampled_idxs)

    gswil_new_subset, new_class_idx, subset_size = generate_dls(train_data, [9])
    new_class_idx_sample = sample(new_class_idx[0].tolist(), sim_sample_sizes[-1])
    gswil_train_new_subset = torch.utils.data.Subset(train_data, new_class_idx_sample)

    gswil_combined_subset = torch.utils.data.ConcatDataset([gswil_train_old_subset, gswil_train_new_subset])
    gswil_train_dl = torch.utils.data.DataLoader(gswil_combined_subset, batch_size=1, shuffle=True, num_workers=2)

    no_bag_test_idx = np.where((np.array(test_data.targets) != 8))[0]
    no_bag_test_subset = torch.utils.data.Subset(test_data, no_bag_test_idx)
    gswil_test_dl = torch.utils.data.DataLoader(no_bag_test_subset, batch_size=1, shuffle=True, num_workers=2)
        
    weight_dir = './logs/'
    model_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_gswil' + str(nt) + '.pt'
    recall_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_gswil_recall'+ str(nt) + '.npy'
    #train_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + '_swil_train_loss0.txt'
    test_losses_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'gswil_test_loss' + str(nt) + '.txt'
    accuracies_file = weight_dir + model_selection + '_' + dataset_selection + '_' + 'holdout' + '_' + '[8]' + 'gswil_accuracies' + str(nt) + '.txt'

    train_losses = []
    test_losses = []
    #t = trange(num_epochs)
    t = range(num_epochs)
    y_preds = []
    y_actuals = []
    accuracies = []

    for epoch in t:
        print(f"Epoch {epoch+1}\n-------------------------------")
        train_loss = train(gswil_train_dl, gswil_model, loss_fn, gswil_optimizer, device, swap=True, swap_labels=[9,8])
        test_loss, y_pred, y_actual, acc = test(test_gswil_dl, gswil_model, loss_fn, device, swap=True, swap_labels=[9,8])
        #train_losses.append(train_loss)
        test_losses.append(test_loss)
        y_preds.append(y_pred)
        y_actuals.append(y_actual)
        accuracies.append(acc)

        #t.set_description(f"Epoch {epoch} train loss: {epoch_loss_train[-1]:.3f}")
        lr_scheduler.step()

    torch.save(swil_model.state_dict(), model_file)

    recalls = get_recall_per_epoch(y_actuals, y_preds, 9)
    np.save(recall_file, recalls)

    #with open(train_losses_file, 'w') as fp:
    #    for s in train_losses:
    #        fp.write("%s\n" % s)

    with open(test_losses_file, 'w') as fp:
        for x in test_losses:
            fp.write("%s\n" % x)

    with open(accuracies_file, 'w') as fp:
        for x in accuracies:
            fp.write("%s\n" % x)

    print("Done!")

[23 28 21 24 21 20 56 59 73]
0
input_size 784
num_outputs 9
Epoch 1
-------------------------------


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/loganbecker/opt/anaconda3/envs/dalleenv/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/loganbecker/opt/anaconda3/envs/dalleenv/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GenDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 