In [1]:
import os
from collections import defaultdict
import pickle as pkl
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorWithPadding,
    set_seed,
)

In [2]:
# configuration
# Every later cell reads these module-level names; edit them here only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Name handed to AutoTokenizer; the extraction loop special-cases 'openai-gpt'.
model_name = "bert-base-uncased"
# Directory the config and model weights are loaded from.
ckpt_dir = "./bert-base-uncased/"
# Directory expected to contain head_mask.pt and neuron_mask.pt.
mask_dir = "./masks/bert-base-uncased/squad/mac/0.5/seed_0/"
# data_file = "./dataset/cardinal.pkl"
# Read line by line; each line is split on a tab into words and tags.
data_file = "./dataset/ptb_pos.txt"
# Pickled mapping: line index -> {word index -> tag} (as consumed by the extraction loop).
sample_file = "./dataset/sample_seed_0.pkl"
# One tag per line; tags are stripped and lower-cased when read.
tag_file = "./dataset/relevant_pos.txt"
# One '<layer>.pkl' file per hidden layer is written here.
output_dir = "./features"
# When True, the probed word is replaced by tokenizer.mask_token before encoding.
RUN_MASKED = False
# Selects AutoModelForQuestionAnswering (True) or AutoModelForSequenceClassification (False).
IS_SQUAD = True

# for reproducibility
set_seed(0) # handles torch, np, random, tf in theory as well

In [3]:
# load the finetuned model and the corresponding tokenizer
# NOTE(review): the recorded output of this cell reports pre-training head
# weights being dropped and task-head weights not being initialized from
# ckpt_dir -- confirm that ckpt_dir really points at the fine-tuned checkpoint.
config = AutoConfig.from_pretrained(ckpt_dir, output_hidden_states=True)
model_generator = AutoModelForQuestionAnswering if IS_SQUAD else AutoModelForSequenceClassification
model = model_generator.from_pretrained(ckpt_dir, config=config)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    use_auth_token=None,
)

# Load masks
# map_location places the tensors on the device the model will run on; without
# it, masks saved from a GPU cannot be loaded on a CPU-only machine and may end
# up on a different device than the model's activations.
head_mask = torch.load(os.path.join(mask_dir, "head_mask.pt"), map_location=device)
neuron_mask = torch.load(os.path.join(mask_dir, "neuron_mask.pt"), map_location=device)

Some weights of the model checkpoint at ./bert-base-uncased/ were not used when initializing BertForQuestionAnswering: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at ./bert-base-unca

In [4]:
# helpers for applying neuron mask
def get_layers(model):
    """Return the encoder layer container of ``model``.

    The backbone is the attribute of ``model`` whose name is stored in
    ``model.base_model_prefix``; the result is that backbone's
    ``encoder.layer``.
    """
    return getattr(model, model.base_model_prefix).encoder.layer

def get_ffn2(model, index):
    """Return the ``output`` submodule of the encoder layer at ``index``."""
    return get_layers(model)[index].output

def register_mask(module, mask):
    """Attach a forward pre-hook that multiplies the module's first input by ``mask``.

    The hook expects the module to be called with two positional inputs: the
    first is scaled by ``mask``, the second is passed through unchanged.

    Args:
        module: the ``torch.nn.Module`` to hook.
        mask: value multiplied into the first input. A tensor mask is moved to
            the input's device at call time, so a mask loaded on the CPU still
            works after the model has been moved to a GPU.

    Returns:
        The handle returned by ``register_forward_pre_hook``; call its
        ``remove()`` method to detach the hook.
    """
    def _mask_first_input(_module, inputs):
        first_input = inputs[0]
        # .to() is a no-op when the mask already lives on the right device.
        local_mask = mask.to(first_input.device) if torch.is_tensor(mask) else mask
        return (first_input * local_mask, inputs[1])

    return module.register_forward_pre_hook(_mask_first_input)

def apply_neuron_mask(model, neuron_mask):
    """Register one masking pre-hook per row of ``neuron_mask``.

    Row ``i`` of ``neuron_mask`` is registered on the module returned by
    ``get_ffn2(model, i)``. The number of rows (``neuron_mask.shape[0]``)
    decides how many layers are hooked.

    Returns:
        The list of hook handles, in layer order.
    """
    layer_count = neuron_mask.shape[0]
    return [
        register_mask(get_ffn2(model, idx), neuron_mask[idx])
        for idx in range(layer_count)
    ]

def remove_neuron_mask(handles):
    """Detach every hook in ``handles`` by calling its ``remove()`` method."""
    for hook_handle in handles:
        hook_handle.remove()


In [5]:
# Prepare model and apply neuron mask
model = model.to(device)
model.eval()
# Re-running this cell must not stack a second set of masking hooks on the
# same model: detach the handles left by a previous run before registering.
if "handles" in globals():
    remove_neuron_mask(handles)
handles = apply_neuron_mask(model, neuron_mask)

In [6]:
# Remove neuron mask
# remove_neuron_mask(handles)

In [7]:
# manifold_vectors[layer][tag] starts as None for every hidden layer
# (1..num_hidden_layers) and every tag listed in tag_file.
manifold_vectors = defaultdict(dict)
# Read with an explicit encoding, matching how data_file is opened.
with open(tag_file, encoding='utf-8') as f:
    # Tags are stripped and lower-cased; blank lines are ignored so that an
    # empty string never becomes a tag.
    relevant_tags = [line.strip().lower() for line in f if line.strip()]

for layer in range(1, config.num_hidden_layers + 1):
    for tag in relevant_tags:
        manifold_vectors[layer][tag] = None

In [8]:
# Unpickling can execute arbitrary code: only load a sample_file you produced
# yourself. The file is opened read-only and closed as soon as it is loaded.
with open(sample_file, 'rb') as sample_fh:
    line_word_tag_map = pkl.load(sample_fh)

In [9]:
# The head mask must live on the same device as the model inputs; move it once.
head_mask_on_device = head_mask.to(device)


def _encode_words(words):
    """Tokenize ``words`` and map every input position back to its source word.

    Returns:
        ``(input_ids, token_word_idx)`` where ``input_ids`` is a ``(1, seq_len)``
        long tensor on ``device`` and ``token_word_idx`` is a NumPy array of
        length ``seq_len`` giving, for each position, the index of the word
        the token came from. Unless ``model_name`` is 'openai-gpt', the first
        and last positions are reserved for special tokens and carry the
        sentinels ``-1`` and ``len(words)``.

    Raises:
        ValueError: if the encoded sequence and the word-index map differ in
            length, i.e. the special-token assumption does not hold.
    """
    has_boundary_tokens = model_name != 'openai-gpt'
    token_word_idx = [-1] if has_boundary_tokens else []

    # tokenization - assign the same id for all sub words of a same word
    word_tokens = []
    for split_id, split_word in enumerate(words):
        tokens = tokenizer.tokenize(split_word)
        word_tokens.extend(tokens)
        token_word_idx.extend([split_id] * len(tokens))

    if has_boundary_tokens:
        token_word_idx.append(len(words))

    # Convert the word pieces to ids directly. Feeding them back through
    # encode(..., is_split_into_words=True) tokenizes every piece a second
    # time, which can change the sequence length and break the alignment
    # with token_word_idx.
    token_ids = tokenizer.build_inputs_with_special_tokens(
        tokenizer.convert_tokens_to_ids(word_tokens)
    )
    if len(token_ids) != len(token_word_idx):
        raise ValueError(
            "token/word alignment mismatch: %d input ids vs %d word indices"
            % (len(token_ids), len(token_word_idx))
        )

    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    return input_ids, np.array(token_word_idx)


def _hidden_states_for(words):
    """Run the model on ``words``; return ``(hidden_states, token_word_idx)``."""
    input_ids, token_word_idx = _encode_words(words)
    with torch.no_grad():
        outputs = model(input_ids, head_mask=head_mask_on_device)
    # The last output element is read as the per-layer hidden states requested
    # through output_hidden_states in the config; presumably index 0 is the
    # embedding output, which is why layers are read starting from 1.
    return outputs[-1], token_word_idx


with open(data_file, encoding='utf-8') as dfile:
    for line_idx, line in enumerate(dfile):
        if line_idx not in line_word_tag_map:
            continue

        words, _tags = line.strip().split('\t')
        word_list = words.split()
        sampled_words = line_word_tag_map[line_idx]

        # Without masking every sampled word of this line sees the same input,
        # so one forward pass is shared by all of them.
        unmasked_result = None

        for word_idx in sampled_words:
            tag = sampled_words[word_idx].lower()

            if RUN_MASKED:
                # Mask a copy: only the probed word is replaced, and masks do
                # not accumulate across the sampled words of the same line.
                masked_words = list(word_list)
                masked_words[word_idx] = tokenizer.mask_token
                hidden_states, token_word_idx = _hidden_states_for(masked_words)
            else:
                if unmasked_result is None:
                    unmasked_result = _hidden_states_for(word_list)
                hidden_states, token_word_idx = unmasked_result

            # Positions of all sub-word tokens belonging to the probed word.
            vector_idcs = np.flatnonzero(token_word_idx == word_idx)
            if vector_idcs.size == 0:
                # The word produced no tokens; averaging nothing would store NaNs.
                continue

            for layer in range(1, config.num_hidden_layers + 1):
                layer_output = hidden_states[layer][0]
                # Mean over the word's sub-word tokens, stored as a column vector.
                token_vector = layer_output[vector_idcs].mean(0).cpu().reshape(-1, 1).numpy()
                if manifold_vectors[layer][tag] is None:
                    manifold_vectors[layer][tag] = token_vector
                else:
                    manifold_vectors[layer][tag] = np.hstack((manifold_vectors[layer][tag], token_vector))


In [10]:
# save embedding vectors
# One pickle per layer, holding the list of per-tag entries of that layer in
# the insertion order of manifold_vectors[layer].
os.makedirs(output_dir, exist_ok=True)  # the target directory may not exist yet
for layer in range(1, config.num_hidden_layers + 1):
    layer_path = os.path.join(output_dir, str(layer) + '.pkl')
    # The context manager flushes and closes the file before the next layer.
    with open(layer_path, 'wb') as out_fh:
        pkl.dump(list(manifold_vectors[layer].values()), out_fh)