In [None]:
%matplotlib inline

from model import Simple
from data import Data, RealData
from utils import ImageChannel, Mode

from torch import nn

import torch
import plot

In [None]:
# Run configuration and training hyperparameters.
# Mini-batch size handed to the DataLoader.
batch_size = 64
# CUDA is only selected when this is greater than zero (see the device cell).
ngpu = 0
# Number of passes over the dataset; lower it (e.g. to 1) for a quick test run.
epochs = 500
# Step size for the Adam optimizer.
learning_rate = 1e-4
# First Adam beta; the second beta is fixed at 0.999 where the optimizer is built.
beta1 = 0.9
# Value passed to Data(...) when the dataset is constructed.
n_samples = 8192

In [None]:
# Pick the compute device: the first CUDA GPU only when CUDA is available and
# the configuration asks for at least one GPU (ngpu > 0); otherwise the CPU.
# NOTE(review): none of the cells visible here move the model or the batches
# onto `device` -- confirm whether that is intended.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the dataset from `n_samples`, then wrap it in a loader that serves
# shuffled mini-batches of `batch_size` items.
dataset = Data(n_samples)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Shared channel object handed to every layer hook created in hook_to_simple.
# It is used to pass data to and from the hooks: it communicates state and the
# layer output images.
d_image_channel = ImageChannel()

def hook_to_simple(d):
    """Register a plotting forward hook on each of the four layers of ``d``.

    Each hook is created by ``plot.simple_layer_hook`` from a display label
    and the module-level ``d_image_channel``.

    Args:
        d: model exposing ``input_layer``, ``second_layer``, ``third_layers``
            and ``fourth_layers``, each providing ``register_forward_hook``.

    Returns:
        list: the values returned by ``register_forward_hook``, in layer
        order. The commented-out cleanup after the training loop calls
        ``.remove()`` on each of them.
    """
    # (attribute on the model, label passed to the plotting hook)
    layer_labels = (
        ("input_layer", "Input Layer"),
        ("second_layer", "Second Layer"),
        ("third_layers", "Third Layer"),
        ("fourth_layers", "Fourth Layer"),
    )

    return [
        getattr(d, attr).register_forward_hook(
            plot.simple_layer_hook(label, d_image_channel)
        )
        for attr, label in layer_labels
    ]

In [None]:
# Instantiate the model that the cells below train and save.
simple = Simple()

# Optional: attach the layer-plotting forward hooks. Uncomment the line below
# to enable them, and also uncomment the matching `handle.remove()` cleanup
# after the training loop so the hooks are detached again.

# handlers = hook_to_simple(simple)

In [None]:
# Adam over all of the model's parameters. The learning rate and first beta come
# from the configuration cell; the second beta is fixed at 0.999.
optimizer = torch.optim.Adam(simple.parameters(), lr=learning_rate, betas=(beta1, 0.999))

In [None]:
# Mean-squared-error criterion; the training loop applies it to the model output
# and the `protons` target of each batch.
loss_func = nn.MSELoss()

In [None]:
# Per-batch loss history. It only gets filled when the telemetry `append`
# inside the loop is uncommented.
disc_loss_total = []

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        # Each batch is indexed positionally: model input first, MSE target second.
        baryons = data[0]
        protons = data[1]

        # Clear the gradients left over from the previous optimisation step.
        optimizer.zero_grad()

        #d_image_channel.set_mode(Mode.GENERATED)
        output = simple(baryons)

        loss = loss_func(output, protons)

        loss.backward()
        optimizer.step()

        #disc_loss_total.append(loss.detach().numpy())

        if i % 128 == 0:
            #plot.plot_telemetry(d_image_channel, disc_loss_total, output.detach().numpy(), protons)

            # len(data_loader) is the real number of batches per epoch: unlike
            # len(dataset) // batch_size it also counts a trailing partial batch.
            # loss.item() converts the scalar loss tensor to a plain float
            # before it is formatted.
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{len(data_loader)}], loss: {loss.item():.4f}")

        #d_image_channel.reset()

# Cleanup for the optional forward hooks; uncomment together with the
# `handlers = hook_to_simple(simple)` line.
# for handle in handlers:
#     handle.remove()

In [None]:
# Persist the trained model. The whole `simple` object is passed to torch.save
# (not just its state_dict), so loading the file later presumably requires the
# `Simple` class to be importable -- confirm against the loading code.
torch.save(simple, './simple_model.pt')