# Recipe 8: Uncertainty-Aware ENAS

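This notebook implements uncertainty-aware ENAS (Efficient Neural Architecture Search). Instead of predicting a single value for each sampled architecture, the controller's critic predicts a Normal distribution (a mean and a standard deviation) over its reward, so the uncertainty in architecture evaluation is modelled explicitly.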

## Imports and Setup

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
import numpy as np

# Fix the NumPy and PyTorch RNG seeds so repeated runs draw the same samples
np.random.seed(42)
torch.manual_seed(42)

# Search-space dimensions
NUM_NODES = 4  # nodes per cell; the controller samples 3 decisions per node
NUM_OPS = 5  # number of choices available at each controller decision

## Uncertainty-Aware Controller

In [2]:
class UncertaintyAwareController(nn.Module):
    """LSTM controller that samples architectures and predicts reward distributions.

    At every decision step the actor head yields a categorical distribution
    over ``num_ops`` choices, and the two critic heads yield the mean and the
    log-standard-deviation of a Normal distribution.

    Args:
        hidden_size: Size of the LSTM cell's hidden and cell state.
        num_ops: Number of choices per decision. ``None`` (default) uses the
            module-level ``NUM_OPS``.
        num_nodes: Nodes per cell; three decisions are sampled per node.
            ``None`` (default) uses the module-level ``NUM_NODES``.
    """

    def __init__(self, hidden_size, num_ops=None, num_nodes=None):
        super().__init__()
        # Resolve the defaults lazily so the module-level constants are only
        # needed when the caller does not override them.
        self.num_ops = NUM_OPS if num_ops is None else num_ops
        self.num_nodes = NUM_NODES if num_nodes is None else num_nodes
        self.lstm = nn.LSTMCell(self.num_ops, hidden_size)
        self.actor = nn.Linear(hidden_size, self.num_ops)
        self.critic_mean = nn.Linear(hidden_size, 1)
        self.critic_std = nn.Linear(hidden_size, 1)

    def forward(self, num_cells):
        """Sample one architecture made of ``num_cells`` cells.

        Args:
            num_cells: Number of cells to sample decisions for.

        Returns:
            A tuple ``(actions, log_probs, value_dists)``:

            * ``actions`` - list of ``num_cells`` lists, each holding
              ``num_nodes * 3`` Python ints in ``[0, num_ops)``.
            * ``log_probs`` - 1-D tensor with the log-probability of every
              sampled decision, in sampling order (empty if ``num_cells`` is 0).
            * ``value_dists`` - list with one ``Normal`` per decision; its
              ``mean`` and ``stddev`` have shape ``(1, 1)``.
        """
        # Allocate state on the same device/dtype as the parameters so the
        # controller keeps working after ``.to(device)`` or ``.double()``.
        ref = self.actor.weight
        h = ref.new_zeros((1, self.lstm.hidden_size))
        c = ref.new_zeros((1, self.lstm.hidden_size))
        # The LSTM input is a constant all-zero vector at every step, so it is
        # built once; only the recurrent state carries information forward.
        x = ref.new_zeros((1, self.num_ops))
        steps_per_cell = self.num_nodes * 3

        actions = []
        log_probs = []
        value_dists = []
        for _ in range(num_cells):
            cell_actions = []
            for _ in range(steps_per_cell):
                h, c = self.lstm(x, (h, c))
                logits = self.actor(h)
                probs = torch.softmax(logits, dim=-1)
                action = torch.multinomial(probs, 1).item()
                cell_actions.append(action)
                # log_softmax stays finite where log(softmax(...)) would
                # underflow to -inf for very unlikely choices.
                log_probs.append(torch.log_softmax(logits, dim=-1)[0, action])
                mean = self.critic_mean(h)
                # The head predicts log-std; exp keeps the scale positive.
                std = torch.exp(self.critic_std(h))
                value_dists.append(dist.Normal(mean, std))
            actions.append(cell_actions)

        # torch.stack rejects an empty list, so handle "no decisions" explicitly.
        stacked = torch.stack(log_probs) if log_probs else ref.new_zeros((0,))
        return actions, stacked, value_dists

class Cell(nn.Module):
    """A cell of ``NUM_NODES`` nodes, each combining two earlier states.

    ``MixedOp`` is not defined in this part of the file; each node is assumed
    to be callable as ``node(tensor, op_index)`` - TODO confirm.

    Args:
        C: Value forwarded to every ``MixedOp`` constructor.
    """

    def __init__(self, C):
        super().__init__()
        self.nodes = nn.ModuleList([MixedOp(C) for _ in range(NUM_NODES)])

    @staticmethod
    def _node_specs(actions):
        """Normalise ``actions`` to a list of ``(in1, in2, combine)`` triples.

        Accepts either the nested layout ``[(in1, in2, (op_a, op_b)), ...]``
        or a flat list of ints ``[in1, in2, op, in1, in2, op, ...]`` (three
        ints per node, as produced by the controller). A scalar ``op`` is
        applied to both inputs.
        """
        actions = list(actions)
        if actions and isinstance(actions[0], int):
            if len(actions) % 3 != 0:
                raise ValueError(
                    "flat cell actions must contain 3 entries per node, "
                    f"got {len(actions)} entries"
                )
            actions = [tuple(actions[i:i + 3]) for i in range(0, len(actions), 3)]

        specs = []
        for in1, in2, combine in actions:
            if isinstance(combine, int):
                combine = (combine, combine)
            specs.append((in1, in2, combine))
        return specs

    def forward(self, x, actions):
        """Run the cell on ``x`` following the sampled ``actions``.

        Args:
            x: Input state of the cell.
            actions: Per-node decisions, nested or flat (see ``_node_specs``).

        Returns:
            The state produced by the last processed node, or ``x`` itself
            when ``actions`` is empty.
        """
        states = [x]
        for node, (in1, in2, combine) in zip(self.nodes, self._node_specs(actions)):
            # Sampled indices may exceed the number of states available so
            # far; wrap them so every decision selects an existing state.
            available = len(states)
            s1 = node(states[in1 % available], combine[0])
            s2 = node(states[in2 % available], combine[1])
            states.append(s1 + s2)
        return states[-1]

class Network(nn.Module):
    """Shared-weight network: conv stem, a stack of cells, linear classifier.

    Args:
        C: Channel count produced by the stem and passed to every ``Cell``.
        num_classes: Number of outputs of the final linear layer.
        num_cells: Number of ``Cell`` modules stacked after the stem.
    """

    def __init__(self, C, num_classes, num_cells):
        super().__init__()
        # 3-channel input, 3x3 kernel; padding=1 keeps the spatial size.
        self.stem = nn.Conv2d(in_channels=3, out_channels=C, kernel_size=3, padding=1)
        stacked_cells = []
        for _ in range(num_cells):
            stacked_cells.append(Cell(C))
        self.cells = nn.ModuleList(stacked_cells)
        self.classifier = nn.Linear(in_features=C, out_features=num_classes)

    def forward(self, x, actions):
        """Apply the stem, each cell with its own actions, then classify.

        Args:
            x: Input batch for the stem convolution.
            actions: One entry per cell; pairing stops at the shorter of
                ``self.cells`` and ``actions``.

        Returns:
            The classifier output computed from features averaged over
            dimensions 2 and 3.
        """
        features = self.stem(x)
        for cell, cell_actions in zip(self.cells, actions):
            features = cell(features, cell_actions)
        pooled = features.mean(dim=(2, 3))
        logits = self.classifier(pooled)
        return logits

## Uncertainty-Aware Training Function

In [3]:
def train_uncertainty_aware_enas(network, controller, train_data, val_data, num_epochs,
                                 num_arch_samples=10):
    """Alternate between training the shared weights and the controller.

    Each epoch (1) trains ``network`` on ``train_data`` with a freshly sampled
    architecture per batch, then (2) samples ``num_arch_samples``
    architectures, scores them with ``evaluate`` on ``val_data`` and takes one
    actor-critic step on ``controller``.

    Args:
        network: Shared-weight model exposing ``cells`` and callable as
            ``network(x, actions)``.
        controller: Module whose call returns
            ``(actions, log_probs, value_dists)`` for a given number of cells.
        train_data: Iterable of ``(x, y)`` batches for the shared weights.
        val_data: Iterable of ``(x, y)`` batches used to compute rewards.
        num_epochs: Number of alternating epochs to run.
        num_arch_samples: Architectures sampled per controller update.

    Raises:
        ValueError: If ``num_arch_samples`` is smaller than 1.
    """
    if num_arch_samples < 1:
        raise ValueError("num_arch_samples must be at least 1")

    network_optim = optim.Adam(network.parameters(), lr=0.01)
    controller_optim = optim.Adam(controller.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    num_cells = len(network.cells)

    for epoch in range(num_epochs):
        # Train shared parameters
        network.train()
        for x, y in train_data:
            # Only the sampled actions are needed here, so skip building an
            # autograd graph through the controller.
            with torch.no_grad():
                actions, _, _ = controller(num_cells)
            network_optim.zero_grad()
            loss = criterion(network(x, actions), y)
            loss.backward()
            network_optim.step()

        # Evaluate architectures
        network.eval()
        rewards = []
        log_probs_list = []
        value_dists_list = []
        for _ in range(num_arch_samples):
            # Gradients are required here: log_probs and value_dists feed the
            # controller loss below.
            actions, log_probs, value_dists = controller(num_cells)
            with torch.no_grad():
                acc = evaluate(network, actions, val_data)
            rewards.append(acc)
            log_probs_list.append(log_probs)
            value_dists_list.append(value_dists)

        # Update controller
        rewards = torch.tensor(rewards)

        controller_optim.zero_grad()
        actor_loss = 0
        critic_loss = 0
        for r, lp, vd in zip(rewards, log_probs_list, value_dists_list):
            # The value distribution of the first decision step serves as the
            # baseline for the whole architecture.
            advantage = r - vd[0].mean
            # Detach so the policy-gradient term does not train the critic.
            actor_loss -= lp * advantage.detach()
            # Critic is fit by maximising the likelihood of the observed reward.
            critic_loss -= vd[0].log_prob(r)

        loss = actor_loss.mean() + 0.5 * critic_loss.mean()
        loss.backward()
        controller_optim.step()

        print(f"Epoch {epoch}, Avg Reward: {rewards.mean().item():.4f}")

def evaluate(model, actions, data):
    """Return the classification accuracy of ``model`` under ``actions``.

    Args:
        model: Callable as ``model(x, actions)`` returning per-class scores
            with the class dimension at index 1.
        actions: Architecture description forwarded unchanged to ``model``.
        data: Iterable of ``(x, y)`` batches.

    Returns:
        Fraction of correctly predicted samples as a float, or ``0.0`` when
        ``data`` yields no samples.
    """
    correct = 0
    total = 0
    # One no_grad scope for the whole pass instead of re-entering per batch.
    with torch.no_grad():
        for x, y in data:
            outputs = model(x, actions)
            _, predicted = outputs.max(1)
            correct += (predicted == y).sum().item()
            total += y.size(0)
    if total == 0:
        # Avoid ZeroDivisionError on an empty dataset.
        return 0.0
    return correct / total

## Usage Example

In [None]:
# Build the shared network and the controller. `train_data` and `val_data`
# are not created in this notebook and must be provided before training.
network = Network(C=16, num_classes=10, num_cells=8)  # 16 channels, 10 classes, 8 cells
controller = UncertaintyAwareController(hidden_size=100)
# Uncomment the following line to train
# train_uncertainty_aware_enas(network, controller, train_data, val_data, num_epochs=50)