A simple notebook for training and testing a super-resolution (SR) model.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

In [None]:
from facedataset import FaceDataset
from metrics import mse_loss
from utils import inner_pad

In [None]:
# Number of samples per DataLoader batch; the preview grids below also draw
# one row per sample, so this doubles as their row count.
batch_size: int = 32

In [None]:
# Training split. The numeric arguments are passed straight to FaceDataset;
# presumably (start, end, low_res, high_res) — confirm against facedataset.py.
train_set = FaceDataset(
    "data/thumbnails128x128",
    0,
    55000,
    32,
    64,
)
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)

In [None]:
# Preview one training batch: input image (left column) next to its target
# (right column).
# squeeze=False keeps `ax` two-dimensional even when batch_size == 1,
# so the ax[i][j] indexing below always works.
fig, ax = plt.subplots(batch_size, 2, squeeze=False)
lower, higher = next(iter(train_loader))
# A batch can be shorter than batch_size (e.g. a tiny dataset), so only
# index samples that actually exist.
for i in range(min(batch_size, len(lower))):
    # Tensors are (C, H, W); imshow expects (H, W, C).
    ax[i][0].imshow(lower[i].permute(1, 2, 0))
    ax[i][1].imshow(higher[i].permute(1, 2, 0))
    ax[i][0].axis('off')
    ax[i][1].axis('off')
fig.set_size_inches(6, 3*batch_size)
plt.tight_layout()
plt.show()

In [None]:
class Model(nn.Module):
    """Small fully-convolutional super-resolution network.

    The input is shifted by -0.5 and passed through ``inner_pad(..., 2)``
    (imported from ``utils``; presumably this enlarges the image by a
    factor of 2 — confirm) and then through four convolutions with ReLU
    between them. Every convolution uses "same" padding (padding equals
    half the odd kernel size), so they do not change the spatial size.
    The final layer produces 3 channels and has no activation.
    """

    def __init__(self):
        super().__init__()

        # Modules are registered in the original order so that
        # state_dict keys (conv1..conv4) stay unchanged.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 64, 9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, 5, padding=2)
        self.conv3 = nn.Conv2d(32, 32, 5, padding=2)
        self.conv4 = nn.Conv2d(32, 3, 5, padding=2)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the 3-channel output."""
        hidden = inner_pad(x - 0.5, 2)
        # The first three convolutions are each followed by a ReLU.
        for layer in (self.conv1, self.conv2, self.conv3):
            hidden = self.relu(layer(hidden))
        return self.conv4(hidden)

In [None]:
# Fresh model plus an Adam optimizer over all of its parameters.
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Same optimizer construction as the previous cell, but without rebuilding the
# model: re-running only this cell gives a fresh Adam optimizer (its internal
# state starts over) while the current model weights are kept.
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

In [None]:
# One pass over the training data. Every 25 batches, print the mean loss of
# the batches seen since the last print.
log_every = 25
lloss = []

# The model has no dropout/batch-norm today, but be explicit in case a later
# cell switched it to eval mode.
model.train()
for k, (img, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
    optimizer.zero_grad()
    out = model(img)
    loss = mse_loss(out, target)
    loss.backward()
    optimizer.step()
    # .item() stores a plain float, so np.mean below works on numbers
    # rather than on a list of tensors.
    lloss.append(loss.item())
    if k % log_every == 0:
        # MSE on images is usually well below 0.01, so two decimals would
        # print 0.00; show enough digits for the trend to be visible.
        print(f"step {k}: {np.mean(lloss):.5f}")
        lloss = []

In [None]:
# Validation split, taken from the range after the training split. The numeric
# arguments are passed straight to FaceDataset; presumably
# (start, end, low_res, high_res) — confirm against facedataset.py.
val_set = FaceDataset(
    "data/thumbnails128x128",
    55000,
    65000,
    32,
    64,
)
val_loader = DataLoader(val_set, shuffle=True, batch_size=batch_size)

In [None]:
# Predict on one validation batch. `it` is kept so later cells can pull more
# batches from the same iterator.
model.eval()
it = iter(val_loader)
img, target = next(it)
# No gradients are needed for inference; skipping the autograd graph saves
# memory. clamp(0, 1) keeps the prediction in the displayable range.
with torch.no_grad():
    out = model(img).clamp(0, 1)

In [None]:
# Grid with one row per validation sample: input | model output | target.
# squeeze=False keeps `ax` two-dimensional even when batch_size == 1.
fig, ax = plt.subplots(batch_size, 3, squeeze=False)
# Only draw samples that are actually present in the batch.
for i in range(min(batch_size, len(img))):
    for j, image in enumerate((img[i], out[i], target[i])):
        # Tensors are (C, H, W); imshow expects (H, W, C).
        ax[i][j].imshow(image.permute(1, 2, 0))
        ax[i][j].axis('off')
fig.set_size_inches(9, 3*batch_size)
plt.tight_layout()
plt.show()

In [None]:
# Apply the model repeatedly to one validation image and show every step:
# input | step 1 .. step n_steps | target.
# This cell previously crashed the notebook. The earlier version ran with
# autograd enabled, so the activations of every pass were kept alive. If each
# pass enlarges the image (inner_pad(..., 2) suggests 2x — confirm), memory
# grows quickly, so no_grad() matters here.
n_steps = 6
fig, ax = plt.subplots(1, n_steps + 2)
img, target = next(it)
img, target = img[0], target[0]
ax[0].imshow(img.permute(1, 2, 0))
ax[-1].imshow(target.permute(1, 2, 0))
out = img.unsqueeze(0)  # add a batch dimension for the model
model.eval()
with torch.no_grad():
    for k in range(1, n_steps + 1):
        # Clamp to [0, 1] so each step's input is a valid image.
        out = model(out).clamp(0, 1)
        ax[k].imshow(out[0].permute(1, 2, 0))
plt.show()