# A simpler version

In [1]:
import torch
from spacyfuncs import get_docs
from transformers import logging, AutoTokenizer
logging.set_verbosity_error()

import pandas as pd
from tqdm import tqdm
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)



docs = get_docs([],
                '../data/bookcorpus_0_5000.spacy',
                id_text_tuples=False,
               )

embeds = torch.load('../data/bookcorpus_embeddings_0_5000.pt')
bert_embeds = torch.load('../data/bert_lookup_embeddings.pt')

assert len(docs) == len(embeds)

In [2]:
def nearest_neighbor_lookup(hidden_state, lookup_embeds, topk=3):
    dist = torch.norm(lookup_embeds - hidden_state, dim=1, p=None)
    return dist.topk(topk, largest=False)

In [5]:
def recover_token(hidden_state, lookup_embeds, orig_tok_id):
    knn = nearest_neighbor_lookup(hidden_state, lookup_embeds, topk=1)
    return orig_tok_id == knn.indices

In [3]:
cos_similarity = torch.nn.CosineSimilarity(dim=0, eps=1e-6)

def distance_to_self(hidden_state, lookup_embed):
    return cos_similarity(hidden_state, lookup_embed)

In [4]:
def identify_tokens(hidden_state, lookup_embeds, topk=1):
    '''Calculates the nearest topk neighbors for a token hidden_state.
    hidden_states are the token embedding with dimensionality [layers, 768].
    Returns true_label, distances, nearest_neighbors.
    true_label: the text form of the original token
    distances: list of distances for each nearest_neighbor
    nearest_neighbors: list of the closest token lookup_embeddings in text form
    '''
    true_label = [tokenizer.decode([orig_tok_id])]

    cosine_distances, tok_ids = nearest_neighbor_lookup(hidden_state, lookup_embeds, topk=topk)
    nearest_neighbors = [tokenizer.decode([tok_id]) for tok_id in tok_ids]
    
    return true_label, cosine_distances.numpy(), nearest_neighbors

In [None]:
word_types = {
    'content': ['ADV',  'VERB', 'ADJ', 'NOUN', 'PROPN'],
    'function': ['ADP', 'AUX', 'CCONJ', 'DET', 'INTJ', 'PART', 'PRON', 'SCONJ'],
    'other': ['NUM', 'PUNCT', 'SYM', 'X'],
}

pos_dict = {tag: word_type for word_type, tags in word_types.items() for tag in tags}

In [130]:
def calculate_score(hidden_state, bert_embeds, orig_tok_id):
    cos = distance_to_self(hidden_state, bert_embeds[orig_tok_id]).item()
    nn = recover_token(hidden_state, bert_embeds, orig_tok_id).item()
    return cos, nn

def get_results(docs, embeds, bert_embeds, layer=12, total=5000):
    simple_counts = []
    for doc, example in tqdm(zip(docs[:total], embeds[:total]), total=total):
        spacy_offsets = {t.idx: t.i for t in doc}
        for (orig_tok_id, bert_offset, hidden_states) in example:
            token_ind = spacy_offsets.get(bert_offset.item())
            if token_ind:
                pos = doc[token_ind].pos_
                cos, nn = calculate_score(hidden_states[layer],
                                          bert_embeds,
                                          orig_tok_id,
                                      )
                simple_counts.append((pos, cos, nn))
    
    df = pd.DataFrame(simple_counts, columns=['pos', 'mean_similarity', 'identification_acc']).assign(word_type=lambda x: x.pos.map(pos_dict))

    by_pos = (df.groupby(['word_type', 'pos'])
              .agg(mean_similarity=('mean_similarity', 'mean'),
                   identification_acc=('identification_acc', 'mean'),
                   support=('pos', 'count'))
             )    
    
    by_word_type = (df.groupby(['word_type'])
                    .agg(mean_similarity=('mean_similarity', 'mean'),
                         identification_acc=('identification_acc', 'mean'),
                         support=('pos', 'count'))
                    .unstack()
                    .to_frame('total')
                    .unstack(0)
                    .stack(0)
                    .pipe(lambda x: x.reindex(x.index.rename('pos', level=1)))
                   )
    
    full_total = (df
                  .agg({'mean_similarity': 'mean',
                        'identification_acc': 'mean',
                        'pos': 'count'})
                  .to_frame('total')
                  .rename({'pos': 'support'})
                  .stack(0)
                  .to_frame('total')
                  .unstack(0)
                  .stack(0)
           )
    
    agg_view = pd.concat([by_pos, by_word_type, full_total]).sort_index()
    
    agg_view.to_csv(f'../results/layer_{layer:02}_{total}_examples.csv')
    
    return agg_view

In [None]:
for layer in range(2,11):
    get_results(docs, embeds, bert_embeds, layer=layer, total=5000)

100%|██████████| 5000/5000 [10:17<00:00,  8.09it/s]
100%|██████████| 5000/5000 [10:26<00:00,  7.98it/s]
100%|██████████| 5000/5000 [10:19<00:00,  8.07it/s]
100%|██████████| 5000/5000 [10:23<00:00,  8.02it/s]
100%|██████████| 5000/5000 [10:23<00:00,  8.02it/s]
 40%|███▉      | 1976/5000 [04:07<06:40,  7.56it/s]

In [2]:
import pandas as pd
import glob

In [59]:
df[df.pos == 'total']

Unnamed: 0,block,word_type,pos,mean_similarity,identification_acc,support
5,1,content,total,0.708811,1.0,22585.0
14,1,function,total,0.717982,1.0,24693.0
18,1,other,total,0.704366,1.0,12058.0
19,1,total,total,0.711724,1.0,59336.0
25,2,content,total,0.614527,1.0,22585.0
34,2,function,total,0.619174,0.999514,24693.0
38,2,other,total,0.617165,0.740338,12058.0
39,2,total,total,0.616997,0.94703,59336.0
45,3,content,total,0.537896,0.999734,22585.0
54,3,function,total,0.539596,0.967521,24693.0
