In [None]:
import pandas as pd
import numpy as np
from google.cloud import bigquery
import pickle
from datetime import date

import warnings
warnings.filterwarnings("ignore")

In [None]:
def run_query(query, project_id='som-nero-phi-sywang-starr'):
    """Run a SQL query on BigQuery and return the result as a DataFrame.

    Parameters
    ----------
    query : str
        SQL text to execute.
    project_id : str, optional
        Project passed both to the BigQuery client and to the query job.
        Defaults to the project this notebook was written against, so
        existing ``run_query(query)`` calls behave as before.

    Returns
    -------
    The object returned by ``to_dataframe()`` on the query job.
    """
    # A new client is built on every call, bound to the requested project.
    client = bigquery.Client(project=project_id)

    # Submit the query and materialise the result.
    return client.query(query, project=project_id).to_dataframe()

## Load Initial Cohort

In [None]:
# Read the cohort file, forcing MRN to pandas' string dtype instead of letting
# read_csv infer a type for it.
cohort = pd.read_csv(
    "processed_data/cohort.csv",
    dtype={"MRN": "string"},
)

In [None]:
# Preview the first five cohort rows.
cohort.head(n=5)

In [None]:
# Report the number of rows in the cohort table.
cohort_size = len(cohort)
print(f'There is a total of {cohort_size} patients in this cohort.')

## Load and Prepare Diagnosis Codes

In [None]:
diag_codes = None

# NOTE: pickle.load can execute arbitrary code stored in the file, so
# diag_codes.pkl must come from a trusted source.
with open("diag_codes.pkl", mode="rb") as pkl_file:
    diag_codes = pickle.load(pkl_file)

# Every entry is split on "_" and its second field is kept as an int.
diag_codes = [int(entry.split("_")[1]) for entry in diag_codes]


In [None]:
# Report how many diagnosis codes were loaded.
code_count = len(diag_codes)
print(f"# of diag codes: {code_count}.")

## Get Diagnosis Data

In [None]:
# Pull every condition occurrence together with its MRN crosswalk row.
# The LEFT JOIN keeps occurrences whose person_id has no crosswalk match;
# those rows come back with null crosswalk columns.
query = """
SELECT 
    co.condition_occurrence_id,
    co.person_id,
    co.condition_concept_id,
    co.condition_start_date,
    co.condition_start_datetime,
    mc.*
FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.condition_occurrence` AS co
LEFT JOIN `som-nero-phi-sywang-starr.gps_stanford_clinic.mrn_crosswalk` AS mc
ON co.person_id = mc.person_id;

"""
cond_occurance_dat = run_query(query)

# mc.* repeats person_id; the duplicate is expected to arrive as 'person_id_1'.
cond_occurance_dat = cond_occurance_dat.drop(columns='person_id_1')

print(f"# of rows: {len(cond_occurance_dat)}")
# nunique() ignores nulls, so occurrences without a crosswalk match (null MRN)
# are not counted as an extra patient the way len(unique()) would count them.
print(f"# of unique pats: {cond_occurance_dat['MRN'].nunique()}")

In [None]:
# Keep only occurrences whose concept id is one of the requested diagnosis
# codes, and only the two columns used downstream. Rows and columns are
# selected in a single .loc call and the result is copied last, so the frame
# is independent of cond_occurance_dat and later column assignments on it do
# not raise SettingWithCopyWarning.
code_mask = cond_occurance_dat['condition_concept_id'].isin(diag_codes)
selected_diag_codes = cond_occurance_dat.loc[code_mask, ['MRN', 'condition_concept_id']].copy()

In [None]:
# Preview the first five (MRN, condition_concept_id) rows.
selected_diag_codes.head(n=5)

In [None]:
# Pivot to one row per MRN and one "omop_<concept id>" column per observed
# code. The helper column of 1s is added with .assign() on a temporary frame,
# so selected_diag_codes itself is not modified and this cell can be re-run
# safely. pivot_table's default aggregation (mean) of the constant 1 is 1, so
# a cell is 1 when the MRN has at least one occurrence of the code and
# fill_value=0 otherwise.
selected_diag_codes_wide = (
    selected_diag_codes
    .assign(pivot_value=1)
    .pivot_table(
        values='pivot_value',
        index=['MRN'],
        columns='condition_concept_id',
        fill_value=0,
    )
    .add_prefix("omop_")
    .reset_index()
)

In [None]:
# Codes that never occur in the data have no column after the pivot. Add them
# as all-zero columns so every requested code is represented.
# All missing columns are appended in a single concat rather than one
# insertion per code, which avoids fragmenting the DataFrame.
expected_cols = ["omop_" + str(code) for code in diag_codes]
present_cols = set(selected_diag_codes_wide.columns)
# dict.fromkeys de-duplicates while keeping the order of diag_codes.
missing_cols = [col for col in dict.fromkeys(expected_cols) if col not in present_cols]

if missing_cols:
    zero_block = pd.DataFrame(
        0,
        index=selected_diag_codes_wide.index,
        # Reuse the existing columns-axis name so concat keeps it.
        columns=pd.Index(missing_cols, name=selected_diag_codes_wide.columns.name),
    )
    selected_diag_codes_wide = pd.concat([selected_diag_codes_wide, zero_block], axis=1)

# Every requested code must now have exactly one column (the extra one is MRN).
assert len(selected_diag_codes_wide.columns) - 1 == len(diag_codes)

In [None]:
# Preview the first five rows of the wide indicator table.
selected_diag_codes_wide.head(n=5)

## Align with Cohort and Save File

In [None]:
# Preview the wide table again before it is aligned to the cohort.
selected_diag_codes_wide.head(n=5)

In [None]:
# Cohort MRNs that have no row in the wide diagnosis table.
cohort_mrns = cohort['MRN'].values
diag_mrns = selected_diag_codes_wide['MRN'].values
missing_mrns = np.setdiff1d(cohort_mrns, diag_mrns)

In [None]:
# One all-zero row per missing MRN, with the same columns as the wide table.
new_rows = pd.DataFrame(
    data=0,
    index=range(len(missing_mrns)),
    columns=selected_diag_codes_wide.columns,
)
# Overwrite the placeholder zeros in the MRN column with the actual MRNs.
new_rows["MRN"] = missing_mrns

In [None]:
# Append the all-zero rows to the wide table.
# Only rows whose MRN is not already present are added, so re-running this
# cell does not duplicate the appended rows (which would break the later
# row-count assertion against the cohort).
rows_to_add = new_rows[~new_rows["MRN"].isin(selected_diag_codes_wide["MRN"])]
selected_diag_codes_wide = pd.concat([selected_diag_codes_wide, rows_to_add], ignore_index=True)

In [None]:
# Display the wide table after the zero rows have been appended.
selected_diag_codes_wide

In [None]:
# Restrict the wide table to rows whose MRN appears in the cohort.
in_cohort = selected_diag_codes_wide["MRN"].isin(cohort["MRN"])
filtered_diag_codes = selected_diag_codes_wide.loc[in_cohort].copy()

In [None]:
# Sanity checks: one row per cohort row, and one column per diagnosis code
# besides the MRN column.
n_rows, n_cols = filtered_diag_codes.shape
assert n_rows == len(cohort)
assert n_cols - 1 == len(diag_codes)

In [None]:
# Report the number of rows in the final table.
final_size = len(filtered_diag_codes)
print(f'There is a total of {final_size} patients in this cohort.')

In [None]:
# Write the final table to disk without the row index.
filtered_diag_codes.to_csv(
    path_or_buf='processed_data/diag_data.csv',
    index=False,
)