In [1]:
import numpy as np
import pandas as pd
from pymatgen.core.composition import Composition

import torch
import torch.nn as nn
import os
import re
import json
import pytorch_lightning as L
import wandb

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR, StepLR
from torch.nn import CrossEntropyLoss, L1Loss, MSELoss, ReLU, NLLLoss
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, balanced_accuracy_score, accuracy_score, roc_auc_score, matthews_corrcoef
from sklearn.metrics import precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay

from roost.Data import data_from_composition_general
from roost.Model import Roost
from roost.utils import count_parameters, Scaler, DummyScaler, BCEWithLogitsLoss, Lamb, Lookahead, get_compute_device

from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

data_type_np = np.float32
data_type_torch = torch.float32
device=get_compute_device()

In [33]:
class RoostDataModule(L.LightningDataModule):
    def __init__(self, train_file: str , 
                 val_file: str, 
                 test_file: str, 
                 batch_size = 256,
                 features='onehot'):
        super().__init__()
        self.train_path = train_file
        self.val_path = val_file
        self.test_path = test_file
        self.batch_size = batch_size
        self.features=features

    def prepare_data(self):
        path='data/el-embeddings/'
        if(self.features == 'onehot'):
            with open(path+'onehot-embedding.json',"r") as f:
                elem_features=json.load(f)
        elif(self.features == 'matscholar'):
            with open(path+'matscholar-embedding.json',"r") as f:
                elem_features=json.load(f)
        elif(self.features == 'mat2vec'):
            with open(path+'mat2vec.json',"r") as f:
                elem_features=json.load(f)
        elif(self.features == 'cgcnn'):
            with open(path+'cgcnn-embedding.json',"r") as f:
                elem_features=json.load(f)
        
        ### loading and encoding trianing data
        if(re.search('.json', self.train_path )):
            self.data_train=pd.read_json(self.train_path)
        elif(re.search('.csv', self.train_path)):
            self.data_train=pd.read_csv(self.train_path)

        self.train_dataset = data_from_composition_general(self.data_train,elem_features)
        self.train_len = len(self.train_dataset)
        
        ### loading and encoding validation data
        if(re.search('.json', self.val_path )):
            self.data_val=pd.read_json(self.val_path)
        elif(re.search('.csv', self.val_path)):
            self.data_val=pd.read_csv(self.val_path)
        
        self.val_dataset = data_from_composition_general(self.data_val,elem_features)
        self.val_len = len(self.val_dataset)

        ### loading and encoding testing data
        if(re.search('.json', self.test_path )):
            self.data_test=pd.read_json(self.test_path)
        elif(re.search('.csv', self.test_path)):
            self.data_test=pd.read_csv(self.test_path)
        
        self.test_dataset = data_from_composition_general(self.data_test,elem_features)
        self.test_len = len(self.test_dataset)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_len, shuffle=False)
    
    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_len, shuffle=False)
    

class RoostLightningClass(L.LightningModule):
    def __init__(self, **config):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        self.batch_size=config['data_params']['batch_size']
        self.out_dims=config['model_params']['output_dim']
        self.n_graphs=config['model_params']['n_graphs']
        self.comp_heads=config['model_params']['comp_heads']
        self.internal_elem_dim=config['model_params']['internal_elem_dim']
        self.setup=config['setup_params']
        self.model = Roost(**config['model_params'])
        # maybe need to do it, to unify Roost and CrabNet
        print('\n Model architecture: out_dims, n_graphs, heads, internal_elem_dim')
        print(f'{self.out_dims}, {self.n_graphs}, '
                  f'{self.comp_heads}, {self.internal_elem_dim}')
        print(f'Model size: {count_parameters(self.model)} parameters\n')
        
        if(config['setup_params']['loss'] == 'BCEWithLogitsLoss'):
            self.criterion = BCEWithLogitsLoss

        if(re.search('.json', config['data_params']['train_path'] )):
            train_data=pd.read_json(config['data_params']['train_path'])
        elif(re.search('.csv', config['data_params']['train_path'])):
            train_data=pd.read_csv(config['data_params']['train_path'])
        y=train_data['disorder'].values
        self.step_size = len(y)
        if(np.sum(y)>0):
            self.weight=torch.tensor(((len(y)-np.sum(y))/np.sum(y)),dtype=data_type_torch).to(device)   

    def forward(self, batch):
        out = self.model(batch.x, batch.edge_index, batch.pos, batch.batch)
        return out

    def configure_optimizers(self):
        if(self.setup['optim'] == 'AdamW'):
        # We use AdamW optimizer with MultistepLR scheduler as in the original Roost model
            optimizer = torch.optim.AdamW(self.parameters(),lr=self.setup['learning_rate'], 
                                        weight_decay=self.setup['weight_decay']) 
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=self.setup['gamma'])

        elif(self.setup['optim'] == 'Lamb'):
            base_optim = Lamb(params=self.model.parameters(),lr=0.001)
            optimizer = Lookahead(base_optimizer=base_optim)
            scheduler = CyclicLR(optimizer,
                                base_lr=self.setup['base_lr'],
                                max_lr=self.setup['max_lr'],
                                cycle_momentum=False,
                                step_size_up=self.step_size)

        return [optimizer], [scheduler]
    
    def training_step(self, batch, batch_idx):
        logits=self(batch)
        loss=self.criterion(logits, batch.y,self.weight)
        prediction = torch.sigmoid(logits)
        y_pred = prediction.detach().cpu().numpy() > 0.5
        acc=balanced_accuracy_score(batch.y.detach().cpu().numpy().astype(bool),y_pred)
        f1=f1_score(batch.y.detach().cpu().numpy().astype(bool),y_pred)
        mc=matthews_corrcoef(batch.y.detach().cpu().numpy().astype(bool),y_pred)

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("train_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        self.log("train_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return loss
    
    def validation_step(self, batch, batch_idx):
        logits=self(batch)
        loss=self.criterion(logits, batch.y,self.weight)
        prediction = torch.sigmoid(logits)
        y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
        acc=balanced_accuracy_score(batch.y.detach().cpu().numpy().astype(bool),y_pred)
        f1=f1_score(batch.y.detach().cpu().numpy().astype(bool),y_pred)
        mc=matthews_corrcoef(batch.y.detach().cpu().numpy().astype(bool),y_pred)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("val_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        self.log("val_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return loss
    
    def test_step(self, batch, batch_idx):
        logits=self(batch)
        loss=self.criterion(logits, batch.y,self.weight)
        prediction = torch.sigmoid(logits)
        y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
        acc=balanced_accuracy_score(batch.y.detach().cpu().numpy().astype(bool),y_pred)
        f1=f1_score(batch.y.detach().cpu().numpy().astype(bool),y_pred)
        mc=matthews_corrcoef(batch.y.detach().cpu().numpy().astype(bool),y_pred)

        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("test_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        self.log("test_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return loss
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        logits=self(batch)
        prediction = torch.sigmoid(logits)
        y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
        
        return batch.y.view(-1).detach().cpu().numpy(), prediction, y_pred
    
def main(**config):
    L.seed_everything(config['seed'])

    data_file = 'data/general_disorder.csv'
    df=pd.read_csv(data_file,usecols=['formula', 'disorder'])
    index=np.linspace(0,len(df)-1,len(df),dtype=int)
    train_idx,test_idx= train_test_split(index, test_size=0.2, random_state=config['seed'])
    train_idx,val_idx= train_test_split(train_idx, test_size=0.1, random_state=config['seed'])
    val_set = df.iloc[val_idx]
    val_set.to_csv('data/roost_data/val.csv',index=False)
    test_set = df.iloc[test_idx]
    test_set.to_csv('data/roost_data/test.csv',index=False)
    train_set = df.iloc[train_idx]
    train_set.to_csv('data/roost_data/train.csv',index=False)
    
    wandb_logger = WandbLogger(project="Roost-global-disorder", config=config, log_model="all")
    model = RoostLightningClass(**config)
    trainer = Trainer(devices=1, accelerator='gpu',max_epochs=config['epochs'], logger=wandb_logger, 
                  callbacks=[StochasticWeightAveraging(swa_epoch_start=config['setup_params']['swa_epoch_start'],swa_lrs=config['setup_params']['swa_lrs']),
                             ModelCheckpoint(monitor='val_acc', mode='max',dirpath='roost_models/trained_models/', filename='disorder-{epoch:02d}-{val_acc:.2f}'),
                             EarlyStopping(monitor='val_loss', mode='max', patience=config['patience']),
                             LearningRateMonitor(logging_interval='step')])
    disorder_data = RoostDataModule(config['data_params']['train_path'],
                                   config['data_params']['val_path'],
                                   config['data_params']['test_path'], features=config['data_params']['embed'])
    trainer.fit(model, datamodule=disorder_data)
    y_true, prediction, y_pred=trainer.predict(ckpt_path='best', datamodule=disorder_data)[0]
    metrics={}
    metrics['acc']=balanced_accuracy_score(y_true,y_pred)
    metrics['f1']=f1_score(y_true,y_pred)
    metrics['precision']=precision_score(y_true,y_pred)
    metrics['recall']=recall_score(y_true,y_pred)
    metrics['mc']=matthews_corrcoef(y_true,y_pred)
    metrics['roc_auc']=roc_auc_score(y_true,prediction)
    pred_matrix={}
    pred_matrix['y_true']=y_true
    pred_matrix['y_score']=prediction.detach().numpy()
    pred_matrix['y_true']=y_pred
   
    wandb.log(metrics)
    wandb.log(pred_matrix)
    return

In [94]:
with open('roost/roost_config.json','r') as f:
    config=json.load(f)
path='data/el-embeddings/'
if(config['data_params']['embed']=='onehot'):
    with open(path+'onehot-embedding.json',"r") as f:
        elem_features=json.load(f)
elif(config['data_params']['embed']=='matscholar'):
    with open(path+'matscholar-embedding.json',"r") as f:
        elem_features=json.load(f)
elif(config['data_params']['embed']=='mat2vec'):
    with open(path+'mat2vec.json',"r") as f:
        elem_features=json.load(f)
elif(config['data_params']['embed']=='cgcnn'):
    with open(path+'cgcnn-embedding.json',"r") as f:
        elem_features=json.load(f)

elem_emb_len=len(elem_features['H'])
config['model_params']['input_dim']=elem_emb_len

In [97]:
config['setup_params']

{'optim': 'AdamW',
 'learning_rate': 0.0001,
 'weight_decay': 1e-06,
 'momentum': 0.9,
 'loss': 'BCEWithLogitsLoss',
 'base_lr': 0.001,
 'max_lr': 0.006,
 'swa_epoch_start': 0.2,
 'swa_lrs': 0.0001,
 'gamma': 0.2}

In [99]:
model = RoostLightningClass.load_from_checkpoint('roost_energy_models/trained_models/energy-epoch=93-val_acc=0.00.ckpt',
                                                classification=True,criterion=BCEWithLogitsLoss,setup=config['setup_params'])


 Model architecture: out_dims, n_graphs, heads, internal_elem_dim
1, 3, 3, 64
Model size: 2352057 parameters



In [105]:
model.setup, model.hparams.classification, model.hparams.setup, model.criterion

AttributeError: 'RoostLightningClass' object has no attribute 'criterion'

In [44]:
disorder_data.prepare_data()
dataloader=disorder_data.train_dataloader()

In [45]:
for ind,batch in enumerate(dataloader):
    print(ind,batch)
    break

0 DataBatch(x=[961, 200], edge_index=[2, 4159], y=[256], pos=[961], batch=[961], ptr=[257])


In [53]:
model.hparams.classification=True

In [73]:
model.hparams.criterion

<function roost.utils.BCEWithLogitsLoss(output, target, weight=None)>

In [58]:
model.hparams

"classification": True
"data_params":    {'embed': 'matscholar', 'batch_size': 128, 'train_path': 'data/energy_data/train_fe.csv', 'val_path': 'data/energy_data/val_fe.csv', 'test_path': 'data/energy_data/test_fe.csv'}
"epochs":         100
"model_name":     roost
"model_params":   {'input_dim': 200, 'output_dim': 1, 'hidden_layer_dims': [1024, 512, 256, 128, 64], 'n_graphs': 3, 'elem_heads': 3, 'internal_elem_dim': 64, 'g_elem_dim': 256, 'f_elem_dim': 256, 'comp_heads': 3, 'g_comp_dim': 128, 'f_comp_dim': 128, 'batchnorm': False, 'negative_slope': 0.2}
"patience":       100
"seed":           42
"setup_params":   {'optim': 'AdamW', 'learning_rate': 0.0003, 'weight_decay': 1e-06, 'momentum': 0.9, 'loss': 'L1Loss', 'base_lr': 0.001, 'max_lr': 0.006, 'swa_epoch_start': 0.2, 'swa_lrs': 0.0003, 'gamma': 0.2}
"test_size":      0.2
"val_size":       0.1

In [43]:
disorder_data = RoostDataModule(config['data_params']['train_path'],
                                   config['data_params']['val_path'],
                                   config['data_params']['test_path'], features=config['data_params']['embed'])

In [59]:
checkpoint = torch.load('roost_energy_models/trained_models/energy-epoch=93-val_acc=0.00.ckpt')

In [60]:
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])

In [63]:
checkpoint['callbacks']

{'StochasticWeightAveraging': {'n_averaged': 75,
  'latest_update_epoch': 93,
  'scheduler_state': {'anneal_func': <function torch.optim.swa_utils.SWALR._cosine_anneal(t)>,
   'anneal_epochs': 10,
   'base_lrs': [0.0003],
   'last_epoch': 176,
   '_step_count': 76,
   'verbose': False,
   '_get_lr_called_within_step': False,
   '_last_lr': [0.0003]},
  'average_model_state': OrderedDict([('model.material_nn.project_fea.weight',
                tensor([[ 0.1814,  0.0198,  0.0808,  ..., -0.0408,  0.1799, -0.1739],
                        [ 0.1793,  0.0948, -0.1398,  ..., -0.2047,  0.0222,  0.1107],
                        [-0.1685, -0.0192,  0.0807,  ...,  0.1732, -0.0293,  0.1981],
                        ...,
                        [ 0.1598,  0.1110,  0.2008,  ..., -0.0157,  0.0770, -0.0737],
                        [ 0.1671,  0.0894,  0.0350,  ...,  0.2250,  0.2377, -0.1426],
                        [ 0.1947,  0.0192,  0.1408,  ..., -0.0987,  0.1253,  0.0840]])),
               ('mod

In [25]:
s='model.material_nn.project_fea.weight'
s[6:]

'material_nn.project_fea.weight'

In [30]:
from collections import OrderedDict

checkpoint1=OrderedDict()
for key,value in checkpoint['state_dict'].items():
    checkpoint1[key[6:]]=value

In [31]:
model.load_state_dict(checkpoint1)

RuntimeError: Error(s) in loading state_dict for Roost:
	Unexpected key(s) in state_dict: "al_nn.project_fea.weight". 

In [74]:
model.hparams.classification

True

In [85]:
model.hparams.criterion(torch.tensor([0.0,1.0]),torch.tensor([0.0,1.0]))

tensor(0.5032)

In [86]:
model.criterion

AttributeError: 'RoostLightningClass' object has no attribute 'criterion'

In [77]:
model.hparams.criterion(torch.tensor([True,False]),torch.tensor([True,False]))

RuntimeError: Negation, the `-` operator, on a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.

In [78]:
model.hparams.criterion

<function roost.utils.BCEWithLogitsLoss(output, target, weight=None)>

In [81]:
loss=BCEWithLogitsLoss
loss(torch.tensor([0.0,1.0]),torch.tensor([0.0,1.0]))

tensor(0.5032)