# Training ResNet-18 with Barlow Twins

## 1. Imports

In [None]:
# --- Standard Libraries ---
import os
import time
import h5py  # For reading and writing HDF5 files
import numpy as np
from glob import iglob  # For finding files with glob patterns
from PIL import Image  # For image manipulation
from tqdm import tqdm, trange  # For progress bars

# --- PyTorch Core Libraries ---
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn  # For CUDA optimizations
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter  # For logging to TensorBoard
from torch.cuda.amp import GradScaler, autocast  # For mixed-precision training
from torchvision.models import resnet18
from torch.utils.checkpoint import checkpoint  # For gradient checkpointing

# --- Additional Libraries ---
import torch_optimizer as optim_extra  # For extra optimizers like LARS
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
from torchvision.transforms import InterpolationMode

# Report the installed PyTorch version in the cell output so the environment
# this notebook ran in can be checked at a glance.
print(f"PyTorch version: {torch.__version__}")

## 2. Configuration

In [None]:
# --- Configuration ---
# `mode` selects what the execution cell at the end of the notebook runs:
#   "prepare" -> prepare_h5(): augment the raw TIFFs and cache both views in an HDF5 file
#   "train"   -> train(): train Barlow Twins from that cached HDF5 file
# Every other setting is shared by both modes, so switching only requires editing `mode`.
mode = "train"  # currently active; set to "prepare" to (re)build the HDF5 file

# --- Paths ---
root_dir = "data/s2_rgb/0k_251k_uint8_jpeg_tif/rgb"  # Directory containing the source TIFF images
h5_path = "data/s2_rgb/augmented_dataset_100000_res18.h5"  # HDF5 file holding the augmented views
logdir = "runs/barlow_twins_ssl4eo_rgb_100000_res18"  # Base directory for TensorBoard logs

# --- Sizes ---
max_samples = 100_000  # Upper bound on the number of images taken from root_dir
epochs = 100  # Number of training epochs

## 3. Dataset Definitions

In [None]:
class TiffRGBDataset(Dataset):
    """A PyTorch Dataset for loading TIFF images and converting them to RGB.

    Files matching ``*.tif`` (any letter case) are collected recursively below
    ``root_dir`` and kept in sorted path order, so indices are reproducible.
    """

    def __init__(self, root_dir, max_samples=None, transform=None):
        """
        Args:
            root_dir (str): The root directory containing the TIFF images.
            max_samples (int, optional): Maximum number of samples to load. Defaults to None (load all).
            transform (callable, optional): A function/transform to apply to the images. Defaults to None.

        Raises:
            ValueError: If ``max_samples`` is negative.
        """
        if max_samples is not None and max_samples < 0:
            # A negative bound would silently drop files from the END of the list via slicing.
            raise ValueError(f"max_samples must be non-negative, got {max_samples}")

        # Expand user path (e.g., '~') and store the transform
        root_dir = os.path.expanduser(root_dir)
        self.transform = transform

        # Create a pattern to find all .tif/.TIF files recursively
        pattern = os.path.join(root_dir, '**', '*.[tT][iI][fF]')

        # Find all samples, sort them, and optionally limit the number of samples.
        # Compare against None explicitly so that max_samples=0 means "no samples"
        # rather than falling through to "all samples".
        all_samples = sorted(iglob(pattern, recursive=True))
        self.samples = all_samples if max_samples is None else all_samples[:max_samples]

        print(f"[DEBUG] Found {len(all_samples)} total images. Loading {len(self.samples)} samples.")

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Fetches the image at the given index, converts it to RGB, and applies the transform."""
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of staying open until garbage collection.
        with Image.open(self.samples[idx]) as raw:
            img = raw.convert('RGB')

        # Apply the transform if it is provided
        if self.transform:
            return self.transform(img)
        return img

In [None]:
# --- HDF5 Dataset ---
class HDF5Dataset(Dataset):
    """A PyTorch Dataset for loading image pairs from an HDF5 file.

    The file must contain two datasets named ``view1`` and ``view2``. Item
    ``idx`` is the pair ``(view1[idx], view2[idx])`` converted to float32 and
    divided by 255.

    The HDF5 handle is opened lazily on first item access rather than in
    ``__init__``. With ``DataLoader(num_workers > 0)`` the dataset object is
    copied into every worker process, and an h5py handle opened in the parent
    is not safe to share between processes; opening on demand gives each
    process its own handle.
    """

    def __init__(self, path):
        """
        Args:
            path (str): Path to the HDF5 file holding ``view1`` and ``view2``.
        """
        self.path = path
        # Populated by _ensure_open() in whichever process first reads an item.
        self.f = None
        self.v1 = None
        self.v2 = None

        # Read the length once through a short-lived handle so that no open
        # file handle is carried into DataLoader worker processes.
        with h5py.File(path, "r") as f:
            self._length = f["view1"].shape[0]

    def _ensure_open(self):
        """Open the HDF5 file in the current process if it is not open yet."""
        if self.f is None:
            self.f = h5py.File(self.path, "r")
            self.v1, self.v2 = self.f["view1"], self.f["view2"]

    def __getstate__(self):
        """Drop open handles when pickled (e.g. for spawned DataLoader workers)."""
        state = self.__dict__.copy()
        state["f"] = state["v1"] = state["v2"] = None
        return state

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        self._ensure_open()
        # Load the image pair and normalize to [0, 1]
        i1 = torch.from_numpy(self.v1[idx].astype(np.float32) / 255.)
        i2 = torch.from_numpy(self.v2[idx].astype(np.float32) / 255.)
        return i1, i2

## 4. Data Preparation & Augmentation

In [None]:
# This cell calculates the per-channel mean and standard deviation of the dataset.
# These statistics are useful for normalization.

# Define a simple transformation to resize images and convert them to a tensor
transform = T.Compose([
    T.Resize((224, 224)),  # Ensure all images have the same size
    T.ToTensor(),          # Convert PIL Image to a PyTorch tensor in [0, 1]
])

# Initialize the dataset with a large number of samples for accurate statistics
dataset = TiffRGBDataset(
    root_dir="data/s2_rgb/0k_251k_uint8_jpeg_tif/rgb",
    max_samples=100000,  # Using 100k images for calculation
    transform=transform,
)

# Create a DataLoader to iterate over the dataset in batches
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,     # No need to shuffle for calculating statistics
    num_workers=0,     # Set to 0 for this calculation to avoid multi-processing issues
    pin_memory=False,  # Batches never leave the CPU here, so pinning would only cost memory
)

# Accumulate in float64: the sums run over billions of pixels, and float32
# accumulators lose enough precision to distort E[x^2] - E[x]^2.
sum_ = torch.zeros(3, dtype=torch.float64)    # Sum of pixel values for each of the 3 (RGB) channels
sum_sq = torch.zeros(3, dtype=torch.float64)  # Sum of squared pixel values for each channel
num_pixels = 0                                # Number of pixels processed per channel

# Iterate through the dataset with a progress bar
for batch in tqdm(loader, desc="Computing mean and std"):
    batch = batch.to(torch.float32)
    b, _, h, w = batch.shape  # Batch size, channels (unused), height, width

    # Reduce over batch, height and width; dtype= makes the reduction itself run in float64
    sum_ += batch.sum(dim=[0, 2, 3], dtype=torch.float64)
    sum_sq += (batch ** 2).sum(dim=[0, 2, 3], dtype=torch.float64)
    num_pixels += b * h * w

if num_pixels == 0:
    raise RuntimeError("No images were loaded; cannot compute dataset statistics.")

# Calculate the mean and standard deviation.
# Rounding can push the variance marginally below zero, which would turn sqrt into NaN.
mean = sum_ / num_pixels
variance = torch.clamp(sum_sq / num_pixels - mean ** 2, min=0.0)
std = torch.sqrt(variance)

# Hand the results back as float32, the dtype used by the rest of the notebook
mean = mean.to(torch.float32)
std = std.to(torch.float32)

# Print the final results
print(f"Computed Mean: {mean.tolist()}")
print(f"Computed Std:  {std.tolist()}")

In [None]:
# This cell defines the data augmentations for Barlow Twins.

# --- Augmentation Setup ---
# The following mean and std values were pre-calculated on a 100k image subset.
mean_list = [0.4824, 0.4808, 0.4779]
std_list  = [0.1902, 0.1688, 0.1462]

# Convert lists to PyTorch tensors
mean = torch.tensor(mean_list, dtype=torch.float32)
std  = torch.tensor(std_list,  dtype=torch.float32)

# Define the base set of transformations applied to both views.
# All of them operate on (and return) PIL images; ToTensor comes afterwards.
base_transforms = [
    T.RandomResizedCrop(224, scale=(0.08, 1.0), interpolation=InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomApply([T.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=23)], p=0.5), # kernel_size must be odd
]

# Create the first transformation pipeline
transform_1 = T.Compose(base_transforms + [
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# Create the second transformation pipeline, adding RandomSolarize
transform_2 = T.Compose(base_transforms + [
    # Solarize runs before ToTensor, i.e. on a PIL image whose pixels span 0-255,
    # so the threshold has to be on that scale. A threshold of 0.5 would invert
    # every non-zero pixel, turning solarization into a plain colour inversion.
    T.RandomSolarize(threshold=128, p=0.2),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# --- Two-Crop Transform for Barlow Twins ---
class TwoCropTransformBT:
    """Callable that turns one input into a pair of views.

    The same input is passed to ``t1`` and then to ``t2``; the two results are
    returned together as ``(t1(img), t2(img))``.
    """

    def __init__(self, t1, t2):
        """Store the two callables used to build the first and second view."""
        self.t1, self.t2 = t1, t2

    def __call__(self, img):
        """Run both stored callables on ``img`` and return their outputs as a tuple."""
        first_view = self.t1(img)
        second_view = self.t2(img)
        return first_view, second_view

# --- Visualization Function ---
def log_augment_steps(img: Image.Image, writer: SummaryWriter, base_transforms: list, step: int = 0):
    """Write the image produced after each of the first five base transforms to TensorBoard.

    The transforms are chained: every op receives the output of the previous
    one. Each intermediate result is logged under ``Augment/<name>``.

    Args:
        img: Starting image fed to the first transform.
        writer: TensorBoard writer receiving the images.
        base_transforms: Transform list; entries 0-4 are used. Entries 2 and 4
            are replaced by their first inner transform when they are a
            ``T.RandomApply`` wrapper.
        step: Global step recorded with every logged image.
    """
    def _inner(op):
        # Look inside a RandomApply wrapper and use its first transform directly.
        if isinstance(op, T.RandomApply):
            return op.transforms[0]
        return op

    tags = ("01_ResizeCrop", "02_HFlip", "03_ColorJitter", "04_Gray", "05_GaussianBlur")
    unwrapped_positions = (2, 4)

    # Resolve every (tag, op) pair up front, before any image is processed.
    stages = [
        (tag, _inner(base_transforms[pos]) if pos in unwrapped_positions else base_transforms[pos])
        for pos, tag in enumerate(tags)
    ]

    to_tensor = T.ToTensor()
    current = img
    for tag, op in stages:
        current = op(current)
        # make_grid expects a batch dimension, hence the unsqueeze
        grid = torchvision.utils.make_grid(to_tensor(current).unsqueeze(0), normalize=True)
        writer.add_image(f"Augment/{tag}", grid, step)

In [None]:
def _normalization_stats(pipeline):
    """Return (mean, std) of the last T.Normalize in `pipeline`, shaped (C, 1, 1), or None if absent."""
    for step in reversed(list(getattr(pipeline, "transforms", []))):
        if isinstance(step, T.Normalize):
            mean_c = torch.as_tensor(step.mean, dtype=torch.float32).view(-1, 1, 1)
            std_c = torch.as_tensor(step.std, dtype=torch.float32).view(-1, 1, 1)
            return mean_c, std_c
    return None


def _to_uint8_image(x, stats):
    """Convert an augmented float tensor to a uint8 array in [0, 255], undoing normalization first."""
    if stats is not None:
        mean_c, std_c = stats
        x = x * std_c + mean_c
    # Clamp before quantizing: casting out-of-range floats to uint8 wraps around.
    return x.clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8).numpy()


def prepare_h5(root_dir, out_path, max_samples, log_dir):
    """
    Preprocesses the dataset by applying augmentations and saving the results
    into an HDF5 file for faster loading during training.

    Two datasets, ``view1`` and ``view2``, of shape ``(N, 3, 224, 224)`` and
    dtype uint8 are written. The views are stored as plain images in [0, 255]:
    any ``T.Normalize`` at the end of ``transform_1`` / ``transform_2`` is
    inverted before quantization, because mean/std-normalized values fall
    outside [0, 1] and cannot be represented as bytes. A reader that divides by
    255 therefore gets images in [0, 1] and must apply normalization itself if
    it needs it.

    Args:
        root_dir (str): Directory searched for source TIFF images.
        out_path (str): Destination HDF5 file (overwritten if it exists).
        max_samples (int or None): Upper bound on the number of images used.
        log_dir (str): Base directory; TensorBoard events go to ``<log_dir>/prep``.

    Raises:
        ValueError: If no images are found under ``root_dir``.
    """
    # dirname is "" for a bare file name, and os.makedirs("") would fail
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Initialize the dataset with the two-crop transform
    dataset = TiffRGBDataset(
        root_dir=root_dir,
        transform=TwoCropTransformBT(transform_1, transform_2),
        max_samples=max_samples
    )
    N = len(dataset)
    if N == 0:
        raise ValueError(f"No TIFF images found under {root_dir!r}; nothing to write.")

    stats_1 = _normalization_stats(transform_1)
    stats_2 = _normalization_stats(transform_2)

    # Create a SummaryWriter for TensorBoard logging; closed in `finally` even on failure
    writer = SummaryWriter(log_dir=os.path.join(log_dir, "prep"))
    try:
        # Log dataset information to TensorBoard
        writer.add_text("Dataset/Info",
                        f"Root: {root_dir}\nSamples: {len(dataset)}\nTransforms: {base_transforms}",
                        global_step=0)

        # Log an original sample image for reference (the second file, or the first if there is only one)
        with Image.open(dataset.samples[min(1, N - 1)]) as raw:
            sample_orig = raw.convert("RGB")
        orig_t = T.ToTensor()(sample_orig)
        writer.add_image("Dataset/OriginalSample",
                         torchvision.utils.make_grid(orig_t.unsqueeze(0), normalize=True),
                         global_step=0)

        # Log the augmentation steps for visualization
        log_augment_steps(sample_orig, writer, base_transforms, step=0)

        # Log a histogram of the crop scales. Sample with the scale/ratio of the
        # crop transform that is actually applied, and record the fraction of the
        # source image area each sampled crop covers.
        crop_tf = base_transforms[0]
        source_area = sample_orig.width * sample_orig.height
        scales = []
        for _ in range(100):
            _, _, h, w = T.RandomResizedCrop.get_params(
                sample_orig, scale=crop_tf.scale, ratio=crop_tf.ratio
            )
            scales.append(h * w / source_area)
        writer.add_histogram("Augment/ScaleDist", torch.tensor(scales), global_step=0)

        # --- Write to HDF5 file ---
        with h5py.File(out_path, "w") as f:
            # Create datasets for the two augmented views
            d1 = f.create_dataset("view1", (N, 3, 224, 224), dtype="uint8")
            d2 = f.create_dataset("view2", (N, 3, 224, 224), dtype="uint8")

            # Iterate through the dataset and save the augmented views
            for i in trange(N, desc="Writing to HDF5"):
                x1, x2 = dataset[i]
                # Store as bytes to save space
                d1[i] = _to_uint8_image(x1, stats_1)
                d2[i] = _to_uint8_image(x2, stats_2)
    finally:
        writer.close()

    print(f"HDF5 file saved to: {out_path}")

## 5. Model and Loss Function

In [None]:
# --- Barlow Twins Model ---
class BarlowTwinsModel(nn.Module):
    """Pretrained ResNet-18 encoder followed by an MLP projector, applied to two views."""

    def __init__(self, proj_dim=2048, hidden_dim=8192):
        """
        Args:
            proj_dim (int): Output width of the projector.
            hidden_dim (int): Width of the two hidden projector layers.
        """
        super().__init__()

        # 1) Backbone: pretrained ResNet-18 with its classification head replaced by a pass-through
        encoder = resnet18(pretrained=True)
        feat_dim = encoder.fc.in_features
        encoder.fc = nn.Identity()
        self.backbone = encoder

        # 2) Projector: three bias-free Linear layers, each followed by BatchNorm1d,
        #    with a ReLU after every stage except the last.
        widths = [feat_dim, hidden_dim, hidden_dim, proj_dim]
        final_stage = len(widths) - 2
        stages = []
        for stage, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(nn.Linear(d_in, d_out, bias=False))
            stages.append(nn.BatchNorm1d(d_out))
            if stage != final_stage:
                stages.append(nn.ReLU(inplace=True))
        self.projector = nn.Sequential(*stages)

    def forward(self, x1, x2):
        """Return ``(z1, z2)``: each view run through the backbone and then the projector."""
        return tuple(self.projector(self.backbone(view)) for view in (x1, x2))

# --- Loss Function ---
def off_diagonal(x):
    """Returns the off-diagonal elements of a square matrix as a flat tensor (row-major order)."""
    n, _ = x.shape
    # Dropping the last element and viewing the rest as (n - 1, n + 1) lines every
    # diagonal entry up in column 0, so slicing that column away leaves the rest.
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def barlow_twins_loss(z1, z2, lambda_offdiag=5e-3, eps=1e-6):
    """Calculates the Barlow Twins loss.

    Args:
        z1: Projections of the first view, shape (B, D).
        z2: Projections of the second view, shape (B, D).
        lambda_offdiag (float): Weight of the redundancy (off-diagonal) term.
        eps (float): Added to the per-dimension std so that a dimension with
            zero variance yields 0 instead of NaN.

    Returns:
        Scalar tensor: sum of squared (diagonal - 1) plus ``lambda_offdiag``
        times the sum of squared off-diagonal entries of the cross-correlation
        matrix.
    """
    B, D = z1.size()  # Batch size, Dimensionality

    # 1) Standardize each dimension over the batch. The population (biased) std is
    #    used because the product below is divided by B: with the unbiased estimator
    #    the diagonal could never exceed (B - 1) / B, even for identical views.
    z1_norm = (z1 - z1.mean(0)) / (z1.std(0, unbiased=False) + eps)
    z2_norm = (z2 - z2.mean(0)) / (z2.std(0, unbiased=False) + eps)

    # 2) Compute the (D, D) cross-correlation matrix
    C = (z1_norm.T @ z2_norm) / B

    # 3) Calculate the loss
    on_diag = torch.diagonal(C)
    invariance_loss = ((on_diag - 1)**2).sum()
    redundancy_loss = (off_diagonal(C)**2).sum()
    return invariance_loss + lambda_offdiag * redundancy_loss

## 6. Training

In [None]:
# --- Training Function ---
def train(h5_path, log_dir, total_epochs=100, batch_size=32, accum_steps=4, lr=5e-5):
    """The main training loop.

    Args:
        h5_path (str): HDF5 file read through ``HDF5Dataset``.
        log_dir (str): Base directory; TensorBoard events go to ``<log_dir>/train``
            and checkpoints to ``<log_dir>/checkpoints``.
        total_epochs (int): Number of epochs to run.
        batch_size (int): Number of image pairs per forward/backward pass.
        accum_steps (int): Number of batches whose gradients are accumulated
            before each optimizer step.
        lr (float): Learning rate passed to the LARS optimizer.

    Raises:
        ValueError: If ``accum_steps`` is smaller than 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Mixed precision (and pinned host memory) only make sense with a GPU
    use_amp = device.type == "cuda"

    # Load the HDF5 dataset
    ds = HDF5Dataset(h5_path)
    print(f"Number of samples in dataset: {len(ds)}")

    # Create the DataLoader
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True,
                        num_workers=8, pin_memory=use_amp,
                        persistent_workers=True, prefetch_factor=2, drop_last=True)

    num_batches = len(loader)
    if num_batches == 0:
        print("DataLoader is empty. Cannot proceed with training.")
        return

    # Initialize the model and move it to the device
    model = BarlowTwinsModel().to(device)

    # LARS Optimizer
    optimizer = optim_extra.LARS(
        model.parameters(),
        lr=lr,  # Base LR is scaled by batch_size / 256 in the paper
        weight_decay=1e-6,
        momentum=0.9,
    )

    # Learning rate scheduler with warmup (stepped once per epoch)
    warmup_epochs = 10
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs),
            # Keep T_max positive even when total_epochs <= warmup_epochs
            CosineAnnealingLR(optimizer, T_max=max(1, total_epochs - warmup_epochs))
        ],
        milestones=[warmup_epochs]
    )

    # GradScaler for mixed-precision training; a disabled scaler is a transparent pass-through
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Create checkpoint directory
    ckpt_dir = os.path.join(log_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=os.path.join(log_dir, "train"))
    try:
        # --- Training Loop ---
        for epoch in range(total_epochs):
            start_time = time.time()
            loss_accumulator = 0.0
            model.train()
            optimizer.zero_grad()

            for batch_idx, (x1, x2) in enumerate(loader):
                x1 = x1.to(device, non_blocking=True)
                x2 = x2.to(device, non_blocking=True)

                # Forward pass with automatic mixed precision.
                # autocast needs the device type spelled out; it has no default.
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    z1, z2 = model(x1, x2)
                    loss = barlow_twins_loss(z1, z2) / accum_steps

                # Backward pass
                scaler.scale(loss).backward()
                batch_loss = loss.item() * accum_steps  # undo the accumulation scaling for reporting
                loss_accumulator += batch_loss

                # Gradient accumulation. Also step on the last batch of the epoch so
                # that a trailing, incomplete accumulation group is applied instead of
                # being wiped by the zero_grad() at the start of the next epoch.
                is_last_batch = (batch_idx + 1) == num_batches
                if (batch_idx + 1) % accum_steps == 0 or is_last_batch:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                # Log batch-level stats
                step = epoch * num_batches + batch_idx
                writer.add_scalar("Loss/train_batch", batch_loss, step)
                writer.add_scalar("LR", optimizer.param_groups[0]['lr'], step)

            # --- End of Epoch ---
            scheduler.step()
            epoch_loss = loss_accumulator / num_batches
            writer.add_scalar("Loss/train_epoch", epoch_loss, epoch)
            writer.add_scalar("Time/epoch", time.time() - start_time, epoch)
            print(f"Epoch {epoch+1}/{total_epochs}, Loss: {epoch_loss:.4f}")

            # Save checkpoint
            if (epoch + 1) % 10 == 0 or (epoch + 1) == total_epochs:
                ckpt_path = os.path.join(ckpt_dir, f"barlow_epoch_{epoch+1:03d}.pt")
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Needed to resume with the same LR schedule and loss scale
                    'scheduler_state_dict': scheduler.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'loss': epoch_loss,
                }, ckpt_path)
                print(f"[CHECKPOINT] Saved to {ckpt_path}")
    finally:
        # Flush and close the event file even if training is interrupted
        writer.close()

## 7. Execution

In [None]:
# --- Execution ---
# Dispatch on the 'mode' variable set in the configuration cell:
# anything other than the two supported values is rejected up front.
if mode not in ("prepare", "train"):
    raise ValueError(f"Unknown mode: {mode}")

if mode == "train":
    # Start the training process from the HDF5 file
    train(h5_path, logdir, total_epochs=epochs)
else:
    # mode == "prepare": run the HDF5 creation process
    prepare_h5(root_dir, h5_path, max_samples, logdir)