In [1]:
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

from model.Models import Transformer
from model.Optim import CosineWithRestarts
from model.Batch import create_masks
from utils.utils import MyTokenizer, MyMasker
from utils.data import TextDataset
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
# Loading data: keep 99% of the words for training and hold out the rest for validation.
bs = 128
dataset = TextDataset()
n_total = len(dataset)
train_size = int(0.99 * n_total)
test_size = n_total - train_size

print(train_size, test_size)

225027 2273


In [3]:
masker = MyMasker()
tokenizer = MyTokenizer(32)

# Fixed-seed generator so the train/validation split is identical on every run.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

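### Implementing a custom Hangman environment (gymnasium-based)
* Follows the gymnasium `Env` protocol (`reset` / `step`, observation and action spaces).
* Is batched by wrapping many instances in `gym.vector.SyncVectorEnv`. That wrapper steps the instances sequentially in a single process; parallel (multi-process) stepping would require `gym.vector.AsyncVectorEnv`.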

In [4]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from sklearn.utils import shuffle

class HangmanEnv(gym.Env):
    """One game of Hangman per episode, exposed through the gymnasium ``Env`` API.

    Words are drawn round-robin (via ``self.counter``) from a shuffled copy of
    ``dataloader.dataset``. Observations are integer lists of length
    ``max_seq_len``: 1..26 = revealed letter 'a'..'z', 27 = hidden letter ('_'),
    0 = padding. ``step`` expects the guess as a one-character string (callers
    pass letters, not the integer ids of ``action_space``). The reward is 1 when
    the word is fully revealed and 0 otherwise; the game ends on a win or when
    ``remaining_attempts`` reaches 0.
    """

    def __init__(self, dataloader, max_seq_len=32, init_counter=0):
        super(HangmanEnv, self).__init__()

        # NOTE(review): sklearn's shuffle is called without a random_state, so the
        # word order is not reproducible between runs.
        self.dataset = shuffle(dataloader.dataset)
        self.counter = init_counter
        self.max_seq_len = max_seq_len
        self.action_space = spaces.Discrete(28)  # ids: 0 = padding, 1..26 = 'a'..'z', 27 = '_'
        self.observation_space = spaces.Box(low=0, high=27, shape=(self.max_seq_len,), dtype=int)

        self.hidden_word = None
        self.word_length = None
        self._reset_attributes()

    def _reset_attributes(self):
        """Clear per-game state: guesses, lives, observation and game-over flag."""
        self.guessed_letters = set()
        self.remaining_attempts = 6  # Maximum wrong guesses per game
        self.current_state = np.zeros(self.max_seq_len, dtype=int)  # Initial state
        self.game_over = False

    def _masked_word(self):
        """Return the hidden word with every not-yet-guessed letter replaced by '_'."""
        return ''.join(char if char in self.guessed_letters else '_' for char in self.hidden_word)

    def _info(self, current_word):
        """Build the info dict. ``guessed_letters`` is a copy so later guesses do not alter it."""
        return {
            'word': current_word,
            'hidden_word': self.hidden_word,
            'guessed_letters': set(self.guessed_letters),
            'remaining_attempts': self.remaining_attempts,
        }

    def reset(self, *, seed=0, options=None):
        """Start a new game with the next word from the shuffled dataset.

        ``seed`` and ``options`` are accepted for API compatibility but unused.

        Returns:
            (observation, info) where the observation is the fully hidden word.
        """
        self.hidden_word = self.dataset[self.counter % len(self.dataset)]
        self.word_length = len(self.hidden_word)
        self._reset_attributes()

        # Advance to the next word for the following reset.
        self.counter += 1

        current_word = self._masked_word()
        self.current_state = self.word2state(current_word)
        return self.current_state, self._info(current_word)

    def generate_random_word(self):
        """Return the next word in round-robin order and advance the counter."""
        word_list = self.dataset
        idx = self.counter % len(word_list)
        self.counter += 1
        return word_list[idx]

    def step(self, action):
        """Apply one letter guess.

        Args:
            action: the guessed letter as a one-character string.

        Returns:
            (observation, reward, terminated, truncated, info). ``truncated`` is
            always False: a game only ends by winning or running out of attempts.
        """
        # Default reward, so a repeated guess that does not end the game still
        # returns a value instead of raising UnboundLocalError.
        reward = 0
        if action in self.guessed_letters:
            print("You have already guessed that letter.")
        else:
            self.guessed_letters.add(action)
            if action not in self.hidden_word:
                self.remaining_attempts -= 1

        won = set(self.hidden_word) <= self.guessed_letters
        if won or self.remaining_attempts <= 0:
            reward = 1 if won else 0
            self.game_over = True

        current_word = self._masked_word()
        self.current_state = self.word2state(current_word)
        return self.current_state, reward, self.game_over, False, self._info(current_word)

    def word2state(self, word):
        """Encode ``word`` ('_' for hidden letters) as ids, right-padded with 0 to ``max_seq_len``."""
        state = [27 if char == '_' else ord(char) - ord('a') + 1 for char in word]
        return state + [0] * (self.max_seq_len - len(state))

In [5]:
def get_valid_actions(guessed_letters, device=None):
    """Build a float mask of the letters that may still be guessed.

    Column ids follow ``HangmanEnv.word2state``: 0 = padding, 1..26 = 'a'..'z',
    27 = '_'. Columns 0 and 27 are never valid, and neither are already-guessed letters.

    Args:
        guessed_letters: sequence (one entry per env) of iterables of lowercase letters.
        device: target device. Defaults to 'cuda' when available, otherwise 'cpu'.

    Returns:
        Tensor of shape (len(guessed_letters), 28): 1.0 = valid, 0.0 = invalid.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Fill the mask on CPU and transfer once; per-element writes to a CUDA
    # tensor would each launch a separate kernel.
    valid_actions = torch.ones((len(guessed_letters), 28))
    valid_actions[:, 0] = 0.
    valid_actions[:, -1] = 0.

    for row, letters in enumerate(guessed_letters):
        for char in letters:
            valid_actions[row, ord(char) - ord('a') + 1] = 0.

    return valid_actions.to(device)

In [6]:

# Define the Replay Buffer
class ReplayBuffer:
    """Fixed-capacity ring buffer of batched transitions from a vectorized env.

    Row ``t % buffer_size`` holds one transition per env (``envs.num_envs``
    columns); once full, the oldest row is overwritten. Storage is allocated on
    the module-level ``device``.
    """

    def __init__(self, envs, buffer_size):
        self.envs = envs
        self.buffer_size = buffer_size

        # Total number of `add` calls so far; the write row is `t_ % buffer_size`.
        self.t_ = 0

        n_envs = envs.num_envs
        obs_shape = tuple(envs.single_observation_space.shape)
        n_actions = envs.single_action_space.n

        self.states = torch.zeros((buffer_size, n_envs, *obs_shape), dtype=torch.long, device=device)
        self.numlives = torch.zeros((buffer_size, n_envs), dtype=torch.long, device=device)
        self.actions = torch.zeros((buffer_size, n_envs), dtype=torch.long, device=device)
        self.rewards = torch.zeros((buffer_size, n_envs), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((buffer_size, n_envs, *obs_shape), dtype=torch.long, device=device)
        self.next_valid_actions = torch.zeros((buffer_size, n_envs, n_actions), dtype=torch.long, device=device)
        self.next_numlives = torch.zeros((buffer_size, n_envs), dtype=torch.long, device=device)
        # Rows that have never been written start as done=True.
        self.dones = torch.ones((buffer_size, n_envs), dtype=torch.bool, device=device)

    def add(self, state, action, reward, next_state, next_valid_action, numlives, next_numlives, done):
        """Write one batched transition (one entry per env) into the next row.

        Values may be tensors, numpy arrays or lists. Each is converted and cast to
        the dtype/device of its destination buffer.
        """
        row = self.t_ % self.buffer_size
        for buf, value in (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
            (self.next_states, next_state),
            (self.next_valid_actions, next_valid_action),
            (self.numlives, numlives),
            (self.next_numlives, next_numlives),
            (self.dones, done),
        ):
            buf[row] = torch.as_tensor(value, device=buf.device)
        self.t_ += 1

    def sample(self):
        """Return every stored transition, flattened over (time, env).

        Only rows that have actually been written are included, so unwritten
        placeholder rows never leak into training.

        Returns:
            (states, actions, rewards, next_states, next_valid_actions,
             numlives, next_numlives, dones)
        """
        filled = len(self)
        obs_shape = tuple(self.envs.single_observation_space.shape)
        n_actions = self.envs.single_action_space.n
        return (
            self.states[:filled].reshape((-1,) + obs_shape),
            self.actions[:filled].reshape(-1),
            self.rewards[:filled].reshape(-1),
            self.next_states[:filled].reshape((-1,) + obs_shape),
            self.next_valid_actions[:filled].reshape((-1, n_actions)),
            self.numlives[:filled].reshape(-1),
            self.next_numlives[:filled].reshape(-1),
            self.dones[:filled].reshape(-1),
        )

    def __len__(self):
        """Number of filled time-step rows (each holds one transition per env)."""
        return min(self.t_, self.buffer_size)

In [7]:

class BehaviourPolicy(nn.Module):
    """Greedy letter-guessing policy driven by a pretrained transformer.

    For each board, the transformer's per-position letter distributions are
    summed over non-padding positions. The highest-scoring letter that is still
    valid is chosen.
    """

    def __init__(self, envs, temperature=1.):
        super(BehaviourPolicy, self).__init__()

        self.memory = ReplayBuffer(envs, buffer_size=32)

        # pretrained model outputs raw logits of `expected` word from supervised learning.
        # map_location='cpu' lets the checkpoint load on CPU-only machines; the
        # caller moves the whole module to the target device with `.to(device)`.
        self.pretrainedLLM = Transformer(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)
        self.pretrainedLLM.load_state_dict(
            torch.load('./weights/model_weights_epoch750_04042024', map_location='cpu'))

        # helper function (kept for compatibility; `act` uses nn.functional.softmax)
        self.softmax = nn.Softmax(dim=-1)

        # Softmax temperature applied to the transformer logits.
        self.temperature = temperature

        self.update_every = 1
        self.t_step = 0

    def step(self, state, action, reward, next_state, next_valid_action, numlives, next_numlives, done):
        """Store the transition in ``self.memory`` (no learning update is performed)."""
        self.memory.add(state, action, reward, next_state, next_valid_action, numlives, next_numlives, done)

        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            pass

    @torch.no_grad()
    def act(self, x, valid_actions):
        """Pick the most likely still-valid letter for every board in the batch.

        Runs without autograd: the result is a discrete choice, so there is no
        need to build a graph through the transformer.

        Args:
            x: integer-encoded boards (0 = padding), presumably (batch, 32) — TODO confirm.
            valid_actions: 0/1 mask of shape (batch, 28) from ``get_valid_actions``.

        Returns:
            (actions, None): LongTensor of chosen action ids, shape (batch,).
        """
        # 1 for real positions, 0 for padding; shape (batch, 1, seq_len).
        mask = (x != 0).unsqueeze(-2).float()
        logits = self.pretrainedLLM(x) / self.temperature
        probs = nn.functional.softmax(logits, dim=-1)

        # Sum each letter's probability over the non-padding positions.
        probs = torch.matmul(mask, probs).squeeze(1)

        # Normalise each row independently; clamp avoids 0/0 on an all-zero row.
        tiny = torch.finfo(probs.dtype).tiny
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(tiny)

        fprobs = probs * valid_actions
        fprobs = fprobs / fprobs.sum(dim=-1, keepdim=True).clamp_min(tiny)

        action = torch.argmax(fprobs, dim=-1)
        return action, None

In [8]:
import numpy as np
from torch.distributions.categorical import Categorical
import random
from collections import namedtuple, deque

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    Weights are drawn from U(0, 1) (``torch.nn.init.uniform_`` defaults). When
    the layer has a bias, it is set to ``bias_const``; layers built with
    ``bias=False`` are left alone.

    NOTE(review): ``std`` is accepted for signature compatibility but is unused
    by the uniform initialisation.
    """
    torch.nn.init.uniform_(layer.weight)
    if getattr(layer, 'bias', None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present

class CategoricalMasked(Categorical):
    """Categorical distribution that assigns ~zero probability to masked-out actions.

    Args:
        probs, logits, validate_args: as for ``torch.distributions.Categorical``.
        masks: boolean-like tensor broadcastable to ``logits``. True = allowed.
            ``None`` or empty means no masking. Masking requires ``logits``.
    """

    def __init__(self, probs=None, logits=None, validate_args=None, masks=None):
        if masks is None or len(masks) == 0:
            self.masks = []
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
            return

        if logits is None:
            raise ValueError("CategoricalMasked requires `logits` when `masks` is given")

        # Keep the mask on the same device as the logits (no global device needed).
        self.masks = torch.as_tensor(masks).to(device=logits.device, dtype=torch.bool)
        very_negative = torch.tensor(-1e8, dtype=logits.dtype, device=logits.device)
        logits = torch.where(self.masks, logits, very_negative)
        super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        """Entropy computed over the allowed actions only."""
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        # `self.logits` are normalised log-probabilities in Categorical.
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, p_log_p.new_zeros(()))
        return -p_log_p.sum(-1)
    

# Define the Q-network
class QNetwork(nn.Module):
    """Q-value head on top of a frozen, pretrained letter-prediction transformer.

    ``forward`` proceeds in four steps:
      1. Turn the pretrained model's per-position logits into softmax probabilities.
      2. Zero padded positions (token 0).
      3. Weight each (position, letter) entry by a learnable matrix selected by
         ``stages`` (remaining lives).
      4. Sum over positions to get one value per action.
    """

    def __init__(self, envs):
        super(QNetwork, self).__init__()

        # NOTE(review): not used in `forward`; kept so parameter/state_dict keys stay
        # compatible with existing checkpoints.
        self.transformer = Transformer(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)

        # Frozen feature extractor (pretrained by supervised learning). map_location='cpu'
        # lets the checkpoint load on CPU-only machines; the caller moves the module
        # with `.to(device)`.
        self.pretrained = Transformer(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)
        self.pretrained.load_state_dict(
            torch.load('./weights/model_weights_epoch750_04042024', map_location='cpu'))
        self.pretrained.eval()

        # NOTE(review): `output` and `scalar` are not used in `forward`; kept for
        # checkpoint compatibility.
        self.output = nn.Sequential(
            layer_init(nn.Linear(32*28, 28, bias=False)),
            nn.Sigmoid(),
        )
        self.scalar = nn.Parameter(torch.tensor(0.))

        # One (32 x 28) weight matrix per remaining-lives stage, indexed by `stages - 1`.
        self.matrix = nn.Parameter(torch.ones((6, 32, 28)))

    def train(self, mode=True):
        """Set training mode but keep the frozen pretrained model in eval mode.

        Without this, ``.train()`` (e.g. called by an agent after acting) would
        re-enable dropout inside the frozen feature extractor.
        """
        super(QNetwork, self).train(mode)
        self.pretrained.eval()
        return self

    def forward(self, x, stages):
        """Compute one Q-value per action.

        Args:
            x: integer-encoded boards (0 = padding), presumably (batch, 32) — TODO confirm.
            stages: remaining lives per row. ``stages - 1`` indexes ``self.matrix``,
                so a value of 0 selects the last matrix.

        Returns:
            Tensor of action values, presumably (batch, 28) — assumes the pretrained
            model outputs (batch, 32, 28).
        """
        with torch.no_grad():
            probs = nn.functional.softmax(self.pretrained(x), dim=-1)

        # Zero the distributions at padded positions.
        probs = probs * (x != 0).unsqueeze(-1)

        return (probs * self.matrix[stages - 1, ...]).sum(dim=-2)
    


class DQNAgent(nn.Module):
    """DQN agent: a local Q-network trained by SGD plus a soft-updated target network."""

    def __init__(self, envs):
        super(DQNAgent, self).__init__()

        self.gamma = 1.00   # discount factor (undiscounted returns)
        self.tau = 1e-1     # soft-update interpolation rate for the target network
        self.lr = 1e-2

        self.qnetwork_local = QNetwork(envs).to(device)
        self.qnetwork_target = QNetwork(envs).to(device)

        self.optimizer = torch.optim.SGD(self.qnetwork_local.parameters(), lr=self.lr)

        self.memory = ReplayBuffer(envs, buffer_size=32)

        # NOTE(review): `buffer_size` / `batch_size` are not used by `self.memory`
        # (which holds 32 steps); kept for compatibility.
        self.buffer_size = int(1e5)
        self.batch_size = 64

        self.update_every = 1
        self.t_step = 0

    def step(self, state, action, reward, next_state, next_valid_action, numlives, next_numlives, done):
        """Store the transition; the learning call is currently disabled."""
        self.memory.add(state, action, reward, next_state, next_valid_action, numlives, next_numlives, done)

        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            pass

    @staticmethod
    def _mask_invalid(q_values, valid_actions):
        """Set Q-values of invalid actions (mask == 0) to -inf so max/argmax skip them.

        Multiplying by the 0/1 mask instead would make invalid actions score 0,
        which beats any negative Q-value of a valid action.
        """
        return q_values.masked_fill(valid_actions == 0, float('-inf'))

    def act(self, state, valid_actions, numlives, eps=0.):
        """Greedy action per row among valid actions.

        Args:
            state: encoded boards.
            valid_actions: 0/1 mask of allowed actions, same width as the Q-values.
            numlives: remaining lives, forwarded to the Q-network as ``stages``.
            eps: currently ignored.

        Returns:
            (actions, None)
        """
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state, numlives)
        self.qnetwork_local.train()
        actions = torch.argmax(self._mask_invalid(action_values, valid_actions), dim=-1)
        return actions, None

    def learn(self, experiences):
        """One SGD step on the TD error, then a soft update of the target network.

        Args:
            experiences: tuple as returned by ``ReplayBuffer.sample``.

        Returns:
            The scalar MSE loss tensor.
        """
        states, actions, rewards, next_states, next_valid_actions, numlives, next_numlives, dones = experiences

        with torch.no_grad():
            q_next_all = self.qnetwork_target(next_states, next_numlives)
            Q_targets_next = self._mask_invalid(q_next_all, next_valid_actions).max(1)[0]
            # A row with no valid next action yields -inf. Use 0 so it cannot
            # produce NaN/inf targets (its bootstrap term is presumably discarded
            # by `dones` anyway).
            Q_targets_next = torch.where(torch.isneginf(Q_targets_next),
                                         torch.zeros_like(Q_targets_next), Q_targets_next)
            Q_targets = rewards + self.gamma * Q_targets_next * (1 - dones.float())
        Q_targets = Q_targets.unsqueeze(-1)

        Q_expected = self.qnetwork_local(states, numlives).gather(1, actions.unsqueeze(-1))

        loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)

        return loss

    def soft_update(self, local_model, target_model, tau):
        """target <- tau * local + (1 - tau) * target, parameter by parameter."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
            


In [9]:
# Use the batch size `bs` defined with the data; a loader's batch size also sets
# how many HangmanEnv instances are built from it.
trainloader = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=False, num_workers=0)
valloader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=False, num_workers=0)

In [10]:
# Set device: GPU when present, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize the environment: one HangmanEnv per slot of the loader's batch size,
# stepped together (sequentially) by SyncVectorEnv.
dataloader = valloader
n_envs = dataloader.batch_size
env_fns = [lambda: HangmanEnv(dataloader) for _ in range(n_envs)]
envs = gym.vector.SyncVectorEnv(env_fns)

In [11]:
# Initialize the agent: the greedy policy backed by the pretrained transformer.
# Alternative (disabled): agent = DQNAgent(envs).to(device), optionally followed by
# agent.qnetwork_local.load_state_dict(torch.load('online-linear-DQN-checkpoint.pth')).
agent = BehaviourPolicy(envs, temperature=1.).to(device)

# Inference only: switch off dropout in the pretrained model (also displays the module).
agent.eval()

BehaviourPolicy(
  (pretrainedLLM): Transformer(
    (encoder): Encoder(
      (embed): Embedder(
        (embed): Embedding(28, 128, padding_idx=0)
      )
      (pe): PositionalEncoder(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layers): ModuleList(
        (0-11): 12 x EncoderLayer(
          (norm_1): Norm()
          (norm_2): Norm()
          (attn): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=True)
            (v_linear): Linear(in_features=128, out_features=128, bias=True)
            (k_linear): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (out): Linear(in_features=128, out_features=128, bias=True)
          )
          (ff): FeedForward(
            (linear_1): Linear(in_features=128, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear_2): Linear(in_features=2048, out_features=128, bias=

In [12]:
from collections import deque

# Rolling window of the most recent losses (the printed loss is their mean); holds at most 10.
max_size = 10
losses = deque(maxlen=max_size)

def dqn(n_episodes=1000, eps=0.0):
    """Roll out the global ``agent`` in the vectorized ``envs`` and print running stats.

    One "episode" here is one pass that resets every env and then steps them
    ``agent.memory.buffer_size`` times, storing each batched transition via
    ``agent.step``. After each pass, the win rate is computed from the rewards
    and dones stored in ``agent.memory``.

    Note: no gradient update is performed yet; the reported loss is a 0 placeholder.

    Args:
        n_episodes: number of buffer-filling passes to run.
        eps: per-step probability of taking a uniformly random valid action in
            every env instead of the agent's action. The default 0.0 always uses the agent.

    Returns:
        list[float]: per-pass win rate (wins / finished games within that pass).
    """
    wins = total_games = 0
    start_time = time.time()
    win_rate_history = []

    for i_episode in range(n_episodes):
        state, info = envs.reset()
        state = torch.tensor(state).to(device)

        for _step in range(agent.memory.buffer_size):
            valid_actions = get_valid_actions(info['guessed_letters'])
            numlives = info['remaining_attempts']

            if np.random.rand() >= eps:
                action_ints, _ = agent.act(state, valid_actions)
            else:
                # Exploration: one valid action chosen uniformly at random per env.
                picks = []
                for v in valid_actions:
                    candidates = v.nonzero().flatten()
                    idx = torch.randint(0, candidates.numel(), (1,)).item()
                    picks.append(candidates[idx])
                action_ints = torch.stack(picks).to(device)

            # Action id k (1..26) is the letter chr(ord('a') + k - 1).
            action_strs = [chr(idx - 1 + ord('a')) for idx in action_ints.tolist()]

            # Take step in the envs
            next_state, reward, terminated, truncated, info = envs.step(action_strs)
            next_state = torch.tensor(next_state).to(device)
            done = terminated | truncated

            agent.step(state, action_ints, reward, next_state,
                       get_valid_actions(info['guessed_letters']), numlives,
                       info['remaining_attempts'], done)

            state = next_state

        # TODO: sample `agent.memory` and run a learning step; `loss` is a placeholder.
        loss = torch.tensor(0)
        losses.append(loss.item())

        episode_wins = agent.memory.rewards.sum().item()
        episode_games = agent.memory.dones.sum().item()
        wins += episode_wins
        total_games += episode_games

        # max(..., 1) avoids dividing by zero before any game has finished.
        win_rate = wins / max(total_games, 1)
        curr_win_rate = episode_wins / max(episode_games, 1)
        mean_time_per_game = (time.time() - start_time) / max(total_games, 1)
        win_rate_history.append(curr_win_rate)

        print('''\r  loss = %.03f \t wins : %d \t total games : %d \t win rate : %.03f%% \t curr win rate : %.03f%% \t time_per_game : %.03f ms''' \
              %(sum(losses) / len(losses), wins, total_games, 100*win_rate, 100*curr_win_rate, 1000*mean_time_per_game), end='', flush=True)

    return win_rate_history
        

In [13]:
# Run the agent in the environments (dqn currently only collects and reports statistics).
scores = dqn(n_episodes=1000)

# TODO: plot `scores` (nothing is plotted yet).

  loss = 0.000 	 wins : 482152 	 total games : 730610 	 win rate : 65.993% 	 curr win rate : 66.964% 	 time_per_game : 4.515 ms


KeyboardInterrupt



In [None]:
# Restore an agent stashed earlier in the session (e.g. via `saved_agent = agent`).
# NOTE(review): `saved_agent` is not created anywhere in the visible notebook, so on a
# fresh kernel this cell fails; raise an actionable message instead of a bare NameError.
try:
    agent = saved_agent
except NameError as err:
    raise NameError(
        "`saved_agent` is not defined in this session; run `saved_agent = agent` "
        "in an earlier cell before restoring it here."
    ) from err