In [1]:
import os
import shutil
import sys
import time
import warnings
import json
import shutil
from random import sample

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR, StepLR
from torch.utils.data.sampler import SubsetRandomSampler, Sampler, SequentialSampler

from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score, f1_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, matthews_corrcoef
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split

import pytorch_lightning as L
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer
from torchmetrics.functional import mean_squared_error, mean_absolute_error
from pymatgen.core.composition import Composition
from pymatgen.core.structure import Structure
from torch.utils.data import DataLoader,random_split
from torch.nn import L1Loss, MSELoss, HuberLoss

# Default floating-point dtypes for NumPy arrays and PyTorch tensors.
data_type_np, data_type_torch = np.float32, torch.float32

# import wandb

import dgl
from jarvis.db.jsonutils import loadjson
import pickle as pk
import pandas as pd
import lmdb
from dgl.dataloading import GraphDataLoader
from jarvis.core.atoms import Atoms
from jarvis.core.graphs import Graph

## Data Formatting

The data should be formatted as for CGCNN: individual structures are stored as CIF files in the folder `data_dir="alignn/examples/sample_data/"`. This folder must also contain `id_prop.csv` with two columns and no header row. The first column is the structure's file name (including the extension, not just the id as in CGCNN), and the second column is the target property (klength in our case). Integer ids are parsed from the file names, so each file should be named `<integer id>.cif`. Finally, `data_dir` must contain a `config_example.json` compatible with `ALIGNNConfig`. Be careful here: the config shipped with the alignn GitHub repository is the `alignn_atomwise` config, which has a different format and cannot be used as-is.

In [2]:
def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [3]:
from torch.utils.data import Dataset
from typing import List, Tuple
from alignn.lmdb_dataset import prepare_line_graph_batch

class TorchLMDBDataset(Dataset):
    """Map-style dataset of pickled ALIGNN samples stored in an LMDB database.

    Entries are keyed by the string form of their integer position ("0", "1", ...)
    and hold a pickled ``(graph, line_graph, label, ids)`` tuple, which is the
    layout written by ``create_lmdb_database`` in this notebook.

    Security note: entries are deserialized with ``pickle``, which can execute
    arbitrary code. Only open databases you created or otherwise trust.
    """

    def __init__(self, lmdb_path="", line_graph=True):
        """Open ``lmdb_path`` read-only and cache its number of entries.

        Args:
            lmdb_path: path to the LMDB environment.
            line_graph: if True, samples are ``(graph, line_graph, label, ids)``;
                otherwise the line graph is dropped and ``(graph, label, ids)``
                is returned.
        """
        super().__init__()
        self.lmdb_path = lmdb_path
        self.line_graph = line_graph
        self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False)
        with self.env.begin() as txn:
            self.length = txn.stat()["entries"]
        self.prepare_batch = prepare_line_graph_batch

    def __len__(self):
        """Return the number of entries in the database."""
        return self.length

    def __getitem__(self, idx):
        """Load and unpickle the sample stored under key ``str(idx)``.

        Raises:
            IndexError: if no entry exists for ``idx``. This also lets plain
                ``for x in dataset`` iteration terminate cleanly.
        """
        with self.env.begin() as txn:
            serialized_data = txn.get(f"{idx}".encode())
        if serialized_data is None:
            raise IndexError(f"index {idx} not found in LMDB at {self.lmdb_path!r}")
        graph, line_graph, label, ids = pk.loads(serialized_data)
        if self.line_graph:
            return graph, line_graph, label, ids
        # Same stored layout; the line graph is simply not returned.
        return graph, label, ids

    def close(self):
        """Close the LMDB environment if it was opened."""
        # getattr: __init__ may have failed before ``env`` was assigned, and
        # __del__ still calls close().
        env = getattr(self, "env", None)
        if env is not None:
            env.close()

    def __del__(self):
        """Release the LMDB environment when the dataset is garbage-collected."""
        self.close()

    @staticmethod
    def collate_line_graph(
        samples: List[Tuple[dgl.DGLGraph, dgl.DGLGraph, torch.Tensor, torch.Tensor]]
    ):
        """Batch ``(graph, line_graph, label, ids)`` samples for a DataLoader.

        Graphs and line graphs are merged with ``dgl.batch``. Labels and ids are
        stacked when the labels are non-scalar tensors; otherwise they are
        gathered into 1-D tensors.
        """
        graphs, line_graphs, labels, ids = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        batched_line_graph = dgl.batch(line_graphs)
        if len(labels[0].size()) > 0:
            return batched_graph, batched_line_graph, torch.stack(labels), torch.stack(ids)
        else:
            return batched_graph, batched_line_graph, torch.tensor(labels), torch.tensor(ids)

In [4]:
# Input folder: CIF files, id_prop.csv and the ALIGNN JSON config live here.
data_dir = "alignn/examples/sample_data/"
config_name, id_prop_csv = (
    os.path.join(data_dir, fname) for fname in ("config_example.json", "id_prop.csv")
)
# Destination folder for outputs.
output_dir = "alignn/output/"

In [5]:
def create_lmdb_database(data, file_path, data_dir, cutoff=15.0, max_neighbors=12,
                         atom_features="cgcnn", neighbor_strategy="k-nearest"):
    """Build an LMDB database of ALIGNN graphs from one split of ``id_prop.csv``.

    For each row of ``data`` (column 0: structure file name relative to
    ``data_dir``; column 1: target value), this function:

    - reads the structure with pymatgen;
    - converts it to a JARVIS ``Atoms`` object;
    - builds the crystal graph and line graph with ``Graph.atom_dgl_multigraph``;
    - stores the pickled tuple ``(g, lg, label, ids)`` under the key ``str(row_position)``.

    Args:
        data: DataFrame with the file name in column 0 and the target in column 1.
        file_path: name of the LMDB environment, created inside ``data_dir``.
        data_dir: directory containing the structure files.
        cutoff, max_neighbors, atom_features, neighbor_strategy: forwarded to
            ``Graph.atom_dgl_multigraph``. The defaults match the values
            previously hard-coded here.

    The id stored with each sample is the integer stem of the file name
    (e.g. ``"12345.cif"`` -> ``12345``). Non-integer stems raise ValueError.
    """
    env = lmdb.open(os.path.join(data_dir, file_path), map_size=int(1e12))
    try:
        with env.begin(write=True) as txn:
            for idx in range(len(data)):
                # Positional column access: column 0 = file name, column 1 = target.
                file_name = data.iloc[idx, 0]
                target = data.iloc[idx, 1]

                structure = Structure.from_file(os.path.join(data_dir, file_name))
                sites = structure.sites
                structure_dict = {
                    'lattice_mat': structure.lattice.matrix,
                    'coords': [site.frac_coords for site in sites],
                    'elements': [str(site.specie) for site in sites],
                    'abc': structure.lattice.abc,
                    'angles': structure.lattice.angles,
                    # Coordinates above are fractional.
                    'cartesian': False,
                    'props': ['' for _ in sites],
                }
                atoms = Atoms.from_dict(structure_dict)
                g, lg = Graph.atom_dgl_multigraph(
                    atoms,
                    cutoff=float(cutoff),
                    max_neighbors=max_neighbors,
                    atom_features=atom_features,
                    compute_line_graph=True,
                    use_canonize=False,
                    neighbor_strategy=neighbor_strategy,
                )
                label = torch.tensor(target).type(torch.get_default_dtype())
                stem = os.path.splitext(os.path.basename(file_name))[0]
                ids = torch.tensor(int(stem))
                txn.put(f"{idx}".encode(), pk.dumps((g, lg, label, ids)))
    finally:
        # Always release the environment, even if a structure fails to parse.
        env.close()

In [6]:
class ALIGNNDataModule(L.LightningDataModule):
    """Split ``id_prop.csv`` and serve LMDB-backed ALIGNN graphs.

    The headerless CSV (column 0: structure file name relative to ``root_dir``;
    column 1: target) is split into train/val/test. The splits are written to
    ``root_dir`` as ``train.csv``, ``val.csv`` and ``test.csv``. Unless
    ``lmdb_exist`` is true, ``train_data.lmdb``, ``val_data.lmdb`` and
    ``test_data.lmdb`` are deleted and rebuilt from those splits.

    Args:
        root_dir: directory with the structure files, split CSVs and LMDB files.
        id_prop_csv: path to the headerless id/property CSV.
        train_ratio: fraction of the whole dataset used for training.
        val_ratio: accepted for interface compatibility but not used.
            Validation receives whatever remains after train and test.
        test_ratio: fraction of the whole dataset used for testing.
        lmdb_exist: reuse existing LMDB files instead of rebuilding them.
        batch_size, num_workers, pin_memory: forwarded to every DataLoader.
        random_seed: seed for both the train/test and the train/val split.

    Raises:
        FileNotFoundError: if ``lmdb_exist`` is true but any LMDB file is missing.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(self, root_dir: str,
                 id_prop_csv: str,
                 train_ratio = 0.8,
                 val_ratio = 0.1,
                 test_ratio = 0.1,
                 lmdb_exist = False,
                 batch_size = 64,
                 num_workers = 0,
                 pin_memory = True,
                 random_seed = 123):
        super().__init__()

        self.random_seed = random_seed
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.num_workers = num_workers

        data = pd.read_csv(id_prop_csv, header=None)
        # Both splits honour ``random_seed`` (the default 123 reproduces the old splits).
        train_idx, test_idx = train_test_split(
            data.index.values, test_size=test_ratio, random_state=self.random_seed)
        # train_size is relative to the non-test remainder, so the training set
        # is train_ratio of the whole dataset and validation gets the rest.
        train_idx, val_idx = train_test_split(
            train_idx, train_size=train_ratio / (1 - test_ratio),
            random_state=self.random_seed)

        splits = {}
        for name, idx in zip(self._SPLITS, (train_idx, val_idx, test_idx)):
            split_df = data.iloc[idx].reset_index(drop=True)
            split_df.to_csv(os.path.join(root_dir, f'{name}.csv'))
            splits[name] = split_df

        lmdb_names = {name: f'{name}_data.lmdb' for name in self._SPLITS}
        lmdb_paths = {name: os.path.join(root_dir, fname) for name, fname in lmdb_names.items()}

        if not lmdb_exist:
            # Rebuild from scratch so stale entries never survive.
            for name in self._SPLITS:
                if os.path.exists(lmdb_paths[name]):
                    shutil.rmtree(lmdb_paths[name])
                create_lmdb_database(splits[name], lmdb_names[name], root_dir)
        else:
            missing = [path for path in lmdb_paths.values() if not os.path.exists(path)]
            if missing:
                raise FileNotFoundError(
                    "Put lmdb_exist to False or provide train/val/test lmdb files. "
                    f"Missing: {missing}")

        self.train_dataset = TorchLMDBDataset(lmdb_path=lmdb_paths['train'], line_graph=True)
        self.val_dataset = TorchLMDBDataset(lmdb_path=lmdb_paths['val'], line_graph=True)
        self.test_dataset = TorchLMDBDataset(lmdb_path=lmdb_paths['test'], line_graph=True)

        self.train_collate = self.train_dataset.collate_line_graph
        self.val_collate = self.val_dataset.collate_line_graph
        self.test_collate = self.test_dataset.collate_line_graph

    def _make_loader(self, dataset, collate_fn, shuffle):
        """Build a DataLoader using this module's batch/worker/pin-memory settings."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, collate_fn=collate_fn,
                          pin_memory=self.pin_memory, shuffle=shuffle)

    def train_dataloader(self, shuffle=True):
        """Training loader (shuffled by default)."""
        return self._make_loader(self.train_dataset, self.train_collate, shuffle)

    def val_dataloader(self, shuffle=False):
        """Validation loader."""
        return self._make_loader(self.val_dataset, self.val_collate, shuffle)

    def test_dataloader(self, shuffle=False):
        """Test loader."""
        return self._make_loader(self.test_dataset, self.test_collate, shuffle)

    def predict_dataloader(self, shuffle=False):
        """Prediction loader; serves the test split."""
        return self._make_loader(self.test_dataset, self.test_collate, shuffle)

In [7]:
# Training hyperparameters plus the "model" section consumed by ALIGNNConfig.
config = loadjson(config_name)

In [8]:
from alignn.models.alignn import ALIGNN, ALIGNNConfig

class ALIGNNLightning(L.LightningModule):
    """Lightning wrapper around an ALIGNN regression model.

    ``config`` is the dict loaded from the ALIGNN JSON config. Keys read here are
    ``random_seed``, ``model`` (kwargs for ``ALIGNNConfig``), ``batch_size``,
    ``learning_rate`` and ``weight_decay``.

    The loss is ``HuberLoss``. Loss, MSE and MAE are logged per split as
    epoch-level metrics (``train_*``, ``val_*``, ``test_*``).
    """

    def __init__(self, **config):
        super().__init__()
        # Seed before the model is built so weight initialization is reproducible.
        L.seed_everything(config['random_seed'])
        # Records ``config`` in ``self.hparams`` (and therefore in checkpoints).
        self.save_hyperparameters()

        model_config = ALIGNNConfig(**config['model'])
        self.model = ALIGNN(model_config)
        print(f'Model size: {count_parameters(self.model)} parameters\n')

        # Optimization hyperparameters.
        self.batch_size = config['batch_size']
        self.learning_rate = config['learning_rate']
        self.decay = config['weight_decay']

        self.criterion = HuberLoss()

    def forward(self, graph_list):
        """Run ALIGNN on ``[graph, line_graph]`` and return its predictions."""
        return self.model(graph_list)

    def configure_optimizers(self):
        """Return an AdamW optimizer over this module's parameters."""
        # self.parameters(), not a notebook-global ``model``: the optimizer must
        # track the module Lightning is actually training.
        optimizer = optim.AdamW(self.parameters(), self.learning_rate,
                                weight_decay=self.decay)
        return [optimizer]

    def _forward_batch(self, batch):
        """Unpack a collated ``(g, lg, labels, ids)`` batch and run the model.

        Returns ``(output, labels, ids)``, with ``output`` reshaped to
        ``labels.shape`` so the loss and metrics compare like-shaped tensors.
        """
        g, lg, labels, ids = batch
        output = self([g, lg])
        # NOTE(review): assumes one prediction per label element. The model's
        # output may lose the batch dimension for single-sample batches.
        output = output.reshape(labels.shape)
        return output, labels, ids

    def _log_metrics(self, stage, loss, output, labels):
        """Log ``{stage}_loss``, ``{stage}_mse`` and ``{stage}_mae`` for one batch."""
        # Imported explicitly: the notebook also imports sklearn functions with
        # these names, so the module-level binding depends on import order.
        from torchmetrics.functional import mean_absolute_error as tm_mae
        from torchmetrics.functional import mean_squared_error as tm_mse

        # Use the real batch size so the epoch average is weighted correctly
        # when the last batch is smaller than ``self.batch_size``.
        log_kwargs = dict(on_step=False, on_epoch=True, logger=True,
                          batch_size=len(labels))
        preds, target = output.detach(), labels.detach()
        self.log(f"{stage}_loss", loss, prog_bar=True, **log_kwargs)
        self.log(f"{stage}_mse", tm_mse(preds, target), prog_bar=False, **log_kwargs)
        self.log(f"{stage}_mae", tm_mae(preds, target), prog_bar=True, **log_kwargs)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/metrics; return the loss for backprop."""
        output, labels, _ = self._forward_batch(batch)
        loss = self.criterion(output, labels)
        self._log_metrics("train", loss, output, labels)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/metrics."""
        output, labels, _ = self._forward_batch(batch)
        loss = self.criterion(output, labels)
        self._log_metrics("val", loss, output, labels)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss/metrics."""
        output, labels, _ = self._forward_batch(batch)
        loss = self.criterion(output, labels)
        self._log_metrics("test", loss, output, labels)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(predictions, labels, ids)`` for one batch, moved to CPU."""
        output, labels, ids = self._forward_batch(batch)
        return output.cpu(), labels.cpu(), ids.cpu()

In [9]:
# Reuse the LMDB files already in data_dir (pass lmdb_exist=False to rebuild them).
data = ALIGNNDataModule(
    root_dir=data_dir,
    id_prop_csv=id_prop_csv,
    lmdb_exist=True,
)

In [10]:
# Seeds RNGs, builds ALIGNN from config["model"] and prints its trainable-parameter count.
model = ALIGNNLightning(**config)

Seed set to 123


Model size: 4026753 parameters



In [11]:
# Early stopping tracks val_loss; checkpoints keep the best val_mae.
# The checkpoint filename may only reference metrics the model logs
# (val_loss / val_mae / val_mse); val_acc is never logged.
trainer = Trainer(max_epochs=1000, accelerator='gpu', devices=1,
                  callbacks=[EarlyStopping(monitor='val_loss', patience=50),
                             ModelCheckpoint(monitor='val_mae', mode="min",
                                             dirpath='alignn_models/alignn_trained_models/',
                                             filename='alignn-{epoch:02d}-{val_mae:.2f}')])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/xyq44482/micromamba/envs/alignn/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
# Train with the splits and loaders provided by the data module.
trainer.fit(model=model, datamodule=data)

You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/xyq44482/micromamba/envs/alignn/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/xyq44482/Documents/Uncertainty-quntification/alignn_models/alignn_trained_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params | Mode 
------------------------------------------------
0 | model     | ALIGNN    | 4.0 M  | train
1 | criterion | HuberLoss | 0      | train
------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.107    Total estimated model pa

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/xyq44482/micromamba/envs/alignn/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.
  return torch.load(io.BytesIO(b))
/home/xyq44482/micromamba/envs/alignn/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

In [29]:
# CPU trainer used for inference; callbacks kept for parity with the training cell.
# The checkpoint filename references val_mae, a metric the model actually logs.
trainer = Trainer(max_epochs=1000, accelerator='cpu', devices=1,
                  callbacks=[EarlyStopping(monitor='val_loss', patience=50),
                             ModelCheckpoint(monitor='val_mae', mode="min",
                                             dirpath='alignn_models/alignn_trained_models/',
                                             filename='alignn-{epoch:02d}-{val_mae:.2f}')])

# NOTE(review): ModelCheckpoint writes files named "alignn-<epoch>-<val_mae>.ckpt";
# "alignn.ckpt" is presumably a manually renamed best checkpoint. Confirm it exists.
ALIGNN_CKPT_PATH = "alignn_models/alignn_trained_models/alignn.ckpt"
pred = trainer.predict(model, ckpt_path=ALIGNN_CKPT_PATH, datamodule=data)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/xyq44482/micromamba/envs/llm/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Restoring states from the checkpoint path at alignn_models/alignn_trained_models/alignn.ckpt
Loaded model weights from the checkpoint at alignn_models/alignn_trained_models/alignn.ckpt
/home/xyq44482/micromamba/envs/llm/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


Predicting: |                                             | 0/? [00:00<?, ?it/s]

In [None]:
# Training-split predictions. The checkpoint must be an ALIGNN checkpoint for
# this model, not a CGCNN one. shuffle=False keeps predictions in dataset order.
pred_train = trainer.predict(model, data.train_dataloader(shuffle=False),
                             ckpt_path="alignn_models/alignn_trained_models/alignn.ckpt")

In [None]:
# Validation-split predictions, using the same ALIGNN checkpoint as the other
# prediction cells (a CGCNN checkpoint cannot be loaded into this model).
pred_val = trainer.predict(model, data.val_dataloader(),
                           ckpt_path="alignn_models/alignn_trained_models/alignn.ckpt")

In [41]:
# Sanity check: count the training batches and show only the first one.
num_batches = 0
for batch in data.train_dataloader():
    if num_batches == 0:
        print(batch)
    num_batches += 1
print(f"{num_batches} training batch(es)")

[Graph(num_nodes=61, num_edges=1508,
      ndata_schemes={'atom_features': Scheme(shape=(92,), dtype=torch.float32)}
      edata_schemes={'r': Scheme(shape=(3,), dtype=torch.float32)}), Graph(num_nodes=1508, num_edges=37506,
      ndata_schemes={'r': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'h': Scheme(shape=(), dtype=torch.float32)}), tensor([30.0000, 45.0000, 57.5000, 71.2500, 60.0000, 45.0000, 32.5000, 25.0000,
        65.0000, 20.0000]), tensor([54784, 41010, 57391, 59819, 60833, 38078, 60814, 44557, 20725, 15453])]


In [40]:
# Display `num_batches`.
# NOTE(review): this cell (In [40]) executed before the cell that initializes
# `num_batches` (In [41]); re-run it after that cell so the shown value
# matches a fresh-kernel run.
num_batches

1