In [1]:
%matplotlib inline

from model import DEModel
from data import DEData
from torch import nn
from IPython import display
import numpy as np

import torch
import plot

In [2]:
# Training hyperparameters.
batch_size = 64
# Number of GPUs to use; 0 selects the CPU even when CUDA is available.
ngpu = 0
# Number of passes over the dataset (lower this, e.g. to 1, for a quick smoke test).
epochs = 1000
learning_rate = 1e-4
# Adam's first-moment decay; the second-moment decay (0.999) is set where the optimizer is built.
beta1 = 0.9
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.
seed = 0
torch.manual_seed(seed);

In [3]:
# Pick the compute device: the first CUDA GPU only if CUDA is available AND ngpu > 0,
# otherwise fall back to the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Combine both data directories into one dataset. The `+` is resolved by DEData
# (presumably torch's Dataset.__add__ -> ConcatDataset; confirm in data.py).
first_part = DEData('dE_data')
second_part = DEData('dE_data_2')
dataset = first_part + second_part

# Reshuffled every epoch; drop_last keeps its default (False), so the final batch
# of an epoch may be smaller than batch_size.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Build the network with DEModel's default constructor arguments (defined in model.py).
# NOTE(review): the model is not moved to `device` in this cell.
dE_model = DEModel()

In [6]:
# Adam with a configurable first-moment decay (beta1) and a second-moment decay of 0.999.
adam_betas = (beta1, 0.999)
optimizer = torch.optim.Adam(params=dE_model.parameters(), lr=learning_rate, betas=adam_betas)

In [7]:
# Mean squared error averaged over every element (reduction='mean' is MSELoss's default).
loss_func = nn.MSELoss(reduction='mean')

In [8]:
# Train the model for `epochs` passes over `data_loader`.
# The model and every kept batch are placed on `device`, so setting ngpu > 0 (with CUDA
# available) actually trains on the GPU; previously `device` was computed but never used.
dE_model.to(device)

loss_total = []               # one float per optimizer step, accumulated across all epochs
n_batches = len(data_loader)  # batches per epoch, including a final partial batch
log_every = 128               # print progress whenever the batch index is a multiple of this
plot_telemetry = False        # set True to redraw plot.plot_telemetry at each progress step
etas = dataset.data_axis      # loop-invariant; only consumed by the optional telemetry plot

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        actual_batch_size = data[1].shape[0]
        # Add a singleton middle dimension. 64 and 141 are the per-sample element counts of
        # the input and target (the 64 here is unrelated to batch_size).
        dE_deta_initial = data[0].reshape(actual_batch_size, 1, 64)
        dNch_deta_final = data[1].reshape(actual_batch_size, 1, 141)

        # NOTE(review): batches whose largest input value is below 1 are skipped entirely;
        # the motivation for this threshold is not documented here — confirm.
        if dE_deta_initial.max() < 1.:
            continue

        inputs = dE_deta_initial.float().to(device)
        targets = dNch_deta_final.float().to(device)

        optimizer.zero_grad()
        output = dE_model(inputs)
        loss = loss_func(output, targets)
        loss.backward()
        optimizer.step()

        # .item() works for CPU and CUDA tensors alike (.numpy() raises on CUDA tensors).
        loss_total.append(loss.item())

        if i % log_every == 0:
            display.clear_output(wait=True)
            if plot_telemetry:
                plot.plot_telemetry(loss_total, output.detach().cpu().numpy(), dNch_deta_final, etas)
            print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{n_batches}], loss: {loss.item():.4f}")

print(np.array(loss_total).mean())

Epoch [1000/1000], Batch [129/171], loss: 0.8355
1.5869615


In [9]:
# Persist the whole pickled module (not just its state_dict).
# NOTE(review): reloading needs model.DEModel importable, and on recent torch versions
# torch.load(..., weights_only=False) — consider saving state_dict() instead.
checkpoint_path = './dE_model.pt'
torch.save(dE_model, checkpoint_path)