In [1]:

import os
import sys

from dotenv import load_dotenv

# Put the parent directory on the import path first; the project-local
# `utils` package imported on the next line is resolved through it.
sys.path.append('../')
from utils.seed import seed_everything

# Called with its defaults (presumably fixes the RNG seeds -- see utils/seed.py).
seed_everything()

# Read environment variables from the .env file one directory up.
load_dotenv('../.env')

EXPERIMENT_DIR = os.getenv('EXPERIMENT_DIR')

## Experiment Settings

In [None]:
import pandas as pd 
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

# The tuning table is read from the CSV whose path is stored in TUNING_CSV.
tuning_df = pd.read_csv(os.getenv('TUNING_CSV'))

# Five shuffled splits with a fixed seed; each entry of `folds` is a
# (train frame, validation frame) pair with a fresh 0..n-1 index.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
folds = [
    tuple(tuning_df.iloc[row_idx].reset_index(drop=True) for row_idx in split)
    for split in kf.split(tuning_df)
]

# Only the first fold is used.
train_df, valid_df = folds[0]
train_df

## Dataset Setting

In [None]:
from torch.utils.data import DataLoader
from utils.datasets import CAG_Dataset
from Args import Args_Train_Loader, Args_Valid_Loader

# Training dataset: rows of train_df, with image/mask folders taken from the
# IMAGE_DIR / MASK_DIR environment variables. Receives both the default
# transform and the augmentation transform from Args_Train_Loader.
Train_Dataset = CAG_Dataset(
    df=train_df,
    image_dir=os.getenv('IMAGE_DIR'),
    mask_dir=os.getenv('MASK_DIR'),
    default_transform = Args_Train_Loader._get_default_transform(),
    aug_transform = Args_Train_Loader._get_aug_transform()
)
# Validation dataset: same folders, rows of valid_df; only the default
# transform is passed (no aug_transform argument).
Valid_Dataset = CAG_Dataset(
    df=valid_df,
    image_dir=os.getenv('IMAGE_DIR'),
    mask_dir=os.getenv('MASK_DIR'),
    default_transform = Args_Valid_Loader._get_default_transform(),
    # Prompt_Args = {
    #     "n_shot" : 3
    # }
)

# Loader options (batch size, shuffling, workers, pinning, drop_last) all come
# from the Args_Train_Loader / Args_Valid_Loader config objects.
Train_Loader = DataLoader(
    Train_Dataset,
    batch_size=Args_Train_Loader.train_bs,
    shuffle=Args_Train_Loader.shuffle,
    num_workers=Args_Train_Loader.num_workers,
    pin_memory=Args_Train_Loader.pin_memory,
    drop_last=Args_Train_Loader.drop_last,

)
Valid_Loader = DataLoader(
    Valid_Dataset,
    batch_size=Args_Valid_Loader.valid_bs,
    shuffle=Args_Valid_Loader.shuffle,
    num_workers=Args_Valid_Loader.num_workers,
    pin_memory=Args_Valid_Loader.pin_memory,
    drop_last=Args_Valid_Loader.drop_last,
)

# Sanity check: pull one training batch, print the tensor shapes and the set
# of distinct values present in the masks.
sample_imgs, sample_masks = next(iter(Train_Loader))
print(sample_imgs.shape, sample_masks.shape)
print(sample_masks.unique())
# Visualize only the first sample of sample_imgs / sample_masks
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 5))
# Left panel: first image of the batch, rendered in grayscale.
plt.subplot(1, 2, 1)
plt.imshow(sample_imgs[0].squeeze().cpu().numpy(), cmap = 'gray')
plt.title('Image')
# Right panel: the matching mask.
plt.subplot(1, 2, 2)
plt.imshow(sample_masks[0].squeeze().cpu().numpy(), cmap = 'gray')
plt.title('Mask')
plt.show()



## Model Setting

In [4]:
from models.DeepSA.model import build_model
from Args import Args_experiments


from models.DeepSA.model import UNet
# Alternative kept for reference: instantiate a UNet directly instead of going
# through build_model (this variant passes no checkpoint path).
# seg_model = UNet(1, 1, 32, bilinear=True).to(Args_experiments.device)

# Build the segmentation model, handing build_model the checkpoint path stored
# in the `deepsa_ckpt_path` environment variable, then move it to the
# configured device.
seg_model = build_model(ckpt_path = os.getenv('deepsa_ckpt_path'), device = Args_experiments.device).to(Args_experiments.device)

# Optimizer, scheduler and loss are produced by factories on Args_experiments:
# the optimizer factory receives the model parameters, the scheduler factory
# receives the optimizer, and the loss factory takes no arguments.
optimizer = Args_experiments.optimizer_fn(seg_model.parameters())
scheduler = Args_experiments.scheduler_fn(optimizer)
loss_fn = Args_experiments.loss_fn()

## Training Setting

In [None]:
import torch
from tqdm import tqdm
import os
from utils.metrics import SegmentationMetrics

# Module-level metrics object; train_one_epoch and valid_one_epoch call
# metrics.evaluate(...) on every batch.
metrics = SegmentationMetrics()

def set_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts.
        lr: value stored under the ``'lr'`` key of each group.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)

def get_lr(optimizer):
    """Return the learning rate stored in the first parameter group.

    Only ``optimizer.param_groups[0]`` is consulted; any further groups are
    ignored.
    """
    first_group = optimizer.param_groups[0]
    return first_group['lr']

def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one training pass over ``loader`` and return per-sample averages.

    For each batch: move tensors to ``device``, forward, compute the loss,
    back-propagate, clip the gradient norm to 5.0 and step the optimizer.
    Batch metrics come from the module-level ``metrics.evaluate`` applied to
    the sigmoid of the model outputs.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable yielding ``(imgs, masks)`` batches.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_fn: callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        device: target passed to ``Tensor.to``.

    Returns:
        ``(avg_loss, avg_metrics)`` where both are weighted by batch size and
        divided by the number of samples that were actually iterated.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    epoch_loss, metric_sum = 0.0, None
    # Count the samples that really went through the model. Dividing by
    # len(loader.dataset) is wrong whenever the loader skips samples
    # (drop_last=True, a subsampling sampler, ...).
    samples_seen = 0
    for imgs, masks in tqdm(loader, desc="Train", leave=False):
        imgs, masks = imgs.to(device), masks.to(device)
        batch_size = imgs.size(0)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, masks)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        epoch_loss += loss.item() * batch_size
        samples_seen += batch_size
        # Metrics are monitoring only: detach so that tensor-valued metrics do
        # not keep every batch's autograd graph alive for the whole epoch.
        with torch.no_grad():
            batch_metrics = metrics.evaluate(torch.sigmoid(outputs.detach()), masks)
        if metric_sum is None:
            metric_sum = {k: v * batch_size for k, v in batch_metrics.items()}
        else:
            for k in metric_sum:
                metric_sum[k] += batch_metrics[k] * batch_size
    if samples_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    avg_metrics = {k: v / samples_seen for k, v in metric_sum.items()}
    return epoch_loss / samples_seen, avg_metrics

@torch.no_grad()
def valid_one_epoch(model, loader, loss_fn, device):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    The model is put in eval mode; for each batch the loss and the
    module-level ``metrics.evaluate`` (on the sigmoid of the outputs) are
    accumulated, weighted by batch size.

    Args:
        model: network to evaluate.
        loader: iterable yielding ``(imgs, masks)`` batches.
        loss_fn: callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        device: target passed to ``Tensor.to``.

    Returns:
        ``(avg_loss, avg_metrics)`` averaged over the samples actually
        iterated.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    epoch_loss, metric_sum = 0.0, None
    # Divide by what was really evaluated, not len(loader.dataset): the two
    # differ when the loader drops or subsamples data.
    samples_seen = 0
    for imgs, masks in tqdm(loader, desc="Valid", leave=False):
        imgs, masks = imgs.to(device), masks.to(device)
        batch_size = imgs.size(0)
        outputs = model(imgs)
        loss = loss_fn(outputs, masks)
        epoch_loss += loss.item() * batch_size
        samples_seen += batch_size
        batch_metrics = metrics.evaluate(torch.sigmoid(outputs), masks)
        if metric_sum is None:
            metric_sum = {k: v * batch_size for k, v in batch_metrics.items()}
        else:
            for k in metric_sum:
                metric_sum[k] += batch_metrics[k] * batch_size
    if samples_seen == 0:
        raise ValueError("valid_one_epoch: loader yielded no batches")
    avg_metrics = {k: v / samples_seen for k, v in metric_sum.items()}
    return epoch_loss / samples_seen, avg_metrics

def run_training(
    model, optimizer, scheduler, loss_fn, 
    train_loader, valid_loader, device, 
    num_epochs, patience, exp_name
):
    """Train with optional linear warmup, dice-based checkpointing and early stopping.

    Each epoch runs ``train_one_epoch`` and ``valid_one_epoch``, then calls
    ``scheduler.step()``. Whenever the validation ``"dice_coef"`` improves,
    the model ``state_dict`` is written to
    ``<EXPERIMENT_DIR>/<exp_name>/best_weight.pth`` (``EXPERIMENT_DIR`` is read
    from the environment, default ``./EXPERIMENT_DIR``). Training stops early
    after ``patience`` consecutive epochs without a dice improvement.

    Warmup: when ``Args_experiments.warmup_epoch`` is positive, the learning
    rate is set to ``Args_experiments.lr * epoch / warmup_epoch`` for the
    first ``warmup_epoch`` epochs and reset to ``Args_experiments.lr`` on the
    following epoch.

    NOTE(review): ``scheduler.step()`` is also called during warmup epochs, so
    the scheduler's internal epoch counter advances while its learning rate is
    being overridden. Confirm this is the intended interaction for the
    configured scheduler.

    Args:
        model, optimizer, scheduler, loss_fn: training objects.
        train_loader, valid_loader: batch iterables for the two phases.
        device: forwarded to the per-epoch functions.
        num_epochs: maximum number of epochs.
        patience: early-stopping budget (epochs without dice improvement).
        exp_name: sub-directory name for the checkpoint.

    Returns:
        dict with ``best_dice`` (``None`` if no epoch improved), ``best_epoch``
        (0 if none), ``best_valid_loss``, ``stopped_epoch`` (last epoch run)
        and ``best_weight_path``.
    """
    # Start below any real score so the first finished epoch always writes a
    # checkpoint. With a 0 start, a run whose dice never exceeds 0 would leave
    # a stale best_weight.pth from a previous run in place.
    best_dice = float('-inf')
    best_valid_loss = float('inf')
    best_epoch = 0
    stopped_epoch = 0
    patience_counter = 0
    save_dir = os.path.join(os.getenv("EXPERIMENT_DIR", "./EXPERIMENT_DIR"), exp_name)
    os.makedirs(save_dir, exist_ok=True)
    best_weight_dice_path = os.path.join(save_dir, "best_weight.pth")
    warmup_epoch = getattr(Args_experiments, "warmup_epoch", 0)
    base_lr = Args_experiments.lr
    for epoch in range(1, num_epochs+1):
        stopped_epoch = epoch
        print(f"Epoch {epoch}/{num_epochs}")
        if warmup_epoch > 0 and epoch <= warmup_epoch:
            # Linear ramp from base_lr / warmup_epoch up to base_lr.
            warmup_lr = base_lr * epoch / warmup_epoch
            set_lr(optimizer, warmup_lr)
            print(f"Warmup lr: {get_lr(optimizer):.6f}")
        elif warmup_epoch > 0 and epoch == warmup_epoch + 1:
            set_lr(optimizer, base_lr)
            print(f"Set lr to base: {get_lr(optimizer):.6f}")
        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = valid_one_epoch(model, valid_loader, loss_fn, device)
        scheduler.step()
        # Best by loss: recorded before the early-stop check so the epoch that
        # triggers the stop is still taken into account.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
        # Best by dice: drives checkpointing and early stopping.
        if valid_metrics["dice_coef"] > best_dice:
            best_dice = valid_metrics["dice_coef"]
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), best_weight_dice_path)
            print(f"[Best Dice] Train Loss: {train_loss:.4f} | " + " | ".join([f'{k}: {v:.4f}' for k, v in train_metrics.items()]))
            print(f"[Best Dice] Valid Loss: {valid_loss:.4f} | " + " | ".join([f'{k}: {v:.4f}' for k, v in valid_metrics.items()]))
            print(f"Best Dice model saved at {best_weight_dice_path} (Dice: {best_dice:.4f})")
        else:
            patience_counter += 1
            print(f"Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break
    return {
        "best_dice": best_dice if best_epoch > 0 else None,
        "best_epoch": best_epoch,
        "best_valid_loss": best_valid_loss,
        "stopped_epoch": stopped_epoch,
        "best_weight_path": best_weight_dice_path,
    }

# Epoch budget and early-stopping patience come from the experiment config;
# the experiment name selects the checkpoint sub-directory.
num_epochs, patience = Args_experiments.epoch, Args_experiments.patience
exp_name = "DeepSA-FineTuning-v2"

run_training(
    model=seg_model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    train_loader=Train_Loader,
    valid_loader=Valid_Loader,
    device=Args_experiments.device,
    num_epochs=num_epochs,
    patience=patience,
    exp_name=exp_name,
)