# SigRelayST Training

This notebook trains SigRelayST (Signature-based Relay for Spatial Transcriptomics) with signature bias on Google Colab.

SigRelayST extends CellNEST with signature-based bias terms derived from the Lignature database.

**CellNEST Citation:**
Zohora, F. T., et al. "CellNEST: A Graph Neural Network Framework for Cell-Cell Communication Inference from Spatial Transcriptomics Data."

## Setup


In [None]:
# Install dependencies
# Required packages for GATv2Conv_SigRelayST, training, and the loss plot
%pip install torch torch-geometric torch-sparse torch-scatter scanpy pandas numpy scipy scikit-learn qnorm matplotlib -q

# Verify GPU and imports
import torch
import torch_geometric
from torch_sparse import SparseTensor

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
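
If one of the imports above fails, it helps to know which packages are missing before starting a long training run. The sketch below is one way to check this without importing anything heavy. The helper name `find_missing_modules` and the list `REQUIRED_MODULES` are illustrative and not part of SigRelayST; the import names are assumed from the pip packages installed above.

```python
import importlib.util

# Import names assumed for the pip packages installed above. Pip names and
# import names differ for some of them (e.g. scikit-learn -> sklearn).
REQUIRED_MODULES = [
    "torch", "torch_geometric", "torch_sparse", "torch_scatter",
    "scanpy", "pandas", "numpy", "scipy", "sklearn", "qnorm",
]


def find_missing_modules(module_names):
    """Return the names in module_names that cannot be located on sys.path."""
    return [
        name for name in module_names
        if importlib.util.find_spec(name) is None
    ]


missing = find_missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules can be located.")
```

A located module can still fail at import time (for example, a `torch-sparse` build that does not match the installed PyTorch), so this check complements the import test above rather than replacing it.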


## Verify Files

Check that all required files are present in the repository.


In [None]:
import os

required_files = [
    'input_graph/V1_Human_Lymph_Node_spatial/V1_Human_Lymph_Node_spatial_adjacency_records',
    'input_graph/V1_Human_Lymph_Node_spatial/V1_Human_Lymph_Node_spatial_cell_vs_gene_quantile_transformed',
    'database/signatures_all.csv',
    'database/CellNEST_database.csv',
    'run_SigRelayST.py',
    'CCC_gat.py',
    'GATv2Conv_SigRelayST.py'
]

print("Checking required files...")
print(f"Working directory: {os.getcwd()}\n")

all_present = True
for file in required_files:
    if os.path.exists(file):
        size = os.path.getsize(file) / (1024*1024)  # MB
        print(f"✓ {file} ({size:.2f} MB)")
    else:
        print(f"✗ {file} - MISSING")
        all_present = False

if all_present:
    print("\n✓ All files present! Ready to train.")
else:
    print("\n⚠ Some files are missing. Please ensure data files are in the repository or upload them.")
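
The check above only prints a warning, so a later "Run all" would still start training with files missing. A minimal sketch of a fail-fast variant follows. The helpers `missing_paths` and `require_files` are hypothetical names introduced here for illustration.

```python
from pathlib import Path


def missing_paths(paths, root="."):
    """Return the entries of `paths` that do not exist under `root`."""
    root = Path(root)
    return [p for p in paths if not (root / p).exists()]


def require_files(paths, root="."):
    """Raise FileNotFoundError listing every missing path, if any."""
    missing = missing_paths(paths, root)
    if missing:
        raise FileNotFoundError(
            "Missing required files:\n  " + "\n  ".join(missing)
        )
```

Calling `require_files(required_files)` at the end of the verification cell would stop the notebook at that point instead of letting the training cell fail later.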


## Run Training

Train SigRelayST with signature bias. Model and outputs will be saved in the repository directories.


In [None]:
# Run training with signature bias.
# The run automatically detects 4D edge attributes and enables the signature bias.
# Outputs are written under model/, embedding_data/, and logs/;
# the exact paths for this run are listed above the command below.

import os
from datetime import datetime

# Ensure output directories exist
os.makedirs('model', exist_ok=True)
os.makedirs('embedding_data', exist_ok=True)
os.makedirs('logs', exist_ok=True)

print("Starting SigRelayST training...")
print("This may take several hours. Progress will be shown every 500 epochs.")
print(f"Working directory: {os.getcwd()}")
print("=" * 60)

# Run training
# Model will be saved to: model/V1_Human_Lymph_Node_spatial/
# Embeddings will be saved to: embedding_data/V1_Human_Lymph_Node_spatial/
# Loss curve will be saved to: logs/DGI_SigRelayST_r1_loss_curve.csv
!python run_SigRelayST.py \
    --data_name='V1_Human_Lymph_Node_spatial' \
    --model_name='SigRelayST' \
    --run_id=1 \
    --num_epoch=40000 \
    --hidden=256 \
    --heads=1 \
    --lr_rate=0.00001

print("=" * 60)
print("Training command finished. Check the output above for errors.")
print("\nModel directory: model/V1_Human_Lymph_Node_spatial/")
print("Embeddings directory: embedding_data/V1_Human_Lymph_Node_spatial/")
print("Loss curve file: logs/DGI_SigRelayST_r1_loss_curve.csv")
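
A `!python` shell line does not raise a Python exception when the script exits with an error, so the lines after it run either way. If that matters, the same command can be launched through `subprocess` so a failure stops the cell. This is a sketch under assumptions: the flags are copied from the command above, while `TRAIN_ARGS`, `build_train_command`, and `run_training` are hypothetical names introduced here.

```python
import subprocess
import sys

# Flags copied from the training command above.
TRAIN_ARGS = {
    "data_name": "V1_Human_Lymph_Node_spatial",
    "model_name": "SigRelayST",
    "run_id": 1,
    "num_epoch": 40000,
    "hidden": 256,
    "heads": 1,
    "lr_rate": "0.00001",  # kept as a string so the value is passed unchanged
}


def build_train_command(args, script="run_SigRelayST.py"):
    """Build the argument list for the training script."""
    return [sys.executable, script] + [
        f"--{name}={value}" for name, value in args.items()
    ]


def run_training(args):
    """Run training; raises CalledProcessError on a non-zero exit status."""
    subprocess.run(build_train_command(args), check=True)
```

Calling `run_training(TRAIN_ARGS)` starts the run. Passing the arguments as a list avoids shell quoting, which is why the values carry no extra quotes here.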


## Visualize Training Loss

Plot the training loss curve to monitor training progress.


In [None]:
# Visualize training loss curve
import matplotlib.pyplot as plt
import numpy as np

loss_file = 'logs/DGI_SigRelayST_r1_loss_curve.csv'

try:
    # Load loss curve (saved every 500 epochs)
    loss_data = np.loadtxt(loss_file, delimiter=',')
    
    # Create epoch numbers (every 500 epochs)
    epochs = np.arange(0, len(loss_data)) * 500
    
    # Plot
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, loss_data, linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('DGI Loss', fontsize=12)
    plt.title('SigRelayST Training Loss Curve (with Signature Bias)', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print(f"Final loss: {loss_data[-1]:.4f}")
    print(f"Minimum loss: {np.min(loss_data):.4f} (at epoch {epochs[np.argmin(loss_data)]})")
    print(f"Last logged epoch: {epochs[-1]}")
    
except FileNotFoundError:
    print(f"Loss curve file not found: {loss_file}")
    print("Training may still be in progress or file hasn't been created yet.")
except Exception as e:
    print(f"Error loading loss curve: {e}")
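
If the raw curve is too noisy to read a trend from, a trailing moving average can be plotted alongside it. The helper below is an illustrative sketch, not part of SigRelayST; the name `moving_average` and the choice of window are assumptions.

```python
def moving_average(values, window):
    """Trailing moving average over `values`.

    The first `window - 1` outputs average over the points seen so far,
    so the result has the same length as the input.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            # Drop the value that just left the window.
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed
```

For example, `plt.plot(epochs, moving_average(loss_data, 5))` would overlay a curve smoothed over five logged points, which is 2,500 epochs at the 500-epoch logging interval assumed in the cell above.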


In [None]:
# Optional: uncomment the lines below to commit and push the results
# !git add model/ embedding_data/ logs/
# !git commit -m "Add SigRelayST training results"
# !git push


