To start this Jupyter Dash app, please run all the cells below. Then, click on the **temporary** URL at the end of the last cell to open the app.

In [1]:
#!pip install -q jupyter-dash==0.3.0rc1 dash-bootstrap-components transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.3/45.3 KB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.7/219.7 KB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m62.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m27.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m34.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [24]:
!pip install "dash-bootstrap-components"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [9]:
import time

import dash
from dash import html
from dash import dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
from jupyter_dash import JupyterDash
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Load Model
pretrained = "sshleifer/distilbart-xsum-12-6"
model = BartForConditionalGeneration.from_pretrained(pretrained)
tokenizer = BartTokenizer.from_pretrained(pretrained)

# Switch to cuda, eval mode, and FP16 for faster inference
if device == "cuda":
    model = model.half()
model.to(device)
model.eval();

Device: cuda


In [28]:
# Define app
#app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])

server = app.server

controls = dbc.Card(
    [
        (
            [
                dbc.Label("Output Length (# Tokens)"),
                dcc.Slider(
                    id="max-length",
                    min=10,
                    max=50,
                    value=30,
                    marks={i: str(i) for i in range(10, 51, 10)},
                ),
            ]
        ),
        (
            [
                dbc.Label("Beam Size"),
                dcc.Slider(
                    id="num-beams",
                    min=2,
                    max=6,
                    value=4,
                    marks={i: str(i) for i in [2, 4, 6]},
                ),
            ]
        ),
        (
            [
                dbc.Spinner(
                    [
                        dbc.Button("Summarize", id="button-run"),
                        html.Div(id="time-taken"),
                    ]
                )
            ]
        ),
    ],
    body=True,
    style={"height": "275px"},
)


# Define Layout
app.layout = dbc.Container(
    fluid=True,
    children=[
        html.H1("Dash Automatic Summarization (with DistilBART)"),
        html.Hr(),
        dbc.Row(
            [
                dbc.Col(
                    width=5,
                    children=[
                        controls,
                        dbc.Card(
                            body=True,
                            children=[
                                (
                                    [
                                        dbc.Label("Summarized Content"),
                                        dcc.Textarea(
                                            id="summarized-content",
                                            style={
                                                "width": "100%",
                                                "height": "calc(75vh - 275px)",
                                            },
                                        ),
                                    ]
                                )
                            ],
                        ),
                    ],
                ),
                dbc.Col(
                    width=7,
                    children=[
                        dbc.Card(
                            body=True,
                            children=[
                                (
                                    [
                                        dbc.Label("Original Text (Paste here)"),
                                        dcc.Textarea(
                                            id="original-text",
                                            style={"width": "100%", "height": "75vh"},
                                        ),
                                    ]
                                )
                            ],
                        )
                    ],
                ),
            ]
        ),
    ],
)

In [29]:
@app.callback(
    [Output("summarized-content", "value"), Output("time-taken", "children")],
    [
        Input("button-run", "n_clicks"),
        Input("max-length", "value"),
        Input("num-beams", "value"),
    ],
    [State("original-text", "value")],
)
def summarize(n_clicks, max_len, num_beams, original_text):
    if original_text is None or original_text == "":
        return "", "Did not run"

    t0 = time.time()

    inputs = tokenizer.batch_encode_plus(
        [original_text], max_length=1024, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate Summary
    summary_ids = model.generate(
        inputs["input_ids"],
        num_beams=num_beams,
        max_length=max_len,
        early_stopping=True,
    )
    out = [
        tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for g in summary_ids
    ]

    t1 = time.time()
    time_taken = f"Summarized on {device} in {t1-t0:.2f}s"

    return out[0], time_taken

Run the cell below to run your Jupyter Dash app. Click on the **temporary** URL to access the app.

In [30]:
!pip install macrodemos --upgrade
!pip install -q dash==1.19.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [33]:
#app.run_server()
app.dash()

AttributeError: ignored