# X‑ray Fracture Detection — Colab Template
*Auto-generated on 2025-09-01. Set `DATA_DIR` in the Configuration cell to your dataset path in Google Drive, then run the cells top-to-bottom.*

This notebook is structured to support **incremental experiments**:
0) Configuration  1) Environment & Drive setup  2) Data loading  3) Model  4) Training (with per-epoch evaluation)  5) LIME explanations  6) Save artifacts  7) Next steps.

> Tip: Keep data in Google Drive, code in GitHub. Use a Colab badge in your README.

## 0. Configuration

In [None]:

# ---- Experiment settings: edit these before running the rest of the notebook ----

# Naming and storage locations (both paths sit under the Drive mount point)
PROJECT_NAME = "xray-fracture-experiments"
MODEL_NAME = "resnet50"
DATA_DIR = "/content/drive/MyDrive/fracatlas/data"  # point this at your dataset root
EXP_DIR = "/content/drive/MyDrive/fracatlas/experiments"  # destination for saved runs/models

# Input pipeline
IMG_SIZE = 224
BATCH_SIZE = 32

# Optimisation
LR = 1e-4
NUM_EPOCHS = 5  # small value for a quick first pass; raise it for real training
SEED = 42

# GitHub coordinates for the optional clone step; leave the strings empty to skip it
GITHUB_BRANCH = "main"
GITHUB_USER = ""  # e.g., "your-username"
GITHUB_REPO = ""  # e.g., "fracture-detection"

print("Config loaded.")


## 1. Environment & Drive

In [None]:

# If on Colab, install/upgrade deps
try:
    import google.colab  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    !pip -q install torch torchvision lime matplotlib --upgrade

from google.colab import drive as _drive if IN_COLAB else None
if IN_COLAB:
    _drive.mount('/content/drive')

import os, random, numpy as np, torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Seed PyTorch, NumPy and the stdlib RNG with the same value so runs are repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Turn off cuDNN autotuning and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


### (Optional) Clone your GitHub repo into the runtime

In [None]:

# Uncomment to clone your repo; notebooks/scripts will be available under /content/<repo>
# if GITHUB_USER and GITHUB_REPO:
#     !git clone -b {GITHUB_BRANCH} https://github.com/{GITHUB_USER}/{GITHUB_REPO}.git
#     %cd {GITHUB_REPO}
# else:
#     print("Skipping GitHub clone (no repo configured).")


## 2. Data Loading (ImageFolder layout)

In [None]:

# Expected structure:
# DATA_DIR/
#   train/
#     fracture/
#     non-fracture/
#   val/
#     fracture/
#     non-fracture/

from torchvision import transforms
# Preprocessing shared by both splits: resize to a square, convert to a tensor,
# then normalise each channel with the fixed mean/std values below.
_preprocess_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# One ImageFolder per split; class labels come from the sub-folder names.
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'val'), transform=transform)

# Both loaders share every setting except shuffling, which is enabled for training only.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

class_names = train_dataset.classes
print("Classes:", class_names)
print(f"Train: {len(train_dataset)} images | Val: {len(val_dataset)} images")


## 3. Model (ResNet50 baseline)

In [None]:

# Start from the ResNet50 weights that torchvision exposes as its default.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Replace the classification head. Its width follows the classes found in the
# training folder rather than a hard-coded 2: a head narrower than the label
# set makes CrossEntropyLoss fail on out-of-range targets. With the expected
# fracture / non-fracture layout this is still 2.
num_ftrs = model.fc.in_features
num_classes = len(class_names)
model.fc = nn.Linear(num_ftrs, num_classes)
model = model.to(device)

# Multi-class cross-entropy on the raw logits; Adam updates every parameter
# (backbone and new head alike) at the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)


## 4. Training Loop

In [None]:

import math, time

def train_one_epoch(loader):
    """Run one optimisation pass over ``loader`` and report epoch averages.

    Relies on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A ``(mean_loss, accuracy)`` tuple, each divided by the number of
        samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size before summing, so the final
        # division yields a per-sample figure (assumes criterion averages over
        # the batch, which is the nn.CrossEntropyLoss default).
        loss_sum += batch_loss.item() * batch_size
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(loader):
    """Score the model on ``loader`` without updating any weights.

    Relies on the module-level ``model``, ``criterion`` and ``device``.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A ``(mean_loss, accuracy)`` tuple, each divided by the number of
        samples seen.
    """
    # Gradient tracking is disabled for the whole function body.
    with torch.no_grad():
        model.eval()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(1) == targets).sum().item()
            n_seen += batch_size
        return loss_sum / n_seen, n_correct / n_seen

# One (epoch, train_loss, train_acc, val_loss, val_acc) tuple per finished epoch.
hist = []
for epoch in range(NUM_EPOCHS):
    # Train first, then score the updated weights on the validation split.
    tr_loss, tr_acc = train_one_epoch(train_loader)
    val_loss, val_acc = evaluate(val_loader)

    # Epochs are stored and printed 1-based.
    epoch_row = (epoch + 1, tr_loss, tr_acc, val_loss, val_acc)
    hist.append(epoch_row)
    print(f"Epoch {epoch+1:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | val loss {val_loss:.4f} acc {val_acc:.3f}")


## 5. LIME Explanations (single image demo)

In [None]:

from lime import lime_image
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

def batch_predict(images):
    """Return class probabilities for a batch of numpy images.

    This is the classifier function handed to LIME. It relies on the
    module-level ``model``, ``device`` and ``IMG_SIZE``.

    Args:
        images: Iterable of HxWxC uint8 numpy arrays.

    Returns:
        A numpy array with one row per input image, holding the softmax of the
        model's logits over dim 1.
    """
    model.eval()

    # Build the preprocessing pipeline once per call rather than once per image.
    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Stack on the CPU and move the whole batch to the device in one transfer.
    batch = torch.stack(
        [preprocess(Image.fromarray(img)) for img in images], dim=0
    ).to(device)

    # Inference only: without no_grad every call would record an autograd
    # graph, which is wasted memory across LIME's many perturbed batches.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
    return probs.cpu().numpy()

# Explain the model's output for the first entry of the validation set.
img_path, true_label = val_dataset.samples[0]
raw = Image.open(img_path).convert("RGB")

# LIME perturbs the image and queries batch_predict on the perturbed copies.
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(
    image=np.array(raw),
    classifier_fn=batch_predict,
    hide_color=0,
    top_labels=2,
    num_samples=1000,
)

# Use the first of the returned top labels as the label to explain
# (NOTE(review): LIME appears to order these by predicted probability — confirm).
top = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    label=top,
    hide_rest=False,
    num_features=5,
    positive_only=True,
)

# Scale the returned image to [0, 1] and outline the selected regions on it.
overlay = mark_boundaries(temp / 255.0, mask)
plt.figure()
plt.imshow(overlay)
plt.title(f"LIME (pred label={top}, true={true_label})")
plt.axis("off")
plt.show()


## 6. Save Artifacts (model + history)

In [None]:

# Make sure the experiment folder exists before writing into it.
os.makedirs(EXP_DIR, exist_ok=True)

# Store the weights only (state_dict), in a file named after the project and model.
model_path = os.path.join(EXP_DIR, f"{PROJECT_NAME}_{MODEL_NAME}.pt")
torch.save(model.state_dict(), model_path)

# Write the per-epoch metrics as CSV: a header row, then one row per entry in hist.
import csv
hist_path = os.path.join(EXP_DIR, f"{PROJECT_NAME}_history.csv")
with open(hist_path, "w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["epoch","train_loss","train_acc","val_loss","val_acc"])
    writer.writerows(hist)

for saved_path in (model_path, hist_path):
    print("Saved:", saved_path)


## 7. Next Steps

- Try **EfficientNet** or **Swin Transformer** backbones
- Handle class imbalance with **Weighted CE** or **Focal Loss**
- Add proper metrics (AUC, sensitivity, specificity)
- If you have boxes/masks, train **YOLO/Faster R-CNN** or **UNet** for localization
- Log experiments with **TensorBoard** or **Weights & Biases**