In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import random
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from dataset import Sentinel2InpaintingDataset


In [2]:
# Location of the extracted Sentinel-2 archive, relative to this notebook.
data_root = "../s2a.tar/s2a"

# Build the dataset. Following the train.py convention, target_size should be
# divisible by 32 (e.g. (256, 256) or (384, 384)).
dataset = Sentinel2InpaintingDataset(
    format="satlas",
    target_size=(256, 256),
    limit_samples=4000,
    mask_type="random",
    root_dir=data_root,
)

  0%|          | 0/4640 [00:00<?, ?it/s]

 14%|█▍        | 664/4640 [00:02<00:13, 288.89it/s]

Found 4000 samples with all 12 bands





In [3]:
# Loader settings are collected in one place so they are easy to tweak.
loader_options = dict(
    batch_size=2,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
trainLoader = DataLoader(dataset, **loader_options)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class InfoNCELoss(nn.Module):
    """
    InfoNCE contrastive loss for SASSL.

    For every student view, the cosine similarity between each student
    embedding and every teacher embedding in the batch is computed. The
    teacher embedding at the same batch index is the positive; all other
    teacher embeddings in the batch act as negatives.

    Args:
        temperature: positive scaling factor the similarities are divided by
            before the softmax.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    def __init__(self, temperature=0.07):
        super().__init__()
        # A zero temperature divides by zero and a negative one flips the
        # sign of the logits, so reject both up front.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.temperature = temperature

    def forward(self, teacher_output, student_outputs):
        """
        Args:
            teacher_output: (batch_size, embed_dim) - teacher embeddings
            student_outputs: iterable of (batch_size, embed_dim) student
                embeddings, one entry per view. A single 2-D tensor is
                accepted and treated as one view.

        Returns:
            loss: scalar tensor, the cross-entropy averaged over all views

        Raises:
            ValueError: if no student views are supplied.
        """
        # Iterating a bare (batch, dim) tensor would yield individual rows,
        # so wrap it and treat it as a single view instead.
        if torch.is_tensor(student_outputs) and student_outputs.dim() == 2:
            student_outputs = [student_outputs]
        student_outputs = list(student_outputs)
        if len(student_outputs) == 0:
            raise ValueError("student_outputs must contain at least one view")

        batch_size = teacher_output.shape[0]

        # L2-normalize so the dot products below are cosine similarities.
        teacher_output = F.normalize(teacher_output, dim=-1, p=2)

        # Positives sit on the diagonal: student i is paired with teacher i.
        labels = torch.arange(batch_size, device=teacher_output.device)

        view_losses = []
        for student_output in student_outputs:
            student_output = F.normalize(student_output, dim=-1, p=2)

            # logits[i, j] = cosine similarity between student_i and teacher_j
            logits = torch.matmul(student_output, teacher_output.T) / self.temperature

            # Cross-entropy pushes logits[i, i] up relative to logits[i, j], i != j.
            view_losses.append(F.cross_entropy(logits, labels))

        # Average over however many student views were provided.
        return torch.stack(view_losses).mean()


In [5]:
import torch
import torch.optim as optim
from tqdm import tqdm
from swinSASSL import SwinSASSL
import time

# Hyperparameters
epochs = 100                # number of training epochs; also the cosine schedule length
learning_rate = 1e-4        # initial AdamW learning rate
momentum_teacher = 0.996    # teacher EMA momentum once warmup has finished
warmup_epochs = 5           # epochs over which the teacher momentum is ramped up
weight_decay = 0.04         # AdamW weight decay

# Resolve the device up front so the model can be placed on it before the
# optimizer is built; the training loop moves each batch to the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and loss
model = SwinSASSL(
    random_crop_size=(64, 64),
    drop_probability=0.4,
    swin_in_channels=9
)
# Without this the weights can stay on the CPU while the inputs are sent to
# `device`. It is a no-op if the model already lives on `device`.
model.to(device)

criterion = InfoNCELoss(temperature=0.07)

# Optimizer (only student parameters; the teacher is updated via
# model.update_teacher() in the training loop, not by gradient descent)
optimizer = optim.AdamW(
    model.student.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)

# Learning rate scheduler: cosine decay from learning_rate down to eta_min
# over `epochs` scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=1e-6
)

def count_parameters(model):
    """Return ``(total, trainable)`` parameter element counts for ``model``.

    ``total`` counts every element of every parameter yielded by
    ``model.parameters()``; ``trainable`` counts only those whose
    ``requires_grad`` flag is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    return total_params, trainable_params

# Whole-model parameter summary.
total, trainable = count_parameters(model)
frozen = total - trainable
print(
    f"Total parameters: {total:,}",
    f"Trainable parameters: {trainable:,}",
    f"Non-trainable parameters: {frozen:,}",
    sep="\n",
)
print(f"Percentage trainable: {100 * trainable / total:.2f}%")

# Per-branch breakdown: student first, then teacher.
print("\n--- Student Model ---")
student_total, student_trainable = count_parameters(model.student)
print(
    f"Student total: {student_total:,}",
    f"Student trainable: {student_trainable:,}",
    sep="\n",
)

print("\n--- Teacher Model ---")
teacher_total, teacher_trainable = count_parameters(model.teacher)
print(
    f"Teacher total: {teacher_total:,}",
    f"Teacher trainable: {teacher_trainable:,}",
    sep="\n",
)

# Training loop
# Only the student is optimized by gradient descent; the teacher is kept in
# eval mode and refreshed through model.update_teacher() after every step.
model.student.train()
model.teacher.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Number of epochs between periodic checkpoints.
checkpoint_every = 4

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0

    # Teacher EMA momentum warmup: ramps linearly from 0 towards the target
    # over the first `warmup_epochs` epochs. It only depends on the epoch, so
    # it is computed once here rather than on every batch.
    if epoch < warmup_epochs:
        m = momentum_teacher * (epoch / warmup_epochs)
    else:
        m = momentum_teacher

    pbar = tqdm(trainLoader, desc=f'Epoch {epoch+1}/{epochs}')

    for batch in pbar:
        # Move the 'c9' entry of the batch to the training device.
        images = batch['c9'].to(device, non_blocking=True)

        # Skip if batch size < 2 (the contrastive loss needs negatives)
        if images.shape[0] < 2:
            continue

        # Forward pass
        teacher_outputs, student_outputs = model(images)

        # Compute contrastive loss
        loss = criterion(teacher_outputs, student_outputs)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (prevent exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.student.parameters(), max_norm=1.0)

        # Update student
        optimizer.step()

        # Update teacher with EMA
        model.update_teacher(momentum=m)

        # Track metrics; read the scalar once to avoid a second host sync.
        loss_value = loss.item()
        epoch_loss += loss_value
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'avg_loss': f'{epoch_loss/num_batches:.4f}',
            'lr': f'{optimizer.param_groups[0]["lr"]:.6f}',
            'momentum': f'{m:.4f}'
        })

    # Every batch may have been skipped (or the loader may be empty); fail
    # with a clear message instead of an opaque ZeroDivisionError.
    if num_batches == 0:
        raise RuntimeError(
            f"Epoch {epoch+1}: no batch with at least 2 samples was processed"
        )

    # Calculate average epoch loss
    avg_loss = epoch_loss / num_batches

    # Step scheduler (once per epoch)
    scheduler.step()

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{epochs} Summary:")
    print(f"  Average Loss: {avg_loss:.4f}")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    print("-" * 60)

    # Save a checkpoint every `checkpoint_every` epochs
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'student_state_dict': model.student.state_dict(),
            'teacher_state_dict': model.teacher.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
        }
        torch.save(checkpoint, f'sassl_checkpoint_epoch_{epoch+1}.pth')
        print(f"✓ Checkpoint saved: sassl_checkpoint_epoch_{epoch+1}.pth\n")

# Save final model
# The scheduler state is stored alongside the optimizer state, as in the
# periodic checkpoints, so the learning-rate schedule can be restored when
# resuming from this file.
final_checkpoint = {
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'student_state_dict': model.student.state_dict(),
    'teacher_state_dict': model.teacher.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'final_loss': avg_loss,  # average loss of the last completed epoch
}
torch.save(final_checkpoint, 'sassl_final_model.pth')
print("\n" + "="*60)
print("Training completed!")
print("Final model saved: sassl_final_model.pth")
print("="*60)

Total parameters: 177,476,608
Trainable parameters: 88,738,304
Non-trainable parameters: 88,738,304
Percentage trainable: 50.00%

--- Student Model ---
Student total: 87,943,136
Student trainable: 87,943,136

--- Teacher Model ---
Teacher total: 87,943,136
Teacher trainable: 0


Epoch 1/100: 100%|██████████| 2000/2000 [27:33<00:00,  1.21it/s, loss=0.6911, avg_loss=0.7601, lr=0.000100, momentum=0.0000]



Epoch 1/100 Summary:
  Average Loss: 0.7601
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 2/100: 100%|██████████| 2000/2000 [27:20<00:00,  1.22it/s, loss=0.9797, avg_loss=0.8080, lr=0.000100, momentum=0.1992]



Epoch 2/100 Summary:
  Average Loss: 0.8080
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 3/100: 100%|██████████| 2000/2000 [27:32<00:00,  1.21it/s, loss=2.7423, avg_loss=0.8384, lr=0.000100, momentum=0.3984]



Epoch 3/100 Summary:
  Average Loss: 0.8384
  Learning Rate: 0.000100
------------------------------------------------------------


Epoch 4/100: 100%|██████████| 2000/2000 [27:23<00:00,  1.22it/s, loss=0.3928, avg_loss=0.9189, lr=0.000100, momentum=0.5976]



Epoch 4/100 Summary:
  Average Loss: 0.9189
  Learning Rate: 0.000100
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_4.pth



Epoch 5/100: 100%|██████████| 2000/2000 [27:27<00:00,  1.21it/s, loss=0.6898, avg_loss=0.8256, lr=0.000100, momentum=0.7968]



Epoch 5/100 Summary:
  Average Loss: 0.8256
  Learning Rate: 0.000099
------------------------------------------------------------


Epoch 6/100: 100%|██████████| 2000/2000 [27:24<00:00,  1.22it/s, loss=0.1350, avg_loss=0.6724, lr=0.000099, momentum=0.9960]



Epoch 6/100 Summary:
  Average Loss: 0.6724
  Learning Rate: 0.000099
------------------------------------------------------------


Epoch 7/100: 100%|██████████| 2000/2000 [27:22<00:00,  1.22it/s, loss=0.2968, avg_loss=0.7095, lr=0.000099, momentum=0.9960]



Epoch 7/100 Summary:
  Average Loss: 0.7095
  Learning Rate: 0.000099
------------------------------------------------------------


Epoch 8/100: 100%|██████████| 2000/2000 [27:21<00:00,  1.22it/s, loss=2.6646, avg_loss=0.7596, lr=0.000099, momentum=0.9960]



Epoch 8/100 Summary:
  Average Loss: 0.7596
  Learning Rate: 0.000098
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_8.pth



Epoch 9/100: 100%|██████████| 2000/2000 [27:22<00:00,  1.22it/s, loss=0.0120, avg_loss=0.7073, lr=0.000098, momentum=0.9960]



Epoch 9/100 Summary:
  Average Loss: 0.7073
  Learning Rate: 0.000098
------------------------------------------------------------


Epoch 10/100: 100%|██████████| 2000/2000 [27:28<00:00,  1.21it/s, loss=0.7083, avg_loss=0.6272, lr=0.000098, momentum=0.9960]



Epoch 10/100 Summary:
  Average Loss: 0.6272
  Learning Rate: 0.000098
------------------------------------------------------------


Epoch 11/100: 100%|██████████| 2000/2000 [27:16<00:00,  1.22it/s, loss=0.1377, avg_loss=0.5851, lr=0.000098, momentum=0.9960]



Epoch 11/100 Summary:
  Average Loss: 0.5851
  Learning Rate: 0.000097
------------------------------------------------------------


Epoch 12/100: 100%|██████████| 2000/2000 [27:30<00:00,  1.21it/s, loss=0.6318, avg_loss=0.5898, lr=0.000097, momentum=0.9960]



Epoch 12/100 Summary:
  Average Loss: 0.5898
  Learning Rate: 0.000097
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_12.pth



Epoch 13/100: 100%|██████████| 2000/2000 [27:30<00:00,  1.21it/s, loss=0.7144, avg_loss=0.6637, lr=0.000097, momentum=0.9960]



Epoch 13/100 Summary:
  Average Loss: 0.6637
  Learning Rate: 0.000096
------------------------------------------------------------


Epoch 14/100: 100%|██████████| 2000/2000 [27:29<00:00,  1.21it/s, loss=0.7828, avg_loss=0.6440, lr=0.000096, momentum=0.9960]



Epoch 14/100 Summary:
  Average Loss: 0.6440
  Learning Rate: 0.000095
------------------------------------------------------------


Epoch 15/100: 100%|██████████| 2000/2000 [27:27<00:00,  1.21it/s, loss=0.6932, avg_loss=0.6631, lr=0.000095, momentum=0.9960]



Epoch 15/100 Summary:
  Average Loss: 0.6631
  Learning Rate: 0.000095
------------------------------------------------------------


Epoch 16/100: 100%|██████████| 2000/2000 [27:26<00:00,  1.21it/s, loss=0.5514, avg_loss=0.6888, lr=0.000095, momentum=0.9960]



Epoch 16/100 Summary:
  Average Loss: 0.6888
  Learning Rate: 0.000094
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_16.pth



Epoch 17/100: 100%|██████████| 2000/2000 [27:23<00:00,  1.22it/s, loss=0.2182, avg_loss=0.7680, lr=0.000094, momentum=0.9960]



Epoch 17/100 Summary:
  Average Loss: 0.7680
  Learning Rate: 0.000093
------------------------------------------------------------


Epoch 18/100: 100%|██████████| 2000/2000 [1:26:47<00:00,  2.60s/it, loss=0.6977, avg_loss=0.7647, lr=0.000093, momentum=0.9960]  



Epoch 18/100 Summary:
  Average Loss: 0.7647
  Learning Rate: 0.000092
------------------------------------------------------------


Epoch 19/100: 100%|██████████| 2000/2000 [35:54<00:00,  1.08s/it, loss=0.3059, avg_loss=0.6802, lr=0.000092, momentum=0.9960] 



Epoch 19/100 Summary:
  Average Loss: 0.6802
  Learning Rate: 0.000091
------------------------------------------------------------


Epoch 20/100: 100%|██████████| 2000/2000 [46:11<00:00,  1.39s/it, loss=0.9080, avg_loss=0.6465, lr=0.000091, momentum=0.9960] 



Epoch 20/100 Summary:
  Average Loss: 0.6465
  Learning Rate: 0.000091
------------------------------------------------------------
✓ Checkpoint saved: sassl_checkpoint_epoch_20.pth



Epoch 21/100:   0%|          | 0/2000 [00:06<?, ?it/s]


KeyboardInterrupt: 