In [None]:
# %matplotlib inline

from IPython import display

import numpy as np
import matplotlib.pyplot as plt

from hydroml import model as m
from hydroml.data import Data
from torch import nn

import torch
import hydroml.plot as plot

In [None]:
# --- Training configuration (read by the cells below) ---

# Data / batching
batch_size = 64
n_samples = 20 * batch_size  # dataset size passed to Data(); 20 full batches

# Hardware: 0 keeps everything on the CPU even when CUDA is available
ngpu = 0

# Optimisation
epochs = 100           # passes over the dataset; lower this for a quick smoke test
learning_rate = 5e-5   # Adam step size, shared by both optimizers
beta1 = 0.9            # first Adam beta, shared by both optimizers

# Model sizes
nz = 100               # channel dimension of the generator's noise input
ndf = ngf = 512        # size arguments for Discriminator / Generator
                       # (presumably base feature widths -- confirm in hydroml.model)

In [None]:
# Pick the compute device: the first CUDA device only when CUDA is available
# and at least one GPU was requested via ngpu; otherwise fall back to the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the dataset (n_samples is presumably the number of samples to
# produce -- confirm in hydroml.data) and wrap it in a loader that yields
# shuffled mini-batches of batch_size items.
dataset = Data(n_samples)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Build both networks and move them onto the selected device. The training
# loop sends every batch and label tensor to `device`, so the model
# parameters must live there too; otherwise a CUDA run fails with a
# device-mismatch error on the first forward pass.
discriminator = m.Discriminator(ndf).to(device)
print(discriminator)

generator = m.Generator(ngf, nz).to(device)
print(generator)

# Hook onto one of the model layers to see the output of the layer.
# discriminator.input_layer.register_forward_hook(plot.layer_hook())
# discriminator.layer1.register_forward_hook(plot.layer_hook())
# discriminator.layer2.register_forward_hook(plot.layer_hook())
# discriminator.layer3.register_forward_hook(plot.layer_hook())
# discriminator.layer4.register_forward_hook(plot.layer_hook())

In [None]:
# One Adam optimizer per network; both use the same step size and beta pair.
adam_options = dict(lr=learning_rate, betas=(beta1, 0.999))

disc_optimizer = torch.optim.Adam(discriminator.parameters(), **adam_options)
gen_optimizer = torch.optim.Adam(generator.parameters(), **adam_options)

In [None]:
# Binary cross-entropy computed on raw scores (the sigmoid is applied inside
# the loss), averaged over the batch. Called as loss_func(input, target).
loss_func = nn.BCEWithLogitsLoss(reduction="mean")

In [None]:
# Constant target vectors, one entry per sample of a full batch:
# all ones for "real", all zeros for "fake". Neither tracks gradients.
label_options = dict(dtype=torch.float32, device=device)
hard_real = torch.ones(batch_size, **label_options)
hard_fake = torch.zeros(batch_size, **label_options)

In [None]:
# Per-batch telemetry: one entry is appended per training step and the
# running history is re-plotted live after every batch.
disc_loss_total = []
gen_loss_total = []
real_total = []
fake_total = []

# Number of batches per epoch as reported by the loader itself (stays
# correct even if the dataset size is not a multiple of batch_size).
n_batches = len(data_loader)

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        i_batch_size = data.size(0)
        i_batch = data.to(device)

        # ---- Discriminator step ----

        # Noisy ("soft") labels, sliced to the actual batch size so a short
        # final batch still lines up with the discriminator output. They are
        # clamped to [0, 1]: BCE targets outside that range make the loss
        # unbounded below.
        real_base = hard_real[:i_batch_size]
        fake_base = hard_fake[:i_batch_size]
        real_labels = (torch.randn_like(real_base) * 0.10 + real_base).clamp(0.0, 1.0)
        fake_labels = (torch.randn_like(fake_base) * 0.10 + fake_base).clamp(0.0, 1.0)

        # Fake samples come from the generator, not from raw noise, so the
        # discriminator learns to tell real data from what the generator
        # currently produces. The noise is created on the same device as the
        # models and data.
        noise = torch.randn(i_batch_size, nz, 1, device=device)
        generated_data = generator(noise)

        disc_optimizer.zero_grad()
        real_output = discriminator(i_batch).view(-1)
        # detach(): the discriminator update must not backpropagate into the
        # generator.
        fake_output = discriminator(generated_data.detach()).view(-1)

        real_loss = loss_func(real_output, real_labels)
        fake_loss = loss_func(fake_output, fake_labels)

        discriminator_loss = fake_loss + real_loss
        discriminator_loss.backward()
        disc_optimizer.step()

        # ---- Generator step ----

        gen_optimizer.zero_grad()
        # Re-score the same generated batch with the freshly updated
        # discriminator; this time gradients flow back into the generator.
        predictions = discriminator(generated_data).view(-1)

        # The generator wants its samples classified as real (target = 1).
        # Argument order is (input, target).
        gen_loss = loss_func(predictions, torch.ones_like(predictions))
        gen_loss.backward()
        # Step the generator's optimizer. Gradients that this backward pass
        # leaves on the discriminator are cleared by disc_optimizer.zero_grad()
        # at the start of the next iteration.
        gen_optimizer.step()

        # ---- Statistics ----
        mean_real_output = real_output.mean().item()
        mean_fake_output = fake_output.mean().item()

        disc_loss_total.append(discriminator_loss.item())
        gen_loss_total.append(gen_loss.item())
        real_total.append(mean_real_output)
        fake_total.append(mean_fake_output)

        # Replace the previous plot/log line instead of stacking output.
        display.clear_output(wait=True)

        plot.plot_telemetry(disc_loss_total, gen_loss_total, real_total, fake_total)

        print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{n_batches}], loss_d: {discriminator_loss.item():.4f}, mean_real_output: {mean_real_output:.4f}, mean_fake_output: {mean_fake_output:.4f}")