# Load Libraries

In [10]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random, os, gc, psutil
import scanpy as sc
import anndata as ann
from scipy.stats import pearsonr as pr

import sklearn as skl
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score,
    mean_squared_error, average_precision_score,
    precision_score, recall_score, precision_recall_curve as prc,
    auc, silhouette_score, confusion_matrix, classification_report
)

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# PyTorch Geometric
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv, GraphConv, GCNConv
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import from_scipy_sparse_matrix, subgraph
from torch_geometric.loader import NeighborLoader, RandomNodeLoader

# Progress bar
from tqdm import tqdm

# Collections
from collections import Counter

In [2]:
def set_random_seeds(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU plus
    the current CUDA device), then forces cuDNN onto deterministic kernels and
    switches off its benchmark autotuner so repeated runs pick the same kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_random_seeds(42)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [3]:
def memory_usage(verbose=True):
    """Release cached memory, then report this process's resident memory.

    Runs Python's garbage collector and, when CUDA is available, asks PyTorch
    to release its cached GPU blocks. It then reads the resident set size of
    the current process through psutil.

    Args:
        verbose: print the usage line (default True, matching the original helper).

    Returns:
        Resident memory of the current process in GiB (bytes / 1024**3).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024**3
    if verbose:
        print(f"Current memory usage: {rss_gb:.2f} GB")
    return rss_gb


# Backward-compatible alias: keep the original (misspelled) name callable.
memory_usgae = memory_usage

# Load data

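Load the data into an AnnData object; here we load the processed atlas as an example. A metastasis label for the classifier is required in `ad.obs` (in this example it is derived below from `Primary_or_Metastatic`, `Final_cancer_type` and `Final_tissue`).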

In [5]:
# Read the processed atlas (Harmony-integrated, per the file name) into an
# AnnData object, then echo it so the notebook shows its shape and fields.
atlas_path = '../Data/Cancer_cell_data/All_integrated.hallmark.harmony.h5ad'
ad = sc.read_h5ad(atlas_path)
ad

AnnData object with n_obs × n_vars = 725643 × 2132
    obs: 'Project_ID', 'Primary_or_Metastatic', 'Final_cancer_type', 'Final_histological_subtype', 'Final_molecular_subtype', 'Final_tissue', 'Final_sample_id', 'Final_patient_age', 'Final_patient_stage', 'Final_patient_treatment', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Final_histological_subtype_backup', 'Final_patient_age_backup', 'Final_tissue_backup', 'Final_patient_treatment_backup', 'Final_patient_stage_backup', 'Classifier_label'
    var: 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: 'Final_cancer_type_colors', 'Final_histological_subtype_colors', 'Final_molecular_subtype_colors', 'Final_patient_stage_colors', 'Final_patient_treatment_colors', 'Final_tissue_colors', 'Primary_or_Metastatic_colors', 'Project_ID_colors', 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_pca_harmony', 'X_pca_harmony_Project_ID', 'X_pca_harmony_project_id', 'X_uma

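Here we create an extra column, `source`, in `ad.obs` to serve as the label to predict: `Non_metastatic_local` (samples annotated as primary), `Metastatic_local` (metastatic samples taken from the cancer type's primary tissue) and `Metastatic_distant` (metastatic samples taken from any other tissue).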

In [6]:
# Map each cancer type to its "primary tissue". A metastatic sample collected
# from that same tissue counts as local; any other tissue counts as distant.
# Cancer types missing from this map therefore never count as local.
primary_tissue_map = {
    "Breast Cancer": "Breast",
    "Lung Cancer": "Lung",
    "Ovarian Cancer": "Ovary",
    "Colorectal Cancer": "Colon"
}


def assign_metastasis_source(obs, tissue_map):
    """Return the three-way ``source`` label for every row of ``obs``.

    Rules (first match wins):
      * ``Primary_or_Metastatic == "Primary"``                  -> "Non_metastatic_local"
      * otherwise, ``Final_tissue`` equals ``tissue_map[Final_cancer_type]``
                                                                 -> "Metastatic_local"
      * otherwise                                                -> "Metastatic_distant"

    This replaces a row-by-row ``iterrows`` loop, which is very slow on
    hundreds of thousands of cells. Columns are cast to plain objects so that
    categorical columns with different category sets can be compared. A missing
    tissue or an unmapped cancer type never counts as a match; the old loop
    treated ``None == None`` as a match.

    Args:
        obs: DataFrame with the three columns named above.
        tissue_map: dict mapping cancer type -> primary tissue.

    Returns:
        numpy object array of labels aligned with the rows of ``obs``.
    """
    is_primary = obs["Primary_or_Metastatic"].astype(object).eq("Primary").to_numpy()
    primary_site = obs["Final_cancer_type"].astype(object).map(tissue_map)
    tissue = obs["Final_tissue"].astype(object)
    # Missing values never compare equal, and unmapped types are excluded explicitly.
    is_local = (tissue.eq(primary_site) & primary_site.notna()).to_numpy()
    return np.where(
        is_primary,
        "Non_metastatic_local",
        np.where(is_local, "Metastatic_local", "Metastatic_distant"),
    ).astype(object)


ad.obs["source"] = assign_metastasis_source(ad.obs, primary_tissue_map)

# Summary:
print(ad.obs["source"].value_counts())


source
Non_metastatic_local    328282
Metastatic_distant      249125
Metastatic_local        148236
Name: count, dtype: int64


In [7]:
# Subsample 10% of the cells (without replacement) to keep this illustration fast.
n_cells = ad.n_obs
sample_size = int(0.1 * n_cells)

# Re-seed NumPy's global RNG so the draw does not depend on earlier cells.
np.random.seed(42)
sample_indices = np.random.choice(n_cells, size=sample_size, replace=False)

# NOTE(review): anndata indexing presumably yields a view of `ad`, and
# `obsp["connectivities"]` is restricted to the sampled cells, so kNN edges to
# unsampled neighbours are dropped. Confirm this is acceptable for the demo.
ad_subset = ad[sample_indices]

print("Original: {} cells".format(n_cells))
print("Sampled: {} cells".format(ad_subset.n_obs))

Original: 725643 cells
Sampled: 72564 cells


# Model

In [11]:
class scMeta(torch.nn.Module):
    """Two-layer graph-transformer node classifier.

    Architecture: TransformerConv (``heads`` heads, outputs concatenated)
    -> ReLU -> TransformerConv (1 head) -> ReLU, followed by an MLP head
    (Linear -> ReLU -> Dropout -> Linear) that produces per-node class logits.

    Args:
        input_dim: number of input features per node.
        hidden_dim: output width of each attention head and of the MLP hidden layer.
        num_classes: number of output classes.
        heads: attention heads in the first TransformerConv.
        dropout: passed to both TransformerConv layers (attention dropout per
            the PyG docs) and used by the Dropout layer in the classifier head.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, heads=4, dropout=0.3):
        super().__init__()

        # Transformer layers. conv1 concatenates its heads, hence conv2's
        # input width is hidden_dim * heads.
        self.conv1 = TransformerConv(in_channels=input_dim, out_channels=hidden_dim, heads=heads, dropout=dropout)
        self.conv2 = TransformerConv(in_channels=hidden_dim * heads, out_channels=hidden_dim, heads=1, dropout=dropout)

        # Classifier head. Fully qualified torch.nn layers: the bare names
        # Linear / Dropout were never imported. Layer order is unchanged so saved
        # state_dict keys (classifier.0.*, classifier.3.*) still load.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, edge_index, return_embedding=False):
        """Compute per-node class logits.

        Args:
            x: node feature matrix, presumably (num_nodes, input_dim).
            edge_index: graph connectivity in PyG COO form, shape (2, num_edges).
            return_embedding: also return the node embeddings fed to the classifier.

        Returns:
            ``logits`` (num_nodes, num_classes), or ``(logits, embeddings)`` when
            ``return_embedding`` is True, where ``embeddings`` is the ReLU'd
            output of conv2 (num_nodes, hidden_dim).
        """
        h = torch.relu(self.conv1(x, edge_index))
        h = torch.relu(self.conv2(h, edge_index))
        logits = self.classifier(h)

        if return_embedding:
            return logits, h  # logits, embeddings
        return logits


# Data preparation

In [13]:
# Features: node feature matrix as float32. AnnData may store X as a scipy
# sparse matrix (densified with .toarray()) or as a dense ndarray.
if not isinstance(ad_subset.X, np.ndarray):
    X = torch.tensor(ad_subset.X.toarray(), dtype=torch.float32)
else:
    # np.asarray drops ndarray subclasses (e.g. view / matrix types) before conversion.
    X = torch.tensor(np.asarray(ad_subset.X), dtype=torch.float32)

# Graph edges: the stored cell-cell connectivity graph, converted from scipy COO
# form to a PyG edge_index of shape (2, num_edges). Edge weights are not used.
adj = ad_subset.obsp["connectivities"].tocoo()
# Row/col arrays come from `adj` (previously mistyped as `ad_subsetj`, a NameError).
edge_index = torch.tensor(np.vstack((adj.row, adj.col)), dtype=torch.long)

In [32]:
# Labels
# Integer-encode the `source` column of ad_subset.obs:
#   0 = Non_metastatic_local
#   1 = Metastatic_local    (metastatic sample from the cancer type's primary tissue)
#   2 = Metastatic_distant  (metastatic sample from any other tissue)
primary_mask = (ad_subset.obs["source"] == "Non_metastatic_local").values
local_mask = (ad_subset.obs["source"] == "Metastatic_local").values
distant_mask = (ad_subset.obs["source"] == "Metastatic_distant").values

# 3-class targets: all three classes are used during training.
labels_3class = np.full(ad_subset.n_obs, -1, dtype=int)
labels_3class[primary_mask] = 0
labels_3class[local_mask] = 1
labels_3class[distant_mask] = 2
# A cell left at -1 would be handed to F.cross_entropy as a target (whose default
# ignore_index is -100, not -1) and break training much later; fail here instead.
assert (labels_3class >= 0).all(), "ad_subset.obs['source'] has values outside the three classes"
y_3class = torch.tensor(labels_3class, dtype=torch.long)

# Binary label used only in evaluation: 0 = Non_metastatic_local,
# 1 = Metastatic_local, -1 = distant (excluded from the binary metrics).
is_binary = np.isin(labels_3class, (0, 1))
y_binary = torch.tensor(np.where(is_binary, labels_3class, -1), dtype=torch.long)

# Candidate nodes for the held-out validation split (labels 0/1 only).
valid_mask = torch.from_numpy(is_binary)
valid_idx = np.flatnonzero(is_binary)



In [33]:
# ===== 10% Stratified Split (Primary vs Local only) =====
# Hold out 10% of the label-0/1 cells, stratified by label, as the final
# validation set. Every other node -- all distant-metastasis cells included --
# stays in the training set.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_valid_idx, heldout_idx = next(
    split.split(valid_idx, y_3class[valid_idx].numpy())
)

val_idx = valid_idx[heldout_idx]  # node positions in ad_subset; labels 0/1 only
train_idx = np.setdiff1d(np.arange(ad_subset.n_obs), val_idx)  # everything else

# Bundle everything run_model needs into one PyG Data object.
data = Data(x=X, edge_index=edge_index, y=y_3class)
data.train_idx, data.val_idx, data.y_binary = train_idx, val_idx, y_binary

print(f"Prepared {data.num_nodes} nodes")
print("Primary: {} | Local: {} | Distant: {}".format(
    *[(y_3class == k).sum().item() for k in (0, 1, 2)]
))
print(len(train_idx), 'training nodes', len(val_idx), 'validation nodes')

Prepared 72564 nodes
Primary: 32880 | Local: 14915 | Distant: 24769
67784 training nodes 4780 validation nodes


# Training

In [34]:
def evaluate(model, x_val, y_val, y_binary, device):
    """Score the model on held-out cells, each treated as an isolated node.

    Validation cells are evaluated inductively: the only edges given to the
    model are self-loops, so no neighbour information is used.

    Metrics are computed on cells whose binary label is set (0 =
    Non_metastatic_local, 1 = Metastatic_local) and whose 3-class label is not
    2 (distant). Predictions are the argmax over all three classes, so a
    primary/local cell predicted as "distant" counts as an error.

    Args:
        model: network called as ``model(x, edge_index=..., return_embedding=True)``.
        x_val: validation features; expected to be on ``device`` (the
            self-loop edges are created there).
        y_val: 3-class labels of the validation cells.
        y_binary: binary labels of the validation cells (-1 where not applicable).
        device: torch device for the self-loop edges and label masks.

    Returns:
        ``(acc, f1, auroc, auprc, y_true_bin, y_pred_bin)``. AUROC/AUPRC score
        cells by the softmax probability of class 1, and are NaN when sklearn
        rejects the input (e.g. only one class present).
    """
    model.eval()

    # Self-loop edges only: edge_index = [[0..N-1], [0..N-1]].
    num_nodes = x_val.size(0)
    self_loops = torch.arange(num_nodes, device=device).unsqueeze(0).repeat(2, 1)

    with torch.no_grad():
        logits, _ = model(x_val, edge_index=self_loops, return_embedding=True)

    y_binary = y_binary.to(device)
    y_val = y_val.to(device)

    # Filter for binary classification (primary vs local mets); exclude distant mets.
    binary_mask = (y_binary != -1) & (y_val != 2)

    y_true_bin = y_binary[binary_mask].cpu().numpy()
    y_pred_bin = logits.argmax(dim=1)[binary_mask].cpu().numpy()
    # Score = P(class 1) from the 3-way softmax (not renormalised over classes 0/1).
    score_bin = torch.softmax(logits, dim=1)[binary_mask, 1].cpu().numpy()

    # Named auroc (not auc) to avoid shadowing sklearn.metrics.auc.
    auroc = _safe_metric(roc_auc_score, y_true_bin, score_bin)
    auprc = _safe_metric(average_precision_score, y_true_bin, score_bin)

    acc = accuracy_score(y_true_bin, y_pred_bin)
    f1 = f1_score(y_true_bin, y_pred_bin, average='weighted')

    return acc, f1, auroc, auprc, y_true_bin, y_pred_bin


def _safe_metric(metric_fn, y_true, y_score):
    """Return ``metric_fn(y_true, y_score)``, or NaN when sklearn raises ValueError on the input."""
    try:
        return metric_fn(y_true, y_score)
    except ValueError:
        return np.nan


In [35]:
# contrast loss
def NT_Xent(embeddings, tau=0.5):
    """Temperature-scaled cross-entropy ("NT-Xent"-style) term on node embeddings.

    Rows are L2-normalised and compared, by cosine similarity divided by ``tau``,
    with every row of a randomly permuted copy of the batch. Row ``i`` is scored
    as a classification whose target column is ``i``, i.e. the embedding at
    position ``perm[i]``.

    NOTE(review): there is no augmentation, so the "positive" for cell i is a
    random cell (usually a different one), and i's own copy usually sits among
    the negatives. This differs from SimCLR's NT-Xent; confirm it is intended.

    Args:
        embeddings: 2-D tensor, presumably (batch, dim).
        tau: softmax temperature.

    Returns:
        Scalar loss tensor. An empty batch returns 0, still attached to the
        autograd graph, instead of the NaN that cross-entropy over zero rows gives.
    """
    n = embeddings.size(0)
    if n == 0:
        return embeddings.sum() * 0.0

    device = embeddings.device
    z_i = F.normalize(embeddings, dim=1)
    # Build the permutation on the embeddings' device (no CPU index tensor on GPU data).
    perm = torch.randperm(n, device=device)
    z_j = F.normalize(embeddings[perm], dim=1)

    logits = z_i @ z_j.t() / tau
    labels = torch.arange(n, device=device)
    return F.cross_entropy(logits, labels)

In [38]:
def run_model(data, save_dir, epochs=100, seed=42, patience=10, gamma=0.1):
    """Train scMeta on the training subgraph with early stopping on validation AUROC.

    Each step minimises cross-entropy over all three classes plus ``gamma``
    times the NT_Xent embedding term. Steps run on random node partitions
    (``RandomNodeLoader``, 100 parts) of the graph induced by
    ``data.train_idx``.

    After every epoch the model is scored with ``evaluate`` on ``data.val_idx``
    (self-loops only). Whenever validation AUROC improves, the weights are
    saved to ``<save_dir>/best_model.pt``. Training stops after ``patience``
    consecutive epochs without improvement.

    Args:
        data: PyG ``Data`` with ``x``, ``edge_index``, ``y`` (3-class labels),
            ``y_binary``, ``train_idx`` and ``val_idx``.
        save_dir: directory for the best checkpoint (created if missing).
        epochs: maximum number of epochs.
        seed: seed applied to Python/NumPy/PyTorch RNGs before the model is
            built. The original function accepted this argument but ignored it.
        patience: epochs without AUROC improvement before stopping early.
        gamma: weight of the NT_Xent term in the loss.

    Returns:
        ``(best_auc, best_model_path)``. ``best_auc`` stays 0.0, and nothing is
        written, if validation AUROC never exceeds 0 (e.g. it is always NaN).
    """
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    memory_usgae()
    set_random_seeds(seed)

    train_data = _induced_train_data(data)

    val_idx = data.val_idx
    x_val = data.x[val_idx].to(device)
    y_val = data.y[val_idx].to(device)
    y_binary_val = data.y_binary[val_idx]

    model = scMeta(
        input_dim=data.num_node_features,
        hidden_dim=128,
        num_classes=len(torch.unique(data.y))
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    best_auc = 0.0
    patience_counter = 0
    best_model_path = os.path.join(save_dir, "best_model.pt")

    # Double-check the class distribution in the val set:
    # no distant mets (label 2) should appear here.
    y_val_counts = torch.bincount(data.y[val_idx])
    for label, count in enumerate(y_val_counts):
        print(f"Val label {label}: {count.item()} samples")

    # Pre-training eval
    acc, f1, auc, auprc, _, _ = evaluate(model, x_val, y_val, y_binary_val, device)
    print(f"Pre-training stats: Acc={acc:.4f}, F1={f1:.4f}, AUROC={auc:.4f}, AUPRC={auprc:.4f}")

    # Built once: with shuffle=True, each pass over the loader re-shuffles nodes
    # into a fresh random partition, so rebuilding it every epoch is unnecessary.
    loader = RandomNodeLoader(train_data, num_parts=100, shuffle=True)

    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch in loader:
            batch = batch.to(device)
            logits, emb = model(batch.x, batch.edge_index, return_embedding=True)

            loss_ce = F.cross_entropy(logits, batch.y)
            loss_con = NT_Xent(emb, tau=0.5)
            loss = loss_ce + gamma * loss_con

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        acc, f1, auc, auprc, _, _ = evaluate(model, x_val, y_val, y_binary_val, device)

        if epoch % 10 == 0:
            mean_loss = total_loss / max(num_batches, 1)
            print(f"Epoch {epoch}: Loss={mean_loss:.4f}, Acc={acc:.4f}, F1={f1:.4f}, AUROC={auc:.4f}, AUPRC={auprc:.4f}")

        # A NaN AUROC never compares greater, so it counts as "no improvement".
        if auc > best_auc:
            best_auc = auc
            torch.save(model.state_dict(), best_model_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    return best_auc, best_model_path


def _induced_train_data(data):
    """Return a Data object with only the training nodes and the edges among them (relabelled 0..n-1)."""
    train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    train_mask[data.train_idx] = True
    edge_index_train, _ = subgraph(train_mask, data.edge_index, relabel_nodes=True)
    return Data(
        x=data.x[train_mask],
        edge_index=edge_index_train,
        y=data.y[train_mask],
    )


In [39]:
# Directory for the best checkpoint. Override it with the SCMETA_SAVE_DIR
# environment variable. The default is relative to the notebook rather than a
# user-specific absolute cluster path, which does not exist (or is not writable)
# on other machines.
save_dir = os.environ.get("SCMETA_SAVE_DIR", "./scMeta_example/")

run_model(
    data,
    save_dir=save_dir,
    epochs=200,
    patience=20,
    gamma=4.0
)

Current memory usage: 23.45 GB
Val label 0: 3288 samples
Val label 1: 1492 samples
Pre-training stats: Acc=0.0274, F1=0.0466, AUROC=0.5672, AUPRC=0.3736


  0%|          | 1/200 [00:00<02:14,  1.48it/s]

Epoch 0: Acc=0.5278, F1=0.5973, AUROC=0.8214, AUPRC=0.6450


  6%|▌         | 11/200 [00:07<02:00,  1.56it/s]

Epoch 10: Acc=0.8427, F1=0.8785, AUROC=0.9419, AUPRC=0.8735


 10%|█         | 21/200 [00:13<01:52,  1.58it/s]

Epoch 20: Acc=0.7257, F1=0.8076, AUROC=0.9183, AUPRC=0.8346


 16%|█▌        | 31/200 [00:19<01:45,  1.60it/s]

Epoch 30: Acc=0.8033, F1=0.8569, AUROC=0.9410, AUPRC=0.8736


 20%|██        | 41/200 [00:25<01:40,  1.59it/s]

Epoch 40: Acc=0.7132, F1=0.8005, AUROC=0.8815, AUPRC=0.7775


 26%|██▌       | 51/200 [00:32<01:33,  1.59it/s]

Epoch 50: Acc=0.8123, F1=0.8647, AUROC=0.9338, AUPRC=0.8621


 29%|██▉       | 58/200 [00:37<01:31,  1.56it/s]

Early stopping at epoch 59





# Downstream Analysis

Please refer to https://github.com/loooooooopi/scMeta/blob/master/Reproducibility/scMeta_downstream_analysis.ipynb for visualization of the latent embeddings generated by the transformer, feature (gene) prioritization, and pathway analysis.