In [None]:
from dash import dash, html, dcc, Input, Output, State, callback_context
from dash.exceptions import PreventUpdate
import dash_ag_grid as dag
from jupyter_dash import JupyterDash
import dash_bootstrap_components as dbc
from sklearn.feature_extraction.text import CountVectorizer
import pandas as pd
import dash, sklearn, jupyter_dash, dash_ag_grid, pandas

# Record the versions of the key libraries so the notebook's outputs can be
# reproduced; each name is left-aligned and padded with dashes to 20 characters.
for module in (dash, sklearn, jupyter_dash, dash_ag_grid, pandas):
    print(f'{module.__name__:-<20}v{module.__version__}')


dash----------------v2.9.1
sklearn-------------v1.2.1
jupyter_dash--------v0.4.2
dash_ag_grid--------v2.0.0a5
pandas--------------v1.5.3


In [None]:
def make_dtm_df(text_list):
    """Build a document-term matrix for a list of texts.

    Parameters
    ----------
    text_list : list of str
        The documents to vectorize. Each document becomes one row and is also
        used as that row's index label.

    Returns
    -------
    pandas.DataFrame
        A sparse-backed frame with one column per vocabulary term holding the
        term's count in each document, plus a final ``ngrams`` column holding
        the row total (the number of counted tokens in the document).
    """
    # The token pattern matches runs of one or more word characters, so
    # single-character tokens are kept.
    vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b")
    counts = vectorizer.fit_transform(text_list)
    vocabulary = vectorizer.get_feature_names_out()
    dtm = pd.DataFrame.sparse.from_spmatrix(
        counts, index=text_list, columns=vocabulary)
    # Appended last; a vocabulary term literally named "ngrams" would be overwritten.
    dtm['ngrams'] = dtm.sum(axis=1)
    return dtm


def similarity(dtm_df, phrase, min_similarity=1):
    """Rank the documents of a document-term matrix by word overlap with `phrase`.

    Parameters
    ----------
    dtm_df : pandas.DataFrame
        Document-term matrix indexed by document text, one column per term.
        An ``ngrams`` row-total column (as added by ``make_dtm_df``) is kept in
        the result but never counted towards the similarity score.
    phrase : str
        Row label of the reference document; must be present in ``dtm_df.index``.
    min_similarity : int, default 1
        Rows scoring below this value are dropped.

    Returns
    -------
    pandas.DataFrame
        The columns of ``dtm_df`` whose value is positive in the reference row,
        plus a ``similarity`` column: for each document, the sum of its counts
        over the terms that occur in ``phrase``. Rows are sorted by
        ``similarity`` in descending order.

    Raises
    ------
    KeyError
        If ``phrase`` is not a row label of ``dtm_df``.
    """
    try:
        phrase_row = dtm_df.loc[phrase, :]
    except KeyError:
        raise KeyError(
            f'phrase {phrase!r} is not a row label of the document-term matrix'
        ) from None

    # A phrase listed more than once yields a DataFrame rather than a Series.
    # The first occurrence is used as the reference row.
    if getattr(phrase_row, 'ndim', 1) > 1:
        phrase_row = phrase_row.iloc[0]

    phrasedf = dtm_df.loc[:, phrase_row.gt(0)]
    # Exclude the row-total column by name rather than by position, so a matrix
    # without an 'ngrams' column does not lose its last shared term.
    shared_terms = phrasedf.drop(columns='ngrams', errors='ignore')
    return (phrasedf
            .assign(similarity=shared_terms.sum(axis=1))
            .sort_values('similarity', ascending=False)
            .query('similarity >= @min_similarity'))

## Simple example with toy data

In [None]:
text_list = [
    'blue green red',
    'blue green yellow',
    'blue black white',
    'white red purple',
    'magenta teal gray',
]


In [None]:
dtm_df = make_dtm_df(text_list)
dtm_df

Unnamed: 0,black,blue,gray,green,magenta,purple,red,teal,white,yellow,ngrams
blue green red,0,1,0,1,0,0,1,0,0,0,3
blue green yellow,0,1,0,1,0,0,0,0,0,1,3
blue black white,1,1,0,0,0,0,0,0,1,0,3
white red purple,0,0,0,0,0,1,1,0,1,0,3
magenta teal gray,0,0,1,0,1,0,0,1,0,0,3


In [None]:
# suppress_callback_exceptions is required because the CSV-export callback
# targets components ('similarity_aggrid', 'export_csv_button') that only exist
# after the Submit callback has rendered them.
app = JupyterDash(
    __name__,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.FLATLY],
)

# Page layout: a single row with narrow empty columns on either side of the
# form, followed by 15 line breaks of bottom padding.
app.layout = html.Div(
    children=[
        dbc.Row(children=[
            dbc.Col(xs=1, lg=1),
            dbc.Col(children=[
                html.Br(),
                html.H1('Word Similarity'),
                html.Br(),
                dbc.Label('Paste a text list, one phrase per line:'),
                dbc.Textarea(id='text_list', rows=10, cols=30),
                html.Br(),
                dbc.Label('Enter a phrase from above to compare:'),
                dbc.Textarea(id='selected_document', rows=1, cols=30),
                html.Br(),
                dbc.Button('Submit', id='submit'),
                html.Br(),
                html.Br(),
                # The similarity table is rendered into this placeholder.
                dcc.Loading(children=html.Div(id='output')),
            ]),
            dbc.Col(lg=1),
        ]),
        *(html.Br() for _ in range(15)),
    ],
)

@app.callback(
    Output('output', 'children'),
    Input('submit', 'n_clicks'),
    State('text_list', 'value'),
    State('selected_document', 'value'))
def show_text(n_clicks, text_list, phrase):
    """Render the similarity table for the submitted text list and phrase.

    Parameters
    ----------
    n_clicks : int or None
        Click count of the Submit button; nothing is rendered until it is truthy.
    text_list : str or None
        Contents of the text-list textarea, one document per line.
    phrase : str or None
        Contents of the phrase textarea; must match one line of `text_list`.

    Returns
    -------
    dash component
        A warning alert when the input is unusable, otherwise a Div holding a
        heading, an "Export to CSV" button and the AG Grid similarity table.

    Raises
    ------
    PreventUpdate
        If the Submit button has not been clicked yet.
    """
    if not n_clicks:
        raise PreventUpdate

    # Neither textarea is given an initial value in the layout, so both arrive
    # as None until the user types something.
    stripped_lines = (line.strip() for line in (text_list or '').splitlines())
    # Drop blank lines and repeated documents (keeping first-seen order): the
    # documents become row labels and the phrase is looked up by label.
    documents = list(dict.fromkeys(line for line in stripped_lines if line))
    phrase = (phrase or '').strip()

    if not documents:
        return dbc.Alert(
            'Please paste at least one phrase into the text list.',
            color='warning')
    if not phrase:
        return dbc.Alert('Please enter a phrase to compare.', color='warning')
    if phrase not in documents:
        return dbc.Alert(
            ['The phrase ', html.B(phrase),
             ' is not one of the lines in the text list.'],
            color='warning')

    dtm_df = make_dtm_df(documents)
    # Name the index before resetting it: renaming the generated 'index' column
    # afterwards would hit the wrong column if "index" is a vocabulary word.
    similarity_df = (similarity(dtm_df, phrase)
                     .rename_axis('Text List')
                     .reset_index())

    totals = ('ngrams', 'similarity')
    column_defs = [
        {
            "headerName": column,
            "field": column,
            'width': 250 if column == 'Text List' else 130 if column in totals else None,
            "pinned": "left" if column == 'Text List' else 'right' if column in totals else None,
            'filter': 'agNumberColumnFilter' if column != 'Text List' else None,
        }
        for column in similarity_df.columns
    ]
    table = dag.AgGrid(
        id='similarity_aggrid',
        columnDefs=column_defs,
        rowData=similarity_df.to_dict('records'),
        defaultColDef=dict(
            resizable=True,
            sortable=True,
            filter=True,
        ),
        dashGridOptions={"rowSelection": "single"},
        columnSize="autoSizeAll",
        csvExportParams={
            "fileName": f"{phrase.replace(' ', '_')}_similarity.csv",
        },
    )
    return html.Div([
        html.H4([
            'Most similar documents to: ', html.B(phrase)
        # 'align' is not a CSS property; 'textAlign' is what centres the heading.
        ], style={'textAlign': 'center'}), html.Br(),
        dbc.Button('Export to CSV', id='export_csv_button', n_clicks=0),
        html.Br(),
        table
    ])

@app.callback(
    Output("similarity_aggrid", "exportDataAsCsv"),
    Input("export_csv_button", "n_clicks"))
def export_data_as_csv(n_clicks):
    """Set the grid's ``exportDataAsCsv`` prop from the export button's clicks.

    Returns True once the button has a truthy click count, and False
    otherwise (including the initial ``n_clicks=0`` state).
    """
    return bool(n_clicks)


# Start the app server with Dash's debug mode enabled.
app.run_server(debug=True)

Dash is running on http://127.0.0.1:8050/

Dash app running on http://127.0.0.1:8050/
