<a href="https://colab.research.google.com/github/mohamedhany13/Interpretability_testing/blob/main/probing_exp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install transformers memory_profiler

Collecting transformers
  Downloading transformers-4.35.0-py3-none-any.whl (7.9 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.9 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/7.9 MB[0m [31m3.1 MB/s[0m eta [36m0:00:03[0m[2K     [91m━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/7.9 MB[0m [31m22.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━[0m [32m5.8/7.9 MB[0m [31m60.2 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m7.9/7.9 MB[0m [31m74.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m56.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting memory_profiler
  Downloading memory_profiler-0.61.0-py3-none-any.whl (31 kB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloa

In [2]:
from nltk.corpus import ptb
from transformers import BertModel, BertTokenizer, BertTokenizerFast
import time
import torch
import numpy as np
import csv
import pandas as pd
import cProfile, pstats, sys
import tracemalloc
from sys import getsizeof
from memory_profiler import memory_usage, profile
import psutil
import gc
import os
from sklearn.model_selection import train_test_split

In [3]:
def load_POS_dataset(code_test = False):
    '''
    Description:
        This function uses the ptb package from nltk to read the files for the penn treebank dataset (wsj only) and
        process the dataset so that it consists of 1 sentence per example with each word in the sentence having its
        corresponding POS tag. These sentences may then be updated so that each word is associated with it hidden
        representations from a language model.

    Parameters:
        None

    Returns:
        ptb_POS_list : list, a list containing the 1 sentence examples from the penn treebank dataset
    '''
    processing_time_start = time.time()     # used to measure the time taken to process the brown + wsj datasets
    lapped_time_start = time.time()     # used to measure the time taken for processing 1000 sentences
    sent_list = []      # a list of all the sentences (each sentence is a list of words)
    POS_list = []    # a list of the POS tags for the sentences (there's a list of POS tags for each sentence)
    for file_id in ptb.fileids():
        POS_tagged_sents = ptb.tagged_sents(file_id)   # nltk corpus reader containing the sentences with each word and its correspoding POS tag
        for tagged_sent in POS_tagged_sents:
            tagged_sent_arr = np.array(tagged_sent)
            sent = tagged_sent_arr[:, 0].tolist()
            sent = ' '.join(sent)
            POS = tagged_sent_arr[:, 1].tolist()
            sent_list.append(sent)
            POS_list.append(POS)
            x = 0

        lapped_num_sents = 5000
        if len(POS_list) % lapped_num_sents == 0:
            print(f"time elapsed for processing {lapped_num_sents} sentences: {time.time() - lapped_time_start} seconds (current total is {len(POS_list)})")
            lapped_time_start = time.time()

        # for debugging purposes of rest of code
        if code_test == True and len(POS_list) >= 500:
            print(f"total time elapsed for processing: {time.time() - processing_time_start} seconds")
            print(f"size of POS dataset: {len(POS_list)} sentences")

            save_csv(sent_list, "input_sentences")
            save_csv(POS_list, "output_POS")

            return sent_list, POS_list


    print(f"total time elapsed for processing: {time.time() - processing_time_start} seconds")
    print(f"size of POS dataset: {len(POS_list)} sentences")

    save_csv(sent_list, "input_sentences")
    save_csv(POS_list, "output_POS")

    return sent_list, POS_list

In [4]:
def infer_model(tokenizer, model, batch, device):
    '''
    Description:
        Run model on input batch

    Parameters:
        batch : list of str, a list of the input sentences
        model: model used for inference
        tokenizer: tokenizer used to convert sentences into tokens
        device: device on which the inference is made (CPU/GPU)

    Returns:
        output: output of the model when fed with the tokenized input sentences
        tokenized_batch: input batch after tokenization
    '''

    # convert sentences into tokens
    tokenized_batch = tokenizer(batch, padding=True, return_tensors="pt").to(device)
    # run model on tokenized batch
    with torch.no_grad():
        model.eval()
        output = model(**tokenized_batch)

    return output, tokenized_batch

In [5]:
def create_cls_sep_pad_mask(tokenized_batch):
    '''
    Description:
        create a mask to zero out any [cls], [sep], and padding tokens

    Parameters:
        tokenized_batch : batch of sentences after tokenization

    Returns:
        combined_mask: torch tensor, mask that zeros out all unnecessary tokens
    '''
    # use the padding mask to zero out any padded hidden representations
    padding_mask = tokenized_batch.data['attention_mask'].view(-1)
    # for BERT models, zero out the [cls] and [sep] by adding their positions to the padding mask
    cls_mask = (tokenized_batch.data["input_ids"].view(-1) != 101).float()
    sep_mask = (tokenized_batch.data["input_ids"].view(-1) != 102).float()
    # combine the 3 masks
    combined_mask = torch.mul(torch.mul(cls_mask, sep_mask), padding_mask)

    return combined_mask

In [6]:
def process_BERT_hidden_states(model_output, tokenized_batch, device):
    '''
    Description:
        remove all [cls], [sep], and padding vectors from hidden states vector

    Parameters:
        model_output : output of the model for a batch
        tokenized_batch: batch input sentences after tokenization

    Returns:
        processed_hidden_states: processed hidden states vector
    '''

    # get hidden states from all layers of the model for every input sentence in the batch
    hidden_states = model_output.hidden_states  # tuple of tensors. num_tuples = num_layers, each layer is a tensor with shape (batch_size, seq_len, embedding_size)
    # convert hidden states from tuple of tensors to tensor of tensors (place num_layers dimension beside the embedding_size dimension so that they can be later joined into one dimension)
    hidden_states = torch.squeeze(torch.stack(hidden_states, dim=-2))  # shape: (batch_size, seq_len, num_layers, embedding_size)
    # stack the hidden states for the words of all sentences, and stack the number of layers with the embedding size.
    # this means that there is a 1-D vector for every subword
    hidden_states = hidden_states.view(hidden_states.shape[0] * hidden_states.shape[1], -1)  # shape: (batch_size*seq_len, num_layers*embedding_size)
    # create a mask to remove [cls], [sep], and padding tokens
    combined_mask = create_cls_sep_pad_mask(tokenized_batch)  # shape: (batch_size*seq_len)
    # repeat the mask so that it has the same shape as that of the hidden states
    # "repeat" allocates double the memory required for the tensor for some reason, "expand" doesn't allocate any new memory at all (just gets references to the original tensor)
    # adjusted_mask has to be of the same datatype of hidden_states so as not to allocate more memory than necessary in the elementwise multiplication
    adjusted_mask = torch.unsqueeze(combined_mask, -1).expand(hidden_states.shape[0], hidden_states.shape[-1])  # shape: (batch_size*seq_len, num_layers*embedding_size)
    # apply the mask to the hidden representation tensor
    processed_hidden_states = torch.mul(hidden_states, adjusted_mask)  # shape: (batch_size*seq_len, num_layers*embedding_size)
    # remove rows with all zero columns ([cls], [sep], and padded hidden representations)
    processed_hidden_states_no_overhead = processed_hidden_states[processed_hidden_states.sum(dim=-1) != 0]

    # delete all unnecessary tensors
    del hidden_states
    del combined_mask
    del adjusted_mask
    del processed_hidden_states
    gc.collect()
    torch.cuda.empty_cache()

    return processed_hidden_states_no_overhead

In [7]:
def get_last_subword_hidden_state(batch, batch_num, tokenized_batch, processed_hidden_states, tokenizer, device):
    '''
    Description:
        every word in the sentence is split into a subword, each having its own token and hidden state vector.
        For probing, there needs to be only 1 hidden state per output label, so we select the hidden state of the last
        subword of a word as the hidden state for which there is an output label

    Parameters:
        batch : list of str, a list of the input sentences
        batch_num: int, batch number. Used for verification purposes
        tokenized_batch: batch input sentences after tokenization
        processed_hidden_states: hidden states vector after removing [cls], [sep], and padding vectors
        tokenizer: tokenizer used to convert sentences into tokens

    Returns:
        last_subword_hidden_states: hidden states for the last subword of each word in the sentences of the batch
    '''

    # create a mask to remove [cls], [sep], and padding tokens
    combined_mask = create_cls_sep_pad_mask(tokenized_batch)  # shape: (batch_size*seq_len)
    # used for verification purposes, get a list of the subword tokens for all words in all sentences in the batch
    # this is stacked into a 1-D vector, and the [cls], [sep], and padding tokens are removed from it
    processed_subword_tokens = torch.mul(tokenized_batch.data["input_ids"].view(-1), combined_mask)     # shape: (batch_size*seq_len)
    # processed_subword_tokens = processed_subword_tokens[processed_subword_tokens != 0].detach().numpy().tolist()
    processed_subword_tokens = processed_subword_tokens[processed_subword_tokens != 0]

    token_counter = 0
    last_subword_indices = []       # list that stores index of the last subword of the words in every sentence
    for i, input_example in enumerate(batch):
        input_example = input_example.split(" ")
        for word in input_example:
            # get subword tokenization while removing [cls] and [sep] token
            tokenized_word = tokenizer(word).data["input_ids"][1:-1]
            first_subword_index = token_counter  # first subword index in the hidden states tensor
            last_subword_index = token_counter + len(tokenized_word) - 1  # last subword index in the hidden states vector
            # check that first and last indices computed actually correspond to first and last indices in the hidden states vector by checking
            # that the subword tokens are the same
            tokenized_word = torch.tensor(tokenized_word).to(device)
            assert torch.all(torch.eq(processed_subword_tokens[first_subword_index:last_subword_index + 1], tokenized_word)), "extracted indices don't match the indices in the adjusted subword tokens vector"
            # del tokenized_word
            last_subword_indices.append(last_subword_index)
            token_counter = last_subword_index + 1

    # extract the rows that have the last subword hidden states
    last_subword_hidden_states = torch.squeeze(processed_hidden_states[[last_subword_indices], ...])

    # check that the extraction is done correctly (verification is done only on the first batch)
    if batch_num == 0:
        for i, index in enumerate(last_subword_indices):
            assert torch.equal(last_subword_hidden_states[i, ...], processed_hidden_states[index, ...]), "mismatch between extracted rows"

    return last_subword_hidden_states

In [8]:
def extract_batch_hidden_states(tokenizer, model, batch, device):
    model_output, tokenized_batch = infer_model(tokenizer, model, batch, device)

    # get the hidden states vector for each batch example and remove from them all [cls], [sep], and padding vectors
    processed_hidden_states = process_BERT_hidden_states(model_output, tokenized_batch, device)

    # delete all unnecessary tensors
    del model_output
    gc.collect()
    torch.cuda.empty_cache()

    return processed_hidden_states, tokenized_batch

In [9]:
def extract_word_hidden_states(hidden_state_token_assignment, batch, batch_iteration_num, tokenized_batch,
                               processed_hidden_states, tokenizer, device):
    if hidden_state_token_assignment == "last":
        # for each word, assign the hidden state of the last subword as the hidden state for the whole word
        last_subword_hidden_states = get_last_subword_hidden_state(batch, batch_iteration_num, tokenized_batch, processed_hidden_states,
                                                                   tokenizer, device)
    else:
        # for each word, assign the average of the hidden states of the subword as the hidden state for the whole word
        assert False, "hidden states averaging not implemented yet"

    return last_subword_hidden_states

In [10]:
def torch_concat_large(probe_input_dataset, last_subword_hidden_states):
    # with torch.no_grad():
    #     probe_input_dataset_old = probe_input_dataset
    #     probe_input_dataset = torch.concat((probe_input_dataset_old, last_subword_hidden_states), dim=0)
    #
    # del probe_input_dataset_old
    # gc.collect()
    # torch.cuda.empty_cache()

    with torch.no_grad():
        probe_input_dataset = torch.concat((probe_input_dataset, last_subword_hidden_states), dim=0)

    gc.collect()
    torch.cuda.empty_cache()

    return probe_input_dataset

In [11]:
def add_batch_to_probe_dataset(tokenizer, model, batch, batch_iteration_num, device,
                                      hidden_state_token_assignment, probe_input_dataset):
    processed_hidden_states, tokenized_batch = extract_batch_hidden_states(tokenizer, model, batch, device)

    last_subword_hidden_states = extract_word_hidden_states(hidden_state_token_assignment, batch, batch_iteration_num, tokenized_batch,
                               processed_hidden_states, tokenizer, device)

    # probe_input_dataset = torch_concat_large(probe_input_dataset, last_subword_hidden_states)

    start_index = (probe_input_dataset == 0).all(dim=-1).nonzero()[0]
    end_index = start_index + last_subword_hidden_states.shape[0]
    probe_input_dataset[start_index:end_index, ...] = last_subword_hidden_states

    # delete any unnecessary tensor
    del processed_hidden_states
    del last_subword_hidden_states
    gc.collect()
    torch.cuda.empty_cache()


    return probe_input_dataset

In [12]:
def extract_hidden_states(dataset, dataset_num_words, batch_size=32, hidden_state_token_assignment = "last", model_name = 'bert-base-uncased'):
    '''
    Description:
        It gets the hidden representations of the model for each input sentence for the whole datasest

    Parameters:
        dataset : Dict, the dictionary for the input sentence
        model_name: str, name of the model from which the hidden representations are extracted
        hidden_state_token_assignment: str, string that determines whether to assign a word with the representation of the
            last subword in that word or with the average of the hidden representations for all the subwords for that word.
            This variable should either equal "last" or "average"

    Returns:
        updated_dataset: list, list of the examples of the dataset updated with the hidden representations for every word in the example sentences
    '''

    assert hidden_state_token_assignment == "last" or hidden_state_token_assignment == "average", "the variable hidden_state_token_assignment is taking an unknown string value"
    # use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # initialize the Bert model along with its tokenizer
    model = BertModel.from_pretrained(model_name, output_hidden_states=True).to(device)
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    # initialize dataloader for use in inference
    test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    lapped_time_start = time.time()  # used to measure the time taken for processing 100 sentences

    probe_input_example_size = 13*768
    probe_input_dataset = torch.zeros([dataset_num_words,probe_input_example_size], requires_grad=False).to(device)
    # probe_input_dataset = torch.empty([0, probe_input_example_size], requires_grad=False).to(device)

    for i, batch in enumerate(test_dataloader):

        probe_input_dataset = add_batch_to_probe_dataset(tokenizer, model, batch, i, device,
                                      hidden_state_token_assignment, probe_input_dataset)

        # delete any unnecessary tensor
        gc.collect()
        torch.cuda.empty_cache()

        print(f"iteration number {i}, time elapsed for processing {batch_size} sentences: {time.time() - lapped_time_start} seconds (current total is {i*batch_size})")
        lapped_time_start = time.time()

        # if i % 4 == 0 and i != 0:
        #     print(f"time elapsed for processing {i*batch_size} sentences: {time.time() - lapped_time_start} seconds (current total is {i*batch_size})")
        #     lapped_time_start = time.time()

    return probe_input_dataset

In [13]:
def load_POS_dataset_csv():

    sent_df_path = os.getcwd()+os.path.sep+"data"+os.path.sep+"input_sentences.csv"
    sent_df = pd.read_csv(sent_df_path, header=None)

    sent_list = sent_df[0].values.tolist()

    POS_df_path = os.getcwd()+os.path.sep+"data"+os.path.sep+"output_POS.csv"
    POS_df = pd.read_csv(POS_df_path, header=None)

    POS_list = POS_df.values.tolist()
    # remove any nan produced from conversion to dataframe
    cleaned_POS_list = []
    for i, sublist in enumerate(POS_list):
        cleaned_POS_list.append([POS_tag for POS_tag in sublist if isinstance(POS_tag, str)])
        assert len(cleaned_POS_list[i]) == len(sent_list[i].split(" ")), "mismatch between POS_list and sent_list"

    return sent_list, cleaned_POS_list

def load_csv(file_name, index_col=None):

    df_path = os.getcwd()+os.path.sep+"data"+os.path.sep+file_name+".csv"
    df = pd.read_csv(df_path, index_col = index_col, header=None)

    return df

def save_csv(input_list, csv_name, save_inex = False):
    df_path = os.getcwd() + os.path.sep + "data" + os.path.sep + csv_name + ".csv"
    df = pd.DataFrame(input_list)
    df.to_csv(df_path, index=save_inex, header=False)

    return

def save_torch_tensor(input_tensor, tensor_name):
    tensor_path = os.getcwd() + os.path.sep + "data" + os.path.sep + tensor_name + ".pt"
    torch.save(input_tensor, tensor_path)

    return

def get_num_words_in_dataset(dataset):
    '''
    Description:
        This function gets the total number of words and punctuation marks in the dataset

    Parameters:
        dataset: list, a list of the example sentences in the dataset

    Returns:
        num_words : int, the number of words and punctuation marks in the dataset
    '''
    num_words = 0
    for example in dataset:
        sentence_len = len(example["split_text"])
        num_words += sentence_len

    return num_words

def flatten_nested_list(nested_list):
    flattened_list = [item for sublist in nested_list for item in sublist]

    return flattened_list

def split_dataset_quarters(sent_list, POS_list):
    dataset_len = len(sent_list)
    quarter_len = int(dataset_len / 4)
    quarter_sent_split = []
    quarter_POS_split = []
    for i in range(4):
        start_index = i*quarter_len
        if i == 3:
            end_index = len(sent_list)
        else:
            end_index = (i+1)*quarter_len
        quarter_sent_split.append(sent_list[start_index:end_index])
        quarter_POS_split.append(POS_list[start_index:end_index])

    return quarter_sent_split, quarter_POS_split

def create_Probe_dataset(sent_list, POS_list, POS_conversion_dict, batch_size, device, set_name):
    flattened_POS_list = flatten_nested_list(POS_list)
    dataset_num_words = len(flattened_POS_list)
    extraction_start_time = time.time()
    probe_input_data = extract_hidden_states(sent_list, dataset_num_words, batch_size)
    print(f"total time for extracting hidden states: {(time.time() - extraction_start_time) / 60} minutes")

    if (set_name == "train"):
        input_data_file_name = set_name + "_probe_input_" + str(i + 1)
        output_data_file_name = set_name + "_probe_output_" + str(i + 1)
    else:
        input_data_file_name = set_name + "_probe_input"
        output_data_file_name = set_name + "_probe_output"

    save_time = time.time()

    # save the input dataset
    save_torch_tensor(probe_input_data, input_data_file_name)

    # convert POS tags to class values
    POS_class_list = convert_POS_tag_to_class(flattened_POS_list, POS_conversion_dict)
    probe_output_data = torch.tensor(POS_class_list).to(device)
    # save the output dataset
    save_torch_tensor(probe_output_data, output_data_file_name)

    print(f"save time: {(time.time() - save_time) / 60} minutes")

    del probe_input_data
    del probe_output_data
    gc.collect()
    torch.cuda.empty_cache()

    return

def create_POS_tags_dictionary(POS_list):

    flattened_POS_list = flatten_nested_list(POS_list)
    flattened_POS_arr = np.array(flattened_POS_list)
    POS_tags = np.unique(flattened_POS_arr)

    POS_dict = {}
    for i in range(len(POS_tags)):
        POS_dict[POS_tags[i]] = i

    POS_conversion_list = list(POS_dict.items())
    save_csv(POS_conversion_list, "POS_conversion_dictionary")

    return POS_dict

def load_POS_conversion_dictionary():

    POS_conversion_df = load_csv("POS_conversion_dictionary")
    POS_conversion_list = POS_conversion_df.values.tolist()
    POS_conversion_dict = {}
    for i, POS_tag in enumerate(POS_conversion_list):
        POS_conversion_dict[POS_tag[0]] = POS_tag[1]

    return POS_conversion_dict

def convert_POS_tag_to_class(flattened_POS_list, POS_conversion_dict):

    # flattened_POS_list = flatten_nested_list(POS_list)
    POS_class_list = []
    for i, POS_tag in enumerate(flattened_POS_list):
        POS_class = POS_conversion_dict[POS_tag]
        POS_class_list.append(POS_class)


    return POS_class_list

In [14]:
    # load the POS dataset and the POS conversion dictionary
    batch_size = 512
    train_split = 0.9
    validation_split = 0.05
    test_split = 0.05
    assert train_split + validation_split + test_split == 1, "train-validation-test split doesn't add to 100%"

    load_POS_conversion_dictionary()


    start_time = time.time()
    # sent_list, POS_list = load_POS_dataset()       # load the penn treebank dataset
    sent_list, POS_list = load_POS_dataset_csv()
    POS_conversion_dict = load_POS_conversion_dictionary() # dictionary that converts each POS tag into a class value

    # POS_dict = create_POS_tags_dictionary(POS_list)


    train_sent_list, temp_sent_list, \
    train_POS_list, temp_POS_list = train_test_split(sent_list, POS_list, train_size= train_split,
                                                     shuffle= True, random_state= 1)
    validation_sent_list, test_sent_list, \
    validation_POS_list, test_POS_list = train_test_split(temp_sent_list, temp_POS_list,
                                                          train_size= validation_split/(validation_split+test_split),
                                                          shuffle= True, random_state= 1)

    train_sent_quarter_splits, train_POS_quarter_splits= split_dataset_quarters(train_sent_list, train_POS_list)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # extract hidden states for training set
    for i in range(len(train_sent_quarter_splits)):
        create_Probe_dataset(train_sent_quarter_splits[i], train_POS_quarter_splits[i], POS_conversion_dict, batch_size,
                             device, set_name = "train")

    # extract hidden states for validation set
    create_Probe_dataset(validation_sent_list, validation_POS_list, POS_conversion_dict, batch_size,
                         device, set_name="validation")

    # extract hidden states for test set
    create_Probe_dataset(test_sent_list, test_POS_list, POS_conversion_dict, batch_size,
                         device, set_name="test")


  POS_df = pd.read_csv(POS_df_path, header=None)


Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

iteration number 0, time elapsed for processing 512 sentences: 11.984705924987793 seconds (current total is 0)
iteration number 1, time elapsed for processing 512 sentences: 3.9679484367370605 seconds (current total is 512)
iteration number 2, time elapsed for processing 512 sentences: 3.788261651992798 seconds (current total is 1024)
iteration number 3, time elapsed for processing 512 sentences: 3.97697114944458 seconds (current total is 1536)
iteration number 4, time elapsed for processing 512 sentences: 3.721247434616089 seconds (current total is 2048)
iteration number 5, time elapsed for processing 512 sentences: 3.701982259750366 seconds (current total is 2560)
iteration number 6, time elapsed for processing 512 sentences: 3.7655158042907715 seconds (current total is 3072)
iteration number 7, time elapsed for processing 512 sentences: 3.86869215965271 seconds (current total is 3584)
iteration number 8, time elapsed for processing 512 sentences: 3.731041669845581 seconds (current t