In [None]:
"""
Main script for training and evaluating the DR Score on knee X-ray data.

This version has:
- Local paths anonymised
- Example subject ID anonymised
- Checkpoint path generalized

Fill in:
- DIR: path to your preprocessed X-ray images
- checkpoint_path: path to a valid checkpoint if resuming training
- id in the inference_Xray call: subject ID you want to visualise / infer on
"""

import os
import random
import numpy as np
from tqdm import tqdm
import lifelines

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers

from dataloaders.data_utils_xray import *
from test import *
from inference import *
from utils import *
from model.SRKNNabmil import *

In [None]:
# ----------------------------------------------------------------------
# CUDA & reproducibility
# ----------------------------------------------------------------------

# Release unused cached GPU memory (useful when re-running in the same kernel).
torch.cuda.empty_cache()

# Report whether PyTorch can see a CUDA device. This is informational only:
# nothing in this cell stops execution when no GPU is present.
cuda_available = torch.cuda.is_available()
if cuda_available:
    print("CUDA Available:")
else:
    print("CUDA Not Available")


def seed_torch(seed: int = 1029) -> None:
    """Seed every RNG used by this notebook and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), then disables cuDNN auto-tuning and enables
    its deterministic mode.

    Args:
        seed: Seed value applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is only read at interpreter start-up, so setting it
        here does not change string hashing in the already-running kernel; the
        value is inherited by child processes started afterwards.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs rather than only the current one; this call is
    # silently ignored when CUDA is unavailable, so it is safe on CPU too.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed all RNGs with the default seed (1029) before any data module or model is created.
seed_torch()

In [None]:
# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------

# --- Paths and run identity -------------------------------------------
DIR = r"/path/to/xray_images"  # replace with your actual data directory
checkpoint_path = "path/to/checkpoint.ckpt"  # only read when resume_from_checkpoint is True
model_name = "xray_knnboth_patch7_r5_s1"

# --- Input images and patching ----------------------------------------
color_img = False
img_input_size = 630
patch_size = 90
img_dim = 2
slide_num = 1
pos_encoding = True

# --- Dataset / sampling -----------------------------------------------
target = "KR"  # e.g. knee replacement outcome
batch_size = 16
sample_ratio = None
weighted_sampling = True
random_state = 11

# --- kNN attention ----------------------------------------------------
knn_type = "both"          # 'radiomic', 'spatial', or 'both'
extraction_layer = "conv"  # feature extraction layer
topk_R = 5                 # k for radiomic nearest neighbours
topk_S = 1                 # k for spatial nearest neighbours

# --- Optimisation -----------------------------------------------------
loss = "deepsurv"
lr = 8e-6  # large batch size, larger lr (other values tried: 7e-6, 1e-6, 2e-5)

# --- Run switches -----------------------------------------------------
find_lr = False
resume_from_checkpoint = False

In [None]:
# ----------------------------------------------------------------------
# Data loading
# ----------------------------------------------------------------------

# Build the data module from the values set in the configuration cell.
xray_datamodule = XrayDataModule(
    img_dir=DIR,
    img_input_size=img_input_size,
    target=target,
    batch_size=batch_size,
    weighted_sampling=weighted_sampling,
    sample_ratio=sample_ratio,
    random_state=random_state,
)

# Keep handles to the split tables and validation transforms; the
# single-subject inference cell passes them to inference_Xray.
df_train, df_valid = xray_datamodule.df_train, xray_datamodule.df_valid
valid_trsf = xray_datamodule.valid_trsfs

# Optionally export validation metadata
# df_valid.to_csv("/path/to/save/df_valid.csv", index=False)

# ----------------------------------------------------------------------
# Spatial distance map for spatial kNN attention
# ----------------------------------------------------------------------

# The L1 distance matrix is only built for the attention variants that use
# spatial neighbours ('spatial' and 'both'); otherwise it stays None.
l1_dist_map = (
    spatial_distance_mat(
        img_shape=(slide_num, img_input_size, img_input_size),
        patch_size=patch_size,
    )
    if knn_type in ("both", "spatial")
    else None
)

In [None]:
# ----------------------------------------------------------------------
# Model definition
# ----------------------------------------------------------------------

# Instantiate the model from the configuration values and move it to the GPU.
# The trailing .cuda() requires a CUDA device to be present.
net = SRkNNAttentionMIL(
    # Task and optimisation
    num_classes=1,
    loss=loss,
    lr=lr,
    training=True,
    # Input / patching
    color_img=color_img,
    img_dim=img_dim,
    patch_size=patch_size,
    pos_encoding=pos_encoding,
    # Feature extraction and embedding sizes
    extraction_layer=extraction_layer,
    hist_output_size=20,
    feat_embedding_size=512,  # 256 in some experiments
    att_embedding_size=256,   # 128 in some experiments
    # kNN attention
    knn_att_type=knn_type,
    topk_R=topk_R,
    topk_S=topk_S,
    spatial_dist_mat=l1_dist_map,
).cuda()

# To unfreeze specific modules, enable gradients here, e.g.:
# for param in net.parameters():
#     param.requires_grad = True

In [None]:
# ----------------------------------------------------------------------
# Trainer setup
# ----------------------------------------------------------------------

# Example: freeze feature extractor if needed
# freeze_modules(net, module="conv_extractor")

# TensorBoard event files are written under logs/.
xray_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

# NOTE(review): `gpus`, `auto_lr_find` and `resume_from_checkpoint` are
# PyTorch Lightning 1.x Trainer arguments; confirm the installed version
# still accepts them.
trainer_params = dict(
    accumulate_grad_batches=4,
    # callbacks=[EarlyStopping(monitor="valid_loss", mode="min")],
    gpus=[0],
    auto_lr_find=find_lr,
    logger=xray_logger,
    # A resumed run gets a higher epoch ceiling (13) than a fresh run (9).
    max_epochs=13 if resume_from_checkpoint else 9,
)
if resume_from_checkpoint:
    trainer_params["resume_from_checkpoint"] = checkpoint_path

trainer = Trainer(**trainer_params)

# Optional learning rate finder: only runs when find_lr is switched on.
if find_lr:
    trainer.tune(net, xray_datamodule)

In [None]:
# ----------------------------------------------------------------------
# Training
# ----------------------------------------------------------------------

# Fit the model using the dataloaders provided by the data module.
trainer.fit(model=net, datamodule=xray_datamodule)

In [None]:
# ----------------------------------------------------------------------
# Evaluation
# ----------------------------------------------------------------------

# Evaluation runs on the data module's *validation* dataloader.
valid_loader = xray_datamodule.val_dataloader()

# net.test returns predictions, ground truth and three risk-group index sets.
(y_preds, y_trues, low_risk_idx, medium_risk_idx, high_risk_idx) = net.test(
    dloader=valid_loader,
    task="xray",
    model_name=model_name,
    save=True,
)

In [None]:
# ----------------------------------------------------------------------
# Example inference for a single subject (explainability / attention maps)
# ----------------------------------------------------------------------

# Subject to run inference on; replace with an anonymised / synthetic ID.
subject_id = "EXAMPLE_SUBJECT_ID"

# Returns the attention map and the prediction for the chosen subject and side.
att_map, y_pred = inference_Xray(
    net, df_train, df_valid, valid_trsf, id=subject_id, side=1
)

In [None]:
# ----------------------------------------------------------------------
# Save & reload model (optional)
# ----------------------------------------------------------------------

# Reuse `model_name` from the configuration cell instead of re-declaring it,
# so the saved model and the evaluation outputs always share one name.
net.save(model_name)

# Later, or in another script / session (after building `net` with the same
# configuration):
# net.load(model_name)