<a href="https://colab.research.google.com/github/logan-cardinal/cs7641_lz/blob/main/Copy_of_SESEMI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Semi-Supervised Learning with SESEMI

This notebook accompanies the Medium article "[Semi-Supervised Learning Demystified with PyTorch](https://medium.com/@masonmcgough/semi-supervised-learning-demystified-with-pytorch-9656c14af031)." Follow along with the post to use this notebook.

In this notebook, I demonstrate the SESEMI technique described in "[Exploring Self-Supervised Regularization for Supervised and Semi-Supervised Learning](https://arxiv.org/pdf/1906.10343.pdf)" on the CIFAR-10 dataset using the handy pretrained ResNet model in `torchvision`. I encourage you to try this notebook with different amounts of labeled data to see the impact that semi-supervised regularization has on the model training.

## Imports

In [None]:
import random
from typing import Optional

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

## Data

Here we define our labeled and unlabeled datasets. Since CIFAR-10 is fully labeled (its training split has 50,000 images), we first need to subsample the labels to simulate a semi-supervised setting.
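
To make the subsampling step concrete, here is a minimal sketch of a seeded index split. The helper `split_indices` is a hypothetical name used only for illustration. It mirrors the shuffle-and-slice logic of `create_label_drop_datasets` but, as an assumption of this sketch, it uses a local `random.Random` instance instead of seeding the global generator.

```python
import random

def split_indices(n_data: int, n_labels_to_keep: int, seed: int, split: bool = False):
    """Return (labeled, unlabeled) index lists for a dataset of size n_data."""
    if n_labels_to_keep > n_data:
        raise ValueError('n_labels_to_keep exceeds the number of labels')
    idxs = list(range(n_data))
    random.Random(seed).shuffle(idxs)      # deterministic shuffle for a given seed
    labeled = idxs[:n_labels_to_keep]      # these examples keep their labels
    if split:
        unlabeled = idxs[n_labels_to_keep:]  # disjoint from the labeled subset
    else:
        unlabeled = list(range(n_data))      # every image, labeled ones included
    return labeled, unlabeled

labeled_idxs, unlabeled_idxs = split_indices(50000, 5000, seed=231)
print(len(labeled_idxs), len(unlabeled_idxs))
```

With `split=False`, the labeled images also stay in the unlabeled pool, so the self-supervised task sees every training image.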

The two classes `LabeledDataset` and `UnlabeledDataset` are subclasses of `torch.utils.data.Dataset` that provide useful iterators over labeled and unlabeled subsets of CIFAR-10, respectively. We do not instantiate these classes directly, instead using the `create_label_drop_datasets` function to do so. As the authors discuss in the paper, the `LabeledDataset` class repeats the subsampled labels so that its length and the length of `UnlabeledDataset` are the same. This allows the model to be trained jointly on both datasets.
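
The repetition works by wrapping indices with a modulo. The sketch below isolates that arithmetic. `wrapped_index` is a hypothetical helper written for this illustration, and the sizes assume 5,000 labeled images padded to the 50,000 unlabeled ones.

```python
from collections import Counter

def wrapped_index(idx: int, n_labeled: int, min_size: int) -> int:
    """Map an index into the padded dataset back onto a real labeled example."""
    length = max(n_labeled, min_size)
    if idx >= length:
        raise IndexError(f'{idx} is out-of-bounds for dataset (length: {length})')
    return idx % n_labeled

n_labeled, n_unlabeled = 5000, 50000
# Count how often each real labeled example is visited in one pass.
visits = Counter(wrapped_index(i, n_labeled, n_unlabeled) for i in range(n_unlabeled))
repeats_per_example = set(visits.values())
print(repeats_per_example)  # prints {10}
```

Under these assumed sizes, each labeled example is served ten times per epoch, so both dataloaders yield the same number of batches.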

We also create a `SesemiTransform` class to apply a random augmentation to unlabeled examples and generate an auxiliary label.
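
As a rough illustration of the idea, the sketch below applies the same six transformations to a tiny 2x2 grid, with plain nested lists standing in for image tensors. The helper names are hypothetical, and the counter-clockwise rotation direction is an assumption chosen to match how `torchvision` documents `rotate`.

```python
import random

CLASSES = ('0', '90', '180', '270', 'hflip', 'vflip')

def rot90(img):
    """Rotate a 2D grid 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*img)][::-1]

def hflip(img):
    """Mirror left-to-right."""
    return [row[::-1] for row in img]

def vflip(img):
    """Mirror top-to-bottom."""
    return [list(row) for row in img[::-1]]

def apply_transform(img, tf_type: int):
    """Apply the transformation whose auxiliary label is tf_type."""
    if tf_type in (0, 1, 2, 3):
        out = [list(row) for row in img]
        for _ in range(tf_type):  # 0, 1, 2, or 3 quarter turns
            out = rot90(out)
        return out
    if tf_type == 4:
        return hflip(img)
    if tf_type == 5:
        return vflip(img)
    raise ValueError(f'unknown transform type: {tf_type}')

def random_pretext_example(img, rng: random.Random):
    """Pick a transformation at random; its index is the auxiliary label."""
    tf_type = rng.randrange(len(CLASSES))
    return apply_transform(img, tf_type), tf_type

img = [[1, 2],
       [3, 4]]
rot180_then_hflip = hflip(apply_transform(img, 2))
print(rot180_then_hflip == vflip(img))  # prints True
```

The last line checks that a 180-degree rotation followed by a horizontal flip equals a vertical flip, which is how the `vflip` class is produced in `SesemiTransform`.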

In [None]:
class SesemiTransform:
    """
    Torchvision-style transform to apply SESEMI augmentation to image.
    """

    classes = ('0', '90', '180', '270', 'hflip', 'vflip')

    def __call__(self, x):
        tf_type = random.randint(0, len(self.classes) - 1)
        if tf_type == 0:
            x = x
        elif tf_type == 1:
            x = transforms.functional.rotate(x, 90)
        elif tf_type == 2:
            x = transforms.functional.rotate(x, 180)
        elif tf_type == 3:
            x = transforms.functional.rotate(x, 270)
        elif tf_type == 4:
            x = transforms.functional.hflip(x)
        elif tf_type == 5:
            # a 180-degree rotation followed by a horizontal flip is a vertical flip
            x = transforms.functional.rotate(x, 180)
            x = transforms.functional.hflip(x)
        return x, tf_type

class LabeledDataset(Dataset):
    """
    Labeled training Dataset class.
    """

    def __init__(self, data: np.ndarray, labels: list,
            dataset_min_size: int = 0,
            transform: Optional[transforms.Compose] = None):
        self.data = data
        self.labels = labels
        self.min_size = dataset_min_size
        self.transform = transform

    def __len__(self) -> int:
        return max(len(self.labels), self.min_size)

    def __getitem__(self, idx: int) -> tuple:
        if idx >= len(self):
            raise IndexError(f'{idx} is out-of-bounds for dataset (length: {len(self)})')
        s_idx = idx % len(self.labels)
        
        data = self.data[s_idx]
        labels = self.labels[s_idx]
        if self.transform is not None:
            data = self.transform(data)
        return data, labels

class UnlabeledDataset(Dataset):
    """
    Unlabeled training Dataset class.
    """

    def __init__(self, data: np.ndarray,
            transform: Optional[transforms.Compose] = None):
        self.data = data
        self.transform = transform
        self.sesemi_transform = SesemiTransform()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple:
        data = self.data[idx]
        if self.transform is not None:
            data = self.transform(data)
        data, label = self.sesemi_transform(data)
        return data, label

def create_label_drop_datasets(dataset: Dataset, n_labels_to_keep: int,
        seed: int = -1, transform: Optional[transforms.Compose] = None,
        split: bool = False):
    """
    Create labeled and unlabeled subsets from a given Dataset instance.
    """

    data = dataset.data
    labels = dataset.targets
    n_data = len(labels)
    assert n_data >= n_labels_to_keep, f'n_labels_to_keep ({n_labels_to_keep}) exceeds number of labels ({n_data})'

    idxs = list(range(n_data))
    if seed >= 0:
        random.seed(seed)
    random.shuffle(idxs)
    selected_idxs = idxs[:n_labels_to_keep]
    unselected_idxs = idxs[n_labels_to_keep:]
    labeled_data = data[selected_idxs]
    labels = [labels[_i] for _i in selected_idxs]
    if split:
        unlabeled_data = data[unselected_idxs]
    else:
        unlabeled_data = data
    
    labeled_dataset = LabeledDataset(labeled_data, labels,
        dataset_min_size=len(unlabeled_data), transform=transform)
    unlabeled_dataset = UnlabeledDataset(unlabeled_data, transform=transform)
    return labeled_dataset, unlabeled_dataset

In [None]:
n_labels_to_keep = 5000
batch_size = 64
seed = 231
n_epochs = 80
n_batches_print = 50

In [None]:
sup_classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
labeled_trainset, unlabeled_trainset = create_label_drop_datasets(trainset,
    n_labels_to_keep=n_labels_to_keep, seed=seed, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

print(f'# labeled training:   {len(labeled_trainset):6d}')
print(f'# unlabeled training: {len(unlabeled_trainset):6d}')
print(f'#  testing:           {len(testset):6d}')

In [None]:
labeled_trainloader = DataLoader(labeled_trainset, batch_size=batch_size,
    shuffle=True, num_workers=2)
unlabeled_trainloader = DataLoader(unlabeled_trainset, batch_size=batch_size,
    shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
    num_workers=2)

print(f'# labeled batches:   {len(labeled_trainloader):5d}')
print(f'# unlabeled batches: {len(unlabeled_trainloader):5d}')
print(f'#  testing batches:  {len(testloader):5d}')

## Display Images

With our new dataloaders, let's generate a few examples to make sure they are doing what we think they should. First, we look at the labeled dataloader.

In [None]:
n_show_images = 4

def imshow(img: torch.Tensor):
    """
    Display a single image.
    """

    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(labeled_trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images[:n_show_images]))
print(' '.join(f'{sup_classes[labels[j]]:5s}' for j in range(n_show_images)))

Now we make sure the unlabeled dataloader is producing the labels we want. Note that the rotated images seem to match their labels.

In [None]:
unsup_classes = unlabeled_trainset.sesemi_transform.classes
dataiter = iter(unlabeled_trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images[:n_show_images]))
print(' '.join(f'{unsup_classes[labels[j]]:5s}' for j in range(n_show_images)))

## Model

The SESEMI technique is model-agnostic and can be applied to any supervised learning task. The authors of [the paper](https://arxiv.org/pdf/1906.10343.pdf) experiment with three CNN architectures: Network-in-Network, a max-pooling ConvNet, and a wide residual network. We opt to go even simpler and use the standard ResNet you can import in `torchvision`.

This model modifies the ResNet to accommodate two output layers, one for the supervised objective and another for the self-supervised objective. In the `forward` method, we add an optional input argument `x_selfsup` so that we can accumulate gradients for labeled and unlabeled batches simultaneously during training. During inference, it is not necessary to provide an `x_selfsup` batch.
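
To see the two-head pattern in miniature, here is a toy sketch in which a single linear layer stands in for the ResNet stem. `TinyTwoHeadNet`, its layer sizes, and the random inputs are assumptions made for illustration only and are not part of the notebook's model.

```python
import torch
from torch import nn

torch.manual_seed(0)

class TinyTwoHeadNet(nn.Module):
    """Toy stand-in: one shared stem feeding two output heads."""

    def __init__(self, in_features=8, hidden=4, n_sup=10, n_selfsup=6):
        super().__init__()
        self.stem = nn.Linear(in_features, hidden)
        self.sup_fc = nn.Linear(hidden, n_sup)
        self.selfsup_fc = nn.Linear(hidden, n_selfsup)

    def forward(self, x, x_selfsup=None):
        out = self.sup_fc(self.stem(x))
        if x_selfsup is None:
            return out  # inference path: supervised head only
        return out, self.selfsup_fc(self.stem(x_selfsup))

net = TinyTwoHeadNet()
x_sup = torch.randn(3, 8)       # pretend labeled batch
x_selfsup = torch.randn(5, 8)   # pretend unlabeled batch

sup_logits, selfsup_logits = net(x_sup, x_selfsup=x_selfsup)
inference_logits = net(x_sup)

# One backward pass sends gradients from both heads into the shared stem.
criterion = nn.CrossEntropyLoss()
loss = (criterion(sup_logits, torch.randint(0, 10, (3,)))
        + criterion(selfsup_logits, torch.randint(0, 6, (5,))))
loss.backward()
print(sup_logits.shape, selfsup_logits.shape, inference_logits.shape)
```

Because both batches pass through the same stem, a single `loss.backward()` call accumulates gradients from both objectives in the shared weights.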

In [None]:
class SesemiNet(nn.Module):
    """
    ResNet backbone with two heads for SESEMI training.
    """

    def __init__(
        self,
        n_sup_classes: int,
        n_unsup_classes: int,
        pretrained: bool = True
    ):
        super().__init__()
        self.stem = torchvision.models.resnet18(pretrained=pretrained)
        self.fc_out = 256
        self.stem.fc = nn.Linear(self.stem.fc.in_features, self.fc_out)
        self.sup_fc = nn.Linear(self.fc_out, n_sup_classes)
        self.selfsup_fc = nn.Linear(self.fc_out, n_unsup_classes)

    def forward(self, x: torch.Tensor, x_selfsup: Optional[torch.Tensor] = None):
        x = self.stem(x)
        x = self.sup_fc(x)
        if x_selfsup is not None:
            x_selfsup = self.stem(x_selfsup)
            x_selfsup = self.selfsup_fc(x_selfsup)
            return x, x_selfsup
        else:
            return x

## Train

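Now we are ready to define our training loop. The `train` function is mostly boilerplate, but note the lines that perform the forward pass and compute the loss. Also note that `train` reads `optimizer`, `sup_criterion`, and `unsup_criterion` from the notebook's global scope, so they must be defined before it is called.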
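
As a compact summary of the objective that `train` computes on each step, with $w$ denoting the `unsup_wt` argument (default 1.0):

$$
\mathcal{L} = \mathcal{L}_{\text{sup}} + w \, \mathcal{L}_{\text{selfsup}}
$$

Here $\mathcal{L}_{\text{sup}}$ is `sup_criterion` evaluated on the labeled batch and $\mathcal{L}_{\text{selfsup}}$ is `unsup_criterion` evaluated on the transformed unlabeled batch. When no unlabeled dataloader is passed, $\mathcal{L}_{\text{selfsup}} = 0$ and the objective reduces to plain supervised training. As a purely illustrative calculation with made-up numbers, $\mathcal{L}_{\text{sup}} = 1.2$, $\mathcal{L}_{\text{selfsup}} = 0.8$, and $w = 0.5$ give $\mathcal{L} = 1.2 + 0.5 \times 0.8 = 1.6$.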

In [None]:
def train(
    model: nn.Module,
    labeled_trainloader: DataLoader,
    unlabeled_trainloader: Optional[DataLoader] = None,
    valloader: Optional[DataLoader] = None,
    n_epochs: int = 2,
    n_batches_print: int = 1000,
    device: Optional[str] = None,
    unsup_wt: float = 1.0
):
    """
    Execute SESEMI training loop.
    """

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_losses = []
    val_losses = []
    val_accs = []
    model.to(device)
    for epoch in range(n_epochs):
        model.train()
        train_loss_sum = 0.0
        running_loss_sup = 0.0
        running_loss_unsup = 0.0
        dataloaders = zip(labeled_trainloader, [None] * len(labeled_trainloader))\
            if unlabeled_trainloader is None else zip(labeled_trainloader, unlabeled_trainloader)
        for i, (sup_data, unsup_data) in enumerate(dataloaders, start=0):
            optimizer.zero_grad()
            sup_inputs, sup_labels = sup_data
            sup_inputs = sup_inputs.to(device)
            sup_labels = sup_labels.to(device)

            if unsup_data is None:
                sup_outputs = model(sup_inputs)
                loss_unsup = 0.0
            else:
                unsup_inputs, unsup_labels = unsup_data
                unsup_inputs = unsup_inputs.to(device)
                unsup_labels = unsup_labels.to(device)
                # forward pass through model with both data subsets
                sup_outputs, unsup_outputs = model(sup_inputs, x_selfsup=unsup_inputs)
                loss_unsup = unsup_criterion(unsup_outputs, unsup_labels)
                running_loss_unsup += loss_unsup.item()
            loss_sup = sup_criterion(sup_outputs, sup_labels)
            # evaluate loss function
            loss = loss_sup + unsup_wt * loss_unsup
            loss.backward()
            optimizer.step()
            iter_loss = loss_sup.item()
            running_loss_sup += iter_loss
            train_loss_sum += iter_loss

            # print statistics
            if i % n_batches_print == n_batches_print - 1:
                if unsup_data is None:
                    print(f'[{epoch + 1}, {i + 1:5d}] loss (sup): {running_loss_sup / n_batches_print:.4f}')
                else:
                    print(f'[{epoch + 1}, {i + 1:5d}] loss (sup): {running_loss_sup / n_batches_print:.4f} '\
                        f'loss (unsup): {running_loss_unsup / n_batches_print:.4f}')
                running_loss_sup = 0.0
                running_loss_unsup = 0.0
        train_losses.append(train_loss_sum / len(labeled_trainloader))

        # validation
        if valloader is not None:
            model.eval()
            acc_values = []
            acc_batchsize = []
            running_loss = 0.0
            with torch.no_grad():  # gradients are not needed for validation
                for i, val_data in enumerate(valloader, start=0):
                    val_inputs, val_labels = val_data
                    val_inputs = val_inputs.to(device)
                    val_labels = val_labels.to(device)
                    # forward pass of labeled data only for validation
                    val_outputs = model(val_inputs)
                    loss = sup_criterion(val_outputs, val_labels)
                    running_loss += loss.item()

                    acc, bsize = accuracy(val_outputs, val_labels, topk=(1,))
                    acc_values.append(acc[0].numpy())
                    acc_batchsize.append(bsize)
            total_loss = running_loss / len(valloader)
            total_acc = np.sum(np.array(acc_values) * np.array(acc_batchsize)) / np.sum(acc_batchsize)
            print(f'Epoch: {epoch + 1}, loss (sup): {running_loss / len(valloader):.4f}, acc: {total_acc:.2f}')
            val_losses.append(total_loss)
            val_accs.append(total_acc)
    print('Training Finished')
    if valloader is None:
        return train_losses
    else:
        return train_losses, val_losses, val_accs

def accuracy(output, target, topk=(1,)):
    """
    Calculate top-k accuracy for the given batch and its targets.
    """

    output = output.cpu()
    target = target.cpu()
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return (res, batch_size)

def plot_losses(trainloss: list, valloss: Optional[list] = None,
        title: str = 'Losses'):
    """
    Display plot of training loss against validation loss.
    """

    x = list(range(1, len(trainloss) + 1))
    plt.plot(x, trainloss, 'b')
    if valloss is not None:
        plt.plot(x, valloss, 'r')
    plt.title(title)
    plt.grid(True)
    plt.show()

### Supervised Training

As a point of comparison, train the model first on the labeled subset only to get a baseline of performance without applying the SESEMI algorithm.

In [None]:
model = SesemiNet(len(sup_classes), len(unsup_classes), pretrained=True)

sup_criterion = nn.CrossEntropyLoss()
unsup_criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

sup_train_losses, sup_val_losses, sup_val_accs = train(
    model,
    labeled_trainloader,
    valloader=testloader,
    n_epochs=n_epochs,
    n_batches_print=n_batches_print)

plot_losses(sup_train_losses, sup_val_losses, 'Loss without Self-supervision')
print(f'Max accuracy: {np.max(sup_val_accs):.2f}')

### Semi-Supervised Training

Note the max accuracy from the supervised-only run above, then reinitialize the model and train again, this time passing the unlabeled dataloader as well. The accuracy should improve considerably over the supervised-only baseline.

In [None]:
model = SesemiNet(len(sup_classes), len(unsup_classes), pretrained=True)

sup_criterion = nn.CrossEntropyLoss()
unsup_criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

unsup_train_losses, unsup_val_losses, unsup_val_accs = train(
    model,
    labeled_trainloader,
    unlabeled_trainloader,
    valloader=testloader,
    n_epochs=n_epochs,
    n_batches_print=n_batches_print)

plot_losses(unsup_train_losses, unsup_val_losses, title='Loss with Self-supervision')
print(f'Max accuracy: {np.max(unsup_val_accs):.2f}')

Although these results are far from state-of-the-art, they are impressive considering how little labeled data we used and how little we had to change to use the SESEMI algorithm. Feel free to experiment with different proportions of labeled and unlabeled data, different models, learning rate schedulers, and other hyperparameters to see if you can improve the results even more!
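
As one possible starting point for the scheduler experiment, the sketch below attaches a `StepLR` schedule to an Adam optimizer. The toy linear model, `step_size=2`, and `gamma=0.1` are arbitrary assumptions chosen to make the decay visible in a few epochs; they are not tuned values.

```python
import torch
from torch import nn, optim
from torch.optim import lr_scheduler

toy_model = nn.Linear(4, 2)  # stand-in for SesemiNet
optimizer = optim.Adam(toy_model.parameters(), lr=0.001)
# Multiply the learning rate by gamma every step_size epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]['lr'])
    # Placeholder for one epoch of training: a single dummy update.
    optimizer.zero_grad()
    toy_model(torch.randn(8, 4)).sum().backward()
    optimizer.step()
    scheduler.step()  # advance the schedule once per epoch

print(lrs)
```

To use a scheduler with the `train` function in this notebook, one option is to call `scheduler.step()` at the end of each epoch inside the loop.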

In [None]:
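import os  # needed for os.path.join below
# Create training and validation datasets (`data_dir` and `data_transforms` must be defined beforehand)
image_datasets = {x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}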

The data in `data_dir` must contain folders named `train` and `val`, each with one subfolder per class.