# CloudSeg: Training & Visualization

This notebook will train your `SimpleSegModel` on the SWINySEG data and visualize:

- A sample image & mask
- Training vs. validation loss curves
- IoU metric per epoch
- A metrics table and its correlation matrix
- Side-by-side sample predictions

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from src.data_loader import get_simple_loaders
from src.model import SimpleSegModel

import matplotlib.pyplot as plt
import pandas as pd


In [None]:
# Hyperparameters (tunable knobs for this notebook)
BATCH_SIZE = 4
IMG_SIZE   = (256, 256)
LR         = 1e-3
# NOTE(review): the training cell assigns MAX_EPOCHS again before its loop,
# so the value set here may not be the one that is actually used.
MAX_EPOCHS = 20

# Settings shared by the train and validation loaders.
# NOTE(review): both loaders point at the same image/mask directories; unless
# get_simple_loaders splits the data internally, validation metrics are
# computed on the training samples -- confirm against src/data_loader.
_common_loader_args = dict(
    images_dir="data/raw/images",
    masks_dir="data/raw/masks",
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    normalize=True,
)

# Only the shuffle flag differs: shuffled for training, fixed order for validation.
train_loader = get_simple_loaders(**_common_loader_args, shuffle=True)
val_loader = get_simple_loaders(**_common_loader_args, shuffle=False)


In [None]:
# Pull one batch from the training loader and preview its first sample.
imgs, masks = next(iter(train_loader))

# permute(1, 2, 0) moves the image's first axis last, which is the layout
# imshow expects (presumably channels-first -> channels-last; confirm).
# NOTE(review): the loader is built with normalize=True, so pixel values may
# fall outside the range imshow displays without clipping -- verify.
sample_image = imgs[0].permute(1, 2, 0)
sample_mask = masks[0, 0]

fig, (ax_image, ax_mask) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

ax_image.imshow(sample_image)
ax_image.set_title("Input Image")

ax_mask.imshow(sample_mask, cmap="gray")
ax_mask.set_title("Mask")

# Hide ticks and frames on both panels.
for panel in (ax_image, ax_mask):
    panel.axis("off")

plt.show()


In [None]:
# Train SimpleSegModel with BCE-with-logits loss and Adam, validating after
# every epoch. Training stops early once the validation loss has failed to
# improve for PATIENCE consecutive epochs; the best weights are checkpointed
# to BEST_MODEL_PATH and restored into `model` when the loop ends.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Running on:", device)

model     = SimpleSegModel().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# Early-stop settings
# NOTE: this assignment overrides any MAX_EPOCHS value set in an earlier cell.
MAX_EPOCHS        = 50
PATIENCE          = 5
BEST_MODEL_PATH   = "best_model.pth"
best_val_loss     = float("inf")
epochs_no_improve = 0

# Per-epoch history; one entry is appended for every epoch that runs,
# so these lists may be shorter than MAX_EPOCHS when early stopping fires.
train_losses = []
val_losses   = []
val_ious     = []

for epoch in range(1, MAX_EPOCHS+1):
    # ── Training ──
    model.train()
    running_train_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss    = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()
    # Mean of the per-batch losses
    train_loss = running_train_loss / len(train_loader)

    # ── Validation ──
    model.eval()
    running_val_loss = 0.0
    intersection, union = 0.0, 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            running_val_loss += criterion(outputs, masks).item()
            # Threshold the sigmoid probabilities at 0.5 to get a binary prediction
            preds = (torch.sigmoid(outputs) > 0.5).float()
            batch_intersection = (preds * masks).sum().item()
            intersection += batch_intersection
            # |A ∪ B| = |A| + |B| - |A ∩ B|, accumulated over the whole validation set
            union        += (preds + masks).sum().item() - batch_intersection
    val_loss = running_val_loss / len(val_loader)
    # Small epsilon keeps the ratio defined when both prediction and mask are empty
    val_iou  = (intersection + 1e-6) / (union + 1e-6)

    # ── Record & print ──
    # Done before the early-stopping check so the epoch that triggers the
    # stop is still logged and kept in the history.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_ious.append(val_iou)
    print(f"Epoch {epoch}/{MAX_EPOCHS} — "
          f"train_loss: {train_loss:.4f}, "
          f"val_loss: {val_loss:.4f}, "
          f"val_iou: {val_iou:.4f}")

    # ── Early stopping check ──
    if val_loss < best_val_loss:
        best_val_loss     = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print(f"No improvement in {PATIENCE} epochs—stopping at epoch {epoch}.")
            break

# Restore the best checkpoint so that later cells use the weights selected by
# early stopping rather than the last (possibly worse) epoch. Skipped when no
# checkpoint was written in this run (e.g. the validation loss was never finite).
if best_val_loss < float("inf"):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    model.eval()

In [None]:
# Plot training vs. validation loss per epoch.
# The x-axis is derived from the number of recorded epochs rather than
# MAX_EPOCHS: early stopping can end training sooner, and plotting a
# MAX_EPOCHS-long range against a shorter history raises a ValueError.
epochs_run = range(1, len(train_losses) + 1)

plt.figure()
plt.plot(epochs_run, train_losses, label="Train")
plt.plot(epochs_run, val_losses,   label="Validation")
plt.title("Loss Curves")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()


In [None]:
# Collect the per-epoch history into a table indexed by epoch number.
# Epoch labels are derived from the recorded history length, not MAX_EPOCHS:
# when early stopping ends training early the lists are shorter and
# pd.DataFrame rejects columns of unequal length.
metrics_df = pd.DataFrame({
    "epoch":       list(range(1, len(train_losses) + 1)),
    "train_loss":  train_losses,
    "val_loss":    val_losses,
    "val_iou":     val_ious
}).set_index("epoch")

metrics_df


In [None]:
# Pairwise correlation between the recorded metrics, drawn as a heatmap.
corr = metrics_df.corr()
tick_positions = range(len(corr))

fig, ax = plt.subplots()
heatmap = ax.imshow(corr, interpolation="nearest")
ax.set_title("Correlation Matrix")

# Label each row/column of the heatmap with its metric name.
ax.set_xticks(tick_positions)
ax.set_xticklabels(corr.columns)
ax.set_yticks(tick_positions)
ax.set_yticklabels(corr.index)

fig.colorbar(heatmap, ax=ax)
plt.show()


In [None]:
# Show image / ground-truth mask / predicted mask side by side for up to
# three samples from the first validation batch.
imgs, masks = next(iter(val_loader))

# Put the model in evaluation mode explicitly and disable gradient tracking:
# this cell only runs a forward pass, so no autograd graph is needed.
model.eval()
with torch.no_grad():
    outputs = model(imgs.to(device)).cpu()
# Threshold the sigmoid probabilities at 0.5 to get a binary prediction
preds = (torch.sigmoid(outputs) > 0.5).float()

# Never index past the end of the batch (it may hold fewer than 3 samples).
n_show = min(3, len(imgs))

for i in range(n_show):
    fig, axes = plt.subplots(1,3, figsize=(12,4))
    # NOTE(review): the loader is built with normalize=True, so the image may
    # hold values outside the range imshow displays without clipping -- verify.
    axes[0].imshow(imgs[i].permute(1,2,0)); axes[0].set_title("Image"); axes[0].axis("off")
    axes[1].imshow(masks[i,0], cmap="gray");        axes[1].set_title("Mask");  axes[1].axis("off")
    axes[2].imshow(preds[i,0], cmap="gray");        axes[2].set_title("Pred");  axes[2].axis("off")
    plt.show()
