In [1]:
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Models import Transformer, Transformer2
from model.Optim import CosineWithRestarts
from model.Batch import create_masks
from utils.utils import MyTokenizer, MyMasker
from utils.data import TextDataset
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
# Loading data
bs=128
dataset = TextDataset()
train_size = int(0.99*len(dataset))
test_size = len(dataset)-train_size

print(train_size, test_size)

225027 2273


In [3]:
train_dataset, val_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(0))

In [4]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from sklearn.utils import shuffle

class HangmanEnv(gym.Env):
    def __init__(self, dataloader, max_seq_len=32, init_counter=0):
        super(HangmanEnv, self).__init__()

        self.dataset = shuffle(dataloader.dataset)
        self.counter = init_counter
        self.max_seq_len = max_seq_len
        self.action_space = spaces.Discrete(28)  # 26 possible actions (a-z) + '' + '_'
        self.observation_space = spaces.Box(low=0, high=27, shape=(self.max_seq_len,), dtype=int)

        self.hidden_word = None
        self.word_length = None
        self._reset_attributes()
    
    def _reset_attributes(self):
        self.guessed_letters = set()
        self.remaining_attempts = 6  # Maximum attempts
        self.current_state = np.zeros(self.max_seq_len, dtype=int)  # Initial state
        self.game_over = False

    def reset(self, *, seed=0, options=None):
        self.hidden_word = self.dataset[self.counter % len(self.dataset)]
        self.word_length = len(self.hidden_word)
        self._reset_attributes()
        
        # Increment reset counter
        self.counter += 1

        current_word = ''.join([char if char in self.guessed_letters else '_' for char in self.hidden_word])
        self.current_state = self.word2state(current_word)
        return self.current_state, {'word': current_word, 'hidden_word': self.hidden_word, 'guessed_letters': self.guessed_letters}

    def generate_random_word(self):
        # Replace this with your logic for generating random words
        word_list = self.dataset
        idx = self.counter % len(word_list)
        self.counter += 1
        return word_list[idx]

    def step(self, action):
        if action in self.guessed_letters:
            print("You have already guessed that letter.")
        else:
            self.guessed_letters.add(action)
            if action in self.hidden_word:
                reward = 0
            else:
                reward = 0
                self.remaining_attempts -= 1

        if set(self.hidden_word) <= self.guessed_letters or self.remaining_attempts == 0:
            reward = 1 if set(self.hidden_word) <= self.guessed_letters else 0
            self.game_over = True

        current_word = ''.join([char if char in self.guessed_letters else '_' for char in self.hidden_word])
        self.current_state = self.word2state(current_word)
        return self.current_state, reward, self.game_over, self.game_over, {'word': current_word, 'hidden_word': self.hidden_word, 'guessed_letters': self.guessed_letters}

    def word2state(self, word):
        state = [27 if char == '_' else ord(char) - ord('a') + 1 for char in word]
        while len(state) < self.max_seq_len:
            state.append(0)
        return state

In [5]:
def get_valid_actions(guessed_letters):
    
    valid_actions = torch.ones((len(guessed_letters), 28)).to('cuda')
    valid_actions[:,  0] = 0.
    valid_actions[:, -1] = 0.
    
    for i, s in enumerate(guessed_letters):
        for char in s:
            idx = ord(char) - ord('a') + 1
            valid_actions[i, idx] = 0.
    
    return valid_actions

In [6]:
import torch
import torch.nn as nn
import numpy as np
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

device = 'cuda'

class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.0).to(device))
        return -p_log_p.sum(-1)


class Agent(nn.Module):
    def __init__(self, envs, temperature=1.):
        super(Agent, self).__init__()
        
        
        # pretrained model outputs raw logits of `expected` word from supervised learning
        self.pretrainedLLM = Transformer(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)
        self.pretrainedLLM.load_state_dict(torch.load('./weights/model_weights_epoch750_04042024'))
        
        # helper function
        self.softmax = nn.Softmax(dim=-1)
        
        # Flatten output logits from transformer and feed into feed-forward NN
        self.critic = nn.Sequential(
            layer_init(nn.Linear(32*28, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(32*28, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 28), std=0.01),
        )
        
        # temperature
        self.temperature = temperature

    def get_value(self, x):
#         with torch.no_grad():
        mask = (x != 0).unsqueeze(-2)
        logits = self.pretrainedLLM(x)
        probs = nn.functional.softmax(logits, dim=-1)
        probs_masked = probs * torch.stack((mask.squeeze(-2),) * probs.shape[-1], dim=-1)
        probs_masked_flatten = probs_masked.view((probs_masked.shape[0], -1))
            
        return self.critic(probs_masked_flatten.detach())

    def get_action_and_value(self, x, valid_actions, action=None):
#         with torch.no_grad():
        mask = (x != 0).unsqueeze(-2)
        logits = self.pretrainedLLM(x) / self.temperature
        probs = nn.functional.softmax(logits, dim=-1)
        probs_masked = probs * torch.stack((mask.squeeze(-2),) * probs.shape[-1], dim=-1)
        probs_masked_flatten = probs_masked.view((probs_masked.shape[0], -1))

        probs = torch.matmul(1.*mask, probs)  # effectively adds the probs row-wise for each action / character
        probs = probs.squeeze(1)
        probs = probs / torch.sum(probs)

        fprobs = torch.mul(probs, valid_actions)
        fprobs = fprobs / torch.sum(fprobs)

        # Choose max probable actions as int form
        action = torch.argmax(fprobs, dim=-1)
        logprob = torch.max(fprobs, dim=-1)[0].log()
        
        return action, logprob, fprobs * 0., self.critic(probs_masked_flatten.detach()) * 0.

In [7]:
def mini_sim(agent, envs, optimizer, train=False):
    
    n = envs.num_envs
    horizon = 32
    batch_size = horizon*n
    minibatch_size = batch_size // 4
    num_updates = 10000000 // batch_size
    
    device = 'cuda'
    
    # Initializing simulation matrices for the given batched episode
    observations = torch.zeros((horizon, n, *envs.single_observation_space.shape), dtype=int).to(device)
    actions = torch.zeros((horizon, n)).to(device)
    logprobs = torch.zeros((horizon, n), dtype=torch.float32).to(device)
    rewards = torch.zeros((horizon, n), dtype=torch.float32).to(device)
    dones = torch.ones((horizon, n), dtype=bool).to(device)
    
    
    # TRY NOT TO MODIFY: start the game
    global_t = 0
    start_time = time.time()
    next_obs, info = envs.reset()
    next_obs = torch.tensor(next_obs).to(device)
    next_done = torch.zeros((n,)).to(device)
    
    valid_actions = get_valid_actions(info['guessed_letters'])
    
    done = False
    cr = 0.
    
    wins = 0
    total_games = 0
    while True:
        
        for t in range(horizon):
            global_t += 1 * n
            observations[t] = next_obs
            dones[t] = next_done

            action_ints, logprob, _, _ = agent.get_action_and_value(next_obs, valid_actions)
            actions[t] = action_ints
            logprobs[t] = logprob
        
            # Convert to action_int to action_str guesses
            action_strs =  [chr(idx-1 + ord('a')) for idx in action_ints]
            
            
            # Take step in the envs
            next_obs, reward, terminated, truncated, info = envs.step(action_strs)
            done = (terminated | truncated)
            
            # log data
            rewards[t] = torch.tensor(reward).to(device).view(-1)
            next_obs = torch.tensor(next_obs).to(device)
            next_done = torch.tensor(done).to(device)
            
            valid_actions = get_valid_actions(info['guessed_letters'])
        
        wins += rewards.sum()
        total_games += dones.sum()
        
        win_rate = wins / total_games
        curr_win_rate = rewards.sum() / dones.sum()
        
        mean_time_per_game = (time.time() - start_time) / total_games 
        
        print('\r  wins : %d \t total games : %d \t win rate : %.03f%% \t curr win rate : %.03f%% \t time_per_game : %.03f ms' \
              %(wins, total_games, 100*win_rate, 100*curr_win_rate, 1000*mean_time_per_game), end='')
        
        if train:
            gamma = 1.
            gae_lambda = 0.95
            clip_coef = 0.1
            ent_coef = 0.01
            vf_coef = 0.5
            max_grad_norm = 0.5

            # bootstrap value if not done
            with torch.no_grad():
                next_value = agent.get_value(next_obs).reshape(1, -1)
                if False: # if args.gae:
                    advantages = torch.zeros_like(rewards).to(device)
                    lastgaelam = 0
                    for t in reversed(range(horizon)):
                        if t == horizon - 1:
                            nextnonterminal = 1.0 - 1.*next_done
                            nextvalues = next_value
                        else:
                            nextnonterminal = 1.0 - 1.*dones[t + 1]
                            nextvalues = values[t + 1]
                        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
                    returns = advantages + values
                else:
                    returns = torch.zeros_like(rewards).to(device)
                    for t in reversed(range(horizon)):
                        if t == horizon - 1:
                            nextnonterminal = 1.0 - 1.*next_done
                            next_return = 0. # next_value
                        else:
                            nextnonterminal = 1.0 - 1.*dones[t + 1]
                            next_return = returns[t + 1]
                        returns[t] = rewards[t] + gamma * nextnonterminal * next_return
                    advantages = returns #- values
          
            optimizer.zero_grad()
            loss = (logprobs * returns).mean()
            loss.backward()
            optimizer.step()
            
        # Initializing simulation matrices for the given batched episode
        observations = torch.zeros((horizon, n, *envs.single_observation_space.shape), dtype=int).to(device)
        actions = torch.zeros((horizon, n)).to(device)
        logprobs = torch.zeros((horizon, n), dtype=torch.float32).to(device)
        rewards = torch.zeros((horizon, n), dtype=torch.float32).to(device)
        dones = torch.ones((horizon, n), dtype=bool).to(device)
            

        
    envs.close()
    return cr

In [8]:
trainloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=False, num_workers=0)
valloader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False, num_workers=0)

In [None]:
# from env.hangman import HangmanEnv
import gymnasium as gym

# agent = Agent(envs).to(device)
# optimizer = torch.optim.Adam(agent.parameters(), lr=2.5e-5, eps=1e-5)

# def test_pgn(dataloader):
dataloader = trainloader
    
envs = gym.vector.SyncVectorEnv(
    [lambda: HangmanEnv(dataloader) for i in range(dataloader.batch_size)]
)

agent = Agent(envs, temperature=1.).to(device)
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-5)


agent.train()

cr = mini_sim(agent, envs, optimizer, train=True)

#         print('\r  wins : %d \t total games : %d \t win rate : %.03f%% \t reward : %.03f \t average reward : %.03f ' %(wins, total_games, 100*win_rate, cr, avg_reward), end='')

  wins : 7339 	 total games : 11079 	 win rate : 66.242% 	 curr win rate : 78.000% 	 time_per_game : 72.974 ms