In [20]:
# Reload edited modules (e.g. the local `spooky` module imported below) before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures (e.g. the lr_find plot) inline in the notebook.
%matplotlib inline

In [21]:
from fastai.learner import *

import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling

from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *

import dill as pickle

from spooky import *

In [22]:
PATH = 'data/spooky'

# Model checkpoints and scratch files live under PATH.
for _subdir in ('models', 'tmp'):
    os.makedirs(f'{PATH}/{_subdir}', exist_ok=True)

bs, bptt = 64, 70   # batch size; back-prop-through-time sequence length

emb_sz = 400        # size of each embedding vector
nh = 1024           # number of hidden activations per layer
nl = 3              # number of layers

# For NLP, configure Adam with less momentum than the default beta1 of 0.9.
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

In [23]:
# Load the raw training sentences and the test sentences to be scored.
train_raw_df, test_df = (pd.read_csv(f'{PATH}/{split}.csv') for split in ('train', 'test'))

len(train_raw_df), len(test_df)

(19579, 8392)

In [24]:
# Build train and validation datasets by holding out 5% of the labelled rows.
# get_cv_idxs is given only len(train_raw_df), so its result indexes into range(len(...)),
# i.e. row *positions*. iloc is positional but drop() is label-based, so translate the
# positions to index labels before dropping. The two halves then always partition the frame,
# even if train_raw_df stops having a default RangeIndex.
val_idxs = get_cv_idxs(len(train_raw_df), val_pct=0.05)

train_df = train_raw_df.drop(train_raw_df.index[val_idxs])
val_df = train_raw_df.iloc[val_idxs]

len(train_df), len(val_df), len(test_df)

(18601, 978, 8392)

In [25]:
# Tokenizing splits a sentence into word and punctuation tokens; show the first training sentence.
' '.join(spacy_tok(train_df['text'].iloc[0]))

'This process , however , afforded me no means of ascertaining the dimensions of my dungeon ; as I might make its circuit , and return to the point whence I set out , without being aware of the fact ; so perfectly uniform seemed the wall .'

In [26]:
# Create a torchtext Field, which describes how to preprocess a piece of text:
# split it into tokens with spacy_tok and lower-case them.
TEXT = data.Field(tokenize=spacy_tok, lower=True)

In [27]:
FILES = {'train_df': train_df, 'val_df': val_df, 'test_df': test_df}

# min_freq=10: any token seen fewer than 10 times is treated as <unk>.
md = LanguageModelData.from_dataframes(
    PATH, TEXT, 'text', bs=bs, bptt=bptt, min_freq=10, **FILES)

In [28]:
# After building the ModelData object, TEXT.vocab is set. Persist the whole Field so the
# classifier section can reload exactly the same token -> id mapping.
# The `with` block guarantees the file is flushed and closed (the bare open() leaked it).
with open(f'{PATH}/models/TEXT.pkl', 'wb') as fh:
    pickle.dump(TEXT, fh)

In [29]:
# Number of batches, unique tokens in the vocab, items in the training dataset,
# and tokens in that (single) item.
(len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text))

(125, 4755, 1, 565792)

In [30]:
# id -> token lookup: the first dozen entries of the vocab.
TEXT.vocab.itos[0:12]

['<unk>', '<pad>', ',', 'the', 'of', '.', 'and', 'to', 'i', 'a', 'in', 'was']

In [31]:
# token -> id lookup for a common word.
TEXT.vocab.stoi["the"]

3

In [32]:
# A LanguageModelData dataset holds a single item: every training token joined into one stream.
md.trn_ds[0].text[0:12]

['this',
 'process',
 ',',
 'however',
 ',',
 'afforded',
 'me',
 'no',
 'means',
 'of',
 'ascertaining',
 'the']

In [33]:
# torchtext turns these tokens into integer ids (unknown words become 0 = '<unk>').
TEXT.numericalize([md.trn_ds[0].text[0:12]])

Variable containing:
   31
 2949
    2
  151
    2
 1431
   27
   42
  301
    4
    0
    3
[torch.cuda.LongTensor of size 12x1 (GPU 0)]

In [34]:
# Grab one training mini-batch. The inputs are (sequence length x bs) token ids. The targets
# are the same stream shifted by one token, flattened into one vector (the first targets
# equal the second input row).
batch = next(iter(md.trn_dl))
print(batch[0].size())
print(batch[1].size())

batch

torch.Size([76, 64])
torch.Size([4864])


(Variable containing:
    31     6    52  ...      2     3     2
  2949     0     0  ...      3     0     0
     2     7     2  ...      0     7     2
        ...          ⋱          ...       
    28     2     4  ...      7    20   320
    24  3289     3  ...   1074     0     2
     2   216     0  ...     26     5    10
 [torch.cuda.LongTensor of size 76x64 (GPU 0)], Variable containing:
  2949
     0
     0
   ⋮  
   338
   325
     3
 [torch.cuda.LongTensor of size 4864 (GPU 0)])

In [35]:
# Language-model learner with light dropout everywhere.
learner = md.get_model(
    opt_fn, emb_sz, nh, nl,
    dropouti=0.1, dropout=0.1, wdrop=0.2, dropoute=0.04, dropouth=0.1,
)

learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)  # extra RNN regularisation term (alpha/beta weights)
learner.clip = 0.3  # presumably the gradient-clipping threshold — confirm against fastai

In [None]:
lrf = learner.lr_find()  # learning-rate sweep; inspect with learner.sched.plot()

In [None]:
learner.sched.plot()  # loss vs. learning rate from the lr_find sweep

In [36]:
lr = 0.001  # base learning rate (presumably chosen from the lr_find plot)

In [37]:
# 4 cycles, each twice as long as the previous one: 1 + 2 + 4 + 8 = 15 epochs.
learner.fit(lr, 4,
            wds=1e-6, cycle_len=1, cycle_mult=2)

  6%|▌         | 7/125 [00:02<00:42,  2.80it/s, loss=8.46]
  6%|▋         | 8/125 [00:02<00:39,  2.97it/s, loss=8.46]

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py", line 144, in run
    for instance in self.tqdm_cls._instances:
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration



[ 0.       5.76916  5.64622]                                
[ 1.       5.41764  5.26515]                                
[ 2.       5.23646  5.1827 ]                                
[ 3.       5.11374  4.99851]                                
[ 4.       4.98249  4.88693]                                
[ 5.       4.89676  4.84356]                                
[ 6.       4.85656  4.83048]                                
[ 7.       4.84061  4.74867]                                
[ 8.       4.74853  4.70073]                                
[ 9.       4.66462  4.63213]                                
[ 10.        4.58969   4.57953]                             
[ 11.        4.52939   4.55109]                             
[ 12.        4.49072   4.5317 ]                             
[ 13.        4.45834   4.52138]                             
[ 14.        4.45779   4.51699]                             



In [38]:
learner.save_encoder('spooky_adam_enc1')
# To restore later (the name must match the one saved above):
# learner.load_encoder('spooky_adam_enc1')

In [None]:
# One 10-epoch cycle. Its weights are saved under cycle_save_name
# (reloadable with learner.load_cycle(name, 0)).
learner.fit(
    lr, 1, wds=1e-6, cycle_len=10,
    cycle_save_name='spooky_adam_enc2_c1_cl10',
)

In [41]:
# learner.load_cycle('spooky_adam_enc2_c1_cl10', 0)

In [42]:
learner.save_encoder('spooky_adam_enc2')
# To restore later (the name must match the one saved above):
# learner.load_encoder('spooky_adam_enc2')

In [43]:
# Same 10-epoch cycle schedule, at half the learning rate.
learner.fit(
    lr/2, 1, wds=1e-6, cycle_len=10,
    cycle_save_name='spooky_adam_enc3_c1_cl10',
)


  0%|          | 0/125 [00:00<?, ?it/s][A

  0%|          | 0/125 [00:00<?, ?it/s, loss=3.9][A
  1%|          | 1/125 [00:00<00:53,  2.33it/s, loss=3.9][A

Exception in thread Thread-21:
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py", line 144, in run
    for instance in self.tqdm_cls._instances:
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration




  1%|          | 1/125 [00:00<01:35,  1.30it/s, loss=3.98][A
  2%|▏         | 2/125 [00:00<00:47,  2.60it/s, loss=3.98][A
  2%|▏         | 2/125 [00:01<01:09,  1.76it/s, loss=4.06][A
  2%|▏         | 3/125 [00:01<00:46,  2.64it/s, loss=4.06][A
  2%|▏         | 3/125 [00:01<01:01,  1.98it/s, loss=4.04][A
  3%|▎         | 4/125 [00:01<00:45,  2.63it/s, loss=4.04][A
  3%|▎         | 4/125 [00:01<00:55,  2.18it/s, loss=4.06][A
  4%|▍         | 5/125 [00:01<00:44,  2.72it/s, loss=4.06][A
  4%|▍         | 5/125 [00:02<00:52,  2.27it/s, loss=4.04][A
  5%|▍         | 6/125 [00:02<00:43,  2.72it/s, loss=4.04][A
  5%|▍         | 6/125 [00:02<00:51,  2.32it/s, loss=4.03][A
  6%|▌         | 7/125 [00:02<00:43,  2.71it/s, loss=4.03][A
  6%|▌         | 7/125 [00:02<00:50,  2.34it/s, loss=4.03][A
  6%|▋         | 8/125 [00:02<00:43,  2.67it/s, loss=4.03][A
  6%|▋         | 8/125 [00:03<00:49,  2.37it/s, loss=4.03][A
  7%|▋         | 9/125 [00:03<00:43,  2.67it/s, loss=4.03][A
  7%|▋ 

 53%|█████▎    | 66/125 [00:24<00:21,  2.74it/s, loss=4.05][A
 54%|█████▎    | 67/125 [00:24<00:20,  2.78it/s, loss=4.05][A
 54%|█████▎    | 67/125 [00:24<00:21,  2.74it/s, loss=4.05][A
 54%|█████▍    | 68/125 [00:24<00:20,  2.78it/s, loss=4.05][A
 54%|█████▍    | 68/125 [00:24<00:20,  2.73it/s, loss=4.05][A
 55%|█████▌    | 69/125 [00:24<00:20,  2.77it/s, loss=4.05][A
 55%|█████▌    | 69/125 [00:25<00:20,  2.73it/s, loss=4.05][A
 56%|█████▌    | 70/125 [00:25<00:19,  2.77it/s, loss=4.05][A
 56%|█████▌    | 70/125 [00:25<00:20,  2.73it/s, loss=4.04][A
 57%|█████▋    | 71/125 [00:25<00:19,  2.77it/s, loss=4.04][A
 57%|█████▋    | 71/125 [00:26<00:19,  2.73it/s, loss=4.04][A
 58%|█████▊    | 72/125 [00:26<00:19,  2.77it/s, loss=4.04][A
 58%|█████▊    | 72/125 [00:26<00:19,  2.73it/s, loss=4.04][A
 58%|█████▊    | 73/125 [00:26<00:18,  2.77it/s, loss=4.04][A
 58%|█████▊    | 73/125 [00:26<00:19,  2.73it/s, loss=4.04][A
 59%|█████▉    | 74/125 [00:26<00:18,  2.76it/s, loss=4

  4%|▍         | 5/125 [00:01<00:41,  2.89it/s, loss=4.07][A
  4%|▍         | 5/125 [00:02<00:49,  2.40it/s, loss=4.07][A
  5%|▍         | 6/125 [00:02<00:41,  2.88it/s, loss=4.07][A
  5%|▍         | 6/125 [00:02<00:48,  2.44it/s, loss=4.07][A
  6%|▌         | 7/125 [00:02<00:41,  2.84it/s, loss=4.07][A
  6%|▌         | 7/125 [00:02<00:47,  2.46it/s, loss=4.07][A
  6%|▋         | 8/125 [00:02<00:41,  2.81it/s, loss=4.07][A
  6%|▋         | 8/125 [00:03<00:47,  2.46it/s, loss=4.07][A
  7%|▋         | 9/125 [00:03<00:42,  2.76it/s, loss=4.07][A
  7%|▋         | 9/125 [00:03<00:46,  2.50it/s, loss=4.07][A
  8%|▊         | 10/125 [00:03<00:41,  2.77it/s, loss=4.07][A
  8%|▊         | 10/125 [00:04<00:46,  2.50it/s, loss=4.07][A
  9%|▉         | 11/125 [00:04<00:41,  2.75it/s, loss=4.07][A
  9%|▉         | 11/125 [00:04<00:45,  2.52it/s, loss=4.07][A
 10%|▉         | 12/125 [00:04<00:41,  2.74it/s, loss=4.07][A
 10%|▉         | 12/125 [00:04<00:44,  2.52it/s, loss=4.06][A
 1

 56%|█████▌    | 70/125 [00:25<00:20,  2.70it/s, loss=4.08][A
 56%|█████▌    | 70/125 [00:26<00:20,  2.66it/s, loss=4.08][A
 57%|█████▋    | 71/125 [00:26<00:20,  2.70it/s, loss=4.08][A
 57%|█████▋    | 71/125 [00:26<00:20,  2.68it/s, loss=4.08][A
 58%|█████▊    | 72/125 [00:26<00:19,  2.72it/s, loss=4.08][A
 58%|█████▊    | 72/125 [00:26<00:19,  2.70it/s, loss=4.08][A
 58%|█████▊    | 73/125 [00:26<00:18,  2.74it/s, loss=4.08][A
 58%|█████▊    | 73/125 [00:27<00:19,  2.70it/s, loss=4.08][A
 59%|█████▉    | 74/125 [00:27<00:18,  2.74it/s, loss=4.08][A
 59%|█████▉    | 74/125 [00:27<00:18,  2.70it/s, loss=4.08][A
 60%|██████    | 75/125 [00:27<00:18,  2.74it/s, loss=4.08][A
 60%|██████    | 75/125 [00:27<00:18,  2.70it/s, loss=4.08][A
 61%|██████    | 76/125 [00:27<00:17,  2.74it/s, loss=4.08][A
 61%|██████    | 76/125 [00:28<00:18,  2.70it/s, loss=4.08][A
 62%|██████▏   | 77/125 [00:28<00:17,  2.73it/s, loss=4.08][A
 62%|██████▏   | 77/125 [00:28<00:17,  2.69it/s, loss=4

  6%|▋         | 8/125 [00:03<00:50,  2.30it/s, loss=4.04][A
  7%|▋         | 9/125 [00:03<00:44,  2.58it/s, loss=4.04][A
  7%|▋         | 9/125 [00:03<00:49,  2.33it/s, loss=4.04][A
  8%|▊         | 10/125 [00:03<00:44,  2.59it/s, loss=4.04][A
  8%|▊         | 10/125 [00:04<00:49,  2.34it/s, loss=4.03][A
  9%|▉         | 11/125 [00:04<00:44,  2.57it/s, loss=4.03][A
  9%|▉         | 11/125 [00:04<00:46,  2.47it/s, loss=4.04][A
 10%|▉         | 12/125 [00:04<00:42,  2.69it/s, loss=4.04][A
 10%|▉         | 12/125 [00:04<00:45,  2.50it/s, loss=4.03][A
 10%|█         | 13/125 [00:04<00:41,  2.71it/s, loss=4.03][A
 10%|█         | 13/125 [00:05<00:44,  2.54it/s, loss=4.04][A
 11%|█         | 14/125 [00:05<00:40,  2.73it/s, loss=4.04][A
 11%|█         | 14/125 [00:05<00:43,  2.53it/s, loss=4.04][A
 12%|█▏        | 15/125 [00:05<00:40,  2.71it/s, loss=4.04][A
 12%|█▏        | 15/125 [00:05<00:43,  2.54it/s, loss=4.04][A
 13%|█▎        | 16/125 [00:05<00:40,  2.70it/s, loss=4.04

 58%|█████▊    | 73/125 [00:27<00:19,  2.66it/s, loss=4.03][A
 59%|█████▉    | 74/125 [00:27<00:18,  2.70it/s, loss=4.03][A
 59%|█████▉    | 74/125 [00:27<00:19,  2.66it/s, loss=4.03][A
 60%|██████    | 75/125 [00:27<00:18,  2.70it/s, loss=4.03][A
 60%|██████    | 75/125 [00:28<00:18,  2.66it/s, loss=4.03][A
 61%|██████    | 76/125 [00:28<00:18,  2.69it/s, loss=4.03][A
 61%|██████    | 76/125 [00:28<00:18,  2.65it/s, loss=4.02][A
 62%|██████▏   | 77/125 [00:28<00:17,  2.69it/s, loss=4.02][A
 62%|██████▏   | 77/125 [00:29<00:18,  2.65it/s, loss=4.02][A
 62%|██████▏   | 78/125 [00:29<00:17,  2.69it/s, loss=4.02][A
 62%|██████▏   | 78/125 [00:29<00:17,  2.66it/s, loss=4.02][A
 63%|██████▎   | 79/125 [00:29<00:17,  2.69it/s, loss=4.02][A
 63%|██████▎   | 79/125 [00:29<00:17,  2.65it/s, loss=4.02][A
 64%|██████▍   | 80/125 [00:29<00:16,  2.69it/s, loss=4.02][A
 64%|██████▍   | 80/125 [00:30<00:16,  2.65it/s, loss=4.02][A
 65%|██████▍   | 81/125 [00:30<00:16,  2.69it/s, loss=4

 10%|█         | 13/125 [00:05<00:43,  2.58it/s, loss=3.99][A
 10%|█         | 13/125 [00:05<00:46,  2.40it/s, loss=3.99][A
 11%|█         | 14/125 [00:05<00:42,  2.58it/s, loss=3.99][A
 11%|█         | 14/125 [00:05<00:45,  2.42it/s, loss=3.99][A
 12%|█▏        | 15/125 [00:05<00:42,  2.59it/s, loss=3.99][A
 12%|█▏        | 15/125 [00:06<00:45,  2.44it/s, loss=3.98][A
 13%|█▎        | 16/125 [00:06<00:41,  2.60it/s, loss=3.98][A
 13%|█▎        | 16/125 [00:06<00:44,  2.45it/s, loss=3.99][A
 14%|█▎        | 17/125 [00:06<00:41,  2.60it/s, loss=3.99][A
 14%|█▎        | 17/125 [00:06<00:43,  2.46it/s, loss=3.99][A
 14%|█▍        | 18/125 [00:06<00:41,  2.61it/s, loss=3.99][A
 14%|█▍        | 18/125 [00:07<00:42,  2.49it/s, loss=3.99][A
 15%|█▌        | 19/125 [00:07<00:40,  2.63it/s, loss=3.99][A
 15%|█▌        | 19/125 [00:07<00:42,  2.49it/s, loss=3.98][A
 16%|█▌        | 20/125 [00:07<00:40,  2.62it/s, loss=3.98][A
 16%|█▌        | 20/125 [00:07<00:41,  2.51it/s, loss=3

 63%|██████▎   | 79/125 [00:29<00:17,  2.68it/s, loss=3.98][A
 63%|██████▎   | 79/125 [00:29<00:17,  2.65it/s, loss=3.99][A
 64%|██████▍   | 80/125 [00:29<00:16,  2.68it/s, loss=3.99][A
 64%|██████▍   | 80/125 [00:30<00:16,  2.65it/s, loss=3.99][A
 65%|██████▍   | 81/125 [00:30<00:16,  2.68it/s, loss=3.99][A
 65%|██████▍   | 81/125 [00:30<00:16,  2.67it/s, loss=3.99][A
 66%|██████▌   | 82/125 [00:30<00:15,  2.70it/s, loss=3.99][A
 66%|██████▌   | 82/125 [00:30<00:16,  2.67it/s, loss=3.99][A
 66%|██████▋   | 83/125 [00:30<00:15,  2.70it/s, loss=3.99][A
 66%|██████▋   | 83/125 [00:31<00:15,  2.67it/s, loss=3.98][A
 67%|██████▋   | 84/125 [00:31<00:15,  2.70it/s, loss=3.98][A
 67%|██████▋   | 84/125 [00:31<00:15,  2.67it/s, loss=3.98][A
 68%|██████▊   | 85/125 [00:31<00:14,  2.70it/s, loss=3.98][A
 68%|██████▊   | 85/125 [00:31<00:14,  2.67it/s, loss=3.98][A
 69%|██████▉   | 86/125 [00:31<00:14,  2.70it/s, loss=3.98][A
 69%|██████▉   | 86/125 [00:32<00:14,  2.67it/s, loss=3

 14%|█▎        | 17/125 [00:06<00:42,  2.53it/s, loss=3.96][A
 14%|█▍        | 18/125 [00:06<00:39,  2.68it/s, loss=3.96][A
 14%|█▍        | 18/125 [00:07<00:42,  2.54it/s, loss=3.96][A
 15%|█▌        | 19/125 [00:07<00:39,  2.68it/s, loss=3.96][A
 15%|█▌        | 19/125 [00:07<00:41,  2.55it/s, loss=3.96][A
 16%|█▌        | 20/125 [00:07<00:39,  2.69it/s, loss=3.96][A
 16%|█▌        | 20/125 [00:07<00:41,  2.55it/s, loss=3.96][A
 17%|█▋        | 21/125 [00:07<00:38,  2.68it/s, loss=3.96][A
 17%|█▋        | 21/125 [00:08<00:40,  2.55it/s, loss=3.96][A
 18%|█▊        | 22/125 [00:08<00:38,  2.67it/s, loss=3.96][A
 18%|█▊        | 22/125 [00:08<00:40,  2.55it/s, loss=3.96][A
 18%|█▊        | 23/125 [00:08<00:38,  2.67it/s, loss=3.96][A
 18%|█▊        | 23/125 [00:08<00:39,  2.56it/s, loss=3.96][A
 19%|█▉        | 24/125 [00:08<00:37,  2.67it/s, loss=3.96][A
 19%|█▉        | 24/125 [00:09<00:39,  2.57it/s, loss=3.96][A
 20%|██        | 25/125 [00:09<00:37,  2.67it/s, loss=3

 66%|██████▌   | 82/125 [00:30<00:16,  2.65it/s, loss=3.93][A
 66%|██████▋   | 83/125 [00:30<00:15,  2.69it/s, loss=3.93][A
 66%|██████▋   | 83/125 [00:31<00:15,  2.65it/s, loss=3.93][A
 67%|██████▋   | 84/125 [00:31<00:15,  2.69it/s, loss=3.93][A
 67%|██████▋   | 84/125 [00:31<00:15,  2.66it/s, loss=3.93][A
 68%|██████▊   | 85/125 [00:31<00:14,  2.69it/s, loss=3.93][A
 68%|██████▊   | 85/125 [00:31<00:14,  2.67it/s, loss=3.93][A
 69%|██████▉   | 86/125 [00:31<00:14,  2.70it/s, loss=3.93][A
 69%|██████▉   | 86/125 [00:31<00:14,  2.69it/s, loss=3.93][A
 70%|██████▉   | 87/125 [00:31<00:13,  2.72it/s, loss=3.93][A
 70%|██████▉   | 87/125 [00:32<00:14,  2.68it/s, loss=3.93][A
 70%|███████   | 88/125 [00:32<00:13,  2.72it/s, loss=3.93][A
 70%|███████   | 88/125 [00:32<00:13,  2.68it/s, loss=3.93][A
 71%|███████   | 89/125 [00:32<00:13,  2.72it/s, loss=3.93][A
 71%|███████   | 89/125 [00:33<00:13,  2.69it/s, loss=3.92][A
 72%|███████▏  | 90/125 [00:33<00:12,  2.72it/s, loss=3

 17%|█▋        | 21/125 [00:07<00:38,  2.68it/s, loss=3.9][A
 17%|█▋        | 21/125 [00:08<00:40,  2.56it/s, loss=3.9][A
 18%|█▊        | 22/125 [00:08<00:38,  2.68it/s, loss=3.9][A
 18%|█▊        | 22/125 [00:08<00:39,  2.62it/s, loss=3.9][A
 18%|█▊        | 23/125 [00:08<00:37,  2.73it/s, loss=3.9][A
 18%|█▊        | 23/125 [00:08<00:38,  2.62it/s, loss=3.9][A
 19%|█▉        | 24/125 [00:08<00:36,  2.74it/s, loss=3.9][A
 19%|█▉        | 24/125 [00:09<00:38,  2.62it/s, loss=3.9][A
 20%|██        | 25/125 [00:09<00:36,  2.72it/s, loss=3.9][A
 20%|██        | 25/125 [00:09<00:38,  2.61it/s, loss=3.9][A
 21%|██        | 26/125 [00:09<00:36,  2.72it/s, loss=3.9][A
 21%|██        | 26/125 [00:09<00:38,  2.60it/s, loss=3.9][A
 22%|██▏       | 27/125 [00:09<00:36,  2.70it/s, loss=3.9][A
 22%|██▏       | 27/125 [00:10<00:37,  2.61it/s, loss=3.9][A
 22%|██▏       | 28/125 [00:10<00:35,  2.70it/s, loss=3.9][A
 22%|██▏       | 28/125 [00:10<00:37,  2.61it/s, loss=3.9][A
 23%|██▎

 69%|██████▉   | 86/125 [00:32<00:14,  2.66it/s, loss=3.89][A
 69%|██████▉   | 86/125 [00:32<00:14,  2.62it/s, loss=3.89][A
 70%|██████▉   | 87/125 [00:32<00:14,  2.65it/s, loss=3.89][A
 70%|██████▉   | 87/125 [00:33<00:14,  2.63it/s, loss=3.89][A
 70%|███████   | 88/125 [00:33<00:13,  2.66it/s, loss=3.89][A
 70%|███████   | 88/125 [00:33<00:14,  2.63it/s, loss=3.88][A
 71%|███████   | 89/125 [00:33<00:13,  2.66it/s, loss=3.88][A
 71%|███████   | 89/125 [00:33<00:13,  2.63it/s, loss=3.88][A
 72%|███████▏  | 90/125 [00:33<00:13,  2.65it/s, loss=3.88][A
 72%|███████▏  | 90/125 [00:34<00:13,  2.63it/s, loss=3.88][A
 73%|███████▎  | 91/125 [00:34<00:12,  2.65it/s, loss=3.88][A
 73%|███████▎  | 91/125 [00:34<00:12,  2.63it/s, loss=3.88][A
 74%|███████▎  | 92/125 [00:34<00:12,  2.65it/s, loss=3.88][A
 74%|███████▎  | 92/125 [00:35<00:12,  2.63it/s, loss=3.88][A
 74%|███████▍  | 93/125 [00:35<00:12,  2.66it/s, loss=3.88][A
 74%|███████▍  | 93/125 [00:35<00:12,  2.63it/s, loss=3

 19%|█▉        | 24/125 [00:09<00:38,  2.63it/s, loss=3.88][A
 20%|██        | 25/125 [00:09<00:36,  2.73it/s, loss=3.88][A
 20%|██        | 25/125 [00:09<00:37,  2.66it/s, loss=3.88][A
 21%|██        | 26/125 [00:09<00:35,  2.77it/s, loss=3.88][A
 21%|██        | 26/125 [00:09<00:37,  2.66it/s, loss=3.87][A
 22%|██▏       | 27/125 [00:09<00:35,  2.76it/s, loss=3.87][A
 22%|██▏       | 27/125 [00:10<00:36,  2.66it/s, loss=3.88][A
 22%|██▏       | 28/125 [00:10<00:35,  2.76it/s, loss=3.88][A
 22%|██▏       | 28/125 [00:10<00:36,  2.66it/s, loss=3.87][A
 23%|██▎       | 29/125 [00:10<00:34,  2.75it/s, loss=3.87][A
 23%|██▎       | 29/125 [00:10<00:35,  2.67it/s, loss=3.87][A
 24%|██▍       | 30/125 [00:10<00:34,  2.76it/s, loss=3.87][A
 24%|██▍       | 30/125 [00:11<00:35,  2.67it/s, loss=3.87][A
 25%|██▍       | 31/125 [00:11<00:34,  2.76it/s, loss=3.87][A
 25%|██▍       | 31/125 [00:11<00:34,  2.71it/s, loss=3.88][A
 26%|██▌       | 32/125 [00:11<00:33,  2.79it/s, loss=3

 71%|███████   | 89/125 [00:32<00:13,  2.71it/s, loss=3.85][A
 72%|███████▏  | 90/125 [00:32<00:12,  2.74it/s, loss=3.85][A
 72%|███████▏  | 90/125 [00:33<00:12,  2.71it/s, loss=3.85][A
 73%|███████▎  | 91/125 [00:33<00:12,  2.74it/s, loss=3.85][A
 73%|███████▎  | 91/125 [00:33<00:12,  2.71it/s, loss=3.85][A
 74%|███████▎  | 92/125 [00:33<00:12,  2.74it/s, loss=3.85][A
 74%|███████▎  | 92/125 [00:33<00:12,  2.71it/s, loss=3.85][A
 74%|███████▍  | 93/125 [00:33<00:11,  2.74it/s, loss=3.85][A
 74%|███████▍  | 93/125 [00:34<00:11,  2.72it/s, loss=3.84][A
 75%|███████▌  | 94/125 [00:34<00:11,  2.75it/s, loss=3.84][A
 75%|███████▌  | 94/125 [00:34<00:11,  2.72it/s, loss=3.85][A
 76%|███████▌  | 95/125 [00:34<00:10,  2.75it/s, loss=3.85][A
 76%|███████▌  | 95/125 [00:34<00:11,  2.72it/s, loss=3.85][A
 77%|███████▋  | 96/125 [00:34<00:10,  2.75it/s, loss=3.85][A
 77%|███████▋  | 96/125 [00:35<00:10,  2.72it/s, loss=3.85][A
 78%|███████▊  | 97/125 [00:35<00:10,  2.74it/s, loss=3

 22%|██▏       | 28/125 [00:10<00:35,  2.73it/s, loss=3.85][A
 22%|██▏       | 28/125 [00:10<00:36,  2.63it/s, loss=3.85][A
 23%|██▎       | 29/125 [00:10<00:35,  2.72it/s, loss=3.85][A
 23%|██▎       | 29/125 [00:11<00:36,  2.62it/s, loss=3.84][A
 24%|██▍       | 30/125 [00:11<00:35,  2.71it/s, loss=3.84][A
 24%|██▍       | 30/125 [00:11<00:36,  2.63it/s, loss=3.84][A
 25%|██▍       | 31/125 [00:11<00:34,  2.71it/s, loss=3.84][A
 25%|██▍       | 31/125 [00:11<00:35,  2.62it/s, loss=3.84][A
 26%|██▌       | 32/125 [00:11<00:34,  2.71it/s, loss=3.84][A
 26%|██▌       | 32/125 [00:12<00:35,  2.62it/s, loss=3.84][A
 26%|██▋       | 33/125 [00:12<00:34,  2.70it/s, loss=3.84][A
 26%|██▋       | 33/125 [00:12<00:35,  2.63it/s, loss=3.84][A
 27%|██▋       | 34/125 [00:12<00:33,  2.71it/s, loss=3.84][A
 27%|██▋       | 34/125 [00:12<00:34,  2.64it/s, loss=3.84][A
 28%|██▊       | 35/125 [00:12<00:33,  2.71it/s, loss=3.84][A
 28%|██▊       | 35/125 [00:13<00:34,  2.65it/s, loss=3

 74%|███████▍  | 93/125 [00:34<00:11,  2.72it/s, loss=3.82][A
 74%|███████▍  | 93/125 [00:34<00:11,  2.70it/s, loss=3.82][A
 75%|███████▌  | 94/125 [00:34<00:11,  2.73it/s, loss=3.82][A
 75%|███████▌  | 94/125 [00:34<00:11,  2.70it/s, loss=3.82][A
 76%|███████▌  | 95/125 [00:34<00:11,  2.72it/s, loss=3.82][A
 76%|███████▌  | 95/125 [00:35<00:11,  2.70it/s, loss=3.82][A
 77%|███████▋  | 96/125 [00:35<00:10,  2.72it/s, loss=3.82][A
 77%|███████▋  | 96/125 [00:35<00:10,  2.70it/s, loss=3.81][A
 78%|███████▊  | 97/125 [00:35<00:10,  2.73it/s, loss=3.81][A
 78%|███████▊  | 97/125 [00:35<00:10,  2.70it/s, loss=3.81][A
 78%|███████▊  | 98/125 [00:35<00:09,  2.73it/s, loss=3.81][A
 78%|███████▊  | 98/125 [00:36<00:10,  2.70it/s, loss=3.81][A
 79%|███████▉  | 99/125 [00:36<00:09,  2.73it/s, loss=3.81][A
 79%|███████▉  | 99/125 [00:36<00:09,  2.70it/s, loss=3.81][A
 80%|████████  | 100/125 [00:36<00:09,  2.72it/s, loss=3.81][A
 80%|████████  | 100/125 [00:37<00:09,  2.69it/s, loss

 25%|██▍       | 31/125 [00:11<00:34,  2.71it/s, loss=3.81][A
 26%|██▌       | 32/125 [00:11<00:33,  2.80it/s, loss=3.81][A
 26%|██▌       | 32/125 [00:11<00:34,  2.69it/s, loss=3.81][A
 26%|██▋       | 33/125 [00:11<00:33,  2.78it/s, loss=3.81][A
 26%|██▋       | 33/125 [00:12<00:34,  2.69it/s, loss=3.81][A
 27%|██▋       | 34/125 [00:12<00:32,  2.77it/s, loss=3.81][A
 27%|██▋       | 34/125 [00:12<00:33,  2.69it/s, loss=3.81][A
 28%|██▊       | 35/125 [00:12<00:32,  2.76it/s, loss=3.81][A
 28%|██▊       | 35/125 [00:13<00:33,  2.67it/s, loss=3.81][A
 29%|██▉       | 36/125 [00:13<00:32,  2.75it/s, loss=3.81][A
 29%|██▉       | 36/125 [00:13<00:33,  2.68it/s, loss=3.81][A
 30%|██▉       | 37/125 [00:13<00:32,  2.75it/s, loss=3.81][A
 30%|██▉       | 37/125 [00:13<00:32,  2.72it/s, loss=3.81][A
 30%|███       | 38/125 [00:13<00:31,  2.79it/s, loss=3.81][A
 30%|███       | 38/125 [00:13<00:32,  2.71it/s, loss=3.81][A
 31%|███       | 39/125 [00:14<00:30,  2.79it/s, loss=3

 77%|███████▋  | 96/125 [00:35<00:10,  2.74it/s, loss=3.79][A
 78%|███████▊  | 97/125 [00:35<00:10,  2.76it/s, loss=3.79][A
 78%|███████▊  | 97/125 [00:35<00:10,  2.74it/s, loss=3.79][A
 78%|███████▊  | 98/125 [00:35<00:09,  2.77it/s, loss=3.79][A
 78%|███████▊  | 98/125 [00:35<00:09,  2.74it/s, loss=3.79][A
 79%|███████▉  | 99/125 [00:35<00:09,  2.77it/s, loss=3.79][A
 79%|███████▉  | 99/125 [00:36<00:09,  2.75it/s, loss=3.79][A
 80%|████████  | 100/125 [00:36<00:09,  2.78it/s, loss=3.79][A
 80%|████████  | 100/125 [00:36<00:09,  2.75it/s, loss=3.79][A
 81%|████████  | 101/125 [00:36<00:08,  2.78it/s, loss=3.79][A
 81%|████████  | 101/125 [00:36<00:08,  2.75it/s, loss=3.79][A
 82%|████████▏ | 102/125 [00:36<00:08,  2.78it/s, loss=3.79][A
 82%|████████▏ | 102/125 [00:37<00:08,  2.75it/s, loss=3.79][A
 82%|████████▏ | 103/125 [00:37<00:07,  2.77it/s, loss=3.79][A
 82%|████████▏ | 103/125 [00:37<00:08,  2.75it/s, loss=3.79][A
 83%|████████▎ | 104/125 [00:37<00:07,  2.77it/

 28%|██▊       | 35/125 [00:12<00:31,  2.83it/s, loss=3.8][A
 28%|██▊       | 35/125 [00:12<00:32,  2.75it/s, loss=3.8][A
 29%|██▉       | 36/125 [00:12<00:31,  2.82it/s, loss=3.8][A
 29%|██▉       | 36/125 [00:13<00:32,  2.75it/s, loss=3.8][A
 30%|██▉       | 37/125 [00:13<00:31,  2.82it/s, loss=3.8][A
 30%|██▉       | 37/125 [00:13<00:31,  2.75it/s, loss=3.8][A
 30%|███       | 38/125 [00:13<00:30,  2.82it/s, loss=3.8][A
 30%|███       | 38/125 [00:13<00:31,  2.75it/s, loss=3.8][A
 31%|███       | 39/125 [00:13<00:30,  2.82it/s, loss=3.8][A
 31%|███       | 39/125 [00:14<00:31,  2.75it/s, loss=3.8][A
 32%|███▏      | 40/125 [00:14<00:30,  2.82it/s, loss=3.8][A
 32%|███▏      | 40/125 [00:14<00:30,  2.75it/s, loss=3.81][A
 33%|███▎      | 41/125 [00:14<00:29,  2.81it/s, loss=3.81][A
 33%|███▎      | 41/125 [00:14<00:30,  2.75it/s, loss=3.81][A
 34%|███▎      | 42/125 [00:14<00:29,  2.81it/s, loss=3.81][A
 34%|███▎      | 42/125 [00:15<00:30,  2.74it/s, loss=3.81][A
 34

 80%|████████  | 100/125 [00:36<00:09,  2.76it/s, loss=3.79][A
 80%|████████  | 100/125 [00:36<00:09,  2.73it/s, loss=3.79][A
 81%|████████  | 101/125 [00:36<00:08,  2.75it/s, loss=3.79][A
 81%|████████  | 101/125 [00:37<00:08,  2.72it/s, loss=3.79][A
 82%|████████▏ | 102/125 [00:37<00:08,  2.75it/s, loss=3.79][A
 82%|████████▏ | 102/125 [00:37<00:08,  2.72it/s, loss=3.79][A
 82%|████████▏ | 103/125 [00:37<00:08,  2.75it/s, loss=3.79][A
 82%|████████▏ | 103/125 [00:37<00:08,  2.72it/s, loss=3.79][A
 83%|████████▎ | 104/125 [00:37<00:07,  2.75it/s, loss=3.79][A
 83%|████████▎ | 104/125 [00:38<00:07,  2.72it/s, loss=3.79][A
 84%|████████▍ | 105/125 [00:38<00:07,  2.75it/s, loss=3.79][A
 84%|████████▍ | 105/125 [00:38<00:07,  2.72it/s, loss=3.79][A
 85%|████████▍ | 106/125 [00:38<00:06,  2.75it/s, loss=3.79][A
 85%|████████▍ | 106/125 [00:38<00:06,  2.72it/s, loss=3.79][A
 86%|████████▌ | 107/125 [00:38<00:06,  2.75it/s, loss=3.79][A
 86%|████████▌ | 107/125 [00:39<00:06,  

In [44]:
learner.save_encoder("spooky_adam_enc3")  # encoder after the half-learning-rate cycle

In [46]:
# NOTE(review): deliberately reverts to enc2 (enc3 was saved just above but is not used);
# the reason is not recorded in this notebook.
learner.load_encoder("spooky_adam_enc2")

In [47]:
# Perplexity (how language-model quality is usually reported) = exp(cross-entropy loss).
# NOTE(review): 4.33935 does not match any loss printed above — presumably a validation
# loss from a run whose output was cleared; record its source.
np.exp(4.33935)

76.657695638682966

In [48]:
from sklearn.model_selection import StratifiedKFold

# 5 folds stratified on author, so each fold keeps the class balance. The folds are shuffled,
# so a fixed random_state is needed for them to be reproducible across kernel restarts.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# get folds: a list of (train_idxs, val_idxs) index-array pairs, one pair per fold.
# NOTE(review): kfolds is not used by any later cell shown here.
kfolds = list(skf.split(train_df.id, train_df.author))
print(len(kfolds))

5


## Test the language model: prime it with a prompt and generate text

In [49]:
# Create a short piece of text to "prime" the predictions. Numericalize it with the same
# vocab so it can be fed into the language model.
m = learner.model
ss = '. It was a dark and scary night. The old'
s = [spacy_tok(ss)]
t = TEXT.numericalize(s)
' '.join(s[0])

'. It was a dark and scary night . The old'

In [50]:
# Run the prompt through the model as a single sequence (batch size 1), then restore the
# training batch size. The try/finally guarantees bs is restored even if reset() or the
# forward pass raises; otherwise the model would be left configured for batch size 1.
m[0].bs = 1          # set batch size = 1 (presumably read by reset() when sizing the hidden state)
try:
    m.eval()         # turn off dropout
    m.reset()        # reset hidden state
    res, *_ = m(t)   # scores over the vocab for each input position
finally:
    m[0].bs = bs     # put batch size back to what it was

In [51]:
# The ten most likely tokens to follow the prompt (scores at the last position).
nexts = res[-1].topk(10)[1]
[TEXT.vocab.itos[o] for o in to_np(nexts)]

['<unk>',
 'man',
 'things',
 'men',
 'old',
 'and',
 'whateley',
 ',',
 'people',
 'houses']

In [52]:
# Greedy generation. Print the prompt, then 50 times: take the most likely next token
# (falling back to the runner-up when the top choice is id 0, '<unk>'), print it, and feed
# it back in as the next input.
print(ss, "\n")

for _step in range(50):
    top2 = res[-1].topk(2)[1]
    n = top2[0] if top2.data[0] != 0 else top2[1]
    print(TEXT.vocab.itos[n.data[0]], end=' ')
    res, *_ = m(n[0].unsqueeze(0))

print('...')

. It was a dark and scary night. The old 

man , who had been a little man , had been the most remarkable and most remarkable of the world . i had been a little , and had not been able to get out of the world . i was not , however , to be sure , but ...


## Predict the author

In [53]:
# Hyper-parameters for the author classifier, re-declared so this section is explicit.
bs, bptt = 64, 70   # batch size; back-prop-through-time sequence length

emb_sz = 400        # size of each embedding vector
nh = 1024           # number of hidden activations per layer
nl = 3              # number of layers

# For NLP, configure Adam with less momentum than the default beta1 of 0.9.
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))

In [54]:
# Use the same vocab built for the language model so every word maps to the same id.
# The `with` block closes the file handle (the bare open() leaked it).
# NOTE: pickle/dill executes code on load — only load this locally produced file.
with open(f'{PATH}/models/TEXT.pkl', 'rb') as fh:
    TEXT = pickle.load(fh)

In [55]:
# Non-sequential Field for the author label. SpookyDataset (presumably from the local
# `spooky` module) builds the train / validation / test splits.
AUTHOR_LABEL = data.Field(sequential=False)
splits = SpookyDataset.splits(
    TEXT, AUTHOR_LABEL,
    train_df, val_df, test_df,
)

In [56]:
t = splits[0].examples[0]  # first example of the training split

In [57]:
(t.label, ' '.join(t.text[0:10]))  # label and first ten tokens

('EAP', 'this process , however , afforded me no means of')

In [58]:
# fastai builds a ModelData object directly from the torchtext splits (batch size bs).
md2 = TextData.from_splits(PATH, splits, bs)

In [59]:
# Classifier on top of the language-model encoder. emb_sz / nh / nl presumably must match the
# saved encoder. 1500 is presumably the maximum sequence length — TODO confirm.
m3 = md2.get_model(
    opt_fn, 1500, bptt,
    emb_sz=emb_sz, n_hid=nh, n_layers=nl,
    dropout=0.1, dropouti=0.4, wdrop=0.5, dropoute=0.05, dropouth=0.3,
)

m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
m3.load_encoder('spooky_adam_enc2')  # encoder weights from the language model above

In [60]:
m3.clip = 25.0
# One learning rate per layer group, smallest first (passed to fit as an array).
lrs = np.array([1e-4, 1e-3, 1e-2])

In [61]:
# Freeze everything except the last layer group, then train it for 2 epochs.
m3.freeze_to(-1)
m3.fit(lrs / 2, 2, metrics=[accuracy])


  0%|          | 0/290 [00:00<?, ?it/s][A
  0%|          | 0/290 [00:00<?, ?it/s, loss=1.58][A
  0%|          | 1/290 [00:00<01:05,  4.45it/s, loss=1.58][A
  0%|          | 1/290 [00:00<01:16,  3.77it/s, loss=1.4] [A
  0%|          | 1/290 [00:00<01:30,  3.20it/s, loss=1.29][A
  0%|          | 1/290 [00:00<01:55,  2.49it/s, loss=1.27][A
  1%|▏         | 4/290 [00:00<00:28,  9.93it/s, loss=1.27][A
  1%|▏         | 4/290 [00:00<00:33,  8.47it/s, loss=1.23][A
  1%|▏         | 4/290 [00:00<00:35,  7.98it/s, loss=1.23][A
  2%|▏         | 6/290 [00:00<00:23, 11.92it/s, loss=1.23][A
  2%|▏         | 6/290 [00:00<00:26, 10.74it/s, loss=1.2] [A
  2%|▏         | 6/290 [00:00<00:33,  8.50it/s, loss=1.17][A
  3%|▎         | 8/290 [00:00<00:24, 11.31it/s, loss=1.17][A
  3%|▎         | 8/290 [00:00<00:27, 10.26it/s, loss=1.17][A
  3%|▎         | 8/290 [00:00<00:29,  9.67it/s, loss=1.16][A
  3%|▎         | 10/290 [00:00<00:23, 12.06it/s, loss=1.16][A
  3%|▎         | 10/290 [00:00<00

 32%|███▏      | 92/290 [00:06<00:13, 14.31it/s, loss=0.905][A
 32%|███▏      | 92/290 [00:06<00:13, 14.18it/s, loss=0.901][A
 32%|███▏      | 92/290 [00:06<00:14, 14.11it/s, loss=0.903][A
 32%|███▏      | 92/290 [00:06<00:14, 14.00it/s, loss=0.899][A
 33%|███▎      | 95/290 [00:06<00:13, 14.45it/s, loss=0.899][A
 33%|███▎      | 95/290 [00:06<00:13, 14.22it/s, loss=0.897][A
 33%|███▎      | 95/290 [00:06<00:13, 14.16it/s, loss=0.898][A
 33%|███▎      | 95/290 [00:06<00:13, 14.05it/s, loss=0.898][A
 35%|███▍      | 101/290 [00:06<00:13, 14.45it/s, loss=0.894][A

Exception in thread Thread-38:
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py", line 144, in run
    for instance in self.tqdm_cls._instances:
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration



[ 0.       0.87596  0.74771  0.67917]                        
[ 1.       0.80553  0.74185  0.68854]                        



In [62]:
# Train all layer groups: 2 cycles of 1 epoch each.
m3.unfreeze()
m3.fit(lrs, 2, metrics=[accuracy], cycle_len=1)

[ 0.       0.64033  0.57021  0.76771]                        
[ 1.       0.54842  0.51538  0.79167]                        



In [63]:
# 4 cycles doubling in length (1 + 2 + 4 + 8 = 15 epochs), checkpointed per cycle.
m3.fit(
    lrs/2, 4, metrics=[accuracy],
    cycle_len=1, cycle_mult=2,
    cycle_save_name='spooky_sent1_c4_cl1x2',
)

[ 0.       0.48122  0.47413  0.80833]                        
[ 1.       0.46605  0.46353  0.81771]                        
[ 2.       0.38426  0.45132  0.82083]                        
[ 3.       0.42502  0.46889  0.80729]                        
[ 4.       0.35398  0.46011  0.81667]                        
[ 5.       0.28249  0.45191  0.82292]                        
[ 6.       0.26402  0.46145  0.82188]                        
[ 7.       0.31668  0.4834   0.82292]                        
[ 8.       0.29387  0.4616   0.82396]                        
[ 9.       0.25343  0.48163  0.82604]                        
[ 10.        0.20527   0.51006   0.82396]                    
[ 11.        0.17881   0.49798   0.83958]                    
[ 12.        0.15088   0.53388   0.83125]                    
[ 13.        0.12904   0.53637   0.83229]                    
[ 14.        0.13446   0.53578   0.83542]                    



In [64]:
# 3 cycles of 3 epochs at a quarter of the learning rate. In the logged run, validation loss
# rose from 0.515 to 0.627 over these epochs, and the checkpoint reloaded later comes from
# the earlier spooky_sent1 run instead.
m3.fit(
    lrs/4, 3, metrics=[accuracy],
    cycle_len=3,
    cycle_save_name='spooky_sent2_c3_cl3',
)

[ 0.       0.16886  0.51518  0.83229]                        
[ 1.       0.11597  0.54311  0.83542]                        
[ 2.       0.10753  0.54276  0.84062]                        
[ 3.       0.14199  0.56554  0.83229]                        
[ 4.       0.11287  0.57523  0.83646]                        
[ 5.       0.10518  0.58223  0.8375 ]                         
[ 6.       0.13167  0.60857  0.83958]                        
[ 7.       0.10147  0.61336  0.83958]                         
[ 8.       0.08056  0.62719  0.84062]                         



In [83]:
# Reload cycle 1 of the spooky_sent1 run. That cycle ended at epoch 2 with val loss 0.451,
# the lowest end-of-cycle validation loss logged; later cycles pushed val loss up to 0.536
# even as accuracy crept higher. Presumably index i is the checkpoint saved at the end of
# cycle i.
m3.load_cycle('spooky_sent1_c4_cl1x2', 1)

In [84]:
# Label vocab: index 0 is '<unk>', followed by the three authors.
classes = AUTHOR_LABEL.vocab.itos
classes

['<unk>', 'EAP', 'MWS', 'HPL']

In [85]:
# Score every test sentence one at a time (batch size 1) and keep the raw model outputs for
# the author classes.
#   * eval() is idempotent, so it is called once instead of per sentence.
#   * reset() is still called per sentence so each one starts from a fresh hidden state.
#   * The batch size is restored afterwards (in `finally`), matching the earlier priming cell.
#     The original loop left m[0].bs == 1, which silently breaks later training/eval of m3.
preds = []

m = m3.model
m.eval()          # turn off dropout
m[0].bs = 1       # one sentence per forward pass
try:
    for ss in test_df['text']:
        t = TEXT.numericalize([spacy_tok(ss)])
        m.reset()
        res, *_ = m(t)
        # Drop column 0 ('<unk>' in AUTHOR_LABEL.vocab.itos) so the columns line up with classes[1:].
        preds.append(to_np(res).squeeze()[1:])
finally:
    m[0].bs = bs  # put batch size back to the training value

preds = np.array(preds)
preds.shape

(8392, 3)

In [86]:
# Turn the raw per-class outputs into probabilities, one row per test sentence.
# NOTE(review): relies on softmax's implicit dim (for this 2-D input, the class axis).
# On PyTorch >= 0.3, pass dim=1 explicitly to silence the deprecation warning.
probs = to_np(F.softmax(torch.from_numpy(preds)))

In [87]:
np.shape(probs)  # expect (number of test sentences, 3 authors)

(8392, 3)

In [88]:
def do_clip(arr, mx, n_classes=None):
    """Clip each row of class probabilities into [(1-mx)/(n_classes-1), mx] and renormalise.

    Bounding probabilities away from 0 and 1 limits how harshly a confident wrong prediction
    is penalised by a log-based metric (presumably the competition's scoring metric).

    The floor spreads the leftover (1 - mx) mass over the other n_classes - 1 classes. The
    previous `(1-mx)/1` floor is only correct for two classes. With three classes it floors
    at 2x the intended value, so a clipped row sums to more than 1 and renormalisation pulls
    the top probability well below mx.

    Args:
        arr: 2-D array-like of probabilities, one row per sample, one column per class.
        mx: upper bound for any single probability (e.g. 0.98).
        n_classes: number of classes; defaults to ``arr.shape[1]``. Passing 2 reproduces
            the old binary behaviour.

    Returns:
        A float array of the same shape whose rows sum to 1.
    """
    arr = np.asarray(arr, dtype=float)
    if n_classes is None:
        n_classes = arr.shape[1]
    mn = (1 - mx) / max(n_classes - 1, 1)  # guard against a single-column input
    clipped = np.clip(arr, mn, mx)
    return clipped / clipped.sum(axis=1, keepdims=True)

In [89]:
# probs = do_clip(probs, 0.98)

In [90]:
# Submission frame: the test rows minus their text, plus one probability column per author.
# Column names come from the label vocab (classes[1:] skips '<unk>'). This matches the [1:]
# slice applied to the model outputs when preds was built, so the column -> author mapping
# cannot drift from a hand-written list. drop() returns a new frame, so no inplace mutation.
assert len(probs) == len(test_df), 'expected one prediction row per test sentence'

preds_test_df = test_df.drop('text', axis=1)
for col_idx, author in enumerate(classes[1:]):
    preds_test_df[author] = probs[:, col_idx]

preds_test_df.head()

Unnamed: 0,id,EAP,MWS,HPL
0,id02310,0.027351,0.2754847,0.6971645
1,id24541,1.0,1.474728e-07,4.589602e-09
2,id00134,0.000164,0.001170791,0.9986647
3,id27757,0.698621,0.0007681423,0.3006105
4,id04081,0.695955,0.2278781,0.07616729


In [91]:
# Write the submission without the pandas row index.
preds_test_df.to_csv(f'{PATH}/subm_wg_20171127_4.csv', index=False)

In [None]:
# Read back the submission file written above to sanity-check it.
# (Previously this read an older submission, subm_wg_20171126_3.csv.)
preds_test_df = pd.read_csv(f'{PATH}/subm_wg_20171127_4.csv', index_col=None)

In [None]:
preds_test_df.head(5)  # first rows of the submission as read back from disk

In [None]:
from IPython.display import FileLink

In [None]:
# Download link for the submission file written above (previously pointed at an older file).
FileLink(f'{PATH}/subm_wg_20171127_4.csv')