In [1]:
from ThermoNet.Development.ThermoNetTorch import ThermoNet, ThermoLossFunc, ThermoDataset
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.optim import Rprop
from SGTEHandler.Development.SGTEHandler import SGTEHandler
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def epoch(net: "ThermoNet", dataloader, loss_func, optimizer, max_grad_norm: float = 100.0):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Each batch unpacks as ``(temp, g, s, h, c)``. The temperature tensor gets a
    trailing axis and is fed to all four inputs of ``net``. The four network
    outputs are compared against the targets by ``loss_func``, and one optimizer
    step is taken per batch.

    Args:
        net: Model called as ``net(temp, temp, temp, temp)`` that returns
            ``(gibbs_energy, entropy, enthalpy, heat_cap)``.
        dataloader: Iterable of ``(temp, g, s, h, c)`` batches that supports ``len()``.
        loss_func: Callable taking predictions and targets interleaved as
            ``(g_pred, g, s_pred, s, h_pred, h, c_pred, c)`` and returning a
            scalar tensor.
        optimizer: Optimizer over ``net``'s parameters.
        max_grad_norm: Threshold passed to ``clip_grad_norm_``. It defaults to 100,
            the previously hard-coded value.

    Returns:
        The mean of the per-batch loss values, as a NumPy float.
    """
    epoch_losses = np.zeros(len(dataloader))

    for i, (temp, g, s, h, c) in enumerate(dataloader):
        # Add a trailing feature axis; assumes temp is (batch,) per batch — TODO confirm.
        temp = temp.unsqueeze(-1)

        # Forward pass: the same temperatures feed every input of the network.
        gibbs_energy, entropy, enthalpy, heat_cap = net(temp, temp, temp, temp)

        # Cast predictions and targets to float32 so the loss sees matching dtypes.
        loss = loss_func(gibbs_energy.float(), g.float(), entropy.float(), s.float(),
                         enthalpy.float(), h.float(), heat_cap.float(), c.float())

        # Backward pass
        net.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
        optimizer.step()

        # .item() stores a plain Python float that is detached from the autograd graph.
        epoch_losses[i] = loss.item()

    mean_epoch_loss = epoch_losses.mean()
    print('Mean epoch loss: ', mean_epoch_loss)
    return mean_epoch_loss

In [3]:
def train(net, element, phase, n_epochs=100, lr=0.01, batch_size=16):
    """Fit ``net`` to the ``ThermoDataset`` of one element/phase and return the loss history.

    Args:
        net: Model to train in place (see ``epoch`` for the call convention).
        element: Element symbol forwarded to ``ThermoDataset`` (e.g. ``'Fe'``).
        phase: Phase selection forwarded to ``ThermoDataset`` (e.g. ``['BCC_A2']``).
        n_epochs: Number of passes over the dataset.
        lr: Initial step size for ``Rprop``.
        batch_size: Mini-batch size for the shuffled ``DataLoader``.

    Returns:
        list: The per-epoch mean losses returned by ``epoch``, with one entry per epoch.
    """
    # Data
    dataset = ThermoDataset(element, phase)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer
    # NOTE(review):
    # - Rprop adapts its step sizes from the signs of successive gradients and is
    #   designed for full-batch training. With shuffled mini-batches those signs are noisy.
    # - Rprop ignores gradient magnitude, so the norm clipping done in `epoch` has no
    #   effect here.
    # Consider Adam if mini-batching is kept.
    optimizer = Rprop(net.parameters(), lr=lr)
    loss_func = ThermoLossFunc()

    losses = []
    for i in range(n_epochs):
        print('-----\nEpoch %i:\n' % i)
        losses.append(epoch(net, dataloader, loss_func, optimizer))

    # Returned so callers can plot or inspect convergence (previously discarded).
    return losses

In [None]:
def graphic_evaluation(net, element, phase, start_temp=500, end_temp=1500):
    """Plot the network's predictions against the SGTE reference equations.

    One figure is drawn per quantity (Gibbs energy, entropy, enthalpy, heat
    capacity). Each figure shows the reference curve computed by ``SGTEHandler``
    as a line. It also shows the network output at every integer temperature in
    ``[start_temp, end_temp)`` as small red dots.

    Args:
        net: Model called as ``net(T, T, T, T)`` that returns four tensors in the
            order (Gibbs energy, entropy, enthalpy, heat capacity).
        element: Element symbol handed to ``SGTEHandler`` (e.g. ``'Fe'``).
        phase: Phase selection forwarded to ``evaluate_equations`` as ``phases``
            (e.g. ``['BCC_A2']``).
        start_temp: First temperature of the evaluation grid (inclusive).
        end_temp: Upper end of the network's evaluation grid (exclusive). It is
            passed unchanged to ``evaluate_equations``.
    """
    # Integer temperature grid with a trailing feature axis: shape (N, 1).
    temp_range = torch.arange(start_temp, end_temp, dtype=torch.float64).unsqueeze(-1)

    # Reference data. The third argument (1e5) is kept from the original call —
    # presumably the pressure in Pa; confirm against SGTEHandler.
    sgte_handler = SGTEHandler(element)
    sgte_handler.evaluate_equations(start_temp, end_temp, 1e5, plot=False, phases=phase, entropy=True, enthalpy=True,
                                    heat_capacity=True)
    data = sgte_handler.equation_result_data
    reference_temp = np.asarray(data['Temperature'], dtype=np.float64)

    # Deliberately no torch.no_grad(): the network may rely on autograd with respect
    # to temperature for derived quantities, so the outputs are detached afterwards.
    outputs = net(temp_range, temp_range, temp_range, temp_range)
    gibbs_pred, entropy_pred, enthalpy_pred, heat_cap_pred = (out.detach() for out in outputs)

    # Pair SGTE result columns 1-4 with the network outputs in the same order.
    panels = [
        ('Gibbs energy', data.iloc[:, 1], gibbs_pred),
        ('Entropy', data.iloc[:, 2], entropy_pred),
        ('Enthalpy', data.iloc[:, 3], enthalpy_pred),
        ('Heat capacity', data.iloc[:, 4], heat_cap_pred),
    ]

    phase_label = ', '.join(map(str, phase)) if isinstance(phase, (list, tuple)) else str(phase)
    temp_grid = temp_range.squeeze(-1).numpy()

    for quantity, reference, prediction in panels:
        _, ax = plt.subplots()
        ax.plot(reference_temp, np.asarray(reference, dtype=np.float64), label='SGTE reference')
        ax.scatter(temp_grid, prediction.cpu().reshape(-1).numpy(), s=0.3, c='red', label='Network prediction')
        ax.grid()
        ax.set(title=f'{element} ({phase_label}): {quantity}', xlabel='Temperature', ylabel=quantity)
        # Enlarge the tiny scatter marker in the legend so it stays visible.
        ax.legend(markerscale=10)
    plt.show()

In [None]:
# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(0)

net = ThermoNet()

element = 'Fe'
phase = ['BCC_A2']

# Bind the result rather than leaving it as the cell's displayed last expression.
loss_history = train(net, element, phase)

Fe successfully selected!

-----
Epoch 0:

Mean epoch loss:  37910699621.98131
-----
Epoch 1:

Mean epoch loss:  320673866.07476634
-----
Epoch 2:

Mean epoch loss:  74722809.27102804
-----
Epoch 3:

Mean epoch loss:  65754308.336448595
-----
Epoch 4:

