In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import random
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torchvision.transforms.functional as TF
import torchvision.transforms as T


from dataset import Sentinel2InpaintingDataset


In [2]:
# Root directory handed to the dataset as `root_dir`. Set the S2A_DATA_ROOT
# environment variable to use a different location; when it is unset the
# original local path is used, so existing runs behave exactly as before.
data_root = os.environ.get('S2A_DATA_ROOT', 'D:/s2a.tar/s2a')

# NOTE(review): the original comment asked for a target_size divisible by 32
# (e.g. (256, 256) or (384, 384)) -- confirm against the model's requirements.
dataset = Sentinel2InpaintingDataset(
    root_dir=data_root,
    mask_type='random',
    augment=True,
    limit_samples=2000,
    target_size=(256, 256)
)

  8%|▊         | 393/4640 [00:16<02:58, 23.77it/s]

Found 2000 samples with all 12 bands





In [3]:
# Training batches: 4 samples each, reshuffled every epoch, produced by
# 4 worker processes and placed in pinned host memory.
trainLoader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class InfoNCELoss(nn.Module):
    """
    InfoNCE contrastive loss for SASSL.

    Each student view is compared against the teacher view of every image in
    the batch. The teacher embedding of the same image (same batch index) is
    the positive; teacher embeddings of the other images are the negatives.

    Args:
        temperature: positive scaling factor applied to the cosine
            similarities before the softmax.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    def __init__(self, temperature=0.07):
        super().__init__()
        # A zero temperature divides by zero and a negative one flips the
        # objective, so reject both up front.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature

    def forward(self, teacher_output, student_outputs):
        """
        Args:
            teacher_output: (batch_size, embed_dim) - teacher embeddings
            student_outputs: list of (batch_size, embed_dim) student
                embeddings, one entry per view. A single 2-D tensor is also
                accepted and treated as one view.

        Returns:
            loss: scalar tensor, the mean cross-entropy over all student views

        Raises:
            ValueError: if ``student_outputs`` contains no views.
        """
        # Accept a bare (batch_size, embed_dim) tensor as a single view;
        # iterating it directly would walk over rows instead of views.
        if isinstance(student_outputs, torch.Tensor) and student_outputs.dim() == 2:
            student_outputs = [student_outputs]
        student_outputs = list(student_outputs)
        if len(student_outputs) == 0:
            raise ValueError("student_outputs must contain at least one view")

        # Stop-gradient on the teacher: it only provides targets here, so no
        # gradient should be propagated into it. Forward values are unchanged.
        teacher_output = F.normalize(teacher_output.detach(), dim=-1, p=2)

        # Labels: diagonal elements are positives (same image index),
        # off-diagonal elements are negatives. Identical for every view, so
        # they are built once.
        batch_size = teacher_output.shape[0]
        labels = torch.arange(batch_size, device=teacher_output.device)

        total_loss = 0
        for student_output in student_outputs:
            student_output = F.normalize(student_output, dim=-1, p=2)

            # logits[i, j] = cosine similarity between student_i and teacher_j,
            # scaled by the temperature: (batch_size, batch_size)
            logits = torch.matmul(student_output, teacher_output.T) / self.temperature

            # Cross-entropy pushes logits[i, i] up relative to logits[i, j], j != i
            total_loss += F.cross_entropy(logits, labels)

        # Average over however many student views were supplied
        return total_loss / len(student_outputs)


In [None]:
import torch
import torch.optim as optim
from tqdm import tqdm
from SASSL import SASSL

# Hyperparameters
epochs = 100
learning_rate = 1e-4
momentum_teacher = 0.996   # momentum passed to model.update_teacher after warmup
warmup_epochs = 10         # epochs over which that momentum ramps up from 0
weight_decay = 0.04
checkpoint_every = 5       # save a periodic checkpoint every N epochs

# Initialize model and loss
model = SASSL(
    random_crop_size=(64, 64),
    drop_probability=0.4
)

criterion = InfoNCELoss(temperature=0.07)

# Optimizer (only student parameters; the teacher is updated through
# model.update_teacher in the loop below)
optimizer = optim.AdamW(
    model.student.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)

# Learning rate scheduler (stepped once per epoch)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=1e-6
)


def _checkpoint_state(epoch_number, loss_key, loss_value):
    """Collect student/teacher/optimizer/scheduler state into one dict for torch.save."""
    return {
        'epoch': epoch_number,
        'student_state_dict': model.student.state_dict(),
        'teacher_state_dict': model.teacher.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        loss_key: loss_value,
    }


# Training loop
model.student.train()
model.teacher.eval()

# Defined up front so the final save below still works when epochs == 0.
avg_loss = float('nan')

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0

    # Momentum warmup: gradually increase from 0 to target momentum.
    # It depends only on the epoch, so it is computed once per epoch.
    if epoch < warmup_epochs:
        m = momentum_teacher * (epoch / warmup_epochs)
    else:
        m = momentum_teacher

    pbar = tqdm(trainLoader, desc=f'Epoch {epoch+1}/{epochs}')

    for batch in pbar:
        images = batch['original']

        # Skip if batch size < 2 (need negatives)
        if images.shape[0] < 2:
            continue

        # Forward pass (images moved to device inside model)
        teacher_outputs, student_outputs = model(images)

        # Compute contrastive loss
        loss = criterion(teacher_outputs, student_outputs)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (prevent exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.student.parameters(), max_norm=1.0)

        # Update student
        optimizer.step()

        # Update teacher with EMA
        model.update_teacher(momentum=m)

        # Track metrics; read the loss value once and reuse it
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'avg_loss': f'{epoch_loss/num_batches:.4f}',
            'lr': f'{optimizer.param_groups[0]["lr"]:.6f}',
            'momentum': f'{m:.4f}'
        })

    # Every batch was skipped (or the loader was empty): nothing was trained,
    # so fail with a clear message instead of a ZeroDivisionError.
    if num_batches == 0:
        raise RuntimeError(
            f"Epoch {epoch+1}: no batch with at least 2 samples was produced; "
            "cannot compute an average loss"
        )

    # Calculate average epoch loss
    avg_loss = epoch_loss / num_batches

    # Step scheduler
    scheduler.step()

    # Print epoch summary.
    # Note: the learning rate is read after scheduler.step(), so it is the
    # value that will be used in the next epoch.
    print(f"\nEpoch {epoch+1}/{epochs} Summary:")
    print(f"  Average Loss: {avg_loss:.4f}")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    print("-" * 60)

    # Save checkpoint every `checkpoint_every` epochs
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint = _checkpoint_state(epoch + 1, 'loss', avg_loss)
        torch.save(checkpoint, f'sassl_checkpoint_epoch_{epoch+1}.pth')
        print(f"✓ Checkpoint saved: sassl_checkpoint_epoch_{epoch+1}.pth\n")

# Save final model (now also carries the scheduler state, like the periodic checkpoints)
final_checkpoint = _checkpoint_state(epochs, 'final_loss', avg_loss)
torch.save(final_checkpoint, 'sassl_final_model.pth')
print("\n" + "="*60)
print("Training completed!")
print("Final model saved: sassl_final_model.pth")
print("="*60)






Epoch 1/100: 100%|██████████| 500/500 [05:04<00:00,  1.64it/s, loss=0.5265, avg_loss=0.9851, lr=0.000100, momentum=0.0000]



Epoch 1/100 Summary:
  Average Loss: 0.9851
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 2/100: 100%|██████████| 500/500 [05:39<00:00,  1.47it/s, loss=0.6161, avg_loss=1.0359, lr=0.000100, momentum=0.0996]



Epoch 2/100 Summary:
  Average Loss: 1.0359
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 3/100: 100%|██████████| 500/500 [04:31<00:00,  1.84it/s, loss=0.0726, avg_loss=1.1671, lr=0.000100, momentum=0.1992]



Epoch 3/100 Summary:
  Average Loss: 1.1671
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 4/100: 100%|██████████| 500/500 [03:40<00:00,  2.26it/s, loss=1.1699, avg_loss=1.0196, lr=0.000100, momentum=0.2988]



Epoch 4/100 Summary:
  Average Loss: 1.0196
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 5/100: 100%|██████████| 500/500 [04:12<00:00,  1.98it/s, loss=0.7360, avg_loss=1.0596, lr=0.000100, momentum=0.3984]



Epoch 5/100 Summary:
  Average Loss: 1.0596
  Learning Rate: 0.000099
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_5.pth



Epoch 6/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=1.0952, avg_loss=0.9484, lr=0.000099, momentum=0.4980]



Epoch 6/100 Summary:
  Average Loss: 0.9484
  Learning Rate: 0.000099
------------------------------------------------------------


Epoch 7/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=1.8369, avg_loss=1.0261, lr=0.000099, momentum=0.5976]



Epoch 7/100 Summary:
  Average Loss: 1.0261
  Learning Rate: 0.000099
------------------------------------------------------------


Epoch 8/100: 100%|██████████| 500/500 [03:39<00:00,  2.27it/s, loss=0.5041, avg_loss=1.1223, lr=0.000099, momentum=0.6972]



Epoch 8/100 Summary:
  Average Loss: 1.1223
  Learning Rate: 0.000098
------------------------------------------------------------


Epoch 9/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.8513, avg_loss=1.0215, lr=0.000098, momentum=0.7968]



Epoch 9/100 Summary:
  Average Loss: 1.0215
  Learning Rate: 0.000098
------------------------------------------------------------


Epoch 10/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=1.7109, avg_loss=0.9365, lr=0.000098, momentum=0.8964]



Epoch 10/100 Summary:
  Average Loss: 0.9365
  Learning Rate: 0.000098
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_10.pth



Epoch 11/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=1.1302, avg_loss=0.8984, lr=0.000098, momentum=0.9960]



Epoch 11/100 Summary:
  Average Loss: 0.8984
  Learning Rate: 0.000097
------------------------------------------------------------


Epoch 12/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.3523, avg_loss=0.8012, lr=0.000097, momentum=0.9960]



Epoch 12/100 Summary:
  Average Loss: 0.8012
  Learning Rate: 0.000097
------------------------------------------------------------


Epoch 13/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.4647, avg_loss=0.8010, lr=0.000097, momentum=0.9960]



Epoch 13/100 Summary:
  Average Loss: 0.8010
  Learning Rate: 0.000096
------------------------------------------------------------


Epoch 14/100: 100%|██████████| 500/500 [03:38<00:00,  2.28it/s, loss=0.1007, avg_loss=0.7648, lr=0.000096, momentum=0.9960]



Epoch 14/100 Summary:
  Average Loss: 0.7648
  Learning Rate: 0.000095
------------------------------------------------------------


Epoch 15/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.2018, avg_loss=0.7854, lr=0.000095, momentum=0.9960]



Epoch 15/100 Summary:
  Average Loss: 0.7854
  Learning Rate: 0.000095
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_15.pth



Epoch 16/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.4587, avg_loss=0.7835, lr=0.000095, momentum=0.9960]



Epoch 16/100 Summary:
  Average Loss: 0.7835
  Learning Rate: 0.000094
------------------------------------------------------------


Epoch 17/100: 100%|██████████| 500/500 [03:39<00:00,  2.27it/s, loss=0.5400, avg_loss=0.8400, lr=0.000094, momentum=0.9960]



Epoch 17/100 Summary:
  Average Loss: 0.8400
  Learning Rate: 0.000093
------------------------------------------------------------


Epoch 18/100: 100%|██████████| 500/500 [03:38<00:00,  2.28it/s, loss=0.4434, avg_loss=0.7793, lr=0.000093, momentum=0.9960]



Epoch 18/100 Summary:
  Average Loss: 0.7793
  Learning Rate: 0.000092
------------------------------------------------------------


Epoch 19/100: 100%|██████████| 500/500 [03:37<00:00,  2.29it/s, loss=1.9335, avg_loss=0.8229, lr=0.000092, momentum=0.9960]



Epoch 19/100 Summary:
  Average Loss: 0.8229
  Learning Rate: 0.000091
------------------------------------------------------------


Epoch 20/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.5864, avg_loss=0.7317, lr=0.000091, momentum=0.9960]



Epoch 20/100 Summary:
  Average Loss: 0.7317
  Learning Rate: 0.000091
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_20.pth



Epoch 21/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.0654, avg_loss=0.7294, lr=0.000091, momentum=0.9960]



Epoch 21/100 Summary:
  Average Loss: 0.7294
  Learning Rate: 0.000090
------------------------------------------------------------


Epoch 22/100: 100%|██████████| 500/500 [05:51<00:00,  1.42it/s, loss=1.0319, avg_loss=0.7812, lr=0.000090, momentum=0.9960]



Epoch 22/100 Summary:
  Average Loss: 0.7812
  Learning Rate: 0.000089
------------------------------------------------------------


Epoch 23/100: 100%|██████████| 500/500 [05:52<00:00,  1.42it/s, loss=0.6178, avg_loss=0.7841, lr=0.000089, momentum=0.9960]



Epoch 23/100 Summary:
  Average Loss: 0.7841
  Learning Rate: 0.000088
------------------------------------------------------------


Epoch 24/100: 100%|██████████| 500/500 [05:53<00:00,  1.42it/s, loss=0.1048, avg_loss=0.8155, lr=0.000088, momentum=0.9960]



Epoch 24/100 Summary:
  Average Loss: 0.8155
  Learning Rate: 0.000087
------------------------------------------------------------


Epoch 25/100: 100%|██████████| 500/500 [05:52<00:00,  1.42it/s, loss=0.8889, avg_loss=0.7003, lr=0.000087, momentum=0.9960]



Epoch 25/100 Summary:
  Average Loss: 0.7003
  Learning Rate: 0.000086
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_25.pth



Epoch 26/100: 100%|██████████| 500/500 [04:24<00:00,  1.89it/s, loss=0.9422, avg_loss=0.7375, lr=0.000086, momentum=0.9960]



Epoch 26/100 Summary:
  Average Loss: 0.7375
  Learning Rate: 0.000084
------------------------------------------------------------


Epoch 27/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.9932, avg_loss=0.7738, lr=0.000084, momentum=0.9960]



Epoch 27/100 Summary:
  Average Loss: 0.7738
  Learning Rate: 0.000083
------------------------------------------------------------


Epoch 28/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.8230, avg_loss=0.7333, lr=0.000083, momentum=0.9960]



Epoch 28/100 Summary:
  Average Loss: 0.7333
  Learning Rate: 0.000082
------------------------------------------------------------


Epoch 29/100: 100%|██████████| 500/500 [03:39<00:00,  2.27it/s, loss=1.2308, avg_loss=0.7861, lr=0.000082, momentum=0.9960]



Epoch 29/100 Summary:
  Average Loss: 0.7861
  Learning Rate: 0.000081
------------------------------------------------------------


Epoch 30/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.4546, avg_loss=0.7672, lr=0.000081, momentum=0.9960]



Epoch 30/100 Summary:
  Average Loss: 0.7672
  Learning Rate: 0.000080
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_30.pth



Epoch 31/100: 100%|██████████| 500/500 [03:46<00:00,  2.21it/s, loss=0.7770, avg_loss=0.7495, lr=0.000080, momentum=0.9960]



Epoch 31/100 Summary:
  Average Loss: 0.7495
  Learning Rate: 0.000078
------------------------------------------------------------


Epoch 32/100: 100%|██████████| 500/500 [03:39<00:00,  2.27it/s, loss=0.9932, avg_loss=0.8536, lr=0.000078, momentum=0.9960]



Epoch 32/100 Summary:
  Average Loss: 0.8536
  Learning Rate: 0.000077
------------------------------------------------------------


Epoch 33/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.1523, avg_loss=0.8868, lr=0.000077, momentum=0.9960]



Epoch 33/100 Summary:
  Average Loss: 0.8868
  Learning Rate: 0.000076
------------------------------------------------------------


Epoch 34/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.3037, avg_loss=0.8427, lr=0.000076, momentum=0.9960]



Epoch 34/100 Summary:
  Average Loss: 0.8427
  Learning Rate: 0.000074
------------------------------------------------------------


Epoch 35/100: 100%|██████████| 500/500 [03:41<00:00,  2.26it/s, loss=0.3147, avg_loss=0.7772, lr=0.000074, momentum=0.9960]



Epoch 35/100 Summary:
  Average Loss: 0.7772
  Learning Rate: 0.000073
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_35.pth



Epoch 36/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.4771, avg_loss=0.8184, lr=0.000073, momentum=0.9960]



Epoch 36/100 Summary:
  Average Loss: 0.8184
  Learning Rate: 0.000072
------------------------------------------------------------


Epoch 37/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=1.2174, avg_loss=0.8192, lr=0.000072, momentum=0.9960]



Epoch 37/100 Summary:
  Average Loss: 0.8192
  Learning Rate: 0.000070
------------------------------------------------------------


Epoch 38/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.4849, avg_loss=0.8289, lr=0.000070, momentum=0.9960]



Epoch 38/100 Summary:
  Average Loss: 0.8289
  Learning Rate: 0.000069
------------------------------------------------------------


Epoch 39/100: 100%|██████████| 500/500 [03:38<00:00,  2.28it/s, loss=0.7620, avg_loss=0.8461, lr=0.000069, momentum=0.9960]



Epoch 39/100 Summary:
  Average Loss: 0.8461
  Learning Rate: 0.000067
------------------------------------------------------------


Epoch 40/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.8175, avg_loss=0.7954, lr=0.000067, momentum=0.9960]



Epoch 40/100 Summary:
  Average Loss: 0.7954
  Learning Rate: 0.000066
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_40.pth



Epoch 41/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=0.9935, avg_loss=0.8094, lr=0.000066, momentum=0.9960]



Epoch 41/100 Summary:
  Average Loss: 0.8094
  Learning Rate: 0.000064
------------------------------------------------------------


Epoch 42/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=1.1372, avg_loss=0.8302, lr=0.000064, momentum=0.9960]



Epoch 42/100 Summary:
  Average Loss: 0.8302
  Learning Rate: 0.000063
------------------------------------------------------------


Epoch 43/100: 100%|██████████| 500/500 [03:40<00:00,  2.26it/s, loss=0.5902, avg_loss=0.7836, lr=0.000063, momentum=0.9960]



Epoch 43/100 Summary:
  Average Loss: 0.7836
  Learning Rate: 0.000061
------------------------------------------------------------


Epoch 44/100: 100%|██████████| 500/500 [03:39<00:00,  2.28it/s, loss=2.6511, avg_loss=0.7600, lr=0.000061, momentum=0.9960]



Epoch 44/100 Summary:
  Average Loss: 0.7600
  Learning Rate: 0.000060
------------------------------------------------------------


Epoch 45/100: 100%|██████████| 500/500 [03:38<00:00,  2.28it/s, loss=1.3035, avg_loss=0.8132, lr=0.000060, momentum=0.9960]



Epoch 45/100 Summary:
  Average Loss: 0.8132
  Learning Rate: 0.000058
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_45.pth



Epoch 46/100: 100%|██████████| 500/500 [03:41<00:00,  2.26it/s, loss=1.2622, avg_loss=0.8246, lr=0.000058, momentum=0.9960]



Epoch 46/100 Summary:
  Average Loss: 0.8246
  Learning Rate: 0.000057
------------------------------------------------------------


Epoch 47/100: 100%|██████████| 500/500 [03:41<00:00,  2.26it/s, loss=1.3815, avg_loss=0.8609, lr=0.000057, momentum=0.9960]



Epoch 47/100 Summary:
  Average Loss: 0.8609
  Learning Rate: 0.000055
------------------------------------------------------------


Epoch 48/100: 100%|██████████| 500/500 [03:38<00:00,  2.29it/s, loss=0.7658, avg_loss=0.8175, lr=0.000055, momentum=0.9960]



Epoch 48/100 Summary:
  Average Loss: 0.8175
  Learning Rate: 0.000054
------------------------------------------------------------


Epoch 49/100: 100%|██████████| 500/500 [03:38<00:00,  2.28it/s, loss=0.6416, avg_loss=0.8329, lr=0.000054, momentum=0.9960]



Epoch 49/100 Summary:
  Average Loss: 0.8329
  Learning Rate: 0.000052
------------------------------------------------------------


Epoch 50/100: 100%|██████████| 500/500 [03:39<00:00,  2.27it/s, loss=2.2180, avg_loss=0.8024, lr=0.000052, momentum=0.9960]



Epoch 50/100 Summary:
  Average Loss: 0.8024
  Learning Rate: 0.000051
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_50.pth



Epoch 51/100:  30%|███       | 151/500 [01:56<04:28,  1.30it/s, loss=1.6141, avg_loss=0.8824, lr=0.000051, momentum=0.9960]


KeyboardInterrupt: 

: 