In [1]:
import re
import torch
import os
import random
import json
from tqdm import tqdm
from PIL import Image
import numpy as np
import time
import wandb
from queue import Queue
from utils import parse_args, prompt_element, preprocess_language, extract_group, create_prompt, extend_prompts, create_question
import argparse
import pathlib
from bertviz import head_view, model_view

from data import VisualReasoningDataset, custom_collate, load_dataset
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from loguru import logger
import gc
import sys
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
# sys.path.append('/mnt/lustre/share/lychen/code/sm/interpret-lm')
# from lm_saliency import *

logger.remove()
logger.add(sys.stdout, colorize=True, format="<green>{time:YYYY-MM-DD:HH:mm:ss}</green> | <cyan>{name}</cyan><cyan>:line {line}</cyan>: <level>{message}</level>", level="INFO")


2

In [2]:
sys.argv = ['test.py',  
            "--data-dir", "../datasets/",
            "--dataset-name", "aokvqa", 
            "--split", "val", 
            "--bs", "1", 
            "--max-length", "250",
            # "--flan", "google/flan-t5-small",
            # "--flan", "google/flan-t5-large",
            "--flan", "../pretrained_models/fl_pc2a_aok_1/google/flan-t5-large_language_profile_bs64_epoch0",
            # "--flan", "../pretrained_models/fx_pc2a_eg0_f_1/google/flan-t5-xxl_language_profile_bs32_epoch2",
            "--prediction-out", "../predictions/viz_atten-da.json",
            "--prediction-output-dir", "../predictions/",
            "--include-choices",
            "--include-profile",
            "--include-caption",
            ]
args = parse_args()

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
train_set = VisualReasoningDataset(
    dataset_dir=args.data_dir, dataset_name=args.dataset_name, split="train", include_image=False
)
dataset = VisualReasoningDataset(
    dataset_dir=args.data_dir, dataset_name=args.dataset_name, split=args.split, include_image=False
)
dataloader = DataLoader(
    dataset,
    batch_size=args.bs,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
    drop_last=False,
    collate_fn=custom_collate,
)
train_context = {}
context = {}
if args.context_file is not None:
    with open(args.context_file, "r") as f:
        context = json.load(args.context_file)
    with open(args.train_context_file, "r") as f:
        train_context = json.load(args.train_context_file)

# Initiate FLAN model and tokenizer
flan_device = 'cuda:{}'.format(str(0))
flan_tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-small', return_token_type_ids=True)
# flan_model = T5ForConditionalGeneration.from_pretrained(args.flan, output_attentions=True)
if args.adapter_name is not None:
    flan_model.load_adapter(args.adapter_name, set_active=True)
# flan_model.parallelize()

with open(f'{args.prediction_output_dir}/{args.dataset_name}_blip_vqa_train-da.json', 'r') as profile: # hard code
    blip_answers_train = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_ofa_vqa_train-da.json', 'r') as profile: # hard code
    ofa_answers_train = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_blip_caption_train-da.json', 'r') as profile: # hard code
    blip_captions_train = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_ofa_caption_train-da.json', 'r') as profile: # hard code
    ofa_captions_train = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_blip_vqa_{args.split}-da.json', 'r') as profile: # hard code
    blip_answers_val = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_ofa_vqa_{args.split}-da.json', 'r') as profile: # hard code
    ofa_answers_val = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_blip_caption_{args.split}-da.json', 'r') as profile: # hard code
    blip_captions_val = json.load(profile)
with open(f'{args.prediction_output_dir}/{args.dataset_name}_ofa_caption_{args.split}-da.json', 'r') as profile: # hard code
    ofa_captions_val = json.load(profile)

[32m2023-01-19:13:50:33[0m | [36mdata[0m[36m:line 82[0m: [1mLoading aokvqa train dataset[0m
[32m2023-01-19:13:50:34[0m | [36mdata[0m[36m:line 89[0m: [1mLoading dataset took 0.80 seconds[0m
[32m2023-01-19:13:50:34[0m | [36mdata[0m[36m:line 82[0m: [1mLoading aokvqa val dataset[0m
[32m2023-01-19:13:50:34[0m | [36mdata[0m[36m:line 89[0m: [1mLoading dataset took 0.02 seconds[0m


In [4]:
total_predictions = {}
counter = 0
iterator = iter(dataloader)

In [5]:
group = next(iterator)

In [6]:
group = [dataset[349]]

In [7]:
prompts = create_prompt(group, "")
if args.incontext:
    for e in torch.utils.data.Subset(train_set, torch.randperm(len(train_set))[:args.num_examples]):
        prompts = extend_prompts(
            prompts,
            prompt_element(
                e,
                include_choices=args.include_choices,
                include_answer=True,
                include_profile=args.include_profile,
                include_rationale=args.include_rationale,
                include_caption=args.include_caption,
                blip_answers=blip_answers_train,
                ofa_answers=ofa_answers_train,
                ofa_captions=ofa_captions_train,
                blip_captions=blip_captions_train,
                cot=False,
                # rationalization=rationalization_train,
            ),
        )
        prompts = extend_prompts(prompts, "\n\n")
    prompts = extend_prompts(prompts, "\n\n")
    
prompts = extend_prompts(
    prompts,
    create_question(
        group,
        include_choices=args.include_choices,
        include_answer=False,
        include_profile=args.include_profile,
        blip_answers=blip_answers_val,
        ofa_answers=ofa_answers_val,
        include_caption=args.include_caption,
        include_rationale=False,
        ofa_captions=ofa_captions_val,
        blip_captions=blip_captions_val,
        cot=args.cot,
        # rationalization=rationalization_train, # placeholder
    ),
)
if args.cot:
    prompts = extend_prompts(prompts, "\n\n")
    cot = "Answer in the following format: \n[YOUR RATIONALE]. The answer is [YOUR CHOICE]."
    prompts = extend_prompts(prompts, cot)
input_ids, attention_mask = preprocess_language(
    flan_tokenizer, prompts, device=flan_device
)
targets = extract_group(group, 'mc_answer')

In [8]:
prompt = prompts[0]
target = targets[0]

input_ = flan_tokenizer.encode_plus(prompt, return_tensors='pt', add_special_tokens=True, return_token_type_ids=True).to(flan_device)
target_ = flan_tokenizer.encode_plus(target, return_tensors='pt', add_special_tokens=True, return_token_type_ids=True).to(flan_device)

input_ids = input_['input_ids']
target_ids = target_['input_ids']

print(input_ids.shape)
# outputs = flan_model(input_ids, labels=target_ids)

encoder_text = flan_tokenizer.convert_ids_to_tokens(input_ids[0])
decoder_text = flan_tokenizer.convert_ids_to_tokens(target_ids[0])

torch.Size([1, 146])


In [9]:
target_ids.shape

torch.Size([1, 3])

# ECCO

In [10]:
import ecco

In [11]:
# save tokenizer for ecco to load autotokenizer
flan_tokenizer.save_pretrained(args.flan)

('../pretrained_models/fl_pc2a_aok_1/google/flan-t5-large_language_profile_bs64_epoch0/tokenizer_config.json',
 '../pretrained_models/fl_pc2a_aok_1/google/flan-t5-large_language_profile_bs64_epoch0/special_tokens_map.json',
 '../pretrained_models/fl_pc2a_aok_1/google/flan-t5-large_language_profile_bs64_epoch0/spiece.model',
 '../pretrained_models/fl_pc2a_aok_1/google/flan-t5-large_language_profile_bs64_epoch0/added_tokens.json')

In [12]:
model_config = {
    'embedding': 'shared.weight',
    'type':'enc-dec',
    'activations': 'wo', #Note that this will be both encoder and decoder layers
    'token_prefix': '▁',
    'partial_token_prefix': '',
}

lm_cola = ecco.from_pretrained(args.flan, model_config=model_config)
lm_cola.parallelize()

lm_colazero = ecco.from_pretrained('google/flan-t5-large', model_config=model_config)
lm_colazero.parallelize()


In [13]:
output = lm_cola.generate(prompt, generate=5, do_sample=False, attribution=['ig', 'grad_x_input'])
output.primary_attributions(attr_method='ig')

<IPython.core.display.Javascript object>

In [14]:
output.primary_attributions(attr_method='grad_x_input')

In [15]:
output = lm_colazero.generate(prompts[0], generate=5, do_sample=False, attribution=['ig', 'grad_x_input'])
output.primary_attributions(attr_method='ig')

<IPython.core.display.Javascript object>

In [16]:
output.primary_attributions(attr_method='grad_x_input')

In [14]:
# put outputs.encoder_attentions, outputs.decoder_attentions, outputs.cross_attentions on the same cuda device
viz_device = 'cuda:0'

def transfer_device(attention_tuple, device=viz_device):
    atten_list = []
    for i in range(len(attention_tuple)):
        atten_list.append(attention_tuple[i].to(device))
    return tuple(atten_list)
        
outputs.encoder_attentions = transfer_device(outputs.encoder_attentions)
outputs.decoder_attentions = transfer_device(outputs.decoder_attentions)
outputs.cross_attentions = transfer_device(outputs.cross_attentions)

In [15]:
# head_view(
#     encoder_attention=outputs.encoder_attentions,
#     decoder_attention=outputs.decoder_attentions,
#     cross_attention=outputs.cross_attentions,
#     encoder_tokens= encoder_text,
#     decoder_tokens = decoder_text
# )

In [16]:
# page = model_view(
#     encoder_attention=outputs.encoder_attentions,
#     decoder_attention=outputs.decoder_attentions,
#     cross_attention=outputs.cross_attentions,
#     encoder_tokens= encoder_text,
#     decoder_tokens = decoder_text,
#     include_layers=[0,1],
#     include_heads=[0,1],
#     display_mode='light',
#     html_action='return'
# )

In [17]:
# page

# Attention Map

In [18]:
def heatplot_attention(i_layer, attentions, encoder_tokens, decoder_tokens):
  encoder_tokens = [t.strip('▁') for t in encoder_tokens]
  decoder_tokens = [t.strip('▁') for t in decoder_tokens]
  sns.color_palette("flare", as_cmap=True)
  fig = plt.figure(figsize=(300, 50))
  axes = []
  for i_head in range(1):
  # for i_head in range(10):
    axes.append(fig.add_subplot(5, 2, 1 + i_head))
    # axes.append(fig.add_subplot(1, 1, 1 + i_head))
    axes[-1].set_title(f"Head {1 + i_layer}:{1 + i_head}")
    # sns.heatmap(attentions[i_layer][0][i_head].detach().cpu().numpy().T, 
    sns.heatmap(attentions[i_layer][0][i_head].detach().cpu().numpy(), 
                vmin=0, vmax=1, 
                # yticklabels=encoder_tokens, xticklabels=decoder_tokens, 
                yticklabels=decoder_tokens, xticklabels=encoder_tokens, 
                annot=True, ax=axes[-1], 
                cbar=False, 
                cmap='Reds',
                linewidths=1,
                fmt='04.2f')
     

In [19]:
def heatplot_attention(i_layer, i_head, attentions, encoder_tokens, decoder_tokens):
  encoder_tokens = [t.strip('▁') for t in encoder_tokens]
  decoder_tokens = [t.strip('▁') for t in decoder_tokens]
  sns.color_palette("flare", as_cmap=True)
  fig = plt.figure(figsize=(15, 10))
  plt.subplots_adjust(hspace = 0.8)
  axes = []
  batches = 4
  len_trunk = len(encoder_tokens) // batches
  for token_batch in range(batches):
    axes.append(fig.add_subplot(5, 1, 1+token_batch))
    axes[0].set_title(f"Head {1 + i_layer}:{1 + i_head}")
    sns.heatmap(attentions[i_layer][0][i_head][:,token_batch*len_trunk:(token_batch+1)*len_trunk].detach().cpu().numpy(), 
                vmin=0, vmax=1, 
                yticklabels=decoder_tokens, xticklabels=encoder_tokens[token_batch*len_trunk:(token_batch+1)*len_trunk], 
                annot=True, ax=axes[-1], 
                cbar=False, 
                cmap='Reds',
                linewidths=1,
                fmt='04.2f')
    axes[-1].tick_params(axis='x', rotation=45)

In [20]:
# multi-layer multi-head average
def heatplot_attention(i_layers, i_heads, attentions, encoder_tokens, decoder_tokens):
  encoder_tokens = [t.strip('▁') for t in encoder_tokens]
  decoder_tokens = [t.strip('▁') for t in decoder_tokens]
  sns.color_palette("flare", as_cmap=True)
  fig = plt.figure(figsize=(15, 10))
  plt.subplots_adjust(hspace = 0.8)
  axes = []
  batches = 4
  len_trunk = len(encoder_tokens) // batches
  for token_batch in range(batches):
    axes.append(fig.add_subplot(5, 1, 1+token_batch))
    # axes[0].set_title(f"Head {[1+x for x in i_layers]}:{[1+x for x in i_heads]}")
    # print([x for x in itertools.product(i_layers, i_heads)])
    attentions_selected = [attentions[i_layer][0][i_head][:,token_batch*len_trunk:(token_batch+1)*len_trunk] for i_layer, i_head in itertools.product(i_layers, i_heads)]
    stacked_attentions = torch.stack(attentions_selected)
    # mean_attention = torch.mean(stacked_attentions, dim=0)
    mean_attention = torch.max(stacked_attentions, dim=0)[0]

    sns.heatmap(mean_attention.detach().cpu().numpy(), 
                vmin=0, vmax=1, 
                yticklabels=decoder_tokens, xticklabels=encoder_tokens[token_batch*len_trunk:(token_batch+1)*len_trunk], 
                annot=True, ax=axes[-1], 
                cbar=False, 
                cmap='Reds',
                linewidths=1,
                fmt='04.2f')
    axes[-1].tick_params(axis='x', rotation=45)

In [21]:
# layers = range(24)
layers = [22]
heads = range(64)
print('layers:', layers)
print('heads:', heads)
heatplot_attention(layers, heads, outputs.cross_attentions, encoder_text, decoder_text)

layers: [22]
heads: range(0, 64)


### Contrastive Explanation

In [22]:
sys.path.append('/mnt/lustre/share/lychen/code/sm/interpret-lm')
# from lm_saliency_t5_8 import saliency, input_x_gradient, l1_grad_norm, erasure_scores, visualize
from lm_saliency_t5_10 import saliency, input_x_gradient, l1_grad_norm, erasure_scores, visualize

In [23]:
def saliency(model, input_ids, attention_mask, decoder_input_ids, batch=0, correct=None, foil=None):
    torch.enable_grad()
    model.eval()
    embeddings_list = []
    handle = register_embedding_list_hook(model, embeddings_list)
    embeddings_gradients = []
    hook = register_embedding_gradient_hooks(model, embeddings_gradients)
    
    if correct is None:
        correct = input_ids[0][-1]
    input_ids = torch.tensor(input_ids, dtype=torch.long).to(model.device)
    decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long).to(model.device)

    model.zero_grad()
    A = model(input_ids=input_ids, 
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids)

    if foil is not None:
        if correct == foil:
            (A.logits[0][-1][correct]).backward()
        else:
            (A.logits[0][-1][correct]-A.logits[0][-1][foil]).backward()
    else:
        (A.logits[0][-1][correct]).backward()
    handle.remove()
    hook.remove()

    dec_saliency, enc_saliency = embeddings_gradients
    enc_embed, dec_embed = embeddings_list
    return enc_saliency.squeeze(), enc_embed, dec_saliency.squeeze(), dec_embed

def register_embedding_list_hook(model, embeddings_list):
    def forward_hook(module, inputs, output):
        embeddings_list.append(output.squeeze(0).clone().cpu().detach().numpy())
    # embedding_layer = model.get_input_embeddings()
    embedding_layer = model.encoder.embed_tokens
    handle = embedding_layer.register_forward_hook(forward_hook)
    return handle

def register_embedding_gradient_hooks(model, embeddings_gradients):
    def hook_layers(module, grad_in, grad_out):
        embeddings_gradients.append(grad_out[0].detach().cpu().numpy())
    embedding_layer = model.encoder.embed_tokens
    hook = embedding_layer.register_backward_hook(hook_layers)
    return hook

In [24]:
import matplotlib as mpl

def visualize(attention, tokenizer, input_ids, gold=None, normalize=False, print_text=True, save_file=None, title=None, figsize=(60,60), fontsize=36):
    tokens = [tokenizer.decode(i) for i in input_ids[0][:len(attention) + 1]]
    if gold is not None:
        for i, g in enumerate(gold):
            if g == 1:
                tokens[i] = "**" + tokens[i] + "**"

    # Normalize to [-1, 1]
    if normalize:
        a,b = min(attention), max(attention)
        x = 2/(b-a)
        y = 1-b*x
        attention = [g*x + y for g in attention]
    attention = np.array([list(map(float, attention))])

    fig, ax = plt.subplots(figsize=figsize)
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)
    im = ax.imshow(attention, cmap='seismic', norm=norm)

    if print_text:
        ax.set_xticks(np.arange(len(tokens)))
        ax.set_xticklabels(tokens, fontsize=fontsize)
    else:
        ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")
    for (i, j), z in np.ndenumerate(attention):
        ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', fontsize=fontsize)


    ax.set_title("")
    fig.tight_layout()
    if title is not None:
        plt.title(title, fontsize=36)
    
    if save_file is not None:
        plt.savefig(save_file, bbox_inches = 'tight',
        pad_inches = 0)
        plt.close()
    else:
        plt.show()

In [25]:
prompt

"Answer the following multiple choice question by OFA and BLIP's description and their answers to the visual question. OFA and BLIP are two different vision-language models to provide clues. \n\nOFA's description:  actor riding a motorcycle on the set of crime fiction film\nBLIP's description: a man riding a motorcycle with a helmet on\nQ: What is in the motorcyclist's mouth?\nOFA's answer:  a cigarette.\nBLIP's answer: helmet.\n\nChoices: toothpick, food, popsicle stick, cigarette.\nA:"

In [26]:
# Liangyu's code, lm

tokenizer = flan_tokenizer
model = flan_model

# input_text = "Can you stop the dog from "
# input_text = "The dog is "
input_text = prompt[:-2]

# input_tokens = tokenizer(input)['input_ids'].to(model.device)
input = flan_tokenizer.encode_plus(input_text, return_tensors='pt', 
                                   add_special_tokens=True, 
                                   ).to(flan_device)
input_ids, attention_ids = input['input_ids'], input['attention_mask']
# attention_ids = tokenizer(input)['attention_mask']

# decoder_input = ' The answer is'
# decoder_input_text = input_text
# decoder_input_text = f'<pad>'
decoder_input_text = f'A: '
# decoder_input_text = target

decoder_input = flan_tokenizer.encode_plus(decoder_input_text, return_tensors='pt', 
                                                  add_special_tokens=False, 
                                                  ).to(flan_device)
decoder_input_ids, _ = decoder_input['input_ids'], decoder_input['attention_mask']

print(input_ids.shape)
print('test')
print(decoder_input)
print(decoder_input_ids)
print('test 2')

# foil = "drying hair"
# correct = "brushing teeth"
correct = "drying hair"
foil = "brushing teeth"
CORRECT_ID = tokenizer(" "+ correct)['input_ids'][0]
FOIL_ID = tokenizer(" "+ foil)['input_ids'][0]

decoder_start_token_id = 0
print(decoder_start_token_id)
decoder_input_ids = torch.ones((input_ids.shape[0], 2), dtype=torch.long, device=model.device) * decoder_start_token_id

base_enc_saliency, base_enc_embed, base_dec_saliency, base_dec_embed = saliency(model, input_ids, attention_ids, 
                                                                                decoder_input_ids=decoder_input_ids)
enc_saliency, enc_embed, dec_saliency, dec_embed = saliency(model, input_ids, attention_ids, foil=FOIL_ID,
                                                            decoder_input_ids=decoder_input_ids)

In [27]:
def erasure_scores(model, input_ids, decoder_input_ids, correct=None, foil=None, normalize=normalize):
    model.eval()
    if correct is None:
        correct = input_ids[0][-1]
    input_ids = torch.tensor(input_ids, dtype=torch.long).to(model.device)
    decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long).to(model.device)
    
    A = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    softmax = torch.nn.Softmax(dim=0)
    logits = A.logits[0][-1]
    probs = softmax(logits)
    if foil is not None and correct != foil:
        base_score = (probs[correct]-probs[foil]).detach().cpu().numpy()
    else:
        base_score = (probs[correct]).detach().cpu().numpy()

    enc_scores = np.zeros(len(input_ids[0]))
    for i in range(len(input_ids[0])):
        input_ids_i = torch.cat((input_ids[0][:i], input_ids[0][i+1:])).unsqueeze(0)
        
        A = model(input_ids=input_ids_i, decoder_input_ids=decoder_input_ids)
        logits = A.logits[0][-1]
        probs = softmax(logits)
        if foil is not None and correct != foil:
            erased_score = (probs[correct]-probs[foil]).detach().cpu().numpy()
        else:
            erased_score = (probs[correct]).detach().cpu().numpy()
                    
        enc_scores[i] = base_score - erased_score # higher score = lower confidence in correct = more influential input

    dec_scores = np.zeros(len(decoder_input_ids[0]))
    
    print('force decoder_input_ids to be 0,0', decoder_input_ids)
    
    for i in range(len(decoder_input_ids[0])):
        
        print('debug decoder_input_ids', decoder_input_ids)
        
        decoder_input_ids_i = torch.cat((decoder_input_ids[0][:i], decoder_input_ids[0][i+1:])).unsqueeze(0)
        
        print('debug i', i)
        print('debug decoder_input_ids_i', decoder_input_ids_i)
        
        A = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids_i)
        logits = A.logits[0][-1]
        probs = softmax(logits)
        if foil is not None and correct != foil:
            erased_score = (probs[correct]-probs[foil]).detach().cpu().numpy()
        else:
            erased_score = (probs[correct]).detach().cpu().numpy()
                    
        dec_scores[i] = base_score - erased_score # higher score = lower confidence in correct = more influential input
    
    
    if normalize:
        norm = np.linalg.norm(enc_scores, ord=1)
        enc_scores /= norm
        norm = np.linalg.norm(dec_scores, ord=1)
        dec_scores /= norm
        
    return enc_scores, dec_scores

In [28]:
normalize = False

# Input x gradient
base_enc_explanation = input_x_gradient(base_enc_saliency, base_enc_embed, normalize=normalize)
base_dec_explanation = input_x_gradient(base_dec_saliency, base_dec_embed, normalize=normalize)
enc_explanation = input_x_gradient(enc_saliency, enc_embed, normalize=normalize)
dec_explanation = input_x_gradient(dec_saliency, dec_embed, normalize=normalize)

# Gradient norm
base_enc_explanation = l1_grad_norm(base_enc_saliency, normalize=normalize)
base_dec_explanation = l1_grad_norm(base_dec_saliency, normalize=normalize)
enc_explanation = l1_grad_norm(enc_saliency, normalize=normalize)
dec_explanation = l1_grad_norm(dec_saliency, normalize=normalize)  

# Erasure
base_enc_explanation, base_dec_explanation = erasure_scores(model, input_ids, decoder_input_ids, correct=CORRECT_ID, normalize=normalize)
enc_explanation, dec_explanation = erasure_scores(model, input_ids, decoder_input_ids, correct=CORRECT_ID, foil=FOIL_ID, normalize=normalize)

# Normalize
base_norm = np.linalg.norm(np.concatenate((base_enc_explanation, base_dec_explanation)), ord=1)
base_enc_explanation /= base_norm
base_dec_explanation /= base_norm
norm = np.linalg.norm(np.concatenate((enc_explanation, dec_explanation)), ord=1)
enc_explanation /= norm
dec_explanation /= norm

# Visualize
visualize(base_enc_explanation, tokenizer, 
          input_ids, print_text=True, 
          title=f"Why did the model predict {correct}? (encoder input)",
        #   normalize=True,
          figsize=(120, 10))
visualize(base_dec_explanation, tokenizer, 
          decoder_input_ids, print_text=True, title=f"Why did the model predict {correct}? (decoder input)")
visualize(enc_explanation, tokenizer, 
          input_ids, print_text=True, title=f"Why did the model predict {correct} instead of {foil}? (encoder input)",
        #   normalize=True,
          figsize=(200, 10))

In [29]:
decoder_input_ids.shape