In [1]:
from RNN_QSR import *

cuda:0


In [16]:
# Hidden size shared by the training and the sampling network.
Nh = 16

# The network being optimised comes with its optimiser; a second,
# independently built network is handed to `train` as the sampler.
trainrnn, optimizer = new_rnn_with_optim("LSTM", Nh, lr=1e-3)
samplernn = RNN(rnntype="LSTM", Nh=Nh)

# Lattice dimensions.
Lx = 8
Ly = Lx

NQ = 32
# M is chosen so that M**NQ is approximately 1/e.
M = 1 - 1 / NQ
K = 2

# When exactly three extra command-line arguments are present, fall back to
# NQ=1. M has already been computed from the original NQ above.
if len(sys.argv) == 4:
    NQ = 1
    


In [17]:
# Only the first of the three values returned by `train` is kept.
DEBUG, _, _ = train(
    samplernn, trainrnn, optimizer,
    K, NQ, M, Lx, Ly,
    steps=3000,
    mydir="hiddentests/Hidden16",
)

8 8 64 0.96875
2.7110414505004883 0.9
0,2.55|15,-0.31|30,-0.32|45,-0.35|60,-0.34|75,-0.34|89.91402173042297 3000
-0.342318594455719 -0.367777943611145 -0.3584502339363098


# TorchScript RNN
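The custom LSTM below follows the PyTorch blog post [Optimizing CUDA Recurrent Neural Networks with TorchScript](https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/).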

In [2]:
from torch import jit
from torch.nn.parameter import Parameter
import math
from torch.nn import init

In [3]:
class LSTMCell(jit.ScriptModule):
    """One-step LSTM cell compiled with TorchScript.

    The weights and biases of the four gates (input, forget, cell candidate,
    output) are stacked along dim 0, which is why every parameter has a
    leading size of ``4 * hidden_size``.

    Args:
        input_size: number of features of each input row.
        hidden_size: number of features of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        stacked = 4 * hidden_size
        self.weight_ih = Parameter(torch.empty(stacked, input_size))
        self.weight_hh = Parameter(torch.empty(stacked, hidden_size))
        self.bias_ih = Parameter(torch.empty(stacked))
        self.bias_hh = Parameter(torch.empty(stacked))

        # The tensors above are uninitialised memory until this call.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Fill every parameter uniformly in [-b, b] with b = 1/sqrt(hidden_size).

        For ``hidden_size == 0`` the bound is 0.
        """
        if self.hidden_size > 0:
            bound = 1.0 / math.sqrt(self.hidden_size)
        else:
            bound = 0
        for param in self.parameters():
            init.uniform_(param, -bound, bound)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # `state` is (hidden, cell); the result is (new_hidden, (new_hidden, new_cell)).
        h_prev, c_prev = state

        # Pre-activations of all four gates at once.
        from_input = torch.mm(input, self.weight_ih.t()) + self.bias_ih
        from_hidden = torch.mm(h_prev, self.weight_hh.t())
        pre_act = from_input + from_hidden + self.bias_hh
        i_pre, f_pre, g_pre, o_pre = pre_act.chunk(4, 1)

        i_gate = torch.sigmoid(i_pre)
        f_gate = torch.sigmoid(f_pre)
        candidate = torch.tanh(g_pre)
        o_gate = torch.sigmoid(o_pre)

        # Keep part of the old cell state and add the gated candidate.
        c_next = (f_gate * c_prev) + (i_gate * candidate)
        h_next = o_gate * torch.tanh(c_next)

        return h_next, (h_next, c_next)
    
    
class LSTMLayer(jit.ScriptModule):
    """Applies a recurrent cell step by step along dim 1 of a batch-first input.

    Args:
        cell: callable (typically a cell class) used to build the cell.
        *cell_args: positional arguments forwarded to ``cell``.
    """

    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        # `cell` is a factory, not an instance: build the cell here.
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # Batch-first layout: slicing along dim 1 gives one tensor per step.
        per_step = input.unbind(1)
        collected = torch.jit.annotate(List[Tensor], [])
        for step_input in per_step:
            step_out, state = self.cell(step_input, state)
            collected.append(step_out)
        # Re-assemble the per-step outputs along dim 1; also hand back the last state.
        return torch.stack(collected, dim=1), state

In [4]:
class LSTM_TS(Sampler):
    """Autoregressive sampler over binary sequences built on the TorchScript LSTM.

    The recurrent layer consumes inputs of shape [B,L,1]. A small MLP ending
    in a sigmoid maps each hidden state to the probability that the
    corresponding entry equals 1.
    """
    def __init__(self,Nh=128,device=device, **kwargs):
        """
        Inputs:
            Nh (int)   - hidden size of the LSTM cell
            device     - forwarded to Sampler.__init__ and used to place the module
            **kwargs   - accepted but not used
        """
        super(LSTM_TS, self).__init__(device=device)
        #rnn takes input shape [B,L,1]
        self.rnn = LSTMLayer(LSTMCell,1,Nh)
        #direct handle on the cell inside the layer
        self.cell=self.rnn.cell

        self.lin = nn.Sequential(
                nn.Linear(Nh,128),
                nn.ReLU(),
                nn.Linear(128,1),
                nn.Sigmoid()
            )
        self.Nh=Nh
        self.to(device)

    def _init_state(self,B):
        """Zero (hidden, cell) state for a batch of size B.

        Returned as a tuple because the scripted layer declares its state
        argument as Tuple[Tensor, Tensor].
        """
        #NOTE(review): self.device is presumably set by Sampler.__init__ - confirm
        return (torch.zeros([B,self.Nh],device=self.device),
                torch.zeros([B,self.Nh],device=self.device))

    def forward(self, input):
        """Run the LSTM over the whole sequence starting from a zero state.
            Inputs:
                input - [B,L,1] tensor
            Returns:
                [B,L,1] tensor of sigmoid outputs, one per sequence position
        """
        # h0 is a pair of [B,H] tensors
        h0=self._init_state(input.shape[0])

        out,h=self.rnn(input,h0)
        return self.lin(out)

    def logprobability(self,input):
        """Compute the logscale probability of a given state
            Inputs:
                input - [B,L,1] matrix of zeros and ones for ground/excited states
            Returns:
                logp - [B] size vector of logscale probability labels
        """

        #Input should have shape [B,L,1]
        B,L,one=input.shape

        #first prediction is with the zero input vector
        data=torch.zeros([B,L,one],device=self.device)
        #data is the input vector shifted one to the right, with the very first entry set to zero instead of using pbc
        data[:,1:,:]=input[:,:-1,:]

        #real is going to be a set of actual values
        real=input

        #probability predictions may be done WITH gradients (no torch.no_grad here)
        #pred is the set of predicted probabilities that each entry equals 1
        pred = self.forward(data)
        #if real[i]=1 then you multiply your conditional probability by pred[i]
        ones = real*pred
        #if real[i]=0 then you multiply by 1-pred[i]
        zeros=(1-real)*(1-pred)
        total = ones+zeros
        #the log of the product of conditionals is the sum of their logs over the sequence
        #add 1e-10 to the prediction to avoid nans when total=0
        logp=torch.sum(torch.log(total+1e-10),dim=1).squeeze(1)
        return logp

    def sample(self,B,L):
        """ Generates a set of states
        Inputs:
            B (int)            - The number of states to generate in parallel
            L (int)            - The length of generated vectors
        Returns:
            samples - [B,L,1] matrix of zeros and ones for ground/excited states
        """
        h=self._init_state(B)

        #Sample set will have shape [B,L,1]
        #need one extra zero entry at the start for the first prediction hence input is [B,L+1,1]
        input = torch.zeros([B,L+1,1],device=self.device)
        #sampling can be done without gradients
        with torch.no_grad():
          for idx in range(1,L+1):
            #run the rnn on shape [B,1,1]
            out,h=self.rnn(input[:,idx-1:idx,:],h)
            out=out[:,0,:]
            #if probs[i]=1 then there should be a 100% chance that sample[i]=1
            #if probs[i]=0 then there should be a 0% chance that sample[i]=1
            #stands that we generate a random uniform u and take int(u<probs) as our sample
            probs=self.lin(out)
            #draw on self.device (not the module-level global) so the comparison
            #with probs never mixes devices
            sample = (torch.rand([B,1],device=self.device)<probs).to(torch.float32)
            input[:,idx,:]=sample
        #input's first entry is zero to get a prediction for the first atom
        return input[:,1:,:]

In [51]:
# Two separately initialised TorchScript-LSTM samplers with hidden size 32:
# one is optimised, the other is handed to `train` as the sampler.
trainrnn = LSTM_TS(Nh=32)
samplernn = LSTM_TS(Nh=32)

# Adam moment coefficients.
beta1, beta2 = 0.9, 0.999
optimizer = torch.optim.Adam(trainrnn.parameters(), lr=1e-3, betas=(beta1, beta2))

In [52]:
# Training can be stopped by hand (kernel interrupt) without losing the
# models. Any other failure is reported instead of being silently discarded,
# so a broken run cannot be mistaken for a finished one.
try:
    train(samplernn,trainrnn,optimizer,K,NQ,M,Lx,Ly,steps=3000,mydir="speedtests")
except KeyboardInterrupt:
    print("Training interrupted by user.")
except Exception as err:
    print("Training aborted: {}: {}".format(type(err).__name__, err))

8 8 64 0.96875
2.618889808654785 1.0
0,2.62|25,-0.30|51,-0.34|76,-0.36|102,-0.37|127,-0.38|153.31546330451965 3000
-0.38360506296157837 -0.3863223195075989 -0.38439613580703735


In [8]:
# Shape check: feed an all-zero [10,16,1] batch through the network and
# display the shape of what comes back.
x = torch.zeros(10, 16, 1, device=device)
trainrnn(x).shape

torch.Size([10, 16, 1])

In [5]:

class LSTM2DCell(jit.ScriptModule):
    """LSTM cell with two incoming (hidden, cell) states, compiled with TorchScript.

    Besides the usual input, forget, cell-candidate and output gates there is
    a fifth gate ``l`` that blends the two incoming cell states, so every
    parameter has a leading size of ``5 * hidden_size``.

    Args:
        input_size: number of features of each input row.
        hidden_size: number of features of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super(LSTM2DCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        stacked = 5 * hidden_size

        # input projection
        self.weight_ih = Parameter(torch.empty(stacked, input_size))
        self.bias_ih = Parameter(torch.empty(stacked))

        # hidden-state projection for the x direction
        self.weight_hh_x = Parameter(torch.empty(stacked, hidden_size))

        # hidden-state projection for the y direction
        self.weight_hh_y = Parameter(torch.empty(stacked, hidden_size))

        self.bias_hh = Parameter(torch.empty(stacked))

        # The tensors above are uninitialised memory until this call.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Fill every parameter uniformly in [-b, b] with b = 1/sqrt(hidden_size).

        For ``hidden_size == 0`` the bound is 0.
        """
        if self.hidden_size > 0:
            bound = 1.0 / math.sqrt(self.hidden_size)
        else:
            bound = 0
        for param in self.parameters():
            init.uniform_(param, -bound, bound)

    @jit.script_method
    def forward(self, input, state_x,state_y):
        # type: (Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        # Each incoming state is (hidden, cell); the result is (new_hidden, (new_hidden, new_cell)).
        h_x, c_x = state_x
        h_y, c_y = state_y

        # Pre-activations of all five gates at once.
        from_input = torch.mm(input, self.weight_ih.t()) + self.bias_ih
        from_x = torch.mm(h_x, self.weight_hh_x.t())
        from_y = torch.mm(h_y, self.weight_hh_y.t())
        pre_act = from_input + from_x + from_y + self.bias_hh

        i_pre, f_pre, g_pre, o_pre, l_pre = pre_act.chunk(5, 1)

        i_gate = torch.sigmoid(i_pre)
        f_gate = torch.sigmoid(f_pre)
        candidate = torch.tanh(g_pre)
        o_gate = torch.sigmoid(o_pre)
        mix = torch.sigmoid(l_pre)

        # `mix` interpolates between the two incoming cell states before the
        # forget gate is applied; the gated candidate is then added.
        c_next = f_gate * (mix * c_x + (1 - mix) * c_y) + (candidate * i_gate)

        h_next = o_gate * torch.tanh(c_next)

        return h_next, (h_next, c_next)

In [6]:
class LSTM2DLayer(jit.ScriptModule):
    """Runs a two-input-state cell along dim 1 of a batch-first input.

    At every step the cell receives two earlier (hidden, cell) states, picked
    from the history of states by the index tables ``X`` and ``Y`` that
    ``setup`` builds from ``N``.
    """
    def __init__(self,N, cell, *cell_args):
        super(LSTM2DLayer, self).__init__()
        # `cell` is a factory, not an instance: build the cell here.
        self.cell = cell(*cell_args)
        self.N=N
        # builds self.X and self.Y, which forward reads
        self.setup()
        #self.to(device)
    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        
        # NOTE(review): bare expression with no effect on the result; possibly
        # a leftover from getting the attributes compiled - confirm before removing
        self.X,self.Y
        #this is batch first so we unbind dimension 1
        inputs = input.unbind(1)
        outputs = torch.jit.annotate(List[Tensor], [])
        # S[0] is the starting state; S[k+1] is the state returned at step k
        S = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
        #set the starting state
        S +=[state]
        for idx in range(len(inputs)):
            #get horizontal input states
            state_x=S[self.X[idx]]
            #get vertical input states
            state_y=S[self.Y[idx]]
            #get output states
            out,state=self.cell(inputs[idx],state_x,state_y)
            #store output states for later use
            S +=[state]
            #collect this step's output (no linear layer is applied here)
            outputs+=[out]
            
        # outputs are stacked along dim 1; the state returned is the last one computed
        return torch.stack(outputs,dim=1), state
    
    def setup(self):
        """Sets up the structure that tells the rnn how to propagate the vertical and horizontal states.

        X and Y hold, for each of the N**2 steps, an index into the list of
        states kept by forward (index 0 is the starting state, index k+1 is
        the state produced at step k).

            X[idx] = idx              -> state produced by the previous step
                                         (the starting state for idx = 0)
            Y[idx] = max(idx-N+1, 0)  -> state produced N steps earlier
                                         (the starting state for idx < N)

            3x3: X = (0,1,2,3,4,5,6,7,8)
                 Y = (0,0,0,1,2,3,4,5,6)

        NOTE(review): for a row-major N x N grid, "N steps earlier" is the site
        directly above - presumably the intent; confirm against the caller.
        """
        N=self.N
        self.X=[]
        self.Y=[]
        
        for idx in range(N**2):
            # alternative ordering, currently disabled; for 3x3 it would give
            # Y = (0,0,0,3,2,1,6,5,4):
            #self.Y+=[max(N*(idx//N)-(idx%N),0)]
            self.Y+=[max(idx-N+1,0)]
            self.X+=[idx]


In [7]:
# Smoke test of the 2D layer with N=4, input size 1 and hidden size 128.
rnn = LSTM2DLayer(4, LSTM2DCell, 1, 128)
rnn.to(device)

# Zero (hidden, cell) starting state for a batch of 10.
state = (
    torch.zeros(10, 128, device=device),
    torch.zeros(10, 128, device=device),
)

# All-zero input of shape [10,16,1].
x = torch.zeros(10, 16, 1, device=device)
out, state = rnn(x, state)

# Display the index table built by setup().
rnn.Y

[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

In [8]:
class LSTM2D_S(Sampler):
    """Autoregressive sampler over binary sequences built on the 2D TorchScript LSTM.

    The recurrent layer consumes inputs of shape [B,L,1] and feeds each step
    with two earlier states selected by the layer's X/Y index tables. A small
    MLP ending in a sigmoid maps each hidden state to the probability that
    the corresponding entry equals 1.

    The index tables have Lx**2 entries, so sequences longer than Lx**2
    cannot be processed.
    """
    def __init__(self,Lx,Nh=128,device=device, **kwargs):
        """
        Inputs:
            Lx (int)   - passed to LSTM2DLayer as N; the index tables get Lx**2 entries
            Nh (int)   - hidden size of the LSTM cell
            device     - forwarded to Sampler.__init__ and used to place the module
            **kwargs   - accepted but not used
        """
        super(LSTM2D_S, self).__init__(device=device)
        #rnn takes input shape [B,L,1]
        self.rnn = LSTM2DLayer(Lx,LSTM2DCell,1,Nh)
        #keep handles on the index tables and the cell so sample() can step manually
        self.Y=self.rnn.Y
        self.X=self.rnn.X
        self.cell=self.rnn.cell

        self.lin = nn.Sequential(
                nn.Linear(Nh,128),
                nn.ReLU(),
                nn.Linear(128,1),
                nn.Sigmoid()
            )
        self.Nh=Nh
        self.to(device)

    def _init_state(self,B):
        """Zero (hidden, cell) state for a batch of size B.

        Returned as a tuple because the scripted layer and cell declare their
        state arguments as Tuple[Tensor, Tensor].
        """
        #NOTE(review): self.device is presumably set by Sampler.__init__ - confirm
        return (torch.zeros([B,self.Nh],device=self.device),
                torch.zeros([B,self.Nh],device=self.device))

    def forward(self, input):
        """Run the 2D LSTM over the whole sequence starting from a zero state.
            Inputs:
                input - [B,L,1] tensor
            Returns:
                [B,L,1] tensor of sigmoid outputs, one per sequence position
        """
        # h0 is a pair of [B,H] tensors
        h0=self._init_state(input.shape[0])

        out,h=self.rnn(input,h0)
        return self.lin(out)

    def logprobability(self,input):
        """Compute the logscale probability of a given state
            Inputs:
                input - [B,L,1] matrix of zeros and ones for ground/excited states
            Returns:
                logp - [B] size vector of logscale probability labels
        """

        #Input should have shape [B,L,1]
        B,L,one=input.shape

        #first prediction is with the zero input vector
        data=torch.zeros([B,L,one],device=self.device)
        #data is the input vector shifted one to the right, with the very first entry set to zero instead of using pbc
        data[:,1:,:]=input[:,:-1,:]

        #real is going to be a set of actual values
        real=input

        #probability predictions may be done WITH gradients (no torch.no_grad here)
        #pred is the set of predicted probabilities that each entry equals 1
        pred = self.forward(data)
        #if real[i]=1 then you multiply your conditional probability by pred[i]
        ones = real*pred
        #if real[i]=0 then you multiply by 1-pred[i]
        zeros=(1-real)*(1-pred)
        total = ones+zeros
        #the log of the product of conditionals is the sum of their logs over the sequence
        #add 1e-10 to the prediction to avoid nans when total=0
        logp=torch.sum(torch.log(total+1e-10),dim=1).squeeze(1)
        return logp

    def sample(self,B,L):
        """ Generates a set of states
        Inputs:
            B (int)            - The number of states to generate in parallel
            L (int)            - The length of generated vectors (at most Lx**2,
                                 the number of entries in the X/Y index tables)
        Returns:
            samples - [B,L,1] matrix of zeros and ones for ground/excited states
        """
        h=self._init_state(B)

        #S[0] is the starting state, S[k+1] the state produced at step k
        S=[h]
        #Sample set will have shape [B,L,1]
        #need one extra zero entry at the start for the first prediction hence input is [B,L+1,1]
        input = torch.zeros([B,L+1,1],device=self.device)
        #sampling can be done without gradients
        with torch.no_grad():
          for idx in range(1,L+1):
            #run the cell on shape [B,1]
            #get horizontal input states
            state_x=S[self.X[idx-1]]
            #get vertical input states
            state_y=S[self.Y[idx-1]]

            out,h=self.cell(input[:,idx-1,:],state_x,state_y)

            S+=[h]
            #if probs[i]=1 then there should be a 100% chance that sample[i]=1
            #if probs[i]=0 then there should be a 0% chance that sample[i]=1
            #stands that we generate a random uniform u and take int(u<probs) as our sample
            probs=self.lin(out)
            #draw on self.device (not the module-level global) so the comparison
            #with probs never mixes devices
            sample = (torch.rand([B,1],device=self.device)<probs).to(torch.float32)
            input[:,idx,:]=sample
        #input's first entry is zero to get a prediction for the first atom
        return input[:,1:,:]

In [9]:
# Lattice dimensions.
Lx = 32
Ly = Lx

NQ = 32
# M is chosen so that M**NQ is approximately 1/e.
M = 1 - 1 / NQ
K = 2

# When exactly three extra command-line arguments are present, fall back to
# NQ=1. M has already been computed from the original NQ above.
if len(sys.argv) == 4:
    NQ = 1

# Two separately initialised 2D-LSTM samplers with hidden size 256:
# one is optimised, the other is handed to `train` as the sampler.
trainrnn = LSTM2D_S(Lx, Nh=256)
samplernn = LSTM2D_S(Lx, Nh=256)

# Adam moment coefficients.
beta1, beta2 = 0.9, 0.999
optimizer = torch.optim.Adam(trainrnn.parameters(), lr=1e-3, betas=(beta1, beta2))

In [10]:
# Only the first of the three values returned by `train` is kept.
DEBUG, _, _ = train(
    samplernn, trainrnn, optimizer,
    K, NQ, M, Lx, Ly,
    steps=12000,
    mydir="hiddentests/Hidden256",
)


32 32 1024 0.96875
3.0072531700134277 17.3
2,2.91|486,1.55|970,0.78|1454,0.27|1937,0.01|2421,-0.13|2906,-0.19|3388,-0.23|3872,-0.25|4355,-0.64|4837,-0.28|5322,-0.27|5806,-0.28|6291,-0.28|6774,-0.29|7258,-0.30|7743,-0.29|8227,-0.28|8712,-0.29|9196,-0.29|9678,-0.29|10162,-0.29|10646,-0.30|11132,-0.29|11625.24712395668 12000
