In [50]:
# ==============================================================================
# 03 - BYOL PRE-TRAINING (v4.1 - NUMERICAL STABILITY FIX)
# ==============================================================================
# Purpose: To train the BYOL model with enhanced numerical stability controls
#          to prevent loss explosion.
# Changes:
#   - Added epsilon to F.normalize in the loss function.
#   - Switched autocast to use bfloat16 for a wider dynamic range.
#   - Automatically deletes old checkpoints on new runs to ensure a fresh start.
# ==============================================================================

# --- 0. Install and Import ---
!pip install "zarr>=2.10.0" numcodecs -q

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision.models as models
from pathlib import Path
import numpy as np, zarr, random, math, os, copy, zipfile, tempfile, gc
from tqdm.notebook import tqdm

print("Final BYOL Pre-training Script Initialized (Numerical Stability Fix).")

# --- 1. Configuration ---
CONFIG = {
    'data_path': "/kaggle/input/01-data-preparation/data/ssl4eo-s12/train/S2RGB",
    'output_dir': "/kaggle/working/",
    'epochs': 30,
    'batch_size': 128,
    'learning_rate': 1e-4,
    'weight_decay': 1.5e-6,
    'image_size': 224,
    'projection_dim': 256,
    'hidden_dim': 4096,
    'base_tau': 0.996,
    'num_workers': 2,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'early_stopping_patience': 7,
    'data_subset_percentage': 0.10,
    'checkpoint_filename': "byol_checkpoint_stable.pth",
    'best_model_filename': "byol_best_encoder_stable.pth"
}

# --- 2. Helper Classes & Functions ---
class EarlyStopping:
    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience, self.verbose, self.delta = patience, verbose, delta
        self.counter, self.best_score, self.early_stop, self.val_loss_min = 0, None, False, np.Inf
    def __call__(self, val_loss, model, path):
        if not math.isfinite(val_loss):
            print("Loss is not finite, triggering early stop."); self.early_stop = True; return
        score = -val_loss
        if self.best_score is None: self.best_score = score; self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose: print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience: self.early_stop = True
        else: self.best_score = score; self.save_checkpoint(val_loss, model, path); self.counter = 0
    def save_checkpoint(self, val_loss, model, path):
        if self.verbose: print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best model...')
        torch.save(model.online_encoder.state_dict(), path); self.val_loss_min = val_loss

def save_checkpoint(epoch, model, optimizer, scheduler, loss, config):
    checkpoint_path = os.path.join(config['output_dir'], config['checkpoint_filename'])
    state = {'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss}
    torch.save(state, checkpoint_path)

def load_checkpoint(model, optimizer, scheduler, config):
    checkpoint_path = os.path.join(config['output_dir'], config['checkpoint_filename'])
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        print(f"Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=config['device'])
        model.load_state_dict(checkpoint['model_state_dict']); optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict']); start_epoch = checkpoint['epoch']
        print(f"Resumed from epoch {start_epoch}, with previous loss {checkpoint['loss']:.4f}")
    else: print("No checkpoint found, starting from scratch.")
    return start_epoch

# --- 3. Dataset, Model, and Loss ---
class SSL4EODataset(Dataset):
    def __init__(self, root_dir, subset_percentage=1.0):
        self.root_dir = Path(root_dir)
        all_zarr_files = sorted(list(self.root_dir.glob("*.zarr.zip")))
        num_files_to_use = int(len(all_zarr_files) * subset_percentage)
        self.zarr_files = random.sample(all_zarr_files, num_files_to_use) if subset_percentage < 1.0 else all_zarr_files
        self.images = self._preload_images(); self.total_images = len(self.images)
        print(f"\nDataset Initialized: Using {subset_percentage*100:.0f}% of data.")
        print(f"-> Pre-loaded {self.total_images:,} images into memory.")
        self.transform_t = T.Compose([
            T.ToPILImage(), T.RandomResizedCrop(CONFIG['image_size'], antialias=True), T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1), T.RandomGrayscale(p=0.2), T.GaussianBlur(23, (0.1, 2.0)),
            T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        self.transform_t_prime = T.Compose([
            T.ToPILImage(), T.RandomResizedCrop(CONFIG['image_size'], antialias=True), T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1), T.RandomGrayscale(p=0.2), T.RandomSolarize(0.5, p=0.2),
            T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    def _preload_images(self):
        images = []; print("Pre-loading images into RAM... This may take a few minutes.")
        for file_path in tqdm(self.zarr_files, desc="Loading files"):
            with tempfile.TemporaryDirectory() as temp_dir:
                with zipfile.ZipFile(str(file_path), 'r') as zf: zf.extractall(temp_dir)
                zarr_group = zarr.open(temp_dir, mode='r')
                numpy_array = zarr_group['bands'][:]; images.extend(numpy_array.reshape(-1, *numpy_array.shape[2:]))
        return images
    def __len__(self): return self.total_images
    def __getitem__(self, idx):
        image_chw = self.images[idx]; image_hwc = np.transpose(image_chw, (1, 2, 0))
        return self.transform_t(image_hwc), self.transform_t_prime(image_hwc)

class MLP(nn.Module):
    def __init__(self, i, h, o): super().__init__(); self.net = nn.Sequential(nn.Linear(i, h), nn.BatchNorm1d(h), nn.ReLU(True), nn.Linear(h, o))
    def forward(self, x): return self.net(x)

class BYOLModel(nn.Module):
    def __init__(self, encoder, encoder_dim, p_dim, h_dim, b_tau):
        super().__init__(); self.base_tau = b_tau
        self.online_encoder=encoder; self.online_projector=MLP(encoder_dim,h_dim,p_dim); self.online_predictor=MLP(p_dim,h_dim,p_dim)
        self.target_encoder=copy.deepcopy(encoder); self.target_projector=copy.deepcopy(self.online_projector)
        for p in self.target_encoder.parameters(): p.requires_grad=False
        for p in self.target_projector.parameters(): p.requires_grad=False
    @torch.no_grad()
    def update_target_network(self, cs, ts):
        tau=1-(1-self.base_tau)*(math.cos(math.pi*cs/ts)+1)/2
        for o,t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): t.data.mul_(tau).add_(o.data,alpha=1-tau)
        for o,t in zip(self.online_projector.parameters(), self.target_projector.parameters()): t.data.mul_(tau).add_(o.data,alpha=1-tau)
    def forward(self, v1, v2):
        op1=self.online_predictor(self.online_projector(self.online_encoder(v1))); op2=self.online_predictor(self.online_projector(self.online_encoder(v2)))
        with torch.no_grad(): tp1=self.target_projector(self.target_encoder(v1)); tp2=self.target_projector(self.target_encoder(v2))
        return (op1, op2), (tp1.detach(), tp2.detach())

# MODIFIED: Stabilized Loss Function
def byol_loss_fn(p, t):
    p1, p2 = p; t1, t2 = t
    eps = 1e-6
    p1_norm = F.normalize(p1, p=2, dim=-1, eps=eps)
    p2_norm = F.normalize(p2, p=2, dim=-1, eps=eps)
    t1_norm = F.normalize(t1.detach(), p=2, dim=-1, eps=eps)
    t2_norm = F.normalize(t2.detach(), p=2, dim=-1, eps=eps)
    loss1 = 2 - 2 * (p1_norm * t2_norm).sum(dim=-1)
    loss2 = 2 - 2 * (p2_norm * t1_norm).sum(dim=-1)
    return (loss1 + loss2).mean() * 0.5


# --- 4. Main Training Execution ---
if __name__ == '__main__':
    gc.collect(); torch.cuda.empty_cache()
    device = torch.device(CONFIG['device']); print(f"Using device: {device}")
    
    dataset = SSL4EODataset(root_dir=CONFIG['data_path'], subset_percentage=CONFIG.get('data_subset_percentage', 1.0))
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, 
                        num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)
    
    print("Using ResNet-18 as the encoder model.")
    resnet = models.resnet18(weights=None); encoder_output_dim = resnet.fc.in_features; resnet.fc = nn.Identity()
    
    model = BYOLModel(
        encoder=resnet, encoder_dim=encoder_output_dim,
        p_dim=CONFIG['projection_dim'], h_dim=CONFIG['hidden_dim'], b_tau=CONFIG['base_tau']
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
    
    total_steps = len(loader) * CONFIG['epochs']
    warmup_steps = len(loader) * 1
    def lr_lambda(current_step):
        if current_step < warmup_steps: return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    scaler = torch.cuda.amp.GradScaler(enabled=(CONFIG['device'] == 'cuda'))
    early_stopper = EarlyStopping(patience=CONFIG['early_stopping_patience'], verbose=True)
    
    checkpoint_path = os.path.join(CONFIG['output_dir'], CONFIG['checkpoint_filename'])
    if os.path.exists(checkpoint_path):
        print(f"Deleting old checkpoint at {checkpoint_path} to start fresh."); os.remove(checkpoint_path)
    start_epoch = 0
    
    print(f"\n--- Starting Training with Stability Controls on {CONFIG['data_subset_percentage']*100:.0f}% of Data ---")
    for epoch in range(start_epoch, CONFIG['epochs']):
        model.train(); total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
        for i, (view1, view2) in enumerate(pbar):
            view1, view2 = view1.to(device), view2.to(device)
            
            # Use bfloat16 for better numerical stability with AMP
            with torch.cuda.amp.autocast(enabled=(CONFIG['device'] == 'cuda'), dtype=torch.bfloat16):
                predictions, targets = model(view1, view2)
                loss = byol_loss_fn(predictions, targets)
            
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            
            scheduler.step()
            current_step = epoch * len(loader) + i
            model.update_target_network(current_step, total_steps)
            
            total_loss += loss.item()
            pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'LR': f"{optimizer.param_groups[0]['lr']:.6f}"})

        avg_loss = total_loss / len(loader)
        if not math.isfinite(avg_loss):
            print(f"Epoch {epoch+1} ended with non-finite loss: {avg_loss}. Stopping training."); break
            
        print(f"Epoch {epoch+1} Summary: Average Loss = {avg_loss:.4f}")
        save_checkpoint(epoch, model, optimizer, scheduler, avg_loss, CONFIG)
        early_stopper(avg_loss, model, os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename']))
        if early_stopper.early_stop:
            print("Early stopping triggered due to no improvement in loss."); break

    print("\n--- Training Finished ---")
    final_encoder_path = os.path.join(CONFIG['output_dir'], "byol_final_encoder.pth")
    torch.save(model.online_encoder.state_dict(), final_encoder_path)
    print(f"Final online encoder saved to {final_encoder_path}")
    print(f"Best encoder (lowest loss) saved to {os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename'])}")

Final BYOL Pre-training Script Initialized (Numerical Stability Fix).
Using device: cuda
Pre-loading images into RAM... This may take a few minutes.


Loading files:   0%|          | 0/48 [00:00<?, ?it/s]


Dataset Initialized: Using 10% of data.
-> Pre-loaded 12,288 images into memory.
Using ResNet-18 as the encoder model.
Deleting old checkpoint at /kaggle/working/byol_checkpoint_stable.pth to start fresh.

--- Starting Training with Stability Controls on 10% of Data ---


  scaler = torch.cuda.amp.GradScaler(enabled=(CONFIG['device'] == 'cuda'))


Epoch 1/30:   0%|          | 0/96 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(CONFIG['device'] == 'cuda'), dtype=torch.bfloat16):


Epoch 1 Summary: Average Loss = 0.9380
Validation loss decreased (inf --> 0.937999). Saving best model...


Epoch 2/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 2 Summary: Average Loss = 0.6688
Validation loss decreased (0.937999 --> 0.668840). Saving best model...


Epoch 3/30:   0%|          | 0/96 [00:00<?, ?it/s]

KeyboardInterrupt: 