# Mathematical Foundations of Computer Graphics and Vision 2022
## EXERCISE 6 - DEEP LEARNING

In [4]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.io import read_image
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm import tqdm

## 1.1. Task 1 - Datasets, Preprocessing and Data loading

In [5]:
class SRDataset(Dataset):
    """Define training/valid dataset loading methods.
    Args:
        image_dir (str): Train/Valid dataset address.
        image_size (int): High resolution image size.
        upscale_factor (int): Image up scale factor.
        mode (str): Data set loading method, the training data set is for data enhancement, and the
            verification dataset is not for data enhancement.
    """

    def __init__(self, image_dir, image_size = 64, upscale_factor = 2, jitter_val = 0.2, mode = 'Train') -> None:
        self.image_file_names = [os.path.join(image_dir, image_file_name) for image_file_name in os.listdir(image_dir)]
        self.image_size = image_size
        self.upscale_factor = upscale_factor
        self.jitter_val = jitter_val
        # Load training dataset or test dataset
        self.mode = mode

    def __getitem__(self, index):
        image = read_image(self.image_file_names[index]).float()/255.
        if self.mode == 'Train':
            hr_transformer = transforms.Compose([
                transforms.RandomCrop(self.image_size),
                transforms.ColorJitter(brightness=self.jitter_val, contrast=self.jitter_val, saturation=self.jitter_val, hue=self.jitter_val),
            ])  
        else:
            hr_transformer = transforms.Compose([
                transforms.CenterCrop(self.image_size),
                transforms.ColorJitter(brightness=self.jitter_val, contrast=self.jitter_val, saturation=self.jitter_val, hue=self.jitter_val),
            ])
        
        lr_transformer = transforms.Compose([
            transforms.Resize(size=(int(self.image_size / self.upscale_factor), int(self.image_size / self.upscale_factor))),
        ])
        hr_image = hr_transformer(image)
        lr_image = lr_transformer(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_file_names)

In [6]:
train_path = './train'
test_path = './eval'
train_dataset = SRDataset(train_path)
test_dataset = SRDataset(test_path)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    pin_memory=True,
)
print(f" * Dataset contains {len(train_dataset)} image(s).")
for _, batch in enumerate(train_dataloader, 0):
    lr_image, hr_image = batch
    torchvision.io.write_png(lr_image[0, ...].mul(255).byte(), "lr_image.png")
    torchvision.io.write_png(hr_image[0, ...].mul(255).byte(), "hr_image.png")
    break # we deliberately break after one batch as this is just a test

 * Dataset contains 301 image(s).


In [7]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

In [8]:
class BasicSRModel(nn.Module):
    def __init__(self, upscale_factor = 2, layers = 10) :
        super(BasicSRModel, self).__init__()
        self.up_sample = nn.Upsample(scale_factor=upscale_factor, mode='bilinear')
        self.conv_first = nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1))
        self.conv_last = nn.Conv2d(64, 3, (3, 3), (1, 1), (1, 1))
        modules = []
        for i in range(layers):
            modules.append(nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)))
            modules.append(nn.LeakyReLU())
        self.conv_middle = nn.Sequential(*modules)

    def forward(self, x):
        x_up = self.up_sample(x)
        out1 = self.conv_first(x_up)
        out2 = self.conv_middle(out1)
        out = self.conv_last(out2)
        return out

In [15]:
# learning_rate [1e-2, 1e-3, 1e-5, 1e-6]
learning_rate = 1e-4
save_iterval = 200
number_of_epochs = 1000
save_dir = './models'

model = BasicSRModel()
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=learning_rate)
num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(num_params)

372803


In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cuda device


In [17]:
model = model.to(device)
loss_function = loss_function.to(device)

In [18]:
for epoch in tqdm(range(number_of_epochs)):
    for _, batch in enumerate(train_dataloader):
        low_res, high_res = batch
        low_res = low_res.to(device)
        high_res = high_res.to(device)
        optimizer.zero_grad()
        high_res_prediction = model(low_res)
        loss = loss_function(high_res_prediction, high_res)
        loss.backward()
        optimizer.step()
    if (epoch + 1) % save_iterval or (epoch + 1) == number_of_epochs:
        
        

 36%|████████████████████████████                                                  | 36/100 [03:50<06:50,  6.42s/it]


KeyboardInterrupt: 

In [22]:
save_dir = './models'
torch.save(model.state_dict(), '{:s}/model_iter_{:d}_lr_{:f}.pth'.format(save_dir,
                                                           20, learning_rate))

In [23]:
model = BasicSRModel()
self_attn.load_state_dict(torch.load('./model/model_iter_20_lr0.000100.pth'))

NameError: name 'SelfAttention' is not defined