In [1]:
import sys
sys.executable

'/Users/chantal/Desktop/systematic_review/abstract_env/bin/python'

In [14]:
import torch
from transformers import *
import logging
import matplotlib.pyplot as plt
import os
import re
from pprint import pprint
import pandas as pd
%matplotlib inline

### In this notebook, we're going to be extracting embeddings from a trained BERT model and visualizing them in different spaces

#### 1. Load text to be visualized

Here, we're going to use some keywords we automatically pulled out earlier for each topic, rather than visualizing random abstracts in full. 

NOTE: we can look at two different spaces: 
- where all systematic review topics are included --> this will show us whether we can build a generalizable model
- where one topic is visualized with a focus on the lable --> this will show us how separable and informative the embeddings are per label

Going to focus on the second bullet point. Also, as Delvin mentioned "[the second label would show] whether there are similarities in the representations learned by the model for each topic, and if it’s generalizable then those representations would be similar regardless of topic"

In [194]:
# read in each dataset into a dictionary
reviews = {}

# assuming naming follows 'type' + '_keywords.csv' structure 
for f in os.listdir('../data/keywords/'):
    if not f.startswith('.'):
        key = re.split(r'_', f)
        reviews[key[0]+'_'+key[1]] = f
        
PATH = os.path.abspath('../data/keywords/')

for key, dataset in reviews.items():
    reviews[key] = pd.read_csv(os.path.join(PATH, dataset), names='0', encoding='latin1')
    reviews[key] = sorted(set(str(reviews[key]['0']).split())) # removes duplicates 
    reviews[key] = [x for x in reviews[key] if x not in ['0,', 'Name:']]
    reviews[key] = ' '.join(reviews[key])

In [195]:
pprint(reviews)

{'ADIPP_0': 'acid adolescents adults aid alcohol also analysis anemia assess '
            'associate association attitude attitudes awareness base behavior '
            'birth blood bmi body boys breast breastfeed cancer care case '
            'cause change child children ci clinical collect common community '
            'compare concentrations conclusion conduct consumption control '
            'countries cross day days deficiency demographic determine develop '
            'development diabetes diagnosis diet dietary disease drink drug '
            'dtype: early eat education effect evidence experience exposure '
            'family feed female first focus follow food gestational girls '
            'growth high higher hiv hospital household human identify impact '
            'improve include income index infant infants infection information '
            'intake intervention interventions interview iodine iron kg '
            'knowledge life low mass maternal may mean measur

#### 2. Tokenize using BERT, obtain mappings of words to ids 

In [229]:
tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = BertForSequenceClassification.from_pretrained('allenai/scibert_scivocab_uncased', 
                                                      output_hidden_states=True)

In [343]:
# encode keywords 
indexed_tokens = {}
segment_ids = {}

for key in reviews:
    indexed_tokens[key] = torch.Tensor(tokenizer.encode(reviews[key], add_special_tokens=True)).long().unsqueeze(0)
    segment_ids[key] = torch.Tensor([1]*len(indexed_tokens[key])).long().unsqueeze(0)

ADIPP_1
ADIPP_0
VitaminD_0
VitaminD_1
NCDS_1
Rehab_1
Rehab_0
NCDS_0
Washing_1
Scaling_0
Washing_0
Scaling_1


In [231]:
indexed_tokens['Scaling_0'].shape

torch.Size([1, 164])

In [344]:
# map ids to tokens 
id2token = {}
for key in indexed_tokens:
    for token in indexed_tokens[key]:
        id2token[key] = tokenizer.convert_ids_to_tokens(token)

ADIPP_1
ADIPP_0
VitaminD_0
VitaminD_1
NCDS_1
Rehab_1
Rehab_0
NCDS_0
Washing_1
Scaling_0
Washing_0
Scaling_1


#### 3. Run through BERT

In [239]:
# original saved file with DataParallel
state_dict = torch.load('/Users/chantal/Desktop/systematic_review/abstract_tool/_model/scibert_Scaling_data_3.pt',
                       map_location=torch.device('cpu'))

# create new OrderedDict that does not contain `module.`
from collections import OrderedDict

new_state_dict = OrderedDict()

for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    
# load params
model.load_state_dict(new_state_dict) 
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [240]:
# testing with Scaling TODO: come back and do the rest
with torch.no_grad():
    _, encoded_layers_0 = model(indexed_tokens['Scaling_0'])
    _, encoded_layers_1 = model(indexed_tokens['Scaling_1'])

#### 4. Obtain hidden layers, stack them together

In [241]:
print ("Number of layers:", len(encoded_layers_0))
layer_i = 0

print ("Number of batches:", len(encoded_layers_0[layer_i]))
batch_i = 0

print ("Number of tokens:", len(encoded_layers_0[layer_i][batch_i]))
token_i = 0

print ("Number of hidden units:", len(encoded_layers_0[layer_i][batch_i][token_i]))

Number of layers: 13
Number of batches: 1
Number of tokens: 164
Number of hidden units: 768


In [359]:
token_embeddings = torch.stack(encoded_layers_0, dim=0)

# remove batch dimension
token_embeddings = torch.squeeze(token_embeddings, dim=1)

# [# tokens, # layers, # features]
token_embeddings = token_embeddings.permute(1,0,2)

# 13 --> includes the FCN layer I think? 
token_embeddings.size()

torch.Size([164, 13, 768])

#### 5. Get the word embeddings by concatenating the last four layers but you can also combine them in different ways (e.g., summing)

In [366]:
token_vecs_cat = []

for token in token_embeddings: 
    cat_vec = torch.cat((token[-1], token[-2], token[-3], token[-4]), dim=0)
    token_vecs_cat.append(cat_vec)

print (len(token_vecs_cat), len(token_vecs_cat[0]))

164 3072


In [365]:
token_vecs_sum = []

for token in token_embeddings:
    sum_vec = torch.sum(token[-4:], dim=0)
    token_vecs_sum.append(sum_vec)

print (len(token_vecs_sum), len(token_vecs_sum[0]))

164 768


#### 6. Project embeddings into space (e.g., using UMAP) 
For now, I'm working with https://projector.tensorflow.org

In [364]:
import numpy as np
tab_sep_sum = []

for i in range(len(token_vecs_sum)):
    string = [''.join(str(x)) for x in np.array(token_vecs_sum[0])]
    tab_sep_sum.append('\t'.join(string))

In [363]:
# vectors
vecs = '\n'.join(tab_sep_sum)
# corresponding words
meta = '\n'.join(id2token['Scaling_0'])

with open('vecs_.tsv', 'w+') as f:
    f.write(vecs)

with open('meta_.tsv', 'w+') as g:
    g.write(meta)

0.07127169	-1.4114288	-6.2630434	-2.7323105	2.0709026	-5.410299	2.3760433	-3.1147933	6.0796514	-2.6846974	0.753406	-3.2065973	1.8979284	-2.8907254	2.5706213	-0.95273787	-8.461882	-2.8048067	1.0391753	-3.7541094	0.62245136	6.8339267	-5.2886868	-6.3811646	2.4212966	-1.2406908	3.9803438	-4.6702905	-3.9610868	2.0533187	-5.493805	-0.9876606	0.10600209	-6.4870696	-2.7739103	4.996336	-1.8896937	-1.3148447	-5.781551	-0.6073555	0.9049854	-0.79099214	-0.021936297	3.3724434	1.4919679	3.929586	2.039454	-4.3041716	4.0270743	-1.8233944	-1.3912396	-2.5102892	-1.7724342	-2.6087055	-1.7403337	-4.5562897	-0.50869775	-3.3349574	7.8990345	-2.0818908	-2.6080108	-2.7706096	1.8942502	-4.7067347	-2.657241	-1.8728621	5.131479	3.8363566	-0.907114	1.6440752	3.7685847	1.2676383	-2.7673101	1.8149548	4.015999	6.7372513	0.19405621	2.894606	-2.802044	-2.598877	-1.9816799	-1.724443	-6.720811	-3.382965	-4.8891873	-0.35526398	2.279666	3.33086	2.671976	2.3068824	1.661581	-1.1860679	-0.48759222	1.4646453	3.6101782	-2.1543

In [286]:
print(f'{len(tab_sep_sum)}, {len(tab_sep_sum[0])}')

164, 7993
