In [None]:
#0) Install & GPU check

In [None]:
!pip -q install albumentations==1.4.7 pycocotools opencv-python tqdm

import torch
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
!nvidia-smi -L || echo "No GPU detected (OK, will use CPU—lower IMG_SIZE/BATCH)."


In [None]:
#1) Config (your paths + knobs)

In [None]:
# Paths
# Dataset root: later cells read the "train", "valid" and "test" sub-folders below it.
DATA_ROOT = "/mmfs1/home/jacks.local/kkumari/FHB_Project/Wheatproject-v1-coco"
# Output folder for the best checkpoint, the validation previews and the inference overlay.
OUT_DIR   = "/mmfs1/home/jacks.local/kkumari/FHB_Project/Outcome/runs_attention_unet"

# Training knobs (lower if on CPU or small GPU)
# IMG_SIZE : side length every image/mask is resized (longest side) and padded to.
# BATCH / VAL_BATCH : batch sizes of the training and validation loaders.
# EPOCHS : number of passes over the training loader.
# LR : learning rate handed to the AdamW optimizer.
# BASE_CH : channel width of the first encoder stage of the network (`base`).
IMG_SIZE  = 768
BATCH     = 2
VAL_BATCH = 2
EPOCHS    = 200
LR        = 3e-4
BASE_CH   = 64  # 64 = stronger but heavier

import os
# Fail early with a readable message if the dataset root is wrong, then make sure the
# output folder exists before any later cell writes into it.
assert os.path.exists(DATA_ROOT), f"DATA_ROOT not found: {DATA_ROOT}"
os.makedirs(OUT_DIR, exist_ok=True)
print("DATA_ROOT OK →", DATA_ROOT)

In [None]:
#2) Fix COCO JSON (annotations→images, drop missing)

In [None]:
import json, os, shutil
from pathlib import Path

def _locate_image(split_dir, img_dir, fname):
    """Return the split-relative POSIX path of the image named by `fname`, or None.

    Only regular files count as a match (a directory never does), and a match that
    lies outside `split_dir` is ignored because it cannot be expressed relative to it.
    """
    base = Path(fname).name
    if not base:
        # Empty/missing file_name: nothing to look for.
        return None
    candidates = (
        split_dir / fname,                                      # path exactly as recorded
        img_dir / base,                                         # same name under images/
        split_dir / fname.replace("annotations/", "images/"),   # annotations/ -> images/
    )
    for cand in candidates:
        if not cand.is_file():
            continue
        try:
            return cand.relative_to(split_dir).as_posix()
        except ValueError:
            # e.g. an absolute file_name pointing outside the split folder
            continue
    return None


def fix_split(split_dir):
    """Write a repaired copy of a split's COCO JSON whose image paths all exist.

    Reads ``<split_dir>/annotations/_annotations.coco.json``, rewrites every image's
    ``file_name`` to the split-relative path of the file actually found on disk, drops
    images that cannot be found together with their annotations, and writes the result
    to ``_annotations.coco.fixed.json`` next to the source. The source file itself is
    never modified; a one-time copy is also kept as ``_annotations.coco.orig.json``.

    Args:
        split_dir: path (str or Path) of one dataset split, e.g. ``.../train``.

    Returns:
        None. Prints a short summary; prints a "Skip" line and returns early when the
        split has no source JSON.
    """
    split_dir = Path(split_dir)
    ann_dir   = split_dir / "annotations"
    img_dir   = split_dir / "images"
    src_json  = ann_dir / "_annotations.coco.json"
    if not src_json.exists():
        print(f"Skip {split_dir.name}: no _annotations.coco.json")
        return
    dst_json  = ann_dir / "_annotations.coco.fixed.json"
    # Context managers guarantee the handles are closed (and the output flushed).
    with open(src_json, "r") as fh:
        coco = json.load(fh)

    kept_images, kept_ids = [], set()
    fixed, dropped = 0, 0

    for im in coco.get("images", []):
        fname = im.get("file_name") or ""
        rel = _locate_image(split_dir, img_dir, fname)
        if rel is None:
            dropped += 1
            continue
        if rel != fname:
            fixed += 1
        im["file_name"] = rel
        kept_images.append(im)
        kept_ids.add(im["id"])

    coco["images"] = kept_images
    # Annotations of dropped images would dangle, so remove them as well.
    coco["annotations"] = [a for a in coco.get("annotations", []) if a["image_id"] in kept_ids]

    bak = ann_dir / "_annotations.coco.orig.json"
    if not bak.exists():
        shutil.copy(src_json, bak)
    with open(dst_json, "w") as fh:
        json.dump(coco, fh, indent=2)
    print(f"[{split_dir.name}] fixed:{fixed}  dropped:{dropped}  kept:{len(kept_images)}  anns:{len(coco['annotations'])}")
    print(f"→ wrote {dst_json}")

# Repair each split that actually ships a COCO annotation file; other splits are ignored.
for sub in ("train", "valid", "test"):
    if not os.path.exists(f"{DATA_ROOT}/{sub}/annotations/_annotations.coco.json"):
        continue
    fix_split(f"{DATA_ROOT}/{sub}")


In [None]:
#3) Dataset (COCO → semantic masks)

In [None]:
import json, cv2, numpy as np, random, torch
from pathlib import Path
from albumentations import (Compose, HorizontalFlip, VerticalFlip, RandomBrightnessContrast,
                            CLAHE, Normalize, PadIfNeeded, LongestMaxSize)
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO

# Label map (robust to old "Diseased")
# Maps a COCO category name to its semantic class index. Category names that are not
# listed here resolve to None in CocoSegDataset, and their annotations are left out of
# the mask (those pixels stay background).
NAME2CLS = {"healthy":1,"Healthy":1,"unhealthy":2,"Unhealthy":2,"diseased":2,"Diseased":2}
# Number of output classes of the segmentation model.
N_CLASSES = 3  # 0 bg, 1 healthy, 2 unhealthy

def colorize_mask(mask):
    """Turn an integer class mask into an RGB image via a fixed 3-colour palette.

    Class 0 -> black, class 1 -> green, class 2 -> red. The result has the shape of
    `mask` with a trailing axis of length 3 and dtype uint8.
    """
    palette = np.zeros((3, 3), dtype=np.uint8)
    palette[1, 1] = 255  # class 1: green channel
    palette[2, 0] = 255  # class 2: red channel
    return palette[mask]

class CocoSegDataset(Dataset):
    """COCO instance annotations served as (image, semantic mask, path) samples.

    The dataset prefers ``annotations/_annotations.coco.fixed.json`` inside the split
    folder and falls back to ``_annotations.coco.json``. Every annotation whose category
    name maps to a class in ``NAME2CLS`` is rasterised into a single-channel mask
    (0 = background); annotations of unmapped categories are ignored.

    Args:
        split_dir: folder of one split; image ``file_name`` entries are resolved
            relative to it.
        img_size: side length images/masks are resized (longest side) and padded to.
        augment: when True, random flips, brightness/contrast and CLAHE are added to
            the resize/pad/normalise pipeline.
    """

    def __init__(self, split_dir, img_size=768, augment=True):
        self.split_dir = Path(split_dir)
        fixed = self.split_dir/"annotations/_annotations.coco.fixed.json"
        orig  = self.split_dir/"annotations/_annotations.coco.json"
        ann_path = fixed if fixed.exists() else orig
        assert ann_path.exists(), f"Missing COCO json for {split_dir}"
        self.coco = COCO(str(ann_path))
        self.img_ids = list(self.coco.imgs.keys())
        self.img_size = img_size

        # COCO category id -> semantic class index (None for categories we do not train on).
        self.cat_to_cls = {
            cat["id"]: self._class_for_name(cat["name"])
            for cat in self.coco.loadCats(self.coco.getCatIds())
        }
        self.tf = self._build_transform(img_size, augment)

    @staticmethod
    def _class_for_name(name):
        """Look a category name up in NAME2CLS, retrying case/whitespace-insensitively."""
        cls = NAME2CLS.get(name)
        if cls is None and isinstance(name, str):
            cls = NAME2CLS.get(name.strip().lower())
        return cls

    @staticmethod
    def _build_transform(img_size, augment):
        """Compose resize -> pad -> (optional augmentations) -> normalise -> tensor."""
        steps = [
            LongestMaxSize(max_size=img_size),
            PadIfNeeded(img_size, img_size, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
        ]
        if augment:
            steps += [
                HorizontalFlip(p=0.5), VerticalFlip(p=0.2),
                RandomBrightnessContrast(p=0.4), CLAHE(p=0.2),
            ]
        steps += [Normalize(), ToTensorV2()]
        return Compose(steps)

    def __len__(self): return len(self.img_ids)

    def _build_mask(self, img_info, anns):
        """Rasterise `anns` into an (height, width) uint8 class mask for one image."""
        h, w = img_info["height"], img_info["width"]
        mask = np.zeros((h,w), np.uint8)
        # healthy first, then unhealthy overwrites on overlaps
        for priority in [1,2]:
            for ann in anns:
                cls = self.cat_to_cls.get(ann["category_id"], None)
                if cls is None or cls != priority: continue
                m = self.coco.annToMask(ann).astype(np.uint8)
                mask[m==1] = cls
        return mask

    def __getitem__(self, idx):
        """Return ``(image, mask, path)`` for sample `idx`.

        `image` and `mask` are the outputs of the transform pipeline (mask cast to
        long); `path` is the image path as a string.

        Raises:
            OSError: if the image file is missing or cannot be decoded.
        """
        img_id = self.img_ids[idx]
        info   = self.coco.loadImgs([img_id])[0]
        path   = str(self.split_dir / info["file_name"])

        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising; surface
            # the offending path rather than crashing later inside cvtColor.
            raise OSError(f"Could not read image (missing or unreadable): {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        anns    = self.coco.loadAnns(ann_ids)
        mask    = self._build_mask(info, anns)

        out = self.tf(image=img, mask=mask)
        return out["image"], out["mask"].long(), path

In [None]:
pip install --upgrade albumentations

In [None]:
#4) Attention U-Net Architecture (no in-place ops)

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> batch-norm -> ReLU stages.

    The first stage maps `in_ch` to `out_ch` channels, the second keeps `out_ch`.
    Padding of 1 keeps the spatial size; the convolutions carry no bias and the ReLUs
    are not in-place.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=False),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class AttentionBlock(nn.Module):
    """Additive attention gate that re-weights skip features `x` using gating signal `g`.

    Both inputs are projected to `inter_ch` channels with 1x1 conv + batch-norm, summed
    and passed through ReLU; a 1x1 conv + batch-norm + sigmoid then yields a
    single-channel map that multiplies `x`.
    """

    def __init__(self, g_ch, x_ch, inter_ch):
        super().__init__()

        def project(src_ch):
            # 1x1 projection into the shared intermediate channel space
            return nn.Sequential(
                nn.Conv2d(src_ch, inter_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(inter_ch),
            )

        self.Wg = project(g_ch)
        self.Wx = project(x_ch)
        self.psi = nn.Sequential(
            nn.Conv2d(inter_ch, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, g, x):
        # Bring the gating signal to the spatial size of the skip features first.
        gate = F.interpolate(g, size=x.shape[2:], mode='bilinear', align_corners=False)
        joint = self.relu(self.Wg(gate) + self.Wx(x))
        weights = self.psi(joint)
        return x * weights

class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, attention-gated skip, conv block.

    Args:
        in_ch: channels of the incoming (coarser) feature map; halved by the upsampling.
        skip_ch: channels of the encoder skip connection.
        out_ch: channels produced by the final ConvBlock.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up  = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
        self.att = AttentionBlock(in_ch//2, skip_ch, skip_ch//2)
        self.conv = ConvBlock(in_ch//2 + skip_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # Max-pooling floors odd sizes, so doubling may not reproduce the skip's size.
        # Align before gating/concatenating; this is a no-op when sizes already match.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        skip = self.att(x, skip)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class AttentionUNet(nn.Module):
    """U-Net with attention-gated skip connections.

    Encoder: four ConvBlock stages (base, 2x, 4x, 8x channels), each followed by 2x
    max-pooling, then a 16x-channel bottleneck. Decoder: four UpBlock stages that
    consume the encoder outputs in reverse order. A 1x1 convolution maps the last
    `base` channels to `n_classes` logits.
    """

    def __init__(self, in_ch=3, n_classes=3, base=32):
        super().__init__()
        # Encoder stages with their pooling layers
        self.c1 = ConvBlock(in_ch, base)
        self.p1 = nn.MaxPool2d(2)
        self.c2 = ConvBlock(base, base * 2)
        self.p2 = nn.MaxPool2d(2)
        self.c3 = ConvBlock(base * 2, base * 4)
        self.p3 = nn.MaxPool2d(2)
        self.c4 = ConvBlock(base * 4, base * 8)
        self.p4 = nn.MaxPool2d(2)
        # Bottleneck
        self.c5 = ConvBlock(base * 8, base * 16)

        # Decoder stages: (incoming channels, skip channels, output channels)
        self.u6 = UpBlock(base * 16, base * 8, base * 8)
        self.u7 = UpBlock(base * 8, base * 4, base * 4)
        self.u8 = UpBlock(base * 4, base * 2, base * 2)
        self.u9 = UpBlock(base * 2, base, base)
        self.outc = nn.Conv2d(base, n_classes, 1)

    def forward(self, x):
        # Contracting path: keep each stage's output for its skip connection.
        enc1 = self.c1(x)
        enc2 = self.c2(self.p1(enc1))
        enc3 = self.c3(self.p2(enc2))
        enc4 = self.c4(self.p3(enc3))
        bottleneck = self.c5(self.p4(enc4))

        # Expanding path: deepest skip first.
        dec = self.u6(bottleneck, enc4)
        for stage, skip in ((self.u7, enc3), (self.u8, enc2), (self.u9, enc1)):
            dec = stage(dec, skip)
        return self.outc(dec)

In [None]:
#5) Losses, metrics, helpers

In [None]:
import torch, torch.nn as nn, numpy as np

class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the non-background classes.

    Args:
        smooth: constant added to numerator and denominator of each class's Dice score.
        ignore_index: target value whose pixels are excluded from all sums
            (None disables the exclusion).

    forward(logits, targets) takes (N, C, H, W) logits and (N, H, W) integer targets
    and returns ``1 - mean(Dice of classes 1..C-1)``.
    """

    def __init__(self, smooth=1.0, ignore_index=0):
        super().__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        n_cls = logits.shape[1]
        pred = torch.softmax(logits, dim=1)  # out-of-place ops only from here on
        truth = torch.nn.functional.one_hot(targets, n_cls).permute(0, 3, 1, 2).float()

        if self.ignore_index is not None:
            # Zero out ignored pixels in both prediction and ground truth.
            valid = (targets != self.ignore_index).unsqueeze(1).float()
            pred, truth = pred * valid, truth * valid

        reduce_over = (0, 2, 3)  # batch and spatial axes -> one value per class
        overlap = (pred * truth).sum(reduce_over)
        total = pred.sum(reduce_over) + truth.sum(reduce_over)
        score = (2 * overlap + self.smooth) / (total + self.smooth)

        foreground = score[1:]  # class 0 (background) does not contribute
        return 1 - foreground.mean()

def per_class_iou(pred, target, num_classes=3):
    """Intersection-over-union of `pred` vs `target` for each class index.

    Returns a list of `num_classes` floats; a class that appears in neither `pred`
    nor `target` yields NaN.
    """
    scores = []
    for cls in range(num_classes):
        in_pred, in_true = (pred == cls), (target == cls)
        inter = (in_pred & in_true).sum().item()
        union = (in_pred | in_true).sum().item()
        scores.append(inter / union if union != 0 else float('nan'))
    return scores

def overlay_rgb(img_t, mask_pred, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Blend a predicted class mask over the image it was predicted from.

    The image tensors in this notebook have passed through ``Normalize()``, so they
    must be de-normalised before being scaled to 0-255; scaling the normalised values
    directly produces a distorted picture.

    Args:
        img_t: (3, H, W) image tensor as fed to the model.
        mask_pred: (H, W) integer class mask tensor.
        mean, std: per-channel statistics used by the normalisation step (defaults
            are the albumentations ``Normalize`` defaults). Pass ``None`` for either
            one if `img_t` is already in the [0, 1] range.

    Returns:
        (H, W, 3) uint8 RGB array: 60% image + 40% colourised mask.
    """
    img = img_t.detach().permute(1,2,0).cpu().numpy().astype(np.float32)
    if mean is not None and std is not None:
        # Undo (x - mean) / std channel-wise; broadcasting runs over the last axis.
        img = img * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
    # Round before the cast so float error (e.g. 254.9999) does not truncate a level.
    img = np.rint(img * 255.0).clip(0, 255).astype(np.uint8)
    cm  = colorize_mask(mask_pred.cpu().numpy())
    return (0.6*img + 0.4*cm).astype(np.uint8)


In [None]:
#6) Dataloaders

In [None]:
from torch.utils.data import DataLoader
import torch

# Split folders under DATA_ROOT that hold the training and validation data.
train_dir = f"{DATA_ROOT}/train"
val_dir   = f"{DATA_ROOT}/valid"

# Only the training set is built with augment=True; validation uses the plain
# resize/pad/normalise pipeline.
train_ds = CocoSegDataset(train_dir, img_size=IMG_SIZE, augment=True)
val_ds   = CocoSegDataset(val_dir,   img_size=IMG_SIZE, augment=False)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start with single-process loading to avoid worker/pickle surprises.
NUM_WORKERS = 0
# Pinned host memory is only requested when CUDA is available.
PIN_MEMORY  = bool(torch.cuda.is_available())

def safe_collate(batch):
    """
    Collate a batch with PyTorch's default collate after discarding None samples.

    Returns None when no sample is left (including an empty input batch).
    """
    usable = list(filter(lambda sample: sample is not None, batch))
    if len(usable) == 0:
        return None
    return torch.utils.data.dataloader.default_collate(usable)


# Options shared by the training and validation loaders.
_loader_opts = dict(num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                    collate_fn=safe_collate, drop_last=False)

# Training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, **_loader_opts)
val_loader   = DataLoader(val_ds, batch_size=VAL_BATCH, shuffle=False, **_loader_opts)

# Pull a single training batch up front so data problems surface here, not mid-training.
sample = next(iter(train_loader))
if sample is None:
    raise RuntimeError("All samples in first batch were invalid. Check dataset.")
x, y, p = sample
print("train batch:", x.shape, y.shape, "device:", device)

In [None]:
#7) Train (corrected)

In [None]:
from tqdm import tqdm
import numpy as np
from pathlib import Path
import cv2, os

# Autograd anomaly detection is explicitly switched off here (it slows training down);
# flip this to True when hunting a NaN/in-place error in the backward pass.
import torch as _torch
_torch.autograd.set_detect_anomaly(False)

model  = AttentionUNet(in_ch=3, n_classes=N_CLASSES, base=BASE_CH).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# Training objective = cross-entropy over all classes + Dice over the foreground classes.
ce_loss   = torch.nn.CrossEntropyLoss()
dice      = DiceLoss(ignore_index=0)

# Best mean validation Dice seen so far; the checkpoint is rewritten whenever it improves.
best = 0.0
save_path = Path(OUT_DIR) / "best_attention_unet.pth"

for epoch in range(1, EPOCHS+1):
    model.train()
    tl=[]
    for imgs, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", ncols=80):
        imgs, masks = imgs.to(device), masks.to(device)
        logits = model(imgs)
        loss = ce_loss(logits, masks) + dice(logits, masks)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        tl.append(loss.item())

    # validate
    model.eval()
    vloss, vdice, viou = [], [], []
    with torch.no_grad():
        for imgs, masks, _ in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            # Evaluate the Dice term once and reuse it for both the loss and the score.
            dice_term = dice(logits, masks)
            vloss.append( (ce_loss(logits, masks) + dice_term).item() )
            vdice.append( 1 - dice_term.item() )
            viou.append( per_class_iou(logits.argmax(1).cpu(), masks.cpu(), N_CLASSES) )

    mean_vloss = float(np.mean(vloss))
    mean_vdice = float(np.mean(vdice))
    # Per-class IoU averaged over validation batches; NaN entries (class absent from a
    # batch) are skipped by nanmean.
    mean_iou   = np.nanmean(np.array(viou), axis=0)

    print(f"Epoch {epoch:03d} | train {np.mean(tl):.4f} | val {mean_vloss:.4f} | "
          f"Dice {mean_vdice:.4f} | IoU(bg,healthy,unhealthy) {np.round(mean_iou,4)}")

    if mean_vdice > best:
        best = mean_vdice
        # Store the weights together with the settings needed to rebuild the model.
        torch.save({"model": model.state_dict(), "epoch": epoch,
                    "img_size": IMG_SIZE, "base": BASE_CH}, save_path)
        print("  ✅ Saved best:", save_path)


In [None]:
#8) Validation previews

In [None]:
from pathlib import Path
import cv2
# Folder that receives the overlay PNGs.
preview_dir = Path(OUT_DIR) / "previews"
preview_dir.mkdir(parents=True, exist_ok=True)

# Restore the best checkpoint written by the training cell and switch to inference mode.
state = torch.load(save_path, map_location=device)
model.load_state_dict(state["model"])
model.eval()

with torch.no_grad():
    for b, (imgs, masks, paths) in enumerate(val_loader):
        pred = model(imgs.to(device)).softmax(1).argmax(1).cpu()
        for i, (img, cls_map) in enumerate(zip(imgs, pred)):
            ov = overlay_rgb(img.cpu(), cls_map)
            outp = preview_dir / f"val_{b:03d}_{i:02d}_{Path(paths[i]).stem}.png"
            # overlay_rgb returns RGB; OpenCV writes BGR.
            cv2.imwrite(str(outp), cv2.cvtColor(ov, cv2.COLOR_RGB2BGR))
        # Stop once batches 0, 1 and 2 have been written.
        if b >= 2:
            break

print("Previews saved to:", preview_dir)


In [None]:
#9) Single-image inference

In [None]:
import os, cv2, torch
from albumentations import Compose, LongestMaxSize, PadIfNeeded, Normalize
from albumentations.pytorch import ToTensorV2

# Pick the first image (by sorted name) from the validation folder. os.listdir returns
# entries in arbitrary order and may include non-image files, so sort and filter first.
_val_img_dir = f"{DATA_ROOT}/valid/images"
_IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
_candidates = sorted(n for n in os.listdir(_val_img_dir) if n.lower().endswith(_IMG_EXTS))
if not _candidates:
    raise FileNotFoundError(f"No image files found in {_val_img_dir}")
RAW_IMG = f"{_val_img_dir}/{_candidates[0]}"

# Same deterministic preprocessing as the validation dataset (no augmentation).
tf = Compose([
    LongestMaxSize(IMG_SIZE),
    PadIfNeeded(IMG_SIZE, IMG_SIZE, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
    Normalize(), ToTensorV2()
])

bgr = cv2.imread(RAW_IMG)
if bgr is None:
    # cv2.imread returns None instead of raising when the file cannot be decoded.
    raise OSError(f"Could not read image: {RAW_IMG}")
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
x = tf(image=rgb)["image"].unsqueeze(0).to(device)

# Restore the best checkpoint and predict a class index per pixel.
state = torch.load(save_path, map_location=device)
model.load_state_dict(state["model"]); model.eval()
with torch.no_grad():
    pred = model(x).softmax(1).argmax(1)[0].cpu()

ov = overlay_rgb(x[0].cpu(), pred)
out_path = f"{OUT_DIR}/inference_overlay.png"
# overlay_rgb returns RGB; OpenCV writes BGR.
cv2.imwrite(out_path, cv2.cvtColor(ov, cv2.COLOR_RGB2BGR))
print("Saved:", out_path)
