In [None]:
!pip install top2vec
!pip install lime
!apt install subversion

In [None]:
!svn export https://github.com/esti4444/XAI_topic_modeling/trunk/Preload_movies /content/preload

In [3]:
#@title Load classifiers (shift+enter to run)
# load classifier
import os
from joblib import dump, load
import pandas as pd
import numpy as np
from top2vec import Top2Vec
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary
# code. Only load artifacts downloaded from a source you trust.

# Classifier used by the "Topic Classifier" button.
if os.path.exists("/content/preload/topic_classifier.joblib"):
  topic_classifier = load("/content/preload/topic_classifier.joblib")
else:
  print("Error - Missing classifier!")

# Classifier used by the "Text Classifier" button.
if os.path.exists("/content/preload/text_classifier.joblib"):
  text_classifier = load("/content/preload/text_classifier.joblib")
else:
  print("Error - Missing classifier!")

# Training features for the tabular LIME explainer.
if os.path.exists("/content/preload/X_train_movies_tabular.csv"):
  # "Unnamed: 0" is presumably the index column written when the CSV was saved.
  # Drop it while loading instead of mutating the frame in place.
  df_train = pd.read_csv("/content/preload/X_train_movies_tabular.csv").drop(columns=["Unnamed: 0"])
else:
  print("Error - Missing training data for tabular lime!")

# Topic model (stored with joblib).
if os.path.exists("/content/preload/top2vec"):
  model = load("/content/preload/top2vec")
else:
  print("Error - Missing topic model!")

# Movie examples offered in the demo.
if os.path.exists("/content/preload/movies.csv"):
  df_movies = pd.read_csv("/content/preload/movies.csv")
else:
  print("Error - Missing movie examples!")

# NOTE(review): the original comment announced loading a sentence transformer
# (document embedding space) here, but no code does so. Confirm whether that
# step is still needed.

# lime explainers
from lime.lime_text import LimeTextExplainer
from lime.lime_tabular import LimeTabularExplainer

# Explanations are written to the relative path "HTML/...", so create that same
# relative directory. exist_ok makes the cell safe to re-run.
os.makedirs("HTML", exist_ok=True)

# Class labels passed to both LIME explainers.
class_names = [False,True]
def interpret_data(explainer,func, class_names, txt_list, exp_type, num_features):
    """Explain each instance and save every explanation as an HTML file.

    Args:
        explainer: object exposing
            ``explain_instance(instance, func, num_features=..., top_labels=...)``.
        func: prediction function forwarded to ``explain_instance``.
        class_names: not used by this function; kept so existing callers
            keep working.
        txt_list: iterable of ``(idx, instance)`` pairs.
        exp_type: label embedded in the output file name.
        num_features: forwarded to ``explain_instance``.

    Returns:
        List with the ``score`` attribute of each explanation, in input order.

    Side effects:
        Writes ``HTML/<idx>-<exp_type>-explanation.html`` (relative to the
        current working directory) for every instance, creating the directory
        if needed.
    """
    scores = []
    for idx, txt in txt_list:
        exp = explainer.explain_instance(txt, func, num_features=num_features, top_labels=1)
        scores.append(exp.score)
        # Render before opening the file so a rendering failure leaves no empty file.
        html_text = exp.as_html()
        output_filename = "HTML/{}-{}-explanation.html".format(idx, exp_type)
        os.makedirs(os.path.dirname(output_filename), exist_ok=True)
        # The context manager closes the handle even if the write fails.
        with open(output_filename, "w", encoding="utf-8") as html_file:
            html_file.write(html_text)

    return scores

def display_html(id, exp_type):
    """Render a saved explanation file inline in the notebook.

    Args:
        id: identifier used in the file name (the ``idx`` given to interpret_data).
        exp_type: explanation label used in the file name, e.g. "text" or "tabular".

    Reads ``HTML/<id>-<exp_type>-explanation.html`` relative to the current
    working directory.
    """
    # Imported locally: the notebook imports HTML only in a later cell, so this
    # function would otherwise depend on cell execution order.
    from IPython.display import HTML, display
    html = 'HTML/{}-{}-explanation.html'.format(id, exp_type)
    # Pass the path as filename= so the file content is rendered, rather than
    # relying on IPython guessing that the string is a path.
    display(HTML(filename=html))

def get_doc(name):
    """Build a one-row feature frame of document-topic scores for a movie.

    Args:
        name: value to match against the ``name`` column of ``df_movies``.

    Returns:
        DataFrame with the columns of ``df_train`` and a single row (index 0)
        holding the leading document-topic inner products.

    Raises:
        ValueError: if no row of ``df_movies`` has this name.

    Note:
        Every call adds the movie's ``noun_plot`` text to the topic model via
        ``model.add_documents``, so the model grows with each call.
    """
    matches = df_movies[df_movies['name']==name]
    if matches.empty:
        # Same exception type that .item() raised on an empty selection, with a readable message.
        raise ValueError("No movie named {!r} in df_movies".format(name))
    text = matches[:1]['noun_plot'].item()
    model.add_documents([text])
    # The newly added document is the last entry of the document vectors.
    doc_idx = len(model.model.docvecs.vectors_docs)-1
    # Score the document against every topic with the inner product of their
    # embeddings (equals cosine similarity if the vectors are unit-normalized;
    # TODO confirm).
    res = np.inner(model.model.docvecs.vectors_docs[doc_idx], model.topic_vectors)
    feature_columns = df_train.columns.to_list()
    df_row = pd.DataFrame(columns=feature_columns)
    # Keep as many leading topic scores as there are training columns (this was
    # a hard-coded 28). Assumes the columns correspond to the first topics, in
    # order; verify against how df_train was built.
    df_row.loc[0] = res[:len(feature_columns)]
    return df_row
    

# LIME explainer for raw plot text.
text_explainer = LimeTextExplainer(class_names=class_names)

# LIME explainer for tabular rows, built from the training features in df_train.
tab_explainer = LimeTabularExplainer(
    training_data=df_train.values,
    feature_names=df_train.columns,
    class_names=class_names,
    kernel_width=5,
)

In [12]:
#@title Demo (shift+enter to run)

from IPython.display import IFrame, display, HTML, clear_output

import ipywidgets as widgets

# Candidate movies: the first 20 names, minus the positions listed in rm_movies.
movie_list = df_movies['name'][:20].to_list()
rm_movies = [2,12,18]
l = [title for position, title in enumerate(movie_list) if position not in rm_movies]

# Dropdown to pick a movie; starts on the first movie of df_movies.
examples = widgets.Dropdown(
    options=l,
    value=df_movies['name'][:1].item(),
    description='Select Movie:',
)

# Text area showing the plot; starts with the first movie's plot.
textarea = widgets.Textarea(
    value=df_movies['plot'][:1].item(),
    description='Plot:',
    layout=widgets.Layout(width="auto"),
)

# One button and one output pane per classifier.
button1 = widgets.Button(description="Text Classifier")
button2 = widgets.Button(description="Topic Classifier")
output1 = widgets.Output()
output2 = widgets.Output()

def get_bigger(args):
    """Resize the plot text area to one row per line of its current text.

    ``args`` is the change notification passed by the widget; it is not used.
    """
    newline_count = textarea.value.count('\n')
    textarea.rows = newline_count + 1
def on_button1_clicked(b):
  """Classify the text in the plot text area and save a LIME text explanation.

  Args:
    b: the clicked button (unused).

  The prediction is printed into ``output1``; the explanation is written by
  ``interpret_data`` under id '1' with type "text".
  """
  with output1:
    clear_output()
    # Classify what the user actually sees or edited in the text area. The
    # original validated the text area but then classified the stored plot of
    # the selected movie, so edits were silently ignored.
    txt = textarea.value
    if txt.strip() == "":
      print("Error - input is missing")
      return
    print("Prediction:", text_classifier.predict([txt]))
    text_explain = [('1',txt)]
    interpret_data(text_explainer, text_classifier.predict_proba, class_names, text_explain, "text", 20)

def on_button2_clicked(b):
  """Classify the selected movie from its topic scores and save a tabular LIME explanation.

  Args:
    b: the clicked button (unused).

  The prediction is printed into ``output2``; the explanation is written by
  ``interpret_data`` under id '1' with type "tabular".
  """
  with output2:
    clear_output()
    if textarea.value == "":
      print("Error - input is missing")
      return
    # Topic features of the movie currently selected in the dropdown.
    topic_row = get_doc(examples.value).loc[0]
    print("Prediction:", topic_classifier.predict([topic_row]))
    interpret_data(tab_explainer, topic_classifier.predict_proba, class_names, [('1',topic_row)], "tabular", 10)


def on_change(change):
    """Dropdown observer: show the selected movie's plot and clear both output panes.

    Args:
        change: notification dict from the widget; only events with
            ``type == 'change'`` and ``name == 'value'`` are handled.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    # Copy the plot of the newly selected movie into the text area.
    selected = df_movies[df_movies['name']==examples.value][:1]
    textarea.value = selected['plot'].item()
    # Results of the previous movie no longer apply.
    for pane in (output2, output1):
        with pane:
            clear_output()


# Connect the widgets to their handlers.
examples.observe(on_change)
textarea.observe(get_bigger, 'value')
button1.on_click(on_button1_clicked)
button2.on_click(on_button2_clicked)

# Usage notes, followed by the widgets themselves.
for usage_line in (
    "* Select a Movie from the list to predict if Thriller or not (family movie)",
    "* Text explainer button predicts using text input",
    "* Topic explainer button predicts using content found topics\n",
):
    print(usage_line)

for widget in (examples, textarea):
    display(widget)
display(button1, output1)
display(button2, output2)


* Select a Movie from the list to predict if Thriller or not (family movie)
* Text explainer button predicts using text input
* Topic explainer button predicts using content found topics



Dropdown(description='Select Movie:', options=('End Game', 'Dark Water', "Charlie Chan's Secret", 'Ashes to As…

Textarea(value="The president is on his way to give a speech. While he is traveling there a man shows up with …

Button(description='Text Classifier', style=ButtonStyle())

Output()

Button(description='Topic Classifier', style=ButtonStyle())

Output()

In [13]:
#@title Show Explanations (shift+enter to run)
# Show whichever explanations have been generated so far.
# The check uses the same relative "HTML/..." path that interpret_data() writes
# and display_html() reads. The previous hard-coded /content prefix only
# matched when the working directory was /content.
for exp_type in ("text", "tabular"):
  if os.path.exists("HTML/1-{}-explanation.html".format(exp_type)):
    display_html('1', exp_type)

Output hidden; open in https://colab.research.google.com to view.