In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import lib
from importlib import reload
reload(lib)

class Discriminator(nn.Module):
    """Small fully connected binary classifier used as the GAN critic.

    The network is Linear(input_size -> 64), LeakyReLU(0.01),
    Linear(64 -> 1), Sigmoid, so each output value lies in [0, 1].
    The linear layers act on the last dimension of the input tensor.

    Args:
        input_size: Size of the last dimension of the tensors passed to
            ``forward``.
    """

    def __init__(self, input_size):
        super().__init__()
        hidden_units = 64
        layers = [
            nn.Linear(input_size, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        # The attribute name is kept as ``sequence`` so state_dict keys
        # stay the same.
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score(s) of ``x``; the last dimension becomes 1."""
        return self.sequence(x)

def _select_device():
    """Return the preferred available torch device: MPS, then CUDA, then CPU."""
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def run():
    """Train the generator and discriminator adversarially, one sample at a time.

    Loads the arrays ``x`` and ``y`` from ``data/training_data.npz``, builds
    ``lib.Generator(3, 1)`` and a ``Discriminator``, and runs ``num_epochs``
    passes over the dataset. Every training step also plots the input image,
    the generated image and the target image.

    Reads the module-level globals ``learning_rate`` and ``num_epochs``;
    they must be defined before this function is called.
    """
    device = _select_device()

    print('Loading training data...')
    # NpzFile keeps the archive open; read the arrays and close it promptly.
    with np.load('data/training_data.npz') as training_data:
        input_dataset = training_data['x']
        output_dataset = training_data['y']

    generator = lib.Generator(3, 1)
    generator.to(device)
    discriminator = Discriminator(input_dataset[0].shape[0])
    # Both networks must live on the same device as the samples fed to them.
    discriminator.to(device)

    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=learning_rate)
    optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()

    # Epochs are the outer loop and samples the inner one. Distinct loop
    # variables keep the sample index from being overwritten by the epoch
    # counter.
    for epoch in range(num_epochs):

        print(f'Starting epoch {epoch+1}...')

        for index, input_image in enumerate(input_dataset):

            input_sample = torch.from_numpy(input_image).float().to(device)
            input_sample = torch.permute(input_sample, (2, 0, 1)) # Channels first
            input_sample = torch.stack([input_sample], 0) # Mini-batch of size 1

            # Generate a fake sample
            print('Generating fake sample...')
            fake_sample = generator(input_sample)

            fake_image = fake_sample.cpu().detach().numpy().squeeze()
            plt.imshow(input_image)
            plt.show()
            plt.imshow(fake_image, cmap='viridis')
            plt.show()
            plt.imshow(output_dataset[index], cmap='viridis')
            plt.show()

            # Run the real and fake samples through the discriminator.
            # The fake sample is detached so this step does not backpropagate
            # into the generator (and needs no retain_graph).
            # NOTE(review): the "real" sample here is the generator's input
            # image, not output_dataset[index] -- confirm this is intended.
            print('Discriminating...')
            discriminator_real_result = discriminator(input_sample)
            discriminator_fake_result = discriminator(fake_sample.detach())

            # Calculate the loss for each sample, and then the average
            loss_discriminator_real_sample = criterion(
                discriminator_real_result, torch.ones_like(discriminator_real_result))
            loss_discriminator_fake_sample = criterion(
                discriminator_fake_result, torch.zeros_like(discriminator_fake_result))
            loss_discriminator = (loss_discriminator_real_sample + loss_discriminator_fake_sample) / 2

            # Backpropagate to train the discriminator
            discriminator.zero_grad()
            loss_discriminator.backward()
            optimizer_discriminator.step()

            # Score the fake sample with the freshly updated discriminator;
            # the generator is rewarded when the score is close to 1 ("real").
            generator_score = discriminator(fake_sample).view(-1)
            loss_generator = criterion(generator_score, torch.ones_like(generator_score))

            # Backpropagate to train the generator
            generator.zero_grad()
            loss_generator.backward()
            optimizer_generator.step()

            # Report the losses once per epoch, after its first sample.
            if index == 0:
                print( "Epoch: {epoch} \t Discriminator Loss: {lossD} Generator Loss: {lossG}".format(
                    epoch=epoch, lossD=loss_discriminator.item(), lossG=loss_generator.item()))

# Training hyperparameters. run() reads num_epochs and learning_rate as
# module-level globals, so they must be assigned before the call below.
num_epochs = 25
# batch_size is not referenced by run(), which builds mini-batches of size 1.
batch_size = 32
learning_rate = 0.0002  # step size for both Adam optimizers

# Start training as soon as this cell/module is executed.
run()
