## Imports

In [1]:
!pip install nltk
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np
import functools

from nltk.translate.bleu_score import sentence_bleu
from nltk.metrics.distance import jaccard_distance
from flax.training import train_state



## Actor-Critic Architecture

In [2]:
class ActorCritic(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, state):
      x = nn.Dense(64)(state)
      x = nn.gelu(x)
      x = nn.Dense(64)(x)
      x = nn.gelu(x)

      logits = nn.Dense(self.action_dim)(x)
      value = nn.Dense(1)(x)

      return logits, value

## Sample Actions

In [3]:
def sample_actions(rng, logits, random = True):
    log_probs = jax.nn.log_softmax(logits)
    action = chosen_action = jax.random.categorical(rng, logits)
    if random:
      if jax.random.uniform(rng, shape=(1,), minval = 0, maxval = 1)[0] > 0.5:
        action = jax.random.randint(rng, shape=(1,), minval = 0, maxval = len(logits[0]))
      else:
        action = chosen_action
    else:
      action = chosen_action
    return action, log_probs

## Compute Advantages and Reward-To-Go

In [4]:
def compute_advantages(rewards, values, dones, last_value, gamma):
    returns = []
    discounted_reward = 0
    for i in reversed(range(len(rewards))):
        discounted_reward = rewards[i] + gamma * discounted_reward * dones[i]
        returns.insert(0, discounted_reward)
    returns = np.array(returns)
    advantages = returns - values
    return advantages, returns


## Create Train State

In [5]:
def create_train_state(rng, model, input_dim, action_dim, learning_rate):
    dummy_input = jnp.ones((1, input_dim))
    params = model.init(rng, dummy_input)
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn = model.apply,
        params = params,
        tx = tx
    )

## Compute Loss

In [6]:
@functools.partial(jax.jit, static_argnums = (6, 7))
def train_step(
    state,
    states,
    actions,
    old_log_probs,
    advantages,
    returns,
    clip_eps,
    max_grad_norm
):
    def loss_fn(params):
        mean, values = state.apply_fn(params, states)
        log_probs = jax.nn.log_softmax(mean)

        ratio = jnp.exp(log_probs - old_log_probs)
        clipped_ratio = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps)
        loss1 = ratio * advantages
        loss2 = clipped_ratio * advantages
        policy_loss = -jnp.mean(jnp.minimum(loss1, loss2))

        value_loss = jnp.mean((values.squeeze() - returns) ** 2)

        return policy_loss + value_loss

    grad_fn = jax.value_and_grad(loss_fn)
    _, grads = grad_fn(state.params)
    # grads, _ = optax.clip_by_global_norm(grads, max_grad_norm)
    new_state = state.apply_gradients(grads = grads)
    return new_state

## Update Function

In [7]:
def update_ppo(state, obs, batch_size, num_minibatches, clip_eps, max_grad_norm):
    indices = jnp.arange(len(obs["states"]))
    indices = jax.random.permutation(jax.random.PRNGKey(42), indices)

    for _ in range(num_minibatches):
        for i in range(0, len(indices), batch_size):
            mb_indices = indices[i: i + batch_size]
            mb_states = obs["states"][mb_indices]
            mb_actions = obs["actions"][mb_indices]
            mb_old_logprobs = obs["log_probs"][mb_indices]
            mb_advantages = obs["advantages"][mb_indices]
            mb_returns = obs["returns"][mb_indices]

            mb_advantages = (mb_advantages - jnp.mean(mb_advantages)) / jnp.std(mb_advantages) + 1e-8
            state = train_step(state, mb_states, mb_actions, mb_old_logprobs, mb_advantages, mb_returns, clip_eps, max_grad_norm)
    return state

## Collect Trajectories

In [8]:
def collect_trajectories(env, state, rng, steps_per_epoch, gamma):
    buffer = {
        "states": [],
        "actions": [],
        "rewards": [],
        "dones": [],
        "values": [],
        "log_probs": []
    }
    obs = env.reset()
    done = False
    episode_return = 0
    episode_length = 0
    for _ in range(steps_per_epoch):
        rng, actions_rng = jax.random.split(rng)
        mean, value = state.apply_fn(state.params, jnp.array([obs]))
        action, log_prob = sample_actions(actions_rng, mean, random = True)

        next_obs, reward, done, _ = env.step(np.array(action))

        buffer["states"].append(obs)
        buffer["actions"].append(action)
        buffer["rewards"].append(reward)
        buffer["dones"].append(done)
        buffer["values"].append(value[0, 0])
        buffer["log_probs"].append(log_prob.squeeze(0))

        episode_return += reward
        episode_length += 1

        if done:
            obs = env.reset()
            done = False
            episode_return = 0
            episode_length = 0
        else:
            obs = next_obs

    observations = {}
    for key, val in buffer.items():
        observations[key] = jnp.array(val)

    last_value = state.apply_fn(state.params, jnp.array([obs]))[1][0, 0]
    advantages, returns = compute_advantages(
        observations["rewards"],
        observations["values"],
        observations["dones"],
        last_value,
        gamma
    )
    observations["advantages"] = advantages[:, jnp.newaxis]
    observations["returns"] = returns[:, jnp.newaxis]
    return observations, rng


## Train

In [9]:
def train(
    env,
    seed,
    num_epochs,
    steps_per_epoch,
    batch_size,
    num_minibatches,
    gamma,
    clip_eps,
    learning_rate,
    max_grad_norm
):
    dummy_state = env.reset()
    input_dim = dummy_state.shape[0]
    action_dim = env.action_dim()

    rng = jax.random.PRNGKey(seed)
    rng, actor_rng = jax.random.split(rng)
    actor_critic = ActorCritic(action_dim)
    state = create_train_state(rng, actor_critic, input_dim, action_dim, learning_rate)

    for epoch in range(num_epochs):
        obs, rng = collect_trajectories(
            env, state, rng, steps_per_epoch, gamma
        )
        state = update_ppo(state, obs, batch_size, num_minibatches, clip_eps, max_grad_norm)

        if epoch % 10 == 0:
            eval_returns, sample = evaluate_policy(env, state, rng, 1)
            print(f"Eval return {eval_returns}, Sample {sample}")
    return state

## Evaluate

In [10]:
def evaluate_policy(env, state, rng, evals):
    returns = []
    samples = []
    for _ in range(evals):
        obs = env.reset()
        done = False
        episode_return = 0
        while not done:
            mean, _ = state.apply_fn(state.params, jnp.array([obs]))
            action = np.array(sample_actions(rng, mean, random = False)[0])
            obs, reward, done, _ = env.step(action)
            episode_return += reward
        sample = env.get_sample()
        returns.append(episode_return)
        samples.append(sample)
    return np.mean(np.array(returns)), samples[-1]


## Environment

In [11]:
class Environment:

    def __init__(self):
      self.pros = "Shall I compare thee to a summer's day ? , Thou art more lovely and more temperate"
      self.tokens = self.pros.split(" ")
      self.pro_length = 17
      self.words_to_actions = {
          "I": 1,
          "compare": 2,
          "thee": 3,
          "to": 4,
          "a": 5,
          "summer's": 6,
          "day": 7,
          "thou": 8,
          "art": 9,
          "more": 10,
          "lovely": 11,
          "and": 12,
          "temperate": 13,
          ",": 14,
          "?": 15,
          "Shall": 16
      }
      self.actions_to_words = {v:k for k, v in self.words_to_actions.items()}
      self.state = np.zeros((self.pro_length,))

    def action_dim(self):
      return len(self.words_to_actions)

    def get_sample(self):
      return self.current_sample

    def reset(self):
      self.state = np.zeros((self.pro_length,))
      self.counter = 0
      self.current_sample = ""
      return self.state

    def get_reward(self):
      prediction = set(self.current_sample)
      ground_truth = set("".join(self.tokens[:len(prediction)]))
      score = 1 - jaccard_distance(prediction, ground_truth)
      return score

    def step(self, action):
      action = action[0]
      word = self.actions_to_words[action + 1]
      self.state[self.counter] = action
      self.current_sample = self.current_sample + " " + word
      reward = self.get_reward()
      self.counter += 1
      if self.counter != self.pro_length:
        done = False
      else:
        done = True
      return self.state, reward, done, None


## Main Runner

In [12]:
if __name__ == "__main__":
    env = Environment()
    state = train(
        env,
        42,
        100,
        1024,
        64,
        64,
        0.99,
        0.2,
        3e-4,
        0.5
    )

Eval return 6.668478260869566, Sample  temperate temperate temperate temperate temperate and and and and and lovely and lovely lovely lovely lovely lovely
Eval return 6.464285714285717, Sample  summer's summer's and and and and and and and and and and and and and and and
Eval return 2.8333333333333317, Sample  I I I I I I I I I I I I I I I I I
Eval return 1.0708333333333333, Sample  day day and and and and and and and and and and and and and and and
Eval return 1.133333333333333, Sample  and and and and and and and and and and and and and and and and and
Eval return 4.77142857142857, Sample  lovely and and and and and and and and and and and and and and and and
Eval return 2.8333333333333317, Sample  I I I I I I I I I I I I I I I I I
Eval return 2.8333333333333317, Sample  I I I I I I I I I I I I I I I I I
Eval return 2.8333333333333317, Sample  I I I I I I I I I I I I I I I I I
Eval return 2.8333333333333317, Sample  I I I I I I I I I I I I I I I I I


## AI Generated Solution (for reference)

In [13]:
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax
from flax.training.train_state import TrainState
import numpy as np
from typing import Any, Dict, List, Tuple
import random

# Set random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
key = jax.random.PRNGKey(SEED)

# Shakespeare poem to overfit
SHAKESPEARE_POEM = """
Shall I compare thee to a summer's day?
Thou art more lovely and more temperate:
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date;
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimm'd;
And every fair from fair sometime declines,
By chance or nature's changing course untrimm'd;
But thy eternal summer shall not fade,
Nor lose possession of that fair thou ow'st;
Nor shall death brag thou wander'st in his shade,
When in eternal lines to time thou grow'st:
So long as men can breathe or eyes can see,
So long lives this, and this gives life to thee.
"""

# Tokenization helpers
def create_vocabulary(text):
    # Create a simple character-level vocabulary
    chars = sorted(list(set(text)))
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}
    vocab_size = len(chars)
    return char_to_idx, idx_to_char, vocab_size

def encode_text(text, char_to_idx):
    return [char_to_idx[ch] for ch in text]

def decode_tokens(tokens, idx_to_char):
    return ''.join([idx_to_char[idx] for idx in tokens])

# Create vocabulary from the poem
char_to_idx, idx_to_char, vocab_size = create_vocabulary(SHAKESPEARE_POEM)
encoded_poem = encode_text(SHAKESPEARE_POEM, char_to_idx)

# Hyperparameters
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
CONTEXT_LENGTH = 10
BATCH_SIZE = 16
LEARNING_RATE = 3e-4
PPO_EPOCHS = 4
NUM_STEPS = 100
CLIP_RATIO = 0.2
VALUE_COEF = 0.5
ENTROPY_COEF = 0.01
GAE_LAMBDA = 0.95
GAMMA = 0.99

# Define the Actor-Critic model with Flax
class ActorCritic(nn.Module):
    vocab_size: int
    emb_dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        # Shared embedding layer
        emb = nn.Embed(self.vocab_size, self.emb_dim)(x)

        # LSTM for sequential processing
        lstm_out = nn.RNN(nn.LSTMCell(self.hidden_dim))(emb)

        # Actor head (policy)
        logits = nn.Dense(self.vocab_size)(lstm_out)

        # Critic head (value function)
        value = nn.Dense(1)(lstm_out)

        return logits, value

# Initialize model and optimizer
def init_model(key, vocab_size):
    model = ActorCritic(vocab_size=vocab_size, emb_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM)

    # Create a sample input
    dummy_input = jnp.zeros((1, CONTEXT_LENGTH), dtype=jnp.int32)

    # Initialize parameters
    params = model.init(key, dummy_input)

    # Create optimizer
    tx = optax.adam(learning_rate=LEARNING_RATE)

    # Create train state
    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx
    )

# Generate data for training
def get_batch(encoded_text, batch_size, context_length):
    # Choose random starting indices
    max_idx = len(encoded_text) - context_length - 1
    indices = np.random.randint(0, max_idx, size=batch_size)

    # Create x and y batches
    x_batch = np.array([encoded_text[idx:idx+context_length] for idx in indices])
    y_batch = np.array([encoded_text[idx+1:idx+context_length+1] for idx in indices])

    return x_batch, y_batch

# PPO loss function
@functools.partial(jax.jit)
def ppo_loss(state, states, actions, rewards, old_log_probs, advantages):

    def loss_fn(params):
        # Forward pass with current parameters
        logits, values = state.apply_fn(params, states)

        # Calculate policy loss (PPO-Clip)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        action_log_probs = jnp.take_along_axis(log_probs, actions[:, :, None], axis=-1).squeeze(-1)

        # Compute ratio and clipped ratio
        ratio = jnp.exp(action_log_probs - old_log_probs)
        clipped_ratio = jnp.clip(ratio, 1 - CLIP_RATIO, 1 + CLIP_RATIO)

        # Policy loss
        policy_loss = -jnp.mean(jnp.minimum(ratio * advantages, clipped_ratio * advantages))

        # Value loss
        value_loss = jnp.mean(jnp.square(values.squeeze(-1) - rewards))

        # Entropy for exploration
        entropy = jnp.mean(-jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1))

        # Total loss
        total_loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

        return total_loss

    # Update model with PPO loss
    grad_fn = jax.value_and_grad(loss_fn, allow_int = True)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state


# Compute GAE (Generalized Advantage Estimation)
def compute_gae(rewards, values, next_values, dones, gamma=GAMMA, lam=GAE_LAMBDA):
    advantages = []
    gae = 0
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            next_value = next_values[t]
        else:
            next_value = values[t + 1]

        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        gae = delta + gamma * lam * (1 - dones[t]) * gae
        advantages.insert(0, gae)

    advantages = jnp.array(advantages)
    returns = advantages + jnp.array(values)

    # Normalize advantages
    advantages = (advantages - jnp.mean(advantages)) / (jnp.std(advantages) + 1e-8)

    return advantages, returns

# Training function
def train_ppo():
    # Initialize model
    key, subkey = jax.random.split(jax.random.PRNGKey(SEED))
    state = init_model(subkey, vocab_size)

    for step in range(NUM_STEPS):
        # Sample batch
        states, next_states = get_batch(encoded_poem, BATCH_SIZE, CONTEXT_LENGTH)
        states = jnp.array(states)
        actions = jnp.array(next_states)

        # Get current policy and value predictions
        logits, values = state.apply_fn(state.params, states)

        # Sample actions and compute log probabilities
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        action_log_probs = jnp.take_along_axis(log_probs, actions[:, :, None], axis=-1).squeeze(-1)

        # Simple reward: match the target poem
        rewards = jnp.sum(jnp.equal(jnp.argmax(logits, axis=-1), actions), axis=-1) / CONTEXT_LENGTH

        # Compute advantages and returns
        dones = jnp.zeros_like(rewards)  # No episode termination in this task
        next_values = values  # Simple approximation
        advantages, returns = compute_gae(rewards, values.squeeze(-1), next_values.squeeze(-1), dones)

        # PPO update loop
        for _ in range(PPO_EPOCHS):
            # Update model with PPO loss
            state = ppo_loss(state, states, actions, returns, action_log_probs, advantages)

        # Print progress and generated text occasionally
        if step % 10 == 0:
            # Generate text from the model
            generated_text = generate_text(state, 100)
            print(f"Step {step}/{NUM_STEPS}, Generated text:")
            print(generated_text)
            print("-" * 40)

    return state

# Text generation function
def generate_text(state, length, temperature=0.8, prompt=None):
    if prompt is None:
        # Start with the first few characters from the poem
        prompt = SHAKESPEARE_POEM[:10]
        tokens = encode_text(prompt, char_to_idx)
    else:
        tokens = encode_text(prompt, char_to_idx)

    # Pad the context if needed
    if len(tokens) < CONTEXT_LENGTH:
        tokens = [0] * (CONTEXT_LENGTH - len(tokens)) + tokens
    elif len(tokens) > CONTEXT_LENGTH:
        tokens = tokens[-CONTEXT_LENGTH:]

    context = jnp.array([tokens])
    generated = list(tokens)

    for _ in range(length):
        # Get logits from the model
        logits, _ = state.apply_fn(state.params, context)

        # Apply temperature
        logits = logits / temperature

        # Sample from the distribution
        next_token = jax.random.categorical(jax.random.PRNGKey(random.randint(0, 10000)),
                                           logits[0, -1, :])

        # Append to generated text
        generated.append(int(next_token))

        # Update context
        context = jnp.roll(context, -1, axis=1)
        context = context.at[:, -1].set(int(next_token))

    # Decode and return
    return decode_tokens(generated, idx_to_char)

# Main execution
if __name__ == "__main__":
    print("Training PPO agent to overfit Shakespeare poem...")
    final_state = train_ppo()

    print("\nFinal generated text:")
    print(generate_text(final_state, 200, temperature=0.5))

    # Evaluate overfitting performance
    print("\nEvaluating overfitting performance...")
    states, next_states = get_batch(encoded_poem, 1, CONTEXT_LENGTH)
    logits, _ = final_state.apply_fn(final_state.params, jnp.array(states))
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(jnp.equal(predictions, jnp.array(next_states)))
    print(f"Character prediction accuracy: {accuracy * 100:.2f}%")

Training PPO agent to overfit Shakespeare poem...
Step 0/100, Generated text:

Shall I cshthkRR
BR' xRRtoW:sf :xNcu' tktb,pkxebuRtasne:?vgIsppmoIm,AAi.nA:ecaBwiIulnx
,:iy  mxhT,pW.tNfbpMb:
----------------------------------------
Step 10/100, Generated text:

Shall I cf?.bmwm;
.el'bRrgys:N?.:bh:omgRhNfrd;svBf?.v?Mmri.f. lN ,rd?mxnlNfxB,iSTB;;pnkn;Rug:I?ehxArNfyrg;R A
----------------------------------------
Step 20/100, Generated text:

Shall I cWMWTMBTMR.vrI?MsMT
RwWT;A:bMbfNbIvcT.xn'IW,AMtM?wMSwM?B.fmwfivvTh,wW?;xn.p?:rWuS?WRNwBNbwMMNd:s:?:;A
----------------------------------------
Step 30/100, Generated text:

Shall I cMTTTBTBBBmmg.S?emmyThwAi whcoWWBB?BdI;aesNiavxBBBRBSThn pdfBATTT?WIhfNRTlr:ITBkRTSBmmdpmyig;nbfyMBBA
----------------------------------------
Step 40/100, Generated text:

Shall I cTBTTTTBTTB::
ts:gAM?W?ABTW?BBT?B
WMAf BTBWSTB?TBvN.BWAlB?'TAWxBWBTNB
WBe;WNWWBBTB?TBN:onANWBxTxATTNx
----------------------------------------
Step 50/100, Generated text:

Shall I cBBBBATB