# Libraries

In [1]:
# Imports 
import os, glob, math
import numpy as np
from typing import Tuple, List

import torch, torchvision
from torch import nn
from torchvision import models, transforms

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from skimage import io, exposure

# 📝 Constants, Helpers, Model & Prediction Setup

### This block defines constants, helper functions, model loading, and prediction utilities for the X-ray classification workflow.

In [2]:
# --- Constants ---
ROOT = r"D:\\MedicalAI-Assistant"
ART_DIR = os.path.join(ROOT, 'artifacts')
os.makedirs(ART_DIR, exist_ok=True)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]

# --- Helpers ---
def read_cxr_u8(path: str)-> np.ndarray:
    """Return uint8 grayscale image. Trims borders slightly to remove rulers/markers."""
    ext = os.path.splitext(path)[-1].lower()
    if ext in {'.dcm', '.dicom'}:
        d = pydicom.dcmread(path)
        pix = d.pixel_array.astype(np.float32)
        try:
            pix = apply_voi_lut(pix, d).astype(np.float32)
        except Exception:
            pass
        slope = float(getattr(d, 'RescaleSlope', 1.0))
        inter = float(getattr(d, 'RescaleIntercept', 0.0))
        pix = pix * slope + inter
        if getattr(d, 'PhotometricInterpretation', '').upper() == 'MONOCHROME1':
            pix = np.max(pix)- pix
        pix -= pix.min()
        pix = (pix / (pix.max() + 1e-8) * 255.0).astype(np.uint8)
    else:
        pix = io.imread(path)
        if pix.ndim == 3:
            pix = pix.mean(2)
        if pix.dtype != np.uint8:
            pix = exposure.rescale_intensity(pix, out_range=(0, 255)).astype(np.uint8)
    h, w = pix.shape
    c = int(0.02 * min(h, w))
    if h > 2*c and w > 2*c:
        pix = pix[c:h-c, c:w-c]
    return pix

_to_resnet = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

def img_to_tensor3(img_u8: np.ndarray)-> torch.Tensor:
    if img_u8.ndim == 2:
        img_u8 = np.stack([img_u8]*3, axis=-1)
    return _to_resnet(img_u8)

def load_resnet18(weights_path: str, num_classes=2):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    in_feats = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_feats, num_classes)
    )
    sd = torch.load(weights_path, map_location=DEVICE)
    if isinstance(sd, dict) and 'state_dict' in sd:
        model.load_state_dict(sd['state_dict'], strict=False)
    else:
        model.load_state_dict(sd, strict=False)
    model.eval().to(DEVICE)
    return model

def predict_binary(image_path: str)-> Tuple[int, float, List[float]]:
    u8 = read_cxr_u8(image_path)
    t1 = img_to_tensor3(u8)
    t2 = img_to_tensor3(u8[:, ::-1].copy()) # h-flip
    X = torch.stack([t1, t2]).to(DEVICE)
    with torch.no_grad():
        logits = clf(X)
        logits = logits.mean(0, keepdim=True) # TTA average
        logits = logits / T
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        pred = int(probs.argmax())
    return pred, float(probs[pred]), probs.tolist()

def decision_text(pred_idx: int):
    return CLASS_NAMES[pred_idx]

# --- Model & Temperature Initialization ---
XMODEL_PATH = os.path.join(ART_DIR, 'xray_model.pth')
clf = load_resnet18(XMODEL_PATH, num_classes=2)

CAL_PATH = os.path.join(ART_DIR, 'xray_temp.pt')
if os.path.exists(CAL_PATH):
    T = torch.load(CAL_PATH, map_location='cpu')['T']
else:
    T = 1.0   # fallback if calibration file is missing


In [3]:
def load_resnet18(weights_path: str, num_classes=2):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    in_feats = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_feats, num_classes)
    )
    sd = torch.load(weights_path, map_location=DEVICE)
    if isinstance(sd, dict) and 'state_dict' in sd:
        model.load_state_dict(sd['state_dict'], strict=False)
    else:
        model.load_state_dict(sd, strict=False)
    model.eval().to(DEVICE)
    return model


# 📝 Function: predict_binary

### This function runs binary classification on a single chest X-ray image. It handles preprocessing, inference, and probability extraction.

In [4]:
def predict_binary(image_path: str)-> Tuple[int, float, List[float]]:
    u8 = read_cxr_u8(image_path)
    t1 = img_to_tensor3(u8)
    t2 = img_to_tensor3(u8[:, ::-1].copy()) # h-flip
    X = torch.stack([t1, t2]).to(DEVICE)
    with torch.no_grad():
        logits = clf(X)
        logits = logits.mean(0, keepdim=True) # TTA average
        logits = logits / T
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        pred = int(probs.argmax())
    return pred, float(probs[pred]), probs.tolist()

In [5]:
def decision_text(pred_idx: int):
    return CLASS_NAMES[pred_idx]

###  Paths

In [6]:
import os, glob, torch
from datetime import datetime
from PIL import Image
 # Reuse helpers from notebook 1 (copy the same functions here if running standalone)
 # read_cxr_u8, img_to_tensor3, load_resnet18, predict_binary, decision_text, 
CLASS_NAMES, ART_DIR, DEVICE
ROOT = r"D:\\MedicalAI-Assistant"
ART_DIR = os.path.join(ROOT, 'artifacts')
CLASS_NAMES = ['NORMAL','PNEUMONIA']

### Load classifier & calibrator 

In [7]:
clf = load_resnet18(os.path.join(ART_DIR, 'xray_model.pth'), num_classes=2)
CAL_PATH = os.path.join(ART_DIR, 'xray_temp.pt')

if os.path.exists(CAL_PATH):
    T = torch.load(CAL_PATH, map_location='cpu')['T']
else:
    T = 1.0

### Normal CXR report template

In [8]:
NORMAL_REPORT = (
 "Chest X-ray (PA):\n"
 "• Cardiomediastinal silhouette: Within normal limits.\n"
 "• Lungs: No focal consolidation identified. No pleural effusion.\n"
 "• Pneumothorax: Not seen.\n"
 "• Bones/soft tissues: Unremarkable.\n\n"
 "Impression: No acute cardiopulmonary abnormality."
 )

### Run on one image 

In [9]:
val_norm = os.path.join(ROOT, 'data', 'Xray', 'val', 'NORMAL')
img_path = sorted(glob.glob(os.path.join(val_norm, '*')))[0]
pred_idx, conf, dist = predict_binary(img_path)

def decision_text(pred_idx: int, conf: float = None):
    label = CLASS_NAMES[pred_idx]
    if conf is not None:
        return f"{label} (confidence: {conf:.2f})"
    return label

NORMAL_REPORT = "No abnormal findings detected. Normal chest X-ray."

print(decision_text(pred_idx, conf))
print("\nAuto-report:")
print(NORMAL_REPORT if pred_idx == 0 else "Abnormal findings suspected. Correlate clinically.")


NORMAL (confidence: 0.54)

Auto-report:
No abnormal findings detected. Normal chest X-ray.
