In [9]:
# pip install torch torchvision
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn as nn
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from loss_functions import CombinedLoss
from discriminator_model import DiscriminatorModel
from upscaling_model import UpscalingModel
from data_sets_loaders import get_train_val_test_dataloaders

%matplotlib inline

In [20]:
def train_prototype(train_dl, val_dl, batch_size= 64, num_epochs=5, learning_rate=1e-4):
    """Adversarially train an UpscalingModel against a DiscriminatorModel.

    Each training batch performs two optimisation steps:

    1. Discriminator step: the discriminator is scored on the real targets
       (label 1) and on the *detached* generated images (label 0), then updated.
    2. Generator step: the generated images are scored again by the
       just-updated discriminator, this time without detaching, so the
       combined loss back-propagates into the upscaling model.

    After every epoch the generator's combined loss is also evaluated on
    ``val_dl`` with both models in eval mode and gradients disabled, so the
    training and validation curves report the same quantity.

    Args:
        train_dl: Iterable yielding ``(inputs_lr, targets_hr)`` tensor pairs
            used for optimisation.
        val_dl: Iterable yielding ``(inputs_lr, targets_hr)`` tensor pairs
            used for evaluation only.
        batch_size: Unused by this function (the loaders already define the
            batching); kept so existing callers keep working.
        num_epochs: Number of passes over ``train_dl``.
        learning_rate: Learning rate for both Adam optimisers.

    Returns:
        ``(iters, train_losses, val_losses)``: the zero-based epoch indices
        and, per epoch, the mean generator loss over the training and
        validation batches as plain Python floats. An epoch whose loader
        yielded no batches reports ``nan`` instead of raising.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed before the models are built so weight initialisation is reproducible too.
    torch.manual_seed(42)

    discriminator_model = DiscriminatorModel(3, 8, 0.3).to(device).double()
    upscaling_model = UpscalingModel(1, 3, 3, 32, 10).to(device).double()

    criterion = CombinedLoss(0.0, 1.0, 1.0, 0.0, 4.0).to(device).double()
    optimizer_upscale = torch.optim.Adam(upscaling_model.parameters(), lr=learning_rate)
    optimizer_discriminator = torch.optim.Adam(discriminator_model.parameters(), lr=learning_rate)

    iters, train_losses, val_losses = [], [], []
    for epoch in range(num_epochs):

        # Validation below switches to eval mode, so re-enable train mode each epoch.
        upscaling_model.train()
        discriminator_model.train()

        train_count = 0
        train_loss_sum = 0.0

        for inputs_lr, targets_hr in train_dl:

            # The models run in double precision; cast every tensor fed to them.
            inputs_lr = inputs_lr.to(device).double()
            targets_hr = targets_hr.to(device).double()

            num_samples = inputs_lr.shape[0]
            real_labels = torch.ones(num_samples, 1, device=device, dtype=torch.double)
            generated_labels = torch.zeros(num_samples, 1, device=device, dtype=torch.double)

            generated_hr = upscaling_model(inputs_lr)

            # --- Discriminator step -------------------------------------
            # detach() keeps this loss from reaching the generator's weights.
            optimizer_discriminator.zero_grad()
            discriminator_gen = discriminator_model(generated_hr.detach())
            discriminator_real = discriminator_model(targets_hr)
            real_loss = criterion(None, real_labels, discriminator_real, is_discriminator=True)
            generated_loss = criterion(None, generated_labels, discriminator_gen, is_discriminator=True)
            discriminator_loss = real_loss + generated_loss
            discriminator_loss.backward()
            optimizer_discriminator.step()

            # --- Generator step -----------------------------------------
            # Score the non-detached output so gradients flow back into the
            # upscaling model. Gradients that also land on the discriminator
            # are cleared by its zero_grad() at the next discriminator step.
            optimizer_upscale.zero_grad()
            gen_loss = criterion(generated_hr, targets_hr, discriminator_model(generated_hr))
            gen_loss.backward()
            optimizer_upscale.step()

            train_loss_sum += gen_loss.item()
            train_count += 1

        # --- Validation ---------------------------------------------------
        upscaling_model.eval()
        discriminator_model.eval()

        val_count = 0
        val_loss_sum = 0.0

        with torch.no_grad():
            for inputs_lr_val, targets_hr_val in val_dl:
                inputs_lr_val = inputs_lr_val.to(device).double()
                targets_hr_val = targets_hr_val.to(device).double()

                generated_hr_val = upscaling_model(inputs_lr_val)
                discriminator_gen_val = discriminator_model(generated_hr_val)

                # Same metric as the training curve, stored as a float so the
                # history never holds (possibly GPU-resident) tensors.
                single_val_loss = criterion(generated_hr_val, targets_hr_val, discriminator_gen_val)
                val_loss_sum += single_val_loss.item()
                val_count += 1

        # Guard against empty loaders instead of dividing by zero.
        t_loss = train_loss_sum / train_count if train_count else float("nan")
        v_loss = val_loss_sum / val_count if val_count else float("nan")

        iters.append(epoch)
        train_losses.append(t_loss)
        val_losses.append(v_loss)

        print("#"+str(epoch + 1)+": training loss value = "+str(t_loss)+" validation loss value = "+str(v_loss))

    return(iters, train_losses, val_losses)


In [21]:
# NOTE(review): `cropped_dir` is not used anywhere in this cell — presumably
# consumed elsewhere; confirm or remove.
cropped_dir = "./cropped_images"
# NOTE(review): the argument 8 is presumably the batch size — confirm against
# data_sets_loaders.get_train_val_test_dataloaders.
train_dl, val_dl, test_dl = get_train_val_test_dataloaders(8)

iters, train_losses, val_losses = train_prototype(train_dl, val_dl)

# Coerce to plain floats so the curves plot even if a loss arrives as a
# 0-dim tensor (possibly living on the GPU).
train_curve = [float(loss) for loss in train_losses]
val_curve = [float(loss) for loss in val_losses]

plt.title("Losses Curve")
plt.plot(iters, train_curve, label="Train")
plt.plot(iters, val_curve, label="Validation")
# One point is recorded per epoch, so the x-axis is epochs, not iterations.
plt.xlabel("Epoch")
plt.ylabel("Loss")
# Without legend() the `label=` values above are never displayed.
plt.legend()
plt.show()