### Running BERT and GPT2 on Gershman and Tenenbaum's Phrase similarity data

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from transformers import *
import pandas as pd
import numpy as np
from copy import copy
from scipy.stats.mstats import rankdata
from scipy.stats import sem

In [None]:
def check_fragment(myword):
    """Validate a sentence fragment and return it unchanged.

    A fragment must have no leading or trailing whitespace and must not end
    with a period. An AssertionError is raised when either rule is broken.
    The last character is inspected directly, so an empty string raises
    IndexError.
    """
    has_outer_space = myword != myword.strip()
    assert not has_outer_space  # no whitespace at the beginning or end
    ends_with_period = myword[-1] == '.'
    assert not ends_with_period  # fragments never end in a period
    return myword

def check_sentence(mysentence):
    """Validate a full sentence and return it unchanged.

    A sentence must have no leading or trailing whitespace and must end with
    a period. An AssertionError is raised when either rule is broken. The
    last character is inspected directly, so an empty string raises
    IndexError.
    """
    trimmed = mysentence.strip()
    assert trimmed == mysentence  # no whitespace at the beginning or end
    final_char = mysentence[-1]
    assert final_char == '.'  # a sentence is terminated by a period
    return mysentence

def check_no_split(tokens):
    """Assert that a word was encoded as a single token.

    tokens : tensor of token ids for one word (should be a 1 x 1 tensor)

    Returns None when the tensor holds at most one element. Otherwise prints
    an error and raises AssertionError, because the tokenizer split the word
    into several pieces.
    """
    # The previous body tested an undefined name (`tkn`) and so always failed
    # with NameError; the argument itself is what must be inspected.
    if tokens.numel() > 1:
        print("Error: word is split in WordPiece")
        raise AssertionError("word is split in WordPiece")
    
def f_extract(X, tkns, mytype, use_special):
    """Reduce the per-token embeddings of one phrase to a single vector.

    Input
     X : embeddings with a leading batch dimension of size 1. Dim 1 is the
         axis that is averaged / indexed (assumed to be the token axis, i.e.
         [1 x ntkns x dim] -- TODO confirm against the model output)
     tkns : tokens of the phrase we are considering (just for printing)
     mytype : either "mean" (take the average over dim 1) or "end" (take the
         last word that is not a period)
     use_special : are we adding special tokens?

    Reads the global `tokenizer` and, if it exists, the global `verbose`.
    Raises AssertionError when the batch size is not 1 or `mytype` is unknown.
    """
    assert X.size(0) == 1
    if mytype == 'mean':
        return torch.mean(X, dim=1)
    if mytype == 'end':
        # Presumably the encoded sentence ends "... word . <special>", which
        # puts the last word at position -3 when a trailing special token exists.
        idx = -3
        if (not use_special) or (tokenizer.__class__.__name__ == 'GPT2Tokenizer'):
            idx += 1  # there is no trailing special symbol, so the word is at -2
        # `verbose` is not defined anywhere in this file; treat a missing
        # flag as "off" instead of failing with NameError.
        if globals().get('verbose', False):
            print("extract embedding for: " + tokenizer.decode(tkns[:, idx]))
        return X[:, idx]
    raise AssertionError("extraction type is undefined")
    
def compute_hidden_cosine(query_phrase, list_phrases, use_special=True, extract='mean'):
    """Compare phrase representations taken from the model's first output.

    query_phrase : string to query
    list_phrases : list of strings to compare against (may be empty)
    use_special : encode with `encode_sentence` (special tokens, sentence
        checks) when True, otherwise with `encode_sentence_plain`
    extract : 'mean' for the mean embedding, 'end' for the embedding of the
        last word (see `f_extract`)

    Prints the comparison phrases sorted from most to least similar and
    returns the cosine similarities as a list in the ORIGINAL order of
    `list_phrases`.

    Relies on the globals `model`, `encode_sentence` and
    `encode_sentence_plain` set up by the model-loading cells.
    """
    # Pick the encoder once instead of repeating the branch for every phrase.
    encode = encode_sentence if use_special else encode_sentence_plain

    def embed(tkn):
        # model(...)[0] is presumably the top-layer hidden states
        # (1 x ntkns x dim); f_extract collapses them to a single vector.
        return f_extract(model(tkn)[0], tkn, extract, use_special)

    query_tkn = encode(query_phrase)
    print("Similarity analysis (cosine) from " + model.__class__.__name__ + ": ")
    print("'" + str(query_phrase) + "'")

    sim_phrases = []  # similarities in the original order
    with torch.no_grad():
        h_query = embed(query_tkn)
        for phrase in list_phrases:
            h_phrase = embed(encode(phrase))
            sim_phrases.append(F.cosine_similarity(h_query, h_phrase, dim=1).item())

    # Report from most to least similar. Iterating over the sorted pairs
    # directly (rather than unpacking zip(*pairs)) also copes with an empty list.
    ranked = sorted(zip(list_phrases, sim_phrases), reverse=True, key=lambda pair: pair[1])
    for phrase, sim in ranked:
        print("  vs. '" + phrase + "' : " + str(round(sim, 2)))
    print("")
    return sim_phrases
    
def predict_mask(sentence, K=3):
    """Print the K most likely fillers for the single mask token in `sentence`.

    sentence : string containing exactly one mask token; it is encoded with
        `encode_sentence`
    K : number of candidate tokens to report

    Prints each candidate with its softmax probability, then the sentence
    with each candidate substituted for the mask. Returns None.
    Raises AssertionError when the sentence does not contain exactly one
    mask token.

    Relies on the globals `tokenizer`, `LM` and `encode_sentence`.
    """
    sentence_tkn = encode_sentence(sentence)
    # Column indices (positions in the sequence) of the mask token.
    mask_token_index = torch.nonzero(sentence_tkn == tokenizer.mask_token_id)[:, 1]
    if mask_token_index.numel() != 1:
        raise AssertionError("wrong number of mask tokens")

    # Inference only: skip autograd bookkeeping, as compute_hidden_cosine does.
    with torch.no_grad():
        token_logits = LM(sentence_tkn)[0]
    mask_token_logits = token_logits[0, mask_token_index, :].flatten()
    mask_token_prob = torch.softmax(mask_token_logits, dim=0)
    _, top_K_tokens = torch.topk(mask_token_logits, K)
    top_K_tokens = top_K_tokens.tolist()

    print('Fill in the blank ("token : score") from ' + LM.__class__.__name__ + ": ")
    print("  " + sentence)
    for my_idx in top_K_tokens:
        score = str(round(mask_token_prob[my_idx].item(), 3))
        print("  " + tokenizer.decode([my_idx]) + " : " + score)
    print("")
    for token in top_K_tokens:
        print(sentence.replace(tokenizer.mask_token, tokenizer.decode([token])))

### Load BERT

In [None]:
# Load the pretrained BERT encoder, its masked-LM head and the tokenizer.
# Nothing is loaded when a model is already present, because this cell and
# the GPT2 cell bind the same global names.
if 'model_class' in globals():
    print('Cannot load BERT model because...')
    print('  A model is already loaded.')
else:
    model_class, LM_class, tokenizer_class = BertModel, BertForMaskedLM, BertTokenizer
    pretrained_weights = 'bert-large-uncased'
    tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
    model = model_class.from_pretrained(pretrained_weights)

    # Input embedding table and the number of rows it holds.
    E = model.get_input_embeddings()
    NE = E.num_embeddings
    LM = LM_class.from_pretrained(pretrained_weights)

    def encode_word(myword):
        # Single fragment, no special tokens.
        return tokenizer.encode(check_fragment(myword), add_special_tokens=False, return_tensors="pt")

    def encode_sentence(sentence):
        # Checked sentence, with the tokenizer's special tokens.
        return tokenizer.encode(check_sentence(sentence), add_special_tokens=True, return_tensors="pt")

    def encode_sentence_plain(sentence):
        # Unchecked text, no special tokens.
        return tokenizer.encode(sentence, add_special_tokens=False, return_tensors="pt")

    def decode_sentence(tkns):
        # Token-id tensor back to a string.
        return tokenizer.decode(torch.squeeze(tkns))

### Or we can load GPT2 (but not both)

In [None]:
# Load the pretrained GPT2 model, its LM head and the tokenizer.
# Nothing is loaded when a model is already present, because this cell and
# the BERT cell bind the same global names.
if 'model_class' in globals():
    print('Cannot load GPT2 model because...')
    print('  A model is already loaded.')
else:
    model_class, LM_class, tokenizer_class = GPT2Model, GPT2LMHeadModel, GPT2Tokenizer
    pretrained_weights = 'gpt2-xl'
    tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
    model = model_class.from_pretrained(pretrained_weights)

    # Input embedding table and the number of rows it holds.
    E = model.get_input_embeddings()
    NE = E.num_embeddings
    LM = LM_class.from_pretrained(pretrained_weights)

    def encode_word(myword):
        # Single fragment, no special tokens; a prefix space is requested.
        return tokenizer.encode(check_fragment(myword), add_special_tokens=False, add_prefix_space=True, return_tensors="pt")

    def encode_sentence(sentence):
        # Checked sentence, special tokens requested; a prefix space is requested.
        return tokenizer.encode(check_sentence(sentence), add_special_tokens=True, add_prefix_space=True, return_tensors="pt")

    def encode_sentence_plain(sentence):
        # Unchecked text, no special tokens; a prefix space is requested.
        return tokenizer.encode(sentence, add_special_tokens=False, add_prefix_space=True, return_tensors="pt")

    def decode_sentence(tkns):
        # Token-id tensor back to a string.
        return tokenizer.decode(torch.squeeze(tkns))

## Some simple sanity checks

In [None]:
# Round-trip check: encode a sentence with the loaded tokenizer, then decode it back.
tkns = encode_sentence('To be or not to be.')
# Left as the final bare expression so the notebook shows the decoded string.
decode_sentence(tkns)

In [None]:
# Only relevant if BERT. GPT2 does not have masks, so the demo is skipped
# there instead of failing inside predict_mask.
if tokenizer.__class__.__name__ == 'GPT2Tokenizer':
    print('Skipping fill-in-the-blank: GPT2 does not have a mask token.')
else:
    predict_mask('To be or not to [MASK].', K=5)

## Gershman phrase similarity

In [None]:
# Read the phrase table; the largest 'set' and 'type' ids give the number of
# phrase sets and of phrase types (the ids are iterated from 1 further down).
df = pd.read_csv('data/gershman_phrases.csv')
nset, ntype = (np.max(df[column]) for column in ('set', 'type'))
type_names = [
    'base',
    'meaning preserve',
    'noun change',
    'preposition change',
    'adjective change',
]
print(df)

In [None]:
def get_phrase(df,idx_set,idx_type):
    """Return the phrase string for one numerical "set" and "type".

    df : DataFrame with 'set', 'type' and 'phrase' columns
    idx_set : value to match in the 'set' column
    idx_type : value to match in the 'type' column

    Raises AssertionError unless exactly one row matches.
    """
    df_sel = df.loc[(df['set']==idx_set) & (df['type']==idx_type)]
    # The previous check, `assert len(df_sel==1)`, measured the length of a
    # boolean frame, so it only rejected an empty selection and let duplicate
    # rows through. Require exactly one match.
    if len(df_sel) != 1:
        raise AssertionError(
            "expected exactly one phrase for set={}, type={}; found {}".format(
                idx_set, idx_type, len(df_sel)))
    return df_sel['phrase'].values[0]

# Build the similarity matrix S: one row per phrase set, one column per phrase
# type, each entry being the cosine similarity between the set's base phrase
# (type 1) and the phrase of that type.
S = np.zeros((nset, ntype))
type_ids = range(1, ntype + 1)
for sid in range(1, nset + 1):
    base = get_phrase(df, sid, 1)
    queries = []
    for tid in type_ids:
        queries.append(get_phrase(df, sid, tid))
    sims = compute_hidden_cosine(base, queries, use_special=False, extract='mean')
    S[sid - 1, :] = np.array(sims)  # set ids start at 1, rows at 0

In [None]:
# Plot the data
# Horizontal bar chart of the mean reversed similarity rank (with standard
# error across sets) for every phrase type except the base phrase.
R = rankdata(S,axis=1) # replace sim with its rank within each set (row): 1 is lowest sim, ntype is highest sim
R = ntype-R # reverse the ranking: 0 for the highest sim, ntype-1 for the lowest
# NOTE(review): column 0 is the base phrase compared with itself; presumably it
# always ranks highest (reversed rank 0), leaving 1..ntype-1 for the other types -- confirm
mean_R = np.mean(R,axis=0)
se_R = sem(R,axis=0)
labels = type_names[1:] # ['meaning preserve','noun change','preposition change','adjective change']
y_pos = [1,4,2,3] # bar position per label, giving the order in the paper: 'meaning preserve','preposition change','adjective change','noun change'
plt.rcdefaults()
fig, ax = plt.subplots(figsize=(5,4))
# [1:] leaves the base column (index 0) out of the plot
ax.barh(y_pos, mean_R[1:], height=.9, xerr=se_R[1:], color='black', align='center')
ax.set_yticks(y_pos)
ax.set_yticklabels(labels)
ax.set_xlabel('Average Rank (model)')
plt.tight_layout()
plt.show()