In [1]:
from bssp.common.reading import read_dataset_cached, indexer_for_embedder, embedder_for_embedding
from bssp.common.embedder_model import EmbedderModel, EmbedderModelPredictor, EmbedderDatasetReader
from allennlp.data import Token

def activate_bert_layers(embedder, layers):
    """
    The Embedder has params deep inside that produce a scalar mix of BERT layers via a softmax
    followed by a dot product. Activate the ones specified in `layers` and deactivate the rest
    """
    # whew!
    scalar_mix = embedder.token_embedder_tokens._matched_embedder._scalar_mix.scalar_parameters

    with torch.no_grad():
        for i, param in enumerate(scalar_mix):
            param.requires_grad = False
            param.fill_(1e9 if i in layers else -1e9)
        
indexer = indexer_for_embedder('bert-base-cased')
vocab, embedder = embedder_for_embedding('bert-base-cased')
reader = EmbedderDatasetReader({"tokens": indexer})
model = EmbedderModel(vocab, embedder).to('cuda').eval()
predictor = EmbedderModelPredictor(model, reader)



In [2]:
import numpy as np
import torch
import torch.nn.functional as F

def process_sent_allennlp(text, pos, ref_layer):
    embs = []
    for i in range(0, 12):
        activate_bert_layers(embedder, [i])
        with torch.no_grad():
            res = predictor.predict(text.split())
        if i == 0:
            print([vocab.get_token_from_index(t, "tokens") for t in res['token_ids']])
        embeddings = torch.tensor(res['embeddings'])
        embs.append(embeddings[pos])

    first_emb = embs[ref_layer]
    first_emb.unsqueeze_(0)
    for i, emb in enumerate(embs):
        if i != ref_layer:
            emb.unsqueeze_(0)
        print(f"Layer {ref_layer+1} and layer {i+1}:", F.cosine_similarity(first_emb, emb).item())

process_sent_allennlp("Luke 's the one who wrote this sentence !", 3, 0)

['[CLS]', 'Luke', "'", 's', 'the', 'one', 'who', 'wrote', 'this', 'sentence', '!', '[SEP]']
Layer 1 and layer 1: 1.0
Layer 1 and layer 2: 0.8992964625358582
Layer 1 and layer 3: 0.794964611530304
Layer 1 and layer 4: 0.6988664865493774
Layer 1 and layer 5: 0.5934978723526001
Layer 1 and layer 6: 0.5190457701683044
Layer 1 and layer 7: 0.44594624638557434
Layer 1 and layer 8: 0.40702757239341736
Layer 1 and layer 9: 0.35980919003486633
Layer 1 and layer 10: 0.34618818759918213
Layer 1 and layer 11: 0.3302433490753174
Layer 1 and layer 12: 0.27428197860717773


In [3]:
# Do it with transformers now
from transformers import BertTokenizer, BertModel, BertConfig

# GPU available?
t_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
t_config = BertConfig.from_pretrained("bert-base-cased", output_hidden_states=True)
t_model = BertModel.from_pretrained('bert-base-cased', config=t_config).to('cuda:0').eval()

In [4]:
def process_sent_transformers(text, pos, ref_layer):
    marked_text = '[CLS] ' + text + ' [SEP]'
    tokenized_text = t_tokenizer.tokenize(marked_text)
    indexed_tokens = t_tokenizer.convert_tokens_to_ids(tokenized_text) 
    print(t_tokenizer.convert_ids_to_tokens(indexed_tokens))
    segments_ids = [1] * len(tokenized_text) 
    
    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
    segments_tensor = torch.tensor([segments_ids]).to('cuda')
    
    with torch.no_grad():
        outputs = t_model(tokens_tensor, segments_tensor, return_dict=True)
    encoded_layers = outputs['hidden_states']
        
    token_embeddings = torch.stack(encoded_layers, dim=0)
    #Remove dimension 1, the "batches".
    token_embeddings = torch.squeeze(token_embeddings, dim=1)
    
    first_emb = token_embeddings[ref_layer, pos, :]
    first_emb.unsqueeze_(0)
    for i, emb in enumerate(token_embeddings[:12, pos, :]):
        emb.unsqueeze_(0)
        print(f"Layer {ref_layer+1} and layer {i+1}:", F.cosine_similarity(first_emb, emb).item())

process_sent_allennlp("Luke is the one who wrote this !", pos=0, ref_layer=0)
print()
process_sent_transformers("Luke is the one who wrote this !", pos=0, ref_layer=0)
print()
print()
process_sent_allennlp("Bah !", pos=1, ref_layer=0)
print()
process_sent_transformers("Bah !", pos=1, ref_layer=0)

# scalar mix matters?! and we still don't get the kinds of differences we get with raw transformers

['[CLS]', 'Luke', 'is', 'the', 'one', 'who', 'wrote', 'this', '!', '[SEP]']
Layer 1 and layer 1: 1.0
Layer 1 and layer 2: 0.9288332462310791
Layer 1 and layer 3: 0.8593798279762268
Layer 1 and layer 4: 0.7701111435890198
Layer 1 and layer 5: 0.7190108299255371
Layer 1 and layer 6: 0.6400652527809143
Layer 1 and layer 7: 0.6084674596786499
Layer 1 and layer 8: 0.5701310634613037
Layer 1 and layer 9: 0.4878999888896942
Layer 1 and layer 10: 0.4057501256465912
Layer 1 and layer 11: 0.3372558355331421
Layer 1 and layer 12: 0.3738574981689453

['[CLS]', 'Luke', 'is', 'the', 'one', 'who', 'wrote', 'this', '!', '[SEP]']
Layer 1 and layer 1: 1.0
Layer 1 and layer 2: 0.8729580044746399
Layer 1 and layer 3: 0.7460393905639648
Layer 1 and layer 4: 0.6788817048072815
Layer 1 and layer 5: 0.6552401185035706
Layer 1 and layer 6: 0.6503338813781738
Layer 1 and layer 7: 0.6055536866188049
Layer 1 and layer 8: 0.5832138061523438
Layer 1 and layer 9: 0.5874451398849487
Layer 1 and layer 10: 0.5963452458