In [1]:
# Hand-written sample used to exercise the renderer below: a predicted class
# plus, for each line, a line score and a list of (token text, token score).
document = {
    'classification': 'Finance',
    'lines': [
        {
            'score': line_score,
            'tokens': [
                {'text': word, 'score': word_score}
                for word, word_score in scored_words
            ],
        }
        for line_score, scored_words in (
            (0.85, (('This', 0.20), ('is', 0.30), ('a', 0.40), ('test', 0.60), ('example', 0.70))),
            (0.25, (('Can', 0.90), ('it', 0.60), ('display', 0.30), ('well?', 0.20))),
        )
    ],
}

In [2]:
from IPython.display import Javascript, display
import json

def output_doc(doc) -> None:
    """Render a line/token score heatmap for ``doc`` in the cell output.

    Displays a Javascript object that loads d3 v6 through RequireJS and
    appends to the output area:

    * an ``<h3>`` reading ``Prediction: <doc['classification']>``;
    * one ``<div>`` per entry of ``doc['lines']``, starting with a 20x20px
      swatch coloured ``d3.interpolateReds(0.6 * line['score'])`` and followed
      by one ``<span>`` per token coloured
      ``d3.interpolateBlues(0.6 * token['score'])``.

    Args:
        doc: JSON-serialisable mapping with keys ``'classification'`` and
            ``'lines'``; each line has ``'score'`` and ``'tokens'``, each token
            has ``'text'`` and ``'score'``. Scores are presumably in [0, 1]
            (d3 interpolators take t in [0, 1]) -- TODO confirm with producer.

    Returns:
        None. Rendering is a display side effect.
    """
    # Escape '<', '>' and '&' so text such as "</script>" inside the document
    # cannot terminate an enclosing <script> element when this output is
    # embedded in HTML (e.g. on nbconvert export). The \uXXXX forms are valid
    # JSON/JS string escapes, and these characters never appear outside string
    # literals in json.dumps output, so the payload still parses to `doc`.
    payload = (
        json.dumps(doc)
        .replace('<', '\\u003c')
        .replace('>', '\\u003e')
        .replace('&', '\\u0026')
    )
    # `element` is injected by the notebook front-end when it runs the script;
    # `.get(0)` assumes the classic Notebook, where it is a jQuery object --
    # TODO confirm for other front-ends.
    script = """
    require.config({
        paths: {
            d3: 'https://d3js.org/d3.v6.min'
        }
    });

    function outputHAN(element, doc) {
        require(['d3'], function(d3) {
            const root = d3.select(element.get(0));
            root.append('h3').text(`Prediction: ${doc['classification']}`);
            root.append('br');
            for (const line of doc['lines']) {
                const lineDiv = root.append('div');
                lineDiv.append('div')
                    .style('height', '20px')
                    .style('width', '20px')
                    .style('background-color', d3.interpolateReds(0.6 * line['score']))
                    .style('display', 'inline-block')
                    .style('margin-right', '4px')
                    .style('vertical-align', 'middle');
                for (const token of line['tokens']) {
                    lineDiv.append('span')
                        .text(token['text'] + ' ')
                        .style('background-color', d3.interpolateBlues(0.6 * token['score']))
                        .style('margin', '2px')
                        .style('padding', '2px');
                }
            }
        });
    }
    outputHAN(element, %s);
    """ % payload
    display(Javascript(script))
    
# Render the sample document from the first cell.
output_doc(doc=document)

<IPython.core.display.Javascript object>

In [3]:
from utils import HANDataset
import pytorch_lightning as pl
from model import HierarchicalAttentionNetwork, Preprocessor
from transformers import AutoTokenizer, AutoModel
from nltk.tokenize import PunktSentenceTokenizer
from tqdm import tqdm
import torch

# Word embeddings and tokenizer both come from the same pretrained checkpoint
# (presumably so the tokenizer's ids index this embedding matrix -- confirm
# against Preprocessor).
pretrained_embedding_model = 'distilroberta-base'
embedding_layer = AutoModel.from_pretrained(pretrained_embedding_model).get_input_embeddings()
pre = Preprocessor(
    PunktSentenceTokenizer(),
    AutoTokenizer.from_pretrained(pretrained_embedding_model, use_fast=True),
)

model = HierarchicalAttentionNetwork(n_classes = 10,
                                     embedding_layer = embedding_layer,
                                     # Derived from the layer instead of hard-coding 768, so it
                                     # stays correct if the pretrained model above is changed.
                                     embedding_size = embedding_layer.embedding_dim,
                                     fine_tune_embeddings = False,
                                     word_rnn_size = 50,
                                     sentence_rnn_size = 50,
                                     word_rnn_layers = 1,
                                     sentence_rnn_layers = 1,
                                     word_att_size = 100, # size of the word-level attention layer (also the size of the word context vector)
                                     sentence_att_size = 100, # size of the sentence-level attention layer (also the size of the sentence context vector)
                                     dropout = 0.3)

# The checkpoint does not contain the embedding matrix (it is supplied by
# `embedding_layer` above), hence strict=False. Check that this is the ONLY
# mismatch so an incompatible checkpoint fails loudly instead of silently
# leaving layers at their random initialisation.
EXPECTED_MISSING_KEYS = {'sentence_attention.word_attention.embeddings.weight'}
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load('model.pth', map_location='cpu')
load_result = model.load_state_dict(state_dict, strict=False)
unexpected_missing = set(load_result.missing_keys) - EXPECTED_MISSING_KEYS
assert not unexpected_missing and not load_result.unexpected_keys, (
    f'checkpoint/model mismatch: missing={sorted(unexpected_missing)}, '
    f'unexpected={list(load_result.unexpected_keys)}'
)
load_result



_IncompatibleKeys(missing_keys=['sentence_attention.word_attention.embeddings.weight'], unexpected_keys=[])