<a href="https://colab.research.google.com/github/gan3sh500/mixmatch-pytorch/blob/master/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook implements the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and attempts to reproduce its results on CIFAR-10 with WideResNet-28.

It depends on PyTorch, NumPy and imgaug. The WideResNet-28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s GitHub repository. Hopefully I can train this on Colab. :)

In [0]:
import torch
import numpy as np
import imgaug.augmenters as iaa

Now that we have the basic imports out of the way, let's get to it.
First we define a function that produces augmented versions of a given batch of images. `get_augmentor` below builds the augmentation pipeline and returns that function.

In [0]:
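def get_augmentor():
    # Random crop, horizontal flip and Gaussian blur, sampled per image.
    seq = iaa.Sequential([
        iaa.Crop(px=(0, 16)),
        iaa.Fliplr(0.5),
        iaa.GaussianBlur(sigma=(0, 3.0))
    ])
    def augment(images):
        # imgaug expects NHWC arrays, so convert from NCHW and back again.
        images = images.transpose(0, 2, 3, 1)
        return seq.augment_images(images).transpose(0, 3, 1, 2)
    return augment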
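
A quick NumPy-only check of the layout conversion, since it is easy to get the inverse permutation wrong. The batch shape is made up for illustration; imgaug itself is not needed here.

```python
import numpy as np

# Hypothetical batch: 4 RGB images of 32x32 in PyTorch's NCHW layout.
batch_nchw = np.arange(4 * 3 * 32 * 32, dtype=np.float32).reshape(4, 3, 32, 32)

# NCHW -> NHWC (channel-last, the layout imgaug works with).
batch_nhwc = batch_nchw.transpose(0, 2, 3, 1)

# NHWC -> NCHW: the inverse permutation is (0, 3, 1, 2).
restored = batch_nhwc.transpose(0, 3, 1, 2)

# Applying (0, 2, 3, 1) a second time does NOT undo the first call.
not_restored = batch_nhwc.transpose(0, 2, 3, 1)

print(batch_nhwc.shape)    # (4, 32, 32, 3)
print(restored.shape)      # (4, 3, 32, 32)
print(not_restored.shape)  # (4, 32, 3, 32)
print(np.array_equal(restored, batch_nchw))  # True
```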

Next we define the sharpening function. It sharpens the guessed label distribution obtained by averaging the model's predictions over all augmentations of an unlabeled image. It has the same effect as applying a temperature inside the softmax, but it operates directly on the probabilities.

In [0]:
def sharpen(x, T):
    temp = x**(1/T)
    return temp / temp.sum(axis=1, keepdims=True)
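
To make the temperature analogy concrete, here is a small worked example. It repeats `sharpen` so the cell runs on its own; `softmax_with_temperature` is a helper introduced only for this comparison.

```python
import numpy as np

def sharpen(x, T):
    temp = x ** (1 / T)
    return temp / temp.sum(axis=1, keepdims=True)

def softmax_with_temperature(log_p, T):
    # Hypothetical helper: softmax over log-probabilities divided by T.
    z = log_p / T
    z = z - z.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

p = np.array([[0.6, 0.4],
              [0.25, 0.75]])

# With T = 0.5 every probability is squared and the rows are renormalized:
# [0.36, 0.16] / 0.52 and [0.0625, 0.5625] / 0.625.
sharp = sharpen(p, T=0.5)
print(sharp)  # approximately [[0.6923, 0.3077], [0.1, 0.9]]

# Same result as a temperature softmax applied to log(p).
print(np.allclose(sharp, softmax_with_temperature(np.log(p), T=0.5)))  # True
```

Lower values of `T` push the distribution towards one-hot, and `T = 1` leaves it unchanged.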

A simple implementation of mixup from the [paper](https://arxiv.org/pdf/1710.09412.pdf) mixup: Beyond Empirical Risk Minimization, which MixMatch uses as well.

In [0]:
def mixup(x1, x2, y1, y2, alpha):
    # Mixing coefficient drawn from Beta(alpha, alpha); both parameters must be positive.
    beta = np.random.beta(alpha, alpha)
    x = beta * x1 + (1 - beta) * x2
    y = beta * y1 + (1 - beta) * y2
    return x, y
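
The MixMatch paper uses a slightly modified mixup: after sampling the coefficient it takes the larger of the coefficient and one minus the coefficient, so the mixed sample stays closer to its first argument. A minimal sketch of that variant follows; the name `mixup_mixmatch`, the `rng` parameter and the returned coefficient are additions for illustration and are not part of the notebook's code.

```python
import numpy as np

def mixup_mixmatch(x1, x2, y1, y2, alpha, rng=None):
    # rng is a hypothetical argument so the example can be seeded.
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    # Keep the result closer to (x1, y1) than to (x2, y2).
    lam = max(lam, 1.0 - lam)
    x = lam * x1 + (1 - lam) * x2
    y = lam * y1 + (1 - lam) * y2
    return x, y, lam

rng = np.random.default_rng(0)
x1, x2 = np.zeros((2, 3)), np.ones((2, 3))
y1, y2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
x, y, lam = mixup_mixmatch(x1, x2, y1, y2, alpha=0.75, rng=rng)
print(lam >= 0.5)  # True
```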

This covers Algorithm 1 from the paper. 

In [0]:
def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):
    # Assumes NumPy inputs, one-hot labels `y`, and a `model` that returns
    # class probabilities as a NumPy array.
    xb = augment_fn(x)
    ub = [augment_fn(u) for _ in range(K)]
    # Guess labels: average the K predictions, then sharpen with temperature T.
    qb = sharpen(sum(map(lambda i: model(i), ub)) / K, T)
    Ux = np.concatenate(ub, axis=0)
    Uy = np.concatenate([qb for _ in range(K)], axis=0)
    # Shuffle the combined labeled and unlabeled batches to form W.
    indices = np.random.permutation(len(xb) + len(Ux))
    Wx = np.concatenate([Ux, xb], axis=0)[indices]
    Wy = np.concatenate([Uy, y], axis=0)[indices]
    X, p = mixup(xb, Wx[:len(xb)], y, Wy[:len(xb)], alpha)
    U, q = mixup(Ux, Wx[len(xb):], Uy, Wy[len(xb):], alpha)
    return X, p, U, q
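
The label-guessing step of Algorithm 1 can be tried in isolation. The sketch below is self-contained and relies on stand-ins: `toy_model` is a made-up two-class classifier, additive noise plays the role of augmentation, and `guess_labels` is a helper name introduced here.

```python
import numpy as np

def sharpen(x, T):
    temp = x ** (1 / T)
    return temp / temp.sum(axis=1, keepdims=True)

def guess_labels(model, augmented_batches, T):
    # Average the predictions over the K augmented views, then sharpen.
    avg = sum(model(batch) for batch in augmented_batches) / len(augmented_batches)
    return sharpen(avg, T)

def toy_model(batch):
    # Stand-in classifier (assumption): class-1 probability from the mean pixel value.
    p1 = 1.0 / (1.0 + np.exp(-batch.reshape(len(batch), -1).mean(axis=1)))
    return np.stack([1 - p1, p1], axis=1)

rng = np.random.default_rng(0)
u = rng.normal(size=(5, 3, 8, 8))                               # 5 unlabeled "images"
views = [u + 0.1 * rng.normal(size=u.shape) for _ in range(2)]  # K = 2 noisy views

q = guess_labels(toy_model, views, T=0.5)
# Every augmented view of an image shares the same guessed label.
Uy = np.concatenate([q for _ in views], axis=0)
print(q.shape, Uy.shape)  # (5, 2) (10, 2)
```

The repeated labels `Uy` have the same length as the concatenated augmented batch, which is what the shuffle and mixup steps rely on.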

The combined loss for training from the paper.

In [0]:
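class MixMatchLoss(torch.nn.Module):
    def __init__(self, lambda_u=100):
        # Module.__init__ must run before submodules are assigned.
        super(MixMatchLoss, self).__init__()
        self.lambda_u = lambda_u
        self.mse = torch.nn.MSELoss()

    def forward(self, X, U, p, q, model):
        # Assumes X, U, p, q are tensors and `model` maps a batch to logits.
        X_ = torch.cat([X, U], dim=0)
        preds = model(X_)
        # Cross-entropy with soft targets on the labeled part.
        log_probs = torch.nn.functional.log_softmax(preds[:len(p)], dim=1)
        loss_x = -(p * log_probs).sum(dim=1).mean()
        # Squared error between predicted probabilities and guessed labels.
        probs_u = torch.nn.functional.softmax(preds[len(p):], dim=1)
        return loss_x + self.lambda_u * self.mse(probs_u, q)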
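
For reference, the two terms of the loss can also be written as a plain function on logits. This is a sketch under assumptions: the function name and signature are made up here, inputs are tensors, and the soft-label cross-entropy is computed by hand so it does not depend on `CrossEntropyLoss` accepting probability targets.

```python
import torch
import torch.nn.functional as F

def mixmatch_loss(logits_x, p, logits_u, q, lambda_u=100.0):
    # Labeled term: cross-entropy between soft targets p and the predictions.
    loss_x = -(p * F.log_softmax(logits_x, dim=1)).sum(dim=1).mean()
    # Unlabeled term: mean squared error between predicted probabilities and q.
    loss_u = F.mse_loss(F.softmax(logits_u, dim=1), q)
    return loss_x + lambda_u * loss_u, loss_x, loss_u

torch.manual_seed(0)
labels = torch.tensor([1, 3, 5, 7])
logits_x = torch.randn(4, 10)                  # stand-in model outputs
logits_u = torch.randn(8, 10)
p = F.one_hot(labels, num_classes=10).float()  # one-hot is a special case of soft targets
q = F.softmax(torch.randn(8, 10), dim=1)       # stand-in guessed labels

total, loss_x, loss_u = mixmatch_loss(logits_x, p, logits_u, q)
# With one-hot targets the labeled term reduces to the usual cross-entropy.
print(torch.allclose(loss_x, F.cross_entropy(logits_x, labels)))  # True
```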

Now that the MixMatch pieces are done, a few things remain: define the WideResNet-28 model, write the data and training code, and write the testing code.
Let's start with the model. The code below is mostly copied from the wide-resnet.pytorch repo by meliketoy.

In [0]:
def conv3x3(in_planes, out_planes, stride=1):
    # padding=1 keeps the spatial size unchanged for stride 1.
    return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                           padding=1, bias=True)

We will need the init function below later, before training.

In [0]:

def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Use the in-place initializers; the names without the trailing
        # underscore are deprecated.
        torch.nn.init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        torch.nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)
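
A short usage sketch: `Module.apply` calls a function on every submodule, so a single call initializes a whole network. The tiny `net` below is a stand-in for illustration, not the WideResNet, and `conv_init` is repeated with the in-place initializers so the cell runs on its own.

```python
import numpy as np
import torch

def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        torch.nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)

# Hypothetical stand-in network with one conv and one batch-norm layer.
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
)

# Visits net itself and each child; layers matching neither branch are skipped.
net.apply(conv_init)

print(net[0].bias.abs().sum().item())  # 0.0
print(net[1].weight.min().item())      # 1.0
```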

The basic block for the WideResNet.

In [0]:
class WideBasic(torch.nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(WideBasic, self).__init__()
        self.bn1 = torch.nn.BatchNorm2d(in_planes)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=3,
                                     padding=1, bias=True)
        # The stride goes on conv2 so the main branch downsamples like the shortcut.
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,
                                     stride=stride, padding=1, bias=True)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.shortcut = torch.nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, planes, kernel_size=1,
                                stride=stride, bias=True)
            )

    def forward(self, x):
        out = self.dropout(self.conv1(torch.nn.functional.relu(self.bn1(x))))
        out = self.conv2(torch.nn.functional.relu(self.bn2(out)))
        return out + self.shortcut(x)

Aaand the full model with default params set for CIFAR10.

In [0]:
class WideResNet(torch.nn.Module):
    def __init__(self, depth=28, widen_factor=10,
                 dropout_rate=0.3, num_classes=10):
        super(WideResNet, self).__init__()
        self.in_planes = 16
        n = (depth - 4) // 6
        k = widen_factor
        nStages = [16, 16*k, 32*k, 64*k]
        self.conv1 = conv3x3(3, nStages[0])
        self.layer1 = self.wide_layer(WideBasic, nStages[1], n, dropout_rate,
                                      stride=1)
        self.layer2 = self.wide_layer(WideBasic, nStages[2], n, dropout_rate,
                                      stride=2)
        self.layer3 = self.wide_layer(WideBasic, nStages[3], n, dropout_rate,
                                      stride=2)
        self.bn1 = torch.nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = torch.nn.Linear(nStages[3], num_classes)
    
    def wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes
        return torch.nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv1(x)
        out = self.layer3(self.layer2(self.layer1(out)))
        out = torch.nn.functional.relu(self.bn1(out))
        out = torch.nn.functional.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        return self.linear(out)
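
The arithmetic behind the default parameters can be checked without building the network. The sketch below assumes that every 3x3 convolution uses padding 1 and that the first block of each stage applies the stage stride, as in the referenced wide-resnet.pytorch code; `wrn_config` is a helper name made up for this illustration.

```python
def wrn_config(depth=28, widen_factor=10, input_size=32):
    # Depth must have the form 6n + 4, where n is the number of blocks per stage.
    assert (depth - 4) % 6 == 0, "depth should be 6n + 4"
    n = (depth - 4) // 6
    widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    sizes = []
    size = input_size
    for stride in [1, 2, 2]:
        # Output size of a 3x3 convolution with padding 1 and the given stride.
        size = (size + 2 * 1 - 3) // stride + 1
        sizes.append(size)
    return n, widths, sizes

print(wrn_config())  # (4, [16, 160, 320, 640], [32, 16, 8])
```

Under these assumptions the final feature map is 8x8 with 640 channels, so `avg_pool2d(out, 8)` reduces it to one 640-dimensional vector per image, which matches the input size of the linear layer.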