In [None]:
import pickle
import numpy as np
from os import makedirs
from os.path import exists
from tqdm.notebook import tqdm
from os.path import dirname
from random import shuffle
from torch.utils.data import Dataset

class SyntheticSequenceDataset(Dataset):
  """All binary sequences of length ``N``, labelled with their number of ones.

  The ``2 ** N`` sequences are shuffled and split 80% / 20% into a training
  and a test partition. Both partitions live in the same object; ``train()``
  and ``eval()`` select which one ``__len__`` and ``__getitem__`` expose.

  The computed arrays are pickled to ``dataset_cache`` and reloaded from
  there on later instantiations unless ``force_recompute`` is set.
  """

  # Length of every binary sequence (so there are 2**N examples and N + 1
  # possible target values 0, 1, ..., N).
  N = 12

  def __init__(self, dataset_cache: str = 'data/synthetic_dataset.pickle', 
                force_recompute: bool = False):
    """Load the dataset from ``dataset_cache`` or compute and cache it.

    Args:
      dataset_cache: path of the pickle file used as cache.
      force_recompute: when True, ignore an existing cache file and
        overwrite it with a freshly computed (re-shuffled) dataset.
    """
    self._data = None
    self._train = True
    self.force_recompute = force_recompute
    self.dataset_cache   = dataset_cache

    self.create_data()

  def eval(self):
    """Expose the test partition through ``__len__`` / ``__getitem__``."""
    self._train = False
  
  def train(self):
    """Expose the training partition through ``__len__`` / ``__getitem__``."""
    self._train = True

  def _active_split(self):
    """Return ``(inputs, targets)`` of the currently selected partition."""
    train_x, train_y, test_x, test_y = self._data
    return (train_x, train_y) if self._train else (test_x, test_y)

  def __len__(self):
    return self._active_split()[0].shape[0]

  def __getitem__(self, item: int):
    inputs, targets = self._active_split()
    return inputs[item], targets[item]

  def create_data(self):
    """Populate ``self._data`` (once) and return it.

    Returns:
      Tuple ``(train_data, train_targets, test_data, test_targets)``.
    """
    if self._data is None:
      if not self.force_recompute and exists(self.dataset_cache):
        print('Loading dataset from cache...')
        # NOTE: pickle.load can execute arbitrary code. Only point
        # `dataset_cache` at files produced by this class / a trusted source.
        with open(self.dataset_cache, 'rb') as dump_file:
          dataset = pickle.load(dump_file)
      else:
        print('Recomputing dataset...')
        dataset = self._compute_dataset()
        # dirname() is '' for a bare file name; makedirs('') would raise, so
        # only create the directory when the path actually has one.
        cache_dir = dirname(self.dataset_cache)
        if cache_dir:
          makedirs(cache_dir, exist_ok=True)
        with open(self.dataset_cache, 'wb') as dump_file:
          pickle.dump(dataset, dump_file)

      # Store data
      self._data = dataset

    return self._data

  def _compute_dataset(self):
    """Enumerate, shuffle and split every binary sequence of length ``N``.

    Returns:
      ``(train_data, train_targets, test_data, test_targets)`` where the data
      arrays are float32 with shape ``(num_examples, N, 1)`` and the target
      arrays are int64 with shape ``(num_examples,)``.
    """
    seq_len      = self.N
    num_examples = 2 ** seq_len

    # How many examples to use for training (others are for test)
    num_train_examples = int(0.8 * num_examples)

    # Generate 2**N binary strings, zero-padded to N characters
    data_strings = [format(i, f'0{seq_len}b') for i in range(num_examples)]

    # Shuffle sequences (uses the global `random` state)
    shuffle(data_strings)

    # Cast to numeric each generated binary string
    data_x, data_y = [], []
    for bits in tqdm(data_strings):
      # examples are binary sequences of int {0, 1}, one feature per step
      data_x.append([[int(binary_char)] for binary_char in bits])
      # targets are the number of ones in the sequence
      data_y.append(bits.count('1'))

    # Separate suggested training and test data. Targets are forced to int64
    # because the platform default integer can be 32-bit, while class-index
    # targets for torch's CrossEntropyLoss must be 64-bit.
    train_data      = np.array(data_x[:num_train_examples], dtype=np.float32)
    train_targets   = np.array(data_y[:num_train_examples], dtype=np.int64)
    test_data       = np.array(data_x[num_train_examples:], dtype=np.float32)
    test_targets    = np.array(data_y[num_train_examples:], dtype=np.int64)

    return train_data, train_targets, test_data, test_targets

In [None]:
from torch.utils.data import DataLoader

# Mini-batch size shared by every pass over the data.
BATCH_SIZE = 32

# Always rebuild (and re-shuffle) the dataset instead of reusing the cache.
synthetic_dataset = SyntheticSequenceDataset(force_recompute=True)

# Single loader over the dataset: batches are drawn in random order, loaded
# in the main process, and a trailing incomplete batch is discarded.
dl = DataLoader(
  synthetic_dataset,
  shuffle=True,
  drop_last=True,
  num_workers=0,
  batch_size=BATCH_SIZE,
)

In [None]:
# Print ten training examples as "bit string -> number of ones".
for _ in range(10):
  # Python 3 iterators are advanced with the built-in next(); DataLoader
  # iterators do not reliably expose a `.next()` method. A fresh iterator is
  # built each time, so every draw comes from a newly shuffled pass.
  x, y = next(iter(dl))
  # Keep only the first example of the batch
  x, y = x[0], y[0]
  print('Input: ', ''.join(str(int(xcur[0])) for xcur in x.tolist()), f'-> y: {y.item()}')

In [None]:
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

class LSTM4Counting(nn.Module):
  """LSTM encoder followed by a two-layer MLP that outputs class scores.

  The sequence is summarised by the LSTM's final hidden state, which the MLP
  maps to ``num_classes`` unnormalised scores (logits).
  """

  def __init__(self, num_features_in: int, hidden_dim: int, 
                num_classes: int):
    """
    Args:
      num_features_in: number of features per time step.
      hidden_dim: size of the LSTM hidden state and of the MLP hidden layer.
      num_classes: number of output scores.
    """
    super().__init__()
    self.hidden_dim = hidden_dim

    # batch_first=True: inputs are laid out as (batch, time, features)
    self.lstm = nn.LSTM(input_size=num_features_in, hidden_size=hidden_dim,
                        batch_first=True)

    # Classification head applied to the final hidden state
    head_layers = [
      nn.Linear(hidden_dim, hidden_dim),
      nn.ReLU(),
      nn.Linear(hidden_dim, num_classes),
    ]
    self.net = nn.Sequential(*head_layers)

  def forward(self, X: torch.Tensor):
    """Return the class scores for a batch of sequences ``X``."""
    # The LSTM returns (per-step outputs, (h_n, c_n)); only h_n is needed.
    final_states = self.lstm(X)[1]
    hidden_per_layer = final_states[0]
    # Index 0 selects the hidden state of the first (and here only) layer
    last_hidden = hidden_per_layer[0]
    return self.net(last_hidden)

In [None]:
def eval_acc(net: nn.Module, data_loader: torch.utils.data.DataLoader, 
             device: torch.device):
  """Compute the classification accuracy of ``net`` over ``data_loader``.

  Args:
    net: model mapping a batch of inputs to per-class scores of shape
      ``(batch, num_classes)``. The caller is responsible for putting it in
      the desired mode (e.g. ``net.eval()``).
    data_loader: iterable yielding ``(inputs, targets)`` tensor pairs, with
      targets given as class indices.
    device: device the batches are moved to before the forward pass.

  Returns:
    Fraction of correctly classified examples as a float in ``[0, 1]``;
    ``0.0`` when the loader yields no examples.
  """
  correct = 0
  total = 0
  
  with torch.no_grad():
    for x, y in data_loader:
      x, y = x.to(device), y.to(device)
      # Use the model passed as argument, not whatever global `model` exists
      y_pred = net(x)
      correct += (y_pred.argmax(dim=1) == y).sum().item()
      total += y_pred.size(0)
  
  # An empty loader (e.g. drop_last=True with fewer examples than one batch)
  # would otherwise raise ZeroDivisionError.
  if total == 0:
    return 0.0

  return correct / total

In [None]:
# Load the TensorBoard notebook extension and open TensorBoard on the
# `logs` directory.
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
from torch.optim import RMSprop
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Hyper-parameters
num_hidden      = 50
num_epochs      = 100
learning_rate   = 0.001
num_features_in = 1
num_classes     = SyntheticSequenceDataset.N + 1   # targets range over 0..N

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = LSTM4Counting(num_features_in=num_features_in, hidden_dim=num_hidden, 
                       num_classes=num_classes).to(device)

loss_fun = nn.CrossEntropyLoss().to(device)
opt = RMSprop(model.parameters(), learning_rate)

# One TensorBoard run directory per launch. The name is a zero-padded
# date-time without ':' so it is a valid directory name on every OS, sorts
# chronologically and does not collide with a run from another day.
now = datetime.now()
train_name = now.strftime('%Y%m%d-%H%M%S') + '/'
writer = SummaryWriter('./logs/' + train_name)

try:
  for e in tqdm(range(num_epochs)):

    # --- Evaluation: accuracy on both splits before this epoch's updates ---
    model.eval()

    # The same loader serves whichever split the dataset currently exposes
    synthetic_dataset.train()
    train_acc = eval_acc(model, dl, device)

    synthetic_dataset.eval()
    test_acc = eval_acc(model, dl, device)

    writer.add_scalar('Acc/train', train_acc, e)
    writer.add_scalar('Acc/test',  test_acc, e)

    # --- Training: one pass over the training split ---
    model.train()
    synthetic_dataset.train()

    for i, (x, y) in enumerate(dl):
      x, y = x.to(device), y.to(device)

      opt.zero_grad()
      y_pred = model(x)
      loss = loss_fun(y_pred, y)
      # Global step = batches seen so far across all epochs
      writer.add_scalar('Loss/train', loss.item(), i + e * len(dl))
      loss.backward()
      opt.step()
finally:
  # Even on interruption: restore the training split and flush/close the
  # event file so the last logged scalars are not lost.
  synthetic_dataset.train()
  writer.close()


In [None]:
print('\n' + 50 * '*' + '\nInteractive Session\n' + 50 * '*')

# Inference only: switch the model out of training mode
model.eval()

while True:
  my_sequence = input('Write your own binary sequence of up to %d digits in {0, 1} (write anything else to exit):\n' % synthetic_dataset.N).strip()

  # Accept only non-empty strings made exclusively of '0' and '1'. Checking
  # just the number of distinct characters would let inputs such as "1a"
  # through and crash on the float conversion below.
  if my_sequence and set(my_sequence) <= {'0', '1'}:

    if len(my_sequence) < synthetic_dataset.N:
      # Pad shorter sequences (left-pad with zeros)
      my_sequence = (synthetic_dataset.N - len(my_sequence))*'0' + my_sequence
      print("Sequence will be padded to", my_sequence)
    elif len(my_sequence) > synthetic_dataset.N:
      # Crop longer sequences (keep the first N digits)
      my_sequence = my_sequence[:synthetic_dataset.N]
      print("Sequence will be cropped to", my_sequence)

    # Prepare example: shape (1, N, 1) = (batch, time, features)
    test_example = torch.tensor([[float(binary_char)] for binary_char in my_sequence],
                                dtype=torch.float32).unsqueeze(0).to(device)

    # No gradients are needed to make a prediction
    with torch.no_grad():
      y_pred = torch.argmax(model(test_example)).item()
    y_real = my_sequence.count('1')
    print(f'Predicted number of ones: {y_pred} - Real: {y_real}\n')
  else:
    print("Stopping")
    break