In [None]:
%matplotlib inline

from IPython import display

from hydroml.model import Generator, Discriminator
from hydroml.data import Data
from hydroml.utils import ImageChannel, Mode

from torch import nn

import torch
import hydroml.plot as plot

In [None]:
# Samples per mini-batch; passed to the DataLoader and used to size the label and noise tensors.
batch_size = 64
# Number of GPUs to use; 0 selects the CPU even when CUDA is available.
ngpu = 0
# Number of passes over the data loader. Lower this (e.g. to 1) for a quick test run.
epochs = 1000
# Adam learning rate, shared by the discriminator and generator optimizers.
learning_rate = 5e-5
# First Adam beta coefficient (the second is fixed at 0.999 where the optimizers are built).
beta1 = 0.9
# Second argument to Generator and size of dimension 1 of the noise tensor fed to it.
nz = 128
# Argument passed to Discriminator (presumably its base feature width — TODO confirm).
ndf = 512
# First argument passed to Generator (presumably its base feature width — TODO confirm).
ngf = 512
# Argument passed to Data: 20 batches' worth (presumably the number of samples — TODO confirm).
n_samples = batch_size*20

In [None]:
# Pick the compute device: the first CUDA GPU when CUDA is available and GPUs
# were requested via ngpu, otherwise the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the dataset from n_samples and wrap it in a loader that yields
# shuffled mini-batches of batch_size items.
dataset = Data(n_samples)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Channels used to pass data to and from the forward hooks: they carry the
# current mode and the layer output images. One channel per network.
d_image_channel, g_image_channel = ImageChannel(), ImageChannel()

# Register forward hooks on the model's layers so that each layer's output can be inspected.
def hook_to_generator(generator):
    """Register a plotting forward hook on each generator layer.

    Each hook is built by ``plot.generator_layer_hook`` from a display title
    and the shared ``g_image_channel``.

    Args:
        generator: Object exposing ``input_layer``, ``layer1``, ``layer2``,
            ``layer3`` and ``output_layer``, each of which provides
            ``register_forward_hook``.
    """
    # (attribute name on the generator, title handed to the hook factory)
    titled_layers = (
        ("input_layer", "Input Layer"),
        ("layer1", "Layer 1"),
        ("layer2", "Layer 2"),
        ("layer3", "Layer 3"),
        ("output_layer", "Output Layer"),
    )
    for attr_name, title in titled_layers:
        layer = getattr(generator, attr_name)
        layer.register_forward_hook(plot.generator_layer_hook(title, g_image_channel))

def hook_to_discriminator(discriminator):
    """Register a plotting forward hook on each discriminator layer.

    Each hook is built by ``plot.discriminator_layer_hook`` from a display
    title and the shared ``d_image_channel``.

    Args:
        discriminator: Object exposing ``input_layer`` and ``layer1`` through
            ``layer4``, each of which provides ``register_forward_hook``.
    """
    # (attribute name on the discriminator, title handed to the hook factory).
    # layer4 gets its own "Layer 4" title so it is not confused with layer3.
    titled_layers = (
        ("input_layer", "Input Layer"),
        ("layer1", "Layer 1"),
        ("layer2", "Layer 2"),
        ("layer3", "Layer 3"),
        ("layer4", "Layer 4"),
    )
    for attr_name, title in titled_layers:
        layer = getattr(discriminator, attr_name)
        layer.register_forward_hook(plot.discriminator_layer_hook(title, d_image_channel))

In [None]:
# Build both networks and move them to the selected device, so their parameters
# live on the same device as the batches and noise tensors used during training.
# NOTE(review): assumes Discriminator and Generator are torch.nn.Module
# subclasses (they are called, hooked and passed to Adam as such) — confirm.
discriminator = Discriminator(ndf).to(device)
generator = Generator(ngf, nz).to(device)

# Optional: uncomment either call below to register the per-layer forward hooks.

# hook_to_discriminator(discriminator)
# hook_to_generator(generator)

In [None]:
# One Adam optimizer per network; both share the learning rate and beta terms.
disc_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)
gen_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)

In [None]:
# Binary cross-entropy on raw logits: called as loss_func(logits, targets).
loss_func = torch.nn.BCEWithLogitsLoss()

In [None]:
# Constant per-sample targets on the training device: 1.0 for real, 0.0 for fake.
# They are plain float32 tensors that do not track gradients.
hard_real = torch.ones(batch_size, dtype=torch.float32, device=device, requires_grad=False)
hard_fake = torch.zeros(batch_size, dtype=torch.float32, device=device, requires_grad=False)

In [None]:
# Per-batch telemetry: one value is appended to each list per training step,
# and the lists are re-plotted after every step.
disc_loss_total = []
gen_loss_total = []
real_total = []
fake_total = []

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        # A final batch may hold fewer than batch_size samples, so every label
        # and noise tensor below is sized from the batch actually received.
        i_batch_size = data.size()[0]
        i_batch = data.to(device)

        # ---- Discriminator step ----

        # Add Gaussian noise (std 0.10) to the hard 1/0 labels.
        real_labels = torch.randn_like(hard_real[:i_batch_size]) * 0.10 + hard_real[:i_batch_size]
        fake_labels = torch.randn_like(hard_fake[:i_batch_size]) * 0.10 + hard_fake[:i_batch_size]

        disc_optimizer.zero_grad()

        d_image_channel.set_mode(Mode.REAL)
        real_output = discriminator(i_batch).view(-1)

        # NOTE(review): the "fake" batch here is raw Gaussian noise of shape
        # (batch, 1, 128), not generator output, so the discriminator is never
        # trained against generated samples. Left unchanged because the
        # Mode.NOISE telemetry suggests it may be deliberate — confirm.
        d_image_channel.set_mode(Mode.NOISE)
        fake_output = discriminator(torch.randn((i_batch_size, 1, 128), device=device)).view(-1)

        real_loss = loss_func(real_output, real_labels)
        fake_loss = loss_func(fake_output, fake_labels)

        discriminator_loss = fake_loss + real_loss
        discriminator_loss.backward()
        disc_optimizer.step()

        # ---- Generator step ----

        gen_optimizer.zero_grad()

        g_image_channel.set_mode(Mode.NOISE)
        generated_data = generator(torch.randn(i_batch_size, nz, 1, device=device))

        d_image_channel.set_mode(Mode.GENERATED)
        predictions = discriminator(generated_data).view(-1)

        # BCEWithLogitsLoss takes (input logits, target): the discriminator's
        # logits are the input and the all-ones "real" labels are the target.
        gen_loss = loss_func(predictions, torch.ones_like(predictions))
        gen_loss.backward()
        gen_optimizer.step()

        # Keep statistics
        mean_real_output = real_output.mean().item()
        mean_fake_output = fake_output.mean().item()

        disc_loss_total.append(discriminator_loss.item())
        gen_loss_total.append(gen_loss.item())
        real_total.append(mean_real_output)
        fake_total.append(mean_fake_output)

        # Replace the previous figure instead of stacking a new one per step.
        display.clear_output(wait=True)

        # Tensors must be on the CPU before conversion to NumPy; .cpu() is a
        # no-op for tensors that are already there.
        plot.plot_telemetry(
            d_image_channel,
            g_image_channel,
            disc_loss_total,
            gen_loss_total,
            real_total,
            fake_total,
            data.detach().cpu().numpy(),
            generated_data.detach().cpu().numpy(),
        )

        # print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{len(dataset) // batch_size}], loss_d: {discriminator_loss:.4f}, mean_real_output: {mean_real_output:.4f}, mean_fake_output: {mean_fake_output:.4f}")
        d_image_channel.reset()
        g_image_channel.reset()