In [None]:
import pandas as pd
import numpy as np
from google.cloud import bigquery
import pickle
from datetime import date

import warnings
warnings.filterwarnings("ignore")

In [None]:
def run_query(query, project_id='som-nero-phi-sywang-starr'):
    """Run a SQL query on BigQuery and return the result as a DataFrame.

    Parameters
    ----------
    query : str
        SQL text handed to ``client.query``.
    project_id : str, optional
        Project used both to build the client and to run the query job.
        Defaults to the project this notebook was written against, so
        existing ``run_query(query)`` calls behave exactly as before.

    Returns
    -------
    Whatever ``to_dataframe()`` returns for the finished query job.
    """
    # A fresh client is created on every call; no credentials are passed here.
    client = bigquery.Client(project=project_id)

    # Submit the query and convert the job result to a DataFrame.
    return client.query(query, project=project_id).to_dataframe()

## Load Initial Cohort

In [None]:
# Read the cohort table; MRN is forced to pandas' string dtype rather than
# being type-inferred from the file.
cohort = pd.read_csv(
    "processed_data/cohort.csv",
    dtype={'MRN': 'string'},
)

In [None]:
# Preview the first five cohort rows.
cohort.head(n=5)

In [None]:
# Report the number of rows in the cohort table (one row is treated as one patient).
print(f'There is a total of {len(cohort)} patients in this cohort.')

## Load and Prep Drug Codes

In [None]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# "drugs_codes.pkl" must come from a trusted source.
with open("drugs_codes.pkl", "rb") as f:
    drug_code_labels = pickle.load(f)

# Each label is split on "_" and its second piece is parsed as an integer
# (presumably labels look like "<prefix>_<concept id>" -- confirm against the
# code that wrote the pickle). The raw labels keep their own name so that
# `drug_codes` only ever holds the integer codes.
drug_codes = [int(label.split("_")[1]) for label in drug_code_labels]


In [None]:
# Report how many integer drug codes were parsed from the pickle.
print(f"# of drugs codes: {len(drug_codes)}.")

## Get Drug Data 

In [None]:
# Pull drug_era rows (era id, person, concept, start date) together with every
# column of the MRN crosswalk. The LEFT JOIN keeps drug_era rows whose
# person_id has no crosswalk entry; their crosswalk columns come back null.
query = """
SELECT 
    co.drug_era_id,
    co.person_id,
    co.drug_concept_id,
    co.drug_era_start_date,
    mc.*
FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.drug_era` AS co
LEFT JOIN `som-nero-phi-sywang-starr.gps_stanford_clinic.mrn_crosswalk` AS mc
ON co.person_id = mc.person_id;

"""
drug_dat = run_query(query)

# `mc.*` carries the crosswalk's own person_id, which shows up in the result
# as 'person_id_1' (presumably renamed to avoid the name clash); drop it.
drug_dat = drug_dat.drop(columns=['person_id_1'])

print(f"# of rows: {len(drug_dat)}")
# nunique() ignores missing MRNs, so rows without a crosswalk match are not
# counted as one extra "patient" (unique() would include the null value).
print(f"# of unique pats: {drug_dat['MRN'].nunique()}")

In [None]:
# Keep only the rows whose concept id is one of the drug codes of interest,
# reduced to the patient identifier and the concept id.
selected_drug_codes = drug_dat.loc[
    drug_dat['drug_concept_id'].isin(drug_codes),
    ['MRN', 'drug_concept_id'],
].copy()

In [None]:
# Preview the first five (MRN, drug_concept_id) rows.
selected_drug_codes.head(n=5)

In [None]:
# Constant marker so the pivot has a value column to aggregate.
selected_drug_codes['pivot_value'] = 1

# One row per MRN and one "omop_<concept id>" column per drug present in the
# filtered data; MRN/drug pairs that never occur are filled with 0. MRN is
# turned back into a regular column at the end.
selected_drug_codes_wide = (
    selected_drug_codes
    .pivot_table(
        index=['MRN'],
        columns='drug_concept_id',
        values='pivot_value',
        fill_value=0,
    )
    .add_prefix("omop_")
    .reset_index()
)

In [None]:
# Column name expected for every requested drug code.
expected_cols = [f"omop_{code}" for code in drug_codes]

# Expected columns the pivot did not produce (codes with no matching rows),
# de-duplicated and kept in drug_codes order.
missing_cols = [
    col for col in dict.fromkeys(expected_cols)
    if col not in selected_drug_codes_wide.columns
]

# Add all missing all-zero columns in a single concat instead of inserting
# them one at a time, which fragments the frame when many codes are absent.
if missing_cols:
    zero_block = pd.DataFrame(
        0,
        index=selected_drug_codes_wide.index,
        columns=missing_cols,
    )
    selected_drug_codes_wide = pd.concat([selected_drug_codes_wide, zero_block], axis=1)

# Compare column names rather than counts: a count check against
# len(drug_codes) fails spuriously when drug_codes contains a repeated code.
actual_cols = set(selected_drug_codes_wide.columns) - {"MRN"}
assert actual_cols == set(expected_cols), "drug columns do not match the requested drug codes"

In [None]:
# Preview the first five rows of the wide drug table.
selected_drug_codes_wide.head(n=5)

In [None]:
# Cohort MRNs that have no row in the wide drug table
# (np.setdiff1d returns them as sorted, unique values).
missing_mrns = np.setdiff1d(
    cohort['MRN'].to_numpy(),
    selected_drug_codes_wide['MRN'].to_numpy(),
)

In [None]:
# All-zero rows, one per missing MRN, with the same columns as the wide table;
# the MRN column is then overwritten with the missing MRNs themselves.
new_rows = pd.DataFrame(
    0,
    index=range(len(missing_mrns)),
    columns=selected_drug_codes_wide.columns,
).assign(MRN=missing_mrns)

In [None]:
# Stack the zero-filled rows under the pivoted table and renumber the index.
# Re-running only this cell would append the same rows a second time.
selected_drug_codes_wide = pd.concat(
    objs=[selected_drug_codes_wide, new_rows],
    axis=0,
    ignore_index=True,
)

## Save File

In [None]:
# Every cohort MRN must have a row in the wide table.
uncovered_mrns = np.setdiff1d(cohort['MRN'].values, selected_drug_codes_wide['MRN'].values)
assert len(uncovered_mrns) == 0, f"{len(uncovered_mrns)} cohort MRNs are missing from the drug table"

# Each MRN must appear once; this also catches the zero-row append cell
# having been executed more than once.
assert selected_drug_codes_wide['MRN'].is_unique, "duplicate MRN rows in the drug table"

# Columns must be exactly MRN plus one omop_<code> column per distinct drug
# code. Comparing names instead of counts keeps the check valid even when
# drug_codes contains a repeated code.
expected_cols = {f"omop_{code}" for code in drug_codes}
actual_cols = set(selected_drug_codes_wide.columns) - {"MRN"}
assert actual_cols == expected_cols, "drug columns do not match the requested drug codes"

In [None]:
# Report the number of rows in the wide drug table.
# NOTE(review): the table is built from every MRN in drug_dat that has a
# selected drug and is never restricted to cohort MRNs, so this count can be
# larger than the cohort size printed earlier -- confirm this is intended.
print(f'There is a total of {len(selected_drug_codes_wide)} patients in this cohort.')

In [None]:
# Write the wide drug table to disk, leaving out the positional index.
selected_drug_codes_wide.to_csv(
    'processed_data/drugs_data.csv',
    index=False,
)