#### Continual Learning Model - Dynamic Layer Expansion - Episodic Replay with Random Sampling - Memory Partitioning & Data Augmentation - Class Incremental - MNIST

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import random
from collections import defaultdict
import torch.nn.functional as F
import copy
import time
import psutil

In [2]:
start_time = time.time()
process = psutil.Process()
start_memory = process.memory_info().rss / 1024 ** 2

In [3]:
device = torch.device('mps')
print(f"Using device: {device}")

Using device: mps


In [4]:
# Dataset preparation function
def prepare_dataset():
    # Load MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = torchvision.datasets.MNIST(
        root='/Users/guptaabhinav/Documents/All Documentation/Education/Postgrad - USC - Coursework/USC Semester-3/Deep Learning/Project/data', train=True, download=False, transform=transform)
    test_dataset = torchvision.datasets.MNIST(
        root='/Users/guptaabhinav/Documents/All Documentation/Education/Postgrad - USC - Coursework/USC Semester-3/Deep Learning/Project/data', train=False, download=False, transform=transform)
    
    # Split dataset into 5 tasks (0-1, 2-3, 4-5, 6-7, 8-9)
    tasks = []
    for i in range(0, 10, 2):
        train_idx = torch.where((train_dataset.targets == i) | 
                              (train_dataset.targets == i+1))[0]
        test_idx = torch.where((test_dataset.targets == i) | 
                              (test_dataset.targets == i+1))[0]
        
        train_subset = Subset(train_dataset, train_idx)
        test_subset = Subset(test_dataset, test_idx)
        tasks.append((train_subset, test_subset))
    
    return tasks

In [5]:
class DynamicBlock(nn.Module):
    def __init__(self, in_channels, initial_channels, kernel_size=3, padding=1, expansion_factor=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.current_channels = initial_channels
        self.expansion_factor = expansion_factor
        self.kernel_size = kernel_size
        self.padding = padding
        
        # Main convolution layer
        self.conv = nn.Conv2d(in_channels, initial_channels, kernel_size=kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(initial_channels)
        self.expansion_layers = nn.ModuleList()
        
    def expand(self, task_id):
        """Add new channels for new task"""
        expansion_size = int(self.current_channels * self.expansion_factor)
        new_conv = nn.Conv2d(
            self.in_channels, 
            expansion_size, 
            kernel_size=self.kernel_size, 
            padding=self.padding
        ).to(self.conv.weight.device)
        new_bn = nn.BatchNorm2d(expansion_size).to(self.conv.weight.device)
        
        # Initialize with similar statistics
        nn.init.kaiming_normal_(new_conv.weight)
        with torch.no_grad():
            new_conv.weight.data *= 0.1
        
        self.expansion_layers.append((new_conv, new_bn, task_id))
        self.current_channels += expansion_size
        return self.current_channels
        
    def forward(self, x, task_id=None):
        main_out = F.relu(self.bn(self.conv(x)))
        
        expansion_outputs = []
        for conv, bn, layer_task_id in self.expansion_layers:
            if task_id is None or layer_task_id <= task_id:
                exp_out = F.relu(bn(conv(x)))
                expansion_outputs.append(exp_out)
        
        if expansion_outputs:
            return torch.cat([main_out] + expansion_outputs, dim=1)
        return main_out

In [6]:
class DynamicIncrementalCNN(nn.Module):
    def __init__(self, num_tasks=5, classes_per_task=2, expansion_threshold=0.5):
        super().__init__()
        self.num_tasks = num_tasks
        self.classes_per_task = classes_per_task
        self.total_classes = num_tasks * classes_per_task
        self.expansion_threshold = expansion_threshold
        self.task_performance = defaultdict(list)
        
        # Dynamic feature extractor matching your architecture
        self.features = nn.ModuleList([
            DynamicBlock(1, 32),
            nn.MaxPool2d(2, 2),
            DynamicBlock(32, 64),
            nn.MaxPool2d(2, 2)
        ])
        
        # Dynamic classifier
        self.current_hidden_size = 512
        self.classifier = nn.ModuleList([
            nn.Linear(64 * 7 * 7, self.current_hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.current_hidden_size, self.total_classes)
        ])
        
    def should_expand(self, task_id, current_performance):
        """Determine if network should expand based on performance"""
        if not self.task_performance[task_id]:
            return False
        
        avg_prev_performance = np.mean(self.task_performance[task_id][-5:])
        return current_performance < avg_prev_performance * self.expansion_threshold
    
    def expand_network(self, task_id):
        """Expand network capacity for new task"""
        # Expand convolutional blocks
        new_channels = 0
        for layer in self.features:
            if isinstance(layer, DynamicBlock):
                new_channels = layer.expand(task_id)
        
        # Update classifier input size if needed
        if new_channels:
            old_fc = self.classifier[0]
            new_in_features = new_channels * 7 * 7
            new_hidden_size = int(self.current_hidden_size * 1.2)  # Grow hidden layer
            
            # Create new classifier layers
            new_fc1 = nn.Linear(new_in_features, new_hidden_size).to(old_fc.weight.device)
            new_fc2 = nn.Linear(new_hidden_size, self.total_classes).to(self.classifier[3].weight.device)
            
            # Initialize and copy weights
            with torch.no_grad():
                # Copy weights for the first layer
                new_fc1.weight.data[:old_fc.out_features, :old_fc.in_features] = old_fc.weight.data
                new_fc1.bias.data[:old_fc.out_features] = old_fc.bias.data
                
                # Copy weights for the output layer
                new_fc2.weight.data[:, :self.classifier[3].in_features] = self.classifier[3].weight.data
                new_fc2.bias.data = self.classifier[3].bias.data
            
            # Update classifier
            self.classifier[0] = new_fc1
            self.classifier[3] = new_fc2
            self.current_hidden_size = new_hidden_size
    
    def forward(self, x, task_id=None):
        # Forward through dynamic feature extractor
        for layer in self.features:
            if isinstance(layer, DynamicBlock):
                x = layer(x, task_id)
            else:
                x = layer(x)
        
        # Flatten and forward through classifier
        x = x.view(x.size(0), -1)
        x = self.classifier[0](x)
        x = self.classifier[1](x)
        x = self.classifier[2](x)
        x = self.classifier[3](x)
        
        return x
    
    def update_task_performance(self, task_id, performance):
        """Track task performance for expansion decisions"""
        self.task_performance[task_id].append(performance)
        if len(self.task_performance[task_id]) > 10:
            self.task_performance[task_id].pop(0)

In [7]:
class AugmentedMemoryItem:
    def __init__(self, image, label, task_id):
        self.image = image
        self.label = label
        self.task_id = task_id

class HybridEpisodicBuffer:
    def __init__(self, samples_per_class=50, n_classes=10, device='mps'):
        self.samples_per_class = samples_per_class
        self.n_classes = n_classes
        self.device = device
        self.memory = defaultdict(list)
        
        # Define augmentation transforms
        self.augmentation_transforms = transforms.Compose([
            transforms.RandomAffine(
                degrees=20,
                translate=(0.15, 0.15),
                scale=(0.85, 1.15),
                fill=0
            ),
            transforms.RandomRotation(
                degrees=20,
                fill=0
            ),
            transforms.RandomPerspective(
                distortion_scale=0.1,
                p=0.3,
                fill=0
            ),
            transforms.GaussianBlur(
                kernel_size=3,
                sigma=(0.1, 0.2)
            ),
        ])
    
    def add_sample(self, image, label, task_id):
        """Add a new sample to the memory buffer"""
        if len(self.memory[label.item()]) >= self.samples_per_class:
            # Remove oldest sample if buffer is full for this class
            self.memory[label.item()].pop(0)
        
        # Store the sample
        self.memory[label.item()].append(
            AugmentedMemoryItem(image.cpu(), label.cpu(), task_id)
        )
    
    def get_memory_samples(self, n_samples=32, augment=True):
        """Get random samples from memory with augmentation"""
        if len(self.memory) == 0:
            return None, None, None
        
        # Randomly select classes that are in memory
        available_classes = list(self.memory.keys())
        samples_per_class = n_samples // len(available_classes)
        remaining_samples = n_samples % len(available_classes)
        
        memory_images = []
        memory_labels = []
        memory_task_ids = []
        
        # Distribute samples across classes
        for class_idx in available_classes:
            class_samples = random.sample(
                self.memory[class_idx],
                min(samples_per_class + (1 if remaining_samples > 0 else 0),
                    len(self.memory[class_idx]))
            )
            if remaining_samples > 0:
                remaining_samples -= 1
            
            for sample in class_samples:
                image = sample.image
                if augment:
                    # Apply augmentation with 50% probability
                    if random.random() < 0.5:
                        image = self.augment_sample(image)
                
                memory_images.append(image)
                memory_labels.append(sample.label)
                memory_task_ids.append(sample.task_id)
        
        # Convert to tensors
        memory_images = torch.stack(memory_images).to(self.device)
        memory_labels = torch.tensor(memory_labels).to(self.device)
        memory_task_ids = torch.tensor(memory_task_ids).to(self.device)
        
        return memory_images, memory_labels, memory_task_ids
    
    def augment_sample(self, image):
        """Apply augmentation to a single image"""
        # Convert to PIL image for transforms
        if isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image.squeeze())
        augmented = self.augmentation_transforms(image)
        # Convert back to tensor
        return transforms.ToTensor()(augmented)

class DynamicContinualLearningModel(DynamicIncrementalCNN):
    def __init__(self, num_tasks=5, classes_per_task=2, memory_samples_per_class=50):
        super().__init__(num_tasks, classes_per_task)
        self.episodic_memory = HybridEpisodicBuffer(
            samples_per_class=memory_samples_per_class,
            n_classes=num_tasks * classes_per_task
        )

In [8]:
def train_task_with_dynamic_expansion(model, train_loader, task_id, device, 
                                    epochs=7, replay_batch_size=32, augment_ratio=2):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            task_targets = targets % 2 + (task_id * 2)
            
            # Generate augmented samples
            augmented_inputs = []
            augmented_targets = []
            
            for _ in range(augment_ratio - 1):
                aug_inputs = torch.stack([
                    model.episodic_memory.augment_sample(img) 
                    for img in inputs
                ]).to(device)
                augmented_inputs.append(aug_inputs)
                augmented_targets.append(task_targets)
            
            # Get replay samples
            memory_inputs, memory_labels, _ = model.episodic_memory.get_memory_samples(
                n_samples=replay_batch_size,
                augment=True
            )
            
            # Combine all samples
            batch_inputs = [inputs] + augmented_inputs
            batch_targets = [task_targets] + augmented_targets
            
            if memory_inputs is not None:
                batch_inputs.append(memory_inputs)
                batch_targets.append(memory_labels)
            
            combined_inputs = torch.cat(batch_inputs)
            combined_targets = torch.cat(batch_targets)
            
            # Training step
            optimizer.zero_grad()
            outputs = model(combined_inputs, task_id)
            loss = criterion(outputs, combined_targets)
            loss.backward()
            optimizer.step()
            
            # Store samples in memory
            for img, lbl in zip(inputs, targets):
                model.episodic_memory.add_sample(img, lbl, task_id)
            
            running_loss += loss.item()
            
            # Calculate accuracy
            task_outputs = outputs[:len(inputs)]
            _, predicted = task_outputs.max(1)
            total += targets.size(0)
            correct += (predicted == task_targets).sum().item()
        
        epoch_acc = 100 * correct / total
        model.update_task_performance(task_id, epoch_acc)
        
        # Check if network should expand
        if model.should_expand(task_id, epoch_acc):
            print(f"\nExpanding network for task {task_id + 1} at epoch {epoch + 1}")
            model.expand_network(task_id)
        
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}, '
              f'Accuracy: {epoch_acc:.2f}%')

In [9]:
tasks = prepare_dataset()

# Initialize the model
model = DynamicContinualLearningModel(
    num_tasks=5, 
    classes_per_task=2, 
    memory_samples_per_class=100
).to(device)

# Train with dynamic expansion
for task_id, (train_subset, _) in enumerate(tasks):
    print(f'\nTraining Task {task_id + 1}')
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    train_task_with_dynamic_expansion(model, train_loader, task_id, device)
    # train_task_with_enhanced_expansion(model, train_loader, task_id, device)


Training Task 1
Epoch 1, Loss: 0.096, Accuracy: 99.21%
Epoch 2, Loss: 0.029, Accuracy: 99.83%
Epoch 3, Loss: 0.022, Accuracy: 99.92%
Epoch 4, Loss: 0.017, Accuracy: 99.94%
Epoch 5, Loss: 0.017, Accuracy: 99.91%
Epoch 6, Loss: 0.018, Accuracy: 99.90%
Epoch 7, Loss: 0.012, Accuracy: 99.99%

Training Task 2


KeyboardInterrupt: 

In [None]:
def evaluate_task_agnostic(model, tasks, device, verbose=True):
    """
    Evaluate model performance without any knowledge of task boundaries or class associations.
    """
    model.eval()
    all_predictions = []
    all_targets = []
    task_predictions = defaultdict(list)
    task_targets = defaultdict(list)
    
    with torch.no_grad():
        for task_id, (_, test_subset) in enumerate(tasks):
            test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)
            
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                
                # Store predictions and targets for both overall and per-task analysis
                predictions = predicted.cpu().numpy()
                targets_np = targets.cpu().numpy()
                
                all_predictions.extend(predictions)
                all_targets.extend(targets_np)
                task_predictions[task_id].extend(predictions)
                task_targets[task_id].extend(targets_np)
    
    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    
    # Calculate metrics
    metrics = np.mean(all_predictions == all_targets) * 100
    
    print("\nTask-Agnostic Evaluation Results")
    print("=" * 50)
    print(f"Overall Accuracy: {metrics:.2f}%")

evaluate_task_agnostic(model, tasks, device)


Task-Agnostic Evaluation Results
Overall Accuracy: 95.40%


In [None]:
end_time = time.time()
print(f"Training Time: {end_time - start_time:.2f} seconds")
end_memory = process.memory_info().rss / 1024 ** 2  # Memory in MB
print(f"Memory Usage: {end_memory - start_memory:.2f} MB")