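# Analyse the SQLite database

Inspect `static_table` in the SQLite database and save pie charts of the demographic make-up of the included cohort.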

In [2]:
from pathlib import Path
import os, sys
# Put the node-specific virtual environment's site-packages first on the import path.
# BB_CPU presumably names the compute node's CPU type (e.g. 'icelake' in the output below);
# one virtual env is kept per node type, so the directory name is derived from it.
node_type = os.environ.get('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
py_tag = 'python{}.{}'.format(*sys.version_info[:2])
venv_site_pkgs = Path(venv_dir).joinpath('lib', py_tag, 'site-packages')

if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Show the working directory: relative paths used later in this notebook (e.g. the
# `figs/` output directory, and presumably the Hydra config_path) resolve against it.
!pwd

# Automatically reload edited project modules (e.g. FastEHR) before each cell executes.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/1_build_database
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from hydra import compose, initialize
from omegaconf import OmegaConf
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from FastEHR.database.collector import SQLiteDataCollector

In [5]:
# Compose the `config_CompetingRisk11M` Hydra configuration (no overrides) and display its
# `data` section, which supplies the database path used by the collector below.
hydra_init_kwargs = dict(
    version_base=None,
    config_path="../../modelling/SurvivEHR/confs",
    job_name="testing_notebook",
)
with initialize(**hydra_init_kwargs):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=[])

data_cfg_yaml = OmegaConf.to_yaml(cfg.data)
print(data_cfg_yaml)

batch_size: 64
unk_freq_threshold: 0.0
min_workers: 12
global_diagnoses: false
repeating_events: true
path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db
path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
subsample_training: null



## Initialise the database connector

In [6]:
# Collector for the SQLite database named in the Hydra config. Cells below open and close
# the connection explicitly with connect()/disconnect().
db_path = cfg.data.path_to_db
collector = SQLiteDataCollector(db_path=db_path)

## Reminder of what the ``static_table`` looks like

In [7]:
# Print up to 10 rows of static_table (male patients with IMD = 1) as a reminder of its layout.
# NOTE(review): these are patient-level records (PATIENT_ID, dates, region); clear this
# cell's output before sharing or committing the notebook.
collector.connect()
try:
    # Standard SQL "=" (SQLite's "==" is a non-portable alias). IMD is declared INTEGER,
    # so compare it with the integer 1 rather than the string '1'.
    collector.cursor.execute("""SELECT * FROM static_table WHERE SEX = 'M' AND IMD = 1 LIMIT 10""")
    results = collector.cursor.fetchall()
finally:
    # Always release the connection, even if the query fails.
    collector.disconnect()

for result in results:
    print(result)

(20960, 2595705420960, 'WHITE', '1994-07-15', 'M', 'E', 1, 'North West', '2017-09-12', '2017-09-12', '2020-02-13')
(20960, 2595768620960, 'WHITE', '1961-07-15', 'M', 'E', 1, 'North West', '2005-01-01', '2005-01-01', '2022-02-22')
(20960, 2595914820960, 'WHITE', '1975-07-15', 'M', 'E', 1, 'North West', '2018-01-31', '2018-01-31', '2019-09-09')
(20960, 2595926220960, 'WHITE', '1937-07-15', 'M', 'E', 1, 'North West', '2005-01-01', '2005-01-01', '2009-08-25')
(20960, 2596034820960, 'WHITE', '1987-07-15', 'M', 'E', 1, 'North West', '2016-02-24', '2016-02-24', '2019-01-30')
(20960, 2596039620960, 'WHITE', '1945-07-15', 'M', 'E', 1, 'North West', '2005-01-01', '2005-01-01', '2018-09-28')
(20960, 2596131020960, 'WHITE', '1980-07-15', 'M', 'E', 1, 'North West', '2012-08-16', '2012-08-16', '2022-02-22')
(20960, 2596183720960, 'WHITE', '1988-07-15', 'M', 'E', 1, 'North West', '2005-01-01', '2005-01-01', '2022-02-22')
(21011, 2723778021011, 'MISSING', '1971-07-15', 'M', 'E', 1, 'London', '2005-01-

In [8]:
# List static_table's columns. PRAGMA table_info returns one row per column:
# (cid, name, declared type, notnull, default value, primary-key position).
collector.connect()
try:
    collector.cursor.execute("""PRAGMA table_info(static_table);""")
    results = collector.cursor.fetchall()
finally:
    # Always release the connection, even if the query fails.
    collector.disconnect()

for result in results:
    print(result)

(0, 'PRACTICE_ID', 'INTEGER', 0, None, 0)
(1, 'PATIENT_ID', 'INTEGER', 0, None, 0)
(2, 'ETHNICITY', 'TEXT', 0, None, 0)
(3, 'YEAR_OF_BIRTH', 'TEXT', 0, None, 0)
(4, 'SEX', 'TEXT', 0, None, 0)
(5, 'COUNTRY', 'TEXT', 0, None, 0)
(6, 'IMD', 'INTEGER', 0, None, 0)
(7, 'HEALTH_AUTH', 'TEXT', 0, None, 0)
(8, 'INDEX_DATE', 'TEXT', 0, None, 0)
(9, 'START_DATE', 'TEXT', 0, None, 0)
(10, 'END_DATE', 'TEXT', 0, None, 0)


## Make figures of the demographics

Each chart is saved as a PNG under `figs/`, relative to the working directory.

In [28]:
def _format_pie_labels(labels):
    """Return display-ready legend labels for a pie chart.

    - Numeric codes (e.g. IMD) are shown as ints. A NULL from SQLite makes pandas store
      the column as float64, which would otherwise render as "3.0".
    - ALL-CAPS strings (e.g. "WHITE") are passed through ``str.capitalize`` ("White").
    - Missing entries (None/NaN) become "Missing".

    Args:
        labels: Series of raw category values.

    Returns:
        An object-dtype Series aligned with ``labels``.
    """
    is_missing = labels.isna()
    if pd.api.types.is_numeric_dtype(labels):
        # Any integer/float dtype; placeholder 0 for NaN is overwritten with "Missing" below.
        formatted = labels.fillna(0).astype(int).astype(object)
    else:
        formatted = labels.astype(object).map(
            lambda s: s.capitalize() if isinstance(s, str) and s.isupper() else s
        )
    return formatted.mask(is_missing, "Missing")


def make_pie_chart(df, col_lbl, sort_by="values", out_dir="figs"):
    """Save a pie chart of category counts to ``{out_dir}/pie_chart_{col_lbl}.png``.

    Args:
        df: DataFrame with a ``labels`` column (category) and a ``values`` column (count).
            The caller's frame is not modified.
        col_lbl: Human-readable variable name, used as the legend title and in the filename.
        sort_by: Column that orders wedges and legend entries. "values" sorts descending
            (largest wedge first); any other column sorts ascending. "Other" and "Missing"
            are always placed last, in that order.
        out_dir: Directory for the PNG. It is created if it does not exist.
    """
    df = df.assign(labels=_format_pie_labels(df["labels"]))

    # Legend order: ordinary categories first, then "Other", then "Missing".
    priority = {"Other": 1, "Missing": 2}
    df = (
        df.assign(_p=df["labels"].map(priority).fillna(0))
          .sort_values(["_p", sort_by], ascending=[True, sort_by != "values"], ignore_index=True)
          .drop(columns="_p")
    )

    fig, ax = plt.subplots(figsize=(3.5 / 0.78, 3.5), dpi=300)
    fig.subplots_adjust(right=0.78)  # reserve the right 22% of the canvas for the legend

    wedges, _texts, _autotexts = ax.pie(
        df["values"],
        labels=None,                                           # categories go in the legend instead
        autopct=lambda pct: "" if pct < 1 else f"{pct:.0f}%",  # no on-slice text for wedges < 1%
        pctdistance=1.2,                                       # > 1 places the percentages just outside the pie
        startangle=90,                                         # first wedge starts at 12 o'clock ...
        counterclock=False,                                    # ... and wedges run clockwise
        wedgeprops=dict(linewidth=1, edgecolor="white"),       # thin white borders between wedges
    )

    # Mark in the legend the wedges whose on-slice percentage was suppressed above.
    pct = df["values"] / df["values"].sum() * 100
    legend_labels = [f"{lab} (<1%)" if p < 1 else lab for lab, p in zip(df["labels"], pct)]

    ax.legend(
        wedges,
        legend_labels,
        title=col_lbl,
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        frameon=False,
    )
    ax.set(aspect="equal")
    sns.despine(left=True, bottom=True)

    os.makedirs(out_dir, exist_ok=True)
    # Save this figure explicitly rather than pyplot's "current" figure; bbox_inches="tight"
    # expands the saved canvas to include the legend placed outside the axes.
    fig.savefig(os.path.join(out_dir, f"pie_chart_{col_lbl}.png"), bbox_inches="tight")
    plt.close(fig)  # don't accumulate open figures when called in a loop

### Categorical columns of the static table

In [29]:
# Categorical static_table columns to chart, their legend titles, and the column that orders each legend.
column_names = ["ETHNICITY", "SEX", "COUNTRY", "IMD", "HEALTH_AUTH"]
column_labels = ["Ethnicity", "Sex", "Country", "IMD", "Health authority"]
column_sort_by = ["values", "values", "values", "labels", "values"]

# Re-apply the inclusion conditions used when the pre-train dataset was created,
# so the charts describe the same cohort.
practice_inclusion_conditions = ["COUNTRY = 'E'"]
where_sql = (
    f"WHERE {' AND '.join(practice_inclusion_conditions)}"
    if practice_inclusion_conditions else ""
)

collector.connect()
try:
    # strict=True: fail loudly if the three configuration lists drift out of sync.
    for col_name, col_lbl, col_sort_by in zip(column_names, column_labels, column_sort_by, strict=True):
        print(col_lbl)

        # col_name comes from the hard-coded list above, so interpolating it into SQL is safe.
        query = f"""
        SELECT {col_name},
               COUNT(*) AS freq
        FROM   static_table
        {where_sql}
        GROUP  BY {col_name}
        ORDER  BY freq DESC;
        """
        collector.cursor.execute(query)
        results = collector.cursor.fetchall()

        df = pd.DataFrame(results, columns=["labels", "values"])
        make_pie_chart(df, col_lbl, sort_by=col_sort_by)

        # Number of patients meeting the inclusion criteria. Every GROUP BY partitions the
        # same filtered rows (NULLs form their own group), so this is identical per column.
        total_patients = sum(freq for _, freq in results)
finally:
    # Always release the connection, even if a query or a plot fails.
    collector.disconnect()

Ethnicity
Sex
Country
IMD
Health authority


### Continuous year of birth column

In [30]:
# Observed birth-year range, used to label the outer birth-period bins below. Apply the same
# inclusion filter (where_sql) as the birth-period query so the labels match the charted cohort.
query = f"""
SELECT
    MIN(YEAR_OF_BIRTH) AS smallest_value,
    MAX(YEAR_OF_BIRTH) AS largest_value
FROM   static_table
{where_sql};
"""
collector.connect()
try:
    collector.cursor.execute(query)
    results = collector.cursor.fetchall()
finally:
    # Always release the connection, even if the query fails.
    collector.disconnect()

smallest_yob, largest_yob = results[0]
if smallest_yob is None or largest_yob is None:
    raise ValueError("static_table has no YEAR_OF_BIRTH values for the selected cohort.")

# YEAR_OF_BIRTH is TEXT in 'YYYY-MM-DD' form, so the lexicographic MIN/MAX is also chronological.
# Keep the year as an int so it can be written into the integer bin columns later.
min_yob = int(smallest_yob[:4])
max_yob = int(largest_yob[:4])
print(min_yob)
print(max_yob)

1890
2021


In [27]:
# Bin patients' birth years into fixed-width periods and chart their distribution.
chunk_size = 25  # width of each birth period, in years

# Integer division maps a birth year to the first year of its period (e.g. 1937 -> 1925).
query = f"""
    SELECT (CAST(strftime('%Y', YEAR_OF_BIRTH) AS INT) / {chunk_size}) * {chunk_size} AS chunk_start,
           COUNT(*) AS freq
    FROM   static_table
    {where_sql}
    GROUP BY chunk_start
    ORDER BY chunk_start;
"""

collector.connect()
try:
    collector.cursor.execute(query)
    results = collector.cursor.fetchall()
finally:
    # Always release the connection, even if the query fails.
    collector.disconnect()

df = pd.DataFrame(results, columns=["chunk_start", "values"])

# A NULL or unparseable YEAR_OF_BIRTH gives a NULL chunk_start, which SQLite sorts first.
# Set those patients aside so they are not merged into the earliest period below.
missing_yob = df["chunk_start"].isna()
n_missing_yob = int(df.loc[missing_yob, "values"].sum())
df = df.loc[~missing_yob].astype({"chunk_start": int}).reset_index(drop=True)
df["chunk_end"] = df["chunk_start"] + chunk_size - 1

# Fold the earliest period into the next one. The earliest period is only partly covered
# by the data (it starts before min_yob), presumably leaving it too small to show on its own.
if len(df) > 1:
    df.loc[1, "values"] += df.loc[0, "values"]
    df = df.iloc[1:].reset_index(drop=True)

# Clamp the outer bins to the observed birth-year range so labels don't extend past the data.
# int() accepts min_yob/max_yob whether they were stored as "1890" or 1890.
df.loc[0, "chunk_start"] = int(min_yob)
df.loc[len(df) - 1, "chunk_end"] = int(max_yob)

df["labels"] = df["chunk_start"].astype(str) + "-" + df["chunk_end"].astype(str)

if n_missing_yob:
    # A None label is shown as "Missing" by make_pie_chart and placed last in the legend.
    df = pd.concat([df, pd.DataFrame([{"labels": None, "values": n_missing_yob}])], ignore_index=True)

make_pie_chart(df, "Birth period", sort_by="labels")