Load packages

In [1]:
import torch
import numpy as np
import torch.nn as nn
import pickle
import random
from torch.utils.data import DataLoader
import datetime
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
from torch_scatter import scatter
warnings.filterwarnings("ignore")

In [2]:
random.seed(0)
pi=3.14159265

The definition of Autoencoder

In [3]:
class Autoencoder_GAP(nn.Module):
    def __init__(self,input_dim,GAP_2_body_dim,GAP_3_body_dim,cutoff,atom_types,
                num_hidden_layers=3,num_hidden_dimensions=256):
        super(Autoencoder_GAP,self).__init__()
        self.atom_num=int(input_dim/4)
        self.cutoff=cutoff
        self.atom_types=atom_types
        self.GAP_2_body_dim=GAP_2_body_dim
        self.GAP_3_body_dim=GAP_3_body_dim
        
        self.GAP_2_body_k=torch.nn.Parameter(torch.rand(int(atom_types*(atom_types+1)/2),GAP_2_body_dim))
        self.GAP_2_body_eta=torch.nn.Parameter(torch.rand(int(atom_types*(atom_types+1)/2),GAP_2_body_dim))
        
        '''  
        # Comment by YY
        self.GAP_3_body_gamma_func1=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_eta_func1=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        
        self.GAP_3_body_gamma_func2=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_eta_func2=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
                                               
        self.GAP_3_body_gamma_func3=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_eta_func3=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        '''
        self.GAP_3_body_alpha_func1=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_eta_func1=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        #self.GAP_3_body_beta_func1=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_alpha_func2=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_eta_func2=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_alpha_func3=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        self.GAP_3_body_eta_func3=torch.nn.Parameter(torch.rand(atom_types**3,GAP_3_body_dim))
        
        #self.decoder=nn.ModuleList([nn.Linear(GAP_2_body_dim*int(atom_types*(atom_types+1)/2)+3*GAP_3_body_dim*atom_types**3,num_hidden_dimensions),nn.ReLU()])
        self.decoder=nn.ModuleList([nn.Linear(GAP_2_body_dim*int(atom_types*(atom_types+1)/2)+3*GAP_3_body_dim*atom_types**3,num_hidden_dimensions),nn.ReLU()])
        for i in range(num_hidden_layers-1):
            self.decoder.append(nn.Linear(num_hidden_dimensions,num_hidden_dimensions))
            self.decoder.append(nn.ReLU())
            self.decoder.append(nn.Dropout())
                                               
        self.decoder.append(nn.Linear(num_hidden_dimensions,input_dim))
        self.decoder.append(nn.ReLU())

    def forward(self,configs,edges_2_body,period_2_body,edges_3_body,period_3_body):
        # calculate two-body fingerprints
        batch_size=int(configs.shape[0]/self.atom_num)
        vec_2_body=torch.index_select(configs[:,1:],0,edges_2_body[:,1])-\
                            torch.index_select(configs[:,1:],0,edges_2_body[:,2])-period_2_body #check
        radius_2_body=torch.norm(vec_2_body,dim=1,p=2,keepdim=True)
        fingerprint_2_body=torch.cos(torch.index_select(self.GAP_2_body_k,0,edges_2_body[:,0])*radius_2_body)*\
                            torch.exp(-radius_2_body/torch.index_select(self.GAP_2_body_eta,0,edges_2_body[:,0]))*\
                            (1+torch.cos(pi*radius_2_body/self.cutoff))/2
        
        # calculate three-body fingerprints
        vec_3_body_ij=torch.index_select(configs[:,1:],0,edges_3_body[:,1])-\
                                torch.index_select(configs[:,1:],0,edges_3_body[:,2])-period_3_body[:,:3] #check
        vec_3_body_ik=torch.index_select(configs[:,1:],0,edges_3_body[:,1])-\
                                torch.index_select(configs[:,1:],0,edges_3_body[:,3])-period_3_body[:,3:] #check
        radius_3_body_ij=torch.norm(vec_3_body_ij,dim=1,p=2,keepdim=True)
        radius_3_body_ik=torch.norm(vec_3_body_ik,dim=1,p=2,keepdim=True)
        
        
        cos_ijk=(torch.sum(vec_3_body_ij*vec_3_body_ik,dim=1,keepdim=True)/radius_3_body_ij/radius_3_body_ik)
        
        '''
        # Comment by YY
        
        fingerprint_3_body_func1=torch.exp(-(radius_3_body_ij**2+radius_3_body_ik**2)/(torch.index_select(self.GAP_3_body_eta_func1,0,edges_3_body[:,0])**2))*\
                                (1+(torch.index_select(self.GAP_3_body_gamma_func1,0,edges_3_body[:,0])*radius_3_body_ij/self.cutoff-torch.index_select(self.GAP_3_body_gamma_func1,0,edges_3_body[:,0])-1)*((radius_3_body_ij/self.cutoff)**torch.index_select(self.GAP_3_body_gamma_func1,0,edges_3_body[:,0])))*\
                                (1+(torch.index_select(self.GAP_3_body_gamma_func1,0,edges_3_body[:,0])*radius_3_body_ik/self.cutoff-torch.index_select(self.GAP_3_body_gamma_func1,0,edges_3_body[:,0])-1)*((radius_3_body_ik/self.cutoff)**torch.index_select(self.GAP_3_body_gamma_func1,0,edges_3_body[:,0])))*\
                                (cos_ijk**3)
        fingerprint_3_body_func2=torch.exp(-(radius_3_body_ij**2+radius_3_body_ik**2)/(torch.index_select(self.GAP_3_body_eta_func2,0,edges_3_body[:,0])**2))*\
                                (1+(torch.index_select(self.GAP_3_body_gamma_func2,0,edges_3_body[:,0])*radius_3_body_ij/self.cutoff-torch.index_select(self.GAP_3_body_gamma_func2,0,edges_3_body[:,0])-1)*((radius_3_body_ij/self.cutoff)**torch.index_select(self.GAP_3_body_gamma_func2,0,edges_3_body[:,0])))*\
                                (1+(torch.index_select(self.GAP_3_body_gamma_func2,0,edges_3_body[:,0])*radius_3_body_ik/self.cutoff-torch.index_select(self.GAP_3_body_gamma_func2,0,edges_3_body[:,0])-1)*((radius_3_body_ik/self.cutoff)**torch.index_select(self.GAP_3_body_gamma_func2,0,edges_3_body[:,0])))*\
                                (1-4/3*cos_ijk**2)
        fingerprint_3_body_func3=torch.exp(-(radius_3_body_ij**2+radius_3_body_ik**2)/(torch.index_select(self.GAP_3_body_eta_func3,0,edges_3_body[:,0])**2))*\
                                (1+(torch.index_select(self.GAP_3_body_gamma_func3,0,edges_3_body[:,0])*radius_3_body_ij/self.cutoff-torch.index_select(self.GAP_3_body_gamma_func3,0,edges_3_body[:,0])-1)*((radius_3_body_ij/self.cutoff)**torch.index_select(self.GAP_3_body_gamma_func3,0,edges_3_body[:,0])))*\
                                (1+(torch.index_select(self.GAP_3_body_gamma_func3,0,edges_3_body[:,0])*radius_3_body_ik/self.cutoff-torch.index_select(self.GAP_3_body_gamma_func3,0,edges_3_body[:,0])-1)*((radius_3_body_ik/self.cutoff)**torch.index_select(self.GAP_3_body_gamma_func3,0,edges_3_body[:,0])))*\
                                ((1-4*cos_ijk**2)*cos_ijk**2)
        '''
        
        fingerprint_3_body_func1=torch.exp(-(radius_3_body_ij**2+radius_3_body_ik**2)/(torch.index_select(self.GAP_3_body_eta_func1,0,edges_3_body[:,0])**2))*\
                                (1+(10.0*radius_3_body_ij/self.cutoff-10.0-1.0)*((radius_3_body_ij/self.cutoff)**10.0))*\
                                (1+(10.0*radius_3_body_ik/self.cutoff-10.0-1.0)*((radius_3_body_ik/self.cutoff)**10.0))*\
                                (cos_ijk**2)*torch.index_select(self.GAP_3_body_alpha_func1,0,edges_3_body[:,0])
        
        fingerprint_3_body_func2=torch.exp(-(radius_3_body_ij**2+radius_3_body_ik**2)/(torch.index_select(self.GAP_3_body_eta_func2,0,edges_3_body[:,0])**2))*\
                                (1+(10.0*radius_3_body_ij/self.cutoff-10.0-1.0)*((radius_3_body_ij/self.cutoff)**10.0))*\
                                (1+(10.0*radius_3_body_ik/self.cutoff-10.0-1.0)*((radius_3_body_ik/self.cutoff)**10.0))*\
                                (cos_ijk**3)*torch.index_select(self.GAP_3_body_alpha_func2,0,edges_3_body[:,0])
        
        fingerprint_3_body_func3=torch.exp(-(radius_3_body_ij**2+radius_3_body_ik**2)/(torch.index_select(self.GAP_3_body_eta_func3,0,edges_3_body[:,0])**2))*\
                                (1+(10.0*radius_3_body_ij/self.cutoff-10.0-1.0)*((radius_3_body_ij/self.cutoff)**10.0))*\
                                (1+(10.0*radius_3_body_ik/self.cutoff-10.0-1.0)*((radius_3_body_ik/self.cutoff)**10.0))*\
                                (cos_ijk**4)*torch.index_select(self.GAP_3_body_alpha_func2,0,edges_3_body[:,0])
        
        
        # get the final fingerprints with respect to the pair types *****
        
        out_2_body_fingerprint=scatter(fingerprint_2_body,edges_2_body[:,3],dim=0).reshape(batch_size,-1)
        out_3_body_fingerprint_1=scatter(fingerprint_3_body_func1,edges_3_body[:,4],dim=0).reshape(batch_size,-1)
        out_3_body_fingerprint_2=scatter(fingerprint_3_body_func2,edges_3_body[:,4],dim=0).reshape(batch_size,-1)
        out_3_body_fingerprint_3=scatter(fingerprint_3_body_func3,edges_3_body[:,4],dim=0).reshape(batch_size,-1)
        x=torch.cat((out_2_body_fingerprint,out_3_body_fingerprint_1,out_3_body_fingerprint_2,out_3_body_fingerprint_3),dim=1)

#         The following is another version to aggregate the fingerprint of the mini-batch
#         final_fingerprint=list()
#         for batch_num in range(batch_size):
#             for i in range(int(self.atom_types*(self.atom_types+1)/2)):
#                 mask=((edges_2_body[:,0]==i) & (edges_2_body[:,3]==batch_num)).reshape(-1,1)  
#                 final_fingerprint.append(torch.sum(torch.masked_select(fingerprint_2_body,mask).reshape(-1,self.GAP_2_body_dim),dim=0))
#             for i in range(self.atom_types**3):
#                 mask=((edges_3_body[:,0]==i) & (edges_3_body[:,4]==batch_num)).reshape(-1,1)
#                 final_fingerprint.append(torch.sum(torch.masked_select(fingerprint_3_body_func1,mask).reshape(-1,self.GAP_3_body_dim),dim=0))

#                 #Comment by YY
#                 final_fingerprint.append(torch.sum(torch.masked_select(fingerprint_3_body_func2,mask).reshape(-1,self.GAP_3_body_dim),dim=0))
#                 final_fingerprint.append(torch.sum(torch.masked_select(fingerprint_3_body_func3,mask).reshape(-1,self.GAP_3_body_dim),dim=0))
            
#         x=torch.cat(final_fingerprint).reshape(batch_size,-1)
        
        # use fingerprints to decode atomic configurations
        for model in self.decoder:
            x=model(x)
        return x.reshape(-1,4)

Hyperparameters

In [4]:
NUM_HIDDEN_LAYERS=3
NUM_HIDDEN_DIMENSIONS=512  # what is meanning of 512??
LEARNING_RATE_INIT=0.001
ATOM_TYPES=2
use_device='cuda' if torch.cuda.is_available() else 'cpu'

The definition of dataset class

In [5]:
class GAP_dataset(torch.utils.data.Dataset):
    def __init__(self,datafile,begin,end):
        super(GAP_dataset,self).__init__()
        self.datafile=datafile
        self.begin=begin # include in the dataset
        self.end=end # exclude in the dataset
        
    def __len__(self):
        return self.end-self.begin
    
    def __getitem__(self,index):
        name=str((index+self.begin)*1)
        config=np.loadtxt(self.datafile+'AtomicInfo-'+name+'.txt')
        info_2_body=np.loadtxt(self.datafile+'PairInfo-'+name+'.txt')
        info_3_body=np.loadtxt(self.datafile+'3BodyInfo-'+name+'.txt')
        return torch.tensor(config,dtype=torch.float),\
                torch.tensor(info_2_body[:,:3],dtype=torch.int64),\
                torch.tensor(info_2_body[:,3:],dtype=torch.float),\
                torch.tensor(info_3_body[:,:4],dtype=torch.int64),\
                torch.tensor(info_3_body[:,4:],dtype=torch.float)

In [6]:
def custom_collate_fn(x):
    config,pair_2_body,period_2_body,pair_3_body,period_3_body=zip(*x)
    atom_num=config[0].shape[0]
    batch_size=len(config)
    
    config=torch.cat(config)
    
    temp_list=list()
    for i in range(batch_size):
        pair_2_body[i][:,1:3]+=(atom_num*i)
        #temp_list.append(torch.full((pair_2_body[i].shape[0],1),i,dtype=torch.int64))
        temp_list.append(torch.full((pair_2_body[i].shape[0],1),i*int(ATOM_TYPES*(ATOM_TYPES+1)/2),dtype=torch.int64))
    temp_list=torch.cat(temp_list).reshape(-1,1)
    pair_2_body=torch.cat(pair_2_body)
    pair_2_body=torch.cat((pair_2_body,temp_list+pair_2_body[:,0].reshape(-1,1)),dim=1)
    
    period_2_body=torch.cat(period_2_body)
    
    temp_list=list()
    for i in range(batch_size):
        pair_3_body[i][:,1:4]+=(atom_num*i)
        #temp_list.append(torch.full((pair_3_body[i].shape[0],1),i,dtype=torch.int64))
        temp_list.append(torch.full((pair_3_body[i].shape[0],1),i*ATOM_TYPES**3,dtype=torch.int64))
    temp_list=torch.cat(temp_list).reshape(-1,1)
    pair_3_body=torch.cat(pair_3_body)
    pair_3_body=torch.cat((pair_3_body,temp_list+pair_3_body[:,0].reshape(-1,1)),dim=1)
    
    period_3_body=torch.cat(period_3_body)
    
    return config,pair_2_body,period_2_body,pair_3_body,period_3_body

In [7]:
is_pin_memory=False # if the RAM capacity is large enough, set this as True can accelerate training process

train_size=50
valid_size=25
test_size=25

train_dataset=GAP_dataset('./test-data-all/',0,50)
valid_dataset=GAP_dataset('./test-data-all/',50,75)
test_dataset=GAP_dataset('./test-data-all/',75,100)

train_dataloader=DataLoader(train_dataset,pin_memory=is_pin_memory,collate_fn=custom_collate_fn,batch_size=10) 
valid_dataloader=DataLoader(valid_dataset,pin_memory=is_pin_memory,collate_fn=custom_collate_fn,batch_size=10)
test_dataloader=DataLoader(test_dataset,pin_memory=is_pin_memory,collate_fn=custom_collate_fn,batch_size=10)

In [8]:
# model=Autoencoder_GAP(input_dim=64*4,GAP_2_body_dim=8,GAP_3_body_dim=4,cutoff=5.5,atom_types=2,       #Comment by YY
#                     num_hidden_layers=NUM_HIDDEN_LAYERS,num_hidden_dimensions=NUM_HIDDEN_DIMENSIONS)  #Comment by YY
model=Autoencoder_GAP(input_dim=64*4,GAP_2_body_dim=8,GAP_3_body_dim=4,cutoff=5.5,atom_types=ATOM_TYPES,
                     num_hidden_layers=NUM_HIDDEN_LAYERS,num_hidden_dimensions=NUM_HIDDEN_DIMENSIONS)

model.to(use_device)
optimizer=torch.optim.Adam(model.parameters(),lr=LEARNING_RATE_INIT)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.8)   ### what is the meaning of this line???

In [None]:
'''
# how to load a model as the intial point of the NN model
model=torch.load('best_model.pkl')
model=final_model.to(use_device)
optimizer=torch.optim.Adam(model.parameters(),lr=LEARNING_RATE_INIT)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.8)   ### what is the meaning of this line???
'''
min_valid_error=np.inf
train_errors=list()
valid_errors=list()
cnt=0
for i in range(200):
    time_beg_epoch=datetime.datetime.now()

    #training process
    model.train()
    train_error=0
    for config,info_2_body_list,info_2_body_period,info_3_body_list,info_3_body_period in train_dataloader:
        config=config.to(use_device)
        info_2_body_list=info_2_body_list.to(use_device)
        info_2_body_period=info_2_body_period.to(use_device)
        info_3_body_list=info_3_body_list.to(use_device)
        info_3_body_period=info_3_body_period.to(use_device)
        optimizer.zero_grad()   ##????
        reconstruct_val=model(config,info_2_body_list,info_2_body_period,info_3_body_list,info_3_body_period)  ###????
        loss=torch.nn.functional.mse_loss(reconstruct_val,config,reduction='sum')
        loss.backward(torch.ones_like(loss))
        optimizer.step() ###???
        train_error+=torch.sum(loss).cpu().detach().numpy()
    train_errors.append(train_error/train_size)

    #validation process
    model.eval()
    valid_error=0
    for config,info_2_body_list,info_2_body_period,info_3_body_list,info_3_body_period in valid_dataloader:
        config=config.to(use_device)
        info_2_body_list=info_2_body_list.to(use_device)
        info_2_body_period=info_2_body_period.to(use_device)
        info_3_body_list=info_3_body_list.to(use_device)
        info_3_body_period=info_3_body_period.to(use_device)
        reconstruct_val=model(config,info_2_body_list,info_2_body_period,info_3_body_list,info_3_body_period)
        loss=torch.nn.functional.mse_loss(reconstruct_val,config,reduction='sum')
        valid_error+=torch.sum(loss).cpu().detach().numpy()
    valid_errors.append(valid_error/valid_size)

    #print information & judgement for early stopping
    scheduler.step()
    time_end_epoch=datetime.datetime.now()
    print('Epoch ',i,' training error = ',train_error/train_size,
          ' validation error = ',valid_error/valid_size,
          ' training and validation time = ',time_end_epoch-time_beg_epoch)

    if valid_error<min_valid_error: #judgement for early stopping
        cnt=0
        torch.save(model,'best_model.pkl')
        min_valid_error=valid_error
    else:
        cnt+=1
        if cnt>=10:
            print('Early stopping')
            del(model)
            with open('training_errors.pickle','wb') as f:
                pickle.dump(train_errors,f)
            with open('valid_errors.pickle','wb') as f:
                pickle.dump(valid_errors,f)
            break

Epoch  0  training error =  14747.8421875  validation error =  14038.218125  training and validation time =  0:00:15.275596
Epoch  1  training error =  12332.53484375  validation error =  8541.218125  training and validation time =  0:00:14.778505
Epoch  2  training error =  5385.413125  validation error =  3039.0363671875  training and validation time =  0:00:14.786486
Epoch  3  training error =  3003.281171875  validation error =  795.17173828125  training and validation time =  0:00:15.085688
Epoch  4  training error =  1271.6692578125  validation error =  1201.9875  training and validation time =  0:00:15.421788
Epoch  5  training error =  1466.978203125  validation error =  821.101240234375  training and validation time =  0:00:15.547452
Epoch  6  training error =  1328.0790234375  validation error =  690.47134765625  training and validation time =  0:00:15.665138
Epoch  7  training error =  1100.10755859375  validation error =  564.713994140625  training and validation time =  0:

In [None]:
final_model=torch.load('best_model.pkl')
final_model=final_model.to(use_device)
final_model.eval()
reconstructions=list()
true_configs=list()
test_error=0
for config,info_2_body_list,info_2_body_period,info_3_body_list,info_3_body_period in test_dataloader:
    config=config.to(use_device)
    info_2_body_list=info_2_body_list.to(use_device)
    info_2_body_period=info_2_body_period.to(use_device)
    info_3_body_list=info_3_body_list.to(use_device)
    info_3_body_period=info_3_body_period.to(use_device)
    pred_reconstruct=final_model(config,info_2_body_list,info_2_body_period,info_3_body_list,info_3_body_period)
    reconstructions.append(pred_reconstruct.cpu().detach().numpy())
    true_configs.append(config.cpu().detach().numpy())
    loss=torch.nn.functional.mse_loss(pred_reconstruct,config,reduction='sum')
    test_error+=torch.sum(loss).cpu().detach().numpy()
print(test_error/test_size)

In [None]:
true_configs[0]-reconstructions[0]

In [None]:
sns.heatmap(reconstructions[0])
#reconstructions[995]

In [None]:
sns.heatmap(true_configs[0])
#true_configs-reconstructions

In [None]:
print(final_model.GAP_2_body_eta)
eta_2 = final_model.GAP_2_body_eta.cpu().detach().numpy()
eta_2
eta_2.reshape(-1,1)
np.shape(eta_2.reshape(-1,1))
np.savetxt("Para-eta2.txt",eta_2.reshape(-1,1))

In [None]:
print(final_model.GAP_2_body_k)
k_2 = final_model.GAP_2_body_k.cpu().detach().numpy().reshape(-1,1)
np.shape(k_2)
np.savetxt("Para-k2.txt",k_2)

In [None]:
print(final_model.GAP_3_body_eta_func1)
eta_3_1 = final_model.GAP_3_body_eta_func1.cpu().detach().numpy().reshape(-1,1)
np.shape(eta_3_1)
np.savetxt("Para-eta3_1.txt",eta_3_1)

alpha_3_1 = final_model.GAP_3_body_alpha_func1.cpu().detach().numpy().reshape(-1,1)
np.shape(alpha_3_1)
np.savetxt("Para-alpha3_1.txt",alpha_3_1)


In [None]:
print(final_model.GAP_3_body_alpha_func1)

In [None]:
print(final_model.GAP_3_body_eta_func2)
eta_3_2 = final_model.GAP_3_body_eta_func2.cpu().detach().numpy().reshape(-1,1)
np.shape(eta_3_2)
np.savetxt("Para-eta3_2.txt",eta_3_2)

alpha_3_2 = final_model.GAP_3_body_alpha_func2.cpu().detach().numpy().reshape(-1,1)
np.shape(alpha_3_2)
np.savetxt("Para-alpha3_2.txt",alpha_3_2)

In [None]:
print(final_model.GAP_3_body_eta_func3)
eta_3_3 = final_model.GAP_3_body_eta_func3.cpu().detach().numpy().reshape(-1,1)
np.shape(eta_3_3)
np.savetxt("Para-eta3_3.txt",eta_3_3)

alpha_3_3 = final_model.GAP_3_body_alpha_func3.cpu().detach().numpy().reshape(-1,1)
np.shape(alpha_3_3)
np.savetxt("Para-alpha3_3.txt",alpha_3_3)