In [None]:
!pip3 install torch torchvision



In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
import numpy as np
import cv2
import matplotlib.pyplot as plt

import os
import time


In [None]:
# Mount Google Drive; the dataset folders configured further down live under this mount point.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Image dimensions in pixels; masks are drawn on an (IMG_DIMS_X, IMG_DIMS_Y, 3) canvas.
IMG_DIMS_X, IMG_DIMS_Y = 256, 256

In [None]:
# Training configuration (adapted from a tutorial) passed to train().
# Note: 'valid' and 'model' are not read anywhere in this notebook.
args = dict(
    gpu=True,
    valid=False,
    checkpoint="",                     # torch.save target for the final state_dict; "" skips saving
    model="CNN",
    kernel=3,                          # conv kernel size
    num_filters=32,                    # tune
    learn_rate=0.001,                  # Adam learning rate; tune
    batch_size=32,                     # tune
    epochs=25,                         # tune
    seed=0,
    plot=True,                         # plot a training sample every 3rd epoch
    experiment_name='cnn_autoencoder', # train() creates outputs/<experiment_name>
    visualize=True,
    downsize_input=False,
)

In [None]:
# Drive folders for the training and validation image sets (each is loaded through Places/ImageFolder).
data_root, data_root_val = (
    '/content/gdrive/MyDrive/segmented data',
    '/content/gdrive/MyDrive/further segments',
)

In [None]:
def plot(masked, inpainted, truth, fpath, peek_idx=0):

  """Show one sample as [masked | inpainted | ground truth], side by side.

  This function takes a lot of time; aim to plot less.

  Args:
    masked, inpainted, truth: 4-D image batches indexed (batch, channel, height,
      width); values are presumably in [0, 1] since they are scaled by 255 here.
    fpath: if not None, the figure is also saved to this path (before showing).
    peek_idx: index of the sample to display; None picks one at random.
  """
  if peek_idx is None:
    peek_idx = np.random.randint(0, masked.size()[0])

  def _sample_hwc(batch):
    # Select one sample and move channels last for imshow: (C, H, W) -> (H, W, C).
    return batch[peek_idx].detach().cpu().permute(1, 2, 0)

  # Concatenate along width so the three images sit next to each other.
  canvas = torch.cat([_sample_hwc(t) for t in (masked, inpainted, truth)], dim=1)
  canvas = (canvas * 255).clamp(0, 255).numpy().astype(np.uint8)

  plt.imshow(canvas)

  # Save before plt.show(): showing may close the figure, leaving nothing to save.
  if fpath is not None:
    plt.savefig(fpath)

  plt.show()




In [None]:
def inpaint(img_masked, prediction, mask):

  """Combine the masked input with the network output into the final inpainted image.

  Wherever `mask` is 0 (a hole) the pixel comes from `prediction`; everywhere else
  the pixel of `img_masked` is kept. Neither input is modified: a new tensor with
  `img_masked`'s dtype is returned.

  The mask is permuted with (0, 3, 2, 1) to line up with the image batch, the same
  transposed orientation Places uses (np.transpose(mask, (2, 1, 0))) when it blanks
  pixels of the input image.
  """
  hole_map = (mask == 0).permute((0, 3, 2, 1))
  # Tensor.where(cond, other): take self where cond is True, `other` elsewhere.
  filled = prediction.where(hole_map, img_masked)
  return filled.to(img_masked.dtype)

In [None]:
# Preprocessing applied by Places to both the original and the masked image:
# the only step is conversion to a float tensor via ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
class Modified_MSE(nn.MSELoss):

  """MSE loss restricted to the masked-out (hole) pixels of an image.

  Scores how faithfully the reconstruction reproduces the region that was hidden
  from the network; pixels outside the holes do not contribute.
  """

  def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
    super(Modified_MSE, self).__init__(size_average, reduce, reduction)

  def forward(self, output, img, img_masked, mask):
    """Compute the MSE between `output` and `img` over hole pixels only.

    Args:
      output: network prediction, indexed (batch, channel, height, width).
      img: ground-truth batch with the same shape as `output`.
      img_masked: masked input batch; not used, kept for call compatibility.
      mask: collated mask batch in which 0 marks a hole. It is indexed
        (batch, w, h, channel), i.e. transposed relative to the image, which
        is how Places applies it when blanking the input image.

    Returns:
      The MSE over hole pixels, reduced with `self.reduction`.
    """
    # The dataset blanks image pixel (h, w) where mask[w, h] == 0, so the mask
    # must be permuted with (0, 3, 2, 1), the same order inpaint() uses, to line
    # up with (B, C, H, W). Permuting with (0, 3, 1, 2) would score the
    # transposed, mostly unmasked, pixels instead of the holes.
    holes = (mask == 0).permute((0, 3, 2, 1))
    return F.mse_loss(output[holes], img[holes], reduction=self.reduction)
        


In [None]:
class Places(ImageFolder):

  """A subsample of the Places Small Dataset (nature images only), served as
  inpainting examples.

  Each item is a triple (img, masked_img, mask):
    img        -- the image after `transform`.
    masked_img -- a copy of the image in which 3-4 random lines (thickness 5-9 px)
                  are set to 255 after scaling to 0-255, then passed through `transform`.
    mask       -- uint8 numpy array of shape (IMG_DIMS_X, IMG_DIMS_Y, 3):
                  0 on the drawn lines, 255 elsewhere.

  Orientation: image pixel [c, h, w] is blanked where mask[w, h, c] == 0, i.e. the
  mask is applied transposed (np.transpose(mask, (2, 1, 0))).

  `transform` is required in practice: it is applied to the masked image, and
  `img` must come out as a tensor (presumably via transforms.ToTensor).

  If `fixed_mask` is True, every masked image is computed once in __init__ (this
  loads the whole dataset). The first `mask_pool_size` images each get a freshly
  drawn mask; image i beyond that reuses the mask of image i % mask_pool_size.
  Otherwise a new mask is drawn on every __getitem__ call.
  """

  def __init__(self, root, transform=None, fixed_mask=False, mask_pool_size=1600):
    super(Places, self).__init__(root, transform)
    if mask_pool_size < 1:
      raise ValueError('mask_pool_size must be >= 1, got %r' % (mask_pool_size,))
    self.fixed_mask = fixed_mask
    if self.fixed_mask:
      # One (masked_image, mask) pair per dataset index.
      self.masks = []
      for i in range(len(self)):
        img = self._getimg(i)
        if i < mask_pool_size:
          self.masks.append(self._create_mask(img))
        else:
          shared_mask = self.masks[i % mask_pool_size][1]
          self.masks.append(self._mask_with_exisiting(img, shared_mask))

  def _create_mask(self, img):
    """Draw a fresh random line mask and apply it to `img`.

    Returns (masked, mask), as described for _mask_with_exisiting.
    """
    mask = np.full((IMG_DIMS_X, IMG_DIMS_Y, 3), 255, np.uint8)
    for _ in range(np.random.randint(3, 5)):
      x1, x2 = np.random.randint(1, IMG_DIMS_X), np.random.randint(1, IMG_DIMS_X)
      y1, y2 = np.random.randint(1, IMG_DIMS_Y), np.random.randint(1, IMG_DIMS_Y)
      thickness = np.random.randint(5, 10)
      cv2.line(mask, (x1, y1), (x2, y2), (0, 0, 0), thickness)
    return self._mask_with_exisiting(img, mask)

  def _mask_with_exisiting(self, img, mask):
    """Apply an existing mask to `img`.

    Args:
      img: image tensor, presumably (C, H, W) in [0, 1] as produced by ToTensor.
      mask: numpy array shaped like the canvas drawn in _create_mask
        ((IMG_DIMS_X, IMG_DIMS_Y, 3)), 0 where pixels should be blanked.
        It is transposed with (2, 1, 0) before indexing the image.

    Returns:
      (masked, mask): `masked` is a uint8 numpy array in the image's (C, H, W)
      order, scaled to 0-255 with blanked pixels set to 255; `mask` is returned as is.
    """
    masked = np.array((img * 255).clamp(0, 255), np.uint8)
    masked[np.transpose(mask, (2, 1, 0)) == 0] = 255
    return masked, mask

  def _getimg(self, index):
    """Return only the transformed image at `index` (the ImageFolder label is dropped)."""
    img, _label = super().__getitem__(index)
    return img

  def __getitem__(self, index):
    img, _label = super().__getitem__(index)

    if self.fixed_mask:
      masked_img, mask = self.masks[index]
    else:
      masked_img, mask = self._create_mask(img)

    # (C, H, W) uint8 -> (H, W, C) before `transform`, which presumably expects
    # channels-last input (as transforms.ToTensor does).
    masked_img = self.transform(np.transpose(masked_img, (1, 2, 0)))

    return img, masked_img, mask


In [None]:
num_in_channels = 3  # colour channels of the images fed to the autoencoder

In [None]:
class ConvAutoencoder(nn.Module):

  """Small convolutional encoder/decoder used for image inpainting.

  Channel flow: num_in_channels -> num_filters -> 2*num_filters (bottleneck)
  -> num_filters -> 3. Two Conv-BN-ReLU-MaxPool(2) stages shrink the spatial size
  by 4x, one Conv-BN-ReLU stage works at the bottleneck, and two
  Conv-BN-ReLU-Upsample(x2) stages restore the size.

  With an odd `kernel`, the output has the input's height and width whenever both
  are divisible by 4. Intermediate activations are kept as self.out1..self.out5
  and self.out_final after each forward pass.
  """

  def __init__(self, kernel, num_filters, num_in_channels):
    # can experiment with different architectures in the future
    super(ConvAutoencoder, self).__init__()

    # Each stage's Conv2d in_channels must equal the previous stage's
    # out_channels, and each BatchNorm2d must match its own Conv2d's output.
    self.down1 = self._stage(num_in_channels, num_filters, kernel, nn.MaxPool2d(2))
    self.down2 = self._stage(num_filters, num_filters * 2, kernel, nn.MaxPool2d(2))

    self.rfconv = self._stage(num_filters * 2, num_filters * 2, kernel)

    self.up1 = self._stage(num_filters * 2, num_filters, kernel, nn.Upsample(scale_factor=2))
    # NOTE(review): BatchNorm + ReLU on the output stage normalises each output
    # channel over the batch and zeroes negatives, while the targets are
    # (presumably) images in [0, 1]; a Sigmoid head may fit better.
    self.up2 = self._stage(num_filters, 3, kernel, nn.Upsample(scale_factor=2))

  @staticmethod
  def _stage(in_channels, out_channels, kernel, resample=None):
    """Conv2d (padding kernel // 2) -> BatchNorm2d -> ReLU, optionally followed by `resample`."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel, padding=kernel // 2),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]
    if resample is not None:
      layers.append(resample)
    return nn.Sequential(*layers)

  def forward(self, x):
    """Run the autoencoder on `x`, indexed (batch, num_in_channels, H, W); returns (batch, 3, H', W')."""
    self.out1 = self.down1(x)
    self.out2 = self.down2(self.out1)
    self.out3 = self.rfconv(self.out2)
    self.out4 = self.up1(self.out3)
    self.out5 = self.up2(self.out4)
    self.out_final = self.out5
    return self.out_final

In [None]:
def run_validation_step(cnn, criterion, val_dloader, batch_size, plotpath=None, visualize=True, downsize_input=False):
    """Evaluate `cnn` on every batch of `val_dloader` and return the mean batch loss.

    Args:
        cnn: model to evaluate; batches are moved to the device of its first
            parameter (CPU if it has none). The caller is responsible for eval().
        criterion: called as criterion(outputs, truth, masked, mask) -> scalar tensor.
        val_dloader: iterable of (ground_truth, masked_input, mask) batches.
        batch_size, downsize_input: accepted for call compatibility; unused.
        plotpath: if truthy (and `visualize` is true), the last batch is
            inpainted and plotted with this path.
        visualize: set False to skip plotting even when `plotpath` is given.

    Returns:
        Mean of the per-batch losses, or nan if `val_dloader` yields no batches.
    """
    try:
        device = next(cnn.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    losses = []
    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for x_gr_truth, x_masked, mask in val_dloader:
            x_gr_truth, x_masked, mask = x_gr_truth.to(device), x_masked.to(device), mask.to(device)
            outputs = cnn(x_masked)
            losses.append(criterion(outputs, x_gr_truth, x_masked, mask).item())

    if not losses:
        return float('nan')

    # Loop variables still hold the last batch here; only it is visualised.
    if plotpath and visualize:
        result = inpaint(x_masked, outputs, mask)
        plot(x_masked, result, x_gr_truth, plotpath, None)

    return np.mean(losses)

In [None]:
def train(args, cnn=None):
  """Train an inpainting ConvAutoencoder on the Drive datasets and return it.

  Args:
    args: config dict. Keys read here: 'gpu', 'kernel', 'num_filters',
      'learn_rate', 'batch_size', 'epochs', 'seed' (optional), 'plot',
      'experiment_name', 'visualize', 'downsize_input', 'checkpoint'.
    cnn: optional model to continue training; a new ConvAutoencoder is built if None.

  Returns:
    The trained model, on the GPU when args['gpu'] is true, otherwise on the CPU.

  Side effects: creates outputs/<experiment_name>, prints per-epoch losses,
  plots samples, and saves cnn.state_dict() to args['checkpoint'] when it is set.
  """
  torch.set_num_threads(5)
  save_dir = os.path.join('outputs', args['experiment_name'])
  # Model and batches share one device, so args['gpu'] == False runs on the CPU.
  device = torch.device('cuda' if args['gpu'] else 'cpu')

  # Seed before the datasets are built: Places draws its fixed masks with np.random,
  # and torch's RNG drives weight init and DataLoader shuffling.
  seed = args.get('seed')
  if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

  num_in_channels = 3
  if cnn is None:
    cnn = ConvAutoencoder(args['kernel'], args['num_filters'], num_in_channels)
  # Move the model before constructing the optimizer, as torch.optim recommends.
  cnn.to(device)
  optimizer = torch.optim.Adam(cnn.parameters(), lr=args['learn_rate'])
  criterion = Modified_MSE()  # holds no per-batch state; reused for training and validation

  # load data
  train_set = Places(data_root, transform=transform, fixed_mask=True)
  # shuffle=True so training batches do not follow ImageFolder's sorted file order.
  train_loader = torch.utils.data.DataLoader(train_set, batch_size=args['batch_size'], shuffle=True)

  val_set = Places(data_root_val, transform=transform, fixed_mask=True)
  val_loader = torch.utils.data.DataLoader(val_set, batch_size=args['batch_size'])

  os.makedirs(save_dir, exist_ok=True)

  print("Training...")
  start = time.time()

  for epoch in range(args['epochs']):
    cnn.train()
    losses = []
    for x_gr_truth, x_masked, mask in train_loader:
      x_gr_truth, x_masked, mask = x_gr_truth.to(device), x_masked.to(device), mask.to(device)

      # Forward + Backward + Optimize
      optimizer.zero_grad()
      outputs = cnn(x_masked)
      loss = criterion(outputs, x_gr_truth, x_masked, mask)
      loss.backward()
      optimizer.step()
      losses.append(loss.item())

    # Plot the epoch's last batch (loop variables still hold it). inpaint() is only
    # needed for this plot, so it runs once here rather than on every batch.
    if args['plot'] and epoch % 3 == 0 and losses:
      result = inpaint(x_masked.detach(), outputs.detach(), mask.detach())
      plot(x_masked.detach(), result, x_gr_truth.detach(),
           os.path.join(save_dir, 'train_%d.png' % epoch))

    avg_loss = np.mean(losses) if losses else float('nan')
    print('Epoch [%d/%d], Loss: %.4f, Time (s): %d' % (
        epoch + 1, args['epochs'], avg_loss, time.time() - start))

    # Evaluate the model
    cnn.eval()
    val_loss = run_validation_step(cnn,
                                   criterion,
                                   val_loader,
                                   args['batch_size'],
                                   os.path.join(save_dir, 'test_%d.png' % epoch),
                                   args['visualize'],
                                   args['downsize_input'])
    print('Epoch [%d/%d], Val Loss: %.4f, Time(s): %d' % (
        epoch + 1, args['epochs'], val_loss, time.time() - start))

  if args['checkpoint']:
    print('Saving model...')
    torch.save(cnn.state_dict(), args['checkpoint'])

  return cnn

In [None]:
# Train a model from scratch with the configuration in `args`; the returned model is kept in `cnn`.
cnn = train(args) 

KeyboardInterrupt: ignored