In [None]:
import os
import sys
import random
import warnings

warnings.filterwarnings('ignore')
sys.path.append(os.path.abspath('..'))

import time
import numpy as np
import pandas as pd
import kagglehub
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, SequentialSampler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils.datasets import MobileNetV2Dataset
from utils.metric import BCEDiceLoss
from utils.visualization import predict_compare, plot_history_loss
from utils.training import train_rl
from models.MobileNetV2_UNet import MobileNetV2_UNet_Attn_MS

# Hyper-parameters

In [None]:
# --- Reproducibility and input geometry -------------------------------------
SEED, IMG_SIZE = 42, (256, 256)
# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Probability threshold (presumably for binarising predicted masks; it is not
# referenced by any cell shown in this notebook — TODO confirm where it is used).
THRESH = 0.55

# --- Optimisation -----------------------------------------------------------
LEARNING_RATE, WEIGHT_DECAY = 1e-4, 5e-4

# --- Model / training-schedule knobs (meaning defined by the model and train_rl)
DROPOUT_P = 0.4
ALPHA = 1.5
RL_WEIGHT = 0.005
PREHEAT_EPOCHS = 20

In [None]:
# Seed every RNG the notebook touches. torch.manual_seed seeds the generators on
# all devices (CPU and every CUDA GPU), so separate torch.cuda.manual_seed* calls
# are unnecessary.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# cuDNN may otherwise select non-deterministic (or benchmark-dependent) convolution
# algorithms. Pin them so repeated runs with the same seed are comparable; this can
# cost some speed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load Dataset

In [None]:
# Fetch BUS-BRA through kagglehub (which returns a local directory), then locate
# the image folder, mask folder and metadata CSV inside its BUSBRA sub-directory.
path = kagglehub.dataset_download("orvile/bus-bra-a-breast-ultrasound-dataset")
base = os.path.join(path, "BUSBRA")
images_root, masks_root, csv_path = (
    os.path.join(base, leaf) for leaf in ("Images", "Masks", "bus_data.csv")
)
df_meta = pd.read_csv(csv_path)

In [None]:
# Pair every image listed in the metadata CSV with its mask.
# Images are "<ID>.png" and masks are "mask_<ID without 'bus_'>.png". Rows whose ID
# lacks the "bus_" prefix, or whose image/mask file is missing, are skipped and
# counted rather than dropped silently.
ID_PREFIX = "bus_"
entries = []
skipped = []
for base_id, pathology in zip(df_meta['ID'].astype(str), df_meta['Pathology']):
    if not base_id.startswith(ID_PREFIX):
        skipped.append(base_id)
        continue
    img_p = os.path.join(images_root, f"{base_id}.png")
    mask_p = os.path.join(masks_root, f"mask_{base_id[len(ID_PREFIX):]}.png")
    if os.path.isfile(img_p) and os.path.isfile(mask_p):
        entries.append((pathology, img_p, mask_p))
    else:
        skipped.append(base_id)

print(f"Matched {len(entries)} image/mask pairs; skipped {len(skipped)} rows.")
assert entries, "No image/mask pairs found - check the dataset directory layout."

In [None]:
df = pd.DataFrame(entries, columns=["label", "image_path", "mask_path"])
# Stratify by pathology only when every class has at least 2 samples;
# train_test_split rejects stratification on a class with a single member.
stratify = df['label'] if df['label'].value_counts().min() >= 2 else None
# A fixed random_state gives the same split every time this cell runs, instead of
# depending on how much of the global NumPy RNG earlier cells have consumed.
train_df, val_df = train_test_split(df,
                                    test_size=0.2,
                                    stratify=stratify,
                                    random_state=SEED)

In [None]:
# IMG_SIZE comes from the hyper-parameter cell; it is deliberately not redefined here.
def _build_transform(*augmentations):
    """Compose Resize -> `augmentations` -> scale to [0, 1] -> tensor.

    Normalize(mean=0, std=1, max_pixel_value=255) computes (x - 0) / (1 * 255),
    i.e. it only rescales 8-bit pixels to [0, 1].
    """
    return A.Compose([
        A.Resize(*IMG_SIZE),
        *augmentations,
        A.Normalize(mean=0, std=1, max_pixel_value=255),
        ToTensorV2(),
    ])


# Augmentations are applied to training data only; validation gets the bare pipeline.
train_transform = _build_transform(
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.8),
    A.GaussNoise(p=0.3),
)
val_transform = _build_transform()

In [None]:
# DataLoader settings shared by both loaders; pinned host memory only helps when
# batches are copied to a CUDA device.
dataset_args = dict(batch_size=8,
                    num_workers=2,
                    pin_memory=torch.cuda.is_available())
train_ds = MobileNetV2Dataset(train_df, IMG_SIZE, train_transform)
val_ds = MobileNetV2Dataset(val_df, IMG_SIZE, val_transform)
# A dedicated, seeded generator drives the shuffle order and the base seed that
# DataLoader derives for its workers, so re-running this cell reproduces them.
train_loader = DataLoader(train_ds,
                          shuffle=True,
                          generator=torch.Generator().manual_seed(SEED),
                          **dataset_args)
# shuffle=False already iterates in dataset order (DataLoader builds a
# SequentialSampler internally), so no explicit sampler is needed.
val_loader = DataLoader(val_ds, shuffle=False, **dataset_args)

# Model Training and Visualization

In [None]:
# DEVICE is defined once, in the hyper-parameter cell, and is not redefined here.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this assumes the model
# ends in a sigmoid - confirm. BCEDiceLoss is imported at the top but never used;
# confirm which loss is intended.
criterion = nn.BCELoss()

In [None]:
# Train NUM_RUNS independent models and collect each run's best validation scores.
# NOTE(review): THRESH from the config cell is not passed to train_rl or
# predict_compare; confirm whether they should receive it.
NUM_RUNS, EPOCHS = 5, 60
best_ios, best_dices, best_precs = [], [], []

for run in range(1, NUM_RUNS + 1):
    print(f"Run {run}/{NUM_RUNS}".center(50, "-"))
    # A fresh model, optimizer and scheduler per run keeps the runs independent.
    model = MobileNetV2_UNet_Attn_MS(DROPOUT_P, IMG_SIZE).to(DEVICE)
    # Use the configured LEARNING_RATE / WEIGHT_DECAY. Previously lr=1e-3 was
    # hard-coded here, which silently overrode the config and ignored weight decay.
    optimizer = optim.Adam(model.parameters(),
                           lr=LEARNING_RATE,
                           weight_decay=WEIGHT_DECAY)
    # mode='max': the monitored quantity is one to maximise (presumably validation
    # IoU, stepped inside train_rl - confirm).
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='max',
                                               factor=0.1,
                                               patience=5)
    history, run_best, run_best_dice, run_best_prec = train_rl(
        model,
        DEVICE,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        scheduler,
        IMG_SIZE,
        ALPHA,
        RL_WEIGHT,
        PREHEAT_EPOCHS,
        num_epochs=EPOCHS,
        save_path=f'best_mobilenetv2-unet_run{run}.pth',
    )

    best_ios.append(run_best)
    best_dices.append(run_best_dice)
    best_precs.append(run_best_prec)
    print(
        f"Run {run} best → IoU: {run_best:.4f}, Dice: {run_best_dice:.4f}, Prec: {run_best_prec:.4f}"
    )

    # NOTE(review): `model` holds the final-epoch weights here, not necessarily the
    # best checkpoint written to save_path; reload it first if the visualisation
    # should match the reported best scores.
    predict_compare(model, DEVICE, val_loader, num_samples=5)

print(f"\nAll runs done. Best IoUs: {best_ios}")

# Results: Metrics, Model Size, and Inference Time

In [None]:
# Summarise each run's best validation score as mean ± std across runs
# (np.std's default ddof=0, i.e. the population standard deviation).
dice_arr, prec_arr, io_arr = (
    np.array(scores) for scores in (best_dices, best_precs, best_ios)
)

for label, arr in (("Dice      ", dice_arr),
                   ("Precision ", prec_arr),
                   ("mIoU      ", io_arr)):
    print(f"{label}: {arr.mean():.4f} ± {arr.std():.4f}")

In [None]:
# Total number of parameters in `model` (the last run's network), in millions.
params_m = sum(map(torch.Tensor.numel, model.parameters())) / 1e6
print(f"Params    : {params_m:.2f}M")

In [None]:
def _sync_device(device):
    """Wait for all queued CUDA work on `device` to finish; no-op for non-CUDA devices."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def measure_inference_time(model, device, loader, warmup_batches=1):
    """Return the per-image forward-pass latency (seconds) for each batch in `loader`.

    CUDA kernels are launched asynchronously, so the device is synchronised right
    before starting and right after stopping the clock. Without that, only the
    launch overhead would be timed. A few untimed warm-up batches run first so
    one-off start-up costs (e.g. lazy CUDA/cuDNN initialisation) do not skew the
    statistics. Host-to-device copies happen outside the timed region, so the
    result reflects the model's forward pass only.

    Args:
        model: network to benchmark; it is switched to eval mode.
        device: torch.device the model lives on.
        loader: iterable of (images, targets) batches; targets are ignored.
        warmup_batches: number of leading batches run untimed before measuring
            (0 disables warm-up).

    Returns:
        1-D numpy array with one value per batch: batch time / batch size.
    """
    model.eval()
    per_image = []
    with torch.no_grad():
        if warmup_batches > 0:
            for batch_idx, (imgs, _) in enumerate(loader, start=1):
                model(imgs.to(device))
                if batch_idx >= warmup_batches:
                    break
            _sync_device(device)

        for imgs, _ in loader:
            imgs = imgs.to(device)
            _sync_device(device)  # make sure the copy is done before timing
            start = time.perf_counter()  # monotonic, high-resolution interval clock
            model(imgs)
            _sync_device(device)  # wait for the forward pass to actually finish
            per_image.append((time.perf_counter() - start) / imgs.size(0))
    return np.array(per_image)


times = measure_inference_time(model, DEVICE, val_loader)
print(
    f"Inference Time: {times.mean()*1000:.2f} ± {times.std()*1000:.2f} ms/image"
)