---

See `2__preprocess_CSVs.ipynb` for why the CSV files are suffixed with `_sub_col*`.

---

(Only a subset of tables and columns is loaded — edit the dictionaries below if you want everything.)

---

---
---

Before running the cells below, make sure the following files are in the same folder as this notebook:

- datadictionary.xlsx
- manifest.csv

Both files are included in the TriNetX data zip file.


# Generate `create.sql`

In [1]:
# Table name -> columns to load from that table's CSV (the column list given to \COPY).
# The keys also select which tables get a CREATE TABLE statement and a \COPY command;
# edit this mapping to load more tables or columns.
specific_columns = dict(
    patient_demographic=["patient_id", "sex", "race", "ethnicity", "year_of_birth", "patient_regional_location"],
    encounter=["encounter_id", "patient_id", "start_date", "type"],
    lab_result=["patient_id", "encounter_id", "code", "date", "lab_result_num_val", "lab_result_text_val", "units_of_measure"],
    diagnosis=["patient_id", "encounter_id", "code", "date"],
    procedure=["patient_id", "encounter_id", "code", "date"],
    medication=["patient_id", "encounter_id", "code", "start_date"],
    vital_sign=["patient_id", "encounter_id", "code", "date", "value", "text_value", "units_of_measure", "code_system"],
)

In [None]:
import re

import pandas as pd

# Reading .xlsx needs the openpyxl package (e.g. `conda install openpyxl -y`).
# Load the TriNetX data dictionary: one row per (table, column) pair.
data_dict_path = 'datadictionary.xlsx'
data_dict = pd.read_excel(data_dict_path)

In [4]:
# Peek at the first rows to confirm the expected column headers.
data_dict.head(n=5)

Unnamed: 0,Table Name,Source File Name,Col #,Data Element,Data Type,Length,Nullable,Key,enumValues,Sample Data,Description
0,Chemotherapy Lines of Treatment,chemo_lines.csv,1,patient_id,VARCHAR,200,No,Foreign Key,-,123456789,The unique ID for the patient (de-identified).
1,Chemotherapy Lines of Treatment,chemo_lines.csv,2,start_date,DATETIME (YYYYMMDD),8,No,-,-,20150314,The date the chemotherapy line of treatment wa...
2,Chemotherapy Lines of Treatment,chemo_lines.csv,3,line,BIGINT,2,No,-,12345,1,The sequential order of chemotherapy regimens....
3,Chemotherapy Lines of Treatment,chemo_lines.csv,4,derived_by_TriNetX,BOOLEAN,1,No,-,"T,F",T,Flag that indicates whether the chemotherapy l...
4,Chemotherapy Lines of Treatment,chemo_lines.csv,5,source_id,VARCHAR,200,No,-,"HCO,HCO-NLP,TriNetX",HCO,The data source and data type. Data source opt...


In [None]:
# Data-dictionary "Data Type" value -> PostgreSQL column type.
# Types not listed here are created as VARCHAR by generate_sql_statements.
sql_type_mapping = dict([
    ('VARCHAR', 'VARCHAR'),
    ('DATETIME (YYYYMMDD)', 'DATE'),
    ('BIGINT', 'BIGINT'),
    ('BOOLEAN', 'BOOLEAN'),
    ('DECIMAL', 'FLOAT'),
])

def _varchar_length(length):
    """Return ``length`` as a positive ``int``, or ``None`` if it is not usable.

    pandas reads an Excel column that contains blanks as float64, so 200 becomes
    200.0, which is invalid inside ``VARCHAR(...)``. The data dictionary also uses
    '-' as a placeholder in other columns. Both cases are normalized here.
    """
    try:
        value = int(float(length))
    except (TypeError, ValueError, OverflowError):
        # NaN, '-', None, inf, ... -> no usable length
        return None
    return value if value > 0 else None


def generate_sql_statements(data_dict, tables=None, type_mapping=None):
    """Build a ``DROP TABLE`` + ``CREATE TABLE`` script for each selected table.

    Every column the data dictionary lists for a selected table is created.
    Which of those columns actually get loaded is decided separately, by the
    column lists used for load.sql.

    Parameters
    ----------
    data_dict : pandas.DataFrame
        Data dictionary with 'Table Name', 'Data Element', 'Data Type' and
        'Length' columns.
    tables : container of str, optional
        Normalized table names (lower-case, spaces -> underscores) to emit.
        Defaults to the global ``specific_columns`` (its keys).
    type_mapping : dict, optional
        Data-dictionary type -> SQL type. Defaults to the global
        ``sql_type_mapping``. Unknown types are created as VARCHAR.

    Returns
    -------
    dict
        Normalized table name -> SQL script for that table.
    """
    # Globals are read only when the caller does not pass explicit values.
    if tables is None:
        tables = specific_columns
    if type_mapping is None:
        type_mapping = sql_type_mapping

    table_sql = {}
    for raw_table_name, group in data_dict.groupby('Table Name'):
        # e.g. 'Patient Demographic' -> 'patient_demographic'
        table_name = raw_table_name.lower().strip().replace(' ', '_')
        if table_name not in tables:
            continue

        columns = []
        for _, row in group.iterrows():
            col_type = type_mapping.get(row['Data Type'], 'VARCHAR')
            # Only VARCHAR takes a length. An unusable length falls back to plain VARCHAR
            # rather than emitting invalid SQL such as VARCHAR(nan) or VARCHAR(200.0).
            length = _varchar_length(row['Length']) if col_type == 'VARCHAR' else None
            size = f"({length})" if length is not None else ""
            # NOT NULL is intentionally not derived from the 'Nullable' column.
            # Rows with nulls were observed, so every column is left nullable.
            columns.append(f"{row['Data Element']} {col_type}{size}")

        columns_sql = ",\n  ".join(columns)
        table_sql[table_name] = (
            f"DROP TABLE IF EXISTS {table_name};\n"
            f"CREATE TABLE {table_name} (\n"
            f"  {columns_sql}\n"
            ");\n"
        )

    return table_sql

# Generate one DROP/CREATE script per selected table.
sql_statements = generate_sql_statements(data_dict)

# Echo each script, then save them all to create.sql (each followed by a newline).
for statement in sql_statements.values():
    print(statement)

with open('create.sql', 'w') as sql_file:
    sql_file.writelines(statement + '\n' for statement in sql_statements.values())

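For example, the generated SQL for `patient_demographic`. Every column from the data dictionary is created, but only the columns in `specific_columns` are loaded: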
```
DROP TABLE IF EXISTS patient_demographic;
CREATE TABLE patient_demographic (
  patient_id VARCHAR(200),
  sex VARCHAR(50),
  race VARCHAR(180),
  ethnicity VARCHAR(180),
  marital_status VARCHAR(180),
  year_of_birth BIGINT,
  reason_yob_missing VARCHAR(50),
  death_date_source_id VARCHAR(200),
  month_year_death BIGINT,
  patient_regional_location VARCHAR(100),
  source_id VARCHAR(50)
);
```

---
---

# Generate `load.sql`

__Note:__   
  
If you pass the data directory to `psql` under a variable name other than `diabetes_data_dir` (see the `psql` command two lines above `\cd` in `pre`), change `\cd :diabetes_data_dir` in `pre` to match.

In [18]:
# Table name -> preprocessed CSV that feeds it (see 2__preprocess_CSVs.ipynb).
# Must cover every table in `specific_columns`. The validation step also derives
# the original manifest.csv file name from these values.
table_to_file_mapping = dict(
    patient_demographic="patient_sub_col.csv",
    encounter="encounter_sub_col.csv",
    lab_result="lab_result_sub_col.csv",
    diagnosis="diagnosis_sub_col.csv",
    procedure="procedure_sub_col.csv",
    medication="medication_ingredient_sub_col.csv",
    vital_sign="vitals_signs_sub_col.csv",
)

# psql preamble written at the top of load.sql.
# Raw string so the psql meta-command `\cd` is kept literally: in a plain string
# `\c` is an invalid escape sequence (DeprecationWarning / SyntaxWarning in 3.12+).
pre = r"""-----------------------------------------
-- Load data into the diabetes schemas --
-----------------------------------------

-- To run from a terminal:
--  psql "dbname=<DBNAME> user=<USER>" -v diabetes_data_dir=<PATH TO DATA DIR> -f load.sql
-- The script assumes the files are in the diabetes_data_dir
\cd :diabetes_data_dir

-- making sure correct encoding is defined as -utf8- 
SET CLIENT_ENCODING TO 'utf8';"""

 # Write to SQL file
# load.sql = the psql preamble followed by one client-side \COPY per table,
# loading only the columns listed in `specific_columns`, in that order.
copy_template = "\\COPY {table} ({columns}) FROM '{path}' DELIMITER ',' CSV HEADER NULL '';"
with open('load.sql', 'w') as out_file:
    header = pre + '\n'
    out_file.write(header)
    print(header)
    for tbl, wanted_cols in specific_columns.items():
        copy_line = copy_template.format(
            table=tbl,
            columns=", ".join(wanted_cols),
            path=table_to_file_mapping[tbl],
        ) + '\n'
        out_file.write(copy_line)
        print(copy_line)


-----------------------------------------
-- Load data into the diabetes schemas --
-----------------------------------------

-- To run from a terminal:
--  psql "dbname=<DBNAME> user=<USER>" -v diabetes_data_dir=<PATH TO DATA DIR> -f load.sql
-- The script assumes the files are in the diabetes_data_dir
\cd :diabetes_data_dir

-- making sure correct encoding is defined as -utf8- 
SET CLIENT_ENCODING TO 'utf8';

\COPY patient_demographic (patient_id, sex, race, ethnicity, year_of_birth, patient_regional_location) FROM 'patient_sub_col2.csv' DELIMITER ',' CSV HEADER NULL '';

\COPY encounter (encounter_id, patient_id, start_date, type) FROM 'encounter_sub_col2.csv' DELIMITER ',' CSV HEADER NULL '';

\COPY lab_result (patient_id, encounter_id, code, date, lab_result_num_val, lab_result_text_val, units_of_measure) FROM 'lab_result_sub_col2.csv' DELIMITER ',' CSV HEADER NULL '';

\COPY diagnosis (patient_id, encounter_id, code, date) FROM 'diagnosis_sub_col2.csv' DELIMITER ',' CSV HEADER N

  pre = """-----------------------------------------


---

# Generate `validate.sql`

---

In [21]:
# manifest.csv lists each TriNetX file with its column, row and unique-patient counts.
# Its row counts become the expected values in validate.sql.
manifest_path = './manifest.csv'
df = pd.read_csv(manifest_path)
df

Unnamed: 0,file,column_count,row_count,unique_patient_count
0,chemo_lines.csv,5,908938,658258
1,cohort_details.csv,3,1,-
2,dataset_details.csv,4,1,-
3,diagnosis.csv,10,2653978299,6012826
4,encounter.csv,9,1274437790,6108666
5,genomic.csv,6,3670955,13057
6,lab_result.csv,10,5713861360,5676455
7,medication_drug.csv,13,1813961625,4902228
8,medication_ingredient.csv,11,5674491072,5775824
9,oncology_treatment.csv,11,528018,90242


In [62]:
# Build validate.sql: compare each loaded table's row count against the count
# published for its source file in manifest.csv (loaded above as `df`).

# Preprocessed CSVs carry a `_sub_col` / `_sub_colN` suffix (see 2__preprocess_CSVs.ipynb)
# that the manifest's original file names lack, e.g. 'lab_result_sub_col2.csv' -> 'lab_result.csv'.
# Only that suffix is stripped; digits elsewhere in a file name are kept.
SUB_COL_SUFFIX = re.compile(r'_sub_col\d*(?=\.csv$)')


def expected_row_count(manifest, csv_file):
    """Return the manifest ``row_count`` for the original file behind ``csv_file``.

    Parameters
    ----------
    manifest : pandas.DataFrame
        Contents of manifest.csv (uses its 'file' and 'row_count' columns).
    csv_file : str
        Preprocessed CSV name, as used in ``table_to_file_mapping``.

    Raises
    ------
    KeyError
        If the derived original file name is not listed in the manifest.
    """
    source_file = SUB_COL_SUFFIX.sub('', csv_file)
    counts = manifest.loc[manifest['file'] == source_file, 'row_count']
    if counts.empty:
        raise KeyError(f"{source_file!r} (derived from {csv_file!r}) is not listed in manifest.csv")
    return counts.iloc[0]


union_all = ' UNION ALL\n'

expected_selects = []
for table, csv_file in table_to_file_mapping.items():
    table_as = f"'{table}' AS tbl,"
    # Pad the table label so the counts line up in the generated SQL.
    expected_selects.append(
        f"    SELECT {table_as:<34}{expected_row_count(df, csv_file)} AS row_count"
    )

observed_selects = [
    f"    SELECT '{table}' AS tbl, count(*) AS row_count FROM {table}"
    for table in table_to_file_mapping
]

val = (
    "-- Validate the TriNetX tables built correctly by checking against known row counts\n"
    "-- from manifest.csv\n"
    "WITH expected AS\n"
    "(\n"
    + union_all.join(expected_selects)
    + "\n), observed as\n"
    "(\n"
    + union_all.join(observed_selects)
    + """
)
SELECT
    exp.tbl
    , exp.row_count AS expected_count
    , obs.row_count AS observed_count
    , CASE
        WHEN exp.row_count = obs.row_count
        THEN 'PASSED'
        ELSE 'FAILED'
    END AS ROW_COUNT_CHECK
FROM expected exp
INNER JOIN observed obs
  ON exp.tbl = obs.tbl
ORDER BY exp.tbl
;"""
)

# Write to SQL file
with open('validate.sql', 'w') as sql_file:
    sql_file.write(val)

# Show the generated script (portable replacement for `!cat validate.sql`).
print(val)

-- Validate the TriNetX tables built correctly by checking against known row counts
-- from manifest.csv
WITH expected AS
(
    SELECT 'patient_demographic' AS tbl,     6108666 AS row_count UNION ALL
    SELECT 'encounter' AS tbl,               1274437790 AS row_count UNION ALL
    SELECT 'lab_result' AS tbl,              5713861360 AS row_count UNION ALL
    SELECT 'diagnosis' AS tbl,               2653978299 AS row_count UNION ALL
    SELECT 'procedure' AS tbl,               1570070266 AS row_count UNION ALL
    SELECT 'medication' AS tbl,              5674491072 AS row_count UNION ALL
    SELECT 'vital_sign' AS tbl,              2033768504 AS row_count
)
, observed as
(
    SELECT 'patient_demographic' AS tbl, count(*) AS row_count FROM patient_demographic UNION ALL
    SELECT 'encounter' AS tbl, count(*) AS row_count FROM encounter UNION ALL
    SELECT 'lab_result' AS tbl, count(*) AS row_count FROM lab_result UNION ALL
    SELECT 'diagnosis' AS tbl, count(*) AS row_count FROM diag