In [None]:
from itertools import chain
from pathlib import Path

import numpy as np
import pandas as pd

In [None]:
# Data location: copy the exported CSVs to your drive and point `folder` at them.
# On Colab, mount Google Drive first:
#   from google.colab import drive
#   drive.mount('/content/drive')
# NOTE(review): the CSV reads in this notebook use bare file names relative to the
# working directory and do not reference `folder`.

folder = "/"

In [None]:
# base query for generating the cohort

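# (anyone who has a diagnosis of breast cancer OR had at least one dose of doxorubicin OR at least one dose of cyclophosphamide)
# NB: because every table is inner-joined, only people with at least one drug exposure AND at least one condition record can
# match, and LIMIT 1000 caps the result.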

# N.B. - these data are synthetically generated and therefore not really representative, but they should give you a handle on
# how to interact with the concepts and data structures. That's why no patients have both a breast cancer diagnosis AND the
# drugs of interest, a combination that would be very common in the real world; the data are still adequate to get started.

# SELECT distinct(de.person_id)
# FROM
# `bigquery-public-data.cms_synthetic_patient_data_omop.drug_exposure` as de join
# `bigquery-public-data.cms_synthetic_patient_data_omop.concept` as c on c.concept_id = de.drug_concept_id join
# `bigquery-public-data.cms_synthetic_patient_data_omop.drug_exposure` as de2 on de.person_id = de2.person_id join
# `bigquery-public-data.cms_synthetic_patient_data_omop.concept` as c2 on c2.concept_id = de2.drug_concept_id join
# `bigquery-public-data.cms_synthetic_patient_data_omop.condition_occurrence` as co on co.person_id = de.person_id join
# `bigquery-public-data.cms_synthetic_patient_data_omop.concept` as c3 on c3.concept_id = co.condition_concept_id
# where upper(c.concept_name) like '%DOXORUBICIN%'
# or upper(c2.concept_name) like '%CYCLOPHOSPHAMIDE%'
# or upper(c3.concept_name) LIKE '%NEOPLASM%BREAST%'
# LIMIT 1000

In [None]:
# ended up downloading them separately because it was too slow combined...

cohort_files = pd.concat([pd.read_csv(f'cyclo.csv'),
                          pd.read_csv(f'doxo.csv'),
                          pd.read_csv(f'breast.csv')])

In [None]:
# Print the cohort's person_ids as a comma-separated list to paste into the `in (...)`
# filter of the follow-up queries, e.g.:
#   select *
#   from `bigquery-public-data.cms_synthetic_patient_data_omop.drug_exposure` as p
#   where p.person_id in (...)
# Each id goes through str() because print(list(ndarray)) shows NumPy 2 scalar reprs
# such as `np.int64(123)`, which are not valid SQL. Missing ids are skipped.
print(', '.join(str(pid) for pid in cohort_files['person_id'].dropna().unique()))

In [None]:
# Per-table OMOP exports for the cohort (presumably each queried with the person_id filter).
person, condition_occurrence, drug_exposure, concept = map(
    pd.read_csv,
    ['person.csv', 'condition_occurrence.csv', 'drug_exposure.csv', 'concept.csv'],
)

In [None]:
# Distinct location_ids of the cohort, comma-separated (presumably to paste into the
# `in (...)` filter of the location query). Missing ids are dropped. Each id is
# printed via int(), which avoids "12.0" strings when NaNs force a float dtype and
# avoids NumPy 2's `np.int64(12)` repr.
# NOTE(review): assumes location_id values are integral.
print(', '.join(str(int(loc_id)) for loc_id in person['location_id'].dropna().unique()))

In [None]:
# Location table for the cohort's location_ids.
location = pd.read_csv(filepath_or_buffer='location.csv')

In [None]:
# concept_id -> concept_name; a later row wins if a concept_id repeats.
concept_lookup = dict(zip(concept['concept_id'], concept['concept_name']))

In [None]:
def make_labels(df, lookup=None):
    """Return a copy of ``df`` with OMOP concept ids replaced by readable labels.

    Every column whose name contains 'concept_id' gets a companion column (its name
    with '_id' replaced by '_label') holding the ids mapped through ``lookup``.
    Three kinds of column are then dropped: the concept id columns, any column whose
    name contains 'source', and any column with no non-null values.

    Args:
        df: table to label. It is not modified.
        lookup: mapping of concept_id -> concept_name. Defaults to the notebook-level
            ``concept_lookup``, resolved at call time.

    Returns:
        pandas.DataFrame: the labelled copy. Label columns are appended at the end,
        in the order their id columns appeared.
    """
    if lookup is None:
        lookup = concept_lookup
    # Work on a copy. Assigning into `df` itself would add label columns to the
    # caller's frame (e.g. `person`) until the first drop rebinds the name.
    labelled = df.copy()
    for col in df.columns:
        is_concept_id = 'concept_id' in col
        if is_concept_id:
            labelled[col.replace('_id', '_label')] = labelled[col].map(lookup)
        if is_concept_id or 'source' in col or labelled[col].isna().all():
            labelled = labelled.drop(columns=col)
    return labelled

In [None]:
# Human-readable versions of each table (concept ids -> names, source/empty columns removed).
person_labelled, condition_occurrence_labelled, drug_exposure_labelled, location_labelled = (
    make_labels(table)
    for table in (person, condition_occurrence, drug_exposure, location)
)

In [None]:
# Exposures whose drug label contains "cyclo" (case-insensitive; rows with no label excluded).
# NOTE(review): the substring also matches other drug names containing "cyclo".
drug_exposure_labelled.loc[lambda frame: frame['drug_concept_label'].str.contains('cyclo', case=False, na=False)]

In [None]:
# Preview: first five labelled person rows.
person_labelled.head(n=5)

In [None]:
# Preview: first five labelled location rows.
location_labelled.head(n=5)

In [None]:
# Preview: first five labelled condition_occurrence rows.
condition_occurrence_labelled.head(n=5)

In [None]:
# Preview: first five labelled drug_exposure rows.
drug_exposure_labelled.head(n=5)

In [None]:
def make_chunks(row, long=False):
    """Turn one person's drug trajectory into an ordered list of regimen labels.

    Drugs that share an offset are merged into one label: their distinct names,
    sorted and joined with '_'. Labels are ordered by ascending offset.

    Args:
        row: object with parallel sequences ``full_offset`` (one offset per
            exposure) and ``full_drugs`` (one drug name per exposure), e.g. a row
            of the per-person trajectory frame.
        long: if True use full drug names, otherwise only their first 4 characters.

    Returns:
        list[str]: one label per distinct offset, in ascending offset order.
    """
    regimens = {}
    # Pair each drug with its own offset. Popping drugs positionally against
    # np.unique's sorted counts silently mislabels regimens whenever full_drugs
    # is not already in offset order.
    for offset, drug in zip(row.full_offset, row.full_drugs):
        regimens.setdefault(offset, set()).add(drug if long else drug[:4])
    return ['_'.join(sorted(regimens[offset])) for offset in sorted(regimens)]

In [None]:
# Draw one Sankey diagram of post-index chemotherapy regimen transitions per patient
# group and write each to figures/group{row}.png.
# NOTE(review): `traj`, `chemo_drugs`, `colour_map`, `go` (presumably
# plotly.graph_objects) and the person_id collections g1..g6 are defined outside
# this section. Confirm they are in scope before running.
Path('figures').mkdir(exist_ok=True)
for row, group in enumerate([g1, g2, g3, g4, g5, g6]):
    current_traj_people = group

    # One row per person: chemo drug names, offsets and doses with offset >= 0,
    # ordered by exposure start date (ties broken by drug name).
    full_drug_trajectory = (
        traj[traj.person_id.isin(current_traj_people)
             & (traj.drug_offset >= 0)
             & traj.drug_concept_id.isin(chemo_drugs)]
        .sort_values(['drug_exposure_start_date', 'drug_name'])
        .groupby(['person_id'])
        .agg({'drug_name': lambda x: list(x),
              'drug_offset': lambda x: list(x),
              'value_as_number': lambda x: list(x)})
        .reset_index()
        .rename(columns={'drug_name': 'full_drugs',
                         'drug_offset': 'full_offset',
                         'value_as_number': 'full_dose'})
    )
    # apply() on an empty frame does not yield a Series of paths; nothing to draw anyway.
    if full_drug_trajectory.empty:
        continue

    # path = regimen labels, one per distinct offset (see make_chunks).
    full_drug_trajectory['path'] = full_drug_trajectory.apply(make_chunks, axis=1)
    path_lengths = full_drug_trajectory.path.apply(len)
    # Nothing to draw unless someone moves between at least two regimens.
    if not (path_lengths > 1).any():
        continue

    # Distinct labels per step. They are sorted so the node order, and therefore the
    # figure layout, is reproducible instead of following set iteration order.
    depth = path_lengths.max()
    cols = [sorted({p[level] for p in full_drug_trajectory.path if len(p) > level and p[level]})
            for level in range(depth)]
    # Nodes are keyed by (level, label) rather than the string f'{level}{label}'.
    # That rules out key collisions, and the label never has to be recovered with
    # strip('0123456789'), which also removes digits that belong to the label itself.
    nodes = [(level, label) for level, labels in enumerate(cols) for label in labels]
    node_numbers = {node: idx for idx, node in enumerate(nodes)}
    node_colours = {node: colour_map[node[1]] for node in nodes}
    rev_node_numbers = dict(enumerate(nodes))

    # One link per consecutive pair of regimens in each person's path.
    source, target = [], []
    for path in full_drug_trajectory.path:
        for level, (src_label, tgt_label) in enumerate(zip(path, path[1:])):
            source.append(node_numbers[(level, src_label)])
            target.append(node_numbers[(level + 1, tgt_label)])

    source_colours = [node_colours[rev_node_numbers[s]] for s in source]
    targ_colours = [node_colours[rev_node_numbers[t]] for t in target]

    # Collapse identical (source, target) links into one weighted link.
    transitions, transition_count = np.unique([f'{s}, {t}' for s, t in zip(source, target)],
                                              return_counts=True)
    trans_source, trans_target = list(zip(*[t.split(',') for t in transitions]))

    # Plot inside the loop so every group gets its own figure, named after its own index.
    sankey = go.Sankey(node=dict(pad=15,
                                 thickness=20,
                                 line=dict(color='black', width=0.5),
                                 # node_colours is in node-index order, as plotly expects
                                 color=list(node_colours.values())),
                       link=dict(source=[int(i) for i in trans_source],
                                 target=[int(i) for i in trans_target],
                                 value=transition_count))
    traces = [sankey]
    layout = go.Layout(showlegend=True, plot_bgcolor='#FFFFFF')
    fig = go.Figure(data=traces, layout=layout)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    reg_name = f'group{row}'
    fig.update_layout(title_text=reg_name,
                      font_size=20,
                      autosize=False,
                      width=2500,
                      height=600)
    fig.write_image(f"figures/{reg_name.replace(' ', '_')}.png")

In [None]:
# Draw and save a Sankey diagram from the trajectory variables left at module level
# by the loop above (trans_source, trans_target, transition_count, node_colours, row).
# NOTE(review): this plots only the last group that loop processed. `row` is the
# loop's final index, which is not that group's index if trailing groups were
# skipped by its `continue`.
s = [int(i) for i in trans_source]
t = [int(i) for i in trans_target]
v = transition_count

sankey = go.Sankey(node=dict(pad=15,
                             thickness=20,
                             line=dict(color='black', width=0.5),
                             color=list(node_colours.values())),
                   # one source index per link; a scalar here sends every link from that single node
                   link=dict(source=s,
                             target=t,
                             value=v))

traces = [sankey]
layout = go.Layout(showlegend=True,
                   plot_bgcolor='#FFFFFF')
fig = go.Figure(data=traces,
                layout=layout)
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
reg_name = f'group{row}'
fig.update_layout(title_text=reg_name,
                  font_size=20,
                  autosize=False,
                  width=2500,
                  height=600)
# Make sure the output directory exists before writing the image.
Path('figures').mkdir(exist_ok=True)
fig.write_image(f"figures/{reg_name.replace(' ', '_')}.png")