<a href="https://colab.research.google.com/github/jdspell/Art-Gallery/blob/master/charLSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

In [0]:
with open('anna.txt', 'r') as f:
  text = f.read()

In [0]:
#create two dictionaries
#1. int2char, maps integers to characters
#2. char2int, maps characters to integers

chars = tuple(set(text))

int2char = dict(enumerate(chars))

char2int = { ch: ii for ii, ch in int2char.items() }

encoded = np.array([char2int[ch] for ch in text])

In [6]:
text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [7]:
encoded[:100]

array([31, 29, 79, 53, 17,  0, 77, 30, 74, 21, 21, 21, 60, 79, 53, 53, 58,
       30, 50, 79, 42, 62, 40, 62,  0, 41, 30, 79, 77,  0, 30, 79, 40, 40,
       30, 79, 40, 62, 22,  0,  5, 30,  0, 47,  0, 77, 58, 30, 68, 23, 29,
       79, 53, 53, 58, 30, 50, 79, 42, 62, 40, 58, 30, 62, 41, 30, 68, 23,
       29, 79, 53, 53, 58, 30, 62, 23, 30, 62, 17, 41, 30, 78, 39, 23, 21,
       39, 79, 58, 49, 21, 21,  4, 47,  0, 77, 58, 17, 29, 62, 23])

In [0]:
def one_hot_encode(arr, n_labels):
  #initialize the encoded array
  one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)

  #fill the appropriate elements with ones
  one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1

  #reshape to retrieve the original array
  one_hot = one_hot.reshape((*arr.shape, n_labels))

  return one_hot

In [0]:
def get_batches(arr, n_seqs, n_steps):
  '''
  generator that returns batches of size n_seqs * n_steps from arr
  '''

  batch_size = n_seqs * n_steps
  n_batches = len(arr)//batch_size

  #keep enough char for full batches
  arr = arr[:n_batches * batch_size]

  #reshape into n_seqs rows
  arr = arr.reshape((n_seqs, -1))

  for n in range(0, arr.shape[1], n_steps):

    #the features
    x = arr[:, n:n+n_steps]

    #the targets shifted by one
    y = np.zeros_like(x)

    try:
      y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n+n_steps]
    except IndexError:
      y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]
    yield x, y

In [0]:
batches = get_batches(encoded, 10, 50)
x, y = next(batches)

In [14]:
print('x\n', x[:10, :10])
print('\ny\n', y[:10, :10])

x
 [[31 29 79 53 17  0 77 30 74 21]
 [30 79 42 30 23 78 17 30 43 78]
 [47 62 23 49 21 21  1 13  0 41]
 [23 30 76 68 77 62 23 43 30 29]
 [30 62 17 30 62 41 24 30 41 62]
 [30 73 17 30 39 79 41 21 78 23]
 [29  0 23 30 59 78 42  0 30 50]
 [ 5 30 44 68 17 30 23 78 39 30]
 [17 30 62 41 23 10 17 49 30 51]
 [30 41 79 62 76 30 17 78 30 29]]

y
 [[29 79 53 17  0 77 30 74 21 21]
 [79 42 30 23 78 17 30 43 78 62]
 [62 23 49 21 21  1 13  0 41 24]
 [30 76 68 77 62 23 43 30 29 62]
 [62 17 30 62 41 24 30 41 62 77]
 [73 17 30 39 79 41 21 78 23 40]
 [ 0 23 30 59 78 42  0 30 50 78]
 [30 44 68 17 30 23 78 39 30 41]
 [30 62 41 23 10 17 49 30 51 29]
 [41 79 62 76 30 17 78 30 29  0]]
