In [1]:
import os
import shutil
import sys
import time
import warnings
from random import sample

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR, StepLR
from torch.utils.data.sampler import SubsetRandomSampler

from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score, f1_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, matthews_corrcoef
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split

import pytorch_lightning as L
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from pymatgen.core.composition import Composition

import torch
from torch.utils.data import DataLoader
from torch.nn import L1Loss, MSELoss, HuberLoss

data_type_np = np.float32
data_type_torch = torch.float32

from pytorch_forecasting.metrics import QuantileLoss
quantiles = [0.1, 0.5, 0.9]

import wandb

In [2]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters()if p.requires_grad)

In [3]:
from cgcnn.data import CIFData, collate_pool

## The data should be prepared as discribed in CIFdata doc-string

class CGCNNDataModule(L.LightningDataModule):
    def __init__(self, root_dir: str,
                 train_ratio: 0.8,
                 val_ratio: 0.1, 
                 test_ratio: 0.1,
                 collate_fn=collate_pool,
                 classification = False,
                 batch_size = 256,
                 num_workers=0,
                 pin_memory = True):
        super().__init__()
        self.dataset = CIFData(root_dir=root_dir, max_num_nbr=12, radius=10, dmin=0, step=0.2, random_seed=123)
        total_size = len(self.dataset)
        indices = list(range(total_size))
        
        self.collate_fn=collate_pool
        
        train_size = int(train_ratio * total_size)
        val_size = int(val_ratio * total_size)
        test_size = total_size - train_size - val_size
        self.train_sampler = SubsetRandomSampler(indices[:train_size])
        self.val_sampler = SubsetRandomSampler(indices[train_size:train_size+val_size])
        self.test_sampler = SubsetRandomSampler(indices[train_size+val_size:])
        
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.classification = classification
        self.num_workers=num_workers
  
    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          sampler=self.train_sampler, num_workers=self.num_workers,
                          collate_fn=self.collate_fn, pin_memory=self.pin_memory)
    def val_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          sampler=self.val_sampler, num_workers=self.num_workers,
                          collate_fn=self.collate_fn, pin_memory=self.pin_memory,shuffle=False)
    def test_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          sampler=self.test_sampler, num_workers=self.num_workers,
                          collate_fn=self.collate_fn, pin_memory=self.pin_memory,shuffle=False)
    def predict_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          sampler=self.test_sampler, num_workers=self.num_workers,
                          collate_fn=self.collate_fn, pin_memory=self.pin_memory,shuffle=False)

In [4]:
config={
    'root_dir': '/Users/elena.patyukova/Documents/github/Uncertainty-quntification/data/cgcnn_data',
    'train_ratio': 0.8,
    'val_ratio':0.1,
    'test_ratio':0.1,
    'atom_fea_len': 64,
    'n_conv': 3,
    'h_fea_len': 128,
    'n_h': 1,
    'classification': False,
    'robust_regression': False,
    'quantile_regression': True,
    'batch_size': 128,
    'base_lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 0,
    'optim': 'AdamW',
    'pin_memory': True,
    'patience': 100,
    'dropout_fr': 0.1,
}

In [5]:
from cgcnn.model import CrystalGraphConvNet

dataset = CIFData(root_dir='/Users/elena.patyukova/Documents/github/Uncertainty-quntification/data/cgcnn_data', max_num_nbr=12, radius=10, dmin=0, step=0.2, random_seed=123)
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model=CrystalGraphConvNet(orig_atom_fea_len=orig_atom_fea_len,
                                       nbr_fea_len=nbr_fea_len,
                                       atom_fea_len=config['atom_fea_len'], 
                                       n_conv=config['n_conv'], 
                                       h_fea_len=config['h_fea_len'], 
                                       n_h=config['n_h'],
                                       robust_regression=config['robust_regression'],
                                       classification=config['classification'],
                                       quantile_regression=config['quantile_regression'])

In [11]:
data=CGCNNDataModule(root_dir='/Users/elena.patyukova/Documents/github/Uncertainty-quntification/data/cgcnn_data', train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)

In [12]:
train_loader=data.train_dataloader()

In [13]:
for batch in train_loader:
    graph, target, idx = batch
    break



In [14]:
input_var=(graph[0],graph[1],graph[2],graph[3])

In [15]:
output=model(*input_var)

In [16]:
output.shape

torch.Size([256, 3])

In [17]:
criterion=QuantileLoss(quantiles=quantiles)

In [18]:
criterion(output,target)

tensor(51.2804, grad_fn=<CloneBackward0>)

In [19]:
a=output[:,1]
a.shape

torch.Size([256])

In [21]:
target[:,0].shape

torch.Size([256])

In [22]:
from cgcnn.model import CrystalGraphConvNet

class CGCNNLightning(L.LightningModule):
    def __init__(self, **config):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        
        self.dataset = CIFData(root_dir=config['root_dir'], max_num_nbr=12, radius=10, dmin=0, step=0.2, random_seed=123)
        structures, _, _ = self.dataset[0]
        orig_atom_fea_len = structures[0].shape[-1]
        nbr_fea_len = structures[1].shape[-1]
        
        self.model=CrystalGraphConvNet(orig_atom_fea_len=orig_atom_fea_len,
                                       nbr_fea_len=nbr_fea_len,
                                       atom_fea_len=config['atom_fea_len'], 
                                       n_conv=config['n_conv'], 
                                       h_fea_len=config['h_fea_len'], 
                                       n_h=config['n_h'],
                                       robust_regression=config['robust_regression'],
                                       classification=config['classification'],
                                       quantile_regression=config['quantile_regression'])

        print(f'Model size: {count_parameters(self.model)} parameters\n')

        ### here we define some important parameters
        self.batch_size=config['batch_size']
        self.classification = config['classification']
        self.robust_regression = config['robust_regression']
        self.quantile_regression=config['quantile_regression']
        self.base_lr=config['base_lr']
        self.momentum=config['momentum']
        self.decay=config['weight_decay']
        
        ### we also define loss function based on task
        if self.classification:
            print("Using BCE loss for classification task")
            self.criterion = BCEWithLogitsLoss
        elif self.robust_regression:
            print('Using RobustL2Loss for regression task')
            self.criterion = RobustL2Loss
        elif self.quantile_regression:
            self.criterion=QuantileLoss(quantiles=quantiles)
        else:
            # print("Using L1Loss loss for regression task")
            # self.criterion = L1Loss()
            # print('Using MSELoss loss for regression task')
            # self.criterion = MSELoss()
            print('Using HuberLoss for regression task')
            self.criterion = HuberLoss()

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        out=self.model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
        return out

    def configure_optimizers(self):
#         optimizer = optim.SGD(model.parameters(), self.base_lr,
#                               momentum=self.momentum,
#                               weight_decay=self.decay)
        optimizer = optim.AdamW(model.parameters(), self.base_lr,
                              weight_decay=self.decay)
        # lr_scheduler=StepLR(optimizer,
        #                     step_size=1,
        #                     gamma=0.5)
        
        # return [optimizer], [lr_scheduler]
        return [optimizer]

    def training_step(self, batch, batch_idx):
        graph, target, _ = batch
        
        input_var=(graph[0],graph[1],graph[2],graph[3])
        output = self(*input_var)
        
        if self.robust_regression:
            prediction, uncertainty = output.chunk(2, dim=-1)
            loss = self.criterion(prediction, uncertainty, target)
        elif self.quantile_regression:
            prediction=output[:,1]
            loss = self.criterion(output, target)
        else:
            loss = self.criterion(output,target)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        if self.classification:
            pass
#             prediction = torch.sigmoid(prediction)
#             y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            
#             acc=balanced_accuracy_score(target_normed,y_pred)
#             f1=f1_score(target_normed,y_pred,average='weighted')
#             mc=matthews_corrcoef(target_normed,y_pred)
            
#             self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
#             self.log("train_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
#             self.log("train_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        elif self.robust_regression:
            mse = mean_squared_error(target.cpu(),prediction.data.cpu())
            mae = mean_absolute_error(target.cpu(),prediction.data.cpu())
            self.log("train_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("train_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        elif self.quantile_regression:
            mse = mean_squared_error(target[:,0].cpu(),prediction.data.cpu())
            mae = mean_absolute_error(target[:,0].cpu(),prediction.data.cpu())
            self.log("train_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("train_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(target.cpu(),self.normalizer.denorm(output.data.cpu()))
            mae = mean_absolute_error(target.cpu(),self.normalizer.denorm(output.data.cpu()))
            self.log("train_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("train_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        graph, target, _ = batch
        
        input_var=(graph[0],graph[1],graph[2],graph[3])
        output = self(*input_var)
        
        if self.robust_regression:
            prediction, uncertainty = output.chunk(2, dim=-1)
            loss = self.criterion(prediction, uncertainty, target)
        elif self.quantile_regression:
            prediction=output[:,1]
            loss = self.criterion(output, target)
        else:
            loss = self.criterion(output,target)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        
        if self.classification:
            pass
#             prediction = torch.sigmoid(prediction)
#             y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            
#             acc=balanced_accuracy_score(target_normed,y_pred)
#             f1=f1_score(target_normed,y_pred,average='weighted')
#             mc=matthews_corrcoef(target_normed,y_pred)
            
#             self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
#             self.log("train_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
#             self.log("train_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        elif self.robust_regression:
                mse = mean_squared_error(target.cpu(),prediction.data.cpu())
                mae = mean_absolute_error(target.cpu(),prediction.data.cpu())
                self.log("val_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
                self.log("val_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size) 
        elif self.quantile_regression:
                mse = mean_squared_error(target[:,0].cpu(),prediction.data.cpu())
                mae = mean_absolute_error(target[:,0].cpu(),prediction.data.cpu())
                self.log("val_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
                self.log("val_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)    
        else:
                mse = mean_squared_error(target.cpu(),self.normalizer.denorm(output.data.cpu()))
                mae = mean_absolute_error(target.cpu(),self.normalizer.denorm(output.data.cpu()))
                self.log("val_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
                self.log("val_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        return loss
     
    def test_step(self, batch, batch_idx):
        graph, target, _ = batch
        
        input_var=(graph[0],graph[1],graph[2],graph[3])
        output = self(*input_var)
        
        if self.robust_regression:
            prediction, uncertainty = output.chunk(2, dim=-1)
            loss = self.criterion(prediction, uncertainty, target)
        elif self.quantile_regression:
            prediction=output[:,1]
            loss = self.criterion(output, target)
        else:
            loss = self.criterion(output,target)
        
        
        if self.classification:
            pass
        elif self.robust_regression:
            mse = mean_squared_error(target.cpu(),prediction.data.cpu())
            mae = mean_absolute_error(target.cpu(),prediction.data.cpu())
            self.log("test_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        elif self.quantile_regression:
            mse = mean_squared_error(target[:,0].cpu(),prediction.data.cpu())
            mae = mean_absolute_error(target[:,0].cpu(),prediction.data.cpu())
            self.log("test_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(target,self.normalizer.denorm(output.data.cpu()))
            mae = mean_absolute_error(target,self.normalizer.denorm(output.data.cpu()))
            self.log("test_mse", float(mse), on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_mae", float(mae), on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        return 
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        graph, target, idx = batch
        
        input_var=(graph[0],graph[1],graph[2],graph[3])
        output = self(*input_var)
        
        if self.classification:
            pass
            return
        
        elif self.robust_regression:
            prediction, uncertainty = output.chunk(2, dim=-1)
            return prediction.data.cpu(), uncertainty.data.cpu(), target, idx
        
        elif self.robust_regression:
            prediction=output[:,1]
            return prediction.data.cpu(), output.data.cpu(), target, idx
        
        else:
            return output.data.cpu(), target, idx


In [23]:
data=CGCNNDataModule(root_dir='/Users/elena.patyukova/Documents/github/Uncertainty-quntification/data/cgcnn_data', train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)

In [24]:
model = CGCNNLightning(**config)

Model size: 84931 parameters



In [25]:
trainer = Trainer(max_epochs=5,accelerator='gpu', devices=1, 
                  callbacks=[EarlyStopping(monitor='val_loss', patience=config['patience']), 
                             ModelCheckpoint(monitor='val_mae', mode="min", 
                                dirpath='cgcnn_models/cgcnn_trained_models/', filename='eform-{epoch:02d}-{val_acc:.2f}')])

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [26]:
trainer.fit(model, datamodule=data)


  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | model     | CrystalGraphConvNet | 84.9 K | train
1 | criterion | QuantileLoss        | 0      | train
----------------------------------------------------------
84.9 K    Trainable params
0         Non-trainable params
84.9 K    Total params
0.340     Total estimated model params size (MB)
28        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/opt/miniconda3/envs/llm/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]







Validation: |                                             | 0/? [00:00<?, ?it/s]



Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [28]:
pred = trainer.predict(model, ckpt_path='/Users/elena.patyukova/Documents/github/Uncertainty-quntification/cgcnn_models/cgcnn_trained_models/eform-epoch=03-val_acc=0.00.ckpt', datamodule=data)

Restoring states from the checkpoint path at /Users/elena.patyukova/Documents/github/Uncertainty-quntification/cgcnn_models/cgcnn_trained_models/eform-epoch=03-val_acc=0.00.ckpt
Loaded model weights from the checkpoint at /Users/elena.patyukova/Documents/github/Uncertainty-quntification/cgcnn_models/cgcnn_trained_models/eform-epoch=03-val_acc=0.00.ckpt
/opt/miniconda3/envs/llm/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |                                             | 0/? [00:00<?, ?it/s]



In [31]:
pred[0][0]

tensor([[23.6189, 37.4230, 78.4041],
        [23.4980, 36.8992, 77.4523],
        [23.6946, 37.2615, 78.0076],
        [23.9441, 37.9520, 79.3621],
        [23.2813, 36.6141, 76.8184],
        [23.2784, 36.5483, 76.6488],
        [23.7229, 37.2983, 77.8813],
        [24.3838, 38.6190, 80.6638],
        [24.1054, 38.2186, 79.7583],
        [23.4010, 36.6904, 76.8520],
        [22.9499, 36.3510, 76.3146],
        [23.1680, 36.3532, 76.5240],
        [24.0143, 38.0682, 79.6391],
        [25.3755, 40.2646, 83.7394],
        [23.2480, 36.2718, 76.1073],
        [22.8301, 35.8715, 75.4783],
        [24.2611, 38.4435, 80.1873],
        [23.9254, 37.7452, 78.9689],
        [23.0264, 36.3967, 76.6130],
        [22.9396, 36.0821, 75.9616],
        [23.2353, 36.4042, 76.5087],
        [23.0542, 36.1380, 76.0469],
        [23.7631, 37.5670, 78.6380],
        [23.1254, 36.2901, 76.3020],
        [22.5145, 35.5675, 74.8371],
        [23.3453, 36.5646, 76.9005],
        [22.7576, 36.0445, 75.8845],
 