In [1]:
import pandas as pd 
import json
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from tokenizers import BertWordPieceTokenizer

# Fixed length (in tokens) of every input sequence handed to the model.
MAX_SEQ_LENGTH = 384

# bert layer: used below to read the vocabulary file and casing flag shipped with it
bert_layer = hub.KerasLayer("bert_en_uncased_L-12_H-768_A-12_2", trainable=True)

# build tokenizer
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy().decode("utf-8")
# .numpy() yields a NumPy scalar; convert it to a plain Python bool before
# handing it to the tokenizer.
do_lower_case = bool(bert_layer.resolved_object.do_lower_case.numpy())
# Follow the casing flag published with the vocabulary instead of hard-coding it,
# so the tokenizer cannot drift out of sync with the loaded BERT module.
tokenizer = BertWordPieceTokenizer(vocab=vocab_file, lowercase=do_lower_case)

# load the model
model= tf.keras.models.load_model('bert_model')

def tokenizer_output(tokenized_context, tokenized_question, seq_maxlen):
  '''
  Prepare the fixed-length input lists for BERT.

  The context ids come first, followed by the question ids without their
  leading token (the [CLS] marker). The result is always exactly
  `seq_maxlen` entries long: shorter inputs are right-padded with zeros and
  longer inputs are truncated, so a dictionary is returned for every input.

  Args:
    tokenized_context: object exposing the context token ids as `.ids`.
    tokenized_question: object exposing the question token ids as `.ids`.
    seq_maxlen: required length of each returned list.

  Returns:
    dict with keys 'input_word_ids', 'input_mask' and 'input_type_ids',
    each a list of `seq_maxlen` ints. The mask is 1 for real tokens and 0
    for padding; the type ids are 0 for context/padding and 1 for question.
  '''
  context_ids = list(tokenized_context.ids)
  # ignore [CLS] at the start of the question
  question_ids = list(tokenized_question.ids[1:])

  # Room left for the context once the whole question is kept.
  context_budget = seq_maxlen - len(question_ids)
  if len(context_ids) > context_budget:
    if context_budget >= 2:
      # Shorten the context but keep its final token so the boundary marker
      # between context and question survives the cut.
      context_ids = context_ids[:context_budget - 1] + context_ids[-1:]
    else:
      # Practically no room for context: keep whatever prefix still fits.
      context_ids = context_ids[:max(context_budget, 0)]

  # The final clip only matters when the question alone exceeds seq_maxlen.
  input_ids = (context_ids + question_ids)[:seq_maxlen]
  input_type_ids = ([0] * len(context_ids) + [1] * len(question_ids))[:seq_maxlen]
  input_mask = [1] * len(input_ids)

  # add padding if the sequence is shorter than seq_maxlen
  padding_length = seq_maxlen - len(input_ids)
  if padding_length > 0:
    padding = [0] * padding_length
    input_ids = input_ids + padding
    input_mask = input_mask + padding
    input_type_ids = input_type_ids + padding

  return {'input_word_ids' : input_ids, 'input_mask':input_mask, 'input_type_ids':input_type_ids}


def predict_answer(context,question):
  '''
  Return the text span of `context` that the model selects for `question`.

  Both strings are tokenized, packed into the model inputs by
  `tokenizer_output`, and run through `model.predict` as a batch of one.
  The two returned arrays are unpacked as start and end predictions
  (presumably per-token scores -- confirm against the saved model); the
  token ids between their argmax positions, inclusive, are decoded to text.
  '''
  context_encoding = tokenizer.encode(context)
  question_encoding = tokenizer.encode(question)
  features = tokenizer_output(context_encoding, question_encoding, MAX_SEQ_LENGTH)

  # Wrap every feature list in an outer list so each array has a leading
  # batch dimension of size one.
  word_ids = np.array([features['input_word_ids']])
  mask = np.array([features['input_mask']])
  type_ids = np.array([features['input_type_ids']])
  start_pred, end_pred = model.predict([word_ids, mask, type_ids])

  first_token = start_pred.argmax()
  last_token = end_pred.argmax()

  # The end index is inclusive, hence the +1 on the slice bound.
  answer_ids = np.array(features['input_word_ids'])[first_token:last_token + 1]
  return tokenizer.decode(answer_ids)


In [10]:

import dash
from dash import html
from dash import dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
import dash_daq as daq
from jupyter_dash import JupyterDash
import dash_loading_spinners as dls


In [13]:
# Define app
app = dash.Dash(
    __name__,
    title='NLP QA',
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    prevent_initial_callbacks=True,
)

# --- Left column: run button, question box and answer area -----------------
_run_button_group = dbc.FormGroup(dbc.Button("Get Answer", id="button-run"))

_question_group = dbc.FormGroup(
    [
        dbc.Label("Question"),
        dcc.Textarea(id="question_label", style={"height": "100px", "width": "100%"}),
    ]
)

# The answer label and its target div are the children of the loading spinner.
_answer_area = html.Div(
    id="answer_label",
    style={
        "width": "100%",
        "height": "calc(75vh - 300px)",
    },
)
_answer_group = dbc.FormGroup(
    [
        dls.Roller(
            children=[dbc.Label("Answer"), _answer_area],
            color="black",
        )
    ]
)

_question_column = dbc.Col(
    width=5,
    children=[
        dbc.Card(
            body=True,
            children=[_run_button_group, _question_group, _answer_group],
        ),
    ],
)

# --- Right column: context text area (card nested inside a card) -----------
_context_group = dbc.FormGroup(
    [
        dbc.Label("Context (Paste here)"),
        dcc.Textarea(
            id="context_label",
            style={"width": "100%", "height": "calc(75vh - 1px)"},
        ),
    ]
)
_context_inner_card = dbc.Card(body=True, children=[_context_group])
_context_column = dbc.Col(
    width=7,
    children=[
        dbc.Card(body=True, children=[_context_inner_card]),
    ],
)

# Define Layout
app.layout = dbc.Container(
    fluid=True,
    children=[
        html.H1("Answer Questions about your Text"),
        html.Hr(),
        dbc.Row([_question_column, _context_column]),
    ],
)

In [14]:
@app.callback(
    Output("answer_label", "children"),
    [Input("button-run", "n_clicks")],
    State("question_label", "value"), State("context_label", "value")
)
def run(n_clicks, question, context):
    """Fill the answer area when the "Get Answer" button is clicked.

    Args:
        n_clicks: click count of the run button; only used as the trigger.
        question: current text of the question box (may be None).
        context: current text of the context box (may be None).

    Returns:
        The predicted answer text, or a short notice when the question or
        the context is missing.
    """
    # Whitespace-only text is truthy but carries no content, so treat it the
    # same as an empty box instead of sending it to the model.
    if not question or not question.strip():
        return "no question enter "
    if not context or not context.strip():
        return "no context enter"
    return predict_answer(context=context, question=question)


In [15]:
if __name__ == "__main__":
    # Start the Dash server on port 8070 with debug mode turned off.
    app.run_server(port=8070, debug=False)


Dash is running on http://127.0.0.1:8070/

Dash is running on http://127.0.0.1:8070/

Dash is running on http://127.0.0.1:8070/

Dash is running on http://127.0.0.1:8070/

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:8070 (Press CTRL+C to quit)
127.0.0.1 - - [01/Sep/2022 11:17:24] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [01/Sep/2022 11:17:24] "GET /_dash-layout HTTP/1.1" 200 -
127.0.0.1 - - [01/Sep/2022 11:17:24] "GET /_dash-dependencies HTTP/1.1" 200 -
127.0.0.1 - - [01/Sep/2022 11:17:24] "GET /_favicon.ico?v=2.6.1 HTTP/1.1" 200 -
127.0.0.1 - - [01/Sep/2022 11:17:37] "POST /_dash-update-component HTTP/1.1" 200 -
