# 🔍 Lecture 7: Neural Architecture Search (Part 1) - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/07_neural_architecture_search_1/demo.ipynb)

## What You'll Learn
- NAS fundamentals: Search space, strategy, and evaluation
- DARTS: Differentiable architecture search
- Supernet training with architecture weights
- Deriving final architecture from trained supernet

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Seed the global torch RNG once so weight initialisation and the synthetic
# data below are reproducible from run to run.
SEED = 42
torch.manual_seed(SEED)
print('Ready for Neural Architecture Search!')

## Part 1: The NAS Problem

**Goal**: Automatically find the best neural network architecture.

**Three components**:
1. **Search Space**: What architectures can we explore?
2. **Search Strategy**: How do we explore efficiently?
3. **Evaluation Strategy**: How do we measure architecture quality?

In [None]:
# Visualize the search space: candidate operations (name -> description)
# available on every edge of a typical DARTS-style cell.
operations = {
    'conv3x3': 'Standard 3×3 convolution',
    'conv5x5': 'Standard 5×5 convolution',
    'sep_conv3x3': 'Depthwise separable 3×3',
    'dil_conv3x3': 'Dilated 3×3 convolution',
    'max_pool': 'Max pooling 3×3',
    'avg_pool': 'Average pooling 3×3',
    'skip': 'Skip connection (identity)',
    'zero': 'No connection',
}

print('🔍 TYPICAL NAS SEARCH SPACE')
print('=' * 60)
print('\nOperations available at each edge:')
for op, desc in operations.items():
    print(f'  • {op:15} - {desc}')

# Every edge picks one operation independently, so the number of distinct
# cells is num_ops ** num_edges.
num_ops = len(operations)
# DARTS cell: 4 intermediate nodes, each fed by all earlier nodes
# (2 + 3 + 4 + 5 = 14 candidate edges).
num_edges = 14
search_space_size = num_ops ** num_edges

print('\n📊 Search Space Size:')
print(f'   Operations: {num_ops}')
print(f'   Edges per cell: {num_edges}')
print(f'   Total architectures: {num_ops}^{num_edges} = {search_space_size:,}')
print('\n⚠️ Exhaustive search is impossible!')

## Part 2: DARTS - Differentiable Architecture Search

**Key Idea**: Make architecture choice differentiable!

Instead of making a discrete choice of operation on each edge, DARTS uses a softmax-weighted sum of all candidate operations:

$$\bar{o}(x) = \sum_i \frac{\exp(\alpha_i)}{\sum_j \exp(\alpha_j)} \cdot o_i(x)$$

Where $\alpha_i$ are learnable architecture parameters.

In [None]:
class MixedOperation(nn.Module):
    """
    Mixed operation: softmax-weighted sum of all candidate operations (DARTS).

    The architecture logits ``alpha`` (one per candidate) are an
    ``nn.Parameter`` of this module, so they are trained by gradient descent
    together with (or separately from) the candidates' own weights.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map. Must equal
            ``in_channels``: the parameter-free candidates (max pooling and
            identity skip) cannot change the channel count, and all candidate
            outputs are summed element-wise.

    Raises:
        ValueError: if ``in_channels != out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Fail at construction time instead of with an opaque shape error in
        # forward() when max_pool/skip outputs are added to conv outputs.
        if in_channels != out_channels:
            raise ValueError(
                'MixedOperation requires in_channels == out_channels '
                f'(got {in_channels} and {out_channels}); max_pool and skip '
                'cannot change the channel count'
            )

        # Candidate operations; the order must match ``op_names`` below.
        self.ops = nn.ModuleList([
            nn.Sequential(  # conv3x3
                nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ),
            nn.Sequential(  # conv5x5
                nn.Conv2d(in_channels, out_channels, 5, padding=2, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ),
            nn.Sequential(  # sep_conv3x3: depthwise 3x3 then pointwise 1x1
                nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=False),
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ),
            nn.MaxPool2d(3, stride=1, padding=1),  # max_pool (shape-preserving)
            nn.Identity(),  # skip connection
        ])

        self.op_names = ['conv3x3', 'conv5x5', 'sep_conv3x3', 'max_pool', 'skip']

        # Architecture logits; zeros give a uniform softmax at the start.
        self.alpha = nn.Parameter(torch.zeros(len(self.ops)))

    def forward(self, x):
        """Return sum_i softmax(alpha)_i * op_i(x)."""
        weights = F.softmax(self.alpha, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.ops))

    def get_selected_op(self):
        """Return ``(name, probability)`` of the highest-weighted candidate.

        Ties are resolved in favour of the first candidate (``argmax``).
        """
        with torch.no_grad():
            idx = int(self.alpha.argmax())
            prob = F.softmax(self.alpha, dim=0)[idx].item()
        return self.op_names[idx], prob

# Demo: one mixed operation applied to a random 16-channel feature map.
mixed_op = MixedOperation(16, 16)
x = torch.randn(1, 16, 32, 32)
out = mixed_op(x)

print('📊 MIXED OPERATION')
print('=' * 50)
print(f'Input shape: {x.shape}')
print(f'Output shape: {out.shape}')
print('\nArchitecture weights (before training):')
weights = F.softmax(mixed_op.alpha, dim=0)
for name, w in zip(mixed_op.op_names, weights):
    print(f'  {name:15}: {w.item():.3f}')
print('\n💡 Initially uniform - will change during training!')

## Part 3: Building a DARTS Supernet

In [None]:
class DARTSCell(nn.Module):
    """
    Simplified searchable cell: a chain of mixed-operation edges.

    Real DARTS cells are DAGs whose nodes sum several incoming edges; for this
    demo the edges are applied one after another, each wrapped in a residual
    connection (``x = x + edge(x)``), which requires every edge to preserve
    the shape of ``x``.

    Args:
        channels: number of feature channels entering and leaving the cell.
        num_edges: number of MixedOperation edges in the chain (default 3,
            the value this demo was originally written with).
    """
    def __init__(self, channels, num_edges=3):
        super().__init__()

        # One MixedOperation (each with its own alpha) per edge.
        self.edges = nn.ModuleList(
            [MixedOperation(channels, channels) for _ in range(num_edges)]
        )

    def forward(self, x):
        # Simple sequential chain for the demo (real DARTS has a DAG structure).
        for edge in self.edges:
            x = x + edge(x)  # residual connection
        return x

class DARTSSupernet(nn.Module):
    """
    Supernet containing every candidate architecture of the search space.

    Layout: conv stem (3 -> ``channels``) -> ``num_cells`` DARTSCells ->
    global average pooling -> linear classifier.

    Args:
        num_classes: number of output logits (default 10).
        channels: feature width of the stem, cells and classifier input
            (default 16).
        num_cells: number of searchable cells (default 2).
    """
    def __init__(self, num_classes=10, channels=16, num_cells=2):
        super().__init__()

        # Stem
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

        # Searchable cells
        self.cells = nn.ModuleList([DARTSCell(channels) for _ in range(num_cells)])

        # Classifier
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        x = self.stem(x)
        for cell in self.cells:
            x = cell(x)
        x = self.gap(x).flatten(1)
        return self.fc(x)

    def get_architecture_params(self):
        """Return the ``alpha`` parameter of every edge, in cell/edge order."""
        return [edge.alpha for cell in self.cells for edge in cell.edges]

    def get_weight_params(self):
        """Return every parameter that is not an architecture ``alpha``.

        Alphas are excluded by identity rather than by name. This keeps the
        two groups exactly complementary: no weight is dropped because its
        name happens to contain 'alpha', and no alpha is kept because it was
        registered under a different name.
        """
        arch_ids = {id(p) for p in self.get_architecture_params()}
        return [p for p in self.parameters() if id(p) not in arch_ids]

    def print_architecture(self):
        """Print the currently top-weighted operation of every edge."""
        print('\n🏗️ Current Architecture:')
        for i, cell in enumerate(self.cells):
            print(f'\nCell {i}:')
            for j, edge in enumerate(cell.edges):
                op, prob = edge.get_selected_op()
                print(f'  Edge {j}: {op} ({prob:.2%})')

# Create supernet
supernet = DARTSSupernet()
print('📊 DARTS SUPERNET')
print('=' * 50)

# Count parameters: architecture logits (alphas) vs. ordinary network weights
arch_params = sum(p.numel() for p in supernet.get_architecture_params())
weight_params = sum(p.numel() for p in supernet.get_weight_params())

print(f'Architecture parameters: {arch_params}')
print(f'Weight parameters: {weight_params:,}')

supernet.print_architecture()

## Part 4: Bi-level Optimization

In [None]:
# Create synthetic data
def create_data(n_samples=500, img_size=32):
    """Build a random classification dataset.

    Returns a tuple ``(images, labels)``: images are standard-normal samples
    of shape ``(n_samples, 3, img_size, img_size)`` and labels are integers
    drawn uniformly from ``[0, 10)`` with shape ``(n_samples,)``.
    """
    images = torch.randn(n_samples, 3, img_size, img_size)
    labels = torch.randint(0, 10, (n_samples,))
    return images, labels

X_train, y_train = create_data(n_samples=500)
X_val, y_val = create_data(n_samples=200)

def train_darts(supernet, X_train, y_train, X_val, y_val, epochs=20, *,
                w_lr=0.01, w_momentum=0.9, arch_lr=0.001, log_every=5):
    """
    Search with DARTS bi-level optimization (first-order approximation).

    Each epoch performs one full-batch step at each level:
      1. update the network weights w on the training data (SGD + momentum);
      2. update the architecture parameters α on the validation data (Adam),
         using the weights just updated in step 1. Using the current w instead
         of a one-step unrolled w' is the first-order DARTS approximation.

    Args:
        supernet: model exposing ``get_weight_params()``,
            ``get_architecture_params()`` and ``cells[*].edges[*].alpha``.
        X_train, y_train: training batch used for the weight step.
        X_val, y_val: validation batch used for the architecture step.
        epochs: number of alternating (weight, architecture) steps.
        w_lr, w_momentum: SGD hyper-parameters for the weights.
        arch_lr: Adam learning rate for the architecture parameters.
        log_every: print a progress line every ``log_every`` epochs and after
            the final epoch; ``0``/``None`` disables logging.

    Returns:
        dict with per-epoch lists ``'train_loss'``, ``'val_loss'`` and
        ``'arch_entropy'`` (sum over edges of the entropy of softmax(α)).
    """
    # Two optimizers over disjoint parameter groups.
    weight_optimizer = torch.optim.SGD(supernet.get_weight_params(), lr=w_lr, momentum=w_momentum)
    arch_optimizer = torch.optim.Adam(supernet.get_architecture_params(), lr=arch_lr)

    criterion = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'arch_entropy': []}

    for epoch in range(epochs):
        supernet.train()

        # Step 1: update weights on training data. backward() also fills
        # alpha.grad; arch_optimizer.zero_grad() below discards it.
        weight_optimizer.zero_grad()
        train_loss = criterion(supernet(X_train), y_train)
        train_loss.backward()
        weight_optimizer.step()

        # Step 2: update architecture on validation data. The weight
        # gradients this produces are cleared at the next epoch's zero_grad().
        arch_optimizer.zero_grad()
        val_loss = criterion(supernet(X_val), y_val)
        val_loss.backward()
        arch_optimizer.step()

        # Architecture entropy (lower = more decisive choices); no graph needed.
        entropy = 0.0
        with torch.no_grad():
            for cell in supernet.cells:
                for edge in cell.edges:
                    probs = F.softmax(edge.alpha, dim=0)
                    entropy -= (probs * (probs + 1e-8).log()).sum().item()

        history['train_loss'].append(train_loss.item())
        history['val_loss'].append(val_loss.item())
        history['arch_entropy'].append(entropy)

        is_last = epoch + 1 == epochs
        if log_every and ((epoch + 1) % log_every == 0 or is_last):
            print(f'Epoch {epoch+1}: Train Loss={train_loss.item():.3f}, '
                  f'Val Loss={val_loss.item():.3f}, Entropy={entropy:.3f}')

    return history

print('🔄 TRAINING DARTS SUPERNET')
print('=' * 50)
history = train_darts(supernet, X_train, y_train, X_val, y_val, epochs=30)

In [None]:
# Visualize training: losses, architecture entropy, and the final per-edge
# softmax(alpha) distribution.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history['train_loss'], label='Train', color='#3b82f6')
axes[0].plot(history['val_loss'], label='Val', color='#ef4444')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss During Search')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history['arch_entropy'], color='#22c55e')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Architecture Entropy')
axes[1].set_title('Architecture Becoming More Decisive')
axes[1].grid(True, alpha=0.3)

# Final architecture weights: one column per edge, one row per candidate op.
op_names = supernet.cells[0].edges[0].op_names
all_weights = []
labels = []
for i, cell in enumerate(supernet.cells):
    for j, edge in enumerate(cell.edges):
        all_weights.append(F.softmax(edge.alpha, dim=0).detach().numpy())
        labels.append(f'C{i}E{j}')

all_weights = np.array(all_weights)
im = axes[2].imshow(all_weights.T, cmap='YlOrRd', aspect='auto')
axes[2].set_xlabel('Edge')
axes[2].set_ylabel('Operation')
axes[2].set_xticks(range(len(labels)))
axes[2].set_xticklabels(labels)
# Tick count follows the candidate list instead of a hard-coded 5.
axes[2].set_yticks(range(len(op_names)))
axes[2].set_yticklabels(op_names)
# YlOrRd maps low values to pale yellow and high values to dark red.
axes[2].set_title('Architecture Weights (Darker = Higher)')
plt.colorbar(im, ax=axes[2])

plt.tight_layout()
plt.show()

# Print final architecture
supernet.print_architecture()

## Part 5: Deriving Final Architecture

In [None]:
def derive_architecture(supernet):
    """
    Discretize a trained supernet by keeping the top-weighted op on each edge.

    Returns:
        A list with one entry per cell; each entry is a list with one dict per
        edge, holding ``'operation'`` (the selected op name) and
        ``'probability'`` (its softmax weight), as reported by the edge's
        ``get_selected_op()``.
    """
    def _edge_choice(edge):
        # Turn an edge's (name, probability) pair into a record.
        op_name, weight = edge.get_selected_op()
        return {'operation': op_name, 'probability': weight}

    return [[_edge_choice(edge) for edge in cell.edges] for cell in supernet.cells]

# Derive the discrete architecture from the trained supernet
final_arch = derive_architecture(supernet)

print('🏆 FINAL DISCOVERED ARCHITECTURE')
print('=' * 50)

for i, cell in enumerate(final_arch):
    print(f'\nCell {i}:')
    for j, edge in enumerate(cell):
        print(f'  Edge {j}: {edge["operation"]} (confidence: {edge["probability"]:.1%})')

# Create discrete model from discovered architecture
class DiscoveredNet(nn.Module):
    """
    Stand-alone network built from a derived (discrete) architecture.

    Layout: conv stem (3 -> ``channels``), then every cell applied as a chain
    of residual edges ``x = x + op(x)``, then global average pooling and a
    linear head. Each edge holds only its single selected operation.

    Args:
        architecture: list of cells, each a list of edge dicts with an
            ``'operation'`` key naming one of conv3x3, conv5x5, sep_conv3x3,
            max_pool or skip.
        channels: feature width (default 16).
        num_classes: number of output logits (default 10).

    Raises:
        ValueError: if an edge names an operation this class cannot build.
    """
    def __init__(self, architecture, channels=16, num_classes=10):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

        # Build cells based on the discovered architecture
        self.cells = nn.ModuleList()
        for cell_arch in architecture:
            cell_ops = [self._make_op(edge['operation'], channels) for edge in cell_arch]
            self.cells.append(nn.ModuleList(cell_ops))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, num_classes)

    def _make_op(self, name, channels):
        """Build the shape-preserving module for candidate ``name``."""
        if name == 'conv3x3':
            return nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels), nn.ReLU())
        if name == 'conv5x5':
            return nn.Sequential(
                nn.Conv2d(channels, channels, 5, padding=2, bias=False),
                nn.BatchNorm2d(channels), nn.ReLU())
        if name == 'sep_conv3x3':
            return nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
                nn.Conv2d(channels, channels, 1, bias=False),
                nn.BatchNorm2d(channels), nn.ReLU())
        if name == 'max_pool':
            return nn.MaxPool2d(3, stride=1, padding=1)
        if name == 'skip':
            return nn.Identity()
        # Previously any unknown name silently became an identity edge,
        # hiding typos and unsupported ops such as 'zero'.
        raise ValueError(
            f'Unknown operation {name!r}; expected one of '
            'conv3x3, conv5x5, sep_conv3x3, max_pool, skip'
        )

    def forward(self, x):
        x = self.stem(x)
        for cell_ops in self.cells:
            for op in cell_ops:
                x = x + op(x)  # residual edge, as in the supernet
        x = self.gap(x).flatten(1)
        return self.fc(x)

# Create and compare
discovered_net = DiscoveredNet(final_arch)

# The supernet count includes every candidate op on every edge plus the
# alphas; the discovered net keeps a single op per edge.
supernet_params = sum(p.numel() for p in supernet.parameters())
discovered_params = sum(p.numel() for p in discovered_net.parameters())

print('\n📊 MODEL COMPARISON')
print(f'Supernet parameters: {supernet_params:,}')
print(f'Discovered net parameters: {discovered_params:,}')
print(f'Reduction: {supernet_params/discovered_params:.1f}x smaller!')

In [None]:
print('🎯 KEY TAKEAWAYS')
print('=' * 60)
print('\n1. NAS automates architecture design')
# 8 ops over 14 edges gives 8**14 ≈ 4.4e12 cells (see the search-space cell).
print('\n2. Search space can have trillions of architectures')
print('\n3. DARTS makes search differentiable using softmax')
print('\n4. Bi-level optimization: weights on train, arch on val')
print('\n5. Architecture weights converge to discrete choices')
print('\n6. Final model is much smaller than supernet')
print('\n7. Discovered architectures often beat hand-designed ones!')
print('\n' + '=' * 60)
print('\n📚 Next: Hardware-Aware NAS!')