In [43]:
import numpy as np

In [44]:
# a simple vocab
vocab = ['apple', 'banana', 'orange', 'grape', 'kiwi']

In [46]:
def softmax(x):
  exp_x = np.exp(x - np.max(x)) # for numerical stability, so large numbers work fine. the prob distrubtion stays the same
  return exp_x / exp_x.sum()

In [47]:
def sample_word(logits):
  probs = softmax(logits)
  word_idx = np.random.choice(len(vocab), p=probs)
  return word_idx, probs

In [48]:
def reward(word_idx): # assume reward is word starting with vowel
  '''for demo, alignmnet target is words that start with a vowel. return 1 if true, 0 otherwise.'''
  word = vocab[word_idx]
  return 1.0 if word[0].lower() in 'aeiou' else 0.0

In [49]:
def group_label(word_idx):
  word = vocab[word_idx]
  return "vowel" if word[0].lower() in 'aeiou' else 'consonant'

In [50]:
# hyperparams
num_iterations = 50
num_episodes = 20
learning_rate = 0.1

In [53]:
np.random.seed(42)
logits = np.random.randn(len(vocab))

for iteration in range(1, num_iterations + 1):
    trajectories = []
    for _ in range(num_episodes):
        word_idx, probs = sample_word(logits)
        r = reward(word_idx)
        grp = group_label(word_idx)
        trajectories.append((word_idx, r, grp, probs))

    group_rewards = {}
    group_counts = {}
    for word_idx, r, grp, probs in trajectories:
        group_rewards.setdefault(grp, 0.0)
        group_counts.setdefault(grp, 0)
        group_rewards[grp] += r
        group_counts[grp] += 1
    group_avg = {grp: group_rewards[grp] / group_counts[grp] for grp in group_rewards if group_counts[grp] > 0}

    grad = np.zeros_like(logits)
    total_advantage = 0.0 
    for word_idx, r, grp, probs in trajectories:

        adv = r - group_avg[grp]
        total_advantage += adv

        grad_sample = -probs.copy() 
        grad_sample[word_idx] += 1.0

        grad += grad_sample * adv

    logits += learning_rate * grad

    avg_advantage = total_advantage / num_episodes
    probs = softmax(logits)
    print(f"Iteration {iteration:02d}:")
    print("  Updated probabilities:")
    for word, p in zip(vocab, probs):
        print(f"    {word:12s}: {p:.3f}")
    print(f"  Average group-relative advantage: {avg_advantage:.3f}\n")

Iteration 01:
  Updated probabilities:
    apple       : 0.168
    banana      : 0.089
    orange      : 0.195
    grape       : 0.468
    kiwi        : 0.081
  Average group-relative advantage: 0.000

Iteration 02:
  Updated probabilities:
    apple       : 0.168
    banana      : 0.089
    orange      : 0.195
    grape       : 0.468
    kiwi        : 0.081
  Average group-relative advantage: 0.000

Iteration 03:
  Updated probabilities:
    apple       : 0.168
    banana      : 0.089
    orange      : 0.195
    grape       : 0.468
    kiwi        : 0.081
  Average group-relative advantage: 0.000

Iteration 04:
  Updated probabilities:
    apple       : 0.168
    banana      : 0.089
    orange      : 0.195
    grape       : 0.468
    kiwi        : 0.081
  Average group-relative advantage: 0.000

Iteration 05:
  Updated probabilities:
    apple       : 0.168
    banana      : 0.089
    orange      : 0.195
    grape       : 0.468
    kiwi        : 0.081
  Average group-relative advantag