# Story Ending Generation (Visualizations)

Erik McGuire

CSC594-810-ADL

Winter 19-20

## I. Imports

In [None]:
"""
from google.colab import drive
drive.mount('/content/drive')
"""

In [None]:
# Required over !pip install transformers for subclassing model/overriding methods.
!git clone https://github.com/huggingface/transformers.git
%cd transformers
!pip install .
%cd ..

In [None]:
%tensorflow_version 2.x
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel, GPT2LMHeadModel

In [None]:
import logging, torch
import torch.nn.functional as F # For softmax viz.
import matplotlib.pyplot as plt # For alt. attention viz.
from typing import Dict, List, Tuple
from ipywidgets import * # Interface.
import seaborn as sns # Attention viz.
import pandas as pd # Data.
import random # Choose random story for attentions.

> We must navigate to the main project folder in mounted My Drive:

In [None]:
%cd drive/My Drive/csc594-ADL

> Assumes the following structure:
<pre>.
├── content
│   ├──drive                         # Mounted drive folder.
│   │   └── My Drive                 # Mounted drive folder.
│   │       └── csc594-ADL           # Main project folder.
│   │           ├── datasets         # ConceptNet and ROCStories.
│   │           ├── endings          # Correct and generated endings per model.
│   │           ├── evals            # Evaluation results for stories and endings per model.
│   │           ├── models           # Pretrained models, tokenizers, vocabulary, etc.
│   │           ├── scripts          # Scripts for training and generation.
│   │           └── stories          # Combined story bodies and generated endings per model.
│   ├── sample_data                  # Default Colab folder.
│   └── transformers                 # Installed from HuggingFace.
└── ...
</pre>

---
## II. Functions
---

#### Attention
* Inspired by Krishan Subudhi's [code](https://krishansubudhi.github.io/deeplearning/2019/09/26/BertAttention.html).

In [None]:
def get_attns(prompt: str, ending: str,
              model, tokenizer,
              mname: str) -> Tuple[List[str], int, torch.Tensor]:
    """Run the model on prompt + ending and collect its attention weights.

    Args:
        prompt: Story body text.
        ending: Ending text. For every model except 'gpt2' it is prefixed
            with the "_delimiter_" special token before tokenizing.
        model: Model whose last output element is the tuple of per-layer
            attentions (i.e. it was loaded with output_attentions=True).
        tokenizer: Tokenizer providing tokenize() and convert_tokens_to_ids().
        mname: Model name; only compared against 'gpt2'.

    Returns:
        (in_tokens, pos_token, attentions) where in_tokens are the tokens
        with the BPE "Ġ" marker removed (for display only), pos_token is the
        index of the first ending token, and attentions is the concatenated
        per-layer attention tensor permuted with (2, 1, 0, 3) -- presumably
        (token, head, layer, token) for a batch of one.
    """
    prompt_tokens = tokenizer.tokenize(prompt)
    if mname == 'gpt2':
        end_tokens = tokenizer.tokenize(ending)
    else:
        end_tokens = tokenizer.tokenize("_delimiter_" + ending)
    pos_token = len(prompt_tokens)
    raw_tokens = prompt_tokens + end_tokens

    # Ids must come from the untouched BPE tokens: "Ġword" and "word" are
    # different vocabulary entries, so stripping first feeds the wrong ids.
    token_ids = tokenizer.convert_tokens_to_ids(raw_tokens)
    in_tokens = [tok.replace("Ġ", "") for tok in raw_tokens]

    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    ids = torch.tensor(token_ids).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(ids)
    attentions = torch.cat(output[-1]).cpu()
    attentions = attentions.permute(2, 1, 0, 3)
    return in_tokens, pos_token, attentions

def display_attns(in_tokens: list, attns: torch.Tensor, pos_token: int) -> None:
    """Draw one heatmap per attention head for the token at `pos_token`.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor indexed as attns[pos_token][head]; presumably
            laid out (token, head, layer, token) as returned by get_attns.
        pos_token: Index of the token whose attention rows are shown.
    """
    attentions_pos = attns[pos_token]
    heads = len(attentions_pos)
    cols = 2
    # Ceiling division so an odd number of heads still gets enough panels.
    rows = -(-heads // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(27, 27), squeeze=False)
    fig.tight_layout(pad=5.0)
    axes = list(axes.flat)
    print('\nMultihead attention weights:')
    for i, att in enumerate(attentions_pos):
        sns.heatmap(att, vmin=0, vmax=1, ax=axes[i], xticklabels=in_tokens)
        axes[i].set_title(f'Head #{i+1} ')
        axes[i].set_ylabel('Layers')
        for tick in axes[i].get_xticklabels():
            tick.set_rotation(45)
            tick.set_fontsize(8)
    # Hide any panel left over when heads is not a multiple of cols.
    for spare in axes[heads:]:
        spare.set_visible(False)

def display_per_head_attn(in_tokens: list,
                          attns: torch.Tensor,
                          pos_token: int,
                          head: int) -> None:
    """Heatmap of one head's attention (rows: layers) for the token at `pos_token`.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor indexed as attns[pos_token][head].
        pos_token: Index of the token of interest; its tick is highlighted.
        head: 1-based head number.
    """
    head_ix = head - 1  # callers pass 1-based head numbers
    fig, axes = plt.subplots(figsize=(20, 8))
    fig.tight_layout(pad=5.0)
    print('\nHeadwise attention weights:')
    sns.heatmap(attns[pos_token][head_ix],
                vmin=0, vmax=1,
                xticklabels=in_tokens,
                ax=axes)
    axes.set_title(f'Head #{head} ')
    axes.set_ylabel('Layers')
    for ix, tick in enumerate(axes.get_xticklabels()):
        tick.set_rotation(55)
        tick.set_fontsize(9)
        # Highlight by index: matching on the tick text and its neighbours
        # missed position 0 and could mark repeated tokens.
        if ix == pos_token:
            tick.set_color("magenta")
            tick.set_fontsize(10)

def display_per_layer_attn(in_tokens: list,
                           attns: torch.Tensor,
                           pos_token: int,
                           layer: int) -> None:
    """Heatmap of one layer's attention (rows: heads) for the token at `pos_token`.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor; axes 1 and 2 are swapped here so it can be
            indexed as [pos_token][layer].
        pos_token: Index of the token of interest; its tick is highlighted.
        layer: 1-based layer number.
    """
    layer_ix = layer - 1  # callers pass 1-based layer numbers
    fig, axes = plt.subplots(figsize=(20, 8))
    fig.tight_layout(pad=5.0)
    by_layer = attns.permute(0, 2, 1, 3)[pos_token]
    print('\nLayerwise attention weights:')
    sns.heatmap(by_layer[layer_ix],
                vmin=0, vmax=1,
                xticklabels=in_tokens,
                ax=axes)
    axes.set_title(f'Layer #{layer} ')
    axes.set_ylabel('Heads')
    for ix, tick in enumerate(axes.get_xticklabels()):
        tick.set_rotation(55)
        tick.set_fontsize(9)
        # Highlight by index: matching on the tick text and its neighbours
        # missed position 0 and could mark repeated tokens.
        if ix == pos_token:
            tick.set_color("magenta")
            tick.set_fontsize(10)

def display_per_layer_per_head_attn(in_tokens: list,
                                    attns: torch.Tensor,
                                    pos_token: int,
                                    layer: int,
                                    head: int) -> None:
    """Single-row heatmap of one (layer, head) attention vector for a token.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor; axes 1 and 2 are swapped here so it can be
            indexed as [pos_token][layer][head].
        pos_token: Index of the token of interest; its tick is highlighted.
        layer: 1-based layer number.
        head: 1-based head number.
    """
    layer_ix = layer - 1  # callers pass 1-based numbers
    head_ix = head - 1
    fig, axes = plt.subplots(figsize=(20, 1))
    by_layer = attns.permute(0, 2, 1, 3)[pos_token]
    print('\nLayerwise attention weights for given head:')
    sns.heatmap(by_layer[layer_ix][head_ix].reshape(1, -1),
                vmin=0, vmax=1,
                xticklabels=in_tokens,
                ax=axes)
    axes.set_title(f'Layer #{layer} ')
    axes.set_ylabel(f'Head #{head}')
    for ix, tick in enumerate(axes.get_xticklabels()):
        tick.set_rotation(55)
        tick.set_fontsize(9)
        # Highlight by index: matching on the tick text and its neighbours
        # missed position 0 and could mark repeated tokens.
        if ix == pos_token:
            tick.set_color("magenta")
            tick.set_fontsize(10)


def display_per_head_attn_alt(in_tokens: list,
                              attns: torch.Tensor,
                              pos_token: int,
                              head: int) -> None:
    """Matplotlib-only (imshow) variant of the per-head attention plot.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor indexed as attns[pos_token][head].
        pos_token: Index of the token whose attention rows are shown.
        head: 1-based head number.
    """
    head_ix = head - 1  # callers pass 1-based head numbers
    fig, axes = plt.subplots(figsize=(20, 8))
    fig.tight_layout(pad=5.0)
    print('\nHeadwise attention weights:')
    # Draw on the axes created above rather than the implicit current axes.
    axes.imshow(attns[pos_token][head_ix])
    axes.set_xticks(range(len(in_tokens)))
    axes.set_xticklabels(in_tokens, rotation=45)
    axes.set_title(f'Head #{head} ')
    axes.set_ylabel('Layers')
    plt.show()

def display_avg_attn(in_tokens: list, attns: torch.Tensor, pos_token: int) -> None:
    """Heatmap of attention for a token, averaged over the first axis of attns[pos_token].

    With the layout the other display functions use (attns[pos_token][head]),
    this is the mean over heads, leaving one row per layer.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor indexed as attns[pos_token].
        pos_token: Index of the token of interest; its tick is highlighted.
    """
    fig, axes = plt.subplots(figsize=(20, 8))
    fig.tight_layout(pad=5.0)
    avg_attn = attns[pos_token].mean(dim=0)
    sns.heatmap(avg_attn, vmin=0, vmax=1, xticklabels=in_tokens, ax=axes)
    axes.set_title('\nAverage attention weights')
    axes.set_ylabel('Layers')
    for ix, tick in enumerate(axes.get_xticklabels()):
        tick.set_rotation(55)
        tick.set_fontsize(8)
        # Highlight by index: matching on the tick text and its neighbours
        # missed position 0 and could mark repeated tokens.
        if ix == pos_token:
            tick.set_color("magenta")
            tick.set_fontsize(10)
    plt.show()

def display_avg_attn_l(in_tokens: list, attns: torch.Tensor, pos_token: int) -> None:
    """Heatmap of attention for a token, averaged over layers (rows: heads).

    Axes 1 and 2 of `attns` are swapped first, so the mean over the first
    axis of the per-token slice runs over what the other functions index as
    the layer axis.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor as returned by get_attns.
        pos_token: Index of the token of interest; its tick is highlighted.
    """
    fig, axes = plt.subplots(figsize=(20, 8))
    fig.tight_layout(pad=5.0)
    avg_attn = attns.permute(0, 2, 1, 3)[pos_token].mean(dim=0)
    sns.heatmap(avg_attn, vmin=0, vmax=1, xticklabels=in_tokens, ax=axes)
    axes.set_title('\nAverage attention weights (layers)')
    axes.set_ylabel('Heads')
    for ix, tick in enumerate(axes.get_xticklabels()):
        tick.set_rotation(55)
        tick.set_fontsize(8)
        # Highlight by index: matching on the tick text and its neighbours
        # missed position 0 and could mark repeated tokens.
        if ix == pos_token:
            tick.set_color("magenta")
            tick.set_fontsize(10)
    plt.show()
    

def display_avg_attn_per_layer(in_tokens: list,
                               attns: torch.Tensor,
                               pos_token: int,
                               layer: int) -> None:
    """Single-row heatmap of one layer's attention for a token, averaged over heads.

    Args:
        in_tokens: Token strings used as x tick labels.
        attns: Attention tensor; axes 1 and 2 are swapped here so it can be
            indexed as [pos_token][layer].
        pos_token: Index of the token of interest; its tick is highlighted.
        layer: 1-based layer number.
    """
    layer_ix = layer - 1  # callers pass 1-based layer numbers
    fig, axes = plt.subplots(figsize=(20, 1))
    by_layer = attns.permute(0, 2, 1, 3)[pos_token]
    avg_attn = by_layer[layer_ix].mean(dim=0)
    print('\nPer-layer average attention weights')
    sns.heatmap(avg_attn.reshape(1, -1), vmin=0, vmax=1,
                xticklabels=in_tokens, ax=axes)
    axes.set_title(f'Layer #{layer} ')
    axes.set_ylabel('Heads (avg)')
    for ix, tick in enumerate(axes.get_xticklabels()):
        tick.set_rotation(55)
        tick.set_fontsize(8)
        # Highlight by index: matching on the tick text and its neighbours
        # missed position 0 and could mark repeated tokens.
        if ix == pos_token:
            tick.set_color("magenta")
            tick.set_fontsize(10)
    plt.show()

#### Logits

In [None]:
def get_model_gen(mpath: str='', typ: str='2', dev: str='cuda') -> Tuple:
    """Load a pretrained GPT-2 tokenizer and model for visualization.

    Args:
        mpath: Path or name passed to from_pretrained().
        typ: "2" loads GPT2DoubleHeadsModel; anything else GPT2LMHeadModel.
        dev: Device to move the model to. Falls back to "cpu" when a CUDA
            device is requested but none is available.

    Returns:
        (tokenizer, model), with the model in eval mode on the chosen device.
    """
    logging.basicConfig(level=logging.INFO)
    tokenizer = GPT2Tokenizer.from_pretrained(mpath)
    model_cls = GPT2DoubleHeadsModel if typ == "2" else GPT2LMHeadModel
    # Both variants need attentions in their outputs: get_attns reads the
    # last output element as the attention tuple.
    model = model_cls.from_pretrained(mpath, output_attentions=True)
    model.eval() # deactivate dropout for reproducibility
    if str(dev).startswith('cuda') and not torch.cuda.is_available():
        logging.warning("CUDA requested but unavailable; loading model on CPU.")
        dev = 'cpu'
    model.to(dev)
    return tokenizer, model

# Loader for the ROCStories CSV used to build generation prompts.
def load_rocstories_dataset(dataset_path: str) -> Tuple[List[str], pd.DataFrame, pd.DataFrame]:
    """Read a ROCStories CSV and split it into bodies, tagged sentences and endings.

    Columns 2-6 of the file are read; they are accessed by the names
    sentence1 ... sentence5.

    Returns:
        story_bodies: sentence1-4 joined with single spaces (no special tokens).
        df_stories: sentence1-4, with "_start_" prepended to sentence1 and
            "_delimiter_" appended to sentence4.
        df_endings: the sentence5 column.
    """
    start_tok, delim_tok = "_start_", "_delimiter_"
    frame = pd.read_csv(dataset_path, sep=',', usecols=[2, 3, 4, 5, 6])
    body_cols = frame.loc[:, :'sentence4']

    # Plain bodies are built before the special tokens are attached.
    story_bodies = body_cols['sentence1']
    for col in ('sentence2', 'sentence3', 'sentence4'):
        story_bodies = story_bodies + " " + body_cols[col]

    df_stories = body_cols.assign(
        sentence1=start_tok + body_cols['sentence1'],
        sentence4=body_cols['sentence4'] + delim_tok,
    )
    df_endings = frame['sentence5']
    return story_bodies, df_stories, df_endings
    
def model_c(model: str) -> str:
    """Return `model` unchanged; used as the callback of the model dropdown so
    the selected value is available as the widget's `.result`."""
    return model

def get_endings(mname: str) -> list:
    """Read the tab-separated file endings/<mname>_gen_ends.txt and return its
    rows as a list of lists (the first line is consumed as the header)."""
    endings_file = f"endings/{mname}_gen_ends.txt"
    frame = pd.read_csv(endings_file, sep='\t')
    return frame.to_numpy().tolist()

def load_res(mname: str) -> pd.DataFrame:
    """Load stories/seg_results_<mname>.txt (tab-separated) as a DataFrame.

    Malformed rows are skipped with a warning rather than aborting the load.
    """
    path = f"stories/seg_results_{mname}.txt"
    try:
        # pandas >= 1.3; `error_bad_lines` was removed in pandas 2.0.
        return pd.read_csv(path, sep='\t', on_bad_lines='warn')
    except TypeError:
        # Older pandas that does not know `on_bad_lines`.
        return pd.read_csv(path, sep='\t', error_bad_lines=False)

def get_logits(gen: str, cref: str, model, tokenizer, mname: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the model on the generated story and return its LM logits and attentions.

    Args:
        gen: Story text with the generated ending.
        cref: Story text with the reference ending. Accepted for interface
            compatibility; its forward pass was never used and is skipped.
        model: Model called as model(ids, lm_labels=ids) whose output unpacks
            into (lm_loss, lm_logits, mc_logits, presents, attns).
        tokenizer: Tokenizer providing encode().
        mname: Model name (unused here).

    Returns:
        (lm_logits, attns) for `gen`.
    """
    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    gen_ids = tokenizer.encode(gen, add_special_tokens=False)
    input_ids_gen = torch.tensor(gen_ids).unsqueeze(0).to(device)  # Batch size 1

    with torch.no_grad():
        outputs_gen = model(input_ids_gen, lm_labels=input_ids_gen)
    _, lm_logits, _, _, attns = outputs_gen
    return lm_logits, attns

def collect_logits(gen_stories: list, model, tokenizer, mname: str) -> list:
    """Collect LM logits for each story paired with its generated ending.

    Args:
        gen_stories: Strings of the form "<story body>\\t<generated ending>".
            A string without a tab gets the placeholder ending "_none_".
        model, tokenizer: Forwarded to get_logits.
        mname: Model name; 'gpt2' gets no "_start_" / "_delimiter_" tokens.

    Returns:
        One [lm_logits, story_body, (correct_ending, generated_ending)] entry
        per story. The correct ending is taken from the module-level
        `df_endings` at the same position as the story in `gen_stories`.
    """
    df_endings_list = df_endings.values.tolist()
    uses_special_tokens = mname != 'gpt2'
    results = []
    for ix, story in enumerate(gen_stories):
        parts = story.split("\t")
        # An explicit length check replaces the former bare `except`.
        gen_end = parts[1] if len(parts) > 1 else "_none_"
        correct_end = df_endings_list[ix]
        if uses_special_tokens:
            story_body = "_start_" + parts[0]
            joiner = "_delimiter_"
        else:
            story_body = parts[0]
            joiner = ""
        gen = story_body + joiner + gen_end
        cref = story_body + joiner + correct_end

        lm_logits, _ = get_logits(gen, cref, model, tokenizer, mname)
        results.append([lm_logits, story_body, (correct_end, gen_end)])
    return results

#### Softmax

In [None]:
# Adapted from HuggingFace.
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """Truncate a logits distribution with top-k and/or nucleus (top-p) filtering.

    `logits` is modified in place and also returned.

    Args:
        logits: logits over the vocabulary on the last axis.
        top_k: when > 0, keep only the k highest-scoring tokens.
        top_p: when < 1.0, keep the smallest set of top tokens whose
            cumulative probability reaches top_p (Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written over every removed logit.
        min_tokens_to_keep: lower bound on the number of surviving tokens.

    Source: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocabulary size].
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        kth_best = torch.topk(logits, k)[0][..., -1:]
        logits.masked_fill_(logits < kth_best, filter_value)

    if top_p < 1.0:
        ordered, order = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(ordered, dim=-1), dim=-1)

        drop_sorted = cumulative > top_p
        if min_tokens_to_keep > 1:
            drop_sorted[..., :min_tokens_to_keep] = False
        # Shift right by one so the first token crossing the threshold is
        # kept; the best token is never dropped.
        shifted = drop_sorted[..., :-1].clone()
        drop_sorted[..., 1:] = shifted
        drop_sorted[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(-1, order, drop_sorted)
        logits.masked_fill_(drop, filter_value)
    return logits
#
    
def softmax(temperature):
    """Bar-plot the temperature-scaled softmax of the module-level `old_logits`.

    Bar i is the probability of entry i of `old_logits`; the entries are
    treated as token ids when labelling ticks. A sample of len(old_logits)
    draws (with replacement) is taken from the distribution and the distinct
    sampled entries are labelled with their decoded token.

    Args:
        temperature: Positive divisor applied to the logits before softmax.
    """
    v = len(old_logits)
    plt.figure(figsize=(8,5))
    plt.title("Softmax distribution w/ temperature")
    plt.xlabel("Vocabulary")
    plt.ylabel("Probability")
    probs = F.softmax(old_logits/temperature, dim=-1)
    plt.bar(range(v),
            probs.tolist(),
            facecolor='cyan',
            edgecolor='blue')
    sample = torch.multinomial(probs, replacement=True, num_samples=v)
    # One label per distinct sampled id. Decoding the whole sample at once
    # returns a single string, which cannot serve as a list of tick labels.
    tick_ids = sorted(set(sample.tolist()))
    tick_labels = [tokenizer.decode([tid]) for tid in tick_ids]
    plt.xticks(tick_ids, tick_labels)
    
def softmaxp(temperature):
    """Bar-plot the temperature-scaled softmax of the truncated `p_logits`.

    Reads the module-level `p_logits`, `top_k`, `top_p` and `tokenizer`.
    Bar i is the probability of entry i of `p_logits`; the entries are
    treated as token ids when labelling ticks. A sample of len(p_logits)
    draws (with replacement) is taken and the distinct sampled entries are
    labelled with their decoded token.

    Args:
        temperature: Positive divisor applied to the logits before softmax.
    """
    v = len(p_logits)
    plt.figure(figsize=(8,5))
    plt.title(f"Truncated distribution w/ temperature; k: {top_k}, p: {top_p}")
    plt.xlabel("Vocabulary")
    plt.ylabel("Probability")
    probs = F.softmax(p_logits/temperature, dim=-1)
    plt.bar(range(v),
            probs.tolist(),
            facecolor='magenta',
            edgecolor='purple')
    sample = torch.multinomial(probs, replacement=True, num_samples=v)
    # One label per distinct sampled id. Decoding the whole sample at once
    # returns a single string, which cannot serve as a list of tick labels.
    tick_ids = sorted(set(sample.tolist()))
    tick_labels = [tokenizer.decode([tid]) for tid in tick_ids]
    plt.xticks(tick_ids, tick_labels)

def get_probs(plogits: torch.Tensor, prompt: str, ending: str, typ: str, tau: float = 1.0) -> list:
    """WIP: plot the probability the model assigns to each story token.

    Args:
        plogits: LM logits indexed as [batch, position, vocabulary]; only
            batch element 0 is used.
        prompt: Story body. When it is exactly "_start_" only the ending
            tokens are plotted and the first logits position is dropped.
        ending: Ending text.
        typ: 'generated' or anything else; selects the line color and is
            interpolated into the title.
        tau: Softmax temperature.

    Returns:
        The plotted probabilities, one per token (callers may ignore it).
    """
    prompt_tokens = tokenizer.tokenize(prompt)
    end_tokens = tokenizer.tokenize(ending)
    if prompt != "_start_":
        # NOTE(review): here row i is read at token i itself, not at the
        # following token -- confirm the intended alignment.
        raw_tokens = prompt_tokens + end_tokens
    else:
        raw_tokens = end_tokens
        # Drop the first position only. Slicing the vocabulary axis as well
        # would shift every token id by one.
        plogits = plogits[:, 1:]
    # Ids come from the untouched BPE tokens; "Ġ" is stripped for display only.
    ids = tokenizer.convert_tokens_to_ids(raw_tokens)
    story_tokens = [tok.replace("Ġ", "") for tok in raw_tokens]

    probs = F.softmax(plogits[0]/tau, dim=-1)
    # Gather directly. A zero-filled mask filtered with `!= 0` would drop
    # exact zeros and misalign the curve with the tick labels.
    n = min(len(ids), probs.shape[0])
    token_probs = [probs[i, ids[i]].item() for i in range(n)]

    plt.figure(figsize=(10, 5))
    plt.title(f"Original probabilities assigned by model to {typ} ending tokens.")
    plt.xlabel("Vocabulary")
    plt.ylabel("Probability")
    c = 'orange' if typ == 'generated' else 'blue'
    plt.plot(token_probs, color=c)
    plt.xticks(range(len(story_tokens)), story_tokens, rotation=45)
    plt.show()
    return token_probs

## III. Visualization

#### Load model, tokenizer and run on toy examples for visualization.

> Run the following cell to choose model from dropdown:

In [None]:
# Dropdown of (label, model folder name) pairs. The selected folder name is
# passed through model_c and read back later from `model_chooser.result`.
model_chooser = interactive(model_c, model=[('Base', 'gpt2'),
                                            ('Base to Sentiment to SCT', 'b_sentiment_SCT'),
                                            ('Base to ConceptNet', 'conceptnet'),
                                            ('Base to ROC', 'roc1617'),
                                            ('Base to SCT', 'b_SCT'),
                                            ('ConceptNet to SCT', 'cn_SCT'),
                                            ('ROC to SCT', 'roc1617_SCT'),
                                            ('CN to Sentiment to SCT', 'cn_sentiment_SCT'),
                                            ('ROC to Sentiment to SCT', 'roc1617_sentiment_SCT')])
display(model_chooser)

> Once the model is chosen, run the next cell to get the model and tokenizer, set parameters and load the stories and endings:

In [None]:
# Read the dropdown selection. If the widget has produced no result yet
# (e.g. a headless run), fall back to the first dropdown entry instead of
# leaving `mname` undefined or stale from an earlier cell.
mname = getattr(model_chooser, "result", None) or 'gpt2'

model_path = f'models/{mname}'

In [None]:
# Load the tokenizer and model from model_path using get_model_gen's defaults
# (typ='2', dev='cuda').
tokenizer, model = get_model_gen(model_path)

### Visualize losses vs. epochs

In [None]:
import ast  # literal parsing of the logged values instead of eval()

# One list per (epoch, loss type); each receives one value per entry of
# `models`, in that order.
e1_mc_losses, e2_mc_losses, e3_mc_losses = [], [], []
e1_lm_losses, e2_lm_losses, e3_lm_losses = [], [], []
models = ['b_SCT', 'cn_SCT', 'roc1617_SCT']
models2 = ['b_sentiment_SCT', 'cn_sentiment_SCT', 'roc1617_sentiment_SCT']
# The loop variable is deliberately NOT called `mname`: that name holds the
# model selected in the dropdown and is used again by later cells.
for loss_mname in models:
    with open(f"models/{loss_mname}/train_results.txt") as trainres:
        lines = trainres.read().splitlines()

    # Each of the first six lines carries its value after a 19-character prefix.
    values = [ast.literal_eval(line[19:].strip()) for line in lines[:6]]

    e1_mc_losses.append(values[0])
    e1_lm_losses.append(values[1])

    e2_mc_losses.append(values[2])
    e2_lm_losses.append(values[3])

    e3_mc_losses.append(values[4])
    e3_lm_losses.append(values[5])

In [None]:
# The MC and LM labels are deliberately swapped here: the training script
# wrote them reversed because of a tuple-indexing error (since corrected).
# Remove this swap once the MTL script has been rerun.

lm_losses = {'e1': e1_mc_losses, 'e2': e2_mc_losses, 'e3': e3_mc_losses}
mc_losses = {'e1': e1_lm_losses, 'e2': e2_lm_losses, 'e3': e3_lm_losses}

In [None]:
import os  # only used to make sure the figure directory exists

fig, axs = plt.subplots(2, 1, sharex=True, figsize=(8, 8))
fig.tight_layout(pad=4)
axs[0].set_title(f"Multiple Choice (MC) + Language Modeling (LM) fine-tuning losses\n")
axs[0].grid(True)
axs[1].grid(True)

epoch_keys = ['e1', 'e2', 'e3']
epoch_ticks = [1, 2, 3]

# mc_losses[e] / lm_losses[e] hold one value per model (in `models` order)
# for epoch e, so a model's curve is read across the epoch keys. This assumes
# each stored value is a scalar loss. Plotting mc_losses[e] directly against
# the epochs drew one line per epoch while labelling it with a model name.
for m_ix, (label, c) in enumerate(zip(models, ['cyan', 'coral', 'lime'])):
    axs[0].plot(epoch_ticks, [mc_losses[e][m_ix] for e in epoch_keys], color=c, label=label)
    axs[1].plot(epoch_ticks, [lm_losses[e][m_ix] for e in epoch_keys], color=c, label=label)

axs[0].set_ylabel("MC Loss")
axs[0].set_ylim(top=1, bottom=0)
axs[0].legend()

axs[1].legend()
axs[1].set_xticks(epoch_ticks)
axs[1].set_xlabel("Epochs")
axs[1].set_ylabel("LM Loss")
axs[1].set_ylim(top=8, bottom=2)

os.makedirs('devoirs/survey/figures', exist_ok=True)
plt.savefig(f'devoirs/survey/figures/models_losses.png', transparent=True)
plt.show()

### Load data

> Get first story results:

In [None]:
# Load the ROCStories test split. `df_endings` is also read as a global by
# collect_logits.
dpath = "datasets/roc_1617_test.csv"
story_bodies, df_stories, df_endings = load_rocstories_dataset(dpath)
# stories = df_stories.join(df_endings).values.tolist()
# gen_endings = get_endings(mname)

# Stored results for the selected model; each entry of `stories` is
# "<Story>\t <GenEnding>", the tab-separated form collect_logits splits on.
gen_stories = load_res(mname)
stories = gen_stories.Story + "\t" + " " + gen_stories.GenEnding

# Only the first story is used for the toy visualizations below.
story = stories[0]

num_logits = 50 # We just want enough to demonstrate temperature. More than 100 or so and the notebook explodes.
top_p = 0.9
top_k = 20

results = collect_logits(gen_stories = [story], 
                         model = model, 
                         tokenizer = tokenizer, 
                         mname = mname)
full_logits, prompt, endings = results[0]
# First position of the first batch element, truncated to the first
# `num_logits` vocabulary entries.
logits = full_logits[0, 0, :num_logits].cpu()
ending = endings[0] # endings: (Correct Ending, Generated Ending)

# Filter logits for truncated visualization
# top_k_top_p_filtering writes into `logits` in place, so the clone must be
# taken first; afterwards `p_logits` and `logits` refer to the same tensor.
old_logits = logits.clone() # Avoid zany side-effects from visualization functions.
p_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Get attention weights, set necessary visualization variables
in_tokens, pos_token, attns = get_attns(prompt, 
                                        ending, 
                                        model, 
                                        tokenizer, 
                                        mname)

Ending probabilities according to model:

In [None]:
def run_plot(ix: int):
    """Plot per-token probabilities for story `ix`: generated ending, then correct ending.

    Each ending is scored on its own (empty story body) through
    collect_logits and drawn with get_probs. Reads the module-level
    `gen_stories`, `model`, `tokenizer` and `mname`.

    Args:
        ix: Row label into gen_stories.GenEnding / gen_stories.CorrectEnding.
    """
    # The former get_attns call here was dropped: its results were never used
    # and it read the unrelated global `ending`.
    for column, typ in (("GenEnding", "generated"), ("CorrectEnding", "correct")):
        ending_text = gen_stories[column][ix]
        results = collect_logits(gen_stories=["\t" + ending_text],
                                 model=model,
                                 tokenizer=tokenizer,
                                 mname=mname)
        full_logits, prompt, _ = results[0]
        get_probs(full_logits.cpu(), prompt, ending_text, typ=typ)

Plot ending probabilities:

In [None]:
# Dropdown over story indices. The former `continuous=False` keyword was
# dropped: run_plot has no such parameter, so it could not configure anything.
interactive(run_plot, ix=range(0, len(gen_stories.Story)))

### Main visualization

#### Run the next cell for softmax visualization on toy example:

In [None]:
# Two independent temperature sliders: `w` drives the full softmax plot and
# `y` the top-k/top-p truncated one. continuous_update=False redraws only
# when the slider is released.
w = interactive(softmax, 
                temperature=FloatSlider(min=0.1, 
                                        max=10.01, 
                                        step=0.1, 
                                        description='Temperature:', 
                                        value=1.0, 
                                        continuous_update=False))
y = interactive(softmaxp, 
                temperature=FloatSlider(min=0.1, 
                                        max=10.01, 
                                        step=0.1, 
                                        description='Temperature:', 
                                        value=1.0, 
                                        continuous_update=False))

# Fixed heights for the two widget output areas.
w.layout.height = '450px'
y.layout.height = '450px'
display(w, y)

#### Attention visualizations

##### Run any below to map attention by token position:

> All heads (slow):

In [None]:
# Manual mode: the (slow) all-heads figure is only drawn when the
# 'Display attns (all)' button is pressed, not on every slider change.
dph_all = interactive(display_attns, 
                      {'manual': True,
                       'manual_name': 'Display attns (all)'},
                        in_tokens=fixed(in_tokens), 
                        attns=fixed(attns), 
                        pos_token=IntSlider(min=0, 
                                            max=len(in_tokens)-1, 
                                            step=1, 
                                            description='position:', 
                                            value=1, 
                                            continuous_update=False))
                                      
display(dph_all)

> Per head:

In [None]:
# Per-head view. The head slider is bounded by axis 1 of `attns` (the axis
# display_per_head_attn indexes with `head`) rather than a hard-coded 12.
dph = interactive(display_per_head_attn, 
                  in_tokens=fixed(in_tokens), 
                  attns=fixed(attns), 
                  pos_token=IntSlider(min=0, 
                                      max=len(in_tokens)-1, 
                                      step=1, 
                                      description='position:', 
                                      value=1, 
                                      continuous_update=False), 
                  head=IntSlider(min=1, 
                                 max=attns.shape[1], 
                                 step=1, 
                                 description='head:', 
                                 value=attns.shape[1], 
                                 continuous_update=False))
dph.layout.height = '650px'
display(dph)

> Per layer:

In [None]:
# Per-layer view. The layer slider is bounded by axis 2 of `attns` (the axis
# display_per_layer_attn moves forward and indexes with `layer`) rather than
# a hard-coded 12.
dpl = interactive(display_per_layer_attn, 
                  in_tokens=fixed(in_tokens), 
                  attns=fixed(attns), 
                  pos_token=IntSlider(min=0, 
                                      max=len(in_tokens)-1, 
                                      step=1, 
                                      description='position:', 
                                      value=1, 
                                      continuous_update=False), 
                  layer=IntSlider(min=1, 
                                  max=attns.shape[2], 
                                  step=1, 
                                  description='layer:', 
                                  value=attns.shape[2], 
                                  continuous_update=False))
dpl.layout.height = '650px'
display(dpl)

> Per head, per layer:

In [None]:
# Single (layer, head) view. Slider bounds come from `attns` itself: axis 2
# is indexed with `layer` and axis 1 with `head` by the display function.
dplh = interactive(display_per_layer_per_head_attn, 
                   in_tokens=fixed(in_tokens), 
                   attns=fixed(attns), 
                   pos_token=IntSlider(min=0, 
                                       max=len(in_tokens)-1, 
                                       step=1, 
                                       description='position:', 
                                       value=1, 
                                       continuous_update=False), 
                   layer=IntSlider(min=1, 
                                   max=attns.shape[2], 
                                   step=1, 
                                   description='layer:', 
                                   value=attns.shape[2], 
                                   continuous_update=False),
                   head=IntSlider(min=1, 
                                  max=attns.shape[1], 
                                  step=1, 
                                  description='head:', 
                                  value=attns.shape[1], 
                                  continuous_update=False))
dplh.layout.height = '300px'
display(dplh)

> Averaged over heads:

In [None]:
# Position slider for the head-averaged heatmap (display_avg_attn).
dph_avg = interactive(display_avg_attn, 
                  in_tokens=fixed(in_tokens), 
                  attns=fixed(attns), 
                  pos_token=IntSlider(min=0, 
                                      max=len(in_tokens)-1, 
                                      step=1, 
                                      description='position:', 
                                      value=1, 
                                      continuous_update=False))
dph_avg.layout.height = '650px'
display(dph_avg)

> Averaged over layers:

In [None]:
# Position slider for the layer-averaged heatmap (display_avg_attn_l).
# NOTE(review): this rebinds `dph_avg`, the name the head-averaged widget
# cell also assigns; consider a distinct name.
dph_avg = interactive(display_avg_attn_l, 
                  in_tokens=fixed(in_tokens), 
                  attns=fixed(attns), 
                  pos_token=IntSlider(min=0, 
                                      max=len(in_tokens)-1, 
                                      step=1, 
                                      description='position:', 
                                      value=1, 
                                      continuous_update=False))
dph_avg.layout.height = '650px'
display(dph_avg)

> Averaged over heads, per layer:

In [None]:
# Head-averaged attention for one layer. The layer slider is bounded by axis 2
# of `attns` (the axis display_avg_attn_per_layer indexes with `layer`)
# rather than a hard-coded 12.
dph_avg_l = interactive(display_avg_attn_per_layer, 
                        in_tokens=fixed(in_tokens), 
                        attns=fixed(attns), 
                        pos_token=IntSlider(min=0, 
                                            max=len(in_tokens)-1, 
                                            step=1, 
                                            description='position:', 
                                            value=1, 
                                            continuous_update=False), 
                        layer=IntSlider(min=1, 
                                        max=attns.shape[2], 
                                        step=1, 
                                        description='layer:', 
                                        value=attns.shape[2], 
                                        continuous_update=False))
dph_avg_l.layout.height = '650px'
display(dph_avg_l)