In [1]:
import torch
import torch.nn as nn
from pathlib import Path
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
from math import *
import torch.distributions as tdist
from random import gauss,seed

In [2]:
import os
import sys

# Make the local `deep_boltzmann` checkout importable. Set the
# DEEP_BOLTZMANN_PATH environment variable to point at your own checkout; the
# fallback is the original author's machine-specific location.
DEEP_BOLTZMANN_PATH = os.environ.get(
    "DEEP_BOLTZMANN_PATH", "/home/mohit/Downloads/code_notebooks/deep_boltzmann"
)
sys.path.insert(0, DEEP_BOLTZMANN_PATH)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

In [3]:
from IPython.display import SVG

In [4]:
# Particle-dimer test system from the local deep_boltzmann package (importable
# thanks to the sys.path entry added in the previous cell).
from deep_boltzmann.models import ParticleDimer
# NOTE(review): presumably `md` supplies the energy / dimensionality that the
# RNAF flow below is trained against — confirm against later cells.
md = ParticleDimer()

Using TensorFlow backend.


In [None]:
class RNAF(nn.Module):
    """Autoregressive (NADE-style) network over ``D = dimer_atoms + solvent_atoms`` coordinates.

    A running hidden pre-activation ``a`` starts at ``c``. After coordinate ``d``
    is processed, ``a`` is updated with ``x_d * W[:, d]``. The prediction for
    coordinate ``d`` therefore only depends on coordinates ``0..d-1``. From
    ``h = relu(a)`` two linear heads produce a mean and a bounded scale term
    per coordinate.

    Args:
        dimer_atoms: size of the dimer part of the input (presumably a number
            of coordinates, since it is added directly to the input width — confirm).
        solvent_atoms: size of the solvent part of the input (same caveat).
        hidden: width ``H`` of the hidden state (default 64).
    """

    # Affine map applied to the sigmoid in the scale head: (0, 1) -> (-0.7, 0.9).
    SCALE_MULT = 1.60
    SCALE_SHIFT = 0.7

    def __init__(self, dimer_atoms, solvent_atoms, hidden=64):
        super().__init__()
        self.dimer_atoms = dimer_atoms
        self.solvent_atoms = solvent_atoms
        self.total_dims = self.dimer_atoms + self.solvent_atoms
        self.D = self.total_dims
        self.H = hidden
        # Key names/order are kept stable so existing state_dicts still load.
        self.params = nn.ParameterDict({
            "V": nn.Parameter(torch.randn(self.D, self.H)),   # scale head weights
            "b": nn.Parameter(torch.zeros(self.D)),           # scale head bias
            "V2": nn.Parameter(torch.randn(self.D, self.H)),  # mean head weights
            "b2": nn.Parameter(torch.zeros(self.D)),          # mean head bias
            "W": nn.Parameter(torch.randn(self.H, self.D)),   # input -> hidden weights
            "c": nn.Parameter(torch.zeros(1, self.H)),        # initial hidden pre-activation
        })
        nn.init.xavier_normal_(self.params["V"])
        nn.init.xavier_normal_(self.params["V2"])
        nn.init.xavier_normal_(self.params["W"])

    def _mean(self, h, d):
        """Mean head for coordinate ``d``: (B, H) hidden -> (B, 1)."""
        return h.mm(self.params["V2"][d:d + 1, :].t()) + self.params["b2"][d:d + 1]

    def _scale(self, h, d):
        """Scale head for coordinate ``d``: (B, H) hidden -> (B, 1), values in (-0.7, 0.9).

        NOTE(review): the range includes negative values, so this is presumably
        used as a log-scale downstream — confirm.
        """
        logits = h.mm(self.params["V"][d:d + 1, :].t()) + self.params["b"][d:d + 1]
        return torch.sigmoid(logits) * self.SCALE_MULT - self.SCALE_SHIFT

    def _advance(self, a, value, d):
        """Fold coordinate ``d`` (B x 1) into the running pre-activation ``a`` (B x H)."""
        return a + value.mm(self.params["W"][:, d:d + 1].t())

    def forward(self, x):
        """Compute conditional parameters for every coordinate of a batch.

        Args:
            x: tensor of shape (B, D).

        Returns:
            Tensor of shape (3, B, D). Index 0 holds the means, index 1 the
            scale terms from ``_scale``, and index 2 fresh standard-normal noise
            (one draw per entry, on ``x``'s dtype and device).
        """
        batch = x.size(0)
        a = self.params["c"].expand(batch, -1)  # B x H
        means, scales, noise = [], [], []
        for d in range(self.D):
            h = torch.relu(a)
            scales.append(self._scale(h, d))
            means.append(self._mean(h, d))
            # Keep each draw 2-D (B x 1) so the columns concatenate along dim 1;
            # a 1-D draw makes torch.cat(..., 1) fail.
            noise.append(torch.randn(batch, 1, dtype=x.dtype, device=x.device))
            a = self._advance(a, x[:, d:d + 1], d)
        return torch.stack([torch.cat(means, 1), torch.cat(scales, 1), torch.cat(noise, 1)])

    def sample(self, x, n_fixed=4):
        """Generate coordinates ``n_fixed..D-1`` autoregressively, conditioned on ``x[:, :n_fixed]``.

        Each generated coordinate is set to its predicted conditional mean (no
        noise is added). It is then fed back to condition the following
        coordinates.

        Args:
            x: tensor of shape (B, D). Only columns ``0..n_fixed-1`` are read.
            n_fixed: number of leading coordinates taken from ``x``
                (default 4, the previously hard-coded value).

        Returns:
            Tensor of shape (B, D - n_fixed) holding the generated coordinates,
            or (B, 0) when ``n_fixed >= D``.
        """
        batch = x.size(0)
        a = self.params["c"].expand(batch, -1)  # B x H
        generated = []
        for d in range(self.D):
            if d < n_fixed:
                value = x[:, d:d + 1]  # conditioning coordinate, taken as given
            else:
                value = self._mean(torch.relu(a), d)
                generated.append(value)
            a = self._advance(a, value, d)
        if not generated:
            return x.new_zeros(batch, 0)
        return torch.cat(generated, 1)

In [None]:
def KlDivergence(output, pred):