In [None]:
import os
import warnings

import torch
from PIL import Image
from torchvision.transforms import functional as TF

# -----------------------------
# Device setup (GPU if available)
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Limit intra-op CPU threads to avoid contention, but never request more
# threads than the machine reports cores (oversubscription slows inference).
# os.cpu_count() can return None; fall back to the cap in that case.
MAX_CPU_THREADS = 4
torch.set_num_threads(min(MAX_CPU_THREADS, os.cpu_count() or MAX_CPU_THREADS))

from model.SRKNNabmil import *
from dataloaders.transforms import GetValidTransforms

In [None]:
import pickle

# Geometry of the patch grid the distance matrix is built for. It must match the
# image size used at inference (630x630) and the model's patch_size (90).
SPATIAL_IMG_SHAPE = (1, 630, 630)  # (channels, height, width)
SPATIAL_PATCH_SIZE = 90

# Encode the patch-grid dimensions (630 // 90 = 7 -> "7x7") in the cache name so
# that changing the grid dimensions does not silently reuse a stale matrix.
_grid_h = SPATIAL_IMG_SHAPE[1] // SPATIAL_PATCH_SIZE
_grid_w = SPATIAL_IMG_SHAPE[2] // SPATIAL_PATCH_SIZE
SPATIAL_DIST_PATH = f"./spatial_dist_{_grid_h}x{_grid_w}.pkl"

if not os.path.exists(SPATIAL_DIST_PATH):
    from model.SRKNNabmil import spatial_distance_mat
    print("Computing spatial distance matrix (one-time)...")
    spatial_dist = spatial_distance_mat(
        img_shape=SPATIAL_IMG_SHAPE,
        patch_size=SPATIAL_PATCH_SIZE,
    )
    # Dump to a temporary file, then atomically rename it into place. An
    # interrupted dump therefore cannot leave a truncated cache behind that
    # would make every later run fail in pickle.load.
    tmp_path = SPATIAL_DIST_PATH + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(spatial_dist, f)
    os.replace(tmp_path, SPATIAL_DIST_PATH)
else:
    print("Loading cached spatial distance matrix...")
    # NOTE: pickle.load can execute arbitrary code. Only load a cache file this
    # notebook produced itself, never one obtained from an untrusted source.
    with open(SPATIAL_DIST_PATH, "rb") as f:
        spatial_dist = pickle.load(f)

print("Spatial dist shape:", spatial_dist.shape)


In [None]:
def load_model(model_name: str, weights_dir: str = "./trained_model", strict: bool = False):
    """
    Build an SRkNNAttentionMIL network and load trained weights for inference.

    Parameters
    ----------
    model_name : str
        Checkpoint base name; weights are read from ``{weights_dir}/{model_name}.pth``.
    weights_dir : str, default "./trained_model"
        Directory holding the ``.pth`` state-dict files.
    strict : bool, default False
        Forwarded to ``load_state_dict``. With ``False`` (the original behavior),
        mismatched keys are tolerated but reported through ``warnings.warn``.

    Returns
    -------
    SRkNNAttentionMIL
        The network switched to eval mode and moved to ``device``.
    """
    # Constructor arguments must match the configuration the checkpoint was
    # trained with, otherwise parameter names/shapes will not line up.
    net = SRkNNAttentionMIL(
        color_img=False,
        num_classes=1,
        loss="deepsurv",
        lr=1e-4,
        patch_size=90,
        img_dim=2,
        pos_encoding=True,
        extraction_layer="conv",
        hist_output_size=20,
        ssl_feat_pretrain_fname=None,
        feat_embedding_size=512,
        att_embedding_size=256,
        knn_att_type="both",
        topk_R=5,
        topk_S=1,
        spatial_dist_mat=spatial_dist,
        training=False,
    )

    weights_path = os.path.join(weights_dir, f"{model_name}.pth")
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
    # the whole network is moved to `device` afterwards.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location="cpu")
    incompatible = net.model.load_state_dict(state_dict, strict=strict)

    # With strict=False a wrong or renamed checkpoint would otherwise "load"
    # silently while leaving unmatched layers at their constructor-initialized values.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {weights_path!r} does not fully match the model: "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"{list(incompatible.missing_keys)[:5]}, "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"{list(incompatible.unexpected_keys)[:5]}"
        )

    # eval() on the wrapper recurses into net.model and any other submodules,
    # so e.g. dropout / batch-norm layers anywhere in the network use inference behavior.
    net.eval()
    net.to(device)
    return net

MODEL_NAME = "xray_knnboth_patch7_r5_s1_best"  # change to yours
net = load_model(MODEL_NAME)
print("Model loaded")


In [None]:
# =============================
# DR Score inference (FULL)
# =============================

# Calibration of the linear rescaling: RAW_MIN maps to TARGET_MIN and RAW_MAX
# maps to TARGET_MAX. NOTE(review): presumably the observed min/max of raw model
# outputs on a reference set; confirm their provenance.
RAW_MIN = -3.255
RAW_MAX = 1.493
TARGET_MIN = 0.0
TARGET_MAX = 4.0


def rescale_dr_score(raw_score: float) -> float:
    """
    Linearly map a raw DR score from [RAW_MIN, RAW_MAX] onto [TARGET_MIN, TARGET_MAX].

    The mapping is NOT clipped. A raw score below RAW_MIN maps below TARGET_MIN,
    and one above RAW_MAX maps above TARGET_MAX. Callers that need a bounded
    score must clamp the result themselves.

    Parameters
    ----------
    raw_score : float
        Raw model output.

    Returns
    -------
    float
        The rescaled score.
    """
    fraction = (raw_score - RAW_MIN) / (RAW_MAX - RAW_MIN)
    return float(TARGET_MIN + fraction * (TARGET_MAX - TARGET_MIN))


def predict_dr_score(image_path: str, clip: bool = True, network=None):
    """
    Predict the DR score for a single X-ray image.

    Parameters
    ----------
    image_path : str
        Path to an image file readable by PIL.
    clip : bool, default True
        If True, clamp the rescaled score to [TARGET_MIN, TARGET_MAX]. Without
        clipping, raw outputs outside [RAW_MIN, RAW_MAX] rescale outside that range.
    network : callable, optional
        Model to run on the preprocessed batch; defaults to the notebook-level ``net``.

    Returns
    -------
    raw_score : float
        Raw model output (risk score): first element of the flattened output.
    scaled_score : float
        ``raw_score`` passed through ``rescale_dr_score`` (clipped if ``clip``).
    """
    model_fn = net if network is None else network

    # Use a context manager so the file handle is released promptly. Converting
    # to RGB first normalizes the input mode (e.g. palette or RGBA) before
    # grayscale conversion.
    with Image.open(image_path) as pil_img:
        rgb_img = pil_img.convert("RGB")
    gray_img = TF.rgb_to_grayscale(rgb_img)

    # 630x630 matches the img_shape the spatial distance matrix was built for.
    trsf = GetValidTransforms((630, 630))
    batch = trsf(gray_img).unsqueeze(0).to(device)  # prepend a batch dim of 1

    # Inference
    with torch.no_grad():
        output = model_fn(batch)

    raw_score = float(output.reshape(-1)[0].item())

    # Rescale, then optionally clamp using the same target range as rescale_dr_score.
    scaled_score = rescale_dr_score(raw_score)
    if clip:
        scaled_score = max(TARGET_MIN, min(TARGET_MAX, scaled_score))

    return raw_score, scaled_score


# =============================
# Test: run inference on one example image
# =============================
image_path = "./new_image/test.png"

raw, scaled = predict_dr_score(image_path)

print(f"Raw DR Score: {raw:.3f}")
# The label previously contained mojibake ("0â€“4"): a UTF-8 en dash mis-decoded as cp1252.
print(f"Final DR Score (0–4): {scaled:.3f}")