In [None]:
# Change fold_idx to choose which fold to run (0~4).
# The dataset has already been split evenly into 5 folds.
fold_idx = '0' # default

In [None]:
!pip install cellpose

In [None]:
"""
augment.py
------------------------------------------------------------------
• Patch‑wise Mosaic
• 階段後：水平/垂直翻轉 + 亮度 ±10%

適用 Sartorius 細胞實例分割
"""

import numpy as np
import albumentations as A
from typing import Sequence, Tuple, Union

# ------------ flip + brightness pipeline ------------------------------
# Shared albumentations pipeline applied to every (image, mask) pair.
# NOTE(review): every transform below is configured with p=0.0, so presumably
# none of them ever fires and the pipeline is currently a pass-through —
# confirm this is intentional before relying on it for augmentation.
_geom_tf = A.Compose(
    [
        A.HorizontalFlip(p=0.0),
        A.VerticalFlip(p=0.0),
        A.RandomBrightnessContrast(
            brightness_limit=0.10,  # ±10% brightness
            contrast_limit=0.00,
            p=0.0
        ),
    ],
    additional_targets={"mask": "mask"},
    is_check_shapes=False,
)

# ------------ Augmenter class ----------------------------------------
class Augmenter:
    """
    Run the module-level ``_geom_tf`` pipeline on one image or on a batch.

    Every image/mask pair is passed through ``_geom_tf`` and the outputs are
    cast back to the dtype of the corresponding input.

    NOTE: no Mosaic step is implemented here. ``mosaic_prob`` is stored on
    the instance but ``__call__`` never reads it; the parameter is kept so
    existing callers keep working.

    Parameters:
      mosaic_prob: Mosaic probability (0~1); currently unused.
    """
    def __init__(
        self,
        mosaic_prob: float = 0.0
    ):
        self.mosaic_prob = mosaic_prob

    def __call__(
        self,
        imgs: Union[np.ndarray, Sequence[np.ndarray]],
        masks: Union[np.ndarray, Sequence[np.ndarray]],
        epoch: int = 0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Augment ``imgs`` together with ``masks``.

        Parameters:
          imgs:  a single 3-D ndarray, or any iterable of per-sample arrays.
          masks: the matching mask, or an iterable of masks of equal length.
          epoch: accepted for interface compatibility; not used.

        Returns:
          ``(image, mask)`` for a single 3-D input, otherwise a pair of
          arrays produced by ``np.stack`` over the per-sample results.

        Raises:
          ValueError: if the number of images and masks differ.
        """
        # A bare 3-D ndarray is treated as one sample; anything else as a batch.
        single = isinstance(imgs, np.ndarray) and imgs.ndim == 3
        if single:
            imgs_list = [imgs]
            masks_list = [masks]  # type: ignore
        else:
            imgs_list = list(imgs)  # type: ignore
            masks_list = list(masks)  # type: ignore

        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(imgs_list) != len(masks_list):
            raise ValueError(
                f"imgs and masks must have the same length "
                f"(got {len(imgs_list)} and {len(masks_list)})"
            )

        out_imgs, out_masks = [], []
        for im, mk in zip(imgs_list, masks_list):
            # NOTE(review): the array layout expected by the pipeline
            # (channel-first vs channel-last) is not checked here — confirm
            # against the caller before enabling any transform.
            aug = _geom_tf(image=im, mask=mk)
            # Cast back so the transforms cannot change the caller's dtypes.
            out_imgs.append(aug["image"].astype(im.dtype))
            out_masks.append(aug["mask"].astype(mk.dtype))

        if single:
            return out_imgs[0], out_masks[0]
        return np.stack(out_imgs), np.stack(out_masks)


In [None]:
import time, logging
import numpy as np
from pathlib import Path
from tqdm import trange

import torch
from torch.cuda.amp import autocast, GradScaler          # ← AMP
from cellpose import models, train
from cellpose.transforms import random_rotate_and_resize

# Logger for the training routine defined below, named after this module.
train_logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
#  Train with callback  +  AMP support
# ---------------------------------------------------------------------------
def train_seg_with_callback(
        net,
        train_data=None, train_labels=None, train_files=None,
        train_labels_files=None, train_probs=None,
        test_data=None,  test_labels=None,  test_files=None,
        test_labels_files=None, test_probs=None,
        channel_axis=None, load_files=True,
        batch_size=1, learning_rate=5e-5, SGD=False,
        n_epochs=100, weight_decay=0.1,
        normalize=True, compute_flows=False,
        save_path=None, save_every=100, save_each=False,
        nimg_per_epoch=None, nimg_test_per_epoch=None,
        rescale=False, scale_range=None, bsize=256,
        min_train_masks=5, model_name=None, class_weights=None,
        callback=None, validate_every=5, tb_writer=None,
        augment_fn=None, use_amp=True
    ):
    """
    Train a Cellpose segmentation network with a per-epoch callback and AMP.

    The data arguments (``train_*`` / ``test_*``, ``channel_axis``,
    ``load_files``, ``min_train_masks``, ``compute_flows``) are forwarded to
    ``train._process_train_test``; see that helper for their exact meaning.

    Args:
        net: network to train; must expose ``device``, ``diam_labels``,
            ``diam_mean`` and ``save_model``.
        batch_size: number of samples per optimisation step.
        learning_rate: peak learning rate reached at the end of warm-up.
        SGD: ignored; only triggers a deprecation warning (AdamW is used).
        n_epochs: number of epochs.
        weight_decay: AdamW weight decay.
        normalize: bool, or dict of overrides merged over
            ``models.normalize_default``.
        save_path: directory under which ``<model_name>/<model_name>`` is
            written; defaults to the current working directory.
        save_every: periodic checkpoint interval in epochs.
        save_each: if true, periodic checkpoints get an ``_epoch_XXXX``
            suffix instead of overwriting the main file.
        nimg_per_epoch / nimg_test_per_epoch: samples drawn per epoch;
            default to the dataset sizes.
        rescale: if true, per-image rescale factors are derived from the
            diameters returned by the preprocessing step.
        scale_range: passed to ``random_rotate_and_resize`` (default 0.5).
        bsize: side length of the square crops used for training.
        model_name: checkpoint name; defaults to ``cellpose_<timestamp>``.
        class_weights: optional weights passed to ``train._loss_fn_class``.
        callback: optional object whose
            ``on_epoch_end(epoch, train_loss, val_loss, lr)`` is called after
            every epoch.
        validate_every: validation interval in epochs (the last epoch is
            always validated).
        tb_writer: optional writer with ``add_scalar`` used for loss/LR logs.
        augment_fn: optional ``fn(imgs, lbls, epoch=...) -> (imgs, lbls)``
            applied before ``random_rotate_and_resize``.
        use_amp: enable autocast + gradient scaling.

    Returns:
        ``(checkpoint_path, train_losses, test_losses)`` where the losses are
        float32 arrays of length ``n_epochs``.

    Raises:
        ValueError: if ``normalize`` is neither a bool nor a dict.
    """

    # ----------------------- basic setup ----------------------- #
    if SGD:
        train_logger.warning("SGD is deprecated, using AdamW instead")

    device = net.device
    scale_range = 0.5 if scale_range is None else scale_range

    # normalisation dict handling
    if isinstance(normalize, dict):
        normalize_params = {**models.normalize_default, **normalize}
    elif isinstance(normalize, bool):
        # Build a fresh dict: writing into models.normalize_default itself
        # would change the library-wide default for every later caller.
        normalize_params = {**models.normalize_default, "normalize": normalize}
    else:
        raise ValueError("normalize must be bool or dict")

    # preprocess data (cellpose util)
    result = train._process_train_test(
        train_data=train_data, train_labels=train_labels,
        train_files=train_files, train_labels_files=train_labels_files,
        train_probs=train_probs,
        test_data=test_data, test_labels=test_labels,
        test_files=test_files, test_labels_files=test_labels_files,
        test_probs=test_probs,
        load_files=load_files, min_train_masks=min_train_masks,
        compute_flows=compute_flows, channel_axis=channel_axis,
        normalize_params=normalize_params, device=device
    )
    (train_data, train_labels, train_files, train_labels_files, train_probs, diam_train,
     test_data,  test_labels,  test_files,  test_labels_files, test_probs, diam_test,
     normed) = result

    # If the data was not normalised up front, ask _get_batch to do it per batch.
    kwargs = {} if normed else {"normalize_params": normalize_params,
                                "channel_axis": channel_axis}

    net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

    # optional class weights
    if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)):
        class_weights = torch.tensor(class_weights, dtype=torch.float32,
                                     device=device)

    # dataset sizes
    nimg        = len(train_data) if train_data is not None else len(train_files)
    nimg_test   = len(test_data) if test_data is not None else len(test_files) if test_files is not None else None
    nimg_per_epoch      = nimg       if nimg_per_epoch       is None else nimg_per_epoch
    nimg_test_per_epoch = nimg_test  if nimg_test_per_epoch  is None else nimg_test_per_epoch

    # ----------------------- lr schedule ---------------------- #
    warmup_epochs = 5                   # number of warm-up epochs
    base_lr       = learning_rate       # peak LR, reached when warm-up ends
    eta_min       = 1e-7                # floor of the cosine phase

    LR = np.zeros(n_epochs, dtype=np.float32)

    # (1) linear warm-up: base_lr / warmup_epochs -> base_lr
    for e in range(min(warmup_epochs, n_epochs)):
        LR[e] = base_lr * (e + 1) / warmup_epochs

    # (2) cosine decay from base_lr towards eta_min
    if n_epochs > warmup_epochs:
        t = np.arange(0, n_epochs - warmup_epochs)
        cos_part = eta_min + 0.5 * (base_lr - eta_min) * \
                (1 + np.cos(np.pi * t / (n_epochs - warmup_epochs)))
        LR[warmup_epochs:] = cos_part

    # ----------------------- optimizer & scaler --------------- #
    optimizer = torch.optim.AdamW(net.parameters(),
                                  lr=learning_rate,
                                  weight_decay=weight_decay)
    scaler = GradScaler(enabled=use_amp)          # AMP scaler

    # ----------------------- paths ----------------------------- #
    t0 = time.time()
    model_name = f"cellpose_{t0}" if model_name is None else model_name
    save_path  = Path.cwd() if save_path is None else Path(save_path)
    model_dir  = save_path / model_name
    # parents=True so that a not-yet-existing save_path does not abort the run.
    model_dir.mkdir(parents=True, exist_ok=True)
    filename   = model_dir / model_name
    train_logger.info(f"Saving checkpoints to {filename}")

    # ----------------------- track losses ---------------------- #
    train_losses = np.zeros(n_epochs, dtype=np.float32)
    test_losses  = np.zeros(n_epochs, dtype=np.float32)

    # ----------------------- epoch loop ------------------------ #
    for iepoch in trange(n_epochs, desc="Epoch", ncols=100):
        # Seed with the epoch index so the sampling order is reproducible.
        np.random.seed(iepoch)
        rperm = (np.random.choice(np.arange(nimg), nimg_per_epoch, p=train_probs)
                 if nimg != nimg_per_epoch
                 else np.random.permutation(np.arange(nimg)))

        # update LR
        for pg in optimizer.param_groups:
            pg["lr"] = LR[iepoch]

        net.train()
        epoch_train_loss, nsamples = 0.0, 0

        # -------- mini-batch loop -------- #
        for k in trange(0, nimg_per_epoch, batch_size,
                        leave=False, desc="Train", ncols=100):
            inds = rperm[k:k+batch_size]
            imgs, lbls = train._get_batch(inds,
                                          data=train_data, labels=train_labels,
                                          files=train_files, labels_files=train_labels_files,
                                          **kwargs)
            diams = np.array([diam_train[i] for i in inds])
            rsc   = diams / net.diam_mean.item() if rescale else np.ones_like(diams)

            if augment_fn is not None:
                imgs, lbls = augment_fn(imgs, lbls, epoch=iepoch)

            imgi, lbl = random_rotate_and_resize(imgs, Y=lbls,
                                                 rescale=rsc,
                                                 scale_range=scale_range,
                                                 xy=(bsize, bsize))[:2]

            X   = torch.from_numpy(imgi).to(device)
            lbl = torch.from_numpy(lbl).to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=use_amp):
                y = net(X)[0]
                loss = train._loss_fn_seg(lbl, y, device)
                # More than 3 output channels: add the class loss on top.
                if y.shape[1] > 3:
                    loss += train._loss_fn_class(lbl, y,
                                                 class_weights=class_weights)

            # ------ backward scaled ------ #
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Weight by batch size so the epoch figure is a per-sample mean.
            batch_loss = loss.item() * len(imgi)
            epoch_train_loss += batch_loss
            nsamples         += len(imgi)

        train_losses[iepoch] = epoch_train_loss / nsamples

        # -------- validation -------- #
        if (test_data is not None or test_files is not None) and \
           (iepoch % validate_every == 0 or iepoch == n_epochs-1):

            # Fixed seed: the same samples and crops are used at every validation.
            np.random.seed(42)
            val_perm = (np.random.choice(np.arange(nimg_test), nimg_test_per_epoch, p=test_probs)
                        if nimg_test != nimg_test_per_epoch
                        else np.random.permutation(np.arange(nimg_test)))

            net.eval()
            val_sum, v_nsamp = 0.0, 0
            with torch.no_grad():
                for k in trange(0, len(val_perm), batch_size,
                                leave=False, desc="Val", ncols=100):
                    inds = val_perm[k:k+batch_size]
                    imgs, lbls = train._get_batch(inds,
                                                  data=test_data, labels=test_labels,
                                                  files=test_files, labels_files=test_labels_files,
                                                  **kwargs)
                    diams = np.array([diam_test[i] for i in inds])
                    rsc   = diams / net.diam_mean.item() if rescale else np.ones_like(diams)

                    imgi, lbl = random_rotate_and_resize(imgs, Y=lbls,
                                                         rescale=rsc,
                                                         scale_range=scale_range,
                                                         xy=(bsize, bsize))[:2]
                    X   = torch.from_numpy(imgi).to(device)
                    lbl = torch.from_numpy(lbl).to(device)

                    with autocast(enabled=use_amp):
                        y = net(X)[0]
                        vloss = train._loss_fn_seg(lbl, y, device)
                        if y.shape[1] > 3:
                            vloss += train._loss_fn_class(lbl, y,
                                                          class_weights=class_weights)
                    val_sum  += vloss.item() * len(imgi)
                    v_nsamp  += len(imgi)

            test_losses[iepoch] = val_sum / v_nsamp
        else:
            # No validation this epoch: carry the previous value forward.
            test_losses[iepoch] = test_losses[iepoch-1] if iepoch else 0

        # -------- callback / TB log -------- #
        if callback is not None:
            callback.on_epoch_end(iepoch,
                                  train_losses[iepoch],
                                  test_losses[iepoch],
                                  LR[iepoch])

        if tb_writer is not None:
            tb_writer.add_scalar("Loss/train", train_losses[iepoch], iepoch)
            tb_writer.add_scalar("Loss/valid", test_losses[iepoch],  iepoch)
            tb_writer.add_scalar("LR",         LR[iepoch],          iepoch)

        # -------- periodic checkpoint -------- #
        # The last epoch is skipped here: the unconditional save after the
        # loop writes exactly the same file, so saving here too is redundant.
        if iepoch != n_epochs-1 and iepoch % save_every == 0 and iepoch:
            ckpt_name = f"{filename}_epoch_{iepoch:04d}" if save_each else filename
            net.save_model(str(ckpt_name))
            train_logger.info(f"Saved model to {ckpt_name}")

    # final save
    net.save_model(str(filename))
    train_logger.info(f"Final model saved to {filename}")
    return str(filename), train_losses, test_losses

In [None]:
import os
import torch
from pathlib import Path
from cellpose import io, models, core, metrics
from torch.utils.tensorboard import SummaryWriter

# Root directory for TensorBoard runs.
log_dir_base = Path("/kaggle/working/runs")
log_dir_base.mkdir(parents=True, exist_ok=True)


# --------------------------- environment setup ----------------------------- #
io.logger_setup()
if not core.use_gpu():
    raise ImportError("No GPU detected; please switch to a GPU runtime")

# Checkpoint name is derived from the fold selected via fold_idx.
model_name = "fold" + fold_idx
save_path = Path("/kaggle/working")
save_path.mkdir(exist_ok=True)

# Training hyperparameters
n_epochs      = 100
learning_rate = 1e-5
weight_decay  = 0.1
batch_size    = 2

# Per-fold train/validation directories; masks are the files ending in masks_ext.
train_dir = Path("/kaggle/input/fold-data/mytrain_fold" + fold_idx)
test_dir  = Path("/kaggle/input/fold-data/myval_fold" + fold_idx)
masks_ext = "_seg.npy"

if not train_dir.exists() or not test_dir.exists():
    raise FileNotFoundError("train_dir or test_dir not found")

# --------------------------- load full-size images & masks ----------------------------- #
train_imgs, train_msks, _, val_imgs, val_msks, _ = io.load_train_test_data(
    str(train_dir), str(test_dir),
    mask_filter=masks_ext,
    image_filter=""
)
print(f"Loaded full images → train: {len(train_imgs)}, val: {len(val_imgs)}")

# --------------------------------- no patch splitting: use the original images and masks directly --------------------------
train_data, train_labels = train_imgs, train_msks
val_data,   val_labels  = val_imgs,   val_msks

# --------------------------- check GPU availability ----------------------------- #
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------- initialise the Augmenter (original plan: switch Mosaic off after epoch 50) ----------------------
# Patch-based mosaic is no longer used; if only the flip/brightness style
# augmentations are wanted, adjust augment_fn accordingly.
train_aug = Augmenter(mosaic_prob=0.0)

def scheduled_augment(imgs, masks, epoch=0):
    """
    Epoch-aware augmentation hook handed to the training loop.

    For every ``epoch >= 0`` the inputs are returned untouched, so
    ``train_aug`` is only reached when ``epoch`` fails that comparison.
    """
    use_augmenter = not (epoch >= 0)
    if use_augmenter:
        return train_aug(imgs, masks, epoch=epoch)
    # Pass-through: no mosaic or any other augmentation is applied.
    return imgs, masks

# --------------------------- set up TensorBoard callback ----------------------------- #
# One TensorBoard run directory per model name (i.e. per fold).
log_dir = Path("/kaggle/working/runs") / model_name
log_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(str(log_dir))

class MAPCallback:
    """
    Epoch-end callback for training.

    At the end of each epoch it
      1. logs train/validation loss and learning rate,
      2. runs ``model.eval`` on the full-resolution validation images,
      3. logs mAP@0.50 and the mean AP over ``iou_thresholds``,
      4. keeps the ``num_best_models`` highest-scoring checkpoints on disk.

    Args:
        model: object exposing ``eval(imgs, batch_size=..., diameter=...)``
            and ``net.save_model(path)``.
        val_imgs: validation images passed to ``model.eval``.
        val_msks: ground-truth masks matching ``val_imgs``.
        writer: object exposing ``add_scalar(tag, value, step)``.
        save_path: directory in which checkpoints are written.
        model_name: filename prefix of the checkpoints.
        num_best_models: number of checkpoints to keep; ``<= 0`` disables
            checkpoint saving.
        iou_thresholds: IoU thresholds handed to
            ``metrics.average_precision``. Defaults to 0.50, 0.55, ..., 0.95
            so that the value logged as ``mAP50_95`` really covers that
            range. The first entry should be 0.5, because its column is what
            gets logged as ``mAP50``.
    """
    def __init__(self, model, val_imgs, val_msks, writer, save_path, model_name,
                 num_best_models=5, iou_thresholds=None):
        self.model    = model
        self.val_imgs = val_imgs
        self.val_msks = val_msks
        self.writer   = writer
        self.save_path = Path(save_path)  # base directory for checkpoints
        self.model_name = model_name  # base name of the checkpoints
        self.num_best_models = num_best_models  # how many checkpoints to keep
        self.best_models = []  # (mAP, epoch, path) tuples, best first
        if iou_thresholds is None:
            # round() removes float drift such as 0.6000000000000001.
            iou_thresholds = [round(0.5 + 0.05 * i, 2) for i in range(10)]
        self.iou_thresholds = list(iou_thresholds)

    @staticmethod
    def _summarize_ap(ap_array):
        """Return ``(mAP50, mean AP over all thresholds)``; both 0.0 when empty."""
        if ap_array.size == 0:
            return 0.0, 0.0
        return float(ap_array[:, 0].mean()), float(ap_array.mean())

    def _track_checkpoint(self, score, epoch):
        """Save the current weights if ``score`` ranks among the best kept so far."""
        if self.num_best_models <= 0:
            return

        list_is_full = len(self.best_models) >= self.num_best_models
        if list_is_full:
            # Only replace the worst kept checkpoint when strictly better.
            if not score > self.best_models[-1][0]:
                return
            old_mAP, _, old_path = self.best_models.pop()
            if old_path and old_path.exists():
                os.remove(old_path)
                print(f"Removed old best model: {old_path.name} (mAP={old_mAP:.4f})")

        checkpoint_path = self.save_path / \
            f"{self.model_name}_epoch_{epoch:04d}_mAP_{score:.4f}.pth"
        self.model.net.save_model(str(checkpoint_path))
        if list_is_full:
            print(f"Saved new best model: {checkpoint_path}")
        else:
            print(f"Saved model to {checkpoint_path}")

        self.best_models.append((score, epoch, checkpoint_path))
        # Keep the list sorted by mAP, best first, so [-1] is the worst.
        self.best_models.sort(key=lambda x: x[0], reverse=True)

    def on_epoch_end(self, epoch, train_loss, val_loss, lr):
        """Log metrics for ``epoch``, evaluate mAP and update kept checkpoints."""
        # 1) Log losses and LR
        self.writer.add_scalar("Loss/train", train_loss, epoch)
        self.writer.add_scalar("Loss/val",   val_loss,   epoch)
        self.writer.add_scalar("LR",         lr,         epoch)

        # 2) Infer on full-res val set
        with torch.no_grad():
            preds_list, *_ = self.model.eval(
                self.val_imgs, batch_size=1, diameter=None
            )

        # 3) Average precision at the configured IoU thresholds
        ap_array = metrics.average_precision(
            self.val_msks, preds_list, threshold=self.iou_thresholds
        )[0]
        current_mAP50, current_mAP50_95 = self._summarize_ap(ap_array)

        # 4) Log mAP@0.50 and mAP@[0.50:0.95] (guarded against empty input)
        self.writer.add_scalar("mAP50/val",    current_mAP50,    epoch)
        self.writer.add_scalar("mAP50_95/val", current_mAP50_95, epoch)

        # 5) Track and save the best checkpoints
        self._track_checkpoint(current_mAP50_95, epoch)

# --------------------------- build model & register callback ----------------------------- #
# Load the pretrained weights, then share the same network between the
# training loop (net) and the evaluation callback (model).
model = models.CellposeModel(gpu=True, pretrained_model = r'/kaggle/input/cellposesam-pretrained-weight-on-livecell/LiveCell_epoch50_bs2_norm_part2')
net   = model.net.to(device)
map_cb = MAPCallback(model, val_data, val_labels, writer, save_path=save_path, model_name=model_name, num_best_models=5)

# --------------------------- run training ----------------------------- #
new_model_path, train_losses, val_losses = train_seg_with_callback(
    net,
    train_data=train_data,         # full-res images, used directly
    train_labels=train_labels,     # full-res masks, used directly
    test_data=val_data,            # full-res validation images
    test_labels=val_labels,        # full-res validation masks
    batch_size=batch_size,
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    nimg_per_epoch=len(train_data),
    save_every=n_epochs + 1, # disables the periodic checkpointing; only the final model is saved
    save_each=False, # make sure no pile of intermediate files is produced
    save_path=save_path, # same base directory the callback saves into (/kaggle/working/)
    model_name=model_name,
    callback=map_cb,
    validate_every=1,
    min_train_masks=0,
    normalize=True,
    augment_fn=scheduled_augment,  # simple augment_fn defined above
)

print("✅ Training complete. Model saved to:", new_model_path)
writer.close()
