In [1]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-0cd23c16-d7c0-07b4-0275-9235c4199c6d)


In [2]:
import numpy as np
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import glob
import cv2
import random
from google.colab.patches import cv2_imshow
import os
from torch.utils.tensorboard import SummaryWriter
import time

# Deblurring Model

In [3]:

class DeblurringModule(nn.Module):
  def __init__(self):

    super(DeblurringModule, self).__init__()

    relu = nn.ReLU(inplace=True)
    conv_in = nn.Conv2d(3,64,3,1,1,bias=False)
    conv_out = nn.Conv2d(64,3,3,1,1,bias=False)
    conv_mid = nn.Conv2d(64,64,3,1,1,bias=False)

    layers = []
    layers.append(conv_in)
    layers.append(relu)
    for i in range(10):
      layers.append(conv_mid)
      layers.append(relu)
    layers.append(conv_out)
    
    self.model = nn.Sequential(*layers)

  def forward(self, img):
    out = self.model(img)
    final_out = torch.add(out,img)
    return final_out

def init_model(model, state_file=None):
  if state_file is None:
    for param in model.parameters():
        if param.dim()>1:
            torch.nn.init.xavier_uniform_(param)
  else:
    ckpt = torch.load(state_file)
    model.load_state_dict(ckpt)



In [4]:

class DeblurringDataset(Dataset):
  def __init__(self, data_dir='/content/drive/MyDrive/Colab Datasets/MEAD_video/M003',batchsize=32):
    self.files = glob.glob(data_dir+'/**/*.mp4',recursive=True)
    self.batchsize = batchsize
    
  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):

    f = self.files[idx]
    vid = cv2.VideoCapture(f)
    num_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
    start_frame = random.randint(0,num_frames-self.batchsize-1)
    vid.set(1,start_frame)
    frames_blur = np.empty((self.batchsize,256,256,3))
    frames_tgt = np.empty((self.batchsize,256,256,3))
    for i in range(self.batchsize):
      ret, frame = vid.read()
      frame_tgt = self.crop_and_downsample(frame)
      im = Image.fromarray(frame_tgt)
      im = im.filter(ImageFilter.BoxBlur(1.5))
      frame_blur = np.array(im)
      frames_blur[i] = frame_blur
      frames_tgt[i] = frame_tgt

    frames_blur = (np.swapaxes(frames_blur,1,3)/255.).astype('float32')
    frames_tgt = (np.swapaxes(frames_tgt,1,3)/255.).astype('float32')

    return torch.tensor(frames_blur).cuda(), torch.tensor(frames_tgt).cuda()

  def crop_and_downsample(self, img, img_dim=256):
    h, w, c = img.shape
    crop_width = (w - h) // 2
    img_crop = img[0:h,crop_width:w-crop_width]
    img_resize = cv2.resize(img_crop,dsize=(img_dim,img_dim))

    return img_resize

In [None]:
def write_training_images(img_dir, epoch, base_img, tgt_img, full_stack, output):
    images = []
    for baseimg, tgtimg, fullstack, tgtemotion, out in zip(base_img, tgt_img, full_stack, tgt_emotion, output):
        baseimg = baseimg.numpy()
        tgtimg = np.swapaxes(tgtimg.detach().cpu().numpy(),0,2) * 255
        lmks = np.swapaxes(fullstack.detach().cpu().numpy()[8:11],0,2) * 255
        out = np.swapaxes(out.detach().cpu().numpy(),0,2) * 255
        concat = np.concatenate((baseimg,tgtimg,lmks,out),axis=1)
        cv2.putText(concat, tgtemotion, (256*2,30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,0), 1, cv2.LINE_AA)
        images.append(concat)
    
    img_grid = np.concatenate(images,axis=0)
    filename = f'{img_dir}epoch_{epoch:0>3d}.jpg'
    cv2.imwrite(filename,img_grid)

In [6]:
# initialise writer and paths
experiment_name = 'exp_1_run_1'
writer = SummaryWriter(f'/content/drive/MyDrive/runs/{experiment_name}')
PATH_model_base = f'/content/drive/MyDrive/models/{experiment_name}/'
PATH_image_base = f'/content/drive/MyDrive/images/{experiment_name}/'
os.mkdir(PATH_model_base)
os.mkdir(PATH_image_base)

In [9]:
# initialise model, data, optimizer and loss
model = DeblurringModule()
init_model(model)
model.to('cuda')
dataset = DeblurringDataset('/content/drive/MyDrive/MEAD_video/M003', batchsize=16)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [None]:
# run dimensions
start_epoch = 0
training_epochs = 100
model_save_freq = 10
img_save_freq = 5
train_split=0.8
random.seed(0)
indices = list(range(len(dataset)))
train_indices = random.sample(indices,int(len(indices)*train_split))
test_indices = [idx for idx in indices if idx not in train_indices]

#epoch loop
for epoch in range(start_epoch, start_epoch + training_epochs):
    
    model.train()
    epoch_start = time.time()
    training_loss = 0.
    test_loss = 0.

    random.shuffle(train_indices)
    
    #training
    for idx in train_indices:

        #get targets and predictions
        blurry, target = dataset[idx]
        predictions = model(blurry)

        # calculate losses
        loss = criterion(predictions, blurry)

        #training step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #update loss
        training_loss += loss.item()

    #testing
    for idx in test_indices:
      with torch.no_grad():

        model.eval()
        blurry_test, target_test = dataset[idx]
        predictions_test = model(blurry_test)
        loss = criterion(predictions_test, blurry_test)
        test_loss += loss.item()

    #log epoch metrics
    writer.add_scalar('training loss', training_loss / len(dataset), epoch)
    writer.add_scalar('test loss', test_loss / len(dataset), epoch)

    #checkpoint model
    if epoch % model_save_freq == 0:
        PATH_model = PATH_model_base + f'epoch_{epoch}.pth'
        torch.save(model.state_dict(),PATH_model)

    #save image examples
    #if epoch % img_save_freq == 0:
    #    write_training_images(PATH_image_base, epoch, base_img, tgt_img, full_stack, tgt_emotion, output)

    #print metrics
    print(f"""end of epoch {epoch}: 
    training loss = {training_loss / len(dataset)}
    test loss = {test_loss / len(dataset)} 
    epoch time taken = {int(time.time()-epoch_start)} s""")

end of epoch 0: 
    training loss = 1.7999643450204008e-09
    test loss = 1.1541443158164122e-10 
    epoch time taken = 189 s
end of epoch 1: 
    training loss = 3.388697979072773e-10
    test loss = 6.487672759567511e-11 
    epoch time taken = 195 s
end of epoch 2: 
    training loss = 2.1241406798370805e-10
    test loss = 4.490683784223809e-11 
    epoch time taken = 198 s
end of epoch 3: 
    training loss = 1.5381860995765702e-10
    test loss = 3.406206441401097e-11 
    epoch time taken = 194 s
end of epoch 4: 
    training loss = 1.2121819328089454e-10
    test loss = 2.7976612938389083e-11 
    epoch time taken = 194 s
end of epoch 5: 
    training loss = 1.0107799425653237e-10
    test loss = 2.3717383562459644e-11 
    epoch time taken = 195 s
end of epoch 6: 
    training loss = 8.649032850915448e-11
    test loss = 2.0545216353719963e-11 
    epoch time taken = 196 s
end of epoch 7: 
    training loss = 7.597631083503813e-11
    test loss = 1.823407765139452e-11 
    