In [1]:
# ==============================================================================
# 03b - RESUME BYOL PRE-TRAINING (50% DATA) - v5.1
# ==============================================================================
# Purpose: To resume the previous BYOL training session from the last saved
#          checkpoint file generated by the run that timed out.
# ==============================================================================

# --- 0. Imports ---
!pip install "zarr>=2.10.0" numcodecs -q
import torch, torch.nn as nn, torch.nn.functional as F, torchvision.transforms as T, torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np, zarr, random, math, os, copy, zipfile, tempfile, gc
from tqdm.notebook import tqdm

# --- 1. Configuration ---
CONFIG = {
    'data_path': "/kaggle/input/01-data-preparation/data/ssl4eo-s12/train/S2RGB",
    'checkpoint_input_folder': '03-byol-training-50pct', # The name of the notebook version whose output you added
    'output_dir': "/kaggle/working/",
    'epochs': 50, 'batch_size': 128, 'learning_rate': 1e-3, 'weight_decay': 1.5e-6,
    'image_size': 224, 'projection_dim': 256, 'hidden_dim': 4096, 'base_tau': 0.996,
    'num_workers': 2, 'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'early_stopping_patience': 10, 'data_subset_percentage': 0.50,
    'checkpoint_filename': "byol_checkpoint_50pct.pth",
    'best_model_filename': "byol_best_encoder_50pct.pth"
}

# --- 2. Helper Classes & Functions ---
class EarlyStopping:
    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience, self.verbose, self.delta = patience, verbose, delta
        self.counter, self.best_score, self.early_stop, self.val_loss_min = 0, None, False, np.Inf
    def __call__(self, val_loss, model, path):
        if not math.isfinite(val_loss): self.early_stop = True; print("Loss is not finite, stopping."); return
        score = -val_loss
        if self.best_score is None: self.best_score = score; self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1; print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience: self.early_stop = True
        else: self.best_score = score; self.save_checkpoint(val_loss, model, path); self.counter = 0
    def save_checkpoint(self, val_loss, model, path):
        print(f'Loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best model...'); torch.save(model.online_encoder.state_dict(), path); self.val_loss_min = val_loss

def save_checkpoint(epoch, model, optimizer, scheduler, loss, config):
    state = {'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': loss}
    torch.save(state, os.path.join(config['output_dir'], config['checkpoint_filename']))

def load_checkpoint(model, optimizer, scheduler, config):
    path = Path(f"/kaggle/input/{config['checkpoint_input_folder']}/{config['checkpoint_filename']}")
    start_epoch = 0
    if path.exists():
        print(f"✅ Resuming from checkpoint: {path}")
        ckpt = torch.load(path, map_location=config['device'])
        model.load_state_dict(ckpt['model_state_dict']); optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict']); start_epoch = ckpt['epoch']
        print(f"   -> Resuming training from the start of epoch {start_epoch + 1}")
    else: print(f"❌ Checkpoint not found at {path}. Cannot resume.")
    return start_epoch

# --- 3. Dataset, Model, and Loss ---
class SSL4EODataset(Dataset):
    def __init__(self, root_dir, subset_percentage=1.0):
        self.root_dir = Path(root_dir)
        all_files = sorted(list(self.root_dir.glob("*.zarr.zip")))
        num_files = int(len(all_files) * subset_percentage)
        self.zarr_files = random.sample(all_files, num_files) if subset_percentage < 1.0 else all_files
        self.images = self._preload_images(); self.total_images = len(self.images)
        print(f"\nDataset Initialized: Using {subset_percentage*100:.0f}% of data -> {len(self.images):,} images pre-loaded.")
        self.transform_t = T.Compose([
            T.ToPILImage(), T.RandomResizedCrop(CONFIG['image_size'], antialias=True), T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1), T.RandomGrayscale(p=0.2), T.GaussianBlur(23, (0.1, 2.0)),
            T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        self.transform_t_prime = T.Compose([
            T.ToPILImage(), T.RandomResizedCrop(CONFIG['image_size'], antialias=True), T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1), T.RandomGrayscale(p=0.2), T.RandomSolarize(0.5, p=0.2),
            T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    def _preload_images(self):
        images = []; print("Pre-loading images into RAM...")
        for fp in tqdm(self.zarr_files, desc="Loading files"):
            with tempfile.TemporaryDirectory() as td:
                with zipfile.ZipFile(str(fp), 'r') as zf: zf.extractall(td)
                za = zarr.open(td, mode='r')['bands'][:]; images.extend(za.reshape(-1, *za.shape[2:]))
        return images
    def __len__(self): return self.total_images
    def __getitem__(self, idx):
        img_chw = self.images[idx]; img_hwc = np.transpose(img_chw, (1, 2, 0))
        return self.transform_t(img_hwc), self.transform_t_prime(img_hwc)

class MLP(nn.Module):
    def __init__(self, i, h, o): super().__init__(); self.net = nn.Sequential(nn.Linear(i, h), nn.BatchNorm1d(h), nn.ReLU(True), nn.Linear(h, o))
    def forward(self, x): return self.net(x)

class BYOLModel(nn.Module):
    def __init__(self, encoder, encoder_dim, p_dim, h_dim, b_tau):
        super().__init__(); self.base_tau = b_tau
        self.online_encoder=encoder; self.online_projector=MLP(encoder_dim,h_dim,p_dim); self.online_predictor=MLP(p_dim,h_dim,p_dim)
        self.target_encoder=copy.deepcopy(encoder); self.target_projector=copy.deepcopy(self.online_projector)
        for p in self.target_encoder.parameters(): p.requires_grad=False
        for p in self.target_projector.parameters(): p.requires_grad=False
    @torch.no_grad()
    def update_target_network(self, cs, ts):
        tau=1-(1-self.base_tau)*(math.cos(math.pi*cs/ts)+1)/2
        for o,t in zip(self.online_encoder.parameters(),self.target_encoder.parameters()): t.data.mul_(tau).add_(o.data,alpha=1-tau)
        for o,t in zip(self.online_projector.parameters(),self.target_projector.parameters()): t.data.mul_(tau).add_(o.data,alpha=1-tau)
    def forward(self, v1, v2):
        op1=self.online_predictor(self.online_projector(self.online_encoder(v1))); op2=self.online_predictor(self.online_projector(self.online_encoder(v2)))
        with torch.no_grad(): tp1=self.target_projector(self.target_encoder(v1)); tp2=self.target_projector(self.target_encoder(v2))
        return (op1, op2), (tp1.detach(), tp2.detach())

def byol_loss_fn(p, t):
    p1, p2=p; t1, t2=t; eps=1e-6
    p1_norm=F.normalize(p1,p=2,dim=-1,eps=eps); p2_norm=F.normalize(p2,p=2,dim=-1,eps=eps)
    t1_norm=F.normalize(t1.detach(),p=2,dim=-1,eps=eps); t2_norm=F.normalize(t2.detach(),p=2,dim=-1,eps=eps)
    loss1=2-2*(p1_norm*t2_norm).sum(dim=-1); loss2=2-2*(p2_norm*t1_norm).sum(dim=-1)
    return (loss1 + loss2).mean() * 0.5

# --- 4. Main Training Execution ---
if __name__ == '__main__':
    gc.collect(); torch.cuda.empty_cache()
    device = torch.device(CONFIG['device']); print(f"Using device: {device}")
    
    dataset = SSL4EODataset(CONFIG['data_path'], CONFIG['data_subset_percentage'])
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)

    print("Using ResNet-18 as the encoder model.")
    resnet=models.resnet18(weights=None); ed=resnet.fc.in_features; resnet.fc=nn.Identity()
    model = BYOLModel(resnet, ed, CONFIG['projection_dim'], CONFIG['hidden_dim'], CONFIG['base_tau']).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(loader)*CONFIG['epochs'])
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
    early_stopper = EarlyStopping(patience=CONFIG['early_stopping_patience'], verbose=True)
    
    start_epoch = load_checkpoint(model, optimizer, scheduler, CONFIG)
    
    if start_epoch > 0 or "03-byol-training-50pct" not in CONFIG['checkpoint_input_folder']:
        print(f"\n--- Starting/Resuming BYOL Training from Epoch {start_epoch} ---")
        for epoch in range(start_epoch, CONFIG['epochs']):
            model.train(); total_loss = 0.0
            pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
            for v1, v2 in pbar:
                v1, v2 = v1.to(device), v2.to(device)
                with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
                    p, t = model(v1, v2); loss = byol_loss_fn(p, t)
                optimizer.zero_grad(set_to_none=True); scaler.scale(loss).backward()
                scaler.step(optimizer); scaler.update(); scheduler.step()
                model.update_target_network(epoch*len(loader)+pbar.n, len(loader)*CONFIG['epochs'])
                total_loss += loss.item(); pbar.set_postfix({'Loss':f"{loss.item():.4f}", 'LR':f"{optimizer.param_groups[0]['lr']:.6f}"})
            
            avg_loss = total_loss / len(loader)
            if not math.isfinite(avg_loss): print(f"Epoch {epoch+1} ended with non-finite loss: {avg_loss}. Stopping."); break
            
            print(f"Epoch {epoch+1} Summary: Average Loss = {avg_loss:.4f}")
            save_checkpoint(epoch, model, optimizer, scheduler, avg_loss, CONFIG)
            early_stopper(avg_loss, model, os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename']))
            if early_stopper.early_stop: print("Early stopping triggered."); break

        print("\n--- Training Finished ---")
        torch.save(model.online_encoder.state_dict(), os.path.join(CONFIG['output_dir'], "byol_final_encoder_50pct.pth"))
        print("Final and best models saved.")
    else:
        # This else block will run if load_checkpoint finds nothing and returns 0.
        print("Halting execution because a valid checkpoint to resume from was not found.")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/205.4 kB[0m [31m?[0m eta [36m-:--:--[0m

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.4/205.4 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/8.8 MB[0m [31m?[0m eta [36m-:--:--[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m5.7/8.8 MB[0m [31m169.4 MB/s[0m eta [36m0:00:01[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m8.8/8.8 MB[0m [31m184.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m101.6 MB/s[0m eta [36m0:00:00[0m
[?25h

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.7/53.7 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h

Using device: cuda
Pre-loading images into RAM...


Loading files:   0%|          | 0/241 [00:00<?, ?it/s]


Dataset Initialized: Using 50% of data -> 61,696 images pre-loaded.
Using ResNet-18 as the encoder model.


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))


✅ Resuming from checkpoint: /kaggle/input/03-byol-training-50pct/byol_checkpoint_50pct.pth


   -> Resuming training from the start of epoch 17

--- Starting/Resuming BYOL Training from Epoch 16 ---


Epoch 17/50:   0%|          | 0/482 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):


Epoch 17 Summary: Average Loss = 0.1841


Loss decreased (inf --> 0.184122). Saving best model...


Epoch 18/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 18 Summary: Average Loss = 0.1807


Loss decreased (0.184122 --> 0.180744). Saving best model...


Epoch 19/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 19 Summary: Average Loss = 0.1812


EarlyStopping counter: 1 out of 10


Epoch 20/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 20 Summary: Average Loss = 0.1810


EarlyStopping counter: 2 out of 10


Epoch 21/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 21 Summary: Average Loss = 0.1814


EarlyStopping counter: 3 out of 10


Epoch 22/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 22 Summary: Average Loss = 0.1797


Loss decreased (0.180744 --> 0.179666). Saving best model...


Epoch 23/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 23 Summary: Average Loss = 0.1801


EarlyStopping counter: 1 out of 10


Epoch 24/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 24 Summary: Average Loss = 0.1806


EarlyStopping counter: 2 out of 10


Epoch 25/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 25 Summary: Average Loss = 0.1785


Loss decreased (0.179666 --> 0.178511). Saving best model...


Epoch 26/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 26 Summary: Average Loss = 0.1778


Loss decreased (0.178511 --> 0.177780). Saving best model...


Epoch 27/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 27 Summary: Average Loss = 0.1760


Loss decreased (0.177780 --> 0.176014). Saving best model...


Epoch 28/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 28 Summary: Average Loss = 0.1769


EarlyStopping counter: 1 out of 10


Epoch 29/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 29 Summary: Average Loss = 0.1771


EarlyStopping counter: 2 out of 10


Epoch 30/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 30 Summary: Average Loss = 0.1783


EarlyStopping counter: 3 out of 10


Epoch 31/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 31 Summary: Average Loss = 0.1794


EarlyStopping counter: 4 out of 10


Epoch 32/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 32 Summary: Average Loss = 0.1793


EarlyStopping counter: 5 out of 10


Epoch 33/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 33 Summary: Average Loss = 0.1834


EarlyStopping counter: 6 out of 10


Epoch 34/50:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 34 Summary: Average Loss = 0.1836


EarlyStopping counter: 7 out of 10


Epoch 35/50:   0%|          | 0/482 [00:00<?, ?it/s]