In [None]:
# NOTE(review): both names are rebound to BertForMaskedLM instances further
# down before they are ever used, so these two loads only cost time/memory.
negativeModel = AutoModel.from_pretrained('bert-base-uncased', output_attentions = True, cache_dir = cache_dir)
positiveModel = AutoModel.from_pretrained('bert-base-uncased', output_attentions = True, cache_dir = cache_dir)

# Split the corpus by label and tokenize each half separately.
# The enumerate index doubles as the label: 0 -> 'negativeModel', 1 -> 'positiveModel'.
data = {}
tokenizedData = {}
for i, modelName in tqdm(enumerate(['negativeModel', 'positiveModel']), total = 2):
    data[modelName] = [sent for sent, targ in zip(sentences, target) if targ == i]
    # `padding='max_length'` is the non-deprecated spelling of
    # `pad_to_max_length=True`: every sequence is padded (or truncated) to
    # exactly `max_length` tokens.
    tokenizedData[modelName] = bertTokenizerFast.batch_encode_plus(
        data[modelName],
        padding='max_length',
        max_length=512,
        truncation=True,
        return_tensors='pt',
    )

from torch.utils.data import Dataset

class DatasetFromTokenized(Dataset):
    """Map-style dataset over the rows of a tokenized batch.

    ``examples`` is the mapping returned by ``batch_encode_plus(...,
    return_tensors='pt')``: one tensor per field (``input_ids``,
    ``attention_mask``, ...), each holding one row per example. Items are the
    ``input_ids`` rows, which is what ``DataCollatorForLanguageModeling``
    consumes.
    """

    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        # len() of the encoding mapping counts its *fields* (input_ids,
        # token_type_ids, attention_mask), not the examples, so count the
        # rows of input_ids instead.
        return len(self.examples["input_ids"])

    def __getitem__(self, i):
        # Return a private copy so a collator that masks tokens in place can
        # never corrupt the stored encoding. Reading input_ids also works for
        # slow tokenizers, unlike the fast-only `Encoding.ids`.
        return self.examples["input_ids"][i].clone()

from transformers import BertForMaskedLM
# Two independent copies of pretrained BERT with a masked-LM head, one per
# sentiment class, each about to be fine-tuned on its own half of the data.
negativeModel = BertForMaskedLM.from_pretrained(
    'bert-base-uncased',
    cache_dir=cache_dir,
)
positiveModel = BertForMaskedLM.from_pretrained(
    'bert-base-uncased',
    cache_dir=cache_dir,
)

# Put both models into training mode before fine-tuning.
negativeModel.train()
positiveModel.train()

from transformers import AdamW
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

optimizer_grouped_parameters = {}
optimizer = {}
no_decay = ['bias', 'LayerNorm.weight']
trainer = {}

# The collator is identical for both models: it masks 15% of each batch's
# tokens for masked-language-model training and builds the matching labels.
data_collator = DataCollatorForLanguageModeling(
    tokenizer = bertTokenizerFast, mlm=True, mlm_probability=0.15
)

for modelName, model in (('positiveModel', positiveModel), ('negativeModel', negativeModel)):

    # AdamW with weight decay on every parameter except biases and LayerNorm
    # weights. Trainer builds its own optimizer rather than using this one, so
    # the same hyper-parameters are also passed via TrainingArguments below;
    # these dicts are kept only for inspection/compatibility.
    optimizer_grouped_parameters[modelName] = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer[modelName] = AdamW(optimizer_grouped_parameters[modelName], lr=1e-5)

    training_args = TrainingArguments(
        output_dir="/vol/scratch/guy/" + modelName,
        overwrite_output_dir=True,
        num_train_epochs=20,
        # Non-deprecated replacement for `per_gpu_train_batch_size`.
        per_device_train_batch_size=4,
        # Without these, Trainer trains with its defaults (lr=5e-5,
        # weight_decay=0.0) and the optimizer configured above is ignored.
        learning_rate=1e-5,
        weight_decay=0.01,
    )

    trainer[modelName] = Trainer(
        model = model,
        args = training_args,
        data_collator = data_collator,
        train_dataset = DatasetFromTokenized(tokenizedData[modelName]),
        prediction_loss_only=True,
    )

    

# Fine-tune each model in turn, in the order they were added to `trainer`.
for t in trainer.values():
    t.train()

def _print_top_k_fill(model, sequence, tokenizer=None, k=5):
    """Print `sequence` once per top-`k` prediction for its first [MASK].

    The model is switched to eval mode so dropout is disabled (the models were
    put in train mode for fine-tuning), and the forward pass runs without
    gradient tracking. The input is moved to the device holding the model's
    parameters instead of being hard-coded to CUDA.

    Args:
        model: a masked-LM model whose first output is the token logits.
        sequence: text containing at least one mask token.
        tokenizer: defaults to the global `bertTokenizerFast`.
        k: number of candidate tokens to print.
    """
    if tokenizer is None:
        tokenizer = bertTokenizerFast
    model.eval()
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(sequence, return_tensors="pt").to(device)
    mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        token_logits = model(input_ids)[0]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Only the first mask position (`[0]`) is reported.
    top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()
    for token in top_k_tokens:
        print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))


sequence = "she is very [MASK]."
_print_top_k_fill(negativeModel, sequence)
_print_top_k_fill(positiveModel, sequence)

# Dissecting BERT

In [None]:
## Dissect BERT

from transformers import BertForMaskedLM

# Fresh pretrained BERT with its masked-LM head; this copy is the one that
# gets dissected (hooks, modified embeddings, duplicated layers) below.
mlmModel = BertForMaskedLM.from_pretrained(
    'bert-base-uncased',
    cache_dir=cache_dir,
    output_attentions=False,
)

### Try to make BERT angry

# The embedding sub-module of the underlying BERT encoder (word + position +
# token-type embeddings as one module).
embeddings = mlmModel.base_model.embeddings

# Presumably the id of "bad" in the bert-base-uncased vocab; verify with the tokenizer.
_bad_token = 2919
# Output of the embedding module for a length-1 sequence holding only
# `_bad_token`, squeezed down to a single vector.
_bad_emb = embeddings(torch.tensor([[_bad_token]], dtype=torch.long)).squeeze()

def _add_bad_emb(module, inp, outp):
    """Forward hook on the embedding module that defers to the global ``f``.

    ``f`` is looked up each time the hook fires, so rebinding it (for
    instance to ``f1`` or ``f2``) swaps the perturbation without re-registering
    the hook. A non-None return value replaces the module's output; None
    leaves it unchanged.
    """
    perturb = f
    return perturb(module, inp, outp)

hookHandle = embeddings.register_forward_hook(_add_bad_emb)

def f1(module, inp, outp):
    """Shift the hooked output by ``_bad_emb`` while preserving its norm.

    The shifted tensor is divided by (norm of shifted / norm of original), so
    its ``torch.norm`` over all elements matches that of ``outp``.
    """
    shifted = outp + _bad_emb
    ratio = torch.norm(shifted) / torch.norm(outp)
    return shifted / ratio


def f2(module, inp, outp, target=None, rate=1):
    """Forward hook: linearly interpolate the output toward a target vector.

    Returns ``(1 - rate) * outp + rate * target``. ``rate=0`` leaves the output
    unchanged and ``rate=1`` (the default) replaces every position with
    ``target``.

    Args:
        module, inp: unused; part of the forward-hook signature.
        outp: output of the hooked module.
        target: vector to move toward; defaults to the global ``_bad_emb``.
        rate: interpolation weight.
    """
    if target is None:
        target = _bad_emb
    # The previous version used the direction ``_bad_emb - outp`` inside this
    # weighted sum, mixing two interpolation forms: at rate=1 it returned
    # ``_bad_emb - outp`` instead of ``_bad_emb``.
    return (1 - rate) * outp + rate * target

def f(module, inp, outp):
    """Default, no-op perturbation used by ``_add_bad_emb``.

    A forward hook that returns None keeps the module's output as is. Rebind
    ``f`` (e.g. ``f = f1``) to activate a perturbation.
    """
    return None

# Snapshot of the [MASK] token's word embedding before it is modified below.
# Indexing the weight yields a view, so detach and clone it; otherwise the
# "old" embedding would silently follow the in-place update.
old_mask_emb = embeddings.word_embeddings.weight[bertTokenizer.mask_token_id].detach().clone()

bertTokenizer("she")  # result unused; presumably kept to look up the id below interactively
_she_token = 2016  # presumably the id of "she" in bert-base-uncased; verify with the line above

# Add half of the "bad" token's word embedding to every vocabulary row (row
# `_bad_token` itself ends up at 1.5x its old value). The weight is a leaf
# parameter that requires grad, so the in-place edit must run under no_grad;
# otherwise autograd raises a RuntimeError.
with torch.no_grad():
    embeddings.word_embeddings.weight[:] += .5 * embeddings.word_embeddings.weight[_bad_token]
                                                                        .weight[_bad_token])

### Play with Layers

#### Old

bertLayers = list(mlmModel.base_model.encoder.layer.children())

# Keep the layer-handle dict from an earlier run so an already-registered
# layer hook can be found and removed below; starting from a fresh {} every
# time would orphan it and stack another hook per run. (If `hookHandle`
# still holds the embedding hook's handle, it is replaced, as before.)
if not isinstance(globals().get('hookHandle'), dict):
    hookHandle = {}

def fLayer(module, inp, outp):
    """Forward hook that clears ``module.isFirst`` the first time it fires.

    Returns None, so the layer output is left unchanged. A disabled variant
    of the experiment re-applied the layer to its own output:
    ``return (module(module(outp[0])[0])[0],)``.
    """
    if getattr(module, 'isFirst', False):
        module.isFirst = False

layerNum = 4

# Remove the hook a previous run registered on this layer, if any (on the
# first run there is none, and indexing the empty dict used to raise KeyError).
if layerNum in hookHandle:
    hookHandle.pop(layerNum).remove()

bertLayers[layerNum].isFirst = True
hookHandle[layerNum] = bertLayers[layerNum].register_forward_hook(fLayer)

bertLayers = list(mlmModel.base_model.encoder.layer.children())

# Keep the layer-handle dict from an earlier run so an already-registered
# layer hook can be found and removed below; starting from a fresh {} every
# time would orphan it and stack another hook per run. (If `hookHandle`
# still holds the embedding hook's handle, it is replaced, as before.)
if not isinstance(globals().get('hookHandle'), dict):
    hookHandle = {}

def fLayer(module, inp, outp):
    """Forward hook that clears ``module.isFirst`` the first time it fires.

    Returns None, so the layer output is left unchanged. A disabled variant
    of the experiment re-applied the layer to its own output:
    ``return (module(module(outp[0])[0])[0],)``.
    """
    if getattr(module, 'isFirst', False):
        module.isFirst = False

layerNum = 4

# Remove the hook a previous run registered on this layer, if any (on the
# first run there is none, and indexing the empty dict used to raise KeyError).
if layerNum in hookHandle:
    hookHandle.pop(layerNum).remove()

bertLayers[layerNum].isFirst = True
hookHandle[layerNum] = bertLayers[layerNum].register_forward_hook(fLayer)

#### New

# Prepend the layer currently at index 4 so it also runs first. The same
# module object (shared weights, shared forward hooks) now appears twice in
# the encoder stack. Re-running this shifts indices by one each time.
mlmModel.base_model.encoder.layer.insert(0, mlmModel.base_model.encoder.layer[4])
myLayers = mlmModel.base_model.encoder.layer

#### Fill Mask

sequence = "She is like that sometimes. Don't let it bother you. She is [MASK]."
input = bertTokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == bertTokenizer.mask_token_id)[1]

# Inference only: don't build an autograd graph for the forward pass.
with torch.no_grad():
    token_logits = mlmModel(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

# Top-5 candidates for the first [MASK] position only (`[0]`).
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(sequence.replace(bertTokenizer.mask_token, bertTokenizer.decode([token])))
