# 🔬 Lecture 10: MCUNet & TinyML - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/10_mcunet_tinyml/demo.ipynb)

## What You'll Learn
- TinyML constraints and challenges
- Memory-efficient inference
- MCUNet architecture design
- Peak memory optimization

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG the notebook draws from: torch initialises the network
# weights, and NumPy's global generator drives the random architecture
# sampling in the TinyNAS cell. Seeding only torch left that search
# different on every run.
torch.manual_seed(42)
np.random.seed(42)
print('Ready for TinyML!')

## Part 1: TinyML Constraints

Microcontrollers have extreme constraints:
- **Flash**: 256KB - 2MB (model weights)
- **SRAM**: 64KB - 512KB (activations)
- **Compute**: No GPU, limited CPU

In [None]:
# MCU specifications: Flash and SRAM sizes in KB, clock rate in MHz, and a
# typical application area for each part.
mcus = {
    'STM32F4': {'flash_kb': 512, 'sram_kb': 128, 'mhz': 180, 'use': 'Wearables'},
    'STM32H7': {'flash_kb': 2048, 'sram_kb': 1024, 'mhz': 480, 'use': 'Industrial'},
    'nRF52840': {'flash_kb': 1024, 'sram_kb': 256, 'mhz': 64, 'use': 'IoT'},
    'ESP32': {'flash_kb': 4096, 'sram_kb': 520, 'mhz': 240, 'use': 'Smart Home'},
    'Cortex-M0': {'flash_kb': 256, 'sram_kb': 32, 'mhz': 48, 'use': 'Sensors'},
}

# Typical ML models to compare against: weight size and peak activation
# memory, both in MB.
models = {
    'MobileNetV2': {'params_mb': 14, 'peak_act_mb': 40},
    'ResNet-18': {'params_mb': 46, 'peak_act_mb': 100},
    'BERT-tiny': {'params_mb': 17, 'peak_act_mb': 50},
    'MCUNet': {'params_mb': 0.74, 'peak_act_mb': 0.39},
}

print('📊 MCU MEMORY CONSTRAINTS')
print('=' * 70)
# The MHz header is 6 wide so that "Use Case" lines up with the data rows
# (each row spends 5 characters on the number plus 2 spaces of padding).
print(f'{"MCU":<15} {"Flash":<12} {"SRAM":<12} {"MHz":<6} {"Use Case":<20}')
print('-' * 70)
for name, spec in mcus.items():
    print(f'{name:<15} {spec["flash_kb"]:>8} KB  {spec["sram_kb"]:>8} KB  {spec["mhz"]:>5}  {spec["use"]:<20}')

print('\n📊 TYPICAL MODEL REQUIREMENTS')
print('=' * 50)
print(f'{"Model":<15} {"Params (MB)":<15} {"Peak Memory (MB)":<20}')
print('-' * 50)
for name, req in models.items():
    print(f'{name:<15} {req["params_mb"]:>10.2f}    {req["peak_act_mb"]:>15.2f}')

print('\n⚠️ Most models are 100x too big for MCUs!')

In [None]:
# Visualize the gap between model footprints and the MCU budgets.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

model_names = list(models)

# Model sizes vs MCU Flash
model_params = [models[m]['params_mb'] * 1024 for m in model_names]  # Convert to KB
mcu_flash = [spec['flash_kb'] for spec in mcus.values()]
avg_flash = np.mean(mcu_flash)  # computed once; used for both the line and its label

axes[0].bar(model_names, model_params, color='#ef4444', alpha=0.8, label='Model Size')
axes[0].axhline(y=avg_flash, color='#22c55e', linestyle='--',
                linewidth=2, label=f'Avg MCU Flash ({avg_flash:.0f} KB)')
axes[0].set_ylabel('Size (KB)')
axes[0].set_title('Model Size vs MCU Flash')
axes[0].legend()
axes[0].set_yscale('log')  # the values span several orders of magnitude

# Peak memory vs MCU SRAM
model_peak = [models[m]['peak_act_mb'] * 1024 for m in model_names]  # Convert to KB
mcu_sram = [spec['sram_kb'] for spec in mcus.values()]
avg_sram = np.mean(mcu_sram)

axes[1].bar(model_names, model_peak, color='#ef4444', alpha=0.8, label='Peak Memory')
axes[1].axhline(y=avg_sram, color='#22c55e', linestyle='--',
                linewidth=2, label=f'Avg MCU SRAM ({avg_sram:.0f} KB)')
axes[1].set_ylabel('Memory (KB)')
axes[1].set_title('Peak Activation Memory vs MCU SRAM')
axes[1].legend()
axes[1].set_yscale('log')

plt.tight_layout()
plt.show()

print('\n💡 MCUNet bridges this gap!')

## Part 2: Understanding Peak Memory

In [None]:
def _conv_output_length(size, kernel, stride, padding, dilation):
    """Output length along one spatial axis, per the nn.Conv2d shape formula."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def calculate_layer_memory(layer, input_size, dtype_bytes=1):
    """
    Calculate memory for a single layer.

    Args:
        layer: Module to analyse. Only ``nn.Conv2d`` is modelled.
        input_size: ``(batch, channels, height, width)`` of the layer input.
        dtype_bytes: Bytes per element (the default of 1 corresponds to INT8).

    Returns:
        ``(input_mem, output_mem, weight_mem, output_size)``: three byte
        counts followed by the ``(batch, channels, height, width)`` shape the
        layer produces. Any layer that is not a ``nn.Conv2d`` yields
        ``(0, 0, 0, input_size)`` so the caller can keep threading the shape.
    """
    if not isinstance(layer, nn.Conv2d):
        return 0, 0, 0, input_size

    batch, in_ch, h, w = input_size
    out_ch = layer.out_channels

    if layer.padding == 'same':
        # 'same' padding keeps the spatial size (PyTorch only allows it
        # together with stride 1).
        out_h, out_w = h, w
    else:
        # 'valid' means no padding; otherwise padding is a (pad_h, pad_w) pair.
        pad_h, pad_w = (0, 0) if layer.padding == 'valid' else layer.padding
        # Use the real convolution arithmetic instead of assuming the padding
        # preserves the size: unpadded convs shrink the map, and odd inputs
        # with stride > 1 do not reduce to a plain h // stride.
        out_h = _conv_output_length(h, layer.kernel_size[0], layer.stride[0],
                                    pad_h, layer.dilation[0])
        out_w = _conv_output_length(w, layer.kernel_size[1], layer.stride[1],
                                    pad_w, layer.dilation[1])

    input_mem = batch * in_ch * h * w * dtype_bytes
    output_mem = batch * out_ch * out_h * out_w * dtype_bytes
    weight_mem = layer.weight.numel() * dtype_bytes  # weights only; bias is not counted

    return input_mem, output_mem, weight_mem, (batch, out_ch, out_h, out_w)

class TinyNet(nn.Module):
    """Small plain CNN: four 3x3 convolutions, global average pooling, and a
    10-way linear classifier."""

    def __init__(self):
        super().__init__()
        # The first conv keeps the resolution; each of the next three halves it.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        """Run conv+ReLU four times, pool to 1x1, and classify."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))
        pooled = self.gap(features)
        return self.fc(pooled.flatten(1))

# Analyze memory usage of TinyNet layer by layer.
model = TinyNet()
input_size = (1, 3, 128, 128)

print('📊 LAYER-BY-LAYER MEMORY ANALYSIS')
print('=' * 70)
print(f'{"Layer":<15} {"Input (KB)":<15} {"Output (KB)":<15} {"Weights (KB)":<15}')
print('-' * 70)

current_size = input_size
peak_memory = 0
memory_timeline = []

# named_modules() yields modules in registration order; TinyNet registers its
# convolutions in the order forward() applies them, so the shape can simply
# be threaded from one conv to the next.
for name, layer in model.named_modules():
    if not isinstance(layer, nn.Conv2d):
        continue

    in_mem, out_mem, w_mem, new_size = calculate_layer_memory(layer, current_size)

    # Peak memory = input + output (need both during computation)
    layer_peak = (in_mem + out_mem) / 1024  # KB
    peak_memory = max(peak_memory, layer_peak)
    memory_timeline.append({'name': name, 'peak': layer_peak, 'in': in_mem/1024, 'out': out_mem/1024})

    print(f'{name:<15} {in_mem/1024:>12.2f}   {out_mem/1024:>12.2f}   {w_mem/1024:>12.2f}')
    current_size = new_size

print(f'\n🔺 Peak activation memory: {peak_memory:.2f} KB')

In [None]:
# Visualize memory timeline: per-layer input/output buffers and their sum.
fig, ax = plt.subplots(figsize=(12, 6))

layers = [m['name'] for m in memory_timeline]
peaks = [m['peak'] for m in memory_timeline]
inputs = [m['in'] for m in memory_timeline]
outputs = [m['out'] for m in memory_timeline]

x = np.arange(len(layers))
width = 0.35
sram_limit_kb = 128  # SRAM of the STM32F4 used as the reference budget

ax.bar(x - width/2, inputs, width, label='Input', color='#3b82f6')
ax.bar(x + width/2, outputs, width, label='Output', color='#22c55e')
ax.plot(x, peaks, 'ro-', linewidth=2, markersize=10, label='Peak (In+Out)')

ax.axhline(y=sram_limit_kb, color='red', linestyle='--',
           label=f'STM32F4 SRAM ({sram_limit_kb}KB)')

ax.set_xlabel('Layer')
ax.set_ylabel('Memory (KB)')
ax.set_title('📊 Memory Usage Through Network')
ax.set_xticks(x)
ax.set_xticklabels(layers)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print('\n💡 Peak memory occurs at early layers (large feature maps)!')
print('   MCUNet optimizes network width at each stage.')

## Part 3: MCUNet Design Principles

In [None]:
class MCUNetBlock(nn.Module):
    """
    Inverted residual block: optional 1x1 expansion, 3x3 depthwise
    convolution, then a linear 1x1 projection.

    The expansion ratio is a per-block argument so that blocks at different
    resolutions can use different hidden widths. The skip connection is only
    used when the block preserves both resolution and channel count.
    """

    def __init__(self, in_ch, out_ch, stride=1, expand_ratio=3):
        super().__init__()
        hidden_ch = in_ch * expand_ratio

        # A residual add needs matching shapes on both branches.
        self.use_residual = (stride == 1 and in_ch == out_ch)

        # Expand: skipped entirely when the ratio is 1 (hidden_ch == in_ch).
        expand = []
        if expand_ratio != 1:
            expand = [
                nn.Conv2d(in_ch, hidden_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(hidden_ch),
                nn.ReLU6(),
            ]

        # Depthwise: one 3x3 filter per channel; carries the stride.
        depthwise = [
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, stride=stride,
                      padding=1, groups=hidden_ch, bias=False),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU6(),
        ]

        # Project: 1x1 conv down to out_ch with no activation afterwards.
        project = [
            nn.Conv2d(hidden_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]

        self.conv = nn.Sequential(*expand, *depthwise, *project)

    def forward(self, x):
        """Apply the block, adding the input back when the skip is enabled."""
        out = self.conv(x)
        return x + out if self.use_residual else out

class MCUNet(nn.Module):
    """
    MCUNet-style classifier built from MCUNetBlock stages.

    Design choices visible in the layer table below:
    1. The stem and every stage downsample by 2, so resolution drops early.
    2. Expansion ratios are small (1-2) in the first, high-resolution stage.
    3. Expansion ratios grow (3, then 4) as the resolution shrinks.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem: strided 3x3 conv halves the resolution right away.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU6(),
        )

        # Each spec is (in_ch, out_ch, stride, expand_ratio).
        # Stage 1: high resolution, low expansion.
        self.stage1 = self._make_stage([(16, 16, 1, 1), (16, 24, 2, 2)])
        # Stage 2: medium resolution, medium expansion.
        self.stage2 = self._make_stage([(24, 24, 1, 3), (24, 40, 2, 3)])
        # Stage 3: low resolution, high expansion.
        self.stage3 = self._make_stage([(40, 40, 1, 4), (40, 80, 2, 4)])

        # Head: widen to 160 channels, then pool to a single vector.
        self.head = nn.Sequential(
            nn.Conv2d(80, 160, kernel_size=1, bias=False),
            nn.BatchNorm2d(160),
            nn.ReLU6(),
            nn.AdaptiveAvgPool2d(1),
        )

        self.fc = nn.Linear(160, num_classes)

    @staticmethod
    def _make_stage(block_specs):
        """Stack MCUNetBlocks given (in_ch, out_ch, stride, expand_ratio) tuples."""
        return nn.Sequential(*[
            MCUNetBlock(in_ch, out_ch, stride=stride, expand_ratio=ratio)
            for in_ch, out_ch, stride, ratio in block_specs
        ])

    def forward(self, x):
        """Stem -> three stages -> head -> flatten -> linear classifier."""
        for stage in (self.stem, self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.head(x)
        return self.fc(pooled.flatten(1))

# Analyze MCUNet
mcunet = MCUNet()

# Count parameters (every learnable tensor, including BatchNorm scale/shift).
params = sum(p.numel() for p in mcunet.parameters())
int8_bytes_per_param = 1
params_kb = params * int8_bytes_per_param / 1024  # INT8: one byte per parameter

print('📊 MCUNET ARCHITECTURE')
print('=' * 50)
print(f'Parameters: {params:,}')
print(f'Model size (INT8): {params_kb:.2f} KB')
print(f'\nFits in {params_kb:.0f} KB Flash! ✅')

In [None]:
# Compare memory usage: Regular vs MCUNet design
def estimate_peak_memory(model, input_size, dtype_bytes=1):
    """
    Estimate peak activation memory by running one forward pass.

    A forward hook on every ``nn.Conv2d`` / ``nn.Linear`` records the size of
    the layer's input and output tensors. A layer needs both buffers alive
    while it computes, so its footprint is input + output, and the network
    peak is the largest such footprint. This is an approximation: buffers
    held elsewhere (e.g. a residual skip tensor) are not counted.

    Args:
        model: Module to profile. Its train/eval flags and any BatchNorm
            running statistics are left exactly as they were.
        input_size: Shape of the random input fed to the model.
        dtype_bytes: Bytes per element (the default of 1 corresponds to INT8).

    Returns:
        ``(peak, activations)``: ``peak`` is the largest per-layer
        input + output footprint in KB (0 if no layer was hooked), and
        ``activations`` lists each hooked layer's output size in KB in
        execution order.
    """
    x = torch.randn(input_size)

    activations = []   # output size of each hooked layer, KB
    layer_peaks = []   # input + output size of each hooked layer, KB

    def hook(module, inputs, output):
        in_elems = sum(t.numel() for t in inputs if torch.is_tensor(t))
        out_kb = output.numel() * dtype_bytes / 1024
        activations.append(out_kb)
        layer_peaks.append(in_elems * dtype_bytes / 1024 + out_kb)

    hooks = [
        module.register_forward_hook(hook)
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    # Profile in eval mode so BatchNorm does not update its running
    # statistics; remember each module's flag so it can be put back.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the hooks and restore the modes, even if the forward
        # pass raised (e.g. on a shape mismatch).
        for h in hooks:
            h.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    peak = max(layer_peaks, default=0)
    return peak, activations

# Compare the naive TinyNet against the MCUNet-style design.
print('📊 PEAK MEMORY COMPARISON')
print('=' * 50)

input_size = (1, 3, 128, 128)

peak_tiny, acts_tiny = estimate_peak_memory(TinyNet(), input_size)
peak_mcu, acts_mcu = estimate_peak_memory(mcunet, input_size)

print(f'TinyNet (naive):     {peak_tiny:.2f} KB')
print(f'MCUNet (optimized):  {peak_mcu:.2f} KB')
print(f'Reduction:           {peak_tiny/peak_mcu:.1f}x')

# Visualize per-layer output sizes. The x axis is the index of each hooked
# layer, so the two curves have different lengths.
fig, ax = plt.subplots(figsize=(12, 5))

ax.plot(acts_tiny, 'o-', label='TinyNet', color='#ef4444', linewidth=2)
ax.plot(acts_mcu, 's-', label='MCUNet', color='#22c55e', linewidth=2)
ax.axhline(y=128, color='gray', linestyle='--', label='128KB SRAM limit')

ax.set_xlabel('Layer')
ax.set_ylabel('Activation Memory (KB)')
ax.set_title('📊 Activation Memory Through Network')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 4: TinyNAS - Joint Network-Memory Optimization

In [None]:
def tinynas_search(target_flash_kb, target_sram_kb, num_trials=20, seed=None):
    """
    Simplified TinyNAS: random search for an architecture that fits an MCU.

    Each trial samples:
    - Network width (base channel count)
    - Expansion ratios (one per stage)
    - Number of blocks (one per stage)

    Cost and accuracy come from a toy model, not from real networks:
    - Flash (KB) grows with width and with expansion x depth per stage,
      since every extra block adds weights.
    - Peak SRAM (KB) grows with width and expansion only: it is a maximum
      over layers, so stacking more blocks does not raise it.
    - Accuracy is a noisy increasing function of model size.

    The earlier flash estimate was at least 950 for every sample, so no
    candidate could meet a realistic budget such as 256 KB and the search
    always returned ``(None, [])``. It also ignored the sampled block counts.

    Args:
        target_flash_kb: Maximum allowed model size in KB.
        target_sram_kb: Maximum allowed peak activation memory in KB.
        num_trials: Number of random architectures to sample.
        seed: Optional integer seed for a private generator. When None, the
            global ``np.random`` state is used, so ``np.random.seed(...)``
            in the caller still controls the run.

    Returns:
        ``(best_config, results)``. ``results`` lists every feasible sample as
        a dict with keys ``'config'``, ``'params'``, ``'peak_mem'`` and
        ``'acc'``; ``best_config`` is the entry with the highest ``'acc'``, or
        None when no sample met both budgets.
    """
    # RandomState exposes the same choice()/randn() API as the np.random module.
    rng = np.random if seed is None else np.random.RandomState(seed)

    best_config = None
    results = []

    for _ in range(num_trials):
        # Sample architecture config; int() keeps numpy scalars out of the
        # result so it prints and compares like plain Python data.
        base_width = int(rng.choice([8, 12, 16, 20, 24]))
        expand_ratios = [int(rng.choice([1, 2, 3, 4])) for _ in range(3)]
        num_blocks = [int(rng.choice([1, 2, 3])) for _ in range(3)]

        # Toy cost model, both values in KB (see docstring).
        weighted_depth = sum(e * b for e, b in zip(expand_ratios, num_blocks))
        params = base_width * 10 + weighted_depth * 5
        peak_mem = base_width * 10 + sum(expand_ratios) * 5

        # Check constraints
        if params > target_flash_kb or peak_mem > target_sram_kb:
            continue

        # Simulate accuracy (in real NAS, train and evaluate)
        acc = 50 + params / 10 + rng.randn() * 2

        candidate = {
            'config': {'width': base_width, 'expand': expand_ratios, 'blocks': num_blocks},
            'params': params,
            'peak_mem': peak_mem,
            'acc': acc,
        }
        results.append(candidate)

        if best_config is None or acc > best_config['acc']:
            best_config = candidate

    return best_config, results

# Run the search against one concrete MCU budget and plot what it found.
flash_limit_kb = 256
sram_limit_kb = 128

print('🔍 TINYNAS SEARCH')
print('=' * 50)
print(f'Target: Flash ≤ {flash_limit_kb}KB, SRAM ≤ {sram_limit_kb}KB')

best, all_results = tinynas_search(target_flash_kb=flash_limit_kb,
                                   target_sram_kb=sram_limit_kb,
                                   num_trials=50)

# The search returns None when no sampled architecture met both budgets;
# indexing into it would raise a TypeError.
if best is None:
    print('\nNo sampled architecture met both budgets - '
          'try more trials or a looser target.')
else:
    print('\n🏆 Best Architecture Found:')
    print(f'   Config: {best["config"]}')
    print(f'   Model size: {best["params"]:.0f} KB')
    print(f'   Peak memory: {best["peak_mem"]:.0f} KB')
    print(f'   Accuracy: {best["acc"]:.1f}%')

    # Visualize search results
    fig, ax = plt.subplots(figsize=(10, 6))

    # Named sizes_kb rather than params so the parameter count computed
    # earlier in the notebook is not clobbered.
    sizes_kb = [r['params'] for r in all_results]
    peak_mems = [r['peak_mem'] for r in all_results]
    accs = [r['acc'] for r in all_results]

    scatter = ax.scatter(sizes_kb, peak_mems, c=accs, cmap='RdYlGn', s=100, alpha=0.7)
    ax.scatter(best['params'], best['peak_mem'], c='red', s=300, marker='*', label='Best')

    # Constraint box
    ax.axvline(x=flash_limit_kb, color='red', linestyle='--', alpha=0.5)
    ax.axhline(y=sram_limit_kb, color='red', linestyle='--', alpha=0.5)
    ax.fill_between([0, flash_limit_kb], [0, 0], [sram_limit_kb, sram_limit_kb],
                    alpha=0.1, color='green', label='Feasible')

    ax.set_xlabel('Model Size (KB)')
    ax.set_ylabel('Peak Memory (KB)')
    ax.set_title('📊 TinyNAS Search Space')
    plt.colorbar(scatter, label='Accuracy (%)')
    ax.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Summary printed at the end of the notebook.
takeaways = [
    'MCUs have KB-level memory (1000x less than GPUs)',
    'Peak memory = max(input + output) across layers',
    'Early layers (high resolution) need low expansion',
    'MCUNet optimizes width/expansion per resolution',
    'TinyNAS jointly optimizes network and memory',
    'Result: ImageNet models in 1MB Flash + 256KB SRAM!',
]

print('🎯 KEY TAKEAWAYS')
print('=' * 60)
# Each takeaway is preceded by a blank line.
for number, text in enumerate(takeaways, start=1):
    print(f'\n{number}. {text}')
print('\n' + '=' * 60)
print('\n📚 Next: Efficient Transformers!')