In [1]:
import os
import json
import random
import scanpy as sc
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import from_scipy_sparse_matrix

import optuna
import argparse
from optuna.samplers import TPESampler
import json

from models import MODEL_CLASSES

In [2]:
# Initial experimental settings
results_dir = 'GIN-sota'
os.makedirs(results_dir, exist_ok=True) #Output directory to save results and model weights
batch_size = 512
num_epochs = 500

In [3]:
# Load datasets
train_adata = sc.read_h5ad('../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-8192/151507.h5ad')
X = train_adata.X.toarray() if not isinstance(train_adata.X, np.ndarray) else adata.X
Y = LabelEncoder().fit_transform(train_adata.obs['Region'])

sc.pp.neighbors(train_adata, n_neighbors=15, use_rep='X')
edge_index, _ = from_scipy_sparse_matrix(train_adata.obsp['connectivities'])

train_idx, eval_idx = train_test_split(np.arange(len(Y)), test_size=0.2, random_state=42, stratify=Y)

data = Data(x=torch.tensor(X, dtype=torch.float),
            edge_index=edge_index,
            y=torch.tensor(Y, dtype=torch.long))
data.train_mask = torch.zeros(len(Y), dtype=torch.bool)
data.train_mask[train_idx] = True
data.eval_mask = torch.zeros(len(Y), dtype=torch.bool)
data.eval_mask[eval_idx] = True

In [4]:
def evaluate(model, loader, loss_fn):
    model.eval()
    all_preds, all_labels, total_loss = [], [], 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to('cpu')
            out = model(batch.x, batch.edge_index)[batch.batch_size:]
            loss = loss_fn(out, batch.y[batch.batch_size:])
            total_loss += loss.item() * batch.batch_size
            all_preds.append(out.argmax(dim=1))
            all_labels.append(batch.y[batch.batch_size:])
    y_true = torch.cat(all_labels)
    y_pred = torch.cat(all_preds)
    acc = accuracy_score(y_true.numpy(), y_pred.numpy())
    return total_loss / len(loader.dataset), acc

In [5]:
selected_model = 'gin'
model_config_path = 'best_model_config.json'
# Load the JSON content
with open(model_config_path, 'r') as f:
    best_params = json.load(f)['params']
# Optional: print or inspect the keys
print(best_params)

{'hidden_channels': 128, 'dropout': 0.4605812003658651, 'lr': 0.0024675736770887937, 'num_layers': 2, 'optimizer': 'AdamW'}


In [6]:
 # ========== 5. Final Training ==========

config_str = f"{selected_model}_h{best_params['hidden_channels']}_l{best_params['num_layers']}_d{best_params['dropout']:.2f}_lr{best_params['lr']:.0e}"
best_model_path = os.path.join(results_dir, f"{config_str}.pth")

model = MODEL_CLASSES[selected_model](
    in_channels=X.shape[1],
    hidden_channels=best_params['hidden_channels'],
    out_channels=np.unique(Y).shape[0],
    num_layers=best_params['num_layers'],
    dropout=best_params['dropout']
)
optimizer_name = best_params.get('optimizer', 'AdamW')  # Default to AdamW if not present

if optimizer_name == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=best_params['lr'])
elif optimizer_name == 'AdamW':
    optimizer = torch.optim.AdamW(model.parameters(), lr=best_params['lr'])
elif optimizer_name == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=best_params['lr'], momentum=0.9)
elif optimizer_name == 'RMSprop':
    optimizer = torch.optim.RMSprop(model.parameters(), lr=best_params['lr'])
else:
    raise ValueError(f"Unsupported optimizer: {optimizer_name}")


loss_fn = nn.CrossEntropyLoss()

train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                            num_neighbors=[3], batch_size=batch_size, shuffle=True, num_workers=0)
eval_loader = NeighborLoader(data, input_nodes=data.eval_mask,
                            num_neighbors=[3], batch_size=batch_size, shuffle=False, num_workers=0)

train_log = []
best_eval_acc = 0

for epoch in range(1, num_epochs+1):
    model.train()
    for batch in train_loader:
        batch = batch.to('cpu')
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)[batch.batch_size:]
        loss = loss_fn(out, batch.y[batch.batch_size:])
        loss.backward()
        optimizer.step()

    train_loss, train_acc = evaluate(model, train_loader, loss_fn)
    eval_loss, eval_acc = evaluate(model, eval_loader, loss_fn)

    if eval_acc > best_eval_acc:
        best_eval_acc = eval_acc
        torch.save(model.state_dict(), best_model_path)

    train_log.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'eval_loss': eval_loss,
        'eval_acc': eval_acc
    })

    print(f"Epoch {epoch:02d} | Train Acc: {train_acc:.4f} | Eval Acc: {eval_acc:.4f}")
# Save logs
with open(os.path.join(results_dir, 'train_log.json'), 'w') as f:
    json.dump(train_log, f, indent=2)

with open(os.path.join(results_dir, 'best_model_config.json'), 'w') as f:
    json.dump({'best_model_path': best_model_path, 'params': best_params}, f, indent=2)


Epoch 01 | Train Acc: 0.4182 | Eval Acc: 0.4088
Epoch 02 | Train Acc: 0.6541 | Eval Acc: 0.6529
Epoch 03 | Train Acc: 0.6977 | Eval Acc: 0.7130
Epoch 04 | Train Acc: 0.7408 | Eval Acc: 0.7420
Epoch 05 | Train Acc: 0.8006 | Eval Acc: 0.8064
Epoch 06 | Train Acc: 0.8264 | Eval Acc: 0.8166
Epoch 07 | Train Acc: 0.8534 | Eval Acc: 0.8580
Epoch 08 | Train Acc: 0.8772 | Eval Acc: 0.8718
Epoch 09 | Train Acc: 0.9017 | Eval Acc: 0.9037
Epoch 10 | Train Acc: 0.9114 | Eval Acc: 0.9083
Epoch 11 | Train Acc: 0.9182 | Eval Acc: 0.9025
Epoch 12 | Train Acc: 0.9390 | Eval Acc: 0.9350
Epoch 13 | Train Acc: 0.9520 | Eval Acc: 0.9450
Epoch 14 | Train Acc: 0.9619 | Eval Acc: 0.9621
Epoch 15 | Train Acc: 0.9542 | Eval Acc: 0.9505
Epoch 16 | Train Acc: 0.9674 | Eval Acc: 0.9678
Epoch 17 | Train Acc: 0.9693 | Eval Acc: 0.9565
Epoch 18 | Train Acc: 0.9763 | Eval Acc: 0.9779
Epoch 19 | Train Acc: 0.9789 | Eval Acc: 0.9750
Epoch 20 | Train Acc: 0.9842 | Eval Acc: 0.9783
Epoch 21 | Train Acc: 0.9850 | Eval Acc:

In [8]:
# ========== 6. Test ==========
with open(os.path.join(results_dir, 'best_model_config.json')) as f:
    best_model_meta = json.load(f)

model = MODEL_CLASSES[selected_model](
    in_channels=X.shape[1],
    hidden_channels=best_model_meta['params']['hidden_channels'],
    out_channels=np.unique(Y).shape[0],
    num_layers=best_model_meta['params']['num_layers'],
    dropout=best_model_meta['params']['dropout']
)
model.load_state_dict(torch.load(best_model_meta['best_model_path']))
print(f"\nLoaded best model from {best_model_meta['best_model_path']}")

test_adata_paths = [
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-8192/151508.h5ad',
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-8192/151509.h5ad',
    '../../../dataset_10x-visium-processed/dataset_10x-visium_filtered_adatas_hvg-8192/151510.h5ad',
]

test_results = {}

for path in tqdm(test_adata_paths, desc="Evaluating"):
    test_pid = path.split('/')[-1].split('.')[0]
    adata_test = sc.read_h5ad(path)
    X_test = adata_test.X.toarray() if not isinstance(adata_test.X, np.ndarray) else adata_test.X
    Y_test = LabelEncoder().fit_transform(adata_test.obs['Region'])

    sc.pp.neighbors(adata_test, n_neighbors=3, use_rep='X')
    edge_index_test, _ = from_scipy_sparse_matrix(adata_test.obsp['connectivities'])

    data_test = Data(x=torch.tensor(X_test, dtype=torch.float),
                    edge_index=edge_index_test,
                    y=torch.tensor(Y_test, dtype=torch.long))
    data_test.test_mask = torch.ones(data_test.num_nodes, dtype=torch.bool)

    test_loader = NeighborLoader(
        data_test, input_nodes=data_test.test_mask,
        num_neighbors=[3], batch_size=16, shuffle=False, num_workers=0)

    test_loss, test_acc = evaluate(model, test_loader, loss_fn)
    test_results[test_pid] = {'loss': float(test_loss), 'accuracy': float(test_acc)}
    print(f"Test PID {test_pid} | Accuracy: {test_acc:.4f}")

with open(os.path.join(results_dir, 'test_results.json'), 'w') as f:
    json.dump(test_results, f, indent=2)



Loaded best model from GIN-sota/gin_h128_l2_d0.46_lr2e-03.pth
Evaluating:  33%|███▎      | 1/3 [00:02<00:05,  2.95s/it]Test PID 151508 | Accuracy: 0.6076
Evaluating:  67%|██████▋   | 2/3 [00:07<00:03,  3.62s/it]Test PID 151509 | Accuracy: 0.6597
Evaluating: 100%|██████████| 3/3 [00:10<00:00,  3.56s/it]Test PID 151510 | Accuracy: 0.6528

