In [1]:
from Neural_Nets.ThermoNet.Development.ThermoNet import ThermoRegressionNet, ThermoDataset
from Neural_Nets.ThermoNetActFuncs.Development.ThermoNetActFuncs import Sigmoid
from Utils.PlotHandler.Development.PlotHandler import PlotHandler 
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch.optim import Rprop, Adam
from Data_Handling.SGTEHandler.Development.SGTEHandler import SGTEHandler
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def epoch(net: ThermoRegressionNet, dataloader, loss_func, optimizer):
    epoch_losses = np.zeros([len(dataloader), ])

    for i, (temp, g) in enumerate(dataloader):
        temp = temp.unsqueeze(-1)

        # Forward pass
        gibbs_energy = net(temp)
        
        # Get the loss
        loss = loss_func(gibbs_energy.float(), g.float())

        # Backward pass
        net.zero_grad()
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(net.parameters(), 100)
        optimizer.step()
        epoch_losses[i] = loss

    mean_epoch_loss = epoch_losses.mean()
    print('Mean epoch loss: ', mean_epoch_loss)
    return mean_epoch_loss

In [3]:
def train(net, dataset):
    # Hyperparameters
    n_epochs = 100
    lr = 0.01
    batch_size = 128
    std_thresh = 0.05

    # Data
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer
    optimizer = Rprop(net.parameters(), lr=lr)
    loss_func = nn.MSELoss()

    losses = []
    
    # Keep track of epoch where learning rate was reduced last
    lr_reduced_last = 0

    for i in range(n_epochs):
        print('-----\nEpoch %i:\n' % i)
        loss = epoch(net, dataloader, loss_func, optimizer)
        losses.append(loss)
        
        # Adapt learning rate if standard deviation over the last 10 epochs is below a threshold
        if np.array(losses[-10:]).std() < std_thresh and (i - lr_reduced_last) >= 10:
            print('Learning rate halfed! \n')
            lr_reduced_last = i
            lr /= 2

In [4]:
net = ThermoRegressionNet()

element = 'Fe'
phase = ['BCC_A2']
dataset = ThermoDataset(element, phase)

train(net, dataset)

tensor(371.5209, grad_fn=<StdBackward0>)
tensor(402.5326, grad_fn=<StdBackward0>)
tensor([[862.7891]], grad_fn=<AddBackward0>)
tensor([[862.7891]], grad_fn=<AddBackward0>)
Fe successfully selected!

-----
Epoch 0:

tensor(0., grad_fn=<StdBackward0>)
tensor(862.7891, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.7592, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.7231, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.6798, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.6280, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.5659, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.4912, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.4017, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.2941, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(862.1652, grad_fn=<MeanBackward0>)
tensor(0., grad

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


tensor(481.9741, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(405.7809, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(314.3492, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(204.6311, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(72.9695, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-77.0305, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-227.0305, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-377.0305, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-527.0305, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-677.0305, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-827.0304, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-977.0304, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-1127.0305, grad_fn=<MeanBackward0>)
Mean epoch loss:  449897371

tensor(0., grad_fn=<StdBackward0>)
tensor(-19577.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-19727.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-19877.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-20027.0254, grad_fn=<MeanBackward0>)
Mean epoch loss:  2714667830.857143
-----
Epoch 13:

tensor(0., grad_fn=<StdBackward0>)
tensor(-20177.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-20327.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-20477.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-20627.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-20777.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-20927.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-21077.0293, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-21227.0293, grad_fn=<Mea

tensor(-38777.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-38927.0273, grad_fn=<MeanBackward0>)
Mean epoch loss:  1609881892.5714285
-----
Epoch 22:

tensor(0., grad_fn=<StdBackward0>)
tensor(-39077.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-39227.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-39377.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-39527.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-39677.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-39827.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-39977.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-40127.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-40277.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-40427.0312, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<S

tensor(0., grad_fn=<StdBackward0>)
tensor(-54600.4141, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-54670.0703, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-54753.6602, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-54853.9727, grad_fn=<MeanBackward0>)
Mean epoch loss:  1248921197.7142856
-----
Epoch 31:

tensor(0., grad_fn=<StdBackward0>)
tensor(-54974.3398, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55118.7812, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55118.7812, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55191., grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55277.6641, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55381.6602, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55506.4648, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-55656.2188, grad_fn=<MeanBa

tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
Mean epoch loss:  1226120548.5714285
-----
Epoch 40:

tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<S

tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
Mean epoch loss:  1238410345.142857
-----
Epoch 49:

tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<StdBackward0>)
tensor(-56584.6172, grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<St

KeyboardInterrupt: 

In [None]:
ph = PlotHandler()

ph.properties_temp(net, dataset)