In [1]:


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from pathlib import Path
from PIL import Image



class PairedImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.transform = transform

        # Match files by name
        self.input_files = sorted(self.input_dir.glob("*"))
        self.target_files = sorted(self.target_dir.glob("*"))
    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_img = Image.open(self.input_files[idx])
        target_img = Image.open(self.target_files[idx])

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img
    

In [2]:
currentTransforms = transforms.Compose([transforms.ToTensor()])


trainingData = PairedImageDataset(input_dir=r"O:/Data upscale train/Dataset/train/input/_upscaleFactor2/", target_dir=r"O:/Data upscale train/Dataset/train/target/_upscaleFactor2/", transform=currentTransforms)
trainingLoader = DataLoader(trainingData, batch_size=1, shuffle=True)

validationData = PairedImageDataset(input_dir=r"O:/Data upscale train/Dataset/validate/input/_upscaleFactor2/", target_dir=r"O:/Data upscale train/Dataset/validate/target/_upscaleFactor2/", transform=currentTransforms)
validationLoader = DataLoader(validationData, batch_size=1, shuffle=False)


In [3]:
import torch.nn as nn
import torch.nn.functional as F


class upscaleModel(nn.Module):
    def __init__(self, upscale_factor=2):
        super(upscaleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = F.tanh(self.conv1(x))
        x = F.tanh(self.conv2(x))
        x = F.sigmoid(self.pixel_shuffle(self.conv3(x)))
        return x


training loop

In [None]:
from tqdm.notebook import tqdm
import torch.optim as optim
# Pick GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Move your model to the device
model = upscaleModel(upscale_factor=2).to(device)
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

train_loss, validation_loss = [], []
num_epoch = 300




for epoch in range(num_epoch):
    run_loss = 0.0
    model.train()
    for input, target in trainingLoader:
        input = input.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        forwardOut = model(input)
        #handle dimension mismatch before calculating MSE
        out_h, out_w = forwardOut.shape[2], forwardOut.shape[3]
        tgt_h, tgt_w = target.shape[2], target.shape[3]

        min_h = min(out_h, tgt_h)
        min_w = min(out_w, tgt_w)

        forwardOut = forwardOut[:, :, :min_h, :min_w]
        target     = target[:, :, :min_h, :min_w]
        #now calc loss
        MSE = loss(forwardOut, target)
        #backpropagate
        MSE.backward()
        optimizer.step()
        run_loss += MSE.item()
    
    train_loss.append(run_loss/len(trainingLoader.dataset))
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for input, target in validationLoader:
            input = input.to(device)
            target = target.to(device)
            forwardOut = model(input)
            #handle dimension mismatch
            out_h, out_w = forwardOut.shape[2], forwardOut.shape[3]
            tgt_h, tgt_w = target.shape[2], target.shape[3]

            min_h = min(out_h, tgt_h)
            min_w = min(out_w, tgt_w)

            forwardOut = forwardOut[:, :, :min_h, :min_w]
            target     = target[:, :, :min_h, :min_w]
            MSE = loss(forwardOut, target)
            run_loss += MSE.item()
    validation_loss.append(run_loss/len(validationLoader.dataset))
    scheduler.step(validation_loss[-1])
    print(f"Epoch {epoch+1}/{num_epoch}, Train Loss: {train_loss[epoch]}, Val Loss: {validation_loss[epoch]}")


Epoch 1/300, Train Loss: 0.06598141962879864, Val Loss: 0.2595246464101707


In [6]:
print(forwardOut.shape)
print(target.shape)

torch.Size([1, 1, 850, 846])
torch.Size([1, 1, 850, 845])


In [7]:

out_h, out_w = forwardOut.shape[2], forwardOut.shape[3]
tgt_h, tgt_w = target.shape[2], target.shape[3]

min_h = min(out_h, tgt_h)
min_w = min(out_w, tgt_w)

forwardOut = forwardOut[:, :, :min_h, :min_w]
target     = target[:, :, :min_h, :min_w]

print(forwardOut.shape)
print(target.shape)

torch.Size([1, 1, 850, 845])
torch.Size([1, 1, 850, 845])
