In [None]:
import pyspark
import pandas as pd
import dxpy
import dxdata
import numpy as np
import matplotlib.pyplot as plt
from bokeh.io import show, output_notebook
from bokeh.layouts import gridplot
import seaborn as sns
import random
import re
# Configure Bokeh (bokeh.io) to display its output inline in this notebook.
output_notebook()

In [None]:
# Start Spark: one SparkContext plus a SparkSession wrapping it.
# `sc` is later handed to hl.init and `spark` is used to run SQL statements.
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)

In [None]:
# Hail is imported in this cell, after the SparkContext has been created.
import hail as hl
# Initialise Hail on the existing SparkContext `sc`, with GRCh38 as the default reference.
hl.init(sc=sc, default_reference='GRCh38')

In [None]:
# Look up the database object named "mdd" through dxpy and open the
# all_presc_v2.ht Hail table stored under its dnax:// location.
db_name = "mdd"
database = dxpy.find_one_data_object(name=f"{db_name}", classname="database")
db_uri = database['id']
url = f"dnax://{db_uri}/all_presc_v2.ht"
full = hl.read_table(url)

In [None]:
# Print the table's schema to see which fields are available.
full.describe()

In [None]:
# Percentage of rows whose `tablets` value is unusable: missing, -1 or 0.
tablets_unusable = hl.is_missing(full.tablets) | (full.tablets == -1) | (full.tablets == 0)
missing = full.aggregate(hl.agg.count_where(tablets_unusable))
all_count = full.count()
print(f'{missing/all_count*100}%')

In [None]:
# Percentage of rows that have no `dose` value.
dose_is_missing = hl.is_missing(full.dose)
missing = full.aggregate(hl.agg.count_where(dose_is_missing))
all_count = full.count()
print(f'{missing/all_count*100}%')

In [None]:
# Distribution, across participants (eid), of the percentage of their rows
# that have no `tablets` value.
# NOTE(review): only truly missing values are counted here, whereas the
# overall summary above also treats -1 and 0 as missing — confirm which
# definition this plot should use.
stats = full.group_by(full.eid).aggregate(
    count_rows = hl.agg.count(),
    count_missing_rows = hl.agg.count_where(hl.is_missing(full.tablets))
)

# Every group has at least one row, so the division is always defined.
stats = stats.annotate(
    percent_of_missing = 100.0 * stats.count_missing_rows / stats.count_rows
)

df=stats.to_pandas()
# Open exactly one figure: a second plt.figure() call would leave an empty
# figure behind next to the histogram.
plt.figure(figsize=(10, 6))
plt.hist(df['percent_of_missing'], bins=100, edgecolor='black')
plt.title('Percentage of missing records in the tablets column')
plt.xlabel('Percentage of missing records')
plt.ylabel('Number of eid')
plt.grid(True)
plt.show()

In [None]:
# Distribution, across participants (eid), of the percentage of their rows
# that have no `dose` value.
dose_is_missing = hl.is_missing(full.dose)
stats = full.group_by(full.eid).aggregate(
    count_rows=hl.agg.count(),
    count_missing_rows=hl.agg.count_where(dose_is_missing),
)
stats = stats.annotate(
    percent_of_missing=100.0 * stats.count_missing_rows / stats.count_rows
)
df = stats.to_pandas()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(df['percent_of_missing'], bins=100, edgecolor='black')
ax.set_title('Percentage of missing records in the dose column')
ax.set_xlabel('Percentage of missing records')
ax.set_ylabel('Number of eid')
ax.grid(True)
plt.show()

In [None]:
# Rows whose `tablets` value is missing, -1 or 0.
# NOTE(review): `missing` is rebuilt with this same filter in a later cell
# before anything reads it, so this cell looks redundant — confirm.
missing=full.filter((hl.is_missing(full.tablets)) | (full.tablets==-1) | (full.tablets==0))

In [None]:
def perc(term):
    """Plot how much of each participant's data for one term lacks a tablets value.

    Rows of the notebook-level table ``full`` whose ``term`` equals ``term``
    are grouped by ``eid``. For every ``eid`` the percentage of rows with a
    missing ``tablets`` field is computed, and the distribution of those
    percentages is drawn as a 100-bin histogram.

    Args:
        term: Value compared against ``full.term`` to select the rows.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    rows_for_term = full.filter(full.term == term)
    per_eid = rows_for_term.group_by(rows_for_term.eid).aggregate(
        count_rows=hl.agg.count(),
        count_missing_rows=hl.agg.count_where(hl.is_missing(rows_for_term.tablets)),
    )
    per_eid = per_eid.annotate(
        percent_of_missing=100.0 * per_eid.count_missing_rows / per_eid.count_rows
    )
    per_eid_df = per_eid.to_pandas()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(per_eid_df['percent_of_missing'], bins=100, edgecolor='black')
    ax.set_title(f'Percentage of missing records in the tablets column for {term}')
    ax.set_xlabel('Percentage of missing records')
    ax.set_ylabel('Number of eid')
    ax.grid(True)
    plt.show()

In [None]:
# No earlier cell in this notebook assigns `terms`, so on a fresh kernel the
# loop below would fail with a NameError. Build the list here from the table:
# the distinct, non-missing values of `term`, sorted for a stable plot order.
# NOTE(review): this draws one histogram per distinct term — if only a
# hand-picked subset was intended, replace this with that explicit list.
terms = sorted(
    full.aggregate(
        hl.agg.filter(hl.is_defined(full.term), hl.agg.collect_as_set(full.term))
    )
)

for term in terms:
    perc(term)

In [None]:
# Histogram bin edges 0, 1, ..., 100 (unit-wide bins) used by the per-eid count plots.
bins = [edge for edge in range(101)]

In [None]:
# How many rows with an unusable `tablets` value (missing, -1 or 0) does each
# participant (eid) have?
tablets_unusable = hl.is_missing(full.tablets) | (full.tablets == -1) | (full.tablets == 0)
missing = full.filter(tablets_unusable)
missing = missing.group_by(missing.eid).aggregate(count_rows=hl.agg.count())
df = missing.to_pandas()

fig, ax = plt.subplots(figsize=(10, 6))
counts, bin_edges, patches = ax.hist(df['count_rows'], bins=bins, edgecolor='black')
ax.set_xlabel('Number of prescriptions')
ax.set_ylabel('Number of unique eids')

# Label every non-empty bar with its share of the binned participants.
# NOTE(review): the denominator is the sum of the bin counts, so any
# count_rows value outside `bins` is not part of it — confirm this is intended.
total_count = np.sum(counts)
percentages = (counts / total_count) * 100
bar_centres = bin_edges[:-1] + (bin_edges[1:] - bin_edges[:-1]) / 2
for bar_height, centre, share in zip(counts, bar_centres, percentages):
    if bar_height > 0:
        ax.text(centre, bar_height, f'{share:.1f}%', ha='center', va='bottom')
ax.grid(True)
plt.show()

In [None]:
# How many prescription rows does each participant (eid) have in total?
df = (
    full.group_by(full.eid)
    .aggregate(count_rows=hl.agg.count())
    .to_pandas()
)

fig, ax = plt.subplots(figsize=(10, 6))
counts, bin_edges, patches = ax.hist(df['count_rows'], bins=bins, edgecolor='black')
ax.set_xlabel('Number of prescriptions')
ax.set_ylabel('Number of unique eids')

# Annotate each non-empty bar with its percentage of the binned participants.
# NOTE(review): participants whose count_rows falls outside `bins` are not
# part of the denominator — confirm this is intended.
total_count = np.sum(counts)
percentages = (counts / total_count) * 100
bar_centres = bin_edges[:-1] + (bin_edges[1:] - bin_edges[:-1]) / 2
for bar_height, centre, share in zip(counts, bar_centres, percentages):
    if bar_height > 0:
        ax.text(centre, bar_height, f'{share:.1f}%', ha='center', va='bottom')
ax.grid(True)
plt.show()

In [None]:
# Row counts for every (term, dose) combination, brought into pandas.
dose_counts = full.group_by(full.term, full.dose).aggregate(counts=hl.agg.count())
dose_counts_df = dose_counts.to_pandas()
# A missing dose is recoded as -1 so that the column can be cast to int.
dose_without_nulls = dose_counts_df['dose'].fillna(-1)
dose_counts_df['dose'] = dose_without_nulls.astype(int)

In [None]:
# Pick, for every term, the dose that will be used to fill in missing doses.
#
# dose_counts_df stores "dose missing" as -1 (it was filled with -1 when the
# frame was built), so dropna() alone removes nothing and the -1 placeholder
# could win as the "most frequent dose" whenever missing values outnumber each
# real dose. Missing doses would then be filled with -1 instead of a real dose.
def _most_frequent_observed_dose(term_rows):
    """Return the highest-count dose among one term's rows.

    Rows carrying the -1 placeholder are ignored whenever the term has at
    least one real dose. If the term has no real dose at all, the placeholder
    is returned, so that every term keeps an entry in the lookup dictionary
    (the Hail lookup by ``full.term`` relies on the key being present).
    """
    observed = term_rows[term_rows['dose'] != -1]
    candidates = term_rows if observed.empty else observed
    return candidates.loc[candidates['counts'].idxmax(), 'dose']

most_frequent_doses = (
    dose_counts_df.dropna(subset=['dose'])
    .groupby('term')[['dose', 'counts']]
    .apply(_most_frequent_observed_dose)
    .reset_index()
    .rename(columns={0: 'most_frequent_dose'})
)
# term -> most frequent dose, as a plain dict and as a Hail literal.
most_frequent_doses_dict = most_frequent_doses.set_index('term')['most_frequent_dose'].to_dict()
most_frequent_doses_hail = hl.literal(most_frequent_doses_dict)

In [None]:
# Fill missing `dose` values with the most frequent dose of the row's term.
with_modal_dose = full.annotate(most_frequent_dose=most_frequent_doses_hail[full.term])
with_imputed_dose = with_modal_dose.annotate(
    dose=hl.or_else(with_modal_dose.dose, with_modal_dose.most_frequent_dose)
)
# The helper column is only needed for the fill-in step.
full = with_imputed_dose.drop('most_frequent_dose')

In [None]:
# Turn the -1 placeholder in `tablets` into a proper missing value.
tablets_recorded = full.tablets != -1
full = full.annotate(tablets=hl.or_missing(tablets_recorded, full.tablets))

In [None]:
# Row counts for every (term, tablets) combination, brought into pandas.
tabs_counts = full.group_by(full.term, full.tablets).aggregate(counts=hl.agg.count())
tabs_counts_df = tabs_counts.to_pandas()
# Missing tablets are recoded as -1 so that the column can be cast to int.
tablets_without_nulls = tabs_counts_df['tablets'].fillna(-1)
tabs_counts_df['tablets'] = tablets_without_nulls.astype(int)

In [None]:
# Row counts for every (term, tablets, dose) combination, brought into pandas.
counts = full.group_by(full.term, full.tablets, full.dose).aggregate(counts=hl.agg.count())
counts_df = counts.to_pandas()
# Missing values in both numeric columns are recoded as -1 before the int cast.
for column in ('tablets', 'dose'):
    counts_df[column] = counts_df[column].fillna(-1).astype(int)

In [None]:
# Row counts for every (term, dose, tablets) combination, brought into pandas.
quantity_counts = full.group_by(full.term, full.dose, full.tablets).aggregate(counts=hl.agg.count())
quantity_counts_df = quantity_counts.to_pandas()
# Missing tablets are recoded as -1 so that the column can be cast to int.
tablets_without_nulls = quantity_counts_df['tablets'].fillna(-1)
quantity_counts_df['tablets'] = tablets_without_nulls.astype(int)

In [None]:
# Pick, for every (term, dose) pair, the tablets value that will be used to
# fill in missing quantities.
#
# quantity_counts_df stores "tablets missing" as -1, and the notebook treats
# quantities that are missing, -1 or 0 as unusable (such rows are dropped
# before the table is written). Without the filter below, -1 or 0 could win as
# the "most frequent quantity", and the rows that needed imputing would be
# deleted instead of filled in.
def _most_frequent_usable_quantity(group_rows):
    """Return the highest-count tablets value among one (term, dose) group's rows.

    Rows whose tablets value is -1 or 0 are ignored whenever the group has at
    least one other value. If it has none, the original choice over all rows
    is kept, so that every (term, dose) pair still gets an entry.
    """
    usable = group_rows[~group_rows['tablets'].isin([-1, 0])]
    candidates = group_rows if usable.empty else usable
    return candidates.loc[candidates['counts'].idxmax(), 'tablets']

most_frequent_quantities = (
    quantity_counts_df
    .groupby(['term', 'dose'])[['tablets', 'counts']]
    .apply(_most_frequent_usable_quantity)
    .reset_index()
    .rename(columns={0: 'most_frequent_quantity'})
)

In [None]:
# (term, dose) -> most frequent quantity, as a plain Python dict.
quantity_by_term_and_dose = most_frequent_quantities.set_index(['term', 'dose'])['most_frequent_quantity']
most_frequent_quantity_dict = quantity_by_term_and_dose.to_dict()

In [None]:
# Wrap the (term, dose) -> quantity dict as a Hail literal so it can be indexed inside table expressions.
most_frequent_quantity_hail = hl.literal(most_frequent_quantity_dict)

In [None]:
# Fill missing tablet counts with the most frequent value for the row's
# (term, dose) pair and store the result in a new `quantity` field.
with_modal_quantity = full.annotate(
    most_frequent_quantity=most_frequent_quantity_hail[full.term, full.dose]
)
with_quantity = with_modal_quantity.annotate(
    quantity=hl.or_else(with_modal_quantity.tablets, with_modal_quantity.most_frequent_quantity)
)
# The helper column and the original `tablets` column are dropped.
full = with_quantity.drop('most_frequent_quantity', 'tablets')

In [None]:
# Percentage of rows whose `quantity` is still unusable (missing, -1 or 0).
quantity_unusable = hl.is_missing(full.quantity) | (full.quantity == -1) | (full.quantity == 0)
missing = full.aggregate(hl.agg.count_where(quantity_unusable))
all_count = full.count()
print(f'{missing/all_count*100}%')

In [None]:
# Percentage of rows that still have no `dose` value at this point.
dose_is_missing = hl.is_missing(full.dose)
missing = full.aggregate(hl.agg.count_where(dose_is_missing))
all_count = full.count()
print(f'{missing/all_count*100}%')

In [None]:
# Remove every row whose quantity is missing, -1 or 0.
quantity_unusable = hl.is_missing(full.quantity) | (full.quantity == -1) | (full.quantity == 0)
filtered_full = full.filter(quantity_unusable, keep=False)

In [None]:
# Target database and table name for the cleaned prescriptions.
db_name = "mdd"
full_tb_name = "all_presc_v3.ht"

# Create the database if it does not exist yet; the statement is printed
# before it is run through Spark SQL.
stmt = f"CREATE DATABASE IF NOT EXISTS {db_name} LOCATION 'dnax://'"
print(stmt)

create_db_result = spark.sql(stmt)
create_db_result.show()

In [None]:
# Resolve the database id and build the dnax:// URL of the output table.
database = dxpy.find_one_data_object(name=f"{db_name}", classname="database")
db_uri = database['id']
url = f"dnax://{db_uri}/{full_tb_name}"

In [None]:
# Write the filtered table to `url`, overwriting anything already stored there.
filtered_full.write(url, overwrite=True)