# Tests
This notebook contains several tests of the source code and the usability of the Knowledge Graph. First, all necessary libraries are imported and a connection to the database is established.

In [None]:
# importing necessary libraries
import os
import timeit
import datetime as dt
import pandas as pd
from termcolor import colored
from helpers.helper_functions import init_connection, excel_import, export_to_excel, test_query, reset_db

# open the connection to the neo4j database via the project helper;
# `graph` is the handle the later cells use to run their Cypher queries
graph = init_connection()

## Functional Tests

### Test Creation of Knowledge Graph
This test case checks whether the Knowledge Graph was created successfully by validating that the database contains at least one node and at least one relationship.

In [None]:
# query to get the number of nodes in the database
query_node_count = """
MATCH (n)
RETURN count(n) as node_count
"""
# query to get the number of relationships in the database
# (kept under its own name so it no longer overwrites the node query)
query_rel_count = """
MATCH ()-[r]->()
RETURN count(r) as relationship_count
"""

# run the node query; the count is read from the first (only) result row
node_result = graph.run(query_node_count).data()
number_of_nodes = node_result[0]['node_count']
# an empty database means the Knowledge Graph was never created
assert number_of_nodes > 0, "Knowledge Graph creation failed: the database contains no nodes."
print(f"Number of nodes in the database: {number_of_nodes}")

# run the relationship query; the count is read from the first (only) result row
rel_result = graph.run(query_rel_count).data()
number_of_rels = rel_result[0]['relationship_count']
# nodes without any relationship would not form a usable graph
assert number_of_rels > 0, "Knowledge Graph creation failed: the database contains no relationships."
print(f"Number of relationships in the database: {number_of_rels}")

### Get all Doctors
This query returns all doctors and checks that every returned node has a specialization attribute, to make sure that only doctor nodes are returned.

In [None]:
# query to get all Doctor nodes together with the attributes shown in the result table
query_doctors = """
MATCH (d:Doctor)
RETURN d.name as name, d.specialization as specialization, d.yearsOfExperience as years_of_experience, d.contactEmail as contact_email
"""
# run the query and get the data of the result (one entry per doctor)
doctors = graph.run(query_doctors).data()
# the number of doctors is the number of returned rows
number_of_doctors = len(doctors)

# collect every returned row whose specialization is missing or empty;
# reporting only these rows keeps the failure message readable
# NOTE: like the original all(...) check, this passes trivially when no doctors are returned
doctors_without_specialization = [doctor for doctor in doctors if not doctor['specialization']]
assert not doctors_without_specialization, (
    f"Doctors without a specialization: {doctors_without_specialization}"
)

# print the number of doctors
print(f"Number of doctors in the database: {number_of_doctors}")
# show the details of the doctors as a table (last expression of the cell)
pd.DataFrame(doctors)

### Test Export Functionality
This test case checks whether the export functionality works correctly. First, the number of files in the export directory is counted. Then the export script is run, and finally the test verifies that the number of files has increased by exactly one.

In [None]:
# directory the export helper is told to write into
export_path = "../data/export"

# how many entries the export directory holds before the export runs
files_before = len(os.listdir(export_path))

# timestamp passed to the export helper; also reused by the restore cell
# at the end of the notebook to locate the exported file
current_time = dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# run the export against the connected database
export_to_excel(
    current_time=current_time,
    graph=graph,
    export_path=export_path,
)

# how many entries the export directory holds after the export
files_after = len(os.listdir(export_path))

# exactly one additional entry must have appeared
assert files_after == files_before + 1, "Export failed: no new file created."

### Test Import Functionality
This test case checks whether the import functionality works correctly. First, the current number of nodes and relationships is retrieved. Then the import script is run, and finally the test verifies that both the number of nodes and the number of relationships have increased.

In [None]:
# query to get the number of current nodes
query_node_count = """
MATCH (n)
RETURN count(n) as node_count
"""
# query to get the number of current relationships
query_rel_count = """
MATCH ()-[r]->()
RETURN count(r) as relationship_count
"""

# snapshot of the graph size BEFORE the import
node_result = graph.run(query_node_count).data()
rel_result = graph.run(query_rel_count).data()
number_of_nodes = node_result[0]['node_count']
number_of_rels = rel_result[0]['relationship_count']

# open the workbook in a context manager so the file handle is released
# again, even when the import raises, and import its data into the database
with pd.ExcelFile("../data/import/import_data.xlsx") as import_file:
    excel_import(excel_file=import_file, graph=graph)

# snapshot of the graph size AFTER the import (same queries as before)
after_result = graph.run(query_node_count).data()
after_rel_result = graph.run(query_rel_count).data()
after_number_of_nodes = after_result[0]['node_count']
after_number_of_rels = after_rel_result[0]['relationship_count']

# assert that the number of nodes has increased
assert after_number_of_nodes > number_of_nodes, f"Number of nodes before import: {number_of_nodes}, after import: {after_number_of_nodes}"
# assert that the number of relationships has increased
assert after_number_of_rels > number_of_rels, f"Number of relationships before import: {number_of_rels}, after import: {after_number_of_rels}"

## Non-functional Tests - Usability Tests

### Get all illnesses

In [None]:
# Cypher query returning name, ICD code and description of every Illness node
query_illness = """
MATCH (i:Illness)
RETURN i.name as name, i.ICDCode as ICD_Code, i.description as description
"""

# measure the total wall time of 100 executions of the query;
# the statement string is resolved against the notebook's global namespace
execution_time = timeit.timeit(
    stmt="graph.run(query_illness).data()",
    globals=globals(),
    number=100,
)
# total time divided by the number of runs gives the average per execution
print(f"Average execution time: {execution_time / 100} seconds")

# execute the query once more to obtain the rows for display
result = graph.run(query_illness).data()

# render the illness rows as a table (last expression of the cell)
pd.DataFrame(result)

### Find all symptoms of a specific illness

In [None]:
# define query to get all symptoms of a specific illness (here: 'Migraine')
query = """ 
MATCH (s:Symptom)-[r:SYMPTOM_OF]->(i:Illness) 
WHERE i.name = 'Migraine'
RETURN s.name as symptom
"""

# hand the query and the open connection to the imported test_query helper
# (the leftover debug print of the connection object was removed)
test_query(query=query, graph=graph)

### Find all doctors who treated patients with a specific illness

In [None]:
# define query to get the names of all doctors who treat at least one patient
# that has the illness named in the WHERE clause (here: 'Breast Cancer');
# DISTINCT keeps a doctor from being listed once per matching patient
query = """ 
MATCH (d:Doctor)-[:TREATS]->(p:Patient)-[:HAS]->(i:Illness)
WHERE i.name = 'Breast Cancer'
RETURN DISTINCT d.name
"""

# hand the query and the open connection to the imported test_query helper
# NOTE(review): what test_query reports is defined in helpers.helper_functions — confirm there
test_query(query=query, graph=graph)

### List illnesses that share at least one symptom

In [None]:
# define query listing pairs of different illnesses together with a symptom
# they have in common; because the filter is `<>`, each pair is returned in
# both orders (A, B) and (B, A)
query = """ 
MATCH (i1:Illness)<-[:SYMPTOM_OF]-(s:Symptom)-[:SYMPTOM_OF]->(i2:Illness)
WHERE i1.name <> i2.name
RETURN DISTINCT i1.name AS Illness1, i2.name AS Illness2, s.name AS SharedSymptom
"""

# time 100 executions of the query; the result has to be stored, otherwise the
# print below would report the stale timing left over from an earlier cell
execution_time = timeit.timeit("graph.run(query).data()", number=100, globals=globals())
# print the average execution time
print(f"Average execution time: {execution_time / 100} seconds")

# run the query once more and show its rows as a table
result = graph.run(query).data()
pd.DataFrame(result)

### Find patients allergic to drugs they were prescribed

In [None]:
# define query to get all patients who take a drug whose name equals the name
# of one of their own allergies
query = """ 
MATCH (p:Patient)-[:TAKES]->(d:Drug),
      (p)-[:HAS]->(a:Allergy)
WHERE d.name = a.name
RETURN p.name AS Patient, d.name AS ConflictMedicament
"""

# time 100 executions of the query; the result has to be stored, otherwise the
# print below would report the stale timing left over from an earlier cell
execution_time = timeit.timeit("graph.run(query).data()", number=100, globals=globals())
# print the average execution time
print(f"Average execution time: {execution_time / 100} seconds")

# run the query once more and collect its rows in a DataFrame
result = graph.run(query).data()
df = pd.DataFrame(result)

# an empty frame means the query matched no patient/drug conflict
if df.empty:
    print("Luckily, no patients have a conflict with their medication.")
else:
    display(df)

### Find the most common symptom across all illnesses

In [None]:
# define query counting, per symptom name, how many SYMPTOM_OF relationships
# to an illness exist, and returning only the row with the highest count
query = """ 
MATCH (s:Symptom)-[:SYMPTOM_OF]->(:Illness)
RETURN s.name, COUNT(*) AS Occurrence
ORDER BY Occurrence DESC
LIMIT 1
"""

# time 100 executions of the query; the result has to be stored, otherwise the
# print below would report the stale timing left over from an earlier cell
execution_time = timeit.timeit("graph.run(query).data()", number=100, globals=globals())
# print the average execution time
print(f"Average execution time: {execution_time / 100} seconds")

# run the query once more and show its row as a table
result = graph.run(query).data()
pd.DataFrame(result)

## Restore database to original state

In [None]:
# path of the export file written by the export test earlier in this notebook
# (export_path and current_time are defined in that cell)
export_file = os.path.join(export_path, f"export_{current_time}.xlsx")

# Open the export file BEFORE resetting the database: if it is missing or
# unreadable, the error is raised here while the database is still untouched.
# The context manager also closes the workbook again, so the file handle is
# released before os.remove runs (an open handle blocks deletion on Windows).
with pd.ExcelFile(export_file) as file:
    # reset the database (presumably clears it — see helpers.helper_functions)
    reset_db(graph=graph)
    # re-import the state that was exported before the import test ran
    excel_import(excel_file=file, graph=graph)

# delete the export file
os.remove(export_file)
print("Export file deleted.")
print(colored("--- Tests completed successfully. ---", "green"))