In [1]:
import torch
from torch import nn

from labml import logger, lab, monit
from labml.logger import Text, Style
from labml_helpers.module import Module
from python_autocomplete.dataset import Tokenizer
from python_autocomplete.dataset.break_words import SourceCodeTokenizer
from python_autocomplete.evaluate.factory import load_experiment
from python_autocomplete.train import StateUpdater
from IPython.core.display import display, HTML

In [2]:
# Load the experiment through the project's evaluation factory. The final cell
# reads the tokenizer, model, state updater and `is_token_by_token` flag from it.
conf = load_experiment()

In [10]:
def anomalies(tokenizer: Tokenizer, text: str, model: Module, state_updater: StateUpdater, is_token_by_token: bool):
    """
    Render `text` in the notebook with every word shaded by how unlikely the model finds it.

    The text is split into words with `SourceCodeTokenizer`. The first word seeds the
    prompt and is shown in bold. For every following word, the model's predicted
    probability of each of its tokens is multiplied together; the word's characters
    are then drawn with opacity `(1 - prob) * 0.8 + 0.2`, so improbable words stand out.
    The output cell is updated in place after each word.

    Args:
        tokenizer: object whose `encode(word)` returns the token ids of a word.
        text: source text to analyse.
        model: called as `model(prompt, state)`, returning `(prediction, new_state)`.
            `prediction[-1, 0]` is taken as the scores for the next token
            (presumably a `(seq, batch, vocab)` layout -- confirm against the model).
        state_updater: called as `state_updater(state, new_state)` to produce the
            state passed to the next model call.
        is_token_by_token: if true, only the latest token is fed to the model on each
            step; otherwise the prompt grows and is truncated to its last 512 tokens.

    Returns:
        None. The result is shown through an IPython display handle.
    """
    # Imported here so the function does not rely on a module-level name `html`,
    # which is also used below as a local variable name.
    from html import escape

    bw = SourceCodeTokenizer()
    words = bw.tokenize(text)
    attrs = 'style="overflow-x: scroll;"'
    html = HTML(f'<pre {attrs}></pre>')
    # `display_id=True` lets IPython generate a fresh id, so repeated calls each
    # update their own output instead of sharing one hard-coded id.
    handle = display(html, display_id=True)

    # Nothing to score or render for empty input; leave the empty <pre> in place.
    if len(words) == 0:
        return

    line_no = 1
    # The text comes from arbitrary source code, so it must be escaped before being
    # embedded in HTML; otherwise `<`, `>` and `&` would be parsed as markup.
    html_code = ['<span style="color: blue">   1: </span>',
                 f'<span style="font-weight: bold">{escape(words[0])}</span>']

    # Column vector of token ids: sequence along dim 0, a single batch entry on dim 1.
    prompt = torch.tensor(tokenizer.encode(words[0]), dtype=torch.long, device=model.device).unsqueeze(-1)

    state = None
    softmax = nn.Softmax(-1)

    for word in words[1:]:
        tokens = tokenizer.encode(word)
        # Probability of the whole word = product of its tokens' probabilities.
        prob = 1.
        for token in tokens:
            with torch.no_grad():
                prediction, new_state = model(prompt, state)

            state = state_updater(state, new_state)
            # Distribution over the next token, taken from the last position.
            prediction = softmax(prediction[-1, 0])

            token = prompt.new_tensor([token]).unsqueeze(-1)
            if is_token_by_token:
                prompt = token
            else:
                prompt = torch.cat((prompt, token), dim=0)
                # Keep only the most recent 512 tokens as context.
                prompt = prompt[-512:]

            prob *= prediction[token].item()

        for c in word:
            if c == '\n':
                html_code.append('<br />')
                line_no += 1
                html_code.append(f'<span style="color: blue">{line_no :4d}: </span>')
            elif c == '\r':
                continue
            else:
                html_code.append(f'<span style="opacity: {(1-prob) * 0.8 + 0.2}">{escape(c)}</span>')

        # Refresh the output after every word so progress is visible while scoring.
        html = HTML(f"<pre {attrs}>{''.join(html_code)}</pre>")
        handle.update(html)


In [None]:
# Read the sample source file from the lab data directory. The encoding is given
# explicitly so decoding does not depend on the platform's locale default.
with open(str(lab.get_data_path() / 'sample.py'), 'r', encoding='utf-8') as f:
    sample = f.read()

# Render the sample, shading each word by the probability the model assigns to it.
anomalies(conf.text.tokenizer, sample, conf.model, conf.state_updater, conf.is_token_by_token)