In [None]:
import numpy as np
from functools import partial


import pandas as pd
import matplotlib.pyplot as plt

In [3]:
# matches = re.match('var La=\[(?:"([a-z]+)",?)+\]', wordle_js)
def extract_wordlist(name):
    raw_list = re.findall(fr'{name}=\[([a-z",]+)\]', wordle_js)[0]
    return raw_list.replace('"', '').split(',')

solutions_str = extract_wordlist('La')
guesses_str   = extract_wordlist('Ta') + solutions_str

In [4]:
def encode(word):
    return np.array([ord(c) - ord('a') for c in word])

def decode(word_vector):
    return ''.join(chr(ord('a') + c) for c in word_vector)

solutions = np.stack([encode(word) for word in solutions_str])
guesses   = np.stack([encode(word) for word in guesses_str])

In [5]:
@jax.jit
def score_guess(guess, solution):
    green  = (guess == solution)
    
    # Effectively mask out green numbers
    solution = solution + jnp.where(green, 100, 0)
    guess    = guess + jnp.where(green, 1000, 0)
        
    def is_yellow(i):
        return (jnp.cumsum(guess == guess[i])[i] <= jnp.sum(solution == guess[i]))
    
    yellow = jax.vmap(is_yellow)(jnp.arange(5))
    return jnp.where(green,  2,
           jnp.where(yellow, 1,
                             0))


def fmt_score(score):
    return ''.join('✕~✓'[i] for i in score)

## Test cases
def test(guess, solution, expected_score):
    score = score_guess(encode(guess), encode(solution))
    if np.all(score == np.asarray(expected_score)):
        print(f'Guess {guess} for {solution} scored correctly.')
    else:
        print(f'Guess {guess} for {solution} scored '
              f'{fmt_score(score)}, expected {fmt_score(expected_score)}!')

test('acccb', 'bccca', [1,2,2,2,1])
test('aaaab', 'baaaa', [1,2,2,2,1])
test('abcda', 'abcde', [2,2,2,2,0])
test('cbada', 'abcde', [1,2,1,2,0])
test('cbada', 'abcde', [1,2,1,2,0])
test('zymic', 'could', [0,0,0,0,1])

Guess acccb for bccca scored correctly.
Guess aaaab for baaaa scored correctly.
Guess abcda for abcde scored correctly.
Guess cbada for abcde scored correctly.
Guess cbada for abcde scored correctly.
Guess zymic for could scored correctly.


In [6]:
def build_functions():
    def encoder(guess, score):
        letter_emb = hk.Embed(26, 112, name='letter_emb')(guess)
        guess_emb  = hk.Embed(26,  16, name='guess_emb')(score)
        x = jnp.concatenate([letter_emb, guess_emb], axis=-1)
        x = rearrange(x, 'batch letter dim -> batch (letter dim)')
        return hk.Sequential([
            hk.Linear(1024, name='enc1'), jax.nn.relu,
            hk.Linear(1024, name='enc2'), jax.nn.relu,
            hk.Linear(1024, name='enc3'),
        ])(x)
        
    def actor(current_information):
        features = hk.Sequential([
            hk.Linear(1024), jax.nn.relu,
            hk.Linear(1024), jax.nn.relu,
        ])(current_information)
        guess_logits    = hk.Linear(len(guesses))(features)
        expected_reward = hk.Linear(len(guesses))(features)
        return guess_logits, expected_reward
        
    def init(guess, score):
        info = encoder(guess, score)
        actions = actor(info)
        return actions

    return init, (encoder, actor)

In [7]:
def unpack_without_apply_rng(multitransformed):
    funs = multitransformed.apply
    def apply_without_rng(fun):
        def inner(params, *args, **kwargs):
            return fun(params, None, *args, **kwargs)
        return inner
    
    return jax.tree_map(apply_without_rng, funs)

net = hk.multi_transform(build_functions)
rng = jax.random.PRNGKey(42)

guess = jnp.ones([7, 5], dtype=jnp.uint8)
score = jnp.ones([7, 5], dtype=jnp.uint8)
params = net.init(rng, guess, score)
encoder, actor = unpack_without_apply_rng(net)
opt = optax.adam(1e-1)
opt_state = opt.init(params)

jax.tree_map(lambda x: x.shape, params)

FlatMap({
  'enc1': FlatMap({'b': (1024,), 'w': (640, 1024)}),
  'enc2': FlatMap({'b': (1024,), 'w': (1024, 1024)}),
  'enc3': FlatMap({'b': (1024,), 'w': (1024, 1024)}),
  'guess_emb': FlatMap({'embeddings': (26, 16)}),
  'letter_emb': FlatMap({'embeddings': (26, 112)}),
  'linear': FlatMap({'b': (1024,), 'w': (1024, 1024)}),
  'linear_1': FlatMap({'b': (1024,), 'w': (1024, 1024)}),
  'linear_2': FlatMap({'b': (12972,), 'w': (1024, 12972)}),
  'linear_3': FlatMap({'b': (12972,), 'w': (1024, 12972)}),
})

In [8]:
@partial(jax.vmap, in_axes=[-1, None], out_axes=-1)
@partial(jax.vmap, in_axes=[None, 0], out_axes=0)
def select_guess(guesses, idx):
    return guesses[idx]

@jax.jit
def training_step(key, solutions, opt_state, params):
    def calculate_loss(θ):
        B = solutions.shape[0]
        information = jnp.zeros([B, 1024])
        
        evaluations = []
        scores = []
        expected_rewards = []
        
        for i in range(6):
            guess_logits, evaluation = actor(θ, information)
            guess_idx = jax.random.categorical(key, guess_logits)
            guess = select_guess(guesses, guess_idx)
            evaluations.append(jnp.take_along_axis(evaluation, guess_idx.reshape(-1, 1), axis=-1))
            
            expected_rewards.append(jnp.sum(jax.nn.softmax(guess_logits, axis=-1) *
                                      jax.lax.stop_gradient(evaluation), axis=-1))
            
            score = jax.vmap(score_guess)(guess, solutions)
            scores.append(score)
            information = information + encoder(θ, guess, score)

        scores = jnp.stack(scores, axis=1)
        solved = jnp.all(scores == 2, axis=-1)
        reward = scores.sum(axis=[1, 2]) + 100 * solved.sum(axis=1)
        critic_loss = sum(jnp.mean(jnp.abs(ev - reward)) for ev in evaluations)
        expected_reward = jnp.mean(jnp.stack(expected_rewards))
        actor_loss = -expected_reward
        
        return critic_loss + actor_loss, dict(
            critic_loss=critic_loss,
            actor_loss=actor_loss,
            expected_reward=expected_reward,
            actual_reward=jnp.mean(reward),
            pct_solved=jnp.mean(solved),
        )
    
    grads, metrics = jax.grad(calculate_loss, has_aux=True)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    return metrics, opt_state, params

In [9]:
batch_size = 64
training_key = jax.random.PRNGKey(0)

def train_epoch(key, opt_state, params):
    key, subkey = jax.random.split(key)
    batch_data = jax.random.permutation(key, solutions, axis=-1)
    N = batch_data.shape[0]
    batch_data = batch_data[:(N//batch_size)*batch_size]
    all_metrics = []
    for batch in rearrange(batch_data, '(N B) ... -> N B ...', B=64):
        metrics, opt_state, params = training_step(training_key, batch, opt_state, params)
        all_metrics.append(jax.tree_map(float, metrics))
    return all_metrics, opt_state, params

In [10]:
epochs = []
for epoch in range(200):
    metrics, opt_state, params = train_epoch(jax.random.PRNGKey(0), opt_state, params)
    epoch_metrics = pd.DataFrame(metrics).mean()
    print(epoch_metrics['actual_reward'], epoch_metrics['pct_solved'])
    epochs.append(epoch_metrics)

8.414496527777779 0.0
9.379340277777779 0.0
9.07595486111111 0.0
7.794270833333333 0.0
7.794270833333333 0.0
7.794270833333333 0.0
8.24826388888889 0.0
8.28689236111111 0.0
8.26953125 0.0
9.58029513888889 0.0
9.6953125 0.0
9.6953125 0.0
9.6953125 0.0
9.6953125 0.0
9.807725694444445 0.0
10.207899305555555 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0
10.1796875 0.0


In [None]:
epoch_metrics = pd.concat([pd.DataFrame(epoch).mean() for epoch in epochs], axis=1).T
plt.plot(epoch_metrics.actor_loss)

In [None]:
df = pd.DataFrame(metrics)
plt.plot(df['actual_reward'])

In [None]:
epoch_metrics