In [1]:
import numpy as np
import pandas as pd
import collections
from collections import OrderedDict
import pytorch_lightning as L
import os
import re
import json
import tqdm

from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score, f1_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, matthews_corrcoef
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split

from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
# from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from pymatgen.core.composition import Composition
from crabnet.kingcrab import CrabNet

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR, StepLR

from crabnet.utils.utils import (Lamb, Lookahead, RobustL1, BCEWithLogitsLoss,
                         EDMDataset, get_edm, Scaler, DummyScaler, count_parameters)
from crabnet.utils.get_compute_device import get_compute_device
# from crabnet.utils.composition import _element_composition, get_sym_dict, parse_formula, CompositionError
#from utils.optim import SWA

data_type_np = np.float32
data_type_torch = torch.float32

import wandb


class CrabNetDataModule(L.LightningDataModule):
    def __init__(self, train_file: str , 
                 val_file: str, 
                 test_file: str,
                 n_elements ='infer', 
                 classification = False,
                 elem_prop='mat2vec',
                 batch_size = 2**10,
                 scale = True,
                 pin_memory = True):
        super().__init__()
        self.train_path = train_file
        self.val_path = val_file
        self.test_path = test_file
        self.batch_size = batch_size
        self.n_elements=n_elements
        self.pin_memory = pin_memory
        self.scale = scale
        self.classification = classification
        self.elem_prop=elem_prop

    def prepare_data(self):
        ### loading and encoding trianing data
        if(re.search('.json', self.train_path )):
            self.data_train=pd.read_json(self.train_path)
        elif(re.search('.csv', self.train_path)):
            self.data_train=pd.read_csv(self.train_path)

        self.train_main_data = list(get_edm(self.data_train, elem_prop=self.elem_prop,
                                      n_elements=self.n_elements,
                                      inference=False,
                                      verbose=True,
                                      drop_unary=False,
                                      scale=self.scale))
        
        self.train_len_data = len(self.train_main_data[0])
        self.train_n_elements = self.train_main_data[0].shape[1]//2

        print(f'loading data with up to {self.train_n_elements:0.0f} '
              f'elements in the formula for training')
        
        ### loading and encoding validation data
        if(re.search('.json', self.val_path )):
            self.data_val=pd.read_json(self.val_path)
        elif(re.search('.csv', self.val_path)):
            self.data_val=pd.read_csv(self.val_path)
        
        self.val_main_data = list(get_edm(self.data_val, elem_prop=self.elem_prop,
                                      n_elements=self.n_elements,
                                      inference=True,
                                      verbose=True,
                                      drop_unary=False,
                                      scale=self.scale))
        
        self.val_len_data = len(self.val_main_data[0])
        self.val_n_elements = self.val_main_data[0].shape[1]//2

        print(f'loading data with up to {self.val_n_elements:0.0f} '
              f'elements in the formula for validation')
        
        ### loading and encoding testing data
        if(re.search('.json', self.test_path )):
            self.data_test=pd.read_json(self.test_path)
        elif(re.search('.csv', self.test_path)):
            self.data_test=pd.read_csv(self.test_path)
        
        self.test_main_data = list(get_edm(self.data_test, elem_prop=self.elem_prop,
                                      n_elements=self.n_elements,
                                      inference=True,
                                      verbose=True,
                                      drop_unary=False,
                                      scale=self.scale))
        
        self.test_len_data = len(self.test_main_data[0])
        self.test_n_elements = self.test_main_data[0].shape[1]//2

        print(f'loading data with up to {self.test_n_elements:0.0f} '
              f'elements in the formula for testing')

        self.train_dataset = EDMDataset(self.train_main_data, self.train_n_elements)
        self.val_dataset = EDMDataset(self.val_main_data, self.val_n_elements)
        self.test_dataset = EDMDataset(self.test_main_data, self.test_n_elements)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          pin_memory=self.pin_memory, shuffle=True)
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                        pin_memory=self.pin_memory, shuffle=False)
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_len_data,
                        pin_memory=self.pin_memory, shuffle=False)
    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_len_data,
                        pin_memory=self.pin_memory, shuffle=False)


class CrabNetLightning(L.LightningModule):
    def __init__(self, **config):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = CrabNet(out_dims=config['out_dims'],
                             d_model=config['d_model'],
                             N=config['N'],
                             heads=config['heads'])
        print('\nModel architecture: out_dims, d_model, N, heads')
        print(f'{self.model.out_dims}, {self.model.d_model}, '
                  f'{self.model.N}, {self.model.heads}')
        print(f'Model size: {count_parameters(self.model)} parameters\n')

        ### here we define some important parameters
        self.fudge=config['fudge']
        self.batch_size=config['batch_size']
        self.classification = config['classification']
        self.base_lr=config['base_lr']
        self.max_lr=config['max_lr']
        ### here we also need to initialise scaler based on training data
        if(re.search('.json', config['train_path'] )):
            train_data=pd.read_json(config['train_path'])
        elif(re.search('.csv', config['train_path'])):
            train_data=pd.read_csv(config['train_path'])
        
        y=train_data['target'].values
        self.step_size = len(y)
        if self.classification:
            self.scaler = DummyScaler(y)
        else:
            self.scaler = Scaler(y)
        ### we also define loss function based on task
        if self.classification:
            if(np.sum(y)>0):
                self.weight=torch.tensor(((len(y)-np.sum(y))/np.sum(y))).cuda()
            print("Using BCE loss for classification task")
            self.criterion = BCEWithLogitsLoss
        else:
            print("Using RobustL1 loss for regression task")
            self.criterion = RobustL1

    def forward(self, src, frac):
        out=self.model(src, frac)
        return out

    def configure_optimizers(self):
        base_optim = Lamb(params=self.model.parameters(),lr=0.001)
        optimizer = Lookahead(base_optimizer=base_optim)
        lr_scheduler = CyclicLR(optimizer,
                                base_lr=self.base_lr,
                                max_lr=self.max_lr,
                                cycle_momentum=False,
                                step_size_up=self.step_size)
        # lr_scheduler=StepLR(optimizer,
        #                     step_size=3,
        #                     gamma=0.5)
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        loss = self.criterion(prediction.view(-1),
                              uncertainty.view(-1),
                              y.view(-1), self.weight)
        
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
            y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            acc=balanced_accuracy_score(y.view(-1).detach().cpu().numpy(),y_pred)
            f1=f1_score(y.view(-1).detach().cpu().numpy(),y_pred,average='weighted')
            mc=matthews_corrcoef(y.view(-1).detach().cpu().numpy(),y_pred)
            
            self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
            self.log("train_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("train_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(prediction.view(-1),y.view(-1))
            mae = mean_absolute_error(prediction.view(-1),y.view(-1))
            self.log("train_mse", mse, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("train_mae", mae, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        val_loss = self.criterion(prediction.view(-1),
                              uncertainty.view(-1),
                              y.view(-1), self.weight)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
            y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            acc=balanced_accuracy_score(y.view(-1).detach().cpu().numpy(),y_pred)
            f1=f1_score(y.view(-1).detach().cpu().numpy(),y_pred,average='weighted')
            mc=matthews_corrcoef(y.view(-1).detach().cpu().numpy(),y_pred)
            
            self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
            self.log("val_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("val_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(prediction.view(-1),y.view(-1))
            mae = mean_absolute_error(prediction.view(-1),y.view(-1))
            self.log("val_mse", mse, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("val_mae", mae, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return val_loss
     
    def test_step(self, batch, batch_idx):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
            y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            acc=balanced_accuracy_score(y.view(-1).detach().cpu().numpy(),y_pred)
            f1=f1_score(y.view(-1).detach().cpu().numpy(),y_pred,average='weighted')
            mc=matthews_corrcoef(y.view(-1).detach().cpu().numpy(),y_pred)
            
            self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
            self.log("test_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(prediction.view(-1),y.view(-1))
            mae = mean_absolute_error(prediction.view(-1),y.view(-1))
            self.log("test_mse", mse, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_mae", mae, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return 
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)

        y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
        acc=balanced_accuracy_score(y.view(-1).detach().cpu().numpy(),y_pred)
        f1=f1_score(y.view(-1).detach().cpu().numpy(),y_pred,average='weighted')
        mc=matthews_corrcoef(y.view(-1).detach().cpu().numpy(),y_pred)
        self.log("predict_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        self.log("predict_f1", f1, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        self.log("predict_mc", mc, on_step=False, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return formula, y_pred, prediction, uncertainty




In [2]:
with open('crabnet/crabnet_config.json','r') as f:
        config=json.load(f)

L.seed_everything(config['random_seed'])

Global seed set to 42


42

In [3]:
model = CrabNetLightning(**config)
# wandb_logger = WandbLogger(project="Crabnet-global-disorder-new", config=config, log_model="all")
trainer = Trainer(max_epochs=10,accelerator='gpu', devices=1, 
                      callbacks=[StochasticWeightAveraging(swa_epoch_start=config['swa_epoch_start'],swa_lrs=config['swa_lrs']),
                                EarlyStopping(monitor='val_loss', patience=config['patience']), ModelCheckpoint(monitor='val_acc', mode="max", 
                                dirpath='crabnet_models/crabnet_trained_models/', filename='disorder-{epoch:02d}-{val_acc:.2f}')])
disorder_data = CrabNetDataModule(config['train_path'],
                                   config['val_path'],
                                   config['test_path'],
                                   classification = config['classification'])
trainer.fit(model, datamodule=disorder_data)
trainer.test(ckpt_path='best',datamodule=disorder_data)



GPU available: True (cuda), used: True



Model architecture: out_dims, d_model, N, heads
3, 512, 3, 4
Model size: 11987206 parameters

Using BCE loss for classification task


TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 207925.06formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 208416.34formulae/s]


loading data with up to 16 elements in the formula for validation


Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 210505.90formulae/s]
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


loading data with up to 16 elements in the formula for testing


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type    | Params
----------------------------------
0 | model | CrabNet | 12.0 M
----------------------------------
12.0 M    Trainable params
23.8 K    Non-trainable params
12.0 M    Total params
48.044    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Swapping scheduler `CyclicLR` for `SWALR`


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 212528.78formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 205074.96formulae/s]


loading data with up to 16 elements in the formula for validation


Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 202331.59formulae/s]
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at crabnet_models/crabnet_trained_models/disorder-epoch=08-val_acc=0.89.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at crabnet_models/crabnet_trained_models/disorder-epoch=08-val_acc=0.89.ckpt


loading data with up to 16 elements in the formula for testing


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8879008745264143
         test_f1            0.8783106564072712
         test_mc            0.7603236724744891
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.8879008745264143,
  'test_f1': 0.8783106564072712,
  'test_mc': 0.7603236724744891}]

In [5]:
for x in disorder_data.predict_dataloader():
    X, y_true, formula = x

In [22]:
trainer.test(ckpt_path='best', datamodule=disorder_data)

Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 204244.68formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 198486.13formulae/s]


loading data with up to 16 elements in the formula for validation


Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 208400.89formulae/s]
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at crabnet_models/crabnet_trained_models/disorder-epoch=08-val_acc=0.89.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


loading data with up to 16 elements in the formula for testing


Loaded model weights from checkpoint at crabnet_models/crabnet_trained_models/disorder-epoch=08-val_acc=0.89.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8879008745264143
         test_f1            0.8783106564072712
         test_mc            0.7603236724744891
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.8879008745264143,
  'test_f1': 0.8783106564072712,
  'test_mc': 0.7603236724744891}]

In [27]:
formula, prediction, uncertainty=trainer.predict(ckpt_path='best', datamodule=disorder_data)[0]

Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 211632.95formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 198487.25formulae/s]


loading data with up to 16 elements in the formula for validation


Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 206336.60formulae/s]
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at crabnet_models/crabnet_trained_models/disorder-epoch=08-val_acc=0.89.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


loading data with up to 16 elements in the formula for testing


Loaded model weights from checkpoint at crabnet_models/crabnet_trained_models/disorder-epoch=08-val_acc=0.89.ckpt
  rank_zero_warn(


Predicting: 74it [00:00, ?it/s]

In [29]:
prediction

tensor([[1.0000],
        [0.9639],
        [0.0169],
        ...,
        [0.9998],
        [0.9998],
        [1.0000]])

In [30]:

y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5

In [31]:
y_pred,y_true,prediction

(array([ True,  True, False, ...,  True,  True,  True]),
 tensor([1., 1., 0.,  ..., 1., 1., 1.]),
 tensor([[1.0000],
         [0.9639],
         [0.0169],
         ...,
         [0.9998],
         [0.9998],
         [1.0000]]))

In [32]:
balanced_accuracy_score(y_true,y_pred)

0.8879008745264143

In [33]:
f1_score(y_true,y_pred,average='weighted')

0.8783106564072712

In [34]:
matthews_corrcoef(y_true,y_pred)

0.7603236724744891

In [20]:
roc_auc_score(y_true,prediction)

0.9507252526634921

In [None]:
def main(**config):
    model = CrabNetLightning(**config)
    wandb_logger = WandbLogger(project="Crabnet-global-disorder-new", config=config, log_model="all")
    trainer = Trainer(max_epochs=100,accelerator='gpu', devices=1, logger=wandb_logger,
                      callbacks=[StochasticWeightAveraging(swa_epoch_start=config['swa_epoch_start'],swa_lrs=config['swa_lrs']),
                                EarlyStopping(monitor='val_loss', patience=config['patience']), ModelCheckpoint(monitor='val_acc', mode="max", 
                                dirpath='crabnet_models/crabnet_trained_models/', filename='disorder-{epoch:02d}-{val_acc:.2f}')])
    disorder_data = CrabNetDataModule(config['train_path'],
                                   config['val_path'],
                                   config['test_path'],
                                   classification = config['classification'])
    trainer.fit(model, datamodule=disorder_data)
    trainer.test(ckpt_path='best',datamodule=disorder_data)

    formula, prediction, uncertainty=trainer.predict(ckpt_path='best', datamodule=disorder_data)
    metrics={}
    metrics['acc']=balanced_accuracy_score(y_true,y_pred)
    metrics['f1']=f1_score(y_true,y_pred,average='weighted')
    metrics['precision']=precision_score(y_true,y_pred)
    metrics['recall']=recall_score(y_true,y_pred)
    metrics['mc']=matthews_corrcoef(y_true,y_pred)
    metrics['roc_auc']=roc_auc_score(y_true,prediction)
    metrics['conf_matrix']=confusion_matrix(y_true,y_pred)
    pred_matrix={}
    pred_matrix['y_true']=y_true
    pred_matrix['y_score']=prediction.detach().numpy()
    pred_matrix['y_true']=y_pred
   
    wandb.log(metrics)
    wandb.log(pred_matrix)


    return

In [None]:

if __name__=='__main__':
    wandb.init(project="Crabnet-global-disorder-ne")
    wandb.login(key='b11d318e434d456c201ef1d3c86a3c1ce31b98d7')

    with open('crabnet/crabnet_config.json','r') as f:
        config=json.load(f)

    L.seed_everything(config['random_seed'])
    main(**config)

    wandb.finish()
    # print('Start sweeping with different parameters for RF...')

    # wandb.login(key='b11d318e434d456c201ef1d3c86a3c1ce31b98d7')

    # sweep_config = {
    # 'method': 'random',
    # 'parameters': {'n_estimators': {'values': [50, 100, 150, 200]},
    #                'class_weight': {'values':['balanced', 'balanced_subsample']},
    #                'criterion': {'values': ['gini', 'entropy', 'log_loss']}
    # }
    # }

    # sweep_id = wandb.sweep(sweep=sweep_config, project="RF-disorder-prediction-global-disorder")

    # wandb.agent(sweep_id, function=main, count=10)

    # wandb.finish()

In [18]:
class CrabNetDataModule(L.LightningDataModule):
    def __init__(self, train_file: str , 
                 val_file: str, 
                 test_file: str,
                 n_elements ='infer', 
                 classification = False,
                 batch_size = 2**10,
                 scale = True,
                 pin_memory = True,
                 num_workers = 1):
        super().__init__()
        self.train_path = train_file
        self.val_path = val_file
        self.test_path = test_file
        self.batch_size = batch_size
        self.n_elements=n_elements
        self.pin_memory = pin_memory
        self.scale = scale
        self.classification = classification
        self.num_workers = num_workers

    def prepare_data(self):
        ### loading and encoding trianing data
        if(re.search('.json', self.train_path )):
            self.data_train=pd.read_json(self.train_path)
        elif(re.search('.csv', self.train_path)):
            self.data_train=pd.read_csv(self.train_path)

        self.train_main_data = list(get_edm(self.data_train, elem_prop='mat2vec',
                                      n_elements=self.n_elements,
                                      inference=False,
                                      verbose=True,
                                      drop_unary=False,
                                      scale=self.scale))
        
        self.train_len_data = len(self.train_main_data[0])
        self.train_n_elements = self.train_main_data[0].shape[1]//2

        print(f'loading data with up to {self.train_n_elements:0.0f} '
              f'elements in the formula for training')
        
        ### loading and encoding validation data
        if(re.search('.json', self.val_path )):
            self.data_val=pd.read_json(self.val_path)
        elif(re.search('.csv', self.val_path)):
            self.data_val=pd.read_csv(self.val_path)
        
        self.val_main_data = list(get_edm(self.data_val, elem_prop='mat2vec',
                                      n_elements=self.n_elements,
                                      inference=True,
                                      verbose=True,
                                      drop_unary=False,
                                      scale=self.scale))
        
        self.val_len_data = len(self.val_main_data[0])
        self.val_n_elements = self.val_main_data[0].shape[1]//2

        print(f'loading data with up to {self.val_n_elements:0.0f} '
              f'elements in the formula for validation')
        
        ### loading and encoding testing data
        if(re.search('.json', self.test_path )):
            self.data_test=pd.read_json(self.test_path)
        elif(re.search('.csv', self.test_path)):
            self.data_test=pd.read_csv(self.test_path)
        
        self.test_main_data = list(get_edm(self.data_test, elem_prop='mat2vec',
                                      n_elements=self.n_elements,
                                      inference=True,
                                      verbose=True,
                                      drop_unary=False,
                                      scale=self.scale))
        
        self.test_len_data = len(self.test_main_data[0])
        self.test_n_elements = self.test_main_data[0].shape[1]//2

        print(f'loading data with up to {self.test_n_elements:0.0f} '
              f'elements in the formula for testing')

    def setup(self, stage: str):
        ### creating dataloaders for training
        if stage == "fit":
            self.train_dataset = EDMDataset(self.train_main_data, self.train_n_elements)
            self.val_dataset = EDMDataset(self.val_main_data, self.val_n_elements)
            
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.test_dataset = EDMDataset(self.test_main_data, self.test_n_elements)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          pin_memory=self.pin_memory, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                        pin_memory=self.pin_memory, shuffle=False,num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                        pin_memory=self.pin_memory, shuffle=False,num_workers=self.num_workers)
    
    def predict_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_len_data,
                        pin_memory=self.pin_memory, shuffle=False,num_workers=self.num_workers)


class CrabNetLightning(L.LightningModule):
    def __init__(self, **config):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = CrabNet(out_dims=config['out_dims'],
                             d_model=config['d_model'],
                             N=config['N'],
                             heads=config['heads'])
        print('\n Model architecture: out_dims, d_model, N, heads')
        print(f'{self.model.out_dims}, {self.model.d_model}, '
                  f'{self.model.N}, {self.model.heads}')
        print(f'Model size: {count_parameters(self.model)} parameters\n')

        ### here we define some important parameters
        self.fudge=config['fudge']
        self.batch_size=config['batch_size']
        self.classification = config['classification']
        self.base_lr=config['base_lr']
        self.max_lr=config['max_lr']
        ### here we also need to initialise scaler based on training data
        if(re.search('.json', config['train_path'] )):
            train_data=pd.read_json(config['train_path'])
        elif(re.search('.csv', config['train_path'])):
            train_data=pd.read_csv(config['train_path'])
        
        y=train_data['target'].values
        self.step_size = len(y)
        if self.classification:
            self.scaler = DummyScaler(y)
        else:
            self.scaler = Scaler(y)
        ### we also define loss function based on task
        if self.classification:
            if(np.sum(y)>0):
                self.weight=torch.tensor(((len(y)-np.sum(y))/np.sum(y)),dtype=data_type_torch).to(device)
            print("Using BCE loss for classification task")
            self.criterion = BCEWithLogitsLoss
        else:
            print("Using RobustL1 loss for regression task")
            self.criterion = RobustL1


    def forward(self, src, frac):
        out=self.model(src, frac)
        return out

    def configure_optimizers(self):
        base_optim = Lamb(params=self.model.parameters(),lr=0.001)
        optimizer = Lookahead(base_optimizer=base_optim)
        lr_scheduler = CyclicLR(optimizer,
                                base_lr=self.base_lr,
                                max_lr=self.max_lr,
                                cycle_momentum=False,
                                step_size_up=self.step_size)
        # lr_scheduler=StepLR(optimizer,
        #                     step_size=3,
        #                     gamma=0.5)
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        loss = self.criterion(prediction.view(-1),
                              uncertainty.view(-1),
                              y.view(-1), self.weight)
        
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
            y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            acc=accuracy_score(y_pred,y.view(-1).detach().cpu().numpy())
            f1=f1_score(y_pred,y.view(-1).detach().cpu().numpy())
            # auc=roc_auc_score(prediction.view(-1).detach().cpu().numpy(),y.view(-1).detach().cpu().numpy())
            self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
            self.log("train_f1", f1, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            # self.log("train_auc", auc, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(prediction.view(-1),y.view(-1))
            mae = mean_absolute_error(prediction.view(-1),y.view(-1))
            self.log("train_mse", mse, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("train_mae", mae, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        val_loss = self.criterion(prediction.view(-1),
                              uncertainty.view(-1),
                              y.view(-1), self.weight)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
        if self.classification:
            y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            acc=accuracy_score(y_pred,y.view(-1).detach().cpu().numpy())
            f1=f1_score(y_pred,y.view(-1).detach().cpu().numpy())
            # auc=roc_auc_score(prediction.view(-1).detach().cpu().numpy(),y.view(-1).detach().cpu().numpy())
            self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.batch_size)
            self.log("val_f1", f1, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            # self.log("val_auc", auc, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(prediction.view(-1),y.view(-1))
            mae = mean_absolute_error(prediction.view(-1),y.view(-1))
            self.log("val_mse", mse, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("val_mae", mae, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return val_loss
     
    def test_step(self, batch, batch_idx):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
        if self.classification:
            y_pred = prediction.view(-1).detach().cpu().numpy() > 0.5
            acc=accuracy_score(y_pred,y.view(-1).detach().cpu().numpy())
            f1=f1_score(y_pred,y.view(-1).detach().cpu().numpy())
            # auc=roc_auc_score(prediction.view(-1).detach().cpu().numpy(),y.view(-1).detach().cpu().numpy())
            self.log("test_acc", acc, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_f1", f1, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            # self.log("test_auc", auc, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        else:
            mse = mean_squared_error(prediction.view(-1),y.view(-1))
            mae = mean_absolute_error(prediction.view(-1),y.view(-1))
            self.log("test_mse", mse, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
            self.log("test_mae", mae, on_step=True, on_epoch=True, prog_bar=False, logger=True, batch_size=self.batch_size)
        return 
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        X, y, formula = batch
        y = self.scaler.scale(y)
        src, frac = X.squeeze(-1).chunk(2, dim=1)
        frac = frac * (1 + (torch.randn_like(frac))*self.fudge)
        frac = torch.clamp(frac, 0, 1)
        frac[src == 0] = 0
        frac = frac / frac.sum(dim=1).unsqueeze(1).repeat(1, frac.shape[-1])
        
        output = self(src, frac)
        prediction, uncertainty = output.chunk(2, dim=-1)
        uncertainty = torch.exp(uncertainty) * self.scaler.std
        prediction = self.scaler.unscale(prediction)
        if self.classification:
            prediction = torch.sigmoid(prediction)
        return formula, prediction, uncertainty

In [3]:
print('Loading the data...')

Loading the data...


In [12]:
disorder_data = CrabNetDataModule(config['train_path'],
                                   config['val_path'],
                                   config['test_path'],
                                   classification = config['classification'])

In [14]:
disorder_data.prepare_data()

Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 230354.83formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 231825.29formulae/s]

loading data with up to 16 elements in the formula for validation



Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 230580.65formulae/s]

loading data with up to 16 elements in the formula for testing





In [15]:
disorder_data.setup(stage='fit')

In [16]:
trainloader=disorder_data.train_dataloader()

In [17]:
for batch in trainloader:
    print(batch)
    
    break

[tensor([[[17.],
         [11.],
         [47.],
         ...,
         [ 0.],
         [ 0.],
         [ 0.]],

        [[ 8.],
         [ 1.],
         [16.],
         ...,
         [ 0.],
         [ 0.],
         [ 0.]],

        [[34.],
         [49.],
         [48.],
         ...,
         [ 0.],
         [ 0.],
         [ 0.]],

        ...,

        [[52.],
         [64.],
         [55.],
         ...,
         [ 0.],
         [ 0.],
         [ 0.]],

        [[ 8.],
         [22.],
         [20.],
         ...,
         [ 0.],
         [ 0.],
         [ 0.]],

        [[ 8.],
         [26.],
         [15.],
         ...,
         [ 0.],
         [ 0.],
         [ 0.]]]), tensor([1., 0., 1.,  ..., 0., 1., 1.]), ('Ag0.201 Cl1 Na0.799', 'H2 Gd1 K1 O9 S2', 'Cd0.95 Cu0.05 In2 Se4', 'Gd5 Ge10 Ir4', 'Nd2 Ni17', 'Cr4 Nb1 Zr1', 'Ba2 Ga1 Li1 S4', 'C1.5 Fe14.6 Ga2.4 Sm2', 'Cu5 O4 Rb3', 'Gd1 Si1 Ti1', 'Fe7 Sm1', 'Co2.73 Cu0.27 O4', 'Cu0.96 Eu1 Sn1.04', 'Ga0.15 La1 Mn0.85 O3', 'Ce1 Mo6 Se8'

In [1]:
config={'train_path': 'data/crabnet_data/train.csv',
            'val_path': 'data/crabnet_data/val.csv',
            'test_path':'data/crabnet_data/test.csv',
            'out_dims': 3,
            'd_model': 512,
            'N': 3,
            'heads': 4,
            'classification': True,
            'batch_size': 2**12,
            'fudge': 0,
            'random_seed': 42,
            'swa_epoch_start' : 0.05,
            'swa_lrs': 1e-2,
            'base_lr': 1e-4,
            'max_lr': 6e-3,
            'schedule': 'CyclicLR',
            'patience': 10,
            'num_workers' : 1 }

In [4]:
import json

with open('crabnet/crabnet_config.json','w') as f:
    json.dump(config,f, indent=4)

In [4]:
data_file='data/general_disorder.csv'
df=pd.read_csv('data/general_disorder.csv',usecols=['formula', 'disorder'])
df=df.rename(columns={'disorder':'target'})

In [5]:
index=np.linspace(0,len(df)-1,len(df),dtype=int)
train_idx,test_idx= train_test_split(index, test_size=0.2, random_state=42)
train_idx,val_idx= train_test_split(train_idx, test_size=0.1, random_state=42)
val_set = df.iloc[val_idx]
val_set.to_csv('data/crabnet_data/val.csv',index=False)
test_set = df.iloc[test_idx]
test_set.to_csv('data/crabnet_data/test.csv',index=False)
train_set = df.iloc[train_idx]
train_set.to_csv('data/crabnet_data/train.csv',index=False)

In [9]:
path=''
config={'train_path': path+'data/crabnet_data/train.csv',
            'val_path': path+'data/crabnet_data/val.csv',
            'test_path': path+'data/crabnet_data/test.csv',
            'out_dims': 3,
            'd_model': 512,
            'N': 3,
            'heads': 4,
            'classification': True,
            'batch_size': 2**12,
            'fudge': 0,
            'random_seed': 42,
            'swa_epoch_start' : 0.05,
            'swa_lrs': 1e-2,
            'base_lr': 1e-4,
            'max_lr': 6e-3,
            'schedule': 'CyclicLR',
            'patience': 10,
            'num_workers' : 1 }

In [7]:
L.seed_everything(config['random_seed'])

Seed set to 42


42

In [8]:
model = CrabNetLightning(**config)


 Model architecture: out_dims, d_model, N, heads
3, 512, 3, 4
Model size: 11987206 parameters

Using BCE loss for classification task




In [9]:
trainer = Trainer(max_epochs=1,accelerator='gpu', devices=1,
                      callbacks=[StochasticWeightAveraging(swa_epoch_start=config['swa_epoch_start'],swa_lrs=config['swa_lrs']),
                                EarlyStopping(monitor='val_loss', patience=config['patience']), ModelCheckpoint(monitor='val_acc', mode="max", 
                                dirpath=path+'crabnet_models/trained_models/', filename='disorder-{epoch:02d}-{val_acc:.2f}')])

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/elenapatyukova/anaconda3/envs/disorder_pred/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [17]:
disorder_data = CrabNetDataModule(config['train_path'],
                                   config['val_path'],
                                   config['test_path'],
                                   classification = config['classification'])

In [11]:
trainer.fit(model, datamodule=disorder_data)

Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 235357.34formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 230421.06formulae/s]

loading data with up to 16 elements in the formula for validation



Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 230603.14formulae/s]

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | CrabNet | 12.0 M | train
------------------------------------------
12.0 M    Trainable params
23.8 K    Non-trainable params
12.0 M    Total params
48.044    Total estimated model params size (MB)


loading data with up to 16 elements in the formula for testing


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/elenapatyukova/anaconda3/envs/disorder_pred/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/Users/elenapatyukova/anaconda3/envs/disorder_pred/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Swapping scheduler `CyclicLR` for `SWALR`


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [18]:
trainer.test(ckpt_path='best',datamodule=disorder_data)

Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 232381.28formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 247921.00formulae/s]

loading data with up to 16 elements in the formula for validation



Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 246702.89formulae/s]
Restoring states from the checkpoint path at /Users/elenapatyukova/Documents/GitHub/Disorder-prediction-new/crabnet_models/trained_models/disorder-epoch=00-val_acc=0.61.ckpt
Loaded model weights from the checkpoint at /Users/elenapatyukova/Documents/GitHub/Disorder-prediction-new/crabnet_models/trained_models/disorder-epoch=00-val_acc=0.61.ckpt


loading data with up to 16 elements in the formula for testing


/Users/elenapatyukova/anaconda3/envs/disorder_pred/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.6029446721076965
      test_f1_epoch         0.7522377967834473
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc_epoch': 0.6029446721076965, 'test_f1_epoch': 0.7522377967834473}]

In [19]:
pred = trainer.predict(ckpt_path='best',datamodule=disorder_data)

Generating EDM: 100%|██████████| 75091/75091 [00:00<00:00, 235414.16formulae/s]


loading data with up to 16 elements in the formula for training


Generating EDM: 100%|██████████| 8344/8344 [00:00<00:00, 239556.39formulae/s]

loading data with up to 16 elements in the formula for validation



Generating EDM: 100%|██████████| 20859/20859 [00:00<00:00, 242176.67formulae/s]
Restoring states from the checkpoint path at /Users/elenapatyukova/Documents/GitHub/Disorder-prediction-new/crabnet_models/trained_models/disorder-epoch=00-val_acc=0.61.ckpt
Loaded model weights from the checkpoint at /Users/elenapatyukova/Documents/GitHub/Disorder-prediction-new/crabnet_models/trained_models/disorder-epoch=00-val_acc=0.61.ckpt


loading data with up to 16 elements in the formula for testing


/Users/elenapatyukova/anaconda3/envs/disorder_pred/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'predict_dataloader' to speed up the dataloader worker initialization.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [None]:
import matplotlib.pyplot as plt

plt.hist(pred[0][2],bins=30)

KeyboardInterrupt: 

Error in callback <function _draw_all_if_interactive at 0x15d89e3e0> (for post_execute), with arguments args (),kwargs {}:


KeyboardInterrupt: 

In [1]:
y=[0.1,0.2,0.5]

In [3]:
from crabnet.utils.utils import (Lamb, Lookahead, RobustL1, BCEWithLogitsLoss,
                         EDMDataset, get_edm, Scaler, DummyScaler, count_parameters)

In [4]:
z=Scaler(y)

In [6]:
z.scale(y)

tensor([-0.8006, -0.3203,  1.1209])