<a href="https://colab.research.google.com/github/danielhou13/cogs402longformer/blob/main/src/CaptumLongformerSequenceClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook adapts the [Captum tutorial for question answering](https://captum.ai/tutorials/Bert_SQUAD_Interpret) to sequence classification with Longformer. It uses the model's embedding layer to compute token attributions for the examples of your choice, or for the entire dataset if needed. The attributions let us visualize which tokens most influence the model's prediction and find the k tokens that contribute most to correct and to incorrect predictions.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Import dependencies

In [None]:
pip install transformers --quiet

In [None]:
pip install captum --quiet

In [None]:
pip install datasets --quiet

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch
import pandas as pd

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Import model

Here we import the model and tokenizer and move the model to the GPU if one is available. Change the model path and the tokenizer to whichever ones you wish to use.

In [None]:
from transformers import LongformerForSequenceClassification, LongformerTokenizer, LongformerConfig
# replace the path below with the Hub id or local path of your saved model
model_path = 'danielhou13/longformer-finetuned_papers_v2'
#model_path = 'danielhou13/longformer-finetuned-new-cogs402'

# load model
model = LongformerForSequenceClassification.from_pretrained(model_path, num_labels = 2)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")


Create functions that give us the input ids and the position ids for the text we want to examine along with the baselines for integrated gradients.

In [None]:
ref_token_id = tokenizer.pad_token_id # token used to fill the baseline (reference) input
sep_token_id = tokenizer.sep_token_id # separator token appended to the end of the text
cls_token_id = tokenizer.cls_token_id # token prepended to the start of the text

Adjust `max_length` for your project. It should be the sequence length you want minus 2, because we add the CLS token at the beginning and the separator token at the end.

In [None]:
max_length = 2046
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, truncation = True, add_special_tokens=False, max_length = max_length)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)

    #taken from the longformer implementation
    mask = input_ids.ne(ref_token_id).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    position_ids = incremental_indices.long().squeeze() + ref_token_id

    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    position_ids = position_ids[:, :seq_length]
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)
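
To make the layout concrete, here is a dependency-free sketch of what the helper functions above produce for a short input. The special token ids (`0` for CLS, `1` for padding, `2` for SEP) and the four text ids are assumptions for illustration; in the notebook they come from the tokenizer. The sketch uses plain lists instead of tensors, so it only mirrors the logic and is not the notebook's own code.

```python
from itertools import accumulate

CLS_ID, PAD_ID, SEP_ID = 0, 1, 2  # assumed RoBERTa-style special token ids

def build_input_and_baseline(text_ids):
    """Mirror construct_input_ref_pair: same length, text replaced by padding."""
    input_ids = [CLS_ID] + text_ids + [SEP_ID]
    ref_input_ids = [CLS_ID] + [PAD_ID] * len(text_ids) + [SEP_ID]
    return input_ids, ref_input_ids

def build_position_ids(input_ids):
    """Mirror the cumsum trick: non-pad tokens count up from PAD_ID + 1."""
    mask = [int(tok != PAD_ID) for tok in input_ids]
    running = list(accumulate(mask))
    return [count * m + PAD_ID for count, m in zip(running, mask)]

text_ids = [713, 16, 10, 2225]  # hypothetical ids for a four-token text
input_ids, ref_input_ids = build_input_and_baseline(text_ids)
position_ids = build_position_ids(input_ids)

print(input_ids)      # [0, 713, 16, 10, 2225, 2]
print(ref_input_ids)  # [0, 1, 1, 1, 1, 2]
print(position_ids)   # [2, 3, 4, 5, 6, 7]
```

The baseline keeps the CLS and separator tokens and replaces only the text with padding, so the attributions measure the effect of the text tokens relative to an "empty" document of the same length.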

### Import Dataset

Here we import the papers dataset.

In [None]:
from datasets import load_dataset
import numpy as np
cogs402_ds = load_dataset("danielhou13/cogs402dataset")["test"]


To use the news dataset instead, uncomment the line below.

In [None]:
# cogs402_ds = load_dataset("danielhou13/cogs402dataset2")["validation"]

## Getting the Attributions

`predict` returns the model's raw logits. `custom_forward`, the forward function we pass to Captum, applies a softmax to those logits to obtain the class probabilities that the model uses for prediction.

In [None]:
def predict(inputs, position_ids=None, attention_mask=None):
    output = model(inputs,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    return output.logits

In [None]:
# forward function for Captum: returns class probabilities; the class to explain
# is chosen later through `target` (1 for the positive class, 0 for the negative class)
def custom_forward(inputs, position_ids=None, attention_mask=None):
    preds = predict(inputs,
                   position_ids=position_ids,
                   attention_mask=attention_mask
                   )
    return torch.softmax(preds, dim = 1)
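
The only difference between `predict` and `custom_forward` is the softmax. The following dependency-free sketch, using made-up logits, shows how a pair of logits becomes class probabilities and which value is explained when `target=1` is passed to the attribution call. The `softmax` helper and the logit values are hypothetical stand-ins for `torch.softmax` and the model output.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-1.2, 2.3]  # hypothetical logits for [negative, positive]
probs = softmax(logits)
positive_prob = probs[1]  # the scalar that attributions explain when target=1

print(probs, positive_prob)
```

Attributing to the probability rather than the raw logit means the scores explain changes in the positive-class probability between the baseline and the real input.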

A helper function to summarize the attributions for each token in the sequence. The attribution output has a shape of (1, seq_len, model_embedding_size), so this function sums over the embedding dimension and divides by the L2 norm, producing one score per token in a tensor of shape (seq_len).

In [None]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.linalg.norm(attributions)
    return attributions
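
As a worked example of the summarization step, here is a plain-Python sketch on a tiny hypothetical attribution matrix (3 tokens, 2 embedding dimensions). It mirrors the sum and L2 normalization above with lists instead of tensors; the numbers are invented so the arithmetic is easy to follow.

```python
import math

def summarize(attributions):
    """Sum each token's embedding-dimension attributions, then L2-normalize."""
    per_token = [sum(row) for row in attributions]
    norm = math.sqrt(sum(v * v for v in per_token))
    return [v / norm for v in per_token]

# hypothetical attributions: 3 tokens x 2 embedding dimensions
toy_attributions = [
    [2.0, 1.0],    # sums to  3.0
    [-1.0, -3.0],  # sums to -4.0
    [0.5, -0.5],   # sums to  0.0
]
scores = summarize(toy_attributions)
print(scores)  # [0.6, -0.8, 0.0]
```

Because of the normalization, the scores show the relative importance of tokens within one example; their sum no longer equals the change in the model's output.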

Perform Layer Integrated Gradients on the Longformer's embedding layer.

In [None]:
lig = LayerIntegratedGradients(custom_forward, model.longformer.embeddings)

This function builds the example and baseline inputs, runs integrated gradients, and adds the result to our visualization records. It also stores the attributions and tokens for each example in dictionaries keyed by the example index, so we can examine the attribution scores further later on. More information about the integrated gradients function can be found [here](https://captum.ai/api/layer.html#layer-integrated-gradients).

In [None]:
vis_data_records = []
all_attributions = {}
all_tokens = {}
all_deltas = {}

In [None]:
# Takes in dataset and example number
def get_token_attributions(dataset, example):
  text = dataset['text'][example]
  label = dataset['labels'][example]

  # get the inputs, position ids, attention mask, and the baselines
  input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
  position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
  attention_mask = construct_attention_mask(input_ids)

  #get the tokens
  indices = input_ids[0].detach().tolist()
  all_tokens_curr = tokenizer.convert_ids_to_tokens(indices)
  all_tokens[str(example)] = all_tokens_curr

  #perform integrated gradients
  attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True,
                                    additional_forward_args=(position_ids, attention_mask),
                                    target=1,
                                    n_steps=250,
                                    internal_batch_size = 2)

  # We want one value for every token.
  attributions_sum = summarize_attributions(attributions)

  # store the values in our dictionaries
  all_attributions[str(example)] = attributions_sum
  all_deltas[str(example)] = delta

  # get the score for our visualization
  score = predict(input_ids, position_ids, attention_mask)

  # storing couple samples in an array for visualization purposes
  # requires array of attributions, prediction score, predicted class, true class 
  # the label you want your attributions to associate positive with, the attribution score
  # the tokens, and the delta if you have it.
  vis_data_records.append(viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.softmax(score, dim = 1).max(),
                        torch.argmax(torch.softmax(score, dim = 1)),
                        label,
                        str(1),
                        attributions_sum.sum(),       
                        all_tokens_curr,
                        delta)
  )
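
For intuition about what `n_steps` and the convergence delta mean, here is a minimal sketch of integrated gradients on a toy function with a known gradient. It is not Captum's implementation: it uses a simple midpoint Riemann sum along the straight path from baseline to input, and the function `f`, its gradient, and the inputs are all hypothetical.

```python
def f(x):
    """Toy differentiable model: f(x) = x0^2 + 3*x1."""
    return x[0] ** 2 + 3.0 * x[1]

def grad_f(x):
    """Analytic gradient of f."""
    return [2.0 * x[0], 3.0]

def integrated_gradients(x, baseline, n_steps=50):
    """Approximate IG with a midpoint Riemann sum along the straight path."""
    totals = [0.0] * len(x)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_f(point)):
            totals[i] += g
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, totals)]

x, baseline = [2.0, 1.0], [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
# Completeness: attributions should sum to f(x) - f(baseline).
delta = sum(attributions) - (f(x) - f(baseline))
print(attributions, delta)
```

A delta close to zero indicates that the numerical integration has converged. If the delta for a real example is large, increasing `n_steps` is the usual remedy, at the cost of more forward and backward passes.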

Here we take some examples from the papers dataset.

In [None]:
get_token_attributions(cogs402_ds, 976)
# get_token_attributions(cogs402_ds, 891)
# get_token_attributions(cogs402_ds, 589)
# get_token_attributions(cogs402_ds, 605)
# get_token_attributions(cogs402_ds, 148)

Further examples can be attributed the same way; uncomment the calls you want to run.

In [None]:
# get_token_attributions(cogs402_ds, 102)
# get_token_attributions(cogs402_ds, 1168)
# # get_token_attributions(cogs402_ds, 2307)
# # get_token_attributions(cogs402_ds, 2359)

This function allows us to display our attributions in a manner that is easy to read. Each token is highlighted with its attribution: green represents positive attributions (i.e. the token pushes the model toward predicting the positive class), while red represents negative attributions.

In [None]:
# # storing couple samples in an array for visualization purposes
# score_vis = viz.VisualizationDataRecord(
#                         attributions_sum,
#                         torch.softmax(score, dim = 1).max(),
#                         torch.argmax(torch.softmax(score, dim = 1)),
#                         label,
#                         str(1),
#                         attributions_sum.sum(),       
#                         all_tokens,
#                         delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
_ = viz.visualize_text(vis_data_records)

Visualization For Score


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (1.00),1.0,7.9,"#s Published Ġas Ġa Ġconference Ġpaper Ġin ĠInternational ĠConference Ġof ĠComputer ĠVision Ġ( IC CV ) Ġ2017 Ġ ĠSpeaking Ġthe ĠSame ĠLanguage : ĠMatch ing ĠMachine Ġto ĠHuman ĠCapt ions Ġby ĠAd vers arial ĠTraining ĠRak sh ith ĠShe tty 1 Ġ ĠMarcus ĠRoh r bach 2 , 3 Ġ Ġar X iv : 17 03 . 10 476 v 2 Ġ[ cs . CV ] Ġ6 ĠNov Ġ2017 Ġ ĠMario ĠFritz 1 Ġ1 Ġ ĠLisa ĠAnne ĠHendricks 2 Ġ ĠBer nt ĠS chie le 1 Ġ ĠMax ĠPlan ck ĠInstitute Ġfor ĠIn format ics , ĠSa ar land ĠIn format ics ĠCampus , ĠSa arb ru Ì Ī ck en , ĠGermany Ġ2 Ġ3 ĠUC ĠBerkeley ĠE EC S , ĠCA , ĠUnited ĠStates ĠFacebook ĠAI ĠResearch Ġ ĠAbstract ĠWhile Ġstrong Ġprogress Ġhas Ġbeen Ġmade Ġin Ġimage Ġcaption ing Ġrecently , Ġmachine Ġand Ġhuman Ġcapt ions Ġare Ġstill Ġquite Ġdistinct . ĠThis Ġis Ġprimarily Ġdue Ġto Ġthe Ġdeficiencies Ġin Ġthe Ġgenerated Ġword Ġdistribution , Ġvocabulary Ġsize , Ġand Ġstrong Ġbias Ġin Ġthe Ġgenerators Ġtowards Ġfrequent Ġcapt ions . ĠFurthermore , Ġhumans ĠâĢĵ Ġrightfully Ġso ĠâĢĵ Ġgenerate Ġmultiple , Ġdiverse Ġcapt ions , Ġdue Ġto Ġthe Ġinherent Ġambiguity Ġin Ġthe Ġcaption ing Ġtask Ġwhich Ġis Ġnot Ġexplicitly Ġconsidered Ġin Ġtoday âĢ Ļ s Ġsystems . ĠTo Ġaddress Ġthese Ġchallenges , Ġwe Ġchange Ġthe Ġtraining Ġobjective Ġof Ġthe Ġcaption Ġgenerator Ġfrom Ġreprodu cing Ġground truth Ġcapt ions Ġto Ġgenerating Ġa Ġset Ġof Ġcapt ions Ġthat Ġis Ġindistinguishable Ġfrom Ġhuman Ġwritten Ġcapt ions . ĠInstead Ġof Ġhand craft ing Ġsuch Ġa Ġlearning Ġtarget , Ġwe Ġemploy Ġadvers arial Ġtraining Ġin Ġcombination Ġwith Ġan Ġapproximate ĠG umb el Ġsam pler Ġto Ġimplicitly Ġmatch Ġthe Ġgenerated Ġdistribution Ġto Ġthe Ġhuman Ġone . ĠWhile Ġour Ġmethod Ġachieves Ġcomparable Ġperformance Ġto Ġthe Ġstate - of - the - art Ġin Ġterms Ġof Ġthe Ġcorrectness Ġof Ġthe Ġcapt ions , Ġwe Ġgenerate Ġa Ġset Ġof Ġdiverse Ġcapt ions Ġthat Ġare Ġsignificantly Ġless Ġbiased Ġand Ġbetter Ġmatch Ġthe Ġglobal Ġun i -, Ġbi - Ġand Ġtri - gram Ġdistributions Ġof Ġthe Ġhuman Ġcapt ions . 
Ġ ĠO urs : Ġa Ġperson Ġon Ġsk is Ġjumping Ġover Ġa Ġramp Ġ ĠO urs : Ġa Ġsk ier Ġis Ġmaking Ġa Ġturn Ġon Ġa Ġcourse Ġ ĠO urs : Ġa Ġcross Ġcountry Ġsk ier Ġmakes Ġhis Ġway Ġthrough Ġthe Ġsnow Ġ ĠO urs : Ġa Ġsk ier Ġis Ġheaded Ġdown Ġa Ġsteep Ġslope Ġ ĠBas eline : Ġa Ġman Ġriding Ġsk is Ġdown Ġa Ġsnow Ġcovered Ġslope Ġ ĠFigure Ġ1 : ĠFour Ġimages Ġfrom Ġthe Ġtest Ġset , Ġall Ġrelated Ġto Ġskiing , Ġshown Ġwith Ġcapt ions Ġfrom Ġour Ġadvers arial Ġmodel Ġand Ġa Ġbaseline . ĠBas eline Ġmodel Ġdescribes Ġall Ġfour Ġimages Ġwith Ġone Ġgeneric Ġcaption , Ġwhereas Ġour Ġmodel Ġproduces Ġdiverse Ġand Ġmore Ġimage Ġspecific Ġcapt ions . ĠAs Ġwe Ġanalyze Ġin Ġthis Ġpaper , Ġthis Ġis Ġlikely Ġdue Ġto Ġartifacts Ġand Ġdeficiencies Ġin Ġthe Ġstatistics Ġof Ġthe Ġgenerated Ġcapt ions , Ġwhich Ġis Ġmore Ġapparent Ġwhen Ġobserving Ġmultiple Ġsamples . ĠSpecifically , Ġwe Ġobserve Ġthat Ġstate - of - the - art Ġsystems Ġfrequently ĠâĢ ľ reve al Ġthemselves âĢ Ŀ Ġby Ġgenerating Ġa Ġdifferent Ġword Ġdistribution Ġand Ġusing Ġsmaller Ġvocabulary . ĠFurther Ġscrutiny Ġreveals Ġthat Ġgeneral ization Ġfrom Ġthe Ġtraining Ġset Ġis Ġstill Ġchallenging Ġand Ġgeneration Ġis Ġbiased Ġto Ġfrequent Ġfragments Ġand Ġcapt ions . ĠAlso , Ġtoday âĢ Ļ s Ġsystems Ġare Ġevaluated Ġto Ġproduce Ġa Ġsingle Ġcaption . ĠYet , Ġmultiple Ġpotentially Ġdistinct Ġcapt ions Ġare Ġtypically Ġcorrect Ġfor Ġa Ġsingle Ġimage ĠâĢĵ Ġa Ġproperty Ġthat Ġis Ġreflected Ġin Ġhuman Ġground - truth . ĠThis Ġdiversity Ġis Ġnot Ġequally Ġreproduced Ġby Ġstate - of - the - art Ġcaption Ġgenerators Ġ[ 40 , Ġ23 ]. ĠTherefore , Ġour Ġgoal Ġis Ġto Ġmake Ġimage Ġcapt ions Ġless Ġdistinguish able Ġfrom Ġhuman Ġones ĠâĢĵ Ġsimilar Ġin Ġthe Ġspirit Ġto Ġa ĠTuring Ġ Ġ1 . ĠIntroduction ĠImage Ġcaption ing Ġsystems Ġhave Ġa Ġvariety Ġof Ġapplications Ġranging Ġfrom Ġmedia Ġretrieval Ġand Ġtagging Ġto Ġassistance Ġfor Ġthe Ġvisually Ġimpaired . 
ĠIn Ġparticular , Ġmodels Ġwhich Ġcombine Ġstate - of - the - art Ġimage Ġrepresentations Ġbased Ġon Ġdeep Ġconv olution al Ġnetworks Ġand Ġdeep Ġrecurrent Ġlanguage Ġmodels Ġhave Ġled Ġto Ġever Ġincreasing Ġperformance Ġon Ġevaluation Ġmetrics Ġsuch Ġas ĠC ID Er Ġ[ 39 ] Ġand ĠMET E OR Ġ[ 8 ] Ġas Ġcan Ġbe Ġseen Ġe . g . Ġon Ġthe ĠC OC O Ġimage ĠCaption Ġchallenge Ġleader board Ġ[ 6 ]. ĠDespite Ġthese Ġadvances , Ġit Ġis Ġoften Ġeasy Ġfor Ġhumans Ġto Ġdifferentiate Ġbetween Ġmachine Ġand Ġhuman Ġcapt ions ĠâĢĵ Ġparticularly Ġwhen Ġobserving Ġmultiple Ġcapt ions Ġfor Ġa Ġsingle Ġimage . Ġ1 Ġ Ġ Č 2 . ĠRelated ĠWork Ġ Ġa Ġbus Ġthat Ġhas Ġpulled Ġinto Ġthe Ġside Ġof Ġthe Ġstreet Ġa Ġbus Ġis Ġparked Ġat Ġthe Ġside Ġof Ġthe Ġroad Ġa Ġwhite Ġbus Ġis Ġparked Ġnear Ġa Ġcurb Ġwith Ġpeople Ġwalking Ġby Ġ Ġa Ġgroup Ġof Ġpeople Ġstanding Ġoutside Ġin Ġa Ġold Ġmuseum Ġan Ġairplane Ġshow Ġwhere Ġpeople Ġstand Ġaround Ġa Ġline Ġof Ġplanes Ġparked Ġat Ġan Ġairport Ġshow Ġ ĠBase ĠâĢ¢ Ġa Ġbus Ġis Ġparked Ġon Ġthe Ġside Ġof Ġline Ġthe Ġroad ĠâĢ¢ Ġa Ġbus Ġthat Ġis Ġparked Ġin Ġthe Ġstreet Ġa Ġbus Ġis Ġparked Ġin Ġthe Ġstreet Ġnext Ġto Ġa Ġbus Ġ Ġa Ġgroup Ġof Ġpeople Ġstanding Ġaround Ġa Ġplane Ġa Ġgroup Ġof Ġpeople Ġstanding Ġaround Ġa Ġplane Ġa Ġgroup Ġof Ġpeople Ġstanding Ġaround Ġa Ġplane Ġ ĠO urs Ġ ĠFigure Ġ2 : ĠTwo Ġexamples Ġcomparing Ġmultiple Ġcapt ions Ġgenerated Ġby Ġour Ġadvers arial Ġmodel Ġand Ġthe Ġbaseline . ĠBi - gram s Ġwhich Ġare Ġtop - 20 Ġfrequent Ġbi - gram s Ġin Ġthe Ġtraining Ġset Ġare Ġmarked Ġin Ġred Ġ( e . g ., ĠâĢ ľ a Ġgroup âĢ Ŀ Ġand ĠâĢ ľ group Ġof âĢ Ŀ ). ĠCapt ions Ġwhich Ġare Ġrepl icas Ġfrom Ġtraining Ġset Ġare Ġmarked Ġwith ĠâĢ¢ Ġ. ĠTest . ĠWe Ġalso Ġembrace Ġthe Ġambiguity Ġof Ġthe Ġtask Ġand Ġextend Ġour Ġinvestigation Ġto Ġpredicting Ġsets Ġof Ġcapt ions Ġfor Ġa Ġsingle Ġimage Ġand Ġevaluating Ġtheir Ġquality , Ġparticularly Ġin Ġterms Ġof Ġthe Ġdiversity Ġin Ġthe Ġgenerated Ġset . 
ĠIn Ġcontrast , Ġpopular Ġapproaches Ġto Ġimage Ġcaption ing Ġare Ġtrained Ġwith Ġan Ġobjective Ġto Ġreproduce Ġthe Ġcapt ions Ġas Ġprovided Ġby Ġthe Ġground - truth . ĠInstead Ġof Ġrelying Ġon Ġhand craft ing Ġloss - fun ctions Ġto Ġachieve Ġour Ġgoal , Ġwe Ġpropose Ġan Ġadvers arial Ġtraining Ġmechanism Ġfor Ġimage Ġcaption ing . ĠFor Ġthis Ġwe Ġbuild Ġon ĠGener ative ĠAd vers arial ĠNetworks Ġ( GAN s ) Ġ[ 14 ], Ġwhich Ġhave Ġbeen Ġsuccessfully Ġused Ġto Ġgenerate Ġmainly Ġcontinuous Ġdata Ġdistributions Ġsuch Ġas Ġimages Ġ[ 9 , Ġ30 ], Ġalthough Ġexceptions Ġexist Ġ[ 27 ]. ĠIn Ġcontrast Ġto Ġimages , Ġcapt ions Ġare Ġdiscrete , Ġwhich Ġposes Ġa Ġchallenge Ġwhen Ġtrying Ġto Ġback prop agate Ġthrough Ġthe Ġgeneration Ġstep . ĠTo Ġovercome Ġthis Ġobstacle , Ġwe Ġuse Ġa ĠG umb el Ġsam pler Ġ[ 20 , Ġ28 ] Ġthat Ġallows Ġfor Ġend - to - end Ġtraining . ĠWe Ġaddress Ġthe Ġproblem Ġof Ġcaption Ġset Ġgeneration Ġfor Ġimages Ġand Ġdiscuss Ġmetrics Ġto Ġmeasure Ġthe Ġcaption Ġdiversity Ġand Ġcompare Ġit Ġto Ġhuman Ġground - truth . ĠWe Ġcontribute Ġa Ġnovel Ġsolution Ġto Ġthis Ġproblem Ġusing Ġan Ġadvers arial Ġformulation . ĠThe Ġevaluation Ġof Ġour Ġmodel Ġshows Ġthat Ġaccuracy Ġof Ġgenerated Ġcapt ions Ġis Ġon Ġpar Ġto Ġthe Ġstate - of - the - art , Ġbut Ġwe Ġgreatly Ġincrease Ġthe Ġdiversity Ġof Ġthe Ġcaption Ġsets Ġand Ġbetter Ġmatch Ġthe Ġground - truth Ġstatistics Ġin Ġseveral Ġmeasures . ĠQual itatively , Ġour Ġmodel Ġproduces Ġmore Ġdiverse Ġcapt ions Ġacross Ġimages Ġcontaining Ġsimilar Ġcontent Ġ( Figure Ġ1 ) Ġand Ġwhen Ġsampling Ġmultiple Ġcapt ions Ġfor Ġan Ġimage Ġ( see Ġsupplementary ) 1 Ġ. Ġ1 Ġhttps :// goo . gl / 3 y R V n q Ġ ĠImage ĠDescription . ĠEarly Ġcaption ing Ġmodels Ġrely Ġon Ġfirst Ġrecognizing Ġvisual Ġelements , Ġsuch Ġas Ġobjects , Ġattributes , Ġand Ġactivities , Ġand Ġthen Ġgenerating Ġa Ġsentence Ġusing Ġlanguage Ġmodels Ġsuch Ġas Ġa Ġtemplate Ġmodel Ġ[ 13 ], Ġn - gram Ġmodel Ġ[ 22 ], Ġor Ġstatistical Ġmachine Ġtranslation Ġ[ 34 ]. 
ĠAdv ances Ġin Ġdeep Ġlearning Ġhave Ġled Ġto Ġend - to - end Ġtrain able Ġmodels Ġthat Ġcombine Ġdeep Ġconv olution al Ġnetworks Ġto Ġextract Ġvisual Ġfeatures Ġand Ġrecurrent Ġnetworks Ġto Ġgenerate Ġsentences Ġ[ 11 , Ġ41 , Ġ21 ]. ĠThough Ġmodern Ġdescription Ġmodels Ġare Ġcapable Ġof Ġproducing Ġcoherent Ġsentences Ġwhich Ġaccurately Ġdescribe Ġan Ġimage , Ġthey Ġtend Ġto Ġproduce Ġgeneric Ġsentences Ġwhich Ġare Ġreplicated Ġfrom Ġthe Ġtrain Ġset Ġ[ 10 ]. ĠFurthermore , Ġan Ġimage Ġcan Ġcorrespond Ġto Ġmany Ġvalid Ġdescriptions . ĠHowever , Ġat Ġtest Ġtime , Ġsentences Ġgenerated Ġwith Ġmethods Ġsuch Ġas Ġbeam Ġsearch Ġare Ġgenerally Ġvery Ġsimilar . Ġ[ 40 , Ġ23 ] Ġfocus Ġon Ġincreasing Ġsentence Ġdiversity Ġby Ġintegrating Ġa Ġdiversity Ġpromoting Ġhe uristic Ġinto Ġbeam Ġsearch . Ġ[ 42 ] Ġattempts Ġto Ġincrease Ġthe Ġdiversity Ġin Ġcaption Ġgeneration Ġby Ġtraining Ġan Ġensemble Ġof Ġcaption Ġgenerators Ġeach Ġspecializing Ġin Ġdifferent Ġportions Ġof Ġthe Ġtraining Ġset . ĠIn Ġcontrast , Ġwe Ġfocus Ġon Ġimproving Ġdiversity Ġof Ġgenerated Ġcapt ions Ġusing Ġa Ġsingle Ġmodel . ĠOur Ġmethod Ġachieves Ġthis Ġby Ġlearning Ġa Ġcorresponding Ġmodel Ġusing Ġa Ġdifferent Ġtraining Ġloss Ġas Ġopposed Ġto Ġafter Ġtraining Ġhas Ġcompleted . ĠWe Ġnote Ġthat Ġgenerating Ġdiverse Ġsentences Ġis Ġalso Ġa Ġchallenge Ġin Ġvisual Ġquestion Ġgeneration , Ġsee Ġconcurrent Ġwork Ġ[ 19 ], Ġand Ġin Ġlanguage - only Ġdialogue Ġgeneration Ġstudied Ġin Ġthe Ġlinguistic Ġcommunity , Ġsee Ġe . g . Ġ[ 23 , Ġ24 ]. ĠWhen Ġtraining Ġrecurrent Ġdescription Ġmodels , Ġthe Ġmost Ġcommon Ġmethod Ġis Ġto Ġpredict Ġa Ġword Ġw t Ġconditioned Ġon Ġan Ġimage Ġand Ġall Ġprevious Ġground Ġtruth Ġwords . ĠAt Ġtest Ġtime , Ġeach Ġword Ġis Ġpredicted Ġconditioned Ġon Ġan Ġimage Ġand Ġpreviously Ġpredicted Ġwords . ĠConsequently , Ġat Ġtest Ġtime Ġpredicted Ġwords Ġmay Ġbe Ġconditioned Ġon Ġwords Ġthat Ġwere Ġincorrectly Ġpredicted Ġby Ġthe Ġmodel . 
ĠBy Ġonly Ġtraining Ġon Ġground Ġtruth Ġwords , Ġthe Ġmodel Ġsuffers Ġfrom Ġexposure Ġbias Ġ[ 31 ] Ġand Ġcannot Ġeffectively Ġlearn Ġto Ġrecover Ġwhen Ġit Ġpredicts Ġan Ġincorrect Ġword Ġduring Ġtraining . ĠTo Ġavoid Ġthis , Ġ[ 4 ] Ġproposes Ġa Ġscheduled Ġsampling Ġtraining Ġscheme Ġwhich Ġbegins Ġby Ġtraining Ġwith Ġground Ġtruth Ġwords , Ġbut Ġthen Ġslowly Ġconditions Ġgenerated Ġwords Ġon Ġwords Ġpreviously Ġproduced Ġby Ġthe Ġmodel . ĠHowever , Ġ[ 17 ] Ġshows Ġthat Ġthe Ġscheduled Ġsampling Ġalgorithm Ġis Ġinconsistent Ġand Ġthe Ġoptimal Ġsolution Ġunder Ġthis Ġobjective Ġdoes Ġnot Ġconverge Ġto Ġthe Ġtrue Ġdata Ġdistribution . ĠTaking Ġa Ġdifferent Ġdirection , Ġ[ 31 ] Ġproposes Ġto Ġaddress Ġthe Ġexposure Ġbias Ġby Ġgradually Ġmixing Ġa Ġsequence Ġlevel Ġloss Ġ( BLE U Ġscore ) Ġusing ĠRE IN FOR CE Ġrule Ġwith Ġthe Ġstandard Ġmaximum Ġlikelihood Ġtraining . ĠSeveral Ġother Ġworks Ġhave Ġfollowed Ġthis Ġup Ġwith Ġusing Ġreinforcement Ġlearning Ġbased Ġapproaches Ġto Ġdirectly Ġoptimize Ġthe Ġevaluation Ġmetrics Ġlike ĠB LE U , ĠMET E OR Ġand ĠC IDER Ġ[ 33 , Ġ25 ]. ĠHowever , Ġoptimizing Ġthe Ġevaluation Ġmetrics Ġdoes Ġnot Ġdirectly Ġaddress Ġthe Ġdiversity Ġof Ġthe Ġ Ġ Č generated Ġcapt ions . ĠSince Ġall Ġcurrent Ġevaluation Ġmetrics Ġuse Ġn - gram Ġmatching Ġto Ġscore Ġthe Ġcapt ions , Ġcapt ions Ġusing Ġmore Ġfrequent Ġn - gram s Ġare Ġlikely Ġto Ġachieve Ġbetter Ġscores Ġthan Ġones Ġusing Ġrare r Ġand Ġmore Ġdiverse Ġn - gram s . ĠIn Ġthis Ġwork , Ġwe Ġformulate Ġour Ġcaption Ġgenerator Ġas Ġa Ġgener ative Ġadvers arial Ġnetwork . ĠWe Ġdesign Ġa Ġdiscrim inator Ġthat Ġexplicitly Ġencourages Ġgenerated Ġcapt ions Ġto Ġbe Ġdiverse Ġand Ġindistinguishable Ġfrom Ġhuman Ġcapt ions . ĠThe Ġgenerator Ġis Ġtrained Ġwith Ġan Ġadvers arial Ġloss Ġwith Ġthis Ġdiscrim inator . 
ĠConsequently , Ġour Ġmodel Ġgenerates Ġcapt ions Ġthat Ġbetter Ġreflect Ġthe Ġway Ġhumans Ġdescribe Ġimages Ġwhile Ġmaintaining Ġsimilar Ġcorrectness Ġas Ġdetermined Ġby Ġa Ġhuman Ġevaluation . ĠGener ative ĠAd vers arial ĠNetworks . ĠThe ĠGener ative ĠAd vers arial ĠNetworks Ġ( GAN s ) Ġ[ 14 ] Ġframework Ġlearns Ġgener ative Ġmodels Ġwithout Ġexplicitly Ġdefining Ġa Ġloss Ġfrom Ġa Ġtarget Ġdistribution . ĠInstead , ĠG AN s Ġlearn Ġa Ġgenerator Ġusing Ġa Ġloss Ġfrom Ġa Ġdiscrim inator Ġwhich Ġtries Ġto Ġdifferentiate Ġreal Ġand Ġgenerated Ġsamples , Ġwhere Ġthe Ġgenerated Ġsamples Ġcome Ġfrom Ġthe Ġgenerator . ĠWhen Ġtraining Ġto Ġgenerate Ġreal Ġimages , #/s"


## Further Examination of the Attributions

Next we might want to look in depth at the attribution scores for each token of an example. We saved the attributions for the examples we looked at above, so we can easily retrieve them. We also retrieve the tokens, because we want to know which token each attribution belongs to.

Both are of shape (seq_len).

In [None]:
example = 976
attributions_sum = all_attributions[f"{example}"]
all_tokens2 = all_tokens[f"{example}"]

These functions return the tokens with the strongest (most positive and most negative) attributions. Change `k` to the number of tokens you wish to inspect. Each function takes the attributions and the tokens retrieved in the previous cell and returns three values: the top-k (or bottom-k) tokens, their attribution values, and their positions.

Note: the attributions are computed with respect to the positive class, so the tokens that most helped the model predict the negative class appear among the bottom-k attributed tokens.

In [None]:
def get_topk_attributed_tokens(attrs, all_tokens, k=20):
    values, indices = torch.topk(attrs, k)
    top_tokens = [all_tokens[idx] for idx in indices]
    return top_tokens, values, indices

In [None]:
def get_botk_attributed_tokens(attrs, all_tokens, k=20):
    values, indices = torch.topk(attrs, k, largest=False)
    top_tokens = [all_tokens[idx] for idx in indices]
    return top_tokens, values, indices
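
The selection logic can be checked on a small hand-made example. This sketch reimplements the idea with plain Python sorting rather than `torch.topk`, so the `top_k` helper, the tokens, and the scores are all hypothetical; the `largest` flag plays the same role as in the functions above.

```python
def top_k(scores, tokens, k=2, largest=True):
    """Return (tokens, values, indices) for the k most extreme scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=largest)[:k]
    return [tokens[i] for i in order], [scores[i] for i in order], order

tokens = ["<s>", "Ġneural", "Ġcode", "Ġtraining", "</s>"]  # hypothetical tokens
scores = [0.0, 0.4, -0.5, 0.7, 0.1]                        # hypothetical scores

print(top_k(scores, tokens))                 # (['Ġtraining', 'Ġneural'], [0.7, 0.4], [3, 1])
print(top_k(scores, tokens, largest=False))  # (['Ġcode', '<s>'], [-0.5, 0.0], [2, 0])
```

Returning the indices alongside the tokens matters because the same token can appear many times in a long document with very different attributions.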

Convert the values, their indices, and the tokens into pandas DataFrames for display. `df_high` is sorted from the highest attribution to the lowest; `df_low`, which holds the most negative attributions, goes from lowest to highest.

In [None]:
top_words_start, top_words_val_start, top_word_ind_start = get_topk_attributed_tokens(attributions_sum, all_tokens2)
bot_words_start, bot_words_val_start, bot_word_ind_start = get_botk_attributed_tokens(attributions_sum, all_tokens2)

df_high = pd.DataFrame({'Word(Index), Attribution': ["{} ({}), {}".format(word, pos, round(val.item(),2)) for word, pos, val in zip(top_words_start, top_word_ind_start, top_words_val_start)]})

df_low = pd.DataFrame({'Word(Index), Attribution': ["{} ({}), {}".format(word, pos, round(val.item(),2)) for word, pos, val in zip(bot_words_start, bot_word_ind_start, bot_words_val_start)]})
# df_start.style.apply(['cell_ids: False'])

# ['{}({})'.format(token, str(i)) for i, token in enumerate(all_tokens)]

Here we display our top k positively and negatively attributed tokens for our example.

In [None]:
df_high

| index | Word(Index), Attribution |
|---|---|
| 0 | Ġtraining (1544), 0.47 |
| 1 | Ġtraining (1593), 0.37 |
| 2 | Ġtraining (1700), 0.35 |
| 3 | Ġtraining (1687), 0.32 |
| 4 | Ġtraining (1791), 0.26 |
| 5 | . (1459), 0.23 |
| 6 | Ġtraining (1538), 0.2 |
| 7 | Ġtraining (1659), 0.19 |
| 8 | Ġtraining (1506), 0.16 |
| 9 | Ġtraining (1705), 0.12 |


In [None]:
df_low

| index | Word(Index), Attribution |
|---|---|
| 0 | Ġlanguage (1334), -0.14 |
| 1 | Ġlanguage (1571), -0.09 |
| 2 | . (1440), -0.08 |
| 3 | Ġlinguistic (1579), -0.07 |
| 4 | . (1688), -0.03 |
| 5 | Ġmodels (1335), -0.03 |
| 6 | Ġa (1536), -0.03 |
| 7 | Ġsentences (1553), -0.03 |
| 8 | - (1572), -0.02 |
| 9 | . (1656), -0.02 |


In [None]:
d = {"tokens":all_tokens2, "attribution":attributions_sum[:len(all_tokens2)].cpu()}

We notice that many tokens repeat within an example at different positions. While we might want to know how position plays into the attributions, if we only care about the tokens themselves, we can sum the attributions of all duplicate tokens to get an aggregate attribution for each token. Therefore, we aggregate the attributions strictly by token type.

In [None]:
df_attrib = pd.DataFrame(d)
aggregation_functions = {'attribution': 'sum'}
df_new = df_attrib.groupby(df_attrib['tokens']).aggregate(aggregation_functions)
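
The aggregation is a group-by-token sum. The sketch below shows the same idea without pandas on a handful of hypothetical token and score pairs, which makes it easy to verify by hand what the `groupby` call is expected to produce.

```python
from collections import defaultdict

tokens = ["Ġtraining", "Ġlanguage", "Ġtraining", ".", "Ġlanguage"]  # hypothetical
scores = [0.5, -0.25, 0.25, 0.125, -0.125]                          # hypothetical

totals = defaultdict(float)
for token, score in zip(tokens, scores):
    totals[token] += score  # same effect as a group-by-token sum

ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
print(ranked)  # [('Ġtraining', 0.75), ('.', 0.125), ('Ġlanguage', -0.375)]
```

Summing favours frequent tokens: a token with many small positive attributions can outrank a rare token with one large attribution, which is worth keeping in mind when reading the tables that follow.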

In [None]:
highest_attrib_tokens = df_new.sort_values(by=['attribution'], ascending=False)
highest_attrib_tokens[:10]

| tokens | attribution |
|---|---|
| Ġtraining | 2.592207 |
| . | 0.439756 |
| Ġhuman | 0.254601 |
| , | 0.216136 |
| Ġto | 0.159274 |
| Ġlearning | 0.152409 |
| Ġmodel | 0.149081 |
| Ġbias | 0.119391 |
| Ġthe | 0.112879 |
| Ġdiversity | 0.099239 |


In [None]:
lowest_attrib_tokens = df_new.sort_values(by=['attribution'])
lowest_attrib_tokens[:10]

| tokens | attribution |
|---|---|
| Ġlanguage | -0.230792 |
| Ġlinguistic | -0.066125 |
| Ġsentences | -0.056785 |
| Ġwords | -0.05628 |
| Ġword | -0.026663 |
| Ġtranslation | -0.018235 |
| Ġ | -0.017787 |
| art | -0.014868 |
| Ġdialogue | -0.014522 |
| Ġdescriptions | -0.013466 |


Using this [notebook](https://colab.research.google.com/drive/1lktilbL1IY4nBanlzCdP8TLsBNfUsl_U?usp=sharing), we can generate the files needed to view the aggregated attributions over the entire dataset for both the positive and negative classes. There, the attributions for every instance of a given token were summed and averaged across the whole dataset, regardless of whether the individual attributions were positive or negative.

In [None]:
df_word = pd.read_csv("/content/drive/MyDrive/cogs402longformer/results/papers/papers_attributions/longformer_emb_papers.csv")
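
The linked notebook is not reproduced here, so the following is only one plausible reading of "summed and averaged": collect every occurrence of a token across all examples, sum its attributions, and divide by the number of occurrences. The per-example data and variable names are hypothetical.

```python
from collections import defaultdict

# hypothetical per-example results: (tokens, attribution scores)
examples = [
    (["Ġlearning", "Ġcode", "Ġlearning"], [0.5, -0.5, 0.25]),
    (["Ġcode", "Ġlearning"], [-0.25, 0.75]),
]

sums, counts = defaultdict(float), defaultdict(int)
for tokens, scores in examples:
    for token, score in zip(tokens, scores):
        sums[token] += score
        counts[token] += 1

averages = {token: sums[token] / counts[token] for token in sums}
ranked = sorted(averages.items(), key=lambda item: item[1], reverse=True)
print(ranked)  # [('Ġlearning', 0.5), ('Ġcode', -0.375)]
```

Sorting the result from highest to lowest gives a table whose head holds the tokens that push toward the positive class and whose tail holds the tokens that push toward the negative class.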

Here we see the highest attributions for the positive class, meaning that these tokens have the most influence when the model predicts the positive class. Most of the content words (for example `learning`, `neural`, `training`, and `AI`) are relevant to AI-related topics.

In [None]:
df_word[:15]

| index | tokens | attribution |
|---|---|---|
| 0 | Ġlearning | 0.163092 |
| 1 | . | 0.145281 |
| 2 | Ġneural | 0.110611 |
| 3 | Ġdata | 0.097347 |
| 4 | , | 0.077573 |
| 5 | Ġthe | 0.072926 |
| 6 | Ġtraining | 0.052609 |
| 7 | Ġdataset | 0.050907 |
| 8 | Ġalgorithms | 0.048352 |
| 9 | ĠAI | 0.045684 |


Here we see the largest attributions for the negative class, meaning that these tokens have the most influence when the model predicts negative.

In [None]:
df_word[:-15:-1]

| index | tokens | attribution |
|---|---|---|
| 30061 | Ġprogramming | -0.121651 |
| 30060 | Ġprogram | -0.085085 |
| 30059 | Ġprograms | -0.078384 |
| 30058 | Ġlanguages | -0.070023 |
| 30057 | Ġlanguage | -0.054024 |
| 30056 | Ġ. | -0.053213 |
| 30055 | Ġcode | -0.049736 |
| 30054 | Ġsoftware | -0.037241 |
| 30053 | Ġcompiler | -0.030792 |
| 30052 | ĠProgramming | -0.029799 |
