<a href="https://colab.research.google.com/github/karthik7147/Robust-Image-Segmentation-Using-U-Net-with-Class-Imbalance-Aware-Training/blob/main/origin_medical_'.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Some of the Kaggle data sources this notebook uses are private, so sign in to
# Kaggle before attempting any download. Run this cell first.
import kagglehub

kagglehub.login()


In [None]:
# Fetch the Kaggle data source (this cell may be deleted once that is done).
# This environment is not Kaggle's own Python image, so some libraries the
# notebook relies on may be missing.
# NOTE(review): later cells read from hard-coded "/kaggle/input/..." paths and
# never use the path stored below -- confirm which location is intended when
# running outside Kaggle (e.g. Colab).
kkaarrtthhiikkr_task_segmentation_path = kagglehub.dataset_download(
    'kkaarrtthhiikkr/task-segmentation'
)

print('Data source import complete.')


In [None]:
# Kaggle's default notebook preamble; the environment comes from the
# kaggle/python Docker image (https://github.com/kaggle/docker-python).
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

# Inputs are mounted read-only under /kaggle/input; print the full path of
# every file found anywhere beneath it.
import os

for root, _subdirs, files in os.walk('/kaggle/input'):
    for file_name in files:
        print(os.path.join(root, file_name))

# Up to 20GB written to /kaggle/working/ is kept as output when a version is
# created with "Save & Run All"; /kaggle/temp/ is scratch space that is not
# saved beyond the current session.

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T


In [None]:
class SegmentationDataset(Dataset):
    """Image / binary-mask pairs loaded from two parallel directories.

    Each entry of ``image_dir`` (sorted by name) is one sample. Its mask is
    looked up in ``mask_dir`` under the image's stem plus ``mask_suffix``,
    e.g. ``001.png`` -> ``001_Annotation.png`` with the default suffix.

    ``__getitem__`` returns ``(image, mask)``:
        image: float32 tensor of shape (3, img_size, img_size), RGB in [0, 1].
        mask:  float32 tensor of shape (1, img_size, img_size), values 0 or 1.

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding the annotation masks.
        img_size: side length that both image and mask are resized to.
        mask_suffix: text appended to the image stem to build the mask
            filename (defaults to the previously hard-coded "_Annotation.png").

    Raises (from ``__getitem__``):
        FileNotFoundError: if OpenCV cannot read the image or its mask
            (``cv2.imread`` returned None).
    """

    def __init__(self, image_dir, mask_dir, img_size=256,
                 mask_suffix="_Annotation.png"):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.img_size = img_size
        self.mask_suffix = mask_suffix
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_name = self.images[idx]
        size = (self.img_size, self.img_size)

        # -------- IMAGE --------
        img_path = os.path.join(self.image_dir, image_name)
        image = cv2.imread(img_path)  # BGR uint8 array, or None if unreadable
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, size)
        image = image / 255.0

        # -------- MASK --------
        base_name = os.path.splitext(image_name)[0]
        mask_path = os.path.join(self.mask_dir, base_name + self.mask_suffix)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Mask not found: {mask_path}")

        # cv2.resize interpolates bilinearly by default; thresholding at > 0
        # afterwards counts every partially covered pixel as foreground, so
        # mask edges may grow slightly compared with nearest-neighbour resizing.
        mask = cv2.resize(mask, size)
        mask = (mask > 0).astype(np.float32)   # binary mask
        mask = np.expand_dims(mask, axis=0)    # (H, W) -> (1, H, W)

        # HWC -> CHW, the layout PyTorch convolutions expect.
        image = torch.tensor(image, dtype=torch.float).permute(2, 0, 1)
        mask = torch.tensor(mask, dtype=torch.float)

        return image, mask


In [None]:
import os

# Dataset location inside the Kaggle notebook environment.
# NOTE(review): these paths assume the /kaggle/input mount; adjust them when
# running elsewhere.
IMAGE_DIR = "/kaggle/input/task-segmentation/Task - Segmentation/images"
MASK_DIR = "/kaggle/input/task-segmentation/Task - Segmentation/masks"

# Quick count check: the dataset expects one mask per image.
for label, folder in (("Images:", IMAGE_DIR), ("Masks:", MASK_DIR)):
    print(label, len(os.listdir(folder)))


In [None]:
dataset = SegmentationDataset(IMAGE_DIR, MASK_DIR)

# Inspect one sample (index 10): print tensor shapes, then show it.
img, mask = dataset[10]

print("Image shape:", img.shape)
print("Mask shape:", mask.shape)

# Side by side: the RGB input (CHW -> HWC for matplotlib) and its mask.
panels = [
    ("Input Image", img.permute(1,2,0), None),
    ("Ground Truth Mask", mask.squeeze(), "gray"),
]
plt.figure(figsize=(10,4))
for position, (title, picture, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture, cmap=cmap)
    plt.title(title)
plt.show()


In [None]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps height and width unchanged; only the channel count
    changes, from ``in_ch`` to ``out_ch`` in the first convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for channels_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(channels_in, out_ch, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        # The attribute name and layer order fix the state_dict keys
        # (conv.0.*, conv.2.*), which saved checkpoints depend on.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features


class UNet(nn.Module):
    """Four-level U-Net for segmentation that returns raw logits.

    Encoder: DoubleConv blocks with 64/128/256/512 channels, each followed by
    2x2 max-pooling, then a 1024-channel bottleneck. Decoder: 2x2 transposed
    convolutions, each concatenated with the matching encoder map and refined
    by a DoubleConv. A final 1x1 convolution produces ``out_channels`` maps.

    No sigmoid is applied: the training code pairs the output with
    ``nn.BCEWithLogitsLoss`` and calls ``torch.sigmoid`` itself.

    Inputs of any height/width >= 16 are accepted. When a side is not
    divisible by 16, the upsampled maps are zero-padded to match their skip
    connections, and the output keeps the input's spatial size.

    Args:
        in_channels: channels of the input image (3 = RGB, the default).
        out_channels: number of output logit maps (1 for a binary mask).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder
        self.d1 = DoubleConv(in_channels, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)
        self.d4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder upsamplers (double H/W, halve channels)
        self.u1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.u2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.u3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.u4 = nn.ConvTranspose2d(128, 64, 2, stride=2)

        # Decoder refiners: input = upsampled channels + skip channels
        self.c1 = DoubleConv(1024, 512)
        self.c2 = DoubleConv(512, 256)
        self.c3 = DoubleConv(256, 128)
        self.c4 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so an upsampled map can be one pixel
        smaller than the encoder map it is concatenated with. When the sizes
        already agree (e.g. sides divisible by 16), ``up`` is returned as-is.
        """
        diff_h = skip.size(-2) - up.size(-2)
        diff_w = skip.size(-1) - up.size(-1)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(
            up,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W) logits."""
        # Encoder: keep each level's features for the skip connections.
        d1 = self.d1(x)
        d2 = self.d2(self.pool(d1))
        d3 = self.d3(self.pool(d2))
        d4 = self.d4(self.pool(d3))

        b = self.bottleneck(self.pool(d4))

        # Decoder: upsample, align with the skip map, concatenate, refine.
        u1 = self._match_size(self.u1(b), d4)
        c1 = self.c1(torch.cat([u1, d4], dim=1))

        u2 = self._match_size(self.u2(c1), d3)
        c2 = self.c2(torch.cat([u2, d3], dim=1))

        u3 = self._match_size(self.u3(c2), d2)
        c3 = self.c3(torch.cat([u3, d2], dim=1))

        u4 = self._match_size(self.u4(c3), d1)
        c4 = self.c4(torch.cat([u4, d1], dim=1))

        # Raw logits -- deliberately no sigmoid here.
        return self.out(c4)


In [None]:
IMAGE_DIR = "/kaggle/input/task-segmentation/Task - Segmentation/images"
MASK_DIR = "/kaggle/input/task-segmentation/Task - Segmentation/masks"

# Whole dataset, served in shuffled mini-batches of two samples.
dataset = SegmentationDataset(IMAGE_DIR, MASK_DIR)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Train on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------------- MODEL ----------------
model = UNet().to(device)

# He (Kaiming) initialisation for the network's convolution layers.
def init_weights(m):
    """Kaiming-normal init for ``nn.Conv2d`` modules; meant for ``model.apply``.

    Weights are drawn with ``mode='fan_out'`` and the ReLU gain, i.e. std
    sqrt(2 / fan_out); biases, when present, are set to 0. Every other module
    type -- including ``nn.ConvTranspose2d``, which is not an ``nn.Conv2d``
    subclass -- is left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

# Re-initialise every Conv2d in the freshly built model with He init.
model.apply(init_weights)

# ---------------- LOSS (STAGE 1) ----------------
# Weighted BCE on logits: each foreground pixel's term counts 50x, presumably
# to counter the scarcity of foreground pixels.
pos_weight = torch.tensor([50.0], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# ---------------- OPTIMIZER ----------------
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)



In [None]:
# One-batch sanity check: forward, loss, backward, then report the mean absolute
# gradient of the first parameter that received one (presumably to confirm that
# gradients actually flow). No optimizer step is taken.
img, mask = (t.to(device) for t in next(iter(loader)))

optimizer.zero_grad()
logits = model(img)
loss = criterion(logits, mask)
loss.backward()

first_with_grad = next(
    ((name, param) for name, param in model.named_parameters()
     if param.grad is not None),
    None,
)
if first_with_grad is not None:
    print(first_with_grad[0], first_with_grad[1].grad.abs().mean().item())


In [None]:
# =========================
# DEVICE
# =========================
# CUDA when PyTorch detects a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# =========================
# MODEL (built from scratch so this training run starts fresh)
# =========================
model = UNet().to(device)

def init_weights(m):
    """He-initialise ``nn.Conv2d`` layers when passed to ``model.apply``.

    Conv2d weights get Kaiming-normal values (``fan_out`` mode, ReLU gain) and
    any bias is zeroed. Modules that are not ``nn.Conv2d`` -- the decoder's
    ``nn.ConvTranspose2d`` layers among them -- keep PyTorch's defaults.
    """
    if isinstance(m, nn.Conv2d):
        weight, bias = m.weight, m.bias
        nn.init.kaiming_normal_(weight, mode="fan_out", nonlinearity="relu")
        if bias is not None:
            nn.init.constant_(bias, 0)

# He-initialise the new model's Conv2d layers before any update happens.
model.apply(init_weights)

# =========================
# OPTIMIZER: Adam, constant learning rate 1e-4
# =========================
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# =========================
# LOSS FUNCTIONS
# =========================
# Weighted BCE on logits; the positive (foreground) term is scaled by 50.
pos_weight = torch.tensor([50.0], device=device)
bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

def dice_loss_logits(logits, target, smooth=1.0):
    """Soft Dice loss computed from raw logits.

    ``sigmoid(logits)`` gives per-element foreground probabilities. Both
    tensors are flattened over the whole batch, so this is one batch-level
    Dice score, not a mean of per-image scores.

    Args:
        logits: raw model outputs; must have as many elements as ``target``.
        target: ground-truth mask with 0/1 float values.
        smooth: added to numerator and denominator so an empty prediction
            against an empty mask gives Dice 1 (loss 0) instead of 0/0.

    Returns:
        Scalar tensor ``1 - Dice``.
    """
    # reshape (not view): works for non-contiguous inputs such as permuted or
    # transposed tensors, where view(-1) raises a RuntimeError.
    probs = torch.sigmoid(logits).reshape(-1)
    target = target.reshape(-1)
    intersection = (probs * target).sum()
    dice = (2. * intersection + smooth) / (probs.sum() + target.sum() + smooth)
    return 1 - dice

def combined_loss(logits, target):
    """Equal-weight sum of the weighted BCE loss and the soft Dice loss."""
    bce_term = bce_loss(logits, target)
    dice_term = dice_loss_logits(logits, target)
    return bce_term + dice_term

# =========================
# METRICS
# =========================
def dice_coeff(pred, target, smooth=1.0):
    """Dice coefficient between a (binarised) prediction and a target mask.

    Both tensors are flattened over all elements, so a batched input yields
    one batch-level score.

    Args:
        pred: predicted mask, typically thresholded to 0/1 floats.
        target: ground-truth mask with 0/1 float values.
        smooth: added to numerator and denominator; two empty masks give 1.

    Returns:
        Scalar tensor ``(2*|pred & target| + smooth) / (|pred| + |target| + smooth)``.
    """
    # reshape instead of view so non-contiguous tensors are accepted too.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def iou_score(pred, target, smooth=1.0):
    """Intersection-over-Union (Jaccard index) of a prediction and a target.

    Both tensors are flattened over all elements, so a batched input yields
    one batch-level score.

    Args:
        pred: predicted mask, typically thresholded to 0/1 floats.
        target: ground-truth mask with 0/1 float values.
        smooth: added to numerator and denominator; two empty masks give 1.

    Returns:
        Scalar tensor ``(|pred & target| + smooth) / (|pred | target| + smooth)``.
    """
    # reshape instead of view so non-contiguous tensors are accepted too.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + smooth) / (union + smooth)

# =========================
# -------- STAGE 1 --------
# BCE ONLY (5 EPOCHS)
# =========================
# Warm-up on the weighted BCE alone; the Dice term is added in stage 2.
print("🔵 Stage 1: BCE-only warmup")

epochs_stage1 = 5
criterion = bce_loss

for epoch in range(epochs_stage1):
    model.train()
    epoch_loss = 0

    for img, mask in loader:
        img, mask = img.to(device), mask.to(device)

        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, mask)
        loss.backward()

        # Cap the total gradient L2 norm at 1.0 before the update
        # (presumably to tame large steps caused by the 50x pos_weight).
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    print(f"[Stage 1 | Epoch {epoch+1}/{epochs_stage1}] Loss: {epoch_loss/len(loader):.4f}")

# =========================
# -------- STAGE 2 --------
# BCE + DICE (50 EPOCHS, NO EARLY STOPPING)
# =========================
print("\n🔵 Stage 2: BCE + Dice training")

epochs_stage2 = 50
criterion = combined_loss

# Probability cut-off that turns sigmoid outputs into a binary mask. Shared by
# the Dice/IoU metrics and the debug plot so the two always agree.
PRED_THRESHOLD = 0.25

for epoch in range(epochs_stage2):
    model.train()
    epoch_loss = 0

    for img, mask in loader:
        img, mask = img.to(device), mask.to(device)

        optimizer.zero_grad()
        logits = model(img)
        loss = criterion(logits, mask)
        loss.backward()

        # Cap the total gradient L2 norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

    # =========================
    # EVALUATION
    # =========================
    # NOTE(review): this iterates the same `loader` used for training -- no
    # held-out split is visible -- so Dice/IoU here measure training fit, not
    # generalisation. Each score covers one whole batch, and the epoch value
    # is an unweighted mean over batches.
    model.eval()
    dice_scores, iou_scores = [], []

    with torch.no_grad():
        for img, mask in loader:
            img, mask = img.to(device), mask.to(device)
            logits = model(img)
            prob = torch.sigmoid(logits)
            pred = (prob > PRED_THRESHOLD).float()

            dice_scores.append(dice_coeff(pred, mask).item())
            iou_scores.append(iou_score(pred, mask).item())

    mean_dice = sum(dice_scores) / len(dice_scores)
    mean_iou = sum(iou_scores) / len(iou_scores)

    print(f"[Stage 2 | Epoch {epoch+1}/{epochs_stage2}] "
          f"Loss: {epoch_loss/len(loader):.4f} | "
          f"Dice: {mean_dice:.4f} | IoU: {mean_iou:.4f}")

    # =========================
    # VISUAL DEBUG EVERY 5 EPOCHS
    # =========================
    if (epoch + 1) % 5 == 0:
        # Always the same sample (index 20), so plots are comparable across epochs.
        img, mask = dataset[20]
        img = img.unsqueeze(0).to(device)  # add a batch dimension

        with torch.no_grad():
            logits = model(img)
            prob = torch.sigmoid(logits).cpu().squeeze()

        plt.figure(figsize=(14,4))
        plt.subplot(1,4,1)
        plt.title("Input")
        plt.imshow(img.cpu().squeeze().permute(1,2,0))

        plt.subplot(1,4,2)
        plt.title("GT Mask")
        plt.imshow(mask.squeeze(), cmap="gray")

        plt.subplot(1,4,3)
        plt.title("Probability Map")
        plt.imshow(prob, cmap="gray")

        plt.subplot(1,4,4)
        plt.title(f"Predicted ({PRED_THRESHOLD})")
        plt.imshow((prob > PRED_THRESHOLD), cmap="gray")

        plt.show()


In [None]:
import os

os.makedirs("Model_Weights", exist_ok=True)

SAVE_PATH = "Model_Weights/hypothesis_final_full_saved_model.pth"

# Only the parameter tensors (state_dict) are saved, not the module object;
# restore with UNet().load_state_dict(torch.load(SAVE_PATH)).
torch.save(model.state_dict(), SAVE_PATH)

print(f"✅ Model weights saved at: {SAVE_PATH}")


In [None]:
import os

# Second copy of the current weights under a numbered filename. The folder is
# created here as well, so this cell also works when run on its own.
os.makedirs("Model_Weights", exist_ok=True)
torch.save(model.state_dict(), "Model_Weights/hypothesis_final_full_saved_model1.pth")
