In [1]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-365c34a9-7a2d-8d61-7c27-a5f55720056c)


In [2]:
import numpy as np
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import glob
import cv2
import random
from google.colab.patches import cv2_imshow
import os
from torch.utils.tensorboard import SummaryWriter
import time

# Deblurring Model

In [17]:
class DeblurringModule(nn.Module):
  def __init__(self, kernel_size=3):
    assert kernel_size%2 == 1, 'kernel size must be odd' 
    super(DeblurringModule, self).__init__()

    padding = kernel_size//2
    relu = nn.ReLU(inplace=True)
    conv_in = nn.Conv2d(3, 64, kernel_size, stride=1, padding=padding,bias=False)
    conv_out = nn.Conv2d(64, 3, kernel_size, stride=1, padding=padding, bias=False)
    conv_mid = nn.Conv2d(64, 64, kernel_size, stride=1, padding=padding, bias=False)

    layers = []
    layers.append(conv_in)
    layers.append(relu)
    for i in range(10):
      layers.append(conv_mid)
      layers.append(relu)
    layers.append(conv_out)
    
    self.model = nn.Sequential(*layers)

  def forward(self, img):
    out = self.model(img)
    final_out = torch.add(out,img)
    return final_out


In [4]:
class DeblurringDataset(Dataset):
  def __init__(self, data_dir='/content/drive/MyDrive/Colab Datasets/MEAD_video/M003',batchsize=32):
    self.files = glob.glob(data_dir+'/**/*.mp4',recursive=True)
    self.batchsize = batchsize
    
  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):

    f = self.files[idx]
    vid = cv2.VideoCapture(f)
    num_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
    start_frame = random.randint(0,num_frames-self.batchsize-1)
    vid.set(1,start_frame)
    frames_blur = np.empty((self.batchsize,256,256,3))
    frames_tgt = np.empty((self.batchsize,256,256,3))
    for i in range(self.batchsize):
      ret, frame = vid.read()
      frame_tgt = self.crop_and_downsample(frame)
      im = Image.fromarray(frame_tgt)
      im = im.filter(ImageFilter.BoxBlur(1.5))
      frame_blur = np.array(im)
      frames_blur[i] = frame_blur
      frames_tgt[i] = frame_tgt

    frames_blur = (np.swapaxes(frames_blur,1,3)/255.).astype('float32')
    frames_tgt = (np.swapaxes(frames_tgt,1,3)/255.).astype('float32')

    return torch.tensor(frames_blur).cuda(), torch.tensor(frames_tgt).cuda()

  def crop_and_downsample(self, img, img_dim=256):
    h, w, c = img.shape
    crop_width = (w - h) // 2
    img_crop = img[0:h,crop_width:w-crop_width]
    img_resize = cv2.resize(img_crop,dsize=(img_dim,img_dim))

    return img_resize

In [5]:
def cat_images(blurry,target,predictions):
  i = random.randint(0,len(blurry)-1)
  grid = torch.cat([blurry[i],predictions[i],target[i]],axis=1)
  grid = grid.detach().cpu().numpy()
  grid = np.swapaxes(grid,0,2)
  grid = grid * 255
  return grid

  
def init_model(model, state_file=None):
  if state_file is None:
    for param in model.parameters():
        if param.dim()>1:
            torch.nn.init.xavier_uniform_(param)
  else:
    ckpt = torch.load(state_file)
    model.load_state_dict(ckpt)

In [32]:
# initialise writer and paths
experiment_name = 'L1_corrected'
writer = SummaryWriter(f'/content/drive/MyDrive/Colab Notebooks/talkingheads/runs/deblurring/{experiment_name}')
PATH_model_base = f'/content/drive/MyDrive/Colab Notebooks/talkingheads/models/deblurring/{experiment_name}/'
PATH_image_base = f'/content/drive/MyDrive/Colab Notebooks/talkingheads/Deblurring Images/{experiment_name}/'
os.mkdir(PATH_model_base)
os.mkdir(PATH_image_base)

In [33]:
# initialise model, data, optimizer and loss
model = DeblurringModule()
init_model(model)
model.to('cuda')
dataset = DeblurringDataset(batchsize=16)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.5, 0.999))

In [None]:
# run dimensions
start_epoch = 1
training_epochs = 50
model_save_freq = 1
img_save_freq = 1
eval_freq = 5
train_split=0.8
random.seed(0)
indices = list(range(len(dataset)))
train_indices = random.sample(indices,int(len(indices)*train_split))
test_indices = [idx for idx in indices if idx not in train_indices]

#epoch loop
for epoch in range(start_epoch, start_epoch + training_epochs):
    
    model.train()
    epoch_start = time.time()
    training_loss = 0.
    test_loss = 0.

    random.shuffle(train_indices)
    
    #training
    for idx in train_indices:

        #get targets and predictions
        blurry, target = dataset[idx]
        predictions = model(blurry)

        # calculate losses
        loss = criterion(predictions, target)

        #training step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #update loss
        training_loss += loss.item()

    #testing
    if epoch % eval_freq == 0:
        with torch.no_grad():
            model.eval()
            for idx in test_indices:
                blurry_test, target_test = dataset[idx]
                predictions_test = model(blurry_test)
                loss = criterion(predictions_test, target_test)
                test_loss += loss.item()
        writer.add_scalar('test loss', test_loss / len(test_indices), epoch)

    #log epoch metrics
    writer.add_scalar('training loss', training_loss / len(train_indices), epoch)
    

    #checkpoint model
    if epoch % model_save_freq == 0:
        PATH_model = PATH_model_base + f'epoch_{epoch}.pth'
        torch.save(model.state_dict(),PATH_model)

    #save image examples
    if epoch % img_save_freq == 0:
        img_grid = cat_images(blurry,target,predictions)
        filename = f'{PATH_image_base}epoch_{epoch:0>3d}.jpg'
        cv2.imwrite(filename,img_grid)

    #print metrics
    print(f"""end of epoch {epoch}: 
    training loss = {training_loss / len(train_indices)}
    test loss = {(test_loss / len(test_indices)) if epoch%eval_freq==0 else "N/A"} 
    epoch time taken = {int(time.time()-epoch_start)} s""")

end of epoch 1: 
    training loss = 0.010642534519672094
    test loss = N/A 
    epoch time taken = 316 s
