In [1]:
# ==============================================================================
# 04 - BARLOW TWINS PRE-TRAINING ON SSL4EO-S12
# ==============================================================================
# Purpose: To train the Barlow Twins model under the exact same conditions
#          as the BYOL model (ResNet-18, 10% data) for a fair comparison.
# ==============================================================================

# --- 0. Install and Import ---
!pip install "zarr>=2.10.0" numcodecs -q

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T, torchvision.models as models
from pathlib import Path
import numpy as np, zarr, random, math, os, copy, zipfile, tempfile, gc
from tqdm.notebook import tqdm

print("Barlow Twins Pre-training Script Initialized.")

# --- 1. Configuration ---
# Single source of run settings; the rest of the script looks values up by key.
CONFIG = dict(
    # Where the training shards are read from and where artefacts are written.
    data_path="/kaggle/input/01-data-preparation/data/ssl4eo-s12/train/S2RGB",
    output_dir="/kaggle/working/",
    # Optimisation schedule.
    epochs=30,
    batch_size=128,
    learning_rate=1e-3,  # Barlow Twins often benefits from a slightly higher LR than BYOL
    weight_decay=1.5e-6,
    # Side length (pixels) of the random crops fed to the encoder.
    image_size=224,
    # Projector head sizes.
    projection_dim=2048,  # Barlow Twins paper suggests larger projection dims
    hidden_dim=4096,
    lambda_param=5e-3,  # Lambda for the loss function, from the paper
    # DataLoader worker processes.
    num_workers=2,
    # Prefer the GPU whenever one is visible to PyTorch.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # Epochs without improvement tolerated before training stops.
    early_stopping_patience=7,
    # Fraction of the shard files to load (0.10 == 10%).
    data_subset_percentage=0.10,
    # File names, created inside ``output_dir``.
    checkpoint_filename="barlow_checkpoint_10pct.pth",
    best_model_filename="barlow_best_encoder_10pct.pth",
)

# --- 2. Helper Classes & Functions (EarlyStopping, Checkpoints) ---
# (These helper functions are identical to the BYOL script and can be reused)
class EarlyStopping:
    """Track a monitored loss, keep the best encoder on disk, and flag when to stop.

    Call the instance once per epoch with the scalar loss to monitor. Each time
    the loss is at least ``delta`` better than (or, with ``delta == 0``, equal
    to) the best value seen so far, ``model.encoder.state_dict()`` is written
    to ``path`` and the patience counter is reset. After ``patience``
    consecutive calls without such an improvement, ``early_stop`` becomes True.
    A non-finite loss (NaN/inf) sets ``early_stop`` immediately.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: Print progress messages when True.
        delta: Minimum improvement of the loss required to count as progress.
    """

    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # float('inf') rather than np.Inf: the capitalised alias was removed in
        # NumPy 2.0 and raises AttributeError there.
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model, path):
        """Update the stopper with this epoch's loss.

        Args:
            val_loss: Scalar loss being monitored (lower is better).
            model: Object exposing an ``encoder`` module whose weights are saved.
            path: Destination file for the best encoder weights.
        """
        if not math.isfinite(val_loss):
            print("Loss is not finite, stopping early.")
            self.early_stop = True
            return

        # Scores are negated losses so that "higher is better".
        score = -val_loss
        improved = self.best_score is None or score >= self.best_score + self.delta
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model, path):
        """Write ``model.encoder``'s weights to ``path`` and record the new best loss."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best model...')
        torch.save(model.encoder.state_dict(), path)
        self.val_loss_min = val_loss

def save_checkpoint(epoch, model, optimizer, scheduler, loss, config):
    """Persist a resumable training snapshot.

    The file is written to ``config['output_dir'] / config['checkpoint_filename']``.
    The stored ``'epoch'`` is ``epoch + 1``, i.e. the index of the next epoch
    to run when training is resumed from this file.

    Args:
        epoch: Zero-based index of the epoch that just finished.
        model: Module whose full ``state_dict`` is stored.
        optimizer: Optimizer whose state is stored.
        scheduler: LR scheduler whose state is stored.
        loss: Loss value recorded alongside the snapshot.
        config: Mapping providing ``'output_dir'`` and ``'checkpoint_filename'``.
    """
    snapshot = {}
    snapshot['epoch'] = epoch + 1
    snapshot['model_state_dict'] = model.state_dict()
    snapshot['optimizer_state_dict'] = optimizer.state_dict()
    snapshot['scheduler_state_dict'] = scheduler.state_dict()
    snapshot['loss'] = loss

    destination = os.path.join(config['output_dir'], config['checkpoint_filename'])
    torch.save(snapshot, destination)

def load_checkpoint(model, optimizer, scheduler, config):
    """Restore training state from the checkpoint file, if one exists.

    Looks for ``config['output_dir'] / config['checkpoint_filename']``. When the
    file is present, the model, optimizer and scheduler states are loaded in
    place (tensors mapped to ``config['device']``).

    Returns:
        The epoch index to resume from: the checkpoint's ``'epoch'`` entry, or
        0 when no checkpoint file was found.
    """
    ckpt_path = os.path.join(config['output_dir'], config['checkpoint_filename'])

    if not os.path.exists(ckpt_path):
        print("No checkpoint found, starting from scratch.")
        return 0

    print(f"Resuming from checkpoint: {ckpt_path}")
    saved = torch.load(ckpt_path, map_location=config['device'])

    model.load_state_dict(saved['model_state_dict'])
    optimizer.load_state_dict(saved['optimizer_state_dict'])
    scheduler.load_state_dict(saved['scheduler_state_dict'])

    resume_epoch = saved['epoch']
    print(f"Resumed from epoch {resume_epoch}, with previous loss {saved['loss']:.4f}")
    return resume_epoch

# --- 3. Dataset, Model, and Loss ---
class SSL4EODataset(Dataset):
    """In-memory dataset that yields two random augmentations of each image.

    Every ``*.zarr.zip`` file directly under ``root_dir`` is a candidate shard.
    A random subset of the shards is chosen when ``subset_percentage < 1.0``,
    all selected shards are decoded into RAM up front, and ``__getitem__``
    returns two independently augmented views of the requested image.

    Args:
        root_dir: Directory containing the ``*.zarr.zip`` shards.
        subset_percentage: Fraction of shard files to load. Values >= 1.0 load
            every shard; at least one shard is always kept.
        image_size: Side length of the random crop. ``None`` falls back to the
            module-level ``CONFIG['image_size']``.
        seed: Optional seed for the shard sub-sampling so that separate runs
            pick the same files. ``None`` uses the global ``random`` state.

    Raises:
        ValueError: If ``subset_percentage`` is not positive.
        FileNotFoundError: If ``root_dir`` contains no ``*.zarr.zip`` file.
    """

    def __init__(self, root_dir, subset_percentage=1.0, image_size=None, seed=None):
        if subset_percentage <= 0:
            raise ValueError(
                f"subset_percentage must be positive, got {subset_percentage!r}"
            )

        self.root_dir = Path(root_dir)
        # Sorted so the candidate list does not depend on directory listing order.
        all_zarr_files = sorted(self.root_dir.glob("*.zarr.zip"))
        if not all_zarr_files:
            # Fail here with a clear message instead of building an empty
            # dataset that only breaks later inside the DataLoader.
            raise FileNotFoundError(f"No '*.zarr.zip' files found in {self.root_dir}")

        if subset_percentage < 1.0:
            # max(1, ...) keeps tiny fractions from rounding down to zero files.
            num_files_to_use = max(1, int(len(all_zarr_files) * subset_percentage))
            sampler = random if seed is None else random.Random(seed)
            self.zarr_files = sampler.sample(all_zarr_files, num_files_to_use)
        else:
            self.zarr_files = all_zarr_files

        self.images = self._preload_images()
        self.total_images = len(self.images)
        print(f"\nDataset Initialized: Using {subset_percentage*100:.0f}% of data.")
        print(f"-> Pre-loaded {self.total_images:,} images into memory.")

        if image_size is None:
            image_size = CONFIG['image_size']
        # Barlow Twins uses a symmetric augmentation pipeline: the same random
        # transform is sampled independently for both views.
        # NOTE(review): ToPILImage assumes the stored arrays are image-like
        # (e.g. uint8 RGB) -- confirm against the data preparation step.
        self.transform = T.Compose([
            T.ToPILImage(),
            T.RandomResizedCrop(image_size, antialias=True),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(23, (0.1, 2.0)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _preload_images(self):
        """Decode the ``'bands'`` array of every selected shard into a flat list."""
        images = []
        print("Pre-loading images into RAM... This may take a few minutes.")
        for fp in tqdm(self.zarr_files, desc="Loading files"):
            # Each archive is unpacked into a throw-away directory; ``[:]``
            # materialises the array before that directory is removed.
            with tempfile.TemporaryDirectory() as td:
                with zipfile.ZipFile(str(fp), 'r') as zf:
                    zf.extractall(td)
                za = zarr.open(td, mode='r')['bands'][:]
            # Merge the two leading axes so each list entry is one image.
            # presumably (sample, timestamp, C, H, W) -- TODO confirm layout.
            images.extend(za.reshape(-1, *za.shape[2:]))
        return images

    def __len__(self):
        """Return the number of pre-loaded images."""
        return self.total_images

    def __getitem__(self, idx):
        """Return two independently augmented views of image ``idx``."""
        image_chw = self.images[idx]
        # Stored channel-first; the PIL-based transforms expect channel-last.
        image_hwc = np.transpose(image_chw, (1, 2, 0))
        view1 = self.transform(image_hwc)
        view2 = self.transform(image_hwc)
        return view1, view2

class BarlowTwinsModel(nn.Module):
    """Encoder followed by a three-layer MLP projector, applied to two views.

    Args:
        encoder: Backbone module producing ``encoder_dim`` features per sample.
        encoder_dim: Size of the encoder output fed to the projector.
        projection_dim: Size of the final embedding.
        hidden_dim: Width of the two hidden projector layers.
    """

    def __init__(self, encoder, encoder_dim, projection_dim, hidden_dim):
        super().__init__()
        self.encoder = encoder
        # Two Linear -> BatchNorm -> ReLU stages, then a plain Linear output.
        head_layers = [
            nn.Linear(encoder_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim),
        ]
        self.projector = nn.Sequential(*head_layers)

    def forward(self, v1, v2):
        """Embed both views with the shared encoder and projector.

        Each view goes through its own forward pass (first ``v1``, then
        ``v2``), so batch-norm statistics are computed per view.

        Returns:
            Tuple ``(z1, z2)`` of projector outputs for ``v1`` and ``v2``.
        """
        first_features = self.encoder(v1)
        z1 = self.projector(first_features)
        second_features = self.encoder(v2)
        z2 = self.projector(second_features)
        return z1, z2

def barlow_twins_loss_fn(z1, z2, lambda_param, eps=1e-5):
    """Compute the Barlow Twins redundancy-reduction loss for two embeddings.

    Both inputs are standardised per feature over the batch, their
    cross-correlation matrix ``c`` (features x features) is formed, and the
    loss is ``sum_i (c_ii - 1)^2 + lambda_param * sum_{i != j} c_ij^2``.

    Standardisation uses the biased (divide-by-N) batch variance with ``eps``
    inside the square root, i.e. the same normalisation as a non-affine batch
    norm. Because the product is also divided by N, perfectly correlated
    features give ``c_ii == 1`` (up to ``eps``), so the invariance term can
    actually reach zero; the unbiased estimator would cap it at ``(N-1)/N``.

    Args:
        z1: Embeddings of the first view, shape ``(batch, features)``.
        z2: Embeddings of the second view, same shape as ``z1``.
        lambda_param: Weight of the off-diagonal (redundancy) term.
        eps: Small constant added to the variance for numerical stability.

    Returns:
        Scalar tensor holding the loss.

    Raises:
        ValueError: If the inputs are not 2-D or their shapes differ.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            f"expected two (batch, features) tensors of equal shape, "
            f"got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    batch_size = z1.shape[0]

    def _standardize(z):
        # Zero mean / unit (biased) variance per feature across the batch.
        # eps sits inside the sqrt so a constant feature yields finite grads.
        centered = z - z.mean(dim=0)
        variance = centered.pow(2).mean(dim=0)
        return centered / torch.sqrt(variance + eps)

    # Cross-correlation matrix between the features of the two views.
    c = (_standardize(z1).T @ _standardize(z2)) / batch_size

    diag = torch.diagonal(c)
    on_diag = (diag - 1).pow(2).sum()
    # Remove the diagonal out of place; c itself is never mutated.
    off_diag = (c - torch.diag(diag)).pow(2).sum()
    return on_diag + lambda_param * off_diag

# --- 4. Main Training Execution ---
if __name__ == '__main__':
    # Collect Python garbage and drop cached CUDA blocks before allocating.
    gc.collect()
    torch.cuda.empty_cache()
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")

    # Read the subset fraction once so the dataset and the banner below agree,
    # including when the key is omitted and the 1.0 default applies.
    subset_pct = CONFIG.get('data_subset_percentage', 1.0)
    dataset = SSL4EODataset(root_dir=CONFIG['data_path'], subset_percentage=subset_pct)
    loader = DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=CONFIG['num_workers'],
        # Pinned host memory is only requested when training on CUDA.
        pin_memory=(device.type == 'cuda'),
        # Only full batches are used; len(loader) below counts full batches.
        drop_last=True,
    )

    print("Using ResNet-18 as the encoder model.")
    resnet = models.resnet18(weights=None)
    encoder_output_dim = resnet.fc.in_features
    # Drop the classification head so the encoder returns pooled features.
    resnet.fc = nn.Identity()

    model = BarlowTwinsModel(
        encoder=resnet,
        encoder_dim=encoder_output_dim,
        projection_dim=CONFIG['projection_dim'],
        hidden_dim=CONFIG['hidden_dim'],
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay'],
    )
    # The scheduler is stepped once per batch, so the cosine period is the
    # total number of optimisation steps in the run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(loader) * CONFIG['epochs']
    )

    early_stopper = EarlyStopping(patience=CONFIG['early_stopping_patience'], verbose=True)
    # NOTE(review): the early stopper is created fresh on every run; its best
    # score is not part of the checkpoint, so after a resume the first epoch
    # always overwrites the "best" encoder file.
    start_epoch = load_checkpoint(model, optimizer, scheduler, CONFIG)
    best_encoder_path = os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename'])

    print(f"\n--- Starting Barlow Twins Training on {subset_pct*100:.0f}% of Data ---")
    for epoch in range(start_epoch, CONFIG['epochs']):
        model.train()
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
        for view1, view2 in pbar:
            view1, view2 = view1.to(device), view2.to(device)
            z1, z2 = model(view1, view2)
            loss = barlow_twins_loss_fn(z1, z2, lambda_param=CONFIG['lambda_param'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'LR': f"{optimizer.param_groups[0]['lr']:.6f}"})

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch+1} Summary: Average Loss = {avg_loss:.4f}")

        # Save the resumable state first, then let the stopper decide whether
        # this epoch also produced the best encoder so far. The monitored
        # value is the average training loss; there is no validation split.
        save_checkpoint(epoch, model, optimizer, scheduler, avg_loss, CONFIG)
        early_stopper(avg_loss, model, best_encoder_path)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    print("\n--- Training Finished ---")
    final_encoder_path = os.path.join(CONFIG['output_dir'], "barlow_final_encoder.pth")
    torch.save(model.encoder.state_dict(), final_encoder_path)
    print(f"Final encoder saved to {final_encoder_path}")
    print(f"Best encoder (lowest loss) saved to {best_encoder_path}")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.4/205.4 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m80.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.7/53.7 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hBarlow Twins Pre-training Script Initialized.
Using device: cuda
Pre-loading images into RAM... This may take a few minutes.


Loading files:   0%|          | 0/48 [00:00<?, ?it/s]


Dataset Initialized: Using 10% of data.
-> Pre-loaded 12,288 images into memory.
Using ResNet-18 as the encoder model.
No checkpoint found, starting from scratch.

--- Starting Barlow Twins Training on 10% of Data ---


Epoch 1/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 1 Summary: Average Loss = 1711.7849
Validation loss decreased (inf --> 1711.784856). Saving best model...


Epoch 2/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 2 Summary: Average Loss = 1578.3484
Validation loss decreased (1711.784856 --> 1578.348357). Saving best model...


Epoch 3/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 3 Summary: Average Loss = 1494.0904
Validation loss decreased (1578.348357 --> 1494.090424). Saving best model...


Epoch 4/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 4 Summary: Average Loss = 1485.7309
Validation loss decreased (1494.090424 --> 1485.730896). Saving best model...


Epoch 5/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 5 Summary: Average Loss = 1469.8250
Validation loss decreased (1485.730896 --> 1469.824993). Saving best model...


Epoch 6/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 6 Summary: Average Loss = 1453.7621
Validation loss decreased (1469.824993 --> 1453.762060). Saving best model...


Epoch 7/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 7 Summary: Average Loss = 1402.5623
Validation loss decreased (1453.762060 --> 1402.562266). Saving best model...


Epoch 8/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 8 Summary: Average Loss = 1374.0013
Validation loss decreased (1402.562266 --> 1374.001284). Saving best model...


Epoch 9/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 9 Summary: Average Loss = 1367.7070
Validation loss decreased (1374.001284 --> 1367.706961). Saving best model...


Epoch 10/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 10 Summary: Average Loss = 1372.5364
EarlyStopping counter: 1 out of 7


Epoch 11/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 11 Summary: Average Loss = 1368.9609
EarlyStopping counter: 2 out of 7


Epoch 12/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 12 Summary: Average Loss = 1362.2470
Validation loss decreased (1367.706961 --> 1362.246980). Saving best model...


Epoch 13/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 13 Summary: Average Loss = 1363.2346
EarlyStopping counter: 1 out of 7


Epoch 14/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 14 Summary: Average Loss = 1350.2438
Validation loss decreased (1362.246980 --> 1350.243837). Saving best model...


Epoch 15/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 15 Summary: Average Loss = 1309.2655
Validation loss decreased (1350.243837 --> 1309.265507). Saving best model...


Epoch 16/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 16 Summary: Average Loss = 1292.5524
Validation loss decreased (1309.265507 --> 1292.552408). Saving best model...


Epoch 17/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 17 Summary: Average Loss = 1264.4738
Validation loss decreased (1292.552408 --> 1264.473836). Saving best model...


Epoch 18/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 18 Summary: Average Loss = 1249.7422
Validation loss decreased (1264.473836 --> 1249.742236). Saving best model...


Epoch 19/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 19 Summary: Average Loss = 1224.1528
Validation loss decreased (1249.742236 --> 1224.152788). Saving best model...


Epoch 20/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 20 Summary: Average Loss = 1212.3201
Validation loss decreased (1224.152788 --> 1212.320133). Saving best model...


Epoch 21/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 21 Summary: Average Loss = 1196.7811
Validation loss decreased (1212.320133 --> 1196.781138). Saving best model...


Epoch 22/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 22 Summary: Average Loss = 1192.2698
Validation loss decreased (1196.781138 --> 1192.269826). Saving best model...


Epoch 23/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 23 Summary: Average Loss = 1185.5230
Validation loss decreased (1192.269826 --> 1185.523020). Saving best model...


Epoch 24/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 24 Summary: Average Loss = 1183.5818
Validation loss decreased (1185.523020 --> 1183.581792). Saving best model...


Epoch 25/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 25 Summary: Average Loss = 1176.7248
Validation loss decreased (1183.581792 --> 1176.724815). Saving best model...


Epoch 26/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 26 Summary: Average Loss = 1169.8698
Validation loss decreased (1176.724815 --> 1169.869798). Saving best model...


Epoch 27/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 27 Summary: Average Loss = 1160.4320
Validation loss decreased (1169.869798 --> 1160.431989). Saving best model...


Epoch 28/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 28 Summary: Average Loss = 1161.9384
EarlyStopping counter: 1 out of 7


Epoch 29/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 29 Summary: Average Loss = 1167.0406
EarlyStopping counter: 2 out of 7


Epoch 30/30:   0%|          | 0/96 [00:00<?, ?it/s]

Epoch 30 Summary: Average Loss = 1164.8102
EarlyStopping counter: 3 out of 7

--- Training Finished ---
Final encoder saved to /kaggle/working/barlow_final_encoder.pth
Best encoder (lowest loss) saved to /kaggle/working/barlow_best_encoder_10pct.pth
