In [None]:
from model import DEModel
from data import DEData
from torch import nn
from IPython import display
from scipy import stats

import numpy as np

import torch
import plot

In [None]:
# Training hyperparameters shared by the cells below.

# Number of samples per mini-batch handed to the DataLoader.
batch_size = 64
# NOTE(review): not referenced in the cells visible here — presumably a GPU
# count for device selection; confirm it is still needed.
ngpu = 0
# Number of full passes over the data loader in the training loop.
# NOTE(review): this comment used to say epochs was set to '1' for testing,
# but the value is 1000 — confirm which one is intended.
epochs = 1000
# Learning rate passed to the optimizer.
learning_rate = 1e-4
# NOTE(review): not referenced in the cells visible here; the optimizer is
# built with `lr` only, so this value currently has no effect.
beta1 = 0.9

In [None]:
# Build the dataset from 'dE_data' with standardization switched on.
# NOTE(review): what `standardize=True` does is defined in data.DEData, which
# is not visible here — confirm before relying on the value ranges.
dataset = DEData('dE_data', standardize=True)

In [None]:
# Take out the outliers for both initial energy distro and final state.
#
# Every sample's initial and final profile is passed through utils.trim with
# the bounds below and integrated with the trapezoidal rule. A sample is
# dropped when the absolute z-score of either integral exceeds its threshold.

import utils
import matplotlib.pyplot as plt

# Bounds handed to utils.trim (presumably an eta window — confirm in utils).
bound_1 = -4.9
bound_2 = -4.

# A sample is discarded when |z-score| of its integral exceeds these limits.
initial_zscore_threshold = 2.25
final_zscore_threshold = 2.25

# np.trapz was deprecated in NumPy 2.0 in favour of np.trapezoid; use the new
# name when it exists and fall back to the old one on earlier releases.
_trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

initial_int = []
final_int = []

for data in dataset:
    initial = data[0]
    final = data[1]

    initial_eta_trim, initial_trim = utils.trim(dataset.start_eta, initial, bound_1, bound_2)
    final_eta_trim, final_trim = utils.trim(dataset.final_eta, final, bound_1, bound_2)

    initial_int.append(_trapezoid(initial_trim, initial_eta_trim))
    final_int.append(_trapezoid(final_trim, final_eta_trim))

initial_int = np.array(initial_int)
final_int = np.array(final_int)

initial_zscore = np.abs(stats.zscore(initial_int))
final_zscore = np.abs(stats.zscore(final_int))

# Indices of the samples failing either test, as ascending plain Python ints
# (the same list the element-by-element loop would build).
outlier_mask = (initial_zscore > initial_zscore_threshold) | (final_zscore > final_zscore_threshold)
to_remove = np.flatnonzero(outlier_mask).tolist()

# NOTE: `dataset` is rebound here, so re-running this cell operates on the
# already-filtered result rather than on the freshly loaded data.
dataset = dataset.delete_elements(to_remove)

print(len(dataset))

# Optional diagnostic: scatter of the z-scores, left disabled.
# plt.scatter( np.linspace( 0, 1, len(initial_zscore) ), initial_zscore )
# plt.scatter( np.linspace( 0, 1, len(final_zscore) ), final_zscore )
# plt.show()

In [None]:
# Wrap `dataset` in a DataLoader that yields mini-batches of `batch_size`
# samples; shuffle=True reshuffles the sample order at every epoch.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Instantiate the model, passing the lengths of the start and final eta grids.
# NOTE(review): presumably these are the input and output widths — confirm
# against model.DEModel, which is not visible here.
dE_model = DEModel(len(dataset.start_eta), len(dataset.final_eta))

In [None]:
# Adamax over all model parameters. Only `lr` is supplied, so every other
# optimizer hyperparameter keeps its torch default (`beta1` is not passed).
optimizer = torch.optim.Adamax(dE_model.parameters(), lr=learning_rate)

In [None]:
# Mean-squared-error criterion applied to the model output and the target.
loss_func = nn.MSELoss()

In [None]:
# Training loop: one optimizer step per mini-batch, repeated for `epochs`
# passes over `data_loader`.

# Per-step loss history. The append inside the loop is currently disabled,
# so this list stays empty.
loss_total = []

# Loop-invariant values, computed once instead of on every batch.
n_start = len(dataset.start_eta)
n_final = len(dataset.final_eta)
etas = dataset.final_eta  # only consumed by the disabled telemetry plot
# len(data_loader) counts the trailing partial batch as well, whereas
# len(dataset) // batch_size under-reports it (and yields 0 for a dataset
# smaller than one batch).
n_batches = len(data_loader)

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        # The final batch of an epoch can hold fewer than batch_size samples,
        # so take the size from the batch itself.
        actual_batch_size = data[1].shape[0]
        # Insert a singleton middle axis: (batch, 1, grid length).
        dE_deta_initial = data[0].reshape(actual_batch_size, 1, n_start)
        dNch_deta_final = data[1].reshape(actual_batch_size, 1, n_final)

        optimizer.zero_grad()

        output = dE_model(dE_deta_initial.float())

        loss = loss_func(output, dNch_deta_final.float())

        loss.backward()
        optimizer.step()

        #loss_total.append(loss.detach().numpy())

        # Refresh the progress line every 128 batches.
        if i % 128 == 0:
            display.clear_output(wait=True)
            #plot.plot_telemetry( loss_total, output.detach().numpy(), dNch_deta_final, etas)

            print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{n_batches}], loss: {loss.item():.4f}")

In [None]:
# Persist the trained model to disk. This pickles the whole module object, so
# loading the file later requires the DEModel class to be importable.
# NOTE(review): the PyTorch docs recommend saving `dE_model.state_dict()`
# instead; left unchanged because existing loaders of this file may expect the
# full object.
torch.save(dE_model, './dE_model_ztest.pt')