In [None]:
# Standard library
import json
import random

# Data handling, Spark and DNAnexus platform clients
import pyspark
import pandas as pd
import numpy as np
import dxpy
import dxdata

# Bokeh, with plots rendered inline in this notebook
from bokeh.io import show, output_notebook
from bokeh.layouts import gridplot

output_notebook()

In [None]:
# getOrCreate() reuses an already-running SparkContext, so re-running this cell does
# not fail with "Cannot run multiple SparkContexts at once"; on a fresh kernel it
# creates a default context exactly like SparkContext().
sc = pyspark.SparkContext.getOrCreate()
spark = pyspark.sql.SparkSession(sc)

In [None]:
import hail as hl

# Attach Hail to the Spark context created above; GRCh38 is the default reference.
hl.init(sc=sc, default_reference='GRCh38')

# Open base.ht from the "base_drug_phenos" DNAnexus database.
db_name = "base_drug_phenos"
db_uri = dxpy.find_one_data_object(name=db_name, classname="database")['id']
url = f"dnax://{db_uri}/base.ht"
full = hl.read_table(url)

In [None]:
# BNF brand lookup; later cells use its Converted_BNF_Code, BNF_Presentation_Code,
# Brand_Name and Term columns.
BNF_BRAND_CSV = '../data/bnf_brand.csv'
bnf_brand_df = pd.read_csv(BNF_BRAND_CSV)

In [None]:
# Rows of `full` whose code is one of the lookup's converted BNF codes.
# A typed set literal gives a set-membership test instead of scanning a list for
# every row, and the explicit dtype keeps the literal valid even if the list is empty.
bnf_df_1 = full.filter(
    hl.literal(
        set(bnf_brand_df['Converted_BNF_Code'].tolist()), dtype=hl.tset(hl.tstr)
    ).contains(full['code'])
)

In [None]:
# All brand names from the lookup (duplicates kept), used for the text filter below.
brand_names_list = bnf_brand_df.loc[:, 'Brand_Name'].tolist()

In [None]:
# Keep only rows whose free-text `info` mentions at least one brand name.
# NOTE(review): this substring match is case-sensitive, whereas get_drug_name
# matches case-insensitively — confirm which behaviour is intended.
bnf_df_1 = bnf_df_1.filter(
    hl.any(
        lambda brand: bnf_df_1['info'].contains(brand),
        hl.literal(brand_names_list),
    )
)

In [None]:
# Rows matched by converted BNF code plus brand-name text.
n_bnf_converted_code_rows = bnf_df_1.count()
n_bnf_converted_code_rows

In [None]:
# Key on every row field so that distinct() after the union drops whole duplicate rows.
bnf_df_1 = bnf_df_1.key_by(*bnf_df_1.row.keys())

In [None]:
# Rows of `full` whose code is one of the lookup's BNF presentation codes.
# A typed set literal gives a set-membership test instead of scanning a list for
# every row, and the explicit dtype keeps the literal valid even if the list is empty.
bnf_df_2 = full.filter(
    hl.literal(
        set(bnf_brand_df['BNF_Presentation_Code'].tolist()), dtype=hl.tset(hl.tstr)
    ).contains(full['code'])
)

In [None]:
# Rows matched by BNF presentation code.
n_bnf_presentation_code_rows = bnf_df_2.count()
n_bnf_presentation_code_rows

In [None]:
# Key on every row field so that distinct() after the union drops whole duplicate rows.
bnf_df_2 = bnf_df_2.key_by(*bnf_df_2.row.keys())

In [None]:
# Stack both BNF match routes, then collapse rows that were found by both.
bnf_df = bnf_df_1.union(bnf_df_2)
bnf_df = bnf_df.distinct()

In [None]:
# Checkpoint the combined BNF matches so later cells start from the stored result
# instead of re-running the filters over `full`.
bnf_df = bnf_df.checkpoint()

In [None]:
# Distinct BNF rows after combining both match routes.
n_bnf_rows = bnf_df.count()
n_bnf_rows

In [None]:
# Deduplicate brand names, keeping their first-seen order in the CSV.
# get_drug_name uses hl.find, which returns the FIRST brand in this list that occurs in
# `info`, so list order decides the brand when several match. Iterating a set here would
# make that order depend on Python's per-process string-hash randomisation, which
# changes results between runs. Missing names are dropped so hl.literal sees only strings.
# TODO(review): consider ordering longest-first so the most specific brand wins.
brand_names_list = list(dict.fromkeys(bnf_brand_df['Brand_Name'].dropna().tolist()))

In [None]:
# Convert the pandas lookup to a Hail table; already-converted tables pass through,
# so the cell is safe to re-run.
bnf_brand_df = (
    hl.Table.from_pandas(bnf_brand_df)
    if isinstance(bnf_brand_df, pd.DataFrame)
    else bnf_brand_df
)

In [None]:
# One lookup row per brand name.
# NOTE(review): distinct() keeps a single row per key, so if a brand appears with
# several codes/terms only one of them survives — confirm that is acceptable.
bnf_brand_df = bnf_brand_df.key_by('Brand_Name').distinct()

In [None]:
def case_insensitive_contains(string, substring):
    """Build an expression that is true when `substring` occurs in `string`, ignoring case.

    Both arguments are used through ``.lower()`` and ``.contains()``, as on Hail
    string expressions.
    """
    haystack = string.lower()
    needle = substring.lower()
    return haystack.contains(needle)


def get_drug_name(info, brand_names):
    """Return an expression for the first of `brand_names` found in `info` (case-insensitive).

    Args:
        info: Hail string expression to search (the prescription free text).
        brand_names: Python list of brand-name strings, embedded as a Hail literal.

    Uses ``hl.find``, so the list order decides which brand is reported when
    several match; per the Hail docs the result is missing when none match.
    """
    def _mentioned(candidate):
        return case_insensitive_contains(info, candidate)

    candidates = hl.literal(brand_names)
    return hl.find(_mentioned, candidates)

In [None]:
# Tag each prescription with the first brand name found in its `info` text.
drug_name_expr = get_drug_name(bnf_df['info'], brand_names_list)
bnf_df = bnf_df.annotate(drug_name=drug_name_expr)

In [None]:
# Attach the brand lookup columns (including Term) by brand name. This is a left join,
# so rows whose brand has no lookup entry keep missing lookup fields.
brand_lookup = bnf_brand_df.key_by('Brand_Name')
bnf_df = bnf_df.key_by('drug_name').join(brand_lookup, how='left')

In [None]:
# Project onto the eight output fields in a fixed order (the Read and dm+d cells
# select the same names in the same order before the tables are unioned), then
# rename the two fields whose names differ.
bnf_cols = ['drug_name', 'eid', 'source', 'code', 'date', 'system', 'info', 'Term']
new_bnf_cols = {col: col for col in bnf_cols}
new_bnf_cols.update(drug_name='brand_name', Term='term')
bnf_df = bnf_df.key_by().select(*bnf_cols).rename(new_bnf_cols)

In [None]:
# Show the schema to check the renamed BNF fields before combining with other sources.
bnf_df.describe()

In [None]:
# Drop rows with no `term` (e.g. brands with no entry in the lookup).
bnf_df = bnf_df.filter(~hl.is_missing(bnf_df['term']))

In [None]:
# Checkpoint the final BNF matches so the later unions and counts reuse them.
bnf_df = bnf_df.checkpoint()

In [None]:
# Read v2 drug lookup; later cells use its read_code column (plus brand_name/term).
READ_2_DRUG_CSV = '../data/read_name_drug.csv'
read_2_brand_df = pd.read_csv(READ_2_DRUG_CSV)

In [None]:
# Read v2 records whose code is in the lookup. A typed set literal gives a
# set-membership test instead of scanning a list for every row, and stays valid
# if the code list is empty.
read_2_codes = hl.literal(
    set(read_2_brand_df['read_code'].tolist()), dtype=hl.tset(hl.tstr)
)
read_2_df = full.filter(read_2_codes.contains(full['code']) & (full['system'] == 'read_2'))

In [None]:
# Read v2 rows matched by code.
n_read_2_rows = read_2_df.count()
n_read_2_rows

In [None]:
# CTV3 (Read v3) drug lookup; later cells use its read_code column (plus brand_name/term).
READ_3_DRUG_CSV = '../data/ctv3_drug.csv'
read_3_brand_df = pd.read_csv(READ_3_DRUG_CSV)

In [None]:
# CTV3 records whose code is in the lookup. A typed set literal gives a
# set-membership test instead of scanning a list for every row, and stays valid
# if the code list is empty.
read_3_codes = hl.literal(
    set(read_3_brand_df['read_code'].tolist()), dtype=hl.tset(hl.tstr)
)
read_3_df = full.filter(read_3_codes.contains(full['code']) & (full['system'] == 'read_3'))

In [None]:
# CTV3 rows matched by code.
n_read_3_rows = read_3_df.count()
n_read_3_rows

In [None]:
# Convert to a Hail table once; an already-converted table passes through on re-run.
read_2_brand_df = (
    hl.Table.from_pandas(read_2_brand_df)
    if isinstance(read_2_brand_df, pd.DataFrame)
    else read_2_brand_df
)

In [None]:
# Convert to a Hail table once; an already-converted table passes through on re-run.
read_3_brand_df = (
    hl.Table.from_pandas(read_3_brand_df)
    if isinstance(read_3_brand_df, pd.DataFrame)
    else read_3_brand_df
)

In [None]:
# Attach lookup fields by exact code match (inner join: unmatched codes are dropped).
read_2_df = read_2_df.key_by('code').join(
    read_2_brand_df.key_by('read_code'), how='inner'
)
read_3_df = read_3_df.key_by('code').join(
    read_3_brand_df.key_by('read_code'), how='inner'
)

In [None]:
# Output fields, in the same order the BNF and dm+d cells use before the unions.
read_cols = 'brand_name eid source code date system info term'.split()

In [None]:
# Unkey both Read tables and project them onto the shared output fields.
read_2_df, read_3_df = [
    table.key_by().select(*read_cols) for table in (read_2_df, read_3_df)
]

In [None]:
# Stack Read v2 and CTV3 matches; both were projected onto read_cols just above.
read_df = read_2_df.union(read_3_df)

In [None]:
# Drop rows with no `term` (e.g. blank in the lookup CSV).
read_df = read_df.filter(~hl.is_missing(read_df['term']))

In [None]:
# Checkpoint the Read matches so the count and final union reuse the stored result.
read_df = read_df.checkpoint()

In [None]:
# Read v2 + CTV3 rows with a term.
n_read_rows = read_df.count()
n_read_rows

In [None]:
# Load the dm+d lookup with dmd_code read as text. With default type inference, a
# single blank code turns the column into float64, and hl.str would then produce
# float-formatted strings that can never equal full.code. Blank codes are dropped
# (they cannot match anything) and stray whitespace is stripped so codes compare exactly.
dmd_brand_pdf = pd.read_csv('../data/dmd_name.csv', dtype={'dmd_code': str})
dmd_brand_pdf = dmd_brand_pdf.dropna(subset=['dmd_code'])
dmd_brand_pdf['dmd_code'] = dmd_brand_pdf['dmd_code'].str.strip()

dmd_brand_df = hl.Table.from_pandas(dmd_brand_pdf)

# No-op for a string column; kept so dmd_code is guaranteed to be a Hail string.
dmd_brand_df = dmd_brand_df.annotate(dmd_code=hl.str(dmd_brand_df['dmd_code']))

In [None]:
# Collect the lookup's dm+d codes into Python and keep records of `full` whose code
# matches one of them, ignoring case.
dmd_code_list = dmd_brand_df.aggregate(hl.agg.collect(dmd_brand_df.dmd_code))
# Missing codes come back as None; skip them so .lower() cannot fail.
dmd_code_list_lower = [code.lower() for code in dmd_code_list if code is not None]
# Typed set literal: set-membership test per row, and valid even for an empty list.
dmd_df = full.filter(
    hl.literal(set(dmd_code_list_lower), dtype=hl.tset(hl.tstr)).contains(
        full['code'].lower()
    )
)

In [None]:
# Attach dm+d lookup fields by exact code match (inner join on code == dmd_code).
dmd_lookup = dmd_brand_df.key_by('dmd_code')
dmd_df = dmd_df.key_by('code').join(dmd_lookup)

In [None]:
# Unkey and project onto the same eight output fields, in the same order as the Read cells.
dmd_cols = 'brand_name eid source code date system info term'.split()
dmd_df = dmd_df.key_by().select(*dmd_cols)

In [None]:
# Drop rows with no `term` (e.g. blank in the lookup CSV).
dmd_df = dmd_df.filter(~hl.is_missing(dmd_df['term']))

In [None]:
# dm+d rows with a term.
n_dmd_rows = dmd_df.count()
n_dmd_rows

In [None]:
# Combine BNF and Read matches; both were projected onto the same eight output fields.
all_df = bnf_df.union(read_df)

In [None]:
# Append the dm+d matches to form the full prescription table.
all_df = all_df.union(dmd_df)

In [None]:
# Total matched prescriptions across BNF, Read and dm+d.
n_all_rows = all_df.count()
n_all_rows

In [None]:
# Show the combined schema (should be the eight shared output fields).
all_df.describe()

In [None]:
# Number of prescriptions per term; show the first 100 groups.
term_counts = all_df.group_by('term').aggregate(count=hl.agg.count())
term_counts.show(100)

In [None]:
# Sanity check: each source was filtered with hl.is_defined(term) before the union,
# so no combined row should lack a term. Fail loudly instead of only displaying a count.
n_missing_term = all_df.filter(hl.is_missing(all_df['term'])).count()
assert n_missing_term == 0, f"{n_missing_term} rows have no term"
n_missing_term

In [None]:
# Small, reproducible sample (0.05% of rows, fixed seed), written below as the test table.
test_fraction = 0.0005
test_seed = 42
test_df = all_df.sample(test_fraction, seed=test_seed)

In [None]:
# Rows in the sampled test table.
n_test_rows = test_df.count()
n_test_rows

In [None]:
# Output database and table names; the database is created in dnax:// storage
# only if it does not exist yet.
db_name = "mdd_db"
test_tb_name = "test_all_presc.ht"
full_tb_name = "all_presc.ht"

stmt = "CREATE DATABASE IF NOT EXISTS {} LOCATION 'dnax://'".format(db_name)
print(stmt)
spark.sql(stmt).show()

In [None]:
# Resolve the output database ID and build the dnax:// URLs for the two tables.
db_uri = dxpy.find_one_data_object(name=db_name, classname="database")['id']
url1, url2 = [f"dnax://{db_uri}/{table_name}" for table_name in (test_tb_name, full_tb_name)]

In [None]:
# Write the sampled test table; overwrite=True replaces any existing table at url1.
test_df.write(url1, overwrite=True)

In [None]:
# Write the full combined prescription table; overwrite=True replaces any existing table at url2.
all_df.write(url2, overwrite=True)