This notebook provides an interactive viewer for inspecting the model's per-token and per-character predictions, making it easier to see how the model arrives at them. It is also used to double-check that post-processing is optimal and that no bugs are present.

In [1]:
from datasets import Dataset
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer
import pickle

d = Path("/drive2/kaggle/pii-dd/piidd/inference/outputs/d3l")
model_path = "/drive2/kaggle/pii-dd/piidd/training/basic/outputs/d3l_1e-5_f_ld0.1_msd/checkpoint-1750"

# Inference artifacts: token-level prediction array, the evaluation dataset,
# its tokenized (chunked) counterpart, and character-level predictions.
np_preds = np.load(str(d / "preds.npy"))
eval_ds = Dataset.from_parquet(str(d / "ds.pq"))
tokenized_ds = Dataset.from_parquet(str(d / "tds.pq"))
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
with open(d / "char_preds.pkl", "rb") as char_preds_file:
    char_preds = pickle.load(char_preds_file)

# Group tokenized-dataset row indices by their "idx" column, preserving order.
# The lookups below treat these keys as eval_ds row indices, so each list holds
# the tokenized rows (chunks) belonging to one eval document.
eidx2tidxs = {}
for i, x in enumerate(tokenized_ds["idx"]):
    eidx2tidxs.setdefault(x, []).append(i)

# Document id -> eval_ds row index.
didx2eidx = {doc: i for i, doc in enumerate(eval_ds["document"])}

In [2]:
import json

# Label names from the model config, keyed by the stringified class index
# (JSON object keys are always strings, hence the str(i) lookups below).
with open(Path(model_path) / "config.json") as config_file:
    id2label = json.load(config_file)["id2label"]

def make_table_header(id2label):
    """Return one ``<th>`` cell per label, ordered by class index.

    ``id2label`` maps stringified class indices ("0", "1", ...) to label
    names; a missing index raises ``KeyError``.
    """
    header_cells = (f"<th>{id2label[str(label_idx)]}</th>" for label_idx in range(len(id2label)))
    return "".join(header_cells)


def generate_html_table(tokens, scores, labels=None):
    """Render per-token label scores as a colour-coded HTML document.

    Args:
        tokens: Sequence of strings, one per table row (token text or single
            characters). Each is HTML-escaped before insertion, so raw document
            text containing ``<``, ``>`` or ``&`` renders correctly.
        scores: Sequence aligned with ``tokens``; each item is an iterable of
            per-label scores in the same order as the header columns.
        labels: Optional mapping of stringified class index -> label name used
            for the header row. Defaults to the module-level ``id2label``.

    Returns:
        A complete HTML document as a string.

    Rows are paired with ``zip``, so if ``tokens`` and ``scores`` differ in
    length the surplus items of the longer one are dropped.
    """
    from html import escape

    if labels is None:
        labels = id2label

    # Start of the HTML string
    html = """
<!DOCTYPE html>
<html>
<head>
    <style>
        .highest {
            background-color: #66ff33;
        }
        .medium {
            background-color: #ccffff;
        }
        .low {
            background-color: #ff99ff;
        }
        .tiny {
            background-color: #ffb3b3;
        }
        .zero {
            background-color: rgba(0,0,0,0.1);
        }
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            text-align: center;
            padding: 0px;
            border: 1px solid black;
        }
        th {
            font-size: 12px;
        }
    </style>
</head>
<body>

<table>
    <tr>
    <th>Token</th>
    """ + make_table_header(labels) + """
    </tr>"""

    rows = []
    for token, score in zip(tokens, scores):
        row = f"\n    <tr>\n        <td>{escape(token)}</td>"
        for s in score:
            # Bucket each score into a colour class; thresholds assume values in [0, 1].
            class_name = "zero"
            if s > 0.9:
                class_name = "highest"
            elif s > 0.7:
                class_name = "medium"
            elif s > 0.5:
                class_name = "low"
            elif s > 0.3:
                class_name = "tiny"
            row += f"\n        <td class='{class_name}'>{s:.2f}</td>"
        row += "\n    </tr>"
        rows.append(row)
    html += "".join(rows)

    # End of the HTML string
    html += """
</table>

</body>
</html>
"""
    return html

In [3]:
import gradio as gr
import random
from scipy.special import softmax


def load_preds(doc_id, do_softmax, char=False):
    """Fetch the prediction arrays for one document.

    Args:
        doc_id: Document id, resolved to an eval_ds row via ``didx2eidx``.
        do_softmax: If True, apply ``softmax`` over the last axis of each item.
        char: If True, return the character-level entry from ``char_preds``;
            otherwise a list with one ``np_preds`` row per tokenized chunk.
    """
    eval_row = didx2eidx[doc_id]
    raw = char_preds[eval_row] if char else [np_preds[tok_row] for tok_row in eidx2tidxs[eval_row]]
    if not do_softmax:
        return raw
    return [softmax(item, -1) for item in raw]


def get_random_id():
    """Return a uniformly random document id from the evaluation set."""
    candidate_ids = list(didx2eidx)
    return random.choice(candidate_ids)


def load_token_preds(doc_id, chunk, do_softmax):
    """Render the model's token-level predictions for one chunk of a document.

    Args:
        doc_id: Document id (key of ``didx2eidx``).
        chunk: Index of the tokenized chunk to display. Coerced to ``int`` in
            case the UI delivers a float.
        do_softmax: Whether to apply softmax over the label axis first.

    Returns:
        HTML produced by ``show_scores`` for the requested chunk.

    Raises:
        IndexError: if ``chunk`` is not a valid chunk index for this document.
    """
    idx = didx2eidx[doc_id]
    chunk = int(chunk)

    tidxs = eidx2tidxs[idx]
    # The chunk slider has a fixed range, so validate against this document's
    # actual number of chunks and say so clearly instead of a bare IndexError.
    if not 0 <= chunk < len(tidxs):
        raise IndexError(
            f"Document {doc_id} has {len(tidxs)} chunk(s); chunk {chunk} is out of range."
        )

    preds = load_preds(doc_id, do_softmax)
    # Only the selected chunk's offsets are needed.
    offset_mapping = tokenized_ds[tidxs[chunk]]["offset_mapping"]
    text = eval_ds[idx]["full_text"]

    return show_scores(preds[chunk], offset_mapping, text)


def load_char_preds(doc_id, chunk, do_softmax):
    """Render character-level predictions for an entire document.

    ``chunk`` is unused; it is accepted so this function can be wired to the
    same inputs as ``load_token_preds``.
    """
    char_scores = load_preds(doc_id, do_softmax, char=True)
    full_text = eval_ds[didx2eidx[doc_id]]["full_text"]
    # With chars=True, show_scores splits the text itself and ignores offset_mapping.
    return show_scores(char_scores, None, full_text, chars=True)



def show_scores(preds, offset_mapping, text, chars=False):
    """Render prediction scores for a document (or one chunk of it) as HTML.

    Args:
        preds: Per-row score vectors; row ``i`` is paired with token/char ``i``.
        offset_mapping: ``(start, end)`` character spans into ``text``, one per
            token. Ignored when ``chars`` is True.
        text: The document's full text.
        chars: If True, every character of ``text`` becomes its own row.

    Returns:
        The HTML from ``generate_html_table``. If rendering fails, an HTML
        ``<pre>`` block containing the escaped traceback is returned instead.
    """
    import traceback
    from html import escape

    if chars:
        tokens = list(text)
    else:
        tokens = [text[m[0]:m[1]] for m in offset_mapping]

    try:
        return generate_html_table(tokens, preds)
    except Exception:
        # Show the real failure both in the console and in the output pane,
        # rather than falling through with no HTML to return.
        error_text = traceback.format_exc()
        print(error_text)
        return f"<pre>{escape(error_text)}</pre>"


with gr.Blocks(css=""".gradio-container {margin: 0 !important; max-width: 4000px}""") as demo:

    # Controls shared by both tabs.
    rnd = gr.Button("Random")
    do_softmax = gr.Checkbox(False, label="Softmax")
    # Explicit step=1: otherwise Gradio derives the step from the (0, 1e6)
    # range, which is far coarser than 1 and makes most ids unreachable.
    doc_id = gr.Slider(0, 1000000, 0, step=1, label="Doc ID")
    # Tokenized chunk index for the token tab (the char tab ignores it).
    chunk = gr.Slider(0, 10, step=1, label="Chunk")

    with gr.Tab("Model token preds"):
        refresh_token = gr.Button("Refresh")
        html_token = gr.HTML()
    with gr.Tab("Char preds"):
        refresh_char = gr.Button("Refresh")
        html_char = gr.HTML()

    rnd.click(fn=get_random_id, outputs=doc_id)

    refresh_token.click(fn=load_token_preds, inputs=[doc_id, chunk, do_softmax], outputs=html_token)
    refresh_char.click(fn=load_char_preds, inputs=[doc_id, chunk, do_softmax], outputs=html_char)


demo.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Traceback (most recent call last):
  File "/home/nicholas/miniconda3/lib/python3.10/site-packages/gradio/queueing.py", line 495, in call_prediction
    output = await route_utils.call_process_api(
  File "/home/nicholas/miniconda3/lib/python3.10/site-packages/gradio/route_utils.py", line 231, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/nicholas/miniconda3/lib/python3.10/site-packages/gradio/blocks.py", line 1591, in process_api
    result = await self.call_function(
  File "/home/nicholas/miniconda3/lib/python3.10/site-packages/gradio/blocks.py", line 1176, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/nicholas/miniconda3/lib/python3.10/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/nicholas/miniconda3/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/ho