In [1]:
# import ehrapy as ep
import pandas as pd
import scanpy as sc
import numpy as np

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import os


In [2]:
# Names of columns in `adata.obs` that the rest of the notebook refers to.
sample_id_col = "scRNASeq_sample_ID"  # per-sample identifier; reused further down as DONOR_COLUMN
cell_type_key = "Annotation_major_subset"  # one of the annotation columns listed in adata.obs
samples_metadata_cols = ["Source", "Outcome", "Death28", "Institute", "Pool_ID"]  # presumably sample-level metadata columns; not referenced by the visible cells — TODO confirm

In [None]:
# Location of the processed input .h5ad file. Set the COMBAT_DATA_PATH
# environment variable to point at the file on another machine; when it is
# unset, the original hard-coded location is used as the fallback.
DATA_PATH = os.environ.get(
    "COMBAT_DATA_PATH",
    "/Users/farhad/helmholtz/patpy/pat_rep_benchmark/data/combat_processed_200.h5ad",
)
# CSV file holding the names of the selected highly variable genes.
HVG_CSV = "hvg_genes.csv"

In [9]:
# Directory that receives one .h5ad file per donor (written by process_donor further down).
SAMPLE_ADATAS_PATH = "sample_adatas_500gene"
os.makedirs(SAMPLE_ADATAS_PATH, exist_ok=True)  # exist_ok=True: no error when the directory already exists

In [4]:
# Working directories for the two stages run later in the notebook:
# a logs/ and a ckpts/ folder each for pretraining and finetuning.
dirs = ["logs/pretrain", "logs/finetune", "ckpts/pretrain", "ckpts/finetune"]

# Ensure all of them exist; directories that are already present are left as they are.
for d in dirs:
    os.makedirs(d, exist_ok=True)

In [14]:
# Load the AnnData object stored at DATA_PATH; the bare `adata` on the last
# line makes the notebook display its summary.
adata = sc.read_h5ad(DATA_PATH)
adata

AnnData object with n_obs × n_vars = 27588 × 3000
    obs: 'Annotation_cluster_id', 'Annotation_cluster_name', 'Annotation_minor_subset', 'Annotation_major_subset', 'Annotation_cell_type', 'GEX_region', 'QC_ngenes', 'QC_total_UMI', 'QC_pct_mitochondrial', 'QC_scrub_doublet_scores', 'TCR_chain_composition', 'TCR_clone_ID', 'TCR_clone_count', 'TCR_clone_proportion', 'TCR_contains_unproductive', 'TCR_doublet', 'TCR_chain_TRA', 'TCR_v_gene_TRA', 'TCR_d_gene_TRA', 'TCR_j_gene_TRA', 'TCR_c_gene_TRA', 'TCR_productive_TRA', 'TCR_cdr3_TRA', 'TCR_umis_TRA', 'TCR_chain_TRA2', 'TCR_v_gene_TRA2', 'TCR_d_gene_TRA2', 'TCR_j_gene_TRA2', 'TCR_c_gene_TRA2', 'TCR_productive_TRA2', 'TCR_cdr3_TRA2', 'TCR_umis_TRA2', 'TCR_chain_TRB', 'TCR_v_gene_TRB', 'TCR_d_gene_TRB', 'TCR_j_gene_TRB', 'TCR_c_gene_TRB', 'TCR_productive_TRB', 'TCR_chain_TRB2', 'TCR_v_gene_TRB2', 'TCR_d_gene_TRB2', 'TCR_j_gene_TRB2', 'TCR_c_gene_TRB2', 'TCR_productive_TRB2', 'TCR_cdr3_TRB2', 'TCR_umis_TRB2', 'BCR_umis_HC', 'BCR_contig_qc_HC'

In [5]:
# Distinct values found in the Death28 column of the per-cell annotations.
adata.obs["Death28"].unique()

array([0, 1])

In [6]:
# Use the raw counts as the working matrix. The layer is copied so that later
# in-place operations on adata.X (normalize_total is called with inplace=True
# further down) cannot overwrite adata.layers["X_raw_counts"] through a shared
# reference.
adata.X = adata.layers["X_raw_counts"].copy()

In [15]:
# Select the 500 most variable genes with the Seurat v3 flavor. That flavor
# expects raw counts, so the counts layer is named explicitly instead of relying
# on whatever adata.X currently holds (which depends on cell execution order).
# inplace=False returns the per-gene statistics as a DataFrame rather than
# writing them into adata.var.
hvg_info = sc.pp.highly_variable_genes(
    adata, layer="X_raw_counts", flavor="seurat_v3", n_top_genes=500, inplace=False
)

# Names of the genes flagged as highly variable.
hvg_genes = adata.var_names[hvg_info["highly_variable"]]

# Save to CSV, using the HVG_CSV constant so the path is defined in one place.
pd.DataFrame({"genes": hvg_genes}).to_csv(HVG_CSV, index=False)



In [16]:
# Gene names chosen in the HVG step, read back from the CSV named by HVG_CSV.
# NOTE(review): SELECTED_GENES is not referenced by any visible cell — the
# per-donor files written below keep every gene in `adata`. TODO confirm
# whether they are meant to be subset to these genes.
SELECTED_GENES = np.array(pd.read_csv(HVG_CSV)["genes"])
# Column of adata.obs used to split cells into per-donor files.
DONOR_COLUMN = sample_id_col

def process_donor(adata, donor):
    """Write the cells belonging to one donor to their own .h5ad file.

    Selects the rows of ``adata`` whose ``obs[DONOR_COLUMN]`` equals ``donor``
    and saves them as ``<SAMPLE_ADATAS_PATH>/<donor>.h5ad``. A progress line is
    printed before the selection and after the file has been written.

    Parameters
    ----------
    adata
        AnnData object holding the cells of all donors.
    donor
        Value of ``DONOR_COLUMN`` identifying the donor to export.
    """
    print(f"Processing: {donor}")
    is_donor_cell = adata.obs[DONOR_COLUMN] == donor
    save_path = f"{SAMPLE_ADATAS_PATH}/{donor}.h5ad"
    sc.write(save_path, adata[is_donor_cell])
    print(f"Saved to {save_path}")



# Normalize `adata` in place so that each cell's values sum to target_sum (here 1).
sc.pp.normalize_total(adata, target_sum=1, inplace=True)
# Distinct donor identifiers found in adata.obs[DONOR_COLUMN].
donors = adata.obs[DONOR_COLUMN].unique()

# Write one .h5ad file per donor.
for donor in donors:
    process_donor(adata, donor)

Processing: S00109-Ja001E-PBCa
Saved to sample_adatas_500gene/S00109-Ja001E-PBCa.h5ad
Processing: S00112-Ja003E-PBCa
Saved to sample_adatas_500gene/S00112-Ja003E-PBCa.h5ad
Processing: S00005-Ja005E-PBCa
Saved to sample_adatas_500gene/S00005-Ja005E-PBCa.h5ad
Processing: S00061-Ja003E-PBCa
Saved to sample_adatas_500gene/S00061-Ja003E-PBCa.h5ad
Processing: S00056-Ja003E-PBCa
Saved to sample_adatas_500gene/S00056-Ja003E-PBCa.h5ad
Processing: N00027-Ja001E-PBGa
Saved to sample_adatas_500gene/N00027-Ja001E-PBGa.h5ad
Processing: H00067-Ha001E-PBGa
Saved to sample_adatas_500gene/H00067-Ha001E-PBGa.h5ad
Processing: G05153-Ja005E-PBCa
Saved to sample_adatas_500gene/G05153-Ja005E-PBCa.h5ad
Processing: U00515-Ua005E-PBUa
Saved to sample_adatas_500gene/U00515-Ua005E-PBUa.h5ad
Processing: U00505-Ua005E-PBUa
Saved to sample_adatas_500gene/U00505-Ua005E-PBUa.h5ad
Processing: S00028-Ja001E-PBCa
Saved to sample_adatas_500gene/S00028-Ja001E-PBCa.h5ad
Processing: N00023-Ja001E-PBGa
Saved to sample_adatas_

In [18]:
# Run the pretraining script as a shell command with the Data2Vec example config.
# NOTE(review): the config is named "heart_example" while this notebook prepares
# the file at DATA_PATH — confirm the config points at the data written above.
!python pretrain.py --config configs/configs_pretraining/Data2Vec_heart_example.yaml



Epoch: 1/2      100%|###########################################################
Evaluating...   100%|###########################################################

Epoch: 2/2      100%|###########################################################
Evaluating...   100%|###########################################################
Saved checkpoint to `ckpts/pretrain/2.pt`


In [23]:
# Run the training script as a shell command with the finetuning example config.
!python train.py --config configs/configs_finetuning/FineTune_heart_example.yaml


Using 98 patients for training and 14 patients for validation representing 2 unique disease
Training diseases:  [0 1]
Epoch 0, loss: 2.8876: 100%|█████████████████████| 4/4 [23:19<00:00, 349.98s/it]
100%|█████████████████████████████████████████████| 4/4 [03:57<00:00, 59.44s/it]
  0%|                                                     | 0/1 [00:00<?, ?it/s]python(8445) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8446) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8447) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(8448) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 0, val_loss: 0.0000: 100%|██████████████████| 1/1 [00:37<00:00, 37.58s/it]
Epoch 1, loss: 2.8876: 100%|█████████████████████| 4/4 [22:25<00:00, 336.35s/it]
100%|█████████████████████████████████████████████| 4/4 [03:48<00:00, 57.23s/it]
Epoch 1, 

In [25]:
# Run the inference script as a shell command with the inference example config.
!python inference.py --config configs/configs_inference/Inference_heart_example.yaml


python(19439) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Using 140 patients representing [0 1] diseases
100%|███████████████████████████████████████████| 18/18 [01:56<00:00,  6.49s/it]
  warn(
