In [1]:
import torch
import torch.nn as nn
from pathlib import Path
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
from math import *
from random import gauss,seed

In [2]:
trajdict = np.load('output.npz')
import ast
#params = ast.literal_eval(str(trajdict['params']))
traj_closed_train = trajdict['traj_closed_train_hungarian']
traj_open_train = trajdict['traj_open_train_hungarian']
traj_closed_test = trajdict['traj_closed_test_hungarian']
traj_open_test = trajdict['traj_open_test_hungarian']
x = np.vstack([traj_closed_train, traj_open_train])
xval = np.vstack([traj_closed_test, traj_open_test])

In [3]:
batch_size = 256
train_set = np.vstack([traj_closed_train[:80000]])

In [4]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [5]:
#RNADE
class RNADE(nn.Module):
    def __init__(self,dimer_atoms,solvent_atoms):
        super(RNADE, self).__init__()
        self.dimer_atoms = dimer_atoms
        self.solvent_atoms = solvent_atoms
        self.total_dims = self.dimer_atoms + self.solvent_atoms
        self.i2h = nn.Linear(self.total_dims-1, self.total_dims-1)
        self.h2o = nn.Linear(self.total_dims-1, self.total_dims-self.dimer_atoms)
        self.D = self.total_dims
        self.H = 256
        self.params = nn.ParameterDict({
            "V" : nn.Parameter(torch.randn(self.D, self.H)),
            "b" : nn.Parameter(torch.zeros(self.D)),
            "V2" : nn.Parameter(torch.randn(self.D, self.H)),
            "b2" : nn.Parameter(torch.zeros(self.D)),
            "W" : nn.Parameter(torch.randn(self.H, self.D)),
            "c" : nn.Parameter(torch.zeros(1, self.H)),
        })
        nn.init.xavier_normal_(self.params["V"])
        nn.init.xavier_normal_(self.params["V2"])
        nn.init.xavier_normal_(self.params["W"])
        
    def forward(self, x):
        ai = self.params["c"].expand(x.size(0), -1)   #B x H
        print(ai.size())
        a1=[]
        m1 = []
        for d in range(self.D):
            h_i = torch.relu(ai) #B x H
            alpha1 = torch.sigmoid( h_i.mm(self.params["V"][d:d+1,:].t() ) + self.params["b"][d:d+1] )*2  + pow(10,-1) + 0.5#  BxH *  Hx1  
            mean1 = h_i.mm(self.params["V2"][d:d+1,:].t() ) + self.params["b2"][d:d+1]
            a1.append(alpha1)
            m1.append(mean1)

            ai = x[:, d:d+1].mm(self.params["W"][:, d:d+1].t() ) + ai #Bx1 * 1xH =  BxH
        
        a1 = torch.cat(a1,1)
        m1 = torch.cat(m1,1)
        final_prob = torch.stack([m1,a1])       
        return final_prob


In [6]:
def lossFunct(output,pred,ind):
    alpha = output[1, :,4:]
    mean = output[0, :, 4:]
    loss = torch.exp( -0.5 *  ( (pred[:,4:]- mean)/alpha)**2  )/(alpha*sqrt(2*3.14) ) +pow(10,-10)
    print("Loss:")
    return -torch.log( loss ).sum(axis= 0)[ind]/alpha.size()[0]

In [7]:
def otherLoss(output,pred):
    alpha = output[1, :,4:]
    mean = output[0, :, 4:]
    loss = torch.exp( -0.5 *  ( (pred[:,4:]- mean)/alpha)**2  )/(alpha*sqrt(2*3.14) ) +pow(10,-10)
    print("Loss:")
    return -torch.log( loss ).sum(axis= 0).sum()/alpha.size()[0]

In [22]:
def MixtureLoss(output,pred):
    mean= output[0]
    alpha = output[2]
    std = output[1]
    k=2
    ans = None;
    for i in range(k):
        print(alpha[:,:,i].size())
        loss = alpha[ :, :,i]* torch.exp(-0.5 * ( (pred[:,4:] -mean[ :, :,i])/std[ :, :,i])**2 ) / (std[ :, :,i]*sqrt(2*3.14))
        if(ans is not None):
            ans = ans - torch.log(loss).sum(axis=0).sum()/256
        else:
            ans = -torch.log(loss).sum(axis=0).sum()/256
    return ans

In [28]:
class RNADE2(nn.Module):
    def __init__(self,dimer_atoms,solvent_atoms):
        super(RNADE2, self).__init__()
        self.dimer_atoms = dimer_atoms
        self.solvent_atoms = solvent_atoms
        self.D = self.dimer_atoms + self.solvent_atoms
        self.H = 128
        self.K = 2
        self.params = nn.ParameterDict({
            "V" : nn.Parameter(torch.randn(self.D, self.H)),
            "b" : nn.Parameter(torch.zeros(self.D)),
            "V2" : nn.Parameter(torch.randn(self.D, self.H)),
            "b2" : nn.Parameter(torch.zeros(self.D)),
            "Vmean" : nn.Parameter(torch.randn(self.D,self.H, self.K)),
            "Valpha" : nn.Parameter(torch.randn(self.D,self.H, self.K)),
            "Vstd" : nn.Parameter(torch.randn(self.D,self.H, self.K)),
            "bmean" : nn.Parameter(torch.zeros(self.D,self.K)),
            "balpha" : nn.Parameter(torch.zeros(self.D,self.K)),
            "bstd" : nn.Parameter(torch.zeros(self.D,self.K)),
            "W" : nn.Parameter(torch.randn(self.H, self.D)),
            "c" : nn.Parameter(torch.zeros(1, self.H)),
        })
        nn.init.xavier_normal_(self.params["V"])
        nn.init.xavier_normal_(self.params["V2"])
        nn.init.xavier_normal_(self.params["W"])
        nn.init.xavier_normal_(self.params["Vmean"])
        nn.init.xavier_normal_(self.params["Valpha"])
        nn.init.xavier_normal_(self.params["Vstd"])
        
    def forward(self, x):
        ai = self.params["c"].expand(x.size(0), -1)   #B x H
        a= None
        m = None
        s = None
        for d in range(self.D):
            if(d<4):
                ai = x[:, d:d+1].mm(self.params["W"][:,d:d+1].t()) + ai
                continue
            h_i = torch.relu(ai) #B x H
            std = torch.sigmoid( ( h_i.mm(self.params["Vstd"][d,:,] ) + self.params["bstd"][d:d+1,:].expand(x.size(0), -1) ) )*2  + pow(10,-1) + 0.5#  BxH *  HxK = BxK  
            mean = ( h_i.mm(self.params["Vmean"][d,:,] ) + self.params["bmean"][d:d+1,:].expand(x.size(0), -1) ) #B xH  * HxK  = B x K + BxK
            alpha = torch.softmax( (h_i.mm(self.params["Valpha"][d,:,] ) +self.params["balpha"][d:d+1,:].expand(x.size(0), -1) ), dim = 0 )
            if(a is not None):
                a = torch.cat((a, alpha.unsqueeze(dim = 0)),0)
                m = torch.cat((m, mean.unsqueeze(dim = 0)) , 0)
                s = torch.cat((s, std.unsqueeze(dim = 0)) , 0)
            else:
                a = alpha.unsqueeze(dim=0)
                m = mean.unsqueeze(dim=0)
                s = std.unsqueeze(dim=0)
            ai = x[:, d:d+1].mm(self.params["W"][:, d:d+1].t() ) + ai #Bx1 * 1xH =  BxH
        
        m = m.permute(1,0,2 )
        a = a.permute(1,0,2)
        s = s.permute(1,0,2)
        #print(a.size(),m.size(),s.size())
        #final_prob = torch.stack([m,s,a]) 
        #print(final_prob.size())
        return [m,s,a]

In [29]:
device  = "cpu"

In [30]:
model = RNADE2(4,72).to(device)

In [31]:
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
for epoch in range(500):
    print(epoch)
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
            optimizer.zero_grad()
            x_hat = model(data)
            loss = MixtureLoss(x_hat, data)
            loss.backward()
            optimizer.step()
            print(loss.item())
            # print statistics
            running_loss += loss.item()
    torch.save(model.state_dict(), 'RNADEtemp2')
    print("Running Loss is " + str(running_loss) )

0
torch.Size([256, 72])
torch.Size([256, 72])
1139.71337890625
torch.Size([256, 72])
torch.Size([256, 72])
1106.256103515625
torch.Size([256, 72])
torch.Size([256, 72])
1082.85107421875
torch.Size([256, 72])
torch.Size([256, 72])
1064.3212890625
torch.Size([256, 72])
torch.Size([256, 72])
1048.92236328125
torch.Size([256, 72])
torch.Size([256, 72])
1035.8818359375
torch.Size([256, 72])
torch.Size([256, 72])
1026.700439453125
torch.Size([256, 72])
torch.Size([256, 72])
1020.3818359375
torch.Size([256, 72])
torch.Size([256, 72])
1013.9923706054688
torch.Size([256, 72])
torch.Size([256, 72])
1006.089111328125
torch.Size([256, 72])
torch.Size([256, 72])
999.3118286132812
torch.Size([256, 72])
torch.Size([256, 72])
993.8810424804688
torch.Size([256, 72])
torch.Size([256, 72])
987.244384765625
torch.Size([256, 72])
torch.Size([256, 72])
978.217041015625
torch.Size([256, 72])
torch.Size([256, 72])
970.4793701171875
torch.Size([256, 72])
torch.Size([256, 72])
965.0411376953125
torch.Size([256,

871.638427734375
torch.Size([256, 72])
torch.Size([256, 72])
871.3636474609375
torch.Size([256, 72])
torch.Size([256, 72])
871.38330078125
torch.Size([256, 72])
torch.Size([256, 72])
871.21142578125
torch.Size([256, 72])
torch.Size([256, 72])
870.94482421875
torch.Size([256, 72])
torch.Size([256, 72])
870.6710205078125
torch.Size([256, 72])
torch.Size([256, 72])
871.197509765625
torch.Size([256, 72])
torch.Size([256, 72])
870.4716796875
torch.Size([256, 72])
torch.Size([256, 72])
870.8242797851562
torch.Size([256, 72])
torch.Size([256, 72])
871.92822265625
torch.Size([256, 72])
torch.Size([256, 72])
870.561767578125
torch.Size([256, 72])
torch.Size([256, 72])
870.9083862304688
torch.Size([256, 72])
torch.Size([256, 72])
870.2588500976562
torch.Size([256, 72])
torch.Size([256, 72])
870.59375
torch.Size([256, 72])
torch.Size([256, 72])
870.8744506835938
torch.Size([256, 72])
torch.Size([256, 72])
870.5556640625
torch.Size([256, 72])
torch.Size([256, 72])
870.2579345703125
torch.Size([256

torch.Size([256, 72])
torch.Size([256, 72])
867.2306518554688
torch.Size([256, 72])
torch.Size([256, 72])
867.22998046875
torch.Size([256, 72])
torch.Size([256, 72])
867.4710693359375
torch.Size([256, 72])
torch.Size([256, 72])
867.3473510742188
torch.Size([256, 72])
torch.Size([256, 72])
867.5906982421875
torch.Size([256, 72])
torch.Size([256, 72])
867.392578125
torch.Size([256, 72])
torch.Size([256, 72])
867.36865234375
torch.Size([256, 72])
torch.Size([256, 72])
867.3035888671875
torch.Size([256, 72])
torch.Size([256, 72])
867.1874389648438
torch.Size([256, 72])
torch.Size([256, 72])
867.1506958007812
torch.Size([256, 72])
torch.Size([256, 72])
867.1419677734375
torch.Size([256, 72])
torch.Size([256, 72])
867.400390625
torch.Size([256, 72])
torch.Size([256, 72])
867.1756591796875
torch.Size([256, 72])
torch.Size([256, 72])
866.9678955078125
torch.Size([256, 72])
torch.Size([256, 72])
867.2567138671875
torch.Size([256, 72])
torch.Size([256, 72])
867.228515625
torch.Size([256, 72])
to

In [396]:
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for ind in range(72):
        for i, data in enumerate(train_loader, 0):
            optimizer.zero_grad()
            x_hat = model(data)
            print(x_hat[2][0][0])
            loss = lossFunct(x_hat, data,ind)
            loss.backward()
            optimizer.step()

            print(loss.item())
           
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    torch.save(model.state_dict(), 'RNADEtemp1')

torch.Size([72, 256, 2])
torch.Size([256, 72, 2]) torch.Size([256, 72, 2]) torch.Size([256, 72, 2])
tensor([0.4983, 0.5017], grad_fn=<SelectBackward>)


RuntimeError: The size of tensor a (72) must match the size of tensor b (2) at non-singleton dimension 2

In [393]:
model.load_state_dict(torch.load("RNADEOpen2"))
model.eval()

RuntimeError: Error(s) in loading state_dict for RNADE2:
	Missing key(s) in state_dict: "params.Valpha", "params.Vmean", "params.Vstd", "params.balpha", "params.bmean", "params.bstd". 
	Unexpected key(s) in state_dict: "i2h.weight", "i2h.bias", "h2o.weight", "h2o.bias". 

In [386]:
for epoch in range(500):
    print(epoch)
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
            optimizer.zero_grad()
            x_hat = model(data)
            loss = otherLoss(x_hat, data)
            loss.backward()
            optimizer.step()
            print(loss.item())
            # print statistics
            running_loss += loss.item()
    torch.save(model.state_dict(), 'RNADEClosed2')
    print("Running Loss is " + str(running_loss) )

0
torch.Size([72, 256, 2])
torch.Size([256, 72, 2]) torch.Size([256, 72, 2]) torch.Size([256, 72, 2])


RuntimeError: The size of tensor a (72) must match the size of tensor b (2) at non-singleton dimension 2