In [None]:
import os
import sys
import random
import warnings

warnings.filterwarnings("ignore")
sys.path.append(os.path.abspath('..'))

import cv2
import time
import numpy as np
import pandas as pd
import kagglehub
from sklearn.model_selection import train_test_split

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, SequentialSampler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils.datasets import BusbraDataset, BarlowTwinDataset
from utils.metric import BCEDiceLoss, BarlowTwinsLoss
from utils.visualization import predict_compare, plot_history_loss
from utils.training import train, pretrain_barlow_twins
from models.BarlowTwins import BarlowTwinsModel, Projector
from models.SpatialAttention_UNet import SpatialAttentionUNet_Barlow

# Hyper-parameters

In [None]:
# ---- Reproducibility & hardware ----
SEED = 42
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NUM_WORKERS = 2

# ---- Data ----
IMG_SIZE = (256, 256)
# Fraction of train_df that is kept (with masks) for supervised fine-tuning.
FINE_TUNE_FRACTION = 0.2

# ---- Model ----
INIT_FEATURES = 32
# Projector widths; presumably (input, hidden, output). The last entry is passed to
# BarlowTwinsLoss as projector_output_dim.
PROJECTOR_DIMS = [16 * INIT_FEATURES] + [2048] * 2

# ---- Pre-training (Barlow Twins) ----
BATCH_SIZE_PRETRAIN = 16
PRETRAIN_EPOCHS = 100
LEARNING_RATE_PRETRAIN = 1e-4
BARLOW_LAMBDA = 0.2

# ---- Fine-tuning ----
BATCH_SIZE_FINETUNE = 8
FINETUNE_EPOCHS = 100
LEARNING_RATE_FINETUNE = 1e-3

In [None]:
# Seed every RNG the pipeline touches: Python, NumPy, and PyTorch (CPU, then CUDA).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
if torch.cuda.is_available():
    # torch.manual_seed is documented to seed CUDA as well; the explicit calls are kept anyway.
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(SEED)

# Load Dataset

In [None]:
# Fetch the BUS-BRA dataset with kagglehub (returns a local directory) and read its metadata CSV.
path = kagglehub.dataset_download("orvile/bus-bra-a-breast-ultrasound-dataset")
base = os.path.join(path, "BUSBRA")
images_root, masks_root, csv_path = (
    os.path.join(base, leaf) for leaf in ("Images", "Masks", "bus_data.csv")
)
df_meta = pd.read_csv(csv_path)

In [None]:
# Pair every image listed in the metadata with its mask.
# Image IDs look like "bus_<suffix>" and the mask file is "mask_<suffix>.png".
# IDs without the "bus_" prefix, and rows whose image or mask file is missing, are skipped.
entries = []
for raw_id, pathology in zip(df_meta['ID'], df_meta['Pathology']):
    base_id = str(raw_id)
    if not base_id.startswith("bus_"):
        continue
    img_p = os.path.join(images_root, f"{base_id}.png")
    mask_p = os.path.join(masks_root, f"mask_{base_id[4:]}.png")
    if not (os.path.exists(img_p) and os.path.exists(mask_p)):
        continue
    entries.append((pathology, img_p, mask_p))

In [None]:
df = pd.DataFrame(entries, columns=["label", "image_path", "mask_path"])


def _stratify_or_none(frame):
    """Return frame['label'] for stratification, or None if some class has < 2 samples."""
    return frame['label'] if frame['label'].value_counts().min() >= 2 else None


# Both splits are seeded so the train/val partition (and hence every reported metric)
# is reproducible across kernel restarts; previously only the fine-tune split was seeded.
train_df, val_df = train_test_split(df,
                                    test_size=0.2,
                                    random_state=SEED,
                                    stratify=_stratify_or_none(df))
# Keep FINE_TUNE_FRACTION of the training split as the labelled fine-tuning subset.
fine_tune_df, _ = train_test_split(train_df,
                                   train_size=FINE_TUNE_FRACTION,
                                   random_state=SEED,
                                   stratify=_stratify_or_none(train_df))

In [None]:
# Augmentation pipelines. IMG_SIZE is taken from the config cell and deliberately not
# redefined here; a local redefinition would silently override the config value.

# Supervised pipelines: for uint8 input, Normalize(mean=0, std=1, max_pixel_value=255)
# maps pixels to [0, 1].
train_transform = A.Compose([
    A.Resize(*IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.8),
    A.GaussNoise(p=0.3),
    A.Normalize(mean=0, std=1, max_pixel_value=255),
    ToTensorV2(),
])
val_transform = A.Compose([
    A.Resize(*IMG_SIZE),
    A.Normalize(mean=0, std=1, max_pixel_value=255),
    ToTensorV2(),
])

# Stronger augmentations for Barlow Twins views (presumably applied independently twice
# per image by BarlowTwinDataset — confirm). The final Resize guarantees IMG_SIZE output
# even when the p=0.5 RandomResizedCrop is skipped.
# NOTE(review): Normalize(mean=0.5, std=0.5) with the default max_pixel_value maps uint8
# input to [-1, 1], whereas the fine-tuning pipelines above produce [0, 1]. The encoder is
# therefore pretrained on a different input range than it is fine-tuned on — confirm
# this is intended.
barlow_twins_transform = A.Compose([
    A.RandomResizedCrop(size=IMG_SIZE,
                        scale=(0.5, 1.0),
                        p=0.5,
                        interpolation=cv2.INTER_LINEAR),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4,
                               p=0.8),
    A.GaussNoise(var_limit=(10.0, 60.0), p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.5),
    A.Resize(height=IMG_SIZE[0],
             width=IMG_SIZE[1],
             interpolation=cv2.INTER_LINEAR),
    A.Normalize(mean=(0.5, ), std=(0.5, )),
    ToTensorV2(),
])

In [None]:
# Supervised loaders over the full training split and the validation split.
# NOTE(review): the fine-tuning cells below build their own loaders from fine_tune_df and
# rebind `val_loader`, so train_loader / this val_loader are not used for fine-tuning.
dataset_args = dict(batch_size=BATCH_SIZE_FINETUNE,
                    num_workers=NUM_WORKERS,
                    pin_memory=torch.cuda.is_available())
train_ds = BusbraDataset(train_df, IMG_SIZE, train_transform)
val_ds = BusbraDataset(val_df, IMG_SIZE, val_transform)
train_loader = DataLoader(train_ds, shuffle=True, **dataset_args)
# shuffle=False already iterates in index order; an explicit SequentialSampler is redundant.
val_loader = DataLoader(val_ds, shuffle=False, **dataset_args)

# Model Training and Visualization

## Pre-train

In [None]:
# Barlow Twins self-supervised pre-training. Only train_df is used, so validation images
# never influence the encoder.
pretrain_dataset = BarlowTwinDataset(train_df, IMG_SIZE,
                                     transform=barlow_twins_transform)
pretrain_loader = DataLoader(pretrain_dataset,
                             batch_size=BATCH_SIZE_PRETRAIN,
                             shuffle=True,
                             num_workers=NUM_WORKERS,
                             # Pinning only benefits host-to-GPU copies; skip it on CPU-only runs.
                             pin_memory=torch.cuda.is_available(),
                             # Keep every batch exactly BATCH_SIZE_PRETRAIN, matching the
                             # batch_size given to BarlowTwinsLoss below.
                             drop_last=True)

unet_encoder_base = SpatialAttentionUNet_Barlow(INIT_FEATURES, in_channels=1, out_channels=1).to(DEVICE)
projector = Projector(PROJECTOR_DIMS[0], PROJECTOR_DIMS[1],
                      PROJECTOR_DIMS[2]).to(DEVICE)
bt_model = BarlowTwinsModel(encoder=unet_encoder_base,
                            projector=projector).to(DEVICE)

optimizer_bt = optim.AdamW(bt_model.parameters(),
                           lr=LEARNING_RATE_PRETRAIN,
                           weight_decay=1e-6)
criterion_bt = BarlowTwinsLoss(
    lambda_param=BARLOW_LAMBDA,
    batch_size=BATCH_SIZE_PRETRAIN,
    projector_output_dim=PROJECTOR_DIMS[-1]).to(DEVICE)

# The encoder weights are written to this path and reloaded by the fine-tune cell.
pretrain_history = pretrain_barlow_twins(
    bt_model,
    pretrain_loader,
    optimizer_bt,
    criterion_bt,
    PRETRAIN_EPOCHS,
    DEVICE,
    encoder_save_path="bt_encoder_pretrained.pth")

## Fine-tune

In [None]:
fine_tune_dataset = BusbraDataset(fine_tune_df, IMG_SIZE, transform=train_transform)
val_dataset = BusbraDataset(val_df, IMG_SIZE, transform=val_transform)

# Pinning only benefits host-to-GPU copies; skip it on CPU-only runs.
pin = torch.cuda.is_available()
fine_tune_loader = DataLoader(fine_tune_dataset,
                              batch_size=BATCH_SIZE_FINETUNE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=pin)
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE_FINETUNE,
                        shuffle=False,
                        num_workers=NUM_WORKERS,
                        pin_memory=pin)

fine_tune_model = SpatialAttentionUNet_Barlow(INIT_FEATURES, in_channels=1,
                                              out_channels=1).to(DEVICE)

pretrained_encoder_path = "bt_encoder_pretrained.pth"
pretrained_dict = torch.load(pretrained_encoder_path, map_location=DEVICE)
model_dict = fine_tune_model.state_dict()

# Transfer only encoder-side tensors (encoder*, pool*, bottleneck*). Tensors that are
# missing from the checkpoint or have a different shape keep their fresh initialisation.
encoder_prefixes = ('encoder', 'pool', 'bottleneck')
encoder_keys = {k for k in model_dict if k.startswith(encoder_prefixes)}
pretrained_dict_filtered = {
    k: v
    for k, v in pretrained_dict.items()
    if k in encoder_keys and v.shape == model_dict[k].shape
}

# Make the transfer visible. A total key/shape mismatch would otherwise turn the whole
# experiment into training from scratch without any signal.
print(f"Transferred {len(pretrained_dict_filtered)}/{len(encoder_keys)} encoder tensors "
      f"from {pretrained_encoder_path}")
if not pretrained_dict_filtered:
    raise RuntimeError(
        f"No tensors in {pretrained_encoder_path} match the fine-tune model's encoder "
        f"keys/shapes; fine-tuning would silently start from random weights.")

model_dict.update(pretrained_dict_filtered)
fine_tune_model.load_state_dict(model_dict)

In [None]:
criterion_ft = BCEDiceLoss(weight_bce=0.5, weight_dice=0.5).to(DEVICE)

# NOTE(review): EPOCHS=60 here overrides FINETUNE_EPOCHS (100) from the config cell —
# confirm which value is intended.
NUM_RUNS, EPOCHS = 5, 60
all_histories, best_ios, best_dices, best_precs = [], [], [], []

for run in range(1, NUM_RUNS + 1):
    print(f"Run {run}/{NUM_RUNS}".center(50, "-"))

    # Each run starts from the same pretrained encoder and a freshly initialised, per-run
    # seeded decoder. Without this, run k would continue from run k-1's fine-tuned weights,
    # and the mean ± std over runs would not reflect independent repetitions.
    torch.manual_seed(SEED + run)
    fine_tune_model = SpatialAttentionUNet_Barlow(INIT_FEATURES, in_channels=1,
                                                  out_channels=1).to(DEVICE)
    run_state = fine_tune_model.state_dict()
    run_state.update(pretrained_dict_filtered)
    fine_tune_model.load_state_dict(run_state)

    optimizer_ft = optim.Adam(fine_tune_model.parameters(),
                              lr=LEARNING_RATE_FINETUNE)
    # mode='max': presumably train() steps the scheduler on a validation score (e.g. IoU) — confirm.
    scheduler_ft = lr_scheduler.ReduceLROnPlateau(optimizer_ft,
                                                  mode='max',
                                                  factor=0.1,
                                                  patience=5)
    # NOTE(review): every run passes the same save_path, so later runs presumably
    # overwrite earlier runs' checkpoints.
    history, run_best, run_best_dice, run_best_prec = train(
        fine_tune_model,
        DEVICE,
        fine_tune_loader,
        val_loader,
        criterion_ft,
        optimizer_ft,
        scheduler_ft,
        num_epochs=EPOCHS,
        save_path=f'bt_unet_finetuned_{int(FINE_TUNE_FRACTION*100)}pct.pth',
    )

    all_histories.append(history)
    best_ios.append(run_best)
    best_dices.append(run_best_dice)
    best_precs.append(run_best_prec)
    print(
        f"Run {run} best → IoU: {run_best:.4f}, Dice: {run_best_dice:.4f}, Prec: {run_best_prec:.4f}"
    )

    predict_compare(fine_tune_model, DEVICE, val_loader, num_samples=5)

print(f"\nAll runs done. Best IoUs: {best_ios}")

# Results Summary and Inference Benchmark

In [None]:
# Mean ± population standard deviation (ndarray.std default, ddof=0) of the best
# per-run validation metrics collected by the fine-tuning loop.
dice_arr, prec_arr, io_arr = (np.array(vals) for vals in (best_dices, best_precs, best_ios))

for metric_label, metric_arr in (("Dice      ", dice_arr),
                                 ("Precision ", prec_arr),
                                 ("mIoU      ", io_arr)):
    print(f"{metric_label}: {metric_arr.mean():.4f} ± {metric_arr.std():.4f}")

In [None]:
# Total number of parameters of the fine-tuned model, reported in millions.
n_params = 0
for param in fine_tune_model.parameters():
    n_params += param.numel()
params_m = n_params / 1e6
print(f"Params    : {params_m:.2f}M")

In [None]:
# Per-image inference latency of the last fine-tuned model over the validation set.
# CUDA kernels are launched asynchronously, so the device is synchronized before the
# clock is read; otherwise only launch overhead would be measured. The host-to-device
# copy stays inside the timed region, as before. One untimed warm-up batch absorbs
# one-off first-call costs so they do not skew the statistics.
def _sync_device():
    """Block until queued CUDA work on DEVICE has finished; no-op on CPU."""
    if DEVICE.type == "cuda":
        torch.cuda.synchronize(DEVICE)


fine_tune_model.eval()
times = []
with torch.no_grad():
    for warmup_imgs, _ in val_loader:  # first batch only; loop form tolerates an empty loader
        fine_tune_model(warmup_imgs.to(DEVICE))
        break
    _sync_device()

    for imgs, _ in val_loader:
        t0 = time.perf_counter()
        fine_tune_model(imgs.to(DEVICE))
        _sync_device()
        # Divide by the actual batch size so a smaller final batch is handled correctly.
        times.append((time.perf_counter() - t0) / imgs.size(0))
times = np.array(times)
print(
    f"Inference Time: {times.mean()*1000:.2f} ± {times.std()*1000:.2f} ms/image"
)