In [30]:
from bssp.common.reading import read_dataset_cached, indexer_for_embedder, embedder_for_embedding
from bssp.common.embedder_model import EmbedderModel, EmbedderModelPredictor, EmbedderDatasetReader
from allennlp.data import Token

def activate_bert_layers(embedder, layers):
    """
    The Embedder has params deep inside that produce a scalar mix of BERT layers via a softmax
    followed by a dot product. Activate the ones specified in `layers` and deactivate the rest
    """
    # whew!
    scalar_mix = embedder.token_embedder_tokens._matched_embedder._scalar_mix.scalar_parameters

    for i, param in enumerate(scalar_mix):
        param.requires_grad = False
        param.fill_(1 if i in layers else 0)
        
indexer = indexer_for_embedder('bert-base-cased')
vocab, embedder = embedder_for_embedding('bert-base-cased')
reader = EmbedderDatasetReader({"tokens": indexer})
model = EmbedderModel(vocab, embedder).eval()
predictor = EmbedderModelPredictor(model, reader)

In [42]:
import numpy as np
import torch
import torch.nn.functional as F

embs = []
for i in range(0, 12):
    activate_bert_layers(embedder, [i])
    with torch.no_grad():
        res = predictor.predict("Luke 's the one who wrote this sentence !".split())
    #print([vocab.get_token_from_index(t, "tokens") for t in res['token_ids']])
    embeddings = torch.tensor(res['embeddings'])
    embs.append(embeddings[-3])

first_emb = embs[0]
first_emb.unsqueeze_(0)
for i, emb in enumerate(embs):
    if i != 0:
        emb.unsqueeze_(0)
    print(f"Layer 1 and layer {i+1}:", F.cosine_similarity(first_emb, emb).item())

Layer 1 and layer 1: 1.0
Layer 1 and layer 2: 0.9985597729682922
Layer 1 and layer 3: 0.9972352981567383
Layer 1 and layer 4: 0.9960039258003235
Layer 1 and layer 5: 0.9952501058578491
Layer 1 and layer 6: 0.9940927028656006
Layer 1 and layer 7: 0.99290931224823
Layer 1 and layer 8: 0.9916004538536072
Layer 1 and layer 9: 0.9903111457824707
Layer 1 and layer 10: 0.9888986945152283
Layer 1 and layer 11: 0.9882490634918213
Layer 1 and layer 12: 0.992591142654419
