In [1]:
from sentiment_analysis.model import Model
from sentiment_analysis.common.checkpointer import Checkpointer
from sentiment_analysis.experiment import load_settings
from flax import nnx
from jax import numpy as jnp
import jax
from pathlib import Path

In [25]:
# Training run directory: both settings.json and the checkpoints/ folder are read from here.
path = Path("results/small_mixed_single_2024-07-18_13-20-12")
# Step number of the checkpoint to restore (previously an inline magic number).
CHECKPOINT_STEP = 49999

settings = load_settings(path / "settings.json")

# Build a model with the run's architecture settings and a fixed seed. It is handed to
# the checkpointer as the restore target. NOTE(review): presumably restore() returns a
# model carrying the checkpointed state — confirm in Checkpointer.restore.
original = Model(settings.model, nnx.Rngs(0))

checkpoints = Checkpointer(path / "checkpoints")
model = checkpoints.restore(original, CHECKPOINT_STEP)

In [6]:
import tokenmonster

# Vocabulary file referenced by the run's model settings; `vocab` is used by eval_text.
vocab_file = settings.model.vocab.path
vocab = tokenmonster.load(vocab_file)

In [47]:
@nnx.jit
def eval_model(model, tokens):
    """Run `model` on one token sequence and return the index of its largest output.

    The positional flag ``True`` and a fresh ``nnx.Rngs(0)`` are forwarded to the model
    as-is. NOTE(review): ``True`` presumably selects deterministic / inference mode —
    confirm against ``Model.__call__``.

    ``jnp.argmax`` is called without ``axis``, so the index is over the flattened output.
    """
    return jnp.argmax(model(tokens, True, nnx.Rngs(0)))

def eval_text(text, max_len=115, pad_id=-1):
    """Tokenize `text`, fit it to a fixed length, and classify it with the global `model`.

    Args:
        text: Raw input string.
        max_len: Sequence length fed to the model. Longer inputs are truncated and
            shorter ones are right-padded with `pad_id`. Defaults to 115, the length
            this notebook has always used (NOTE(review): presumably the model's
            training sequence length — confirm against the run settings).
        pad_id: Token id used for padding. Defaults to -1.

    Returns:
        Whatever `eval_model` returns for the padded sequence (its argmax index).

    Raises:
        ValueError: if `max_len` is not positive.

    Relies on the notebook globals `vocab`, `model` and `eval_model`.
    """
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    tokens = list(vocab.tokenize(text))
    # Report the untruncated length so it is visible when input gets cut off.
    print('tokens: ' + str(len(tokens)))

    # Truncate first, then pad: the sequence always has exactly `max_len` entries, so the
    # jitted `eval_model` always sees the same input shape.
    tokens = tokens[:max_len]
    padded = tokens + [pad_id] * (max_len - len(tokens))
    # NOTE(review): int16 assumes every vocab id fits in [-32768, 32767] — confirm
    # against the vocabulary size.
    padded_tokens = jnp.array(padded, jnp.int16)

    return eval_model(model, padded_tokens)

In [27]:
import ipywidgets as widgets
from IPython.display import display

In [46]:
text_area = widgets.Textarea(
    value='',
    placeholder='Type something',
    description='Input:',
    layout=widgets.Layout(width="auto"),
)

submit_button = widgets.Button(
    description='Submit',
    disabled=False,
    button_style='success',
    tooltip='Submit text for classification',
    icon='check'
)

# Output widget that captures everything printed by the click handler below.
output = widgets.Output()


def on_text_change(change):
    """Classify the current contents of `text_area` and print the result into `output`.

    Despite its name, this is registered as the Submit button's click handler, so
    `change` is the Button instance that was clicked (it is not used).
    """
    with output:
        output.clear_output()  # Replace the previous result instead of appending to it.
        text = text_area.value
        if not text.strip():
            # Blank input would reach the model as nothing but padding tokens.
            print('Please enter some text to classify.')
            return
        result = eval_text(text)
        # Show one star per (predicted class index + 1). NOTE(review): assumes the
        # classes are ordered ratings with 0 as the lowest — confirm against the labels.
        print("*" * (result.item() + 1))


# Run the classifier each time the Submit button is clicked.
submit_button.on_click(on_text_change)

# Render the input box, the button and the result area.
display(text_area, submit_button, output)

Textarea(value='', description='Input:', layout=Layout(width='auto'), placeholder='Type something')

Button(button_style='success', description='Submit', icon='check', style=ButtonStyle(), tooltip='Submit text f…

Output()