This notebook converts the results of VASP calculations into PyTorch Geometric (PyG) graphs.

In [1]:
import numpy as np
import pandas as pd
import networkx as nx
from torch_geometric.data import Data 
import torch
import pickle
import math
import os
import random

In [2]:
# Radial cutoff for graph construction: two atoms (or an atom and a periodic
# image) are connected when their squared distance is below this value squared.
# NOTE(review): presumably in the same length unit as the coordinates in the
# data files (likely Angstrom) -- confirm against the data set.
CUTOFF_DISTANCE=6.0

In [3]:
def generate_reps(cutoff_dist, vec_1_len, vec_2_len, vec_3_len):
    """Return the periodic-image multipliers to scan along each lattice vector.

    For a lattice vector of length ``L`` the number of repetitions is
    ``ceil(cutoff_dist / L)``.  The multipliers are returned as
    ``[0, 1, -1, 2, -2, ...]``: the home cell first, then each repetition in
    the positive and negative direction.

    Parameters
    ----------
    cutoff_dist : float
        Interaction cutoff distance.
    vec_1_len, vec_2_len, vec_3_len : float
        Lengths of the three lattice vectors (must be non-zero).

    Returns
    -------
    tuple of three lists of int
        Multipliers for the first, second and third lattice vector.

    Raises
    ------
    ZeroDivisionError
        If any lattice vector length is zero.
    """
    def _multipliers(vec_len):
        # Use the cutoff that was passed in, not a module-level global, so the
        # function honours its argument and also works outside the notebook.
        reps = math.ceil(cutoff_dist / vec_len)
        multipliers = [0]
        for k in range(1, reps + 1):
            multipliers.append(k)
            multipliers.append(-k)
        return multipliers

    return _multipliers(vec_1_len), _multipliers(vec_2_len), _multipliers(vec_3_len)

In [4]:
def _read_vec3(f):
    """Read one line holding three whitespace-separated numbers; return them as floats."""
    a, b, c = f.readline().split()
    return float(a), float(b), float(c)


def generate_graphs(filename, graph_list, CUTOFF_DISTANCE, sample_stride=8, max_energy_per_atom=-6.5):
    """Parse a configuration file and append one PyG ``Data`` graph per kept configuration.

    File layout, as read by this function:

    * one header line with three integers; the second is the number of atoms
      per configuration and the third the number of configurations (the first
      is not used here);
    * for every configuration: three lines with the Cartesian components of
      the three lattice vectors, then one line per atom
      (``type x y z fx fy fz``), then one line with the total energy.

    Every atom pair, including periodic images and an atom's own images, whose
    distance is below ``CUTOFF_DISTANCE`` becomes a pair of directed edges
    (a single edge for self-images).  ``edge_attr`` holds the periodic shift
    ``s`` of each edge ``u -> v`` such that ``pos[v] - pos[u] + s`` is the
    displacement that passed the cutoff test.

    Parameters
    ----------
    filename : str
        Path of the data file to read.
    graph_list : list
        Output list; graphs are appended in place.
    CUTOFF_DISTANCE : float
        Neighbour cutoff distance.
    sample_stride : int, optional
        Keep only configurations whose index is a multiple of this value
        (default 8).  Pass 1 to load every configuration in the file.
    max_energy_per_atom : float or None, optional
        Configurations with ``total_energy / atom_num`` above this value are
        dropped (default -6.5, presumably eV/atom -- confirm).  Pass ``None``
        to disable the filter.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``sample_stride`` is smaller than 1 or a line does not have the
        expected number of fields.
    """
    if sample_stride < 1:
        raise ValueError("sample_stride must be >= 1")
    cutoff_sq = CUTOFF_DISTANCE * CUTOFF_DISTANCE

    with open(filename, 'r') as f:
        # The first header field is read but not needed below.
        _unused_total, atom_num, total_config = f.readline().split()
        atom_num = int(atom_num)
        total_config = int(total_config)

        for i in range(total_config):
            vec_1_x, vec_1_y, vec_1_z = _read_vec3(f)
            vec_2_x, vec_2_y, vec_2_z = _read_vec3(f)
            vec_3_x, vec_3_y, vec_3_z = _read_vec3(f)

            # Sub-sample to limit the data-set size: skip the atom lines plus
            # the energy line of configurations that are not kept.
            if i % sample_stride != 0:
                for _ in range(atom_num + 1):
                    f.readline()
                continue

            x_list = []
            y_list = []
            z_list = []
            fx_list = []
            fy_list = []
            fz_list = []
            charge_list = []
            for _ in range(atom_num):
                # Each atom line: type, position (x, y, z), force (fx, fy, fz).
                atom_type, x, y, z, fx, fy, fz = f.readline().split()
                charge_list.append(int(atom_type) - 1)  # types in the file are 1-based
                x_list.append(float(x))
                y_list.append(float(y))
                z_list.append(float(z))
                fx_list.append(float(fx))
                fy_list.append(float(fy))
                fz_list.append(float(fz))
            total_energy = float(f.readline())

            # Drop configurations whose energy per atom is above the threshold.
            if max_energy_per_atom is not None and total_energy / atom_num > max_energy_per_atom:
                continue

            vec_1_len = math.sqrt(vec_1_x * vec_1_x + vec_1_y * vec_1_y + vec_1_z * vec_1_z)
            vec_2_len = math.sqrt(vec_2_x * vec_2_x + vec_2_y * vec_2_y + vec_2_z * vec_2_z)
            vec_3_len = math.sqrt(vec_3_x * vec_3_x + vec_3_y * vec_3_y + vec_3_z * vec_3_z)
            # NOTE(review): the image counts are derived from the lattice-vector
            # lengths.  For strongly skewed (non-orthogonal) cells the spacing
            # between lattice planes is smaller than the vector length, so some
            # neighbours inside the cutoff could be missed -- verify if such
            # cells occur in the data.
            vec_1_iter, vec_2_iter, vec_3_iter = generate_reps(CUTOFF_DISTANCE, vec_1_len, vec_2_len, vec_3_len)

            # Pre-compute the shift of every periodic image once per
            # configuration instead of inside the atom-pair loops.  The first
            # multiplier varies fastest, then the second, then the third.
            images = []
            for k3 in vec_3_iter:
                for k2 in vec_2_iter:
                    for k1 in vec_1_iter:
                        neg_shift = [-k1 * vec_1_x - k2 * vec_2_x - k3 * vec_3_x,
                                     -k1 * vec_1_y - k2 * vec_2_y - k3 * vec_3_y,
                                     -k1 * vec_1_z - k2 * vec_2_z - k3 * vec_3_z]
                        pos_shift = [k1 * vec_1_x + k2 * vec_2_x + k3 * vec_3_x,
                                     k1 * vec_1_y + k2 * vec_2_y + k3 * vec_3_y,
                                     k1 * vec_1_z + k2 * vec_2_z + k3 * vec_3_z]
                        is_home = (k1 == 0 and k2 == 0 and k3 == 0)
                        images.append((is_home, neg_shift, pos_shift))

            period_vec_list = []
            edge_list_u = []
            edge_list_v = []

            for m in range(atom_num):
                # Edges between atom m and every earlier atom n, over all images.
                for n in range(m):
                    dx = x_list[m] - x_list[n]
                    dy = y_list[m] - y_list[n]
                    dz = z_list[m] - z_list[n]
                    for _, neg_shift, pos_shift in images:
                        x_dist = dx + neg_shift[0]
                        y_dist = dy + neg_shift[1]
                        z_dist = dz + neg_shift[2]
                        if x_dist * x_dist + y_dist * y_dist + z_dist * z_dist < cutoff_sq:
                            # Store both directions with opposite shift vectors.
                            edge_list_u.append(n)
                            edge_list_v.append(m)
                            period_vec_list.append(neg_shift)
                            edge_list_u.append(m)
                            edge_list_v.append(n)
                            period_vec_list.append(pos_shift)

                # Edges between atom m and its own periodic images; the home
                # cell is skipped because it would be a zero-length self loop.
                for is_home, neg_shift, _ in images:
                    if is_home:
                        continue
                    if (neg_shift[0] * neg_shift[0] + neg_shift[1] * neg_shift[1]
                            + neg_shift[2] * neg_shift[2]) < cutoff_sq:
                        edge_list_u.append(m)
                        edge_list_v.append(m)
                        period_vec_list.append(neg_shift)

            data_temp = Data(x=torch.tensor(charge_list, dtype=torch.long),
                             pos=torch.tensor(np.array([x_list, y_list, z_list]).transpose(1, 0), dtype=torch.float32, requires_grad=True),
                             force=torch.tensor(np.array([fx_list, fy_list, fz_list]).transpose(1, 0), dtype=torch.float32, requires_grad=True),
                             y=torch.tensor([[total_energy]]),
                             edge_index=torch.tensor([edge_list_u, edge_list_v], dtype=torch.long),
                             # Explicit reshape keeps the [E, 3] shape even when no edge was found.
                             edge_attr=torch.tensor(period_vec_list, dtype=torch.float32).reshape(len(period_vec_list), 3)
                             )

            graph_list.append(data_temp)

In [5]:
# Convert every data file in the database directory into graphs, then shuffle
# them with a fixed seed.
graph_list=list()
file_path='.//HEA_database'
# os.listdir returns entries in arbitrary, platform-dependent order; sort them so
# that the seeded shuffle below yields the same graph order on every machine.
file_names=sorted(os.listdir(file_path))
for file in file_names:
    full_path=os.path.join(file_path,file)
    # Skip anything that is not a regular file (e.g. a .ipynb_checkpoints directory).
    if not os.path.isfile(full_path):
        continue
    generate_graphs(full_path,graph_list,CUTOFF_DISTANCE)
random.seed(0)
random.shuffle(graph_list)

In [6]:
# Serialize the shuffled list of graphs to 'pyg_graph_hea.pickle' in the
# current working directory (an existing file of that name is overwritten).
with open('pyg_graph_hea.pickle','wb') as f:
    pickle.dump(graph_list,f)

In [7]:
# Sanity check: number of graphs that were collected.
print(len(graph_list))

8856


In [8]:
# Sanity check: summary (attribute names and tensor shapes) of the first graph after shuffling.
print(graph_list[0])

Data(edge_attr=[1814, 3], edge_index=[2, 1814], force=[32, 3], pos=[32, 3], x=[32], y=[1, 1])
