In [1]:
import anndata

# Load the .h5ad file
adata = anndata.read_h5ad("897e76b2-59f6-482c-827d-37cc62fa4f50.h5ad")

# Check basic details about the dataset
print(adata)


AnnData object with n_obs × n_vars = 352734 × 45453
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'harm_study', 'harm_healthy.tissue', 'harm_tumor.site', 'harm_sample.type', 'harm_condition', 'harm_tumor.type', 'harm_cd45pos', 'harm_healthy.pat', 'percent.mt', 'ratio_nCount_nFeature', 'batch', 'X_scvi_batch', 'X_scvi_labels', 'X_scvi_local_l_mean', 'X_scvi_local_l_var', 'leiden_0.2', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1', 'leiden_1.2', 'leiden_1.4', 'author_first_cell_type', 'author_cell_type', 'cnv_score', 'organism_ontology_term_id', 'donor_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_id', 'dup', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    v

In [2]:
import scanpy as sc
import pandas as pd
import numpy as np
from mygene import MyGeneInfo

# -----------------------------
# Step 0: Remove mitochondrial genes
# -----------------------------
mt_mask = adata.var["feature_name"].str.upper().str.startswith("MT-")
adata = adata[:, ~mt_mask].copy()

# -----------------------------
# Step 1: Filter samples
# -----------------------------
adata = adata[adata.obs["disease"] != "normal"].copy()
adata = adata[adata.obs["cell_type"] != "megakaryocyte"].copy()

# -----------------------------
# Step 2: Define group and filter rare groups
# -----------------------------
adata.obs["group"] = adata.obs["cell_type"].astype(str) + "_" + adata.obs["disease"].astype(str)
group_counts = adata.obs["group"].value_counts()
valid_groups = group_counts[group_counts >= 2].index
adata = adata[adata.obs["group"].isin(valid_groups)].copy()

# -----------------------------
# Step 3: Normalize and log1p
# -----------------------------
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# -----------------------------
# Step 4: HVG selection
# -----------------------------
sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor="seurat_v3")

# -----------------------------
# Step 5: Differential expression on HVGs only
# -----------------------------
adata_hvg = adata[:, adata.var["highly_variable"]].copy()
sc.tl.rank_genes_groups(
    adata_hvg,
    groupby="group",
    method="wilcoxon",
    n_genes=50
)

# -----------------------------
# Step 6: Collect marker genes (Ensembl IDs)
# -----------------------------
marker_genes_by_group = {}
for group in adata_hvg.obs["group"].unique():
    marker_genes_by_group[group] = list(adata_hvg.uns["rank_genes_groups"]["names"][group][:10])

# -----------------------------
# Step 7: Map Ensembl to gene symbols
# -----------------------------
mg = MyGeneInfo()
unmapped = set(g for genes in marker_genes_by_group.values() for g in genes if g.startswith("ENSG"))
query_result = mg.querymany(list(unmapped), scopes="ensembl.gene", fields="symbol", species="human")
external_map = {item["query"]: item.get("symbol", item["query"]) for item in query_result}
ensembl_to_symbol = adata.var["feature_name"].to_dict()
ensembl_to_symbol.update(external_map)

# -----------------------------
# Step 8: Rebuild gene symbol marker dict
# -----------------------------
converted_marker_genes_by_group = {}
for group, gene_list in marker_genes_by_group.items():
    converted = [ensembl_to_symbol.get(g, g) for g in gene_list]
    converted_marker_genes_by_group[group] = converted

# -----------------------------
# Step 9: Merge HVGs and marker genes for downstream
# -----------------------------
# 1. Flatten marker gene list
all_marker_genes = set(g for genes in converted_marker_genes_by_group.values() for g in genes)

# 2. Make sure they are in the dataset
gene_symbols = adata.var["feature_name"]
marker_genes_in_data = [g for g in all_marker_genes if g in gene_symbols.values]

# 3. Create mask for HVGs + marker genes
hvg_mask = adata.var["highly_variable"]
additional_genes_mask = gene_symbols.isin(marker_genes_in_data)
final_mask = hvg_mask | additional_genes_mask

# 4. Subset adata
adata = adata[:, final_mask].copy()

# -----------------------------
# Optional: Report
# -----------------------------
included = [g for g in all_marker_genes if g in adata.var["feature_name"].values]
print(f" Included marker genes: {len(included)} / {len(all_marker_genes)}")






  foldchanges = (self.expm1_func(mean_group) + 1e-9) / (
  foldchanges = (self.expm1_func(mean_group) + 1e-9) / (
  foldchanges = (self.expm1_func(mean_group) + 1e-9) / (
  foldchanges = (self.expm1_func(mean_group) + 1e-9) / (
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] =

 Included marker genes: 300 / 311


In [6]:
print("Final existing marker genes:")
print(included)


Final existing marker genes:
['GLUL', 'MIF', 'ENG', 'ESAM', 'TPSAB1', 'CLN8', 'ORAI2', 'CXCR4', 'CD52', 'PPP1R14A', 'IKZF3', 'SRGN', 'TM4SF4', 'SPARCL1', 'RAMP2', 'PECAM1', 'NAPSA', 'VWA5A', 'HBB', 'IGKC', 'VWF', 'RBP1', 'IGHV3-23', 'TNFAIP3', 'ORC4', 'MALAT1', 'ANKRD20A9P', 'ACTG1', 'RPL21', 'KRT7', 'HLA-DQA1', 'B2M', 'ZNF106', 'S100P', 'CD68', 'TCF4', 'NKG7', 'COL3A1', 'PTPRC', 'MS4A1', 'EEF1A1', 'SEC61G', 'MZB1', 'RAMP3', 'CORO1A', 'HBA1', 'SPARC', 'ALOX5AP', 'IGLC3', 'CD2', 'SHISA9', 'C15orf48', 'CRIP2', 'LGALS1', 'RAB38', 'HSP90AA1', 'ACTB', 'FCER1G', 'KRT19', 'ITM2C', 'C1S', 'MLANA', 'SCGB3A2', 'VIM', 'TMSB4X', 'H3-3A', 'ANXA5', 'FTL', 'FABP4', 'MUCL1', 'LYZ', 'SFTPA2', 'GZMB', 'MPEG1', 'PLAUR', 'FKBP1A', 'C12orf75', 'FCGR3B', 'C1R', 'ELOB', 'PTGDS', 'HOPX', 'TPSD1', 'TEX41', 'CTSS', 'MGST1', 'JCHAIN', 'PPDPF', 'EPCAM', 'HBA2', 'TCEA1', 'TFF3', 'BCL2A1', 'IRF4', 'VTN', 'HLA-C', 'CD3D', 'CST7', 'CSF3R', 'WFDC2', 'SEC61B', 'RGS13', 'KRT18', 'LUM', 'BGN', 'S100A9', 'HSPG2', 'RGS1', 

In [9]:
# -----------------------------
# Final: Print marker genes per cell type_disease group
# -----------------------------
filtered_marker_dict = {}

for group, genes in converted_marker_genes_by_group.items():
    valid_genes = [g for g in genes if g in adata.var["feature_name"].values]
    if valid_genes:
        filtered_marker_dict[group] = valid_genes

# Print in requested format
for group, genes in filtered_marker_dict.items():
    print(f"{group}: {genes}")


endothelial cell_lung cancer: ['VWF', 'EPAS1', 'PECAM1', 'CLDN5', 'CLEC14A', 'SPARCL1', 'RAMP2', 'HYAL2', 'AQP1', 'TM4SF1']
mononuclear phagocyte_lung cancer: ['TYROBP', 'LYZ', 'HLA-DRA', 'FCER1G', 'HLA-DRB1', 'CD74', 'LAPTM5', 'SRGN', 'GLUL', 'CTSS']
B cell_lung cancer: ['CD79A', 'CD74', 'HLA-DRA', 'CXCR4', 'TTC19', 'MS4A1', 'CD37', 'HLA-DRB1']
malignant cell_lung cancer: ['KRT19', 'KRT8', 'KRT18', 'PERP', 'SPINT2', 'EPCAM', 'MGST1', 'TXN', 'KRT7', 'TACSTD2']
epithelial cell_lung cancer: ['SFTPB', 'SFTPA2', 'NAPSA', 'SFTPA1', 'KRT19', 'SFTA2', 'HOPX', 'SCGB3A2']
T cell_lung cancer: ['CXCR4', 'CD3D', 'TRBC2', 'CD2', 'SRGN', 'RGS1', 'IL32', 'TRBC1', 'PTPRC', 'CD7']
mast cell_lung cancer: ['CPA3', 'TPSAB1', 'TPSB2', 'MS4A2', 'KIT', 'SLC18A2', 'GATA2', 'HPGDS', 'VWA5A', 'SRGN']
fibroblast_lung cancer: ['DCN', 'COL1A2', 'SPARC', 'COL6A2', 'BGN', 'COL3A1', 'SERPING1', 'C1R', 'LUM', 'C1S']
plasmacytoid dendritic cell_lung cancer: ['GZMB', 'GPR183', 'JCHAIN', 'PLAC8', 'IRF7', 'IRF4', 'CD74', 

In [10]:
# -----------------------------
# Final: Print marker genes grouped by cell type and cancer
# -----------------------------
from collections import defaultdict

grouped_by_celltype = defaultdict(dict)

for group, genes in converted_marker_genes_by_group.items():
    if "_" in group:
        cell_type, disease = group.split("_", 1)
        grouped_by_celltype[cell_type][disease] = [g for g in genes if g in adata.var["feature_name"].values]

# Print in requested format
for cell_type, disease_dict in grouped_by_celltype.items():
    for disease, genes in disease_dict.items():
        print(f"{cell_type}_{disease}: {genes}")


endothelial cell_lung cancer: ['VWF', 'EPAS1', 'PECAM1', 'CLDN5', 'CLEC14A', 'SPARCL1', 'RAMP2', 'HYAL2', 'AQP1', 'TM4SF1']
endothelial cell_breast cancer: ['RAMP2', 'GNG11', 'EGFL7', 'PLVAP', 'SPARCL1', 'PECAM1', 'CRIP2', 'CD34', 'ADGRL4', 'IGFBP7']
endothelial cell_melanoma: ['SHISA9', 'KCNQ1OT1', 'ANKRD20A9P', 'ORC4', 'EEF1G', 'VIM', 'CFLAR', 'TNFAIP8L1', 'HLA-B', 'RBMS2']
endothelial cell_liver cancer: ['RAMP2', 'EGFL7', 'HSPG2', 'GNG11', 'PLVAP', 'ENG', 'PECAM1', 'TCF4', 'FKBP1A', 'CRIP2']
endothelial cell_uveal melanoma: ['GNG11', 'RAMP2', 'CAVIN2', 'IFITM3', 'PECAM1', 'TM4SF1', 'RAMP3', 'CRIP2', 'ESAM', 'ENG']
endothelial cell_colorectal cancer: ['RAMP2', 'GNG11', 'PECAM1', 'FKBP1A', 'IGFBP7', 'CRIP2', 'TM4SF1', 'HSPG2', 'VWF', 'SPARCL1']
endothelial cell_ovarian cancer: ['SPARCL1', 'GNG11', 'RAMP2', 'VWF', 'IGFBP7', 'CLDN5', 'PECAM1', 'TM4SF1', 'HSPG2', 'CRIP2']
mononuclear phagocyte_lung cancer: ['TYROBP', 'LYZ', 'HLA-DRA', 'FCER1G', 'HLA-DRB1', 'CD74', 'LAPTM5', 'SRGN', 'GLUL

In [11]:
# Total genes in adata
print(f"Genes in adata: {adata.shape[1]}")

# Marker genes in adata
print(f"Included marker genes: {len(marker_genes_in_data)}")

# HVG genes in adata
print(f"HVGs in adata: {adata.var['highly_variable'].sum()}")


Genes in adata: 2188
Included marker genes: 300
HVGs in adata: 2000


In [12]:
# This was your confirmed list of marker genes
included_marker_genes = [g for g in all_marker_genes if g in adata.var["feature_name"].values]

print(f"Included marker genes in adata: {len(included_marker_genes)} / {len(all_marker_genes)}")
print("Some examples:", included_marker_genes[:10])


Included marker genes in adata: 300 / 311
Some examples: ['GLUL', 'MIF', 'ENG', 'ESAM', 'TPSAB1', 'CLN8', 'ORAI2', 'CXCR4', 'CD52', 'PPP1R14A']


In [14]:
print("Genes used for scGPT training:", len(adata.var["feature_name"]))


Genes used for scGPT training: 2188


In [15]:
adata.write("adata_hvg_marker.h5ad")


In [16]:
import anndata

# Load the .h5ad file
adata = anndata.read_h5ad("adata_hvg_marker.h5ad")

# Check basic details about the dataset
print(adata)

AnnData object with n_obs × n_vars = 234610 × 2188
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'harm_study', 'harm_healthy.tissue', 'harm_tumor.site', 'harm_sample.type', 'harm_condition', 'harm_tumor.type', 'harm_cd45pos', 'harm_healthy.pat', 'percent.mt', 'ratio_nCount_nFeature', 'batch', 'X_scvi_batch', 'X_scvi_labels', 'X_scvi_local_l_mean', 'X_scvi_local_l_var', 'leiden_0.2', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1', 'leiden_1.2', 'leiden_1.4', 'author_first_cell_type', 'author_cell_type', 'cnv_score', 'organism_ontology_term_id', 'donor_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_id', 'dup', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'grou

In [17]:
print(adata.obs["group"].unique())


['endothelial cell_lung cancer', 'mononuclear phagocyte_lung cancer', 'B cell_lung cancer', 'malignant cell_lung cancer', 'epithelial cell_lung cancer', ..., 'mast cell_ovarian cancer', 'plasmacytoid dendritic cell_ovarian cancer', 'neutrophil_ovarian cancer', 'epithelial cell_uveal melanoma', 'malignant cell_uveal melanoma']
Length: 66
Categories (66, object): ['B cell_breast cancer', 'B cell_colorectal cancer', 'B cell_liver cancer', 'B cell_lung cancer', ..., 'plasmacytoid dendritic cell_lung cancer', 'plasmacytoid dendritic cell_melanoma', 'plasmacytoid dendritic cell_ovarian cancer', 'plasmacytoid dendritic cell_uveal melanoma']


In [18]:
print(adata.obs[["cell_type", "disease", "group"]].head())


                                          cell_type      disease  \
BT1299_GTACTCCTCGGAAACG-1-17       endothelial cell  lung cancer   
BT1299_GTACTTTAGCAGGCTA-1-17  mononuclear phagocyte  lung cancer   
BT1299_GTACTTTAGCATGGCA-1-17       endothelial cell  lung cancer   
BT1299_GTACTTTAGGATGTAT-1-17                 B cell  lung cancer   
BT1299_GTACTTTAGGCTCTTA-1-17         malignant cell  lung cancer   

                                                          group  
BT1299_GTACTCCTCGGAAACG-1-17       endothelial cell_lung cancer  
BT1299_GTACTTTAGCAGGCTA-1-17  mononuclear phagocyte_lung cancer  
BT1299_GTACTTTAGCATGGCA-1-17       endothelial cell_lung cancer  
BT1299_GTACTTTAGGATGTAT-1-17                 B cell_lung cancer  
BT1299_GTACTTTAGGCTCTTA-1-17         malignant cell_lung cancer  


In [19]:
print(adata.obs["group"].value_counts())


group
T cell_lung cancer                               30482
epithelial cell_uveal melanoma                   23628
malignant cell_uveal melanoma                    18137
fibroblast_ovarian cancer                        15163
mononuclear phagocyte_lung cancer                13364
                                                 ...  
plasmacytoid dendritic cell_colorectal cancer       24
neutrophil_ovarian cancer                           11
neutrophil_uveal melanoma                           11
mast cell_liver cancer                               4
mast cell_melanoma                                   3
Name: count, Length: 66, dtype: int64


In [20]:
print("group" in adata.obs.columns)


True


In [21]:
print(adata.obs.columns.tolist())


['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'harm_study', 'harm_healthy.tissue', 'harm_tumor.site', 'harm_sample.type', 'harm_condition', 'harm_tumor.type', 'harm_cd45pos', 'harm_healthy.pat', 'percent.mt', 'ratio_nCount_nFeature', 'batch', 'X_scvi_batch', 'X_scvi_labels', 'X_scvi_local_l_mean', 'X_scvi_local_l_var', 'leiden_0.2', 'leiden_0.4', 'leiden_0.6', 'leiden_0.8', 'leiden_1', 'leiden_1.2', 'leiden_1.4', 'author_first_cell_type', 'author_cell_type', 'cnv_score', 'organism_ontology_term_id', 'donor_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_id', 'dup', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'group']


In [23]:
print(adata.obs["cell_type"].value_counts())

cell_type
T cell                         65153
malignant cell                 47442
epithelial cell                38933
mononuclear phagocyte          27218
fibroblast                     26306
B cell                         16527
endothelial cell                7255
neutrophil                      2824
mast cell                       1989
plasmacytoid dendritic cell      963
Name: count, dtype: int64


In [24]:
from sklearn.model_selection import train_test_split
import anndata as ad

# Split with stratification
train_indices, test_indices = train_test_split(
    adata.obs_names,
    test_size=0.2,
    stratify=adata.obs["group"],
    random_state=42,
)

# Subset
train_dataset = adata[train_indices].copy()
test_dataset = adata[test_indices].copy()

# Save
train_dataset.write_h5ad("train_datamarker.h5ad")
test_dataset.write_h5ad("test_datamarker.h5ad")

print(f"Train dataset: {train_dataset.shape[0]} cells")
print(f"Test dataset: {test_dataset.shape[0]} cells")

# Reload if needed
train_dataset = ad.read_h5ad("train_datamarker.h5ad")
test_dataset = ad.read_h5ad("test_datamarker.h5ad")

print(f"Loaded train dataset: {train_dataset.shape}")
print(f"Loaded test dataset: {test_dataset.shape}")


Train dataset: 187688 cells
Test dataset: 46922 cells
Loaded train dataset: (187688, 2188)
Loaded test dataset: (46922, 2188)


In [32]:
import helical
from helical.models.scgpt.model import scGPT, scGPTConfig


print("Helical and scGPT imported successfully!")

from helical.models.scgpt import scGPTConfig, scGPTFineTuningModel  # ✅ Correct Import
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt
import logging, warnings
import umap
import pandas as pd
import seaborn as sns
import anndata as ad 

Helical and scGPT imported successfully!


In [38]:
from sklearn.preprocessing import LabelEncoder
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
import anndata as ad
import numpy as np
import torch

# **Set device (Use GPU if available)**
device = "cuda" if torch.cuda.is_available() else "cpu"

# **Load preprocessed train and test datasets**
adata = ad.read_h5ad("adata_hvg_marker.h5ad")

# **Ensure 'group' column exists**
if "group" not in adata.obs:
    raise ValueError("The dataset must have a 'group' column for classification!")

# **Split into train (80%) and test (20%)**
train_indices, test_indices = train_test_split(
    adata.obs_names,
    test_size=0.2,
    stratify=adata.obs["group"],
    random_state=42
)

# Create separate train and test datasets
train_dataset = adata[train_indices, :].copy()
test_dataset = adata[test_indices, :].copy()

# Save them separately for direct use in training
train_dataset.write_h5ad("train_finaldata.h5ad")
test_dataset.write_h5ad("test_finaldata.h5ad")

print(f" Train dataset: {train_dataset.shape[0]} cells")
print(f" Test dataset: {test_dataset.shape[0]} cells")

# **Load preprocessed train and test datasets**
train_dataset = ad.read_h5ad("train_finaldata.h5ad")
test_dataset = ad.read_h5ad("test_finaldata.h5ad")

print(f"Loaded train dataset: {train_dataset.shape}")
print(f"Loaded test dataset: {test_dataset.shape}")

# **Convert UMI counts to integers**
train_dataset.X = np.round(train_dataset.X).astype(np.int32)
test_dataset.X = np.round(test_dataset.X).astype(np.int32)
print("Converted gene expression matrix to integer UMI counts.")

# **Extract unique group labels and count**
label_encoder = LabelEncoder()
group_labels = sorted(train_dataset.obs["group"].unique())
num_groups = len(group_labels)
print(f"Found {num_groups} unique group labels.")

# **Configure scGPT**
scgpt_config = scGPTConfig(batch_size=128, device=device)

# **Initialize the fine-tuning model**
scgpt_fine_tune = scGPTFineTuningModel(
    scGPT_config=scgpt_config,
    fine_tuning_head="classification",
    output_size=num_groups
)
print("scGPT fine-tuning model initialized successfully.")

final_marker_dict = {
"endothelial cell_lung cancer": ['VWF', 'EPAS1', 'PECAM1', 'CLDN5', 'CLEC14A', 'SPARCL1', 'RAMP2', 'HYAL2', 'AQP1', 'TM4SF1'],
"endothelial cell_breast cancer": ['RAMP2', 'GNG11', 'EGFL7', 'PLVAP', 'SPARCL1', 'PECAM1', 'CRIP2', 'CD34', 'ADGRL4', 'IGFBP7'],
"endothelial cell_melanoma": ['SHISA9', 'KCNQ1OT1', 'ANKRD20A9P', 'ORC4', 'EEF1G', 'VIM', 'CFLAR', 'TNFAIP8L1', 'HLA-B', 'RBMS2'],
"endothelial cell_liver cancer": ['RAMP2', 'EGFL7', 'HSPG2', 'GNG11', 'PLVAP', 'ENG', 'PECAM1', 'TCF4', 'FKBP1A', 'CRIP2'],
"endothelial cell_uveal melanoma": ['GNG11', 'RAMP2', 'CAVIN2', 'IFITM3', 'PECAM1', 'TM4SF1', 'RAMP3', 'CRIP2', 'ESAM', 'ENG'],
"endothelial cell_colorectal cancer": ['RAMP2', 'GNG11', 'PECAM1', 'FKBP1A', 'IGFBP7', 'CRIP2', 'TM4SF1', 'HSPG2', 'VWF', 'SPARCL1'],
"endothelial cell_ovarian cancer": ['SPARCL1', 'GNG11', 'RAMP2', 'VWF', 'IGFBP7', 'CLDN5', 'PECAM1', 'TM4SF1', 'HSPG2', 'CRIP2'],
"mononuclear phagocyte_lung cancer": ['TYROBP', 'LYZ', 'HLA-DRA', 'FCER1G', 'HLA-DRB1', 'CD74', 'LAPTM5', 'SRGN', 'GLUL', 'CTSS'],
"mononuclear phagocyte_breast cancer": ['TYROBP', 'FCER1G', 'GPX1', 'AIF1', 'SPI1', 'CD74', 'GABARAP', 'CD68', 'CST3', 'HLA-DRB1'],
"mononuclear phagocyte_melanoma": ['HLA-B', 'SAT1', 'HLA-C', 'SRGN', 'B2M', 'HLA-DRA', 'TAPBP', 'LYZ', 'METTL21A', 'ACTB'],
"mononuclear phagocyte_liver cancer": ['AIF1', 'HLA-DRA', 'TYROBP', 'LST1', 'HLA-DPB1', 'HLA-DRB1', 'CST3', 'CD74', 'HLA-DPA1', 'LYZ'],
"mononuclear phagocyte_uveal melanoma": ['TYROBP', 'CST3', 'CD74', 'AIF1', 'FCER1G', 'FTL', 'HLA-DRA', 'HLA-DPA1', 'CD68', 'MS4A7'],
"mononuclear phagocyte_colorectal cancer": ['TYROBP', 'FCER1G', 'AIF1', 'LYZ', 'CYBA', 'FTL', 'GPX1', 'HLA-DRA', 'SRGN', 'LST1'],
"mononuclear phagocyte_ovarian cancer": ['TYROBP', 'AIF1', 'HLA-DRA', 'FCER1G', 'HLA-DPB1', 'HLA-DPA1', 'HLA-DRB1', 'CD74', 'FTL', 'SAT1'],
"B cell_lung cancer": ['CD79A', 'CD74', 'HLA-DRA', 'CXCR4', 'TTC19', 'MS4A1', 'CD37', 'HLA-DRB1'],
"B cell_breast cancer": ['CD79A', 'CD74', 'MS4A1', 'NCF1', 'CD37', 'HLA-DRB5', 'HLA-DQA1'],
"B cell_melanoma": ['CD74', 'ORAI2', 'UGDH-AS1', 'KCNQ1OT1', 'CFLAR', 'SHISA9', 'ORC4', 'TMEM212', 'ASTN2', 'METTL21A'],
"B cell_liver cancer": ['CD79A', 'RPS29', 'RPS10', 'IGKC', 'CD74', 'RPLP2', 'RPS27', 'MS4A1', 'HBB', 'RPL27A'],
"B cell_uveal melanoma": ['IGHG1', 'CD79A', 'JCHAIN', 'HERPUD1', 'MZB1', 'SSR4'],
"B cell_colorectal cancer": ['CD79A', 'MS4A1', 'IGKC', 'CD37', 'RPL21', 'CD52', 'RPS29', 'CD79B', 'HLA-DRA', 'RPLP2'],
"B cell_ovarian cancer": ['IGKC', 'IGLC2', 'IGHG3', 'IGHG1', 'IGHG4', 'CD79A', 'IGLC3', 'IGHGP', 'HLA-DRA', 'MZB1'],
"malignant cell_lung cancer": ['KRT19', 'KRT8', 'KRT18', 'PERP', 'SPINT2', 'EPCAM', 'MGST1', 'TXN', 'KRT7', 'TACSTD2'],
"malignant cell_breast cancer": ['MIF', 'KRT19', 'PPDPF', 'COX6C', 'GABARAP', 'H3-3A', 'RPL30', 'TCEA1', 'RPL8'],
"malignant cell_melanoma": ['GAPDH', 'HSP90AB1', 'KCNQ1OT1', 'RPL17', 'RPS17', 'ACTG1', 'EEF1G', 'ALDOA', 'EEF1A1', 'HSP90AA1'],
"malignant cell_liver cancer": ['KRT18', 'UQCRQ', 'SEC61G', 'KRT8', 'SERPINA1', 'ATP5MF', 'TM4SF4', 'ELOB', 'COX6A1'],
"malignant cell_colorectal cancer": ['PHGR1', 'KRT8', 'S100A6', 'TFF3', 'LGALS4', 'KRT18', 'TSPAN8', 'EPCAM', 'KRT19', 'C19orf33'],
"malignant cell_ovarian cancer": ['KRT18', 'KRT8', 'CD24', 'WFDC2', 'RPL7', 'GSTP1', 'MDK', 'KRT19', 'YWHAE', 'RBP1'],
"malignant cell_uveal melanoma": ['S100A1', 'CITED1', 'SDCBP', 'MLANA', 'ZNF106', 'RAB38', 'MIF', 'EFHD1', 'ANXA5', 'ST3GAL4'],
"epithelial cell_lung cancer": ['SFTPB', 'SFTPA2', 'NAPSA', 'SFTPA1', 'KRT19', 'SFTA2', 'HOPX', 'SCGB3A2'],
"epithelial cell_breast cancer": ['MUCL1', 'SCGB2A2', 'IGKV3-20', 'IGHV3-23', 'IGKV4-1', 'IGLV1-51', 'KRT19', 'MGP', 'AZGP1', 'IGHG1'],
"epithelial cell_melanoma": ['KCNQ1OT1', 'ANKRD20A9P', 'TOR1AIP2', 'MBOAT1', 'CFLAR', 'UGDH-AS1', 'CHP1', 'GAPDH', 'ORAI2', 'GATD1'],
"epithelial cell_liver cancer": ['ALB', 'ATP1B1', 'SPP1', 'VTN', 'FXYD2', 'TM4SF4', 'PDZK1IP1', 'ANXA4', 'CLDN10', 'SERPINA1'],
"epithelial cell_colorectal cancer": ['PHGR1', 'TFF3', 'LGALS4', 'KRT8', 'S100A6', 'KRT18', 'FXYD3', 'TSPAN8', 'S100P', 'AGR2'],
"epithelial cell_ovarian cancer": ['HBB', 'HBA2', 'HBA1', 'WFDC2', 'COL1A1', 'COL1A2', 'TAGLN', 'COL3A1', 'RBP1', 'FABP4'],
"epithelial cell_uveal melanoma": ['MLANA', 'TYRP1', 'PMEL', 'S100A1', 'SNHG7', 'MITF', 'TEX41', 'SLCO4A1-AS1', 'RAB38'],
"T cell_lung cancer": ['CXCR4', 'CD3D', 'TRBC2', 'CD2', 'SRGN', 'RGS1', 'IL32', 'TRBC1', 'PTPRC', 'CD7'],
"T cell_breast cancer": ['CD3E', 'IFITM1', 'RAC2', 'CORO1A', 'PTPRC', 'CD52', 'IL32'],
"T cell_melanoma": ['ASTN2', 'UGDH-AS1', 'SHISA9', 'KCNQ1OT1', 'IKZF3', 'ORC4', 'CFLAR', 'TMEM212', 'ANKRD20A9P', 'B2M'],
"T cell_liver cancer": ['RPS29', 'LTB', 'RPS27', 'RPLP2', 'CD3D', 'RPS10', 'CD52', 'RPL21', 'HBB', 'RPL39'],
"T cell_uveal melanoma": ['CST7', 'CD3E', 'CREM', 'CD3D', 'PTPRC', 'SRGN', 'IL32', 'TNFAIP3', 'NKG7'],
"T cell_colorectal cancer": ['CD3D', 'CD2', 'IL32', 'CD52', 'CORO1A', 'TMSB4X', 'RPS29', 'BTG1', 'MALAT1', 'CD3E'],
"T cell_ovarian cancer": ['CXCR4', 'IL32', 'BTG1', 'CD2', 'CD3D', 'CD52', 'CORO1A', 'GZMA', 'IGKC', 'TRBC1'],
"mast cell_lung cancer": ['CPA3', 'TPSAB1', 'TPSB2', 'MS4A2', 'KIT', 'SLC18A2', 'GATA2', 'HPGDS', 'VWA5A', 'SRGN'],
"mast cell_breast cancer": ['TPSAB1', 'CPA3', 'TPSB2', 'MS4A2', 'GATA2', 'HPGDS', 'VWA5A', 'SLC18A2', 'LTC4S', 'HDC'],
"mast cell_melanoma": ['TPSD1', 'CPA3', 'TPSAB1', 'TPSB2', 'MS4A2', 'HPGDS', 'CTSG', 'SLC18A2', 'VWA5A'],
"mast cell_liver cancer": ['TPSAB1', 'TPSB2', 'HPGDS', 'IL1RL1', 'CPA3', 'LINC00623', 'REST', 'ALOX5AP', 'SRGN', 'ID2'],
"mast cell_colorectal cancer": ['TPSAB1', 'CPA3', 'TPSB2', 'MS4A2', 'HPGDS', 'SAMSN1', 'LTC4S', 'HPGD', 'RGS1', 'ALOX5AP'],
"mast cell_ovarian cancer": ['TPSAB1', 'TPSB2', 'CPA3', 'HPGDS', 'MS4A2', 'RGS13', 'VWA5A', 'FCER1G', 'ALOX5AP', 'LTC4S'],
"fibroblast_lung cancer": ['DCN', 'COL1A2', 'SPARC', 'COL6A2', 'BGN', 'COL3A1', 'SERPING1', 'C1R', 'LUM', 'C1S'],
"fibroblast_breast cancer": ['COL1A2', 'COL1A1', 'COL3A1', 'SPARC', 'COL6A2', 'COL6A1', 'CALD1', 'COL6A3', 'BGN', 'AEBP1'],
"fibroblast_melanoma": ['EEF1G', 'KCNQ1OT1', 'UGDH-AS1', 'VIM', 'CHP1', 'RPL13AP5', 'ACTB', 'EEF1A1', 'RPS17', 'ACTG1'],
"fibroblast_liver cancer": ['IGFBP7', 'CALD1', 'BGN', 'MYL9', 'RGS5', 'PPP1R14A', 'MGP', 'TAGLN', 'PLAC9', 'SPARC'],
"fibroblast_uveal melanoma": ['FRZB', 'CST3', 'TIMP3', 'IGFBP7', 'SELENOW', 'RBP1', 'SERPINF1', 'PTGDS'],
"fibroblast_colorectal cancer": ['CALD1', 'LGALS1', 'TPM2', 'IGFBP7', 'SELENOM', 'COL3A1', 'SPARC', 'TIMP1', 'COL1A2', 'COL6A2'],
"fibroblast_ovarian cancer": ['COL1A2', 'COL1A1', 'DCN', 'CALD1', 'COL3A1', 'LUM', 'SPARC', 'LGALS1', 'RARRES2', 'TPM2'],
"plasmacytoid dendritic cell_lung cancer": ['GZMB', 'GPR183', 'JCHAIN', 'PLAC8', 'IRF7', 'IRF4', 'CD74', 'RGS2', 'TCF4', 'PLD4'],
"plasmacytoid dendritic cell_breast cancer": ['JCHAIN', 'IRF8', 'IRF7', 'LILRA4', 'PLAC8', 'GZMB', 'ITM2C', 'PLD4', 'CD74', 'CCDC50'],
"plasmacytoid dendritic cell_melanoma": ['IRF7', 'MPEG1', 'PLD4', 'PLEK', 'IRF8', 'JCHAIN', 'PLAC8', 'GZMB'],
"plasmacytoid dendritic cell_liver cancer": ['JCHAIN', 'IRF7', 'GZMB', 'PLAC8', 'GPR183', 'PPP1R14B', 'IRF8', 'RGS2', 'ITM2C', 'SEC61B'],
"plasmacytoid dendritic cell_uveal melanoma": ['GZMB', 'LILRA4', 'JCHAIN', 'PLD4', 'IRF8', 'IRF7', 'SEC61B', 'CD74', 'PLAC8', 'GPR183'],
"plasmacytoid dendritic cell_colorectal cancer": ['JCHAIN', 'GZMB', 'IRF8', 'PLAC8', 'IRF7', 'SEC61B', 'UGCG', 'IRF4', 'CCDC50', 'CLN8'],
"plasmacytoid dendritic cell_ovarian cancer": ['GZMB', 'JCHAIN', 'GPR183', 'IRF7', 'PLAC8', 'ALOX5AP', 'AREG', 'C12orf75', 'TSPAN13', 'CD74'],
"neutrophil_lung cancer": ['G0S2', 'CXCL8', 'NAMPT', 'ALOX5AP', 'SRGN', 'BCL2A1', 'PLAUR', 'LITAF'],
"neutrophil_uveal melanoma": ['FCGR3B', 'CSF3R', 'S100A8', 'MNDA', 'S100A9', 'NAMPT', 'RGS2', 'FPR1', 'RIPOR2', 'MXD1'],
"neutrophil_colorectal cancer": ['CXCL8', 'BCL2A1', 'S100A8', 'NAMPT', 'G0S2', 'S100A9', 'SAT1', 'FTH1', 'ALOX5AP', 'PLEK'],
"neutrophil_ovarian cancer": ['CXCL8', 'BCL2A1', 'C15orf48', 'PELATON', 'S100A9', 'G0S2', 'NAMPT', 'FTH1', 'MXD1', 'S100A8']
}

# **(Optional) Set marker genes if using attention supervision**
scgpt_fine_tune.model.marker_gene_dict = final_marker_dict  # Define this earlier if needed
scgpt_fine_tune.model.lambda_attn = 0.1  # Tune this weight

# **Process datasets**
dataset = scgpt_fine_tune.process_data(train_dataset, gene_names="feature_name")
validation_dataset = scgpt_fine_tune.process_data(test_dataset, gene_names="feature_name")

# **Check dataset sizes**
print(f"Processed training dataset size: {len(dataset)}")
print(f"Processed validation dataset size: {len(validation_dataset)}")

if len(dataset) == 0 or len(validation_dataset) == 0:
    raise ValueError("Processed dataset is empty! Check gene names or dataset formatting.")

# **Encode group labels**
label_encoder.fit(group_labels)
group_labels_train = torch.tensor(label_encoder.transform(train_dataset.obs["group"]), dtype=torch.long)
group_labels_test = torch.tensor(label_encoder.transform(test_dataset.obs["group"]), dtype=torch.long)

print(f"Encoded {len(label_encoder.classes_)} unique groups.")

# **Compute Raw Class Weights**
class_counts = train_dataset.obs["group"].value_counts().to_dict()
total_samples = sum(class_counts.values())

# **Ensure class alignment**
raw_weights = torch.tensor(
    [total_samples / class_counts.get(cls, 1) for cls in group_labels], dtype=torch.float32
)

# **Apply Log Scaling for Stability**
log_weights = torch.log1p(raw_weights)
class_weights = log_weights / log_weights.mean()
class_weights = class_weights.to(device)

print("Log-scaled normalized class weights:", class_weights)

# **Apply Class Weights to CrossEntropyLoss**
loss_fn = CrossEntropyLoss(weight=class_weights)


 Train dataset: 187688 cells
 Test dataset: 46922 cells
Loaded train dataset: (187688, 2188)
Loaded test dataset: (46922, 2188)
Converted gene expression matrix to integer UMI counts.
Found 66 unique group labels.


INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cls'.
INFO:helical.models.scgpt.model:Processing data for scGPT.


scGPT fine-tuning model initialized successfully.


  new_obj.index = new_index


INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.
INFO:helical.models.scgpt.model:Processing data for scGPT.
  new_obj.index = new_index


INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.


Processed training dataset size: 187688
Processed validation dataset size: 46922
Encoded 66 unique groups.
Log-scaled normalized class weights: tensor([0.8413, 0.9679, 1.1711, 0.5796, 0.9881, 0.8200, 1.0100, 0.5529, 0.7757,
        0.8683, 0.3827, 1.0418, 0.5403, 0.6291, 0.9064, 0.9676, 0.9639, 0.8910,
        1.3871, 0.8459, 1.2289, 0.7797, 0.7779, 0.9876, 0.7773, 1.2474, 0.6698,
        0.4232, 0.6869, 0.8010, 0.9758, 0.8247, 1.3410, 0.4958, 1.1622, 0.7311,
        1.0548, 0.9333, 0.5727, 0.8640, 0.5311, 0.4662, 1.2095, 1.2139, 1.9542,
        0.9117, 2.0260, 1.3448, 0.8766, 0.9064, 1.0993, 0.5168, 1.5097, 0.5905,
        0.8820, 1.4396, 0.7900, 1.7598, 1.7598, 1.2987, 1.6276, 1.4428, 1.1277,
        1.5938, 1.2431, 1.4128], device='cuda:0')


In [40]:
scgpt_fine_tune.train(
    train_input_data=dataset,
    train_labels=group_labels_train,
    validation_input_data=validation_dataset,
    validation_labels=group_labels_test,
    epochs=25,
    optimizer_params={"lr": 2e-5, "weight_decay": 1e-4},
    lr_scheduler_params={
        "name": "linear",
        "num_warmup_steps": 100,
        "num_training_steps": len(dataset) * 10
    },
    loss_function=loss_fn
)

print("Training completed successfully!")

# ✅ Save the fine-tuned model
torch.save(scgpt_fine_tune.state_dict(), "scgpt_markerattn.pth")
print("Fine-tuned model saved as 'scgpt_final_marker_attn.pth'.")


INFO:helical.models.scgpt.fine_tuning_model:Starting Fine-Tuning
  validation_batch_count = 0

Fine-Tuning: epoch 1/25: 100%|████████████████████████████████████████████████████████████████████| 1467/1467 [06:26<00:00,  3.79it/s, loss=4.33]
  The dataset to get the outputs from.

Fine-Tuning Validation: 100%|███████████████████████████████████████████████████████████████████| 367/367 [00:37<00:00,  9.90it/s, val_loss=4.33]
Fine-Tuning: epoch 2/25: 100%|████████████████████████████████████████████████████████████████████| 1467/1467 [06:26<00:00,  3.79it/s, loss=3.77]
Fine-Tuning Validation: 100%|███████████████████████████████████████████████████████████████████| 367/367 [00:37<00:00,  9.90it/s, val_loss=3.27]
Fine-Tuning: epoch 3/25: 100%|████████████████████████████████████████████████████████████████████| 1467/1467 [06:27<00:00,  3.79it/s, loss=2.67]
Fine-Tuning Validation: 100%|███████████████████████████████████████████████████████████████████| 367/367 [00:37<00:00,  9.90it/s, val_

Training completed successfully!
Fine-tuned model saved as 'scgpt_final_marker_attn.pth'.


In [41]:
from sklearn.preprocessing import LabelEncoder
import torch

# Ensure 'group' is categorical (group = cell type + cancer type)
train_labels = train_dataset.obs["group"].astype("category")
test_labels = test_dataset.obs["group"].astype("category")

# Encode group labels as numerical labels
label_encoder = LabelEncoder()
group_labels_train = label_encoder.fit_transform(train_labels)
group_labels_test = label_encoder.transform(test_labels)

# Convert to PyTorch tensor (scGPT expects tensors)
group_labels_train = torch.tensor(group_labels_train, dtype=torch.long)
group_labels_test = torch.tensor(group_labels_test, dtype=torch.long)

print(f"Group labels: {len(label_encoder.classes_)} unique classes")


Group labels: 66 unique classes


In [42]:
outputs = scgpt_fine_tune.get_outputs(validation_dataset)

Fine-Tuning Validation: 100%|██████████████████████████████████████████████████████████████████████████████████| 367/367 [00:35<00:00, 10.43it/s]


In [43]:
embeddings = scgpt_fine_tune.get_embeddings(validation_dataset)

INFO:helical.models.scgpt.model:Started getting embeddings:
  torch.cuda.amp.autocast(enabled=True),

Embedding cells: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 367/367 [00:16<00:00, 22.05it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


In [44]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Convert predictions and true labels to NumPy arrays (if not already)
y_true = group_labels_test.numpy() if isinstance(group_labels_test, torch.Tensor) else group_labels_test
y_pred = outputs.argmax(dim=1).cpu().numpy() if isinstance(outputs, torch.Tensor) else outputs.argmax(axis=1)

# Compute evaluation metrics
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_score(y_true, y_pred, average='weighted')

# Display results
metrics = {
    "Accuracy": accuracy,
    "Precision": precision,
    "Recall": recall,
    "F1 Score": f1
}
metrics


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



{'Accuracy': 0.9240654703550574,
 'Precision': 0.9251111742860351,
 'Recall': 0.9240654703550574,
 'F1 Score': 0.9240700466694403}

In [48]:
import torch
import numpy as np
import pandas as pd
from scipy.sparse import issparse

# Step 1: Prepare validation data
test_data = validation_dataset.count_matrix
if issparse(test_data):
    test_data = test_data.toarray()
test_data = test_data.astype(int)

gene_names = adata.var["feature_name"].tolist()
num_samples, seq_len = test_data.shape
batch_size = 32

# Step 2: Define function to get attention-based importance per sample
def get_attention_weights_batch(model, input_tensor):
    with torch.no_grad():
        embedded_input = model.encoder(input_tensor)
        attn_list = []
        for layer in model.transformer_encoder.layers:
            _, attn_weight = layer.self_attn(
                embedded_input, embedded_input, embedded_input, need_weights=True
            )
            attn_list.append(attn_weight)
        attn_stack = torch.stack(attn_list)  # [num_layers, B, seq_len, seq_len]
        attn_avg = attn_stack.mean(dim=0)    # [B, seq_len, seq_len]
        gene_importance = attn_avg.sum(dim=1)  # Sum over source: [B, seq_len]
    return gene_importance

# Step 3: Run attention scoring in batches
all_gene_importance = []
for i in range(0, num_samples, batch_size):
    batch = test_data[i:i + batch_size]
    batch_tensor = torch.tensor(batch, dtype=torch.long).to(device)
    gene_imp_batch = get_attention_weights_batch(scgpt_fine_tune.model, batch_tensor)
    all_gene_importance.append(gene_imp_batch.cpu().numpy())

all_gene_importance = np.concatenate(all_gene_importance, axis=0)  # [N, seq_len]

# Step 4: Map label indices to group names
id_class_dict = {i: label for i, label in enumerate(label_encoder.classes_)}
true_groups = [id_class_dict[label] for label in y_true]
predicted_groups = [id_class_dict[label] for label in y_pred]

# Step 5: For each sample, extract top attended genes
results = []
for i in range(num_samples):
    gene_imp = all_gene_importance[i]
    top_indices = np.argsort(gene_imp)[::-1][:10]
    top_genes = [gene_names[idx] for idx in top_indices]
    results.append({
        "True Group": true_groups[i],
        "Predicted Group": predicted_groups[i],
        "Top Genes": top_genes
    })

final_df = pd.DataFrame(results)
display(final_df.head(20))



Unnamed: 0,True Group,Predicted Group,Top Genes
0,T cell_lung cancer,T cell_lung cancer,"[CYBA, HERPUD1, MGP, PNLIPRP3, ENSG00000224500..."
1,malignant cell_lung cancer,malignant cell_lung cancer,"[CLEC3A, COL1A1, CRABP1, MCEMP1, HERPUD1, MACR..."
2,fibroblast_colorectal cancer,fibroblast_colorectal cancer,"[GZMK, CHRNA4, ENSG00000237993.1, CLEC3A, HTR5..."
3,T cell_uveal melanoma,T cell_uveal melanoma,"[CLEC3A, AIF1, ENSG00000237993.1, EPCAM, RPL8,..."
4,T cell_lung cancer,T cell_lung cancer,"[CXCL13, C4orf50, COX6A1, IFNG, PNLIPRP3, SAMS..."
5,B cell_melanoma,B cell_melanoma,"[RUBCNL, ACTB, IFNG, CD74, ENSG00000237993.1, ..."
6,T cell_lung cancer,T cell_lung cancer,"[ENSG00000224500.1, CLEC3A, IL32, IL1RN, SAMSN..."
7,endothelial cell_ovarian cancer,endothelial cell_ovarian cancer,"[MACROD2-IT1, COL6A3, MIR28, ENSG00000237993.1..."
8,malignant cell_uveal melanoma,malignant cell_uveal melanoma,"[ENSG00000176593.8, PAK6-AS1, VCX2, TAAR8, GNG..."
9,malignant cell_ovarian cancer,B cell_ovarian cancer,"[ENSG00000237993.1, IFNG, SAMSN1, CXCL13, FTL,..."


In [56]:
tokens = input_ids_list[i]


In [60]:
print("Batch attention shape:", attn_scores.shape)


Batch attention shape: torch.Size([10, 123])


In [61]:
print("DEBUG: max_length used in collator =", collator.max_length)

val_loader = DataLoader(
    validation_dataset,
    batch_size=32,
    sampler=SequentialSampler(validation_dataset),
    collate_fn=collator,
    drop_last=False,
)


DEBUG: max_length used in collator = 1200


In [64]:
for i, batch in enumerate(val_loader):
    input_ids = batch["gene"].to(device)
    print(f"Batch {i} input shape: {input_ids.shape}")  # This should be [B, 1200]
    break


Batch 0 input shape: torch.Size([32, 168])


In [68]:
print(f"Batch {i} - attn shape: {attn_scores.shape}")


Batch 0 - attn shape: torch.Size([10, 123])


In [74]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, SequentialSampler
from helical.models.scgpt.data_collator import DataCollator

# Ensure eval mode
scgpt_fine_tune.model.eval()
vocab = scgpt_fine_tune.vocab
id_to_gene = {v: k for k, v in vocab.items()}

# --- Collator with fixed padding ---
collator = DataCollator(
    do_padding=True,
    pad_token_id=vocab["<pad>"],
    pad_value=0,
    max_length=1200,
    do_binning=True,
    do_mlm=False,
    sampling=False,
    keep_first_n_tokens=1,
)

val_loader = DataLoader(
    validation_dataset,
    batch_size=32,
    sampler=SequentialSampler(validation_dataset),
    collate_fn=collator,
    drop_last=False,
)

# --- Prediction outputs ---
outputs = scgpt_fine_tune.get_outputs(validation_dataset)
y_pred = np.argmax(outputs, axis=1)
y_true = group_labels_test.numpy() if isinstance(group_labels_test, torch.Tensor) else group_labels_test
id_class_dict = {i: c for i, c in enumerate(label_encoder.classes_)}
true_groups = [id_class_dict[idx] for idx in y_true]
pred_groups = [id_class_dict[idx] for idx in y_pred]

# --- Map marker genes to token IDs ---
group_to_marker_ids = {}
for group, genes in final_marker_dict.items():
    marker_ids = [vocab[g] for g in genes if g in vocab]
    group_to_marker_ids[group] = marker_ids

# --- CLS attention extractor ---
def get_cls_attention(model, input_tensor):
    with torch.no_grad():
        embedded = model.encoder(input_tensor)
        attn_list = []
        for layer in model.transformer_encoder.layers:
            attn = layer.self_attn(embedded, embedded, embedded, need_weights=True)[1]
            cls_attn = attn[:, 0, :]  # CLS to each token
            attn_list.append(cls_attn)
        attn_avg = torch.stack(attn_list).mean(dim=0)  # (B, seq)
    return attn_avg

# --- Extract top attended marker genes per sample ---
results = []
sample_idx = 0

for batch in val_loader:
    device = next(scgpt_fine_tune.model.parameters()).device
    input_ids = batch["gene"].to(device)
    attn_scores = get_cls_attention(scgpt_fine_tune.model, input_ids)  # (B, seq)

    for i in range(attn_scores.shape[0]):
        sample_input = input_ids[i].cpu().numpy()
        sample_attn = attn_scores[i].cpu().numpy()
        true_group = true_groups[sample_idx]
        pred_group = pred_groups[sample_idx]

        marker_ids = group_to_marker_ids.get(true_group, [])
        top_marker_attn = [
            (id_to_gene[token_id], sample_attn[j])
            for j, token_id in enumerate(sample_input)
            if token_id in marker_ids
        ]
        top_marker_attn.sort(key=lambda x: x[1], reverse=True)
        top_genes = [g for g, _ in top_marker_attn[:10]]

        results.append({
            "True Group": true_group,
            "Predicted Group": pred_group,
            "Top Attended Marker Genes": top_genes
        })
        sample_idx += 1

# --- Display ---
df = pd.DataFrame(results)
display(df.head(20))


Fine-Tuning Validation: 100%|██████████████████████████████████████████████████████████████████████████████████| 367/367 [00:35<00:00, 10.43it/s]


Unnamed: 0,True Group,Predicted Group,Top Attended Marker Genes
0,T cell_lung cancer,T cell_lung cancer,"[CD2, TRBC2, TRBC1, CD3D, CD7, RGS1, IL32, CXC..."
1,malignant cell_lung cancer,malignant cell_lung cancer,"[EPCAM, KRT19, SPINT2, KRT18, MGST1, TXN]"
2,fibroblast_colorectal cancer,fibroblast_colorectal cancer,"[SPARC, COL6A2, COL1A2, TPM2, COL3A1, LGALS1, ..."
3,T cell_uveal melanoma,T cell_uveal melanoma,"[CST7, CD3D, CREM, CD3E, IL32, NKG7, SRGN, TNF..."
4,T cell_lung cancer,T cell_lung cancer,"[TRBC2, CD3D, RGS1, IL32, PTPRC, CXCR4, SRGN]"
5,B cell_melanoma,B cell_melanoma,"[TMEM212, UGDH-AS1, KCNQ1OT1, METTL21A, ORC4, ..."
6,T cell_lung cancer,T cell_lung cancer,"[CD2, CD7, CD3D, RGS1, IL32, CXCR4, SRGN]"
7,endothelial cell_ovarian cancer,endothelial cell_ovarian cancer,"[RAMP2, CLDN5, HSPG2, VWF, CRIP2, GNG11, PECAM..."
8,malignant cell_uveal melanoma,malignant cell_uveal melanoma,"[EFHD1, ST3GAL4, S100A1, MLANA, CITED1, RAB38,..."
9,malignant cell_ovarian cancer,B cell_ovarian cancer,"[MDK, CD24, RPL7, WFDC2, GSTP1]"


In [75]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, SequentialSampler
from helical.models.scgpt.data_collator import DataCollator

# Ensure eval mode
scgpt_fine_tune.model.eval()
vocab = scgpt_fine_tune.vocab
id_to_gene = {v: k for k, v in vocab.items()}

# --- Collator with fixed padding ---
collator = DataCollator(
    do_padding=True,
    pad_token_id=vocab["<pad>"],
    pad_value=0,
    max_length=1200,
    do_binning=True,
    do_mlm=False,
    sampling=False,
    keep_first_n_tokens=1,
)

val_loader = DataLoader(
    validation_dataset,
    batch_size=32,
    sampler=SequentialSampler(validation_dataset),
    collate_fn=collator,
    drop_last=False,
)

# --- Prediction outputs ---
outputs = scgpt_fine_tune.get_outputs(validation_dataset)
y_pred = np.argmax(outputs, axis=1)
y_true = group_labels_test.numpy() if isinstance(group_labels_test, torch.Tensor) else group_labels_test
id_class_dict = {i: c for i, c in enumerate(label_encoder.classes_)}
true_groups = [id_class_dict[idx] for idx in y_true]
pred_groups = [id_class_dict[idx] for idx in y_pred]

# --- Map marker genes to token IDs ---
group_to_marker_ids = {}
for group, genes in final_marker_dict.items():
    marker_ids = [vocab[g] for g in genes if g in vocab]
    group_to_marker_ids[group] = marker_ids

# --- CLS attention extractor ---
def get_cls_attention(model, input_tensor):
    with torch.no_grad():
        embedded = model.encoder(input_tensor)
        attn_list = []
        for layer in model.transformer_encoder.layers:
            attn = layer.self_attn(embedded, embedded, embedded, need_weights=True)[1]
            cls_attn = attn[:, 0, :]  # CLS to each token
            attn_list.append(cls_attn)
        attn_avg = torch.stack(attn_list).mean(dim=0)  # (B, seq)
    return attn_avg

# --- Extract top attended marker genes per sample ---
results = []
sample_idx = 0

for batch in val_loader:
    device = next(scgpt_fine_tune.model.parameters()).device
    input_ids = batch["gene"].to(device)
    attn_scores = get_cls_attention(scgpt_fine_tune.model, input_ids)  # (B, seq)

    for i in range(attn_scores.shape[0]):
        sample_input = input_ids[i].cpu().numpy()
        sample_attn = attn_scores[i].cpu().numpy()
        true_group = true_groups[sample_idx]
        pred_group = pred_groups[sample_idx]

        marker_ids = group_to_marker_ids.get(true_group, [])
        top_marker_attn = [
            (id_to_gene[token_id], sample_attn[j])
            for j, token_id in enumerate(sample_input)
            if token_id in marker_ids
        ]
        top_marker_attn.sort(key=lambda x: x[1], reverse=True)
        top_genes = [g for g, _ in top_marker_attn[:5]]

        results.append({
            "True Group": true_group,
            "Predicted Group": pred_group,
            "Top Attended Marker Genes": top_genes
        })
        sample_idx += 1

# --- Display ---
df = pd.DataFrame(results)
display(df.head(20))


Fine-Tuning Validation: 100%|██████████████████████████████████████████████████████████████████████████████████| 367/367 [00:35<00:00, 10.43it/s]


Unnamed: 0,True Group,Predicted Group,Top Attended Marker Genes
0,T cell_lung cancer,T cell_ovarian cancer,"[CD2, TRBC2, TRBC1, CD3D, CD7]"
1,malignant cell_lung cancer,malignant cell_lung cancer,"[EPCAM, KRT19, SPINT2, KRT18, MGST1]"
2,fibroblast_colorectal cancer,fibroblast_colorectal cancer,"[SPARC, COL6A2, COL1A2, TPM2, COL3A1]"
3,T cell_uveal melanoma,T cell_uveal melanoma,"[CST7, CD3D, CREM, CD3E, IL32]"
4,T cell_lung cancer,T cell_lung cancer,"[TRBC2, CD3D, RGS1, IL32, PTPRC]"
5,B cell_melanoma,B cell_melanoma,"[TMEM212, UGDH-AS1, KCNQ1OT1, METTL21A, ORC4]"
6,T cell_lung cancer,T cell_lung cancer,"[CD2, CD7, CD3D, RGS1, IL32]"
7,endothelial cell_ovarian cancer,endothelial cell_ovarian cancer,"[RAMP2, CLDN5, HSPG2, VWF, CRIP2]"
8,malignant cell_uveal melanoma,malignant cell_uveal melanoma,"[EFHD1, ST3GAL4, S100A1, MLANA, CITED1]"
9,malignant cell_ovarian cancer,B cell_ovarian cancer,"[MDK, CD24, RPL7, WFDC2, GSTP1]"
