# X-ray Classification

### This block imports all the required libraries for image processing, deep learning, and evaluation.
### It also defines file paths, ensures the artifacts directory exists, sets the compute device (CPU or GPU), and specifies the X-ray classification labels.

In [9]:
# Imports & Paths 
import os, glob, json, math
import numpy as np
from typing import Dict, Tuple, List

import torch, torchvision
from torch import nn
from torchvision import models, transforms

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from skimage import io, exposure
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# Paths
# NOTE: ROOT is a raw string, so the doubled backslash is kept literally
# (the resulting path has two backslash characters after the drive letter).
ROOT = r"D:\\MedicalAI-Assistant"
# Input data and saved artifacts both live under ROOT.
DATA_DIR = os.path.join(ROOT, 'data')
ART_DIR = os.path.join(ROOT, 'artifacts')
# Create the artifacts directory up front; no error if it already exists.
os.makedirs(ART_DIR, exist_ok=True)

# Device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Classes
# The position in this list is the integer class index.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]

# 📝 Robust Chest X-ray Reader & Preprocessing

### This block defines utility functions for reading and preprocessing chest X-rays from different formats (JPG, PNG, and DICOM).

In [10]:
# Robust CXR reader
import numpy as np

def read_cxr_u8(path: str) -> np.ndarray:
    """Read a chest X-ray file and return it as a 2-D uint8 grayscale array.

    Files ending in ``.dcm`` / ``.dicom`` are read with pydicom; every other
    extension is read with ``skimage.io.imread``. A 2% border (relative to the
    shorter side) is trimmed from each edge to remove rulers/markers.

    Args:
        path: Path of the image file.

    Returns:
        A ``uint8`` array of shape ``(H, W)``.

    Raises:
        ValueError: If the decoded pixel data is not 2-D after conversion
            (for example a multi-frame DICOM).
    """
    ext = os.path.splitext(path)[-1].lower()
    if ext in {'.dcm', '.dicom'}:
        d = pydicom.dcmread(path)
        pix = d.pixel_array.astype(np.float32)

        # Apply the linear modality rescale BEFORE the VOI LUT / window: the
        # window centre and width are expressed in rescaled units, so the
        # reverse order windows the wrong range whenever intercept != 0.
        slope = float(getattr(d, 'RescaleSlope', 1.0))
        inter = float(getattr(d, 'RescaleIntercept', 0.0))
        pix = pix * slope + inter

        try:
            pix = apply_voi_lut(pix, d).astype(np.float32)
        except Exception:
            # Deliberate best effort: if the VOI LUT cannot be applied, fall
            # back to the rescaled values, which are min-max normalised below.
            pass

        # MONOCHROME1 stores bright values as dark; invert so that both
        # photometric interpretations end up with the same polarity.
        if getattr(d, 'PhotometricInterpretation', '').upper() == 'MONOCHROME1':
            pix = np.max(pix) - pix

        # Min-max normalise to 0..255; the epsilon guards a constant image.
        pix -= pix.min()
        pix = (pix / (pix.max() + 1e-8) * 255.0).astype(np.uint8)
    else:
        pix = io.imread(path)
        if pix.ndim == 3:
            # Drop an alpha channel (gray+alpha or RGBA) so transparency is
            # not averaged into the intensity.
            if pix.shape[2] in (2, 4):
                pix = pix[..., :-1]
            pix = pix.mean(2)
        if pix.dtype != np.uint8:
            pix = exposure.rescale_intensity(
                pix, out_range=(0, 255)
            ).astype(np.uint8)

    if pix.ndim != 2:
        raise ValueError(
            f"Expected a 2-D image after decoding, got shape {pix.shape}: {path}"
        )

    # Trim 2% of the shorter side from every edge, unless the image is too
    # small for that to leave anything.
    h, w = pix.shape
    c = int(0.02 * min(h, w))
    if h > 2 * c and w > 2 * c:
        pix = pix[c:h - c, c:w - c]
    return pix

# Preprocessing pipeline for the ResNet input: array -> PIL image -> 224x224
# -> float tensor -> per-channel normalisation (the mean/std values are the
# ones commonly used with ImageNet-pretrained torchvision models).
_to_resnet = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def img_to_tensor3(img_u8: np.ndarray) -> torch.Tensor:
    """Convert an image array into the tensor produced by ``_to_resnet``.

    A 2-D (grayscale) array is replicated into three identical channels along
    a new last axis; arrays of any other rank are passed through unchanged.

    Args:
        img_u8: Image array, named for the uint8 output of the reader.

    Returns:
        The result of applying the ``_to_resnet`` pipeline.
    """
    is_grayscale = img_u8.ndim == 2
    rgb = np.stack((img_u8, img_u8, img_u8), axis=-1) if is_grayscale else img_u8
    return _to_resnet(rgb)


# 📝 Load Chest X-ray Classifier (ResNet18)

### This block defines and loads a ResNet18 classifier for detecting chest conditions (e.g., Normal vs Pneumonia).

In [11]:
# Load classifier (ResNet18) & flexible state dict loader
def load_resnet18(weights_path: str, num_classes=2):
    """Build a ResNet18 classifier and load a checkpoint into it.

    The network starts from torchvision's ImageNet weights and its ``fc``
    layer is replaced by ``Dropout(0.2) -> Linear(in_feats, num_classes)``.
    The checkpoint may be a bare state dict or a dict holding one under
    ``'state_dict'``. Two common key layouts are normalised before loading:

    * a ``module.`` prefix on every key is stripped;
    * ``fc.weight`` / ``fc.bias`` (head saved as a plain Linear) are mapped
      onto ``fc.1.weight`` / ``fc.1.bias`` when their shapes match this head.

    Loading stays non-strict, but any keys that still do not line up are
    reported with a warning instead of being dropped silently, because the
    affected layers would otherwise keep their initial weights unnoticed.

    Args:
        weights_path: Path of the checkpoint file.
        num_classes: Number of output classes of the replacement head.

    Returns:
        The model in eval mode on ``DEVICE``.
    """
    import warnings
    from collections import OrderedDict

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    in_feats = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_feats, num_classes),
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(weights_path, map_location=DEVICE)
    sd = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt

    if isinstance(sd, dict):
        target_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
        remapped = OrderedDict()
        for key, value in sd.items():
            new_key = key
            if isinstance(new_key, str):
                if new_key.startswith('module.'):
                    new_key = new_key[len('module.'):]
                if new_key in ('fc.weight', 'fc.bias'):
                    candidate = 'fc.1.' + new_key[len('fc.'):]
                    # Only remap when the tensor really fits this head, so a
                    # checkpoint with a different class count is still skipped
                    # (as before) rather than raising a size-mismatch error.
                    if tuple(getattr(value, 'shape', ())) == target_shapes.get(candidate):
                        new_key = candidate
            remapped[new_key] = value
        # Keep the versioning metadata that load_state_dict looks for.
        metadata = getattr(sd, '_metadata', None)
        if metadata is not None:
            remapped._metadata = metadata
        sd = remapped

    result = model.load_state_dict(sd, strict=False)
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"{weights_path}: {len(missing)} missing and {len(unexpected)} "
            f"unexpected keys while loading; layers with missing keys keep "
            f"their initial weights. Missing (first 5): {missing[:5]}; "
            f"unexpected (first 5): {unexpected[:5]}"
        )

    model.eval().to(DEVICE)
    return model

# Checkpoint location inside the artifacts directory.
XMODEL_PATH = os.path.join(ART_DIR, 'xray_model.pth') # ensure this exists
# Build the two-class model once at module level (runs when this cell executes).
clf = load_resnet18(XMODEL_PATH, num_classes=2)


# 📝 Temperature Scaling & Prediction with TTA

### This block improves prediction reliability with temperature scaling and makes predictions more robust using test-time augmentation.

In [12]:
# Temperature scaling calibrator --
class _TempScaler(nn.Module):
 def __init__(self, T=1.0):
  super().__init__()
  self.logT = nn.Parameter(torch.tensor(math.log(T), dtype=torch.float32))

 def forward(self, logits):
  T = torch.exp(self.logT)
  return logits / T


CAL_PATH = os.path.join(ART_DIR, 'xray_temp.pt')


def fit_temperature(model, val_dir: str, max_iter=500, lr=0.01, batch_size=32) -> float:
    """Fit a softmax temperature on a validation folder by minimising NLL.

    Images are read from ``val_dir/<class name>/*`` for every entry of
    ``CLASS_NAMES``. The model's logits are computed once without gradients,
    then a single temperature is optimised with L-BFGS on the cross-entropy
    of the scaled logits. The result is saved to ``CAL_PATH`` as ``{'T': T}``.

    Args:
        model: Classifier that maps an image batch to logits.
        val_dir: Folder with one sub-folder per class name.
        max_iter: Maximum L-BFGS iterations for the single optimiser step.
        lr: L-BFGS learning rate.
        batch_size: Number of images per forward pass. Logits are computed in
            batches so the whole validation set never has to fit on the
            device at once.

    Returns:
        The fitted temperature.

    Raises:
        RuntimeError: If no validation image could be read.
    """
    model.eval()
    scaler = _TempScaler(1.0).to(DEVICE)
    opt = torch.optim.LBFGS(list(scaler.parameters()), lr=lr, max_iter=max_iter)

    imgs, labels = [], []
    skipped = 0
    for cls_idx, cls in enumerate(CLASS_NAMES):
        for p in sorted(glob.glob(os.path.join(val_dir, cls, '*'))):
            try:
                t = img_to_tensor3(read_cxr_u8(p))
            except Exception as exc:
                # Best effort: an unreadable file is skipped, but reported so
                # that a shrinking calibration set does not go unnoticed.
                skipped += 1
                print(f"fit_temperature: skipping {p}: {exc}")
                continue
            imgs.append(t)
            labels.append(cls_idx)

    if not imgs:
        raise RuntimeError('Validation set empty or unreadable')
    if skipped:
        print(f"fit_temperature: skipped {skipped} unreadable file(s)")

    y = torch.tensor(labels, dtype=torch.long, device=DEVICE)

    # Forward pass in batches; only the small logit tensors are accumulated.
    step = max(1, int(batch_size))
    chunks = []
    with torch.no_grad():
        for start in range(0, len(imgs), step):
            batch = torch.stack(imgs[start:start + step]).to(DEVICE)
            chunks.append(model(batch))
    logits = torch.cat(chunks)

    ce = nn.CrossEntropyLoss()

    def closure():
        # L-BFGS re-evaluates the loss several times per step.
        opt.zero_grad()
        loss = ce(scaler(logits), y)
        loss.backward()
        return loss

    opt.step(closure)
    T = float(torch.exp(scaler.logT).item())
    torch.save({'T': T}, CAL_PATH)
    return T


# Load or fit temperature
# Reuse the cached temperature when the calibration file exists; otherwise
# fit it on the validation folder (fit_temperature also writes CAL_PATH).
if os.path.exists(CAL_PATH):
 T = torch.load(CAL_PATH, map_location='cpu')['T']
else:
 VAL_DIR = os.path.join(DATA_DIR, 'Xray', 'val')
 T = fit_temperature(clf, VAL_DIR)
 # Only printed on the fitting path; a cached value is loaded silently.
 print(f"Temperature T = {T:.3f}")


# Helper to apply T and predict with TTA
from PIL import Image

def predict_binary(image_path: str) -> Tuple[int, float, List[float]]:
    """Classify one image using flip test-time augmentation.

    The image and its horizontal mirror are run through ``clf`` as one batch;
    the two logit rows are averaged, divided by the module-level temperature
    ``T`` and turned into probabilities with softmax.

    Args:
        image_path: Path of the image to classify.

    Returns:
        ``(predicted class index, probability of that class, all class
        probabilities as a list)``.
    """
    gray = read_cxr_u8(image_path)
    # The copy turns the negative-stride flipped view into a contiguous array.
    views = [gray, gray[:, ::-1].copy()]
    batch = torch.stack([img_to_tensor3(view) for view in views]).to(DEVICE)

    with torch.no_grad():
        mean_logits = clf(batch).mean(0, keepdim=True)
        probs = torch.softmax(mean_logits / T, dim=1).cpu().numpy()[0]

    pred = int(probs.argmax())
    return pred, float(probs[pred]), probs.tolist()


In [13]:
# Evaluate on a folder (val)
# Runs predict_binary on every file under VAL_DIR/<class name>/ and prints a
# classification report plus confusion matrix for the files that were read.
VAL_DIR = os.path.join(DATA_DIR, 'Xray', 'val')
all_y, all_p, all_paths = [], [], []
eval_skipped = 0

for cls_idx, cls in enumerate(CLASS_NAMES):
    folder = os.path.join(VAL_DIR, cls)
    for p in sorted(glob.glob(os.path.join(folder, '*'))):
        try:
            pred, conf, _ = predict_binary(p)
        except Exception as exc:
            # Best effort: keep evaluating, but say which file was dropped so
            # the reported support can be reconciled with the folder content.
            eval_skipped += 1
            print(f"Skipping {p}: {exc}")
            continue
        all_y.append(cls_idx)
        all_p.append(pred)
        all_paths.append(p)

if eval_skipped:
    print(f"Skipped {eval_skipped} unreadable file(s)")

if all_y:
    # Passing the label ids explicitly keeps the report aligned with
    # CLASS_NAMES even when a class never appears in the labels/predictions;
    # zero_division=0 reports 0.0 for undefined metrics without warnings.
    label_ids = list(range(len(CLASS_NAMES)))
    print(classification_report(
        all_y, all_p,
        labels=label_ids,
        target_names=CLASS_NAMES,
        zero_division=0,
    ))
    print(confusion_matrix(all_y, all_p, labels=label_ids))
else:
    print('No validation images could be evaluated in', VAL_DIR)


              precision    recall  f1-score   support

      NORMAL       0.00      0.00      0.00         4
   PNEUMONIA       0.43      1.00      0.60         3

    accuracy                           0.43         7
   macro avg       0.21      0.50      0.30         7
weighted avg       0.18      0.43      0.26         7

[[0 4]
 [0 3]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


# 📝 Decision Helper & Quick Test

In [15]:
# Decision helper --
UNCERTAIN_THRESH = 0.55 # after temperature scaling, 0.55 is a sane default


def decision_text(pred_idx: int, conf: float) -> str:
    """Turn a prediction into a short human-readable message.

    Args:
        pred_idx: Index into ``CLASS_NAMES`` of the predicted class.
        conf: Probability assigned to the predicted class.

    Returns:
        ``"Prediction: <label> (confidence x.xx)"`` when ``conf`` reaches
        ``UNCERTAIN_THRESH``; otherwise a message that marks the result as
        uncertain and asks for clinical correlation.
    """
    label = CLASS_NAMES[pred_idx]
    if conf < UNCERTAIN_THRESH:
        # A low-confidence output is not evidence of a normal study, so it is
        # reported as uncertain rather than as "Normal".
        return (
            f"Uncertain — leaning {label} (confidence {conf:.2f}). "
            "Needs clinical correlation."
        )
    return f"Prediction: {label} (confidence {conf:.2f})"


# Quick test on one image
# Takes the first file that glob returns from the NORMAL validation folder.
sample = os.path.join(DATA_DIR, 'Xray', 'val', 'NORMAL')
try:
 # [0] raises IndexError on an empty folder; the except below reports it.
 one = glob.glob(os.path.join(sample, '*'))[0]
 pi, cf, _ = predict_binary(one)
 print(one, '\n', decision_text(pi, cf))
except Exception as e:
 # Best-effort smoke test: print the reason instead of failing the cell.
 print('Sample test skipped:', e)


D:\\MedicalAI-Assistant\data\Xray\val\NORMAL\person1947_bacteria_4876.jpeg 
 Prediction: PNEUMONIA (confidence 0.55)
