In [None]:
#libraries
import datetime as dt
import random
#wrap text:
import textwrap
from collections import Counter
from itertools import chain
from pathlib import Path

import nbformat
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import plotly.offline

pd.set_option("display.max_rows", None)

In [None]:
# Load every input table from ../datasets, in a fixed order
(
    concept,
    condition_occurrence,
    drug_exposure,
    person,
    hierarchy,
    props,
) = map(
    pd.read_csv,
    (
        '../datasets/CONCEPT.csv',
        '../datasets/CONDITION_OCCURRENCE.csv',
        '../datasets/DRUG_EXPOSURE.csv',
        '../datasets/PERSON.csv',
        '../datasets/hierarchy.csv',
        '../datasets/hemonc_component_properties.csv',
    ),
)

In [None]:
# Drug classes (RxNorm/HemOnc, list supplied by Ivy) that count as anti-cancer therapy.
# NOTE(review): spellings such as 'Aromatase inhibitorsthird generation' and
# 'Human DNA synthesisinhibitor' are kept verbatim -- presumably they match the
# values found in `component_class_name`; confirm before correcting them.
sact = [
    'Alkylating agent', 'Anti-CD38 antibody', 'Anti-CTLA-4 antibody',
    'Anti-TACSTD2 antibody-drug conjugate', 'Anthracycline', 'Antiandrogen', 'Antifolate',
    'Antimetabolite', 'Antitumor antibiotic', 'Anti-CD52 antibody', 'Anti-CD20 antibody',
    'Anti-EGFR antibody', 'Anti-HER2 antibody', 'Anti-CD38 antibody', 'Anti-PD-1 antibody',
    'Anti-PD-L1 antibody', 'Anti-RANKL antibody', 'Anti-SLAMF7 antibody', 'Anti-VEGF antibody',
    'Aromatase inhibitor', 'Aromatase inhibitorsthird generation',
    'Biosimilar', 'BRAF inhibitor', 'DNA methyltransferase inhibitor', 'Deoxycytidine analog',
    'EGFR inhibitor', 'ERBB 2 inhibitor', 'Estrogen receptor inhibitor',
    'Folic acid analog', 'Fluoropyrimidine', 'GnRH agonist', 'HDAC inhibitor',
    'Human DNA synthesisinhibitor', 'Microtubule inhibitor', 'MTOR inhibitor',
    'Nitrogen mustard', 'Nitrosourea', 'Neutral', 'PARP inhibitor', 'PARP1 inhibitor',
    'PARP2 inhibitor', 'Phenothiazine', 'Platinum agent', 'Proteasome inhibitor',
    'Purine analog', 'Pyrimidine analog', 'RANK ligand inhibitor',
    'Selective estrogen receptor modulator', 'Somatostatin analog', 'T-cell activator',
    'Targeted therapeutic', 'Taxane', 'Topoisomerase I inhibitor', 'Topoisomerase II inhibitor',
    'Triazene', 'Vinca alkaloid', 'Xanthine oxidase inhibitor',
    'WHO Essential Cancer Medicine',
]
#rxnorm = rxnorm[rxnorm['component_class_name'].isin(sact)]

# Keep only the component properties whose class is in the list above ...
props = props.loc[props['component_class_name'].isin(sact)]
# ... and use their `concept_id_2` values as the set of anti-cancer drug concepts
antican = props['concept_id_2']
# Restrict the drug exposures to those anti-cancer concepts
drug_exposure = drug_exposure.loc[drug_exposure['drug_concept_id'].isin(antican)]
#rxnorm['component_class_name'].value_counts()

In [None]:
# Map every concept_id in the concept table to its concept_name
concept_lookup = dict(zip(concept['concept_id'], concept['concept_name']))

In [None]:
#add labels
def make_labels(df, lookup=None):
    """Return a copy of ``df`` with concept ids replaced by readable labels.

    For every column whose name contains ``'concept_id'`` a sibling column is
    added in which ``'_id'`` is replaced by ``'_label'`` and each id is mapped
    through ``lookup``. Afterwards the following original columns are dropped:
    those containing ``'concept_id'``, those containing ``'source'``, and
    those without a single non-null value. Only the columns present in the
    input are examined, so freshly created label columns are always kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to label. It is left unmodified.
    lookup : mapping, optional
        Maps concept id -> concept name. Defaults to the notebook-level
        ``concept_lookup``.

    Returns
    -------
    pandas.DataFrame
        A new, labelled frame.
    """
    if lookup is None:
        lookup = concept_lookup

    # Work on a copy: writing into `df` directly would add a label column to
    # the caller's frame (and warn when that frame is a filtered slice).
    labelled = df.copy()
    for col in df.columns:
        is_concept = 'concept_id' in col
        if is_concept:
            labelled[col.replace('_id', '_label')] = labelled[col].map(lookup)
        if is_concept or 'source' in col or labelled[col].notna().sum() == 0:
            labelled = labelled.drop(columns=col)
    return labelled

In [None]:
# Produce human-readable versions of the condition and drug tables
condition_occurrence_labelled = make_labels(condition_occurrence)
drug_exposure_labelled = make_labels(drug_exposure)

# Drugs that should not appear in the treatment pathways
exclusions = ['dexamethasone']
is_excluded = drug_exposure_labelled['drug_concept_label'].isin(exclusions)
drug_exposure_labelled = drug_exposure_labelled.loc[~is_excluded]

In [None]:
'''if required, mask by a particular condition or set of conditions

# filter only by occurrences of Squamous cell carcinoma, NOS, of glottis
glottis = condition_occurrence[condition_occurrence.condition_concept_id==44500236]
# patient IDs matching this occurrence
glottis_patients = glottis.person_id.tolist()
# mask the drug exposures only by people matching the condition
mask = drug_exposure_labelled['person_id'].isin(glottis_patients)
masked = drug_exposure_labelled[mask]
'''

In [None]:
#reduce DF down to relevant variables for the visualization
small = (
    drug_exposure_labelled[['person_id', 'drug_exposure_start_datetime', 'drug_concept_label']]
    .dropna()
    .drop_duplicates()
)

# Sorting by drug name first makes every combined label alphabetical,
# e.g. always 'carboplatin & paclitaxel', never the reverse.
small_sorted = small.sort_values('drug_concept_label')

# Drugs given to the same person at the same timestamp are merged into a
# single 'a & b' label; the result is aligned back onto `small` by index.
small['drug_concept_label'] = (
    small_sorted
    .groupby(['person_id', 'drug_exposure_start_datetime'])['drug_concept_label']
    .transform(lambda x: ' & '.join(x))
)
#small.head()
# Manual re-ordering of combined labels, no longer needed:
#small['drug_concept_label'] = small['drug_concept_label'].str.replace('dexamethasone & cisplatin','cisplatin & dexamethasone')
#small['drug_concept_label'] = small['drug_concept_label'].str.replace('dexamethasone & cetuximab','cetuximab & dexamethasone')
#small['drug_concept_label'] = small['drug_concept_label'].str.replace('dexamethasone & carboplatin','carboplatin & dexamethasone')

# After joining, every drug of a combination carries the same label, so
# de-duplicating `small` leaves one row per person and timestamp.
# (The original de-duplicated `small_sorted`, which still holds the
# un-joined labels, so the combination step above had no effect.)
small_nodup = small.drop_duplicates()
#small_nodup['drug_concept_label']=small_nodup['drug_concept_label'].str.replace('& ', '&<br>')

In [None]:
# Zero-filled integer counter, one entry per row of small_nodup (same index)
readministrations = pd.Series(
    np.zeros(small_nodup.shape[0], dtype=int),
    index=small_nodup.index,
)

In [None]:
# Number each person's drug administrations in chronological order (0, 1, 2, ...).
# A gap of MAX_GAP or more between two consecutive administrations restarts
# the numbering at 0.
MAX_GAP = dt.timedelta(days=6000)

all_id = small_nodup['person_id'].unique()
id_administrations = {}  # not filled anywhere in the visible cells; kept for compatibility

# groupby hands over each person's rows directly instead of re-scanning the
# whole frame once per person id
for pid, patient in small_nodup.groupby('person_id', sort=False):
    # All administration timestamps of this person, oldest first
    administrations_sorted = pd.to_datetime(
        patient['drug_exposure_start_datetime'], format='%Y-%m-%d %H:%M:%S'
    ).sort_values()

    # True where the previous administration was less than MAX_GAP ago
    # (the first entry compares against NaT and is therefore False)
    frequency = administrations_sorted.diff() < MAX_GAP

    # Running counter: +1 while within the gap, reset to 0 otherwise
    n_administrations = [0]
    for within_gap in frequency.values[1:]:
        n_administrations.append((n_administrations[-1] + 1) * within_gap)

    # Write the counters back at this person's row positions
    readministrations.loc[administrations_sorted.index] = n_administrations

small_nodup['readministration'] = readministrations

In [None]:
# Preview the first rows of small_nodup
small_nodup.head()

In [None]:
# Long -> wide: one row per person, one column per administration number
pivoted = (
    small_nodup
    .pivot(index='person_id', columns='readministration', values='drug_concept_label')
    .reset_index()
)
# Prefix every column name with 'drug' (0 -> 'drug0', 1 -> 'drug1', ...)
prefixed = pivoted.add_prefix('drug')
# The id column was prefixed as well; give it its plain name back
renamed = prefixed.rename(
    columns={"drugperson_id": "person_id", "readministration": "index"}
)
# Empty cells are deliberately left as NaN (no filler string)
#fillednones = renamed.fillna(" ")
fillednones = renamed

In [None]:
#add a value of 1 to all data points for sums in the visualization
# NOTE: `fillednones` was bound to `renamed` by plain assignment, so this
# column is added to the frame both names refer to.
fillednones["count"] = 1
fillednones.head()

In [None]:
# Stack labelled drug exposures on top of labelled condition occurrences
drugs = pd.concat([drug_exposure_labelled, condition_occurrence_labelled])
display(drugs.dtypes)

# Parse both start-date columns into real datetimes
drugs = drugs.assign(
    drug_exposure_start_date=pd.to_datetime(drugs['drug_exposure_start_date']),
    condition_start_date=pd.to_datetime(drugs['condition_start_date']),
)
display(drugs.head())

In [None]:
# Earliest recorded condition start date per person (the "index" date)
earliest_start = (
    drugs.loc[drugs['condition_start_date'].notna()]
    .groupby('person_id')['condition_start_date']
    .min()
    .reset_index()
    .rename(columns={'condition_start_date': 'index_condition_date'})
)
drugs = drugs.merge(earliest_start, how='left')

# Rows without their own condition date fall back to the person's index date
drugs['condition_start_date'] = drugs['condition_start_date'].combine_first(
    drugs['index_condition_date']
)
# Whole days between the condition date and the drug exposure
drugs['drug_offset'] = (
    drugs['drug_exposure_start_date'] - drugs['condition_start_date']
).dt.days

In [None]:
# Keep only rows that have an offset and store it as a whole number of days
drugs = (
    drugs
    .dropna(subset=['drug_offset'])
    .astype({'drug_offset': int})
)

In [None]:
def make_chunks(row, long=False):
    """Collapse one person's drug trajectory into one label per time point.

    Parameters
    ----------
    row : object with ``full_offset`` and ``full_drugs`` attributes
        Two equally long sequences: the day offset of every drug exposure
        and the matching drug name (e.g. a row of ``full_drug_trajectory``).
    long : bool, default False
        If True use the full drug names, otherwise only their first four
        characters.

    Returns
    -------
    list of str
        One entry per distinct offset, in ascending offset order. Each entry
        is the sorted, de-duplicated drug names given at that offset, joined
        with ``'_'``.
    """
    # Pair every drug with its own offset. The original read an undefined
    # global (`full_drugs`) and matched drugs to offsets purely by position,
    # which is only right when the offsets are already sorted.
    drugs_by_offset = {}
    for offset, drug in zip(row.full_offset, row.full_drugs):
        name = drug if long else drug[:4]
        drugs_by_offset.setdefault(offset, set()).add(name)
    return ['_'.join(sorted(drugs_by_offset[offset])) for offset in sorted(drugs_by_offset)]

In [None]:
#chunks = make_chunks(drugs)
#chunks

In [None]:
# `traj` is a second name for the same DataFrame object as `drugs` (no copy is made)
traj = drugs
traj.head()

In [None]:
def _sankey_inputs(paths, user_colours=None):
    """Convert per-person drug paths into the arrays a Sankey trace needs.

    Parameters
    ----------
    paths : iterable of list of str
        One list per person; element ``k`` is the treatment label at step ``k``.
    user_colours : mapping, optional
        Label -> colour. Labels missing from it get a palette colour.

    Returns
    -------
    dict with keys ``labels``, ``colours`` (one per node) and ``source``,
    ``target``, ``value`` (one per distinct transition).
    """
    paths = [list(p) for p in paths]
    user_colours = user_colours or {}
    depth = max((len(p) for p in paths), default=0)

    # One node per distinct (step, label) pair. Sorting keeps the node order,
    # and therefore the figure, identical between runs.
    labels_per_step = [
        sorted({p[step] for p in paths if len(p) > step}) for step in range(depth)
    ]
    nodes = list(chain.from_iterable(
        ((step, label) for label in labels)
        for step, labels in enumerate(labels_per_step)
    ))
    node_numbers = {node: number for number, node in enumerate(nodes)}

    # Count how many people make each step -> step+1 transition
    link_counts = Counter(
        (node_numbers[(step, src)], node_numbers[(step + 1, dst)])
        for p in paths
        for step, (src, dst) in enumerate(zip(p, p[1:]))
    )

    # The same label gets the same colour at every step
    palette = px.colors.qualitative.Alphabet
    distinct_labels = sorted({label for _, label in nodes})
    fallback = {label: palette[i % len(palette)] for i, label in enumerate(distinct_labels)}

    return {
        'labels': [label for _, label in nodes],
        'colours': [user_colours.get(label, fallback[label]) for _, label in nodes],
        'source': [src for src, _ in link_counts],
        'target': [dst for _, dst in link_counts],
        'value': list(link_counts.values()),
    }


# Draw one Sankey diagram of treatment pathways per group and save it as PNG
for row, group in enumerate([drugs]):
    # `quantity` may be missing: make_labels drops columns that are entirely null
    agg_spec = {'drug_concept_label': list, 'drug_offset': list}
    if 'quantity' in group.columns:
        agg_spec['quantity'] = list

    # One row per person holding the chronologically ordered drugs and offsets
    full_drug_trajectory = (
        group.sort_values(['drug_exposure_start_date', 'drug_concept_label'])
        .groupby('person_id')
        .agg(agg_spec)
        .reset_index()
        .rename(columns={'drug_concept_label': 'full_drugs',
                         'drug_offset': 'full_offset',
                         'quantity': 'full_dose'})
    )
    if full_drug_trajectory.empty:
        continue
    full_drug_trajectory['path'] = full_drug_trajectory.apply(make_chunks, axis=1)

    # Nothing to draw unless at least one person has two or more steps
    paths = full_drug_trajectory['path'].tolist()
    if not any(len(p) > 1 for p in paths):
        continue

    # No `colour_map` is defined in the visible cells; use one if another
    # cell provides it, otherwise palette colours are assigned automatically.
    sankey_inputs = _sankey_inputs(paths, globals().get('colour_map'))

    sankey = go.Sankey(
        node=dict(pad=15,
                  thickness=20,
                  line=dict(color='black', width=0.5),
                  label=sankey_inputs['labels'],
                  color=sankey_inputs['colours']),
        # `source` used to be the literal 5 (a typo for the list `s`)
        link=dict(source=sankey_inputs['source'],
                  target=sankey_inputs['target'],
                  value=sankey_inputs['value']),
    )

    traces = [sankey]
    layout = go.Layout(showlegend=True,
                       plot_bgcolor='#FFFFFF')
    fig = go.Figure(data=traces,
                    layout=layout)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    reg_name = f'group{row}'
    fig.update_layout(title_text=reg_name,
                      font_size=20,
                      autosize=False,
                      width=2500,
                      height=600)

    # write_image does not create missing directories
    Path('figures').mkdir(parents=True, exist_ok=True)
    fig.write_image(f"figures/{reg_name.replace(' ', '_')}.png")