In [1]:
%load_ext watermark
%watermark -p tensorflow,numpy,pandas -v -m

CPython 3.6.5
IPython 6.1.0

tensorflow 1.8.0
numpy 1.14.3
pandas 0.22.0

compiler   : GCC 7.2.0
system     : Linux
release    : 4.10.0-32-generic
machine    : x86_64
processor  : x86_64
CPU cores  : 8
interpreter: 64bit


In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd

from inspect_lm import build_graph, transform_texts, TEXT_ENCODER

## Load Model

In [3]:
graph = tf.Graph()
sess =  tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=graph)
with graph.as_default():
    X, M, lm_logits, lm_losses = build_graph(sess)

## Prepare and Examine Test Data

In [4]:
list_of_texts =[
    "Karen was assigned a roommate her first year of college. Her roommate asked her to go to a nearby city for a concert. Karen agreed happily. The show was absolutely exhilarating.",
    "Jim got his first credit card in college. He didn’t have a job so he bought everything on his card. After he graduated he amounted a $10,000 debt. Jim realized that he was foolish to spend so much money.",
    "Gina misplaced her phone at her grandparents. It wasn’t anywhere in the living room. She realized she was in the car before. She grabbed her dad’s keys and ran outside."
]

In [5]:
x, m = transform_texts(list_of_texts)
x.shape

length: 35
length: 47
length: 36


(3, 128, 2)

In [6]:
DECODER = TEXT_ENCODER.decoder
restored = []
for token, mask in zip(x[0, :, 0], m[0, :]):
    # if DECODER[token] != "<unk>":
    if mask:
        restored.append(DECODER[token].replace("</w>", ""))
" ".join(restored)

'karen was assigned a roommate her first year of college . her roommate asked her to go to a nearby city for a concert . karen agreed happily . the show was absolutely exhilarating .'

## Model Prediction

In [7]:
batch_lm_logits, batch_lm_losses = sess.run([lm_logits, lm_losses], {X: x, M: m})

In [8]:
first_choices = np.argmax(batch_lm_logits, axis=-1)
first_choices.shape

(3, 128)

In [9]:
DECODER = TEXT_ENCODER.decoder
restored = []
preds = []
for token, mask, pred in zip(x[0, :, 0], m[0, :], first_choices[0, :]):
    # if DECODER[token] != "<unk>":
    if mask:
        restored.append(DECODER[token].replace("</w>", ""))
        preds.append(DECODER.get(pred, "<ctx_token>").replace("</w>", ""))
print(" ".join(restored))
print(" ".join(preds))

karen was assigned a roommate her first year of college . her roommate asked her to go to a nearby city for a concert . karen agreed happily . the show was absolutely exhilarating .
's n't to room . freshman semester of college . she roommate was her to stay to the party college and a few , karen agreed and . 
 concert was a amazing . she


In [10]:
def collect_predictions(idx, topk=3):
    topk_preds = [["<start>"] for _ in range(topk)]
    original = []
    for token, mask, logits in zip(x[idx, :, 0], m[idx, :], batch_lm_logits[idx, :, :]):
        top_tokens = np.argsort(logits, axis=-1)[::-1][::topk]
        if mask:
            original.append(DECODER[token].replace("</w>", ""))
            for i in range(topk):
                topk_preds[i].append(DECODER.get(top_tokens[i], "<ctx_token>").replace("</w>", ""))
    original.append("<end>")
    df = pd.DataFrame({"original": original})
    for i in range(topk):
        df[f"pred_{i+1}"] = topk_preds[i]
    return df
collect_predictions(0).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,26,27,28,29,30,31,32,33,34,35
original,karen,was,assigned,a,roommate,her,first,year,of,college,...,agreed,happily,.,the,show,was,absolutely,exhilarating,.,<end>
pred_1,<start>,'s,n't,to,room,.,freshman,semester,of,college,...,agreed,and,.,\n,concert,was,a,amazing,.,she
pred_2,<start>,said,still,as,seat,at,own,night,in,school,...,had,to,because,"""",city,had,to,wonderful,for,it
pred_3,<start>,had,not,",",small,because,parents,weekend,here,grad,...,told,without,to,karen,two,she,on,stunning,-,when


In [11]:
collect_predictions(1).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,38,39,40,41,42,43,44,45,46,47
original,jim,got,his,first,credit,card,in,college,.,he,...,he,was,foolish,to,spend,so,much,money,.,<end>
pred_1,<start>,'s,up,hands,look,card,.,the,.,he,...,he,had,n't,to,think,that,much,money,on,he
pred_2,<start>,said,to,ass,glimpse,",",and,two,for,it,...,if,needed,a,.,buy,it,long,.,and,the
pred_3,<start>,had,off,feet,kiss,and,bill,about,because,his,...,there,'d,no,in,get,money,large,cash,just,after


In [12]:
collect_predictions(2).transpose()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,27,28,29,30,31,32,33,34,35,36
original,gina,misplaced,her,phone,at,her,grandparents,.,it,was,...,grabbed,her,dad,'s,keys,and,ran,outside,.,<end>
pred_1,<start>,'s,her,keys,.,the,house,',"""",was,...,had,her,purse,'s,phone,and,ran,to,.,\n
pred_2,<start>,said,a,phone,in,her,car,and,i,fell,...,looked,a,bag,",",hand,",",went,back,and,her
pred_3,<start>,had,that,shoes,number,one,mother,when,her,went,...,pulled,hold,handbag,as,jacket,as,left,.,for,he


In [13]:
batch_lm_losses

array([3.4980624, 3.1991193, 3.2979164], dtype=float32)