# Interpretable Machine Learning
## Exercise Sheet 9: Ante-hoc Interpretability Methods
## This exercise sheet covers lecture 9 on Ante-hoc Interpretability Methods
Sophie Langbein (langbein@leibniz-bips.de)<br>
Pegah Golchian (golchian@leibniz-bips.de)
<hr style="border:1.5px solid gray"> </hr>


# Instance-wise Feature Selection with Select and Predict Models

Select-and-predict models consist of a selector, which predicts a binary mask over the input features of each instance, and a predictor, which consumes the masked input to make the final prediction. Unfortunately, binary masking is a non-differentiable operation, which makes it hard to train such models end-to-end. A workaround is the so-called pipeline setup, where selector and predictor are trained independently. Besides a label for each instance to train the predictor, this setup requires groundtruth explanations to train the selector.

In this exercise, the goal is to train and run such a pipeline model on a movie review sentiment classification dataset. The dataset also includes rationale annotations, i.e. highlights that mark the important tokens in the form of a binary mask. We want to use these as groundtruth explanations to train the selector model. Here are two examples:

`text: A gorgeous musical I watched at the palace cinema` <br>
`label: 1 (positive)` <br>
`rationale: [0,1,0,0,0,0,0,0,0]`

`text: What a bad drama .` <br>
`label: 0 (negative)` <br>
`rationale: [0,0,1,0,0]`

Note that the 1s in the rationales highlight the tokens that are important for sentiment classification.
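To make the mask semantics concrete, the small sketch below pairs each token with its rationale value and keeps only the highlighted ones, using the two examples above. The helper name `highlighted_tokens` is hypothetical and not part of `utils`.

```python
def highlighted_tokens(text, rationale):
    """Return the tokens whose rationale value is 1 (hypothetical helper, not from utils)."""
    # Text and rationale must be aligned: exactly one mask entry per token
    assert len(text) == len(rationale)
    return [token for token, keep in zip(text, rationale) if keep == 1]

text = "A gorgeous musical I watched at the palace cinema".split()
print(highlighted_tokens(text, [0, 1, 0, 0, 0, 0, 0, 0, 0]))  # ['gorgeous']

text = "What a bad drama .".split()
print(highlighted_tokens(text, [0, 0, 1, 0, 0]))  # ['bad']
```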


**Selector Model**

The goal of the selector model is to predict these rationale masks. In our case, the selector is a token classifier, which predicts either 0 or 1 for each token in the sequence. For this exercise, we use `DistilBERT` for both the selector and the predictor model. Running the selector consists of the following steps:

1. First, the input text has to be tokenized. That means, the input text tokens are mapped to a sequence of integer input ids. For BERT-style models, a few special tokens are added: Each sequence starts with a `[CLS]` token and ends with a `[SEP]` token. Since all sequences in one batch must be of the same length, a `[PAD]` token is used to pad shorter sequences to the same length as the longest one in the batch. For a batch containing our two examples, the tokenized input could look like this:

    `text: [CLS] A gorgeous musical I watched at the palace cinema [SEP]` <br>
    `input ids: [101, 2, 876, 1098, 5, 66, 78, 134, 867, 555, 102]` <br>
    `attention mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]` <br>
    `rationale: [-100,0,1,0,0,0,0,0,0,0,-100]` <br>

    `text: [CLS] What a bad drama . [SEP] [PAD] [PAD] [PAD] [PAD]` <br>
    `input ids: [101, 44, 2, 11, 43, 3, 102, 0, 0, 0, 0]` <br>
    `attention mask: [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]` <br>
    `rationale: [-100,0,0,1,0,0,-100,-100,-100,-100,-100]` <br>

The special tokens `[CLS]`, `[SEP]`, `[PAD]` are represented by the special input ids `101`, `102`, and `0`, respectively. Also, the rationale is aligned to the tokenized input ids: at the positions of the special tokens it is filled with `-100`, which indicates that the label at that index should be ignored when computing the loss or performance metrics.

2. The model takes the input ids and attention masks as input. A forward pass returns logits of shape (batch size, 2, sequence length). Applying a softmax over the class dimension (dimension 1) turns these logits into the probabilities of the two classes for each token in the input.

3. After obtaining the predictions, the input tokens for which the predicted label is `0` can be masked by replacing the corresponding token ids with the token id of the `[MASK]` token.

4. Finally, the input ids can be converted back to text using a decode function, which inverts the tokenization process. Any `[CLS]`, `[SEP]`, `[PAD]` tokens can be dropped. For the first example, the result could be:

    `masked text: [MASK] gorgeous [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]`
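Steps 2 to 4 can be sketched in plain Python for the first example. This is a minimal illustration, not the `utils` implementation: the logits are made up, `ID_TO_TOKEN` is a hypothetical stand-in for the tokenizer vocabulary that only covers the toy ids above, and `[MASK]` is assumed to have id `103` (as in the uncased BERT vocabulary).

```python
import math

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0
MASK_ID = 103  # assumption: [MASK] id as in the uncased BERT vocabulary
SPECIAL_IDS = {CLS_ID, SEP_ID, PAD_ID}

# Hypothetical id -> token table covering only the toy input ids used above
ID_TO_TOKEN = {2: "a", 876: "gorgeous", 1098: "musical", 5: "i", 66: "watched",
               78: "at", 134: "the", 867: "palace", 555: "cinema", MASK_ID: "[MASK]"}

def softmax_over_classes(logits):
    """logits: [class_0_scores, class_1_scores], each of length seq_len."""
    probs = []
    for z0, z1 in zip(*logits):
        m = max(z0, z1)  # subtract the max for numerical stability
        e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
        probs.append((e0 / (e0 + e1), e1 / (e0 + e1)))
    return probs

def mask_and_decode(input_ids, keep):
    tokens = []
    for token_id, k in zip(input_ids, keep):
        if token_id in SPECIAL_IDS:
            continue  # step 4: drop [CLS], [SEP] and [PAD]
        # step 3: replace unselected tokens with the [MASK] id
        tokens.append(ID_TO_TOKEN[token_id if k == 1 else MASK_ID])
    return " ".join(tokens)

input_ids = [101, 2, 876, 1098, 5, 66, 78, 134, 867, 555, 102]
# Made-up logits of shape (2, seq_len): row 0 scores "mask", row 1 scores "keep"
logits = [[2.0, 1.5, -1.0, 2.0, 1.0, 1.2, 0.8, 1.1, 0.9, 1.3, 2.0],
          [-1.0, -0.5, 2.0, -1.0, -0.2, -0.3, -0.4, -0.1, -0.6, -0.2, -1.0]]
probs = softmax_over_classes(logits)
keep = [int(p1 > p0) for p0, p1 in probs]  # step 2: argmax over the class dimension
print(mask_and_decode(input_ids, keep))
# [MASK] gorgeous [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
```

The predicted label at the `[CLS]` and `[SEP]` positions does not matter, because those tokens are dropped before decoding.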

**Predictor Model**

The predictor model is a sequence classification model, as its goal is to predict a single label for every input sequence. Given a movie review text that was masked by the selector model, it predicts either 1 (`positive`) or 0 (`negative`). The inputs must be tokenized in the same way as for the selector. The model output is now of shape (batch size, 2). Again, these are the logits for each class.
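For the predictor, the class dimension is the last one, so a softmax over each row of the (batch size, 2) logits gives the class probabilities and the row-wise argmax gives the label. A plain-Python sketch with made-up logits for a batch of two reviews (`predict_labels` is a hypothetical helper):

```python
import math

def predict_labels(logits):
    """logits: list of [negative_score, positive_score] rows, i.e. shape (batch size, 2)."""
    labels, probs = [], []
    for row in logits:
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(z - m) for z in row]
        total = sum(exps)
        p = [e / total for e in exps]
        probs.append(p)
        # argmax over the two classes: 0 = negative, 1 = positive
        labels.append(max(range(len(p)), key=p.__getitem__))
    return labels, probs

logits = [[-1.2, 2.3],   # made-up scores favouring "positive"
          [1.8, -0.7]]   # made-up scores favouring "negative"
labels, probs = predict_labels(logits)
print(labels)  # [1, 0]
```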


<hr style="border:1.5px solid gray"> </hr>

## Training the Selector Model

In this exercise, we build on pretrained models. However, the selector needs to be fine-tuned for the new task of extracting explanation masks.

**a)** As a first step, make sure the required packages are installed in your Python environment. Then import the required packages and functions as shown below and familiarize yourself with the imports from `utils` using the documentation below.

In [None]:
# Required packages (install each one, e.g., with `pip install <package>`):
# numpy
# pandas
# pytest
# torch
# transformers
# pytreebank
# tqdm

In [2]:
import sys
import os  # noqa
sys.path.insert(0, "")  # noqa

import torch
torch.manual_seed(0)

from utils.dataset import (
        SentimentRationaleDataset,
        tokenize,
        decode,
        pad_token_id,
        mask_token_id,
        cls_token_id,
        sep_token_id,
        _custom_collate
    )
from classifiers.distilbert import Selector, Predictor
from tqdm.auto import tqdm

In [None]:
"""
Some documentation for imports:


tokenize (function):
    Use this function to tokenize the input text, and optionally align the corresponding rationales.

    Parameters:
        text (List[List[str]]):
            A batch of text as returned by the dataloaders.
        Optional: rationale (List[List[int]]):
            A batch of rationale masks as returned by the dataloaders.
            Required when using the rationales as labels, as they have to remain aligned with the text tokens.
    Returns:
        tokenized_inputs (dict):
            A dict containing the tokenized text (key='input_ids'), an attention mask (key='attention_mask') and the aligned rationales (key='rationales') if passed.
            Rationale labels of tokens that do not belong to the text (special tokens such as [CLS], [SEP], [PAD]) are labeled with -100.


decode (function):
    Use this function to turn tokenized input_ids back to text.

    Parameters:
        input_ids (torch.tensor): A batch of input ids.
    Returns:
        text (str): decoded text.


{pad, mask, cls, sep}_token_id (int):
    The token_id representing the [PAD], [MASK], [CLS], [SEP] token, respectively.

"""

**b)** The goal of this exercise is to complete the `train_selector_model` function, which trains the token classification head of the `DistilBERT` model for one epoch and then validates the model. The function takes the selector model and the training and validation dataloaders as inputs and returns the average training loss and accuracy as well as the average validation loss and accuracy, each averaged over all batches. For more detailed descriptions, read the documentation in the function below.

**i)** In a first step, complete the `train_one_epoch` function, which trains the selector model for one epoch. The training can be summarized in the following steps:

- first, set the selector model to training mode using `selector_model.train()`
- then, iterate over all batches in the training dataloader `dl_train`
- tokenize the text and rationales of each batch using the `tokenize` function from `utils`
- set the gradients to zero using `optimizer.zero_grad()`
- perform a forward pass through the selector model to obtain the model output, using `selector_model(input_ids=tokenized_batch['input_ids'], attention_mask=tokenized_batch['attention_mask'])` with the input ids and attention mask returned by the `tokenize` function
- compute the cross-entropy loss between the model output (the logits) and the aligned rationales using `loss = torch.nn.functional.cross_entropy(output, tokenized_batch['rationales'])`
- perform a backward pass and an optimization step using `loss.backward()` and `optimizer.step()`
- finally, compute the training accuracy for each batch
- for the training accuracy, obtain the predicted (binary) labels from the selector model output using `.argmax(1)` (important: exclude the tokens labeled -100, which marks special tokens such as [CLS], [SEP] and padding)
- obtain the true labels of all tokens != -100 from `tokenized_batch['rationales']`
- then, compute the accuracy by comparing predicted and true labels
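The steps above can be sketched as follows. To keep the example self-contained, it replaces `tokenize` and the DistilBERT `Selector` with hypothetical stand-ins: the batches are already tokenized, and `ToySelector` is a tiny embedding-plus-linear model that, like the selector described above, returns logits of shape (batch size, 2, sequence length). Unlike the closure in the exercise, `train_one_epoch_sketch` takes the model, optimizer and batches as arguments; only the loop structure is meant to carry over.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

class ToySelector(torch.nn.Module):
    """Hypothetical stand-in for the DistilBERT Selector."""
    def __init__(self, vocab_size=200, dim=16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, 2)

    def forward(self, input_ids, attention_mask):
        hidden = self.emb(input_ids) * attention_mask.unsqueeze(-1)
        return self.head(hidden).permute(0, 2, 1)  # (batch, 2, seq_len)

# Two made-up, already tokenized batches in the format described for `tokenize`
batches = [
    {"input_ids": torch.tensor([[101, 44, 2, 11, 43, 3, 102, 0]]),
     "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]]),
     "rationales": torch.tensor([[-100, 0, 0, 1, 0, 0, -100, -100]])},
    {"input_ids": torch.tensor([[101, 2, 150, 7, 102]]),
     "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
     "rationales": torch.tensor([[-100, 0, 1, 0, -100]])},
]

def train_one_epoch_sketch(model, optimizer, batches):
    model.train()
    train_losses, train_accs = [], []
    for tokenized_batch in batches:
        optimizer.zero_grad()
        logits = model(input_ids=tokenized_batch["input_ids"],
                       attention_mask=tokenized_batch["attention_mask"])
        labels = tokenized_batch["rationales"]
        loss = F.cross_entropy(logits, labels)  # positions labeled -100 are ignored
        loss.backward()
        optimizer.step()
        # Accuracy only over text tokens, i.e. positions whose label is not -100
        valid = labels != -100
        preds = logits.argmax(1)
        acc = (preds[valid] == labels[valid]).float().mean()
        train_losses.append(loss.item())
        train_accs.append(acc.item())
    return train_losses, train_accs

model = ToySelector()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
train_losses, train_accs = train_one_epoch_sketch(model, optimizer, batches)
print(train_losses, train_accs)
```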

**ii)** In a second step, complete the `validate` function, which validates the selector model. The validation can be summarized in the following steps:

- first, set the selector model to evaluation mode using `selector_model.eval()`
- then, iterate over all batches in the validation dataloader `dl_val`
- tokenize the text and rationales of each batch using the `tokenize` function from `utils`
- switch off gradient tracking using `with torch.no_grad()`
- compute the model output without gradient tracking using `selector_model(input_ids=tokenized_batch['input_ids'], attention_mask=tokenized_batch['attention_mask'])` with the input ids and attention mask returned by the `tokenize` function
- compute the cross-entropy loss using `loss = torch.nn.functional.cross_entropy(output, tokenized_batch['rationales'])`
- for the validation accuracy, obtain the predicted (binary) labels from the selector model output using `.argmax(1)` (important: exclude the tokens labeled -100, which marks special tokens such as [CLS], [SEP] and padding)
- obtain the true labels of all tokens != -100 from `tokenized_batch['rationales']`
- then, compute the accuracy by comparing predicted and true labels
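Validation reuses the loss and accuracy computation, but inside `torch.no_grad()`. The sketch below uses made-up logits for one instance (no model involved) to check two facts these steps rely on: `torch.nn.functional.cross_entropy` skips positions labeled `-100` (its default `ignore_index`), whereas the accuracy must be restricted to the same positions explicitly.

```python
import torch
import torch.nn.functional as F

# Made-up selector output for one instance: shape (batch=1, classes=2, seq_len=7)
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.3, 1.0, 2.0],
                        [-1.0, -0.5, 0.1, 2.0, -0.2, -0.4, 3.0]]])
# Aligned rationale: the [CLS] and [SEP] positions carry -100
labels = torch.tensor([[-100, 0, 0, 1, 0, 0, -100]])

with torch.no_grad():
    loss = F.cross_entropy(logits, labels)  # ignore_index defaults to -100
    valid = labels != -100
    # The same loss computed explicitly over the text positions only
    loss_valid_only = F.cross_entropy(logits.permute(0, 2, 1)[valid], labels[valid])
    preds = logits.argmax(1)
    acc = (preds[valid] == labels[valid]).float().mean()  # correct: text tokens only
    acc_all = (preds == labels).float().mean()  # wrong: special tokens count as errors

print(loss.item(), loss_valid_only.item(), acc.item(), acc_all.item())
```

Here the two losses agree, `acc` is 1.0, and the unmasked `acc_all` drops to 5/7 only because the two special positions are compared against `-100`.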

**iii)** In a final step, call the `train_one_epoch` and `validate` functions to obtain the training and validation losses and accuracies for one epoch. Compute their respective means and return them as `epoch_train_loss`, `epoch_train_acc`, `epoch_val_loss`, `epoch_val_acc`.

**Hint:** To retain the compatibility with other PyTorch operations, convert the Python lists `train_losses`, `train_accs`, `val_losses`, `val_accs` to PyTorch tensors using `torch.tensor` and then retrieve the mean as a Python float.
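Following the hint, the per-batch lists can be averaged as below; the numbers are made up, and the same two lines apply to `val_losses` and `val_accs`. `.item()` turns the zero-dimensional result back into a Python float.

```python
import torch

train_losses = [0.71, 0.65, 0.60]  # made-up per-batch losses
train_accs = [0.50, 0.75, 1.00]    # made-up per-batch accuracies

epoch_train_loss = torch.tensor(train_losses).mean().item()
epoch_train_acc = torch.tensor(train_accs).mean().item()
print(epoch_train_loss, epoch_train_acc)
```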

**Solution:**

In [None]:
def train_selector_model(selector_model, dl_train, dl_val):
    """
    Trains the given selector model for one epoch, then validates the model.
    Essentially, the goal of the selector model is to predict a mask such that only the important tokens are revealed.
    For example, for the positive movie review
        `A gorgeous movie .`
    the prediction could be `0, 1, 0, 0`.
    For each token in the input sequence, the model predicts 0 if the token should be masked and 1 if the token should be revealed.
    In this exercise, we train the selector in a supervised manner, using annotated rationale data (in the form of binary masks).

    The dataloaders for the selector return batches in the form of dicts, with the following structure:
        'text': List[List[str]]:
            A batch of movie review text.
            Each review is a List of tokens.
        'rationale': List[List[int]]:
            A batch of rationale masks.
            Each rationale is a List representing a binary mask over tokens (length = num of tokens in text).
        'label': List[int]: 
            A batch of labels, either 0 (negative) or 1 (positive).
            Not relevant for training the selector, as here the rationale masks are used as groundtruth.

    Parameters:
        selector_model (Selector): 
            A token classification model based on DistilBERT.
            For each token in the input sequence, the model predicts whether it should be masked (0) or not (1).
            The selector_model is also a torch.nn.Module, so you can call its forward method as
                selector_model(input_ids, attention_mask)
            Both of these inputs can be created by applying the `tokenize` function on a batch returned by the dataloaders.

        dl_train (torch.utils.data.DataLoader): The dataloader containing the training instances in batches.
        dl_val (torch.utils.data.DataLoader): The dataloader containing the validation instances in batches.
    
    Returns:
        epoch_train_loss (float): The average loss over the batches seen during training.
        epoch_train_acc (float): The average accuracy over the batches seen during training.
        epoch_val_loss (float): The average loss over the batches seen during validation.
        epoch_val_acc (float): The average accuracy over the batches seen during validation.
    """
    optimizer = torch.optim.AdamW(selector_model.parameters(), lr=1e-5) #  initialize an AdamW optimizer for the parameters of the selector_model with a learning rate of 1e-5.

    def train_one_epoch():
        """

        Returns:
            train_losses (List[float]): A list containing the loss of each batch seen during training.
            train_accs (List[float]): A list containing the accuracy of each batch seen during training.

        Hints: 
            - Use `.item()` before appending the loss / accuracy of a batch to the corresponding list.
            - torch.nn.functional.cross_entropy already automatically ignores inputs labeled with -100.
            - When computing accuracy, only include the input_ids belonging to the text.
                These are all the tokens for which the tokenized rationale mask is != -100.
            
        """
        # fill in 
        
        return train_losses, train_accs # return the list of training losses and accuracies

    def validate():
        """

        Returns:
            val_losses (List[float]): A list containing the loss of each batch seen during validation.
            val_accs (List[float]): A list containing the accuracy of each batch seen during validation.

        Hint: See train_one_epoch(). Optionally use `with torch.no_grad():` to disable gradient tracking (not needed during eval).
            
        """
        # fill in 
        
        return val_losses, val_accs # return the list of validation losses and accuracies
        
    # fill in

    return epoch_train_loss, epoch_train_acc, epoch_val_loss, epoch_val_acc # return the computed metrics for one epoch

**c)** Complete the `select` function, which uses the selector model to predict the masked text for all instances in the dataloader.

The function should perform the following steps:

- iterate over the batches in the provided dataloader
- tokenize the text and the rationale of each batch
- disable gradient tracking during inference using `with torch.no_grad()`
- obtain the model output by performing a forward pass on the tokenized batch using `selector_model(input_ids=tokenized_batch['input_ids'], attention_mask=tokenized_batch['attention_mask'])` as before
- compute the predicted mask by selecting the index with the maximum value along the second dimension (the class dimension) of the model output using `.argmax(1)`
- then iterate over the instances in the batch
- retrieve the input ids of the current instance
- create a mask of the relevant tokens, excluding the special tokens [CLS], [SEP] and [PAD] (i.e. keep all input ids that are not `cls_token_id`, `sep_token_id` or `pad_token_id` from `utils`)
- apply the predicted mask to the relevant input ids by replacing all input ids for which the predicted mask is 0 with `mask_token_id` from `utils`
- decode the input ids back to (now masked) text using the `decode` function from `utils`
- store the masked text (split into tokens) and the label of the current instance in a dictionary
- by iterating over all batches, create a list containing one such dictionary for every instance in the dataloader
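The per-instance part of these steps can be sketched with boolean tensor indexing. The example is a stand-in under stated assumptions: the batch and the predicted mask are written by hand instead of coming from `tokenize` and the selector, `[MASK]` is assumed to have id `103`, and `toy_decode` with its tiny vocabulary is a hypothetical replacement for `decode` from `utils`.

```python
import torch

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0
MASK_ID = 103  # assumption: [MASK] id as in the uncased BERT vocabulary
# Hypothetical vocabulary covering only the toy ids below
VOCAB = {2: "a", 3: ".", 11: "bad", 43: "drama", 44: "what",
         876: "gorgeous", 555: "cinema", MASK_ID: "[MASK]"}

def toy_decode(input_ids):
    """Hypothetical stand-in for `decode`: map ids to tokens and join them."""
    return " ".join(VOCAB[i] for i in input_ids.tolist())

input_ids = torch.tensor([[101, 44, 2, 11, 43, 3, 102, 0, 0],
                          [101, 2, 876, 555, 102, 0, 0, 0, 0]])
labels = [0, 1]
# Made-up predicted mask of shape (batch, seq_len), as `.argmax(1)` would return;
# the values at special-token positions are irrelevant and get discarded
pred_mask = torch.tensor([[1, 0, 0, 1, 0, 0, 1, 1, 0],
                          [0, 0, 1, 0, 0, 1, 0, 0, 0]])

selector_outputs = []
for ids, mask, label in zip(input_ids, pred_mask, labels):
    # Keep only text tokens: drop [CLS], [SEP] and [PAD]
    relevant = (ids != CLS_ID) & (ids != SEP_ID) & (ids != PAD_ID)
    text_ids = ids[relevant]  # boolean indexing returns a copy
    # Replace every unselected text token with the [MASK] id
    text_ids[mask[relevant] == 0] = MASK_ID
    selector_outputs.append({"text": toy_decode(text_ids).split(), "label": label})

for out in selector_outputs:
    print(out)
```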

**Solution:**

In [5]:
def select(selector_model, dl):
    """

    Parameters:
        selector_model (Selector): 
            The selector model used for prediction.
        dl (torch.utils.data.DataLoader):
            The dataloader containing the instances to predict.

    Returns:
        selector_outputs (List[dict]):
            A list containing one dict for every instance in the dataloader.
            Each dict has two keys:
                'text': A list of tokens (as in the dataloader). Some tokens are replaced with the mask token [MASK]
                'label': The label of the instance (as in the dataloader).

    """
    # fill in 
    
    return selector_outputs

**c)** Complete the `predict` function, which predicts the sentiment label for the instances in the dataloader.

The function should perform the following steps: 

- iterate over the batches in the provided dataloader
- tokenize the text of each batch (the outputs of `select` contain only text and label, no rationales)
- disable gradient tracking during inference using `with torch.no_grad()`
- obtain the output of the predictor model by performing a forward pass on the tokenized batch using `predictor_model(input_ids=tokenized_batch['input_ids'], attention_mask=tokenized_batch['attention_mask'])`
- compute the predicted labels by selecting the index with the maximum value along the second dimension (the class dimension) of the model output using `.argmax(1)`
- collect the predictions in a list and return them
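A sketch of this loop, again with hypothetical stand-ins: `ToyPredictor` replaces the DistilBERT `Predictor` and returns logits of shape (batch size, 2), the batches are already tokenized instead of being passed through `tokenize`, and the function is called `predict_sketch` so it does not shadow the exercise's `predict`.

```python
import torch

torch.manual_seed(0)

class ToyPredictor(torch.nn.Module):
    """Hypothetical stand-in for the DistilBERT Predictor."""
    def __init__(self, vocab_size=1000, dim=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, 2)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        # Mean-pool the embeddings of the non-padding tokens
        pooled = (self.emb(input_ids) * mask).sum(1) / mask.sum(1)
        return self.head(pooled)  # (batch, 2)

def predict_sketch(predictor_model, tokenized_batches):
    predictor_model.eval()
    predictions = []
    for tokenized_batch in tokenized_batches:
        with torch.no_grad():
            logits = predictor_model(input_ids=tokenized_batch["input_ids"],
                                     attention_mask=tokenized_batch["attention_mask"])
        predictions.extend(logits.argmax(1).tolist())  # one label per sequence
    return predictions

# Made-up tokenized batches of masked reviews (103 assumed to be the [MASK] id)
batches = [
    {"input_ids": torch.tensor([[101, 103, 876, 103, 102], [101, 103, 11, 102, 0]]),
     "attention_mask": torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])},
    {"input_ids": torch.tensor([[101, 103, 102]]),
     "attention_mask": torch.tensor([[1, 1, 1]])},
]
predictions = predict_sketch(ToyPredictor(), batches)
print(predictions)
```

The untrained toy model produces arbitrary labels; the sketch only shows how per-batch argmax results are flattened into one list with one entry per instance.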

**Solution:**

In [6]:
def predict(predictor_model, dl):
    """

    Parameters:
        predictor_model (Predictor):
            A sequence classification model based on DistilBERT.
            For each input sequence, the model predicts whether it is negative (0) or positive (1).
            The predictor model is also a torch.nn.Module, so you can call its forward method as
                predictor_model(input_ids, attention_mask)
            Both of these inputs can be created by applying the `tokenize` function on a batch returned by the dataloaders.

        dl (torch.utils.data.DataLoader):
            The dataloader containing the instances to predict.
            The instances in the dataloader are the results of the `select` function.

    Returns:
        predictions (List[int]): A list containing the predicted labels (0 or 1) for each instance in the dataloader.
    """
    # fill in 
    
    return predictions # returns the list of predicted labels for each instance in the dataloader

**d)** Now we want to use the above functions to perform instance-wise feature selection. To this end, first load the training and validation data and create the corresponding dataloaders as shown below. Also initialize the selector model.

**Solution:**

In [8]:
# get training and validation set
ds_train = SentimentRationaleDataset('train', limit=1000)
ds_val = SentimentRationaleDataset('dev', limit=100)

# get training and validation dataloaders
dl_train = ds_train.get_dataloader()
dl_val = ds_val.get_dataloader()

# initialize an instance of the selector class and assign it to the variable selector_model
selector_model = Selector()


**e)** Train the token classification head of the `DistilBERT` model for one epoch and then validate the model, by computing training and validation loss and accuracy using the `train_selector_model` function.

**f)** Predict the masked text for all instances in the validation dataloader using the `select` function.

**g)** Predict the sentiment labels of the masked validation instances using the `predict` function. For this purpose, the predictor model needs to be initialized and a DataLoader (`dl`) over the masked validation instances returned by `select` needs to be created.

**Solution:**

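In [None]:
# e) train the selector for one epoch, then validate it
epoch_train_loss, epoch_train_acc, epoch_val_loss, epoch_val_acc = train_selector_model(selector_model, dl_train, dl_val)
# f) predict the masked text for all validation instances
ds_val_masked = select(selector_model, dl_val)

# g) initialize an instance of the predictor class and assign it to the variable predictor_model
predictor_model = Predictor()
# create a DataLoader for the masked validation instances and predict their labels
dl = torch.utils.data.DataLoader(ds_val_masked, batch_size=8, collate_fn=_custom_collate)
predictions = predict(predictor_model, dl)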
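The `collate_fn` argument matters because each instance holds a variable-length list of tokens, which PyTorch's default collate function cannot batch (it raises an error for lists of unequal length). The documentation does not describe `_custom_collate`; the hypothetical `collate_as_lists` below shows one plausible behaviour, turning a list of instance dicts into a dict of lists, matching the batch format described in the `train_selector_model` docstring.

```python
from torch.utils.data import DataLoader

def collate_as_lists(instances):
    """Hypothetical collate function: list of dicts -> dict of lists."""
    return {key: [instance[key] for instance in instances] for key in instances[0]}

# Made-up masked instances in the format returned by `select`
ds_toy = [
    {"text": ["[MASK]", "gorgeous", "[MASK]"], "label": 1},
    {"text": ["[MASK]", "[MASK]", "bad", "[MASK]", "[MASK]"], "label": 0},
    {"text": ["great"], "label": 1},
]
dl_toy = DataLoader(ds_toy, batch_size=2, collate_fn=collate_as_lists)
for batch in dl_toy:
    print(batch)
```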

**h)** Run the following code to check whether the instance-wise feature selection was successful. For each of the first four instances of the validation set, the code prints the original movie review, the selection (the masked text), the prediction of the predictor model and the groundtruth label.

**Solution:**

In [None]:
print('=' * 80)
print('Examples:')
for i in range(0, 4):
    print('-' * 80)
    print(f'Instance: {" ".join(ds_val[i]["text"])}')
    print(f'Selection: {" ".join(ds_val_masked[i]["text"])}')
    print(f'Prediction: {"positive" if predictions[i] else "negative"}')
    print(f'Groundtruth: {"positive" if ds_val[i]["label"] else "negative"}')
    print('-' * 80)

<hr style="border:1.5px solid gray"> </hr>