In [1]:
%load_ext autoreload
%autoreload 2 
import torch, os
from tqdm.auto import tqdm
import torch.nn.functional as F  
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

from huggingface_hub import notebook_login

#notebook_login()

In [None]:

# Model IMPORTS
from models.DinoV3.SemanDino import GlacierSegmenter
from models.DinoV3.GlacierDataset import GlacierDataset

# Utils import
from models.utils.metrics import get_combined_loss, get_iou_metric
from models.utils.training import train_one_epoch, validate_one_epoch

# --- Segmentation label conventions ---
NUM_CLASS = 2          # forwarded to GlacierSegmenter as num_classes
IGNORE_INDEX = 255     # forwarded as ignore_index to both the loss and the IoU metric

# --- Training hyper-parameters ---
EPOCHS = 10            # epochs run inside each cross-validation fold
LR = 1e-4              # learning rate handed to AdamW
BATCH_SIZE = 8         # batch size of the train and validation DataLoaders
NUM_WORKERS = 4        # num_workers of the train and validation DataLoaders
N_SPLITS = 5           # number of K-fold splits

# --- Filesystem layout (relative paths) ---
CHECKPOINT_DIR = "checkpoints/"
TRAIN_IMAGE_DIR = "dataset/clean/images/"
TRAIN_MASK_DIR = "dataset/clean/masks/"
# NOTE(review): the TEST_* paths are not referenced by the visible training code;
# presumably they are consumed by a later evaluation cell - confirm.
TEST_IMAGE_DIR = "dataset/test/images/"
TEST_MASK_DIR = "dataset/test/masks/"

# Create the checkpoint folder up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Training on: {device}")


# Datasets
# Two views over the SAME training files are built on purpose:
#   * full_train_ds    -> mode="train", provides the training split of every fold
#   * train_ds_for_val -> mode="test" (augmentations disabled), provides the
#                         validation split of every fold
# Both are built from the TRAIN_* path constants so the two views cannot drift
# apart if the dataset location changes.
full_train_ds = GlacierDataset(
    image_dir=TRAIN_IMAGE_DIR,
    mask_dir=TRAIN_MASK_DIR,
    mode="train",
)

train_ds_for_val = GlacierDataset(
    image_dir=TRAIN_IMAGE_DIR,
    mask_dir=TRAIN_MASK_DIR,
    mode="test",  # augmentations disabled
)

# The fold indices are computed once and then used to index BOTH datasets, so the
# two views must at least expose the same number of samples.
# NOTE(review): this also assumes GlacierDataset lists files in the same
# deterministic order in both modes - confirm in GlacierDataset.
assert len(full_train_ds) == len(train_ds_for_val), (
    "train/validation views of the training set differ in length: "
    f"{len(full_train_ds)} vs {len(train_ds_for_val)}"
)

# Loss function and evaluation metric. Both factories receive the same
# ignore_index keyword, so it is spelled out once and shared.
# NOTE(review): the two positional 0.5 values are presumably the weights of the
# loss terms combined by get_combined_loss - confirm in models.utils.metrics.
_ignore_kwargs = {"ignore_index": IGNORE_INDEX}
criterion = get_combined_loss(0.5, 0.5, **_ignore_kwargs)
iou_metric = get_iou_metric(**_ignore_kwargs)

# K-Fold Cross Validation
# For every fold a fresh model and optimizer are created, trained for EPOCHS
# epochs on the training indices, and evaluated after each epoch on the held-out
# indices. The weights with the highest validation IoU of each fold are written
# to CHECKPOINT_DIR/best_model_fold_<k>.pth, and the per-fold best scores are
# summarised once all folds are done.
SEED = 42  # shared by the fold split and the per-fold torch seeding below
kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

# Pinned host memory only helps host-to-GPU copies; requesting it on a CPU-only
# machine just triggers a DataLoader warning.
use_pinned_memory = str(device).startswith("cuda")

# KFold only needs the number of samples, so split over a plain index list
# instead of relying on the Dataset object being accepted as array-like.
sample_indices = list(range(len(full_train_ds)))

fold_best_ious = []  # best validation IoU reached in each fold

for fold, (train_ids, val_ids) in enumerate(kfold.split(sample_indices)):
    print(f"FOLD {fold+1}")
    print("------------------------------")

    # Seed torch per fold so weight initialisation and batch shuffling are
    # reproducible, while still differing from one fold to the next.
    torch.manual_seed(SEED + fold)

    # Initialize model; only parameters with requires_grad=True are handed to
    # the optimizer, so any frozen parameters are left out.
    model = GlacierSegmenter(num_classes=NUM_CLASS).to(device)
    params_to_update = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_update, lr=LR)

    # Restrict each dataset view to this fold's indices: training samples come
    # from the "train" view, validation samples from the "test" view.
    train_subsampler = Subset(full_train_ds, train_ids)
    val_subsampler = Subset(train_ds_for_val, val_ids)

    # Define data loaders for training and validation
    train_loader = DataLoader(
        train_subsampler,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=use_pinned_memory,
        drop_last=True
    )

    val_loader = DataLoader(
        val_subsampler,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=use_pinned_memory,
        drop_last=False
    )

    # Start below any reachable score so the first finite validation IoU always
    # produces a checkpoint, even when it is exactly 0.
    best_val_iou = float("-inf")

    epoch_pbar = tqdm(range(EPOCHS), desc=f"Fold {fold+1} Progress")
    for epoch in epoch_pbar:

        # Train
        train_loss = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )

        # Validation
        val_loss, val_iou = validate_one_epoch(
            model, val_loader, criterion, iou_metric, device
        )
        # Store a plain Python float in case the helper returns a 0-d tensor.
        val_iou = float(val_iou)

        # Show the current metrics on the main progress bar
        epoch_pbar.set_postfix({
            "T_Loss": f"{train_loss:.3f}", 
            "V_Loss": f"{val_loss:.3f}", 
            "V_IoU": f"{val_iou:.3f}"
        })

        # Save Best Model
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(
                model.state_dict(),
                os.path.join(CHECKPOINT_DIR, f"best_model_fold_{fold+1}.pth")
            )

    print(f"FOLD {fold+1}: Best val IoU : {best_val_iou:.4f}")
    fold_best_ious.append(best_val_iou)

# Cross-validation summary: mean and sample standard deviation of the per-fold bests.
fold_scores = torch.tensor(fold_best_ious)
print(
    f"Mean best val IoU over {len(fold_best_ious)} folds: "
    f"{fold_scores.mean().item():.4f} +/- {fold_scores.std().item():.4f}"
)
    

Training on: cuda
FOLD 1
------------------------------


Fold 1 Progress:   0%|          | 0/10 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]

Training:   0%|          | 0/55 [00:00<?, ?it/s]

Validation:   0%|          | 0/14 [00:00<?, ?it/s]