<a href="https://colab.research.google.com/github/leondotle/research/blob/main/Reinforcement_Learning_for_Optimal_Folding_Patterns.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required packages
!pip install jax jaxlib gymnax flax optax chex tqdm

In [None]:
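# Dependencies are installed by the cell above; uncomment the line below to
# install them (plus matplotlib) when running this cell on its own.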
# !pip install jax jaxlib gymnax flax optax chex tqdm matplotlib

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training.train_state import TrainState
import gymnax
from gymnax.environments import environment, spaces
import numpy as np
from typing import Tuple, Dict, NamedTuple
import chex
from tqdm import tqdm
import matplotlib.pyplot as plt

# Show which devices JAX detected so the active backend is visible in the cell output.
print(f"JAX devices: {jax.devices()}")

# Simplified Environment Info Structure
class EnvInfo(NamedTuple):
    """Per-step diagnostics returned by ``OrigamiEnv.step_env`` as its info value.

    All fields are filled from the clipped action and the geometry derived
    from it during that step.
    """
    # Z-extent of the folded vertex grid.
    height: float
    # Product of the X-extent and Y-extent of the folded vertex grid.
    area: float
    # Absolute difference between ``height`` and the environment's target height.
    height_error: float
    # Absolute difference between ``area`` and the environment's target area.
    area_error: float
    # First component of the clipped action (radians; it is fed to jnp.sin).
    fold_angle: float
    # Second component of the clipped action (radians; it is fed to jnp.sin/jnp.cos).
    crease_angle: float
    # Same flag as the step's ``done`` output: both errors are inside their tolerances.
    solved: bool

# Streamlined Origami Environment
class OrigamiEnv(environment.Environment):
    """Two-angle origami folding task on a fixed vertex grid.

    The state is a pair ``(fold_angle, crease_angle)`` and an action is a new
    pair of the same form. Each step clips the action to ``[0.01, pi/2]``,
    derives the folded sheet's bounding extents from it, rewards closeness of
    the height and footprint area to the targets, and returns the clipped
    action as the next state.
    """

    def __init__(self, target_height=1.0, target_area=15.0):
        """Store the targets and precompute the integer vertex index grids.

        Args:
            target_height: Desired Z-extent of the folded sheet.
            target_area: Desired X-extent times Y-extent of the folded sheet.
        """
        super().__init__()
        self.target_height = target_height
        self.target_area = target_area

        # a / b scale the vertex spacing along the two grid directions;
        # the grid has (m + 1) x (n + 1) vertices.
        self.a = 1.0
        self.b = 1.5
        self.m = 8
        self.n = 6
        row_idx = jnp.arange(self.m + 1)
        col_idx = jnp.arange(self.n + 1)
        self.I_grid, self.J_grid = jnp.meshgrid(row_idx, col_idx, indexing="ij")

    @property
    def default_state(self) -> chex.Array:
        """Nominal start state: both angles at pi/4, as float32."""
        return jnp.full((2,), jnp.pi/4, dtype=jnp.float32)

    def step_env(self, key: chex.PRNGKey, state: chex.Array, action: chex.Array) -> Tuple[chex.Array, float, bool, EnvInfo]:
        """Apply ``action`` and score the resulting folded geometry.

        Args:
            key: PRNG key; accepted for interface uniformity, not used here.
            state: Previous angle pair, used only for the movement penalty.
            action: Proposed ``(fold_angle, crease_angle)`` pair.

        Returns:
            ``(next_state, reward, done, info)`` where ``next_state`` is the
            clipped action and ``info`` is an ``EnvInfo``.
        """
        action = jnp.clip(action, 0.01, jnp.pi/2)
        fold, crease = action[0], action[1]

        width, length, height = self._compute_geometry(fold, crease)
        area = width * length

        height_error = jnp.abs(height - self.target_height)
        area_error = jnp.abs(area - self.target_area)

        # Two bounded closeness terms minus a small penalty on how far the
        # angles moved relative to the previous state.
        height_term = 1.0 / (1.0 + 2.0 * height_error)
        area_term = 1.0 / (1.0 + 0.1 * area_error)
        move_penalty = 0.01 * jnp.sum(jnp.square(action - state))
        reward = height_term + area_term - move_penalty

        # Solved once both errors are inside their tolerances; add a bonus then.
        done = (height_error < 0.2) & (area_error < 2.0)
        reward = jnp.where(done, reward + 3.0, reward)

        info = EnvInfo(
            height=height,
            area=area,
            height_error=height_error,
            area_error=area_error,
            fold_angle=fold,
            crease_angle=crease,
            solved=done,
        )
        return action, reward, done, info

    def reset_env(self, key: chex.PRNGKey) -> Tuple[chex.Array, Dict]:
        """Sample a start state near ``default_state``.

        Uniform noise in ``[-0.1, 0.1)`` is added per angle and the result is
        clipped to the valid range. The returned dict reports the initial
        height and footprint area.
        """
        jitter = jax.random.uniform(key, (2,), minval=-0.1, maxval=0.1)
        start = jnp.clip(self.default_state + jitter, 0.01, jnp.pi/2)
        width, length, height = self._compute_geometry(start[0], start[1])
        return start, {"initial_height": height, "initial_area": width * length}

    def _compute_geometry(self, fold_angle, crease_angle):
        """Return the (X, Y, Z) bounding extents of the folded vertex grid.

        The X and Y extents are floored at 0.001 and the Z extent at 0.0.
        """
        sin_crease = jnp.sin(crease_angle)
        cos_crease = jnp.cos(crease_angle)
        sin_fold = jnp.sin(fold_angle)

        # Keep the arcsin argument strictly inside (-1, 1).
        gamma = jnp.arcsin(jnp.clip(sin_crease * sin_fold, -0.999, 0.999))

        xs = self.I_grid * self.a * jnp.cos(gamma) + (self.J_grid % 2) * self.b * cos_crease
        ys = self.J_grid * self.b * sin_crease
        zs = (self.I_grid % 2) * self.a * sin_crease * sin_fold

        # One row per vertex, columns are x / y / z; peak-to-peak per column.
        vertices = jnp.stack([xs.ravel(), ys.ravel(), zs.ravel()], axis=-1)
        extents = jnp.ptp(vertices, axis=0)

        width = jnp.maximum(extents[0], 0.001)
        length = jnp.maximum(extents[1], 0.001)
        height = jnp.maximum(extents[2], 0.0)
        return (width, length, height)

    @property
    def name(self) -> str:
        """Environment identifier string."""
        return "Origami-v1"

    @property
    def num_actions(self) -> int:
        """Number of action components (the two angles)."""
        return 2

    def action_space(self) -> spaces.Box:
        """Box covering ``[0.01, pi/2]`` for each of the two angles."""
        lower = jnp.array([0.01, 0.01])
        upper = jnp.array([jnp.pi/2, jnp.pi/2])
        return spaces.Box(low=lower, high=upper, shape=(2,))

    def observation_space(self) -> spaces.Box:
        """Observations share the action bounds, since the state is the last clipped action."""
        return self.action_space()

# Simplified Actor-Critic Network
class ActorCritic(nn.Module):
    """Shared-trunk network producing a Gaussian policy and a state value.

    Attributes:
        action_dim: Number of action components the policy head outputs.
    """
    action_dim: int

    @nn.compact
    def __call__(self, x):
        """Map an observation to ``(mean, log_std, value)``.

        ``mean`` lies in ``(0, pi/2)`` per component, ``log_std`` is a learned
        parameter of shape ``(action_dim,)`` that does not depend on ``x``,
        and ``value`` has the trailing size-1 axis removed.
        """
        # Trunk: Dense -> LayerNorm -> ReLU, for widths 128 then 64.
        features = x
        for width in (128, 64):
            features = nn.Dense(width)(features)
            features = nn.LayerNorm()(features)
            features = nn.relu(features)

        # Policy head: the sigmoid keeps each mean inside the open interval (0, pi/2).
        policy_out = nn.Dense(self.action_dim)(features)
        mean = nn.sigmoid(policy_out) * jnp.pi/2
        log_std = self.param(
            'log_std',
            nn.initializers.constant(-1.0),
            (self.action_dim,),
        )

        # Value head: scalar estimate per input.
        value = jnp.squeeze(nn.Dense(1)(features), axis=-1)

        return mean, log_std, value

# Vectorized Environment Wrapper
class VectorizedEnv:
    """Run ``num_envs`` independent copies of an environment via ``jax.vmap``.

    The wrapped ``env`` must expose ``reset_env(key)`` and
    ``step_env(key, state, action)``.
    """

    def __init__(self, env, num_envs):
        """Keep the env and build batched versions of its reset/step functions."""
        self.env = env
        self.num_envs = num_envs
        self.reset_fn = jax.vmap(env.reset_env)
        self.step_fn = jax.vmap(env.step_env)

    def _per_env_keys(self, key):
        """Split ``key`` into one PRNG key per environment copy."""
        return jax.random.split(key, self.num_envs)

    def reset(self, key):
        """Reset every copy, each with its own subkey of ``key``."""
        env_keys = self._per_env_keys(key)
        return self.reset_fn(env_keys)

    def step(self, key, states, actions):
        """Step every copy; ``states`` and ``actions`` are batched on axis 0."""
        env_keys = self._per_env_keys(key)
        return self.step_fn(env_keys, states, actions)

# PPO Helper Functions
def gaussian_log_prob(x, mean, log_std):
    """Log-density of ``x`` under a diagonal Gaussian, summed over the last axis.

    Args:
        x: Points at which to evaluate the density.
        mean: Per-component means, broadcastable against ``x``.
        log_std: Per-component log standard deviations.

    Returns:
        The log-probability with the last axis reduced away.
    """
    sigma = jnp.exp(log_std)
    standardized = (x - mean) / sigma
    # Per component: z^2 + log(sigma^2) + log(2*pi); the -1/2 is applied after the sum.
    per_component = jnp.square(standardized) + 2 * log_std + jnp.log(2 * jnp.pi)
    return -0.5 * jnp.sum(per_component, axis=-1)

def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_param=0.2):
    """Negated mean of the PPO clipped surrogate objective.

    Args:
        new_log_probs: Log-probabilities of the actions under the current policy.
        old_log_probs: Log-probabilities of the same actions recorded at rollout time.
        advantages: Advantage estimates matching the log-probabilities.
        clip_param: Half-width of the allowed probability-ratio band around 1.

    Returns:
        Scalar loss to minimise.
    """
    ratio = jnp.exp(new_log_probs - old_log_probs)
    lower, upper = 1 - clip_param, 1 + clip_param

    unclipped_term = ratio * advantages
    clipped_term = jnp.clip(ratio, lower, upper) * advantages

    # Element-wise pessimistic choice between the two surrogate terms.
    objective = jnp.minimum(unclipped_term, clipped_term)
    return -jnp.mean(objective)

def compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    """Compute Generalized Advantage Estimation over a rollout.

    Time is the leading axis of ``rewards``, ``values`` and ``dones``; any
    trailing axes (e.g. parallel environments) are processed element-wise.

    Args:
        rewards: Rewards received at each step.
        values: Value estimates for the state each step started from.
        next_values: Value estimate for the state after the final step, used
            to bootstrap the last advantage; broadcast to one time slice.
        dones: Flags marking steps that ended an episode. A set flag blocks
            both the bootstrap value and the accumulated advantage from
            crossing that step.
        gamma: Discount factor.
        lam: GAE smoothing factor.

    Returns:
        ``(advantages, returns)``. ``returns`` is the un-normalised advantage
        plus ``values``; ``advantages`` is normalised to zero mean and unit
        standard deviation over the whole array (with a 1e-8 guard on the
        divisor).
    """
    rewards = jnp.asarray(rewards)
    values = jnp.asarray(values)
    # Subtracting from 1.0 also turns boolean flags into a float mask.
    non_terminal = 1.0 - jnp.asarray(dones)

    # Value of the state following each step: values shifted by one step,
    # with the bootstrap estimate appended for the final step.
    bootstrap = jnp.broadcast_to(jnp.asarray(next_values), values.shape[1:])[None]
    following_values = jnp.concatenate([values[1:], bootstrap], axis=0)

    # One-step TD residuals for all time steps at once.
    deltas = rewards + gamma * following_values * non_terminal - values

    def _accumulate(running_gae, step):
        # Backward recursion: A_t = delta_t + gamma * lam * mask_t * A_{t+1}.
        delta, mask = step
        running_gae = delta + gamma * lam * mask * running_gae
        return running_gae, running_gae

    # reverse=True walks time backwards but stacks outputs in forward order,
    # so no list reversal or front-insertion is needed.
    initial_gae = jnp.zeros(deltas.shape[1:], dtype=deltas.dtype)
    _, advantages = jax.lax.scan(
        _accumulate, initial_gae, (deltas, non_terminal), reverse=True
    )

    returns = advantages + values

    # Normalize advantages
    advantages = (advantages - jnp.mean(advantages)) / (jnp.std(advantages) + 1e-8)

    return advantages, returns

# Streamlined PPO Training
def train_ppo(env, num_envs=32, total_steps=50000, lr=2e-4,
              rollout_len=64, num_epochs=4, seed=42):
    """Train an ``ActorCritic`` policy on ``env`` with PPO.

    Each update collects ``rollout_len`` steps from ``num_envs`` parallel
    copies of ``env``, computes GAE advantages, and applies ``num_epochs``
    full-batch gradient steps on the clipped PPO loss plus a value loss and
    an entropy bonus.

    Args:
        env: Environment exposing ``reset_env(key)`` and
            ``step_env(key, state, action)``; the step info must provide
            ``height`` and ``area`` attributes.
        num_envs: Number of environment copies stepped in parallel.
        total_steps: Total environment-step budget. The number of updates is
            ``total_steps // (num_envs * rollout_len)``; when that is zero no
            update runs and the metric lists come back empty.
        lr: Adam learning rate.
        rollout_len: Steps collected per environment before each update.
        num_epochs: Gradient steps taken on each collected batch.
        seed: Seed for the PRNG key driving init, sampling and resets.

    Returns:
        ``(train_state, metrics)`` where ``metrics`` maps ``'rewards'``,
        ``'losses'``, ``'heights'`` and ``'areas'`` to lists holding one
        float per update.

    Raises:
        ValueError: If ``num_envs``, ``rollout_len`` or ``num_epochs`` is
            smaller than 1.
    """
    if num_envs < 1 or rollout_len < 1 or num_epochs < 1:
        raise ValueError("num_envs, rollout_len and num_epochs must all be >= 1")

    # Initialize
    key = jax.random.PRNGKey(seed)
    model = ActorCritic(action_dim=2)

    # Give parameter init its own subkey so no key is ever consumed twice.
    key, init_key = jax.random.split(key)
    dummy_obs = jnp.zeros((2,))
    params = model.init(init_key, dummy_obs)

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(lr),
    )

    train_state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

    # Vectorized environment; compile step/reset once instead of dispatching
    # every primitive eagerly on each of the many rollout steps.
    venv = VectorizedEnv(env, num_envs)
    env_step = jax.jit(venv.step)
    env_reset = jax.jit(venv.reset)

    num_updates = total_steps // (num_envs * rollout_len)

    # Apply the model to a batch of observations with shared parameters.
    batched_apply = jax.vmap(model.apply, (None, 0))

    # Action sampling function
    @jax.jit
    def sample_actions(params, obs, key):
        mean, log_std, values = batched_apply(params, obs)
        std = jnp.exp(log_std)

        actions = mean + jax.random.normal(key, mean.shape) * std
        actions = jnp.clip(actions, 0.01, jnp.pi/2)

        # NOTE(review): the log-probability is evaluated at the clipped
        # action rather than at the raw Gaussian sample. ppo_loss below sees
        # the same clipped actions, so the ratio is self-consistent, but it is
        # not the density of the sample actually drawn — confirm intended.
        log_probs = gaussian_log_prob(actions, mean, log_std)

        return actions, log_probs, values

    # PPO loss function
    def ppo_loss(params, obs, actions, old_log_probs, advantages, returns):
        mean, log_std, values = batched_apply(params, obs)

        # Log-probabilities of the stored actions under the current parameters
        log_probs = gaussian_log_prob(actions, mean, log_std)

        # PPO clipped surrogate loss
        policy_loss = ppo_clipped_loss(log_probs, old_log_probs, advantages)

        # Value loss (MSE)
        value_loss = 0.5 * jnp.mean(jnp.square(values - returns))

        # Entropy of a diagonal Gaussian depends only on log_std
        entropy = jnp.mean(jnp.sum(log_std + 0.5 * jnp.log(2 * jnp.pi * jnp.e), axis=-1))

        total_loss = policy_loss + value_loss - 0.01 * entropy

        return total_loss, {
            'policy_loss': policy_loss,
            'value_loss': value_loss,
            'entropy': entropy
        }

    # Update function
    @jax.jit
    def update(train_state, batch):
        grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
        (loss, info), grads = grad_fn(train_state.params, *batch)
        return train_state.apply_gradients(grads=grads), loss, info

    # Training loop
    key, reset_key = jax.random.split(key)
    obs, _ = env_reset(reset_key)

    metrics = {'rewards': [], 'losses': [], 'heights': [], 'areas': []}

    for update_idx in tqdm(range(num_updates)):
        # Collect rollout
        rollout_obs, rollout_actions, rollout_rewards = [], [], []
        rollout_log_probs, rollout_values, rollout_dones = [], [], []
        rollout_heights, rollout_areas = [], []

        for _ in range(rollout_len):
            key, action_key, step_key = jax.random.split(key, 3)

            actions, log_probs, values = sample_actions(train_state.params, obs, action_key)
            next_obs, rewards, dones, infos = env_step(step_key, obs, actions)

            # Store rollout data
            rollout_obs.append(obs)
            rollout_actions.append(actions)
            rollout_rewards.append(rewards)
            rollout_log_probs.append(log_probs)
            rollout_values.append(values)
            rollout_dones.append(dones)
            rollout_heights.append(infos.height)
            rollout_areas.append(infos.area)

            obs = next_obs

            # Reset done environments: draw a fresh start for every copy and
            # keep it only where that copy just finished.
            if jnp.any(dones):
                key, reset_key = jax.random.split(key)
                reset_obs, _ = env_reset(reset_key)
                obs = jnp.where(dones[:, None], reset_obs, obs)

        # Stack the per-step lists along a new leading time axis
        rollout_obs = jnp.stack(rollout_obs)
        rollout_actions = jnp.stack(rollout_actions)
        rollout_rewards = jnp.stack(rollout_rewards)
        rollout_log_probs = jnp.stack(rollout_log_probs)
        rollout_values = jnp.stack(rollout_values)
        rollout_dones = jnp.stack(rollout_dones)

        # Value of the state after the last step, used to bootstrap GAE
        _, _, next_values = batched_apply(train_state.params, obs)

        # Compute advantages using GAE
        advantages, returns = compute_gae(rollout_rewards, rollout_values, next_values, rollout_dones)

        # Flatten (time, env) into one batch axis
        batch_size = rollout_len * num_envs
        batch = (
            rollout_obs.reshape(batch_size, -1),
            rollout_actions.reshape(batch_size, -1),
            rollout_log_probs.reshape(batch_size),
            advantages.reshape(batch_size),
            returns.reshape(batch_size),
        )

        # Multiple PPO updates on the same batch
        for _ in range(num_epochs):
            train_state, loss, _ = update(train_state, batch)

        # Record metrics (loss is the one from the last epoch)
        metrics['rewards'].append(float(jnp.mean(rollout_rewards)))
        metrics['losses'].append(float(loss))
        metrics['heights'].append(float(jnp.mean(jnp.stack(rollout_heights))))
        metrics['areas'].append(float(jnp.mean(jnp.stack(rollout_areas))))

        if update_idx % 10 == 0:
            print(f"Update {update_idx}: Reward={metrics['rewards'][-1]:.3f}, "
                  f"Height={metrics['heights'][-1]:.3f}, Area={metrics['areas'][-1]:.3f}")

    return train_state, metrics

# Simplified testing function
def test_agent(env, train_state, num_episodes=5, max_steps=50, seed=0):
    """Roll out the policy mean for several episodes and report the best one.

    Actions are the network's mean output clipped to ``[0.01, pi/2]``; no
    noise is sampled. One line is printed per episode.

    Args:
        env: Environment exposing ``reset_env(key)`` and
            ``step_env(key, state, action)``; the step info must provide
            ``fold_angle``, ``crease_angle``, ``height`` and ``area``.
        train_state: Object with ``apply_fn`` and ``params``; calling
            ``apply_fn(params, obs)`` must return a 3-tuple whose first item
            is the action mean.
        num_episodes: Number of episodes to run.
        max_steps: Upper bound on steps per episode; an episode also ends as
            soon as the environment reports done.
        seed: Seed for the PRNG key used for resets and steps.

    Returns:
        Dict with ``'fold_angle'``, ``'crease_angle'``, ``'height'``,
        ``'area'`` (taken from the final step) and ``'reward'`` (episode
        total) for the highest-reward episode, or ``None`` when no episode
        was run.

    Raises:
        ValueError: If ``max_steps`` is smaller than 1, since an episode
            without steps has no final step info to report.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be at least 1")

    key = jax.random.PRNGKey(seed)
    best_reward = -float('inf')
    best_config = None

    for episode in range(num_episodes):
        key, reset_key = jax.random.split(key)
        obs, _ = env.reset_env(reset_key)
        episode_reward = 0

        for _step in range(max_steps):
            mean, _log_std, _value = train_state.apply_fn(train_state.params, obs)
            action = jnp.clip(mean, 0.01, jnp.pi/2)

            key, step_key = jax.random.split(key)
            obs, reward, done, info = env.step_env(step_key, obs, action)
            episode_reward += reward

            if done:
                break

        # `info` is the last step's info: the solved step or the step at the cap.
        if episode_reward > best_reward:
            best_reward = episode_reward
            best_config = {
                'fold_angle': float(info.fold_angle),
                'crease_angle': float(info.crease_angle),
                'height': float(info.height),
                'area': float(info.area),
                'reward': float(episode_reward)
            }

        print(f"Episode {episode+1}: Reward={episode_reward:.3f}, "
              f"Height={float(info.height):.3f}, Area={float(info.area):.3f}")

    return best_config

# Simplified plotting
def plot_metrics(metrics, target_height=1.0, target_area=15.0):
    """Plot the training curves in a 2x2 grid and show the figure.

    Args:
        metrics: Mapping with ``'rewards'``, ``'losses'``, ``'heights'`` and
            ``'areas'`` sequences, one value per update.
        target_height: Height at which the dashed reference line is drawn on
            the heights panel.
        target_area: Area at which the dashed reference line is drawn on the
            areas panel.
    """
    # (metrics key, panel title, y-axis label, reference value or None),
    # listed in row-major panel order.
    panels = (
        ('rewards', 'Rewards', 'Average Reward', None),
        ('losses', 'Losses', 'Loss', None),
        ('heights', 'Heights', 'Height', target_height),
        ('areas', 'Areas', 'Area', target_area),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    for ax, (metric_key, title, ylabel, target) in zip(axes.flat, panels):
        ax.plot(metrics[metric_key])
        if target is not None:
            ax.axhline(y=target, color='r', linestyle='--', label='Target')
        ax.set_title(title)
        ax.set_xlabel('Update')
        ax.set_ylabel(ylabel)
        if target is not None:
            ax.legend()

    plt.tight_layout()
    plt.show()

# Main execution
if __name__ == "__main__":
    # Build the environment and train a policy on it.
    env = OrigamiEnv(target_height=1.0, target_area=15.0)
    print("Training agent...")
    train_state, metrics = train_ppo(env, num_envs=32, total_steps=25000)

    # Show the training curves.
    plot_metrics(metrics)

    # Evaluate the trained policy and keep the best episode's summary.
    print("\nTesting agent:")
    best_config = test_agent(env, train_state)

    # Report every entry of the best configuration to four decimals.
    print("\nBest configuration:")
    for field_name, field_value in best_config.items():
        print(f"{field_name}: {field_value:.4f}")