In [1]:
from sentence_transformers import SentenceTransformer, util
import csv
import pandas as pd

## Load Model

In [2]:
# Pick the checkpoint to evaluate. The suffix presumably records the training
# run (epochs / batch size / learning rate) -- TODO confirm against training code.
# Alternative checkpoint names kept for reference:
#   'bert-base-nli-mean-tokens-eps1-batch32-lr2e-05'
#   'bert-base-nli-mean-tokens-eps2-batch32-lr2e-05'
#   'bert-base-nli-mean-tokens-eps3-batch16-lr2e-05'
#   'bert-base-nli-mean-tokens-eps3-batch16-lr2e-07'
#   'bert-base-nli-mean-tokens-eps3-batch32-lr2e-05'
base_model = 'bert-base-nli-mean-tokens-eps1-batch16-lr2e-05'

# The checkpoint is read from the local ./model/ directory.
model = SentenceTransformer(f'./model/{base_model}')

## Load WHO Covid19 Data

In [3]:
# Build the question -> answer lookup from the WHO COVID-19 Q&A file.
# The file is opened in a `with` block so the handle is closed as soon as the
# rows have been consumed. `newline=''` is what the csv module asks for so that
# quoted fields containing line breaks are parsed intact, and UTF-8 matches the
# default encoding pandas uses for the same file below, so the dictionary keys
# and the entries of `sentences1` are decoded the same way.
with open('data/WHO/covid19_QA_data.csv', newline='', encoding='utf-8') as qa_file:
    test_data = csv.reader(qa_file)

    WHO_covid19_QA = {}
    for row in test_data:
        # Blank or truncated lines have no question/answer columns; skip them
        # instead of failing with IndexError.
        if len(row) < 3:
            continue
        question = row[1]
        WHO_covid19_QA[question] = row[2]

# Extract all WHO COVID-19 questions (the 'question' column) as a plain list.
sentences1 = pd.read_csv('data/WHO/covid19_QA_data.csv')['question'].tolist()

In [4]:
def chatbot(user_query, corpus_embeddings=None):
    """Return the WHO question closest to ``user_query`` and its answer.

    The query is embedded once and compared against every entry of the
    module-level ``sentences1`` list with cosine similarity.

    Parameters
    ----------
    user_query : str
        Free-text question asked by the user.
    corpus_embeddings : tensor, optional
        Embeddings of ``sentences1`` produced by
        ``model.encode(sentences1, convert_to_tensor=True)``. When omitted
        they are computed on each call, as before; passing them in lets a
        caller that issues many queries encode the corpus only once.

    Returns
    -------
    tuple
        ``(q, a, top5)`` where ``q`` is the best-scoring WHO question, ``a``
        is ``WHO_covid19_QA[q]``, and ``top5`` is a DataFrame with columns
        ``'WHO Question'`` and ``'Similarity Score'`` holding the five
        highest-scoring rows in descending score order.

    Raises
    ------
    KeyError
        If the best-scoring question is not a key of ``WHO_covid19_QA``.
    """
    # Embed the WHO questions unless the caller already did.
    if corpus_embeddings is None:
        corpus_embeddings = model.encode(sentences1, convert_to_tensor=True)

    # Embed the query a single time. Repeating it len(sentences1) times and
    # encoding every copy yields identical vectors at many times the cost.
    query_embedding = model.encode([user_query], convert_to_tensor=True)

    # Compute cosine similarities: one column holding the score of each WHO
    # question against the query (instead of an N x N matrix whose diagonal
    # was the only part ever read).
    cosine_scores = util.pytorch_cos_sim(corpus_embeddings, query_embedding)
    scores = cosine_scores[:, 0].cpu().tolist()

    # Build the score table in one step rather than growing it row by row.
    df = pd.DataFrame({'WHO Question': sentences1, 'Similarity Score': scores})
    df = df.sort_values('Similarity Score', ascending=False)

    # Get the answer by looking the best question up in WHO_covid19_QA.
    q = df.iloc[0]['WHO Question']
    a = WHO_covid19_QA[q]
    top5 = df[:5]
    return q, a, top5


In [5]:
# Load the evaluation set and group it by user query: every distinct
# 'user_query' maps to the list of 'who_question' values that appear next to
# it in the file, in file order.
test_df = pd.read_csv('data/test_data.csv')
test_dict = {}
for row_idx, record in test_df.iterrows():
    test_dict.setdefault(record['user_query'], []).append(record['who_question'])

In [6]:
# Evaluate the chatbot on every distinct user query in the test set. A query
# counts as a match when the WHO question picked by the chatbot is one of the
# questions listed for that query in test_dict.
test_queries = test_df["user_query"].unique()
n = len(test_queries)

match = 0
# enumerate supplies the running (0-based) position directly; the earlier
# test_queries.tolist().index(...) rebuilt a list and scanned it on every
# iteration to recover the same number.
for position, test_query in enumerate(test_queries):
    print("User Query: {}".format(test_query))
    expected_who_questions = test_dict[test_query]
    question_by_model, answer_by_model, top5_qa_by_model = chatbot(test_query)
    print("WHO Question Selected by Chatbot: {}".format(question_by_model))
    if question_by_model in expected_who_questions:
        match += 1
        outcome = "MATCH"
    else:
        outcome = "FAIL"
    print("No. {} / {} : {}".format(position, n, outcome))


print("Model Matches: {} / {}".format(match, n))
# With an empty test set there is nothing to divide by; report NaN instead of
# raising ZeroDivisionError.
accuracy = match / n if n else float("nan")
print("Model Accuracy: {}".format(accuracy))

User Query: Is coronavirus the flu?
WHO Question Selected by Chatbot: What is a coronavirus?
No. 0 / 21 : FAIL
User Query: Are there medicines to treat the coronavirus infection?
WHO Question Selected by Chatbot: Can antiretrovirals be used to prevent COVID-19 infection?
No. 1 / 21 : MATCH
User Query: What is coronavirus?
WHO Question Selected by Chatbot: What is a coronavirus?
No. 2 / 21 : MATCH
User Query: Which foods boost immunity to viruses such as COVID-19?
WHO Question Selected by Chatbot: Can antiretrovirals be used to prevent COVID-19 infection?
No. 3 / 21 : MATCH
User Query: Is it true that COVID-19 does not spread through the air?
WHO Question Selected by Chatbot: Is there anything I should not do?
No. 4 / 21 : FAIL
User Query: Why is no one able to find medicine for Covid 19?
WHO Question Selected by Chatbot: Are there any medicines or therapies that can prevent or cure COVID-19?
No. 5 / 21 : MATCH
User Query: Is social distancing helpful during the COVID-19 infection world