1. Download the LLaMA2-13B-Chat-GPTQ model (e.g. `TheBloke/Llama-2-13B-chat-GPTQ` from the Hugging Face Hub).

2. Create a conda environment and install the required packages by running:
```
conda create --name token-vis python=3.12
conda activate token-vis
pip install torch numpy transformers ipykernel spacy
python -m spacy download en_core_web_sm
python -m ipykernel install --user --name=token-vis
```

3. Select `token-vis` as the Jupyter kernel for this notebook.
4. Set `MODEL_DIR` in the first code cell to the downloaded model directory, and `RESULTS_DIR` in the second code cell to the path of the results JSON file.

In [2]:
import html 
import base64 
import re 
import numpy as np
import random
import torch

# set tokenizer and spacy
from transformers import AutoTokenizer 
import spacy 

# Local directory (or Hugging Face repo id) of the LLaMA2-13B-chat-GPTQ model.
# Point this at your downloaded copy of the model.
MODEL_DIR = "TheBloke/Llama-2-13B-chat-GPTQ"
# Only the tokenizer is loaded from MODEL_DIR; the notebook never loads model weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# spaCy English pipeline; splits text into words before sub-word tokenization.
nlp = spacy.load("en_core_web_sm")

In [None]:
# get data
import json

# Path to the results JSON file. Later cells index it as results[i]['generated'],
# so it is expected to hold a list of per-sample dicts.
# Change this to your own results file.
RESULTS_DIR = "results.json"
with open(RESULTS_DIR, encoding="utf-8") as f:
    results = json.load(f)

In [3]:
def tokenize_with_spacy(text):
    """Tokenize ``text`` word by word.

    spaCy (the module-level ``nlp``) splits the text into words. The
    module-level ``tokenizer`` then splits each word into sub-word token ids.

    Args:
        text: the string to tokenize.

    Returns:
        tuple: ``(ids, wordwise_ids)``.
        ``wordwise_ids`` holds one list of token ids per spaCy token.
        ``ids`` is the concatenation of those lists, in order.
    """
    ids = []
    wordwise_ids = []
    for word in nlp(text):
        # Tokenize each word once and reuse the result for both outputs.
        # [1:] drops the first id returned for every word. Presumably this is the
        # BOS token the LLaMA tokenizer prepends -- confirm for other tokenizers.
        word_ids = tokenizer(word.text)['input_ids'][1:]
        ids.extend(word_ids)
        wordwise_ids.append(word_ids)
    return ids, wordwise_ids

In [70]:
def get_tokens_html_code(token_ids, scores=None):
    """Render a sequence of token ids as an HTML fragment for the token visualizer.

    Each token becomes ``<div class='token ...' id='token-{i}'>``, where ``i`` is
    the token's index in ``token_ids``.

    A word-start token (one beginning with the "▁" marker) opens a ``<nobr>``
    group. That group also holds the continuation tokens that follow it. The group
    is closed before the next word starts, after a space token, or at a
    line break. Every ``<nobr>`` emitted is therefore closed exactly once.

    Line-break byte tokens ("<0x0A>") emit an empty div followed by ``<br>``.

    Args:
        token_ids: sequence of token ids understood by the module-level
            ``tokenizer``.
        scores: unused; kept only for backward compatibility.

    Returns:
        str: HTML wrapped in ``<div class='tokens-container'>``.
    """
    parts = ["<div class='tokens-container'>"]
    nobr_open = False
    previous_space_flag = False

    def close_nobr():
        # Close the current word group, if one is open.
        nonlocal nobr_open
        if nobr_open:
            parts.append("</nobr>")
            nobr_open = False

    for i, token_id in enumerate(token_ids):
        class_name = "token"
        token_decoded = tokenizer.convert_ids_to_tokens([token_id])[0]

        if token_decoded == "<0x0A>":
            parts.append(f"<div class='{class_name} line-break-token' id='token-{i}'></div>")
            close_nobr()
            parts.append("<br>")
            continue

        # Escape &, < and > so token text cannot inject markup.
        # Quotes are left alone so the punctuation check below still matches.
        token_decoded = html.escape(token_decoded, quote=False)

        if token_decoded == "▁":
            if i == 0:
                continue
            parts.append(f"<div class='{class_name} space-token' id='token-{i}'>&nbsp;</div>")
            # The space stays inside the preceding word's group; the group ends after it.
            close_nobr()
            previous_space_flag = True
        elif token_decoded in ("▁.", "▁,", "▁'", "▁\""):
            # Punctuation carrying a word marker: drop the marker.
            # Unlike other word starts, it gets no left-space class.
            close_nobr()
            parts.append(f"<nobr><div class='{class_name}' id='token-{i}'>{token_decoded[1:]}</div>")
            nobr_open = True
            previous_space_flag = False
        elif token_decoded.startswith("▁"):
            close_nobr()
            if i > 0:
                class_name += " left-space-token"
            parts.append(f"<nobr><div class='{class_name}' id='token-{i}'>{token_decoded[1:]}</div>")
            nobr_open = True
            previous_space_flag = False
        else:
            div = f"<div class='{class_name}' id='token-{i}'>{token_decoded}</div>"
            if previous_space_flag:
                # A continuation token right after a space starts a new group.
                close_nobr()
                parts.append("<nobr>" + div)
                nobr_open = True
            else:
                parts.append(div)
            previous_space_flag = False

    close_nobr()
    parts.append("</div>")
    return "".join(parts)

In [None]:
# Select one sample from the results and unpack the fields the visualization needs.
i = 0  # index of data to visualize
record = results[i]
generated_text = record['generated']
label = record['label']  # correct text: aligned, mistake: misaligned, fabrication: fabricated
mkt_scores = record['kld']
at_scores = record['delta_p']
# Flatten {subject: [[pos, ...], ...]} into one list of token positions.
# Order: dict value order, then span order, then position order.
subject_token_pos = [
    pos
    for spans in record['generated_subject_token_pos'].values()
    for span in spans
    for pos in span
]
ids, wordwise_ids = tokenize_with_spacy(generated_text)

In [77]:
# # test data point (don't run this cell)
# generated_text = "Western Rat Snakes reproduce by laying eggs. Females will typically lay between 6-12 eggs per clutch, and the eggs will hatch after an incubation period of approximately 60-70 days. The hatchlings will then go through a series of shedding and growth stages before reaching maturity."
# ids, wordwise_ids = tokenize_with_spacy(generated_text)
# label = "correct text"
# mkt_scores = np.random.random(len(ids))*5 
# at_scores = np.random.random(len(ids))*2-1
# subject_token_pos = {'Western Rat Snake': [[0,1,2,3]]}
# subject_token_pos = sum(sum(subject_token_pos.values(), []), [])

In [108]:
score_type = "mkt"  # change to at to check Alignment Score

mkt_scale = 5  # increase this value to make color lighter
at_scale = 1  # Alignment score is always in the range of [-1,1] so may not need to be changed

if score_type == "mkt":
    raw_scores, score_scale = mkt_scores, mkt_scale
elif score_type == "at":
    raw_scores, score_scale = at_scores, at_scale
else:
    # Fail loudly instead of silently reusing `scores` left over from a previous run.
    raise ValueError(f"score_type must be 'mkt' or 'at', got {score_type!r}")

# np.array copies its input, so mkt_scores / at_scores are not mutated below.
scores = np.array(raw_scores, dtype=float)
scores[subject_token_pos] = 0.0  # neutralize the subject tokens (score 0)
scores = scores.tolist()

In [109]:
import os 
import json  # also needed when the synthetic-data cell is used instead of the results cell
from IPython.display import display_html

text_html_code = get_tokens_html_code(ids)
html_code_filename = "./vis/vis.html"

# Read the page template, stylesheet and script; `with` closes each file after reading.
with open(html_code_filename, "r", encoding="utf-8") as f:
    html_code = f.read()
with open("./vis/styles.css", "r", encoding="utf-8") as f:
    css_code = f"<style>{f.read()}</style>"
with open("./vis/vis.js", "r", encoding="utf-8") as f:
    js_code = f.read()
js_b = bytes(js_code, encoding="utf-8")
js_base64 = base64.b64encode(js_b).decode("utf-8")

# Script that dispatches a 'scores' DOM event carrying the data.
# Presumably consumed by vis.js -- confirm.
# json.dumps emits valid JavaScript literals; Python's repr would write `nan`,
# which is not valid JS, whereas json.dumps writes NaN.
message_js = f"""
        (function() {{
            const event = new Event('scores');
            event.scores = {json.dumps(scores)};
            event.score_type = {json.dumps(score_type)};
            event.score_scale = {json.dumps(score_scale)};
            document.dispatchEvent(event);
        }}())
        """
message_js = message_js.encode()
messenger_js_base64 = base64.b64encode(message_js).decode("utf-8")
message_js = f"""<script src='data:text/javascript;base64,{messenger_js_base64}'></script>"""

html_code = html_code.replace("<!--tokens-slot-->", text_html_code)
html_code = html_code.replace("<!--style-slot-->", css_code)
html_code = html_code.replace("<!--js-slot-->", f"""<script data-notebookMode="true" data-package="{__name__}" src='data:text/javascript;base64,{js_base64}'></script>""")
html_code = html_code.replace("<!--message-slot-->", message_js)


# The whole page is embedded via srcdoc, so it must be attribute-escaped.
# The iframe element is explicitly closed.
iframe = f"""
        <iframe 
            srcdoc="{html.escape(html_code)}" 
            frameBorder="0" 
            height="300px"
            width="100%">
        </iframe>
        """
display_html(iframe, raw=True)