In [None]:
# =============================================================================
# 1. IMPORT LIBRARIES
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from tqdm import tqdm
import math

# Report library versions so results can be tied to a specific environment.
print(f"PyTorch Version: {torch.__version__}",
      f"Torchvision Version: {torchvision.__version__}",
      sep="\n")

In [None]:
# =============================================================================
# 2. SETUP AND CONFIGURATION
# =============================================================================
# Hyperparameters
NUM_EPOCHS = 100 # SSL requires more epochs than supervised learning
BATCH_SIZE = 1024 # Crucial for contrastive learning to have a large batch size
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
TEMPERATURE = 0.5 # Temperature for the NT-Xent loss
PROJECTION_DIM = 128 # Dimension of the projected features
SEED = 42 # Global torch seed (weight init, shuffling, augmentation RNG)

# Seed before any model or DataLoader is created so runs are repeatable
# (cuDNN kernels may still introduce small nondeterminism on GPU).
torch.manual_seed(SEED)

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Note: each training step runs two forward passes of BATCH_SIZE images (one per
# view) and keeps both for backward, so BATCH_SIZE = 1024 may not fit on a small
# GPU (e.g. 4GB). If you get a CUDA out-of-memory error, reduce it (e.g. to 256 or 128).

In [None]:
# =============================================================================
# 3. DATA AUGMENTATION AND LOADING
# =============================================================================
# Define the strong augmentations for SimCLR
# We apply these transforms twice to get two correlated views of the same image.
class SimCLRAugmentation:
    """Turn one image into two independently augmented views (SimCLR-style).

    The same random pipeline is sampled twice per call. The two views differ
    only through each transform's random draws.

    Args:
        size: output side length handed to ``RandomResizedCrop``.
    """

    def __init__(self, size):
        pipeline = [
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(p=0.5),
            # Colour jitter (brightness, contrast, saturation, hue), applied 80% of the time.
            transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            # Standard ImageNet mean/std; the linear-evaluation cell normalises identically.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

# Load EuroSAT into './Data' (downloaded on first run). Because of
# SimCLRAugmentation, each sample is ((view1, view2), label); EuroSAT images
# are 64x64, so the crop size matches the native resolution.
train_dataset = torchvision.datasets.EuroSAT(
    root='./Data',
    transform=SimCLRAugmentation(size=64),
    download=True,
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # every batch holds exactly BATCH_SIZE image pairs
    pin_memory=True,
    num_workers=4,
)

In [None]:
# =============================================================================
# 4. MODEL ARCHITECTURE (ENCODER + PROJECTOR)
# =============================================================================
class SimCLR(nn.Module):
    """SimCLR model: a backbone encoder followed by a 2-layer MLP projection head.

    The encoder's ``fc`` classifier is replaced by ``nn.Identity`` so the
    encoder returns its pooled feature vector ``h``. The projector maps
    ``h`` to ``z``, which is what the contrastive loss consumes.

    Args:
        base_encoder: backbone module; its ``fc`` attribute is overwritten in place.
        projection_dim: output width of the projection head.
        feature_dim: width of the encoder output. If ``None``, it is read from
            ``base_encoder.fc.in_features`` (512 for ResNet-18), falling back to 512
            when ``fc`` has no ``in_features``.
    """

    def __init__(self, base_encoder, projection_dim, feature_dim=None):
        super().__init__()
        if feature_dim is None:
            # Must be read before `fc` is replaced below.
            feature_dim = getattr(getattr(base_encoder, "fc", None), "in_features", 512)

        self.encoder = base_encoder
        self.encoder.fc = nn.Identity() # Replace the classifier layer

        # Projector head: Linear -> ReLU -> Linear, no biases.
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, feature_dim, bias=False),
            nn.ReLU(),
            nn.Linear(feature_dim, projection_dim, bias=False),
        )

    def forward(self, x):
        """Return ``(h, z)``: encoder features and their projection."""
        h = self.encoder(x)
        z = self.projector(h)
        return h, z

# Build a ResNet-18 backbone (no weights argument is passed, so presumably it is
# randomly initialised) and wrap it with the SimCLR projection head.
base_encoder = resnet18()
model = SimCLR(base_encoder=base_encoder, projection_dim=PROJECTION_DIM)
model = model.to(device)



In [None]:
# =============================================================================
# 5. LOSS FUNCTION AND OPTIMIZER
# =============================================================================
# NT-Xent Loss Function
def nt_xent_loss(z1, z2, temperature):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` are the two views of the same
    image (the positive pair). Every other row of the concatenated ``2N``
    projections acts as a negative for that anchor.

    Args:
        z1: projections of the first views, shape ``(N, D)``.
        z2: projections of the second views, shape ``(N, D)``.
        temperature: temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: the loss averaged over all ``2N`` anchors.
    """
    # Take N and the device from the inputs, not from globals. This keeps the loss
    # correct for any batch size (e.g. after BATCH_SIZE is rebound in a later cell).
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)

    # Rows are unit-norm, so the dot product equals the cosine similarity.
    logits = (z @ z.T) / temperature

    # Remove self-similarity from the denominator: exp(-inf) == 0.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    # The positive of anchor i is i + N (first half) or i - N (second half).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)

    # cross_entropy == mean_i[-log(exp(pos_i / t) / sum_{k != i} exp(s_ik / t))],
    # computed via log-softmax, which is numerically stabler than exp/log by hand.
    return F.cross_entropy(logits, targets)

# AdamW in place of LARS (often recommended for SimCLR); AdamW works well too.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# T_max is measured in batches because the training loop steps the scheduler
# once per batch, giving one cosine decay over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS * len(train_loader),
    eta_min=0,
    last_epoch=-1,
)


In [None]:
# =============================================================================
# 6. PRE-TRAINING LOOP
# =============================================================================
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    # Each batch is ((view1, view2), labels); labels are unused in self-supervised training.
    for (view1, view2), _ in pbar:
        view1 = view1.to(device)
        view2 = view2.to(device)

        optimizer.zero_grad()

        # model(...) returns (h, z); only the projections z feed the contrastive loss.
        z1 = model(view1)[1]
        z2 = model(view2)[1]
        loss = nt_xent_loss(z1, z2, TEMPERATURE)

        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped once per batch

        step_loss = loss.item()
        total_loss += step_loss
        pbar.set_postfix({'loss': step_loss, 'lr': scheduler.get_last_lr()[0]})

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Average Loss: {avg_loss:.4f}")

print("Finished Pre-training!")

In [None]:
# =============================================================================
# 7. SAVE THE ENCODER (BACKBONE)
# =============================================================================
# Persist only the backbone's weights; the projection head is not saved.
torch.save(obj=model.encoder.state_dict(), f='simclr_encoder_eurosat.pth')
print("Encoder saved to simclr_encoder_eurosat.pth")

In [None]:
# =============================================================================
# 1. IMPORT LIBRARIES
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from tqdm import tqdm

# =============================================================================
# 2. SETUP AND CONFIGURATION
# =============================================================================
# Hyperparameters for the linear probe. NOTE: these rebind NUM_EPOCHS,
# BATCH_SIZE, LEARNING_RATE and device, which the pre-training cells also define.
NUM_EPOCHS = 50
BATCH_SIZE = 128
LEARNING_RATE = 0.01  # learning rate for the linear classifier only

# Pick the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# =============================================================================
# 3. DATA LOADING (Supervised)
# =============================================================================
# No random augmentation for evaluation: only tensor conversion plus the same
# normalisation used during pre-training.
eval_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Same root as the pre-training cell ('./Data'). Using './data' would download a
# second copy into a different directory on case-sensitive file systems.
full_dataset = torchvision.datasets.EuroSAT(root='./Data', download=True, transform=eval_transforms)
class_names = full_dataset.classes
NUM_CLASSES = len(class_names)
print(f"Number of classes: {NUM_CLASSES}")

# 80/20 train/val split with a fixed seed so the split is reproducible.
# NOTE(review): the SimCLR encoder was pre-trained (without labels) on the full
# EuroSAT set, so validation images were seen during pre-training — confirm intended.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# =============================================================================
# 4. MODEL FOR LINEAR EVALUATION
# =============================================================================
# Rebuild the backbone the way it was saved (ResNet-18 with `fc` replaced by
# Identity) so the checkpoint's parameter names match.
encoder = resnet18()
encoder.fc = nn.Identity() # We don't need the original classifier
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine as well.
encoder.load_state_dict(torch.load('simclr_encoder_eurosat.pth', map_location='cpu'))
print("Pre-trained encoder loaded successfully.")

# Freeze every encoder parameter; only the new linear head will be trained.
encoder.requires_grad_(False)

# Create the full model with a new linear classifier
class LinearClassifier(nn.Module):
    """Linear probe: an encoder followed by a single linear classification layer.

    The class does not freeze the encoder itself. For a true linear probe, the
    caller must set ``requires_grad=False`` on the encoder's parameters and pass
    only ``classifier`` parameters to the optimizer.

    Args:
        encoder: module mapping inputs to feature vectors of width ``feature_dim``.
        num_classes: number of output logits.
        feature_dim: encoder output width; defaults to 512 (ResNet-18).
    """

    def __init__(self, encoder, num_classes, feature_dim=512):
        super().__init__()
        self.encoder = encoder
        # The new classifier layer is the only part that will be trained
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        features = self.encoder(x)
        return self.classifier(features)

model = LinearClassifier(encoder, NUM_CLASSES).to(device)

# =============================================================================
# 5. LOSS, OPTIMIZER, AND TRAINING
# =============================================================================
criterion = nn.CrossEntropyLoss()
# Only the classifier's parameters go to the optimizer; the encoder stays frozen.
optimizer = optim.Adam(model.classifier.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    # --- Training ---
    model.train()
    # Put the frozen encoder back in eval mode so its batch-norm layers use their
    # stored running statistics instead of updating them from probe batches.
    model.encoder.eval()

    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # --- Validation ---
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = argmax over the logits (avoids the discouraged `.data`).
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {running_loss/len(train_loader):.4f}, Val Accuracy: {accuracy:.2f}%")

print("Finished Linear Evaluation.")