In [1]:

import argparse
import json
import logging
import math
import os
import random
from pathlib import Path
import sys
import evaluate
metric = evaluate.load("rouge")
from scipy import stats
from datasets import concatenate_datasets

import datasets
import evaluate
import nltk
import numpy as np
import pandas as pd
from torch import nn
import torch
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate.utils import DummyOptim, DummyScheduler
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from filelock import FileLock
from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from copy import deepcopy
from torch.nn import CrossEntropyLoss
import gc
from accelerate import FullyShardedDataParallelPlugin
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_int8_training,  TaskType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
from peft.utils.other import fsdp_auto_wrap_policy
from datasets import Dataset
from datasets import concatenate_datasets
import seaborn as sns

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    SchedulerType,
    get_scheduler,
    AutoModelForCausalLM,
)
pd.options.display.float_format = '{:,.2f}'.format
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from datasets import Dataset, load_from_disk
pd.set_option('max_colwidth', 800)
pd.set_option('display.max_columns', 100)

def plot_heatmap(df, figsize=(6, 5), fmt='.2f',):
    # Set up the matplotlib figure
    f, ax = plt.subplots(figsize=figsize)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(240, 10, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(df, cmap=cmap, vmax=1.0, vmin=-1.0, cbar_kws={"shrink": .8}, center=0,
                square=True, linewidths=.5, annot=True, fmt=fmt)
    plt.title("Column Correlation Heatmap")
    plt.show()
    

def plot_correlation_heatmap(df, threshold=0, figsize=(6, 5), fmt='.2f', spearman=False):
    import seaborn as sns
    corr = df.corr()
    if spearman:
        from scipy import stats
        res = stats.spearmanr(df.values)
        corr = pd.DataFrame(res.statistic, index=corr.index, columns=corr.columns)

    # corr = corr.where(np.abs(corr) > threshold, 0)

    # Generate a mask for the upper triangle
    mask = np.zeros_like(corr, dtype=bool)
    mask[np.triu_indices_from(mask)] = True

    # Set up the matplotlib figure
    f, ax = plt.subplots(figsize=figsize)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(240, 10, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(corr, mask=mask, cmap=cmap, vmax=1.0, vmin=-1.0, cbar_kws={"shrink": .8}, center=0,
                square=True, linewidths=.5, annot=True, fmt=fmt)
    plt.title("Column Correlation Heatmap")
    plt.show()
    
    fig, ax = plt.subplots(figsize=(8, 6))
    plt.title("Columns as heatmap plot")
    sns.heatmap(df, cmap='coolwarm', ax=ax)
    plt.show()
    
    return corr

def calculate_jsd(x, y):
    if not torch.is_tensor(x):
        x = torch.tensor(x)
    if not torch.is_tensor(y):
        y = torch.tensor(y)
    jsd_m = 0.5 * (x + y)
    jsd = 0.5 * nn.KLDivLoss(reduction='none', log_target=False)(torch.log(x), jsd_m) + 0.5 * nn.KLDivLoss(reduction='none', log_target=False)(torch.log(y), jsd_m)
    jsd = jsd.sum(-1).tolist()
    return jsd
    

  from .autonotebook import tqdm as notebook_tqdm


In [159]:
gender_words = {
    'he': 'she',
 'him': 'her',
 'his': 'hers',
 'himself': 'herself',
 'son': 'daughter',
 'father': 'mother',
 'brother': 'sister',
 'uncle': 'aunt',
 'nephew': 'niece',
 'grandfather': 'grandmother',
 'husband': 'wife',
 'boyfriend': 'girlfriend',
 'male': 'female',
 'king': 'queen',
 'sir': 'madam',
 'actor': 'actress',
 'host': 'hostess',
 'waiter': 'waitress',
 'steward': 'stewardess',
 'policeman': 'policewoman',
 'fireman': 'firewoman',
 'chairman': 'chairwoman',
 'businessman': 'businesswoman',
 'salesman': 'saleswoman',
 'doctor': 'doctora',
 'mr.': 'ms.',
 'groom': 'bride',
 'duke': 'duchess',
 'hero': 'heroine',
 'landlord': 'landlady',
 'manager': 'manageress',
 'monk': 'nun',
 'postman': 'postwoman',
 'prince': 'princess',
 'prophet': 'prophetess',
 'singer': 'songstress',
 'sorcerer': 'sorceress',
 'waitperson': 'server',
 'widower': 'widow',
 'she': 'he',
 'her': 'him',
 'hers': 'his',
 'herself': 'himself',
 'daughter': 'son',
 'mother': 'father',
 'sister': 'brother',
 'aunt': 'uncle',
 'niece': 'nephew',
 'grandmother': 'grandfather',
 'wife': 'husband',
 'girlfriend': 'boyfriend',
 'female': 'male',
 'queen': 'king',
 'madam': 'sir',
 'actress': 'actor',
 'hostess': 'host',
 'waitress': 'waiter',
 'stewardess': 'steward',
 'policewoman': 'policeman',
 'firewoman': 'fireman',
 'chairwoman': 'chairman',
 'businesswoman': 'businessman',
 'saleswoman': 'salesman',
 'doctora': 'doctor',
 'ms.': 'mr.',
 'bride': 'groom',
 'duchess': 'duke',
 'heroine': 'hero',
 'landlady': 'landlord',
 'manageress': 'manager',
 'nun': 'monk',
 'postwoman': 'postman',
 'princess': 'prince',
 'prophetess': 'prophet',
 'songstress': 'singer',
 'sorceress': 'sorcerer',
 'server': 'waitperson',
 'widow': 'widower',
 'He': 'She',
 'Him': 'Her',
 'His': 'Hers',
 'Himself': 'Herself',
 'Son': 'Daughter',
 'Father': 'Mother',
 'Brother': 'Sister',
 'Uncle': 'Aunt',
 'Nephew': 'Niece',
 'Grandfather': 'Grandmother',
 'Husband': 'Wife',
 'Boyfriend': 'Girlfriend',
 'Male': 'Female',
 'King': 'Queen',
 'Sir': 'Madam',
 'Actor': 'Actress',
 'Host': 'Hostess',
 'Waiter': 'Waitress',
 'Steward': 'Stewardess',
 'Policeman': 'Policewoman',
 'Fireman': 'Firewoman',
 'Chairman': 'Chairwoman',
 'Businessman': 'Businesswoman',
 'Salesman': 'Saleswoman',
 'Doctor': 'Doctora',
 'Mr.': 'Ms.',
 'Groom': 'Bride',
 'Duke': 'Duchess',
 'Hero': 'Heroine',
 'Landlord': 'Landlady',
 'Manager': 'Manageress',
 'Monk': 'Nun',
 'Postman': 'Postwoman',
 'Prince': 'Princess',
 'Prophet': 'Prophetess',
 'Singer': 'Songstress',
 'Sorcerer': 'Sorceress',
 'Waitperson': 'Server',
 'Widower': 'Widow',
 'She': 'He',
 'Her': 'Him',
 'Hers': 'His',
 'Herself': 'Himself',
 'Daughter': 'Son',
 'Mother': 'Father',
 'Sister': 'Brother',
 'Aunt': 'Uncle',
 'Niece': 'Nephew',
 'Grandmother': 'Grandfather',
 'Wife': 'Husband',
 'Girlfriend': 'Boyfriend',
 'Female': 'Male',
 'Queen': 'King',
 'Madam': 'Sir',
 'Actress': 'Actor',
 'Hostess': 'Host',
 'Waitress': 'Waiter',
 'Stewardess': 'Steward',
 'Policewoman': 'Policeman',
 'Firewoman': 'Fireman',
 'Chairwoman': 'Chairman',
 'Businesswoman': 'Businessman',
 'Saleswoman': 'Salesman',
 'Doctora': 'Doctor',
 'Ms.': 'Mr.',
 'Bride': 'Groom',
 'Duchess': 'Duke',
 'Heroine': 'Hero',
 'Landlady': 'Landlord',
 'Manageress': 'Manager',
 'Nun': 'Monk',
 'Postwoman': 'Postman',
 'Princess': 'Prince',
 'Prophetess': 'Prophet',
 'Songstress': 'Singer',
 'Sorceress': 'Sorcerer',
 'Server': 'Waitperson',
 'Widow': 'Widower'}

In [97]:
model_name="t5-large"
dataset_name="samsum"
N_FOLD=2

max_length = 512
max_target_length=128
padding="max_length"
proba_columns = [
                # "baseline_lora", 
                 # "inverted_jsd", 
                 "proba_v10_cumulative_windowed", 
                 "proba_v11_cumulative_windowed", 
                 "proba_v12_cumulative_windowed", 
                 # "proba_v10_cumulative_windowed_logsig_w10", 
                 "proba_v10_cumulative_windowed_w10", 
                 # "proba_v12_cumulative_windowed_logsig", 
                 # "proba_v10_cumulative_windowed_logsig", 
                 # "proba_v12_cumulative_windowed_logsig_w8",
                 # "proba_v10_cumulative_logsig_windowed_w8", 
                 # "proba_v12_cumulative_windowed_logsig_w5", 
                 # "proba_v10_cumulative_windowed_logsig_w5", 
                 "proba_v12_cumulative_windowed_w5"]



In [98]:
models = dict()
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset(dataset_name)
# models["pretrained"] = model
generator=pipeline(task='text2text-generation',model=model,tokenizer=tokenizer, max_length=max_target_length)
# generator._forward_params
# generator("summarize: A: Hi Tom, are you busy tomorrow’s afternoon?\r\nB: I’m pretty sure I am. What’s up?\r\nA: Can you go with me to the animal shelter?.\r\nB: What do you want to do?\r\nA: I want to get a puppy for my son.\r\nB: That will make him so happy.\r\nA: Yeah, we’ve discussed it many times. I think he’s ready now.\r\nB: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \r\nA: I'll get him one of those little dogs.\r\nB: One that won't grow up too big;-)\r\nA: And eat too much;-))\r\nB: Do you know which one he would like?\r\nA: Oh, yes, I took him there last Monday. He showed me one that he really liked.\r\nB: I bet you had to drag him away.\r\nA: He wanted to take it home right away ;-).\r\nB: I wonder what he'll name it.\r\nA: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))",
#           max_length=10, num_return_sequences=4, num_beams=4, do_sample=True)


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
100%|██████████| 3/3 [00:00<00:00, 742.57it/s]


In [4]:
from collections import Counter
ctr = Counter([w for sent in map(lambda x: x.strip().split(), list(dataset["train"]["summary"])) for w in sent])
len(ctr)

26730

In [99]:
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
for proba_col in proba_columns: 
    
    md = deepcopy(model)
    if "_lora" not in proba_col:
        state_dict = torch.load(f"outputs/{model_name}/{dataset_name}/folds_{N_FOLD}_{proba_col}_combined/model.pt", map_location='cpu')
        md.load_state_dict(state_dict)
    else:
        
        peft_config = PeftConfig.from_pretrained(f"outputs/{model_name}/{dataset_name}/{proba_col}")
        md = PeftModel.from_pretrained(md, f"outputs/{model_name}/{dataset_name}/{proba_col}")
    md = md.eval()
    models[proba_col] = md
    
    

<All keys matched successfully>

<All keys matched successfully>

<All keys matched successfully>

<All keys matched successfully>

<All keys matched successfully>

In [93]:
# https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/pipelines/text2text_generation.py


In [100]:
dsets = []

for FOLD in range(N_FOLD):
    ds = Dataset.load_from_disk(f"outputs/{model_name}/{dataset_name}/fold_{N_FOLD}_{FOLD}")
    ds = ds.rename_column("proba", f"proba{FOLD}")
    md = deepcopy(model)
    md.load_state_dict(torch.load(f"outputs/{model_name}/{dataset_name}/fold_{N_FOLD}_{FOLD}/model.pt", map_location = 'cpu'))
    md = md.eval()
    models[FOLD] = md
    
    dsets.append(ds)
    
combined_ds = Dataset.load_from_disk(f"outputs/{model_name}/{dataset_name}/folds_{N_FOLD}_combined")
combined_ds_jsd = Dataset.load_from_disk(f"outputs/{model_name}/{dataset_name}/folds_{N_FOLD}_jsd")
combined_ds_jsd = combined_ds_jsd.map(lambda x:{k: v + [v[-1]]*2 for k, v in x.items()})
combined_ds_jsd = combined_ds_jsd.rename_column("proba0", "proba0_jsd").rename_column("proba1", "proba1_jsd")
combined_ds = concatenate_datasets([combined_ds, combined_ds_jsd], axis=1)

md = deepcopy(model)
md.load_state_dict(torch.load(f"outputs/{model_name}/{dataset_name}/baseline/model.pt", map_location = 'cpu'))
models["baseline"] = md


<All keys matched successfully>

<All keys matched successfully>



<All keys matched successfully>

In [8]:
# combined_ds.save_to_disk(f"outputs/{model_name}/{dataset_name}/folds_{N_FOLD}_combined")

In [81]:
!pip install line_profiler

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting line_profiler
  Downloading line_profiler-4.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.9/661.9 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: line_profiler
Successfully installed line_profiler-4.0.3


In [82]:
%load_ext line_profiler

In [125]:

        
        
        
        
def get_one_step_proba(model, labels, input_ids, attention_mask, encoder_outputs=None):
    original_lables = deepcopy(labels)
    # labels = model._shift_right(labels)
    probas = []
    log_proba = 0
    logits = []
    with torch.no_grad():
        outputs = model(input_ids, attention_mask, labels=deepcopy(labels), encoder_outputs=encoder_outputs)
        lm_logits = outputs.logits.softmax(dim=-1).squeeze(0)
        loss = outputs.loss
        
        for i in range(labels.shape[-1]-1):
            actual = original_lables[..., :(i+1)]
            proba = lm_logits[i, actual[0, -1]].item()
            if actual[0, -1].item()!=0:
                log_proba += np.log(proba)
            probas.append(proba)
        return {"probas": probas, "loss": loss.item(), "log_proba": log_proba}
        

In [182]:
def greedy_prefix_decoding(model, tokenizer, prefix_text, input_ids, attention_mask, max_length, encoder_outputs=None):
    eos_token_id, pad_token_id = tokenizer.eos_token_id, tokenizer.pad_token_id
    prefixs = tokenizer(text_target=[prefix_text], max_length=max_length, padding="do_not_pad", 
                    truncation=True, return_tensors="pt", 
                    add_special_tokens=False)
    prefix_ids = prefixs["input_ids"].type(torch.int).to(input_ids.device)
    out_token_ids = deepcopy(prefix_ids)
    
    # prefix_ids = model._shift_right(prefix_ids)
    start_pad = torch.tensor([pad_token_id], device=prefix_ids.device).unsqueeze(0)
    prefix_ids = torch.cat([start_pad, prefix_ids], dim=1)
    assert prefix_ids.shape[0] == 1
    with torch.no_grad():
        if encoder_outputs is None:
            encoder_outputs = model.encoder(input_ids, attention_mask)
        encoder_hidden_states = encoder_outputs[0]
        
        past_key_values = None
        next_token = None
        gen_len = 0
        while next_token is None or next_token[0].item() != eos_token_id and out_token_ids.shape[1] < max_length:
            # print(prefix_ids, tokenizer.batch_decode(prefix_ids))
            with torch.autocast("cuda"):
                decoder_outputs = model.decoder(input_ids=prefix_ids, past_key_values=past_key_values, 
                                                encoder_hidden_states=encoder_hidden_states, 
                                                encoder_attention_mask=attention_mask, use_cache=True)
            past_key_values = decoder_outputs.past_key_values
            gen_len += 1
            if gen_len % 32 == 0:
                past_key_values = None
            sequence_output = decoder_outputs[0]
            if model.config.tie_word_embeddings:
                sequence_output = sequence_output * (model.model_dim**-0.5)
            lm_logits = model.lm_head(sequence_output).softmax(dim=-1) # B, S, D
            next_token_logits = lm_logits[:, -1] # B, D
            next_token = torch.argmax(next_token_logits, dim=1).unsqueeze(1) # B, 1
            out_token_ids = torch.cat([out_token_ids, next_token], dim=1)
            prefix_ids = torch.cat([prefix_ids, next_token], dim=1)
        return out_token_ids

In [154]:
# Noising (He to She), Double stopwords mistake
# Invert gender, 
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

[nltk_data] Downloading package stopwords to /home/ahemf/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [194]:
def run_for_one_model(model, model_key, idx, input_text, label_text, tokenizer, padding, device, **gen_kwargs):
    _ = gc.collect()
    torch.cuda.empty_cache()
    label_ids = tokenizer(text_target=[label_text], max_length=gen_kwargs["max_length"], padding=padding, truncation=True)["input_ids"][0]
    words = label_text.split()
    batch = tokenizer(input_text, max_length=max_length, padding=padding, truncation=True, return_tensors="pt")
    model_texts = dict()
    model_texts["idx"] = idx
    model_texts["label"] = label_text
    model_texts["model_key"] = model_key
    batch["input_ids"] = batch["input_ids"].to(device)
    batch["attention_mask"] = batch["attention_mask"].to(device)
    with torch.no_grad():
        encoder_outputs = model.encoder(batch["input_ids"], batch["attention_mask"])
        generated_ids = model.generate(
                            input_ids=batch["input_ids"],
                            encoder_outputs=encoder_outputs,
                            attention_mask=batch["attention_mask"], 
                            use_cache=True,
                            **gen_kwargs,
                        )
        one_steps_probas = get_one_step_proba(model, 
                                              tokenizer(text_target=[label_text], max_length=gen_kwargs["max_length"], padding="max_length", truncation=True, return_tensors="pt")["input_ids"].to(device), 
                                              batch["input_ids"], batch["attention_mask"], encoder_outputs=encoder_outputs,)
        log_proba_of_label = one_steps_probas["log_proba"]
        model_texts["log_proba_of_label"] = log_proba_of_label
        generated_ids = generated_ids.squeeze().tolist()
        if generated_ids[0] == tokenizer.pad_token_id:
            generated_ids = generated_ids[1:]
        predictions = tokenizer.decode(generated_ids, skip_special_tokens=True)
        model_texts["predictions"] = predictions
        rouge_score = metric.compute(predictions=[predictions], references=[label_text], use_stemmer=True)
        model_texts["rouge_score"] = rouge_score
        
    # Get proba of label
    # inputs_embeds
    
    model_texts["word_index_list"] = list(range(1, len(words) - 1))
    for i in range(1, len(words) - 1):
        _ = gc.collect()
        torch.cuda.empty_cache()
        prefix_text = " ".join(words[:i])
        model_texts[i] = dict()
        model_texts[i]["prefix_text"] = prefix_text
        greedy_prefix_decode = greedy_prefix_decoding(model, 
                                                      tokenizer, 
                                                      prefix_text, 
                                                      batch["input_ids"].type(torch.int), 
                                                      batch["attention_mask"].type(torch.int), 
                                                      max_target_length, encoder_outputs=encoder_outputs,).squeeze().tolist()
        greedy_prefix_predictions = tokenizer.decode(greedy_prefix_decode, skip_special_tokens=True)
        model_texts[i]["predictions"] = greedy_prefix_predictions
        rouge_score = metric.compute(predictions=[greedy_prefix_predictions], references=[label_text], use_stemmer=True)
        model_texts[i]["rouge_score"] = rouge_score
        model_texts[i]["prefix_text_match_label"] = greedy_prefix_predictions.strip().lower() == label_text.strip().lower()
        generation_length_verbatim = 0
        generation_suffix_verbatim_text = ""
        gen_words = greedy_prefix_predictions.lower().split()
        for aw, gw in list(zip(label_text.lower().split(), gen_words))[i:]:
            if aw != gw:
                break
            generation_length_verbatim += 1
            generation_suffix_verbatim_text += (gw+" ")
        generation_suffix_verbatim_text = generation_suffix_verbatim_text.strip()
        model_texts[i]["generation_length_verbatim"] = generation_length_verbatim
        model_texts[i]["generation_suffix_verbatim_text"] = generation_suffix_verbatim_text
        model_texts[i]["generation_length_verbatim_by_prefix_length"] = generation_length_verbatim / i
        model_texts[i]["generation_length_verbatim_by_generation_length"] = generation_length_verbatim / len(gen_words)
        model_texts[i]["generation_length_verbatim_by_label_length"] = generation_length_verbatim / len(words)
        
        if words[:i][-1].lower() in stop_words and False:
            model_texts[i]["stopword_repeat"] = dict()
            model_texts[i]["stopword_repeat"]["stopword"] = words[:i][-1]
            prefix_text = " ".join(words[:i] + [words[:i][-1]])
            greedy_prefix_decode = greedy_prefix_decoding(model, 
                                                      tokenizer, 
                                                      prefix_text, 
                                                      batch["input_ids"].type(torch.int), 
                                                      batch["attention_mask"].type(torch.int), 
                                                      max_target_length, encoder_outputs=encoder_outputs,).squeeze().tolist()
            greedy_prefix_predictions = tokenizer.decode(greedy_prefix_decode, skip_special_tokens=True)
            model_texts[i]["stopword_repeat"]["prefix_text"] = prefix_text
            model_texts[i]["stopword_repeat"]["predictions"] = greedy_prefix_predictions
            generation_length_verbatim = 0
            generation_suffix_verbatim_text = ""
            gen_words = greedy_prefix_predictions.lower().split()
            for aw, gw in list(zip(label_text.lower().split()[i:], gen_words[i+1:])):

                if aw != gw:
                    break
                generation_length_verbatim += 1
                generation_suffix_verbatim_text += (gw+" ")
            generation_suffix_verbatim_text = generation_suffix_verbatim_text.strip()
            model_texts[i]["stopword_repeat"]["generation_length_verbatim"] = generation_length_verbatim
            model_texts[i]["stopword_repeat"]["generation_suffix_verbatim_text"] = generation_suffix_verbatim_text
            model_texts[i]["stopword_repeat"]["generation_length_verbatim_by_prefix_length"] = generation_length_verbatim / i
            model_texts[i]["stopword_repeat"]["generation_length_verbatim_by_generation_length"] = generation_length_verbatim / len(gen_words)
            model_texts[i]["stopword_repeat"]["generation_length_verbatim_by_label_length"] = generation_length_verbatim / len(words)
        if words[:i][-1].lower() in gender_words:
            model_texts[i]["gender_swap"] = dict()
            model_texts[i]["gender_swap"]["stopword"] = gender_words[words[:i][-1]]
            prefix_text = " ".join(words[:i-1] + [gender_words[words[:i][-1]]])
            greedy_prefix_decode = greedy_prefix_decoding(model, 
                                                      tokenizer, 
                                                      prefix_text, 
                                                      batch["input_ids"].type(torch.int), 
                                                      batch["attention_mask"].type(torch.int), 
                                                      max_target_length, encoder_outputs=encoder_outputs,).squeeze().tolist()
            greedy_prefix_predictions = tokenizer.decode(greedy_prefix_decode, skip_special_tokens=True)
            model_texts[i]["gender_swap"]["prefix_text"] = prefix_text
            model_texts[i]["gender_swap"]["predictions"] = greedy_prefix_predictions
            generation_length_verbatim = 0
            generation_suffix_verbatim_text = ""
            gen_words = greedy_prefix_predictions.lower().split()
            for aw, gw in list(zip(label_text.lower().split(), gen_words))[i:]:

                if aw != gw:
                    break
                generation_length_verbatim += 1
                generation_suffix_verbatim_text += (gw+" ")
            generation_suffix_verbatim_text = generation_suffix_verbatim_text.strip()
            model_texts[i]["gender_swap"]["generation_length_verbatim"] = generation_length_verbatim
            model_texts[i]["gender_swap"]["generation_suffix_verbatim_text"] = generation_suffix_verbatim_text
            model_texts[i]["gender_swap"]["generation_length_verbatim_by_prefix_length"] = generation_length_verbatim / i
            model_texts[i]["gender_swap"]["generation_length_verbatim_by_generation_length"] = generation_length_verbatim / len(gen_words)
            model_texts[i]["gender_swap"]["generation_length_verbatim_by_label_length"] = generation_length_verbatim / len(words)
    return model_texts



In [192]:
models["proba_v10_cumulative_windowed_w10"] = models["proba_v10_cumulative_windowed_w10"].to("cpu")

In [193]:
device1 = torch.device("cuda:1")
models[1]=models[1].to("cpu")
models[0]=models[0].to("cpu")
models["baseline"] = models["baseline"].to(device1)
device2 = torch.device("cuda:2")
models["proba_v12_cumulative_windowed"] = models["proba_v12_cumulative_windowed"].to(device2)

In [196]:
from tqdm import trange
results = []
for idx in trange(0, 30):
    label_text = dataset["train"][idx]["summary"]
    input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
    r1 = run_for_one_model(models["baseline"], "baseline", idx, input_text, label_text, tokenizer, padding, device1, **gen_kwargs)
    r2 = run_for_one_model(models["proba_v12_cumulative_windowed"], "proba_v12_cumulative_windowed", idx, input_text, label_text, tokenizer, padding, device2, **gen_kwargs)
    results.append([r1, r2])
    

100%|██████████| 30/30 [28:02<00:00, 56.10s/it]


In [206]:
for idx in trange(91, 101):
    label_text = dataset["train"][idx]["summary"]
    input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
    r1 = run_for_one_model(models["baseline"], "baseline", idx, input_text, label_text, tokenizer, padding, device1, **gen_kwargs)
    r2 = run_for_one_model(models["proba_v12_cumulative_windowed"], "proba_v12_cumulative_windowed", idx, input_text, label_text, tokenizer, padding, device2, **gen_kwargs)
    results.append([r1, r2])
    

 50%|█████     | 10/20 [09:08<09:08, 54.86s/it]

KeyboardInterrupt



In [210]:
normal_verbatims = 0
stopword_verbatims = 0
gender_verbatims = 0
log_proba_diff = 0
for i, (r1, r2) in enumerate(results):
    wl = r1["word_index_list"]
    log_proba_diff += ((r2["log_proba_of_label"] - r1["log_proba_of_label"])/len(wl))
    for ix in wl:
        normal_verbatims += (r2[ix]["generation_length_verbatim"] - r1[ix]["generation_length_verbatim"])
        
        # if "stopword_repeat" in r1[ix]:
        #     stopword_verbatims+= (r2[ix]["stopword_repeat"]["generation_length_verbatim"] - r1[ix]["stopword_repeat"]["generation_length_verbatim"])
        if "gender_swap" in r1[ix]:
            gender_verbatims += (r2[ix]["gender_swap"]["generation_length_verbatim"] - r1[ix]["gender_swap"]["generation_length_verbatim"])
print(normal_verbatims, stopword_verbatims, gender_verbatims, log_proba_diff/len(results))


-264 0 -20 -1.4828793558960458


In [212]:
normal_verbatims = 0
stopword_verbatims = 0
gender_verbatims = 0
log_proba_diff = 0
for i, (r1, r2) in enumerate(results):
    wl = r1["word_index_list"]
    log_proba_diff += ((r2["log_proba_of_label"] - r1["log_proba_of_label"])/len(wl))
    for ix in wl:
        if r2[ix]["generation_length_verbatim"] < r1[ix]["generation_length_verbatim"]:
            print(r2["label"], 
                  "\n", r2[ix]["predictions"],
                  "\n", r2[ix]["prefix_text"],"||", 
                  r2[ix]["generation_suffix_verbatim_text"], "||", r1[ix]["generation_suffix_verbatim_text"])
            print("="*40)
        # r2[ix]["generation_length_verbatim"], r1[ix]["generation_length_verbatim"]
        normal_verbatims += (r2[ix]["generation_length_verbatim"] - r1[ix]["generation_length_verbatim"])
        
        # if "stopword_repeat" in r1[ix]:
        #     stopword_verbatims+= (r2[ix]["stopword_repeat"]["generation_length_verbatim"] - r1[ix]["stopword_repeat"]["generation_length_verbatim"])
        if "gender_swap" in r1[ix]:
            gender_verbatims += (r2[ix]["gender_swap"]["generation_length_verbatim"] - r1[ix]["gender_swap"]["generation_length_verbatim"])
print(normal_verbatims, stopword_verbatims, gender_verbatims, log_proba_diff/len(results))


Amanda baked cookies and will bring Jerry some tomorrow. 
 Amanda baked cookies and will bring them to Jerry tomorrow. 
 Amanda baked cookies and || will bring || will bring jerry some tomorrow.
Amanda baked cookies and will bring Jerry some tomorrow. 
 Amanda baked cookies and will bring them to Jerry tomorrow. 
 Amanda baked cookies and will || bring || bring jerry some tomorrow.
Amanda baked cookies and will bring Jerry some tomorrow. 
 Amanda baked cookies and will bring them to Jerry tomorrow. 
 Amanda baked cookies and will bring ||  || jerry some tomorrow.
Kim may try the pomodoro technique recommended by Tim to get more stuff done. 
 Kim may move her ass tomorrow. 
 Kim may ||  || try the pomodoro technique
Kim may try the pomodoro technique recommended by Tim to get more stuff done. 
 Kim may try to do everything tomorrow. 
 Kim may try ||  || the pomodoro technique
Sam is confused, because he overheard Rick complaining about him as a roommate. Naomi thinks Sam should talk to 

In [None]:
# idx = random.randint(0, len(dataset["train"]))
label_text = dataset["train"][idx]["summary"]
input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'


# Label is a prefix of generation
# Can I tell it the number of words
run_for_one_model(models[0], "0", idx, input_text, label_text, tokenizer, padding, device, **gen_kwargs)


In [None]:
run_for_one_model(models[1], "1", input_text, label_text, tokenizer, padding, device, **gen_kwargs)

In [107]:
gen_kwargs = {
                "max_length": max_target_length,
                "num_beams": 1,
    }


def investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, fold, percent_frac, **more_gen_kwargs):
    gen_kwargs = deepcopy(gen_kwargs)
    gen_kwargs.update(more_gen_kwargs)
    labels = tokenizer(text_target=[label_text], max_length=max_target_length, padding=padding, truncation=True)
    label_ids = labels["input_ids"][0]
    labels = [tokenizer.decode(i) for i in label_ids]
    model_texts = dict()
    prefix_length = int(len(label_text.split()) * percent_frac)
    prefix_text = " ".join(label_text.split()[:prefix_length])
    model_texts["input"] = (input_text,prefix_text, 0, 0)
    model_texts["label"] = (label_text,prefix_text, 0, 0)
    
    
    # pipeline_generation = generator(input_text, **gen_kwargs)
    # pipeline_token_ids = tokenizer(text_target=[pipeline_generation[0]['generated_text']], max_length=gen_kwargs["max_length"], padding=padding, truncation=True)
    # pipeline_token_ids = pipeline_token_ids["input_ids"][0]
    # pipeline_decoded = [tokenizer.decode(i) for i in pipeline_token_ids]
    # model_texts["pipeline"] = (pipeline_generation[0]['generated_text'],prefix_text, 0, 0)

    
    batch = tokenizer(input_text, max_length=max_length, padding=padding, truncation=True, return_tensors="pt")
    model_predictions = dict()

    for FOLD, used_model in models.items():
        with torch.no_grad():
            generated_ids = used_model.generate(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            use_cache=True,
                            **gen_kwargs,
                        )
            
            one_steps_probas = get_one_step_proba(used_model, 
                                                  tokenizer(text_target=[label_text], max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt")["input_ids"], 
                                                  batch["input_ids"], batch["attention_mask"])
            probas = one_steps_probas["probas"]
            if len(probas) < max_target_length:
                probas = probas + ([0] * (max_target_length - len(probas)))
            
            log_proba_of_label = one_steps_probas["log_proba"]
        generated_ids = generated_ids.squeeze().tolist()
        if generated_ids[0] == tokenizer.pad_token_id:
            generated_ids = generated_ids[1:]
        if len(generated_ids) < max_target_length:
            generated_ids = generated_ids + [tokenizer.pad_token_id]*(max_target_length - len(generated_ids))
        generated_ids_decoded = [tokenizer.decode(i) for i in generated_ids]
        model_predictions[f"generated_token_ids_{FOLD}"] = generated_ids
        model_predictions[f"generated_tokens_{FOLD}"] = generated_ids_decoded
        model_predictions[f"generated_probas_{FOLD}"] = probas
        predictions = tokenizer.decode(generated_ids, skip_special_tokens=True)
        
        greedy_prefix_decode = greedy_prefix_decoding(used_model, tokenizer, prefix_text, batch["input_ids"], batch["attention_mask"], max_target_length).squeeze().tolist()
        greedy_prefix_predictions = tokenizer.decode(greedy_prefix_decode, skip_special_tokens=True)
        
        one_steps_probas = get_one_step_proba(used_model, 
                                                  tokenizer(text_target=[predictions], max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt")["input_ids"], 
                                                  batch["input_ids"], batch["attention_mask"])
        log_proba_of_generated = one_steps_probas["log_proba"]
        model_texts[FOLD] = (predictions, greedy_prefix_predictions, log_proba_of_label, log_proba_of_generated)
        
    rouge_scores = dict()
    for FOLD, (predictions, greedy_prefix_predictions, log_proba_of_label, log_proba_of_generated) in model_texts.items():
        rouge_score = metric.compute(predictions=[predictions], references=[label_text], use_stemmer=True)
        rouge_score["text"] = predictions
        rouge_score["prefix_text"] = greedy_prefix_predictions
        rouge_score["prefix_text_match_label"] = greedy_prefix_predictions.strip().lower() == label_text.strip().lower()
        rouge_score["rouge_of_prefix_text"] = metric.compute(predictions=[greedy_prefix_predictions], references=[label_text], use_stemmer=True)
        rouge_score["log_proba_of_label"] = log_proba_of_label
        rouge_score["log_proba_of_generated"] = log_proba_of_generated
        rouge_scores[FOLD] = rouge_score
    rouge_df = pd.DataFrame(rouge_scores.values(), index=rouge_scores.keys())
    rouge_df
    if fold == "train":
        probas_dict = combined_ds[idx]
        pdict = dict()
        # [x.replace("_windowed", "") for x in models.keys() if "_windowed" in str(x)] +
        for k in ["proba0", "proba1"] + [x for x in models.keys()]:
            if k in probas_dict:
                pdict[k] = probas_dict[k]
        probas_dict = pdict
    else:
        probas_dict = dict()
    our_dict = {
                "labels": labels, 
                        # "label_ids": label_ids, 
                        # "pipeline_token_ids": pipeline_token_ids, "pipeline_tokens": pipeline_decoded, 
                **model_predictions}

    probas_dict.update(our_dict)
    # print({k: len(v) for k, v in probas_dict.items()})
    df = pd.DataFrame(probas_dict)
    df = df.loc[((df==tokenizer.pad_token).sum(axis=1) < 6)|(df["labels"]!=tokenizer.pad_token)]
    return rouge_df, df


In [87]:
# idx = random.randint(0, len(dataset["train"]))
# label_text = dataset["train"][idx]["summary"]
# input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
# rouge_df, df = investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, "train", 0.5, do_sample=False, temperature=0.7)
%lprun -f investigate investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, "train", 0.5, do_sample=False, temperature=0.7)
# rouge_df
# print("=" * 80)

Timer unit: 1e-09 s

Total time: 73.6908 s
File: /tmp/ipykernel_84155/3118769970.py
Function: investigate at line 7

Line #      Hits         Time  Per Hit   % Time  Line Contents
     7                                           def investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, fold, percent_frac, **more_gen_kwargs):
     8         1      28108.0  28108.0      0.0      gen_kwargs = deepcopy(gen_kwargs)
     9         1       1852.0   1852.0      0.0      gen_kwargs.update(more_gen_kwargs)
    10         1     574587.0 574587.0      0.0      labels = tokenizer(text_target=[label_text], max_length=max_target_length, padding=padding, truncation=True)
    11         1       4771.0   4771.0      0.0      label_ids = labels["input_ids"][0]
    12         1    1585692.0 1585692.0      0.0      labels = [tokenizer.decode(i) for i in label_ids]
    13         1        611.0    611.0      0.0      model_texts = dict()
    14         1       3940.0   

In [108]:
idx = random.randint(0, len(dataset["train"]))
label_text = dataset["train"][idx]["summary"]
input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
rouge_df, df = investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, "train", 0.9, do_sample=False, temperature=0.7)
rouge_df
print("=" * 80)

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum,text,prefix_text,prefix_text_match_label,rouge_of_prefix_text,log_proba_of_label,log_proba_of_generated
input,0.23,0.08,0.21,0.18,"summarize: Anna: where are you lost?\r\nEma: here only where will i go?\r\nAnna: why aint you joining us for girly parties?\r\nEma: just been busy with school and job\r\nAnna: job? your working?\r\nEma: yes during weekends\r\nAnna: wow! super girl how do you manage\r\nEma: by missing girly parties\r\nAnna: awww why are you working.. its time to enjoy\r\nEma: i need money that why\r\nAnna: aww.. is everything ok?\r\nEma: oh yes absolutely!! its just that i want to start my own business after graduating from university.\r\nAnna: wow.. you amaze me every time i talk to you, you are so ambitious\r\nEma: thank you, i have very long plans and as for enjoyment and parties are concerned i would better have them with my business associates.\r\nAnna: i am so inspired! wish you all the happiness...",Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in,False,"{'rouge1': 0.9473684210526316, 'rouge2': 0.945945945945946, 'rougeL': 0.9473684210526316, 'rougeLsum': 0.9473684210526316}",0.0,0.0
label,1.0,1.0,1.0,1.0,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in thinking about her future.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in,False,"{'rouge1': 0.9473684210526316, 'rouge2': 0.945945945945946, 'rougeL': 0.9473684210526316, 'rougeLsum': 0.9473684210526316}",0.0,0.0
proba_v10_cumulative_windowed,0.52,0.31,0.42,0.42,Ema has been busy with school and job. She wants to start her own business after graduating from university. She will join Anna for girly parties.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established,False,"{'rouge1': 0.9487179487179489, 'rouge2': 0.9210526315789475, 'rougeL': 0.9487179487179489, 'rougeLsum': 0.9487179487179489}",-88.63,-7.74
proba_v11_cumulative_windowed,0.58,0.34,0.5,0.5,Ema hasn't been to girly parties because she's been busy with school and job. She wants to start her own business after graduating from university. Anna is inspired by Ema.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established,False,"{'rouge1': 0.9487179487179489, 'rouge2': 0.9210526315789475, 'rougeL': 0.9487179487179489, 'rougeLsum': 0.9487179487179489}",-108.35,-6.96
proba_v12_cumulative_windowed,0.55,0.38,0.46,0.46,Ema is busy with school and job. She wants to start her own business after graduating from university. She will join Anna for girly parties.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established,False,"{'rouge1': 0.9487179487179489, 'rouge2': 0.9210526315789475, 'rougeL': 0.9487179487179489, 'rougeLsum': 0.9487179487179489}",-112.46,-5.24
proba_v10_cumulative_windowed_w10,0.63,0.49,0.63,0.63,Ema is busy with school and work. She wants to start her own business after graduating from university. Anna is inspired by Ema.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established business.,False,"{'rouge1': 0.9367088607594937, 'rouge2': 0.9090909090909091, 'rougeL': 0.9367088607594937, 'rougeLsum': 0.9367088607594937}",-94.45,-4.85
proba_v12_cumulative_windowed_w5,0.47,0.3,0.38,0.38,Ema hasn't been to the girly parties because she's been busy with school and job. She wants to start her own business after graduating from university.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established,False,"{'rouge1': 0.9487179487179489, 'rouge2': 0.9210526315789475, 'rougeL': 0.9487179487179489, 'rougeLsum': 0.9487179487179489}",-114.12,-4.35
0,0.34,0.0,0.25,0.25,Ema has been working during weekends to earn money for her business. She would rather have parties with her business associates. Anna wishes her good luck and wishes her good grades.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established,False,"{'rouge1': 0.9487179487179489, 'rouge2': 0.9210526315789475, 'rougeL': 0.9487179487179489, 'rougeLsum': 0.9487179487179489}",-98.3,-17.82
1,0.48,0.27,0.43,0.43,Ema has been busy with school and work. She works during weekends to earn money for starting her own business after graduating. She prefers to spend time with her business associates. Anna is impressed with her ambition and wants to be like Ema.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in thinking about,False,"{'rouge1': 0.9743589743589743, 'rouge2': 0.9736842105263158, 'rougeL': 0.9743589743589743, 'rougeLsum': 0.9743589743589743}",-25.22,-21.52
baseline,0.44,0.1,0.31,0.31,Ema is working during weekends to earn money for her business. She is not going to the girly parties with Anna and her friends.,Ema is busy with school and work. She is missing out on girly parties. She is saving up to start her own business after graduating. Anna is amazed and inspired by this. Ema finds motivation in her established,False,"{'rouge1': 0.9487179487179489, 'rouge2': 0.9210526315789475, 'rougeL': 0.9487179487179489, 'rougeLsum': 0.9487179487179489}",-48.89,-18.75




In [None]:
for k in range(20):
    idx = random.randint(0, len(dataset["train"]))
    label_text = dataset["train"][idx]["summary"]
    input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
    rouge_df, df = investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, "train", do_sample=False)
    rouge_df
    print("=" * 80)

In [None]:
for k in range(5):
    idx = random.randint(0, len(dataset["train"]))
    label_text = dataset["train"][idx]["summary"]
    input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
    rouge_df, df = investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, "train", do_sample=False, num_beams=10, temperature=0.6)
    rouge_df
    print("=" * 80)

In [None]:
idx = 3
label_text = dataset["train"][idx]["summary"]
input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
rouge_df, df = investigate(idx, input_text, label_text, models, generator, tokenizer, gen_kwargs, padding, "train")

rouge_df
# df

# do_sample=False, num_beams=1, temperature=0.7

In [117]:
dataset["train"][idx]["dialogue"]
dataset["train"][idx]["summary"]

"Edward: Rachel, I think I'm in ove with Bella..\r\nrachel: Dont say anything else..\r\nEdward: What do you mean??\r\nrachel: Open your fu**ing door.. I'm outside"

'Edward thinks he is in love with Bella. Rachel wants Edward to open his door. Rachel is outside. '

In [12]:
idx = 3
label_text = dataset["train"][idx]["summary"]
altered_label_text = 'Rajeev thinks he is in love with Rashmi. Sima wants Edward to open his door. Sima is outside. '
input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'


altered_input_text = "Rajeev: Sima, I think I'm in ove with Rashmi..\r\nsima: Dont say anything else..\r\nRajeev: What do you mean??\r\nsima: Open your fu**ing door.. I'm outside"
altered_input_text = f'summarize: {altered_input_text}'

rouge_df, df = investigate(idx, altered_input_text, altered_label_text, models, generator, tokenizer, gen_kwargs, padding, "train")

rouge_df
df

# do_sample=False, num_beams=1, temperature=0.7

In [None]:
check_df = df.select_dtypes(include=['float64'])

probas_df_only = check_df
probas_df_only.index = df["labels"]
plot_heatmap(probas_df_only, figsize=(14,14))


In [None]:
plot_correlation_heatmap(probas_df_only, figsize=(10,10))

In [13]:
input_ids = tokenizer(text_target=[input_text], max_length=512, padding="max_length", truncation=True, return_tensors="pt")
labels = tokenizer(text_target=[label_text], max_length=128, padding="max_length", truncation=True, return_tensors="pt")
label_ids = labels["input_ids"]
label_ids.shape
label_ids

model._shift_right(label_ids)

torch.Size([1, 128])

tensor([[ 8200,   317,     7,     3,    88,    19,    16,   333,    28,  5377,
             9,     5, 15868,  2746,  8200,    12,   539,   112,  1365,     5,
         15868,    19,  1067,     5,     1,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

tensor([[    0,  8200,   317,     7,     3,    88,    19,    16,   333,    28,
          5377,     9,     5, 15868,  2746,  8200,    12,   539,   112,  1365,
             5, 15868,    19,  1067,     5,     1,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [96]:
get_one_step_proba(models[0], label_ids, input_ids["input_ids"], input_ids["attention_mask"])["log_proba"]
get_one_step_proba(models[1], label_ids, input_ids["input_ids"], input_ids["attention_mask"])["log_proba"]

-22.270734582341067

-11.429230845551135

In [19]:
tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.pad_token_id

(1, None, 0)

In [66]:
idx = 3
label_text = dataset["train"][idx]["summary"]
input_text = f'summarize: {dataset["train"][idx]["dialogue"]}'
input_ids = tokenizer(text_target=[input_text], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
labels = tokenizer(text_target=[label_text], max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt")
label_ids = labels["input_ids"]

prefix_text = " ".join(label_text.split()[:4]) + " "
label_text
prefix_text
print("="*40)




greedy_decode = greedy_prefix_decoding(models[1], tokenizer, prefix_text, input_ids["input_ids"], input_ids["attention_mask"], max_target_length, )
tokenizer.batch_decode(greedy_decode)

greedy_decode = greedy_prefix_decoding(models[0], tokenizer, prefix_text, input_ids["input_ids"], input_ids["attention_mask"], max_target_length,)
tokenizer.batch_decode(greedy_decode)
tokenizer.decode(greedy_decode.squeeze().tolist(), skip_special_tokens=True)

'Edward thinks he is in love with Bella. Rachel wants Edward to open his door. Rachel is outside. '

'Edward thinks he is '



['Edward thinks he is in love with Bella. Rachel is outside.</s>']

['Edward thinks he is in love with Bella. Rachel is outside, but Edward is outside.</s>']

'Edward thinks he is in love with Bella. Rachel is outside, but Edward is outside.'