<a href="https://colab.research.google.com/github/ckraju/beyond-supervised/blob/main/2-Self-Supervised_Learning_with_Context_Inpainting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Context Encoders: Feature Learning by Inpainting
### Deepak Pathak, Phillip Krähenbühl, Jeff Donahue, Trevor Darrell, and Alexei A. Efros.
### CVPR, 2016
<a href="http://people.eecs.berkeley.edu/~pathak/papers/cvpr16.pdf">[Paper]</a>

<img src="http://people.eecs.berkeley.edu/~pathak/context_encoder/resources/teaser.jpg" width="400"/> <br/>
Given an image with a missing region (a), a human artist has no trouble inpainting it (b). Automatic inpainting using the context encoder trained with L2 reconstruction loss is shown in (c), and using both L2 and adversarial losses in (d).

A CNN (encoder-decoder network) is trained to generate the contents of an arbitrary image region conditioned on its surroundings. To succeed at this task, the model needs both to understand the content of the entire image and to produce a plausible hypothesis for the missing part(s).

As in the previous notebook, the (self-)supervision, here in the form of semantic inpainting, comes at no labeling cost and is very effective for learning useful representations.

We will use 5,000 images to pre-train the network on the semantic inpainting task, and later use this pre-trained model for face parsing.

In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn

### import project utilities and other dependencies
from enc_dec import encoder_decoder
from utils import *
from inpaint_utils import *
import os
import numpy as np
import matplotlib.pyplot as plt

DATA_ROOT = '/tmp/school/data/beyond_supervised/'

In [None]:
### define dataset paths
train_img_root = DATA_ROOT + 'part_labels/data/all/'
train_image_list = DATA_ROOT + 'part_labels/splits/train_unlabeled_5k.txt'

val_img_root = DATA_ROOT + 'part_labels/data/all/'
val_image_list = DATA_ROOT + 'part_labels/splits/val_unlabeled_500.txt'

You can change the number and size of the regions to be erased by passing arguments to the data loader. The defaults are context_shape = [32, 32] and context_count = 4.
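
To make the two parameters concrete, here is a minimal sketch of what erasing `context_count` regions of size `context_shape` could look like. The function `erase_random_regions` is hypothetical and is not the code in `inpaint_utils`; in particular, filling the erased regions with zeros and sampling positions uniformly are assumptions.

```python
import numpy as np

def erase_random_regions(image, context_shape=(32, 32), context_count=4, rng=None):
    """Zero out `context_count` random patches of size `context_shape`.

    image: float array of shape (H, W, C).
    Returns (masked_image, mask), where mask is 1 inside erased patches and 0 elsewhere.
    Patches may overlap, so the erased area can be smaller than the sum of patch areas.
    """
    rng = np.random.default_rng() if rng is None else rng
    height, width = image.shape[:2]
    patch_h, patch_w = context_shape
    masked = image.copy()
    mask = np.zeros((height, width), dtype=np.float32)
    for _ in range(context_count):
        # pick the top-left corner so that the patch fits inside the image
        top = rng.integers(0, height - patch_h + 1)
        left = rng.integers(0, width - patch_w + 1)
        masked[top:top + patch_h, left:left + patch_w] = 0.0  # assumed fill value
        mask[top:top + patch_h, left:left + patch_w] = 1.0
    return masked, mask

# Example: erase two 16x16 patches from a dummy 128x128 image.
image = np.ones((128, 128, 3), dtype=np.float32)
masked, mask = erase_random_regions(image, context_shape=(16, 16), context_count=2,
                                    rng=np.random.default_rng(0))
```

Larger or more numerous regions make the pretext task harder, since less context remains to condition on.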

In [None]:
train_dataset = ContextInpaintingDataLoader(img_root=train_img_root,
                                            image_list=train_image_list, mirror=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, num_workers=2,
                                           shuffle=True, pin_memory=False)

val_dataset = ContextInpaintingDataLoader(img_root=val_img_root,
                                          image_list=val_image_list, mirror=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, num_workers=2,
                                         shuffle=False, pin_memory=False)

We define an encoder-decoder architecture in which the encoder and the decoder have 4 convolution layers each. Each convolution layer (except the last) is followed by BatchNorm and ReLU (not shown in the figure). We will use context inpainting to pre-train both the encoder and the decoder in a self-supervised way, and later use them for face parsing in the 3rd notebook.

<img src="https://docs.google.com/drawings/d/e/2PACX-1vS_yenRY55ol0M6k3aJTh6yVVSYEgcCmqQEFWtkBeCg2tXOtMLTntjWZgwtrGy4xFitUVs3n-W6Ss5Y/pub?w=2373&h=442" width=1400>

In [None]:
net = encoder_decoder().cuda()
tanh = nn.Tanh()
experiment = 'self_supervised_pre_train_semantic_inpainting'

In [None]:
print('Net params count (M): ', param_counts(net)/(1000000.0))

In [None]:
use_cuda = torch.cuda.is_available()
best_loss = float('inf')  # lowest validation loss seen so far

We use an MSE loss for the inpainting task. A higher weight (0.99) is applied to the loss corresponding to the missing regions, while a weight of 0.01 is used for the remaining regions.
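
Written out, the loss computed in the training and validation cells is roughly the following. The notation is introduced here for illustration and is not from the notebook: $\hat{y}$ is the tanh output, $y$ the target (`contexts`), $M$ the mask (`masks`, equal to 1 on erased pixels), and $N$ the total number of elements over which `torch.mean` averages.

$$
\mathcal{L} = 0.99 \cdot \frac{1}{N}\sum_{i=1}^{N} M_i\,(\hat{y}_i - y_i)^2 \;+\; 0.01 \cdot \frac{1}{N}\sum_{i=1}^{N} (1 - M_i)\,(\hat{y}_i - y_i)^2
$$

Note that both terms are divided by the full element count $N$, not by the number of erased or visible pixels, so the effective emphasis on the missing regions also depends on how much of the image is erased.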

In [None]:
def train(epoch):
    print('\nTrain epoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    for batch_idx, (inputs, masks, contexts) in enumerate(train_loader):
        if use_cuda:
            inputs, masks, contexts = inputs.cuda(), masks.cuda(), contexts.cuda()
        optimizer.zero_grad()
        # tanh bounds the prediction to [-1, 1]
        outputs = tanh(net(inputs))
        # weighted MSE: 0.99 on the erased regions (masks == 1), 0.01 elsewhere
        sq_err = (outputs - contexts) ** 2
        loss = 0.99 * torch.mean(sq_err * masks) + 0.01 * torch.mean(sq_err * (1 - masks))
        loss.backward()
        optimizer.step()
        # accumulate the scalar loss value
        train_loss += loss.item()

    # average loss per batch
    print('Loss: %f ' % (train_loss / (batch_idx + 1)))

In [None]:
def val(epoch):
    print('\nVal epoch: %d' % epoch)
    global best_loss
    net.eval()
    val_loss = 0.0
    # gradients are not needed during validation
    with torch.no_grad():
        for batch_idx, (inputs, masks, contexts) in enumerate(val_loader):
            if use_cuda:
                inputs, masks, contexts = inputs.cuda(), masks.cuda(), contexts.cuda()
            outputs = tanh(net(inputs))
            # same weighted MSE as in training
            sq_err = (outputs - contexts) ** 2
            loss = 0.99 * torch.mean(sq_err * masks) + 0.01 * torch.mean(sq_err * (1 - masks))
            val_loss += loss.item()

    # average loss per batch
    print('Loss: %f ' % (val_loss / (batch_idx + 1)))
    # Save a checkpoint whenever the summed validation loss improves.
    if val_loss < best_loss:
        print('Saving..')
        state = {'net': net}
        os.makedirs(DATA_ROOT + 'checkpoint', exist_ok=True)
        torch.save(state, DATA_ROOT + 'checkpoint/' + experiment + 'ckpt.t7')
        best_loss = val_loss

In [None]:
# step schedule: lr = 0.1 for epochs 0-29, 0.01 for epochs 30-39, 0.001 for epochs 40-49
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
for epoch in range(0, 50):
    if epoch == 30:
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
    if epoch == 40:
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    train(epoch)
    val(epoch)
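
The training loop lowers the learning rate tenfold at epochs 30 and 40. As a small sketch, the same schedule can be expressed as a function of the epoch; `step_lr` is a hypothetical helper introduced here for illustration, not part of the notebook's utilities.

```python
def step_lr(epoch, base_lr=0.1, milestones=(30, 40), gamma=0.1):
    """Return base_lr multiplied by gamma once for every milestone already reached."""
    drops = sum(1 for milestone in milestones if epoch >= milestone)
    return base_lr * (gamma ** drops)

# Print the learning rate at a few epochs around the milestones.
for epoch in (0, 29, 30, 39, 40, 49):
    print(epoch, step_lr(epoch))
```

One way to apply it is to set `param_group['lr'] = step_lr(epoch)` for every entry of `optimizer.param_groups` at the start of each epoch. Unlike creating a new optimizer, this keeps the SGD momentum buffers. `torch.optim.lr_scheduler.MultiStepLR` provides the same kind of schedule out of the box.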

Now let's visualize some semantic inpainting results.

In [None]:
net = torch.load(DATA_ROOT + 'checkpoint/'+experiment+'ckpt.t7')['net'].cuda().eval()

In [None]:
mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
std_bgr = 255*np.array([0.229, 0.224, 0.225])

In [None]:
val_loader = torch.utils.data.DataLoader(ContextInpaintingDataLoader(img_root = val_img_root,
                                                                  image_list = val_image_list, mirror = True),
                                           batch_size=1, num_workers=1, shuffle = True, pin_memory=False)

for batch_idx, (inputs, masks, contexts) in enumerate(val_loader):
    if use_cuda:
        inputs, masks, contexts = inputs.cuda(), masks.cuda(), contexts.cuda()
    inputs = Variable(inputs)
    masks = Variable(masks)
    contexts = Variable(contexts)
    outputs = tanh(net(inputs))
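    # undo the normalization, clip to [0, 255] so the uint8 cast cannot wrap around, and flip BGR -> RGB
    i = np.clip(inputs[0].data.cpu().numpy().transpose(1,2,0) + mean_bgr, 0, 255).astype(np.uint8)[:,:,::-1]
    c = np.clip(contexts[0].data.cpu().numpy().transpose(1,2,0)*3*std_bgr + mean_bgr, 0, 255).astype(np.uint8)[:,:,::-1]
    o = np.clip(outputs[0].data.cpu().numpy().transpose(1,2,0)*3*std_bgr + mean_bgr, 0, 255).astype(np.uint8)[:,:,::-1]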
    vis = np.concatenate((i,c,o), axis = 1)
    plt.imshow(vis)
    plt.show()
    
    if batch_idx == 50:
        break