### Training Resnet 50

In [1]:
# Imports & Setup
import os
import time
import h5py
import numpy as np
from glob import iglob
from PIL import Image
from tqdm import tqdm, trange

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torchvision.models import resnet50
from torch.utils.checkpoint import checkpoint
import torch_optimizer as optim_extra
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# Print the PyTorch version to verify environment setup
print("PyTorch version:", torch.__version__)

2025-06-23 08:53:23.200522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750668803.221341   19581 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750668803.227707   19581 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750668803.244076   19581 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750668803.244097   19581 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750668803.244099   19581 computation_placer.cc:177] computation placer alr

PyTorch version: 2.7.1+cu126


In [2]:
# Dataset class for loading TIFF images
class TiffRGBDataset(Dataset):
    """Images in TIFF Format"""
    def __init__(self, root_dir, max_samples=None, transform=None):
        # Expand user path and store transform
        root_dir = os.path.expanduser(root_dir)
        self.transform = transform
        # Create a pattern to find all .tif/.TIF files recursively
        pattern = os.path.join(root_dir, '**', '*.[tT][iI][fF]')
        # Count total files matching the pattern
        total = sum(1 for _ in iglob(pattern, recursive=True))
        # Optionally limit number of samples
        self.samples = sorted(iglob(pattern, recursive=True))[:max_samples] if max_samples else sorted(iglob(pattern, recursive=True))
        print(f"[DEBUG] {len(self.samples)} TIFF images loaded")  # Debug print

    def __len__(self):
        # Return number of samples
        return len(self.samples)

    def __getitem__(self, idx):
        # Open image, convert to RGB, and apply transform if provided
        img = Image.open(self.samples[idx]).convert('RGB')
        return self.transform(img) if self.transform else img


In [3]:
# Define transformation to resize images and convert to tensor
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize all images to 224x224
    transforms.ToTensor(),           # Convert PIL Image to PyTorch tensor
])

# Initialize the TIFF dataset with up to 100k samples
dataset = TiffRGBDataset(
    root_dir="data/s2_rgb/0k_251k_uint8_jpeg_tif/rgb",
    max_samples=100000,
    transform=transform
)

# Create DataLoader for batching and loading data
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

# Initialize accumulators for mean and standard deviation calculation
sum_ = torch.zeros(3)    # Sum of pixel values per channel
sum_sq = torch.zeros(3)  # Sum of squared pixel values per channel
cnt = 0                  # Total number of pixels processed

# Iterate through the dataset in batches
for batch in tqdm(loader, desc="Computing mean and std"):  # Progress bar for monitoring
    batch = batch.to(torch.float32)  # Ensure data is float32 for numerical stability
    b, c, h, w = batch.shape          # Batch size, channels, height, width
    # Accumulate sum of pixels and sum of squared pixels
    sum_ += batch.sum(dim=[0, 2, 3])         # Sum across batch, height, and width dims
    sum_sq += (batch ** 2).sum(dim=[0, 2, 3])  # Sum of squares for each channel
    cnt += b * h * w                           # Increment pixel count

# Compute mean and standard deviation for each channel
mean = sum_ / cnt
std = torch.sqrt(sum_sq / cnt - mean ** 2)

# Print the computed statistics
print(f"Mean: {mean.tolist()}")
print(f"Std:  {std.tolist()}")

KeyboardInterrupt: 

In [3]:
# Augmentation in for Barlow Twins
#Mean for subset RGB tensor([0.3601, 0.3573, 0.3337]) 
#STD for subset RGB tensor([0.2403, 0.2317, 0.2347])
#Mean: for 100k image set [0.4824180603027344, 0.4808058738708496, 0.47794070839881897]
# Std: for 100k image set [0.19021621346473694, 0.16879530251026154, 0.14623168110847473]


# Aus Deiner Ausgabe kopiert
mean_list = [0.4824180603027344, 0.4808058738708496, 0.47794070839881897]
std_list  = [0.19021621346473694, 0.16879530251026154, 0.14623168110847473]

# In Torch-Tensoren umwandeln
mean = torch.tensor(mean_list, dtype=torch.float32)
std  = torch.tensor(std_list,  dtype=torch.float32)

base = [
    T.RandomResizedCrop(224, scale=(0.08,1.0), interpolation=InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomApply([T.ColorJitter(0.8,0.8,0.8,0.2)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=int(0.1*224)|1)], p=0.5),
]

transform_1 = T.Compose(base + [
    T.ToTensor(),
    T.Normalize(mean, std),
])
transform_2 = T.Compose(base + [
    T.RandomSolarize(threshold=0.5, p=0.2),
    T.ToTensor(),
    T.Normalize(mean, std),
])

# TwoCropTransform for Barlow Twins
class TwoCropTransformBT:
    def __init__(self, t1, t2):
        self.t1 = t1
        self.t2 = t2
    def __call__(self, img):
        return self.t1(img), self.t2(img)

# Augmentation-Schritte visualisieren
def log_augment_steps(img: Image.Image, writer: SummaryWriter, base_transforms: list, step: int = 0):
    """
    Zeigt jeweils das Ergebnis nach jedem Basisschritt.
    """
    ops = [
        ("01_ResizeCrop", base_transforms[0]),
        ("02_HFlip",       base_transforms[1]),
        ("03_ColorJitter", base_transforms[2].transforms[0] if isinstance(base_transforms[2], T.RandomApply) else base_transforms[2]),
        ("04_Gray",        base_transforms[3]),
        ("05_GaussianBlur", base_transforms[4].transforms[0] if isinstance(base_transforms[4], T.RandomApply) else base_transforms[4]),
    ]
    x = img
    for name, op in ops:
        x = op(x)
        t = T.ToTensor()(x)
        writer.add_image(f"Augment/{name}", torchvision.utils.make_grid(t.unsqueeze(0), normalize=True), step)

In [4]:
# Funktion zum Erstellen des HDF5-Datensatzes

def prepare_h5(root_dir, out_path, max_samples, log_dir):
    writer = SummaryWriter(log_dir=os.path.join(log_dir, "prep"))
    os.makedirs(os.path.dirname(out_path), exist_ok=True)


    dataset = TiffRGBDataset(
        root_dir=root_dir,
        transform=TwoCropTransformBT(transform_1, transform_2),
        max_samples=max_samples
    )
    
    writer.add_text("Dataset/Info",
                    f"Root: {root_dir}\nSamples: {len(dataset)}\nTransforms: {base}",
                    global_step=0)

    # Original-Sample loggen
    sample_orig = Image.open(dataset.samples[min(1, len(dataset)-1)]).convert("RGB")
    orig_t = T.ToTensor()(sample_orig)
    writer.add_image("Dataset/OriginalSample",
                     torchvision.utils.make_grid(orig_t.unsqueeze(0), normalize=True),
                     global_step=0)

    log_augment_steps(sample_orig, writer, base, step=0)

    # Histogramm der Crop-Skalen
    scales = []
    for _ in range(100):
        _, _, h, _ = T.RandomResizedCrop.get_params(sample_orig, scale=(0.8,1.0), ratio=(1,1))
        scales.append(h/224)
    writer.add_histogram("Augment/ScaleDist", torch.tensor(scales), global_step=0)

    # HDF5 schreiben
    N = len(dataset)
    with h5py.File(out_path, "w") as f:
        d1 = f.create_dataset("view1", (N, 3, 224, 224), dtype="uint8")
        d2 = f.create_dataset("view2", (N, 3, 224, 224), dtype="uint8")
        
        for i in trange(N, desc="Schreibe HDF5"):
            x1, x2 = dataset[i]
            d1[i] = (x1.mul(255).byte().numpy())  # Skaliere zu [0, 255] und speichere als Byte
            d2[i] = (x2.mul(255).byte().numpy())  # Für beide Ansichten (view1 und view2)

    writer.close()
    print(f"HDF5 gespeichert: {out_path}")


In [5]:
# Modell, Loss-Funktion & Trainings-Loop

class HDF5Dataset(Dataset):
    def __init__(self, path):
        self.f = h5py.File(path, "r")
        self.v1, self.v2 = self.f["view1"], self.f["view2"]
    def __len__(self): return self.v1.shape[0]
    def __getitem__(self, idx):
        # Gibt das Bildpaar als Tensor zurück
        i1 = torch.from_numpy(self.v1[idx].astype(np.float32) / 255.)  # Normalisieren auf [0,1]
        i2 = torch.from_numpy(self.v2[idx].astype(np.float32) / 255.)       
        return i1, i2

class BarlowTwinsModel(nn.Module):
    def __init__(self, proj_dim=2048, hidden_dim=8192):
        super().__init__()
        # 1) Pretrained ResNet-50 als Encoder
        self.backbone = resnet50(pretrained=True)
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # 2) Projector mit 3 Schichten
        self.projector = nn.Sequential(
            nn.Linear(feat_dim,   hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, proj_dim,   bias=False),
            nn.BatchNorm1d(proj_dim),
        )

    def forward(self, x1, x2):
        return self.projector(self.backbone(x1)), self.projector(self.backbone(x2))

def off_diagonal(x):
    n,_ = x.shape
    return x.flatten()[:-1].view(n-1,n+1)[:,1:].flatten()

def barlow_twins_loss(z1, z2, lambda_offdiag=5e-3):
    # z1, z2: [B, D]
    B, D = z1.size()
    # 1) Standardisierung
    z1 = (z1 - z1.mean(0)) / z1.std(0)
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    # 2) Korrelationsmatrix
    C = (z1.T @ z2) / B   # [D, D]
    # 3) Loss
    diag_loss    = torch.sum((torch.diagonal(C) - 1) ** 2)
    offdiag_loss = torch.sum(C**2) - torch.sum(torch.diagonal(C)**2)
    return diag_loss + lambda_offdiag * offdiag_loss
    

def train(h5_path, log_dir, total_epochs=100, batch_size=32, accum_steps=4, lr=5e-5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)
    
    # Dataset laden
    ds = HDF5Dataset(h5_path)
    print(f"Anzahl der Samples im Dataset: {len(ds)}")  # Debugging-Ausgabe

    # Überprüfen, ob DataLoader korrekt initialisiert wurde
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True,
                        num_workers=8, pin_memory=True,
                        persistent_workers=True, prefetch_factor=2, drop_last=True)
    
    # Wenn der DataLoader leer ist, die Funktion verlassen
    if len(loader) == 0:
        print("Der DataLoader enthält keine Daten. Training kann nicht fortgesetzt werden.")
        return

    model = BarlowTwinsModel().to(device)

    # LARS-Optimizer
    opt = optim_extra.LARS(
        model.parameters(),
        lr=lr,       # Basis-LR skaliert mit Batch-Size / 256
        weight_decay=1e-6,
        momentum=0.9,
    )
    
    # Warmup und Cosine Learning Rate Scheduler
    warmup_epochs = 10
    
    sched = torch.optim.lr_scheduler.SequentialLR(
        opt,
        schedulers=[
            LinearLR(opt, start_factor=0.01, total_iters=warmup_epochs),
            CosineAnnealingLR(opt, T_max=total_epochs-warmup_epochs)
        ],
        milestones=[warmup_epochs]
    )

    scaler = GradScaler()
    writer = SummaryWriter(log_dir=os.path.join(log_dir, "train"))
    
    # Checkpoint-Verzeichnis anlegen
    ckpt_dir = os.path.join(log_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(total_epochs):
        t0, loss_acc = time.time(), 0
        model.train()
        opt.zero_grad()
        for batch_idx, (x1, x2) in enumerate(loader):
            x1, x2 = x1.to(device), x2.to(device)
            with autocast():
                z1, z2 = model(x1, x2)
                loss = barlow_twins_loss(z1, z2) / accum_steps
            scaler.scale(loss).backward()
            loss_acc += loss.item() * accum_steps

            if (batch_idx + 1) % accum_steps == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()

            step = epoch * len(loader) + batch_idx
            writer.add_scalar("Loss/train_batch", loss.item() * accum_steps, step)
            writer.add_scalar("LR", opt.param_groups[0]['lr'], step)

        sched.step()
        epoch_loss = loss_acc / len(loader) if len(loader) > 0 else 0
        writer.add_scalar("Loss/train", epoch_loss, epoch)
        writer.add_scalar("Time/epoch", time.time() - t0, epoch)

        if (epoch + 1) % 10 == 0 or (epoch + 1) == total_epochs:
            ckpt_path = os.path.join(ckpt_dir, f"barlow_epoch{epoch+1:03d}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': epoch_loss,
            }, ckpt_path)
            print(f"[CHECKPOINT] {ckpt_path} gespeichert")

    writer.close()


In [13]:
# Notebook-Parameter for subset

# data preprocessing and saving in HDF
mode = "prepare"            
root_dir    = "data/ssl4eo-s12-subset/rgb"
h5_path     = "data/ssl4eo-s12-subset_res50/augmented_dataset_subset1.h5"
logdir      = "runs/ssl4eo-s12-subset_res50"
max_samples = 100
epochs      = 100


In [26]:
# Notebook-Parameter for subset

# training
mode = "train"            
root_dir    = "data/ssl4eo-s12-subset/rgb"
h5_path     = "data/ssl4eo-s12-subset_res50/augmented_dataset_subset1.h5"
logdir      = "runs/ssl4eo-s12-subset_res50"
max_samples = 100
epochs= 100

In [39]:
# Notebook-Parameter for total RGB

# data preprocessing and saving in HDFmode

mode = "prepare"            
root_dir    = "data/s2_rgb/0k_251k_uint8_jpeg_tif/rgb"
h5_path     = "data/s2_rgb/augmented_dataset_100000_res50.h5"
logdir      = "runs/barlow_twins_ssl4eo_rgb_100000_res50"
max_samples = 100000
epochs      = 100

In [6]:
# Notebook-Parameter for total RGB

# training

mode = "train"            
root_dir    = "data/s2_rgb/0k_251k_uint8_jpeg_tif/rgb"
h5_path     = "data/s2_rgb/augmented_dataset_100000_res50.h5"
logdir      = "runs/barlow_twins_ssl4eo_rgb_100000_res50"
max_samples = 100000
epochs      = 100

In [4]:
# Ausführung

if mode == "prepare":
    prepare_h5(root_dir, h5_path, max_samples, logdir)
elif mode == "train":
    train(h5_path, logdir, total_epochs=epochs)
else:
    raise ValueError(f"Unbekannter Modus: {mode}")


device: cuda
Anzahl der Samples im Dataset: 100000


NameError: name 'resnet50' is not defined