## View 5: UNet

Also bad for now. Can copy the code here to improve

https://www.kaggle.com/code/balraj98/unet-resnet50-frontend-road-segmentation-pytorch

In [None]:
import os, numpy as np, torch, rasterio
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import albumentations as A
import segmentation_models_pytorch as smp
from scipy.ndimage import distance_transform_edt

# ============================================================
# Dataset: TIFF with 5 bands (RGB + DEM + label)
# ============================================================

class RoadTiffDataset(Dataset):
    def __init__(self, paths, augment=None, preprocess=None):
        self.paths = paths
        self.augment = augment
        self.preprocess = preprocess

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        with rasterio.open(p) as src:
            data = src.read().astype(np.float32)

        rgb = np.moveaxis(data[:3], 0, -1) / 255.0      # (H,W,3)
        dem = data[3][..., None] / 50.0                 # (H,W,1)
        x = np.concatenate([rgb, dem], axis=-1)         # (H,W,4)
        y = data[4][..., None]                          # (H,W,1), 0/1

        if self.augment:
            out = self.augment(image=x, mask=y)
            x, y = out["image"], out["mask"]

        if self.preprocess:
            out = self.preprocess(image=x, mask=y)
            x, y = out["image"], out["mask"]

        return x, y


# ============================================================
# Load TIFFs + split
# ============================================================

TIFF_DIR = Path("data/tiffs")
tifs = sorted(TIFF_DIR.glob("*.tif")) + sorted(TIFF_DIR.glob("*.tiff"))

N = len(tifs)
train_paths = tifs[: int(0.8*N)]
val_paths   = tifs[int(0.8*N): int(0.9*N)]
test_paths  = tifs[int(0.9*N):]

# ============================================================
# Augmentation + Preprocessing
# ============================================================

augment_train = A.Compose([
    A.RandomCrop(512,512),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
])

def to_tensor(x, **kw): 
    return x.transpose(2,0,1).astype("float32")

preprocess = A.Compose([A.Lambda(image=to_tensor, mask=to_tensor)])

train_ds = RoadTiffDataset(train_paths, augment=augment_train, preprocess=preprocess)
val_ds   = RoadTiffDataset(val_paths, preprocess=preprocess)
test_ds  = RoadTiffDataset(test_paths, preprocess=preprocess)

train_ld = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_ld   = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2)
test_ld  = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=2)

# ============================================================
# Model
# ============================================================

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=4,
    classes=1,
    activation="sigmoid",   # output 0–1 mask
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# ============================================================
# Training
# ============================================================

def train_one_epoch():
    model.train()
    losses=[]
    for x,y in train_ld:
        x,y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        pred = model(x)
        loss = loss_fn(pred,y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return np.mean(losses)

@torch.no_grad()
def val_one_epoch():
    model.eval()
    losses=[]
    for x,y in val_ld:
        x,y = x.to(DEVICE), y.to(DEVICE)
        pred = model(x)
        loss = loss_fn(pred,y)
        losses.append(loss.item())
    return np.mean(losses)

best_val = 1e9
for epoch in range(5):
    tr = train_one_epoch()
    va = val_one_epoch()
    print(f"Epoch {epoch}: train={tr:.4f}  val={va:.4f}")
    if va < best_val:
        best_val = va
        torch.save(model.state_dict(), "best_unet.pth")
        print("Model saved.")

# ============================================================
# Metrics: IoU + Modified Hausdorff
# ============================================================

def compute_iou(pred, true):
    pb = pred > 0.5
    tb = true > 0.5
    inter = (pb & tb).sum()
    union = (pb | tb).sum()
    return float(inter) / float(union + 1e-9)

def modified_hausdorff(pred, true):
    p = pred.squeeze() > 0.5
    t = true.squeeze() > 0.5
    if not p.any() and not t.any(): return 0.0
    if not p.any() or not t.any(): return 9999.0
    dt_t = distance_transform_edt(~t)
    dt_p = distance_transform_edt(~p)
    d1 = dt_t[p].mean()
    d2 = dt_p[t].mean()
    return max(d1, d2)


KeyboardInterrupt: 

In [4]:

# ============================================================
# Test Evaluation
# ============================================================

model.load_state_dict(torch.load("best_unet.pth"))
model.eval()

ious=[]
mhds=[]

for x,y in tqdm(test_ld, desc="Testing"):
    x = x.to(DEVICE)
    pred = model(x).detach().cpu().numpy()
    y = y.numpy()
    ious.append(compute_iou(pred,y))
    mhds.append(modified_hausdorff(pred,y))

print("\n=== TEST RESULTS ===")
print("Mean IoU:",    np.mean(ious))
print("Median IoU:",  np.median(ious))
print("Mean MHD:",    np.mean(mhds))
print("Median MHD:",  np.median(mhds))

Testing:   0%|          | 0/118 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 138.00 MiB. GPU 0 has a total capacity of 5.61 GiB of which 149.62 MiB is free. Including non-PyTorch memory, this process has 4.53 GiB memory in use. Of the allocated memory 4.28 GiB is allocated by PyTorch, and 166.06 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [5]:
# ============================================================
# Sliding window inference (fits 6GB GPU)
# ============================================================

WINDOW = 512
STRIDE = 384    # overlap so seams don't appear

def sliding_window_predict(model, x):
    """
    x: numpy array (C,H,W)
    returns: prediction mask (H,W)
    """
    c, h, w = x.shape
    out = np.zeros((h, w), dtype=np.float32)
    counts = np.zeros((h, w), dtype=np.float32)

    model.eval()
    with torch.no_grad():
        for i in range(0, h - WINDOW + 1, STRIDE):
            for j in range(0, w - WINDOW + 1, STRIDE):
                patch = x[:, i:i+WINDOW, j:j+WINDOW]
                patch = torch.tensor(patch).unsqueeze(0).to(DEVICE)

                pred = model(patch)[0, 0].detach().cpu().numpy()
                out[i:i+WINDOW, j:j+WINDOW] += pred
                counts[i:i+WINDOW, j:j+WINDOW] += 1

    # avoid divide-by-zero
    counts[counts == 0] = 1

    return out / counts


# ============================================================
# Test Evaluation
# ============================================================

model.load_state_dict(torch.load("best_unet.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()
torch.cuda.empty_cache()

ious = []
mhds = []

for x, y in tqdm(test_ld, desc="Testing"):
    # x: (1,C,H,W)
    x_np = x[0].numpy()              # (C,H,W)
    pred = sliding_window_predict(model, x_np)  # (H,W)

    # Convert GT mask
    y_np = y[0].numpy()              # (1,H,W) but already 0/1
    y_np = y_np.squeeze()

    # Threshold pred
    pred_bin = (pred > 0.5).astype(np.uint8)

    # Metrics
    ious.append(compute_iou(pred_bin, y_np))
    mhds.append(modified_hausdorff(pred_bin, y_np))


print("\n=== TEST RESULTS ===")
print("Mean IoU:",    np.mean(ious))
print("Median IoU:",  np.median(ious))
print("Mean MHD:",    np.mean(mhds))
print("Median MHD:",  np.median(mhds))

Testing: 100%|██████████| 118/118 [00:36<00:00,  3.27it/s]


=== TEST RESULTS ===
Mean IoU: 0.21518108412626188
Median IoU: 0.22451167941882685
Mean MHD: 301.09696901055776
Median MHD: 70.4869050196447



