In [18]:
import glob
import random

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
# NOTE(review): later cells also use `v2` (torchvision.transforms), `Image` (PIL),
# `tqdm`, `PeakSignalNoiseRatio` and `StructuralSimilarityIndexMeasure`
# (torchmetrics), none of which are imported anywhere in this notebook.

In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Root of the DIV2K download; change this one constant if the data lives elsewhere.
DATA_DIR = "./data"


def div2k_paths(split, scale=None):
    """Return the sorted PNG paths of one DIV2K folder.

    Args:
        split: "train" or "valid".
        scale: None for the high-resolution images, otherwise the bicubic
            downscaling factor of the low-resolution folder (e.g. 2 for X2).

    Returns:
        Sorted list of matching paths (empty if the folder is missing). Sorting
        keeps the HR and LR lists aligned by index.
    """
    if scale is None:
        pattern = f"{DATA_DIR}/DIV2K_{split}_HR/*.png"
    else:
        pattern = f"{DATA_DIR}/DIV2K_{split}_LR_bicubic/X{scale}/*.png"
    return sorted(glob.glob(pattern))


HR_train_paths = div2k_paths("train")
X2_train_paths = div2k_paths("train", 2)
X4_train_paths = div2k_paths("train", 4)
X8_train_paths = div2k_paths("train", 8)
X16_train_paths = div2k_paths("train", 16)
X32_train_paths = div2k_paths("train", 32)
X64_train_paths = div2k_paths("train", 64)

HR_valid_paths = div2k_paths("valid")
X2_valid_paths = div2k_paths("valid", 2)
X4_valid_paths = div2k_paths("valid", 4)
X8_valid_paths = div2k_paths("valid", 8)
X16_valid_paths = div2k_paths("valid", 16)
X32_valid_paths = div2k_paths("valid", 32)
X64_valid_paths = div2k_paths("valid", 64)

### FSRCNN - Fast Super-Resolution Convolutional Neural Network
The architecture below is described in this paper: [Accelerating the Super-Resolution Convolutional Neural Network](https://arxiv.org/abs/1608.00367)

I derived the `padding` and `output_padding` of the transposed convolution from its output-size formula (see the [PyTorch docs](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)):
$$
H_{\text{out}} = (H_{\text{in}} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
$$

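With stride $n$, kernel size $9$ and dilation $1$, requiring $H_{\text{out}} = n \cdot H_{\text{in}}$ reduces this to $2 \times \text{padding} - \text{output\_padding} = 9 - n$. Choosing `padding` and `output_padding` to satisfy it makes the output image exactly $n$ times larger than the input.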

In [7]:
class FSRCNN(nn.Module):
    """FSRCNN super-resolution network for 3-channel images.

    Layout: feature extraction (5x5) -> shrinking (1x1) -> ``m`` mapping layers (3x3)
    -> expanding (1x1), each a 'same'-padded convolution followed by PReLU, then a
    9x9 transposed convolution with stride ``n`` that upscales the image by exactly
    ``n`` in height and width.
    """

    # Kernel size of the final transposed convolution.
    DECONV_KERNEL_SIZE = 9

    def __init__(self, d: int, s: int, m: int, n: int):
        """
        Args:
            d: feature dimension
            s: shrinking dimension
            m: mapping layers
            n: scaling factor (positive integer); the output is exactly n times
                the input height and width

        Raises:
            ValueError: if n < 1
        """
        super().__init__()
        if n < 1:
            raise ValueError(f"scaling factor n must be a positive integer, got {n}")

        self.model = nn.Sequential(
            self._conv(3, d, 5),  # feature extraction
            self._conv(d, s, 1),  # shrinking
        )

        for _ in range(m):
            self.model.append(self._conv(s, s, 3))  # mapping

        self.model.append(self._conv(s, d, 1))  # expanding

        # Ensure the output image is exactly n times bigger than the input
        padding, output_padding = self._deconv_padding(n, self.DECONV_KERNEL_SIZE)
        self.model.append(
            nn.ConvTranspose2d(
                d, 3, self.DECONV_KERNEL_SIZE,
                stride=n, padding=padding, output_padding=output_padding,
            )
        )

    def forward(self, x):
        """Upscale a batch of images; output spatial size is n times the input's."""
        return self.model(x)

    @staticmethod
    def _deconv_padding(n, kernel_size):
        """Return ``(padding, output_padding)`` for a stride-``n`` transposed conv.

        With dilation 1 the output height is
        ``(H_in - 1) * n - 2 * padding + (kernel_size - 1) + output_padding + 1``,
        so ``H_out == n * H_in`` requires
        ``2 * padding - output_padding == kernel_size - n``.
        The smallest non-negative padding is chosen, and output_padding is derived
        from it. For kernel_size 9 and any n >= 1 this also satisfies PyTorch's
        requirement ``output_padding < stride``.
        """
        padding = max(0, (kernel_size - n + 1) // 2)
        output_padding = 2 * padding - (kernel_size - n)
        return padding, output_padding

    def _conv(self, ni, nf, ks):
        """Convolution with 'same' padding followed by a PReLU activation."""
        return nn.Sequential(
            nn.Conv2d(ni, nf, ks, padding='same'),
            nn.PReLU()
        )

In [8]:
# Scale-2 FSRCNN with d=56, s=12, m=4, used for the architecture summary below.
fsr_cnn = FSRCNN(56, 12, 4, 2).to(device)

In [12]:
# Batch of 32 three-channel 128x128 images; at scale 2 the output is [32, 3, 256, 256].
summary(fsr_cnn, (32, 3, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
FSRCNN                                   [32, 3, 256, 256]         --
├─Sequential: 1-1                        [32, 3, 256, 256]         --
│    └─Sequential: 2-1                   [32, 56, 128, 128]        --
│    │    └─Conv2d: 3-1                  [32, 56, 128, 128]        4,256
│    │    └─PReLU: 3-2                   [32, 56, 128, 128]        1
│    └─Sequential: 2-2                   [32, 12, 128, 128]        --
│    │    └─Conv2d: 3-3                  [32, 12, 128, 128]        684
│    │    └─PReLU: 3-4                   [32, 12, 128, 128]        1
│    └─Sequential: 2-3                   [32, 12, 128, 128]        --
│    │    └─Conv2d: 3-5                  [32, 12, 128, 128]        1,308
│    │    └─PReLU: 3-6                   [32, 12, 128, 128]        1
│    └─Sequential: 2-4                   [32, 12, 128, 128]        --
│    │    └─Conv2d: 3-7                  [32, 12, 128, 128]        1,308
│    │  

In [15]:
class SRDataset(Dataset):
    """Super-resolution pairs built on the fly from high-resolution images.

    Each item is ``(lr, hr)``: ``hr`` is a random ``crop_size x crop_size`` RGB crop
    of a target image, and ``lr`` is that crop downscaled by ``scale`` with bicubic
    interpolation. Both are float tensors with values in [0, 1].
    """

    def __init__(self, target_paths: list[str], scale: int, crop_size: int=128):
        """
        Args:
            target_paths: paths of the high-resolution target images
            scale: downscaling factor; must evenly divide ``crop_size`` so that the
                LR crop upscaled ``scale`` times has exactly the target's size
            crop_size: side length of the square high-resolution crop

        Raises:
            ValueError: if ``scale`` or ``crop_size`` is not positive, or ``scale``
                does not divide ``crop_size``
        """
        # Make sure that the scaling is possible. Raise instead of `assert`,
        # which `python -O` strips. Divisibility, not evenness, is what keeps the
        # LR/HR sizes consistent.
        if scale < 1 or crop_size < 1 or crop_size % scale != 0:
            raise ValueError(
                f"scale ({scale}) must be a positive divisor of crop_size ({crop_size})"
            )

        self.target = target_paths
        self.crop_size = crop_size
        self.scale = scale

        # uint8 PIL image -> tensor scaled to [0, 1]
        self.transforms = v2.Compose([
            v2.PILToTensor(),
            v2.Lambda(lambda x: x/255.0)
        ])

    def __len__(self):
        """Number of target images."""
        return len(self.target)

    def __getitem__(self, idx):
        """Return ``(lr, hr)`` tensors built from the image at ``self.target[idx]``."""
        # Close the file as soon as it is decoded. convert("RGB") returns an
        # independent, fully loaded copy. It also guarantees 3 channels, whereas
        # grayscale/RGBA PNGs would otherwise yield 1 or 4.
        with Image.open(self.target[idx]) as img:
            target = img.convert("RGB")
        target = self.random_crop(target, self.crop_size)
        inp = target.resize((target.width // self.scale, target.height // self.scale), Image.BICUBIC)
        return self.transforms(inp), self.transforms(target)

    def random_crop(self, img, size):
        """Return a random ``size x size`` crop of ``img``.

        Images smaller than ``size`` in either dimension are first resized to
        ``size x size`` (bicubic), so the crop then covers the whole image.
        """
        w, h = img.size
        if w < size or h < size:
            img = img.resize((size, size), Image.BICUBIC)  # Resize if image is too small
            # The offsets below must use the resized dimensions; with the old
            # ones randint would get a negative upper bound and raise.
            w, h = img.size

        x = random.randint(0, w - size)
        y = random.randint(0, h - size)
        return img.crop((x, y, x + size, y + size))

In [16]:
scales = [2, 4, 8, 16, 32, 64]
EPOCHS = 10

# One trained FSRCNN (and its training history) per scale factor, so results
# survive the loop instead of being overwritten by the next scale.
trained_models = {}
histories = {}

for scale in tqdm(scales):
    # SRDataset builds the low-resolution inputs on the fly from random HR crops.
    train_ds = SRDataset(HR_train_paths, scale)
    valid_ds = SRDataset(HR_valid_paths, scale)

    train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=10)
    valid_dl = DataLoader(valid_ds, batch_size=16, shuffle=False, num_workers=10)

    model = FSRCNN(56, 12, 4, scale).to(device)
    loss_fn = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # NOTE: `train` is defined in the next cell of this notebook; run (or move)
    # that cell before this one so a fresh kernel can execute this loop.
    histories[scale] = train(model, train_dl, valid_dl, optimizer, loss_fn, EPOCHS)
    trained_models[scale] = model

In [17]:
def _metric_list(accuracy):
    """Return ``accuracy`` as a list of metric callables (a single callable becomes a one-item list)."""
    return [accuracy] if callable(accuracy) else list(accuracy)


def _epoch_averages(total_loss, metric_totals, num_batches):
    """Convert running sums into per-batch means, shaped ``(avg_loss, *avg_metrics)``."""
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return (total_loss / num_batches, *(total / num_batches for total in metric_totals))


def train_step(model, dataloader, optimizer, scheduler, loss_fn, device, accuracy):
    """Run one training epoch and return the average loss and metric value(s).

    Args:
        model: network to optimize; switched to train mode.
        dataloader: yields ``(batch, target)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped after every batch, or None to skip it.
        loss_fn: callable ``loss_fn(preds, target)`` returning a scalar tensor.
        device: device the batches are moved to.
        accuracy: a metric callable ``metric(preds, target)`` returning a scalar
            tensor, or a sequence of such callables.

    Returns:
        ``(avg_loss, avg_metric)`` for a single metric callable, otherwise
        ``(avg_loss, avg_metric_1, avg_metric_2, ...)`` in the order given.

    Raises:
        ValueError: if the dataloader is empty.
    """
    metrics = _metric_list(accuracy)
    total_loss = 0.0
    metric_totals = [0.0] * len(metrics)
    num_batches = len(dataloader)
    model.train()

    # Necessary for dataloader to work with tqdm without errors, tqdm interferes with dataloader workers shutdown process,
    # therefore I separated them
    dl_iterator = iter(dataloader)
    for _ in tqdm(range(num_batches), desc="Training", leave=False):
        batch, target = next(dl_iterator)
        batch, target = batch.to(device), target.to(device)

        preds = model(batch)
        loss = loss_fn(preds, target)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        # Metrics are only monitored, so keep them off the autograd graph.
        detached = preds.detach()
        for i, metric in enumerate(metrics):
            metric_totals[i] += metric(detached, target).item()

    return _epoch_averages(total_loss, metric_totals, num_batches)


def valid_step(model, dataloader, loss_fn, device, accuracy):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Args:
        model: network to evaluate; switched to eval mode.
        dataloader: yields ``(batch, target)`` pairs.
        loss_fn: callable ``loss_fn(preds, target)`` returning a scalar tensor.
        device: device the batches are moved to.
        accuracy: a metric callable or a sequence of them (see ``train_step``).

    Returns:
        ``(avg_loss, *avg_metrics)``, shaped like ``train_step``'s return value.

    Raises:
        ValueError: if the dataloader is empty.
    """
    metrics = _metric_list(accuracy)
    total_loss = 0.0
    metric_totals = [0.0] * len(metrics)
    num_batches = len(dataloader)
    model.eval()

    # Necessary for dataloader to work with tqdm without errors
    dl_iterator = iter(dataloader)
    with torch.inference_mode():
        for _ in tqdm(range(num_batches), desc="Validation", leave=False):
            batch, target = next(dl_iterator)
            batch, target = batch.to(device), target.to(device)

            preds = model(batch)
            loss = loss_fn(preds, target)

            total_loss += loss.item()
            for i, metric in enumerate(metrics):
                metric_totals[i] += metric(preds, target).item()

    return _epoch_averages(total_loss, metric_totals, num_batches)


def train(model, train_dl, valid_dl, optimizer, loss_fn, epochs,
          scheduler=None, device=None, data_range=1.0):
    """Train ``model`` for ``epochs`` epochs, printing loss, PSNR and SSIM per epoch.

    Args:
        model: network to train.
        train_dl: training dataloader of ``(batch, target)`` pairs.
        valid_dl: validation dataloader of ``(batch, target)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: training loss ``loss_fn(preds, target)``.
        epochs: number of epochs.
        scheduler: optional LR scheduler stepped after every training batch.
        device: device for batches and metrics; defaults to the device of the
            model's first parameter.
        data_range: pixel value range used by PSNR/SSIM. SRDataset scales pixels
            to [0, 1], hence the default of 1.0.

    Returns:
        A list with one dict per epoch holding the printed values.
    """
    if device is None:
        device = next(model.parameters()).device

    # The metric objects hold state tensors, so they must live on the same device
    # as the predictions.
    metrics = (
        PeakSignalNoiseRatio(data_range=data_range).to(device),
        StructuralSimilarityIndexMeasure(data_range=data_range).to(device),
    )

    history = []
    for epoch in tqdm(range(epochs), desc="Epochs"):
        train_loss, train_PSNR, train_SSIM = train_step(
            model,
            train_dl,
            optimizer,
            scheduler,
            loss_fn,
            device,
            metrics,
        )

        valid_loss, valid_PSNR, valid_SSIM = valid_step(
            model,
            valid_dl,
            loss_fn,
            device,
            metrics,
        )

        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_PSNR: {train_PSNR:.4f} | "
            f"train_SSIM: {train_SSIM:.4f} | "
            f"valid_loss: {valid_loss:.4f} | "
            f"valid_PSNR: {valid_PSNR:.4f} | "
            f"valid_SSIM: {valid_SSIM:.4f}"
        )
        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_PSNR": train_PSNR,
            "train_SSIM": train_SSIM,
            "valid_loss": valid_loss,
            "valid_PSNR": valid_PSNR,
            "valid_SSIM": valid_SSIM,
        })

    return history