In [1]:
import torch_geometric
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing,global_add_pool
from torch_geometric.utils import remove_self_loops
import torch
import numpy as np
import matplotlib.pyplot as plt
from molecular_mpns.data import MolGraph
from molecular_mpns.systems import WellSystem
from molecular_mpns.config import fig_dir
from torch.optim.lr_scheduler import ExponentialLR

In [2]:
# system parameters
dim = 2
N = 2
a = 0
b = -5
c = 0.9
d0 = 4
tau = 1
system = WellSystem(dim,N,a,b,c,d0,tau)

# create a langevian trajectory
x = np.array([[0.,0.],[1.1,1.1]])
M = 2000000
beta = 2
dt = 1e-3

traj = np.zeros((M,N,2))
d_traj = np.zeros(M)
pot_traj = np.zeros(M)

np.random.seed(42)
for i in range(M):
    x = x - system._gradient(x)*dt + np.sqrt(2*dt/beta)*np.random.randn(N,2)
    dist = system._r(x)
    pot = system._potential(x)
    
    traj[i] = x
    d_traj[i] = dist[0,1]
    pot_traj[i] = pot
    

In [7]:
# define model

class BLWellMPN(MessagePassing):
    
    def __init__(self,h_dim):
        
        super(BLWellMPN,self).__init__(aggr = 'mean')
        self.h_dim = h_dim
        
        self.lin1 = torch.nn.Linear(1,h_dim)
        self.lin2 = torch.nn.Linear(h_dim,h_dim)
        self.lin3 = torch.nn.Linear(h_dim,h_dim)
        self.lin4 = torch.nn.Linear(h_dim,1)
        
    def forward(self,edge_index,x):
        
        edge_index, _ = remove_self_loops(edge_index)
        return self.propagate(edge_index = edge_index,x = x)
    
    def message(self,x_i,x_j):
        
        # compute distances
        dists = torch.sqrt(((x_i - x_j)**2).sum(dim=1,keepdim = True))
        
        # pass through nn
        h = self.lin1(dists)
        h = torch.nn.functional.silu(h)
        h = self.lin2(h)
        h = torch.nn.functional.silu(h)
        h = self.lin3(h)
        h = torch.nn.functional.silu(h)
        h = self.lin4(h)
        
        return h
    
    def update(self,aggr_out):
        return aggr_out
        
        

In [8]:
# build model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mod = BLWellMPN(256)
mod = mod.double()
mod = mod.to(device)

opt = torch.optim.Adam(mod.parameters(),lr = 1e-4)
sched = ExponentialLR(opt, gamma = 0.995)

# train
epochs,batch_size = 100,32

for ep in range(epochs):
    ep_loss = 0
    
    # bootstrap sample equillibrium distribution
    np.random.seed(42)
    rand_idx = np.random.choice(M,10000)
    traj_train,pot_traj_train = traj[rand_idx],pot_traj[rand_idx]
    
    # compute graphs
    G_traj = [MolGraph(x,V,dV = 0) for x,V in zip(traj_train,pot_traj_train)]
    loader = DataLoader(G_traj,batch_size = batch_size)
    
    for G_batch in loader:
        G_batch = G_batch.to(device)
        
        Vtheta = mod(G_batch.edge_index,G_batch.x)
        
        Vtheta = global_add_pool(Vtheta,G_batch.batch)
        V = G_batch.V.view(Vtheta.shape[0],1)
        
        loss = torch.mean((Vtheta - V)**2)
        ep_loss += loss.item()
        
        loss.backward()
        opt.step()
        opt.zero_grad()
    
    sched.step()
    print('Epoch ' + str(ep+1) + ' Loss: ' + str(ep_loss))
    
    # validate potential
    x = np.array([[0.,0.],[1.1,1.1]])
    res = 40
    grid = np.linspace(0,0.18,res)
    dists = np.zeros(res)
    V,Vhat = np.zeros(res),np.zeros(res)
    

    initial = x[1:,]
    for i,pt in enumerate(grid):
        x[1,:] = initial + pt
        d = system._r(x)
        d = d[0,1]
        true_potential = system._potential(x)
        dists[i],V[i] = d,true_potential
        
        G=MolGraph(x,true_potential,0)
        G = G.to(device)
        with torch.no_grad():
            est_potential = mod(G.edge_index,G.x)
            est_potential = est_potential.sum()
        Vhat[i] = est_potential
    
    fname = str(fig_dir)+'/BLDblWellMPN'+str(ep+1)+'.png'
    plt.plot(dists,V,color = 'black',label = 'True')
    plt.plot(dists,Vhat,color = 'red',label = 'Predicted')
    plt.legend()
    plt.xlabel('$d$')
    plt.ylabel('$V(d)$')
    plt.savefig(fname)
    #plt.show()
    plt.close()
        
        
        

Epoch 1 Loss: 1102.287201824779
Epoch 2 Loss: 193.60539203242607
Epoch 3 Loss: 185.07909353894917
Epoch 4 Loss: 178.01719507127146
Epoch 5 Loss: 171.772936698203
Epoch 6 Loss: 163.81245715120082
Epoch 7 Loss: 150.82274380170125
Epoch 8 Loss: 132.7236685075308
Epoch 9 Loss: 116.06665169675466
Epoch 10 Loss: 107.35993829171358
Epoch 11 Loss: 103.98444621390546
Epoch 12 Loss: 101.87533419102408
Epoch 13 Loss: 99.90023183949094
Epoch 14 Loss: 97.73535102744373
Epoch 15 Loss: 95.08825266719771
Epoch 16 Loss: 91.69577030396866
Epoch 17 Loss: 86.77664100101632
Epoch 18 Loss: 78.93221059633342
Epoch 19 Loss: 65.98726434066545
Epoch 20 Loss: 47.576199059237716
Epoch 21 Loss: 30.500892990795517
Epoch 22 Loss: 20.26540406568925
Epoch 23 Loss: 14.36969436442172
Epoch 24 Loss: 10.668114327254779
Epoch 25 Loss: 8.241767930536794
Epoch 26 Loss: 6.541854804342162
Epoch 27 Loss: 5.368497894141534
Epoch 28 Loss: 4.516714112704437
Epoch 29 Loss: 3.935411900122434
Epoch 30 Loss: 3.445164725915382
Epoch 31