In [1]:
from copy import deepcopy
from collections import Counter
from pathlib import Path
import numpy as np
import pandas as pd
from src.data import data_io, dataset

In [2]:
def get_last_followup_status(patient_data):
    if patient_data.get("vital_status") == "Dead":
        cancer_status = patient_data.get("person_neoplasm_cancer_status", "NA")
        primary_therapy_outcome = patient_data.get(
            "primary_therapy_outcome_success", "NA"
        )
        return "Dead", cancer_status, primary_therapy_outcome
    followups = list(patient_data.get("follow_ups", {}).values())
    if not followups:
        return (
            patient_data.get("vital_status", "NA"),
            "NA",
            patient_data.get("primary_therapy_outcome_success", "NA"),
        )
    last_followup = None
    last_followup_days = -1
    for followup in followups:
        if followup.get("vital_status") == "Dead":
            # If the patient is dead, this must be the most recent followup
            last_followup = followup
            break
        days = int(followup.get("days_to_last_followup", 1))
        if days > last_followup_days:
            last_followup = followup
            last_followup_days = days
    vital_status = last_followup.get("vital_status", "NA")
    cancer_status = last_followup.get("person_neoplasm_cancer_status", "NA")
    primary_therapy_outcome = last_followup.get(
        "primary_therapy_outcome_success", "NA"
    )
    return vital_status, cancer_status, primary_therapy_outcome


def get_drug_treatment_responses(patient_data):
    responses = set()
    for treatment in patient_data.get("drugs", []):
        response = treatment.get("measure_of_response")
        if response:
            responses.add(response)
    return responses


def get_stage_and_grade(patient_data: dict) -> tuple[str, str]:
    stage = patient_data.get("clinical_stage")
    if not stage:
        stage = patient_data.get("pathologic_stage")
    grade = patient_data.get("neoplasm_histologic_grade")
    return stage, grade

In [3]:
DATA_DIR = Path("../data")
CLINICAL_DIR = DATA_DIR / "interim/clinical"
CLINICAL_DIR.mkdir(parents=True, exist_ok=True)
data_io.unzip_files(DATA_DIR / "raw/TCGA_clinical_bcr_xml.zip", CLINICAL_DIR)
clin_tcga = dataset.parse_clinical_data(CLINICAL_DIR)
TCGA_rnaseq_counts_metadata = pd.read_csv("../data/raw/TCGA_rnaseq_counts_metadata.csv.gz", index_col=0)

tcga_coldata = pd.DataFrame(
    [
        [
            clin_tcga[barcode]["disease_code"],
            *get_stage_and_grade(clin_tcga[barcode]),
            "NA",
        ]
        for barcode in clin_tcga
    ],
    columns=["disease_code", "stage", "grade", "stage_grade"],
    index=list(clin_tcga),
)
tcga_coldata["stage"] = tcga_coldata["stage"].str.replace(
    r"[^IV]", "", regex=True
)
tcga_coldata = tcga_coldata[
    tcga_coldata["stage"].isin(("I", "II", "III", "IV"))
]
tcga_coldata["grade"] = tcga_coldata["grade"].str.replace("High Grade", "G3")
tcga_coldata["grade"] = tcga_coldata["grade"].str.replace(
    r"[^GXx\d]", "", regex=True
)
tcga_coldata = tcga_coldata[
    tcga_coldata["grade"].isin(("G1", "G2", "G3", "G4"))
]
tcga_coldata["stage_grade"] = tcga_coldata["stage"].str.cat(
    tcga_coldata["grade"], sep="."
)
tcga_coldata = tcga_coldata[tcga_coldata.index.isin(set(TCGA_rnaseq_counts_metadata["cases.submitter_id"]))]
clin_tcga = {k: v for k, v in clin_tcga.items() if k in tcga_coldata.index}
len(tcga_coldata)

3702

## Extract platinum (PT) chemotherapy response groups
*E.g. Carboplatin, Cisplatin, and Oxaliplatin*

In [4]:
clin_pt_treated = {
    k: deepcopy(v)
    for k, v in clin_tcga.items()
    if "platin" in str(v.get("drugs", [])).lower()
}
for k, v in clin_pt_treated.items():
    v["drugs"] = [d for d in v["drugs"].values() if "platin" in str(d).lower()]

In [5]:
clin_pt_resistant = {}
clin_pt_sensitive = {}

for k in list(clin_pt_treated):
    patient_data = clin_pt_treated[k]
    stage, grade = get_stage_and_grade(patient_data)
    if not stage or not grade or "GX" in grade:
        continue
    responses = get_drug_treatment_responses(patient_data)
    if responses:
        # Infer PT resistance based on documented responses
        any_complete = any("Complete" in r for r in responses)
        any_partial = any("Partial" in r for r in responses)
        any_stable = any("Stable" in r for r in responses)
        any_progressive = any("Progressive" in r for r in responses)
        if (any_complete or any_partial) and not (any_stable or any_progressive):
            clin_pt_sensitive[k] = patient_data
        elif not (any_complete or any_partial) and (any_stable or any_progressive):
            # Additional requirement for stage IV: "Progressive Disease" response only
            # (stable might actually be a better than expected response in stage IV)
            if "IV" in stage and any_stable:
                continue
            clin_pt_resistant[k] = patient_data
    else:
        # Infer PT resistance based on outcomes after treatment. To reduce risk of
        # introducing bias, stage IV patients are excluded here due to poor prognosis.
        # if "IV" not in stage:
        vital_status, cancer_status, primary_therapy_outcome = get_last_followup_status(
            patient_data
        )
        if (
            vital_status == "Alive"
            and cancer_status == "TUMOR FREE"
            and "Disease" not in primary_therapy_outcome
        ):
            clin_pt_sensitive[k] = patient_data
        elif (
            vital_status == "Dead"
            and cancer_status == "WITH TUMOR"
            and "Remission" not in primary_therapy_outcome
            and "IV" not in stage
        ):
            clin_pt_resistant[k] = patient_data

# Filter out underrepresented stage/grade subgroups within each disease code
# Require each subgroup to have at least 2 instances in each response group
disease_codes = set(
    [v["disease_code"] for v in (clin_pt_resistant | clin_pt_sensitive).values()]
)
for disease_code in disease_codes:
    resistant_barcodes = [
        k for k, v in clin_pt_resistant.items() if v["disease_code"] == disease_code
    ]
    sensitive_barcodes = [
        k for k, v in clin_pt_sensitive.items() if v["disease_code"] == disease_code
    ]
    resistant_subgroups = tcga_coldata.loc[resistant_barcodes]
    sensitive_subgroups = tcga_coldata.loc[sensitive_barcodes]
    sensitive_subgroups_counts = sensitive_subgroups["stage_grade"].value_counts()
    resistant_subgroups_counts = resistant_subgroups["stage_grade"].value_counts()
    exclude_subgroups = []
    for group in set(sensitive_subgroups["stage_grade"]).union(
        set(resistant_subgroups["stage_grade"])
    ):
        if (
            group not in sensitive_subgroups_counts
            or group not in resistant_subgroups_counts
        ):
            exclude_subgroups.append(group)
        elif (
            sensitive_subgroups_counts[group] < 2
            or resistant_subgroups_counts[group] < 2
        ):
            exclude_subgroups.append(group)
    exclude_barcodes_sen = sensitive_subgroups[
        sensitive_subgroups["stage_grade"].isin(exclude_subgroups)
    ].index
    exclude_barcodes_res = resistant_subgroups[
        resistant_subgroups["stage_grade"].isin(exclude_subgroups)
    ].index
    clin_pt_sensitive = {
        k: v for k, v in clin_pt_sensitive.items() if k not in exclude_barcodes_sen
    }
    clin_pt_resistant = {
        k: v for k, v in clin_pt_resistant.items() if k not in exclude_barcodes_res
    }

# Count frequencies of disease codes in each cohort
resistant_counter = Counter(v["disease_code"] for v in clin_pt_resistant.values())
sensitive_counter = Counter(v["disease_code"] for v in clin_pt_sensitive.values())

# Keep only disease codes with more than 15 instances in each response group
keep_codes = {
    c
    for c in resistant_counter
    if resistant_counter[c] > 15 and sensitive_counter[c] > 15
}

# Filter out patients with disease codes not in keep_codes
clin_pt_resistant = {
    k: v for k, v in clin_pt_resistant.items() if v["disease_code"] in keep_codes
}
clin_pt_sensitive = {
    k: v for k, v in clin_pt_sensitive.items() if v["disease_code"] in keep_codes
}

resistant_counter = Counter(v["disease_code"] for v in clin_pt_resistant.values())
sensitive_counter = Counter(v["disease_code"] for v in clin_pt_sensitive.values())

# Update TCGA coldata
pt_cohorts_disease_codes = set([v["disease_code"] for v in clin_pt_sensitive.values()])
tcga_coldata_filt = tcga_coldata[
    tcga_coldata["disease_code"].isin(pt_cohorts_disease_codes)
].copy()
tcga_coldata_filt["PT_resistant"] = "NA"
tcga_coldata_filt.loc[clin_pt_sensitive.keys(), "PT_resistant"] = "No"
tcga_coldata_filt.loc[clin_pt_resistant.keys(), "PT_resistant"] = "Yes"

print("-" * 80)
for disease_code in pt_cohorts_disease_codes:
    print(disease_code)
    # sen_barcodes = [k for k, v in clin_pt_sensitive.items() if v["disease_code"] == disease_code]
    # res_barcodes = [k for k, v in clin_pt_resistant.items() if v["disease_code"] == disease_code]
    coldata_disease = tcga_coldata_filt[
        tcga_coldata_filt["disease_code"] == disease_code
    ]
    coldata_unlabeled = coldata_disease[coldata_disease["PT_resistant"] == "NA"]
    coldata_sensitive = coldata_disease[coldata_disease["PT_resistant"] == "No"]
    coldata_resistant = coldata_disease[coldata_disease["PT_resistant"] == "Yes"]
    unlabeled_counts = coldata_unlabeled["stage_grade"].value_counts().sort_index()
    sensitive_counts = coldata_sensitive["stage_grade"].value_counts().sort_index()
    resistant_counts = coldata_resistant["stage_grade"].value_counts().sort_index()
    unlabeled_counts = unlabeled_counts.loc[sensitive_counts.index]
    unlabeled_proportions = unlabeled_counts / len(coldata_disease)
    sensitive_proportions = sensitive_counts / len(coldata_disease)
    resistant_proportions = resistant_counts / len(coldata_disease)
    subgroup_counts_table = pd.DataFrame(
        {
            "Unlabeled": unlabeled_counts,
            "Sensitive": sensitive_counts,
            "Resistant": resistant_counts,
            "Total": unlabeled_counts + unlabeled_counts + resistant_counts,
        }
    )
    subgroup_counts_table.loc["Total"] = subgroup_counts_table.sum(0)

    print("Subgroup counts:")
    print(subgroup_counts_table)
    print("\n" + "-" * 80)

#######################################################################################
# Notice the trend in the output for higher stages and grades to be labeled resistant.
# Stage and grade are covariates for our target variable (inferred PT resistance).
# To deal with this, we will adjust downstream analysis as follows:
# - Add covariates stage and grade to DE analysis design formula.
# - For binary classification, include stage and grade as categorical features.
# - For binary classification, undersample the minority class per stage/grade subgroup.
# - Use propensity scoring to evaluate binary classification models.
# - For binary classification feature attributions, train additional models
#   on random samples containing the same distribution of stage and grade within each
#   class (classes = {sensitive, resistant}) as the labeled dataset, and use the
#   feature attributions of models trained on pseudo samples as the baseline
#   attribution of each feature. Subtract baseline feature attributions from the
#   feature attribution values calculated for models trained on the labeled dataset.
#######################################################################################

### Downstream analyses (not necessarily in this order):
# - PCA, UMAP, ICA
# - Differential expression (DE) analysis
# - Gene set enrichment analysis (GSEA)
# - Weighted gene correlation network analysis (WGCNA)
# - Binary classification and feature attribution explorations
#   - Iterative random forest (iRF) using ranger (https://github.com/imbs-hl/ranger) for RF, Gini w/ p-values
#     - stage and grade included as categorical variables and always considered for splitting. See `--catvars` and `--alwayssplitvars`
#   - Top-scoring pairs (TSP) + Support vector machine (SVM)
#   - L1 and L2 regularized logistic regression (Elastic net)
#   - Pre-defined feature subsets
#     - Hallmark gene sets from MSigDB
#     - Matrisome, core matrisome (MSigDB: NABA_MATRISOME and NABA_CORE_MATRISOME)
#     - PT sensitivity gene expression markers (links below)
#       - Publication: https://www.nature.com/articles/s41388-021-02055-2
#       - Database: http://ptrc-ddr.cptac-data-view.org/#/
#       - Additional filters:
#         - "Source of the supporting data" must include human cancer tissue
#         - Source of supporting data cannot include TCGA (data leakage).
#           - Check is performed by searching through articles by their PMIDs annotated in the database.

--------------------------------------------------------------------------------
CESC
Subgroup counts:
             Unlabeled  Sensitive  Resistant  Total
stage_grade                                        
I.G2                51         17          4    106
I.G3                46         12          7     99
II.G2               16         10          3     35
II.G3               12         11          2     26
III.G2              13          7          2     28
III.G3              10          7          2     22
IV.G2                4          3          3     11
Total              152         67         23    327

--------------------------------------------------------------------------------
STAD
Subgroup counts:
             Unlabeled  Sensitive  Resistant  Total
stage_grade                                        
II.G2               33          4          3     69
II.G3               62         11          4    128
III.G2              40          4          3     83
III.G3       

In [6]:
### Prepare coldata and raw counts for each disease code cohort
count_df = pd.read_parquet(DATA_DIR / "raw/TCGA_rnaseq_counts.parquet")
genes_df = pd.read_csv(DATA_DIR / "raw/genes.csv.gz", index_col=0)
assert np.all(genes_df.index == count_df.index)
genes_keep = genes_df["gene_type"] == "protein_coding"
genes_df = genes_df[genes_keep]
count_df = count_df.loc[genes_df.index]
count_df.index = count_df.index.str.replace(r"\.\d+$", "", regex=True)

# Use symbols instead of gene id. For duplicate symbols, keep only one w/ max total expression
# sym = genes_df.loc[count_df.index, "gene_name"]
# sym = sym.loc[count_df.assign(sum=count_df.sum(1), group=sym).groupby("group")["sum"].idxmax()]
# count_df = count_df.loc[sym.index]
# count_df.index = sym

# Create count matrices for each disease code cohort
barcodes_df = TCGA_rnaseq_counts_metadata
for disease_code in pt_cohorts_disease_codes:
    barcodes = tcga_coldata_filt[
        tcga_coldata_filt["disease_code"] == disease_code
    ].index
    count_df_disease = pd.DataFrame(
        np.zeros((len(count_df), len(barcodes)), dtype=np.float32),
        index=count_df.index,
        columns=barcodes,
    )
    count_df_disease.index.name = "gene"
    for barcode in barcodes:
        fnames = barcodes_df[barcodes_df["cases.submitter_id"] == barcode].index
        counts = count_df[fnames].values
        # If more than one expression profile available, use the one with more counts
        count_df_disease[barcode] = counts[:, np.argmax(counts.sum(0))]

    # Filter out genes with low expression (more than 50% of samples zero)
    keep = (count_df_disease > 0).sum(1) > count_df_disease.shape[1] / 2
    count_df_disease = count_df_disease[keep]

    counts_dir = DATA_DIR / "interim/TCGA/counts"
    counts_dir.mkdir(parents=True, exist_ok=True)
    coldata_dir = DATA_DIR / "interim/TCGA/coldata"
    coldata_dir.mkdir(parents=True, exist_ok=True)

    count_df_disease.to_csv(counts_dir / f"{disease_code}.csv.gz", compression="gzip")
    coldata = tcga_coldata_filt.loc[barcodes].drop(
        columns=["disease_code", "stage_grade"]
    )
    coldata.to_csv(coldata_dir / f"{disease_code}.csv")