In [1]:
import os
from fastai import text
import torch

import pathlib

In [2]:
os.chdir('../data') # had to do this because FastAI has stupid convention of prefixing model name with 'model/'

In [3]:
!ls

file.csv  github_repos.json  lm_data.pkl  models  nlp_github_repos.json


In [4]:
!ls ../data

file.csv  github_repos.json  lm_data.pkl  models  nlp_github_repos.json


In [5]:
config_dict = dict(
    base_lr = 1e-2,
    finetuning_lr = (1e-4, 1e-2),
    drop_mult = 0.5,
    bptt=50,
    bs = 64
)

config = text.awd_lstm_lm_config.copy()
config['qrnn'] = True

DATASET_PATH = 'lm_data.pkl'
MODEL_PATH = 'ft_cleaned_qrnn_bptt50_10'
MODEL_ENCODER_PATH = 'ft_enc_cleaned_qrnn_bptt50_10'

In [6]:
data_lm = text.load_data('', DATASET_PATH, bs=config_dict['bs'], bptt=config_dict['bptt'])

In [7]:
learn = text.language_model_learner(data_lm, text.AWD_LSTM, config=config, drop_mult=config_dict['drop_mult'], pretrained=False).to_fp16()

In [8]:
learn.load(MODEL_PATH)
learn.load_encoder(MODEL_ENCODER_PATH);

In [9]:
sentence_segments = [
    "Machine learning and",
    "Deploying Artificial",
    "Github automation pipeline for"
]

In [10]:
for temp in [0.1, 0.5, 1.0, 2.0]:
    print()
    print('temperature:', temp)
    for segment in sentence_segments:
        
        segment_text = 'segment: {}'.format(segment)
        generated_text = 'generated: {}'.format(learn.predict(segment, n_words=20, temperature=temp))
        print(segment_text)
        print(generated_text)


temperature: 0.1
segment: Machine learning and
generated: Machine learning and machine learning Machine Learning Machine Learning Machine Learning Machine Learning Machine
segment: Deploying Artificial
generated: Deploying Artificial Intelligence This is a Java Application Development Kit for Java . It
segment: Github automation pipeline for
generated: Github automation pipeline for Python This is a Python module for the Python programming language . It is a

temperature: 0.5
segment: Machine learning and
generated: Machine learning and Machine Learning Machine Learning Machine Learning Machine Learning Machine Learning
segment: Deploying Artificial
generated: Deploying Artificial Intelligence This is a repository for the Udacity course Developing Android Apps :
segment: Github automation pipeline for
generated: Github automation pipeline for Django This is a collection of tools to help automate the development of Django Social

temperature: 1.0
segment: Machine learning and
generated

In [16]:
encoder = learn.model[0]

In [72]:
def print_shapes_recursively(tpl, nesting=''):
    if type(tpl) is tuple or type(tpl) is list:
        l = len(tpl)
        print(nesting + 'Collection of {} elements:'.format(l))
        for item in tpl:
            print_shapes_recursively(item, nesting + '\t')
    else:
        print(nesting + str(tpl.shape))

In [None]:
def get_shapes_recursively(tpl, nesting=''):
    if type(tpl) is tuple or type(tpl) is list:
        l = len(tpl)
        print(nesting + 'Collection of {} elements:'.format(l))
        for item in tpl:
            print_shapes_recursively(item, nesting + '\t')
    else:
        print(nesting + str(tpl.shape))

In [68]:
def get_model_states(learner, text):
    input_tensor, __ = learner.data.one_item(text)
    return learner.model[0](input_tensor)

In [69]:
item = learn.data.one_item('foo bar baz')

In [74]:
states = get_model_states(learn, "Github automation pipeline for young warranties user preserved reports couldn ' explore bukkitdev performant · observatory album occasions blob tpg ifftxxnumber redirecturi analyze smxxnumber")

In [81]:
states[0][1] == states[1][1]

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]], device='cuda:0', dtype=torch.uint8)

In [88]:
print_shapes_recursively(states)

Collection of 2 elements:
	Collection of 3 elements:
		torch.Size([1, 26, 1152])
		torch.Size([1, 26, 1152])
		torch.Size([1, 26, 400])
	Collection of 3 elements:
		torch.Size([1, 26, 1152])
		torch.Size([1, 26, 1152])
		torch.Size([1, 26, 400])
