# 01 LSTM Counter

The first step is to build an LSTM that counts the number of 1s in a binary string.

The code below learns this task on strings of length 12 and reaches low error on held-out strings of the same length. The next steps are to test it on longer strings and then to measure how well it generalizes across lengths. The goal is to train on strings of length $n \le 64$ and test on strings of length $64 < n < 128$.
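One way to build the length-generalization split described above is sketched below. This is an assumption, not code from this notebook: strings are sampled uniformly at random (enumerating all $2^{64}$ strings is infeasible), and because `convert_sequence` expects every sequence in a batch to share one length, each batch is drawn at a single length. The helpers `random_binary_batch` and `make_length_split` and the batch size of 32 are hypothetical.

```python
import numpy as np

def random_binary_batch(length, batch_size, rng):
    """Hypothetical helper: sample batch_size random 0/1 strings of one length.

    Returns an int array of shape [batch_size x length]."""
    return rng.integers(0, 2, size=(batch_size, length))

def make_length_split(batch_size=32, seed=0):
    """Hypothetical split: train on lengths 1..64, test on lengths 65..127.

    Returns two dicts mapping length -> [batch_size x length] array,
    so every batch can go through convert_sequence unchanged."""
    rng = np.random.default_rng(seed)
    train = {n: random_binary_batch(n, batch_size, rng) for n in range(1, 65)}
    test = {n: random_binary_batch(n, batch_size, rng) for n in range(65, 128)}
    return train, test

train_by_len, test_by_len = make_length_split()
print(len(train_by_len), len(test_by_len))  # 64 63
print(test_by_len[100].shape)               # (32, 100)
```

Grouping by length avoids padding; a padded alternative would also need a mask so the padding does not count as extra input.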

**Links/References:**
1.   https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html


## Package Loading

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from sklearn.model_selection import train_test_split
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present

## Model Definition

In [0]:
class Counter():
    def __init__(self, hidden):
        '''hidden is the number of hidden variables to use per cell'''

        # this LSTM maps input [batch x length x 1] to output [batch x length x hidden]
        self.lstm = nn.LSTM(1, hidden, batch_first=True).double().to(device)

        # this matrix maps the final hidden output [batch x hidden] to a count [batch x 1]
        self.combine = torch.randn([hidden, 1], dtype=torch.double, device=device, requires_grad=True)

        params = list(self.lstm.parameters())
        params.append(self.combine)
        self.optimizer = optim.Adam(params)

    @staticmethod
    def convert_sequence(seq):
        '''Converts a batch of equal-length sequences (list or NumPy array) into a
        float64 tensor of shape [batch x length x 1].'''
        seq = torch.tensor(seq, device=device).double()
        seq = seq.reshape([len(seq), -1, 1])
        return seq

    def predict(self, sequence):
        '''Predicts the number of 1s in each sequence of a [batch x length x 1] tensor.
        Returns the per-sequence squared error against the true count, and the predicted count.'''
        pred, _ = self.lstm(sequence)
        # pred has shape [batch x length x hidden]; the second index is time,
        # so the output after reading the whole string is pred[:, -1, :]
        pred = pred[:, -1, :] @ self.combine

        loss = (pred - sequence.sum(1)).pow(2)
        return loss, pred
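The shapes flowing through `predict` can be checked without a GPU or a trained model. The NumPy sketch below mirrors the readout and loss; the LSTM output is replaced by random numbers of the right shape (an assumption purely for illustration), and `combine` plays the role of `self.combine`.

```python
import numpy as np

# Stand-in sizes (assumed for illustration): 4 strings of length 12, hidden=2
batch, length, hidden = 4, 12, 2
rng = np.random.default_rng(0)

sequence = rng.integers(0, 2, size=(batch, length, 1)).astype(np.float64)
lstm_out = rng.standard_normal((batch, length, hidden))  # placeholder for self.lstm(sequence)[0]
combine = rng.standard_normal((hidden, 1))               # same role as self.combine

last = lstm_out[:, -1, :]        # output at the final time step: [batch x hidden]
pred = last @ combine            # linear readout:                [batch x 1]
target = sequence.sum(axis=1)    # number of 1s per string:       [batch x 1]
loss = (pred - target) ** 2      # per-string squared error:      [batch x 1]

print(last.shape, pred.shape, target.shape, loss.shape)
# (4, 2) (4, 1) (4, 1) (4, 1)
```

Keeping both `pred` and the target at shape [batch x 1] matters: if the target were flattened to [batch], the subtraction would silently broadcast to a [batch x batch] matrix and the loss would compare every prediction with every count.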

## Model Testing

Here I generate all $2^{12} = 4096$ binary strings of length 12, split them 80/20 into train and test sets, train on the training set, and evaluate on the test set.

For training, I predict the count for every training string, average the squared errors, and take one gradient step on that mean loss per epoch (full-batch training).

For testing, I predict the count for every test string and compare it with the true count.
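The string generation in the next cell uses `np.apply_along_axis` with `np.binary_repr`. A vectorized alternative, sketched below as an option rather than what the notebook runs, extracts each bit with a right shift and a mask, keeping the most significant bit first so the rows match `np.binary_repr` ordering. The last line shows how the counts are distributed.

```python
import numpy as np

length = 12
codes = np.arange(2 ** length)[:, None]   # [4096 x 1] integers 0..4095
shifts = np.arange(length - 1, -1, -1)    # 11, 10, ..., 0: most significant bit first
seqs = (codes >> shifts) & 1              # broadcasts to [4096 x 12] array of 0/1

print(seqs.shape)                 # (4096, 12)
print(seqs[5])                    # [0 0 0 0 0 0 0 0 0 1 0 1]
print(np.bincount(seqs.sum(axis=1)))
# [  1  12  66 220 495 792 924 792 495 220  66  12   1]
```

The counts follow a binomial distribution, so strings with very few or very many 1s (such as the single all-ones string) are rare in both the training and test splits.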

In [0]:
#generate all the strings and partition into train and test
length = 12
seqs = np.arange(2**length)[:,None]
seqs = np.apply_along_axis(lambda x: np.array(list(np.binary_repr(x[0], width=length)), dtype=int), 1, seqs)
train, test = train_test_split(seqs, test_size=0.2, shuffle=True)
train = Counter.convert_sequence(train)
test  = Counter.convert_sequence(test)

print(train.shape)
print(test.shape)


torch.Size([3276, 12, 1])
torch.Size([820, 12, 1])
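The split sizes above follow from how scikit-learn's `train_test_split` handles a float `test_size`: to my understanding it rounds the test count up and gives the remainder to training. The arithmetic, as a quick check:

```python
import math

n_samples = 2 ** 12
n_test = math.ceil(0.2 * n_samples)   # 819.2 rounded up
n_train = n_samples - n_test          # the remainder goes to training
print(n_train, n_test)                # 3276 820
```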


In [0]:
#train over all the training data
model = Counter(hidden=2)

for epoch in range(10000): 
    #take the average loss over all the train data
    loss = model.predict(train)[0].mean()

    if epoch % 500 == 0:
        print(loss.item())
    
    #and update
    model.optimizer.zero_grad()
    loss.backward()  # the graph is rebuilt every epoch, so retain_graph is not needed
    model.optimizer.step()

37.262636851006334
15.20766886075256
6.652507507842696
4.256048111054809
3.3252727085772573
3.021195004300546
2.325172169588343
0.888399799634254
0.5395329202144713
0.35725916545191466
0.2468087923382904
0.17381460975124424
0.12233646947481083
0.08496430968063116
0.058486344887476206
0.04043293520153698
0.02829033415101975
0.02005198223465755
0.014380488517167559
0.010408532599298762


In [0]:
#display testing results
loss, pred = model.predict(test)
print("Average Loss:", loss.mean().item())


Average Loss: 0.0040787440506473
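Because the reported loss is a mean squared error, its square root (about 0.064) gives the typical error in units of counts. Since true counts are integers, another natural metric is the fraction of predictions that round to the exact count. The helper `rounded_accuracy` below is a hypothetical addition, not part of the notebook; the sample pairs are copied from the printed test output.

```python
import numpy as np

def rounded_accuracy(pred, target):
    """Hypothetical metric: fraction of predictions that round to the exact integer count."""
    pred = np.asarray(pred, dtype=float).ravel()
    target = np.asarray(target).ravel()
    return float(np.mean(np.rint(pred) == target))

# A few (actual, predicted) pairs copied from the test printout in this notebook
actual = [5, 6, 8, 10, 7]
predicted = [5.015, 5.972, 8.011, 9.697, 6.965]
print(rounded_accuracy(predicted, actual))   # 1.0

test_mse = 0.0040787440506473
print(round(float(np.sqrt(test_mse)), 4))    # 0.0639
```

With the trained model, the same helper could be applied as `rounded_accuracy(pred.detach().cpu().numpy(), test.sum(1).cpu().numpy())`; the `detach()` is needed because `pred` still requires gradients.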


In [0]:
#string, then actual count, then predicted count
for k in range(len(pred)):
    print(test[k].cpu().numpy().astype(int).squeeze(), "%2d"%test[k].sum().item(), "%6.3f"%pred[k].item())

[0 1 1 0 1 0 1 0 1 0 0 0]  5  5.015
[1 1 0 1 1 0 0 0 0 1 1 0]  6  5.972
[1 1 0 1 1 1 0 1 1 0 1 0]  8  8.011
[1 1 0 1 0 0 1 0 0 1 0 1]  6  5.986
[1 0 0 0 1 0 0 1 0 0 0 1]  4  3.991
[0 0 0 1 1 0 1 1 0 0 1 0]  5  4.994
[0 1 0 1 0 1 0 1 0 0 1 0]  5  4.982
[0 1 1 0 1 0 1 1 0 0 1 1]  7  6.965
[0 0 1 1 0 1 1 0 0 1 1 1]  7  6.984
[0 0 1 1 1 0 1 1 1 0 1 1]  8  8.034
[0 0 0 0 0 1 1 1 1 1 0 0]  5  5.064
[0 0 1 0 1 1 1 0 1 1 0 1]  7  6.990
[1 0 1 0 1 1 0 1 0 1 0 1]  7  6.956
[0 0 0 1 1 1 1 1 1 0 1 1]  8  8.071
[1 0 0 1 0 0 1 1 1 0 1 1]  7  7.009
[0 0 0 1 1 0 1 1 0 1 1 1]  7  7.018
[0 1 0 1 0 1 1 0 0 1 1 0]  6  5.947
[1 0 1 0 1 0 0 0 1 0 1 0]  5  4.987
[1 1 0 0 0 0 1 0 0 0 0 1]  4  4.019
[1 1 0 1 0 0 1 1 1 0 0 0]  6  5.986
[1 1 1 0 0 0 1 0 0 0 1 0]  5  5.044
[1 1 0 0 0 0 0 1 0 1 1 1]  6  6.023
[1 1 1 0 1 0 1 0 1 0 1 0]  7  6.941
[0 1 0 0 1 0 0 1 0 0 1 1]  5  4.987
[0 1 1 0 1 1 0 0 1 1 1 0]  7  6.934
[1 1 0 1 1 1 1 1 0 1 1 1] 10  9.697
[0 1 1 0 0 1 1 0 1 0 0 1]  6  5.977
[0 1 1 1 1 0 1 0 1 1 1 0]  8