In [1]:
# ==============================================================================
# 04c - RESUME BARLOW TWINS TRAINING (PART 3 - FINAL PATH CORRECTION)
# ==============================================================================
# Purpose: To resume the Barlow Twins training from the correct checkpoint path.
# Change: Updated 'checkpoint_input_folder' to match the actual Kaggle path.
# ==============================================================================

# --- 0. Imports ---
!pip install "zarr>=2.10.0" numcodecs -q
import torch, torch.nn as nn, torch.nn.functional as F, torchvision.transforms as T, torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np, zarr, random, math, os, copy, zipfile, tempfile, gc
from tqdm.notebook import tqdm

# --- 1. Configuration ---
CONFIG = dict(
    # --- Inputs: Kaggle dataset folder names (resolved under /kaggle/input) ---
    data_input_folder='01-data-preparation',  # output of the data-preparation notebook
    checkpoint_input_folder='04-barlow-twins-training-50pct',  # exact folder name as shown by 'ls -R'

    # --- Outputs of this run ---
    output_dir="/kaggle/working/",

    # --- Training hyper-parameters: keep identical to the run being resumed ---
    epochs=30,  # final target epoch count, counted from the very first run
    batch_size=128,
    learning_rate=1e-3,
    weight_decay=1.5e-6,
    image_size=224,
    projection_dim=2048,
    hidden_dim=4096,
    lambda_param=5e-3,
    num_workers=2,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    early_stopping_patience=10,
    data_subset_percentage=0.50,

    # --- File names: must match the previous run's files exactly ---
    checkpoint_filename="barlow_checkpoint_50pct.pth",
    best_model_filename="barlow_best_encoder_50pct.pth",
)

# --- 2. Helper Classes & Functions ---
class EarlyStopping:
    """Signal when a monitored loss has stopped improving, saving the best encoder.

    Each call compares the new loss with the best seen so far. A decrease of at
    least ``delta`` saves ``model.encoder``'s state dict to ``path`` and resets
    the patience counter. Otherwise the counter grows, and ``early_stop`` is set
    once it reaches ``patience``. A non-finite loss sets ``early_stop``
    immediately.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: Print progress messages when True.
        delta: Minimum loss decrease that counts as an improvement.
    """

    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience, self.verbose, self.delta = patience, verbose, delta
        self.counter = 0
        self.best_score = None  # negated best loss (higher is better); None until the first call
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model, path):
        """Record one epoch's loss; may save weights to ``path`` and/or set ``early_stop``."""
        if not math.isfinite(val_loss):
            print("Loss is not finite, stopping.")
            self.early_stop = True
            return
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.encoder``'s state dict to ``path`` and remember ``val_loss`` as the best."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best model...')
        torch.save(model.encoder.state_dict(), path)
        self.val_loss_min = val_loss

def save_checkpoint(epoch, model, optimizer, scheduler, loss, config):
    """Write a resumable training checkpoint into ``config['output_dir']``.

    ``epoch`` is the 0-based index of the epoch just finished. ``epoch + 1`` is
    stored, so a resumed run starts at the following epoch. The file is named
    ``config['checkpoint_filename']`` and contains the model, optimizer and
    scheduler state dicts plus ``loss``.
    """
    target = os.path.join(config['output_dir'], config['checkpoint_filename'])
    payload = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
    }
    torch.save(payload, target)

def load_checkpoint(model, optimizer, scheduler, config):
    """Restore model/optimizer/scheduler state from a Kaggle input checkpoint.

    Looks for ``/kaggle/input/<checkpoint_input_folder>/<checkpoint_filename>``.
    When the file exists, the three state dicts are loaded in place (tensors
    mapped to ``config['device']``), and the stored ``'epoch'`` value is
    returned as the epoch to resume from. When it is missing, nothing is
    touched and 0 is returned.
    """
    ckpt_path = Path(f"/kaggle/input/{config['checkpoint_input_folder']}/{config['checkpoint_filename']}")
    if not ckpt_path.exists():
        print(f"❌ Checkpoint not found at {ckpt_path}. Starting a fresh run.")
        return 0

    print(f"✅ Resuming from checkpoint: {ckpt_path}")
    state = torch.load(ckpt_path, map_location=config['device'])
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    scheduler.load_state_dict(state['scheduler_state_dict'])
    resume_epoch = state['epoch']
    print(f"   -> Resuming training from the start of epoch {resume_epoch}")
    return resume_epoch

# --- 3. Dataset, Model, and Loss ---
class SSL4EODataset(Dataset):
    """In-memory SSL4EO-S12 image dataset that yields two augmented views per item.

    Every ``*.zarr.zip`` archive in ``root_dir`` (or a seeded random subset of
    them) is extracted, and its ``'bands'`` array is loaded into RAM up front.
    ``__getitem__`` returns two independently augmented views of the same
    image, as required by Barlow Twins.

    Args:
        root_dir: Directory containing the ``*.zarr.zip`` archives.
        subset_percentage: Fraction of archives to use. Values below 1.0 pick
            ``int(n_files * subset_percentage)`` archives using a fixed seed
            (42). 1.0 or more uses every archive.
        image_size: Output side length of the random resized crop. When None,
            ``CONFIG['image_size']`` is used.

    Raises:
        FileNotFoundError: If ``root_dir`` contains no ``*.zarr.zip`` files.
    """

    def __init__(self, root_dir, subset_percentage=1.0, image_size=None):
        self.root_dir = Path(root_dir)
        all_zarr_files = sorted(self.root_dir.glob("*.zarr.zip"))
        if not all_zarr_files:
            # Fail early: an empty dataset only breaks later (e.g. a 0/0 average loss per epoch).
            raise FileNotFoundError(f"No *.zarr.zip files found in {self.root_dir}")
        num_files_to_use = int(len(all_zarr_files) * subset_percentage)
        if subset_percentage < 1.0:
            # A private RNG seeded with 42 draws the same subset that seeding the global
            # `random` module with 42 would, so resumed runs see the same files, but the
            # process-wide random state is left untouched.
            self.zarr_files = random.Random(42).sample(all_zarr_files, num_files_to_use)
        else:
            self.zarr_files = all_zarr_files
        self.images = self._preload_images()
        self.total_images = len(self.images)
        print(f"\nDataset Initialized: Using {subset_percentage*100:.0f}% of data. -> Pre-loaded {self.total_images:,} images.")
        crop_size = CONFIG['image_size'] if image_size is None else image_size
        self.transform = T.Compose([
            T.ToPILImage(),
            T.RandomResizedCrop(crop_size, antialias=True),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(23, (0.1, 2.0)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _preload_images(self):
        """Extract each selected archive to a temp dir and return a flat list of image arrays."""
        images = []
        print("Pre-loading images into RAM...")
        for fp in tqdm(self.zarr_files, desc="Loading files"):
            with tempfile.TemporaryDirectory() as td:
                with zipfile.ZipFile(str(fp), 'r') as zf:
                    zf.extractall(td)
                # `[:]` reads the array fully into memory before the temp dir is deleted.
                bands = zarr.open(td, mode='r')['bands'][:]
                # Merge the two leading axes so each entry is a single image
                # (assumed channels-first, as __getitem__ expects — TODO confirm).
                images.extend(bands.reshape(-1, *bands.shape[2:]))
        return images

    def __len__(self):
        return self.total_images

    def __getitem__(self, idx):
        """Return two independently augmented views of image ``idx``."""
        image_chw = self.images[idx]
        image_hwc = np.transpose(image_chw, (1, 2, 0))  # ToPILImage expects H x W x C for arrays
        view1 = self.transform(image_hwc)
        view2 = self.transform(image_hwc)
        return view1, view2

class BarlowTwinsModel(nn.Module):
    """Shared encoder plus MLP projector, applied to both views of a pair.

    The projector is (Linear -> BatchNorm1d -> ReLU) twice, followed by a final
    Linear to ``projection_dim``. The layer order, and therefore the
    ``projector.<index>`` state-dict keys, must stay fixed so that saved
    checkpoints keep loading.

    Args:
        encoder: Backbone producing ``encoder_dim`` features per input.
        encoder_dim: Feature size produced by ``encoder``.
        projection_dim: Output size of the projector.
        hidden_dim: Width of the projector's two hidden layers.
    """

    def __init__(self, encoder, encoder_dim, projection_dim, hidden_dim):
        super().__init__()
        self.encoder = encoder
        hidden_layers = []
        for in_features in (encoder_dim, hidden_dim):
            hidden_layers += [
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        self.projector = nn.Sequential(*hidden_layers, nn.Linear(hidden_dim, projection_dim))

    def forward(self, v1, v2):
        """Return the projected embeddings ``(z1, z2)`` of views ``v1`` and ``v2``."""
        def embed(view):
            return self.projector(self.encoder(view))

        return embed(v1), embed(v2)

def barlow_twins_loss_fn(z1, z2, lambda_param, eps=1e-5):
    """Barlow Twins redundancy-reduction loss between two batches of embeddings.

    Each embedding dimension is standardised over the batch, using the batch
    mean and torch's default (unbiased) std. The D x D cross-correlation
    matrix ``c`` between the two views is then formed, and the loss is

        sum_i (c_ii - 1)^2 + lambda_param * sum_{i != j} c_ij^2

    Args:
        z1, z2: Embeddings of shape (batch, D) for the two views.
        lambda_param: Weight of the off-diagonal (redundancy) term.
        eps: Added to the per-dimension std to avoid division by zero.

    Returns:
        Scalar loss tensor.
    """
    batch_size = z1.shape[0]
    z1_norm = (z1 - z1.mean(dim=0)) / (z1.std(dim=0) + eps)
    z2_norm = (z2 - z2.mean(dim=0)) / (z2.std(dim=0) + eps)
    c = (z1_norm.T @ z2_norm) / batch_size
    on_diag_loss = ((torch.diagonal(c) - 1) ** 2).sum()
    # Zero the diagonal out-of-place: c's diagonal view is already part of the
    # autograd graph, so writing into c in place is fragile. masked_fill gives
    # exactly the same values, and zero gradient on the masked diagonal.
    diag_mask = torch.eye(c.shape[0], dtype=torch.bool, device=c.device)
    off_diag = c.masked_fill(diag_mask, 0)
    off_diag_loss = (off_diag ** 2).sum()
    return on_diag_loss + lambda_param * off_diag_loss

# --- 4. Main Training Execution ---
# Script entry point: load the data, build the model, optimizer and scheduler, restore
# the checkpoint, then keep training from the saved epoch up to CONFIG['epochs'].
if __name__ == '__main__':
    gc.collect(); torch.cuda.empty_cache()  # free memory that earlier notebook cells may have left behind
    device = torch.device(CONFIG['device']); print(f"Using device: {device}")
    
    # S2RGB training images produced by the data-preparation notebook (a Kaggle input dataset).
    data_full_path = Path(f"/kaggle/input/{CONFIG['data_input_folder']}/data/ssl4eo-s12/train/S2RGB")
    dataset = SSL4EODataset(data_full_path, CONFIG['data_subset_percentage'])
    # drop_last=True: every batch has exactly batch_size samples, so len(loader) counts full
    # batches only (it is used for the scheduler's T_max and for the per-epoch average loss).
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, 
                        num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)
    
    print("Using ResNet-18 as the encoder model.")
    # Randomly initialised ResNet-18 (weights=None). fc's input width is read before fc is
    # replaced by Identity, so the encoder outputs the features that used to feed fc.
    resnet = models.resnet18(weights=None); encoder_output_dim = resnet.fc.in_features; resnet.fc = nn.Identity()
    
    model = BarlowTwinsModel(
        encoder=resnet, encoder_dim=encoder_output_dim,
        projection_dim=CONFIG['projection_dim'], hidden_dim=CONFIG['hidden_dim']
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
    # T_max is measured in optimizer steps, because scheduler.step() is called once per batch
    # below. On resume, load_checkpoint replaces this freshly built state with the saved one.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(loader)*CONFIG['epochs'])
    
    # NOTE(review): early-stopping state is not stored in the checkpoint, so the best-loss
    # tracking and the patience counter restart with this session. The first resumed epoch
    # always counts as an improvement and writes a new best-encoder file.
    early_stopper = EarlyStopping(patience=CONFIG['early_stopping_patience'], verbose=True)
    
    # This will load the state from the previous 12-hour run
    # start_epoch = number of epochs already completed (save_checkpoint stores epoch + 1),
    # or 0 when no checkpoint file was found.
    start_epoch = load_checkpoint(model, optimizer, scheduler, CONFIG)
    
    if start_epoch > 0:
        print(f"\n--- Resuming Barlow Twins Training from Epoch {start_epoch} ---")
        for epoch in range(start_epoch, CONFIG['epochs']):  # epoch is 0-based; messages show epoch+1
            model.train(); total_loss = 0.0
            pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
            for view1, view2 in pbar:
                view1, view2 = view1.to(device), view2.to(device)
                z1, z2 = model(view1, view2)
                loss = barlow_twins_loss_fn(z1, z2, lambda_param=CONFIG['lambda_param'])
                
                # The LR schedule advances per batch, matching T_max above.
                optimizer.zero_grad(); loss.backward(); optimizer.step(); scheduler.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'Loss': f"{loss.item():.2f}", 'LR': f"{optimizer.param_groups[0]['lr']:.6f}"})

            avg_loss = total_loss / len(loader)  # mean training loss over this epoch's batches
            print(f"Epoch {epoch+1} Summary: Average Loss = {avg_loss:.2f}")
            
            # Checkpoint after every epoch. This writes to output_dir, while load_checkpoint reads
            # from /kaggle/input/<checkpoint_input_folder>, so the file must be published as an
            # input dataset before it can be resumed from.
            save_checkpoint(epoch, model, optimizer, scheduler, avg_loss, CONFIG)
            # Early stopping monitors the training loss here (there is no validation split);
            # it saves the encoder weights whenever that loss improves.
            early_stopper(avg_loss, model, os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename']))
            if early_stopper.early_stop:
                print("Early stopping triggered."); break

        print("\n--- Training Finished ---")
        # The final encoder is saved whether training ran to completion or stopped early.
        torch.save(model.encoder.state_dict(), os.path.join(CONFIG['output_dir'], "barlow_final_encoder_50pct.pth"))
        
    else:
        print("Could not find a valid checkpoint to resume from. Halting execution.")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/205.4 kB[0m [31m?[0m eta [36m-:--:--[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m204.8/205.4 kB[0m [31m8.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.4/205.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/8.8 MB[0m [31m?[0m eta [36m-:--:--[0m

[2K   [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/8.8 MB[0m [31m64.7 MB/s[0m eta [36m0:00:01[0m

[2K   [91m━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m4.1/8.8 MB[0m [31m64.3 MB/s[0m eta [36m0:00:01[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m6.4/8.8 MB[0m [31m63.5 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m8.8/8.8 MB[0m [31m74.2 MB/s[0m eta [36m0:00:01[0m

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m54.5 MB/s[0m eta [36m0:00:00[0m
[?25h

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.7/53.7 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h

Using device: cuda
Pre-loading images into RAM...


Loading files:   0%|          | 0/241 [00:00<?, ?it/s]


Dataset Initialized: Using 50% of data. -> Pre-loaded 61,696 images.
Using ResNet-18 as the encoder model.


✅ Resuming from checkpoint: /kaggle/input/04-barlow-twins-training-50pct/barlow_checkpoint_50pct.pth


   -> Resuming training from the start of epoch 17

--- Resuming Barlow Twins Training from Epoch 17 ---


Epoch 18/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 18 Summary: Average Loss = 457.67


Validation loss decreased (inf --> 457.667400). Saving best model...


Epoch 19/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 19 Summary: Average Loss = 446.41


Validation loss decreased (457.667400 --> 446.412952). Saving best model...


Epoch 20/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 20 Summary: Average Loss = 436.04


Validation loss decreased (446.412952 --> 436.035797). Saving best model...


Epoch 21/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 21 Summary: Average Loss = 427.15


Validation loss decreased (436.035797 --> 427.153399). Saving best model...


Epoch 22/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 22 Summary: Average Loss = 419.74


Validation loss decreased (427.153399 --> 419.742676). Saving best model...


Epoch 23/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 23 Summary: Average Loss = 413.37


Validation loss decreased (419.742676 --> 413.367056). Saving best model...


Epoch 24/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 24 Summary: Average Loss = 407.93


Validation loss decreased (413.367056 --> 407.925351). Saving best model...


Epoch 25/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 25 Summary: Average Loss = 403.44


Validation loss decreased (407.925351 --> 403.442369). Saving best model...


Epoch 26/30:   0%|          | 0/482 [00:00<?, ?it/s]