In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
%pip install -q torch torchvision pydicom opencv-python matplotlib pandas tqdm
# SigLIP models and `trust_remote_code` need a recent transformers release;
# the previous pin (4.4.0) predates both.
%pip install -q "transformers>=4.37"


In [None]:
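# Notes: mapping VinDr-CXR findings (class_id - name) onto CLASS_NAMES.
# (Kept as comments so this cell survives Restart & Run All.)
# Marks used in these notes: "-->" / ">" = has a counterpart in CLASS_NAMES,
# "x" = no counterpart, "-" = tentative match, no mark = not yet annotated.
#
# VinDr-CXR findings:
# 0  - Aortic enlargement   x
# 1  - Atelectasis          --> "Atelectasis"
# 2  - Calcification        x
# 3  - Cardiomegaly         --> "Cardiomegaly"
# 4  - Consolidation        --> "Consolidation"
# 5  - ILD                  x
# 6  - Infiltration         --> "Infiltration"
# 7  - Lung Opacity         >
# 8  - Nodule/Mass          >
# 9  - Other lesion         -
# 10 - Pleural effusion
# 11 - Pleural thickening   >
# 12 - Pneumothorax         >
# 13 - Pulmonary fibrosis   >
#
# CLASS_NAMES (the model's label space), marked the same way:
# [
#     "Atelectasis" >, "Consolidation" >, "Infiltration" >, "Pneumothorax" >,
#     "Edema", "Emphysema" x, "Fibrosis" >, "Effusion" >, "Pneumonia" x,
#     "Pleural_Thickening" >, "Cardiomegaly" >, "Nodule", "Mass", "Hernia",
#     "Lung Lesion" -, "Fracture", "Lung Opacity" >, "Enlarged Cardiomediastinum"
# ]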

In [None]:
import math
import os
import random
import time
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageDraw
import pydicom
import cv2  # to help convert DICOM to uint8

from torchvision import transforms as T
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

from transformers import AutoConfig, AutoModel
from tqdm import tqdm

warnings.filterwarnings("ignore", category=UserWarning)

# ────────────────────────────────────────────────────────────────────────
# 1) GLOBAL CONFIGURATION (adjust these as needed)
# ────────────────────────────────────────────────────────────────────────

# ─── Paths to VinDr data ───────────────────────────────────────────────
VINDR_ROOT       = "/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/"            # <— change this
DICOM_TRAIN_DIR  = os.path.join(VINDR_ROOT, "train")
ANNOTATION_CSV   = os.path.join(VINDR_ROOT, "train.csv")
# No separate validation split is read; main() splits the training annotations itself.
CHECKPOINT_DIR   = "./checkpoints"                  # where model checkpoints and the loss curve are written
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ─── Reproducibility ───────────────────────────────────────────────────
# Seed every RNG the notebook uses so Restart & Run All is repeatable
# (some GPU kernels may still be non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ─── Training hyperparameters ──────────────────────────────────────────
NUM_EPOCHS        = 12
BATCH_SIZE_TRAIN  = 4      # reduce if you run out of GPU memory
BATCH_SIZE_VAL    = 8
NUM_WORKERS       = 4

LEARNING_RATE_BACKBONE = 5e-6    # when fine-tuning the entire backbone
LEARNING_RATE_HEAD     = 1e-4    # RPN and detection (ROI) heads
WEIGHT_DECAY           = 1e-2

# ─── SigLIP ViT settings ───────────────────────────────────────────────
SIGLIP_MODEL_NAME = "StanfordAIMI/XraySigLIP__vit-l-16-siglip-384__webli"
DEVICE            = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_SIZE        = 512  # every image is resized to IMAGE_SIZE × IMAGE_SIZE
PATCH_SIZE        = 16   # SigLIP ViT-L uses 16×16 patches; IMAGE_SIZE should be a multiple
# Model label space (CheXpert/NIH-style names). VinDr-CXR's own finding names
# differ; only annotations whose label resolves to one of these names are kept.
CLASS_NAMES       = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "Lung Lesion", "Fracture", "Lung Opacity", "Enlarged Cardiomediastinum"
]
# Number of foreground classes (no background), derived so it cannot drift from CLASS_NAMES.
NUM_CLASSES       = len(CLASS_NAMES)
# Disease name → integer label. Label 0 is reserved for background, so labels start at 1.
CLASS2IDX = {cls_name: idx + 1 for idx, cls_name in enumerate(CLASS_NAMES)}

# ────────────────────────────────────────────────────────────────────────
# 2) UTILITY FUNCTIONS: DICOM → PIL, parse VinDr CSV, etc.
# ────────────────────────────────────────────────────────────────────────

def dicom_to_pil(dicom_path: str) -> Image.Image:
    """
    Read a DICOM file, convert to 8‐bit grayscale, then to RGB PIL.
    This uses a simple windowing approach: scale pixel intensities to [0,255].
    """
    ds = pydicom.dcmread(dicom_path)
    arr = ds.pixel_array.astype(np.float32)
    # normalize to [0, 255]
    arr -= arr.min()
    arr /= arr.max() + 1e-6
    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
    # convert to 3‐channel by cv2.cvtColor
    arr_rgb = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    pil = Image.fromarray(arr_rgb)
    return pil

# VinDr-CXR finding names indexed by integer class_id. Ids 0-13 are the ones listed in
# the mapping-notes cell; 14 is the box-less "No finding" row of the Kaggle CSV.
_VINDR_CLASS_NAMES = [
    "Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung Opacity", "Nodule/Mass",
    "Other lesion", "Pleural effusion", "Pleural thickening", "Pneumothorax",
    "Pulmonary fibrosis", "No finding",
]

# VinDr finding names whose CLASS_NAMES counterpart is spelled differently.
# Names that match CLASS_NAMES exactly need no entry here.
# TODO: "Nodule/Mass" and "Other lesion" have no agreed target yet, so they are skipped.
_VINDR_TO_CLASS_NAME = {
    "Pleural effusion": "Effusion",
    "Pleural thickening": "Pleural_Thickening",
    "Pulmonary fibrosis": "Fibrosis",
}


def _vindr_label_name(raw) -> str:
    """Turn a class_name/class_id cell into a finding name; integer ids go through _VINDR_CLASS_NAMES."""
    try:
        as_int = int(raw)
    except (TypeError, ValueError):
        return str(raw)
    if as_int == raw and 0 <= as_int < len(_VINDR_CLASS_NAMES):
        return _VINDR_CLASS_NAMES[as_int]
    return str(raw)


def load_vindr_annotations(csv_path: str, class2idx=None, name_aliases=None):
    """
    Parse a VinDr annotation CSV into per-image lists of boxes.

    Supported column layouts:
      * Kaggle VinBigData: image_id, class_name, class_id, rad_id, x_min, y_min, x_max, y_max
      * image_id, x_min, y_min, width, height, class_id   (boxes given as width/height)

    Where the label comes from:
      * ``class_name`` is used when that column exists; otherwise ``class_id``.
      * An integer id is translated through ``_VINDR_CLASS_NAMES``; any other value is
        used as a name directly.
      * The name is then passed through ``name_aliases`` and looked up in ``class2idx``.

    Rows are skipped when the name is not found, when coordinates are missing (NaN),
    or when the box has zero area.

    Args:
        csv_path: path to the annotation CSV.
        class2idx: name → integer label; defaults to the global CLASS2IDX.
        name_aliases: VinDr name → CLASS_NAMES name; defaults to _VINDR_TO_CLASS_NAME.

    Returns:
        { image_id: [ {"bbox": [x1, y1, x2, y2], "label": int}, … ], … }
        Coordinates are in original DICOM pixel space. Images with no kept box are absent.

    Raises:
        KeyError: if the CSV has neither x_max/y_max nor width/height columns.
    """
    if class2idx is None:
        class2idx = CLASS2IDX
    if name_aliases is None:
        name_aliases = _VINDR_TO_CLASS_NAME

    df = pd.read_csv(csv_path)
    columns = set(df.columns)
    if {"x_max", "y_max"} <= columns:
        has_corners = True
    elif {"width", "height"} <= columns:
        has_corners = False
    else:
        raise KeyError(
            f"{csv_path}: expected x_max/y_max or width/height columns, got {sorted(columns)}"
        )
    label_column = "class_name" if "class_name" in columns else "class_id"

    records = {}
    for row in df.to_dict("records"):
        name = _vindr_label_name(row[label_column])
        lbl = class2idx.get(name_aliases.get(name, name))
        if lbl is None:
            # Label outside the model's label space.
            continue
        x1, y1 = float(row["x_min"]), float(row["y_min"])
        if has_corners:
            x2, y2 = float(row["x_max"]), float(row["y_max"])
        else:
            x2, y2 = x1 + float(row["width"]), y1 + float(row["height"])
        # Box-less rows have NaN coordinates. torchvision also rejects zero-area boxes.
        if any(math.isnan(v) for v in (x1, y1, x2, y2)) or x2 <= x1 or y2 <= y1:
            continue
        entry = {"bbox": [x1, y1, x2, y2], "label": lbl}
        records.setdefault(str(row["image_id"]), []).append(entry)
    return records

# ────────────────────────────────────────────────────────────────────────
# 3) DATASET CLASSES
# ────────────────────────────────────────────────────────────────────────

class VinDrCXRDetectionDataset(Dataset):
    """
    A PyTorch Dataset for VinDr CXR detection.

    Each item returns:
        image_tensor (3×H×W, the size produced by ``transform``),
        target { "boxes": FloatTensor[K×4], "labels": Int64Tensor[K],
                 "image_id": Int64Tensor[1] (dataset index),
                 "orig_size": Int64Tensor[2] (H_orig, W_orig) }.
    Boxes are rescaled from original DICOM pixels to the transformed image size.
    """

    def __init__(self, image_dir: str, annotations: dict, transform=None):
        """
        image_dir: folder containing DICOM files named "<image_id>.dcm".
        annotations: image_id (no .dcm) → list of { "bbox": [x1,y1,x2,y2], "label": idx }.
        transform: callable mapping a PIL image to a CHW tensor (e.g. Resize + ToTensor +
            Normalize). If None, the image is converted to a float tensor in [0, 1] at
            its original size.
        """
        super().__init__()
        self.image_dir    = image_dir
        self.annotations  = annotations
        self.transform    = transform
        # Every key of `annotations` is one sample, so callers decide whether
        # images without boxes are included.
        self.image_ids = sorted(self.annotations.keys())

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _scaled_targets(anns, sx: float, sy: float):
        """Return (boxes[K,4] float32, labels[K] int64), with box x scaled by sx and y by sy."""
        boxes = [
            [x1 * sx, y1 * sy, x2 * sx, y2 * sy]
            for x1, y1, x2, y2 in (ann["bbox"] for ann in anns)
        ]
        labels = [ann["label"] for ann in anns]
        # reshape keeps an empty box list as [0, 4], the shape torchvision requires.
        return (
            torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            torch.tensor(labels, dtype=torch.int64),
        )

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        dicom_path = os.path.join(self.image_dir, f"{image_id}.dcm")
        pil_img = dicom_to_pil(dicom_path)
        orig_w, orig_h = pil_img.size

        if self.transform is not None:
            img_tensor = self.transform(pil_img)
        else:
            img_tensor = torch.from_numpy(
                np.asarray(pil_img, dtype=np.float32) / 255.0
            ).permute(2, 0, 1).contiguous()

        # Scale boxes to whatever size the transform actually produced, rather
        # than assuming a fixed 512×512 output.
        out_h, out_w = img_tensor.shape[-2:]
        sx = float(out_w) / float(orig_w)
        sy = float(out_h) / float(orig_h)
        boxes, labels = self._scaled_targets(self.annotations[image_id], sx, sy)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "orig_size": torch.tensor([orig_h, orig_w])
        }
        return img_tensor, target


def collate_fn(batch):
    """
    DataLoader collate function for detection batches.

    Samples are (image_tensor, target_dict) pairs. They are not stacked into one
    tensor. Instead they are split into two parallel lists, which is the input
    format torchvision detection models take.

    Returns:
        (images_list, targets_list)
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
    return images, targets


# ────────────────────────────────────────────────────────────────────────
# 4) BUILD SIGLIP BACKBONE FOR Faster R-CNN
# ────────────────────────────────────────────────────────────────────────

class SigLIPBackbone(nn.Module):
    """
    Wrap a pretrained ViT vision tower so its patch tokens become one feature map.

    forward() returns {"0": Tensor[B, hidden_dim, H/patch, W/patch]}. The Faster R-CNN
    heads read feature map "0".

    Special tokens:
      * SigLIP vision towers emit one token per patch and have no [CLS] token.
      * Classic ViTs prepend one (or more) extra tokens.
      * Any tokens beyond the patch grid are treated as leading extras and dropped.
    """

    def __init__(self, vision_model, image_size=None, patch_size=None,
                 interpolate_pos_encoding=False):
        """
        vision_model: callable as vision_model(pixel_values=..., return_dict=True),
            returning an object with `.last_hidden_state` of shape [B, N, D].
            It must also have `.config.hidden_size`.
        image_size: nominal square input size; defaults to IMAGE_SIZE. It only sets
            `feature_size`, because forward() derives the grid from the actual input.
        patch_size: ViT patch size; defaults to vision_model.config.patch_size,
            then to PATCH_SIZE.
        interpolate_pos_encoding: if True, pass interpolate_pos_encoding=True to the
            vision model. Use it when inputs differ from the checkpoint's native
            resolution (e.g. 512 px into a 384 px model). The model's forward must
            accept this keyword (recent transformers SigLIP does; confirm for
            remote-code models).
        """
        super().__init__()
        self.vision = vision_model
        config = vision_model.config
        self.hidden_dim = config.hidden_size  # e.g. 1024 for ViT-L
        self.out_channels = self.hidden_dim   # attribute FasterRCNN reads from a backbone
        if patch_size is None:
            patch_size = getattr(config, "patch_size", None)
        if patch_size is None:
            patch_size = PATCH_SIZE
        self.patch_size = int(patch_size)
        if image_size is None:
            image_size = IMAGE_SIZE
        # Grid side for nominal square inputs (kept for backward compatibility).
        self.feature_size = int(image_size) // self.patch_size
        self.interpolate_pos_encoding = interpolate_pos_encoding

    def forward(self, x):
        """
        x: Tensor[B, 3, H, W], normalized as the vision model expects.
        returns: { "0": Tensor[B, hidden_dim, H // patch_size, W // patch_size] }
        """
        kwargs = {"pixel_values": x, "return_dict": True}
        if self.interpolate_pos_encoding:
            kwargs["interpolate_pos_encoding"] = True
        tokens = self.vision(**kwargs).last_hidden_state  # [B, N, D]
        B, N, D = tokens.shape
        # Derive the grid from the real input size: FasterRCNN's internal transform
        # may resize or pad images, so a fixed grid would be wrong.
        H = x.shape[-2] // self.patch_size
        W = x.shape[-1] // self.patch_size
        n_extra = N - H * W
        if n_extra < 0:
            raise ValueError(
                f"Vision model returned {N} tokens, fewer than the {H}x{W} patch grid "
                f"expected for input of size {tuple(x.shape[-2:])}"
            )
        patch_tokens = tokens[:, n_extra:, :]  # drop [CLS]-style tokens, if any
        feats = patch_tokens.transpose(1, 2).reshape(B, D, H, W)
        return {"0": feats}


def make_fasterrcnn_model(num_classes):
    """
    Create a Faster R-CNN whose backbone is SigLIPBackbone.

    - rpn_anchor_generator: 5 anchor sizes × 3 aspect ratios on the single feature map.
    - box_roi_pool: RoIAlign over feature map "0" to 7×7.
    - box_head, box_predictor: torchvision defaults.

    Args:
        num_classes: number of foreground classes; one background class is added.

    Returns:
        A torchvision FasterRCNN on DEVICE. It expects inputs that are already resized
        to IMAGE_SIZE × IMAGE_SIZE and normalized (see the dataset transform).
    """
    # ① Load the pretrained SigLIP model and keep only its vision tower.
    vision_config = AutoConfig.from_pretrained(SIGLIP_MODEL_NAME, trust_remote_code=True)
    vision_full   = AutoModel.from_pretrained(SIGLIP_MODEL_NAME, config=vision_config, trust_remote_code=True)
    vision_model  = vision_full.vision_model.to(DEVICE)
    del vision_full
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ② Wrap into backbone
    backbone = SigLIPBackbone(vision_model)
    backbone.out_channels = backbone.hidden_dim  # e.g. 1024 for ViT-L

    # ③ The backbone yields ONE feature map, so AnchorGenerator needs exactly ONE
    # tuple of sizes and ONE tuple of ratios. Passing one tuple per size (5 tuples)
    # means "5 feature maps" and makes AnchorGenerator raise. Here all 5 sizes × 3
    # ratios = 15 anchors are placed at every location of map "0".
    # 32 → small nodules; 512 → large opacities.
    rpn_anchor_generator = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )

    # ④ ROI Pooling: use only the single feature map "0"
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=["0"],
        output_size=7,
        sampling_ratio=2
    )

    # ⑤ Build Faster R-CNN. The dataset transform already resizes to IMAGE_SIZE
    # and normalizes. FasterRCNN's defaults would resize again to 800 px and
    # re-normalize with ImageNet statistics, so make its internal transform a no-op.
    model = FasterRCNN(
        backbone=backbone,
        num_classes=num_classes + 1,  # +1 for background
        rpn_anchor_generator=rpn_anchor_generator,
        box_roi_pool=roi_pooler,
        min_size=IMAGE_SIZE,
        max_size=IMAGE_SIZE,
        image_mean=[0.0, 0.0, 0.0],
        image_std=[1.0, 1.0, 1.0],
    ).to(DEVICE)

    return model


# ────────────────────────────────────────────────────────────────────────
# 5) TRAIN / VALIDATION LOOPS
# ────────────────────────────────────────────────────────────────────────

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=50, scheduler=None):
    """
    Run one training epoch of a torchvision-style detection model.

    Args:
        model: detector that returns a dict of losses in train mode.
        optimizer: optimizer over the model's parameters.
        data_loader: yields (list of image tensors, list of target dicts).
        device: device that images and target boxes/labels are moved to.
        epoch: epoch number, used only for the progress-bar label.
        print_freq: show the latest batch loss on the progress bar every N batches.
        scheduler: optional LR scheduler, stepped after every optimizer step
            (per-iteration schedule). Leave None to step it per epoch yourself.

    Returns:
        Mean total loss over the epoch's batches (0.0 if the loader is empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(data_loader, desc=f"[Epoch {epoch}][Train]", leave=False)
    for images, targets in pbar:
        images = [img.to(device) for img in images]
        # Only the fields the detector consumes need to be on the device.
        for t in targets:
            t["boxes"]  = t["boxes"].to(device)
            t["labels"] = t["labels"].to(device)
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        batch_loss = losses.item()
        total_loss += batch_loss
        num_batches += 1
        if num_batches % print_freq == 0:
            pbar.set_postfix(loss=batch_loss)
    # Guard against an empty loader, matching validate_one_epoch.
    return total_loss / max(num_batches, 1)


@torch.no_grad()
def validate_one_epoch(model, data_loader, device, epoch):
    """
    Compute the mean detection loss on a validation set, without gradients.

    torchvision detection models only return a loss dict in training mode; in
    eval mode they return a list of detections. So the RPN/ROI heads are kept in
    training mode here, while the backbone is put in eval mode (disabling any
    dropout it has). The heads still sample proposals randomly in training mode,
    so the value can vary slightly between runs. A proper detection metric such
    as mAP would need separate code.

    Returns:
        Mean total loss over batches (0.0 if the loader is empty). The model is left
        in eval mode.
    """
    model.train()
    model.backbone.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        pbar = tqdm(data_loader, desc=f"[Epoch {epoch}][Val]", leave=False)
        for images, targets in pbar:
            images = [img.to(device) for img in images]
            for t in targets:
                t["boxes"]  = t["boxes"].to(device)
                t["labels"] = t["labels"].to(device)
            loss_dict = model(images, targets)
            total_loss += sum(loss_dict.values()).item()
            num_batches += 1
    finally:
        model.eval()
    return total_loss / max(num_batches, 1)


def save_checkpoint(state, filename):
    """Serialize `state` (e.g. a dict of state_dicts and loss histories) to `filename` via torch.save."""
    payload, destination = state, filename
    torch.save(payload, destination)


# ────────────────────────────────────────────────────────────────────────
# 6) MAIN TRAINING SCRIPT
# ────────────────────────────────────────────────────────────────────────

def _split_train_val(all_annotations: dict, train_fraction: float = 0.8, seed: int = 42):
    """Shuffle image IDs with a private seeded RNG and split them into (train, val) annotation dicts."""
    ids = sorted(all_annotations)
    # A dedicated RNG keeps the split fixed no matter what else has consumed `random`.
    random.Random(seed).shuffle(ids)
    split_idx = int(train_fraction * len(ids))
    train_anns = {img_id: all_annotations[img_id] for img_id in ids[:split_idx]}
    val_anns = {img_id: all_annotations[img_id] for img_id in ids[split_idx:]}
    return train_anns, val_anns


def _build_loaders(train_anns: dict, val_anns: dict):
    """Build the train/val DataLoaders that share one resize + normalize transform."""
    # Keep in sync with the preprocessing in run_inference_on_folder.
    # NOTE(review): these are CLIP normalization statistics; SigLIP image processors
    # commonly use mean = std = 0.5. Confirm against the checkpoint's preprocessor config.
    transform = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std =[0.26862954, 0.26130258, 0.27577711]
        ),
    ])
    train_dataset = VinDrCXRDetectionDataset(DICOM_TRAIN_DIR, train_anns, transform)
    val_dataset   = VinDrCXRDetectionDataset(DICOM_TRAIN_DIR, val_anns,   transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE_VAL,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        pin_memory=True
    )
    return train_loader, val_loader


def _build_optimizer(model):
    """AdamW with a small LR for the backbone and a larger LR for the RPN and ROI heads."""
    def trainable(module):
        return [p for p in module.parameters() if p.requires_grad]

    param_groups = [
        {"params": trainable(model.backbone),
         "lr": LEARNING_RATE_BACKBONE, "weight_decay": WEIGHT_DECAY},
        {"params": trainable(model.rpn),
         "lr": LEARNING_RATE_HEAD,     "weight_decay": WEIGHT_DECAY},
        {"params": trainable(model.roi_heads),
         "lr": LEARNING_RATE_HEAD,     "weight_decay": WEIGHT_DECAY},
    ]
    return torch.optim.AdamW(param_groups)


def _plot_loss_curves(train_losses, val_losses, out_path):
    """Save a per-epoch train/val loss line plot to `out_path`."""
    epochs = list(range(1, len(train_losses) + 1))
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(epochs, train_losses, label="Train Loss", marker="o")
    ax.plot(epochs, val_losses,   label="Val   Loss", marker="s")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Faster R-CNN Training & Validation Loss")
    ax.legend()
    ax.grid(True)
    fig.savefig(out_path)
    plt.close(fig)


def main():
    """
    Train the SigLIP-backbone Faster R-CNN on VinDr-CXR.

    Steps:
      1. Load annotations and split them 80/20 into train/val with a fixed seed.
      2. Train for NUM_EPOCHS with a per-epoch cosine LR schedule, computing the
         validation loss after each epoch.
      3. Write a checkpoint every epoch, plus "fasterrcnn_best.pt" whenever the
         validation loss improves.
      4. Save the loss curve to CHECKPOINT_DIR.

    Raises:
        RuntimeError: if no usable annotations are parsed from ANNOTATION_CSV.
    """
    # 6.1) Load VinDr annotations and split into train/val
    all_annotations = load_vindr_annotations(ANNOTATION_CSV)
    if not all_annotations:
        raise RuntimeError(
            f"No usable annotations were parsed from {ANNOTATION_CSV}; "
            "check that its labels map onto CLASS_NAMES."
        )
    train_anns, val_anns = _split_train_val(all_annotations)

    # 6.2) Build Dataset & DataLoader
    train_loader, val_loader = _build_loaders(train_anns, val_anns)
    print(f"  → Training on {len(train_loader.dataset)} images, validating on {len(val_loader.dataset)} images.")
    print(f"  → {len(train_loader)} train batches, {len(val_loader)} val batches.")

    # 6.3) Build model
    model = make_fasterrcnn_model(NUM_CLASSES)
    model.to(DEVICE)

    # 6.4) Optimizer & LR scheduler. The scheduler is stepped once per epoch
    # below, so its cosine period is NUM_EPOCHS.
    optimizer = _build_optimizer(model)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    best_val_loss = float("inf")
    train_losses, val_losses = [], []

    for epoch in range(1, NUM_EPOCHS + 1):
        start_time = time.time()
        train_loss = train_one_epoch(model, optimizer, train_loader, DEVICE, epoch, print_freq=50)
        val_loss   = validate_one_epoch(model, val_loader, DEVICE, epoch)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"[Epoch {epoch}/{NUM_EPOCHS}] "
              f"Train Loss={train_loss:.4f}  Val Loss={val_loss:.4f}  "
              f"Time={time.time() - start_time:.1f}s")

        scheduler.step()

        # Save checkpoint every epoch
        ckpt = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "train_losses": train_losses,
            "val_losses": val_losses,
        }
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"fasterrcnn_epoch{epoch}.pt")
        save_checkpoint(ckpt, ckpt_path)
        print(f"  → Saved checkpoint: {ckpt_path}")

        # Save “best” if val_loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_path = os.path.join(CHECKPOINT_DIR, "fasterrcnn_best.pt")
            save_checkpoint(ckpt, best_path)
            print(f"  → Saved new BEST checkpoint: {best_path}")

    # 6.5) Plot Training & Validation Loss Curves
    _plot_loss_curves(train_losses, val_losses, os.path.join(CHECKPOINT_DIR, "loss_curve.png"))
    print("Training complete. Loss curve saved.")


# ────────────────────────────────────────────────────────────────────────
# 7) INFERENCE ON NEW IMAGES (after training)
# ────────────────────────────────────────────────────────────────────────

def run_inference_on_folder(model_ckpt_path: str, dicom_folder: str, output_folder: str, score_thresh: float = 0.5):
    """
    Run a trained checkpoint over a folder of DICOMs and save annotated images.

    Loads the Faster R-CNN weights from `model_ckpt_path` and runs every `*.dcm` file
    in `dicom_folder` (in sorted order). For each file it saves an RGB PNG with
    predicted boxes drawn at original resolution to
    `output_folder/<image_id>_pred.png`. Only boxes with score ≥ score_thresh are drawn.

    NOTE: torch.load unpickles the checkpoint; only load trusted files.
    """
    os.makedirs(output_folder, exist_ok=True)

    # 7.1) Rebuild backbone + heads, then load the trained weights.
    model = make_fasterrcnn_model(NUM_CLASSES)
    state = torch.load(model_ckpt_path, map_location="cpu")
    model.load_state_dict(state["model_state"])
    model.to(DEVICE).eval()

    # Must match the preprocessing used for training in main(). Built once, not per image.
    preprocess = T.Compose([
        T.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std =[0.26862954, 0.26130258, 0.27577711]
        )
    ])

    # 7.2) Loop over each DICOM in the folder.
    for fname in sorted(os.listdir(dicom_folder)):
        if not fname.lower().endswith(".dcm"):
            continue
        image_id = os.path.splitext(fname)[0]
        pil = dicom_to_pil(os.path.join(dicom_folder, fname))
        orig_w, orig_h = pil.size

        img_tensor = preprocess(pil).to(DEVICE)  # [3, IMAGE_SIZE, IMAGE_SIZE]
        with torch.no_grad():
            # torchvision detectors take a list of CHW tensors and return one
            # dict ("boxes" [M×4], "labels" [M], "scores" [M]) per image.
            outputs = model([img_tensor])[0]

        keep = outputs["scores"] >= score_thresh
        boxes  = outputs["boxes"][keep].cpu().numpy()
        labels = outputs["labels"][keep].cpu().numpy()
        scores = outputs["scores"][keep].cpu().numpy()

        # Boxes are in IMAGE_SIZE×IMAGE_SIZE space; scale back to the original DICOM.
        sx = orig_w / IMAGE_SIZE
        sy = orig_h / IMAGE_SIZE
        draw_img = pil.copy()
        draw = ImageDraw.Draw(draw_img)
        for (x1, y1, x2, y2), lbl, scr in zip(boxes, labels, scores):
            cls_name = CLASS_NAMES[int(lbl) - 1]  # model labels are 1..NUM_CLASSES (0 = background)
            rx1, ry1, rx2, ry2 = x1 * sx, y1 * sy, x2 * sx, y2 * sy
            draw.rectangle([rx1, ry1, rx2, ry2], outline="red", width=4)
            # Keep the caption inside the image for boxes touching the top edge.
            draw.text((rx1, max(ry1 - 10, 0)), f"{cls_name}:{scr:.2f}", fill="white")

        out_path = os.path.join(output_folder, f"{image_id}_pred.png")
        draw_img.save(out_path)
        print(f"  → Saved: {out_path}")

    print("Inference complete.")


# ────────────────────────────────────────────────────────────────────────
# 8) ENTRY POINT
# ────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Train, validate, checkpoint, and plot the loss curve.
    main()

    # Flip to True to draw the best checkpoint's predictions on a folder of DICOMs.
    RUN_INFERENCE_AFTER_TRAINING = False
    if RUN_INFERENCE_AFTER_TRAINING:
        run_inference_on_folder(
            model_ckpt_path=os.path.join(CHECKPOINT_DIR, "fasterrcnn_best.pt"),
            dicom_folder="/path/to/your/test_dicom_folder",
            output_folder="./inference_results",
            score_thresh=0.5,
        )
