In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Reload imported modules automatically so edits to source files are picked up.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from htools import add_docstring

In [None]:
# Used in notebook but not needed in package.
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import spacy

from htools import assert_raises, InvalidArgumentError
import pandas_htools

In [None]:
# Location of the text file used as the corpus, given as a relative path.
FPATH = Path('../data/warbreaker.txt')

In [None]:
# Read explicitly as UTF-8: the corpus contains non-ASCII punctuation (curly
# quotes, em dashes), so relying on the platform default encoding can fail or
# silently mis-decode on systems whose default is not UTF-8.
text = FPATH.read_text(encoding='utf-8')
len(text)

18509

In [None]:
# Character count of the loaded text.
len(text)

18509

In [None]:
# Character vocabulary: every distinct lowercase character in sorted order.
# `i2c` maps index -> character, `c2i` is the inverse mapping.
i2c = sorted(set(text.lower()))
c2i = dict(zip(i2c, range(len(i2c))))
print(c2i, i2c, sep='\n')

{'\n': 0, ' ': 1, ',': 2, '-': 3, '.': 4, ':': 5, ';': 6, '?': 7, 'a': 8, 'b': 9, 'c': 10, 'd': 11, 'e': 12, 'f': 13, 'g': 14, 'h': 15, 'i': 16, 'j': 17, 'k': 18, 'l': 19, 'm': 20, 'n': 21, 'o': 22, 'p': 23, 'q': 24, 'r': 25, 's': 26, 't': 27, 'u': 28, 'v': 29, 'w': 30, 'x': 31, 'y': 32, 'z': 33, '—': 34, '’': 35, '“': 36, '”': 37}
['\n', ' ', ',', '-', '.', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '—', '’', '“', '”']


In [None]:
# Load the small English pipeline with the NER, tagger and parser components
# disabled; the notebook only reads token text from the resulting docs.
nlp = spacy.load('en_core_web_sm', disable=['ner', 'tagger', 'parser'])

In [None]:
def tokenize_one(text):
    """Split a single string into a list of token strings.

    Parameters
    ----------
    text: str
        Text to run through the module-level `nlp` pipeline.

    Returns
    -------
    list[str]: The `.text` of each token, in document order.
    """
    doc = nlp(text)
    return [token.text for token in doc]

In [None]:
def tokenize(texts):
    """Tokenize several texts in parallel using a pool of worker processes.

    Parameters
    ----------
    texts: Iterable[str]
        Texts to tokenize. Each one is passed to `tokenize_one`.

    Returns
    -------
    list[list[str]]: One list of tokens per input text, in input order.
        An empty input yields an empty list without starting any workers.
    """
    # Imported here because `multiprocessing` is not among the notebook's
    # imports; without this the function raised NameError when called.
    import multiprocessing

    texts = list(texts)
    if not texts:
        return []
    with multiprocessing.Pool() as p:
        tokens = p.map(tokenize_one, texts)
    return tokens

In [None]:
# Tokenize the whole text as one document.
tokens = tokenize_one(text)

In [None]:
# Number of tokens produced for the whole text.
len(tokens)

4103

## Issues

- Currently assumes every word has length >= 4 (shorter tokens are dropped).
- No padding yet, so the input sequences all have different lengths.
- No padding yet, so the output sequences all have different lengths.
- Character-level or word-level encoding? Still need to decide how to handle this.

In [None]:
class CharJumbleDS(Dataset):
    """Dataset of sliding token windows in which one word is scrambled.

    Each item is a pair `(x, y)` of character-index lists. `y` encodes
    `window` consecutive tokens joined by single spaces. `x` encodes the same
    text after the interior characters (all but the first and last) of the
    token at position `window // 2` have been randomly permuted.

    Parameters
    ----------
    tokens: list[str]
        Word tokens in document order. Tokens shorter than 4 characters are
        dropped.
    c2i: dict[str, int]
        Maps each lowercase character to its integer index. `decode` is only
        the inverse of `encode` when the indices are exactly 0..len(c2i)-1.
    window: int
        Number of consecutive tokens per item. Must be at least 1.
    """

    def __init__(self, tokens, c2i, window=3):
        if window < 1:
            raise ValueError(f'window must be >= 1, got {window}.')
        # TO DO: For now, start by assuming all words have len >= 4. Fix later.
        self.tokens = [t for t in tokens if len(t) >= 4]
        self.c2i = c2i
        # Order characters by their index rather than by dict insertion order,
        # so decoding does not depend on how `c2i` happened to be built.
        self.i2c = sorted(c2i, key=c2i.get)
        self.window = window
        self.mid_i = window // 2

    def __getitem__(self, idx):
        """Return `(jumbled, original)` character indices for window `idx`.

        Raises
        ------
        IndexError: If `idx` does not refer to a full window. Negative
            indices count back from the last full window.
        """
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(f'Index out of range for {self!r}.')

        # Slicing copies, so scrambling below never mutates `self.tokens`.
        chunk = self.tokens[idx:idx+self.window]
        # Encode the untouched window first; it is the seq2seq target.
        label = self.encode(' '.join(chunk))
        mid = chunk[self.mid_i]
        # Permute only interior positions 1..len-2; first/last chars stay put.
        order = np.random.permutation(len(mid) - 2) + 1
        chunk[self.mid_i] = mid[0] + ''.join(mid[i] for i in order) + mid[-1]

        # Design notes on alternative targets that were considered:
        # - Return the permutation that was applied, i.e.
        #   [0] + list(order) + [len(mid) - 1]. Less intuitive but simple; the
        #   inverse can be computed in a prediction wrapper outside training.
        # - Return the inverse mapping from permuted to original positions.
        #   Intuitive, but adds computation and hard-to-read logic.
        # - (Current, "V3") Return whole sequences of char indices as input
        #   and output. Probably more computationally expensive (seq2seq vs
        #   multiclass classification).
        return self.encode(' '.join(chunk)), label

    def encode(self, word_str):
        """Lowercase `word_str` and map each character to its index.

        Raises KeyError for a character that is missing from `c2i`.
        """
        return [self.c2i[char] for char in word_str.lower()]

    def decode(self, idx):
        """Map an iterable of character indices back to a string."""
        return ''.join(self.i2c[i] for i in idx)

    def __len__(self):
        # Count full windows only. Counting tokens advertised trailing indices
        # whose windows were too short or raised IndexError.
        return max(len(self.tokens) - self.window + 1, 0)

    def __repr__(self):
        return f'CharJumbleDS(len={len(self)})'

In [None]:
# Build the dataset with 4 tokens per sample, then display its repr.
ds = CharJumbleDS(tokens, c2i, window=4)
ds

CharJumbleDS(len=1953)

In [None]:
# Preview the first 50 samples: encoded input, encoded label, then both
# decoded back to text, with a blank line between samples.
for i in range(50):
    x, y = ds[i]
    print(x, y, ds.decode(x), ds.decode(y), sep='\n', end='\n\n')

[13, 28, 21, 21, 32, 1, 29, 8, 26, 15, 12, 25, 1, 27, 22, 14, 28, 15, 15, 27, 1, 20, 8, 21, 32]
[13, 28, 21, 21, 32, 1, 29, 8, 26, 15, 12, 25, 1, 27, 15, 22, 28, 14, 15, 27, 1, 20, 8, 21, 32]
funny vasher toguhht many
funny vasher thought many

[29, 8, 26, 15, 12, 25, 1, 27, 15, 22, 28, 14, 15, 27, 1, 20, 8, 21, 32, 1, 27, 15, 16, 21, 14, 26]
[29, 8, 26, 15, 12, 25, 1, 27, 15, 22, 28, 14, 15, 27, 1, 20, 8, 21, 32, 1, 27, 15, 16, 21, 14, 26]
vasher thought many things
vasher thought many things

[27, 15, 22, 28, 14, 15, 27, 1, 20, 8, 21, 32, 1, 27, 16, 14, 15, 21, 26, 1, 9, 12, 14, 16, 21]
[27, 15, 22, 28, 14, 15, 27, 1, 20, 8, 21, 32, 1, 27, 15, 16, 21, 14, 26, 1, 9, 12, 14, 16, 21]
thought many tighns begin
thought many things begin

[20, 8, 21, 32, 1, 27, 15, 16, 21, 14, 26, 1, 9, 16, 14, 12, 21, 1, 30, 16, 27, 15]
[20, 8, 21, 32, 1, 27, 15, 16, 21, 14, 26, 1, 9, 12, 14, 16, 21, 1, 30, 16, 27, 15]
many things bigen with
many things begin with

[27, 15, 16, 21, 14, 26, 1, 9, 12, 14, 1