In this notebook, we will see how useful self-supervised pre-training techniques are. We pick the face parsing task to evaluate the models. Note that this pre-training strategy is also applicable to other tasks such as object recognition, detection, segmentation, etc. This is because most tasks share a similar underlying network architecture.


We will be using the <a href="http://vis-www.cs.umass.edu/lfw/part_labels/">Part Labels dataset</a> in this experiment. The task is to assign each pixel in the image to one of three classes: Background (blue), Hair (red), and Skin (green).
<img src="http://vis-www.cs.umass.edu/lfw/part_labels/images/img_funneled.jpg" width=100><img src="http://vis-www.cs.umass.edu/lfw/part_labels/images/img_ground_truth.png" width=100> <br/>
There are 13,233 images in total, out of which 2,927 have been labeled. There are 1,500 train, 500 val, and 927 test images. We will be using only 10% of the training set in our experiments. For self-supervised pre-training we will use 5,000 images.


We will measure three metrics: mIoU, pixel accuracy, and frequency-weighted pixel accuracy. These are popular metrics used in semantic segmentation tasks.

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn

### import other stuffs
from enc_dec import encoder_decoder
from loss import *
from utils import *
from seg_utils import *
import matplotlib.pyplot as plt

# Root folder that holds the dataset, the split lists and the checkpoints
# used by every cell below; adjust it to the local setup.
DATA_ROOT = '/tmp/school/data/beyond_supervised/'

Available splits for experiments:

10% - train_10p.txt 

50% - train_50p.txt 

100% - train.txt

In [None]:
### define dataset paths
# Every split reads its images and ground-truth masks from the same two
# folders; only the list file naming the samples differs per split.
train_img_root = val_img_root = test_img_root = DATA_ROOT + 'part_labels/data/images/'
train_gt_root = val_gt_root = test_gt_root = DATA_ROOT + 'part_labels/data/gt/'

train_image_list = DATA_ROOT + 'part_labels/splits/train_10p.txt'
val_image_list = DATA_ROOT + 'part_labels/splits/minival.txt'
test_image_list = DATA_ROOT + 'part_labels/splits/test.txt'

# Number of segmentation classes (background, hair, skin).
nClasses = 3
# Per-channel offset in BGR order; evaluate() adds it back to the network
# input before flipping the channels for display.
mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

In [None]:
def _make_loader(img_root, gt_root, image_list, mirror, shuffle):
    """Wrap one dataset split in a batched ``DataLoader``.

    All splits share the batch size (16), the worker count (2) and
    ``transform=True``; only ``mirror`` and ``shuffle`` vary.
    NOTE(review): ``mirror`` presumably enables flip augmentation inside
    ``SegmentationDataLoader`` - confirm against its implementation.
    """
    dataset = SegmentationDataLoader(img_root=img_root, gt_root=gt_root,
                                     image_list=image_list, transform=True,
                                     mirror=mirror)
    return torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=2,
                                       shuffle=shuffle, pin_memory=False)


# Only the training split is mirrored and shuffled.
train_loader = _make_loader(train_img_root, train_gt_root, train_image_list,
                            mirror=True, shuffle=True)

val_loader = _make_loader(val_img_root, val_gt_root, val_image_list,
                          mirror=False, shuffle=False)

test_loader = _make_loader(test_img_root, test_gt_root, test_image_list,
                           mirror=False, shuffle=False)

<img src="https://docs.google.com/drawings/d/e/2PACX-1vT_ZXwfGNnjfS221bBh9HDxGM79aavoLARgwHep4hKvlql1si6qscZ9M4fhXKCWxuXNRy6tgBvj__GD/pub?w=2011&h=331" />

In [None]:
'''Experiment 1: train semantic segmentation network form scratch using 10% of training data'''
'''Initialize model with random weights (He initialization).'''
# `experiment` names the checkpoint file written by val() and read back
# before the final evaluation.
experiment = 'from_scratch'
net = encoder_decoder()
# Move the model to the GPU only when one is present, matching the
# `use_cuda` guard the train/val/test loops apply to each batch.
if torch.cuda.is_available():
    net = net.cuda()

# NOTE(review): the numbers below are presumably the recorded
# mIoU / pixel accuracy / frequency-weighted accuracy of this run - confirm.
# # 55.098452484703344 82.77928629989214 71.4377468041187





In [None]:
# This cell is disabled: uncomment the code lines below to run Experiment 2
# in place of the other experiment cells.
'''Experiment 2: Fine-tune semantic segmentation network using pre-trained encoder (context prediction) using 10% of training data'''
'''Load the pre-trained encoder'''

# experiment = 'from_relative_tiles_pre_training'
# net = torch.load(DATA_ROOT + 'checkpoint/self_supervised_pre_train_relative_tileckpt.t7')['net']
# upsample = nn.Upsample(scale_factor=2, mode='bilinear')

'''Add the decoder to the model. Note that the decoder is initialized with random weights'''
# The decoder alternates 2x bilinear upsampling with 3x3 convolutions and
# ends in a 3-channel convolution.
# net.decoder = nn.Sequential(upsample, nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(64),
#                             upsample, nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(32), 
#                             upsample, nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=0, bias=False), nn.BatchNorm2d(16), 
#                             upsample, nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=0, bias=True))

# net.cuda()

# NOTE(review): the numbers below are presumably the recorded
# mIoU / pixel accuracy / frequency-weighted accuracy of this run - confirm.
# # 62.867514016826135 86.01381488673138 76.27833472519274

In [None]:
# This cell is disabled: uncomment the code lines below to run Experiment 3
# in place of the other experiment cells.
'''Experiment 3: Fine-tune semantic segmentation network using pre-trained encoder-decoder (context inpainting) using 10% of training data'''
'''Load the pre-trained encoder-decoder'''
# net = torch.load(DATA_ROOT + 'checkpoint/self_supervised_pre_train_semantic_inpaintingckpt.t7')['net'].cuda()
# experiment = 'from_semantic_inpainting_pre_training'
# NOTE(review): the numbers below are presumably the recorded
# mIoU / pixel accuracy / frequency-weighted accuracy of this run - confirm.
# # 65.35361380810471 87.81917583603021 78.72283941521574

In [None]:
# Report the size of the selected model in millions of parameters.
n_params_millions = param_counts(net) / 1e6
print('Net params count (M): ', n_params_millions)

In [None]:
# Batches are moved to the GPU in train()/val()/evaluate() only when this is True.
use_cuda = torch.cuda.is_available()
best_acc = 0  # best validation pixel accuracy so far; val() raises it whenever it saves a checkpoint

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print its metrics.

    Relies on the module-level ``net``, ``optimizer``, ``train_loader``,
    ``use_cuda`` and ``nClasses``.

    Args:
        epoch: Epoch index; used for the log line and forwarded to
            ``performMetrics``.
    """
    print('\nTrain epoch: %d' % epoch)
    net.train()
    # Per-class statistics accumulated over the whole epoch (presumably a
    # confusion matrix); starts at a tiny non-zero value rather than 0.
    hist = np.zeros((nClasses, nClasses))+1e-12
    train_loss = 0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        inputs = Variable(inputs)
        outputs = net(inputs)

        loss = cross_entropy2d(outputs, labels)
        loss.backward()
        optimizer.step()
        # `loss.data[0]` fails on current PyTorch, where the loss is a 0-dim
        # tensor; `.item()` is the supported accessor. The fallback keeps
        # releases that predate `.item()` working.
        train_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        # Predicted class per pixel = arg-max over dimension 1 of the output.
        _, predicted = torch.max(outputs.data, 1)
        correctLabel = labels.view(-1, labels.size()[1], labels.size()[2])

        # Flatten each sample to one row before updating the statistics.
        hist += fast_hist(correctLabel.view(correctLabel.size(0),-1).cpu().numpy(),
                        predicted.view(predicted.size(0),-1).cpu().numpy(),
                        nClasses)

        # Running metrics after every batch, using the mean loss so far.
        miou, p_acc, fwacc = performMetrics(epoch,batch_idx,len(train_loader),hist,train_loss/(batch_idx+1),is_train=True)

    # Final metrics for the complete epoch.
    miou, p_acc, fwacc = performMetrics(epoch,batch_idx,len(train_loader),hist,train_loss/(batch_idx+1),is_train=True)
    print('train: mIoU/Accuracy/Freqweighted_Accuracy', miou, p_acc, fwacc)


In [None]:
def val(epoch):
    """Score ``net`` on ``val_loader`` and checkpoint it when it improves.

    Relies on the module-level ``net``, ``val_loader``, ``use_cuda``,
    ``nClasses``, ``experiment`` and ``DATA_ROOT``; updates the global
    ``best_acc`` whenever the pixel accuracy beats the previous best.

    Args:
        epoch: Epoch index; used for the log line and forwarded to
            ``performMetrics``.
    """
    global best_acc
    print('\nVal epoch: %d' % epoch)
    net.eval()
    val_loss = 0
    # Per-class statistics accumulated over the whole split (presumably a
    # confusion matrix); starts at a tiny non-zero value rather than 0.
    hist = np.zeros((nClasses, nClasses))+1e-12
    # NOTE(review): gradients are not disabled here; on PyTorch >= 0.4 the
    # loop could be wrapped in `torch.no_grad()` to save memory.
    for batch_idx, (inputs, labels) in enumerate(val_loader):
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()

        inputs = Variable(inputs)
        outputs = net(inputs)
        loss = cross_entropy2d(outputs, labels)

        # `loss.data[0]` fails on current PyTorch, where the loss is a 0-dim
        # tensor; `.item()` is the supported accessor. The fallback keeps
        # releases that predate `.item()` working.
        val_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        # Predicted class per pixel = arg-max over dimension 1 of the output.
        _, predicted = torch.max(outputs.data, 1)
        correctLabel = labels.view(-1, labels.size()[1], labels.size()[2])

        # Flatten each sample to one row before updating the statistics.
        hist += fast_hist(correctLabel.view(correctLabel.size(0),-1).cpu().numpy(),
                        predicted.view(predicted.size(0),-1).cpu().numpy(),
                        nClasses)

        # Running metrics after every batch, using the mean loss so far.
        miou, p_acc, fwacc = performMetrics(epoch,batch_idx,len(val_loader),hist,val_loss/(batch_idx+1),is_train=False)

    # Final metrics for the complete validation split.
    miou, p_acc, fwacc = performMetrics(epoch,batch_idx,len(val_loader),hist,val_loss/(batch_idx+1),is_train=False)
    print('val: mIoU/Accuracy/Freqweighted_Accuracy', miou, p_acc, fwacc)

    # Save checkpoint.
    if p_acc > best_acc:
        print('Saving..')
        # The whole module object is stored under 'net'; the reload before
        # the final evaluation reads exactly this key.
        state = {'net': net}
        if not os.path.isdir(DATA_ROOT + 'checkpoint'):
            os.mkdir(DATA_ROOT + 'checkpoint')
        torch.save(state, DATA_ROOT + 'checkpoint/'+experiment+'ckpt.t7')
        best_acc = p_acc

In [None]:
# Train for 100 epochs with a step schedule: 1e-3 until epoch 60,
# 1e-4 until epoch 80, then 1e-5.
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0005)
lr_schedule = {60: 1e-4, 80: 1e-5}
for epoch in range(0, 100):
    if epoch in lr_schedule:
        # Lower the rate on the existing optimizer: building a new SGD
        # instance at each drop would throw away the momentum buffers.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[epoch]
    train(epoch)
    val(epoch)

In [None]:
# Reload the best checkpoint written by val() and switch it to eval mode.
# NOTE: torch.load unpickles the file, so only load checkpoints you produced.
checkpoint_path = DATA_ROOT + 'checkpoint/'+experiment+'ckpt.t7'
if torch.cuda.is_available():
    net = torch.load(checkpoint_path)['net'].cuda().eval()
else:
    # Without a GPU, remap the stored tensors to CPU so a checkpoint saved
    # from a CUDA model can still be loaded.
    net = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)['net'].eval()

In [None]:
def apply_color_map(x):
    """Turn a 2-D label map into a 3-channel colour image for display.

    Label 0 lights up channel 2, label 1 channel 1 and label 2 channel 0,
    each at full intensity; pixels with any other value stay black.

    Args:
        x: 2-D array of class indices.

    Returns:
        A ``uint8`` array of shape ``(x.shape[0], x.shape[1], 3)``.
    """
    height, width = x.shape[0], x.shape[1]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    # (label, channel) pairs: each class gets its own colour channel.
    for label, channel in ((0, 2), (1, 1), (2, 0)):
        hits = np.nonzero(x == label)
        canvas[hits[0], hits[1], channel] = 255
    return canvas

In [None]:
def evaluate(epoch=0):
    """Score ``net`` on ``test_loader``, showing one visualisation per batch.

    For the first sample of every batch the input image, the ground-truth
    colour map and the predicted colour map are plotted side by side.
    Relies on the module-level ``net``, ``test_loader``, ``use_cuda``,
    ``nClasses`` and ``mean_bgr``.

    Args:
        epoch: Value forwarded to ``performMetrics``; defaults to 0.
    """
    net.eval()
    test_loss = 0
    # Per-class statistics accumulated over the whole split (presumably a
    # confusion matrix); starts at a tiny non-zero value rather than 0.
    hist = np.zeros((nClasses, nClasses))+1e-12
    for batch_idx, (inputs, labels) in enumerate(test_loader):
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        inputs = Variable(inputs)
        outputs = net(inputs)
        loss = cross_entropy2d(outputs, labels)
        # Accumulate the batch loss so the mean passed to performMetrics is
        # meaningful (it was previously left at 0). `.item()` is the
        # supported scalar accessor on current PyTorch; the fallback keeps
        # releases that predate it working.
        test_loss += loss.item() if hasattr(loss, 'item') else loss.data[0]
        # Predicted class per pixel = arg-max over dimension 1 of the output.
        _, predicted = torch.max(outputs.data, 1)

        # Rebuild a displayable image: channels-last, add the mean back,
        # then reverse the channel order (BGR -> RGB) for matplotlib.
        i = (inputs[0].data.cpu().numpy().transpose(1,2,0) + mean_bgr).astype(np.uint8)[:,:,::-1]
        g = apply_color_map(labels[0].cpu().numpy())
        o = apply_color_map(predicted[0].cpu().numpy())
        # Input | ground truth | prediction, joined along the width.
        vis = np.concatenate((i,g,o), axis = 1)
        plt.imshow(vis)
        plt.show()

        correctLabel = labels.view(-1, labels.size()[1], labels.size()[2])

        # Flatten each sample to one row before updating the statistics.
        hist += fast_hist(correctLabel.view(correctLabel.size(0),-1).cpu().numpy(),
                        predicted.view(predicted.size(0),-1).cpu().numpy(),
                        nClasses)

        # Running metrics after every batch, using the mean loss so far.
        miou, p_acc, fwacc = performMetrics(epoch,batch_idx,len(test_loader),hist,test_loss/(batch_idx+1),is_train=False)

    # Final metrics for the complete test split.
    miou, p_acc, fwacc = performMetrics(epoch,batch_idx,len(test_loader),hist,test_loss/(batch_idx+1),is_train=False)
    print('test: mIoU/Accuracy/Freqweighted_Accuracy', miou, p_acc, fwacc)

In [None]:
# Score the current `net` on the test split: plots one sample per batch and
# prints mIoU / pixel accuracy / frequency-weighted accuracy at the end.
evaluate()