In [None]:
import json
import os
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import requests

from fiber.cohort import Cohort
from fiber.condition import Procedure, Diagnosis
from fiber.database import read_with_progress
from fiber.database.hana import engine as hana_engine
from fiber.database.mysql import engine as mysql_engine
from fiber.utils import Timer

In [None]:
def slack_notification(text, webhook_url=None, timeout=10):
    """Post ``text`` to Slack through an incoming webhook.

    Args:
        text: Message body, sent as the ``text`` field of the JSON payload.
        webhook_url: Incoming-webhook URL. When ``None``, the
            ``SLACK_WEBHOOK_URL`` environment variable is used, falling back to
            the placeholder URL below. This keeps the real webhook secret out
            of the notebook.
        timeout: Seconds to wait for Slack. Without a timeout ``requests`` can
            block indefinitely, which would stall the benchmark loop.

    Raises:
        ValueError: If Slack answers with a status code other than 200.
        requests.exceptions.RequestException: On connection errors or timeouts.
    """
    if webhook_url is None:
        webhook_url = os.environ.get(
            'SLACK_WEBHOOK_URL', 'https://hooks.slack.com/services/xxxx/yyyy'
        )
    slack_data = {'text': text}

    response = requests.post(
        webhook_url, data=json.dumps(slack_data),
        headers={'Content-Type': 'application/json'},
        timeout=timeout,
    )
    if response.status_code != 200:
        raise ValueError(
            'Request to slack returned an error %s, the response is:\n%s'
            % (response.status_code, response.text)
        )

# `sample_cohort.values_for(Diagnosis('584.9', 'ICD-9'))`

This notebook benchmarks value fetching for a specific diagnosis (ICD-9 code `584.9`).
The values are fetched for `sample_cohort`, a cohort of heart surgery patients defined by the ICD-9 procedure codes `35.%` and `36.1%`.

The queries in `build_query` emulate FIBER's translation of this condition into SQL.
Unlike FIBER, however, they can limit the number of medical record numbers (MRNs) included in the query, which controls the size of the result.

The benchmark is run for up to 15,000 MRNs. For each MRN count, it reports the combined execution and fetching time of the queries on HANA and MySQL, as well as the number of rows fetched.

In [None]:
# Heart surgery cohort: two ICD-9 procedure-code patterns (35.% and 36.1%)
# combined with `|`. The benchmark draws its MRN subsets from this cohort.
heart_surgery = Procedure('35.%', 'ICD-9') | Procedure('36.1%', 'ICD-9')
sample_cohort = Cohort(heart_surgery)
hs_mrns = sample_cohort.mrns()

In [None]:
def _sql_literal(value):
    """Render ``value`` as a single-quoted SQL string literal, doubling embedded quotes.

    Only meant to keep identifiers such as MRNs from breaking the query text.
    It is not a general-purpose escaper: MySQL in its default mode also treats
    backslashes as escape characters.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _sql_in_list(values):
    """Render ``values`` as a parenthesised SQL ``IN`` list.

    An empty input yields ``(NULL)``. That is valid SQL and matches no rows,
    whereas naive trailing-comma slicing would produce a malformed ``)``.
    """
    if not values:
        return '(NULL)'
    return '(' + ','.join(_sql_literal(value) for value in values) + ')'


def build_query(mrns, limit, code='584.9', context='ICD-9'):
    """Build equivalent HANA and MySQL queries that fetch a diagnosis for a set of MRNs.

    The queries emulate FIBER's translation of ``Diagnosis(code, context)`` into
    SQL. They are restricted to at most ``limit`` medical record numbers so the
    benchmark can control the result size.

    Args:
        mrns: Iterable of medical record numbers. They are sorted (by their
            string form) before slicing, so every run selects the same subset.
            Iteration order of a set of strings changes between Python
            processes because of hash randomisation.
        limit: Maximum number of MRNs to include; ``None`` includes all of them.
        code: Diagnosis code pattern in SQL ``LIKE`` syntax. It is upper-cased
            because the query compares it against ``upper(CONTEXT_DIAGNOSIS_CODE)``.
        context: Coding-system name matched against ``CONTEXT_NAME``.

    Returns:
        Tuple ``(hana_query, mysql_query)`` of SQL strings.
    """
    selected_mrns = sorted(mrns, key=str)[:limit]
    mrn_query = _sql_in_list(selected_mrns)
    context_literal = _sql_literal(context)
    code_literal = _sql_literal(code.upper())

    hana_query = f"""
        SELECT DISTINCT D_PERSON.MEDICAL_RECORD_NUMBER, FACT.AGE_IN_DAYS, FD_DIAGNOSIS.CONTEXT_NAME, FD_DIAGNOSIS.CONTEXT_DIAGNOSIS_CODE
        FROM "MSDW_2018"."FACT"
            JOIN "MSDW_2018"."D_PERSON" ON "MSDW_2018"."FACT"."PERSON_KEY" = "MSDW_2018"."D_PERSON"."PERSON_KEY"
            JOIN "MSDW_2018"."B_DIAGNOSIS" ON "MSDW_2018"."FACT"."DIAGNOSIS_GROUP_KEY" = "MSDW_2018"."B_DIAGNOSIS"."DIAGNOSIS_GROUP_KEY"
            JOIN "MSDW_2018"."FD_DIAGNOSIS" ON "MSDW_2018"."FD_DIAGNOSIS"."DIAGNOSIS_KEY" = "MSDW_2018"."B_DIAGNOSIS"."DIAGNOSIS_KEY"
        WHERE "MSDW_2018"."FD_DIAGNOSIS"."CONTEXT_NAME" LIKE {context_literal}
            AND upper("MSDW_2018"."FD_DIAGNOSIS"."CONTEXT_DIAGNOSIS_CODE") LIKE {code_literal}
            AND "MSDW_2018"."D_PERSON"."MEDICAL_RECORD_NUMBER" IN {mrn_query}"""

    mysql_query = f"""
        SELECT DISTINCT `D_PERSON`.`MEDICAL_RECORD_NUMBER`, `FACT`.`AGE_IN_DAYS`, `FD_DIAGNOSIS`.`CONTEXT_NAME`, `FD_DIAGNOSIS`.`CONTEXT_DIAGNOSIS_CODE`
        FROM `FACT`
            INNER JOIN `D_PERSON` ON `FACT`.`PERSON_KEY` = `D_PERSON`.`PERSON_KEY`
            INNER JOIN `B_DIAGNOSIS` ON `FACT`.`DIAGNOSIS_GROUP_KEY` = `B_DIAGNOSIS`.`DIAGNOSIS_GROUP_KEY`
            INNER JOIN `FD_DIAGNOSIS` ON `FD_DIAGNOSIS`.`DIAGNOSIS_KEY` = `B_DIAGNOSIS`.`DIAGNOSIS_KEY`
        WHERE `FD_DIAGNOSIS`.`CONTEXT_NAME` LIKE {context_literal}
            AND upper(`FD_DIAGNOSIS`.`CONTEXT_DIAGNOSIS_CODE`) LIKE {code_literal}
            AND `D_PERSON`.`MEDICAL_RECORD_NUMBER` IN {mrn_query}"""

    return hana_query, mysql_query

In [None]:
def _timed_read(query, engine):
    """Run ``query`` against ``engine`` and return ``(result_frame, elapsed)``.

    ``elapsed`` is ``Timer.elapsed`` measured around the whole
    ``read_with_progress`` call, i.e. query execution plus fetching the rows.
    """
    with Timer() as timer:
        frame = read_with_progress(query, engine, silent=True)
    return frame, timer.elapsed


def execute_benchmark(mrns, limits, query_builder, notify=None):
    """Time equivalent HANA and MySQL queries for increasing numbers of MRNs.

    Args:
        mrns: Medical record numbers, passed through to ``query_builder``.
        limits: Iterable of MRN counts; one benchmark iteration runs per value.
        query_builder: Callable ``(mrns, limit) -> (hana_query, mysql_query)``.
        notify: Callable that receives a progress message after each iteration.
            Defaults to ``slack_notification``. Failures are reported as
            warnings rather than raised, so a notification outage cannot
            discard the timings of a long-running benchmark.

    Returns:
        Tuple of DataFrames ``(hana_runtimes, mysql_runtimes, row_counts)``.
        The runtime frames have columns ``['# Patients', 'Runtime in s']``.
        ``row_counts`` has ``['# Patients', '# Rows']``, counting the rows
        returned by HANA.
    """
    if notify is None:
        notify = slack_notification

    hana_benchmark_results = []
    mysql_benchmark_results = []
    number_of_rows = []
    for limit in limits:
        queries = query_builder(mrns, limit)
        hana_query, mysql_query = queries[0], queries[1]

        hana_df, hana_elapsed = _timed_read(hana_query, hana_engine)
        number_of_rows.append((limit, len(hana_df)))
        hana_benchmark_results.append([limit, hana_elapsed])

        mysql_df, mysql_elapsed = _timed_read(mysql_query, mysql_engine)
        mysql_benchmark_results.append([limit, mysql_elapsed])

        # The two queries are written to be equivalent. A row-count mismatch
        # means the timings are not comparable, so flag it without aborting.
        if len(mysql_df) != len(hana_df):
            warnings.warn(
                f'Row count mismatch for {limit} MRNs: '
                f'HANA returned {len(hana_df)}, MySQL returned {len(mysql_df)}'
            )

        try:
            notify(f'Done value fetching for {limit} MRNs')
        except Exception as exc:  # best-effort progress ping; keep benchmarking
            warnings.warn(f'Progress notification failed: {exc}')

    runtime_columns = ['# Patients', 'Runtime in s']
    return (
        pd.DataFrame(hana_benchmark_results, columns=runtime_columns),
        pd.DataFrame(mysql_benchmark_results, columns=runtime_columns),
        pd.DataFrame(number_of_rows, columns=['# Patients', '# Rows']),
    )

In [None]:
# MRN counts to benchmark; each value is passed to `build_query` as `limit`.
limits = [10, 100, 500, 1_000, 5_000, 10_000, 15_000]
hana_results, mysql_results, number_of_rows = execute_benchmark(
    mrns=hs_mrns,
    limits=limits,
    query_builder=build_query,
)

### Persisting Results

In [None]:
# Persist the raw benchmark numbers so the figures can be regenerated without
# re-running the long benchmark. `to_csv` does not create missing parent
# directories, so make sure the target exists first.
RESULTS_DIR = '../results/value_fetching'
os.makedirs(RESULTS_DIR, exist_ok=True)

hana_results.to_csv(os.path.join(RESULTS_DIR, 'hana.csv'), index=False)
mysql_results.to_csv(os.path.join(RESULTS_DIR, 'mysql.csv'), index=False)
number_of_rows.to_csv(os.path.join(RESULTS_DIR, 'number_of_rows.csv'), index=False)

### Visualization

In [None]:
# Row counts come from the HANA result of each iteration.
number_of_rows.plot.line(
    x='# Patients', y='# Rows', title='Rows fetched from HANA per number of MRNs'
);

In [None]:
# Titled so this plot is distinguishable from the HANA runtime plot.
mysql_results.plot.line(
    x='# Patients', y='Runtime in s', title='MySQL value fetching runtime'
);

In [None]:
# Titled so this plot is distinguishable from the MySQL runtime plot.
hana_results.plot.line(
    x='# Patients', y='Runtime in s', title='HANA value fetching runtime'
);

In [None]:
# Combine both engines' runtimes into one frame keyed by MRN count. Naming the
# runtime columns before merging avoids relying on merge's default _x/_y
# suffixes and an in-place rename.
results = pd.merge(
    hana_results.rename(columns={'Runtime in s': 'IMDB Runtime in s'}),
    mysql_results.rename(columns={'Runtime in s': 'MySQL Runtime in s'}),
    on='# Patients',
)

# Draw on an explicit Axes. DataFrame.plot without `ax` opens its own figure,
# so a bare plt.figure() beforehand only leaves an empty figure behind.
fig, ax = plt.subplots()
results.plot(ax=ax, x='# Patients', logy=True, logx=False)
ax.set_ylabel('Runtime in s')

# savefig does not create missing directories.
os.makedirs('../figures/value_fetching', exist_ok=True)
fig.savefig('../figures/value_fetching/runtime.png', dpi=600, bbox_inches="tight")