In [1]:
import os
import json
import random
import scanpy as sc
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import from_scipy_sparse_matrix

import optuna
import argparse
from optuna.samplers import TPESampler
import json

from utils import * 
from models import MODEL_CLASSES

In [2]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [3]:
# Initial experimental settings
results_dir = 'GIN-sota'
os.makedirs(results_dir, exist_ok=True) #Output directory to save results and model weights
batch_size = 512
num_epochs = 500

In [4]:
def evaluate(model, loader, loss_fn):
    model.eval()
    all_preds, all_labels, total_loss = [], [], 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to('cpu')
            out = model(batch.x, batch.edge_index)[batch.batch_size:]
            loss = loss_fn(out, batch.y[batch.batch_size:])
            total_loss += loss.item() * batch.batch_size
            all_preds.append(out.argmax(dim=1))
            all_labels.append(batch.y[batch.batch_size:])
    y_true = torch.cat(all_labels)
    y_pred = torch.cat(all_preds)
    acc = accuracy_score(y_true.numpy(), y_pred.numpy())
    return total_loss / len(loader.dataset), acc

In [5]:
selected_model = 'gin'
model_config_path = 'best_model_config.json'
# Load the JSON content
with open(model_config_path, 'r') as f:
    best_params = json.load(f)['params']
# Optional: print or inspect the keys
print(best_params)

{'hidden_channels': 512, 'num_neighbors': 5, 'dropout': 0.2690851529012735, 'lr': 0.002025225495945881, 'num_layers': 3, 'optimizer': 'Adam'}


In [6]:
best_params['num_neighbors'] = 15

In [7]:
# Load datasets
train_adata = sc.read_h5ad('../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151507.h5ad')
X = train_adata.X.toarray() if not isinstance(train_adata.X, np.ndarray) else adata.X
Y = LabelEncoder().fit_transform(train_adata.obs['Region'])

construct_interaction_KNN(train_adata, n_neighbors=best_params['num_neighbors'])  
adj_spatial = train_adata.obsm['adj']
edge_index = dense_to_sparse_edge_index(adj_spatial)

train_idx, eval_idx = train_test_split(np.arange(len(Y)), test_size=0.2, random_state=42, stratify=Y)

data = Data(x=torch.tensor(X, dtype=torch.float),
            edge_index=edge_index,
            y=torch.tensor(Y, dtype=torch.long))
data.train_mask = torch.zeros(len(Y), dtype=torch.bool)
data.train_mask[train_idx] = True
data.eval_mask = torch.zeros(len(Y), dtype=torch.bool)
data.eval_mask[eval_idx] = True

Graph constructed with 15 neighbors and stored in adata.obsm['adj'].


In [8]:
 # ========== 5. Final Training ==========

config_str = f"{selected_model}_h{best_params['hidden_channels']}_l{best_params['num_layers']}_d{best_params['dropout']:.2f}_lr{best_params['lr']:.0e}"
best_model_path = os.path.join(results_dir, f"{config_str}.pth")

model = MODEL_CLASSES[selected_model](
    in_channels=X.shape[1],
    hidden_channels=best_params['hidden_channels'],
    out_channels=np.unique(Y).shape[0],
    num_layers=best_params['num_layers'],
    dropout=best_params['dropout']
)
print(model)

GIN(
  (convs): ModuleList(
    (0): GINConv(nn=Sequential(
      (0): Linear(in_features=3266, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    ))
    (1): GINConv(nn=Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    ))
    (2): GINConv(nn=Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=7, bias=True)
    ))
  )
)


In [9]:

optimizer_name = best_params.get('optimizer', 'AdamW')  # Default to AdamW if not present

if optimizer_name == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=best_params['lr'])
elif optimizer_name == 'AdamW':
    optimizer = torch.optim.AdamW(model.parameters(), lr=best_params['lr'])
elif optimizer_name == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=best_params['lr'], momentum=0.9)
elif optimizer_name == 'RMSprop':
    optimizer = torch.optim.RMSprop(model.parameters(), lr=best_params['lr'])
else:
    raise ValueError(f"Unsupported optimizer: {optimizer_name}")


loss_fn = nn.CrossEntropyLoss()

train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                            num_neighbors=[best_params['num_neighbors']], batch_size=batch_size, shuffle=True, num_workers=0)
eval_loader = NeighborLoader(data, input_nodes=data.eval_mask,
                            num_neighbors=[best_params['num_neighbors']], batch_size=batch_size, shuffle=False, num_workers=0)

train_log = []
best_eval_acc = 0

for epoch in range(1, num_epochs+1):
    model.train()
    for batch in train_loader:
        batch = batch.to('cpu')
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)[batch.batch_size:]
        loss = loss_fn(out, batch.y[batch.batch_size:])
        loss.backward()
        optimizer.step()

    train_loss, train_acc = evaluate(model, train_loader, loss_fn)
    eval_loss, eval_acc = evaluate(model, eval_loader, loss_fn)

    if eval_acc > best_eval_acc:
        best_eval_acc = eval_acc
        torch.save(model.state_dict(), best_model_path)

    train_log.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'eval_loss': eval_loss,
        'eval_acc': eval_acc
    })

    print(f"Epoch {epoch:02d} | Train Acc: {train_acc:.4f} | Eval Acc: {eval_acc:.4f}")
# Save logs
with open(os.path.join(results_dir, 'train_log.json'), 'w') as f:
    json.dump(train_log, f, indent=2)

with open(os.path.join(results_dir, 'best_model_config.json'), 'w') as f:
    json.dump({'best_model_path': best_model_path, 'params': best_params}, f, indent=2)


al Acc: 0.9615
Epoch 93 | Train Acc: 0.9950 | Eval Acc: 0.9942
Epoch 94 | Train Acc: 0.9837 | Eval Acc: 0.9832
Epoch 95 | Train Acc: 0.9986 | Eval Acc: 0.9982
Epoch 96 | Train Acc: 0.9999 | Eval Acc: 0.9997
Epoch 97 | Train Acc: 0.9998 | Eval Acc: 0.9997
Epoch 98 | Train Acc: 0.9999 | Eval Acc: 0.9997
Epoch 99 | Train Acc: 0.9998 | Eval Acc: 0.9997
Epoch 100 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 101 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 102 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 103 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 104 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 105 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 106 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 107 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 108 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 109 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 110 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 111 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 112 | Train Acc: 1.0000 | Eval Acc: 1.0000
Epoch 113 | 

In [10]:
# ========== 6. Test ==========
with open(os.path.join(results_dir, 'best_model_config.json')) as f:
    best_model_meta = json.load(f)

model = MODEL_CLASSES[selected_model](
    in_channels=X.shape[1],
    hidden_channels=best_model_meta['params']['hidden_channels'],
    out_channels=np.unique(Y).shape[0],
    num_layers=best_model_meta['params']['num_layers'],
    dropout=best_model_meta['params']['dropout']
)
model.load_state_dict(torch.load(best_model_meta['best_model_path']))
print(f"\nLoaded best model from {best_model_meta['best_model_path']}")

test_adata_paths = [
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151508.h5ad',
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151509.h5ad',
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151510.h5ad',
]

test_results = {}

for path in tqdm(test_adata_paths, desc="Evaluating"):
    test_pid = path.split('/')[-1].split('.')[0]
    adata_test = sc.read_h5ad(path)
    X_test = adata_test.X.toarray() if not isinstance(adata_test.X, np.ndarray) else adata_test.X
    Y_test = LabelEncoder().fit_transform(adata_test.obs['Region'])

    construct_interaction_KNN(adata_test, n_neighbors=best_params['num_neighbors'])  
    adj_spatial_test = adata_test.obsm['adj']
    edge_index_test = dense_to_sparse_edge_index(adj_spatial_test)

    data_test = Data(x=torch.tensor(X_test, dtype=torch.float),
                    edge_index=edge_index_test,
                    y=torch.tensor(Y_test, dtype=torch.long))
    data_test.test_mask = torch.ones(data_test.num_nodes, dtype=torch.bool)

    test_loader = NeighborLoader(
        data_test, input_nodes=data_test.test_mask,
        num_neighbors=[best_params['num_neighbors']], batch_size=512, shuffle=False, num_workers=0)

    test_loss, test_acc = evaluate(model, test_loader, loss_fn)
    test_results[test_pid] = {'loss': float(test_loss), 'accuracy': float(test_acc)}
    print(f"Test PID {test_pid} | Accuracy: {test_acc:.4f}")

with open(os.path.join(results_dir, 'test_results.json'), 'w') as f:
    json.dump(test_results, f, indent=2)



Loaded best model from GIN-sota/gin_h512_l3_d0.27_lr2e-03.pth
Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]Graph constructed with 15 neighbors and stored in adata.obsm['adj'].
Evaluating:  33%|███▎      | 1/3 [00:02<00:04,  2.14s/it]Test PID 151508 | Accuracy: 0.5755
Graph constructed with 15 neighbors and stored in adata.obsm['adj'].
Evaluating:  67%|██████▋   | 2/3 [00:04<00:02,  2.31s/it]Test PID 151509 | Accuracy: 0.6134
Graph constructed with 15 neighbors and stored in adata.obsm['adj'].
Evaluating: 100%|██████████| 3/3 [00:06<00:00,  2.27s/it]Test PID 151510 | Accuracy: 0.5905



In [11]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
import torch

def extract_embeddings_with_hook(model, data):
    embeddings = []

    def hook(module, input, output):
        embeddings.append(output.detach().cpu())

    # Register the hook on model.convs[2].nn[0] (i.e., Linear(512 → 512))
    handle = model.convs[2].nn[0].register_forward_hook(hook)

    # Forward pass (you don’t care about output here)
    with torch.no_grad():
        model.eval()
        _ = model(data.x, data.edge_index)

    # Remove the hook
    handle.remove()

    # Return the embeddings
    return embeddings[0]  # shape: [num_nodes, 512]

def test_cluster(model, data, num_classes):
    # Step 1: Extract intermediate node embeddings
    node_embeddings = extract_embeddings_with_hook(model, data)

    # Step 2: Clustering
    kmeans = KMeans(n_clusters=num_classes, n_init=10, random_state=42)
    pred = kmeans.fit_predict(node_embeddings.numpy())

    # Step 3: ARI & NMI
    true = data.y.cpu().numpy()
    ari = adjusted_rand_score(true, pred)
    nmi = normalized_mutual_info_score(true, pred)

    print(f"ARI: {ari:.4f}, NMI: {nmi:.4f}")
    return ari, nmi


In [13]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
import scanpy as sc
import numpy as np
import torch
import os
import json
from tqdm import tqdm

def extract_embeddings_with_hook(model, data, device):
    embeddings = []

    def hook(module, input, output):
        embeddings.append(output.detach().cpu())

    handle = model.convs[2].nn[0].register_forward_hook(hook)

    model.eval()
    with torch.no_grad():
        _ = model(data.x.to(device), data.edge_index.to(device))

    handle.remove()
    return embeddings[0]  # shape [N, 512]

def run_clustering_eval(embeddings, labels, n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    pred = kmeans.fit_predict(embeddings.numpy())

    ari = adjusted_rand_score(labels, pred)
    nmi = normalized_mutual_info_score(labels, pred)

    return float(ari), float(nmi)


# === Main evaluation loop with clustering ===
test_adata_paths = [
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151508.h5ad',
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151509.h5ad',
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-16384/151510.h5ad',
]

test_results = {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for path in tqdm(test_adata_paths, desc="Evaluating"):
    test_pid = path.split('/')[-1].split('.')[0]
    adata_test = sc.read_h5ad(path)
    X_test = adata_test.X.toarray() if not isinstance(adata_test.X, np.ndarray) else adata_test.X
    Y_test = LabelEncoder().fit_transform(adata_test.obs['Region'])

    construct_interaction_KNN(adata_test, n_neighbors=best_params['num_neighbors'])  
    adj_spatial_test = adata_test.obsm['adj']
    edge_index_test = dense_to_sparse_edge_index(adj_spatial_test)

    data_test = Data(
        x=torch.tensor(X_test, dtype=torch.float),
        edge_index=edge_index_test,
        y=torch.tensor(Y_test, dtype=torch.long)
    )
    data_test.test_mask = torch.ones(data_test.num_nodes, dtype=torch.bool)

    # Run classification evaluation
    test_loader = NeighborLoader(
        data_test, input_nodes=data_test.test_mask,
        num_neighbors=[best_params['num_neighbors']], batch_size=512, shuffle=False, num_workers=0
    )
    test_loss, test_acc = evaluate(model, test_loader, loss_fn)

    # Run embedding extraction + clustering
    emb = extract_embeddings_with_hook(model, data_test, device)
    ari, nmi = run_clustering_eval(emb, Y_test, n_clusters=len(np.unique(Y_test)))

    # Store all results
    test_results[test_pid] = {
        'loss': float(test_loss),
        'accuracy': float(test_acc),
        'ARI': ari,
        'NMI': nmi
    }
    print(f"Test PID {test_pid} | Acc: {test_acc:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f}")

# Save to JSON
with open(os.path.join(results_dir, 'test_results.json'), 'w') as f:
    json.dump(test_results, f, indent=2)


Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]Graph constructed with 15 neighbors and stored in adata.obsm['adj'].
Evaluating:  33%|███▎      | 1/3 [00:04<00:08,  4.36s/it]Test PID 151508 | Acc: 0.5750 | ARI: 0.5938 | NMI: 0.7071
Graph constructed with 15 neighbors and stored in adata.obsm['adj'].
Evaluating:  67%|██████▋   | 2/3 [00:08<00:04,  4.46s/it]Test PID 151509 | Acc: 0.6130 | ARI: 0.5914 | NMI: 0.6851
Graph constructed with 15 neighbors and stored in adata.obsm['adj'].
Evaluating: 100%|██████████| 3/3 [00:13<00:00,  4.36s/it]Test PID 151510 | Acc: 0.5902 | ARI: 0.5306 | NMI: 0.6669

