In [1]:
import numpy as np 
import pandas as pd 
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2 
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold, GroupKFold, StratifiedGroupKFold
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import warnings
from collections import Counter, defaultdict
import torch.cuda.amp as amp
warnings.filterwarnings('ignore')
from tqdm import tqdm
import math
from timm.utils import ModelEmaV2



In [2]:
class Config:
    def __init__(self):
        
        self.num_header = 3
        self.model_name = "vit_base_patch14_dinov2"
        
        if self.model_name == 'convnext_base' or self.model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k':
            self.img_size = 224
        elif self.model_name == 'convnext_tiny':
            self.img_size = 500
        elif self.model_name == "vit_base_patch14_dinov2":
            self.img_size = 518

        self.freeze_lr = 1e-3
        self.unfreeze_lr = 1e-4

        self.head_lr = 1e-3
        self.backbone_lr = 1e-4
        
        self.wd = 1e-2

        self.warmup_epochs = 3
        self.epochs = 1
        
        self.loss_weights = {'total_loss' : 0.5, 'gdm_loss': 0.2, 'green_loss':0.1}
        self.scoring_weights = [0.5, 0.2, 0.1, 0.1, 0.1]

        self.train_path = "/kaggle/input/csiro-biomass/train.csv"
        self.parent_image_path = "/kaggle/input/csiro-biomass/"

        self.n_folds = 4
        self.random_state = 42

        self.batch_size = 8
        self.num_workers = 4
        self.n_epochs_before_unfreeze = 10
        self.n_epochs_after_unfreeze = 20

        self.all_targets = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g', 'Dry_Clover_g', 'Dry_Dead_g']

        self.accumulation_steps = 4

        

        
        

class TabularDataLoader:

    def __init__(self, config: Config):
        self.config = config

    def load_and_pivot_data(self):
        
        df = pd.read_csv(self.config.train_path)
        df_pivoted = df.pivot(
        index=["image_path", "Pre_GSHH_NDVI", "Height_Ave_cm", 'Sampling_Date', 'State'],
        columns="target_name",
        values="target",
        ).reset_index()

        df_pivoted['image_path'] = self.config.parent_image_path + df_pivoted['image_path']

        df_pivoted['Sampling_Date'] = pd.to_datetime(df_pivoted['Sampling_Date'])
        day = df_pivoted['Sampling_Date'].dt.dayofyear
        df_pivoted['date_encoding'] = np.sin(2 * np.pi * day / 365)
        
        return df_pivoted
    
    def create_stratified_folds_w_total(self, df: pd.DataFrame) -> pd.DataFrame:

        print(f"\nPreparing {self.config.n_folds}-Fold Cross-Validation...")
        
        df = df.copy()
        df['fold'] = -1
        
        # Bin targets (continuous → discrete)
        # Determine number of bins using Sturges' formula
        num_bins = min(10, int(np.floor(1 + np.log2(len(df)))))
        print(f"Stratifying Dry_Total_g into {num_bins} bins")
        
        df['total_bin'] = pd.cut(
            df['Dry_Total_g'], 
            bins=num_bins, 
            labels=False,
            duplicates='drop'  # Remove duplicate edges
        )
        
        # Stratified K-Fold split
        skf = StratifiedKFold(
            n_splits=self.config.n_folds,
            shuffle=True,
            random_state=self.config.random_state
        )
        
        for fold_num, (_, valid_idx) in enumerate(skf.split(df, df['total_bin'])):
            df.loc[valid_idx, 'fold'] = fold_num
        
        # Remove binning column (no longer needed)
        df = df.drop(columns=['total_bin'])
        
        print("\nFold distribution:")
        print(df['fold'].value_counts().sort_index())
        
        return df.drop(columns = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'date_encoding'])

    def create_group_folds_w_date(self, df: pd.DataFrame) -> pd.DataFrame:

        print(f"\nPreparing {self.config.n_folds}-Fold Cross-Validation...")
        cols_to_keep = ['Pre_GSHH_NDVI', 
                        'Height_Ave_cm', 'date_encoding']
        df = df.copy()
        df['fold'] = -1
        
        #X_scaled = StandardScaler().fit_transform(df[cols_to_keep].values)

        #kmeans = KMeans(n_clusters=15, random_state=42)
        #clusters = kmeans.fit_predict(X_scaled)

        skf = GroupKFold(n_splits=self.config.n_folds)
        
        for fold, (_, val_idx) in enumerate(skf.split(df, groups = df['Sampling_Date'])):
            df.loc[val_idx, "fold"] = fold

        print("\nFold distribution:")
        print(df['fold'].value_counts().sort_index())
        return df.drop(columns = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'date_encoding', 'Sampling_Date'])
                
    def create_stratified_groups_fold(self,df: pd.DataFrame) -> pd.DataFrame:
        print(f"\nPreparing {self.config.n_folds}-Fold Cross-Validation...")

        df["State_code"] = pd.factorize(df["State"])[0]

        n_splits = self.config.n_folds
        folds = defaultdict(list)
        fold_counts = [Counter() for _ in range(n_splits)]
        
        # Unique dates
        dates = df["Sampling_Date"].unique()
        
        # Trier les dates par nombre d’échantillons
        dates = sorted(dates, key=lambda d: len(df[df["Sampling_Date"]==d]), reverse=True)

        for date in dates:
            date_states = df[df["Sampling_Date"]==date]["State_code"]
            
            # Choisir le fold avec le moins de cette combinaison
            fold_idx = np.argmin([sum([fold_counts[i][s] for s in date_states]) for i in range(n_splits)])
            
            # Ajouter cette date au fold
            folds[fold_idx].append(date)
            
            # Mettre à jour les compteurs
            for s in date_states:
                fold_counts[fold_idx][s] += 1
        
        # Ajouter la colonne fold à df
        df["fold"] = -1
        for fold_idx, date_list in folds.items():
            df.loc[df["Sampling_Date"].isin(date_list), "fold"] = fold_idx

        print("\nFold distribution:")
        print(df['fold'].value_counts().sort_index())
        return df.drop(columns = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'date_encoding', 'Sampling_Date', 'State_code'])

    def create_str_group_w_month_and_state(self,df: pd.DataFrame)-> pd.DataFrame:
        df = df.copy()
        df['fold'] = -1
        
        groups = df['Sampling_Date'].dt.month
        skgf = StratifiedGroupKFold(n_splits=self.config.n_folds,shuffle=True,random_state=self.config.random_state) 

        X = df[self.config.all_targets]
        y = df['State']

        for fold, (_, val_idx) in enumerate(skgf.split(X,y=y, groups = groups)):
            df.loc[val_idx, "fold"] = fold
        
        print("\nFold distribution:")
        print(df['fold'].value_counts().sort_index())

        return pd.DataFrame(df[['image_path', 'Dry_Total_g', 'GDM_g', 'Dry_Green_g', 'Dry_Clover_g', 'Dry_Dead_g', 'fold']])


    def create_str_group_w_month_and_bin(self, df: pd.DataFrame) -> pd.DataFrame:
        print(f"\nPreparing Optimal {self.config.n_folds}-Fold Cross-Validation...")
        df = df.copy()
        df['fold'] = -1
    
        groups = df['Sampling_Date'] # Ou df['Sampling_Date'].astype(str)
    
        num_bins = min(10, int(np.floor(1 + np.log2(len(df)))))
        df['total_bin'] = pd.cut(
            df['Dry_Total_g'], 
            bins=num_bins, 
            labels=False
        )
    
        sgkf = StratifiedGroupKFold(
            n_splits=self.config.n_folds, 
            shuffle=True, 
            random_state=self.config.random_state
        )
    
        for fold, (_, val_idx) in enumerate(sgkf.split(X=df, y=df['total_bin'], groups=groups)):
            df.loc[val_idx, "fold"] = fold
    
        print("\nFold distribution (Images count):")
        print(df['fold'].value_counts().sort_index())
        
        print("\nMean Biomass per Fold:")
        print(df.groupby('fold')['Dry_Total_g'].mean())
    
        return df.drop(columns=['total_bin'])


    
class VisionDataTransformer:

    def __init__(self):
        self.img_size = Config().img_size
        return None

    def get_left_right_input(self,img):

        if img is None:
            raise ValueError("img error")
        
        h, w = img.shape[:2]
        
        mid = w // 2
        
        img_left = img[:, :mid]      
        img_right = img[:, mid:]

        return img_left, img_right
    
    def data_augmentation_pipeline(self):
        
        transform = A.Compose([A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5), 
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(
        p=0.5
        ),
        # A.RandomGamma(
        #     gamma_limit=(80, 120),
        #     p=0.3
        # ),
        A.Resize(self.img_size, self.img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 
        ToTensorV2()])
        
        return transform

    def data_reshape_only_pipeline(self):
        transform = A.Compose([A.Resize(self.img_size, self.img_size),A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2()])
        
        return transform
        

    def transform(self, img, part):
        assert part == "right" or part == "left"
        pipeline = self.data_augmentation_pipeline()
        
        if part == "right":
            img_cut = self.get_left_right_input(img)[1]
        elif part == 'left':
            img_cut = self.get_left_right_input(img)[0]
        
        transformed = pipeline(image = img_cut)
        transformed_image = transformed["image"]

        return transformed_image

    def reshape(self, img, part):
        assert part == "right" or part == "left"
        pipeline = self.data_reshape_only_pipeline()
        
        if part == "right":
            img_cut = self.get_left_right_input(img)[1]
        elif part == 'left':
            img_cut = self.get_left_right_input(img)[0]
        
        transformed = pipeline(image = img_cut)
        transformed_image = transformed["image"]

        return transformed_image

class BiomassDataset:
    
    def __init__(self, labels, transform = False):

        self.labels = labels
        self.vision_transformer = VisionDataTransformer()
        self.transform = transform
        self.train_targets = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g']
        self.all_targets = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g', 'Dry_Clover_g', 'Dry_Dead_g']
        self.train_targets = self.labels[self.train_targets].values
        self.all_targets = self.labels[self.all_targets].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self,idx):
        
        img_path = self.labels['image_path'].iloc[idx]
        img = cv2.imread(img_path)

        train_target = self.train_targets[idx]
        all_target = self.all_targets[idx]
        

        train_target_tensor = torch.tensor(train_target, dtype=torch.float32)
        all_target_tensor = torch.tensor(all_target, dtype=torch.float32)
        
        if self.transform:
            transformed_image_right = self.vision_transformer.transform(img, 'right')
            transformed_image_left = self.vision_transformer.transform(img, 'left')
            
            return transformed_image_right, transformed_image_left, train_target_tensor, all_target_tensor

        else:
            image_right = self.vision_transformer.reshape(img, 'right')
            image_left = self.vision_transformer.reshape(img, 'left')
            return image_right, image_left, train_target_tensor, all_target_tensor



class LocalMambaBlock(nn.Module):
    """
    Lightweight Mamba-style block (Gated CNN) from the reference notebook.
    Efficiently mixes tokens with linear complexity.
    """
    def __init__(self, dim, kernel_size=5, dropout=0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Depthwise conv mixes spatial information locally
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.gate = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: (Batch, Tokens, Dim)
        shortcut = x
        x = self.norm(x)
        # Gating mechanism
        g = torch.sigmoid(self.gate(x))
        x = x * g
        # Spatial mixing via 1D Conv (requires transpose)
        x = x.transpose(1, 2)  # -> (B, D, N)
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # -> (B, N, D)
        # Projection
        x = self.proj(x)
        x = self.drop(x)
        return shortcut + x



class CsiroModel(nn.Module):

    def __init__(self, config):

        super().__init__()
        self.config = config
        self.backbone_model = timm.create_model(config.model_name, pretrained=True, num_classes=0, global_pool='')

        self.n_features = self.backbone_model.num_features
        self.n_combined = self.n_features * 2 

    

        self.fusion = nn.Sequential(
            LocalMambaBlock(self.n_features, kernel_size=5, dropout=0.1),
            LocalMambaBlock(self.n_features, kernel_size=5, dropout=0.1)
        )

        self.pool = nn.AdaptiveAvgPool1d(1)

        self.head_total = self.create_head()
        self.head_gdm = self.create_head()
        self.head_green = self.create_head()
        
    def create_head(self):

        head = nn.Sequential(nn.Linear(self.n_features, self.n_features//2),
                            #nn.LayerNorm(self.n_combined//2),
                            nn.GELU(),
                            nn.Dropout(0.2),
                            nn.Linear(self.n_features//2 , 1),
                            nn.Softplus()
                )
        return head

    def forward(self, img_right, img_left):

        right_embedd = self.backbone_model(img_right)
        left_embedd = self.backbone_model(img_left)

        combined_embedd = torch.concat([right_embedd, left_embedd], dim = 1)

        x_fused = self.fusion(combined_embedd)
        x_pool = self.pool(x_fused.transpose(1, 2)).flatten(1)

        out_total = self.head_total(x_pool)
        out_gdm = self.head_gdm(x_pool)
        out_green = self.head_green(x_pool)
        
        return out_total, out_gdm, out_green


class WeightedBiomassLoss(nn.Module):
  
    def __init__(self, loss_weights: dict[str, float]):

        super().__init__()
        self.criterion = nn.SmoothL1Loss(beta=5.0) # A variant of Huber loss: nn.SmoothL1Loss(beta=5.0)
        self.weights = loss_weights
    
    def forward(self, predictions, targets):

        pred_total, pred_gdm, pred_green = predictions
        
        true_total = targets[:, 0:1]  # Maintain [B, 1] shape
        true_gdm = targets[:, 1:2]
        true_green = targets[:, 2:3]
        
        loss_total = self.criterion(pred_total, true_total)
        loss_gdm = self.criterion(pred_gdm, true_gdm)
        loss_green = self.criterion(pred_green, true_green)
        
        total_loss = (
            self.weights['total_loss'] * loss_total +
            self.weights['gdm_loss'] * loss_gdm +
            self.weights['green_loss'] * loss_green
        )
        
        return total_loss


def csiro_scheduler(optimizer, max_epochs):
    def lr_lambda(epoch):
        config = Config()
        e = max(0, epoch)
        if e < config.warmup_epochs:
            return float(e + 1) / float(max(1, config.warmup_epochs))
        progress = (e - config.warmup_epochs) / float(max(1, max_epochs - config.warmup_epochs))
        progress = min(1.0,progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


class Trainer:
    
    def __init__(self,model,ema_model, optimizer, config, loss, train_loader, valid_loader, device):
        self.model = model
        self.criterion = loss
        self.device = device
        
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model.to(device)
        self.config = config
        
        self.optimizer = optimizer
        self.scaler = amp.GradScaler()
        self.epoch_index = 0

        self.model_ema = ema_model
        self.model_ema.to(device)

    def train_one_epoch(self):
        self.model.train()
        self.model.module.backbone_model.eval()
        total_loss = 0
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Train Epoch {self.epoch_index}", leave=False)
        
        for i, (img_right, img_left, train_targets, all_targets) in enumerate(pbar):
            
            img_left = img_left.to(self.device)
            img_right = img_right.to(self.device)
            train_targets = train_targets.to(self.device)

            with amp.autocast(): #gradient accumulation
                out_total, out_gdm, out_green = self.model(img_right, img_left)

                out_tuple = (out_total, out_gdm, out_green)

                loss = self.criterion(out_tuple, train_targets)
                loss = loss / self.config.accumulation_steps

            
            self.scaler.scale(loss).backward()
            is_last_batch = (i + 1) == len(self.train_loader)
            
            if (i + 1) % self.config.accumulation_steps == 0 or is_last_batch:
                self.scaler.step(self.optimizer)                
                self.scaler.update()                
                self.optimizer.zero_grad()
                self.model_ema.update(self.model) #EMA

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            total_loss += loss.item() * self.config.accumulation_steps

        self.epoch_index +=1
        return total_loss / len(self.train_loader)


    def predict_with_tta(self, model, img_right, img_left):
        # images shape: (B, C, H, W)
        preds_total = []
        preds_gdm = []
        preds_green = []
        
        # 1. Original
        out_total, out_gdm, out_green = model(img_right, img_left)
        preds_total.append(out_total)
        preds_gdm.append(out_gdm)
        preds_green.append(out_green)
        
        
        # 2. Horizontal Flip
        out_h_total, out_h_gdm, out_h_green = model(torch.flip(img_right, [3]), torch.flip(img_left, [3]))
        preds_total.append(out_h_total)
        preds_gdm.append(out_h_gdm)
        preds_green.append(out_h_green)
        
        # 3. Vertical Flip
        out_v_total, out_v_gdm, out_v_green  = model(torch.flip(img_right, [2]), torch.flip(img_left, [2]))
        preds_total.append(out_v_total)
        preds_gdm.append(out_v_gdm)
        preds_green.append(out_v_green)
        
        
        final_pred_total = torch.stack(preds_total).mean(dim=0)
        final_pred_gdm = torch.stack(preds_gdm).mean(dim=0)
        final_pred_green = torch.stack(preds_green).mean(dim=0)
        
        return final_pred_total, final_pred_gdm, final_pred_green


    def valid_one_epoch(self, scorer, n_fold, tta = False, ema =False):
        self.model.eval()
        total_loss = 0

        predictions = {'dry_total_pred' : [], 'gdm_pred': [], 'dry_green_pred': []}
        targets = []

        pbar = tqdm(self.valid_loader, desc="Validating", leave=False)
        
        with torch.no_grad(): 
            for img_right, img_left, train_targets, all_targets in pbar:
                
                img_left = img_left.to(self.device)
                img_right = img_right.to(self.device)
                train_targets = train_targets.to(self.device)


                if tta:
                    out_total, out_gdm, out_green = self.predict_with_tta(self.model,img_right, img_left)
                else:
                    if ema:
                        out_total, out_gdm, out_green = self.model_ema(img_right, img_left)
                    else:
                        out_total, out_gdm, out_green = self.model(img_right, img_left)
    
                out_tuple = (out_total, out_gdm, out_green)
    
                loss = self.criterion(out_tuple, train_targets)
                total_loss += loss.item()

                predictions['dry_total_pred'].append(out_total.cpu().numpy())
                predictions['gdm_pred'].append(out_gdm.cpu().numpy())
                predictions['dry_green_pred'].append(out_green.cpu().numpy())

                targets.append(all_targets.cpu().numpy())


        predictions = {
            k: np.concatenate(v).flatten() 
            for k, v in predictions.items()
        }
        targets = np.concatenate(targets)
        
        avg_loss = total_loss / len(self.valid_loader)
        score = scorer.compute_score(predictions, targets)

        return avg_loss, score


class CsiroScorer:

    def __init__(self, scoring_weights):
        self.scoring_weights = scoring_weights

    def weighted_r2_score(self,y_true, y_pred, weights):
        y_true = np.asarray(y_true).reshape(-1)
        y_pred = np.asarray(y_pred).reshape(-1)
        weights = np.asarray(weights).reshape(-1)

        #print(f'y_true: {y_true[:10]}, y_pred: {y_pred[:10]}, weights: {weights[:10]}')
        
        # Moyenne pondérée
        y_mean = np.sum(weights * y_true) / np.sum(weights)
    
        ss_res = np.sum(weights * (y_true - y_pred) ** 2)
        ss_tot = np.sum(weights * (y_true - y_mean) ** 2)

        return 1 - ss_res / ss_tot
    
    
    def compute_score(self, pred_targets:dict, true_targets):
        
        dry_total_pred = np.array(pred_targets['dry_total_pred'])
        gdm_pred = np.array(pred_targets['gdm_pred'])
        dry_green_pred = np.array(pred_targets['dry_green_pred'])

        dry_clover_pred = np.maximum(0, gdm_pred - dry_green_pred)
        dry_dead_pred = np.maximum(0, dry_total_pred - gdm_pred)

        
        y_preds = np.stack([
            dry_total_pred, gdm_pred, dry_green_pred, dry_clover_pred, dry_dead_pred
        ], axis=1)
        

        weights_per_target = np.array([
        0.5,  # Dry_Total_g
        0.2,  # GDM_g
        0.1,  # Dry_Green_g
        0.1,  # Dry_Clover_g
        0.1,  # Dry_Dead_g
        ])
        
        weights = np.tile(weights_per_target, true_targets.shape[0])
        
        score =self.weighted_r2_score(
            y_true=true_targets.reshape(-1),
            y_pred=y_preds.reshape(-1),
            weights=weights
        )
        
        return float(score)
        

class CsiroPipeline:
    
    def __init__(self, config):
        self.config = config
        self.tabular_data_loader = TabularDataLoader(self.config)
        self.score = CsiroScorer(self.config.scoring_weights)
        self.best_score = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loss = WeightedBiomassLoss(self.config.loss_weights)

        self.all_scores = []


    def get_train_valid_from_fold(self, df, n_fold : int):
        assert n_fold <= self.config.n_folds

        train = df[df['fold'] != n_fold].reset_index(drop = True)
        valid = df[df['fold'] == n_fold].reset_index(drop = True)

        return BiomassDataset(train, transform = True), BiomassDataset(valid, transform = False)


    def unfreeze_stages(self,model, stage_ids):
        for i in stage_ids:
            for p in model.module.backbone_model.stages[i].parameters():
                p.requires_grad = True

    def unfreeze_backbone(self, model):
        for p in model.module.backbone_model.parameters():
            p.requires_grad = True
    
    def valid_one_fold(self, df, n_fold:int, scorer):
        self.best_score = 0
        self.model = CsiroModel(self.config)
        self.model_ema = ModelEmaV2(self.model, decay=0.90)

        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs (DataParallel)")
            self.model = nn.DataParallel(self.model)
     
        train_dataset, valid_dataset = self.get_train_valid_from_fold(df, n_fold)
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        valid_loader = DataLoader(
            valid_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True
        )


        for param in self.model.module.backbone_model.parameters():
            param.requires_grad = False
        
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.config.head_lr,
            weight_decay = self.config.wd
        )

        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer, mode='min', factor=0.2, patience=3
        # )

        scheduler = csiro_scheduler(optimizer, max_epochs = self.config.n_epochs_before_unfreeze)
        
        trainer = Trainer(self.model, self.model_ema, optimizer, self.config, self.loss, train_loader, valid_loader, self.device)
        print(f'\n=== Fold {n_fold} | Phase 1: Backbone Frozen ===')
        for n in range(self.config.n_epochs_before_unfreeze):
            train_loss = trainer.train_one_epoch()
            valid_loss, score = trainer.valid_one_epoch(scorer, n_fold, tta = True, ema = True)

            scheduler.step()
            print(f"[Epoch {n}] Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Score: {score:.4f}")
            if score > self.best_score:
                print(f"  >>> Score Improved ({self.best_score:.4f} -> {score:.4f}). Saving model.")
                torch.save(self.model_ema.module.state_dict(), f"best_model_fold_{n_fold}.pth")
                self.best_score = score
            


        for param in self.model.module.backbone_model.parameters():
            param.requires_grad = True
        
        #Unfreeze Backbone
        self.unfreeze_backbone(self.model)

        head_lr = self.config.unfreeze_lr        
        backbone_lr = head_lr * 0.1

        head_params = list(self.model.module.head_total.parameters()) + list(self.model.module.head_gdm.parameters()) + list(self.model.module.head_green.parameters())
        
        optimizer = optim.Adam([
            {'params': self.model.module.backbone_model.parameters(), 'lr': 5e-6, 'weight_decay': self.config.wd}, 
            {'params': head_params, 'lr': 5e-5, 'weight_decay': self.config.wd}
        ]
        )

        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        #     optimizer, mode='min', factor=0.1, patience=2
        # )
        
        scheduler = csiro_scheduler(optimizer, max_epochs = self.config.n_epochs_after_unfreeze)
        trainer = Trainer(self.model, self.model_ema, optimizer, self.config, self.loss, train_loader, valid_loader, self.device)

        print(f'\n=== Fold {n_fold} | Phase 2: Backbone Unfrozen ===')
        for n in range(self.config.n_epochs_after_unfreeze):
            print(f"epoch {n + self.config.n_epochs_before_unfreeze}")

            # if n == 10:
            #     print('Continuation of training: stage 2-3 Backbone Unfreeze')
            #     self.unfreeze_stages(self.model, [2,3])
            #     optimizer = optim.Adam(
            #         filter(lambda p: p.requires_grad, self.model.parameters()),
            #         lr=self.config.unfreeze_lr
            #     )
            #     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            #         optimizer, mode='min', factor=0.2, patience=3
            #     )
            #     trainer = Trainer(self.model, optimizer, self.config, self.loss, train_loader, valid_loader, self.device)
            
            train_loss = trainer.train_one_epoch()
            valid_loss, score = trainer.valid_one_epoch(scorer, n_fold, tta = True, ema = True)
            scheduler.step()
            print(f"[Epoch {n}] Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Score: {score:.4f}")
            
            if score > self.best_score:
                print(f"  >>> Score Improved ({self.best_score:.4f} -> {score:.4f}). Saving model.")
                torch.save(self.model_ema.module.state_dict(), f"best_model_fold_{n_fold}.pth")
                self.best_score = score
        
        return self.best_score
        

    def run(self):
        df = self.tabular_data_loader.load_and_pivot_data()
        df = self.tabular_data_loader.create_str_group_w_month_and_state(df)
        for fold in range(self.config.n_folds):
            print(f"start training on fold {fold}")
            best_score = self.valid_one_fold(df, fold, self.score)
            self.all_scores.append(best_score)   

        print(self.all_scores)

In [3]:
config = Config()
print(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
pipeline = CsiroPipeline(config)
pipeline.run()

cuda

Fold distribution:
fold
0    115
1     91
2     95
3     56
Name: count, dtype: int64
start training on fold 0


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Using 2 GPUs (DataParallel)

=== Fold 0 | Phase 1: Backbone Frozen ===


                                                           

[Epoch 0] Train Loss: 28.5639 | Val Loss: 14.1115 | Score: -0.3866


                                                           

[Epoch 1] Train Loss: 18.6302 | Val Loss: 16.1110 | Score: -0.3375


                                                           

[Epoch 2] Train Loss: 16.4941 | Val Loss: 9.3728 | Score: 0.3593
  >>> Score Improved (0.0000 -> 0.3593). Saving model.


                                                           

[Epoch 3] Train Loss: 14.2719 | Val Loss: 8.0945 | Score: 0.4064
  >>> Score Improved (0.3593 -> 0.4064). Saving model.


                                                           

[Epoch 4] Train Loss: 11.5351 | Val Loss: 8.3723 | Score: 0.4706
  >>> Score Improved (0.4064 -> 0.4706). Saving model.


                                                           

[Epoch 5] Train Loss: 9.4000 | Val Loss: 6.3136 | Score: 0.5827
  >>> Score Improved (0.4706 -> 0.5827). Saving model.


                                                           

[Epoch 6] Train Loss: 8.0185 | Val Loss: 6.3514 | Score: 0.5688


                                                           

[Epoch 7] Train Loss: 7.2732 | Val Loss: 5.6913 | Score: 0.6365
  >>> Score Improved (0.5827 -> 0.6365). Saving model.


                                                           

[Epoch 8] Train Loss: 6.7942 | Val Loss: 6.2510 | Score: 0.6036


                                                           

[Epoch 9] Train Loss: 6.8581 | Val Loss: 6.1095 | Score: 0.6104

=== Fold 0 | Phase 2: Backbone Unfrozen ===
epoch 10


                                                           

[Epoch 0] Train Loss: 6.8585 | Val Loss: 5.6571 | Score: 0.6402
  >>> Score Improved (0.6365 -> 0.6402). Saving model.
epoch 11


                                                           

[Epoch 1] Train Loss: 7.4198 | Val Loss: 6.4384 | Score: 0.5825
epoch 12


                                                           

[Epoch 2] Train Loss: 6.6596 | Val Loss: 5.6042 | Score: 0.6503
  >>> Score Improved (0.6402 -> 0.6503). Saving model.
epoch 13


                                                           

[Epoch 3] Train Loss: 5.6588 | Val Loss: 5.2478 | Score: 0.6706
  >>> Score Improved (0.6503 -> 0.6706). Saving model.
epoch 14


                                                           

[Epoch 4] Train Loss: 4.8574 | Val Loss: 5.5586 | Score: 0.6515
epoch 15


                                                           

[Epoch 5] Train Loss: 4.6690 | Val Loss: 7.2925 | Score: 0.5003
epoch 16


                                                           

[Epoch 6] Train Loss: 5.3934 | Val Loss: 6.0441 | Score: 0.5902
epoch 17


                                                           

[Epoch 7] Train Loss: 4.2785 | Val Loss: 5.3078 | Score: 0.6767
  >>> Score Improved (0.6706 -> 0.6767). Saving model.
epoch 18


                                                           

[Epoch 8] Train Loss: 3.9111 | Val Loss: 5.4152 | Score: 0.6703
epoch 19


                                                           

[Epoch 9] Train Loss: 3.7170 | Val Loss: 5.4398 | Score: 0.6556
epoch 20


                                                           

[Epoch 10] Train Loss: 3.5847 | Val Loss: 5.8355 | Score: 0.6538
epoch 21


                                                           

[Epoch 11] Train Loss: 3.3380 | Val Loss: 5.4601 | Score: 0.6501
epoch 22


                                                           

[Epoch 12] Train Loss: 3.3513 | Val Loss: 5.3875 | Score: 0.6622
epoch 23


                                                           

[Epoch 13] Train Loss: 2.9369 | Val Loss: 5.4749 | Score: 0.6421
epoch 24


                                                           

[Epoch 14] Train Loss: 3.0279 | Val Loss: 5.5991 | Score: 0.6551
epoch 25


                                                           

[Epoch 15] Train Loss: 3.1713 | Val Loss: 5.3487 | Score: 0.6560
epoch 26


                                                           

[Epoch 16] Train Loss: 2.7295 | Val Loss: 5.7929 | Score: 0.6508
epoch 27


                                                           

[Epoch 17] Train Loss: 2.8084 | Val Loss: 5.7119 | Score: 0.6613
epoch 28


                                                           

[Epoch 18] Train Loss: 2.8017 | Val Loss: 5.3498 | Score: 0.6584
epoch 29


                                                           

[Epoch 19] Train Loss: 2.8543 | Val Loss: 5.5405 | Score: 0.6605
start training on fold 1
Using 2 GPUs (DataParallel)

=== Fold 1 | Phase 1: Backbone Frozen ===


                                                           

[Epoch 0] Train Loss: 22.9429 | Val Loss: 23.6860 | Score: -0.4636


                                                           

[Epoch 1] Train Loss: 14.8920 | Val Loss: 20.1070 | Score: -0.0213


                                                           

[Epoch 2] Train Loss: 12.6349 | Val Loss: 15.3550 | Score: 0.3521
  >>> Score Improved (0.0000 -> 0.3521). Saving model.


                                                           

[Epoch 3] Train Loss: 10.8677 | Val Loss: 16.5309 | Score: 0.1829


                                                           

[Epoch 4] Train Loss: 8.7569 | Val Loss: 12.7889 | Score: 0.5103
  >>> Score Improved (0.3521 -> 0.5103). Saving model.


                                                           

[Epoch 5] Train Loss: 7.4299 | Val Loss: 9.4583 | Score: 0.6601
  >>> Score Improved (0.5103 -> 0.6601). Saving model.


                                                           

[Epoch 6] Train Loss: 6.8038 | Val Loss: 9.2746 | Score: 0.6791
  >>> Score Improved (0.6601 -> 0.6791). Saving model.


                                                           

[Epoch 7] Train Loss: 6.4383 | Val Loss: 9.2848 | Score: 0.6848
  >>> Score Improved (0.6791 -> 0.6848). Saving model.


                                                           

[Epoch 8] Train Loss: 6.1128 | Val Loss: 8.4859 | Score: 0.7178
  >>> Score Improved (0.6848 -> 0.7178). Saving model.


                                                           

[Epoch 9] Train Loss: 6.2435 | Val Loss: 8.6283 | Score: 0.7209
  >>> Score Improved (0.7178 -> 0.7209). Saving model.

=== Fold 1 | Phase 2: Backbone Unfrozen ===
epoch 10


                                                           

[Epoch 0] Train Loss: 6.2091 | Val Loss: 8.2809 | Score: 0.7391
  >>> Score Improved (0.7209 -> 0.7391). Saving model.
epoch 11


                                                           

[Epoch 1] Train Loss: 5.5722 | Val Loss: 8.0941 | Score: 0.7478
  >>> Score Improved (0.7391 -> 0.7478). Saving model.
epoch 12


                                                           

[Epoch 2] Train Loss: 5.0965 | Val Loss: 9.2751 | Score: 0.7568
  >>> Score Improved (0.7478 -> 0.7568). Saving model.
epoch 13


                                                           

[Epoch 3] Train Loss: 5.3945 | Val Loss: 9.6049 | Score: 0.6757
epoch 14


                                                           

[Epoch 4] Train Loss: 5.4978 | Val Loss: 8.1007 | Score: 0.7589
  >>> Score Improved (0.7568 -> 0.7589). Saving model.
epoch 15


                                                           

[Epoch 5] Train Loss: 5.0396 | Val Loss: 7.4960 | Score: 0.7775
  >>> Score Improved (0.7589 -> 0.7775). Saving model.
epoch 16


                                                           

[Epoch 6] Train Loss: 4.5599 | Val Loss: 7.5644 | Score: 0.8050
  >>> Score Improved (0.7775 -> 0.8050). Saving model.
epoch 17


                                                           

[Epoch 7] Train Loss: 4.4030 | Val Loss: 8.2372 | Score: 0.7842
epoch 18


                                                           

[Epoch 8] Train Loss: 3.9125 | Val Loss: 7.4754 | Score: 0.8018
epoch 19


                                                           

[Epoch 9] Train Loss: 3.3745 | Val Loss: 7.2194 | Score: 0.8009
epoch 20


                                                           

[Epoch 10] Train Loss: 3.1661 | Val Loss: 7.1324 | Score: 0.8166
  >>> Score Improved (0.8050 -> 0.8166). Saving model.
epoch 21


                                                           

[Epoch 11] Train Loss: 2.9666 | Val Loss: 7.4607 | Score: 0.8108
epoch 22


                                                           

[Epoch 12] Train Loss: 3.0528 | Val Loss: 7.2227 | Score: 0.8044
epoch 23


                                                           

[Epoch 13] Train Loss: 2.8241 | Val Loss: 7.6027 | Score: 0.8056
epoch 24


                                                           

[Epoch 14] Train Loss: 2.6378 | Val Loss: 7.4141 | Score: 0.8102
epoch 25


                                                           

[Epoch 15] Train Loss: 2.4525 | Val Loss: 7.3553 | Score: 0.8026
epoch 26


                                                           

[Epoch 16] Train Loss: 2.4750 | Val Loss: 7.5820 | Score: 0.8054
epoch 27


                                                           

[Epoch 17] Train Loss: 2.6145 | Val Loss: 7.4503 | Score: 0.8075
epoch 28


                                                           

[Epoch 18] Train Loss: 2.4842 | Val Loss: 7.0377 | Score: 0.8083
epoch 29


                                                           

[Epoch 19] Train Loss: 2.3333 | Val Loss: 7.0784 | Score: 0.8084
start training on fold 2
Using 2 GPUs (DataParallel)

=== Fold 2 | Phase 1: Backbone Frozen ===


                                                           

[Epoch 0] Train Loss: 28.2325 | Val Loss: 11.7587 | Score: -0.0874


                                                           

[Epoch 1] Train Loss: 16.4887 | Val Loss: 12.5173 | Score: 0.0940
  >>> Score Improved (0.0000 -> 0.0940). Saving model.


                                                           

[Epoch 2] Train Loss: 13.6021 | Val Loss: 12.9395 | Score: 0.1209
  >>> Score Improved (0.0940 -> 0.1209). Saving model.


                                                           

[Epoch 3] Train Loss: 10.8627 | Val Loss: 8.1126 | Score: 0.4920
  >>> Score Improved (0.1209 -> 0.4920). Saving model.


                                                           

[Epoch 4] Train Loss: 8.1514 | Val Loss: 7.1979 | Score: 0.5668
  >>> Score Improved (0.4920 -> 0.5668). Saving model.


                                                           

[Epoch 5] Train Loss: 7.7725 | Val Loss: 6.7323 | Score: 0.6010
  >>> Score Improved (0.5668 -> 0.6010). Saving model.


                                                           

[Epoch 6] Train Loss: 6.8423 | Val Loss: 5.9882 | Score: 0.6554
  >>> Score Improved (0.6010 -> 0.6554). Saving model.


                                                           

[Epoch 7] Train Loss: 6.7599 | Val Loss: 6.1088 | Score: 0.6482


                                                           

[Epoch 8] Train Loss: 6.5372 | Val Loss: 7.3810 | Score: 0.5903


                                                           

[Epoch 9] Train Loss: 6.3794 | Val Loss: 6.5977 | Score: 0.6359

=== Fold 2 | Phase 2: Backbone Unfrozen ===
epoch 10


                                                           

[Epoch 0] Train Loss: 6.9947 | Val Loss: 8.4270 | Score: 0.5313
epoch 11


                                                           

[Epoch 1] Train Loss: 6.2167 | Val Loss: 5.6730 | Score: 0.6484
epoch 12


                                                           

[Epoch 2] Train Loss: 5.9278 | Val Loss: 6.3565 | Score: 0.6695
  >>> Score Improved (0.6554 -> 0.6695). Saving model.
epoch 13


                                                           

[Epoch 3] Train Loss: 4.9630 | Val Loss: 7.0070 | Score: 0.6271
epoch 14


                                                           

[Epoch 4] Train Loss: 4.9901 | Val Loss: 5.2369 | Score: 0.7052
  >>> Score Improved (0.6695 -> 0.7052). Saving model.
epoch 15


                                                           

[Epoch 5] Train Loss: 4.5748 | Val Loss: 6.7638 | Score: 0.6381
epoch 16


                                                           

[Epoch 6] Train Loss: 4.1801 | Val Loss: 5.0528 | Score: 0.7145
  >>> Score Improved (0.7052 -> 0.7145). Saving model.
epoch 17


                                                           

[Epoch 7] Train Loss: 4.0963 | Val Loss: 5.3546 | Score: 0.6899
epoch 18


                                                           

[Epoch 8] Train Loss: 3.7324 | Val Loss: 5.2753 | Score: 0.7245
  >>> Score Improved (0.7145 -> 0.7245). Saving model.
epoch 19


                                                           

[Epoch 9] Train Loss: 3.7814 | Val Loss: 5.3581 | Score: 0.7236
epoch 20


                                                           

[Epoch 10] Train Loss: 3.4710 | Val Loss: 4.9394 | Score: 0.6981
epoch 21


                                                           

[Epoch 11] Train Loss: 3.3802 | Val Loss: 5.3486 | Score: 0.7148
epoch 22


                                                           

[Epoch 12] Train Loss: 3.0141 | Val Loss: 6.2726 | Score: 0.6782
epoch 23


                                                           

[Epoch 13] Train Loss: 3.0967 | Val Loss: 5.3283 | Score: 0.7219
epoch 24


                                                           

[Epoch 14] Train Loss: 2.7894 | Val Loss: 4.9534 | Score: 0.7340
  >>> Score Improved (0.7245 -> 0.7340). Saving model.
epoch 25


                                                           

[Epoch 15] Train Loss: 2.7108 | Val Loss: 5.1457 | Score: 0.7279
epoch 26


                                                           

[Epoch 16] Train Loss: 2.6208 | Val Loss: 5.0888 | Score: 0.7290
epoch 27


                                                           

[Epoch 17] Train Loss: 2.5248 | Val Loss: 5.0872 | Score: 0.7307
epoch 28


                                                           

[Epoch 18] Train Loss: 2.6057 | Val Loss: 5.0461 | Score: 0.7330
epoch 29


                                                           

[Epoch 19] Train Loss: 2.5249 | Val Loss: 5.0361 | Score: 0.7320
start training on fold 3
Using 2 GPUs (DataParallel)

=== Fold 3 | Phase 1: Backbone Frozen ===


                                                         

[Epoch 0] Train Loss: 22.5214 | Val Loss: 25.5136 | Score: -0.4429


                                                         

[Epoch 1] Train Loss: 14.7249 | Val Loss: 23.1457 | Score: -0.2781


                                                         

[Epoch 2] Train Loss: 12.0282 | Val Loss: 15.9783 | Score: 0.2485
  >>> Score Improved (0.0000 -> 0.2485). Saving model.


                                                         

[Epoch 3] Train Loss: 9.8968 | Val Loss: 13.9597 | Score: 0.3681
  >>> Score Improved (0.2485 -> 0.3681). Saving model.


                                                         

[Epoch 4] Train Loss: 8.1998 | Val Loss: 13.2416 | Score: 0.4123
  >>> Score Improved (0.3681 -> 0.4123). Saving model.


                                                         

[Epoch 5] Train Loss: 7.0456 | Val Loss: 12.7137 | Score: 0.4247
  >>> Score Improved (0.4123 -> 0.4247). Saving model.


                                                         

[Epoch 6] Train Loss: 6.6767 | Val Loss: 11.8411 | Score: 0.4765
  >>> Score Improved (0.4247 -> 0.4765). Saving model.


                                                         

[Epoch 7] Train Loss: 5.8490 | Val Loss: 11.3773 | Score: 0.5047
  >>> Score Improved (0.4765 -> 0.5047). Saving model.


                                                         

[Epoch 8] Train Loss: 5.5599 | Val Loss: 11.6047 | Score: 0.4924


                                                         

[Epoch 9] Train Loss: 5.4617 | Val Loss: 11.1916 | Score: 0.5183
  >>> Score Improved (0.5047 -> 0.5183). Saving model.

=== Fold 3 | Phase 2: Backbone Unfrozen ===
epoch 10


                                                         

[Epoch 0] Train Loss: 5.8367 | Val Loss: 10.9304 | Score: 0.5590
  >>> Score Improved (0.5183 -> 0.5590). Saving model.
epoch 11


                                                         

[Epoch 1] Train Loss: 5.5850 | Val Loss: 10.2258 | Score: 0.5816
  >>> Score Improved (0.5590 -> 0.5816). Saving model.
epoch 12


                                                         

[Epoch 2] Train Loss: 4.8586 | Val Loss: 10.3822 | Score: 0.5611
epoch 13


                                                         

[Epoch 3] Train Loss: 5.1080 | Val Loss: 9.8259 | Score: 0.6514
  >>> Score Improved (0.5816 -> 0.6514). Saving model.
epoch 14


                                                         

[Epoch 4] Train Loss: 4.6596 | Val Loss: 10.8396 | Score: 0.5449
epoch 15


                                                         

[Epoch 5] Train Loss: 4.5165 | Val Loss: 10.0783 | Score: 0.5689
epoch 16


                                                         

[Epoch 6] Train Loss: 3.9046 | Val Loss: 9.1675 | Score: 0.6483
epoch 17


                                                         

[Epoch 7] Train Loss: 3.2478 | Val Loss: 9.2273 | Score: 0.6216
epoch 18


                                                         

[Epoch 8] Train Loss: 2.8693 | Val Loss: 9.6539 | Score: 0.6219
epoch 19


                                                         

[Epoch 9] Train Loss: 2.8405 | Val Loss: 9.1493 | Score: 0.6321
epoch 20


                                                         

[Epoch 10] Train Loss: 2.7322 | Val Loss: 9.9953 | Score: 0.6430
epoch 21


                                                         

[Epoch 11] Train Loss: 2.9006 | Val Loss: 9.6560 | Score: 0.6244
epoch 22


                                                         

[Epoch 12] Train Loss: 2.6871 | Val Loss: 9.5457 | Score: 0.6162
epoch 23


                                                         

[Epoch 13] Train Loss: 2.5416 | Val Loss: 9.2226 | Score: 0.6336
epoch 24


                                                         

[Epoch 14] Train Loss: 2.3306 | Val Loss: 9.4558 | Score: 0.6234
epoch 25


                                                         

[Epoch 15] Train Loss: 2.1091 | Val Loss: 9.8705 | Score: 0.6306
epoch 26


                                                         

[Epoch 16] Train Loss: 2.0506 | Val Loss: 9.8049 | Score: 0.6342
epoch 27


                                                         

[Epoch 17] Train Loss: 2.0439 | Val Loss: 9.6142 | Score: 0.6274
epoch 28


                                                         

[Epoch 18] Train Loss: 2.0432 | Val Loss: 9.5799 | Score: 0.6260
epoch 29


                                                         

[Epoch 19] Train Loss: 2.0501 | Val Loss: 9.6305 | Score: 0.6273
[0.6766994394646929, 0.8166043147154585, 0.7340455502715365, 0.6514164163460323]


