In [1]:
from data_inference import GaussianDistance, CIFData, data_loader

In [6]:
# Folder holding the input .cif structures — change this to point at your own data.
CIF_FOLDER = 'cifs_folder'

dataset = CIFData(cif_folder_path=CIF_FOLDER)
data_list = dataset.load_data()

# Batches of 64 structures in a fixed (unshuffled) order for inference.
inference_loader = data_loader(
    data_list,
    batch_size=64,
    shuffle=False,
)

100%|██████████| 9/9 [00:00<00:00, 14.35sample/s]


In [8]:
from scipy.ndimage import gaussian_filter1d
from model_head4_nonegraphbias_layers3_phys.CGT_phys import CGT
import warnings
import csv
import torch
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import os
import json

# Silence warnings so the notebook output stays readable.
# NOTE(review): this suppresses *every* warning process-wide (including torch /
# torch_geometric deprecation notices); consider narrowing it to specific
# categories or modules.
warnings.filterwarnings("ignore")

# Run name: selects both the model class and the weights file below.
name = 'CGT_phys'
# Resolve the model class from the run name ('CGT_phys' -> 'CGT', imported above).
# `variable` holds the class object and is kept for any later cell that uses it.
variable = globals()[name.replace('_phys', '')]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# out_dim = 201 * 4 — presumably 201 grid points for each of the 4 orbital channels
# (s, p, d, f) written out in the inference loop — TODO confirm against the model.
model = variable(edge_dim=14, out_dim=201 * 4, seed=123).to(device)

# Weights file follows the run name (evaluates to "CGT_phys_weights.pth").
model_path = f"{name}_weights.pth"
# load_state_dict copies into the parameters already moved to `device`, so no second
# .to(device) is needed afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints; newer
# torch versions accept weights_only=True for a plain state dict.
model.load_state_dict(torch.load(model_path, map_location=device))

# Switch to evaluation mode (affects layers such as dropout / batch norm).
model.eval()

# Directory for the per-structure PDOS JSON files, e.g. inference_result_CGT_phys/.
output_dir = f'inference_result_{name}'
os.makedirs(output_dir, exist_ok=True)  # exist_ok makes a separate exists() check unnecessary

# Gaussian smoothing width (in grid points) applied to each predicted curve.
SMOOTHING_SIGMA = 3
# Orbital channel order of the prediction rows; JSON keys become '<orbital>_pre'.
ORBITALS = ('s', 'p', 'd', 'f')

with torch.no_grad():
    for batch in inference_loader:
        batch = batch.to(device)
        output = model(batch)
        # Clip negative predictions to zero (a PDOS is non-negative).
        output[output < 0] = 0

        # One device-to-host transfer per batch instead of one per structure.
        preds = output.cpu().numpy()
        name_id = batch.name_id
        # NOTE(review): assumes `energies` collates to one row per structure, so that
        # energies[i] is the energy grid of structure i — confirm in data_inference.
        energies = batch.energies.cpu().numpy()

        for i in range(batch.num_graphs):
            # Key order matches the original file layout: energies, s, p, d, f.
            pdos_dic = {'energies': energies[i].tolist()}
            for k, orbital in enumerate(ORBITALS):
                # Smooth each orbital curve independently (default axis=-1).
                smoothed = gaussian_filter1d(preds[i][k], sigma=SMOOTHING_SIGMA)
                pdos_dic[f'{orbital}_pre'] = smoothed.tolist()

            pdos_path = os.path.join(output_dir, f'{name_id[i]}.json')
            with open(pdos_path, 'w', encoding='utf-8') as f:
                json.dump(pdos_dic, f, indent=4)






