# Model Inference & GenAI Demo
This notebook demonstrates inference with the trained survival-prediction model and the generative AI model.


In [None]:
import torch
import sys
import matplotlib.pyplot as plt
from PIL import Image  # NOTE(review): not used by any cell shown here
import numpy as np
from pathlib import Path

# Add the parent of the current working directory to the import path so the
# `src.*` packages below resolve. NOTE(review): ".." is resolved against the
# process CWD, not this notebook's location — presumably the notebook is
# launched from a directory one level below the repo root; confirm.
sys.path.append("..")

from src.models.model_survival import SurvivalCNN
from src.models.model_gen import UNetGenerator
from src.data.dataset import TCGAPatchDataset

# Device used for every model and input tensor in this notebook; checkpoints are
# loaded with map_location=DEVICE so GPU-saved weights also load here.
DEVICE = torch.device("cpu") # Use CPU for demo inference



## 1. Survival Prediction (Phase 1)


In [None]:
# Load the trained survival model (Phase 1).
# eval() is applied unconditionally so later cells never run the model in
# training mode (which changes the behaviour of any dropout / batch-norm layers),
# even when no checkpoint was found.
model_surv = SurvivalCNN().to(DEVICE)
ckpt_path = Path("../checkpoints/survival/best_model.pth")

if ckpt_path.exists():
    # map_location lets a checkpoint saved on another device load on DEVICE.
    model_surv.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    print("Survival Model Loaded.")
else:
    print("Checkpoint not found. Run pipeline first.")
    print("WARNING: using randomly initialised weights; predicted risk scores are meaningless.")
model_surv.eval()



In [None]:
# Run survival inference on the first validation sample and display the patch.
ds = TCGAPatchDataset(split="val", patch_dir="../data/processed/patches",
                      clinical_file="../data/processed/clinical_processed.csv",
                      manifest_file="../data/raw/image_manifest.csv")

# Per-channel statistics used to undo the dataset's Normalize transform before
# plotting. NOTE(review): assumes the ImageNet mean/std (the same values the
# GenAI visualisation cell uses) — confirm against the dataset transforms.
_IMG_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_IMG_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

if len(ds) > 0:
    img, clinical, target = ds[0]
    img_tensor = img.unsqueeze(0).to(DEVICE)  # add batch dim: (C, H, W) -> (1, C, H, W)

    with torch.no_grad():
        risk_score = model_surv(img_tensor)

    # NOTE(review): target is read as (time, event) below — confirm against
    # TCGAPatchDataset. float()/int() give clean numbers instead of tensor reprs.
    print(f"Predicted Risk Score: {risk_score.item():.4f}")
    print(f"Actual Time: {float(target[0]):.1f} months, Event: {int(target[1])}")

    # Undo normalisation so matplotlib shows real colours instead of clipping
    # the out-of-range normalised values. (Not named `display`, which would
    # shadow IPython's display() in the notebook namespace.)
    img_vis = (img.cpu() * _IMG_STD + _IMG_MEAN).clamp(0, 1)
    plt.imshow(img_vis.permute(1, 2, 0))  # CHW -> HWC
    plt.title(f"Risk: {risk_score.item():.2f}")
    plt.axis('off')
    plt.show()
else:
    print("Dataset empty.")



## 2. Generative AI: Rewinding Disease (Phase 3)
Translate late-stage images into an early-stage appearance.


In [None]:
# Load the late-to-early (L2E) generator (Phase 3).
import re

gen = UNetGenerator().to(DEVICE)
ckpt_gen_path = Path("../checkpoints/gen/G_L2E_epoch15.pth")  # preferred (later) epoch


def _epoch_of(path):
    """Return the epoch number embedded in a checkpoint filename, or -1 if none."""
    match = re.search(r"epoch(\d+)", path.stem)
    return int(match.group(1)) if match else -1


if not ckpt_gen_path.exists():
    # Fall back to the latest saved checkpoint following the same naming scheme
    # as the preferred path (G_L2E_epochN.pth). Other files that may sit in this
    # directory are not the L2E generator this cell wants. Checkpoints are
    # ordered by epoch number, not lexically ("epoch9" > "epoch15" as strings).
    avail = list(Path("../checkpoints/gen").glob("G_L2E_epoch*.pth"))
    if avail:
        ckpt_gen_path = max(avail, key=_epoch_of)

if ckpt_gen_path.exists():
    gen.load_state_dict(torch.load(ckpt_gen_path, map_location=DEVICE))
    print(f"GenAI Model Loaded from {ckpt_gen_path.name}")
else:
    print("GenAI Checkpoint not found.")
# Always switch to inference mode so the visualisation cell never runs the
# generator in training mode.
gen.eval()



In [None]:
# Visualise a late-to-early translation of a random validation patch.


def denorm(t):
    """Undo ImageNet-style Normalize on a (3, H, W) tensor and clamp to [0, 1] for display.

    mean/std are created on t's device so this works for tensors on any DEVICE.
    NOTE(review): the generator output goes through the same inverse transform as
    the input; if the generator ends in tanh ([-1, 1]) a different mapping is
    needed — confirm against UNetGenerator and the training transforms.
    """
    mean = torch.tensor([0.485, 0.456, 0.406], device=t.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=t.device).view(3, 1, 1)
    return (t * std + mean).clamp(0, 1)


if len(ds) > 0:
    # Pick a random sample
    idx = np.random.randint(0, len(ds))
    real_img, _, _ = ds[idx]
    real_tensor = real_img.unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        fake_early = gen(real_tensor)

    # Plot side by side; squeeze(0) removes only the batch dim, and .cpu() is
    # needed because matplotlib cannot plot tensors living on a GPU DEVICE.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(denorm(real_img).cpu().permute(1, 2, 0))  # CHW -> HWC
    ax[0].set_title("Input (Original)")
    ax[0].axis('off')

    ax[1].imshow(denorm(fake_early.squeeze(0)).cpu().permute(1, 2, 0))
    ax[1].set_title("Generated (Early Stage)")
    ax[1].axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("Dataset empty.")

