### Necessary imports

In [518]:
from __future__ import print_function, division
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

### Hyperparameters

In [519]:
sequence_length = 5
total_digits = 5
input_size = 10*total_digits
output_size = 11*(total_digits+1)
batch = 32
hidden_size = 512
num_layers = 1
num_epochs = 1000
learning_rate = 0.0001

### Generate samples

In [547]:
def add_vec(x,y):
    n = max(len(x),len(y))
    return num2vec(vec2num(x)+vec2num(y),n)

def vec2num(x):
    s = 0
    for i in range(len(x)):
        if x[i] == 10:
            break    
        s *= 10
        s += x[i]
    return s

def num2vec(x, n):
    y = np.zeros(n) + 10
    digits = len(str(int(x)))
    for i in range(digits):
        y[i] = (x//10**(digits-i-1))%10
    return y

def encode_in(x):
    y = np.zeros(len(x)*10)
    for i in range(len(x)):
        if x[i] == 10:
            break
        else:
            y[10*i+int(x[i])] = 1
    return y

def encode_out(x):
    y = np.zeros(len(x)*11)
    for i in range(len(x)):
        if x[i] == 10:
            y[11*i+10] = 1
        else:
            y[11*i+int(x[i])] = 1
    return y

def decode_out(x):
    y = np.zeros(len(x)//11, dtype=int)
    for i in range(len(y)):
        y[i] = np.argmax(x[i*11:(i+1)*11])
    return y

def generate():
    input_dec = np.zeros((batch,sequence_length,total_digits),dtype=int) + 10
    input_enc = np.zeros((batch,sequence_length,input_size),dtype=int)
    output_dec = np.zeros((batch,sequence_length,total_digits+1),dtype=int)
    output_enc = np.zeros((batch,sequence_length,output_size),dtype=int)
    for i in range(batch):
        for j in range(sequence_length):
            digits = np.random.randint(5) + 1    
            for k in range(digits):
                d = np.random.randint(10)
                input_dec[i,j,k] = d
            if j == 0:
                output_dec[i,j,:-1] = input_dec[i,j,:]
                output_dec[i,j,-1] = 10
            elif j > 0:
                output_dec[i,j,:] = add_vec(output_dec[i,j-1,:], input_dec[i,j,:])
            input_enc[i,j,:] = encode_in(input_dec[i,j,:])
            output_enc[i,j,:] = encode_out(output_dec[i,j,:])
    x = Variable(torch.from_numpy(input_enc)).float()
    y = Variable(torch.from_numpy(output_dec)).long()
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()
    return x, y
    

### Neural Network

In [548]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(sequence_length*hidden_size, sequence_length*output_size)
    
    def forward(self, x):
        (h0, c0) = self.init_hidden_state()
        out, (h0, c0) = self.lstm(x, (h0, c0))
        out = out.contiguous()
        out = self.fc(out.view(batch,-1))
        for j in range(sequence_length):
            for i in range(total_digits+1):
                out[:,output_size*j + i*11:output_size*j + (i+1)*11] = \
                F.softmax(out[:,output_size*j + i*11:output_size*j + (i+1)*11].clone())
        return out
    
    def init_hidden_state(self):
        if torch.cuda.is_available():
            h0 = Variable(torch.zeros(self.num_layers, batch, self.hidden_size).cuda()) 
            c0 = Variable(torch.zeros(self.num_layers, batch, self.hidden_size).cuda())
        else:
            h0 = Variable(torch.zeros(self.num_layers, batch, self.hidden_size)) 
            c0 = Variable(torch.zeros(self.num_layers, batch, self.hidden_size))
        return h0, c0

rnn = RNN(input_size, hidden_size, num_layers)
if torch.cuda.is_available():
    rnn.cuda()

### Loss and optimizer

In [549]:
def CustomLoss(x,y,criterion):
    s = Variable(torch.zeros(1), requires_grad=True)
    y = y.view(batch,-1)
    for j in range(1,sequence_length):
        for i in range(total_digits+1):
            s = s + criterion(x[:,output_size*j + i*11:output_size*j + (i+1)*11],\
                              y[:,(total_digits+1)*j + i])
    return s

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)

In [554]:
def accuracy(out,y):
    out = out.view(batch,sequence_length,-1).data.numpy()
    
    y = y.view(batch,sequence_length,-1).data.numpy()
    s = 0.
    for k in range(batch):
        for j in range(1,sequence_length):
            dec = decode_out(out[k,j,:])
            print(out[k,j,:], dec)
            if (dec - y[k,j,:]).all():
                s += 1
    return s/((sequence_length-1)*batch)
            
    

### Train

In [555]:
losses = np.zeros(num_epochs)
acc = np.zeros(num_epochs)
last_time = time.time()
for i in range(num_epochs):
    rnn.zero_grad()
    x,y = generate()
    out = rnn(x)
    acc[i] = accuracy(out,y)
    loss = CustomLoss(out,y,criterion)
    losses[i] = loss.data[0]
    loss.backward()
    optimizer.step()
    if (i+1)%(num_epochs//20) == 0:
        print('Step '+ str(i+1) + '/' +str(num_epochs)+ ' done. Loss = ' + \
              str(losses[i])+ '. Accuracy = ' + str(acc[i]) + '. Time elapsed: ' + str(time.time()-last_time))
        last_time = time.time()
        

[  3.46802735e-05   6.64687276e-01   1.67376758e-03   1.75809767e-03
   5.51241152e-02   1.36164771e-02   9.35393292e-03   1.41886957e-02
   1.32961318e-01   1.06568515e-01   3.31054689e-05   3.46206993e-01
   6.96878284e-02   1.90722719e-01   3.63921709e-02   9.32877362e-02
   1.00929089e-01   3.22992429e-02   5.39877824e-02   1.63413659e-02
   6.00025766e-02   1.42517456e-04   1.15432462e-03   1.40565690e-02
   4.47768718e-04   7.20434298e-04   1.17662572e-03   1.20123342e-01
   1.45506218e-01   1.34600804e-03   5.51508390e-04   2.65538431e-04
   7.14651644e-01   3.52330767e-06   4.93228845e-06   4.24543214e-06
   6.85666555e-06   3.44624891e-06   5.42178304e-06   3.41873238e-06
   4.73992168e-06   4.80134531e-06   2.45441629e-06   9.99956131e-01
   1.48698882e-06   2.11800557e-06   3.62044216e-06   2.36668484e-06
   2.39996416e-06   2.87752459e-06   3.20509093e-06   2.19453591e-06
   1.56886813e-06   1.93555047e-06   9.99976218e-01   1.35814230e-06
   1.10133180e-06   1.39177325e-06

   3.14227714e-06   9.99967515e-01] [ 1  0  0  1 10 10]
[  1.63982459e-05   7.02399313e-01   1.09437120e-03   1.15418667e-03
   4.67940867e-02   1.05319712e-02   6.85018394e-03   1.12662828e-02
   1.22793756e-01   9.70840678e-02   1.53951805e-05   3.70562851e-01
   6.48247451e-02   1.95087954e-01   3.38630080e-02   9.29653347e-02
   9.88566875e-02   2.84394752e-02   4.93026040e-02   1.34369899e-02
   5.25805913e-02   7.97481480e-05   8.22771573e-04   1.22548956e-02
   2.84368551e-04   4.75987996e-04   8.30564590e-04   1.30711004e-01
   1.61895812e-01   9.73087328e-04   3.63240979e-04   1.60999698e-04
   6.91227257e-01   1.23817961e-06   1.78656319e-06   1.51647362e-06
   2.58905789e-06   1.22188146e-06   2.01251032e-06   1.20599339e-06
   1.73686567e-06   1.73248679e-06   8.53612960e-07   9.99984086e-01
   4.88293551e-07   7.13181009e-07   1.25748647e-06   8.15023782e-07
   8.00501709e-07   9.76844717e-07   1.09798805e-06   7.37913922e-07
   5.18858201e-07   6.46034266e-07   9.99991953

   4.12109472e-07   9.99994218e-01] [ 1  0 10 10 10 10]
[  1.52519647e-06   9.98930097e-01   4.44437355e-05   5.33268394e-05
   9.47322333e-05   1.83727301e-04   3.98846132e-05   2.03853226e-04
   2.69360229e-04   1.77879585e-04   1.19246192e-06   5.37195683e-01
   1.60825327e-01   5.81966937e-02   3.70529555e-02   8.53844434e-02
   6.80695400e-02   2.14002002e-02   1.83143485e-02   1.06825475e-02
   2.85893632e-03   1.93119686e-05   1.33282274e-01   1.36299983e-01
   1.17366485e-01   1.62999943e-01   1.43722996e-01   1.06329970e-01
   9.25492048e-02   4.79694195e-02   4.81286757e-02   1.02762263e-02
   1.07482844e-03   7.60564126e-06   1.01459482e-05   1.62581910e-05
   2.45072151e-05   1.10222745e-05   6.96389725e-06   5.08316043e-06
   1.17160871e-05   1.03345701e-05   7.16777686e-06   9.99889195e-01
   1.05624520e-06   1.00491923e-06   1.36408175e-06   1.92360108e-06
   9.67098003e-07   9.73772558e-07   1.21097366e-06   1.44702699e-06
   1.28387535e-06   1.49140510e-06   9.99987304

   1.02515662e-06   9.99990046e-01] [ 1  0  0 10 10 10]
[  1.25859492e-06   9.99843061e-01   2.21072896e-05   7.62569107e-06
   1.04314868e-05   1.02932845e-05   1.50754877e-05   2.65963081e-05
   1.11895697e-05   5.14519961e-05   9.02513818e-07   6.47548556e-01
   1.18385933e-01   9.39586386e-02   8.77628326e-02   2.19864324e-02
   1.44109130e-02   4.48655849e-03   9.53394081e-03   9.95648443e-04
   9.11136332e-04   1.94182103e-05   1.53138876e-01   1.19387917e-01
   1.29781455e-01   6.74166828e-02   8.83478150e-02   8.59006941e-02
   1.25395060e-01   1.23949246e-02   1.31757960e-01   8.64438489e-02
   3.47565037e-05   5.68266399e-02   2.21648276e-01   8.42094496e-02
   9.57381204e-02   1.57679930e-01   8.17666054e-02   9.17814299e-02
   1.14961460e-01   4.22815233e-02   5.25354370e-02   5.71120996e-04
   1.53271903e-05   8.21853428e-06   1.21652402e-05   1.22136344e-05
   6.07644233e-06   9.59107183e-06   9.09419850e-06   6.41401266e-06
   1.02617596e-05   7.81471317e-06   9.99902844

[  1.57131331e-06   9.99812186e-01   2.67283431e-05   9.30568967e-06
   1.27804087e-05   1.25393699e-05   1.82098411e-05   3.11545991e-05
   1.36052577e-05   6.07493348e-05   1.14710929e-06   6.58141434e-01
   1.17212459e-01   8.94028097e-02   8.27765688e-02   2.22401824e-02
   1.39846141e-02   4.55474015e-03   9.62173939e-03   1.05676206e-03
   9.85808787e-04   2.28907702e-05   1.46432117e-01   1.19084276e-01
   1.20348714e-01   7.00453594e-02   8.92212540e-02   8.84414837e-02
   1.28774941e-01   1.32946968e-02   1.35174826e-01   8.91426057e-02
   3.97469375e-05   5.79582527e-02   2.11506903e-01   8.19965750e-02
   9.83986929e-02   1.48516759e-01   8.51189494e-02   9.53555703e-02
   1.22650459e-01   4.22549322e-02   5.55926822e-02   6.50227652e-04
   1.86497055e-05   9.98445739e-06   1.43485258e-05   1.45419108e-05
   7.42950533e-06   1.15052326e-05   1.06301341e-05   7.88211037e-06
   1.22480624e-05   9.41880899e-06   9.99883354e-01   1.22182973e-06
   1.09181008e-06   1.39088797e-06

[  1.28851298e-06   9.99036729e-01   3.93540577e-05   4.85366763e-05
   8.37119514e-05   1.62905839e-04   3.50552400e-05   1.83915399e-04
   2.46142386e-04   1.61327931e-04   1.01124976e-06   5.48242271e-01
   1.69147372e-01   5.26877902e-02   3.64159495e-02   8.45743492e-02
   5.87333068e-02   2.02007908e-02   1.75797846e-02   9.68746096e-03
   2.71393894e-03   1.69719970e-05   1.48148105e-01   1.39537066e-01
   1.23300210e-01   1.40705854e-01   1.41602695e-01   1.05757430e-01
   8.49467441e-02   5.10703325e-02   5.33927865e-02   1.04527865e-02
   1.08601328e-03   6.43737030e-06   8.93581000e-06   1.39437079e-05
   2.10630660e-05   9.61017668e-06   6.06782305e-06   4.33436935e-06
   1.01892510e-05   8.97175050e-06   6.14531200e-06   9.99904275e-01
   9.39592269e-07   8.81471578e-07   1.21606081e-06   1.71923045e-06
   8.49062189e-07   8.50027959e-07   1.07852554e-06   1.25803899e-06
   1.13291617e-06   1.32600201e-06   9.99988735e-01   4.57707785e-07
   6.59706700e-07   3.55950903e-07

KeyboardInterrupt: 

### Plots

In [None]:
%matplotlib inline
plt.plot(losses)
plt.title('Loss')
plt.figure()
plt.plot(acc)
plt.title('Accuracy')

In [None]:
x = np.array(range(5))
print(x)

print(x[1:2])