In [29]:
import os
import shutil
import sys
import time
import warnings
from random import sample

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR, StepLR
from torch.utils.data.sampler import SubsetRandomSampler

from sklearn.metrics import balanced_accuracy_score, accuracy_score, roc_auc_score, f1_score
from sklearn.metrics import mean_absolute_error, mean_squared_error, matthews_corrcoef
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, train_test_split

import pytorch_lightning as L
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer
from torchmetrics.functional import mean_squared_error, mean_absolute_error

from pymatgen.core.composition import Composition
from pymatgen.core.structure import Structure

import torch
from torch.utils.data import DataLoader
from torch.nn import L1Loss, MSELoss, HuberLoss

# Default floating-point precision shared by numpy arrays and torch tensors.
data_type_np, data_type_torch = np.float32, torch.float32

import wandb

In [30]:
# Installed pytorch_lightning version (the package is imported as `L`).
pl_version = L.__version__
pl_version

'2.4.0'

In [42]:
class Normalizer(object):
    """Normalize a Tensor and restore it later.

    Keeps the mean and standard deviation of a reference sample and uses them to
    map values into standardized space (``norm``) and back (``denorm``).
    """

    def __init__(self, tensor):
        """Compute mean and std from a reference sample.

        Args:
            tensor: reference sample (a tensor or anything ``torch.as_tensor``
                accepts). Non-floating inputs are converted to float so that
                mean/std are defined.

        Raises:
            ValueError: if the sample is empty.
        """
        tensor = torch.as_tensor(tensor)
        if not tensor.is_floating_point():
            tensor = tensor.float()
        if tensor.numel() == 0:
            raise ValueError("Normalizer needs a non-empty sample to compute mean and std")
        self.mean = torch.mean(tensor)
        if tensor.numel() < 2:
            # torch.std of a single value is NaN (zero degrees of freedom) and
            # warns; use unit scale so norm/denorm stay finite.
            self.std = torch.ones_like(self.mean)
        else:
            std = torch.std(tensor)
            # A constant sample has std 0, which would make norm() divide by zero.
            self.std = std if std > 0 else torch.ones_like(std)

    def norm(self, tensor):
        """Map values into standardized space: (x - mean) / std."""
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        """Invert ``norm``: x * std + mean."""
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        """Return the statistics as a dict with keys 'mean' and 'std'."""
        return {'mean': self.mean,
                'std': self.std}

    def load_state_dict(self, state_dict):
        """Overwrite the statistics from a dict with keys 'mean' and 'std'."""
        self.mean = state_dict['mean']
        self.std = state_dict['std']

In [43]:
# Settings for the CGCNN data pipeline, model architecture and optimiser.
config = dict(
    # Directory holding the CIF files used to build the dataset.
    root_dir='/Users/elena.patyukova/Documents/github/qe-input/src/qe_input/cgcnn/cgcnn_data',
    # Data split fractions.
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    # CrystalGraphConvNet architecture.
    atom_fea_len=64,
    n_conv=3,
    h_fea_len=128,
    n_h=1,
    classification=False,
    # Training settings.
    batch_size=256,
    base_lr=0.01,
    momentum=0.9,
    weight_decay=0.1,
    optim='SGD',
    pin_memory=True,
    patience=50,
)

In [44]:
# Example structure taken from the same CIF directory the dataset is built
# from; deriving the path from config keeps the two in sync.
structure = Structure.from_file(os.path.join(config['root_dir'], '0.cif'))
structure

Structure Summary
Lattice
    abc : 4.93725642 4.937256420000001 5.44931282
 angles : 90.0 90.0 120.00000589
 volume : 115.0386328504509
      A : 4.93725642 0.0 3.02319763566136e-16
      B : -2.468628649550688 4.275789230943074 3.02319763566136e-16
      C : 0.0 0.0 5.44931282
    pbc : True True True
PeriodicSite: Si0 (Si) (-0.1869, 3.056, 3.449) [0.3195, 0.7146, 0.6329]
PeriodicSite: Si1 (Si) (0.8728, 1.22, 4.449) [0.3195, 0.2854, 0.8164]
PeriodicSite: Si2 (Si) (1.473, 0.0, 9.019e-17) [0.2983, 0.0, 0.0]
PeriodicSite: O3 (O) (2.714, 2.201, 2.381) [0.807, 0.5146, 0.4369]
PeriodicSite: O4 (O) (-0.2512, 2.512, 1.381) [0.2429, 0.5875, 0.2534]
PeriodicSite: O5 (O) (1.329, 0.7423, 3.381) [0.356, 0.1736, 0.6205]
PeriodicSite: O6 (O) (-0.2823, 3.533, 4.517) [0.356, 0.8264, 0.8288]
PeriodicSite: O7 (O) (0.1809, 1.764, 1.068) [0.2429, 0.4125, 0.196]
PeriodicSite: O8 (O) (2.786, 2.075, 0.06753) [0.807, 0.4854, 0.01239]

In [45]:
# Target statistics for the k-spacing model, presumably from its training run
# (TODO: confirm source). The constructor sample is only a placeholder that
# load_state_dict overwrites; it holds two values so torch.std has non-zero
# degrees of freedom and does not emit a warning.
KSPACING_TARGET_MEAN = 43.78969486328977
KSPACING_TARGET_STD = 20.44467702225091
normalizer = Normalizer(torch.tensor([0.0, 1.0]))
normalizer.load_state_dict(state_dict={'mean': KSPACING_TARGET_MEAN, 'std': KSPACING_TARGET_STD})

  self.std = torch.std(tensor)


In [46]:
from cgcnn.model import CrystalGraphConvNet
from cgcnn.data import CIFData, collate_pool

# Graph-construction settings for CIFData; assumed to match the ones used when
# the checkpoint was trained (TODO confirm).
dataset = CIFData(
    root_dir=config['root_dir'],
    max_num_nbr=12,
    radius=10,
    dmin=0,
    step=0.2,
    random_seed=123,
)

# Read the input feature widths off the first sample so the network matches
# the featurisation above.
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

arch_keys = ('atom_fea_len', 'n_conv', 'h_fea_len', 'n_h', 'classification')
model = CrystalGraphConvNet(
    orig_atom_fea_len=orig_atom_fea_len,
    nbr_fea_len=nbr_fea_len,
    **{key: config[key] for key in arch_keys},
)

In [47]:
# The model above lives on the CPU, so map every checkpoint tensor there; this
# also lets the file load when it was saved from a device (CUDA/MPS) that is
# not available in the current session.
checkpoint = torch.load('/Users/elena.patyukova/Documents/github/qe-input/src/qe_input/trained_models/kspacing_checkpoint.ckpt', map_location='cpu')

In [48]:
# Checkpoint keys carry a "model." prefix (presumably the LightningModule
# attribute that held the network). Strip it only at the start of each key and
# build a new dict, so checkpoint["state_dict"] is left untouched and the cell
# can be re-run safely.
prefix = "model."
model_weights = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}

model.load_state_dict(model_weights)
model.eval()

CrystalGraphConvNet(
  (embedding): Linear(in_features=92, out_features=64, bias=True)
  (convs): ModuleList(
    (0-2): 3 x ConvLayer(
      (fc_full): Linear(in_features=179, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1.0, threshold=20.0)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (softplus2): Softplus(beta=1.0, threshold=20.0)
    )
  )
  (conv_to_fc): Linear(in_features=64, out_features=128, bias=True)
  (conv_to_fc_softplus): Softplus(beta=1.0, threshold=20.0)
  (fc_out): Linear(in_features=128, out_features=1, bias=True)
)

In [49]:
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_pool)

# With batch_size=1 the loader yields collate_pool([dataset[i]]) for each i.
# Collating the final sample directly gives the same batch a full pass would
# end on, without featurising every CIF in the directory.
if len(dataset) == 0:
    raise ValueError("dataset is empty: no samples to run inference on")
batch = collate_pool([dataset[len(dataset) - 1]])
graph, target, _ = batch

In [50]:
# Fourth element of the collated graph input; the displayed output is a list
# holding one index tensor for the single crystal in this batch.
atom_index_groups = graph[3]
atom_index_groups

[tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])]

In [51]:
# Inference only: call the module itself (not .forward) so nn.Module hooks run,
# and disable autograd because no gradients are needed.
with torch.no_grad():
    out = model(graph[0], graph[1], graph[2], graph[3])

In [52]:
# Convert the single model output to a Python float and undo the target
# normalisation.
prediction = normalizer.denorm(out.item())
prediction

44.03903842442581

In [53]:
import numpy as np
import pandas as pd
import math
import json
import re
import os
import shutil
from sklearn.model_selection import train_test_split
from jarvis.db.figshare import data
import matplotlib.pyplot as plt
from pymatgen.core.composition import Composition

In [56]:
import torch

# Show the installed torch build for reproducibility.
torch_version = torch.__version__
torch_version

'2.3.0.post101'

In [54]:
# JARVIS-DFT 'dft_3d' records fetched through jarvis-tools, judging by the
# displayed frame one row per material (JVASP id).
dft_3d = data('dft_3d')
df = pd.DataFrame(data=dft_3d)

Obtaining 3D dataset 76k ...
Reference:https://www.nature.com/articles/s41524-020-00440-1
Other versions:https://doi.org/10.6084/m9.figshare.6815699
Loading the zipfile...
Loading completed.


Unnamed: 0,jid,spg_number,spg_symbol,formula,formation_energy_peratom,func,optb88vdw_bandgap,atoms,slme,magmom_oszicar,...,density,poisson,raw_files,nat,bulk_modulus_kv,shear_modulus_gv,mbj_bandgap,hse_gap,reference,search
0,JVASP-90856,129,P4/nmm,TiCuSiAs,-0.42762,OptB88vdW,0.000,"{'lattice_mat': [[3.566933224304235, 0.0, -0.0...",na,0.0,...,5.956,na,[],8,na,na,na,na,mp-1080455,-As-Cu-Si-Ti
1,JVASP-86097,221,Pm-3m,DyB6,-0.41596,OptB88vdW,0.000,"{'lattice_mat': [[4.089078911208881, 0.0, 0.0]...",na,0.0,...,5.522,na,"[OPT-LOPTICS,JVASP-86097.zip,https://ndownload...",7,na,na,na,na,mp-568319,-B-Dy
2,JVASP-64906,119,I-4m2,Be2OsRu,0.04847,OptB88vdW,0.000,"{'lattice_mat': [[-1.833590720595598, 1.833590...",na,0.0,...,10.960,na,"[OPT-LOPTICS,JVASP-64906.zip,https://ndownload...",4,na,na,na,na,auid-3eaf68dd483bf4f4,-Be-Os-Ru
3,JVASP-98225,14,P2_1/c,KBi,-0.44140,OptB88vdW,0.472,"{'lattice_mat': [[7.2963518353359165, 0.0, 0.0...",na,0.0,...,5.145,na,[],32,na,na,na,na,mp-31104,-Bi-K
4,JVASP-10,164,P-3m1,VSe2,-0.71026,OptB88vdW,0.000,"{'lattice_mat': [[1.6777483798834445, -2.90594...",na,0.0,...,5.718,0.23,"[FD-ELAST,JVASP-10.zip,https://ndownloader.fig...",3,48.79,33.05,0.0,na,mp-694,-Se-V
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75988,JVASP-156020,139,I4/mmm,AcRh2Pb2,-0.30652,OptB88vdW,0.000,"{'lattice_mat': [[-2.374509402119681, 2.374509...",na,0.0,...,11.194,na,[],5,na,na,na,na,1206834,-Ac-Pb-Rh
75989,JVASP-156398,216,F-43m,PrTlZn,-0.34112,OptB88vdW,0.000,"{'lattice_mat': [[-0.0, 3.4210598347774503, 3....",na,0.0,...,8.517,na,[],3,na,na,na,na,915022,-Pr-Tl-Zn
75990,JVASP-156099,139,I4/mmm,BaIn2Bi2,-0.39352,OptB88vdW,0.000,"{'lattice_mat': [[4.082347574975881, -4.076131...",na,0.0,...,7.460,0.67,[],5,30.67,-11.44,na,na,1214095,-Ba-Bi-In
75991,JVASP-156007,139,I4/mmm,TmSi2Tc2,-0.54853,OptB88vdW,0.000,"{'lattice_mat': [[2.90400678672412, -2.9037689...",na,0.0,...,8.212,na,[],5,na,na,na,na,1206745,-Si-Tc-Tm


In [65]:
# Map tensors to the CPU rather than MPS: this checkpoint is written back to the
# shipped trained_models file, and device-specific (MPS) tensors would make it
# fail to load on machines without that backend unless map_location is given.
checkpoint = torch.load('/Users/elena.patyukova/Documents/github/qe-input/src/qe_input/trained_models/kspacing_checkpoint.ckpt', map_location='cpu')

In [66]:
# This overwrites the trained checkpoint in place. Write to a sibling temp file
# first and atomically swap it in, so an interrupted save cannot leave a
# truncated file as the only copy.
ckpt_path = '/Users/elena.patyukova/Documents/github/qe-input/src/qe_input/trained_models/kspacing_checkpoint.ckpt'
tmp_ckpt_path = ckpt_path + '.tmp'
torch.save(checkpoint, tmp_ckpt_path)
os.replace(tmp_ckpt_path, ckpt_path)

In [67]:
# Reload the checkpoint with every tensor mapped onto the CPU.
checkpoint = torch.load(
    '/Users/elena.patyukova/Documents/github/qe-input/src/qe_input/trained_models/kspacing_checkpoint.ckpt',
    map_location=torch.device('cpu'),
)