<a href="https://colab.research.google.com/github/bassel-94/Bert-App/blob/main/dash_app.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

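# Dash Flask test: BERT multi-label classification and sentiment analysis

This notebook serves a Dash app (via JupyterDash) that labels a French text with a fine-tuned CamemBERT multi-label classifier and scores its sentiment with the `tblard/tf-allocine` model.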

In [None]:
# mount google drive storage
from google.colab import drive
drive.mount('/content/drive')

# create temporary file for results
!mkdir results

# unzip from google drive to the created temperary file
!unzip /content/drive/MyDrive/results_V8_light.zip -d results

# install libraries
!pip install fast_bert hydra-core omegaconf jupyter-dash
!pip install -q jupyter-dash==0.3.0rc1 dash-bootstrap-components

In [None]:
import warnings
warnings.filterwarnings("ignore")

import time
import dash
import dash_html_components as html
import dash_core_components as dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
from jupyter_dash import JupyterDash
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd

# import prediction class from fast bert
from fast_bert.prediction import BertClassificationPredictor

# import tokenizer and pipeline for sentiment analysis
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, pipeline

# Multi-label text classifier: a CamemBERT model fine-tuned with fast-bert,
# read from the model and label directories unzipped into results/.
predictor = BertClassificationPredictor(
    model_path='results/finetuned_model/model_out',
    label_path='results/labels/',
    multi_label=True,
    model_type='camembert-base',
    do_lower_case=False,
)

# Sentiment analyzer: tokenizer and TensorFlow model are loaded from the same
# Hugging Face checkpoint, then wrapped in a 'sentiment-analysis' pipeline.
SENTIMENT_CHECKPOINT = "tblard/tf-allocine"
tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_CHECKPOINT)
model = TFAutoModelForSequenceClassification.from_pretrained(SENTIMENT_CHECKPOINT)
predict_sentiment = pipeline('sentiment-analysis', tokenizer=tokenizer, model=model)

In [None]:
# Define the app; `server` exposes the underlying Flask server object
# (presumably kept for WSGI-style deployment — not used elsewhere in this notebook).
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
server = app.server


def _blank_figure():
    """Return an empty figure with a white background and hidden axes.

    Used as the placeholder shown in the result graphs before any prediction
    has been made.
    """
    fig = go.Figure()
    fig.update_layout(paper_bgcolor='white', plot_bgcolor='white',
                      yaxis={'visible': False, 'showticklabels': False},
                      xaxis={'visible': False, 'showticklabels': False})
    return fig


# Placeholders for the label bar chart (fig1) and the sentiment gauge (fig2).
fig1 = _blank_figure()
fig2 = _blank_figure()

# Define Layout: a title, then one row with the text-input card on the left and
# the two result graphs (label chart, sentiment gauge) on the right.
# Form sections are plain Divs with Bootstrap's "mb-3" bottom margin rather than
# dbc.FormGroup, which was removed in dash-bootstrap-components 1.0 (the version
# is not pinned at install time); the Div renders the same spacing on 0.x and 1.x.
app.layout = dbc.Container(
    fluid=True,
    children=[
        html.H1("BERT App test for multilabel classification!", style={'textAlign': 'center'}),
        html.Hr(),
        dbc.Row([
            dbc.Col(
                width=4,
                children=[dbc.Card(
                    body=True,
                    children=[
                        html.Div(className="mb-3", children=[
                            dbc.Label("Qu'en pensez-vous de Groupama?"),
                            dcc.Textarea(id="original-text", style={"width": "100%", "height": "20vh"}),
                        ]),
                        html.Div(className="mb-3", children=[
                            # The spinner is displayed while a callback is updating
                            # one of its children (here the "time-taken" message).
                            dbc.Spinner([
                                dbc.Button("Get labels", id="button-run"),
                                html.Div(id="time-taken"),
                            ]),
                        ]),
                    ],
                )],
            ),
            dbc.Col(
                dbc.Card(children=[
                    dcc.Graph(id="label-graph", figure=fig1, style={"height": "400px"}),
                    dcc.Graph(id="sentiment-graph", figure=fig2),
                ]),
                md=7,
            ),
        ]),
    ],
)

# Minimum predicted probability for a label to appear in the bar chart.
LABEL_THRESHOLD = 0.5


def _label_figure(predictions, threshold=LABEL_THRESHOLD):
    """Build the horizontal bar chart of labels whose probability is >= `threshold`.

    `predictions` is iterated as (label, probability) pairs, as returned by
    `predictor.predict`. Probabilities are rounded to 2 decimals for display.
    """
    kept = [(label, round(proba, 2)) for label, proba in predictions if proba >= threshold]
    labels = [label for label, _ in kept]
    probs = [proba for _, proba in kept]

    fig = go.Figure()
    fig.add_trace(go.Bar(x=probs,
                         y=labels,
                         text=probs,
                         textposition='auto',
                         orientation='h'))
    fig.update_layout(autosize=True,
                      paper_bgcolor='white',
                      plot_bgcolor='white',
                      xaxis_title="Probabilité",
                      yaxis_title="Label(s) détecté(s)",
                      title={'text': "Label(s) prédits en fonction de la probabilité", 'x': 0.5, 'xanchor': 'center', 'yanchor': 'top'})
    fig.update_traces(marker_color='steelblue', marker_line_color='rgb(8,48,107)', marker_line_width=1.5, opacity=0.7, width=0.5)
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey', zeroline=True, zerolinewidth=2, zerolinecolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')

    # Without this note, an empty bar chart looks like a rendering failure.
    if not kept:
        fig.add_annotation(text=f"Aucun label avec une probabilité ≥ {threshold}",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
    return fig


def _sentiment_figure(sentiment):
    """Build the gauge showing the sentiment label and its score as a percentage.

    `sentiment` is one result of `predict_sentiment`, read through its "label"
    and "score" keys; the bar is green for "POSITIVE" and red otherwise.
    """
    pred_label = sentiment["label"]
    pred_proba = sentiment["score"]
    color = "limegreen" if pred_label == "POSITIVE" else "red"

    fig = go.Figure()
    fig.add_trace(go.Indicator(
        mode="gauge+number",
        value=round(pred_proba * 100, 2),
        number={'suffix': "%", 'font': {'size': 30}},
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': '<b>' + pred_label + '</b>', 'font': {'size': 14}},
        gauge={'axis': {'range': [0, 100]}, 'bar': {'color': color}}))
    return fig


@app.callback(
    [Output("label-graph", "figure"),
     Output("sentiment-graph", "figure"),
     Output("time-taken", "children")],
    [Input("button-run", "n_clicks")],
    [State("original-text", "value")])
def get_labels(n_clicks, original_text):
    """Classify the textarea content when "Get labels" is clicked.

    Returns the label bar chart, the sentiment gauge, and a message with the
    elapsed time. Dash also calls this once on page load (no
    `prevent_initial_call` is set), with `n_clicks=None` and an empty textarea:
    that call is skipped with `PreventUpdate`. A click with empty/blank text
    resets both graphs to the module-level placeholders `fig1`/`fig2` instead of
    sending empty input to the models.
    """
    from dash.exceptions import PreventUpdate

    if not original_text or not original_text.strip():
        if not n_clicks:
            raise PreventUpdate
        return fig1, fig2, "Did not run: please enter some text."

    # perf_counter is the clock intended for measuring elapsed durations.
    start = time.perf_counter()
    label_fig = _label_figure(predictor.predict(original_text))
    sentiment_fig = _sentiment_figure(predict_sentiment(original_text)[0])
    elapsed = time.perf_counter() - start

    time_taken = f"It took {elapsed:.2f} seconds to compute labels and sentiment"
    return label_fig, sentiment_fig, time_taken

In [None]:
# Start the app from the kernel; "external" mode serves it at a URL to open in
# a separate browser tab rather than rendering it inline in the notebook.
app.run_server(mode="external")

Dash app running on:


<IPython.core.display.Javascript object>