In [None]:
%matplotlib inline

from model import DEModel
from data import DEData
from utils import ImageChannel, Mode

from torch import nn

import torch
import plot

In [None]:
# Seed torch's global RNG so DataLoader shuffling and torch-based weight
# initialisation are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

# Training hyperparameters. Set epochs = 1 for a quick smoke test of the pipeline.
batch_size = 64
ngpu = 0  # number of GPUs to use; 0 keeps training on the CPU even if CUDA is available
epochs = 1000
learning_rate = 1e-4
beta1 = 0.9  # Adam beta1 (decay rate of the first-moment estimate)

In [None]:
# Select the compute device: the first CUDA GPU when CUDA is available AND
# ngpu > 0; otherwise fall back to the CPU.
cuda_requested = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0") if cuda_requested else torch.device("cpu")

In [None]:
dataset = DEData()

# Batches are reshuffled every epoch; the final batch may be smaller than
# batch_size, which is why the training loop reads the actual batch size.
loader_kwargs = dict(batch_size=batch_size, shuffle=True)
data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

In [None]:
# Shared channel the forward hooks use to pass state and per-layer output
# images between the model and the plotting code.
image_channel = ImageChannel()

def hook_to_model(d):
    """Attach a plotting forward hook to each of the model's four layers.

    Args:
        d: the model; it must expose ``input_layer``, ``second_layer``,
           ``third_layer`` and ``fourth_layer`` modules.

    Returns:
        List of hook handles in layer order. Call ``.remove()`` on each
        handle to detach the hooks again.
    """
    layer_labels = (
        ("input_layer", "Input Layer"),
        ("second_layer", "Second Layer"),
        ("third_layer", "Third Layer"),
        ("fourth_layer", "Fourth Layer"),
    )
    return [
        getattr(d, attr).register_forward_hook(plot.simple_layer_hook(label, image_channel))
        for attr, label in layer_labels
    ]

In [None]:
# Build the model and place its parameters on the selected device; previously
# `device` was computed but never used, so ngpu > 0 had no effect.
dE_model = DEModel().to(device)

# Uncomment to attach the per-layer plotting hooks; remember to remove the
# returned handles once training is done.
#handlers = hook_to_model(dE_model)

In [None]:
# Adam over every model parameter; beta1 comes from the config cell, beta2 is fixed at 0.999.
adam_betas = (beta1, 0.999)
optimizer = torch.optim.Adam(dE_model.parameters(), lr=learning_rate, betas=adam_betas)

In [None]:
# Mean squared error between predicted and target dNch/deta profiles,
# averaged over every element of the batch.
loss_func = nn.MSELoss(reduction="mean")

In [None]:
# Train dE_model to map each initial dE/deta profile to the final dNch/deta
# profile. loss_total collects one Python float per batch for later plotting.
loss_total = []
log_every = 128  # print progress on every 128th batch of an epoch

# Make sure the model lives on the configured device (a no-op if it already
# does). NOTE(review): Module.to() updates parameters in place in current
# PyTorch, so the optimizer built earlier still tracks them; if that ever
# changes, rebuild the optimizer after moving the model.
dE_model.to(device)
dE_model.train()

# Loop-invariant: the eta axis is only needed by the optional telemetry plot.
etas = dataset.data_axis
# len(data_loader) counts the final partial batch (drop_last defaults to False);
# len(dataset) // batch_size undercounts whenever the sizes do not divide evenly.
n_batches = len(data_loader)

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        actual_batch_size = data[1].shape[0]
        # Reshape to (batch, 1, profile_length); -1 infers the length that was
        # previously hard-coded as 141.
        dE_deta_initial = data[0].reshape(actual_batch_size, 1, -1).float().to(device)
        dNch_deta_final = data[1].reshape(actual_batch_size, 1, -1).float().to(device)

        optimizer.zero_grad()

        image_channel.set_mode(Mode.GENERATED)
        output = dE_model(dE_deta_initial)

        loss = loss_func(output, dNch_deta_final)

        loss.backward()
        optimizer.step()

        # .item() yields a Python float and works for CPU and CUDA tensors alike
        # (.numpy() raises on CUDA tensors).
        loss_value = loss.item()
        loss_total.append(loss_value)

        if i % log_every == 0:
            #plot.plot_telemetry(image_channel, loss_total, output.detach().cpu().numpy(), dNch_deta_final.cpu(), etas)

            print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{n_batches}], loss: {loss_value:.4f}")

        image_channel.reset()

# for handle in handlers:
#     handle.remove()

In [None]:
# Persist the whole trained module (pickled), not just its state_dict. Loading
# it back requires the DEModel class to be importable; newer PyTorch versions
# may also need torch.load(..., weights_only=False) — confirm for your version.
model_path = './dE_model.pt'
torch.save(obj=dE_model, f=model_path)