In [86]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from Wordle import WordleEnv
from models.DQN import DQN
from models.ActorCritic import Actor, Critic
from ReplayMemory import ReplayMemory
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

# Initialize the environment (input subset size if necessary)
size = None
env = WordleEnv(subset_size=size) 

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


## DQN

In [89]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.99
EPS_END = 0.05
EPS_DECAY = 150000
TAU = 0.005
LR = 1e-4

n_actions = env.action_size
state = env.reset()
n_observations = len(state)

policy_net = DQN(env.state_size, env.action_size).to(device)
target_net = DQN(env.state_size, env.action_size).to(device)

target_net.load_state_dict(policy_net.state_dict())

# Optimizer initialization
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0
# Epsilon greedy action selection
# Gradually decrease epsilon
# If epsilon is greater than the random sample, take random action
# Otherwise, take the action that gives the most Q value.
def select_action(state, available_actions, action_size):
    global steps_done
    sample = random.random()
    global eps_threshold
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Create a mask tensor for previously chosen actions
            mask = torch.full((1, action_size), -float('inf'), device=device)
            for idx in available_actions:
                mask[0, idx] = 0

            # Add the mask to the DQN output and select the maximum value
            masked_output = policy_net(state) + mask
            return masked_output.max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.choice(available_actions)]], device=device, dtype=torch.long)


In [90]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [None]:
num_episodes = 300000
average_reward = 0
for episode in range(num_episodes):
    # Initialize the environment and get it's state
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state, env.available_actions, env.action_size)
        observation, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        
        if done:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            average_reward += reward/1000
            if(episode % 1000 == 0):
                print(f"Episode: {episode}/{num_episodes}, Attempts: {env.attempts}, Reward: {reward[0]}")
                average_reward = 0
            break

print('Complete')



In [92]:
total_attempts = 0
correct_guesses = 0

no_test_trials = 1000

eps_threshold = 1e-11 #For only exploration

for episode in range(no_test_trials):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        action = select_action(state, env.available_actions, env.action_size)
        observation, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        
        if done:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        state = next_state
        total_attempts += 1

        if done:
            if(reward[0] == 10):
                correct_guesses += 1
            break

success_rate = correct_guesses / (no_test_trials)
average_attempts = total_attempts / (no_test_trials)

print(f"Trials: {no_test_trials}, Success rate: {success_rate:.2f}, Average number of attempts: {average_attempts:.2f}")

Trials: 1000, Success rate: 0.66, Average number of attempts: 5.09


In [94]:
total_attempts = 0
correct_guesses = 0

no_test_trials = 1000

eps_threshold = 1e-11 #For only exploration

for episode in range(no_test_trials):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    
    action = torch.tensor([[345]], device=device, dtype=torch.long) #Salet start.
    for t in count():
        observation, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        
        if done:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        state = next_state
        total_attempts += 1

        if done:
            if(reward[0] == 10):
                correct_guesses += 1
            break
        action = select_action(state, env.available_actions, env.action_size)

success_rate = correct_guesses / (no_test_trials)
average_attempts = total_attempts / (no_test_trials)

print(f"Trials with SALET start: {no_test_trials}, Success rate: {success_rate:.2f}, Average number of attempts: {average_attempts:.2f}")

Trials with SALET start: 1000, Success rate: 0.71, Average number of attempts: 4.95


In [131]:
state = env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

for t in count():
    action = select_action(state, env.available_actions, env.action_size)
    observation, reward, done, _ = env.step(action.item())
    reward = torch.tensor([reward], device=device)
    env.render()
    if done:
        next_state = None
    else:
        next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        
    if done:
        break

    state = next_state

Current guess: SLOSH
Target word: THORN
Attempts left: 5
Current guess: BOOZY
Target word: THORN
Attempts left: 4
Current guess: ATOLL
Target word: THORN
Attempts left: 3
Current guess: ERODE
Target word: THORN
Attempts left: 2
Current guess: THORN
Target word: THORN
Attempts left: 2


In [142]:
state = env.reset()
env.target_word = 'FRONT'
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
action = torch.tensor([[344]], device=device, dtype=torch.long) #Salet start.

for t in count():
    observation, reward, done, _ = env.step(action.item())
    reward = torch.tensor([reward], device=device)
    env.render()
    if done:
        next_state = None
    else:
        next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        
    if done:
        break
    action = select_action(state, env.available_actions, env.action_size)
    state = next_state

Current guess: SALET
Target word: FRONT
Attempts left: 5
Current guess: POINT
Target word: FRONT
Attempts left: 4
Current guess: CHANT
Target word: FRONT
Attempts left: 3
Current guess: FRONT
Target word: FRONT
Attempts left: 3


In [158]:
state = env.reset()
env.target_word = 'CHOKE'
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
action = torch.tensor([[344]], device=device, dtype=torch.long) #Salet start.

for t in count():
    observation, reward, done, _ = env.step(action.item())
    reward = torch.tensor([reward], device=device)
    env.render()
    if done:
        next_state = None
    else:
        next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        
    if done:
        break
    action = select_action(state, env.available_actions, env.action_size)
    state = next_state

Current guess: SALET
Target word: CHOKE
Attempts left: 5
Current guess: KEBAB
Target word: CHOKE
Attempts left: 4
Current guess: TRAWL
Target word: CHOKE
Attempts left: 3
Current guess: ACUTE
Target word: CHOKE
Attempts left: 2
Current guess: DODGE
Target word: CHOKE
Attempts left: 1
Current guess: ELOPE
Target word: CHOKE
Attempts left: 0


## Actor Critic

In [None]:
actor = Actor(env.state_size, env.action_size).to(device)
critic = Critic(env.state_size, env.action_size).to(device)

def select_action_actor(state, available_actions, action_size):
    with torch.no_grad():
        # Create a mask tensor for previously chosen actions
        mask = torch.full((1, action_size), -float('inf'), device=device)
        for idx in available_actions:
            mask[0, idx] = 0

        # Add the mask to the actor output and sample from the distribution
        actor_output = actor(state).probs
        masked_output = actor_output + mask
        masked_distribution = Categorical(logits=masked_output)
        return masked_distribution.sample().view(1, 1)

In [None]:
average_reward = 0
num_episodes = 10000

optimizer_actor = optim.Adam(actor.parameters(), lr=0.001)
optimizer_critic = optim.Adam(critic.parameters(), lr=0.001)

for episode in range(num_episodes):
    state = env.reset()
    log_probs = []
    values = []
    rewards = []
    masks = []
    entropy = 0
    env.reset()
    state = torch.FloatTensor(state).to(device)

    for i in count():
        dist, value = actor(state), critic(state)

        action = select_action_actor(state, env.available_actions, env.action_size)
        next_state, reward, done, _ = env.step(action.item())

        next_state = torch.tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0)

        log_prob = dist.log_prob(action).unsqueeze(1)
        entropy += dist.entropy().mean()

        log_probs.append(log_prob)
        if value.dim() > 1:
            values.append(value.squeeze(0))
        rewards.append(torch.tensor([reward], dtype=torch.float, device=device))
        masks.append(torch.tensor([1-done], dtype=torch.float, device=device))

        state = next_state

        if done:
            average_reward += reward/1000
            if(episode % 1000 == 0):
                print(f"Episode: {episode}/{num_episodes}, Attempts: {env.attempts}, Reward: {reward}")
                average_reward = 0
            break

    next_value = critic(next_state)

    log_probs = torch.cat(log_probs)
    returns = torch.tensor(rewards).sum()
    if values:
        values = torch.cat(values, dim=0).unsqueeze(1)
    else:
        # If values is empty, initialize it with a dummy tensor to avoid errors
        values = torch.zeros(1, 1, device=device, requires_grad=True)

    advantage = returns - values

    actor_loss = -(log_probs * advantage.detach()).mean()
    critic_loss = advantage.pow(2).mean()

    optimizer_actor.zero_grad()
    optimizer_critic.zero_grad()
    actor_loss.backward()
    critic_loss.backward()
    optimizer_actor.step()
    optimizer_critic.step()

In [None]:
total_attempts = 0
correct_guesses = 0

no_test_trials = 1000

eps_threshold = 1e-11 #For only exploration

for episode in range(no_test_trials):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        action = select_action_actor(state, env.available_actions, env.action_size)
        observation, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        
        if done:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        state = next_state
        total_attempts += 1

        if done:
            if(reward[0] == 10):
                correct_guesses += 1
            break

success_rate = correct_guesses / (no_test_trials)
average_attempts = total_attempts / (no_test_trials)

print(f"Trials: {no_test_trials}, Success rate: {success_rate:.2f}, Average number of attempts: {average_attempts:.2f}")

In [None]:
total_attempts = 0
correct_guesses = 0

no_test_trials = 1000

eps_threshold = 1e-11 #For only exploration

for episode in range(no_test_trials):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    
    action = torch.tensor([[345]], device=device, dtype=torch.long) #Salet start.
    for t in count():
        observation, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        
        if done:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        state = next_state
        total_attempts += 1

        if done:
            if(reward[0] == 10):
                correct_guesses += 1
            break
        action = select_action_actor(state, env.available_actions, env.action_size)

success_rate = correct_guesses / (no_test_trials)
average_attempts = total_attempts / (no_test_trials)

print(f"Trials with SALET start: {no_test_trials}, Success rate: {success_rate:.2f}, Average number of attempts: {average_attempts:.2f}")

In [None]:
state = env.reset()
env.target_word = 'OTHER'
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
action = torch.tensor([[344]], device=device, dtype=torch.long) #Salet start.

for t in count():
    observation, reward, done, _ = env.step(action.item())
    reward = torch.tensor([reward], device=device)
    env.render()
    if done:
        next_state = None
    else:
        next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        
    if done:
        break
    action = select_action_actor(state, env.available_actions, env.action_size)
    state = next_state