In [3]:
import html

import torch; torch.set_grad_enabled(False)
from IPython.display import HTML
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils import load_orig_ds_txt, tokenize, get_logits, get_correct_probs

# TinyStories-1M checkpoint and the tokenizer that goes with it
model_name = "roneneldan/TinyStories-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# presumably the first 100 validation stories (HF split-slicing syntax), each tokenized on its own
ds_txt = load_orig_ds_txt("validation[:100]")
ds_tok = list(map(lambda txt: tokenize(tokenizer, txt), ds_txt))
# the sample visualized in all the cells below
sample_tok = ds_tok[0]



  0%|          | 0/100 [00:00<?, ?it/s]

## visualize the tokenized prompt
each token is going to be an HTML div with a border around it

generate a `<style>` tag with CSS for the token divs

In [4]:
# CSS declarations shared by every token div
token_style = {
    "border": "1px solid #888",
    "display": "inline-block",
    # monospace: every character has the same width, so a space is easy to spot
    "font-family": "monospace",
    "font-size": "14px",
    "color": "black",
    "background-color": "white",
    "margin": "1px 0px 1px 1px",
    "padding": "0px 1px 1px 1px",
}
# one "property: value;" declaration per entry, separated by single spaces
style_str = " ".join(f"{prop}: {value};" for prop, value in token_style.items())
# the rule targets every element of class "token"
style_tag = "<style>.token { " + style_str + " }</style>"
style_tag

'<style>.token { border: 1px solid #888; display: inline-block; font-family: monospace; font-size: 14px; color: black; background-color: white; margin: 1px 0px 1px 1px; padding: 0px 1px 1px 1px; }</style>'

convert a token to its HTML representation

In [5]:
def token_to_html_v0(token):
    """Render one token id as an HTML div of class "token".

    The decoded text is HTML-escaped (so <, > and & show up literally) and
    spaces become non-breakable spaces. A token made only of newlines is shown
    as the characters "\\n" and followed by one <br> per newline.
    """
    str_token = tokenizer.decode(token)

    # line break or not
    br = ""
    if str_token and not str_token.strip("\n"):
        # token consists only of newlines ("\n", "\n\n", ...):
        # add one line break in html per new line character
        br = "<br>" * len(str_token)
        # replace each new line character with two characters: \ and n
        str_token = str_token.replace("\n", r"\n")
    else:
        # escape first so the token text is never parsed as markup, then use
        # non-breakable spaces, w/o them leading spaces wouldn't be displayed
        str_token = html.escape(str_token).replace(" ", "&nbsp;")

    return f"<div class='token'>{str_token}</div>{br}"
print(token_to_html_v0(10435)) # " authorized"
print(token_to_html_v0(198))   # "\n"

<div class='token'>&nbsp;authorized</div>
<div class='token'>\n</div><br>


combine the style tag with divs for each token

In [6]:
def vis_sample_v0(sample_tok):
    """Display the tokenized sample as a row of bordered token divs."""
    html_str = style_tag + "".join(token_to_html_v0(tok) for tok in sample_tok)
    display(HTML(html_str))
vis_sample_v0(sample_tok)

There is one small problem. When we select and copy from the visualization, we get "\\" and "n" in addition to line breaks. So let's make the text inside new line token divs unselectable.

### unselectable `\n`

In [7]:
def token_to_html_v1(token):
    """Render one token id as an HTML div of class "token" (v1).

    Token text is HTML-escaped and spaces become non-breakable spaces. A token
    made only of newlines is shown as the characters "\\n", followed by one
    <br> per newline, and gets ``user-select: none``. Copying the
    visualization then doesn't pick up the displayed "\\n" characters.
    """
    str_token = tokenizer.decode(token)

    br = style_attr = ""
    if str_token and not str_token.strip("\n"):
        # token consists only of newlines ("\n", "\n\n", ...):
        # add one line break in html per new line character
        br = "<br>" * len(str_token)
        # replace each new line character with two characters: \ and n
        str_token = str_token.replace("\n", r"\n")
        # NEW THING: make \n unselectable
        style_attr = " style='user-select: none'"
    else:
        # escape first so the token text is never parsed as markup, then use
        # non-breakable spaces, w/o them leading spaces wouldn't be displayed
        str_token = html.escape(str_token).replace(" ", "&nbsp;")

    return f"<div class='token'{style_attr}>{str_token}</div>{br}"
print(token_to_html_v1(10435)) # " authorized"
print(token_to_html_v1(198))   # "\n"

<div class='token'>&nbsp;authorized</div>
<div class='token' style='user-select: none'>\n</div><br>


In [8]:
def vis_sample_v1(sample_tok):
    """Display the tokenized sample using token_to_html_v1 (unselectable newlines)."""
    pieces = [style_tag]
    pieces.extend(map(token_to_html_v1, sample_tok))
    display(HTML("".join(pieces)))
vis_sample_v1(sample_tok)

## visualize correct next token probabilities

Now we will make the background color represent what probability the model assigned to the actual next token.

### collect probabilities of correct next tokens

In [9]:
# probability the model assigned to each actual next token of the sample.
# NOTE(review): presumably one value per position except the last (probs_to_colors
# prepends a color for the first token) — confirm against utils.get_correct_probs
correct_probs = get_correct_probs(model, sample_tok)

convert correct probabilities to colors

In [10]:
def probs_to_colors(probs, red_gap=150):
    """Map correct-next-token probabilities to CSS background colors.

    Returns one color more than there are probabilities. The first token gets
    "white", since nothing predicts it. Every following token gets a red shade:
    probability 0 gives white and probability 1 the strongest red.

    probs: tensor/array (anything with ``.tolist()``) or a plain sequence of floats.
    red_gap: floor of the green/blue channels; the higher it is, the less red
        the tokens will be (0 lets a probability of 1 become pure red).
    """
    # .tolist() for tensors/arrays, list() for plain sequences
    values = probs.tolist() if hasattr(probs, "tolist") else list(probs)
    # for the first (e.g. endoftext) token: no prediction, no color
    colors = ["white"]
    for p in values:
        green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
        colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
    return colors

In [11]:
def token_to_html_v2(token, bg_color=None):
    """Render one token id as an HTML div of class "token" (v2).

    bg_color: optional CSS color for the div background (falsy = keep the
        default from the style tag).
    Token text is HTML-escaped and spaces become non-breakable spaces. A token
    made only of newlines is shown as the characters "\\n", followed by one
    <br> per newline, and made unselectable.
    """
    str_token = tokenizer.decode(token)

    # NEW THING:
    #  now background color and user select will be defined in
    #  this dict if needed, it will be easier to extend later
    specific_styles = {}
    # line breaks to add after the div, if any
    br = ""

    # NEW THING
    if bg_color:
        specific_styles["background-color"] = bg_color
    if str_token and not str_token.strip("\n"):
        # token consists only of newlines ("\n", "\n\n", ...):
        # add one line break in html per new line character
        br = "<br>" * len(str_token)
        # replace each new line character with two characters: \ and n
        str_token = str_token.replace("\n", r"\n")
        # this is so we can copy the prompt without "\n"s
        specific_styles["user-select"] = "none"
    else:
        # escape first so the token text is never parsed as markup, then use
        # non-breakable spaces, w/o them leading spaces wouldn't be displayed
        str_token = html.escape(str_token).replace(" ", "&nbsp;")

    # NEW THING: converting the dict into the style attribute
    style_str = ""
    if specific_styles:
        inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
        style_str = f" style='{inside_style_str}'"
    return f"<div class='token'{style_str}>{str_token}</div>{br}"
print(token_to_html_v2(10435, "green")) # " authorized"
print(token_to_html_v2(198, "red"))     # "\n"

<div class='token' style='background-color: green'>&nbsp;authorized</div>
<div class='token' style='background-color: red; user-select: none'>\n</div><br>


In [12]:
def vis_sample_v2(sample_tok, probs=None):
    """Display the tokenized sample, optionally coloring tokens by probability.

    probs: probabilities of the correct next tokens. When given, each token's
        background comes from probs_to_colors. When None, tokens keep the
        default background.
    """
    if probs is None:
        colors = [None] * sample_tok.shape[0]
    else:
        colors = probs_to_colors(probs)
    token_htmls = [
        token_to_html_v2(tok, bg_color=colors[pos])
        for pos, tok in enumerate(sample_tok)
    ]
    display(HTML(style_tag + "".join(token_htmls)))
vis_sample_v2(sample_tok, correct_probs)

## showing more info on hover

It would be good to see the exact probability (in %) and what the model's other top predictions were.

include additional information in `data-*` attributes of token divs

In [13]:
def token_to_html_v3(token, bg_color=None, data=None):
    """Render one token id as an HTML div of class "token" (v3).

    bg_color: optional CSS color for the div background (falsy = default).
    data: optional dict. Each item becomes a ``data-<key>`` attribute whose
        value is ``str(value)``, HTML-escaped and with spaces turned into
        &nbsp;. The escaping covers quotes, so apostrophes (as in "'s") can't
        end the single-quoted attribute early.
    Token text is HTML-escaped and spaces become non-breakable spaces. A token
    made only of newlines is shown as the characters "\\n", followed by one
    <br> per newline, and made unselectable.
    """
    # NEW THING: we can define arbitrary data as a dict
    data = data or {}  # equivalent to if not data: data = {}
    str_token = tokenizer.decode(token)

    # background or user-select (for \n) goes here
    specific_styles = {}
    # line breaks to add after the div, if any
    br = ""

    if bg_color:
        specific_styles["background-color"] = bg_color
    if str_token and not str_token.strip("\n"):
        # token consists only of newlines ("\n", "\n\n", ...):
        # add one line break in html per new line character
        br = "<br>" * len(str_token)
        # replace each new line character with two characters: \ and n
        str_token = str_token.replace("\n", r"\n")
        # this is so we can copy the prompt without "\n"s
        specific_styles["user-select"] = "none"
    else:
        # escape first so the token text is never parsed as markup, then use
        # non-breakable spaces, w/o them leading spaces wouldn't be displayed
        str_token = html.escape(str_token).replace(" ", "&nbsp;")

    style_str = data_str = ""
    # converting style dict into the style attribute
    if specific_styles:
        inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
        style_str = f" style='{inside_style_str}'"
    # NEW THING: converting data dict into data attributes
    for k, v in data.items():
        # quote=True turns ' into &#x27;, keeping the single-quoted attribute intact
        value = html.escape(str(v), quote=True).replace(" ", "&nbsp;")
        data_str += f" data-{k}='{value}'"
    return f"<div class='token'{style_str}{data_str}>{str_token}</div>{br}"
print(token_to_html_v3(10435, "green", data=dict(prob="10%")))       # " authorized"
print(token_to_html_v3(198,   "red",   data=dict(top_pred="Once")))  # "\n"

<div class='token' style='background-color: green' data-prob='10%'>&nbsp;authorized</div>
<div class='token' style='background-color: red; user-select: none' data-top_pred='Once'>\n</div><br>


### collect top k predictions

In [14]:
def get_probs(sample_tok, top_k=3):
    """Get probabilities for the actual next token and for top k predictions.

    Returns (correct_probs, top_k_probs). correct_probs[i] is the probability
    the model gave at position i to the token that follows it,
    sample_tok[i + 1]. top_k_probs is the torch.topk result (values, indices)
    over the same positions. The last position is dropped because its next
    token is unknown.
    """
    # shape: (pos, d_vocab)
    logits = get_logits(model, sample_tok)
    # normalize over the vocab, then drop the last position's distribution
    probs = torch.softmax(logits, dim=-1)[:-1]
    top_k_probs = torch.topk(probs, top_k, dim=-1)
    # at each position, pick the probability of the token that actually follows
    positions = range(probs.shape[0])
    correct_probs = probs[positions, sample_tok[1:]]
    return correct_probs, top_k_probs

In [15]:
# correct-next-token probabilities plus the model's 3 most likely next tokens per position
correct_probs, top_k_probs = get_probs(sample_tok, top_k=3)

In [16]:
# token CSS plus a fixed-height monospace box (#hover_info) where hover details appear
style_tag_with_hover = "".join([
    "<style>",
    f".token {{ {style_str} }} ",
    "#hover_info { height: 100px; font-family: monospace }",
    "</style>",
])

In [24]:
def to_tok_prob_str(tok, prob):
    """Format a token and its probability for the hover box, e.g. "12.34% |&nbsp;the|".

    The result is HTML, because the hover script inserts data values with
    innerHTML. The token text is escaped (so <, & and quotes are shown
    literally), spaces become &nbsp; and newlines become the two characters
    \\n. The probability is right-aligned to 6 characters so lines line up in
    the monospace box.

    prob: float or 0-d tensor/array convertible with float().
    """
    tok_str = html.escape(tokenizer.decode(tok)).replace(" ", "&nbsp;").replace("\n", r"\n")
    prob_str = f"{float(prob):.2%}"
    return f"{prob_str:>6} |{tok_str}|"

def vis_sample_v3(sample_tok, correct_probs, top_k_probs):
    """Display the sample colored by correct-next-token probability, with hover data.

    Every token after the first carries these data-* attributes:
    - "next": the probability the model gave to this very token.
    - "top0", "top1", ...: the model's top-k predictions for this slot.
    The JavaScript cell below shows them in the #hover_info box on mouse-over.
    """
    colors = probs_to_colors(correct_probs)
    token_htmls = []
    for pos, tok in enumerate(sample_tok):
        data = {}
        if pos > 0:
            # the prediction for token `pos` was made at position pos - 1
            data["next"] = to_tok_prob_str(tok, correct_probs[pos - 1])
            top_tokens = top_k_probs.indices[pos - 1]
            top_values = top_k_probs.values[pos - 1]
            for rank, (top_tok, top_prob) in enumerate(zip(top_tokens, top_values)):
                data[f"top{rank}"] = to_tok_prob_str(top_tok, top_prob)
        token_htmls.append(token_to_html_v3(tok, bg_color=colors[pos], data=data))
    html_str = style_tag_with_hover + "".join(token_htmls) + "<div id='hover_info'>run the cell below</div>"
    display(HTML(html_str))
vis_sample_v3(sample_tok, correct_probs, top_k_probs)

This cell defines JavaScript logic that formats the data from the `data-*` attributes when you hover over a token

In [25]:
%%js
var token_divs = document.querySelectorAll('.token');
var hover_info = document.getElementById('hover_info');

token_divs.forEach(function (div) {
    // Hovering a token lists each of its data-* attributes: name in bold, then value
    div.addEventListener('mousemove', function (e) {
        var info = this.dataset;
        hover_info.innerHTML = "";
        for (var name in info) {
            hover_info.innerHTML += "<b>" + name + "</b> ";
            hover_info.innerHTML += info[name] + "<br>";
        }
    });

    // Leaving the token clears the info box
    div.addEventListener('mouseout', function (e) {
        hover_info.innerHTML = "";
    });
});

<IPython.core.display.Javascript object>