# Using `antra` with BERT

[BERT](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) is one of the most widely used pretrained transformer architectures, with strong results on a wide range of NLP tasks.

`antra` has built-in functionality that converts BERT into a computation graph, so that you can run intervention experiments on it and analyze the causal interactions between the hidden values in its layers.

In `antra.compgraphs.bert` we define `BertCompGraph`, an `antra.CompGraph` object that implements the `forward()` function of `transformers.BertModel`. It has the following structure, where each hidden layer in BERT is a computation graph node:

![BERT computation graph](figures/bert_compgraph.png)

Here the large thick arrows indicate that a node outputs a tensor of size `(batch_size, sentence_len, hidden_dim)` which is passed into the next node. You can intervene on any node that has this type of output, i.e. `embed` and `bert_layer_n` for `n in range(12)`.

The `input_preparation` node is a special node that prepares the "metainfo", such as attention masks and token type ids, that every hidden layer needs for its computation but that stays the same between layers. You cannot intervene on this node.
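
The node names that accept interventions follow directly from the naming scheme above. The snippet below is a minimal sketch that enumerates them, assuming the 12-layer `bert-base` configuration; `intervenable_nodes` is a hypothetical helper written for illustration and is not part of `antra`.

```python
def intervenable_nodes(num_layers=12):
    """Names of nodes whose output has shape (batch_size, sentence_len, hidden_dim).

    Hypothetical helper: it only reproduces the naming scheme described in
    the text. `input_preparation` is left out because it cannot be intervened on.
    """
    return ["embed"] + [f"bert_layer_{n}" for n in range(num_layers)]


nodes = intervenable_nodes()
print(len(nodes))            # 13
print(nodes[0], nodes[-1])   # embed bert_layer_11
```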


## Setup

In [5]:
%load_ext autoreload
%autoreload 2

import transformers
import torch
from torch.utils.data import DataLoader
from antra import Intervention
from antra.compgraphs.bert import BertCompGraph, BertGraphInput
from bert_utils import SentimentData



## Load the model and toy dataset

Here we use a toy dataset taken from a sentiment classification task.
Each line in the `tsv` file contains a sentence followed by a label 0 (negative) or 1 (positive). We define a dataset `SentimentData` to preprocess this file and tokenize all the sentences using the default `BertTokenizer`. 
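
To make the file format concrete, here is a minimal sketch of reading such a file with the standard library. It is an illustration only: the real `SentimentData` class in `bert_utils` may parse the file differently, `read_sentiment_tsv` is a hypothetical name, and the two sample sentences are made up.

```python
import csv
import io


def read_sentiment_tsv(handle):
    """Yield (sentence, label) pairs from a tab-separated file-like object."""
    for row in csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_NONE):
        if not row:
            continue  # skip empty lines
        sentence, label = row[0], int(row[1])
        yield sentence, label


# Made-up sample standing in for the contents of a sentiment TSV file.
sample = io.StringIO("a wonderful film\t1\ndull and lifeless\t0\n")
pairs = list(read_sentiment_tsv(sample))
print(pairs)  # [('a wonderful film', 1), ('dull and lifeless', 0)]
```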

We use this dataset purely to demonstrate how to prepare the data, feed it into `BertCompGraph`, and run computations and interventions on it. Normally you would first fine-tune the model on a dataset and then analyze it.

In [2]:
model = transformers.BertModel.from_pretrained("bert-base-uncased")
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
g = BertCompGraph(model)

dataset = SentimentData("sentiment_data.tsv", tokenizer)


--- Loading sentences from sentiment_data.tsv
--- Loaded 100 sentences from sentiment_data.tsv


## Retrieve values of `BertModel`'s hidden layers, without intervention

The following shows an example of using `antra` to retrieve the hidden values in the layers of `transformers.BertModel`.

The input to `BertCompGraph` should be a `BertGraphInput` (defined in `antra.compgraphs.bert`), which inherits from the `GraphInput` class. `BertGraphInput` automatically computes keys for the inputs based on the `input_ids`. Once it is constructed, call `compute()` or `compute_node()` as usual.

`BertGraphInput` takes a batch of inputs in the form of a `dict` whose key-value pairs correspond to the parameter-value pairs of BERT's `forward()` function. To see how to package data into this form using PyTorch's `Dataset` and `DataLoader`, refer to `examples/bert_utils.py`.
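
As a rough picture of what such a `dict` contains, the sketch below pads a batch of token-id lists to a common length and builds the three entries used later in this notebook. It is written in plain Python for illustration; `pad_batch` is a hypothetical helper, not the code in `examples/bert_utils.py`, and in practice each list would be converted to a tensor (for example with `torch.tensor`) before being wrapped in a `BertGraphInput`.

```python
PAD_ID = 0  # id of [PAD] in the bert-base-uncased vocabulary


def pad_batch(token_id_lists):
    """Pad token-id lists to equal length and build a forward()-style dict."""
    max_len = max(len(ids) for ids in token_id_lists)
    batch = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        batch["input_ids"].append(ids + [PAD_ID] * n_pad)
        # 1 marks a real token, 0 marks padding.
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # Single-sentence inputs use segment id 0 everywhere.
        batch["token_type_ids"].append([0] * max_len)
    return batch


# Illustrative ids: 101 and 102 are [CLS] and [SEP]; 2307 is an arbitrary token id.
batch = pad_batch([[101, 2307, 102], [101, 102]])
print(batch["input_ids"])       # [[101, 2307, 102], [101, 102, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```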


In [3]:
dataloader = DataLoader(dataset, shuffle=False, batch_size=10)


with torch.no_grad():
    for batch in dataloader:
        # Load one batch from dataloader.

        # In the following we demonstrate that the output from the forward()
        # is identical to that from antra's computation graph.

        batch = {k: v.to(device) for k, v in batch.items()}

        gi = BertGraphInput(batch)
        compgraph_root_output = g.compute(gi)
        compgraph_embedding_hidden = g.compute_node("embed", gi)
        compgraph_layer5_hidden = g.compute_node("bert_layer_5", gi)

        print("compgraph root output shape", compgraph_root_output.shape)
        print("compgraph embedding layer shape", compgraph_embedding_hidden.shape)
        print("compgraph 5th hidden layer shape", compgraph_layer5_hidden.shape)

        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
            return_dict=True,
            output_hidden_states=True,
        )
        final_output = outputs.pooler_output
        # hidden_states[0] is the embedding layer, so bert_layer_5 is at index 6.
        embedding = outputs.hidden_states[0]
        layer5 = outputs.hidden_states[6]

        print("final output shape", final_output.shape)
        print("embedding layer shape", embedding.shape)
        print("5th hidden layer shape", layer5.shape)

        assert torch.allclose(final_output, compgraph_root_output) # True
        assert torch.allclose(embedding, compgraph_embedding_hidden) # True
        assert torch.allclose(layer5, compgraph_layer5_hidden) # True
        break

compgraph root output shape torch.Size([10, 768])
compgraph embedding layer shape torch.Size([10, 53, 768])
compgraph 5th hidden layer shape torch.Size([10, 53, 768])
final output shape torch.Size([10, 768])
embedding layer shape torch.Size([10, 53, 768])
5th hidden layer shape torch.Size([10, 53, 768])
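
The off-by-one between computation graph node names and `hidden_states` indices in the comparison above can be captured in a small helper. This is a sketch for illustration; `hidden_state_index` is a hypothetical function, not part of `antra` or `transformers`.

```python
def hidden_state_index(node_name):
    """Map a BertCompGraph node name to its index in `outputs.hidden_states`.

    Index 0 holds the embedding output, so `bert_layer_n` sits at index n + 1.
    """
    if node_name == "embed":
        return 0
    prefix = "bert_layer_"
    if node_name.startswith(prefix):
        return int(node_name[len(prefix):]) + 1
    raise ValueError(f"node {node_name!r} has no entry in hidden_states")


print(hidden_state_index("embed"))          # 0
print(hidden_state_index("bert_layer_5"))   # 6
print(hidden_state_index("bert_layer_11"))  # 12
```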


## Performing interventions on `BertCompGraph`

Here we perform an intervention that zeroes out the first 50 elements of the left-most hidden vector (sequence position 0) of each item in the batch. In effect, we are doing
`bert_layer_5[:,0,:50] = torch.zeros(batch_size, 50)`.
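
To see what that slice assignment does on its own, the sketch below applies it to a random tensor standing in for the output of `bert_layer_5`. It uses plain PyTorch only and does not call `antra`; the shapes are assumptions chosen to match the batch shown earlier.

```python
import torch

batch_size, sentence_len, hidden_dim = 10, 53, 768
# Random stand-in for the hidden values produced by bert_layer_5.
hidden = torch.randn(batch_size, sentence_len, hidden_dim)

# Work on a copy so the original values remain available for comparison.
intervened = hidden.clone()
intervened[:, 0, :50] = torch.zeros(batch_size, 50)

# Only the first 50 dimensions at position 0 are changed.
assert torch.all(intervened[:, 0, :50] == 0)
assert torch.equal(intervened[:, 0, 50:], hidden[:, 0, 50:])
assert torch.equal(intervened[:, 1:, :], hidden[:, 1:, :])
print(intervened[:, 0, :50].shape)  # torch.Size([10, 50])
```

Every other position and dimension is left untouched, which is why the intervention values must have exactly the shape of the indexed slice.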

Start by constructing a `BertGraphInput`, which serves as the base input for a batched `Intervention` object. Then prepare intervention values whose shape matches the slice being replaced. Finally, perform the intervention using `intervene()` or `intervene_node()` as usual.

In [10]:
with torch.no_grad():
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        
        gi = BertGraphInput(batch)
        
        batch_size, _ = batch["input_ids"].shape
        # Create the values on the same device as the model inputs.
        interv_values = torch.zeros(batch_size, 50, device=device)
        interv_dict = {"bert_layer_5[:,0,:50]": interv_values}
        
        interv = Intervention.batched(gi, interv_dict)
        
        interv_before, interv_after = g.intervene(interv)
        
        assert not torch.allclose(interv_before, interv_after)
        
        # note that the shapes should match
        bert_layer_5 = g.compute_node("bert_layer_5", gi)
        print("shape of bert_layer_5[:,0,:50]", bert_layer_5[:,0,:50].shape)
        print("shape of interv_values        ", interv_values.shape)
        
        break

shape of bert_layer_5[:,0,:50] torch.Size([10, 50])
shape of interv_values         torch.Size([10, 50])
