In [None]:
import torch
import numpy as np
import sys
import os
sys.path.append('..')

from mll import hierrnn_lm, nospaces_dataset, run_mem_recv, run_mem_send

# Path of the saved checkpoint this notebook analyses; it is opened with
# torch.load further down. The commented-out line is an alternative checkpoint
# that can be swapped in instead.
# model_file = '../models/hierrnn_lm_mlb9_hughtok1_20190120_173425.dat.22k'
model_file = '../tmp/hierrnn_lm_mlb10_hughtok1_20190120_202735.dat'

Plan for this notebook:
- load the model file
- reconstruct the model
- sample some batches of words from the dataset
- pass them through the model
- look at the resulting stopness, averaged per character position for each word length

In [None]:
from collections import defaultdict

# Load the checkpoint. map_location='cpu' remaps any tensors that were saved
# from a GPU run, so the file can also be opened on a machine without CUDA;
# nothing in this notebook moves the model to a GPU.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only open
# checkpoint files from a source you trust.
with open(model_file, 'rb') as f:
    state_dict = torch.load(f, map_location='cpu')
print('state_dict.keys()', state_dict.keys())
print('episode', state_dict['episode'])

# Hyper-parameters stored alongside the weights; `p` is a short alias that the
# rest of the notebook uses.
p = params = state_dict['params']
print('params', params)

# Rebuild the dataset from the text file recorded in the checkpoint's params.
dataset = nospaces_dataset.Dataset(
    in_textfile=p.in_textfile
)

# Rebuild the model with the saved hyper-parameters (dropout switched off for
# analysis) and restore its weights.
hier_enc_dec = hierrnn_lm.HierarchicalEncoderDecoder(
    vocab_size=dataset.vocab_size,
    embedding_size=p.embedding_size,
    rnn_type=p.rnn_type,
    dropout=0
)
hier_enc_dec.load_state_dict(state_dict['hier_enc_dec_state'])
print('loaded model, and dataset')

# Sequence lengths requested from dataset.sample for the encoder and decoder side.
encode_len, decode_len = 50, 15

# Number of batches to draw and push through the model.
num_its = 16

# Running totals keyed by word length: the *_sum_* dicts hold the summed
# stopness tensor for each word length, the *_count_* dicts hold how many
# words of that length were accumulated.
encoder_sum_by_word_length, decoder_sum_by_word_length = {}, {}
encoder_count_by_word_length = defaultdict(int)
decoder_count_by_word_length = defaultdict(int)

for it in range(num_its):
    print('.', end='', flush=True)

    # Draw a fresh batch and unpack the fields used in the rest of the loop.
    sample = dataset.sample(batch_size=p.batch_size, encode_len=encode_len, decode_len=decode_len)
    encode_chars_t = sample['encode_chars_t']
    decode_chars_t = sample['decode_chars_t']
    encode_chars_l = sample['encode_chars_l']
    decode_chars_l = sample['decode_chars_l']
    encode_words_l = sample['encode_words_l']
    decode_words_l = sample['decode_words_l']
    encode_lens_t = sample['encode_lens_t']
    decode_lens_t = sample['decode_lens_t']
    # Show the first encoder-side example of the batch as a progress/sanity check.
    print(encode_chars_l[0])

    # This is pure inference, so run the forward pass without recording an
    # autograd graph. The decoder length comes from decode_len rather than a
    # second hard-coded 15, so the two cannot drift apart.
    with torch.no_grad():
        utts_out_logits, enc_stopness, dec_stopness = hier_enc_dec(encode_chars_t, decoder_utt_len=decode_len)

    def inc_sums(sum_by_word_length, count_by_word_length, words_l, chars_l, lens_t, stopness, justify):
        """
        Add the per-character stopness of every word in a batch to running totals.

        For each word, the slice of `stopness` covering that word's characters
        is added element-wise to `sum_by_word_length[len(word)]`, and
        `count_by_word_length[len(word)]` is incremented. Both mappings are
        updated in place; nothing is returned.

        Args:
            sum_by_word_length: dict mapping word length -> summed stopness tensor.
            count_by_word_length: mapping word length -> number of words summed;
                must tolerate `+= 1` on a missing key (e.g. defaultdict(int)).
            words_l: one sequence of words per batch example.
            chars_l: per-example character strings. Not read here; kept so
                existing callers do not need to change.
            lens_t: tensor of per-example lengths, used only when
                justify == 'right' to skip the leading padding.
            stopness: tensor indexed as [position, example].
            justify: 'right' if each example is padded at the front (the first
                word starts at max_len - length), 'left' if it starts at 0.

        Raises:
            ValueError: if `justify` is neither 'left' nor 'right'.
        """
        if justify not in ('left', 'right'):
            raise ValueError('justify ' + justify + ' not recognized')

        batch_size = len(words_l)
        max_len = stopness.size(0)
        # Accumulate plain values only: detaching keeps the totals from holding
        # on to the model's autograd graph from one batch to the next.
        stopness = stopness.detach()
        for n in range(batch_size):
            pos = max_len - lens_t[n].item() if justify == 'right' else 0
            for word in words_l[n]:
                length = len(word)
                # assert word == chars_l[n][pos:pos + length]
                stopness_t = stopness[pos:pos + length, n]
                pos += length
                if stopness_t.size(0) != length:
                    # The word runs past the end of `stopness`, so the slice is
                    # truncated. Adding it would either fail or silently
                    # broadcast into the total, so leave this word out.
                    continue
                count_by_word_length[length] += 1
                if length in sum_by_word_length:
                    sum_by_word_length[length] += stopness_t
                else:
                    # Store a copy, not the slice itself: the slice is a view,
                    # so later in-place `+=` would otherwise write into the
                    # caller's stopness tensor.
                    sum_by_word_length[length] = stopness_t.clone()

    # Fold this batch into the running totals, once for each side of the
    # model. The encoder side is passed as right-justified and the decoder
    # side as left-justified.
    for side_kwargs in (
        dict(
            sum_by_word_length=encoder_sum_by_word_length,
            count_by_word_length=encoder_count_by_word_length,
            words_l=encode_words_l,
            chars_l=encode_chars_l,
            lens_t=encode_lens_t,
            stopness=enc_stopness,
            justify='right',
        ),
        dict(
            sum_by_word_length=decoder_sum_by_word_length,
            count_by_word_length=decoder_count_by_word_length,
            words_l=decode_words_l,
            chars_l=decode_chars_l,
            lens_t=decode_lens_t,
            stopness=dec_stopness,
            justify='left',
        ),
    ):
        inc_sums(**side_kwargs)

# Report, for the encoder and then the decoder, the mean stopness at each
# character position of a word, grouped by word length. Values are shown as
# integer percentages (scaled by 100 and truncated).
for label, sums, counts in (
    ('encoder', encoder_sum_by_word_length, encoder_count_by_word_length),
    ('decoder', decoder_sum_by_word_length, decoder_count_by_word_length),
):
    print('')
    print(label)
    for length in sorted(counts):
        num_words = counts[length]
        this_stopness = sums[length] / num_words
        this_stopness = (this_stopness * 100).int()
        print('%i samples' % num_words, this_stopness.tolist())