# $\frac{dE}{d\eta}$ Model

## Training

In [None]:
import sys
sys.path.append('../hydroml')

from hydroml.model import DELinearModel
from hydroml.dataset import EnergyDensityDataset
from hydroml.utils import trim
from hydroml.plot import plot_cc_graph
from torch import nn
from IPython import display

import numpy as np

import torch

In [None]:
# Paired initial/final profiles: the first path is the initial-state input,
# the second the final-state target (argument order as in EnergyDensityDataset).
initial_profiles_path = '../Datasets/dE_data-5.02tev/dE_detas_initial'
final_profiles_path = '../Datasets/dE_data-5.02tev/dET_deta_final'
dataset = EnergyDensityDataset(initial_profiles_path, final_profiles_path)

In [None]:
# --- Training hyperparameters ---
epochs = 500          # number of full passes over the dataset
batch_size = 64       # samples per optimizer step
learning_rate = 1e-3  # optimizer step size
beta1 = 0.9           # intended as the optimizer's first-moment decay (Adam-family beta1)
ngpu = 0              # presumably the GPU count; not referenced by the visible cells

In [None]:
# Shuffled mini-batches of `batch_size` samples for training.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Instantiate the network to be trained (architecture lives in hydroml.model.DELinearModel).
dE_model = DELinearModel()

In [None]:
# Adamax over all model parameters. beta1 (first-moment decay) comes from the
# hyperparameter cell so editing it there takes effect; beta2 stays at
# PyTorch's default of 0.999.
optimizer = torch.optim.Adamax(
    dE_model.parameters(),
    lr=learning_rate,
    betas=(beta1, 0.999),
)

In [None]:
# Mean-squared error between predicted and target profiles, averaged over all
# elements (reduction='mean' is MSELoss's default, spelled out here).
loss_func = nn.MSELoss(reduction='mean')

In [None]:
# Train for `epochs` passes over the dataset, one optimizer step per mini-batch.
# Profile lengths are fixed by the dataset's eta grids, so compute them once.
n_eta_start = len(dataset.start_eta)
n_eta_final = len(dataset.final_eta)
# len(data_loader) includes the trailing partial batch (drop_last defaults to
# False); len(dataset) // batch_size would under-report it.
n_batches = len(data_loader)

for epoch in range(epochs):
    for i, data in enumerate(data_loader):
        # The last batch may be smaller than batch_size, so read the real size.
        actual_batch_size = data[1].shape[0]

        # Reshape to (batch, 1, n_eta), the layout fed to the model here.
        dE_deta_initial = data[0].reshape(actual_batch_size, 1, n_eta_start)
        dNch_deta_final = data[1].reshape(actual_batch_size, 1, n_eta_final)

        optimizer.zero_grad()
        output = dE_model(dE_deta_initial.float())
        loss = loss_func(output, dNch_deta_final.float())
        loss.backward()
        optimizer.step()

        # Replace the previous progress line in the notebook output.
        display.clear_output(wait=True)
        print(f"Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{n_batches}], loss: {loss.item():.4f}")

In [None]:
# Pickle the entire module object (not just its state_dict) to disk.
trained_model_path = '../Trained Models/dE_model.pt'
torch.save(dE_model, trained_model_path)

## Inference

In [None]:
# Load the pickled model for inference. A full nn.Module checkpoint requires
# unrestricted unpickling: PyTorch >= 2.6 defaults to weights_only=True, which
# rejects such files, so opt out explicitly. Only load checkpoints you trust.
model_path = '../Trained Models/dE_model.pt'
try:
    dE_detas_model = torch.load(model_path, weights_only=False)
except TypeError:  # torch < 1.13 has no `weights_only` parameter
    dE_detas_model = torch.load(model_path)
dE_detas_model.eval()

In [None]:
# Lower/upper eta bounds handed to trim() when integrating the final profiles.
bound_1, bound_2 = -4.9, -4.0

In [None]:
# One sample per batch: the inference loop flattens each batch into a single
# profile, which is only correct when batch_size == 1. Evaluation does not
# need shuffling, and a fixed order makes the collected results reproducible.
# NOTE(review): this iterates the same dataset that was used for training.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=False,
)

In [None]:
# np.trapz was deprecated in NumPy 2.0 in favour of np.trapezoid (same (y, x)
# signature); use whichever this NumPy provides.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Integrated target (finals) vs. integrated prediction (models), one pair per sample.
finals = []
models = []

with torch.no_grad():  # inference only: no autograd graph needed
    for data in data_loader:
        # Relies on batch_size == 1, so flattening yields a single profile.
        dE_detas_initial = data[0].flatten()
        dNch_detas_final = data[1].flatten()

        output = dE_detas_model(dE_detas_initial.float())

        # trim() (hydroml.utils) presumably restricts each curve to the eta
        # window [bound_1, bound_2] and returns (eta_values, curve_values).
        _, dNch_detas_final_trim = trim(dataset.final_eta, dNch_detas_final.numpy(), bound_1, bound_2)
        x_axis, output_trim = trim(dataset.final_eta, output.detach().numpy(), bound_1, bound_2)

        finals.append(_trapezoid(dNch_detas_final_trim, x_axis))
        models.append(_trapezoid(output_trim, x_axis))

plot_cc_graph(finals, models)