In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from copy import deepcopy

class SimpleConvNet(nn.Module):
    """Small two-convolution classifier for 1x28x28 (MNIST-sized) images.

    Pipeline: conv(1->32) -> ReLU -> conv(32->64) -> ReLU -> 2x2 max-pool
    -> channel dropout -> flatten -> fc(12544->128) -> ReLU -> dropout
    -> fc(128->10).

    ``forward`` returns raw logits of shape (N, 10); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Input: 1x28x28. 3x3 kernels with padding=1 keep the spatial size.
        self.conv1 = nn.Conv2d(1, 32, 3, 1, padding=1)  # Output: 32x28x28
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)  # Output: 64x28x28
        # Channel-wise dropout on the 4-D pooled feature map (N, 64, 14, 14).
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D (N, 128) hidden activations.
        # nn.Dropout2d is only defined for 3-D/4-D inputs; fed a 2-D tensor it
        # emits a deprecation warning (slated to become an error) telling you
        # to use plain dropout to keep the same behavior -- so use it directly.
        self.dropout2 = nn.Dropout(0.5)

        # After 2x2 max pooling the map is 64x14x14 -> 64 * 14 * 14 = 12544.
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images (N, 1, 28, 28) to class logits (N, 10)."""
        # Convolutional feature extractor
        x = self.conv1(x)  # 32x28x28
        x = F.relu(x)
        x = self.conv2(x)  # 64x28x28
        x = F.relu(x)
        x = F.max_pool2d(x, 2)  # 64x14x14
        x = self.dropout1(x)

        # Classifier head on the flattened features: 64*14*14 = 12544
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

class MeanTeacher:
    """Student/teacher pair for Mean Teacher semi-supervised training.

    The student is trained by gradient descent.  The teacher starts as a deep
    copy of the student, is excluded from autograd, and is only changed by
    ``update_teacher``, which blends in the student's weights as an
    exponential moving average (EMA).

    Args:
        ema_decay: EMA coefficient in [0, 1];
            ``teacher <- ema_decay * teacher + (1 - ema_decay) * student``.
        student: optional module to use as the student.  Defaults to a fresh
            ``SimpleConvNet()`` so existing ``MeanTeacher()`` calls keep working.

    Raises:
        ValueError: if ``ema_decay`` is outside [0, 1].
    """

    def __init__(self, ema_decay=0.999, student=None):
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError(f"ema_decay must be in [0, 1], got {ema_decay}")
        self.student = SimpleConvNet() if student is None else student
        self.teacher = deepcopy(self.student)
        self.ema_decay = ema_decay

        # The teacher is never optimized directly: stop autograd tracking it.
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    @torch.no_grad()
    def update_teacher(self):
        """Blend the student's weights into the teacher's, in place.

        Parameters follow the EMA rule above.  Buffers (e.g. BatchNorm running
        statistics) are copied verbatim from the student, since a teacher kept
        in eval mode would otherwise never refresh them.
        """
        decay = self.ema_decay
        for teacher_param, student_param in zip(self.teacher.parameters(),
                                                self.student.parameters()):
            # In-place update: no new tensor allocated per parameter per step.
            teacher_param.mul_(decay).add_(student_param, alpha=1 - decay)
        for teacher_buf, student_buf in zip(self.teacher.buffers(),
                                            self.student.buffers()):
            teacher_buf.copy_(student_buf)

    def consistency_loss(self, student_logits, teacher_logits):
        """Mean squared error between student and teacher class probabilities.

        Both logit tensors are softmaxed over dim 1.  The teacher side is
        detached, so gradients only ever flow into the student.
        """
        return F.mse_loss(F.softmax(student_logits, dim=1),
                          F.softmax(teacher_logits.detach(), dim=1))

def prepare_mnist_data(num_labeled=1000, batch_size=128, data_dir='./data',
                       num_workers=2, seed=None):
    """Build MNIST data loaders for semi-supervised training.

    The MNIST training split is partitioned into two *disjoint* subsets that
    together cover every image: ``num_labeled`` images used with their labels,
    and the remainder treated as unlabeled.

    Args:
        num_labeled: number of training images kept as labeled data.
        batch_size: batch size for all three loaders.
        data_dir: directory where MNIST is stored / downloaded.
        num_workers: worker processes per DataLoader.
        seed: optional integer seed for the labeled/unlabeled split only
            (loader shuffling is not seeded).  ``None`` uses the global RNG.

    Returns:
        ``(labeled_loader, unlabeled_loader, test_loader)``; the first two
        shuffle, the test loader does not.

    Raises:
        ValueError: if ``num_labeled`` would leave either subset empty.
    """
    # ToTensor + normalization with the commonly used MNIST mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    full_dataset = MNIST(data_dir, train=True, download=True, transform=transform)
    test_dataset = MNIST(data_dir, train=False, download=True, transform=transform)

    n = len(full_dataset)
    # Either subset being empty would make its shuffling DataLoader unusable.
    if not 0 < num_labeled < n:
        raise ValueError(f"num_labeled must be in [1, {n - 1}], got {num_labeled}")

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    # Slicing ONE permutation keeps the two subsets disjoint and exhaustive;
    # two independent randperm calls would make them overlap heavily.
    perm = torch.randperm(n, generator=generator).tolist()
    labeled_indices = perm[:num_labeled]
    unlabeled_indices = perm[num_labeled:]

    labeled_dataset = torch.utils.data.Subset(full_dataset, labeled_indices)
    unlabeled_dataset = torch.utils.data.Subset(full_dataset, unlabeled_indices)

    labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size,
                                shuffle=True, num_workers=num_workers)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    return labeled_loader, unlabeled_loader, test_loader

def train_mean_teacher(
    model,
    labeled_loader,
    unlabeled_loader,
    epochs=50,
    device="cuda" if torch.cuda.is_available() else "cpu",
    lr=0.01,
    momentum=0.9,
    rampup_epochs=10,
    max_consistency_weight=0.5,
):
    """Train a Mean Teacher model on paired labeled and unlabeled batches.

    Each step draws one labeled batch and one unlabeled batch.  The student is
    optimized on ``cross_entropy(labeled) + w * consistency(unlabeled)``, then
    ``model.update_teacher()`` is called.  The weight ``w`` rises linearly from
    0 to ``max_consistency_weight`` over ``rampup_epochs`` epochs.  Per-epoch
    averages and training accuracy on the labeled batches are printed.

    Args:
        model: object with ``student`` and ``teacher`` modules, plus
            ``update_teacher()`` and
            ``consistency_loss(student_logits, teacher_logits)``
            (e.g. ``MeanTeacher``).
        labeled_loader: iterable of ``(images, labels)`` batches supporting
            ``len()``; one epoch is one pass over it.
        unlabeled_loader: iterable of ``(images, _)`` batches (labels ignored).
            A single iterator is kept across epochs and restarted whenever it
            runs out.
        epochs: number of passes over ``labeled_loader``.
        device: device that both models and every batch are moved to.
        lr, momentum: SGD hyper-parameters for the student.
        rampup_epochs: length of the linear consistency ramp; ``<= 0``
            applies the full weight from the first epoch.
        max_consistency_weight: consistency weight after the ramp.

    Raises:
        ValueError: if ``unlabeled_loader`` yields no batches when one is needed.
    """
    print(f"Using device: {device}")
    model.student.to(device)
    model.teacher.to(device)

    optimizer = torch.optim.SGD(model.student.parameters(), lr=lr, momentum=momentum)

    # Created once rather than per epoch: re-creating it every epoch re-spawns
    # loader workers and re-shuffles a large dataset only to consume a few
    # batches, and carrying it over eventually visits all unlabeled data.
    unlabeled_iter = iter(unlabeled_loader)

    def next_unlabeled_batch():
        """Return the next unlabeled image batch, restarting the loader when exhausted."""
        nonlocal unlabeled_iter
        try:
            imgs, _ = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            try:
                imgs, _ = next(unlabeled_iter)
            except StopIteration:
                raise ValueError("unlabeled_loader yielded no batches") from None
        return imgs

    for epoch in range(epochs):
        model.student.train()
        model.teacher.eval()

        # Consistency weight is constant within an epoch; guard the ramp length.
        ramp = min(epoch / rampup_epochs, 1.0) if rampup_epochs > 0 else 1.0
        consistency_weight = ramp * max_consistency_weight

        total_loss = 0.0
        sup_loss_sum = 0.0
        cons_loss_sum = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch_idx, (labeled_imgs, labels) in enumerate(labeled_loader):
            unlabeled_imgs = next_unlabeled_batch().to(device)
            labeled_imgs, labels = labeled_imgs.to(device), labels.to(device)

            # Forward passes; the teacher only provides targets, so no graph.
            student_labeled_logits = model.student(labeled_imgs)
            student_unlabeled_logits = model.student(unlabeled_imgs)
            with torch.no_grad():
                teacher_unlabeled_logits = model.teacher(unlabeled_imgs)

            sup_loss = F.cross_entropy(student_labeled_logits, labels)
            cons_loss = model.consistency_loss(student_unlabeled_logits,
                                               teacher_unlabeled_logits)
            loss = sup_loss + consistency_weight * cons_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # The teacher tracks the freshly updated student weights.
            model.update_teacher()

            total_loss += loss.item()
            sup_loss_sum += sup_loss.item()
            cons_loss_sum += cons_loss.item()
            num_batches += 1

            pred = student_labeled_logits.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

            if batch_idx % 100 == 0:
                print(f'Batch [{batch_idx}/{len(labeled_loader)}], Loss: {loss.item():.4f}')

        # Guard against an empty labeled loader (would divide by zero).
        denom = max(num_batches, 1)
        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'Average Loss: {total_loss/denom:.4f}')
        print(f'Supervised Loss: {sup_loss_sum/denom:.4f}')
        print(f'Consistency Loss: {cons_loss_sum/denom:.4f}')
        print(f'Accuracy: {100.*correct/max(total, 1):.2f}%\n')

# Usage example: train on 1000 labeled MNIST images, then report held-out
# accuracy for both networks on the test split.
if __name__ == "__main__":
    # Prepare data
    labeled_loader, unlabeled_loader, test_loader = prepare_mnist_data(num_labeled=1000)

    # Create and train model
    mean_teacher = MeanTeacher()
    train_mean_teacher(mean_teacher, labeled_loader, unlabeled_loader, epochs=50)

    # Evaluate on the test split (previously built but never used).  Both
    # models sit on whatever device training moved them to.
    eval_device = next(mean_teacher.teacher.parameters()).device
    for net_name, net in (("Student", mean_teacher.student),
                          ("Teacher", mean_teacher.teacher)):
        net.eval()
        n_correct = 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                preds = net(imgs.to(eval_device)).argmax(dim=1)
                n_correct += (preds == labels.to(eval_device)).sum().item()
        print(f"{net_name} test accuracy: "
              f"{100. * n_correct / len(test_loader.dataset):.2f}%")

Using device: cpu




Batch [0/8], Loss: 2.2942

Epoch 1/50:
Average Loss: 2.2246
Supervised Loss: 2.2246
Consistency Loss: 0.0003
Accuracy: 21.50%

Batch [0/8], Loss: 1.9979

Epoch 2/50:
Average Loss: 1.7015
Supervised Loss: 1.7011
Consistency Loss: 0.0065
Accuracy: 50.50%

Batch [0/8], Loss: 1.1588

Epoch 3/50:
Average Loss: 0.9490
Supervised Loss: 0.9451
Consistency Loss: 0.0395
Accuracy: 66.80%

Batch [0/8], Loss: 0.7444

Epoch 4/50:
Average Loss: 0.6961
Supervised Loss: 0.6877
Consistency Loss: 0.0564
Accuracy: 75.50%

Batch [0/8], Loss: 0.6984

Epoch 5/50:
Average Loss: 0.5948
Supervised Loss: 0.5831
Consistency Loss: 0.0585
Accuracy: 80.80%

Batch [0/8], Loss: 0.5710

Epoch 6/50:
Average Loss: 0.5516
Supervised Loss: 0.5358
Consistency Loss: 0.0632
Accuracy: 82.40%

Batch [0/8], Loss: 0.5162

Epoch 7/50:
Average Loss: 0.4968
Supervised Loss: 0.4788
Consistency Loss: 0.0600
Accuracy: 84.90%

Batch [0/8], Loss: 0.4293

Epoch 8/50:
Average Loss: 0.4631
Supervised Loss: 0.4411
Consistency Loss: 0.0629
Ac