# 05 Differentiable Counter

This notebook builds a differentiable counter that an LSTM learns to drive in order to count the ones in a binary sequence. The counter works as follows:


*   Upon seeing a sequence element $x_i$, the LSTM predicts a value $0 \le r_i \le 1$, which corresponds to the probability that the counter should be incremented by 1.
*   The total count for a sequence is then $c=\sum_i r_i$
*   The target for the sequence is the number of ones it contains, $\sum_i \mathbb{1}[x_i = 1]$



In [0]:
#mount Google Drive inside the Colab runtime at /content/drive
from google.colab import drive
drive.mount('/content/drive')

In [0]:
#switch the working directory to the project folder (relative path, so it depends on the current directory)
%cd drive/My\ Drive/CS281\ Final\ Project

## Package Loading

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from random import shuffle
import numpy as np
from sklearn.model_selection import train_test_split
device = 'cuda'

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle
from random import sample

## Model Definition

In [0]:
class Counter():
    '''Differentiable counter driven by an LSTM.

    For every element of a sequence the LSTM output is projected to a single value and
    squashed with a sigmoid, giving an increment between 0 and 1. The predicted count of a
    sequence is the sum of its increments, and the training target is the sum of the inputs.

    All tensors and parameters are float64 and live on the module-level `device`.
    '''
    def __init__(self, hidden):
        '''hidden is the number of hidden variables to use per cell'''
        self.hidden = hidden
        #this LSTM goes from input [batch x length x 1] to output [batch x length x hidden]
        self.lstm = nn.LSTM(1, hidden, batch_first=True).double().to(device)

        #this matrix transforms from [1 x hidden] to [1 x 1]
        self.combine = torch.randn([hidden,1], dtype=torch.float64, device=device, requires_grad=True)

        #a single Adam optimizer updates both the LSTM weights and the projection matrix
        params = list(self.lstm.parameters())
        params.append(self.combine)
        self.optimizer = optim.Adam(params)

    @staticmethod
    def convert_sequence(seq):
        '''converts a set of sequences with the same length from array or numpy into a correctly formatted tensor.
        Shape: [batch x length x 1]'''
        seq = torch.tensor(seq, device=device).double()
        seq = seq.reshape([len(seq), -1, 1])
        return seq

    def predict(self, sequence):
        '''takes a tensor of shape [batch x length x 1], predicts the sum of each sequence, and compares
        it to the actual sum of that sequence.
        Returns the squared error and the predicted sum, each of shape [batch]'''

        #no initial state is passed, so the LSTM starts every sequence from its default (zero) state
        o, _ = self.lstm(sequence)
        #per-step increment in (0, 1): [batch x length x hidden] @ [hidden x 1] -> [batch x length x 1]
        add = torch.sigmoid(o @ self.combine)
        count = add.sum(1).squeeze(1)
        true = sequence.sum(1).squeeze(1)
        loss = (count - true).pow(2)
        return loss, count

    def predict_multilength(self, sequences):
        '''Takes a list of batches of tensors of different length. Predicts on each batch. Sums up the loss.
        Reduces to a single mean over every sequence in every batch (a scalar tensor).
        An empty list gives nan, because a zero loss is divided by a zero count'''
        loss = torch.tensor(0, device=device).double()
        count = 0
        for s in sequences:
            res    = self.predict(s)[0]
            count += res.shape[0]
            loss  += res.sum()
        return loss / count

## Data Generation

In [0]:
def generate_data(length, total):
    '''Generates random binary sequences grouped into one batch per sequence length.

    Roughly 90% of `total` is spread over the lengths 1..length and split into train and
    validation batches (2/9 of each length goes to validation). The remaining 10% is spread
    over the lengths length+1..2*length and used as test batches. How many sequences each
    length gets is drawn from a Dirichlet distribution whose weights grow with the square of
    the length. Duplicate sequences are dropped, so the final sizes can be smaller than requested.

    Returns three lists of tensors (train, validation, test), each tensor [batch x length x 1]'''
    train_set, val_set, test_set = [], [], []

    #lengths 1..length feed the train and validation sets
    short_weights = (np.arange(length)+1)**2
    short_counts = np.round(np.random.dirichlet(short_weights) * total * 0.9).astype(int)

    for seq_len, n_seqs in enumerate(short_counts, start=1):
        if n_seqs == 0:
            continue
        unique_seqs = np.unique(np.random.randint(0,2, size=[n_seqs,seq_len]), axis=0)
        try:
            train_part, val_part = train_test_split(unique_seqs, test_size=2/9, shuffle=True)
            train_part = Counter.convert_sequence(train_part)
            val_part = Counter.convert_sequence(val_part)
        except ValueError:
            #this length could not be split or converted, so it contributes to neither set
            continue
        else:
            train_set.append(train_part)
            val_set.append(val_part)

    #lengths length+1..2*length are held out for the test set
    long_weights = (np.arange(length, 2*length)+1)**2
    long_counts = np.round(np.random.dirichlet(long_weights) * total * 0.1).astype(int)

    for seq_len, n_seqs in enumerate(long_counts, start=length+1):
        if n_seqs == 0:
            continue
        unique_seqs = np.unique(np.random.randint(0,2, size=[n_seqs,seq_len]), axis=0)
        test_set.append(Counter.convert_sequence(unique_seqs))

    return train_set, val_set, test_set

In [0]:
#experiment settings shared by the cells below
output_folder = "Part-5-Outputs"   #folder the data, training history and model are read from

hidden, length = 10, 64            #hidden size handed to Counter; length argument for generate_data

In [0]:

# print(length)
# depth = 100000
# train, val, test = generate_data(length,depth)

# with open("%s/Data.pickle"%output_folder, "wb") as f:
#     pickle.dump([train, val, test], f)

In [0]:
#load the previously generated train / validation / test batches
#NOTE: pickle runs arbitrary code on load, so only open a Data.pickle this project produced itself
data_path = "%s/Data.pickle"%output_folder
with open(data_path, "rb") as data_file:
    train, val, test = pickle.load(data_file)

In [0]:
def _length_range(batches):
    '''Returns the (shortest, longest) sequence length found among the batches'''
    lengths = [b.shape[1] for b in batches]
    return min(lengths), max(lengths)

#number of sequences in each split: the first axis of every batch is the batch size
trainsize, valsize, testsize = (sum(b.shape[0] for b in split) for split in (train, val, test))
print(trainsize, valsize, testsize)

total = trainsize+valsize+testsize
print("Total:",total)
print("Fraction %.3f %.3f %.3f"%(trainsize/total, valsize/total, testsize/total))

print("Train    length range: %d-%d"%_length_range(train))
print("Validate length range: %d-%d"%_length_range(val))
print("Test     length range: %d-%d"%_length_range(test))

69924 20015 10003
Total: 99942
Fraction 0.700 0.200 0.100
Train    length range: 2-64
Validate length range: 2-64
Test     length range: 65-128


## Model Training

In [0]:
#train over all the training data (full-batch gradient descent with early stopping)
model = Counter(hidden=hidden)

history = []          #one [train_loss, val_loss] pair per logged epoch
best = float('inf')   #lowest validation loss seen so far
patience = 10         #stop after this many logged epochs without improvement
tol = 0.001           #validation loss must drop by more than this to count as an improvement
count = 0             #logged epochs since the last improvement

for epoch in range(1000000): 
    #the loss is a mean over every batch, so shuffling only changes the order of summation
    shuffle(train)
    shuffle(val)

    #take the average loss over all the train data
    loss = model.predict_multilength(train)

    if epoch % 100 == 0:
        #reuse the loss that is about to be back-propagated instead of a second forward pass
        train_loss = loss.item()
        with torch.no_grad():
            val_loss = model.predict_multilength(val).item()
        history.append([train_loss, val_loss])
        print("Epoch: %5d. Train Loss: %7.3f. Validation Loss: %7.3f"%(epoch, train_loss, val_loss))

        if val_loss + tol < best:
            best = val_loss
            count = 0
        else:
            count += 1
        if count >= patience:
            break

    #and update; the graph is rebuilt every epoch, so it does not need to be retained
    model.optimizer.zero_grad()
    loss.backward()
    model.optimizer.step()

# history = np.array(history)
# np.save("%s/Train-History"%output_folder, history)
# torch.save(model, "%s/Model"%output_folder)

# #display testing results
# loss = model.predict_multilength(test)
# print("Average Test Loss:", loss.item())

## Results Evaluation

Deliverable for this notebook: the main figure should be a plot of string length versus prediction accuracy. Generate a string of length $n$ in which each digit is 1 with a probability of roughly 10%, and evaluate the loss of the network on that string. Run $n$ from $65$ up to about $10{,}000$ (or wherever is a reasonable place to stop).

In [0]:
#reload the saved training curves: column 0 is the train loss, column 1 the validation loss
history = np.load("%s/Train-History.npy"%output_folder)
train_loss, val_loss = history[:,0], history[:,1]

#NOTE(review): presumably written by torch.save(model, ...), i.e. a pickled Counter object,
#so the Counter class has to be defined before this cell runs - confirm
model = torch.load("%s/Model"%output_folder)

In [0]:
#evaluation only: skip autograd bookkeeping so no graph is kept for the test sequences
with torch.no_grad():
    loss = model.predict_multilength(test)
print("Average Test Loss:", loss.item())

Average Test Loss: 0.0009120174796806725


In [0]:
#print the true sum and predicted sum for one randomly chosen string from each test batch
for batch in test:
    #index with a length-1 array so the picked string keeps its leading batch axis
    #(named `picked` so the `sample` function imported from random is not overwritten)
    picked = batch[np.random.choice(batch.shape[0], 1)]
    with torch.no_grad():
        predicted = model.predict(picked)[1].item()
    print("Length: %3d Sum: %3d Pred: %7.4f"%(picked.shape[1], int(picked.cpu().numpy().sum()), predicted))

Length:  65 Sum:  31 Pred: 30.9915
Length:  66 Sum:  27 Pred: 26.9762
Length:  67 Sum:  30 Pred: 30.0075
Length:  68 Sum:  30 Pred: 30.0223
Length:  69 Sum:  28 Pred: 28.0328
Length:  70 Sum:  36 Pred: 36.0122
Length:  71 Sum:  39 Pred: 39.0389
Length:  72 Sum:  29 Pred: 28.9874
Length:  73 Sum:  33 Pred: 33.0060
Length:  74 Sum:  38 Pred: 38.0294
Length:  75 Sum:  36 Pred: 35.9969
Length:  76 Sum:  40 Pred: 40.0073
Length:  77 Sum:  34 Pred: 33.9626
Length:  78 Sum:  37 Pred: 36.9973
Length:  79 Sum:  43 Pred: 43.0261
Length:  80 Sum:  37 Pred: 37.0211
Length:  81 Sum:  36 Pred: 35.9805
Length:  82 Sum:  49 Pred: 48.9899
Length:  83 Sum:  34 Pred: 33.9506
Length:  84 Sum:  36 Pred: 36.0293
Length:  85 Sum:  30 Pred: 30.0006
Length:  86 Sum:  43 Pred: 43.0322
Length:  87 Sum:  50 Pred: 49.9885
Length:  88 Sum:  45 Pred: 44.9416
Length:  89 Sum:  46 Pred: 46.0317
Length:  90 Sum:  49 Pred: 48.9772
Length:  91 Sum:  38 Pred: 38.0174
Length:  92 Sum:  47 Pred: 46.9660
Length:  93 Sum:  51