# Subclassing Models in Deep Learning

## Introduction

Subclassing is an advanced and flexible approach to building neural networks. Unlike sequential models, which are limited to stacking layers in a single chain, subclassed models can express custom architectures with complex topologies, branching paths, and arbitrary Python control flow in the forward pass.

## Mathematical Foundation

Whereas a sequential model composes its layers in a fixed chain, $f_\theta = f_n \circ \cdots \circ f_1$, a subclassed model can implement an arbitrary function $f_\theta: X \rightarrow Y$ mapping inputs to outputs.

For a model with multiple branches and paths, we can represent the computation as a directed graph where each node is a layer or operation:

- For a model with branching paths, the output at a given node can be expressed as:
$$h_i = f_i\left(\{h_j : j \in \text{parents}(i)\}\right)$$

- For layers with multiple inputs (e.g., concatenation or addition), we might have:
$$h_i = g_i\left(h_{j_1}, h_{j_2}, \ldots, h_{j_k}\right)$$

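- For skip connections (as in ResNets), an earlier activation is added back before the nonlinearity, which requires $h_{i-2}$ and $W_i h_{i-1}$ to have the same shape: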
$$h_i = \sigma(W_i h_{i-1} + h_{i-2} + b_i)$$

This flexibility allows for implementing complex architectures like ResNets, Inception models, or custom research architectures.
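
As a minimal sketch of how the three patterns above translate into code, the hypothetical `TinyGraphNet` below wires two branches from a shared parent, merges them, and then applies a skip connection. The class name, the layer sizes, and the choice of concatenation as the merge function are illustrative assumptions, not taken from the text.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGraphNet(nn.Module):
    """Hypothetical example: branch, merge, and skip in one forward pass."""

    def __init__(self, dim=8):
        super().__init__()
        self.stem = nn.Linear(dim, dim)        # h_1 = f_1(x)
        self.branch_a = nn.Linear(dim, dim)    # h_2 = f_2(h_1)
        self.branch_b = nn.Linear(dim, dim)    # h_3 = f_3(h_1)
        self.merge = nn.Linear(2 * dim, dim)   # h_4 = g_4(h_2, h_3)
        self.out = nn.Linear(dim, dim)         # linear part of the skip step

    def forward(self, x):
        h1 = F.relu(self.stem(x))
        # Two nodes share the same parent h1 (branching).
        h2 = F.relu(self.branch_a(h1))
        h3 = F.relu(self.branch_b(h1))
        # One node with several parents (merge by concatenation).
        h4 = F.relu(self.merge(torch.cat([h2, h3], dim=1)))
        # Skip connection in the spirit of the ResNet-style formula:
        # add an earlier activation before the nonlinearity.
        h5 = F.relu(self.out(h4) + h1)
        return h5


net = TinyGraphNet(dim=8)
y = net(torch.randn(4, 8))
print(y.shape)  # torch.Size([4, 8])
```

Because `forward` is ordinary Python, the graph is defined by which tensors are passed to which layers, not by the order in which the layers were created.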

## Visual Representation

Subclassing allows for complex model architectures with branching paths and skip connections:

```mermaid
flowchart TB
    Input(("Input")) --> Conv1["Conv1"]
    Conv1 --> Conv2A["Conv2A"]
    Conv1 --> Conv2B["Conv2B"]
    Conv2A --> Conv3A["Conv3A"]
    Conv2B --> Conv3B["Conv3B"]
    Conv1 --> |"Skip Connection"| Add
    Conv3A --> Add{{"+"}}
    Conv3B --> Add
    Add --> Output(("Output"))
    style Input fill:#f5f5f5,stroke:#333,stroke-width:1px,color:black
    style Output fill:#f5f5f5,stroke:#333,stroke-width:1px,color:black
    style Add fill:#ffcc99,stroke:#333,stroke-width:1px,color:black
    style Conv1 fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
    style Conv2A fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
    style Conv2B fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
    style Conv3A fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
    style Conv3B fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
```

A more detailed diagram showing a residual block in a subclassed model:

```mermaid
flowchart LR
    Input(("x")) --> Conv1["Conv2D"] --> BN1["BatchNorm"] --> ReLU1["ReLU"]
    ReLU1 --> Conv2["Conv2D"] --> BN2["BatchNorm"]
    Input --> |"Identity"| Add{{"+"}}
    BN2 --> Add --> ReLU2["ReLU"] --> Output(("output"))
    style Input fill:#f5f5f5,stroke:#333,stroke-width:1px,color:black
    style Output fill:#f5f5f5,stroke:#333,stroke-width:1px,color:black
    style Conv1 fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
    style Conv2 fill:#bbdefb,stroke:#333,stroke-width:1px,color:black
    style BN1 fill:#dcedc8,stroke:#333,stroke-width:1px,color:black
    style BN2 fill:#dcedc8,stroke:#333,stroke-width:1px,color:black
    style ReLU1 fill:#ffe0b2,stroke:#333,stroke-width:1px,color:black
    style ReLU2 fill:#ffe0b2,stroke:#333,stroke-width:1px,color:black
    style Add fill:#ffcc99,stroke:#333,stroke-width:1px,color:black
```

In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

## Implementation with PyTorch

In PyTorch, every model is a subclass of `nn.Module`: layers are created in `__init__`, and the computation is defined in `forward`:

In [61]:
class BasicNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
        super(BasicNet, self).__init__()
        
        # Define layers as class attributes
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.fc3 = nn.Linear(hidden_size // 2, output_size)
        self.dropout = nn.Dropout(dropout_rate)
        
        # Keep configuration for later reference
        self.config = {
            "input_size": input_size,
            "hidden_size": hidden_size,
            "output_size": output_size,
            "dropout_rate": dropout_rate
        }
    
    def forward(self, x):
        # Define the forward pass with custom logic
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    # Custom methods can be added
    def get_config(self):
        return self.config

# Create the model
model = BasicNet(input_size=784, hidden_size=128, output_size=10)
print(model)

BasicNet(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)
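
To see such a model in action, the sketch below pushes a random batch through a compact stand-in with the same layer layout. It is redefined here so the snippet runs on its own; the name `TinyNet`, the batch size, and the random inputs are assumptions. It also shows that dropout makes train-mode outputs stochastic while eval-mode outputs are repeatable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Stand-in with the same fc1 -> dropout -> fc2 -> fc3 layout as BasicNet."""

    def __init__(self, input_size=784, hidden_size=128, output_size=10, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.fc3 = nn.Linear(hidden_size // 2, output_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


torch.manual_seed(0)
net = TinyNet()
batch = torch.randn(16, 784)  # 16 flattened 28x28 inputs (random, for shape checking only)

# Layers assigned as attributes are registered automatically, so their
# parameters show up in net.parameters().
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(num_params)  # 109386

net.eval()  # dropout disabled: repeated calls agree
with torch.no_grad():
    eval_same = torch.equal(net(batch), net(batch))

net.train()  # dropout active: repeated calls use different random masks
with torch.no_grad():
    train_same = torch.equal(net(batch), net(batch))

print(eval_same, train_same)
```

Calling `net.eval()` before inference is what switches dropout off; `torch.no_grad()` alone only disables gradient tracking.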


## Advanced Subclassing Patterns

### Multi-Input and Multi-Output Networks

In [62]:
class MIMONet(nn.Module):
    """Multi-Input Multi-Output Network"""
    def __init__(self):
        super(MIMONet, self).__init__()
        
        # Image processing branch
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.image_fc = nn.Linear(64 * 7 * 7, 128)
        
        # Metadata processing branch
        self.meta_fc1 = nn.Linear(10, 64)
        self.meta_fc2 = nn.Linear(64, 32)
        
        # Combined processing
        self.combined_fc = nn.Linear(128 + 32, 128)
        
        # Multiple output heads
        self.class_output = nn.Linear(128, 10)  # Classification head
        self.reg_output = nn.Linear(128, 1)     # Regression head
    
    def forward(self, image, metadata):
        # Process image
        img = F.relu(self.conv1(image))
        img = self.pool(img)
        img = F.relu(self.conv2(img))
        img = self.pool(img)
        img = img.view(img.size(0), -1)  # Flatten
        img = F.relu(self.image_fc(img))
        
        # Process metadata
        meta = F.relu(self.meta_fc1(metadata))
        meta = F.relu(self.meta_fc2(meta))
        
        # Combine features
        combined = torch.cat([img, meta], dim=1)
        combined = F.relu(self.combined_fc(combined))
        
        # Generate outputs
        class_output = self.class_output(combined)
        reg_output = self.reg_output(combined)
        
        return class_output, reg_output

# To demonstrate the model shape:
mimo_model = MIMONet()
print(mimo_model)

MIMONet(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (image_fc): Linear(in_features=3136, out_features=128, bias=True)
  (meta_fc1): Linear(in_features=10, out_features=64, bias=True)
  (meta_fc2): Linear(in_features=64, out_features=32, bias=True)
  (combined_fc): Linear(in_features=160, out_features=128, bias=True)
  (class_output): Linear(in_features=128, out_features=10, bias=True)
  (reg_output): Linear(in_features=128, out_features=1, bias=True)
)
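
A model with two inputs and two outputs is called with both tensors and is typically trained on a weighted sum of per-head losses. The sketch below illustrates this with a much smaller hypothetical model (`TinyMIMO`), random stand-in data, and an assumed regression weight of 0.5; none of these values come from the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMIMO(nn.Module):
    """Hypothetical two-input, two-output model (much smaller than MIMONet)."""

    def __init__(self):
        super().__init__()
        self.image_fc = nn.Linear(28 * 28, 16)    # image branch
        self.meta_fc = nn.Linear(10, 8)           # metadata branch
        self.class_head = nn.Linear(16 + 8, 10)   # classification logits
        self.reg_head = nn.Linear(16 + 8, 1)      # scalar regression

    def forward(self, image, metadata):
        img = F.relu(self.image_fc(image.flatten(1)))
        meta = F.relu(self.meta_fc(metadata))
        combined = torch.cat([img, meta], dim=1)
        return self.class_head(combined), self.reg_head(combined)


torch.manual_seed(0)
mimo = TinyMIMO()

# Random stand-in data: 4 images, 4 metadata vectors, and targets for both heads.
images = torch.randn(4, 1, 28, 28)
metadata = torch.randn(4, 10)
class_targets = torch.randint(0, 10, (4,))
reg_targets = torch.randn(4, 1)

class_logits, reg_pred = mimo(images, metadata)

# Weighted sum of the two losses; reg_weight is a tunable assumption.
reg_weight = 0.5
loss = F.cross_entropy(class_logits, class_targets) + reg_weight * F.mse_loss(reg_pred, reg_targets)
loss.backward()  # gradients flow into both branches and both heads

print(class_logits.shape, reg_pred.shape)  # torch.Size([4, 10]) torch.Size([4, 1])
```

A single `backward()` call on the combined loss is enough, because both heads share the same upstream graph.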


### Implementing a ResNet Block

ResNets use skip connections that are easily implemented with subclassing:

In [63]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        
        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection (skip connection)
        self.shortcut = nn.Sequential()
        # If dimensions change, we need to adjust the shortcut path
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        # Store input for skip connection
        shortcut = x
        
        # Main path
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # Add skip connection
        out += self.shortcut(shortcut)
        out = F.relu(out)
        
        return out

# Build a simple ResNet
class SimpleResNet(nn.Module):
    def __init__(self, num_blocks=2, num_classes=10):
        super(SimpleResNet, self).__init__()
        
        self.in_channels = 64
        
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        # Stacking residual blocks
        self.layer1 = self._make_layer(64, num_blocks, stride=1)
        self.layer2 = self._make_layer(128, num_blocks, stride=2)
        
        self.fc = nn.Linear(128 * 7 * 7, num_classes)
    
    def _make_layer(self, out_channels, num_blocks, stride):
        # Create a layer with multiple residual blocks
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResidualBlock(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.avg_pool2d(out, 2)  # 14x14 -> 7x7 for 28x28 inputs, matching self.fc (128 * 7 * 7)
        out = out.view(out.size(0), -1)  # Flatten
        out = self.fc(out)
        return out

resnet = SimpleResNet()
print("ResNet architecture:")
print(resnet)

ResNet architecture:
SimpleResNet(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64,
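
Because the final `nn.Linear` must match the flattened feature size exactly, it can help to trace spatial dimensions with a dummy tensor before fixing `in_features`. The sketch below is an illustration under assumptions (a 28x28 single-channel input and two stand-in convolutions with the same strides as the network above), not part of the original notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)  # one MNIST-sized dummy image

stem = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
down = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)

with torch.no_grad():
    h = stem(x)   # stride 1 with padding 1 keeps 28x28
    h = down(h)   # stride 2: floor((28 + 2 - 3) / 2) + 1 = 14
    pooled_2 = F.avg_pool2d(h, 2)  # 14 // 2 = 7
    pooled_4 = F.avg_pool2d(h, 4)  # floor((14 - 4) / 4) + 1 = 3, leftover rows/cols are dropped

print(h.shape)                    # torch.Size([1, 128, 14, 14])
print(pooled_2.flatten(1).shape)  # torch.Size([1, 6272]), i.e. 128 * 7 * 7
print(pooled_4.flatten(1).shape)  # torch.Size([1, 1152]), i.e. 128 * 3 * 3
```

Only a pooling window of 2 yields `128 * 7 * 7` features on a 28x28 input; a window of 4 would require `in_features=128 * 3 * 3` instead.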

## Training a Subclassed Model

Below is an example of training a subclassed model on the MNIST dataset:

In [64]:
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )
        
        # Add metrics tracking
        self.train_accuracy = []
        self.val_accuracy = []
        self.losses = []
    
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # Flatten the features
        x = self.classifier(x)
        return x
    
    # Custom method for feature extraction
    def extract_features(self, x):
        with torch.no_grad():
            x = self.features(x)
            return x.view(x.size(0), -1)  # Return flattened features

# Create the model and define training components
model = MNISTClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function with progress tracking.
# Note: it relies on the module-level `criterion` and `optimizer` defined above.
def train_model(model, train_loader, val_loader=None, epochs=3):
    train_losses = []
    train_accs = []
    val_accs = []
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for i, (inputs, labels) in enumerate(train_loader):
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            if i % 100 == 99:  # Print every 100 mini-batches
                print(f'Epoch {epoch+1}, Batch {i+1}: Loss = {running_loss/100:.4f}, '
                      f'Accuracy = {100*correct/total:.2f}%')
                running_loss = 0.0
        
        # Save epoch statistics
        train_acc = 100 * correct / total
        train_accs.append(train_acc)
        
        # Validation
        if val_loader:
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    outputs = model(inputs)
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            val_acc = 100 * correct / total
            val_accs.append(val_acc)
            print(f'Epoch {epoch+1}: Validation Accuracy = {val_acc:.2f}%')
    
    # Store metrics in the model for later use
    model.train_accuracy = train_accs
    if val_loader:
        model.val_accuracy = val_accs
    
    return train_accs, val_accs

# Set up data loaders (kept inside a string literal so the cell runs without downloading MNIST)
"""
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

# Train the model
train_accs, val_accs = train_model(model, train_loader, test_loader, epochs=3)
"""

print("Model ready for training with MNIST dataset.")

Model ready for training with MNIST dataset.
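
Since the MNIST run is disabled above, the following sketch exercises the same training pattern end to end on synthetic data. The random images and labels, the small `TinyClassifier`, and the two-epoch budget are all assumptions chosen so the snippet runs quickly without any download; with random labels the loss is not expected to reach a useful value.

```python
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for MNIST: 256 random "images" with random labels.
images = torch.randn(256, 1, 28, 28)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)


class TinyClassifier(nn.Module):
    """Hypothetical small CNN, used only to keep the example fast."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


net = TinyClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

epoch_losses = []
for epoch in range(2):
    net.train()
    total_loss, seen = 0.0, 0
    for inputs, targets in loader:
        optimizer.zero_grad()                  # clear gradients from the previous step
        loss = criterion(net(inputs), targets)
        loss.backward()                        # backpropagate
        optimizer.step()                       # update parameters
        total_loss += loss.item() * inputs.size(0)  # weight by batch size
        seen += inputs.size(0)
    epoch_losses.append(total_loss / seen)     # mean loss per sample for this epoch

print(epoch_losses)
```

Passing the optimizer and loss explicitly into a training function, rather than reading them from module scope, makes the same loop easier to reuse across models.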


## Inference and Evaluation

Subclassed models can implement custom inference methods:

In [65]:
class InferenceModel(nn.Module):
    def __init__(self, base_model):
        super(InferenceModel, self).__init__()
        self.base_model = base_model
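        # Freeze parameters for inference. Note: this modifies base_model in place,
        # so the wrapped model cannot be trained again until requires_grad is re-enabled.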
        for param in self.base_model.parameters():
            param.requires_grad = False
        
    def forward(self, x):
        return self.base_model(x)
    
    def predict(self, x, return_confidence=False):
        """Custom prediction method with confidence scores"""
        self.eval()  # Set to evaluation mode
        with torch.no_grad():
            logits = self(x)
            probabilities = F.softmax(logits, dim=1)
            confidences, predictions = torch.max(probabilities, dim=1)
            
            if return_confidence:
                return predictions, confidences
            return predictions
    
    def predict_batch(self, dataloader):
        """Predict over every batch yielded by a dataloader"""
        all_predictions = []
        all_labels = []
        
        self.eval()
        with torch.no_grad():
            for inputs, labels in dataloader:
                predictions = self.predict(inputs)
                all_predictions.append(predictions)
                all_labels.append(labels)
                
        return torch.cat(all_predictions), torch.cat(all_labels)
    
    def evaluate(self, dataloader):
        """Evaluate model performance"""
        predictions, labels = self.predict_batch(dataloader)
        correct = (predictions == labels).sum().item()
        total = labels.size(0)
        accuracy = 100 * correct / total
        
        return {
            'accuracy': accuracy,
            'correct': correct,
            'total': total
        }

# Create an inference model wrapper around our classifier
inference_model = InferenceModel(model)
print("Inference model ready for evaluation.")

# Example of synthetic data for demonstration
dummy_input = torch.randn(5, 1, 28, 28)  # 5 MNIST-like images
predictions, confidences = inference_model.predict(dummy_input, return_confidence=True)
print(f"\nPredictions: {predictions}")
print(f"Confidence scores: {confidences}")

Inference model ready for evaluation.

Predictions: tensor([3, 3, 3, 3, 3])
Confidence scores: tensor([0.1214, 0.1248, 0.1190, 0.1214, 0.1272])


## Advantages and Limitations of Subclassing

### Advantages
- **Maximum Flexibility**: Can implement any architecture or computational graph
- **Code Reusability**: Can create class hierarchies and reuse components
- **Dynamic Behavior**: Can implement conditional computation paths
- **Complex Topologies**: Supports skip connections, branches, and custom operations
- **Custom Methods**: Can add domain-specific utility functions to models
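
The "dynamic behavior" point can be illustrated with a forward pass whose depth is chosen at call time. This is a minimal sketch under assumptions: `AdaptiveDepthNet`, its `num_blocks` argument, and the layer sizes are hypothetical and do not appear elsewhere in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveDepthNet(nn.Module):
    """Hypothetical model whose depth is chosen per call."""

    def __init__(self, dim=16, max_blocks=4):
        super().__init__()
        # nn.ModuleList registers each block so its parameters are tracked.
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(max_blocks)])
        self.head = nn.Linear(dim, 2)

    def forward(self, x, num_blocks=None):
        # Plain Python control flow decides how many blocks run.
        if num_blocks is None:
            num_blocks = len(self.blocks)
        for block in self.blocks[:num_blocks]:
            x = x + F.relu(block(x))  # residual update keeps the shape fixed
        return self.head(x)


torch.manual_seed(0)
net = AdaptiveDepthNet()
x = torch.randn(3, 16)

shallow = net(x, num_blocks=1)  # only the first block runs
deep = net(x)                   # all four blocks run
print(shallow.shape, deep.shape)  # torch.Size([3, 2]) torch.Size([3, 2])

# Blocks skipped in a given call receive no gradient from it.
shallow.sum().backward()
print(net.blocks[0].weight.grad is not None)  # True
print(net.blocks[3].weight.grad is None)      # True
```

A fixed `nn.Sequential` cannot express this, because it always applies every layer in order.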

### Limitations
- **Complexity**: More complex implementation than sequential models
- **Debugging Challenges**: Harder to trace through custom forward logic
- **Serialization Issues**: Saving a `state_dict` stores only parameters and buffers, so loading requires the class definition, and custom attributes need separate handling
- **Potentially Slower**: May not benefit from certain optimizations available to static graphs
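
To make the serialization point concrete, the sketch below saves and restores a subclassed model through its `state_dict`. The `TrackedNet` class and its `history` attribute are hypothetical, and an in-memory buffer stands in for a file on disk.

```python
import io

import torch
import torch.nn as nn


class TrackedNet(nn.Module):
    """Hypothetical model with a plain-Python attribute next to its parameters."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.history = []  # ordinary attribute: not part of state_dict

    def forward(self, x):
        return self.fc(x)


src = TrackedNet()
src.history.append(0.9)

# Save only the parameters and buffers.
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# Loading needs the class definition to rebuild the module first.
dst = TrackedNet()
dst.load_state_dict(torch.load(buffer))

print(list(src.state_dict().keys()))              # ['fc.weight', 'fc.bias']
print(torch.equal(src.fc.weight, dst.fc.weight))  # True
print(dst.history)                                # []
```

Anything stored as a regular Python attribute, such as metric lists, has to be saved alongside the `state_dict` if it should survive a reload.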

## Conclusion

Subclassing provides the ultimate flexibility in designing neural network architectures. It allows for:

1. Implementation of custom architectures with complex topologies
2. Addition of domain-specific methods and behaviors to models
3. Creation of models with multiple inputs and outputs
4. Development of sophisticated architectures like ResNets, Inception models, or custom research designs

Although subclassing requires more code and a deeper understanding than sequential models, its flexibility makes it the usual choice for advanced deep learning architectures and research.

## Exercises

1. **Basic**: Implement a simple feed-forward neural network using subclassing and compare it to an equivalent sequential model.

2. **Intermediate**: Modify the ResidualBlock to create a "DenseBlock" that concatenates the input with the output instead of adding them.

3. **Advanced**: Implement a U-Net architecture for image segmentation using subclassing, with encoder and decoder paths and skip connections.

4. **Research**: Design and implement a custom attention mechanism within a subclassed model for sequence-to-sequence tasks.