In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils import data
import torchvision

import matplotlib.pyplot as plt
import numpy as np

class Generator(nn.Module):
    """U-Net style encoder/decoder mapping an image to an ``out_channels`` map.

    Encoder: four stages of (conv3x3 -> ReLU -> conv3x3 -> ReLU, 2x2 max-pool),
    each doubling the channel width, followed by a bottleneck double-conv.
    Decoder: four stages of (bilinear 2x upsample -> conv3x3 halving channels,
    concatenation with the matching encoder feature map along the channel
    axis, conv3x3 -> ReLU -> conv3x3 -> ReLU), then a final conv3x3 to
    ``out_channels``.

    Input height and width must be divisible by 16 (four 2x poolings) so every
    upsampled map lines up with its skip connection.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the produced map (default 1).
        base_channels: Width of the first encoder stage; deeper stages double
            it (default 64 gives 64/128/256/512/1024).
        verbose: If True, ``forward`` prints stage names and tensor shapes.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64, verbose=False):
        super().__init__()
        self.verbose = verbose

        # Channel width at each depth level (64/128/256/512/1024 by default).
        c1, c2, c3, c4, c5 = (base_channels * 2 ** level for level in range(5))

        def conv3x3(cin, cout):
            return nn.Conv2d(in_channels=cin, out_channels=cout, kernel_size=3, padding='same')

        self.relu = nn.ReLU(inplace=True)

        # Encoder layers
        self.conv_1 = conv3x3(in_channels, c1)
        self.conv_2 = conv3x3(c1, c1)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_3 = conv3x3(c1, c2)
        self.conv_4 = conv3x3(c2, c2)
        self.max_pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_5 = conv3x3(c2, c3)
        self.conv_6 = conv3x3(c3, c3)
        self.max_pool_6 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_7 = conv3x3(c3, c4)
        self.conv_8 = conv3x3(c4, c4)
        self.max_pool_8 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_9 = conv3x3(c4, c5)
        self.conv_10 = conv3x3(c5, c5)

        # Decoder layers. The first conv of each stage halves the channels; the
        # skip concat doubles them back, so the next conv takes 2x its output.
        self.upsample_11 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_11 = conv3x3(c5, c4)
        self.conv_12 = conv3x3(2 * c4, c4)
        self.conv_13 = conv3x3(c4, c4)

        self.upsample_14 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_14 = conv3x3(c4, c3)
        self.conv_15 = conv3x3(2 * c3, c3)
        self.conv_16 = conv3x3(c3, c3)

        self.upsample_17 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_17 = conv3x3(c3, c2)
        self.conv_18 = conv3x3(2 * c2, c2)
        self.conv_19 = conv3x3(c2, c2)

        self.upsample_20 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_20 = conv3x3(c2, c1)
        self.conv_21 = conv3x3(2 * c1, c1)
        self.conv_22 = conv3x3(c1, c1)

        self.conv_23 = conv3x3(c1, out_channels)

    def _log(self, *values):
        """Print *values* only when this generator was built with ``verbose=True``."""
        if self.verbose:
            print(*values)

    def _double_conv(self, conv_a, conv_b, x):
        """Apply conv_a -> ReLU -> conv_b -> ReLU to *x*."""
        return self.relu(conv_b(self.relu(conv_a(x))))

    def _up_block(self, upsample, up_conv, conv_a, conv_b, x, skip):
        """Upsample *x*, halve its channels, join it with *skip*, then double-conv."""
        up = up_conv(upsample(x))
        # dim=1 is the channel axis of (N, C, H, W); the default dim=0 would
        # stack the skip and upsampled maps along the batch axis instead.
        merged = torch.cat([skip, up], dim=1)
        return self._double_conv(conv_a, conv_b, merged)

    def forward(self, x):
        """Run the U-Net on a batch of images.

        Args:
            x: Tensor of shape (N, in_channels, H, W), H and W divisible by 16.

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        self._log(x.shape)

        # Encoder: the pre-pool feature map of each stage is kept as a skip.
        self._log("Encoding 1...")
        encode_1_convolved = self._double_conv(self.conv_1, self.conv_2, x)
        encode_1_downsampled = self.max_pool_2(encode_1_convolved)
        self._log(encode_1_convolved.shape)
        self._log(encode_1_downsampled.shape)

        self._log("Encoding 2...")
        encode_2_convolved = self._double_conv(self.conv_3, self.conv_4, encode_1_downsampled)
        encode_2_downsampled = self.max_pool_4(encode_2_convolved)
        self._log(encode_2_convolved.shape)
        self._log(encode_2_downsampled.shape)

        self._log("Encoding 3...")
        encode_3_convolved = self._double_conv(self.conv_5, self.conv_6, encode_2_downsampled)
        encode_3_downsampled = self.max_pool_6(encode_3_convolved)
        self._log(encode_3_convolved.shape)
        self._log(encode_3_downsampled.shape)

        self._log("Encoding 4...")
        encode_4_convolved = self._double_conv(self.conv_7, self.conv_8, encode_3_downsampled)
        encode_4_downsampled = self.max_pool_8(encode_4_convolved)
        self._log(encode_4_convolved.shape)
        self._log(encode_4_downsampled.shape)

        self._log("Encoding 5...")
        encode_5 = self._double_conv(self.conv_9, self.conv_10, encode_4_downsampled)
        self._log(encode_5.shape)

        # Decoder
        self._log('Decoding 1...')
        decode_1 = self._up_block(self.upsample_11, self.conv_11, self.conv_12, self.conv_13,
                                  encode_5, encode_4_convolved)

        self._log('Decoding 2...')
        decode_2 = self._up_block(self.upsample_14, self.conv_14, self.conv_15, self.conv_16,
                                  decode_1, encode_3_convolved)

        self._log('Decoding 3...')
        decode_3 = self._up_block(self.upsample_17, self.conv_17, self.conv_18, self.conv_19,
                                  decode_2, encode_2_convolved)

        self._log('Decoding 4...')
        decode_4 = self._up_block(self.upsample_20, self.conv_20, self.conv_21, self.conv_22,
                                  decode_3, encode_1_convolved)

        self._log('Decoding 5...')
        return self.conv_23(decode_4)

class Discriminator(nn.Module):
    """Two-layer MLP that squashes its input to a score in [0, 1].

    Layout: Linear(input_size, 64) -> LeakyReLU(0.01) -> Linear(64, 1) -> Sigmoid.

    Args:
        input_size: Length of the last dimension of the tensors passed to ``forward``.
    """

    def __init__(self, input_size):
        super().__init__()
        hidden_width = 64
        layers = [
            nn.Linear(input_size, hidden_width),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),
        ]
        # Stored as ``sequence`` so parameter names (state_dict keys) stay stable.
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Score *x*; the linear layers act on its last dimension."""
        return self.sequence(x)

def _to_input_tensor(sample, device):
    """Convert one channels-last numpy image (H, W, C) into a (1, C, H, W) float tensor."""
    tensor = torch.from_numpy(sample).float().to(device)
    # Channels first, then a mini-batch dimension of size 1.
    return tensor.permute(2, 0, 1).unsqueeze(0)


def _to_target_tensor(sample, device):
    """Convert one numpy target sample into a (1, N) float tensor of its flattened values."""
    return torch.from_numpy(sample).float().to(device).reshape(1, -1)


def run(data_path='training_data.npz', device=None):
    """Adversarially train a Generator against a Discriminator.

    Reads arrays ``x`` (generator inputs) and ``y`` (real targets) from the
    ``.npz`` file at *data_path* and trains one sample at a time (mini-batch
    of 1) for the module-level ``num_epochs`` with Adam at ``learning_rate``.
    The discriminator scores flattened maps: the real target ``y[i]`` versus
    the generator's output for ``x[i]``.

    Assumes each ``x[i]`` is channels-last (H, W, C) and each ``y[i]`` holds as
    many values as the generator output (1 * H * W) — TODO confirm against
    the dataset.

    Args:
        data_path: Path of the ``.npz`` archive holding ``x`` and ``y``.
        device: Torch device to train on; defaults to CUDA when available,
            otherwise CPU.

    Raises:
        ValueError: If the dataset is empty or ``x`` and ``y`` differ in length.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Loading training data...')
    # Indexing an NpzFile loads the array into memory, so closing it is safe.
    with np.load(data_path) as training_data:
        input_dataset = training_data['x']
        output_dataset = training_data['y']

    if len(input_dataset) == 0:
        raise ValueError(f'{data_path} contains no training samples')
    if len(input_dataset) != len(output_dataset):
        raise ValueError(
            f'x has {len(input_dataset)} samples but y has {len(output_dataset)}')

    generator = Generator().to(device)
    # The discriminator is an MLP over one flattened real or generated map.
    discriminator = Discriminator(int(np.prod(output_dataset[0].shape))).to(device)

    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=learning_rate)
    optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()

    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}...')

        for input_array, target_array in zip(input_dataset, output_dataset):
            input_sample = _to_input_tensor(input_array, device)
            real_sample = _to_target_tensor(target_array, device)

            # Generate a fake sample, flattened to match the discriminator input.
            fake_sample = generator(input_sample).reshape(1, -1)

            # Discriminator step. detach() keeps this loss from backpropagating
            # into the generator, so no retained graph is needed.
            optimizer_discriminator.zero_grad()
            real_result = discriminator(real_sample)
            fake_result = discriminator(fake_sample.detach())
            loss_discriminator_real_sample = criterion(real_result, torch.ones_like(real_result))
            loss_discriminator_fake_sample = criterion(fake_result, torch.zeros_like(fake_result))
            loss_discriminator = (loss_discriminator_real_sample + loss_discriminator_fake_sample) / 2
            loss_discriminator.backward()
            optimizer_discriminator.step()

            # Generator step: reward fakes the (updated) discriminator calls real.
            optimizer_generator.zero_grad()
            fake_result = discriminator(fake_sample).view(-1)
            loss_generator = criterion(fake_result, torch.ones_like(fake_result))
            loss_generator.backward()
            optimizer_generator.step()

        # Losses of the last sample in the epoch.
        print("Epoch: {epoch} \t Discriminator Loss: {lossD} Generator Loss: {lossG}".format(
            epoch=epoch, lossD=loss_discriminator.item(), lossG=loss_generator.item()))

# Training hyperparameters, read as module globals by run().
learning_rate = 2e-4
batch_size = 32  # NOTE(review): not referenced by run(), which trains on mini-batches of size 1.
num_epochs = 25

# Train only when executed as a script/notebook, not when imported.
if __name__ == "__main__":
    run()
