In [None]:
import time
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.utils.data import SubsetRandomSampler

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import itertools

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split

from utils import *

In [None]:
def get_dataloaders_cifar10_selfsupervised(batch_size, num_workers=0,
                            validation_fraction=None,
                            train_transforms=None,
                            test_transforms=None,
                            mode = 'supervised',
                            num_labelled=5000):
    """Build CIFAR-10 data loaders for supervised or pseudo-label training.

    The last ``validation_fraction`` share of the CIFAR-10 training split
    (in dataset order) is held out for validation; the samples before it are
    used for training.

    Args:
        batch_size: Batch size used by every returned loader.
        num_workers: Number of worker processes per loader.
        validation_fraction: Fraction of the training split, in ``[0, 1)``,
            reserved for validation. Must be provided.
        train_transforms: Transform applied to training samples; defaults to
            ``transforms.ToTensor()``.
        test_transforms: Transform applied to validation and test samples;
            defaults to ``transforms.ToTensor()``.
        mode: ``'supervised'`` or ``'self-supervised'``.
        num_labelled: Only used in ``'self-supervised'`` mode. Number of
            training samples placed in the labelled loader (forwarded as
            ``test_size`` to ``train_test_split``); the rest go to the
            unlabelled loader.

    Returns:
        ``(train_loader, valid_loader, test_loader)`` in ``'supervised'``
        mode, or ``(train_labelled_loader, train_unlabelled_loader,
        valid_loader, test_loader)`` in ``'self-supervised'`` mode. The
        "unlabelled" loader still yields the dataset's labels; it is up to
        the consumer to ignore them.

    Raises:
        ValueError: If ``mode`` is unknown or ``validation_fraction`` is
            missing or outside ``[0, 1)``.
    """
    # Validate before touching the dataset so bad arguments fail fast instead
    # of crashing with a TypeError or silently returning None.
    if mode not in ('supervised', 'self-supervised'):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'supervised' or 'self-supervised'.")
    if validation_fraction is None or not 0 <= validation_fraction < 1:
        raise ValueError(
            f'validation_fraction must be a number in [0, 1), got {validation_fraction!r}.')

    if train_transforms is None:
        train_transforms = transforms.ToTensor()
    if test_transforms is None:
        test_transforms = transforms.ToTensor()

    train_dataset = datasets.CIFAR10(root='data',
                                    train=True,
                                    transform=train_transforms,
                                    download=True)
    # Same images as train_dataset but with the evaluation transform.
    valid_dataset = datasets.CIFAR10(root='data',
                                    train=True,
                                    transform=test_transforms)
    test_dataset = datasets.CIFAR10(root='data',
                                    train=False,
                                    transform=test_transforms)

    # Derive the split from the dataset rather than hard-coding 50000.
    num_total = len(train_dataset)
    num_valid = int(validation_fraction * num_total)
    train_indices = torch.arange(0, num_total - num_valid)
    valid_indices = torch.arange(num_total - num_valid, num_total)

    test_loader = DataLoader(dataset=test_dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            shuffle=False)

    if mode == 'supervised':
        valid_loader = DataLoader(dataset=valid_dataset,
                                batch_size=batch_size,
                                num_workers=num_workers,
                                sampler=SubsetRandomSampler(valid_indices))
        train_loader = DataLoader(dataset=train_dataset,
                                batch_size=batch_size,
                                num_workers=num_workers,
                                drop_last=True,
                                sampler=SubsetRandomSampler(train_indices))
        return train_loader, valid_loader, test_loader

    # mode == 'self-supervised'
    train_subset = torch.utils.data.Subset(train_dataset, train_indices)
    valid_subset = torch.utils.data.Subset(valid_dataset, valid_indices)

    # No random_state is passed, so the split follows NumPy's global RNG.
    train_unlabelled_indices, train_labelled_indices = train_test_split(
        list(range(len(train_subset))), test_size=num_labelled)
    train_unlabelled_dataset = torch.utils.data.Subset(train_subset, train_unlabelled_indices)
    train_labelled_dataset = torch.utils.data.Subset(train_subset, train_labelled_indices)

    valid_loader = DataLoader(dataset=valid_subset,
                            batch_size=batch_size,
                            num_workers=num_workers)
    # Shuffle the labelled data: with a fixed order and drop_last=True the
    # same trailing samples would be dropped in every epoch.
    train_labelled_loader = DataLoader(dataset=train_labelled_dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            shuffle=True,
                            drop_last=True)
    # Deliberately NOT shuffled: the training loop in this notebook stores
    # pseudo-labels keyed by batch index and needs a stable batch order.
    train_unlabelled_loader = DataLoader(dataset=train_unlabelled_dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            drop_last=True)

    return train_labelled_loader, train_unlabelled_loader, valid_loader, test_loader

In [None]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

BATCH_SIZE = 256

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image-shaped loaders (self-supervised split), used by the non-MLP models.
train_transforms, test_transforms = None, None
train_labelled_loader, train_unlabelled_loader, valid_loader, test_loader = get_dataloaders_cifar10_selfsupervised(
    batch_size=BATCH_SIZE,
    validation_fraction=0.1,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    num_workers=2, mode = 'self-supervised')

# Flattened loaders (each image becomes a 1-D vector), used by the MLP.
# The labelled/unlabelled split is drawn by train_test_split from NumPy's
# global RNG, so reseed here: otherwise this second call would get a different
# split than the image-shaped loaders and the models would not be comparable.
np.random.seed(SEED)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: torch.flatten(x))])
test_transforms = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: torch.flatten(x))])
train_labelled_loader_flatten, train_unlabelled_loader_flatten, valid_loader_flatten, test_loader_flatten = get_dataloaders_cifar10_selfsupervised(
    batch_size=BATCH_SIZE,
    validation_fraction=0.1,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    num_workers=2, mode = 'self-supervised')

# Sanity check: show the shape of the first batch from each labelled loader.
for preview_loader in (train_labelled_loader, train_labelled_loader_flatten):
    images, labels = next(iter(preview_loader))
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    print('Class labels of 10 examples:', labels[:10])
    print()

In [None]:
def _cross_entropy_step(model, optimizer, features, targets):
    """Run one forward/backward/optimizer step with cross-entropy; return the loss tensor."""
    logits = model(features)
    loss = torch.nn.functional.cross_entropy(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss


def train_model_selfsupervised(model, num_epochs, train_loader_lst,
                valid_loader, test_loader, optimizer,
                device, logging_interval=50,
                scheduler=None,
                scheduler_on='valid_acc'):
    """Train ``model`` with labelled data plus pseudo-labelled unlabelled data.

    Every epoch has three phases:

    1. (from the second epoch on) train on unlabelled batches, using the
       pseudo-labels predicted at the end of the previous epoch;
    2. train on the labelled batches;
    3. predict fresh pseudo-labels for the unlabelled batches and evaluate
       accuracy on the labelled training data and the validation data.

    Pseudo-labels are stored per batch index, so the unlabelled loader must
    yield its batches in the same order on every pass (i.e. not shuffled).

    Args:
        model: Network to train; called as ``model(features)`` to get logits.
        num_epochs: Number of epochs to run.
        train_loader_lst: Pair ``(labelled_loader, unlabelled_loader)``. Both
            yield ``(features, targets)``; targets of the unlabelled loader
            are ignored.
        valid_loader: Loader used for the per-epoch validation accuracy.
        test_loader: Loader used for the final test accuracy.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to.
        logging_interval: Print the loss every this many batches.
        scheduler: Optional learning-rate scheduler, stepped once per epoch
            with a metric value.
        scheduler_on: ``'valid_acc'`` or ``'minibatch_loss'``; which metric
            is passed to ``scheduler.step``.

    Returns:
        ``(minibatch_loss_list, train_acc_list, valid_acc_list)``. The loss
        list holds one entry per optimisation step, covering both the
        pseudo-labelled and the labelled phases.

    Raises:
        ValueError: If a scheduler is given and ``scheduler_on`` is invalid.
    """
    # Fail before training starts rather than after the first full epoch.
    if scheduler is not None and scheduler_on not in ('valid_acc', 'minibatch_loss'):
        raise ValueError(f'Invalid `scheduler_on` choice.')

    train_loader, train_unlabelled_loader = train_loader_lst

    # The number of unlabelled batches used grows by one "window unit" per
    # epoch. The unit is taken from the loader instead of a notebook-level
    # global.
    # NOTE(review): the window counts *batches* but is scaled by the batch
    # size, so it covers the whole loader unless the loader has more batches
    # than that. Kept as in the original; confirm the intended schedule.
    window_unit = (getattr(train_unlabelled_loader, 'batch_size', None)
                   or len(train_unlabelled_loader))

    start_time = time.time()
    minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []
    pseudo_labels = {}  # batch index -> predicted class indices (on `device`)

    for epoch in range(num_epochs):

        model.train()

        # ## PHASE 1: train on pseudo-labels from the previous epoch
        if epoch > 0:
            previous_window = itertools.islice(
                enumerate(train_unlabelled_loader), epoch * window_unit)
            for batch_idx, (features, _) in previous_window:
                features = features.to(device)
                loss = _cross_entropy_step(
                    model, optimizer, features, pseudo_labels[batch_idx])

                # ## LOGGING
                minibatch_loss_list.append(loss.item())
                if not batch_idx % logging_interval:
                    print(f'Self-Supervised '
                        f'| Batch {batch_idx:04d}/{len(train_unlabelled_loader):04d} '
                        f'| Loss: {loss:.4f}')

        # ## PHASE 2: train on the labelled data
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.to(device)
            targets = targets.to(device)
            loss = _cross_entropy_step(model, optimizer, features, targets)

            # ## LOGGING
            minibatch_loss_list.append(loss.item())
            if not batch_idx % logging_interval:
                print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                      f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
                      f'| Loss: {loss:.4f}')

        # ## PHASE 3: pseudo-label and evaluate
        # Both are pure inference, so run them in eval mode without autograd:
        # predictions are then free of dropout noise, normalisation statistics
        # are not updated, and no graph memory is held.
        model.eval()
        with torch.no_grad():
            pseudo_labels = {}
            current_window = itertools.islice(
                enumerate(train_unlabelled_loader), (epoch + 1) * window_unit)
            for batch_idx, (features, _) in current_window:
                logits = model(features.to(device))
                _, predicted_labels = torch.max(logits, 1)
                pseudo_labels[batch_idx] = predicted_labels

            train_acc = compute_accuracy(model, train_loader, device=device)
            valid_acc = compute_accuracy(model, valid_loader, device=device)
            print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                  f'| Train: {train_acc :.2f}% '
                  f'| Validation: {valid_acc :.2f}%')
            train_acc_list.append(train_acc.item())
            valid_acc_list.append(valid_acc.item())

        elapsed = (time.time() - start_time)/60
        print(f'Time elapsed: {elapsed:.2f} min')

        if scheduler is not None:
            if scheduler_on == 'valid_acc':
                scheduler.step(valid_acc_list[-1])
            else:  # 'minibatch_loss' (validated at the top)
                scheduler.step(minibatch_loss_list[-1])

    elapsed = (time.time() - start_time)/60
    print(f'Total Training Time: {elapsed:.2f} min')

    # Set eval mode explicitly so the test result does not depend on whatever
    # mode the loop left the model in (e.g. when num_epochs is 0).
    model.eval()
    with torch.no_grad():
        test_acc = compute_accuracy(model, test_loader, device=device)
    print(f'Test accuracy {test_acc :.2f}%')

    return minibatch_loss_list, train_acc_list, valid_acc_list

In [None]:
def start_train_selfsupervised(model, device, NUM_EPOCHS, data_loader, lr = 0.001, FileName = 'model'):
    """Train one model with pseudo-labelling, plot diagnostics, save it, and return test accuracy.

    Args:
        model: Network to train; it is moved to ``device``.
        device: Device used for training and evaluation.
        NUM_EPOCHS: Number of training epochs.
        data_loader: Sequence ``(train_labelled_loader, train_unlabelled_loader,
            valid_loader, test_loader)``.
        lr: Learning rate for SGD (momentum 0.9).
        FileName: Prefix of the checkpoint file; the weights are written to
            ``f'{FileName}_selfsupervised.ckpt'`` in the working directory.

    Returns:
        The result of ``compute_accuracy`` on the test loader, moved to the
        CPU and converted with ``.numpy()``.
    """
    train_labelled_loader, train_unlabelled_loader, valid_loader, test_loader = data_loader
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=lr)
    # mode='max' because the scheduler is stepped on validation accuracy.
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
    # confirm against the installed version.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        factor=0.1,
                                                        mode='max',
                                                        verbose=True)

    minibatch_loss_list, train_acc_list, valid_acc_list = train_model_selfsupervised(
        model=model,
        num_epochs=NUM_EPOCHS,
        train_loader_lst=[train_labelled_loader, train_unlabelled_loader],
        valid_loader=valid_loader,
        test_loader=test_loader,
        optimizer=optimizer,
        device=device,
        scheduler=scheduler,
        scheduler_on='valid_acc',
        logging_interval=100)

    # The loss list contains the pseudo-label steps as well as the labelled
    # steps, so the labelled loader length understates the steps per epoch.
    # Use the average number of recorded steps per epoch instead.
    # NOTE(review): presumably plot_training_loss uses iter_per_epoch to place
    # epoch marks on the x-axis — confirm against its implementation.
    iter_per_epoch = max(1, len(minibatch_loss_list) // max(1, NUM_EPOCHS))
    plot_training_loss(minibatch_loss_list=minibatch_loss_list,
                    num_epochs=NUM_EPOCHS,
                    iter_per_epoch=iter_per_epoch,
                    results_dir=None,
                    averaging_iterations=200)
    plt.show()

    plot_accuracy(train_acc_list=train_acc_list,
                valid_acc_list=valid_acc_list,
                results_dir=None)
    plt.show()

    class_dict = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'}

    # Evaluate in eval mode without autograd, rather than relying on the
    # training routine having left the model in that state.
    model.eval()
    with torch.no_grad():
        mat = compute_confusion_matrix(model=model, data_loader=test_loader, device=device)
        test_acc = compute_accuracy(model, test_loader, device=device)
    plot_confusion_matrix(mat, class_names=class_dict.values())
    plt.show()

    torch.save(model.state_dict(), f'{FileName}_selfsupervised.ckpt')
    return test_acc.to('cpu').numpy()

In [None]:
# One untrained network per architecture compared in this notebook.
# Construction order is kept (VGG, ResNet, MLP) so that weight initialisation
# consumes the random number generator in the same sequence.
model_vgg = VGG16(num_classes=10)

model_resnet = ResNet(ResidualBlock, [2, 2, 2])
model_resnet = model_resnet.to(device)

model_mlp = MultilayerPerceptron()

In [None]:
# Test accuracies per model, one entry appended per learning rate tried.
test_acc_lst = {model_key: [] for model_key in ('vgg', 'resnet', 'mlp')}

In [None]:
NUM_EPOCHS = 10
for model_name, model in zip(['vgg', 'resnet', 'mlp'],[model_vgg, model_resnet, model_mlp]):
    # The MLP consumes flattened vectors; the other models consume images.
    if model_name == 'mlp':
        data_loader = [train_labelled_loader_flatten, train_unlabelled_loader_flatten, valid_loader_flatten, test_loader_flatten]
    else:
        data_loader = [train_labelled_loader, train_unlabelled_loader, valid_loader, test_loader]

    # Snapshot the weights the model has before this sweep so that every
    # learning rate starts from the same state. Without this, each run would
    # continue from the weights trained at the previous learning rate and the
    # accuracies would not be comparable.
    initial_state = {key: value.clone() for key, value in model.state_dict().items()}

    # Start from an empty list so re-running this cell does not append
    # duplicate results.
    test_acc_lst[model_name] = []

    for lr in [0.0001,0.0003,0.001,0.003,0.01]:
        model.load_state_dict(initial_state)

        print("*************************************************************************************")
        print(f"Model: {model_name}, lr: {lr}")
        print("*************************************************************************************")
        val = start_train_selfsupervised(
            model, device, NUM_EPOCHS, data_loader, lr = lr,
            FileName = f"model_{model_name}_{str(lr).split('.')[-1]}")
        test_acc_lst[model_name].append(val)
        print("*************************************************************************************")
        print("*************************************************************************************")
        print()

In [None]:
# Test accuracy per learning rate, one panel per model. The x positions are
# the indices of the learning rates in the order they were tried.
learning_rate_labels = [0.0001,0.0003,0.001,0.003,0.01]
panels = [('vgg', 'VGG'), ('resnet', 'ResNet'), ('mlp', 'N-layer NN')]

fig, ax = plt.subplots(len(panels), 1, sharex = True)
for axis, (model_key, panel_label) in zip(ax, panels):
    axis.plot(test_acc_lst[model_key])
    axis.set_ylabel(panel_label)

ax[0].set_title('Test Accuracy')

# With a shared x-axis the tick labels are drawn on the bottom panel only, so
# the axis label belongs there too (it used to sit under the top panel).
ax[-1].set_xticks(np.arange(len(learning_rate_labels)))
ax[-1].set_xticklabels(learning_rate_labels)
ax[-1].set_xlabel('Learning Rate')
plt.show()