In [1]:
# ==============================================================================
# 04 - BARLOW TWINS PRE-TRAINING (25% DATA) - RESUME RUN
# ==============================================================================
# Purpose: To resume the Barlow Twins training from the last saved checkpoint.
# ==============================================================================

# --- 0. Install and Import ---
!pip install "zarr>=2.10.0" numcodecs -q

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision.models as models
from pathlib import Path
import numpy as np, zarr, random, math, os, copy, zipfile, tempfile, gc
from tqdm.notebook import tqdm

print("Barlow Twins Pre-training Script Initialized (Resume Mode).")

# --- 1. Configuration ---
CONFIG = {
    'data_path': "/kaggle/input/01-data-preparation/data/ssl4eo-s12/train/S2RGB",
    'output_dir': "/kaggle/working/",
    # Path to the checkpoint from the PREVIOUS run's output
    'resume_from_checkpoint_path': "/kaggle/input/04-barlow-twins-training-25pct/barlow_checkpoint_25pct.pth",
    'epochs': 20,
    'batch_size': 128,
    'learning_rate': 1e-3,
    'weight_decay': 1.5e-6,
    'image_size': 224,
    'projection_dim': 2048,
    'hidden_dim': 4096,
    'lambda_param': 5e-3,
    'num_workers': 2,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'early_stopping_patience': 7,
    'data_subset_percentage': 0.25,
    'checkpoint_filename': "barlow_checkpoint_25pct.pth", # The file to be saved in THIS run
    'best_model_filename': "barlow_best_encoder_25pct.pth"
}

# --- 2. Helper Classes & Functions ---
class EarlyStopping:
    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience, self.verbose, self.delta = patience, verbose, delta; self.counter, self.best_score, self.early_stop, self.val_loss_min = 0, None, False, np.Inf
    def __call__(self, val_loss, model, path):
        if not math.isfinite(val_loss): print("Loss is not finite, stopping."); self.early_stop = True; return
        score = -val_loss
        if self.best_score is None: self.best_score = score; self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose: print(f'EarlyStopping counter: {self.counter} of {self.patience}')
            if self.counter >= self.patience: self.early_stop = True
        else: self.best_score = score; self.save_checkpoint(val_loss, model, path); self.counter = 0
    def save_checkpoint(self, val_loss, model, path):
        if self.verbose: print(f'Loss decreased ({self.val_loss_min:.2f}-->{val_loss:.2f}). Saving best model.')
        torch.save(model.encoder.state_dict(), path); self.val_loss_min = val_loss

def save_checkpoint(epoch, model, optimizer, scheduler, loss, config):
    state = {'epoch':epoch+1, 'model_state_dict':model.state_dict(), 'optimizer_state_dict':optimizer.state_dict(), 'scheduler_state_dict':scheduler.state_dict(), 'loss':loss}
    torch.save(state, os.path.join(config['output_dir'], config['checkpoint_filename']))

# MODIFIED: Intelligent checkpoint loading
def load_checkpoint(model, optimizer, scheduler, config):
    # Priority 1: Check for a checkpoint from a previous run's output
    resume_path = config.get('resume_from_checkpoint_path')
    if resume_path and os.path.exists(resume_path):
        print(f"Resuming from previous run's checkpoint: {resume_path}")
        path_to_load = resume_path
    # Priority 2: Check for a checkpoint in the current working directory
    elif os.path.exists(os.path.join(config['output_dir'], config['checkpoint_filename'])):
        path_to_load = os.path.join(config['output_dir'], config['checkpoint_filename'])
        print(f"Resuming from current session's checkpoint: {path_to_load}")
    else:
        print("No checkpoint found, starting from scratch.")
        return 0

    ckpt = torch.load(path_to_load, map_location=config['device'])
    model.load_state_dict(ckpt['model_state_dict']); optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict']); start_epoch = ckpt['epoch']
    print(f"Successfully loaded. Resuming from epoch {start_epoch}.")
    return start_epoch

# --- 3. Dataset, Model, and Loss (No changes needed) ---
class SSL4EODataset(Dataset):
    def __init__(self, root_dir, subset_percentage=1.0):
        self.root_dir = Path(root_dir)
        all_files = sorted(list(self.root_dir.glob("*.zarr.zip")))
        num_files = int(len(all_files) * subset_percentage)
        self.zarr_files = random.sample(all_files, num_files) if subset_percentage < 1.0 else all_files
        self.images = self._preload_images()
        print(f"\nDataset Initialized: Using {subset_percentage*100:.0f}% of data -> {len(self.images):,} images pre-loaded.")
        self.transform = T.Compose([
            T.ToPILImage(), T.RandomResizedCrop(CONFIG['image_size'], antialias=True), T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1), T.RandomGrayscale(p=0.2), T.GaussianBlur(23, (0.1, 2.0)),
            T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    def _preload_images(self):
        images = []; print("Pre-loading images into RAM...")
        for fp in tqdm(self.zarr_files, desc="Loading files"):
            with tempfile.TemporaryDirectory() as td:
                with zipfile.ZipFile(str(fp), 'r') as zf: zf.extractall(td)
                za = zarr.open(td, mode='r')['bands'][:]; images.extend(za.reshape(-1, *za.shape[2:]))
        return images
    def __len__(self): return len(self.images)
    def __getitem__(self, idx):
        img_chw = self.images[idx]; img_hwc = np.transpose(img_chw, (1, 2, 0))
        return self.transform(img_hwc), self.transform(img_hwc)

class BarlowTwinsModel(nn.Module):
    def __init__(self, e, ed, pd, hd):
        super().__init__(); self.encoder = e
        self.projector = nn.Sequential(
            nn.Linear(ed, hd, bias=False), nn.BatchNorm1d(hd), nn.ReLU(inplace=True),
            nn.Linear(hd, hd, bias=False), nn.BatchNorm1d(hd), nn.ReLU(inplace=True),
            nn.Linear(hd, pd, bias=False))
    def forward(self, v1, v2):
        return self.projector(self.encoder(v1)), self.projector(self.encoder(v2))

def barlow_twins_loss_fn(z1, z2, lambda_param):
    bs, D = z1.shape
    z1_n=(z1-z1.mean(0))/(z1.std(0)+1e-5); z2_n=(z2-z2.mean(0))/(z2.std(0)+1e-5)
    c=(z1_n.T@z2_n)/bs
    on=torch.diagonal(c)-1; off=c.fill_diagonal_(0)
    return on.pow(2).sum()+lambda_param*off.pow(2).sum()

# --- 4. Main Training Execution ---
device = torch.device(CONFIG['device']); print(f"Using device: {device}")
dataset = SSL4EODataset(CONFIG['data_path'], CONFIG['data_subset_percentage'])
loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)

resnet=models.resnet18(weights=None); ed=resnet.fc.in_features; resnet.fc=nn.Identity()
model = BarlowTwinsModel(resnet, ed, CONFIG['projection_dim'], CONFIG['hidden_dim']).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(loader)*CONFIG['epochs'])
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
early_stopper = EarlyStopping(patience=CONFIG['early_stopping_patience'], verbose=True)

start_epoch = load_checkpoint(model, optimizer, scheduler, CONFIG)

print(f"\n--- Resuming Barlow Twins Training on {CONFIG['data_subset_percentage']*100:.0f}% of Data ---")
for epoch in range(start_epoch, CONFIG['epochs']):
    model.train(); total_loss = 0.0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
    for v1, v2 in pbar:
        v1, v2 = v1.to(device), v2.to(device)
        with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
            z1, z2 = model(v1, v2); loss = barlow_twins_loss_fn(z1, z2, CONFIG['lambda_param'])
        optimizer.zero_grad(); scaler.scale(loss).backward()
        scaler.step(optimizer); scaler.update(); scheduler.step()
        total_loss += loss.item(); pbar.set_postfix({'Loss':f"{loss.item():.2f}"})
    
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1} Summary: Average Loss = {avg_loss:.2f}")
    save_checkpoint(epoch, model, optimizer, scheduler, avg_loss, CONFIG)
    early_stopper(avg_loss, model, os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename']))
    if early_stopper.early_stop: print("Early stopping triggered."); break

print("\n--- Training Finished ---")
torch.save(model.encoder.state_dict(), os.path.join(CONFIG['output_dir'], "barlow_final_encoder_25pct.pth"))

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.4/205.4 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m80.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.7/53.7 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hBarlow Twins Pre-training Script Initialized (Resume Mode).
Using device: cuda
Pre-loading images into RAM...


Loading files:   0%|          | 0/120 [00:00<?, ?it/s]


Dataset Initialized: Using 25% of data -> 30,720 images pre-loaded.


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))


Resuming from previous run's checkpoint: /kaggle/input/04-barlow-twins-training-25pct/barlow_checkpoint_25pct.pth
Successfully loaded. Resuming from epoch 19.

--- Resuming Barlow Twins Training on 25% of Data ---


Epoch 20/20:   0%|          | 0/240 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):


Epoch 20 Summary: Average Loss = 864.63
Loss decreased (inf-->864.63). Saving best model.

--- Training Finished ---
