## Captum Attributions

Compute per-word importance scores for each test tweet with Captum (via `transformers-interpret`).

In [None]:
import torch 
import pandas as pd
import numpy as np
from transformers import BertForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# Dataset folders under ./outputs/ to process. Uncomment an entry to include it.
DATASETS = [
    # 'V-DS1',
    # 'V-DS2',
    # 'V-DS3',
    # 'SI-DS1',
    # 'SI-DS2',
    'H-DS1',
]

# OUTPUTS and CHECKPOINTS are parallel lists: the run folder OUTPUTS[j]
# (under ./outputs/<dataset>/) loads its tokenizer from CHECKPOINTS[j].
# Uncomment matching positions in BOTH lists so they stay aligned.
# NOTE(review): 'bert-base-multilingual-uncased' appears twice among the
# commented options; confirm whether one was meant to be the cased variant.
OUTPUTS = [
    # 'bertimbau-base',
    # 'bert-base-multilingual-uncased',
    # 'bert-base-multilingual-uncased',
    'bert-base-uncased',
    # 'distilbert-base-uncased'
]

CHECKPOINTS = [
    # 'neuralmind/bert-base-portuguese-cased',
    # 'bert-base-multilingual-uncased',
    # 'bert-base-multilingual-uncased',
    'bert-base-uncased',
    # 'distilbert-base-uncased'
]

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def _run_dir(dataset, output):
    """Return the folder holding the fine-tuned model and test results of one run."""
    return f'./outputs/{dataset}/{output}'


def _compute_captum_scores(explainer, tweets, expected_classes):
    """Return one array of per-token attribution scores for each tweet.

    Args:
        explainer: a ``SequenceClassificationExplainer`` wrapping the model.
        tweets: sequence of tweet texts, in CSV order.
        expected_classes: class recorded for each tweet in the test results
            (the CSV ``got`` column), aligned element-wise with ``tweets``.

    Returns:
        List with one ``np.ndarray`` per tweet, holding the second element
        (the attribution score) of every entry the explainer returned.

    Raises:
        RuntimeError: if the explainer's predicted class differs from the
            recorded class, since the attributions would then explain a
            different prediction than the stored one.
    """
    scores = []
    total = len(tweets)
    for k, (tweet, expected) in enumerate(zip(tweets, expected_classes)):
        word_attributions = explainer(text=tweet)

        # Read once and reuse for both the comparison and the error message.
        predicted = explainer.predicted_class_index
        if expected != predicted:
            print(f'  Error for tweet ({k}): {tweet}')
            # An explicit exception instead of a bare `raise`, which outside an
            # `except` block only reports "No active exception to reraise".
            raise RuntimeError(
                f'Predicted class {predicted} differs from recorded class '
                f'{expected} for tweet ({k}): {tweet}'
            )

        # Entries are (token, attribution) pairs; keep only the attribution.
        scores.append(np.array([pair[1] for pair in word_attributions]))

        if (k + 1) % 200 == 0:
            print(f'  {k+1}/{total}...')
    return scores


# OUTPUTS[j] is paired with CHECKPOINTS[j]; fail fast instead of crashing
# mid-run or silently ignoring trailing entries when they drift apart.
if len(OUTPUTS) != len(CHECKPOINTS):
    raise ValueError(
        f'OUTPUTS ({len(OUTPUTS)}) and CHECKPOINTS ({len(CHECKPOINTS)}) '
        'must be parallel lists of equal length'
    )

for dataset in DATASETS:
    print(f'Dataset {dataset}...')
    for output, checkpoint in zip(OUTPUTS, CHECKPOINTS):
        print(f' Model {output}...')
        run_dir = _run_dir(dataset, output)

        model = BertForSequenceClassification.from_pretrained(
            f'{run_dir}/model/',
            config=f'{run_dir}/model/config.json'
        )
        model.to(device)
        # The tokenizer comes from the paired checkpoint, not from the run folder.
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        print(f'  Model is designed for {model.num_labels} labels.')

        df = pd.read_csv(f'{run_dir}/test_results_complete.csv', sep=';')
        tweets = df.tweet.values
        # NOTE(review): assumes `got` holds the integer class index predicted
        # at test time -- confirm against the script that writes this CSV.
        gotten_classes = df.got.values

        multiclass_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)

        captum_scores = _compute_captum_scores(multiclass_explainer, tweets, gotten_classes)

        torch.save(captum_scores, f'{run_dir}/captum.pt')

        print('  Done.')

    print('\n')