In [1]:
# ==============================================================================
# 04b - RESUME BARLOW TWINS TRAINING (ON NEW ACCOUNT)
# ==============================================================================
# Purpose: To resume the Barlow Twins training from a checkpoint generated
#          by a different Kaggle account's notebook version.
# ==============================================================================

# --- 0. Imports ---
!pip install "zarr>=2.10.0" numcodecs -q
import torch, torch.nn as nn, torch.nn.functional as F, torchvision.transforms as T, torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np, zarr, random, math, os, copy, zipfile, tempfile, gc
from tqdm.notebook import tqdm

# --- 1. Configuration ---
CONFIG = dict(
    # --- Locations ---
    # Directory globbed for "*.zarr.zip" archives by SSL4EODataset.
    data_path="/kaggle/input/01-data-preparation/data/ssl4eo-s12/train/S2RGB",
    # Folder under /kaggle/input/ that holds the checkpoint to resume from.
    checkpoint_input_folder='04-barlow-twins-50pct',
    # Directory receiving the new checkpoint and the encoder weight files.
    output_dir="/kaggle/working/",

    # --- Optimisation ---
    epochs=30,
    batch_size=128,
    learning_rate=1e-3,
    weight_decay=1.5e-6,

    # --- Data / model shape ---
    image_size=224,        # target size passed to RandomResizedCrop
    projection_dim=2048,   # output width of the projector head
    hidden_dim=4096,       # width of the projector's hidden layers
    lambda_param=5e-3,     # weight of the off-diagonal term in the loss

    # --- Runtime ---
    num_workers=2,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    early_stopping_patience=10,
    # Fraction of the archive files sampled when building the dataset.
    data_subset_percentage=0.50,

    # --- File names ---
    checkpoint_filename="barlow_checkpoint_50pct.pth",
    best_model_filename="barlow_best_encoder_50pct.pth",
)

# --- 2. Helper Classes & Functions ---
class EarlyStopping:
    """Track a loss across epochs and flag when training should stop.

    Every call compares the new loss with the best one seen so far. An
    improvement saves ``model.encoder``'s weights to ``path`` and resets the
    counter; anything else increments the counter, and ``early_stop`` is set
    once the counter reaches ``patience``. A non-finite loss sets
    ``early_stop`` immediately without saving anything.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: When true, print counter and save messages.
        delta: Margin on the score (the negated loss); a new score below
            ``best_score + delta`` is treated as "no improvement".
    """

    def __init__(self, patience=7, verbose=True, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # math.inf instead of np.Inf: the np.Inf alias was removed in
        # NumPy 2.0 and raises AttributeError there.
        self.val_loss_min = math.inf

    def __call__(self, val_loss, model, path):
        """Record ``val_loss`` and update the stopping state.

        Args:
            val_loss: Loss value for the epoch that just finished.
            model: Object exposing an ``encoder`` whose ``state_dict()`` is
                saved when the loss improves.
            path: Destination handed to ``torch.save`` for the best encoder.
        """
        if not math.isfinite(val_loss):
            print("Loss is not finite, stopping.")
            self.early_stop = True
            return

        # Negate so that "higher score" means "better".
        score = -val_loss
        if self.best_score is None:
            # First finite loss seen: it is the best by definition.
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save the encoder weights and remember ``val_loss`` as the best."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving best model...')
        torch.save(model.encoder.state_dict(), path)
        self.val_loss_min = val_loss

def save_checkpoint(epoch, model, optimizer, scheduler, loss, config):
    """Write a resumable training checkpoint to the output directory.

    The file is first written under a temporary name and then moved over the
    final name, so an interrupted save cannot leave a truncated checkpoint in
    place of the previous good one.

    Args:
        epoch: Zero-based index of the epoch that just finished; ``epoch + 1``
            is stored so a resumed run starts at the following epoch.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        loss: Loss value recorded alongside the states.
        config: Mapping providing ``'output_dir'`` and ``'checkpoint_filename'``.
    """
    state = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
    }
    out_dir = config['output_dir']
    if out_dir:
        # Create the directory on demand instead of failing inside torch.save.
        os.makedirs(out_dir, exist_ok=True)
    final_path = os.path.join(out_dir, config['checkpoint_filename'])
    tmp_path = final_path + ".tmp"
    torch.save(state, tmp_path)
    # os.replace swaps the file in one step, overwriting any older checkpoint.
    os.replace(tmp_path, final_path)

def load_checkpoint(model, optimizer, scheduler, config):
    """Restore training state from a checkpoint under ``/kaggle/input``.

    Args:
        model: Module that receives the stored ``'model_state_dict'``.
        optimizer: Optimizer that receives ``'optimizer_state_dict'``.
        scheduler: Scheduler that receives ``'scheduler_state_dict'``.
        config: Mapping providing ``'checkpoint_input_folder'``,
            ``'checkpoint_filename'`` and ``'device'``.

    Returns:
        The ``'epoch'`` value stored in the checkpoint, or ``0`` when no
        checkpoint file exists at the expected location.
    """
    ckpt_path = Path(f"/kaggle/input/{config['checkpoint_input_folder']}/{config['checkpoint_filename']}")

    # Nothing to restore: report it and tell the caller via a zero epoch.
    if not ckpt_path.exists():
        print(f"❌ Checkpoint not found at {ckpt_path}. Cannot resume.")
        return 0

    print(f"✅ Resuming from checkpoint: {ckpt_path}")
    saved = torch.load(ckpt_path, map_location=config['device'])
    model.load_state_dict(saved['model_state_dict'])
    optimizer.load_state_dict(saved['optimizer_state_dict'])
    scheduler.load_state_dict(saved['scheduler_state_dict'])
    resume_epoch = saved['epoch']
    print(f"   -> Resuming training from the start of epoch {resume_epoch}")
    return resume_epoch

# --- 3. Dataset, Model, and Loss ---
class SSL4EODataset(Dataset):
    """In-memory dataset that yields two augmented views of each image.

    All ``*.zarr.zip`` archives found directly under ``root_dir`` (or a random
    sample of them) are unpacked once at construction time and their
    ``'bands'`` arrays are kept in RAM. Indexing returns the same image passed
    twice through the random augmentation pipeline.

    Args:
        root_dir: Directory containing the ``*.zarr.zip`` archives.
        subset_percentage: Fraction of the archive files to use. Values below
            1.0 select that share of files with ``random.sample``; otherwise
            every file is used.
    """

    def __init__(self, root_dir, subset_percentage=1.0):
        self.root_dir = Path(root_dir)

        # Sort first so the candidate list does not depend on directory order.
        archives = sorted(self.root_dir.glob("*.zarr.zip"))
        keep_count = int(len(archives) * subset_percentage)
        if subset_percentage < 1.0:
            self.zarr_files = random.sample(archives, keep_count)
        else:
            self.zarr_files = archives

        self.images = self._preload_images()
        self.total_images = len(self.images)
        print(f"\nDataset Initialized: Using {subset_percentage*100:.0f}% of data. -> Pre-loaded {self.total_images:,} images.")

        # Random augmentations; every call produces a different view.
        augmentations = [
            T.ToPILImage(),
            T.RandomResizedCrop(CONFIG['image_size'], antialias=True),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.4, 0.4, 0.2, 0.1),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(23, (0.1, 2.0)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = T.Compose(augmentations)

    def _preload_images(self):
        """Unpack every selected archive and return its images as a list."""
        print("Pre-loading images into RAM...")
        loaded = []
        for archive in tqdm(self.zarr_files, desc="Loading files"):
            # Each archive is extracted into a throw-away directory that is
            # removed again once its array has been read into memory.
            with tempfile.TemporaryDirectory() as scratch_dir:
                with zipfile.ZipFile(str(archive), 'r') as zipped:
                    zipped.extractall(scratch_dir)
                bands = zarr.open(scratch_dir, mode='r')['bands'][:]
                # Merge the two leading axes so each entry is a single image.
                per_image = bands.reshape(-1, *bands.shape[2:])
                loaded.extend(per_image)
        return loaded

    def __len__(self):
        """Return the number of pre-loaded images."""
        return self.total_images

    def __getitem__(self, idx):
        """Return two independently transformed views of image ``idx``."""
        # Move axis 0 to the end (presumably channels-first -> channels-last)
        # before handing the array to the transform pipeline.
        channels_last = np.transpose(self.images[idx], (1, 2, 0))
        first_view = self.transform(channels_last)
        second_view = self.transform(channels_last)
        return first_view, second_view

class BarlowTwinsModel(nn.Module):
    """Encoder followed by a three-layer MLP projector, applied to two views.

    Args:
        encoder: Module mapping an input batch to features of width
            ``encoder_dim``.
        encoder_dim: Feature width produced by ``encoder``.
        projection_dim: Output width of the projector.
        hidden_dim: Width of the projector's two hidden layers.
    """

    def __init__(self, encoder, encoder_dim, projection_dim, hidden_dim):
        super().__init__()
        self.encoder = encoder
        # Layer order fixes the state_dict keys (projector.0 ... projector.6),
        # so it must stay as is for existing checkpoints to load.
        projector_layers = [
            nn.Linear(encoder_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, projection_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

    def _embed(self, view):
        """Run one view through the encoder and then the projector."""
        features = self.encoder(view)
        return self.projector(features)

    def forward(self, v1, v2):
        """Return the projected embeddings of ``v1`` and ``v2``, in that order."""
        first = self._embed(v1)
        second = self._embed(v2)
        return first, second

def barlow_twins_loss_fn(z1, z2, lambda_param):
    """Compute the Barlow Twins loss for two batches of embeddings.

    Each embedding dimension is standardised over the batch axis, the
    cross-correlation matrix between the two batches is formed, and the loss
    is the squared distance of its diagonal from 1 plus ``lambda_param``
    times the sum of its squared off-diagonal entries.

    Args:
        z1: 2-D tensor of embeddings for the first view (batch on axis 0).
        z2: 2-D tensor of embeddings for the second view, same batch size.
        lambda_param: Weight applied to the off-diagonal term.

    Returns:
        A scalar tensor holding the loss.
    """
    def _standardize(z):
        # Small epsilon keeps the division finite for near-constant columns.
        return (z - z.mean(dim=0)) / (z.std(dim=0) + 1e-5)

    num_samples = z1.shape[0]
    cross_corr = (_standardize(z1).T @ _standardize(z2)) / num_samples

    # Diagonal entries should approach 1.
    diag_term = (torch.diagonal(cross_corr) - 1).pow(2).sum()

    # Off-diagonal entries should approach 0; blank the diagonal out-of-place.
    diag_mask = torch.eye(*cross_corr.shape, dtype=torch.bool, device=cross_corr.device)
    off_diag_term = cross_corr.masked_fill(diag_mask, 0).pow(2).sum()

    return diag_term + lambda_param * off_diag_term

# --- 4. Main Training Execution ---
if __name__ == '__main__':
    # Release unreferenced Python objects and cached CUDA memory before
    # building the dataset and model.
    gc.collect()
    torch.cuda.empty_cache()
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")

    dataset = SSL4EODataset(CONFIG['data_path'], CONFIG['data_subset_percentage'])
    # drop_last=True keeps every batch at the full batch size.
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True,
                        num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)

    print("Using ResNet-18 as the encoder model.")
    resnet = models.resnet18(weights=None)
    encoder_output_dim = resnet.fc.in_features
    # Replace the classifier head so the network outputs its pooled features.
    resnet.fc = nn.Identity()

    model = BarlowTwinsModel(
        encoder=resnet, encoder_dim=encoder_output_dim,
        projection_dim=CONFIG['projection_dim'], hidden_dim=CONFIG['hidden_dim']
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
    # The scheduler is stepped once per batch, hence T_max counts batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(loader)*CONFIG['epochs'])

    # Starts with no history: the checkpoint written by save_checkpoint does
    # not carry the early-stopping state of the previous run.
    early_stopper = EarlyStopping(patience=CONFIG['early_stopping_patience'], verbose=True)

    # This will load the state from the previous 12-hour run
    start_epoch = load_checkpoint(model, optimizer, scheduler, CONFIG)

    if start_epoch > 0:
        print(f"\n--- Resuming Barlow Twins Training from Epoch {start_epoch} ---")
        # start_epoch is the zero-based index of the first epoch still to run.
        for epoch in range(start_epoch, CONFIG['epochs']):
            model.train()
            total_loss = 0.0
            pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
            for view1, view2 in pbar:
                view1, view2 = view1.to(device), view2.to(device)
                z1, z2 = model(view1, view2)
                loss = barlow_twins_loss_fn(z1, z2, lambda_param=CONFIG['lambda_param'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                total_loss += loss.item()
                pbar.set_postfix({'Loss': f"{loss.item():.2f}", 'LR': f"{optimizer.param_groups[0]['lr']:.6f}"})

            avg_loss = total_loss / len(loader)
            print(f"Epoch {epoch+1} Summary: Average Loss = {avg_loss:.2f}")

            # Never overwrite the last good resume checkpoint with a diverged
            # (NaN/inf) state; the early stopper below halts the run instead.
            if math.isfinite(avg_loss):
                save_checkpoint(epoch, model, optimizer, scheduler, avg_loss, CONFIG)
            else:
                print("Average loss is not finite; keeping the previous checkpoint.")
            early_stopper(avg_loss, model, os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename']))
            if early_stopper.early_stop:
                print("Early stopping triggered.")
                break

        print("\n--- Training Finished ---")
        final_encoder_path = os.path.join(CONFIG['output_dir'], "barlow_final_encoder_50pct.pth")
        torch.save(model.encoder.state_dict(), final_encoder_path)
        print(f"Final encoder saved to {final_encoder_path}")
        print(f"Best encoder (lowest loss) saved to {os.path.join(CONFIG['output_dir'], CONFIG['best_model_filename'])}")

    else:
        print("Could not find a valid checkpoint to resume from. Halting execution.")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/205.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m204.8/205.4 kB[0m [31m7.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.4/205.4 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/8.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/8.8 MB[0m [31m58.5 MB/s[0m eta [36m0:00:01[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m8.4/8.8 MB[0m [31m132.6 MB/s[0m eta [36m0:00:01[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m8.4/8.8 MB[0m [31m132.6 MB/s[0m eta [36m0:00:01[0m

[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m8.4/8.8 MB[0m [31m132.6 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m8.8/8.8 MB[0m [31m58.0 MB/s[0m eta [36m0:00:01[0m

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m44.8 MB/s[0m eta [36m0:00:00[0m
[?25h

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.7/53.7 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h

Using device: cuda
Pre-loading images into RAM...


Loading files:   0%|          | 0/241 [00:00<?, ?it/s]


Dataset Initialized: Using 50% of data. -> Pre-loaded 61,696 images.
Using ResNet-18 as the encoder model.


✅ Resuming from checkpoint: /kaggle/input/04-barlow-twins-50pct/barlow_checkpoint_50pct.pth


   -> Resuming training from the start of epoch 8

--- Resuming Barlow Twins Training from Epoch 8 ---


Epoch 9/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 9 Summary: Average Loss = 663.90


Validation loss decreased (inf --> 663.899121). Saving best model...


Epoch 10/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 10 Summary: Average Loss = 619.46


Validation loss decreased (663.899121 --> 619.456833). Saving best model...


Epoch 11/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 11 Summary: Average Loss = 583.80


Validation loss decreased (619.456833 --> 583.804136). Saving best model...


Epoch 12/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 12 Summary: Average Loss = 557.73


Validation loss decreased (583.804136 --> 557.726988). Saving best model...


Epoch 13/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 13 Summary: Average Loss = 535.54


Validation loss decreased (557.726988 --> 535.536615). Saving best model...


Epoch 14/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 14 Summary: Average Loss = 515.72


Validation loss decreased (535.536615 --> 515.722460). Saving best model...


Epoch 15/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 15 Summary: Average Loss = 499.11


Validation loss decreased (515.722460 --> 499.110033). Saving best model...


Epoch 16/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 16 Summary: Average Loss = 484.36


Validation loss decreased (499.110033 --> 484.361633). Saving best model...


Epoch 17/30:   0%|          | 0/482 [00:00<?, ?it/s]

Epoch 17 Summary: Average Loss = 470.28


Validation loss decreased (484.361633 --> 470.281585). Saving best model...


Epoch 18/30:   0%|          | 0/482 [00:00<?, ?it/s]