# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim import AdamW
from copy import deepcopy

class EncoderNetwork(nn.Module):
    """Convolutional encoder with a projection head and an optional predictor.

    The network is made of three parts:

    * ``backbone``  - three Conv/BatchNorm/ReLU stages (3 -> 64 -> 128 -> 256
      channels) ending in global average pooling, so each image becomes a
      256-dimensional vector.
    * ``projector`` - MLP mapping the 256-d vector to ``feature_dim``.
    * ``predictor`` - MLP mapping ``feature_dim`` back to ``feature_dim``.
      It may be set to ``None`` on copies that do not need it; ``forward``
      then returns ``None`` in place of the prediction.

    Args:
        feature_dim: Size of the projection (and prediction) vectors.
    """

    def __init__(self, feature_dim=256):
        super().__init__()
        # Simple ConvNet backbone. The first convolution expects 3 input
        # channels; the adaptive pooling at the end removes any dependence on
        # the spatial size of the input.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Projection MLP
        self.projector = nn.Sequential(
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim),
        )

        # Prediction MLP (only used in student)
        self.predictor = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch with 3 channels in dimension 1.

        Returns:
            A tuple ``(z, p)`` where ``z`` is the L2-normalized projection and
            ``p`` is the L2-normalized prediction computed from ``z``. When
            ``self.predictor`` is ``None``, ``p`` is ``None``.
        """
        # Backbone features, flattened from (N, 256, 1, 1) to (N, 256)
        features = torch.flatten(self.backbone(x), 1)

        # Projection, L2-normalized along the feature dimension
        z = F.normalize(self.projector(features), dim=-1)

        # A copy without a predictor (e.g. a teacher network) must still be
        # callable, so skip the prediction head instead of calling None.
        if self.predictor is None:
            return z, None

        # Prediction, L2-normalized along the feature dimension
        p = F.normalize(self.predictor(z), dim=-1)

        return z, p

class TeacherStudentSSL:
    """Teacher-Student framework for self-supervised learning.

    The student is an ``EncoderNetwork`` trained by gradient descent. The
    teacher starts as a deep copy of the student, receives no gradients, and is
    updated only through ``update_teacher`` as an exponential moving average
    (EMA) of the student's parameters.

    Args:
        feature_dim: Size of the projection/prediction vectors.
        ema_decay: EMA coefficient; the teacher keeps ``ema_decay`` of its own
            value and takes ``1 - ema_decay`` from the student on each update.
        image_size: Output size passed to ``RandomResizedCrop`` for both
            augmented views (default 32, the previously hard-coded value).
    """

    def __init__(self, feature_dim=256, ema_decay=0.99, image_size=32):
        self.student = EncoderNetwork(feature_dim)
        # Teacher is a moving average of the student
        self.teacher = deepcopy(self.student)
        # The teacher needs no predictor weights, but the encoder's forward()
        # routes the projection through `predictor`; a parameter-free Identity
        # keeps the teacher callable where `None` would raise a TypeError.
        self.teacher.predictor = nn.Identity()

        # Disable gradients for teacher
        for param in self.teacher.parameters():
            param.requires_grad_(False)

        self.ema_decay = ema_decay

        # Define augmentations
        self.augment = T.Compose([
            T.RandomResizedCrop(image_size),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            T.GaussianBlur(kernel_size=3),
        ])

    @torch.no_grad()
    def update_teacher(self):
        """Update teacher weights as EMA of student weights.

        Parameters are paired by name rather than by position, so parameters
        that exist only in the student (the predictor) can never be matched
        against an unrelated teacher parameter. Only parameters are averaged;
        module buffers are left untouched.

        Raises:
            KeyError: If the teacher has a parameter name the student lacks.
        """
        student_params = dict(self.student.named_parameters())
        for name, teacher_param in self.teacher.named_parameters():
            # In place: teacher = decay * teacher + (1 - decay) * student
            teacher_param.mul_(self.ema_decay).add_(
                student_params[name].detach(), alpha=1 - self.ema_decay
            )

    def training_step(self, images):
        """Perform one training step and return the loss.

        Args:
            images: Iterable of image tensors; each one is augmented
                independently, twice, to form the two views.

        Returns:
            Scalar loss tensor: the mean of the two cross-view MSE terms
            between student predictions and teacher projections.
        """
        # Generate two random augmentations
        view1 = torch.stack([self.augment(img) for img in images])
        view2 = torch.stack([self.augment(img) for img in images])

        # Student forward passes
        _, student_p1 = self.student(view1)
        _, student_p2 = self.student(view2)

        # Teacher forward passes (no gradients, so targets are already detached)
        with torch.no_grad():
            teacher_z1, _ = self.teacher(view1)
            teacher_z2, _ = self.teacher(view2)

        # Symmetric loss: each student prediction targets the teacher's
        # projection of the *other* view.
        loss1 = F.mse_loss(student_p1, teacher_z2)
        loss2 = F.mse_loss(student_p2, teacher_z1)
        loss = (loss1 + loss2) * 0.5

        return loss

def train_ssl(model, train_loader, epochs=100, lr=1e-3):
    """Training loop.

    Optimizes ``model.student`` with AdamW and refreshes the teacher through
    ``model.update_teacher()`` after every optimizer step.

    Args:
        model: Object exposing ``student`` (an ``nn.Module``),
            ``training_step(images)`` returning a scalar loss tensor, and
            ``update_teacher()``.
        train_loader: Iterable yielding ``(images, target)`` pairs; the second
            element is ignored. It does not need to support ``len()``.
        epochs: Number of passes over ``train_loader``.
        lr: AdamW learning rate.

    Returns:
        List with the average loss of each epoch (``nan`` for an epoch in
        which the loader produced no batches).
    """
    params = list(model.student.parameters())
    optimizer = AdamW(params, lr=lr)
    # AdamW rejects an empty parameter list, so params[0] exists here.
    device = params[0].device

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, _ in train_loader:
            # Keep the batch on the same device as the student; batches that
            # are not a single tensor (e.g. lists) are passed through as-is.
            if isinstance(images, torch.Tensor):
                images = images.to(device)

            # Training step
            loss = model.training_step(images)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update teacher
            model.update_teacher()

            total_loss += loss.item()
            num_batches += 1

        # Count batches instead of calling len(train_loader): this also works
        # for loaders without __len__ and avoids dividing by zero when empty.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)

        # Print epoch stats
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

    return epoch_losses

# Example usage:
# NOTE: the triple-quoted block below is a bare string literal, not executed
# code; it only illustrates the intended call sequence. `train_loader` is not
# defined in the visible part of this file and must be supplied by the caller.
"""
# Initialize model
ssl_model = TeacherStudentSSL(feature_dim=256, ema_decay=0.99)

# Train
train_ssl(ssl_model, train_loader, epochs=100)

# After training, use teacher for downstream tasks
trained_encoder = ssl_model.teacher
"""

# NOTE(review): the line below appears to be the notebook's echoed repr of the
# example string above; as a bare string literal it has no runtime effect.
'\n# Initialize model\nssl_model = TeacherStudentSSL(feature_dim=256, ema_decay=0.99)\n\n# Train\ntrain_ssl(ssl_model, train_loader, epochs=100)\n\n# After training, use teacher for downstream tasks\ntrained_encoder = ssl_model.teacher\n'