In [15]:
# Standard library
import os
import time

# Third-party utilities
import tqdm

# Local modules. nbimporter is imported first; presumably it lets `dataset`/`model`
# be imported from .ipynb files — confirm.
# NOTE(review): star imports hide where names come from; they presumably provide
# Contrastive_Dataset and SuperResolution. They stay *above* the explicit imports below
# so that the explicit torch/torchvision names win any name clash (as before).
import nbimporter
from dataset import *
from model import *

# PyTorch
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms


In [None]:
# Hyperparameters / run configuration.
# Set `debug = True` to force CPU execution even when the MPS backend is available.
debug = False
if debug:
    device = "cpu"
else:
    # Use the MPS backend when torch reports it as available, otherwise fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
lr = 0.001
batch_size = 16
epochs = 50
seed = 0  # fixes the train/test split, shuffling order and weight init across restarts

# Seed the global torch RNG before building the model and loaders so that weight
# initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(seed)


# Train/Test split (90% / 10%). A dedicated seeded generator keeps the split identical
# between runs, so test losses (and the saved "best" checkpoint) are comparable.
dataset = Contrastive_Dataset("../data/cryptopunks/")
train_set, test_set = random_split(
    dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(seed)
)

# NOTE(review): `transformations` is not passed to Contrastive_Dataset and is not used
# elsewhere in this notebook; confirm whether the dataset applies its own
# normalisation before deleting it.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


# Dataloader: reshuffle training batches every epoch; keep test order fixed.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Model
model = SuperResolution().to(device)
# Use the `lr` hyperparameter (it was previously hard-coded here, so changing `lr`
# above had no effect).
optimizer = optim.Adam(model.parameters(), lr=lr)
# reduction="sum": the loss is summed over every element of the batch, so its scale
# grows with the batch size; the training loop normalises by sample count when logging.
criterion = nn.MSELoss(reduction="sum").to(device)

# Tensorboard: runs are written under a directory tagged with `ifx`.
ifx = "256x"
writer = SummaryWriter(f"../runs/super_resolution_{ifx}/")


In [None]:
# Training loop.
# Each epoch runs one optimisation pass over `train_loader`, then one evaluation pass over
# `test_loader`. Per-sample mean losses are logged to TensorBoard, and the weights with the
# lowest test loss seen so far (across *all* epochs) are saved to `checkpoint_path`.
checkpoint_path = "./params/best_params.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # torch.save does not create dirs

image_log_every = 100    # log an LQ/SR preview grid every N training steps (every step is slow)
min_loss = float("inf")  # best test loss so far; must persist across epochs, not reset per epoch
global_step = 0          # monotonically increasing step so per-batch logs don't collide across epochs

for epoch in range(epochs):
    # ---- Training ----
    model.train()
    train_loss_sum, n_train = 0.0, 0
    loop = tqdm.tqdm(train_loader, total=len(train_loader), leave=False,
                     desc=f"Epoch {epoch + 1}/{epochs}")
    for img1, img2, _label in loop:
        # img2 is the model input (logged as "LQ"), img1 the target (logged as "HQ").
        img1, img2 = img1.to(device), img2.to(device)

        sr_image = model(img2)
        loss = criterion(sr_image, img1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion uses reduction="sum", so accumulate batch sums and divide by the
        # number of samples; this keeps a smaller final batch from skewing the epoch loss.
        batch_loss = loss.item()
        batch_n = img1.size(0)
        train_loss_sum += batch_loss
        n_train += batch_n
        loop.set_postfix(loss=batch_loss / batch_n)

        if global_step % image_log_every == 0:
            with torch.no_grad():
                preview_sr = sr_image[:8].detach()
                # Resize the low-res originals to match the super-resolved images
                img2_resized = nn.functional.interpolate(img2[:8], size=preview_sr.shape[2:])
                combined_grid = torchvision.utils.make_grid(
                    torch.cat((preview_sr, img2_resized), dim=0), padding=5
                )
            writer.add_image('Batch LQ/SR Images', combined_grid, global_step)
        global_step += 1

    # Mean per-sample training loss over the whole epoch (not just the last batch)
    writer.add_scalar('Loss/train', train_loss_sum / max(n_train, 1), epoch)

    # ---- Evaluation ----
    model.eval()
    test_loss_sum, n_test = 0.0, 0
    with torch.no_grad():
        for img1, img2, _label in test_loader:
            img1, img2 = img1.to(device), img2.to(device)
            sr_image = model(img2)
            test_loss_sum += criterion(sr_image, img1).item()
            n_test += img1.size(0)

    if n_test == 0:
        # Empty test split: nothing to log or compare, so skip checkpointing this epoch.
        continue

    # Mean per-sample test loss over the whole test set (not just the last batch)
    test_loss = test_loss_sum / n_test
    writer.add_scalar('Loss/test', test_loss, epoch)

    # Add images from the last test batch to tensorboard
    writer.add_image('Original HQ Images', torchvision.utils.make_grid(img1[:4], padding=5), epoch)
    writer.add_image('Original LQ Images', torchvision.utils.make_grid(img2[:4], padding=5), epoch)
    writer.add_image('Super Resolved Images', torchvision.utils.make_grid(sr_image[:4], padding=5), epoch)

    # Save the weights whenever the test loss beats the lowest test loss seen so far
    if test_loss < min_loss:
        min_loss = test_loss
        torch.save(model.state_dict(), checkpoint_path)

writer.flush()
model.train()  # leave the model in training mode, as the original loop did

            