In [1]:
import copy
import html
import json
import token

import matplotlib.pyplot as plt
import numpy
import seaborn
import torch
import transformers
from IPython.core.display import display, HTML
from transformers import BertTokenizer, RobertaTokenizer

### Initialize the tokenizer

In [2]:
# Extra special tokens: six structural markers followed by "[RANK:0]" .. "[RANK:99]".
st = {
    'additional_special_tokens': (
        ['[CIT]', '[URL]', '[SYSTEM]', '[TITLE]', '[SECTION]', '[TEXT]']
        + [f'[RANK:{rank}]' for rank in range(100)]
    )
}

# NOTE(review): do_lower_case is passed as the *string* "True"; RoBERTa's byte-level
# BPE is presumably case-sensitive regardless — confirm whether this has any effect.
tokenizer = RobertaTokenizer.from_pretrained("roberta-large", do_lower_case="True")
# The return value (number of tokens added) is what this cell displays.
tokenizer.add_special_tokens(st)

106

### Load interactions file

You can also load other prediction shards: set `shard_id` below to any value from 0 to 15.

In [3]:
# Prediction shard index, valid range 0-15 (the predictions are split into 16 files named "<i>-16").
shard_id = 15

In [4]:
colbert_mode_prediction_path = "/checkpoint/fabiopetroni/WAI/Samuel/misc/colbert_interactions_example/"
# Shard files are named "...jsonl.<shard>-16" (0-16, 1-16, ..., 15-16). Build the name
# from `shard_id` so that changing it in the previous cell actually takes effect.
colbert_interactions_file = f"{colbert_mode_prediction_path}wafer-dev-kiltweb.jsonl.{shard_id}-16"

# JSON Lines: one JSON object per line. Blank lines are skipped because json.loads rejects them.
with open(colbert_interactions_file, encoding="utf-8") as f:
    colbert_interactions = [json.loads(line) for line in f if line.strip()]

In [5]:
# Uncomment to see raw text from instance to check
#colbert_interactions[instance_id]["input"]

In [6]:
# Uncomment to see raw text from instance to check
#colbert_interactions[instance_id]["output"][0]["provenance"][0]["text"]

### Define helper functions for visualization

In [318]:
def get_lower_bound(_interactions, keep_fraction=0.99):
    """Return a robust lower bound of the scores (roughly a low percentile).

    The input is flattened and the k-th largest element is returned, where
    k = int(numel * keep_fraction), but at least 1. With the default of 0.99 this is
    approximately the 1st percentile. It is a more robust alternative to the plain
    minimum as the zero point of the colour scale in `print_highlighted`.

    Args:
        _interactions: tensor of any shape (contiguous or not).
        keep_fraction: fraction of elements that lie at or above the returned value.

    Returns:
        A 0-d tensor. Fails on an empty input.
    """
    # reshape (unlike view) also works for non-contiguous tensors, e.g. transposes.
    interactions = _interactions.reshape(-1)
    # Clamp to 1 so tiny inputs (e.g. a single element) don't call topk(0) and then
    # index into an empty result.
    k = max(1, int(interactions.numel() * keep_fraction))
    return interactions.topk(k).values[-1]

def cl(tok_str, t):
    """Clean one sub-word token for display and adjust the running output list.

    `tok_str` is the list of fragments emitted so far and may be mutated in place:
    - for a BERT "##" continuation, the last fragment is popped
      (print_highlighted appends a " " after every BERT token);
    - for a RoBERTa token with a leading "Ġ", a " " is appended first.
    RoBERTa's "<s>", "</s>", "<pad>" and "âĢ" are replaced by display strings.

    Returns:
        (tok_str, cleaned_token)
    """
    if isinstance(tokenizer, BertTokenizer) and t.startswith("##"):
        if tok_str:
            tok_str.pop()
        t = t[2:]

    if isinstance(tokenizer, RobertaTokenizer):
        if t.startswith("Ġ"):
            # Render the leading "Ġ" as a space (only when something precedes it).
            if tok_str:
                tok_str.append(" ")
            t = t[1:]
        # Markup-like specials are renamed so they don't read as HTML tags; padding
        # and the "âĢ" artefact are dropped.
        t = {"<s>": "[S]", "</s>": "[/S]", "<pad>": "", "âĢ": ""}.get(t, t)

    return tok_str, t
    
def print_highlighted(interactions, toks, other_toks, other_ind, context=3):
    """Render `toks` as HTML with a green background whose intensity follows each score.

    Args:
        interactions: 1-D tensor with one score per entry of `toks`.
        toks: tokens to render.
        other_toks: tokens of the other side of the pair.
        other_ind: for each token in `toks`, the index into `other_toks` it is aligned
            to (tensor or list of ints). The hover tooltip shows that token in
            brackets, with up to `context` neighbours on each side.
        context: number of neighbouring `other_toks` shown on each side of the tooltip.

    Scores at or below `get_lower_bound(interactions)` render white; the maximum
    renders pure green.
    """
    # Shift so the lower bound maps to 0 and clip everything below it. The bound is
    # taken from the same tensor it is subtracted from, and the shift is
    # translation-invariant, so negative scores need no special handling.
    shifted = (interactions - get_lower_bound(interactions)).clamp(min=0)
    peak = shifted.max()
    if peak > 0:
        scaled = shifted / peak
    else:
        # All scores equal: avoid 0/0 -> NaN (which .int() turns into garbage).
        scaled = torch.zeros_like(shifted, dtype=torch.float)
    # Used as the red/blue channels of rgb(x, 255, x): 255 = white, 0 = full green.
    shades = ((1 - scaled) * 255).int().tolist()

    aligned_to = [int(j) for j in other_ind]
    fragments = list()
    for i, t in enumerate(toks):
        fragments, t = cl(fragments, t)

        center = aligned_to[i]
        window = range(max(0, center - context), min(center + context + 1, len(other_toks)))
        attended = ' '.join(
            "[" + cl([], other_toks[j])[1] + "]" if j == center else cl([], other_toks[j])[1]
            for j in window
        )

        # Escape token text and tooltip: raw tokens may contain <, & or ", which
        # would otherwise corrupt the generated markup.
        fragments.append(
            f'<span style="background:rgb({shades[i]}, 255, {shades[i]})" '
            f'title="{html.escape(attended)}">{html.escape(t)}</span>'
        )
        if isinstance(tokenizer, BertTokenizer):
            fragments.append(" ")

    display(HTML("".join(fragments) + "<p>"))

def _best_match_alignment(scores, keep_dim):
    """Score each index along `keep_dim` of a 2-D matrix by the matches it "wins".

    For every index along the other dimension, only the cell with the highest score
    along `keep_dim` is kept; all other cells count as 0. Then, for each index along
    `keep_dim`, return the maximum of the masked row/column and the position (along
    the other dimension) where it occurs.

    Returns:
        (values, indices), both 1-D with length scores.size(keep_dim).
    """
    other_dim = 1 - keep_dim
    winners = scores.argmax(dim=keep_dim, keepdim=True)
    mask = torch.zeros_like(scores).scatter_(keep_dim, winners, 1)
    # Compute values and indices from the SAME masked matrix.
    best = (scores * mask).max(dim=other_dim)
    return best.values, best.indices


def print_instance(instance_id):
    """Show token-alignment highlights for one loaded instance.

    Reads the first provenance of `colbert_interactions[instance_id]` and renders two
    views with `print_highlighted`:
    - the question tokens (printed under "Wikipedia"), scored against the context tokens;
    - the context tokens (printed under "Passage"), scored against the question tokens.
    """
    provenance = colbert_interactions[instance_id]["output"][0]["provenance"][0]
    # NOTE(review): the stored ids/interactions appear to hold two examples, and
    # `instance_id % 2` selects this instance's slot — confirm against the exporter.
    slot = instance_id % 2

    q_toks = tokenizer.convert_ids_to_tokens(provenance["question_ids"][slot])
    c_toks = tokenizer.convert_ids_to_tokens(provenance["context_ids"][slot])
    # 2-D score matrix. Rows are paired with q_toks and columns with c_toks below.
    scores = torch.tensor(provenance["interactions"])[slot][slot]

    print("Wikipedia")
    highlights, aligned = _best_match_alignment(scores, keep_dim=0)
    print_highlighted(
        interactions=highlights,
        toks=q_toks,
        other_toks=c_toks,
        other_ind=aligned,
    )

    print("Passage")
    highlights, aligned = _best_match_alignment(scores, keep_dim=1)
    print_highlighted(
        interactions=highlights,
        toks=c_toks,
        other_toks=q_toks,
        other_ind=aligned,
    )


In [319]:
#fig, ax = plt.subplots()
#fig.set_size_inches(50., 50.)
#seaborn.heatmap((interactions[0][0][:192][:192]).max(-1).values.view(-1, 1), yticklabels=q_toks)

In [320]:
#fig, ax = plt.subplots()
#fig.set_size_inches(50., 50.)
#seaborn.heatmap((interactions[0][0][:192][:192]).max(-2).values.view(1, -1).t(), yticklabels=c_toks)

### Playground

Set `instance_id` to other values to inspect other instances. Each run of the display cell also advances `instance_id` by one.

In [321]:
# Largest valid index into the loaded shard.
max_instance_id = len(colbert_interactions) - 1
print(f"You can set instance_id maximal to {max_instance_id}.")

You can set instance_id maximal to 269.


In [351]:
# Starting instance for the playground; valid ids are 0 .. len(colbert_interactions) - 1.
instance_id = 17

In [352]:
# Show the current instance, then advance `instance_id` so that re-running this
# cell steps through the shard one instance at a time.
shown_id = instance_id
print("\n", shown_id, "\n")
print_instance(instance_id=shown_id)
instance_id = shown_id + 1
# Instance ids noted by the author: 17, 41 (presumably interesting examples).


 17 

Wikipedia


Passage


### Show heatmap

In [335]:
def show_hm(instance_id, size=50.0):
    """Plot the full token-interaction matrix of one instance as a heatmap.

    Uses the first provenance of `colbert_interactions[instance_id]`. Rows are
    labelled with the question tokens and columns with the context tokens.

    Args:
        instance_id: index into `colbert_interactions`.
        size: figure width and height in inches (default 50, as before).

    Returns:
        The matplotlib Axes holding the heatmap.
    """
    provenance = colbert_interactions[instance_id]["output"][0]["provenance"][0]
    # NOTE(review): `instance_id % 2` appears to select this instance's slot in a
    # two-example batch — confirm against the code that wrote the predictions.
    slot = instance_id % 2

    q_toks = tokenizer.convert_ids_to_tokens(provenance["question_ids"][slot])
    c_toks = tokenizer.convert_ids_to_tokens(provenance["context_ids"][slot])
    scores = torch.tensor(provenance["interactions"])[slot][slot]

    fig, ax = plt.subplots(figsize=(size, size))
    # Draw on the Axes created above explicitly instead of pyplot's "current axes".
    seaborn.heatmap(scores, xticklabels=c_toks, yticklabels=q_toks, ax=ax)
    ax.set(
        title=f"Interactions for instance {instance_id}",
        xlabel="context tokens",
        ylabel="question tokens",
    )
    return ax

In [334]:
#show_hm(instance_id=1)