In [21]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Rprop
from torch.utils.data import DataLoader, Dataset

from Data_Handling.SGTEHandler.Development.SGTEHandler import SGTEHandler
from Neural_Nets.ThermoNet.Development.ThermoNetTorch import ThermoNet, ThermoLossFunc, ThermoDataset

In [17]:
def epoch(net: 'ThermoNet', dataloader, loss_func, optimizer, scale=10000, max_grad_norm=100):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Each batch is unpacked as ``(temp, g, s, h, c)``: temperatures plus the target
    Gibbs energy, entropy, enthalpy and heat capacity. The temperatures are fed to
    all four inputs of ``net``; predictions and targets are divided by ``scale``
    before the loss is computed.

    Args:
        net: model called as ``net(temp, temp, temp, temp)`` returning four tensors
            (Gibbs energy, entropy, enthalpy, heat capacity predictions).
        dataloader: sized iterable of ``(temp, g, s, h, c)`` batches.
        loss_func: callable ``loss_func(g_pred, g, s_pred, s, h_pred, h, c_pred, c)``
            returning a scalar tensor.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        scale: divisor applied to both predictions and targets (default 10000).
        max_grad_norm: gradient-norm clipping threshold (default 100).

    Returns:
        Mean of the per-batch losses, or NaN if the dataloader yields no batches.
    """
    epoch_losses = np.zeros(len(dataloader))

    for i, (temp, g, s, h, c) in enumerate(dataloader):
        # Add a trailing feature dimension (presumably (batch,) -> (batch, 1)).
        temp = temp.unsqueeze(-1)

        # Forward pass: the same temperatures are used for every output head.
        predictions = net(temp, temp, temp, temp)

        # Rescale predictions and targets to a smaller magnitude before the loss.
        gibbs_energy, entropy, enthalpy, heat_cap = (p / scale for p in predictions)
        g, s, h, c = (t / scale for t in (g, s, h, c))

        loss = loss_func(gibbs_energy.float(), g.float(), entropy.float(), s.float(), enthalpy.float(), h.float(),
                         heat_cap.float(), c.float())

        # Backward pass with gradient clipping.
        net.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
        optimizer.step()

        # .item() extracts a plain Python float, detached from the autograd graph.
        epoch_losses[i] = loss.item()

    # An empty dataloader has nothing to average; avoid numpy's empty-mean warning.
    mean_epoch_loss = epoch_losses.mean() if epoch_losses.size else np.nan
    print('Mean epoch loss: ', mean_epoch_loss)
    return mean_epoch_loss

In [50]:
def _halve_learning_rate(optimizer):
    """Halve the effective learning rate of every parameter group of ``optimizer``.

    ``torch.optim.Rprop`` uses ``lr`` only as the initial per-element step size,
    which it keeps in optimizer state under ``'step_size'``. Halving ``lr`` alone
    would therefore have no effect once training has started, so stored step
    sizes are halved as well. Optimizers without that state only get ``lr`` halved.
    """
    for group in optimizer.param_groups:
        group['lr'] /= 2
        for param in group['params']:
            # .get avoids creating empty state entries in the defaultdict.
            state = optimizer.state.get(param, {})
            if 'step_size' in state:
                state['step_size'].mul_(0.5)


def train(net, element, phase, n_epochs=500, lr=0.05, batch_size=16, std_thresh=0.05, patience=10):
    """Train ``net`` on SGTE data for ``element``/``phase`` with Rprop.

    The learning rate is halved whenever the standard deviation of the last
    ``patience`` mean epoch losses falls below ``std_thresh``, at most once every
    ``patience`` epochs.

    Args:
        net: ThermoNet-style model to train in place.
        element: element symbol forwarded to ``ThermoDataset`` (e.g. ``'Fe'``).
        phase: phase selection forwarded to ``ThermoDataset``.
        n_epochs: number of epochs to run.
        lr: initial Rprop learning rate.
        batch_size: DataLoader batch size (batches are shuffled every epoch).
        std_thresh: loss-std plateau threshold that triggers a learning-rate cut.
        patience: plateau window length and minimum epochs between cuts.

    Returns:
        List of the mean loss of every completed epoch.
    """
    # Data
    dataset = ThermoDataset(element, phase)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer
    optimizer = Rprop(net.parameters(), lr=lr)
    loss_func = ThermoLossFunc()

    losses = []

    # Epoch at which the learning rate was last reduced (0 also delays the first check).
    lr_reduced_last = 0

    for i in range(n_epochs):
        print('-----\nEpoch %i:\n' % i)
        loss = epoch(net, dataloader, loss_func, optimizer)
        losses.append(loss)

        # Plateau detection: losses barely move over the last `patience` epochs.
        if np.array(losses[-patience:]).std() < std_thresh and (i - lr_reduced_last) >= patience:
            print('Learning rate halfed! \n')
            lr_reduced_last = i
            # Must act on the optimizer; rebinding a local `lr` would change nothing.
            _halve_learning_rate(optimizer)

    return losses

In [51]:
def graphic_evaluation(net, element, phase, start_temp=200, end_temp=2000):
    """Plot ``net``'s predictions against SGTE reference curves for one element/phase.

    One figure is drawn per quantity (Gibbs energy, entropy, enthalpy, heat
    capacity). The SGTE reference is shown as a line and the network prediction
    as red points.

    Args:
        net: model called as ``net(T, T, T, T)`` on an ``(n, 1)`` float64
            temperature tensor, returning four tensors in the order Gibbs
            energy, entropy, enthalpy, heat capacity.
        element: element symbol passed to ``SGTEHandler`` (e.g. ``'Fe'``).
        phase: phase selection forwarded to ``SGTEHandler.evaluate_equations``.
        start_temp: lower end of the temperature range. The prediction grid is
            raised to the lowest reference temperature if that is higher.
        end_temp: upper end of the range (exclusive for the prediction grid).
    """
    # Reference data from the SGTE equations (1e5 is presumably the pressure; TODO confirm).
    sgte_handler = SGTEHandler(element)
    sgte_handler.evaluate_equations(start_temp, end_temp, 1e5, plot=False, phases=phase, entropy=True, enthalpy=True,
                                    heat_capacity=True)
    data = sgte_handler.equation_result_data
    temp = data['Temperature'].to_numpy()

    # Integer temperature grid for the network, never starting below the reference data.
    grid_start = max(start_temp, int(temp.min()))
    temp_range = torch.arange(grid_start, end_temp, dtype=torch.float64).unsqueeze(-1)

    # Predictions may carry autograd history; detach them before plotting.
    gibbs_pred, entropy_pred, enthalpy_pred, heat_cap_pred = (
        p.detach() for p in net(temp_range, temp_range, temp_range, temp_range)
    )

    # NOTE(review): assumes columns 1-4 of the SGTE result hold Gibbs energy, entropy,
    # enthalpy and heat capacity in that order; confirm against SGTEHandler.
    panels = [
        ('Gibbs energy', 1, gibbs_pred),
        ('Entropy', 2, entropy_pred),
        ('Enthalpy', 3, enthalpy_pred),
        ('Heat capacity', 4, heat_cap_pred),
    ]
    for quantity, column, prediction in panels:
        fig, ax = plt.subplots()
        ax.plot(temp, data.iloc[:, column].to_numpy(), label='SGTE')
        ax.grid()
        ax.scatter(temp_range, prediction, s=0.3, c='red', label='ThermoNet')
        ax.set(title=f'{element}: {quantity}', xlabel='Temperature', ylabel=quantity)
        ax.legend()
    plt.show()

In [52]:
# Train a freshly constructed ThermoNet on the BCC_A2 phase of iron.
# Seed torch's global RNG so that anything drawn from it is reproducible. This covers
# the shuffling DataLoader inside train() and presumably ThermoNet's weight init.
SEED = 0
torch.manual_seed(SEED)
net = ThermoNet()

element = 'Fe'
phase = ['BCC_A2']
train(net, element, phase)

Fe successfully selected!

-----
Epoch 0:

Mean epoch loss:  7.957581000907399
-----
Epoch 1:

Mean epoch loss:  4.749310217171072
-----
Epoch 2:

Mean epoch loss:  4.545609405107587
-----
Epoch 3:

Mean epoch loss:  4.517665435220594
-----
Epoch 4:

Mean epoch loss:  4.553784858400577
-----
Epoch 5:

Mean epoch loss:  4.463663602543768
-----
Epoch 6:

Mean epoch loss:  4.427689371822036
-----
Epoch 7:

Mean epoch loss:  4.383941699411268
-----
Epoch 8:

Mean epoch loss:  4.286846684518261
-----
Epoch 9:

Mean epoch loss:  4.281068288277242
-----
Epoch 10:

Mean epoch loss:  4.287426429374196
-----
Epoch 11:

Mean epoch loss:  4.321660741467342
-----
Epoch 12:

Mean epoch loss:  4.27499183093276
-----
Epoch 13:

Mean epoch loss:  4.241568217767734
-----
Epoch 14:

Mean epoch loss:  4.3110162833026635
-----
Epoch 15:

Mean epoch loss:  4.218645505816023
-----
Epoch 16:

Mean epoch loss:  4.246716746660036
Learning rate halfed! 

-----
Epoch 17:

Mean epoch loss:  4.278587428208824
-----

KeyboardInterrupt: 

In [None]:
# Compare the trained network's predictions with the SGTE reference curves.
graphic_evaluation(net=net, element=element, phase=phase)

In [8]:
# Persist the trained module (pickled whole by torch.save). torch.save does not create
# missing directories, so make sure the target folder exists first.
# NOTE(review): saving net.state_dict() would be more portable than pickling the module.
model_path = Path('ThermoNet/Models/model_12_01_22_1535')
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(net, model_path)