In [1]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# set up matplotlib
# `is_ipython` is True when the active backend's name contains 'inline'
# (e.g. the Jupyter inline backend). Only then is IPython's `display`
# module imported, so any code calling `display.*` must check `is_ipython`.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode: figures are redrawn without a blocking plt.show().
plt.ion()

# if GPU is to be used: pick CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class ReplayMemory():
    """Bounded FIFO buffer of transitions used for experience replay.

    When `capacity` items are stored, each new push silently evicts the
    oldest one. `Transition` is looked up when `push` runs, so it only needs
    to be defined before the first push, not before this class.
    """

    def __init__(self, capacity):
        # A deque with maxlen discards from the left once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Pack `args` into a `Transition` and append it to the buffer."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` stored transitions drawn without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.memory)

class Checkpointer():
    """Save model and optimizer state whenever a new best loss is observed.

    Args:
        dir: directory the checkpoint files are written to; created on the
            first save if it does not exist yet.
        model_file: file name (inside `dir`) for the model's `state_dict`.
        optim_file: file name (inside `dir`) for the optimizer's `state_dict`.
    """

    def __init__(self, dir, model_file, optim_file):
        self.dir = dir
        self.model_file = os.path.join(self.dir, model_file)
        self.optim_file = os.path.join(self.dir, optim_file)
        # Loss of the checkpoint currently on disk (None until a real loss is saved).
        self.best_loss = None

    def checkpoint(self, loss, model, optim) -> None:
        """Save `model` and `optim` if `loss` beats the best loss seen so far.

        The first call always saves. `loss=None` also always saves (e.g. to
        snapshot an untrained model) and resets the best loss to None, so the
        next real loss is saved too. The model's train/eval mode is restored
        afterwards instead of being forced back into training mode.
        """
        is_improvement = loss is None or self.best_loss is None or loss < self.best_loss
        if not is_improvement:
            return
        if self.dir:
            # Avoid a FileNotFoundError from torch.save on a fresh checkout.
            os.makedirs(self.dir, exist_ok=True)
        was_training = model.training
        model.eval()
        try:
            torch.save(model.state_dict(), self.model_file)
            torch.save(optim.state_dict(), self.optim_file)
        finally:
            model.train(was_training)
        self.best_loss = loss


class EarlyStopper():
    """Patience-based early stopping on a loss signal.

    Args:
        min_delta: a loss at or below `best_loss - min_delta` counts as an
            improvement.
        patience: number of episodes without improvement that is tolerated.
        start_episode: stopping is never signalled at or before episode
            `start_episode + patience`.
    """

    def __init__(self, min_delta=0, patience=0, start_episode=0):
        self.min_delta = min_delta
        self.patience = patience
        self.start_episode = start_episode
        self.episode_last_improved = 0
        self.best_loss = None
        self.stopped = False

    def stop(self, loss, episode_n) -> bool:
        """Record `loss` observed during episode `episode_n`; return True once stopping is due.

        A `None` loss carries no information and leaves the state unchanged.
        Once `stopped` becomes True it stays True.
        """
        if loss is None:
            return self.stopped
        if self.best_loss is None or loss <= self.best_loss - self.min_delta:
            self.best_loss = loss
            self.episode_last_improved = episode_n
        elif (
            episode_n > self.start_episode + self.patience
            and episode_n > self.episode_last_improved + self.patience
        ):
            self.stopped = True
        return self.stopped

In [13]:
from __future__ import annotations  # must be the first statement in the cell

import os
import sys

from itertools import pairwise
from typing import Iterable
from typing import Union

from numpy import int64
from numpy import ndarray
from numpy import prod

# Put the parent directory on sys.path so `wordle` can be imported from it.
parent_dir = '/'.join(os.getcwd().split('/')[:-1])
sys.path.append(parent_dir)

from wordle import Wordle
from wordle import load_vocab
from wordle import ALPH_LEN

# Return type of DQNTrainer.train / DQN.train_model: metric name -> list of values.
training_history = dict[str, list[float]]

# One step of experience stored in ReplayMemory. `next_state` is not
# bootstrapped from when `is_terminal` is true.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'is_terminal'))

# One Conv2d -> ReLU [-> MaxPool2d] [-> Dropout2d] stage of DQN:
#   conv_out:          Conv2d output channels
#   conv_kernel_size:  Conv2d kernel size ('same' padding)
#   pool_kernel_size:  MaxPool2d kernel size (falsy = no pooling; stride is 1)
#   pool_padding:      MaxPool2d padding
#   dropout:           Dropout2d probability (falsy = no dropout)
ConvPoolConfig = namedtuple(
    'ConvPoolConfig',
    (
        'conv_out',
        'conv_kernel_size',
        'pool_kernel_size',
        'pool_padding',
        'dropout'
    )
)

# One fully connected layer of DQN: output features and dropout (falsy = none).
FCConfig = namedtuple('FCConfig', ('f_out', 'dropout'))


class DQN(nn.Module):
    """Convolutional Q-network mapping a Wordle board state to one Q-value per word.

    The raw state is one-hot encoded into `channels_in` planes of size
    (attempts, positions): `n_char_channels` character planes followed by
    `n_hint_channels` hint planes. These go through the conv/pool stack built
    from `conv_pool_configs`, are flattened to `fcin` features, and go through
    the fully connected stack built from `fc_configs`, whose last layer must
    output `n_actions` values.
    """
    n_char_channels = 1 + ALPH_LEN
    n_hint_channels = Wordle.n_hint_states
    channels_in = n_char_channels + n_hint_channels

    def __init__(
        self,
        target_len: int,
        n_actions: int,
        conv_pool_configs: Iterable[ConvPoolConfig],
        fc_configs: list[FCConfig]
    ) -> None:
        super().__init__()
        self.target_len = target_len
        self.input_size = DQN.get_input_size(self.target_len)

        self.n_actions = n_actions
        # Materialize once so any iterable works and train_model can rebuild an identical net.
        self.conv_pool_configs = list(conv_pool_configs)
        self.fc_configs = list(fc_configs)
        if not self.fc_configs:
            raise ValueError('fc_configs must contain at least one FCConfig')
        if self.n_actions != self.fc_configs[-1].f_out:
            raise ValueError(f'fc_configs[-1].f_out ({self.fc_configs[-1].f_out}) should equal n_actions ({n_actions})')

        self.conv_pool_layers = DQN.make_conv_pool_layers(self.conv_pool_configs)
        self.fcin = DQN.get_fcin(self.input_size, self.conv_pool_layers)
        self.fc_layers = DQN.make_fc_layers(self.n_actions, self.fc_configs, self.fcin)

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions).

        `x` is a raw Wordle state, unbatched or with a leading batch axis;
        see `one_hot_state` for the expected layout.
        """
        one_hot_chars, one_hot_hints = self.one_hot_state(x)
        # Stack character and hint planes along the channel axis.
        x = torch.cat((one_hot_chars, one_hot_hints), dim=-3)
        if x.dim() == 3:
            x = x.unsqueeze(0)  # unbatched input -> batch of one
        x = x.to(next(self.parameters()).device)
        x = self.conv_pool_layers(x)
        x = x.reshape(-1, self.fcin)  # -1 for batch size inference
        return self.fc_layers(x)

    @staticmethod
    def make_conv_pool_layers(conv_pool_configs: Iterable[ConvPoolConfig]) -> nn.Sequential:
        """Build one Conv2d -> ReLU [-> MaxPool2d] [-> Dropout2d] stage per config.

        The first stage reads `DQN.channels_in` channels and each later stage
        reads the previous stage's `conv_out`. Convolutions use 'same' padding
        and pooling uses stride 1.

        Raises:
            ValueError: if `conv_pool_configs` is empty.
        """
        configs = list(conv_pool_configs)
        if not configs:
            raise ValueError('conv_pool_configs must contain at least one ConvPoolConfig')
        layers = []
        in_channels = DQN.channels_in
        for config in configs:
            layers.append(nn.Conv2d(in_channels, config.conv_out, kernel_size=config.conv_kernel_size, padding='same'))
            layers.append(nn.ReLU())
            if config.pool_kernel_size:
                layers.append(
                    nn.MaxPool2d(kernel_size=config.pool_kernel_size, stride=1, padding=config.pool_padding)
                )
            if config.dropout:
                layers.append(nn.Dropout2d(config.dropout))
            in_channels = config.conv_out
        return nn.Sequential(*layers)

    @staticmethod
    def get_input_size(target_len):
        """Shape of a single one-hot encoded board with a batch axis of 1."""
        return (1, DQN.channels_in, Wordle.max_attempts, target_len)

    @staticmethod
    def get_fcin(input_size: tuple[int, ...], conv_pool_layers: nn.Sequential) -> int:
        """Number of features `conv_pool_layers` produces for one input of `input_size`.

        `input_size` carries a leading batch axis of 1, so the element count of
        the output equals the flattened per-sample feature count.
        """
        with torch.no_grad():
            test_out = conv_pool_layers(torch.zeros(input_size))
        return test_out.numel()

    @staticmethod
    def make_fc_layers(n_actions, fc_configs: list[FCConfig], fcin: int | None = None) -> nn.Sequential:
        """Build the Linear [-> ReLU] [-> Dropout] stack that outputs `n_actions` values.

        Args:
            n_actions: required output width of the final layer.
            fc_configs: one FCConfig per Linear layer, in order.
            fcin: input width of the first Linear layer (see `get_fcin`).

        ReLU follows every layer except the last, so Q-values are unbounded.

        Raises:
            ValueError: if `fcin` is missing, `fc_configs` is empty, or the last
                layer's width differs from `n_actions`.
        """
        if fcin is None:
            raise ValueError('make_fc_layers needs fcin (flattened conv output size, see DQN.get_fcin)')
        configs = list(fc_configs)
        if not configs:
            raise ValueError('fc_configs must contain at least one FCConfig')
        if configs[-1].f_out != n_actions:
            raise ValueError(f'fc_configs[-1].f_out ({configs[-1].f_out}) should equal n_actions ({n_actions})')

        layers = []
        in_features = fcin
        last_i = len(configs) - 1
        for i, config in enumerate(configs):
            layers.append(nn.Linear(in_features, config.f_out))
            if i != last_i:
                layers.append(nn.ReLU())
            if config.dropout:
                layers.append(nn.Dropout(config.dropout))
            in_features = config.f_out
        return nn.Sequential(*layers)

    def one_hot_state(
            self,
            state: ndarray[int64] | torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
        """One-hot encode a raw Wordle state into character and hint planes.

        `state` is indexed as (..., 2, attempts, positions) with an optional
        leading batch axis. Channel 0 is taken as character ids and channel 1 as
        hint ids (assumed layout; TODO confirm against Wordle.hint_channel,
        which get_reward uses to read hints).

        Cells are encoded in row-major (attempt, position) order up to the first
        cell whose character equals `Wordle.initial_empty`; that cell and every
        later one are left all-zero.

        Returns:
            (one_hot_chars, one_hot_hints), float32 tensors of shape
            (..., n_char_channels, attempts, positions) and
            (..., n_hint_channels, attempts, positions) on `state`'s device.
        """
        state = torch.as_tensor(state).long()
        chars, hints = state.unbind(dim=-3)
        filled = chars != Wordle.initial_empty
        # The running product over the flattened board stays 1 only while every
        # earlier cell is filled, i.e. it marks the cells before the first empty one.
        encoded = filled.flatten(start_dim=-2).long().cumprod(dim=-1).bool().reshape(filled.shape)
        # Give unencoded cells a valid class index; they are zeroed by the mask below.
        chars = torch.where(encoded, chars, torch.zeros_like(chars))
        hints = torch.where(encoded, hints, torch.zeros_like(hints))
        mask = encoded.unsqueeze(-1)
        # F.one_hot appends the class axis last: (..., A, L, C) -> (..., C, A, L).
        one_hot_chars = (F.one_hot(chars, DQN.n_char_channels) * mask).movedim(-1, -3).float()
        one_hot_hints = (F.one_hot(hints, DQN.n_hint_channels) * mask).movedim(-1, -3).float()
        return (one_hot_chars, one_hot_hints)

    def train_model(self, trainer: DQNTrainer) -> training_history:
        """Train this network as the policy net using `trainer`.

        A target network with the same architecture is created from this
        network's weights, placed on the same device, and handed to `trainer`.
        """
        target_net = DQN(
            self.target_len,
            self.n_actions,
            self.conv_pool_configs,
            self.fc_configs
        )
        target_net.load_state_dict(self.state_dict())
        target_net.to(next(self.parameters()).device)
        trainer.set_target_net(target_net)
        return trainer.train(self)

    def get_reward(self, state, status, guess_n) -> float:
        """Reward for the state reached after guess number `guess_n` (1-based).

        Every guess costs `turn_value`; each hint equal to `Wordle.correct` in
        row `guess_n - 1` adds 1/target_len; a win adds
        `Wordle.max_attempts * turn_value`. The total is divided by target_len.
        `state` may be None (e.g. after an invalid guess); then no per-hint
        credit is given.
        """
        turn_value = 10.
        reward = -turn_value
        if state is not None and guess_n > 0:
            latest_hints = state[Wordle.hint_channel, guess_n - 1]
            reward += sum(
                1. / self.target_len for hint in latest_hints
                if hint == Wordle.correct
            )
        reward += Wordle.max_attempts * turn_value * int(status == Wordle.won)
        return reward / self.target_len


class DQNTrainer():
    """Deep Q-learning trainer with experience replay and a soft-updated target network.

    Args:
        optimizer: optimizer over the policy network's parameters.
        memory_size: capacity of the replay buffer.
        n_episodes: default number of episodes for `train`.
        checkpointer: optional Checkpointer, called after each optimization step.
        stopper: optional EarlyStopper, fed each optimization loss.
        device: device for states, actions and rewards; when None, `train`
            uses the policy network's device.
        batch_size: replay minibatch size.
        discount: reward discount factor (gamma).
        eps_start, eps_end, eps_decay: epsilon-greedy schedule
            eps = eps_end + (eps_start - eps_end) * exp(-step / eps_decay).
        update_rate: soft-update coefficient tau for the target network.
        words: vocabulary; action i guesses words[i]. When None, the
            notebook-level `ALL_WORDS` is used.
    """
    def __init__(
        self,
        optimizer,
        memory_size: int = 10_000,
        n_episodes: int = 10_000,
        checkpointer: Checkpointer = None,
        stopper: EarlyStopper = None,
        device: torch.device = None,
        batch_size: int = 128,
        discount: float = 0.99,
        eps_start: float = 0.9,
        eps_end: float = 0.05,
        eps_decay: int = 1000,
        update_rate: float = 1e-4,
        words: list[str] | None = None,
    ) -> None:
        self.optimizer = optimizer
        self.memory = ReplayMemory(memory_size)
        self.n_episodes = n_episodes
        self.batch_size = batch_size
        self.discount = discount
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.update_rate = update_rate
        self.checkpointer = checkpointer
        self.stopper = stopper
        self.device = device
        self.words = words

        self.step_n = 0  # actions selected so far (drives epsilon decay)
        self.batch_n = 0  # optimization steps performed so far

        self.target_net = None  # set in DQN.train_model w/ DQNTrainer.set_target_net
        self.policy_net = None  # set by train, used by optimize_model

    def set_target_net(self, target_net: DQN) -> None:
        """Register the target network; it is put in eval mode (no dropout)."""
        self.target_net = target_net
        if self.device is not None:
            self.target_net.to(self.device)
        # The target net only produces bootstrap targets, so disable its dropout.
        self.target_net.eval()

    def select_action(self, policy_net: DQN, state: torch.Tensor) -> torch.Tensor:
        """Epsilon-greedy action index as a (1, 1) long tensor."""
        self.step_n += 1
        eps_threshold = (
            (self.eps_start - self.eps_end)
            * math.exp(-1. * self.step_n / self.eps_decay)
            + self.eps_end
        )
        if random.random() > eps_threshold:
            with torch.no_grad():
                return policy_net(state).max(1).indices.view(1, 1)
        return torch.tensor(
            [[random.randint(0, policy_net.n_actions - 1)]],
            device=self.device,
            dtype=torch.long
        )

    def train(self, policy_net: DQN, n_episodes: int = 0) -> training_history:
        """Play up to `n_episodes` games (0 means `self.n_episodes`) and learn from them.

        Returns:
            {'episode_rewards': total reward per finished episode,
             'losses': Huber loss per optimization step}

        Raises:
            RuntimeError: if no target network has been set.
        """
        if n_episodes == 0:
            n_episodes = self.n_episodes
        if self.target_net is None:
            raise RuntimeError('No target network: call set_target_net (or use DQN.train_model)')
        if self.device is None:
            self.device = next(policy_net.parameters()).device
        self.target_net.to(self.device)
        self.policy_net = policy_net
        words = self._get_words()

        episode_rewards = []
        losses = []
        for episode_i in range(n_episodes):
            if self.stopper is not None and self.stopper.stopped:
                print(f'Early stopping after {episode_i} episodes')
                break
            wordle = Wordle(words)
            state = self._to_tensor(wordle.state)
            status = Wordle.ongoing
            episode_reward = 0
            while status == Wordle.ongoing:
                action = self.select_action(policy_net, state)
                try:
                    next_state = wordle.guess(words[action.item()])
                    status = wordle.check_state()
                except ValueError:
                    # Invalid guess: end the episode as a loss with no next state.
                    next_state = None
                    status = Wordle.lost
                reward = policy_net.get_reward(next_state, status, wordle.attempts_made)
                episode_reward += reward
                if next_state is not None:
                    next_state = self._to_tensor(next_state)
                is_terminal = status != Wordle.ongoing

                # Store the transition in memory
                self.memory.push(
                    state, action, next_state, torch.tensor([reward], device=self.device), is_terminal
                )
                state = next_state

                # Perform optimization on the policy network, checkpoint, early stopping
                loss = self.optimize_model()
                if loss is not None:
                    losses.append(loss)
                    self.batch_n += 1
                    if self.checkpointer is not None:
                        self.checkpointer.checkpoint(loss, policy_net, self.optimizer)
                    if self.stopper is not None and self.stopper.stop(loss, episode_i):
                        break

                self._soft_update_target(policy_net)

            if status != Wordle.ongoing:
                episode_rewards.append(episode_reward)
                self.plot_training(losses, self.batch_n, episode_rewards, episode_i)

        return {
            'episode_rewards': episode_rewards,
            'losses': losses
        }

    def _get_words(self):
        """Vocabulary to play with: `self.words`, else the notebook-level ALL_WORDS."""
        return self.words if self.words is not None else ALL_WORDS

    def _to_tensor(self, state) -> torch.Tensor:
        """Copy a raw Wordle state into a float32 tensor with a leading batch axis of 1."""
        # torch.tensor always copies, so later in-place board updates cannot alter stored states.
        return torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)

    def _soft_update_target(self, policy_net: DQN) -> None:
        """Soft update of the target weights: θ′ ← τ θ + (1 − τ) θ′ with τ = update_rate."""
        policy_state = policy_net.state_dict()
        target_state = self.target_net.state_dict()
        for key, policy_value in policy_state.items():
            target_state[key] = (
                policy_value * self.update_rate
                + target_state[key] * (1 - self.update_rate)
            )
        self.target_net.load_state_dict(target_state)

    def optimize_model(self) -> Union[None, float]:
        """Run one DQN update on a replay minibatch.

        Returns:
            The Huber loss as a float, or None when the buffer holds fewer
            than `batch_size` transitions.

        Raises:
            RuntimeError: if called before `train` registered a policy network.
        """
        if len(self.memory) < self.batch_size:
            return None
        if self.policy_net is None:
            raise RuntimeError('optimize_model called before train() set the policy network')
        policy_net = self.policy_net
        transitions = self.memory.sample(self.batch_size)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043)
        batch = Transition(*zip(*transitions))

        # Terminal transitions are not bootstrapped from (their next_state may be None).
        non_terminal_mask = torch.logical_not(
            torch.tensor(batch.is_terminal, device=self.device, dtype=torch.bool)
        )
        non_terminal_next_states = [
            s for s, terminal in zip(batch.next_state, batch.is_terminal) if not terminal
        ]
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Q(s_t, a) for the actions actually taken
        state_action_values = policy_net(state_batch).gather(1, action_batch)

        # max_a Q_target(s_{t+1}, a) for non-terminal next states, 0 otherwise
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        if non_terminal_next_states:
            with torch.no_grad():
                next_state_values[non_terminal_mask] = (
                    self.target_net(torch.cat(non_terminal_next_states)).max(1).values
                )
        expected_state_action_values = torch.where(
            non_terminal_mask,
            reward_batch + next_state_values * self.discount,
            reward_batch
        )

        criterion = nn.SmoothL1Loss()  # Huber loss
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
        self.optimizer.step()

        return float(loss)

    @staticmethod
    def plot_training(losses: list[float], batch_n: int, episode_rewards: list[float], episode_n: int):
        """Redraw figure 0 with the loss history (once any batch has run) and per-episode rewards."""
        fig = plt.figure(0, figsize=(20, 8))
        fig.clf()
        # suptitle survives the subplot layout, unlike a title on the current axes.
        fig.suptitle(f'Episode {episode_n}')

        if batch_n:
            ax_loss, ax_reward = fig.subplots(1, 2)
            ax_loss.set_xlabel('Batch Number')
            ax_loss.set_ylabel('Loss')
            ax_loss.scatter(range(len(losses)), losses)
        else:
            ax_reward = fig.subplots()

        ax_reward.set_xlabel('Episode')
        ax_reward.set_ylabel('Total Reward')
        ax_reward.scatter(range(len(episode_rewards)), episode_rewards)
        fig.tight_layout()

        if is_ipython:
            display.display(fig)
            display.clear_output(wait=True)

In [14]:
# Seed Python's and torch's RNGs so weight init and epsilon-greedy choices are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

TARGET_LEN = 5
ALL_WORDS = load_vocab(f'../corncob_caps_{TARGET_LEN}.txt')
N_ACTIONS = len(ALL_WORDS)

conv_pool_configs = [
    # ConvPoolConfig(conv_out, conv_kernel_size, pool_kernel_size, pool_padding, dropout)
    ConvPoolConfig(DQN.channels_in, 1,           0, 0,                          0.05),
    ConvPoolConfig(DQN.channels_in * 4, 3,       (1, 3), (0, 1),                0.1),
    ConvPoolConfig(DQN.channels_in * 8, 3,       (1, 2), (0, 1),                0.1)
]
# Flattened conv output size, i.e. the first fully connected layer's input width.
fcin = DQN.get_fcin(
    DQN.get_input_size(TARGET_LEN),
    DQN.make_conv_pool_layers(conv_pool_configs)
)
fc_configs = [
    FCConfig(2 ** 12, 0.5),
    FCConfig(N_ACTIONS, 0.)
]

wordle_net = DQN(TARGET_LEN, N_ACTIONS, conv_pool_configs, fc_configs).to(device)

lr = 1e-5
optimizer = optim.AdamW(wordle_net.parameters(), lr=lr, amsgrad=True)
checkpointer = Checkpointer('../checkpoints', 'model_5c.pt', 'optim_5c.pt')
# Snapshot the untrained model so a checkpoint exists even if training crashes early.
checkpointer.checkpoint(None, wordle_net, optimizer)
stopper = EarlyStopper(
    min_delta=0.0,
    patience=50,
    start_episode=200
)

trainer = DQNTrainer(
    optimizer,
    checkpointer=checkpointer,
    stopper=stopper,
    device=device
)
# Keep the returned metrics ({'episode_rewards': [...], 'losses': [...]}) for later analysis.
history = wordle_net.train_model(trainer)

plt.ioff();

In [None]:
# Let the trained network make one greedy guess in a fresh game.
game = Wordle(ALL_WORDS)
print(f'{game.target = }')
print(f'{game.target_len = }')
state = torch.tensor(game.state, dtype=torch.float32, device=device).unsqueeze(0)
was_training = wordle_net.training
wordle_net.eval()  # disable dropout for inference
with torch.no_grad():
    action = wordle_net(state).max(1).indices.item()
wordle_net.train(was_training)
guess = ALL_WORDS[action]
print(f'guessing: {guess}')
state = game.guess(guess)

In [None]:
# Reload the best saved checkpoint into a fresh network with the same architecture.
npn = DQN(TARGET_LEN, N_ACTIONS, conv_pool_configs, fc_configs).to(device)
# weights_only=True refuses to unpickle arbitrary objects from the checkpoint file.
npn.load_state_dict(torch.load(checkpointer.model_file, map_location=device, weights_only=True))
npn.eval()