In [19]:


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from pathlib import Path
from PIL import Image



class PairedImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.transform = transform

        # Match files by name
        self.input_files = sorted(self.input_dir.glob("*"))
        self.target_files = sorted(self.target_dir.glob("*"))
    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_img = Image.open(self.input_files[idx])
        target_img = Image.open(self.target_files[idx])

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img
    

In [42]:
currentTransforms = transforms.Compose([transforms.ToTensor()])


trainingData = PairedImageDataset(input_dir=r"O:/Data upscale train/Dataset/train/input/_upscaleFactor2/", target_dir=r"O:/Data upscale train/Dataset/train/target/_upscaleFactor2/", transform=currentTransforms)
trainingLoader = DataLoader(trainingData, batch_size=1, shuffle=True)

validationData = PairedImageDataset(input_dir=r"O:/Data upscale train/Dataset/validate/input/_upscaleFactor2/", target_dir=r"O:/Data upscale train/Dataset/validate/target/_upscaleFactor2/", transform=currentTransforms)
validationLoader = DataLoader(validationData, batch_size=1, shuffle=False)


In [43]:
import torch.nn as nn
import torch.nn.functional as F


class upscaleModel(nn.Module):
    def __init__(self, upscale_factor=2):
        super(upscaleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = F.tanh(self.conv1(x))
        x = F.tanh(self.conv2(x))
        x = F.sigmoid(self.pixel_shuffle(self.conv3(x)))
        return x


training loop

In [None]:
from tqdm.notebook import tqdm
import cv2 as cv

#create instance
Net = upscaleModel(upscale_factor=2)

loss = nn.MSELoss()
optimizer = optim.Adam(Net.parameters(), lr=0.001)

train_loss, val_loss = [], []
num_epoch = 10

# Pick GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move your model to the device
Net.to(device)

for epoch in range(num_epoch):
    run_loss = 0.0
    Net.train()
    for input, target in trainingLoader:
        optimizer.zero_grad()
        forwardOut = Net(input.to(device))
        #handle dimension mismatch before calculating MSE
        outDim = forwardOut.cpu().detach().numpy().shape[2:] 
        targetDim = target.cpu().detach().numpy().shape[2:]
        diff1 = outDim[0] - targetDim[0]
        diff2 = outDim[1] - targetDim[1]
        if diff1 > 0:
            forwardOut = forwardOut[:, :, :targetDim[0], :]
        if diff2 > 0:
            forwardOut = forwardOut[:, :, :, :targetDim[1]]
        if diff1 < 0:
            target = target[:, :, :outDim[0], :]
        if diff2 < 0:
            target = target[:, :, :, :outDim[1]]

        #now calc loss
        MSE = loss(forwardOut, target.to(device))

        MSE.backward()
        optimizer.step()
        run_loss = MSE.item()

    train_loss.append(run_loss/len(trainingLoader.input_files))
    Net.eval()
    val_loss = 0.0
    with torch.no_grad():
        for input, target in validationLoader:
            forwardOut = Net(input.to(device))
            MSE = loss(forwardOut, target.to(device))
            run_loss += MSE.item()
    val_loss.append(run_loss/len(validationLoader.input_files))
    print(f"Epoch {epoch+1}/{num_epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}")   


In [54]:
print(forwardOut.cpu().detach().numpy().shape )

print(target.cpu().detach().numpy().shape[2:])

(1, 1, 922, 1446)
(921, 1445)
