# Unsupervised Visual Representation Learning by Context Prediction
### Carl Doersch, Abhinav Gupta, and Alexei A. Efros.
### ICCV, 2015
<a href="https://arxiv.org/pdf/1505.05192.pdf">[Paper]</a>

<img src="images/doersch_1.png" width = 400>

There are millions of unannotated images available on the web. Can we use them to learn a useful representation? One option is unsupervised learning with a denoising autoencoder, and many other pretext tasks can also improve feature learning without any annotation cost. In this notebook we look at one of them: context prediction.

The idea in this paper is simple. Given two neighboring tiles from an image (indicated with red and blue squares), the model tries to predict their relative position. The second tile is taken from one of the eight positions surrounding the first, so this is an 8-way classification problem. To solve it well, the model needs to learn discriminative representations of the patches that make up the object.
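To make the pretext task concrete, here is a minimal NumPy sketch of how one training pair could be sampled. It is an illustration under stated assumptions, not the `RelativeTileDataLoader` used below: the function `sample_pair`, the `OFFSETS` label ordering, and the `tile`, `gap` and `jitter` values are all hypothetical. The gap and jitter follow the paper's idea of keeping the two tiles from touching exactly, so the network cannot solve the task from trivial low-level cues such as lines continuing across the tile boundary.

```python
import numpy as np

# Hypothetical label convention: the 8 neighbors of the center tile in a 3x3 grid,
# numbered row by row and skipping the center itself.
OFFSETS = [(-1, -1), (-1, 0), (-1, 1),
           (0, -1),           (0, 1),
           (1, -1),  (1, 0),  (1, 1)]


def sample_pair(img, tile=64, gap=8, jitter=4, rng=None):
    """Sample (center_tile, neighbor_tile, label) from an HxW or HxWxC array."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    step = tile + gap          # distance between top-left corners of adjacent grid cells
    margin = step + jitter     # room needed around the center tile for any neighbor
    if h < 2 * margin + tile or w < 2 * margin + tile:
        raise ValueError("image too small for the requested tile/gap/jitter")
    # Top-left corner of the center tile, chosen so every neighbor stays inside the image.
    cy = int(rng.integers(margin, h - margin - tile + 1))
    cx = int(rng.integers(margin, w - margin - tile + 1))
    # Pick one of the 8 positions; its index is the classification target.
    label = int(rng.integers(len(OFFSETS)))
    dy, dx = OFFSETS[label]
    # Neighbor corner: one grid step away, plus a small random jitter.
    ny = cy + dy * step + int(rng.integers(-jitter, jitter + 1))
    nx = cx + dx * step + int(rng.integers(-jitter, jitter + 1))
    center = img[cy:cy + tile, cx:cx + tile]
    neighbor = img[ny:ny + tile, nx:nx + tile]
    return center, neighbor, label
```

For example, `sample_pair(np.zeros((256, 256, 3)))` returns two 64×64×3 tiles and an integer label in `range(8)`.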

This (self-)supervision in the form of relative position comes at no annotation cost, yet it is effective for learning useful representations.

We pass each patch through the same encoder network (AlexNet, VGG, ResNets, etc.) to get its representation, then use the concatenation of the two representations to classify their relative position.
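The same pipeline can be packaged as a single module. This is a minimal PyTorch sketch, assuming any encoder whose output can be flattened to `feat_dim` values per image; the class name `ContextPredictionNet` is hypothetical, and the notebook below instead keeps the encoder (`net`) and the classifier (`mlp`) as two separate objects. Because the encoder is shared, both tiles are embedded with identical weights, and the order of concatenation (center first, neighbor second) is what gives the label its meaning.

```python
import torch
import torch.nn as nn


class ContextPredictionNet(nn.Module):
    """Hypothetical Siamese wrapper: shared encoder + concatenation + 8-way classifier."""

    def __init__(self, encoder, feat_dim, num_positions=8):
        super().__init__()
        self.encoder = encoder  # a single encoder, applied to both tiles (shared weights)
        self.head = nn.Sequential(
            nn.Linear(2 * feat_dim, 16), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(16, num_positions))

    def forward(self, center, neighbor):
        f_center = self.encoder(center).flatten(1)      # (N, feat_dim)
        f_neighbor = self.encoder(neighbor).flatten(1)  # (N, feat_dim)
        # Order matters: the label says where `neighbor` lies relative to `center`.
        return self.head(torch.cat((f_center, f_neighbor), dim=1))  # (N, num_positions) logits


# Shape check with a toy encoder: 3x64x64 -> 4x2x2 = 16 features per tile.
toy = nn.Sequential(nn.Conv2d(3, 4, 3, stride=2, padding=1), nn.ReLU(),
                    nn.AdaptiveAvgPool2d(2))
model = ContextPredictionNet(toy, feat_dim=16).eval()
logits = model(torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64))
print(logits.shape)  # torch.Size([4, 8])
```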


We will use the <a href="http://vis-www.cs.umass.edu/lfw/part_labels/">Part Labels dataset</a> in this experiment. The downstream task is to label each pixel of the image as one of three classes: Background (blue), Hair (red), and Skin (green).
<img src="http://vis-www.cs.umass.edu/lfw/part_labels/images/img_funneled.jpg" width=100><img src="http://vis-www.cs.umass.edu/lfw/part_labels/images/img_ground_truth.png" width=100> <br/>
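Since the ground truth is stored as a color image, one plausible way to turn it into per-pixel class indices is to take the dominant RGB channel at each pixel. This NumPy sketch assumes exactly that color coding (blue = Background, red = Hair, green = Skin) and a hypothetical index order of 0 = Background, 1 = Hair, 2 = Skin; the helper `rgb_to_labels` is not part of the provided code.

```python
import numpy as np

# Hypothetical index order: 0 = Background (blue), 1 = Hair (red), 2 = Skin (green).
CHANNEL_TO_CLASS = np.array([1, 2, 0])  # maps RGB channel index (R, G, B) -> class index


def rgb_to_labels(gt):
    """Map an HxWx3 RGB ground-truth array to an HxW array of class indices."""
    dominant = np.argmax(gt[..., :3], axis=-1)  # 0 = R, 1 = G, 2 = B
    return CHANNEL_TO_CLASS[dominant]


gt = np.zeros((2, 2, 3), dtype=np.uint8)
gt[0, 0] = (255, 0, 0)   # hair
gt[0, 1] = (0, 255, 0)   # skin
gt[1, :] = (0, 0, 255)   # background
print(rgb_to_labels(gt))
# [[1 2]
#  [0 0]]
```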
There are 13,233 images in total, of which 2,927 are labeled: 1,500 train, 500 val, and 927 test images. We will use only 10% of the labeled training set in our experiments. For self-supervised pre-training we will use 5,000 unlabeled images (available splits: `train_unlabeled_2k.txt`, `train_unlabeled_5k.txt`, and `train_unlabeled_10k.txt`).
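Taking 10% of the 1,500 labeled training images leaves 150 images for supervised training. Here is a minimal sketch of drawing such a subset reproducibly, assuming a split file with one image name per line; the function name `subsample_split` and the fixed seed are hypothetical, not part of the provided utilities.

```python
import random


def subsample_split(list_path, fraction=0.1, seed=0):
    """Return a reproducible random subset of the image names listed in a split file."""
    with open(list_path) as f:
        names = [line.strip() for line in f if line.strip()]  # skip blank lines
    k = int(round(fraction * len(names)))
    # A dedicated Random instance keeps the subset fixed without touching global state.
    return random.Random(seed).sample(names, k)
```

Fixing the seed means every run (and every downstream comparison) sees the same 10% of images.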

In [None]:
import os
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable  # legacy wrapper; plain tensors suffice in PyTorch >= 0.4
import torch.nn as nn

### import project-local helpers (encoder, dataset and utility functions)
from enc_dec import encoder
from utils import *
from relative_utils import *
import matplotlib.pyplot as plt
DATA_ROOT = '/tmp/school/data/beyond_supervised/'

In [None]:
### define dataset paths
train_img_root = DATA_ROOT + 'part_labels/data/all/'
train_image_list = DATA_ROOT + 'part_labels/splits/train_unlabeled_5k.txt'

val_img_root = DATA_ROOT + 'part_labels/data/all/'
val_image_list = DATA_ROOT + 'part_labels/splits/val_unlabeled_500.txt'

In [None]:
crop_shape = (64, 64)  # height, width of each tile

# Each batch yields (center_crops, random_crops, class_idxs, class_locs).
train_loader = torch.utils.data.DataLoader(
    RelativeTileDataLoader(img_root=train_img_root, image_list=train_image_list,
                           crop_shape=crop_shape, mirror=True),
    batch_size=128, num_workers=2, shuffle=True, pin_memory=False)

val_loader = torch.utils.data.DataLoader(
    RelativeTileDataLoader(img_root=val_img_root, image_list=val_image_list,
                           crop_shape=crop_shape, mirror=True),
    batch_size=32, num_workers=2, shuffle=False, pin_memory=False)

We define an encoder architecture with 4 convolution layers. We will use the context prediction technique to pre-train the encoder in a self-supervised way, and later use it for face parsing in the third notebook.
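The classifier head defined below expects 2,048 inputs, i.e. 1,024 features per tile. As a hypothetical illustration (the actual layer configuration lives in `enc_dec.encoder` and may differ), four 3×3 convolutions with stride 2 and padding 1 halve a 64×64 tile four times down to 4×4, and with 64 output channels the flattened feature has 4·4·64 = 1,024 values. The standard convolution output-size formula makes this easy to check; `conv_out` is a helper written for this example.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution (the formula given in the torch.nn.Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


size = 64
for _ in range(4):  # hypothetical: four 3x3, stride-2, padding-1 conv layers
    size = conv_out(size, kernel=3, stride=2, padding=1)
channels = 64       # hypothetical channel count of the last conv layer
print(size, size * size * channels)  # 4 1024
```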

<img src="https://docs.google.com/drawings/d/e/2PACX-1vQ8zrtcyVOGwxvd8HgccmSWQad_WKefGT_KDQIu61IcAgzYw-MxfYWgwPKI25mu7etpm2b09jBwoqgj/pub?w=1413&h=360" width = 1200>

In [None]:
net = encoder().cuda()
experiment = 'self_supervised_pre_train_relative_tile'

In [None]:
print('Net params count (M): ', param_counts(net)/(1000000.0))

In [None]:
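# Simple MLP head: the input is two concatenated encoder features (2 x 1024 = 2048 values);
# the output is 8 logits, one per relative position.
mlp = nn.Sequential(nn.Linear(2048, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 8)).cuda()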

In [None]:
use_cuda = torch.cuda.is_available()
best_loss = 9999  # best val loss

In [None]:
loss_fn = nn.CrossEntropyLoss()

def train(epoch):
    print('\nTrain epoch: %d' % epoch)
    net.train()
    mlp.train()
    train_loss = 0.0

    # class_locs is not needed for the loss; class_idxs is the relative-position label.
    for batch_idx, (center_crops, random_crops, class_idxs, class_locs) in enumerate(train_loader):
        if use_cuda:
            center_crops, random_crops, class_idxs = center_crops.cuda(), random_crops.cuda(), class_idxs.cuda()
        optimizer.zero_grad()

        # Both tiles go through the same encoder (shared weights); flatten and concatenate.
        feat_center = net(center_crops).view(center_crops.size(0), -1)
        feat_random = net(random_crops).view(random_crops.size(0), -1)
        v = torch.cat((feat_center, feat_random), 1)

        outputs = mlp(v)  # logits over the 8 relative positions
        loss = loss_fn(outputs, class_idxs)

        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    print('Loss: %f ' % (train_loss / (batch_idx + 1)))

In [None]:
def val(epoch):
    print('\nVal epoch: %d' % epoch)
    global best_loss
    net.eval()
    mlp.eval()
    val_loss = 0.0
    # No gradients are needed for validation.
    with torch.no_grad():
        for batch_idx, (center_crops, random_crops, class_idxs, class_locs) in enumerate(val_loader):
            if use_cuda:
                center_crops, random_crops, class_idxs = center_crops.cuda(), random_crops.cuda(), class_idxs.cuda()
            v = torch.cat((net(center_crops).view(center_crops.size(0), -1),
                           net(random_crops).view(random_crops.size(0), -1)), 1)
            outputs = mlp(v)
            loss = loss_fn(outputs, class_idxs)
            val_loss += loss.item()

    print('Loss: %f ' % (val_loss / (batch_idx + 1)))
    # Save a checkpoint whenever the (summed) validation loss improves.
    if val_loss < best_loss:
        print('Saving..')
        state = {'net': net}
        os.makedirs(DATA_ROOT + 'checkpoint', exist_ok=True)
        torch.save(state, DATA_ROOT + 'checkpoint/' + experiment + 'ckpt.t7')
        best_loss = val_loss

In [None]:
optimizer = optim.SGD(list(net.parameters()) + list(mlp.parameters()), lr=0.01, momentum=0.9, weight_decay=0.0005)
# Step schedule: lr = 0.01 for epochs 0-59, 0.001 for epochs 60-79, 0.0001 for epochs 80-99.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 80], gamma=0.1)
for epoch in range(0, 100):
    train(epoch)
    val(epoch)
    scheduler.step()