# LIME to Inspect Text Classification 

This tutorial shows how to use Captum's implementation of Local Interpretable Model-agnostic Explanations (LIME) to understand neural models. The content is divided into an image classification section, which presents the high-level `Lime` class, and a text classification section, which covers the more customizable low-level `LimeBase` class.

## 2. Text Classification

In this section, we will use a news subject classification example to demonstrate the more customizable functions in Lime. We will train a simple embedding-bag classifier on the AG_NEWS dataset and analyze its understanding of words.

In [None]:
!conda install -c pytorch torchtext==0.8

In [4]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab

from collections import Counter

from IPython.display import HTML, display

# used in section 2.3
from captum.attr import LimeBase
from captum._utils.models.linear_model import SkLearnLasso

### 2.1 Load the data and define the model

`torchtext` includes the AG_NEWS dataset, but since it is only split into train and test, we need to carve a validation set out of the original train split. We then build the vocabulary from the word counts of our train split.

In [4]:
ag_ds = list(AG_NEWS(split='train'))

ag_train, ag_val = ag_ds[:100000], ag_ds[100000:]

tokenizer = get_tokenizer('basic_english')
word_counter = Counter()
for (label, line) in ag_train:
    word_counter.update(tokenizer(line))
voc = Vocab(word_counter)

print('Vocabulary size:', len(voc))

num_class = len(set(label for label, _ in ag_train))
print('Num of classes:', num_class)

Vocabulary size: 86716
Num of classes: 4


The model we use is composed of an embedding-bag, which averages the word embeddings to form the latent text representation, and a final linear layer, which maps the latent vector to the logits. Unconventionally, `pytorch`'s embedding-bag does not assume the first dimension is the batch dimension. Instead, it takes a flattened vector of indices together with an offset tensor that marks the starting position of each example. You can refer to its [documentation](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#embeddingbag) for details.

In [13]:
class EmbeddingBagModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, num_class)

    def forward(self, inputs, offsets):
        embedded = self.embedding(inputs, offsets)
        return self.linear(embedded)
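
To make the flattened-indices-plus-offsets convention concrete, here is a minimal framework-free sketch of mean pooling over bags. The helper `embedding_bag_mean` and the tiny embedding table are hypothetical and only illustrate the idea; they are not how `nn.EmbeddingBag` is implemented.

```python
def embedding_bag_mean(table, indices, offsets):
    """Average the embedding rows of each bag.

    `indices` is one flat list of token ids for the whole batch and
    `offsets[i]` is the position in `indices` where example i starts.
    """
    dim = len(table[0])
    bags = []
    for i, start in enumerate(offsets):
        # a bag ends where the next one starts (or at the end of the flat list)
        end = offsets[i + 1] if i + 1 < len(offsets) else len(indices)
        rows = [table[j] for j in indices[start:end]]
        if not rows:
            # assumption of this sketch: an empty bag maps to a zero vector
            bags.append([0.0] * dim)
            continue
        bags.append([sum(col) / len(rows) for col in zip(*rows)])
    return bags


# hypothetical 4-word vocabulary with 2-dimensional embeddings
table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
# two examples, [1, 2] and [3, 3, 0], flattened into one list
indices = [1, 2, 3, 3, 0]
offsets = [0, 2]

pooled = embedding_bag_mean(table, indices, offsets)
print(pooled)
```

The first example averages rows 1 and 2, the second averages rows 3, 3 and 0, so each example ends up as a single vector regardless of its length.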

### 2.2 Training and Baseline Classification

To train our classifier, we need to define a collate function that batches the samples into the tensor format required by the embedding-bag, and then create the iterable dataloaders.

In [6]:
BATCH_SIZE = 64

def collate_batch(batch):
    labels = torch.tensor([label - 1 for label, _ in batch]) 
    text_list = [tokenizer(line) for _, line in batch]
    
    # flatten tokens across the whole batch
    text = torch.tensor([voc[t] for tokens in text_list for t in tokens])
    # the offset of each example
    offsets = torch.tensor(
        [0] + [len(tokens) for tokens in text_list][:-1]
    ).cumsum(dim=0)

    return labels, text, offsets

train_loader = DataLoader(ag_train, batch_size=BATCH_SIZE,
                          shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(ag_val, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=collate_batch)
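
The offset computation in `collate_batch` can be checked by hand on a tiny batch. The following is a plain-Python sketch of the same idea; `flatten_with_offsets` and the token ids are made up for illustration.

```python
from itertools import accumulate


def flatten_with_offsets(token_id_lists):
    """Flatten per-example token ids and record where each example starts."""
    flat = [t for ids in token_id_lists for t in ids]
    lengths = [len(ids) for ids in token_id_lists]
    # the start of example i is the total length of examples 0..i-1
    offsets = list(accumulate([0] + lengths[:-1]))
    return flat, offsets


# hypothetical batch of three examples with 3, 1 and 2 tokens
batch = [[5, 7, 9], [4], [8, 8]]
flat, offsets = flatten_with_offsets(batch)
print(flat)     # [5, 7, 9, 4, 8, 8]
print(offsets)  # [0, 3, 4]
```

Dropping the last length before the cumulative sum is what keeps the number of offsets equal to the number of examples.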

We will then train our embedding-bag model with the common cross-entropy loss and Adam optimizer. Due to the simplicity of this task, 5 epochs should be enough to give us a stable 90% validation accuracy. 

In [8]:
EPOCHS = 5  # matches the 5 epochs suggested in the text above
EMB_SIZE = 64
CHECKPOINT = './embedding_bag_ag_news.pt'
USE_PRETRAINED = True  # change to False if you want to retrain your own model

def train_model(train_loader, val_loader):
    model = EmbeddingBagModel(len(voc), EMB_SIZE, num_class)
    
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(1, EPOCHS + 1):      
        # training
        model.train()
        total_acc, total_count = 0, 0
        
        for idx, (label, text, offsets) in enumerate(train_loader):
            optimizer.zero_grad()
            predicted_label = model(text, offsets)
            loss(predicted_label, label).backward()
            optimizer.step()
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

            if (idx + 1) % 500 == 0:
                print('epoch {:3d} | {:5d}/{:5d} batches | accuracy {:8.3f}'.format(
                    epoch, idx + 1, len(train_loader), total_acc / total_count
                ))
                total_acc, total_count = 0, 0       
        
        # evaluation
        model.eval()
        total_acc, total_count = 0, 0

        with torch.no_grad():
            for label, text, offsets in val_loader:
                predicted_label = model(text, offsets)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)

        print('-' * 59)
        print('end of epoch {:3d} | valid accuracy {:8.3f} '.format(epoch, total_acc / total_count))
        print('-' * 59)
    
    torch.save(model, CHECKPOINT)
    return model
        
eb_model = torch.load(CHECKPOINT) if USE_PRETRAINED else train_model(train_loader, val_loader)

Now, let us take the following sports news and test how our model performs.

In [12]:
test_label = 2  # {1: World, 2: Sports, 3: Business, 4: Sci/Tech}
test_line = ('US Men Have Right Touch in Relay Duel Against Australia THENS, Aug. 17 '
            '- So Michael Phelps is not going to match the seven gold medals won by Mark Spitz. '
            'And it is too early to tell if he will match Aleksandr Dityatin, '
            'the Soviet gymnast who won eight total medals in 1980.')

test_labels, test_text, test_offsets = collate_batch([(test_label, test_line)])

probs = F.softmax(eb_model(test_text, test_offsets), dim=1).squeeze(0)
print('Prediction probability:', round(probs[test_labels[0]].item(), 4))

Our embedding-bag does successfully identify the above news as sports with pretty high confidence.
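
The probability printed by the cell above comes from a softmax over the four logits, indexed with the zero-based label (`test_label - 1`, as done in `collate_batch`). A small sketch with made-up logits, not the model's actual output:

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# hypothetical logits for World, Sports, Business, Sci/Tech
logits = [0.3, 4.1, -1.2, 0.5]
probs = softmax(logits)

test_label = 2  # AG_NEWS labels start at 1
print('Prediction probability:', round(probs[test_label - 1], 4))
```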

### 2.3 Inspect the model prediction with Lime

Finally, it is time to bring back Lime to inspect how the model makes its prediction. This time, however, we will use the more customizable `LimeBase` class, which is also the low-level implementation powering the `Lime` class we used before. The `Lime` class is opinionated about how it creates features from perturbed binary interpretable representations: it can only set the "absent" features to some baseline values while keeping the "present" features unchanged. This is not what we want in this case. For text, the interpretable representation is a binary vector indicating whether the word at each position is present. The corresponding text input should literally remove the absent words so that our embedding-bag averages the embeddings of the remaining words only. Setting them to any baseline would pollute the average and, moreover, our embedding-bag has no common baseline token such as `<padding>` at all. Therefore, we have to use `LimeBase` to customize the conversion logic through the `from_interp_rep_transform` argument.

`LimeBase` is not opinionated at all so we have to define every piece manually. Let us talk about them in order:
- `forward_func`, the forward function of the model. Notice that we cannot pass our model directly, since Captum always assumes the first dimension is the batch dimension while our embedding-bag requires flattened indices. We will therefore add a dummy dimension later when calling `attribute` and define a wrapper here that removes it before passing the input to our model.
- `interpretable_model`, the surrogate model. This works the same as demonstrated in the image classification example above. We again use the sklearn linear lasso here.
- `similarity_func`, the function calculating the weights for training samples. The most common distance used for texts is the cosine distance in their latent embedding space. The text inputs are just sequences of token indices, so we have to leverage the trained embedding layer from the model to encode them into their latent vectors. Due to this extra encoding step, we cannot use the utility `get_exp_kernel_similarity_function('cosine')` as in the image classification example, because it directly calculates the cosine similarity of the given inputs.
- `perturb_func`, the function to sample interpretable representations. We present another way to define this argument, other than using a generator as shown in the image classification example above. Here we directly define a function that returns a randomized sample on every call. It outputs a binary vector where each token is selected independently and uniformly at random.
- `perturb_interpretable_space`, whether perturbed samples are in interpretable space. `LimeBase` also supports sampling in the original input space, but we do not need it in our case.
- `from_interp_rep_transform`, the function transforming the perturbed interpretable samples back to the original input space. As explained above, this argument is the main reason for us to use `LimeBase`. We pick the subset of the present tokens from the original text input according to the interpretable representation.
- `to_interp_rep_transform`, the opposite of `from_interp_rep_transform`. It is needed only when `perturb_interpretable_space` is set to false.
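
To see how these pieces fit together, the following is a framework-free sketch of the overall loop: sample a binary mask, drop the absent tokens, query the model, weight the sample by similarity, and fit a weighted linear surrogate. Everything here is an assumption made for illustration: `lime_sketch`, `toy_model` and `solve` are hypothetical helpers, the distance is simply the fraction of removed tokens, and the surrogate is plain weighted least squares rather than the Lasso used in this tutorial. It is not Captum's implementation.

```python
import math
import random


def solve(a, b):
    """Solve a x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                f = m[r][col] / m[col][col]
                m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    return [m[i][n] / m[i][i] for i in range(n)]


def lime_sketch(model, tokens, n_samples=500, seed=0):
    rng = random.Random(seed)
    n = len(tokens)
    rows, targets, weights = [], [], []
    for _ in range(n_samples):
        # perturb_func: each token is kept with probability 0.5
        mask = [rng.randint(0, 1) for _ in range(n)]
        # from_interp_rep_transform: literally remove the absent tokens
        kept = [t for t, keep in zip(tokens, mask) if keep]
        # similarity_func: toy distance = fraction of removed tokens
        distance = 1 - sum(mask) / n
        rows.append(mask + [1])  # trailing 1 is the intercept column
        targets.append(model(kept))
        weights.append(math.exp(-(distance ** 2) / 2))
    # weighted least squares: (X^T W X) beta = X^T W y
    d = n + 1
    xtwx = [
        [sum(w * r[i] * r[j] for r, w in zip(rows, weights)) for j in range(d)]
        for i in range(d)
    ]
    xtwy = [
        sum(w * r[i] * y for r, w, y in zip(rows, weights, targets))
        for i in range(d)
    ]
    return solve(xtwx, xtwy)[:n]  # drop the intercept


# hypothetical additive "model": the score is the sum of per-word scores
word_scores = {'medals': 2.0, 'gold': 1.0, 'stocks': -1.5}


def toy_model(kept_tokens):
    return sum(word_scores.get(t, 0.0) for t in kept_tokens)


tokens = ['gold', 'medals', 'and', 'stocks']
attrs = lime_sketch(toy_model, tokens)
for token, attr in zip(tokens, attrs):
    print(f'{token:>8s} {attr:+.3f}')
```

Because the toy model is exactly additive, the surrogate coefficients should match the per-word scores up to floating-point error. A real classifier is not additive, which is why the choice of kernel and regularization matters in practice.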

In [27]:
# remove the batch dimension for the embedding-bag model
def forward_func(text, offsets):
    return eb_model(text.squeeze(0), offsets)

# encode text indices into latent representations & calculate cosine similarity
def exp_embedding_cosine_distance(original_inp, perturbed_inp, _, **kwargs):
    original_emb = eb_model.embedding(original_inp, None)
    perturbed_emb = eb_model.embedding(perturbed_inp, None)
    distance = 1 - F.cosine_similarity(original_emb, perturbed_emb, dim=1)
    return torch.exp(-1 * (distance ** 2) / 2)

# binary vector where each word is selected independently and uniformly at random
def bernoulli_perturb(text, **kwargs):
    probs = torch.ones_like(text) * 0.5
    return torch.bernoulli(probs).long()

# remove absent tokens based on the interpretable representation sample
def interp_to_input(interp_sample, original_input, **kwargs):
    return original_input[interp_sample.bool()].view(original_input.size(0), -1)

lasso_lime_base = LimeBase(
    forward_func, 
    interpretable_model=SkLearnLasso(alpha=0.08),
    similarity_func=exp_embedding_cosine_distance,
    perturb_func=bernoulli_perturb,
    perturb_interpretable_space=True,
    from_interp_rep_transform=interp_to_input,
    to_interp_rep_transform=None
)
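
The similarity function above maps a cosine distance `d` to the weight `exp(-d**2 / 2)`. The sketch below reproduces that arithmetic in plain Python on hypothetical 2-dimensional embeddings to show the range of weights it can produce. Treating a zero vector as having similarity 0 is an assumption of this sketch.

```python
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    # assumption: a zero vector (e.g. every token removed) gets similarity 0
    return dot / norm if norm > 0 else 0.0


def exp_kernel_weight(original_emb, perturbed_emb):
    distance = 1 - cosine_similarity(original_emb, perturbed_emb)
    return math.exp(-(distance ** 2) / 2)


original = [1.0, 0.0]
same_direction = exp_kernel_weight(original, [2.0, 0.0])
orthogonal = exp_kernel_weight(original, [0.0, 3.0])
opposite = exp_kernel_weight(original, [-1.0, 0.0])
print(same_direction, orthogonal, opposite)
```

Since the cosine distance lies in [0, 2], the weight lies between `exp(-2)` (about 0.135) and 1, so even very dissimilar perturbed samples keep some influence on the surrogate model.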

The attribution call is the same as for the `Lime` class. Just remember to add the dummy batch dimension to the text input and to pass the offsets through `additional_forward_args`, because they are not a feature for the classification but metadata for the text input.

In [28]:
attrs = lasso_lime_base.attribute(
    test_text.unsqueeze(0), # add batch dimension for Captum
    target=test_labels,
    additional_forward_args=(test_offsets,),
    n_samples=32000,
    show_progress=True
).squeeze(0)

print('Attribution range:', attrs.min().item(), 'to', attrs.max().item())

Lime Base attribution: 100%|██████████| 32000/32000 [00:22<00:00, 1432.67it/s]


Attribution range: -0.4232264757156372 to 0.9536108374595642
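
Before visualizing, it can help to list the tokens with the largest absolute attribution. A minimal sketch follows; the helper `top_tokens` is hypothetical and the tokens and scores are made up, not taken from the run above.

```python
def top_tokens(tokens, attributions, k=3):
    """Return the k (token, attribution) pairs with the largest magnitude."""
    paired = list(zip(tokens, attributions))
    return sorted(paired, key=lambda p: abs(p[1]), reverse=True)[:k]


# hypothetical tokens and scores, for illustration only
tokens = ['phelps', 'is', 'not', 'going', 'to', 'match', 'medals']
scores = [0.21, 0.0, -0.05, 0.0, 0.01, 0.48, 0.95]

for token, score in top_tokens(tokens, scores):
    print(f'{token:>8s} {score:+.2f}')
```

With the variables from this notebook, the call would presumably be `top_tokens(tokenizer(test_line), attrs.tolist())`.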


Finally, let us create a simple visualization that highlights the influential words, where green stands for positive correlation and red for negative.

In [29]:
def show_text_attr(attrs):
    rgb = lambda x: '255,0,0' if x < 0 else '0,255,0'
    alpha = lambda x: abs(x) ** 0.5
    token_marks = [
        f'<mark style="background-color:rgba({rgb(attr)},{alpha(attr)})">{token}</mark>'
        for token, attr in zip(tokenizer(test_line), attrs.tolist())
    ]
    
    display(HTML('<p>' + ' '.join(token_marks) + '</p>'))
    
show_text_attr(attrs)

The above visualization should render something like the image below, where the model links the "Sports" subject to many reasonable words, like "match" and "medals".

![Lime Text](img/lime_text_viz.png)
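
If the highlighted text needs to be saved or reused outside the notebook, the same markup can be built as a plain string. This is a sketch under two assumptions that go beyond the cell above: tokens are HTML-escaped, and the opacity is clamped to 1 so attributions larger than 1 still yield a valid alpha. The helper name `attr_to_html` is hypothetical.

```python
from html import escape


def attr_to_html(tokens, attributions):
    """Build the highlighted paragraph as an HTML string."""
    marks = []
    for token, attr in zip(tokens, attributions):
        rgb = '255,0,0' if attr < 0 else '0,255,0'  # red: negative, green: positive
        alpha = min(abs(attr), 1.0) ** 0.5          # clamp, then soften small values
        marks.append(
            f'<mark style="background-color:rgba({rgb},{alpha:.2f})">'
            f'{escape(token)}</mark>'
        )
    return '<p>' + ' '.join(marks) + '</p>'


# hypothetical tokens and scores, for illustration only
html = attr_to_html(['a', '<', 'match'], [0.0, -0.25, 1.5])
print(html)
```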