In [1]:
import torch

# Can this PyTorch install actually reach a CUDA device? (True/False)
print(torch.cuda.is_available())

# CUDA version string reported by this PyTorch build (torch.version.cuda)
print(torch.version.cuda)

False
12.1


  return torch._C._cuda_getDeviceCount() > 0


In [None]:
import os
import shutil
import psutil 
import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from config import *
import dataProcess
import CGCNN

"""
Training, testing, and validation of CGCNN model
"""

# Snapshot of system RAM; psutil reports byte counts, shown here divided by 2**30
mem = psutil.virtual_memory()
print(
    f"Total Memory: {mem.total / 2 ** 30:.2f} GB",
    f"Available Memory: {mem.available / 2 ** 30:.2f} GB",
    sep="\n",
    flush=True,
)

class TrainCGCNN():
    """
    Train, validate, and test a crystal graph convolutional neural network (CGCNN).

    The model predicts four target groups per crystal structure: the isotherm
    parameters and the enthalpy parameters (average, lower bound, upper bound).
    All hyper-parameters and paths are read from module-level names that the
    notebook pulls in via ``from config import *`` (e.g. ``jobPath``, ``lr``,
    ``num_epoch``, ``train``, ``test``, ``disable_checkpt``).
    """

    def __init__(self, resume=False):
        """
        Build the datasets, model, optimizer, scheduler, and checkpoint directory.

        Parameters
        ----------
        resume : bool, default False
            When False (the default, matching the previous behaviour) any existing
            ``state_dicts`` directory is wiped and training starts at epoch 0.
            When True the newest ``epoch_{i}_sd.pt`` found in that directory is
            loaded and training continues from the epoch stored in it.

        Raises
        ------
        ValueError
            If the configured ``optimizer`` name is not sgd / adam / adamax.
        """
        print("Initializing...", flush=True)

        # Log directory (train log and checkpoints live directly under jobPath)
        os.makedirs(jobPath, exist_ok=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")          # device
        print(self.device, flush=True)

        # torch.set_default_dtype() returns None, so keep the dtype itself on the instance
        self.dtype = torch.float64 if fp64 else torch.float32
        torch.set_default_dtype(self.dtype)

        # Number of parameters used to represent each target curve
        self.L_y_isotherm = 3                          # isotherm curve
        self.L_y_enthalpy = 3                          # enthalpy curve (average)
        self.L_y_enthalpy_LB = self.L_y_enthalpy       # enthalpy curve (lower bound)
        self.L_y_enthalpy_UB = self.L_y_enthalpy       # enthalpy curve (upper bound)

        # Structural inputs are always needed: the model dimensions below are derived
        # from them, and the test split is cut from the same data as the train split.
        if run_dataProcess:
            _, structureInputs = dataProcess.structure_inputs(dataDir=dataPath)
        else:
            print("Loading structural data...", flush=True)
            structureInputs = torch.load(dataPath+"X_dataset.pth")
            print("Done loading.", flush=True)

        print("Building train, validation, and test data...", flush=True)
        nodeFeat = structureInputs["nodeFeat"]
        bondFeat = structureInputs["bondFeat"]
        connectivityFeat = structureInputs["connectivityFeat"]
        # Crop the bond / connectivity matrices to N x N, N being the node count of nodeFeat
        n_nodes = nodeFeat.size(1)
        bondFeat = bondFeat[:, :n_nodes, :n_nodes]
        connectivityFeat = connectivityFeat[:, :n_nodes, :n_nodes]

        print("Getting x_data, y_data...", flush=True)
        # One tensor per structure: [node features | bond matrix | connectivity matrix]
        x_data = torch.cat((nodeFeat, bondFeat, connectivityFeat), dim=2).double()
        y_data = dataProcess.structure_outputs(dataDir=dataPath)

        print("Splitting...", flush=True)
        # 20% validation first, then 20% of the remainder as test (64 / 16 / 20 overall)
        x_train, x_val, y_train, y_val = train_test_split(x_data, y_data, test_size=0.2, random_state=42)
        x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.2, random_state=42)

        # Pinned host memory only helps host-to-GPU copies
        pin = self.device.type == "cuda"

        if train:
            print("Train y_data...", flush=True)
            train_targets = self._split_targets(y_train)
            print("Val y_data...", flush=True)
            val_targets = self._split_targets(y_val)

            print("TensorDataset...", flush=True)
            train_dataset = TensorDataset(x_train, *train_targets)
            val_dataset = TensorDataset(x_val, *val_targets)
            self.train_DataLoader = DataLoader(train_dataset, batch_size=train_batchSize, shuffle=True, pin_memory=pin)
            self.val_DataLoader = DataLoader(val_dataset, batch_size=val_batchSize, shuffle=True, pin_memory=pin)

        if test:
            print("Test y_data...", flush=True)
            test_dataset = TensorDataset(x_test, *self._split_targets(y_test))
            self.test_DataLoader = DataLoader(test_dataset, batch_size=test_batchSize, shuffle=True, pin_memory=pin)

        print("Done building train, validation, and test data.", flush=True)

        # Initialize model
        structureParams = {
            # Per-node input width: node feature columns + bond columns.
            # TODO: textural features still need to be added to the input.
            "dim_in": nodeFeat.size(2) + bondFeat.size(2),

            "n_convLayer": 3,
            "dim_out": [128, 64, 32],

            "n_hidLayer_pool": 2,
            "dim_hidFeat": [32, 16],

            "dim_fc_out": [self.L_y_isotherm, self.L_y_enthalpy, self.L_y_enthalpy_LB, self.L_y_enthalpy_UB]
        }

        self.model = CGCNN.CGCNNModel(structureParams)
        self.model.to(self.device)

        self.N = n_nodes            # max number of nodes across all crystal structures

        # Initialize optimizer and scheduler
        optimizer_name = optimizer.lower()
        if optimizer_name == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        elif optimizer_name == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        elif optimizer_name == "adamax":
            self.optimizer = torch.optim.Adamax(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            # Fail here instead of with an AttributeError on self.optimizer further down
            raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'sgd', 'adam' or 'adamax'")

        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=scheduler_gamma)

        # Checkpoint handling
        self.logPath = os.path.join(jobPath, "train_log.txt")
        self.start_epoch = 0
        if not disable_checkpt:
            self.statePath = os.path.join(jobPath, "state_dicts")
            if not resume and os.path.exists(self.statePath):
                # Fresh run: discard checkpoints left over from a previous run
                shutil.rmtree(self.statePath)
            os.makedirs(self.statePath, exist_ok=True)
            if resume:
                # Newest checkpoint wins: scan from the last possible epoch downwards
                for i in range(num_epoch, 0, -1):
                    fileName = os.path.join(self.statePath, f"epoch_{i}_sd.pt")
                    if os.path.isfile(fileName):
                        checkpoint = torch.load(fileName, map_location=self.device)
                        self.model.load_state_dict(checkpoint["model_state_dict"])
                        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                        self.start_epoch = checkpoint["epoch"]
                        self.model.eval()
                        break

    def _split_targets(self, y):
        """Slice the target matrix column-wise into (isotherm, enthalpy, enthalpy_LB, enthalpy_UB), each cast to double."""
        widths = (self.L_y_isotherm, self.L_y_enthalpy, self.L_y_enthalpy_LB, self.L_y_enthalpy_UB)
        parts = []
        start = 0
        for width in widths:
            parts.append(y[:, start:start + width].double())
            start += width
        return tuple(parts)

    def _prepare_batch(self, x_data_conn):
        """Move one input batch to the model device and split it into [node, bond, connectivity] plus the node-to-structure index."""
        N = self.N
        x_data_conn = x_data_conn.to(self.device, non_blocking=True)

        # Last 2N columns are the bond matrix followed by the connectivity matrix
        x_node = x_data_conn[:, :, 0:-(2 * N)]
        x_bond = x_data_conn[:, :, -(2 * N):-N]
        x_connectivity = x_data_conn[:, :, -N:]

        # All structures are padded to N nodes, so structure b owns N consecutive
        # entries: (0, ..., 0, 1, ..., 1, ..., B-1, ..., B-1)
        batchAssign = torch.arange(x_data_conn.size(0), device=self.device).repeat_interleave(N)
        return [x_node, x_bond, x_connectivity], batchAssign

    def _forward_loss(self, x_data_conn, targets):
        """Run the model on one batch and return (mse, mae) against the four target tensors."""
        x_data, batchAssign = self._prepare_batch(x_data_conn)
        targets = [t.to(self.device, non_blocking=True) for t in targets]
        predictions = self.model(x_data, batchAssign)
        return self.calcLoss(*predictions, *targets)

    def _evaluate(self, loader):
        """Score every batch of `loader` in eval mode without building autograd graphs; returns per-batch (mse, mae) lists."""
        batch_mse = []
        batch_mae = []
        self.model.eval()
        with torch.no_grad():
            for x_data_conn, *targets in loader:
                mse, mae = self._forward_loss(x_data_conn, targets)
                batch_mse.append(mse.item())
                batch_mae.append(mae.item())
        return batch_mse, batch_mae

    def calcLoss(self, y_pred_isotherm, y_pred_enthalpy, y_pred_enthalpy_LB, y_pred_enthalpy_UB, y_target_isotherm, y_target_enthalpy, y_target_enthalpy_LB, y_target_enthalpy_UB):
        """
        Compute the combined losses over the four output heads.

        Predictions are cast to double before comparison; targets are used as given.

        Returns
        -------
        (mse, mae) : tuple of scalar tensors
            Each is the plain (unweighted) sum of the per-head mean losses.
        """
        predictions = [
            y_pred_isotherm.double(),
            y_pred_enthalpy.double(),
            y_pred_enthalpy_LB.double(),
            y_pred_enthalpy_UB.double(),
        ]
        targets = [y_target_isotherm, y_target_enthalpy, y_target_enthalpy_LB, y_target_enthalpy_UB]

        mse_fn = nn.MSELoss()
        mae_fn = nn.L1Loss()

        # Every head contributes with weight 1; revisit if the heads need different weights.
        mse = sum(mse_fn(pred, target) for pred, target in zip(predictions, targets))
        mae = sum(mae_fn(pred, target) for pred, target in zip(predictions, targets))

        return mse, mae

    def train(self):
        """
        Run the training loop from ``self.start_epoch`` up to ``num_epoch``.

        After every epoch the model is scored on the validation loader; the epoch
        with the lowest mean validation MAE is tracked and, unless checkpointing
        is disabled, a checkpoint is written.

        Returns
        -------
        int
            1-based index of the best validation epoch. If there is no epoch left
            to run, ``self.start_epoch`` is returned instead.
        """
        best_mae = 1e10
        best_mae_epoch = 0
        last_epoch = None
        val_mse = []
        for epoch in range(self.start_epoch, num_epoch):
            last_epoch = epoch

            # Train
            train_mse = []
            train_mae = []
            self.model.train()
            for x_data_conn, *targets in tqdm(self.train_DataLoader, total=len(self.train_DataLoader)):
                mse, mae = self._forward_loss(x_data_conn, targets)

                self.optimizer.zero_grad()
                mse.backward()
                self.optimizer.step()
                train_mse.append(mse.item())
                train_mae.append(mae.item())
            self.scheduler.step()

            # Validation
            val_mse, val_mae = self._evaluate(self.val_DataLoader)
            mean_val_mae = np.mean(val_mae)

            if mean_val_mae < best_mae:
                best_mae = mean_val_mae
                best_mae_epoch = epoch
                message = f'epoch {epoch + 1}: weighted MAE {mean_val_mae:.4f} -> (best MAE, stored)'
                print(message, flush=True)
                # Context manager so the log handle is closed after each write
                with open(self.logPath, 'a') as log_file:
                    print(message, file=log_file)
            else:
                print(f'epoch {epoch + 1}: weighted MAE {mean_val_mae:.4f}', flush=True)

            # Save checkpoint every epoch
            if not disable_checkpt:
                statedict_filename = os.path.join(self.statePath, f"epoch_{epoch + 1}_sd.pt")
                torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "scheduler_state_dict": self.scheduler.state_dict()
                }, statedict_filename)

        if last_epoch is None:
            # Nothing ran (e.g. resumed from a checkpoint that is already at num_epoch)
            print(f"No epochs left to train (start epoch {self.start_epoch}, num_epoch {num_epoch}).", flush=True)
            return self.start_epoch

        print(f"Epoch {last_epoch + 1}/{num_epoch}, Average mse: {np.mean(val_mse):.4f}", flush=True)

        return best_mae_epoch + 1


    def test(self, best_epoch=0):
        """
        Evaluate the model on the test loader and print the mean MSE / MAE once.

        Parameters
        ----------
        best_epoch : int, default 0
            1-based epoch whose checkpoint should be evaluated; 0 means
            ``num_epoch``. If checkpointing is enabled and that checkpoint file
            exists, its model weights are loaded first; otherwise the weights
            currently held in memory are evaluated.
        """
        if not best_epoch:
            best_epoch = num_epoch

        state_dir = getattr(self, "statePath", None)
        if not disable_checkpt and state_dir is not None:
            fileName = os.path.join(state_dir, f"epoch_{best_epoch}_sd.pt")
            if os.path.isfile(fileName):
                checkpoint = torch.load(fileName, map_location=self.device)
                self.model.load_state_dict(checkpoint["model_state_dict"])

        # Testing
        test_mse, test_mae = self._evaluate(self.test_DataLoader)

        if test_mse:
            print(f"Average mse {np.mean(test_mse):.4f}, Average MAE: {np.mean(test_mae):.4f}", flush=True)


if __name__ == "__main__":
    # Build data + model, fit, then score the held-out split using the epoch
    # index that train() reports as best on validation.
    trainer = TrainCGCNN()
    best_epoch = trainer.train()
    trainer.test(best_epoch)

Total Memory: 31.67 GB
Available Memory: 26.82 GB
Initializing...
cpu
Loading structural data...
Done loading.
Building train, validation, and test data...
Getting x_data, y_data...
Splitting...
Train y_data...
Val y_data...
Test y_data...
TensorDataset...
Done building train, validation, and test data.


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:27<00:00,  1.95s/it]


epoch 1: weighted MAE 44.1807 -> (best MAE, stored)


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:35<00:00,  2.51s/it]


epoch 2: weighted MAE 42.8658 -> (best MAE, stored)


100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [00:34<00:00,  2.49s/it]


epoch 3: weighted MAE 41.6072 -> (best MAE, stored)


 57%|███████████████████████████████████████████████▍                                   | 8/14 [00:20<00:14,  2.39s/it]