# Recipe 3: ENAS with Policy Gradient (PPO)

This notebook implements Efficient Neural Architecture Search (ENAS) and trains its architecture controller with Proximal Policy Optimization (PPO).

## Imports and Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import Network

# Seed the NumPy and PyTorch generators so runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Search-space sizes used by the controller:
#   NUM_OPS   - number of candidate operations the actor chooses from per decision
#   NUM_NODES - nodes per cell (the controller makes 3 decisions per node)
NUM_OPS, NUM_NODES = 5, 4

## PPO Controller Implementation

In [None]:
class PPOController(nn.Module):
    """LSTM controller that samples ENAS cell architectures.

    At every decision step an actor head yields a categorical distribution
    over ``num_ops`` operations, and a critic head yields a scalar value
    estimate for the same step (used as the PPO baseline).

    Args:
        hidden_size: Width of the LSTM hidden and cell states.
        num_ops: Number of candidate operations per decision. Defaults to the
            module-level ``NUM_OPS``.
        num_nodes: Number of nodes per cell; each node contributes three
            decisions. Defaults to the module-level ``NUM_NODES``.
    """

    def __init__(self, hidden_size, num_ops=None, num_nodes=None):
        super().__init__()
        # Resolve defaults lazily so explicit arguments never touch the globals.
        self.num_ops = NUM_OPS if num_ops is None else num_ops
        self.num_nodes = NUM_NODES if num_nodes is None else num_nodes
        self.lstm = nn.LSTMCell(self.num_ops, hidden_size)
        self.actor = nn.Linear(hidden_size, self.num_ops)
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, num_cells):
        """Sample one architecture made of ``num_cells`` cells.

        Returns:
            actions: list of ``num_cells`` lists, each holding
                ``3 * num_nodes`` sampled operation indices.
            log_probs: 1-D tensor with the log-probability of every sampled
                action, in sampling order.
            values: tensor of shape ``(num_decisions, 1)`` with the critic's
                estimate at every decision.
        """
        hidden_size = self.lstm.hidden_size
        h = torch.zeros(1, hidden_size)
        c = torch.zeros(1, hidden_size)
        # NOTE(review): the LSTM input is all zeros at every step, so the
        # per-step distributions never depend on previously sampled actions.
        # ENAS normally feeds back an encoding of the previous action, and the
        # input width (num_ops) suggests a one-hot was intended — confirm
        # before changing, since any re-scoring code must use the same inputs.
        x = torch.zeros(1, self.num_ops)

        actions, log_probs, values = [], [], []
        for _ in range(num_cells):
            cell_actions = []
            for _ in range(self.num_nodes * 3):
                h, c = self.lstm(x, (h, c))
                logits = self.actor(h)
                probs = torch.softmax(logits, dim=-1)
                action = torch.multinomial(probs, 1).item()
                cell_actions.append(action)
                # log_softmax avoids log(0) = -inf if a probability underflows.
                log_probs.append(torch.log_softmax(logits, dim=-1)[0, action])
                values.append(self.critic(h))
            actions.append(cell_actions)

        if not log_probs:
            # torch.stack/cat reject empty lists; return correctly shaped empties.
            return actions, torch.empty(0), torch.empty(0, 1)
        return actions, torch.stack(log_probs), torch.cat(values)

## PPO Update Function

In [None]:
def _replay_architecture(ppo_controller, architecture):
    """Re-score one fixed architecture with the controller's current weights.

    Mirrors the controller's sampling loop in this file: fresh zero hidden
    state per architecture and a zero LSTM input at every decision. Returns
    ``(log_probs, values)``, both 1-D with one entry per decision.
    """
    hidden_size = ppo_controller.lstm.hidden_size
    h = torch.zeros(1, hidden_size)
    c = torch.zeros(1, hidden_size)
    x = torch.zeros(1, ppo_controller.lstm.input_size)
    log_probs, values = [], []
    for cell_actions in architecture:
        for action in cell_actions:
            h, c = ppo_controller.lstm(x, (h, c))
            step_log_probs = torch.log_softmax(ppo_controller.actor(h), dim=-1)
            log_probs.append(step_log_probs[0, action])
            values.append(ppo_controller.critic(h)[0, 0])
    return torch.stack(log_probs), torch.stack(values)


def ppo_update(ppo_controller, old_log_probs, old_values, rewards, actions, epsilon=0.2):
    """Compute the PPO clipped-surrogate loss for a batch of sampled architectures.

    The controller is replayed on exactly the architectures in ``actions`` so
    that the new and old log-probabilities refer to the same decisions.

    Args:
        ppo_controller: Controller exposing ``lstm`` (an ``nn.LSTMCell``) and
            ``actor`` / ``critic`` heads, such as ``PPOController``.
        old_log_probs: Log-probabilities recorded at sampling time,
            ``num_archs * num_decisions`` entries (flat or 2-D).
        old_values: Critic values recorded at sampling time, with the same
            number of entries as ``old_log_probs`` (e.g. shape ``(N, 1)``).
        rewards: One scalar reward per architecture (tensor or sequence).
        actions: List of architectures; each is a list of cells, each cell a
            list of operation indices. All must have the same decision count.
        epsilon: Clipping range for the probability ratio.

    Returns:
        Scalar tensor ``actor_loss + 0.5 * critic_loss``. Gradients flow only
        into the controller's current parameters.

    Raises:
        ValueError: if ``actions`` is empty or the recorded tensors do not
            match the number of decisions in ``actions``.
    """
    if not actions:
        raise ValueError("ppo_update() needs at least one sampled architecture")

    replays = [_replay_architecture(ppo_controller, arch) for arch in actions]
    new_log_probs = torch.stack([lp for lp, _ in replays])  # (num_archs, num_decisions)
    new_values = torch.stack([v for _, v in replays])       # (num_archs, num_decisions)
    num_archs, num_decisions = new_log_probs.shape

    expected = num_archs * num_decisions
    if old_log_probs.numel() != expected or old_values.numel() != expected:
        raise ValueError(
            f"expected {expected} recorded log-probs/values for {num_archs} "
            f"architectures, got {old_log_probs.numel()} and {old_values.numel()}"
        )

    # Old quantities are constants of the PPO objective: detaching them stops
    # gradient leaking into (and repeated backward through) the sampling graph.
    old_log_probs = old_log_probs.detach().reshape(num_archs, num_decisions)
    old_values = old_values.detach().reshape(num_archs, num_decisions)
    # One terminal reward per architecture, shared by all of its decisions.
    rewards = torch.as_tensor(rewards, dtype=new_values.dtype).reshape(num_archs, 1)

    ratios = torch.exp(new_log_probs - old_log_probs)
    advantages = rewards - old_values

    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - epsilon, 1 + epsilon) * advantages

    actor_loss = -torch.min(surr1, surr2).mean()
    critic_loss = nn.MSELoss()(new_values, rewards.expand_as(new_values))

    return actor_loss + 0.5 * critic_loss

## Training Function

In [None]:
def train_enas_ppo(network, ppo_controller, train_data, val_data, num_epochs,
                   *, num_samples=10, ppo_steps=5, network_lr=0.01,
                   controller_lr=0.001):
    """Alternate ENAS shared-weight training with PPO controller updates.

    Each epoch:
    (1) trains the shared ``network`` weights for one pass over
        ``train_data``, sampling a fresh architecture per batch;
    (2) samples ``num_samples`` architectures and rewards each with its
        validation accuracy;
    (3) runs ``ppo_steps`` PPO updates of the controller on that batch.

    Args:
        network: Shared-weight model. It must expose ``cells`` (its length is
            the number of cells sampled per architecture) and be callable as
            ``network(x, actions)``.
        ppo_controller: Controller returning ``(actions, log_probs, values)``
            when called with a cell count.
        train_data: Iterable of ``(x, y)`` training batches.
        val_data: Iterable of ``(x, y)`` validation batches.
        num_epochs: Number of epochs to run.
        num_samples: Architectures sampled per epoch for the PPO batch.
        ppo_steps: PPO optimisation steps per epoch.
        network_lr: Adam learning rate for the shared weights.
        controller_lr: Adam learning rate for the controller.

    Raises:
        ValueError: if ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    network_optim = optim.Adam(network.parameters(), lr=network_lr)
    ppo_optim = optim.Adam(ppo_controller.parameters(), lr=controller_lr)
    criterion = nn.CrossEntropyLoss()
    num_cells = len(network.cells)

    for epoch in range(num_epochs):
        _train_shared_weights(network, ppo_controller, train_data,
                              network_optim, criterion, num_cells)

        rewards, old_log_probs, old_values, actions_list = _sample_and_score(
            network, ppo_controller, val_data, num_samples, num_cells)

        for _ in range(ppo_steps):
            ppo_optim.zero_grad()
            loss = ppo_update(ppo_controller, old_log_probs, old_values, rewards, actions_list)
            loss.backward()
            ppo_optim.step()

        print(f"Epoch {epoch}, Avg Reward: {rewards.mean().item():.4f}")


def _train_shared_weights(network, ppo_controller, train_data, optimizer, criterion, num_cells):
    """Run one pass over ``train_data``, sampling a new architecture per batch."""
    network.train()
    for x, y in train_data:
        # The controller only samples here; no graph is needed through it.
        with torch.no_grad():
            actions, _, _ = ppo_controller(num_cells)
        optimizer.zero_grad()
        loss = criterion(network(x, actions), y)
        loss.backward()
        optimizer.step()


def _sample_and_score(network, ppo_controller, val_data, num_samples, num_cells):
    """Sample architectures and reward each with its validation accuracy.

    Returns ``(rewards, old_log_probs, old_values, actions_list)``. The
    log-probs and values are recorded without autograd history because PPO
    treats them as constants.
    """
    network.eval()
    rewards, log_probs_list, values_list, actions_list = [], [], [], []
    for _ in range(num_samples):
        with torch.no_grad():
            actions, log_probs, values = ppo_controller(num_cells)
            acc = evaluate(network, actions, val_data)
        rewards.append(acc)
        log_probs_list.append(log_probs)
        values_list.append(values)
        actions_list.append(actions)
    return (torch.tensor(rewards), torch.cat(log_probs_list),
            torch.cat(values_list), actions_list)

def evaluate(model, actions, data):
    """Return the top-1 accuracy of ``model`` run with ``actions`` over ``data``.

    ``data`` yields ``(inputs, targets)`` batches. The result is the fraction
    of correctly predicted targets as a float. Empty ``data`` raises
    ZeroDivisionError.
    """
    hits = 0
    seen = 0
    for inputs, targets in data:
        with torch.no_grad():
            predictions = model(inputs, actions).max(dim=1).indices
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)
    return hits / seen

## Usage Example

In [None]:
# Assuming you have your data loaded as train_data and val_data
# NOTE(review): `import Network` binds a module, and a module object is not
# callable, so `Network(16, 10, 8)` raises TypeError. This assumes the module
# defines a class of the same name — confirm against the Network module.
network = Network.Network(16, 10, 8)  # 16 channels, 10 classes, 8 cells
ppo_controller = PPOController(100)  # LSTM hidden size 100
# Uncomment the following line to train
# train_enas_ppo(network, ppo_controller, train_data, val_data, num_epochs=50)