In [1]:
#!pip install ngrok -q
#!pip install dash -q
#!pip install "dash[diskcache]" -q

In [2]:
import getpass
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AutoConfig, StoppingCriteriaList, StoppingCriteria

  _torch_pytree._register_pytree_node(


In [3]:
# https://github.com/oobabooga/text-generation-webui/blob/2cf711f35ec8453d8af818be631cb60447e759e2/modules/callbacks.py#L12
class _SentinelTokenStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when a sequence ends with a sentinel.

    Only the part of each sequence from ``starting_idx`` onwards is inspected,
    so tokens before that index can never trigger a stop.
    """

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        """
        Args:
            sentinel_token_ids: tensors of token ids; each one is a stop pattern.
            starting_idx: index of the first position that is checked.
        """
        super().__init__()
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx
        # Length of the shortest pattern: anything shorter cannot match at all.
        self.shortest = min(pattern.shape[-1] for pattern in sentinel_token_ids)

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        """Return True as soon as any sequence in ``input_ids`` ends with a sentinel."""
        for sequence in input_ids:
            tail = sequence[self.starting_idx:]
            tail_len = tail.shape[-1]
            if tail_len >= self.shortest:
                for pattern in self.sentinel_token_ids:
                    pattern_len = pattern.shape[-1]
                    # Compare the pattern against the last `pattern_len` tokens.
                    if pattern_len <= tail_len and torch.all(torch.eq(pattern, tail[-pattern_len:])):
                        return True
        return False
####

def generate_stopping_criteria(stopgen_tokens, input_len=0):
    """Wrap the sentinel-token criterion in a ``StoppingCriteriaList``.

    Args:
        stopgen_tokens: tensors of token ids that should end generation.
        input_len: index from which the sequence is inspected (0 by default).

    Returns:
        A ``StoppingCriteriaList`` containing a single sentinel criterion.
    """
    sentinel_criterion = _SentinelTokenStoppingCriteria(
        sentinel_token_ids=stopgen_tokens,
        starting_idx=input_len,
    )
    return StoppingCriteriaList([sentinel_criterion])


# CODE

In [4]:
# Checkpoint to load. The per-model settings cell further down only defines
# values for the two identifiers listed here.
#model_id = "microsoft/phi-1_5"
model_id = "meta-llama/Llama-2-7b-hf"
# Create every tensor on the CPU by default.
torch.set_default_device("cpu")

In [5]:
# Access token for gated checkpoints. None (the library default) is passed for
# public models instead of an empty string, which could otherwise be sent as a
# malformed credential.
hf_key = None
if model_id in ["meta-llama/Llama-2-7b-hf"]:
    # getpass does not echo what is typed, so the token never ends up in the
    # saved cell output the way it would with input().
    hf_key = getpass.getpass("Hugging Face Key: ")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_key)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, token=hf_key)
model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True, token=hf_key)
# Drop the token from the kernel namespace once the downloads are done.
del hf_key

  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Per-model generation and display settings:
#   stopgen_tokens       - token-id tensors that end generation
#   prompt_structure     - template the user prompt is inserted into
#   exclude_token_offset - offset used to locate the prompt token whose
#                          attention edges are hidden in the plot
#   fix_characters       - (old, new) replacements applied to token strings
if model_id in ["microsoft/phi-1_5"]:
    stopgen_tokens = [
        torch.tensor([198, 198]),  # \n\n
        torch.tensor([628])        # \n\n
    ]
    prompt_structure = "Question: {prompt}\n\nAnswer:"
    exclude_token_offset = 3
    fix_characters = [("Ġ", "␣"), ("Ċ", "\n")]
elif model_id in ["meta-llama/Llama-2-7b-hf"]:
    stopgen_tokens = [
        torch.tensor([1]),  # <s>
        torch.tensor([2])   # </s>
    ]
    prompt_structure = "{prompt}\n"
    exclude_token_offset = 0
    fix_characters = []
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"No generation settings defined for model '{model_id}'")

In [7]:
### ALTERATION ### Divided computation for attentions
### ALTERATION ###  Added function to compute attentions also for prompt

def pad_masked_attentions(attentions, max_len):
    """Right-pad every attention row with zeros so all rows have ``max_len`` entries.

    Causal attention rows grow with the sequence, so they have to be brought to
    a common length before they can be stacked into a heatmap matrix.

    Args:
        attentions: iterable of one-dimensional array-likes.
        max_len: target length of every row.

    Returns:
        A numpy array with one padded row per input row.
    """
    padded_rows = []
    for row in attentions:
        values = np.array(row)
        filler = np.zeros([max_len - len(values)])
        padded_rows.append(np.concatenate([values, filler]))
    return np.array(padded_rows)

def compute_complete_padded_attentions(generated_output, layer, head):
    """Build the padded attention matrix of one layer/head for prompt and response.

    Args:
        generated_output: generation result exposing ``attentions``; element 0
            is treated as the prompt pass and the rest as one entry per
            generated token.
        layer: index of the layer to read.
        head: index of the attention head to read (selected along dim 1).

    Returns:
        The rows padded by ``pad_masked_attentions`` to the length of the last row.
    """
    # Prompt pass: after squeezing, iterating yields one attention row per token.
    # assumes batch size 1 and more than one prompt token — TODO confirm
    prompt_block = torch.squeeze(torch.select(generated_output.attentions[0][layer], 1, head))
    rows = list(prompt_block)
    # Generation steps: each step contributes a single row for the chosen head.
    rows.extend(
        torch.select(step_attentions[layer], 1, head)
        for step_attentions in generated_output.attentions[1:]
    )
    # Drop the remaining singleton dimensions so every row is one-dimensional.
    flat_rows = [row.squeeze() for row in rows]
    # The last row is used as the reference length for padding.
    return pad_masked_attentions(flat_rows, len(flat_rows[-1]))

def compute_batch_complete_padded_attentions(generated_output, heads):
    """Build padded attention matrices for every layer of each requested head.

    Args:
        generated_output: generation result exposing ``attentions``; element 0
            is treated as the prompt pass and the rest as one entry per
            generated token.
        heads: iterable of head indices (selected along dim 1).

    Returns:
        A list with one entry per head, each a list with one padded matrix per
        layer, i.e. ``result[head_position][layer]``.
    """
    prompt_pass = generated_output.attentions[0]
    response_passes = generated_output.attentions[1:]
    multi_layer_head_attentions = []
    for head in heads:
        multi_layer_attentions = []
        for layer in range(len(prompt_pass)):
            # Prompt tokens: one row per prompt position.
            # assumes batch size 1 and more than one prompt token — TODO confirm
            rows = [
                torch.squeeze(single_head)
                for single_head in torch.squeeze(torch.select(prompt_pass[layer], 1, head))
            ]
            # Response tokens: one row per generation step.
            rows.extend(
                torch.squeeze(torch.select(step_attentions[layer], 1, head))
                for step_attentions in response_passes
            )
            # Pad to the length of the last row overall. Using the last
            # *response* row raised IndexError when no generation step followed
            # the prompt pass; this matches compute_complete_padded_attentions.
            multi_layer_attentions.append(pad_masked_attentions(rows, len(rows[-1])))
        multi_layer_head_attentions.append(multi_layer_attentions)
    return multi_layer_head_attentions

def plot_attentions(generated_output, layer, head, generated_tokens, past_tokens):
    """Draw a matplotlib heatmap of the attention weights of one layer/head.

    Args:
        generated_output: generation result exposing ``attentions``.
        layer: index of the layer to plot.
        head: index of the attention head to plot.
        generated_tokens: labels for the y axis (one per heatmap row).
        past_tokens: labels for the x axis (one per heatmap column).
    """
    # The helper defined in this notebook is compute_complete_padded_attentions;
    # the previously referenced compute_padded_attentions does not exist.
    data = compute_complete_padded_attentions(generated_output, layer, head)
    fig, ax = plt.subplots(figsize = (12,5))
    im = ax.imshow(data)
    # Show all ticks and label them with the respective list entries
    ax.set_yticks(np.arange(len(generated_tokens)), labels=generated_tokens)
    ax.set_xticks(np.arange(len(past_tokens)), labels=past_tokens, fontsize=8)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    # Create colorbar
    ax.figure.colorbar(im, ax=ax)

    ax.set_title(f"Heatmap of attention layers: layer {layer} head {head}")
    fig.tight_layout()
    plt.show()

In [8]:
### ALTERATION ### Adapted functions to work outside model
### ALTERATION ### Added option to include hidden states for prompt and ending token in embed_hidden_states
def _apply_lm_head(model, hidden_states):
    """
    Function which takes as input the hidden states of the model and returns the prediction of the next token.
    Uses the language modeling head of output
    """
    pred_ids = []
    per_token_logits = []
    for i in range(len(hidden_states)):        
        logits = model.lm_head(hidden_states[i])
        logits = logits.float()
        pred_id = torch.argmax(logits)
        pred_ids.append(pred_id)
        per_token_logits.append(logits)
    return pred_ids, per_token_logits

def _apply_input_lm_head(model, hidden_states):
    """
    Function which takes as input the hidden states of the model and returns the prediction of the next token.
    Uses the language modeling head of input
    """
    pred_ids = []
    per_token_logits = []
    for layer in hidden_states:
        output = torch.matmul(layer.to(model.model.embed_tokens.weight.device), model.model.embed_tokens.weight.T)
        token_id = output.argmax(dim=-1)
        pred_ids.append(token_id)
        per_token_logits.append(output)
    return pred_ids, per_token_logits
    
def embed_hidden_states(model, hidden_states, embedding="output", include_prompt=False, include_end=True):
    """Decode the hidden states of a generation into predicted token ids.

    Args:
        model: model passed on to the projection helper.
        hidden_states: per-step hidden states; element 0 is treated as the
            prompt pass and the following elements as generation steps.
        embedding: ``"output"`` to use the output head, ``"input"`` to use the
            input embedding matrix.
        include_prompt: also decode the entries of the prompt pass.
        include_end: when False, the last generation step is skipped.

    Returns:
        A list with one list of integer token ids per decoded position.

    Raises:
        ValueError: if ``embedding`` is neither ``"input"`` nor ``"output"``.
    """
    if embedding not in ['input', 'output']:
        raise ValueError("Embedding not valid")

    # Pick the projection once instead of branching inside every loop.
    project = _apply_lm_head if embedding == 'output' else _apply_input_lm_head
    last_step = len(hidden_states) - (0 if include_end else 1)

    predictions = []
    if include_prompt:
        # Stack the prompt-pass tensors and bring the third axis to the front so
        # that each iteration handles one prompt position.
        # presumably (layers, batch, seq, dim) -> (seq, batch, layers, dim) — TODO confirm
        per_position = torch.stack(hidden_states[0]).swapaxes(0, 2)
        for position_states in per_position:
            token_ids, _ = project(model, position_states.swapaxes(0, 1))
            predictions.append([int(token_id) for token_id in token_ids])
    # Generation steps start at index 1; index 0 is only read when include_prompt is set.
    for step in range(1, last_step):
        token_ids, _ = project(model, hidden_states[step])
        predictions.append([int(token_id) for token_id in token_ids])
    return predictions

In [9]:
def fix_dataframe_characters(df, replacements, columns=False):
    """Apply literal substring replacements to every string cell of a frame.

    Args:
        df: frame whose cells hold token strings; non-string cells (missing
            values, numbers) are left untouched instead of raising.
        replacements: iterable of ``(old, new)`` pairs, applied in order.
        columns: when True the column labels are cleaned the same way.

    Returns:
        A new frame with the replacements applied, or ``df`` itself when there
        is nothing to replace.
    """
    replacements = list(replacements)
    if not replacements:
        return df

    def _clean(value):
        # Only strings support .replace; anything else passes through unchanged.
        if isinstance(value, str):
            for old, new in replacements:
                value = value.replace(old, new)
        return value

    # DataFrame.applymap was deprecated in favour of DataFrame.map; use
    # whichever element-wise method this pandas version provides.
    elementwise = getattr(pd.DataFrame, "map", None) or pd.DataFrame.applymap
    df = elementwise(df, _clean)
    if columns:
        # Plain str.replace keeps the substitution literal regardless of the
        # pandas version's default for regex handling.
        df.columns = [_clean(label) for label in df.columns]
    return df

In [10]:
def gen_edges(
    edges, line_hover_traces, permanent_edges, permanent_line_hover_traces, attentions,
    generated_output, 
    nodexs, nodeys, 
    head, 
    exclude, 
    threshold, permanent_threshold,
):
    """Append attention edges (and their hover markers) for one head.

    For every pair of nodes in consecutive rows of the plot an edge is created
    when the attention weight reaches ``threshold``. Edges whose weight also
    reaches ``permanent_threshold`` go into the "permanent" collections, all
    others into the regular ones. The collections are mutated in place; nothing
    is returned.

    Args:
        edges, line_hover_traces: dicts mapping head -> list, receiving the
            regular edge traces and their hover markers.
        permanent_edges, permanent_line_hover_traces: same, for the edges at or
            above ``permanent_threshold``.
        attentions: weights indexed as ``attentions[row - 1][previous][current]``.
        generated_output: not used by this function.
        nodexs, nodeys: per-row lists of node x / y coordinates.
        head: key under which the traces are stored.
        exclude: x coordinates for which no edge is drawn.
        threshold: minimum weight for an edge to be created.
        permanent_threshold: minimum weight for an edge to count as permanent.
    """
    for row_pos, (row_xs, row_ys) in enumerate(zip(nodexs, nodeys)):
        # The first row has no previous row to connect to.
        if row_pos == 0:
            continue

        row_weights = attentions[row_pos - 1]

        for col, (x, y) in enumerate(zip(row_xs, row_ys)):
            for prev_col, (xx, yy) in enumerate(zip(nodexs[row_pos - 1], nodeys[row_pos - 1])):
                weight = row_weights[prev_col][col]

                if x in exclude or not weight >= threshold:
                    continue

                # Strong edges are stored separately so they can stay visible.
                if weight >= permanent_threshold:
                    edge_bucket = permanent_edges[head]
                    hover_bucket = permanent_line_hover_traces[head]
                    edge_name, hover_name = "pedge", "ptrace"
                else:
                    edge_bucket = edges[head]
                    hover_bucket = line_hover_traces[head]
                    edge_name, hover_name = "edge", "trace"

                # Line whose width encodes the attention weight.
                edge_bucket.append(go.Scattergl(
                    x=[x, xx],
                    y=[y, yy],
                    name=edge_name,
                    mode="lines",
                    hoverinfo="none",
                    line=dict(color="rgba(125,125,125,0.8)", width=2 * weight + 0.25),
                    customdata=[{"type": "edge"}],
                    showlegend=False,
                ))

                # Invisible marker at the midpoint that shows the weight on hover.
                hover_bucket.append(go.Scattergl(
                    x=[(x+xx)/2],
                    y=[(y+yy)/2],
                    name=hover_name,
                    text=f"{str(weight)}",
                    mode='markers',
                    hoverinfo='text',
                    opacity=0,
                    customdata=[{"type": "hover_trace", "P1": {"x": x, "y": y}, "P2": {"x": xx, "y":yy}}],
                    showlegend=False,
                ))

def create_transformer_plot(dfs, generated_output, exclude=None, heads=range(0, 1), max_heads=32, threshold=0.002, permanent_threshold=0.4):
    """Create the node and edge traces of the transformer visualisation.

    Args:
        dfs: dict mapping an embedding name to a frame of token strings with one
            row per plotted level and one column per position.
        generated_output: generation result exposing ``attentions``.
        exclude: x positions drawn in red and left without edges
            (defaults to none).
        heads: head indices to create edges for; ``-1`` stands for the average
            over all heads.
        max_heads: number of heads for which attentions are computed.
        threshold: minimum weight for an edge to be created.
        permanent_threshold: minimum weight for an edge to be stored as permanent.

    Returns:
        ``(figure, nodes, (edges, permanent_edges),
        (line_hover_traces, permanent_line_hover_traces))`` where ``figure`` is
        an empty ``FigureWidget`` with the layout applied, ``nodes`` maps each
        key of ``dfs`` to its node traces and the other dicts map head -> traces.
    """
    # Avoid sharing a mutable default between calls.
    exclude = [] if exclude is None else exclude

    nodes = {key: [] for key in dfs.keys()}
    edges = {head: [] for head in heads}
    permanent_edges = {head: [] for head in heads}
    line_hover_traces = {head: [] for head in heads}
    permanent_line_hover_traces = {head: [] for head in heads}

    attentions = compute_batch_complete_padded_attentions(generated_output, range(0, max_heads))

    # Coordinates of the most recently processed frame are reused for the
    # edges; initialised here so an empty `dfs` does not leave them unbound.
    nodexs, nodeys = [], []
    for key, df in dfs.items():

        nodexs = []
        nodeys = []

        # Cycle through every row of the frame, gathering all cells as nodes
        for idx, row in df.iterrows():

            # Generate coordinates for nodes
            xs = list(range(len(row)))
            ys = [idx] * len(row)

            nodexs.append(xs)
            nodeys.append(ys)

            for x, y in zip(xs, ys):

                color = "red" if x in exclude else "lightblue"

                # Create nodes
                nodes[key].append(go.Scattergl(
                    x=[x],
                    y=[y],
                    name="node",
                    mode="markers+text",
                    marker=dict(size=20, color=color),
                    marker_line_width=2,
                    marker_symbol=1,
                    # Positional access: the row is labelled with token strings,
                    # so row[x] relied on the deprecated integer fallback.
                    text=row.iloc[x],
                    textposition="bottom center",
                    hoverinfo="none",
                    customdata=[{"type":"node"}],
                    showlegend=False
                ))

    for head in heads:
        # Head -1 is the average over all heads. The averaged weights were
        # computed before but never passed on, so "-1" silently showed the
        # last head instead.
        att = np.mean(attentions, axis=0) if head == -1 else attentions[head]
        gen_edges(
            edges, line_hover_traces, permanent_edges, permanent_line_hover_traces, att,
            generated_output, 
            nodexs, nodeys, 
            head, 
            exclude, 
            threshold, permanent_threshold
        )

    # Create figure
    fig = go.Figure(data=[])

    # Customize layout
    fig.update_layout(
        title="Transformer Weights Visualization",
        showlegend=True,
        xaxis=dict(showticklabels=False, zeroline=False),
        yaxis=dict(showticklabels=False, zeroline=False),
        plot_bgcolor='white',
        width=1600, height=1500,
        uirevision="const"
    )
    
    return go.FigureWidget(fig), nodes, (edges, permanent_edges), (line_hover_traces, permanent_line_hover_traces)

In [11]:
def compute_edge_cache(nodes, edges, hover_traces, heads=range(0,1)):
    """Index, per node and head, the edges and hover markers touching that node.

    Args:
        nodes: node traces exposing iterable ``x`` and ``y`` coordinates.
        edges: dict head -> list of edge traces with two-point ``x`` / ``y``.
        hover_traces: dict head -> list of traces whose ``customdata[0]`` holds
            the endpoints under ``"P1"`` and ``"P2"``.
        heads: heads to index.

    Returns:
        Dict keyed by ``"<x>.<y>"`` mapping each head to
        ``{"edges": [...], "hovers": [...]}`` lists of positions in the
        corresponding per-head lists.
    """
    coordinate_pairs = [zip(node_batch.x, node_batch.y) for node_batch in nodes]

    cache = {}
    for pairs in coordinate_pairs:
        for x, y in pairs:
            per_head = {}
            for head in heads:
                edge_positions = []
                for position, edge in enumerate(edges[head]):
                    if (edge.x[1] == x and edge.y[1] == y) or (edge.x[0] == x and edge.y[0] == y):
                        edge_positions.append(position)

                hover_positions = []
                for position, trace in enumerate(hover_traces[head]):
                    if (trace.customdata[0]["P1"]["x"] == x and trace.customdata[0]["P1"]["y"] == y) or (trace.customdata[0]["P2"]["x"] == x and trace.customdata[0]["P2"]["y"] == y):
                        hover_positions.append(position)

                per_head[head] = {"edges": edge_positions, "hovers": hover_positions}
            cache[f"{x}.{y}"] = per_head
    return cache

def compute_add_traces(edges, hover_traces, x, y, head):
    """Collect the edges and hover markers of ``head`` that touch node ``(x, y)``.

    Args:
        edges: dict head -> list of edges indexable by ``"x"`` / ``"y"``.
        hover_traces: dict head -> list of traces whose ``customdata[0]`` holds
            the endpoints under ``"P1"`` and ``"P2"``.
        x, y: coordinates of the node.
        head: head whose traces are searched.

    Returns:
        Matching edges followed by matching hover markers, in original order.
    """
    touching = []
    for edge in edges[head]:
        if edge["x"][1] == x and edge["y"][1] == y:
            touching.append(edge)
        elif edge["x"][0] == x and edge["y"][0] == y:
            touching.append(edge)

    for trace in hover_traces[head]:
        if trace.customdata[0]["P1"]["x"] == x and trace.customdata[0]["P1"]["y"] == y:
            touching.append(trace)
        elif trace.customdata[0]["P2"]["x"] == x and trace.customdata[0]["P2"]["y"] == y:
            touching.append(trace)
    return touching

def edge_match_vis(el, el_name, x, y, vis):
    """Set ``el.visible`` depending on whether the element touches ``(x, y)``.

    The element first gets ``not vis``; it is switched to ``vis`` when it is an
    ``"edge"`` or ``"trace"`` with an endpoint at the given coordinates. Any
    other ``el_name`` keeps ``not vis``.

    Returns:
        None; the element is modified in place.
    """
    el.visible = not vis
    if el_name == "edge":
        touches = (el.x[1] == x and el.y[1] == y) or (el.x[0] == x and el.y[0] == y)
    elif el_name == "trace":
        touches = (
            el.customdata[0]["P1"]["x"] == x and el.customdata[0]["P1"]["y"] == y
        ) or (
            el.customdata[0]["P2"]["x"] == x and el.customdata[0]["P2"]["y"] == y
        )
    else:
        touches = False
    if touches:
        el.visible = vis
    return None

In [12]:
def model_generate(model, tokenizer, prompt, max_extra_length, config, min_stop_length, stopping_tokens):
    """Run generation for ``prompt`` and split the result into prompt / new tokens.

    Args:
        model: model exposing ``generate``.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        prompt: text fed to the model.
        max_extra_length: maximum number of tokens generated after the prompt.
        config: generation configuration passed to ``generate``.
        min_stop_length: number of generated tokens that are exempt from the
            stopping criteria.
        stopping_tokens: token-id tensors that end generation.

    Returns:
        ``(text_output, generated_output, tokens)`` where ``text_output`` is the
        decoded continuation, ``generated_output`` the raw result of
        ``generate`` and ``tokens`` is ``{"in": [...], "gen": [...]}``.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Take the length from the tensor shape: squeeze().tolist() turns a
    # one-token prompt into a bare int, on which len() fails.
    input_len = inputs.input_ids.shape[-1]
    max_len = input_len + max_extra_length

    stopping_criteria = generate_stopping_criteria(stopping_tokens, input_len + min_stop_length)

    generated_output = model.generate(inputs.input_ids, generation_config=config, max_length=max_len, stopping_criteria=stopping_criteria)

    # Single-prompt generation: work on the first (only) sequence throughout
    # instead of mixing squeeze() and [0].
    sequence = generated_output.sequences[0]
    text_output = tokenizer.decode(sequence[input_len:])

    all_tokens = tokenizer.convert_ids_to_tokens(sequence)
    input_tokens = all_tokens[0:input_len]
    generated_tokens = all_tokens[input_len:]

    return text_output, generated_output, {"in": input_tokens, "gen": generated_tokens}

def create_hidden_states_df(model, tokenizer, generated_output, gen_tokens, embedding, include_prompt, fix_characters):
    """Build the frame of decoded tokens shown in the plot.

    The predictions of ``embed_hidden_states`` are converted to token strings,
    transposed, sorted by descending index, and the columns are renamed after
    the prompt and generated tokens. Character replacements are applied to
    cells and column labels.

    Returns:
        The resulting DataFrame.
    """
    predicted_ids = embed_hidden_states(model, generated_output.hidden_states, embedding, include_prompt=include_prompt)
    token_rows = [tokenizer.convert_ids_to_tokens(ids) for ids in predicted_ids]
    frame = pd.DataFrame(token_rows).T
    frame = frame.sort_index(ascending=False)
    # Map positional column numbers to the token at that position.
    column_names = dict(enumerate(gen_tokens["in"] + gen_tokens["gen"]))
    frame = frame.rename(columns=column_names)
    return fix_dataframe_characters(frame, fix_characters, columns=True)

def create_attention_visualization(dfs, generated_output, exclude, heads, max_heads, compute_cache=True):
    """Create the figure and group all traces needed by the dashboard.

    Returns:
        ``(figure, elements, edge_cache)`` where ``elements`` has the keys
        ``"nodes"``, ``"edges"``, ``"hovers"`` and ``"perm"`` (per head:
        ``{"edges": ..., "hovers": ...}`` of the permanent traces) and
        ``edge_cache`` is None unless ``compute_cache`` is set.
    """
    figure, nodes, edge_groups, hover_groups = create_transformer_plot(dfs, generated_output, exclude, heads=heads, max_heads=max_heads)
    edges, permanent_edges = edge_groups
    hover_traces, permanent_hover_traces = hover_groups

    permanent_traces = {}
    for head in heads:
        permanent_traces[head] = {"edges": permanent_edges[head], "hovers": permanent_hover_traces[head]}

    edge_cache = None
    if compute_cache:
        # The cache is built from the node traces of the first frame.
        first_nodes = list(nodes.values())[0]
        edge_cache = compute_edge_cache(first_nodes, edges, hover_traces, heads)

    elements = {"nodes": nodes, "edges": edges, "hovers": hover_traces, "perm": permanent_traces}
    return figure, elements, edge_cache

In [13]:
#import ngrok
import asyncio
import dash
import diskcache
import uuid

#from kaggle_secrets import UserSecretsClient

from dash import dcc, html, ctx, Patch, DiskcacheManager
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate

import dash_daq as daq

import plotly.graph_objects as go

In [14]:
async def deploy_ngrok(address="http://127.0.0.1:8050"):
    """Expose ``address`` through an ngrok tunnel and print the public URL.

    NOTE(review): ``ngrok`` and ``UserSecretsClient`` are only imported in
    commented-out lines of this notebook, so calling this function as-is raises
    NameError until those imports are restored.

    Args:
        address: local address to forward (defaults to the Dash port 8050).

    Returns:
        The value of ``listener.result()``.
    """
    # The auth token is read from the secret named "ngrok_key".
    listener = ngrok.forward(addr=address, authtoken=UserSecretsClient().get_secret("ngrok_key"))
    # Wait at most 10 seconds for the tunnel to come up.
    # presumably ngrok.forward returns a task/future here — TODO confirm
    await asyncio.wait_for(listener, timeout=10)
    public_url = listener.result().url()
    print(f"Deploy URL: {public_url}")
    return listener.result()

In [15]:
# Module-level state shared with the scatter-plot callback: the traces added by
# the last hover, the head whose permanent edges are drawn, and the embedding
# whose nodes are drawn.
current_els = []
current_head = -1
current_emb = "output"

# Disk-backed cache in ./cache, used for memoisation and as the manager of
# background callbacks.
cache = diskcache.Cache("./cache")
background_callback_manager = DiskcacheManager(cache)
app = dash.Dash("Test")

# Define the layout of the app
app.layout = html.Div([
    # Top row: prompt / output text areas on the left, run parameters on the right
    html.Div([
        html.Div([
            dcc.Textarea(id='model_input', placeholder ='Insert prompt...', style={'width': '100%', 'height': "50%"}),
            dcc.Loading(id="model_loading", type="dot", color="#873ba1", children =
                dcc.Textarea(id='model_output', readOnly=True, style={'width': '100%', 'height': "50%"})
            )
        ], style={"float": "left", "width": "70%", "height": "100%", "padding": 2}),
        html.Div([
            html.P(children="# of attention heads to load", style={"margin": "0"}),
            # Values run from -1 (labelled "Average") to num_attention_heads - 1; labels are 1-based
            dcc.Dropdown([{"value": i, "label": i+1 if i >= 0 else "Average"} for i in range(-1, model_config.num_attention_heads)], id='model_generate_heads', value=-1, clearable=False),
            daq.NumericInput(id="min_stop_tokens", label="Min # tokens for stopping criteria", value=1, min=0, max=1024, labelPosition="right"),
            daq.NumericInput(id="max_new_tokens", label="Max # of generated tokens", value=10, min=0, max=1024, labelPosition="right"),
            html.Button('Generate', id='model_generate', style={"width": "100%", "height": "20px"}),
        ], style={"float": "right", "height": "100%", "width": "20%", "padding": 2}),
    ], style={"height": "140px"}),
    # Bottom part: head / embedding selectors and the plot itself
    html.Div([
        html.P("Attention Head Selector"),
        # Marks start empty and are filled in by the generation callback
        dcc.Slider(marks={}, step=1, value=-1, id='attention_heads'),
        html.P("Embeddings Selector"),
        dcc.RadioItems(['input', 'output'], value='output', inline=True, id='embeddings'),
        dcc.Graph(id='scatterplot'),
    ]),
    # Client-side stores: aggregated run parameters, "new graph available"
    # signal, and the id of the current graph
    dcc.Store(id="run_config"),
    dcc.Store(id="notify"),
    dcc.Store(id="graph_id")
])

# Aggregate run parameters data
@app.callback(
    Output('run_config', 'data'),
    [
        Input('model_generate_heads', 'value'),
        Input('min_stop_tokens', 'value'),
        Input('max_new_tokens', 'value'),
    ]
)
def update_run_config(gen_heads, min_stop_tokens, max_new_tok):
    """Bundle the three run parameters into the dict kept in the ``run_config`` store."""
    field_names = ("gen_heads", "min_stop_tokens", "max_new_tok")
    field_values = (gen_heads, min_stop_tokens, max_new_tok)
    return dict(zip(field_names, field_values))


@cache.memoize()
def model_output(prompt, session, run_config):
    """Generate a completion and build every plot element for it (memoised).

    Args:
        prompt: raw user prompt; it is inserted into ``prompt_structure``.
        session: not used in the body; it only takes part in the memoisation key.
        run_config: dict with ``"max_new_tok"``, ``"min_stop_tokens"`` and
            ``"gen_heads"``.

    Returns:
        ``(text_output, fig_els, figure)``.
    """
    full_prompt = prompt_structure.format(prompt=prompt)
    generation_settings = GenerationConfig(output_attentions=True, output_hidden_states=True, return_dict_in_generate=True)
    text_output, generated_output, gen_tokens = model_generate(
        model, tokenizer, full_prompt,
        max_extra_length=run_config["max_new_tok"],
        config=generation_settings,
        min_stop_length=run_config["min_stop_tokens"],
        stopping_tokens=stopgen_tokens,
    )

    # One frame of decoded tokens per embedding.
    dfs = {
        emb: create_hidden_states_df(
            model, tokenizer, generated_output, gen_tokens, emb,
            include_prompt=True, fix_characters=fix_characters,
        )
        for emb in ("input", "output")
    }

    # Position 0 and one position near the end of the prompt get no edges.
    excluded_prompt_position = len(gen_tokens["in"]) - exclude_token_offset - 1
    loaded_heads = range(-1, run_config["gen_heads"]+1)
    figure, fig_els, _unused_cache = create_attention_visualization(
        dfs, generated_output,
        exclude=[0, excluded_prompt_position],
        heads=loaded_heads, max_heads=model_config.num_attention_heads,
        compute_cache=False,
    )
    return text_output, fig_els, figure

# Define callback to generate output
@app.callback(
    [
        Output('model_output', 'value'),
        Output('scatterplot', 'figure', allow_duplicate=True),
        Output('attention_heads', 'marks'),
        Output('attention_heads', 'value'),
        Output('graph_id', 'data'),
        Output("notify", "data"),
    ],
    Input('model_generate', 'n_clicks'),
    [
        State('model_input', 'value'),
        State('run_config', 'data'),
    ],
    running=[(Output("model_generate", "disabled"), True, False)],
    prevent_initial_call=True,
    background=True,
    manager=background_callback_manager
)
def update_model_generation(click_data, prompt, run_config):
    """Background callback of the Generate button.

    Runs the model for the current prompt under a fresh graph id and returns
    the generated text, the figure, the slider marks for the loaded heads, the
    slider value -1, the graph id and the notification flag.
    """
    if not ctx.triggered_prop_ids:
        raise PreventUpdate

    # A new id per click gives every generation its own memoisation entry.
    graph_id = str(uuid.uuid4())
    slider_marks = {}
    for head_idx in range(0, run_config["gen_heads"] + 1):
        slider_marks[head_idx] = f"Head {head_idx}"
    slider_marks[-1] = "AVG"

    text_output, _fig_els, figure = model_output(prompt, graph_id, run_config)
    return text_output, figure, slider_marks, -1, graph_id, True

#Define callback to update scatter plot
@app.callback(
    Output('scatterplot', 'figure'),
    [
        Input('scatterplot', 'hoverData'),
        Input('attention_heads', 'value'),
        Input('embeddings', 'value'),
        Input('notify', 'data'),
    ],[
        State('model_input', 'value'),
        State('graph_id', 'data'),
        State('run_config', 'data'),
    ]
)
def update_scatter_plot(hover_data, attention_head, embeddings, notify, prompt, graph_id, run_config):
    """Patch the scatter plot in response to UI events.

    Handles, in order of precedence: a newly generated graph, hovering over a
    node, a change of attention head and a change of embedding. The module
    globals ``current_els``, ``current_head`` and ``current_emb`` track what is
    currently drawn so the next patch can remove it.

    Returns:
        A ``Patch`` that removes the outdated traces and appends the new ones.

    Raises:
        PreventUpdate: when nothing triggered the callback, no graph has been
            generated yet, or the event needs no redraw.
    """
    global current_els
    global current_head
    global current_emb

    triggered = ctx.triggered_prop_ids
    # Without a graph id nothing has been generated: moving the slider or the
    # embedding selector would otherwise run the model on the prompt "None".
    if not triggered or graph_id is None:
        raise PreventUpdate

    # Served from the memoisation cache as long as prompt, graph id and run
    # configuration are unchanged.
    # NOTE(review): editing the prompt or run parameters after generating
    # changes the cache key, so the next event would regenerate — confirm intent.
    _, fig_els, fig = model_output(prompt, graph_id, run_config)
    p = Patch()
    # Update for new graph available (suppresses other updates)
    if "notify.data" in triggered:
        default_emb = "output"
        add_traces = (fig_els["nodes"][default_emb] + fig_els["perm"][attention_head]["edges"] + fig_els["perm"][attention_head]["hovers"])
        # Bring the bookkeeping in line with what is drawn on the fresh figure;
        # stale values made later patches remove traces that are not there and
        # leave the drawn ones behind.
        current_els = []
        current_head = attention_head
        current_emb = default_emb
    # Update for hovering over a plot node
    elif "scatterplot.hoverData" in triggered and hover_data and "customdata" in hover_data["points"][0] and hover_data["points"][0]["customdata"]["type"] == "node":
        x = hover_data['points'][0]['x']
        y = hover_data['points'][0]['y']
        for el in current_els:
            p["data"].remove(el)
        add_traces = compute_add_traces(fig_els["edges"], fig_els["hovers"], x, y, attention_head)
        current_els = add_traces
    # Update for changing attention head visualization
    elif "attention_heads.value" in triggered:
        for el in fig_els["perm"][current_head]["edges"] + fig_els["perm"][current_head]["hovers"]:
            p["data"].remove(el)
        for el in current_els:
            p["data"].remove(el)
        add_traces = fig_els["perm"][attention_head]["edges"] + fig_els["perm"][attention_head]["hovers"]
        current_head = attention_head
        current_els = []
    # Update for changing embeddings visualization
    elif "embeddings.value" in triggered:
        for el in fig_els["nodes"][current_emb]:
            p["data"].remove(el)
        add_traces = fig_els["nodes"][embeddings]
        current_emb = embeddings
        current_els = []
    else:
        raise PreventUpdate
    p["data"].extend(add_traces)
    return p

In [16]:
# Run the app
if __name__ == '__main__':
    # Start the Dash server on port 8050 with debug mode enabled.
    # NOTE(review): jupyter_mode="_none" presumably turns off Dash's notebook
    # display integration — confirm against the installed Dash version.
    app.run(debug=True, jupyter_mode="_none", port=8050)



In [17]:
#listener = await deploy_ngrok()

In [18]:
# await listener.close()