# MLX ResNet Model Training

This notebook demonstrates training a ResNet (Residual Network) model with approximately 2000 neurons using Apple's MLX framework on a random dataset.

In [1]:
# Import required libraries
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List
import time
import math

In [2]:
# Set random seed for reproducibility
np.random.seed(42)
mx.random.seed(42)

In [3]:
# Generate random dataset
def generate_random_dataset(n_samples: int = 10000, n_features: int = 64, n_classes: int = 10) -> Tuple[mx.array, mx.array]:
    """
    Generate a random dataset for classification.
    Using 64 features to simulate image-like data (8x8 "images")
    
    Args:
        n_samples: Number of samples
        n_features: Number of input features (64 for 8x8 images)
        n_classes: Number of output classes
    
    Returns:
        Tuple of (features, labels)
    """
    # Generate random features normalized to [0, 1] range like images
    X = mx.random.uniform(0, 1, (n_samples, n_features))
    
    # Generate random labels
    y = mx.random.randint(0, n_classes, (n_samples,))
    
    return X, y

# Generate dataset
print("Generating random dataset...")
X, y = generate_random_dataset(n_samples=10000, n_features=64, n_classes=10)
print(f"Dataset shape: X={X.shape}, y={y.shape}")
print(f"Features range: [{X.min():.3f}, {X.max():.3f}]")

Generating random dataset...
Dataset shape: X=(10000, 64), y=(10000,)
Features range: [0.000, 1.000]


In [4]:
# Split dataset into train and test
def train_test_split(X: mx.array, y: mx.array, test_size: float = 0.2) -> Tuple[mx.array, mx.array, mx.array, mx.array]:
    """
    Split dataset into training and testing sets.
    """
    n_samples = X.shape[0]
    n_test = int(n_samples * test_size)
    n_train = n_samples - n_test
    
    # Random indices for splitting
    indices = mx.arange(n_samples)
    train_indices = indices[:n_train]
    test_indices = indices[n_train:]
    
    X_train = X[train_indices]
    X_test = X[test_indices]
    y_train = y[train_indices]
    y_test = y[test_indices]
    
    return X_train, X_test, y_train, y_test

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y)
print(f"Training set: X_train={X_train.shape}, y_train={y_train.shape}")
print(f"Test set: X_test={X_test.shape}, y_test={y_test.shape}")

Training set: X_train=(8000, 64), y_train=(8000,)
Test set: X_test=(2000, 64), y_test=(2000,)


In [12]:
# Define ResNet Building Blocks
class ResidualBlock(nn.Module):
    """
    Basic Residual Block for ResNet.
    """
    
    def __init__(self, in_features: int, out_features: int, stride: int = 1, downsample=None):
        super().__init__()
        
        self.linear1 = nn.Linear(in_features, out_features)
        self.bn1 = nn.BatchNorm(out_features)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_features, out_features)
        self.bn2 = nn.BatchNorm(out_features)
        self.downsample = downsample
        self.dropout = nn.Dropout(0.1)
    
    def __call__(self, x):
        identity = x
        
        # First linear layer
        out = self.linear1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        
        # Second linear layer
        out = self.linear2(out)
        out = self.bn2(out)
        
        # Skip connection
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out = out + identity  # Residual connection
        out = self.relu(out)
        
        return out


class ResNet(nn.Module):
    """
    ResNet architecture with approximately 2000 neurons.
    """
    
    def __init__(self, input_size: int, num_classes: int, num_neurons: int = 2000):
        super().__init__()
        
        # Calculate layer sizes to achieve ~2000 neurons
        # We'll use 4 residual blocks with decreasing sizes
        layer1_size = 512  # 512 neurons
        layer2_size = 512  # 512 neurons  
        layer3_size = 512  # 512 neurons
        layer4_size = 464  # 464 neurons (total ≈ 2000)
        
        # Initial projection layer
        self.initial_linear = nn.Linear(input_size, layer1_size)
        self.initial_bn = nn.BatchNorm(layer1_size)
        self.initial_relu = nn.ReLU()
        
        # Residual blocks
        self.layer1 = self._make_layer(layer1_size, layer1_size, 2)
        self.layer2 = self._make_layer(layer1_size, layer2_size, 2, 
                                      downsample=nn.Linear(layer1_size, layer2_size))
        self.layer3 = self._make_layer(layer2_size, layer3_size, 2)
        self.layer4 = self._make_layer(layer3_size, layer4_size, 2,
                                      downsample=nn.Linear(layer3_size, layer4_size))
        
        # Global average pooling and classifier
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(layer4_size, num_classes)
        
        # Store layer sizes for reporting
        self.layer_sizes = [layer1_size, layer2_size, layer3_size, layer4_size]
        self.total_neurons = sum(self.layer_sizes)
    
    def _make_layer(self, in_features: int, out_features: int, blocks: int, downsample=None):
        """
        Create a layer with multiple residual blocks.
        """
        layers = []
        
        # First block (may have downsample)
        layers.append(ResidualBlock(in_features, out_features, downsample=downsample))
        
        # Remaining blocks
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_features, out_features))
        
        return nn.Sequential(*layers)
    
    def __call__(self, x):
        # Initial projection
        x = self.initial_linear(x)
        x = self.initial_bn(x)
        x = self.initial_relu(x)
        
        # Residual layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # Global pooling and classification
        # Remove the mean operation since we don't need global pooling for 1D features
        # x has shape (batch_size, layer4_size) already
        x = self.dropout(x)
        x = self.classifier(x)
        
        return x

# Initialize the ResNet model
model = ResNet(input_size=64, num_classes=10)
print(f"ResNet model initialized with {model.total_neurons} neurons")
print(f"Layer sizes: {model.layer_sizes}")

# Count total parameters
def count_parameters(model):
    total = 0
    for name, param in model.parameters().items():
        if hasattr(param, 'size'):
            total += param.size
        elif hasattr(param, 'shape'):
            total += np.prod(param.shape)
    return total

total_params = count_parameters(model)
print(f"Total parameters: {total_params:,}")

ResNet model initialized with 2000 neurons
Layer sizes: [512, 512, 512, 464]
Total parameters: 0


In [13]:
# Define loss function and optimizer
def cross_entropy_loss(logits, targets):
    """
    Compute cross-entropy loss.
    """
    return mx.mean(nn.losses.cross_entropy(logits, targets))

def accuracy(logits, targets):
    """
    Compute accuracy.
    """
    predictions = mx.argmax(logits, axis=1)
    return mx.mean(predictions == targets)

# Initialize optimizer with learning rate scheduling
initial_lr = 0.001
optimizer = optim.Adam(learning_rate=initial_lr)
print(f"Optimizer initialized: Adam with learning rate {initial_lr}")

Optimizer initialized: Adam with learning rate 0.001


In [14]:
# Training function
def forward_and_loss(model, X, y):
    """
    Forward pass and loss computation.
    """
    logits = model(X)
    loss = cross_entropy_loss(logits, y)
    return loss, logits

# Compile the training step
loss_and_grad_fn = nn.value_and_grad(model, forward_and_loss)

def train_step(model, optimizer, X, y):
    """
    Single training step.
    """
    (loss, logits), grads = loss_and_grad_fn(model, X, y)
    optimizer.update(model, grads)
    return loss, logits

def adjust_learning_rate(optimizer, epoch, initial_lr):
    """
    Adjust learning rate during training (cosine annealing).
    """
    lr = initial_lr * 0.5 * (1 + math.cos(math.pi * epoch / 100))
    optimizer.learning_rate = lr
    return lr

In [15]:
# Training loop
def train_resnet(model, optimizer, X_train, y_train, X_test, y_test, epochs=100, batch_size=128):
    """
    Train the ResNet model.
    """
    train_losses = []
    train_accuracies = []
    test_accuracies = []
    learning_rates = []
    
    n_batches = len(X_train) // batch_size
    
    print(f"Starting ResNet training for {epochs} epochs...")
    print(f"Batch size: {batch_size}, Number of batches: {n_batches}")
    print(f"Model has {model.total_neurons} neurons and {total_params:,} parameters")
    print("=" * 80)
    
    for epoch in range(epochs):
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_acc = 0.0
        
        # Adjust learning rate
        current_lr = adjust_learning_rate(optimizer, epoch, initial_lr)
        learning_rates.append(current_lr)
        
        # Training
        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            
            X_batch = X_train[start_idx:end_idx]
            y_batch = y_train[start_idx:end_idx]
            
            loss, logits = train_step(model, optimizer, X_batch, y_batch)
            
            epoch_loss += loss.item()
            epoch_acc += accuracy(logits, y_batch).item()
        
        # Average metrics
        avg_loss = epoch_loss / n_batches
        avg_train_acc = epoch_acc / n_batches
        
        # Test accuracy
        test_logits = model(X_test)
        test_acc = accuracy(test_logits, y_test).item()
        
        # Store metrics
        train_losses.append(avg_loss)
        train_accuracies.append(avg_train_acc)
        test_accuracies.append(test_acc)
        
        epoch_time = time.time() - epoch_start
        
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1:3d}/{epochs} | "
                  f"Loss: {avg_loss:.4f} | "
                  f"Train Acc: {avg_train_acc:.4f} | "
                  f"Test Acc: {test_acc:.4f} | "
                  f"LR: {current_lr:.6f} | "
                  f"Time: {epoch_time:.2f}s")
    
    return train_losses, train_accuracies, test_accuracies, learning_rates

# Start training
print("Starting ResNet Training...")
print("=" * 80)
train_losses, train_accs, test_accs, lrs = train_resnet(
    model, optimizer, X_train, y_train, X_test, y_test, 
    epochs=100, batch_size=128
)
print("=" * 80)
print("ResNet training completed!")

Starting ResNet Training...
Starting ResNet training for 100 epochs...
Batch size: 128, Number of batches: 62
Model has 2000 neurons and 0 parameters
Epoch   1/100 | Loss: 2.5821 | Train Acc: 0.0906 | Test Acc: 0.1010 | LR: 0.001000 | Time: 1.63s
Epoch   1/100 | Loss: 2.5821 | Train Acc: 0.0906 | Test Acc: 0.1010 | LR: 0.001000 | Time: 1.63s
Epoch  11/100 | Loss: 2.3316 | Train Acc: 0.0955 | Test Acc: 0.0985 | LR: 0.000976 | Time: 0.68s
Epoch  11/100 | Loss: 2.3316 | Train Acc: 0.0955 | Test Acc: 0.0985 | LR: 0.000976 | Time: 0.68s
Epoch  21/100 | Loss: 2.3110 | Train Acc: 0.1023 | Test Acc: 0.1015 | LR: 0.000905 | Time: 0.66s
Epoch  21/100 | Loss: 2.3110 | Train Acc: 0.1023 | Test Acc: 0.1015 | LR: 0.000905 | Time: 0.66s
Epoch  31/100 | Loss: 2.3090 | Train Acc: 0.0958 | Test Acc: 0.1000 | LR: 0.000794 | Time: 0.71s
Epoch  31/100 | Loss: 2.3090 | Train Acc: 0.0958 | Test Acc: 0.1000 | LR: 0.000794 | Time: 0.71s
Epoch  41/100 | Loss: 2.3052 | Train Acc: 0.1056 | Test Acc: 0.0980 | LR: 

RuntimeError: [metal::malloc] Resource limit (499000) exceeded.

In [None]:
# Plot comprehensive training results
plt.figure(figsize=(20, 12))

# Plot loss
plt.subplot(2, 3, 1)
plt.plot(train_losses, 'b-', label='Training Loss', linewidth=2)
plt.title('Training Loss Over Time', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Cross-Entropy Loss')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot training accuracy
plt.subplot(2, 3, 2)
plt.plot(train_accs, 'g-', label='Training Accuracy', linewidth=2)
plt.title('Training Accuracy Over Time', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot both training and test accuracy
plt.subplot(2, 3, 3)
plt.plot(train_accs, 'g-', label='Training Accuracy', linewidth=2)
plt.plot(test_accs, 'r-', label='Test Accuracy', linewidth=2)
plt.title('Training vs Test Accuracy', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot learning rate schedule
plt.subplot(2, 3, 4)
plt.plot(lrs, 'm-', label='Learning Rate', linewidth=2)
plt.title('Learning Rate Schedule', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')

# Plot loss (log scale)
plt.subplot(2, 3, 5)
plt.semilogy(train_losses, 'b-', label='Training Loss (Log Scale)', linewidth=2)
plt.title('Training Loss (Log Scale)', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Cross-Entropy Loss (Log)')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot accuracy difference
plt.subplot(2, 3, 6)
acc_diff = [train - test for train, test in zip(train_accs, test_accs)]
plt.plot(acc_diff, 'orange', label='Train-Test Accuracy Gap', linewidth=2)
plt.title('Overfitting Monitor', fontsize=14, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Accuracy Difference')
plt.legend()
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

# Print final results
print(f"\nFinal Results:")
print(f"="*50)
print(f"Final Training Accuracy: {train_accs[-1]:.4f} ({train_accs[-1]*100:.2f}%)")
print(f"Final Test Accuracy: {test_accs[-1]:.4f} ({test_accs[-1]*100:.2f}%)")
print(f"Final Training Loss: {train_losses[-1]:.4f}")
print(f"Best Test Accuracy: {max(test_accs):.4f} ({max(test_accs)*100:.2f}%) at epoch {test_accs.index(max(test_accs))+1}")
print(f"Final Learning Rate: {lrs[-1]:.6f}")
print(f"="*50)

In [None]:
# ResNet Model Architecture Summary
print("ResNet Model Architecture Summary:")
print("=" * 60)
print(f"Input Layer: 64 features")
print(f"Initial Projection: 64 → {model.layer_sizes[0]} neurons")
print("\nResidual Blocks:")
for i, size in enumerate(model.layer_sizes, 1):
    print(f"  Layer {i}: {size} neurons (2 ResBlocks with skip connections)")
print(f"\nOutput Layer: {model.layer_sizes[-1]} → 10 classes")
print(f"\nTotal Hidden Neurons: {model.total_neurons}")
print(f"Total Parameters: {total_params:,}")
print(f"Architecture: ResNet with Batch Normalization and Dropout")
print("=" * 60)

# Analyze residual connections
print("\nResNet Features:")
print("• Skip connections for gradient flow")
print("• Batch normalization for training stability")
print("• Dropout for regularization")
print("• Cosine annealing learning rate schedule")
print("• Global average pooling")

In [None]:
# Test predictions and model analysis
print("\nModel Performance Analysis:")
print("=" * 40)

# Test predictions on a sample
sample_X = X_test[:10]
sample_y = y_test[:10]
sample_logits = model(sample_X)
sample_preds = mx.argmax(sample_logits, axis=1)
sample_probs = nn.softmax(sample_logits, axis=1)

print("Sample Predictions:")
for i in range(10):
    true_label = sample_y[i].item()
    pred_label = sample_preds[i].item()
    confidence = sample_probs[i, pred_label].item()
    status = "✓" if true_label == pred_label else "✗"
    print(f"Sample {i+1:2d}: True={true_label}, Pred={pred_label}, Conf={confidence:.3f} {status}")

# Calculate confusion matrix elements
all_preds = mx.argmax(model(X_test), axis=1)
correct_per_class = {}
total_per_class = {}

for i in range(10):
    mask = y_test == i
    if mx.sum(mask) > 0:
        class_preds = all_preds[mask]
        correct_per_class[i] = mx.sum(class_preds == i).item()
        total_per_class[i] = mx.sum(mask).item()

print("\nPer-Class Accuracy:")
for i in range(10):
    if i in total_per_class and total_per_class[i] > 0:
        acc = correct_per_class[i] / total_per_class[i]
        print(f"Class {i}: {acc:.3f} ({correct_per_class[i]}/{total_per_class[i]})")

In [None]:
# Compare ResNet vs Standard NN
print("\nResNet vs Standard Neural Network:")
print("=" * 50)
print("ResNet Advantages:")
print("• Skip connections prevent vanishing gradients")
print("• Easier to train deeper networks")
print("• Better gradient flow through residual paths")
print("• Batch normalization improves training stability")
print("• Can achieve better performance with same parameter count")
print("\nArchitectural Benefits:")
print(f"• {len(model.layer_sizes)} residual layers with skip connections")
print(f"• Total of {model.total_neurons} neurons across hidden layers")
print(f"• Batch normalization in every residual block")
print(f"• Dropout regularization (0.1 in blocks, 0.5 before classifier)")
print(f"• Cosine annealing learning rate schedule")

# Save model information
print("\n" + "="*60)
print("ResNet Training Completed Successfully!")
print("="*60)
print("\nTo save this ResNet model:")
print("```python")
print("# Save model parameters")
print("mx.save_safetensors('resnet_model.safetensors', model.parameters())")
print("```")
print("\nTo load the model later:")
print("```python")
print("# Load model")
print("model = ResNet(input_size=64, num_classes=10)")
print("model.load_weights(mx.load('resnet_model.safetensors'))")
print("```")