In [None]:
class RNN(nn.Module):
    """Simple recurrent cell with a deep feed-forward output head.

    At every step the current input is concatenated with the previous hidden
    state. One linear layer followed by ``tanh`` yields the next hidden state;
    five linear layers with ReLU in between, followed by log-softmax over
    ``dim=1``, yield log-probabilities over ``input_size`` classes.

    NOTE: ``relu4`` was previously an ``nn.Linear`` by mistake, which put two
    affine layers back to back with no nonlinearity. It is now a real ReLU.
    State dicts saved from the earlier definition carry extra
    ``relu4.weight`` / ``relu4.bias`` entries; load them with ``strict=False``.
    """

    def __init__(self, input_size, hidden_size):
        """Build the layers.

        Args:
            input_size: Width of one input vector; also the number of output
                classes.
            hidden_size: Width of the recurrent hidden state.
        """
        super().__init__()
        self.hd_sz = hidden_size
        self.in_sz = input_size

        combined = input_size + hidden_size

        # Recurrent path: (input ++ hidden) -> next hidden state.
        self.h1 = nn.Linear(combined, hidden_size)

        # Output head. nn.ReLU's only constructor argument is ``inplace``, so
        # the layer width must not be passed to it.
        self.o1 = nn.Linear(combined, combined)
        self.relu = nn.ReLU()

        self.o2 = nn.Linear(combined, combined)
        self.relu2 = nn.ReLU()

        self.o3 = nn.Linear(combined, combined)
        self.relu3 = nn.ReLU()

        self.o4 = nn.Linear(combined, combined)
        self.relu4 = nn.ReLU()

        self.o5 = nn.Linear(combined, input_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def batch_forward(self, xb, yb, hidden, loss_fn):
        """Run the cell over every time step of a batch and average the loss.

        Puts the module in training mode as a side effect.

        Args:
            xb: Inputs indexed as ``xb[batch, step, feature]``.
            yb: Targets indexed as ``yb[batch, step]`` and handed to
                ``loss_fn`` one step at a time.
            hidden: Hidden state carried over from the previous batch. It is
                replaced by zeros when ``xb[0, 0, 1] == 1``
                (presumably a start-of-sequence marker; confirm with the
                data pipeline).
            loss_fn: Callable ``loss_fn(output, y)`` returning the step loss.

        Returns:
            Tuple ``(output, hidden, loss)``: the last step's output, the
            final hidden state detached from the graph, and the loss averaged
            over the steps that were actually processed.

        Raises:
            ValueError: If no step was processed because the very first step
                was empty after ``unpad``.
        """
        self.train()
        if xb[0, 0, 1].item() == 1:
            hidden = self.initHidden(xb.shape[0])

        loss = 0
        steps = 0
        output = None
        for char in range(xb.shape[1]):
            x, y = xb[:, char], yb[:, char]
            # ``unpad`` is defined elsewhere in the project; presumably it
            # drops padded rows from x, y and hidden together (confirm).
            x, y, hidden = unpad(x, y, hidden)
            if x.shape[0] == 0:
                break
            output, hidden = self.forward(x, hidden)
            loss += loss_fn(output, y)
            steps += 1

        if steps == 0:
            raise ValueError("batch_forward: batch contains no non-padded time steps")

        # Divide by the number of processed steps. ``char + 1`` over-counts by
        # one whenever the loop exits through ``break``.
        return output, hidden.detach(), loss / steps

    def forward(self, input, hidden):
        """Compute one time step.

        Args:
            input: Tensor of shape ``(batch, input_size)``.
            hidden: Tensor of shape ``(batch, hidden_size)``.

        Returns:
            Tuple ``(output, hidden)``: log-probabilities of shape
            ``(batch, input_size)`` and the next hidden state of shape
            ``(batch, hidden_size)``.
        """
        combined = torch.cat((input, hidden), 1)

        # Both paths read the same concatenated vector.
        hidden = torch.tanh(self.h1(combined))

        output = self.relu(self.o1(combined))
        output = self.relu2(self.o2(output))
        output = self.relu3(self.o3(output))
        output = self.relu4(self.o4(output))

        output = self.softmax(self.o5(output))
        return output, hidden

    def initHidden(self, bs):
        """Return a zero hidden state of shape ``(bs, hidden_size)``.

        The tensor is passed through the project's ``cuda`` helper
        (presumably a device-placement wrapper; confirm).
        """
        return cuda(torch.zeros(bs, self.hd_sz))

class GRU(nn.Module):
    """Hand-written gated recurrent unit with a linear + log-softmax readout.

    The input and the previous hidden state are each projected to three
    chunks (update, reset, candidate). The new hidden state is
    ``u * hidden + (1 - u) * tanh(x_n + r * h_n)``. The prediction is a
    log-softmax over a linear map of ``input`` concatenated with the new
    hidden state.
    """

    def __init__(self, in_sz, hd_sz):
        """Build the layers.

        Args:
            in_sz: Width of one input vector; also the number of output
                classes.
            hd_sz: Width of the hidden state.
        """
        super().__init__()
        self.in_sz = in_sz
        self.hd_sz = hd_sz

        # One fused projection per source; split into three gates in forward.
        self.h_lin = nn.Linear(self.hd_sz, 3 * self.hd_sz)
        self.x_lin = nn.Linear(self.in_sz, 3 * self.hd_sz)

        self.up_sig = nn.Sigmoid()
        self.re_sig = nn.Sigmoid()

        self.o1 = nn.Linear(self.hd_sz + self.in_sz, self.in_sz)

        self.softmax = nn.LogSoftmax(dim=1)
        # Not read anywhere in this class; kept for external code that may.
        self.loss = 0

    def forward(self, input, hidden):
        """Compute one time step.

        Args:
            input: Tensor of shape ``(batch, in_sz)``.
            hidden: Tensor of shape ``(batch, hd_sz)``.

        Returns:
            Tuple ``(prediction, h_new)``: log-probabilities of shape
            ``(batch, in_sz)`` and the next hidden state of shape
            ``(batch, hd_sz)``.
        """
        x_u, x_r, x_n = self.x_lin(input).chunk(3, 1)
        h_u, h_r, h_n = self.h_lin(hidden).chunk(3, 1)

        update_gate = self.up_sig(x_u + h_u)
        reset_gate = self.re_sig(x_r + h_r)
        # The reset gate scales only the hidden contribution to the candidate.
        new_gate = torch.tanh(x_n + reset_gate * h_n)
        h_new = update_gate * hidden + (1 - update_gate) * new_gate

        combined = torch.cat((input, h_new), 1)
        prediction = self.softmax(self.o1(combined))

        return prediction, h_new

    def batch_forward(self, xb, yb, hidden, loss_fn):
        """Run the cell over every time step of a batch and average the loss.

        Puts the module in training mode as a side effect.

        Args:
            xb: Inputs indexed as ``xb[batch, step, feature]``.
            yb: Targets indexed as ``yb[batch, step]`` and handed to
                ``loss_fn`` one step at a time.
            hidden: Hidden state carried over from the previous batch. It is
                replaced by zeros when ``xb[0, 0, 1] == 1``
                (presumably a start-of-sequence marker; confirm with the
                data pipeline).
            loss_fn: Callable ``loss_fn(output, y)`` returning the step loss.

        Returns:
            Tuple ``(output, hidden, loss)``: the last step's output, the
            final hidden state detached from the graph, and the loss averaged
            over the steps that were actually processed.

        Raises:
            ValueError: If no step was processed because the very first step
                was empty after ``unpad``.
        """
        self.train()
        if xb[0, 0, 1].item() == 1:
            hidden = self.initHidden(xb.shape[0])

        loss = 0
        steps = 0
        output = None
        for char in range(xb.shape[1]):
            x, y = xb[:, char], yb[:, char]
            # ``unpad`` is defined elsewhere in the project; presumably it
            # drops padded rows from x, y and hidden together (confirm).
            x, y, hidden = unpad(x, y, hidden)
            if x.shape[0] == 0:
                break
            output, hidden = self.forward(x, hidden)
            loss += loss_fn(output, y)
            steps += 1

        if steps == 0:
            raise ValueError("batch_forward: batch contains no non-padded time steps")

        # Divide by the number of processed steps. ``char + 1`` over-counts by
        # one whenever the loop exits through ``break``.
        return output, hidden.detach(), loss / steps

    def initHidden(self, bs):
        """Return a zero hidden state of shape ``(bs, hd_sz)``.

        The tensor is passed through the project's ``cuda`` helper
        (presumably a device-placement wrapper; confirm).
        """
        return cuda(torch.zeros(bs, self.hd_sz))