A simple notebook for testing an SR (super-resolution) model.

# Setup

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import random
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

In [None]:
from facedataset import FaceDataset
from metrics import mse_loss, psnr
from utils import inner_pad, clip

In [None]:
from model import Model

In [None]:
import PIL
from classic_model import Classic

Make torch as deterministic as possible:

In [None]:
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) with one value.
SEED = 714
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Ask cuDNN for deterministic kernels and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Hyperparameters:

In [None]:
# Batching and image-size configuration.
batch_size = 64
# Passed to FaceDataset below; output_size is also the spatial size the
# discriminator is built for. Presumably the low-res / high-res side lengths.
input_size, output_size = 32, 64
# Run on the GPU whenever CUDA is available.
use_gpu = torch.cuda.is_available()

# Base name for the weights file.
name = "convolutional"
# Number of example rows drawn in the preview / comparison figures.
num_examples = 10

In [None]:
# Training split (presumably images 0..55000 of the thumbnail folder -- see FaceDataset).
train_set = FaceDataset("data/thumbnails128x128", 0, 55000, input_size, output_size)
# Regular loader used for training the upscaling model.
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=batch_size)
# Double-sized batches; incomplete final batches are dropped so each batch
# can be split into two halves of exactly `batch_size` samples.
train2_loader = DataLoader(
    dataset=train_set,
    shuffle=True,
    batch_size=2 * batch_size,
    drop_last=True,
)

In [None]:
# Validation split (presumably images 55000..65000 -- see FaceDataset).
val_set = FaceDataset("data/thumbnails128x128", 55000, 65000, input_size, output_size)
# Shuffled so that the example figures further down show a different sample each run.
val_loader = DataLoader(dataset=val_set, shuffle=True, batch_size=batch_size)

Ignore the test set (images 65,000–70,000) for now.

## Example Batch


In [None]:
# Preview the first `num_examples` (low-res, high-res) pairs of one training batch.
fig, ax = plt.subplots(num_examples, 2)
for lower, higher in train_loader:
    for row in range(num_examples):
        # Column 0: low-resolution input, column 1: high-resolution target.
        for col, batch in enumerate((lower, higher)):
            panel = ax[row][col]
            # Tensors are channel-first; imshow expects channel-last.
            panel.imshow(batch[row].detach().permute(1, 2, 0))
            panel.axis('off')
    # Only the first batch is needed.
    break
fig.set_size_inches(6, 3 * num_examples)
plt.tight_layout()
plt.show()

# Training

In [None]:
class Discriminator(nn.Module):
    """Small CNN that scores whether an image is real or upscaled.

    The network returns one raw logit per image (no sigmoid is applied
    here). Label convention used by the training code in this file:
    0 for real data, 1 for upscaled data.

    Args:
        input_size: Side length of the square, 3-channel images passed to
            ``forward``. It fixes the width of the first linear layer.
    """

    def __init__(self, input_size):
        super().__init__()

        self.input_size = input_size

        # `padding=1` with a 3x3 kernel keeps the spatial size unchanged.
        self.convs = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            # Disabled deeper variant:
            # nn.Conv2d(64, 32, 3, padding=1),
            # nn.ReLU(),
            # nn.Conv2d(32, 32, 3, padding=1),
            # nn.ReLU(),
        )
        self.fcs = nn.Sequential(
            nn.Linear(32*input_size**2, 100),
            nn.ReLU(),
            # Disabled extra hidden layer:
            # nn.Linear(100, 100),
            # nn.ReLU(),
            nn.Linear(100, 1),
        )
        # -> 0 for real data
        # -> 1 for upscaled data

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of logits for the image batch ``x``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not match the
                ``input_size`` the network was built for.
        """
        out = self.convs(x)
        # Flatten per sample. A `reshape(-1, features)` would silently merge
        # or split samples when the spatial size is wrong; keeping the batch
        # dimension fixed makes the linear layer raise a shape error instead.
        out = out.flatten(1)
        out = self.fcs(out)

        return out

In [None]:
def train_model(model, optimizer, freq):
    """Run one pass over ``train_loader``, minimising ``mse_loss``.

    Args:
        model: Network mapping an input batch to an upscaled batch.
        optimizer: Optimizer that updates ``model``'s parameters.
        freq: Every ``freq`` batches (including the very first one) the mean
            of the losses collected since the previous report is printed.
    """
    model.train()

    recent_losses = []

    batches = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (img, target) in batches:
        if use_gpu:
            img, target = img.cuda(), target.cuda()

        optimizer.zero_grad()
        loss = mse_loss(model(img), target)
        loss.backward()
        optimizer.step()

        # Detach so the stored loss does not keep the autograd graph alive.
        recent_losses.append(loss.detach())

        if step % freq != 0:
            continue

        # Report the running mean, then start a fresh window.
        window = torch.tensor(recent_losses)
        print(f"{torch.mean(window):.6f}")
        recent_losses = []

In [None]:
def train_discriminator(model, discriminator, optimizer, freq):
    """Run one pass over ``train2_loader``, training only the discriminator.

    Each double-sized batch is split in two: the first ``batch_size`` inputs
    are upscaled by ``model`` (label 1), and the last ``batch_size`` targets
    are used as real images (label 0).

    Args:
        model: Upscaling network; put in eval mode and run without gradients.
        discriminator: Network returning one raw logit per image.
        optimizer: Optimizer that updates ``discriminator``'s parameters.
        freq: Every ``freq`` batches (including the first) the mean accuracy
            since the last report, the current batch's per-class accuracy
            and the two loss terms are printed.
    """
    model.eval()
    discriminator.train()

    # The discriminator outputs logits; this loss applies the sigmoid
    # internally in a numerically stable way.
    criterion = nn.BCEWithLogitsLoss()

    accs = []

    for k, (img, target) in tqdm(enumerate(train2_loader), total=len(train2_loader)):
        if use_gpu:
            img = img.cuda()
            target = target.cuda()

        with torch.no_grad():
            upscaled = model(img[:batch_size])
        orig = target[batch_size:]

        optimizer.zero_grad()

        # Build targets with *_like so they share device and dtype with the
        # predictions (CPU-only targets break as soon as use_gpu is True).
        pred_upscaled = discriminator(upscaled)
        l1 = criterion(pred_upscaled, torch.ones_like(pred_upscaled))

        pred_orig = discriminator(orig)
        l2 = criterion(pred_orig, torch.zeros_like(pred_orig))

        loss = l1 + l2
        loss.backward()
        optimizer.step()

        # Predictions are logits: probability > 0.5 is equivalent to logit > 0.
        hits_upscaled = torch.sum(pred_upscaled.detach() > 0).float()
        hits_orig = torch.sum(pred_orig.detach() <= 0).float()

        accs.append((hits_upscaled + hits_orig)/(2*batch_size))
        if k % freq == 0:
            accs = torch.tensor(accs)
            print(f"{torch.mean(accs):.6f}")
            print(f"Upscale acc: {hits_upscaled/batch_size}", end=" ")
            print(f"Orig acc: {hits_orig/batch_size}")
            print(f"l1: {l1}, l2: {l2}")
            accs = []

In [None]:
def save_model(model):
    """Save ``model``'s state dict to ``weights/<name>.ckpt``.

    ``name`` is the module-level hyperparameter naming the weights file.
    The ``weights`` directory is created if it does not exist yet.
    """
    import os

    os.makedirs("weights", exist_ok=True)
    # f-string: without the prefix every run wrote the literal file
    # 'weights/{name}.ckpt' regardless of `name`.
    torch.save(model.state_dict(), f'weights/{name}.ckpt')

In [None]:
# Build the upscaling network and place it on the GPU when one is available.
model = Model()
model = model.cuda() if use_gpu else model

In [None]:
# The discriminator judges images at the upscaled (output) resolution.
discriminator = Discriminator(output_size)
if use_gpu:
    # Move the discriminator itself to the GPU. Moving `model` again here
    # would leave the discriminator on the CPU while it is fed CUDA tensors.
    discriminator = discriminator.cuda()

## Training program

In [None]:
# One pass of MSE training with Adam (lr 1e-4), reporting every 10 batches.
warmup_opt = torch.optim.Adam(model.parameters(), 1e-4)
train_model(model, warmup_opt, 10)

In [None]:
# One pass of discriminator training with Adam (lr 1e-4), reporting every 15 batches.
disc_opt = torch.optim.Adam(discriminator.parameters(), 1e-4)
train_discriminator(model, discriminator, disc_opt, 15)

In [None]:
# Schedule: four passes at lr 1e-4, then two passes at lr 1e-5.
# A fresh Adam optimizer is created for each learning-rate stage.
for stage_lr, stage_passes in ((1e-4, 4), (1e-5, 2)):
    opt = torch.optim.Adam(model.parameters(), stage_lr)
    for stage_pass in range(stage_passes):
        train_model(model, opt, 80)

# Evaluation

In [None]:
# Compute PSNR over the whole validation set and report mean and std.
model.eval()

with torch.no_grad():
    val_psnr = []

    for img, target in tqdm(val_loader):
        if use_gpu:
            img, target = img.cuda(), target.cuda()
        out = model(img)
        batch_psnr = psnr(out, target)
        val_psnr.append(batch_psnr)

    # Join the per-batch results into a single tensor before aggregating.
    val_psnr = torch.cat(val_psnr)
    psnr_mean = torch.mean(val_psnr)
    psnr_std = torch.std(val_psnr)
    print(f"Mean PSNR {psnr_mean:.2f} ± {psnr_std:.2f}")

Upscaled output side by side with input and original:

In [None]:
# Show input, model output and ground truth next to each other for a few
# validation samples.
model.eval()

with torch.no_grad():
    it = iter(val_loader)
    img, target = next(it)
    if use_gpu:
        img = img.cuda()
        target = target.cuda()
    # Clip the prediction and bring everything back to the CPU for plotting.
    out = clip(model(img)).cpu()
    img = img.cpu()
    target = target.cpu()

    columns = (("Input", img), ("Upscaled", out), ("Original", target))

    fig, ax = plt.subplots(num_examples, 3)
    for row in range(num_examples):
        for col, (label, batch) in enumerate(columns):
            # Channel-first tensor -> channel-last image for imshow.
            ax[row][col].imshow(batch[row].permute(1, 2, 0))
            ax[row][col].axis('off')
    fig.set_size_inches(12, 4 * num_examples + 1)

    # Column headers go on the first row only.
    for col, (label, batch) in enumerate(columns):
        ax[0][col].set_title(label)

    plt.show()

Multiple upscalings in succession:

In [None]:
# Apply the model repeatedly to its own output and plot every stage.
model.eval()

with torch.no_grad():
    steps = 4
    fig, ax = plt.subplots(2, steps)

    # Take the first sample of one validation batch.
    img, target = next(iter(val_loader))
    img, target = img[0], target[0]

    # Top row: input on the far left, original on the far right.
    for col, picture in ((0, img), (-1, target)):
        ax[0][col].imshow(picture.permute(1, 2, 0))
        ax[0][col].axis('off')
    # Remove the unused axes in between.
    for i in range(1, steps - 1):
        fig.delaxes(ax[0][i])

    if use_gpu:
        img, target = img.cuda(), target.cuda()

    # Bottom row: result after 1, 2, ..., `steps` successive upscalings.
    out = img.unsqueeze(0)  # add the batch dimension the model expects
    for k in range(steps):
        out = clip(model(out))
        stage = out[0].permute(1, 2, 0).cpu()
        ax[1][k].imshow(stage)
        ax[1][k].axis('off')

    fig.set_size_inches(4 * steps, 8)
    plt.show()