In [None]:
import pyspark
import pandas as pd
import dxpy
import dxdata
import numpy as np
import matplotlib.pyplot as plt
from bokeh.io import show, output_notebook
from bokeh.layouts import gridplot
import seaborn as sns
import random
from pyspark.sql import SparkSession
from datetime import date
import re
import datetime
from scipy import stats
from scipy.stats import mode

# Render Bokeh figures inline in this notebook.
output_notebook()

In [None]:
# Start Spark and hand its context to Hail (GRCh38 as the default reference genome).
# NOTE(review): the hail import could live in the top import cell.
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)

import hail as hl
hl.init(sc=sc, default_reference='GRCh38')

In [None]:
# Resolve the DNAnexus database named db_name to its id and open the
# prescription Hail table 'all_presc_v3.ht' stored in it.
db_name = "mdd_db"
db_uri = dxpy.find_one_data_object(name=f"{db_name}", classname="database")['id']
url = f"dnax://{db_uri}/all_presc_v3.ht"
full = hl.read_table(url)

In [None]:
# Print the schema, then count the rows with a defined `date` and all rows
# (two separate passes over the table).
full.describe()
with_data = full.filter(hl.is_defined(full.date)).count()
all_data=full.count()

In [None]:
# Fraction of prescription records that have a defined date.
print(with_data/all_data)

In [None]:
# Per patient (eid): the set of distinct prescribed terms and its size.
# NOTE(review): `df` is reassigned in later cells before this result is used.
df = full.group_by(full.eid).aggregate(unique_meds=hl.agg.collect_as_set(full.term))
df = df.annotate(num_unique_meds=hl.len(df.unique_meds))

In [None]:
# One struct per prescription row. year/month/day are parsed from `date`, assumed to be
# a 'YYYY-MM-DD' string (TODO confirm); the remaining raw fields are carried along.
drug_struct = hl.struct(
        year=hl.int(full.date.split('-')[0]), 
        month=hl.int(full.date.split('-')[1]), 
        day=hl.int(full.date.split('-')[2]), 
        drug_name=full.term,
        brand_name=full.brand_name,
        source=full.source,
        code=full.code,
        date=full.date,
        system=full.system,
        info=full.info,
        dose=full.dose,
        tablets = full.tablets,
    )

In [None]:
# Collect every prescription of the same drug term for the same patient.
aggregated_full = full.group_by(full.eid, full.term).aggregate(
    medicines=hl.agg.collect(drug_struct)
)
aggregated_full.describe()

In [None]:
# Order each group's prescriptions chronologically.
aggregated_full = aggregated_full.annotate(
    medicines=hl.sorted(
        aggregated_full.medicines, 
        key=lambda x: (x.year, x.month, x.day)
    )
)

In [None]:
def days_between_dates(date1, date2):
    """Return the number of days from ``date1`` to ``date2`` (i.e. ``date2 - date1``).

    Both arguments only need integer ``year``, ``month`` and ``day`` attributes
    (Hail int32 expressions in this notebook; plain Python ints work as well) and are
    read as proleptic Gregorian dates.

    Each date is turned into a day number with the closed-form "days from civil"
    formula. This takes O(1) per date and is correct for every year, including years
    before 1970. A range over the years since 1970 would count every pre-1970 year as
    zero days and collapse all such dates onto 1970.
    """

    def days_since_epoch(date):
        # Shift the calendar so the year starts in March. February (with the leap day)
        # then becomes the last month, so month offsets no longer depend on leap years.
        a = (14 - date.month) // 12      # 1 for January/February, 0 otherwise
        y = date.year - a                # year counted from March
        m = date.month + 12 * a - 3      # March = 0 ... February = 11
        # (153 * m + 2) // 5 is the number of days before month m in the March-based year;
        # the y // 4 - y // 100 + y // 400 terms add the leap days.
        # The constant 719469 maps 1970-01-01 to day 0.
        return (date.day + (153 * m + 2) // 5 + 365 * y
                + y // 4 - y // 100 + y // 400 - 719469)

    return days_since_epoch(date2) - days_since_epoch(date1)

In [None]:
# date_diffs[k] = days between prescription k and prescription k + 1 of the same
# (eid, term) group (medicines were sorted by date above). A trailing -1 marks the
# last prescription, which has no successor. As a result, date_diffs has exactly one
# entry per prescription and can be zipped with `medicines`.
aggregated_full = aggregated_full.annotate(
    date_diffs=hl.enumerate(aggregated_full.medicines).flatmap(lambda i_drug:
        hl.if_else(
            i_drug[0] == 0,
            hl.empty_array(hl.tint32),
            hl.array([days_between_dates(aggregated_full.medicines[i_drug[0] - 1], i_drug[1])])
        )
    ).append(-1)
)

In [None]:
# Keep only the first prescription of each (eid, term) group. Its `interval` is the
# first entry of date_diffs: days until the group's second prescription, or the -1
# sentinel when the group has a single prescription.
aggregated_first = aggregated_full.annotate(
    medicines=hl.zip(aggregated_full.medicines, aggregated_full.date_diffs)
        .map(lambda meds_diff: hl.struct(
            year=meds_diff[0].year,
            month=meds_diff[0].month,
            day=meds_diff[0].day,
            drug_name=meds_diff[0].drug_name,
            interval=meds_diff[1],
            brand_name=meds_diff[0].brand_name,
            source=meds_diff[0].source,
            code=meds_diff[0].code,
            date=meds_diff[0].date,
            system=meds_diff[0].system,
            info=meds_diff[0].info,
            dose=meds_diff[0].dose,
            tablets=meds_diff[0].tablets
        ))
        [0]
)
aggregated_first=aggregated_first.drop('date_diffs')

In [None]:
# Number of (eid, term) groups for each first-to-second prescription interval.
# (Named interval_counts so it is not confused with the `intervals` plotting function.)
interval_counts = aggregated_first.group_by(aggregated_first.medicines.interval).aggregate(counts=hl.agg.count())
# Collect the counts as a list. collect_as_set would merge different intervals that
# happen to have the same count, making sum(tab_counts) too small.
tab_counts=interval_counts.aggregate(hl.agg.collect(interval_counts.counts))

In [None]:
# suma is meant to be the total number of (eid, term) groups (the sum of the per-interval
# counts). NOTE(review): that only holds if tab_counts keeps duplicate counts (a list,
# not a set).
# NOTE(review): 36238 is a hard-coded count with no visible derivation here. It is
# presumably the size of one interval bucket; compute it from the interval-count table
# instead of typing it in.
suma=sum(tab_counts)
print(36238/suma)
print(suma)

In [None]:
# Attach each prescription's `date_diff` (days to the next prescription of the same
# (eid, term) group; -1 for the last one) to its struct.
aggregated_full = aggregated_full.annotate(
    medicines=hl.zip(aggregated_full.medicines, aggregated_full.date_diffs).map(lambda meds_diff:
        hl.struct(
            year=meds_diff[0].year,
            month=meds_diff[0].month,
            day=meds_diff[0].day,
            drug_name=meds_diff[0].drug_name,
            date_diff=meds_diff[1],
            brand_name=meds_diff[0].brand_name,
            source=meds_diff[0].source,
            code=meds_diff[0].code,
            date=meds_diff[0].date,
            system=meds_diff[0].system,
            info=meds_diff[0].info,
            dose=meds_diff[0].dose,
            tablets=meds_diff[0].tablets
        )
    )
)

In [None]:
aggregated_full = aggregated_full.drop('date_diffs')

In [None]:
# Split each (eid, term) prescription history into therapies at long gaps.
# One row per prescription; idx identifies the original (eid, term) group.
ht=aggregated_full.add_index('idx')
ht=ht.explode(ht.medicines)
# A missing tablet count is treated as 0 for the gap test.
ht = ht.annotate(
    tablets_value=hl.coalesce(ht.medicines.tablets, 0),
)
# scan.count_where counts qualifying rows *before* the current one. So a new segment
# number starts on the row after a prescription whose date_diff exceeds its tablet count
# by >= 90 days. This presumes one tablet lasts one day (TODO confirm).
# The scan runs over the whole table, not per group; segment is only used together with
# idx in the grouping that follows.
ht = ht.annotate(
    segment=hl.scan.count_where(
        (ht.medicines.date_diff - ht.tablets_value) >= 90)
)
ht=ht.drop('tablets_value')

In [None]:
# One row per therapy: the prescriptions of one (eid, term) group that share a segment.
# NOTE(review): later cells treat `medicines` as chronological. That relies on
# hl.agg.collect keeping the exploded row order, which is worth confirming.
ht=ht.group_by('idx','segment','term','eid').aggregate(medicines=hl.agg.collect(ht.medicines))

In [None]:
# therapy_id is assigned by add_index (0-based row index), then used as the key.
ht=ht.add_index('therapy_id')
ht=ht.key_by('therapy_id')
ht=ht.drop('idx').drop('segment')

In [None]:
# Number of therapies per patient.
count_by_eid=ht.group_by(ht.eid).aggregate(count_therapy=hl.agg.count())
df=count_by_eid.to_pandas()

In [None]:
# Pair each prescription with its position in the therapy: (index, struct).
ht = ht.annotate(
    medicines=hl.enumerate(ht.medicines)
)

In [None]:
ht = ht.annotate(
    num_prescriptions=hl.len(ht.medicines)
)

In [None]:
ht.describe()

In [None]:
def add_is_last_flag(med_tuple, num_prescriptions):
    """Convert an ``(index, prescription)`` pair from ``hl.enumerate`` into a flagged struct.

    The result keeps the prescription fields ``year`` ... ``tablets`` in their original
    order and appends ``is_last``: 1 when ``index == num_prescriptions - 1`` (the final
    prescription of the therapy), otherwise 0.
    """
    position = med_tuple[0]
    med = med_tuple[1]
    last_flag = hl.if_else(position == num_prescriptions - 1, 1, 0)
    return med.select(
        'year', 'month', 'day', 'drug_name', 'date_diff', 'brand_name', 'source',
        'code', 'date', 'system', 'info', 'dose', 'tablets',
        is_last=last_flag,
    )

ht = ht.annotate(
    medicines=ht.medicines.map(lambda med: add_is_last_flag(med, ht.num_prescriptions))
)

In [None]:
ht.describe()

In [None]:
# Therapy duration: days from the first to the last prescription in `medicines`.
ht=ht.annotate(
    duration=days_between_dates(ht.medicines[0], ht.medicines[ht.num_prescriptions - 1])
)

In [None]:
# One row per therapy (therapy_id, term, eid, duration) as a pandas frame for plotting.
# NOTE(review): the table is keyed by therapy_id, which is unique, so distinct() looks like
# a no-op. `pandas` is also a confusing variable name next to the `pd` module alias.
pandas=ht.distinct()
pandas=pandas.key_by()
pandas=pandas.select(pandas.therapy_id, pandas.term, pandas.eid, pandas.duration)
pandas_df = pandas.to_pandas()

In [None]:
# Histogram of therapy durations, x axis clipped to 0-1000 days.
plt.figure(figsize=(10, 6))
sns.histplot(pandas_df['duration'], kde=False)
plt.xlabel('Therapy duration')
plt.ylabel('Frequency')
plt.title('Histogram of duration of therapies')
plt.xlim(0,1000)
plt.grid(True)
plt.show()

In [None]:
# Back to one row per prescription. `term` is dropped; each struct still carries it
# as drug_name.
ht=ht.drop('term')
ht=ht.explode(ht.medicines)

In [None]:
def update_date_diff(med_struct):
    """Return the prescription struct with ``date_diff`` forced to -1 on the last prescription.

    The ``is_last`` flag is consumed here. The returned struct has only the prescription
    fields ``year`` ... ``tablets`` (same order) and no ``is_last`` field.
    """
    field_names = ['year', 'month', 'day', 'drug_name', 'date_diff', 'brand_name',
                   'source', 'code', 'date', 'system', 'info', 'dose', 'tablets']
    fields = {name: med_struct[name] for name in field_names}
    # -1 is the notebook's "no following prescription" sentinel; replacing the dict
    # entry keeps date_diff in its original position.
    fields['date_diff'] = hl.if_else(med_struct.is_last == 1, -1, med_struct.date_diff)
    return hl.struct(**fields)

ht = ht.annotate(
    medicines=update_date_diff(ht.medicines)
)

In [None]:
# Promote the prescription struct's fields to top-level row fields (date_diff becomes `interval`).
ht = ht.annotate(
    drug_name=ht.medicines.drug_name,
    interval=ht.medicines.date_diff,
    brand_name=ht.medicines.brand_name,
    source=ht.medicines.source,
    code=ht.medicines.code,
    date=ht.medicines.date,
    system=ht.medicines.system,
    info=ht.medicines.info,
    dose=ht.medicines.dose,
    tablets=ht.medicines.tablets
)

In [None]:
ht=ht.drop('medicines')

In [None]:
ht.describe()

In [None]:
# Use -1 as the sentinel for a missing tablet count.
df = ht.annotate(
    tablets = hl.case()
    .when(hl.is_missing(ht.tablets), -1)
    .default(ht.tablets)
)

In [None]:
# For the last prescription of a therapy (interval == -1) or a missing interval, use the
# tablet count as the interval (presumably one tablet per day; TODO confirm).
# If tablets was missing, this yields the -1 sentinel again.
df = df.annotate(
    interval = hl.case()
    .when(((df.interval == -1) | hl.is_missing(df.interval)), df.tablets)
    .default(df.interval)
)

In [None]:
df.describe()

In [None]:
# Merge prescriptions of the same therapy issued on the same date (interval 0) into one record.
# Every per-record field is collected into a list per (date, therapy_id). compute_result
# then combines dose and tablets; the other fields take the first list element, and
# interval takes the maximum.

merged_df = df.group_by(df.date, df.therapy_id).aggregate(
    dose_list=hl.agg.collect(df.dose),
    tablets_list=hl.agg.collect(df.tablets),
    eid_list=hl.agg.collect(df.eid),
    num_prescriptions_list=hl.agg.collect(df.num_prescriptions),
    duration_list=hl.agg.collect(df.duration),
    drug_name_list=hl.agg.collect(df.drug_name),
    interval_list=hl.agg.collect(df.interval),
    brand_name_list=hl.agg.collect(df.brand_name),
    source_list=hl.agg.collect(df.source),
    code_list=hl.agg.collect(df.code),
    system_list=hl.agg.collect(df.system),
    info_list=hl.agg.collect(df.info)
)


def compute_result(dose_list, tablets_list):
    """Combine same-day prescriptions into a single (dose, tablets) struct.

    The merged dose is the smallest dose of the day. The merged tablet count is the total
    dose (sum of dose * tablets) expressed in units of that smallest dose and converted
    with ``hl.int32``.
    NOTE(review): tablets may hold the -1 "missing" sentinel set earlier; such entries
    reduce the total.
    """
    smallest_dose = hl.min(dose_list)
    total_dose = hl.sum(hl.zip(dose_list, tablets_list).map(lambda pair: pair[0] * pair[1]))
    return hl.struct(dose=smallest_dose, tablets=hl.int32(total_dose / smallest_dose))

merged_df = merged_df.annotate(
    result=compute_result(merged_df.dose_list, merged_df.tablets_list)
)

# Unkey the table; date and therapy_id remain as ordinary row fields.
merged_df = merged_df.key_by()

# Unpack the merged dose/tablets and keep one representative value per list (the largest
# for interval). transmute removes the fields referenced here; dose_list and tablets_list
# are dropped in the next cell.
merged_df = merged_df.transmute(
    dose=merged_df.result.dose,
    tablets=merged_df.result.tablets,
    eid=merged_df.eid_list[0],
    num_prescriptions=merged_df.num_prescriptions_list[0],
    duration=merged_df.duration_list[0],
    drug_name=merged_df.drug_name_list[0],
    interval=hl.max(merged_df.interval_list),
    brand_name=merged_df.brand_name_list[0],
    source=merged_df.source_list[0],
    code=merged_df.code_list[0],
    system=merged_df.system_list[0],
    info=merged_df.info_list[0],
)

In [None]:
# ratio = tablets per day of interval. Intervals of -1 and 0 map to the sentinels -1 and 0.
df = merged_df
df = df.annotate(
    ratio = hl.case()
    .when(df.interval == -1, -1)
    .when(df.interval == 0, 0)
    .default(df.tablets/df.interval)
)
df = df.drop(df.dose_list, df.tablets_list)

In [None]:
selected_df = df.select('eid', 'ratio')
pandas_df = selected_df.to_pandas()

In [None]:
# Histogram of ratio with unit-width bins centred on 0..11.
plt.figure(figsize=(10, 6))
sns.histplot(pandas_df['ratio'], bins=[-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5], kde=False)
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram for ratio column')
plt.grid(True)
plt.show()

In [None]:
# Number of distinct therapies per patient.
unique_therapy_per_patient = df.group_by('eid').aggregate(
    unique_therapy_ids=hl.agg.collect_as_set(df.therapy_id)
)
unique_therapy_per_patient = unique_therapy_per_patient.annotate(
    unique_therapy_count=hl.len(unique_therapy_per_patient.unique_therapy_ids)
)

In [None]:
# Summary statistics and histogram of therapies per patient (unit bins centred on 0..20).
pandas_df=unique_therapy_per_patient.to_pandas()
print('max:')
print(pandas_df['unique_therapy_count'].max())
print('median:')
print(pandas_df['unique_therapy_count'].median())
print('mean:')
print(pandas_df['unique_therapy_count'].mean())
sns.histplot(pandas_df['unique_therapy_count'], bins=[-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5, 17.5, 18.5, 19.5, 20.5], kde=False)
plt.xlabel('Number of unique therapies per patient')
plt.ylabel('Frequency')
plt.title('Histogram showing number of unique therapies per patient')
plt.grid(True)
plt.show()

In [None]:
# Total days covered by each therapy: the sum of its intervals.
result = df.group_by(df.therapy_id).aggregate(number_of_days=hl.agg.sum(df.interval))

In [None]:
# Ids of therapies whose intervals add up to more than 1000 days.
long_therapies=result.filter(result.number_of_days>1000).to_pandas()['therapy_id']

In [None]:
def intervals(therapy_id, points, median_change_points, mode_change_points, labels=True,):
    """Plot one therapy's prescriptions as stem charts with change-point markers.

    For every prescription, two stems are drawn: ``tablets`` (diamond markers) and
    ``interval`` (square markers), joined by a vertical line per date. Dashed vertical
    lines mark the supplied change points (magenta = mode-based, red = median-based).
    The mean, median and mode of ``interval`` are printed.

    Args:
        therapy_id: identifier shown in the plot title.
        points: DataFrame with ``date``, ``tablets``, ``interval`` and ``dose`` columns,
            one row per prescription. It is not modified.
        median_change_points: x positions (dates) of median-based change points.
        mode_change_points: x positions (dates) of mode-based change points.
        labels: when True, label each tablets stem with ``"<dose> - <tablets/interval>"``.
    """
    color='blue'
    plt.figure(figsize=(16, 6))
    # Work on a copy so converting the date column does not mutate the caller's frame.
    points = points.copy()
    points['date'] = pd.to_datetime(points['date'])

    # Tablets stems.
    plt.stem(points['date'], points['tablets'], 'b', linefmt='-', markerfmt='D', basefmt=" ")

    if labels:
        for date, interval, dose, tablets in zip(points['date'], points['interval'], points['dose'], points['tablets']):
            plt.annotate(f'{dose} - {(tablets/interval):.2f}', (date, tablets), textcoords="offset points", xytext=(0,10), ha='left', fontsize=8, rotation=45)

    # Interval stems.
    plt.stem(points['date'], points['interval'], 'b', linefmt='-', markerfmt='s', basefmt=" ")

    # Largest value drawn by either stem series; it sets the upper y-limit.
    max_interval = max(points['tablets'].max(), points['interval'].max())

    print("Mean:", np.mean(points['interval']))
    print("Median:", np.median(points['interval']))
    # SciPy >= 1.11 returns a scalar mode by default (older releases a length-1 array),
    # so `.mode[0]` fails on new versions; ravel makes [0] valid for both.
    print("Mode:", np.ravel(stats.mode(points['interval']).mode)[0])

    # Vertical connector between the tablets and interval values of each date
    # (first row of that date if it occurs more than once).
    for date in points['date'].unique():
        same_date = points[points['date'] == date]
        tablets_value = same_date['tablets'].values[0]
        interval_value = same_date['interval'].values[0]
        plt.vlines(date, min(tablets_value, interval_value), max(tablets_value, interval_value), colors=color)

    for x in mode_change_points:
        plt.axvline(x=x, color='magenta', linestyle='--')

    for x in median_change_points:
        plt.axvline(x=x, color='red', linestyle='--')

    plt.ylim(-5, (max_interval * 1.2))
    plt.grid(axis='y', linestyle='--', linewidth=0.5)
    plt.title(f'therapy_id-{therapy_id}')

In [None]:
def change_points_graphs(tid, window_size=9, sd_fraction=1.6):
    """Detect and plot dosing change points for a single therapy.

    The per-prescription signal ``tablets / interval * dose`` is smoothed with a rolling
    median and a rolling mode of width ``window_size``. The first differences of each
    smoothed series are compared with their own moving average; positions deviating by
    more than ``sd_fraction`` standard deviations of the differences are change points.
    Two figures are drawn: the smoothed differences with their threshold bands, and the
    therapy's prescription plot from ``intervals`` with the change points marked.

    Args:
        tid: therapy_id to look up in the global Hail table ``df``.
        window_size: rolling-window length (default 9, the value previously hard-coded).
        sd_fraction: multiplier on the standard deviation used as threshold (default 1.6).
    """
    data = df.filter(df.therapy_id == tid)
    data_pd = data.to_pandas()

    # Detection needs more rows than the window. With len(data_pd) <= window_size the
    # differenced series is shorter than the window, np.convolve(mode='valid') swaps its
    # operands, and the comparison below fails on mismatched shapes.
    if len(data_pd) <= window_size:
        print(f"For {tid} its to short.")
        intervals(tid, data_pd, [], [])
        return

    def detect_changes(data, window_size=3, sd_fraction=1.5):
        """Rolling median/mode change detection on a 1-D Series (positional indexing)."""
        median = []
        mode = []
        # Seed both smoothed series with the first window_size - 1 raw values so they
        # end up the same length as `data`.
        for x in range(window_size - 1):
            median.append(data.iloc[x])
            mode.append(data.iloc[x])

        for x in range(len(data) - window_size + 1):
            window = data.iloc[x:(x + window_size)]
            median.append(np.median(window))
            mode.append(window.mode().iloc[0])

        # Work on the first differences of the smoothed series.
        median = np.diff(median)
        mode = np.diff(mode)

        median_moving_avg = np.convolve(median, np.ones(window_size)/window_size, mode='valid')
        mode_moving_avg = np.convolve(mode, np.ones(window_size)/window_size, mode='valid')

        # Threshold: sd_fraction standard deviations of the differences.
        median_threshold = sd_fraction * np.std(median)
        mode_threshold = sd_fraction * np.std(mode)

        # A change point is a difference that deviates from the moving average of the
        # window ending at it by more than the threshold.
        median_changes = np.abs(median[window_size-1:] - median_moving_avg) > median_threshold
        median_change_points = np.where(median_changes)[0] + (window_size - 1)
        mode_changes = np.abs(mode[window_size-1:] - mode_moving_avg) > mode_threshold
        mode_change_points = np.where(mode_changes)[0] + (window_size - 1)

        return median_change_points, median_moving_avg, median_threshold, np.array(median), mode_change_points, mode_moving_avg, mode_threshold, np.array(mode)

    noised = (data_pd['tablets'] / data_pd['interval'] * data_pd['dose']).astype(float)
    # One date per difference: the last row's date is dropped.
    dates = pd.to_datetime(data_pd['date']).iloc[:-1]

    (median_change_points, median_moving_avg, median_threshold, median,
     mode_change_points, mode_moving_avg, mode_threshold, mode) = detect_changes(
        noised, window_size=window_size, sd_fraction=sd_fraction)

    print(f"Detected change median points dates: {dates.iloc[median_change_points].tolist()}")
    print(f"Detected change mode points dates: {dates.iloc[mode_change_points].tolist()}")
    print(f"Calculated median threshold: {median_threshold:.2f}")
    print(f"Calculated mode threshold: {mode_threshold:.2f}")

    # Dates aligned with the moving averages (one per full window of differences).
    avg_dates = dates.iloc[window_size - 1:]

    plt.figure(figsize=(12, 6))
    plt.plot(dates, median, label='Median data', marker='o')
    plt.plot(dates, mode, label='Mode data', marker='o')

    plt.plot(avg_dates, median_moving_avg, label='Median Moving Average', color='green')
    plt.plot(avg_dates, mode_moving_avg, label='Mode Moving Average', color='lightgreen')

    plt.scatter(dates.iloc[median_change_points], median[median_change_points], color='red', label='Median change points', zorder=5)
    plt.scatter(dates.iloc[mode_change_points], mode[mode_change_points], color='magenta', label='Mode change points', zorder=5)

    # Threshold bands around each moving average.
    plt.fill_between(avg_dates,
                     median_moving_avg - median_threshold,
                     median_moving_avg + median_threshold,
                     color='gray', alpha=0.2, label='Median threshold band')
    plt.fill_between(avg_dates,
                     mode_moving_avg - mode_threshold,
                     mode_moving_avg + mode_threshold,
                     alpha=0.2, label='Mode threshold band')

    plt.legend()
    plt.title(f'Change Point Detection for {tid} (Window: {window_size}, SD Fraction: {sd_fraction})')
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.grid(True, linestyle=':', alpha=0.7)

    # Shift change points back by half a window (the former hard-coded 4 for window 9).
    half_window = window_size // 2
    intervals(tid, data_pd, dates.iloc[median_change_points - half_window], dates.iloc[mode_change_points - half_window])

    plt.show()

In [None]:
df.count()

In [None]:
# Number of distinct therapy ids.
num_therapies = df.select('therapy_id').key_by('therapy_id').distinct().count()

In [None]:
# Empty table with a per-prescription schema.
# NOTE(review): neither `schema` nor `new_df` is used by any later cell in this notebook.
schema = hl.tstruct(
    date=hl.tstr,
    therapy_id=hl.tint64,
    dose=hl.tint32,
    tablets=hl.tint32,
    eid=hl.tstr,
    num_prescriptions=hl.tint32,
    duration=hl.tint32,
    drug_name=hl.tstr,
    interval=hl.tint32,
    brand_name=hl.tstr,
    source=hl.tstr,
    code=hl.tstr,
    system=hl.tstr,
    info=hl.tstr,
    ratio=hl.tfloat64,
    change_flag=hl.tint32
)
new_df = hl.Table.parallelize([], schema)

In [None]:
def find_change_points(data_pd, window_size=9, sd_fraction=1.6):
    """Return the dates at which a therapy's dosing pattern changes.

    The signal ``tablets / interval * dose`` (one value per row; rows are assumed to be in
    chronological order) is smoothed with a rolling median of width ``window_size``.
    The first differences of the smoothed series are compared with their moving average.
    Positions deviating by more than ``sd_fraction`` standard deviations of the differences
    are change points. Each one is shifted back by ``window_size // 2`` rows before its
    ``date`` is looked up.

    Args:
        data_pd: DataFrame with ``tablets``, ``interval``, ``dose`` and ``date`` columns.
            All lookups are positional, so its index labels do not matter.
        window_size: rolling-median window length (default 9, the value previously hard-coded).
        sd_fraction: multiplier on the standard deviation used as threshold (default 1.6).

    Returns:
        An empty list when ``len(data_pd) <= window_size`` (too few rows for the moving
        average); otherwise a Series of the ``date`` values at the change points
        (possibly empty).
    """
    if len(data_pd) <= window_size:
        return []

    def detect_changes(data, window_size=3, sd_fraction=1.5):
        """Rolling-median change detection on a 1-D Series (positional indexing)."""
        # Smoothed series: the first window_size - 1 raw values, then the median of every
        # full window, so it has the same length as `data`.
        median = [data.iloc[x] for x in range(window_size - 1)]
        for x in range(len(data) - window_size + 1):
            median.append(np.median(data.iloc[x:(x + window_size)]))

        median = np.diff(median)

        median_moving_avg = np.convolve(median, np.ones(window_size)/window_size, mode='valid')
        median_threshold = sd_fraction * np.std(median)

        # Compare each difference with the moving average of the window ending at it.
        median_changes = np.abs(median[window_size-1:] - median_moving_avg) > median_threshold
        median_change_points = np.where(median_changes)[0] + (window_size - 1)

        return median_change_points, median_moving_avg, median_threshold, np.array(median)

    noised = (data_pd['tablets'] / data_pd['interval'] * data_pd['dose']).astype(float)
    # One date per difference: the last row's date is dropped.
    dates = data_pd['date'].iloc[:-1]

    median_change_points, _, _, _ = detect_changes(noised, window_size=window_size, sd_fraction=sd_fraction)

    # Shift back by half a window (the former hard-coded 4 for the default window of 9).
    return dates.iloc[median_change_points - window_size // 2]

In [None]:
lowest_free_therapy_number = num_therapies + 1

In [None]:
results = []

In [None]:
for range_start in range(1, num_therapies + 1, 10000):
    range_end = min(range_start + 9999, num_therapies)
    data_chunk_hl = df.filter((df.therapy_id >= range_start) & (df.therapy_id <= range_end))
    data_chunk = data_chunk_hl.to_pandas()
    for i in range(range_start, range_end + 1):
        data = data_chunk[data_chunk['therapy_id'] == i]
        dates = pd.to_datetime(find_change_points(data))
        if not dates.empty: 
            for date in dates:
                results.append({'therapy_id': i, 'date': date})

In [None]:
for entry in results:
    entry['date'] = entry['date'].strftime('%Y-%m-%d')
    

In [None]:
results_set = hl.literal({(record['therapy_id'], record['date']) for record in results})

In [None]:
df = df.annotate(
    is_change = results_set.contains((hl.int32(df.therapy_id), hl.str(df.date)))
)

In [None]:
# Number of rows flagged as change points (compare with len(results) below).
df.filter(df.is_change).count()

In [None]:
len(results)

In [None]:
df.describe()

In [None]:
# Per-prescription struct, now also carrying eid, interval, ratio and is_change.
drug_struct = hl.struct(
    year=hl.int(df.date.split('-')[0]), 
    month=hl.int(df.date.split('-')[1]), 
    day=hl.int(df.date.split('-')[2]), 
    drug_name=df.drug_name,
    brand_name=df.brand_name,
    source=df.source,
    code=df.code,
    date=df.date,
    system=df.system,
    info=df.info,
    dose=df.dose,
    tablets=df.tablets,
    eid=df.eid,
    interval=df.interval,
    ratio=df.ratio,
    is_change=df.is_change
)

In [None]:
# Regroup prescriptions by therapy.
aggregated_df = df.group_by(df.therapy_id).aggregate(
    medicines=hl.agg.collect(drug_struct)
)
aggregated_df.describe()

In [None]:
# Order each therapy's prescriptions chronologically.
aggregated_df = aggregated_df.annotate(
    medicines=hl.sorted(
        aggregated_df.medicines, 
        key=lambda x: (x.year, x.month, x.day)
    )
)

In [None]:
aggregated_df=aggregated_df.explode(aggregated_df.medicines)

In [None]:
# scan.count_where counts flagged rows *before* the current one, so a new segment starts
# on the row after each change point. The scan is global; segment is combined with
# therapy_id in the grouping that follows.
aggregated_df = aggregated_df.annotate(
    segment=hl.scan.count_where(
        (aggregated_df.medicines.is_change==True)
))

In [None]:
# One row per (therapy, segment): each segment becomes a separate therapy.
aggregated_df=aggregated_df.group_by('therapy_id','segment').aggregate(medicines=hl.agg.collect(aggregated_df.medicines))

In [None]:
# New id `tid` (add_index row index) per split therapy; the old therapy_id and segment are dropped.
aggregated_df=aggregated_df.add_index('tid')
aggregated_df=aggregated_df.key_by('tid')
aggregated_df=aggregated_df.drop('therapy_id').drop('segment')

In [None]:
aggregated_df=aggregated_df.annotate(
    num_prescriptions=hl.len(aggregated_df.medicines)
)

In [None]:
# Pair each prescription with its position in the split therapy: (index, struct).
aggregated_df = aggregated_df.annotate(
    medicines=hl.enumerate(aggregated_df.medicines)
)

In [None]:
def add_is_last_flag(med_tuple, num_prescriptions):
    """Convert an ``(index, prescription)`` pair from ``hl.enumerate`` into a flagged struct.

    The result keeps the prescription fields ``year`` ... ``is_change`` in their original
    order and appends ``is_last``: 1 when ``index == num_prescriptions - 1``, otherwise 0.
    NOTE(review): this redefines (and shadows) the earlier add_is_last_flag, which works
    on a different field set.
    """
    position = med_tuple[0]
    med = med_tuple[1]
    last_flag = hl.if_else(position == num_prescriptions - 1, 1, 0)
    return med.select(
        'year', 'month', 'day', 'drug_name', 'brand_name', 'source', 'code', 'date',
        'system', 'info', 'dose', 'tablets', 'eid', 'interval', 'ratio', 'is_change',
        is_last=last_flag,
    )

aggregated_df = aggregated_df.annotate(
    medicines=aggregated_df.medicines.map(
        lambda med: add_is_last_flag(med, aggregated_df.num_prescriptions)
    )
)

In [None]:
def days_between_dates(date1, date2):
    """Return the number of days from ``date1`` to ``date2`` (i.e. ``date2 - date1``).

    Both arguments only need integer ``year``, ``month`` and ``day`` attributes
    (Hail int32 expressions in this notebook; plain Python ints work as well) and are
    read as proleptic Gregorian dates.

    Each date is turned into a day number with the closed-form "days from civil"
    formula. This takes O(1) per date and is correct for every year, including years
    before 1970. A range over the years since 1970 would count every pre-1970 year as
    zero days and collapse all such dates onto 1970.
    NOTE(review): this is a second definition of the same helper; one copy could be removed.
    """

    def days_since_epoch(date):
        # Shift the calendar so the year starts in March. February (with the leap day)
        # then becomes the last month, so month offsets no longer depend on leap years.
        a = (14 - date.month) // 12      # 1 for January/February, 0 otherwise
        y = date.year - a                # year counted from March
        m = date.month + 12 * a - 3      # March = 0 ... February = 11
        # (153 * m + 2) // 5 is the number of days before month m in the March-based year;
        # the y // 4 - y // 100 + y // 400 terms add the leap days.
        # The constant 719469 maps 1970-01-01 to day 0.
        return (date.day + (153 * m + 2) // 5 + 365 * y
                + y // 4 - y // 100 + y // 400 - 719469)

    return days_since_epoch(date2) - days_since_epoch(date1)

In [None]:
# Duration of each split therapy: days from its first to its last prescription.
aggregated_df=aggregated_df.annotate(
    duration=days_between_dates(aggregated_df.medicines[0], aggregated_df.medicines[aggregated_df.num_prescriptions - 1])
)

In [None]:
aggregated_df=aggregated_df.explode(aggregated_df.medicines)

In [None]:
def modify_interval_if_last(med_struct):
    """Return med_struct with `interval` set to -1 when it is the last prescription (is_last == 1)."""
    new_interval = hl.if_else(med_struct.is_last == 1, -1, med_struct.interval)
    
    return med_struct.annotate(interval=new_interval)

aggregated_df = aggregated_df.annotate(
    medicines=modify_interval_if_last(aggregated_df.medicines)
)

In [None]:
# Promote every prescription field to a top-level row field and drop the struct.
aggregated_df = aggregated_df.transmute(
    year=aggregated_df.medicines.year,
    month=aggregated_df.medicines.month,
    day=aggregated_df.medicines.day,
    drug_name=aggregated_df.medicines.drug_name,
    brand_name=aggregated_df.medicines.brand_name,
    source=aggregated_df.medicines.source,
    code=aggregated_df.medicines.code,
    date=aggregated_df.medicines.date,
    system=aggregated_df.medicines.system,
    info=aggregated_df.medicines.info,
    dose=aggregated_df.medicines.dose,
    tablets=aggregated_df.medicines.tablets,
    eid=aggregated_df.medicines.eid,
    interval=aggregated_df.medicines.interval,
    ratio=aggregated_df.medicines.ratio,
    is_change=aggregated_df.medicines.is_change,
    is_last=aggregated_df.medicines.is_last
)

In [None]:
# Rows with interval == -1: the last prescription of each split therapy, plus any row
# where -1 came from a missing tablet count earlier.
num_intervals_equal_minus_1 = aggregated_df.aggregate(
    hl.agg.count_where(aggregated_df.interval == -1)
)

In [None]:
num_intervals_equal_minus_1

In [None]:
# As before the split, the last prescription's interval becomes its tablet count.
aggregated_df = aggregated_df.annotate(
    interval=hl.if_else(aggregated_df.interval == -1, aggregated_df.tablets, aggregated_df.interval)
)

In [None]:
# Remove helper fields that are not exported.
aggregated_df=aggregated_df.drop('is_change')
aggregated_df=aggregated_df.drop('is_last')
aggregated_df=aggregated_df.drop('ratio')

In [None]:
# Sanity check: number of distinct tids (collects every tid into one set).
aggregated_df.aggregate(hl.len(hl.agg.collect_as_set(aggregated_df.tid)))

In [None]:
# Ensure the Spark database exists on DNAnexus storage.
db_name = "mdd_db"
full_tb_name = "therapies.ht"

stmt = f"CREATE DATABASE IF NOT EXISTS {db_name} LOCATION 'dnax://'"
print(stmt)

spark.sql(stmt).show()

In [None]:
db_uri = dxpy.find_one_data_object(name=f"{db_name}", classname="database")['id']
url = f"dnax://{db_uri}/{full_tb_name}"

In [None]:
# Write the split therapies as a Hail table into the database, replacing any existing copy.
aggregated_df.write(url, overwrite=True)