In [1]:
import torch
import hypertools
%matplotlib widget
import matplotlib.pyplot as plt

# Initialize BERT and some helper functions
BERT_MODEL = "bert-base-cased"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Bad key "text.kerning_factor" on line 4 in
/home/luke/.anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
or from the matplotlib source distribution


In [2]:
# use allennlp's bert wrappers, which handle the wordpiece tokenization nastiness 
# and let us get a single vector per input token. don't worry about the rest of this cell
from transformers import BertTokenizer
from allennlp.data import Vocabulary, Token
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)
indexer = PretrainedTransformerMismatchedIndexer(model_name=BERT_MODEL, namespace="tokens")
vocab = Vocabulary()
for word in tokenizer.vocab.keys():
    vocab.add_token_to_namespace("tokens")
embedder = PretrainedTransformerMismatchedEmbedder(model_name=BERT_MODEL).eval().to(DEVICE)

In [3]:
def encode_sentence(sentence):
    """Given a list of tokens or a string of length m, return a tensor of embeddings 
    with shape (m, N), where N is 768 for bert base"""
    with torch.no_grad():
        if isinstance(sentence, str):
            sentence = sentence.split()
        inputs = indexer.tokens_to_indices([Token(text=token) for token in sentence], vocab)
        for key, value in inputs.items():
            inputs[key] = torch.tensor([value], device=DEVICE)
        tensor = embedder(**inputs)
        return tensor[0]
    
tensor = encode_sentence(['Hello', ',', 'world', '!'])
print(tensor)
print(tensor.shape)

tensor([[ 0.4953, -0.2757,  0.8327,  ..., -0.3281,  0.4333, -0.1131],
        [ 0.4680,  0.3816,  0.5320,  ..., -0.0647, -0.1558, -0.1061],
        [ 0.4916,  0.2297,  0.3994,  ..., -0.3129, -0.3210, -0.3038],
        [ 0.5205,  0.2332,  0.8453,  ...,  0.2023,  0.3757,  0.1231]],
       device='cuda:0')
torch.Size([4, 768])


In [8]:
from nltk.corpus import framenet as fn
from tqdm.notebook import tqdm
import nltk
print(nltk.__file__)
frames = fn.frames()

# see: https://github.com/nltk/nltk/issues/2616
annotations = []
progress = tqdm(total=400000)
try:
    for annotation in fn.annotations():
        annotations.append(annotation)
        progress.update(1)
except RuntimeError as e:
    print("Exception!", e)
finally:
    print(len(annotations))
progress.close()

/home/luke/.anaconda3/lib/python3.7/site-packages/nltk/__init__.py


HBox(children=(IntProgress(value=0, max=400000), HTML(value='')))

229692



In [9]:
#print(frame)
#lexical_unit = frame.lexUnit['abandon.v']
#print(lexical_unit)
#print([x for x in dir(fn) if not x.startswith("_")])

def slice2index(text, slice_):
    begin, end = slice_
    target_word = text[begin:end]
    tokens = text.split()
    token_index = 0
    i = 0
    while i < len(text):
        token = tokens[token_index]
        if i == begin and token.lower() == target_word.lower():
            if end - begin != len(token):
                print(f"ERROR: word '{token}' is {len(token)} chars long, but FrameNet said it was {j-1} chars long")
            return token_index
        else:
            i += len(token) + 1
            token_index += 1
    raise Exception(f"Word not found!\nLooked for {target_word}, but couldn't find it in {tokens}")
        
        
def target_vector(annotation):
    text, targets = annotation.text, annotation.Target
    tokens = text.split()
    target_indexes = [slice2index(text, target_slice) for target_slice in targets]
    
    tensor = encode_sentence(tokens)
    # make sure allennlp bert kept its promise
    assert len(tokens) == tensor.shape[0]
    
    # for multi-word targets, take their average
    target_vector = torch.mean(torch.stack([tensor[i] for i in target_indexes][:1]), dim=0)
    return target_vector
    

In [10]:
import os, pickle
def get_vectors(annotations):
    vectors = []
    
    path = 'vecs.pkl'
    if os.path.isfile(path):
        print("Loading from pickle")
        with open(path, 'rb') as f:
            vectors = pickle.load(f)
        
    def handle_annotation(annotation):
        try:
            if "Target" not in annotation:
                print(f"{i}: Skipping targetless ann")
                return
            if "frame" not in annotation:
                print(f"{i}: Skipping frameless ann")
                return
            vectors.append(target_vector(annotation))
        except Exception as e:
            print(type(annotation))
            print(annotation.keys())
            print("Uh oh!")
            print(e)
            vectors.append(None)
            
    i = len(vectors)
    any_new = False
    progress = tqdm(total=len(annotations))
    progress.update(len(vectors))
    while i < len(annotations):
        handle_annotation(annotations[i])
        progress.update(1)
        any_new = True
        i += 1
    progress.close()
    
    if any_new:
        with open(path, 'wb') as f:
            print("Writing pickle")
            pickle.dump(vectors, f)
    
    return vectors

vectors = get_vectors(annotations)
len(vectors), len(annotations)

HBox(children=(IntProgress(value=0, max=229692), HTML(value='')))

2560: Skipping targetless ann
<class 'nltk.corpus.reader.framenet.AttrDict'>
dict_keys(['cDate', 'status', 'ID', '_type', 'layer', '_ascii', 'FE', 'GF', 'PT', 'Sent', 'Other', 'Target', 'Verb', 'sent', 'text', 'LU', 'frame'])
Uh oh!
list index out of range
<class 'nltk.corpus.reader.framenet.AttrDict'>
dict_keys(['cDate', 'status', 'ID', '_type', 'layer', '_ascii', 'Target', 'FE', 'GF', 'PT', 'Other', 'Sent', 'Noun', 'sent', 'text', 'LU', 'frame'])
Uh oh!
Word not found!
Looked for Mr, but couldn't find it in ['Mr.', 'Gonzalez', 'also', 'has', 'split', 'with', 'the', 'left', 'in', 'reaffirming', 'Spain', "'s", 'NATO', 'commitment', 'and', 'in', 'renewing', 'a', 'defense', 'treaty', 'with', 'the', 'U.S', '.']
<class 'nltk.corpus.reader.framenet.AttrDict'>
dict_keys(['cDate', 'status', 'ID', '_type', 'layer', '_ascii', 'Target', 'FE', 'GF', 'PT', 'Other', 'Sent', 'Noun', 'sent', 'text', 'LU', 'frame'])
Uh oh!
Word not found!
Looked for Mr, but couldn't find it in ['Mr.', 'Gonzalez', 'is'

(229691, 229692)

In [11]:
from collections import defaultdict
frame2pairs = defaultdict(list)
parent2children = defaultdict(list)

skipped = 0
for i, (vector, annotation) in tqdm(enumerate(zip(vectors, annotations)), total=len(annotations)):
    if vector is None:
        print(f"Skipping none vector at {i}", end='\r')
        skipped += 1
        continue
        
    frame = annotation.frame
    frame_name = frame.name
    frame2pairs[frame_name].append((vector.to('cpu'), annotation))
    for relation in frame.frameRelations:
        if relation.type.name == 'Inheritance':
            child_name = relation.Child.name
            parent_name = relation.Parent.name
            if child_name not in parent2children[parent_name]:
                parent2children[parent_name].append(child_name)
print(f"Skipped {skipped}")

HBox(children=(IntProgress(value=0, max=229692), HTML(value='')))

Skipping none vector at 227746
Skipped 206


In [58]:
import re
import numpy as np
from scipy.stats import normaltest
import pingouin as pg

roots = [k for k in parent2children.keys() 
         if not any(k in children for children in parent2children.values())]
def tree_print(node, degree=0):
    print(("   " * degree) + node)
    if node in parent2children:
        for child in parent2children[node]:
            tree_print(child, degree + 1)
            
def filter_by_pattern(pairs, pattern):
    return [pair for pair in pairs if re.match(pattern, pair[1].LU.name)]
            
    
#for root in roots:
#    tree_print(root)
#    print()
#tree_print('Relation')

def f(frames):
    hue = []
    labels = []
    vectors = []
    for frame in frames:
        fvecs, anns = zip(*frame2pairs[frame])
        print(frame, len(fvecs))
        vectors += fvecs
        hue += [frame for i in range(len(fvecs))]
        if len(fvecs) > 0:
            labels += [frame]
            labels += [None for i in range(len(fvecs)-1)]
                
    #hue = [x.LU.name for x in annotations]
    vectors = torch.stack(vectors).numpy()
    print(vectors.shape)
    print(len(labels))
    print(len(hue))
    print(dir(hypertools.plot(np.array(vectors), '.', reduce='TSNE', hue=hue, labels=labels, ndims=3)))

f(['Telling', 'Suasion', 'Warning', 'Reveal_secret', 'Recording', 'Predicting', 'Complaining', 'Affirm_or_deny', 'Statement'])

#Statement
#   Affirm_or_deny
#   Complaining
#   Predicting
#   Reading_aloud
#   Recording
#   Reveal_secret
#   Telling
#      Suasion
#      Warning


Telling 632
Suasion 79
Reveal_secret 700
Recording 55
Predicting 154
Complaining 220
Affirm_or_deny 172
Statement 3696
(5822, 768)
5822
5822


IndexError: index 4741 is out of bounds for axis 0 with size 4741

In [25]:
rs = [r for r in fn.frame_relations() if r.type.name == "Using"]
#for r in rs:
#    print(r)
#for k ,v in parent2children.items():
#    if 'Volubility' in v:
#        print(k)
