In [1]:
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Models import Transformer, Transformer2
from model.Optim import CosineWithRestarts
from model.Batch import create_masks
from utils.utils import MyTokenizer, MyMasker
from utils.data import TextDataset
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
# Loading data
bs=128
dataset = TextDataset()
train_size = int(0.99*len(dataset))
test_size = len(dataset)-train_size

print(train_size, test_size)

225027 2273


In [3]:
masker = MyMasker()
tokenizer = MyTokenizer(32)

train_dataset, val_dataset = random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(0))

In [4]:
trainloader = DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True, num_workers=0)
valloader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, num_workers=0)

In [5]:
# Loading Tranformer model from scratch
max_len = 32
model = Transformer(src_vocab=28, d_model=128, max_seq_len=max_len, N=12, heads=8, dropout=0.1)
model.to('cuda')
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [6]:
masker = MyMasker()
tokenizer = MyTokenizer(max_len)

optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

In [7]:
def train_model(model, bs, epochs, printevery):

    print("training model...")
    start = time.time()
    if torch.cuda.is_available():
        print('gpu detected!')
    else:
        print('no gpu detected')
        return 0

    model.train()
    for epoch in range(epochs):

        total_loss = 0

        for i, trg in enumerate(trainloader):

            # src = batch.src.transpose(0,1)
            # trg = batch.trg.transpose(0,1)
            # trg_input = trg[:, :-1]
            # src_mask, _ = create_masks(src, trg_input) # need to edit

            # test to check if overfit

            # src is the incomplete word
            perc=None
            src = masker.mask(trg, perc)  # e.g. [m_zh__n, _s, _w_eso_e]
            src = tokenizer.encode(src)  # e.g. [[], [], []]
            
            # trg is the complete word
            trg = tokenizer.encode(trg)

            # our src_mask is the same as trg_mask = mask
            mask, _ = create_masks(src)  # e.g. [[1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 0]]

            # Converting to cuda
            if torch.cuda.is_available():
                src = src.to('cuda')
                mask = mask.to('cuda')
                trg = trg.to('cuda')
            
            model.train()
            # preds = model(src, mask)
            preds = model(src)
            # ys = trg[:, 1:].contiguous().view(-1)
            # y = mask.squeeze(1)
            
            # 

            optim.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), trg.contiguous().view(-1), ignore_index=0)
            loss.backward()
            optim.step()

            total_loss += loss.item()

            # print(i+1)
            if (i + 1) % printevery == 0:
                p = int(100 * (i + 1) / len(trainloader.dataset) * bs)
                avg_loss = total_loss / printevery
                print("\r   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_loss), end='')
                total_loss = 0

            
            if (i+1) % 10 == 0:
                torch.save(model.state_dict(), f'./weights/model_automask_weights_{datetime.today().strftime("%m%d%Y")}')
                pass
                
        total_val_loss = 0
        sims = 0
        for i, val in enumerate(valloader):
            perc=None
            src = masker.mask(val, perc)  # e.g. [m_zh__n, _s, _w_eso_e]
            src = tokenizer.encode(src)  # e.g. [[], [], []]
            
            # trg is the complete word
            val = tokenizer.encode(val)
            
            # our src_mask is the same as trg_mask = mask
            mask, _ = create_masks(src)  # e.g. [[1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 0]]
            
            # Converting to cuda
            if torch.cuda.is_available():
                src = src.to('cuda')
                mask = mask.to('cuda')
                val = val.to('cuda')
            
            model.eval()
            preds = model(src)
            
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), val.contiguous().view(-1), ignore_index=0)
            
            total_val_loss += loss.item()
            sims += 1
            if (i + 1) % printevery == 0:
                p = int(100 * (i + 1) / len(valloader.dataset) * bs)
                avg_val_loss = total_val_loss / sims
                print("\r   %dm: epoch %d [%s%s]  %d%%  loss = %.3f" %((time.time() - start)//60, epoch + 1, "".join('#'*(p//5)), "".join(' '*(20-(p//5))), p, avg_val_loss), end='')
            
        print("\r   %dm: epoch %d [%s%s]  %d%%  loss = %.3f\nepoch %d complete, val loss = %.03f" %\
        ((time.time() - start)//60, epoch + 1, "".join('#'*(100//5)), "".join(' '*(20-(100//5))), 100, avg_loss, epoch + 1, avg_val_loss))

In [8]:
# train_model(model, bs=bs, epochs=25, printevery=1)

In [9]:
start = ord('a')
alphabets = {'_': 27}
ids = {27:'_', 0:''}
for i in range(26):
    ch = chr(start)
    alphabets[ch] = i+1
    ids[i+1] = ch
    start += 1

In [10]:
from agent import Agent
from model.Models import PGN


pgn = PGN(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)
pgn.transformer.load_state_dict(torch.load('./weights/model_weights_03202024'))

'''
pgn = PGN(src_vocab=28, d_model=32, max_seq_len=32, N=2, heads=4, dropout=0.1)
pgn.transformer.load_state_dict(torch.load('./weights/model_weights_lite_1'))
'''

if torch.cuda.is_available():
    pgn.to('cuda')

pgn.eval()

PGN(
  (transformer): Transformer(
    (encoder): Encoder(
      (embed): Embedder(
        (embed): Embedding(28, 128, padding_idx=0)
      )
      (pe): PositionalEncoder(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layers): ModuleList(
        (0-11): 12 x EncoderLayer(
          (norm_1): Norm()
          (norm_2): Norm()
          (attn): MultiHeadAttention(
            (q_linear): Linear(in_features=128, out_features=128, bias=True)
            (v_linear): Linear(in_features=128, out_features=128, bias=True)
            (k_linear): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (out): Linear(in_features=128, out_features=128, bias=True)
          )
          (ff): FeedForward(
            (linear_1): Linear(in_features=128, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear_2): Linear(in_features=2048, out_features=128, bias=True)
        

In [11]:
def simulate_trajectories(envs, policy, horizon, device):
    n = envs.num_envs

    # Initializing simulation matrices for the given batched episode
    log_probs = torch.zeros((horizon, n), dtype=torch.float32).to(device)
    rewards = torch.zeros((horizon, n), dtype=torch.float32).to(device)
    dones = torch.ones((horizon, n), dtype=bool).to(device)

    obs, _ = envs.reset()
    done = np.zeros((n,), dtype=bool)  # e.g. [False, False, False]
    T = None

    for t in range(horizon):
        obs = torch.tensor(np.float32(obs)).to(device)

        action, log_prob = policy.get_action(obs)

        log_probs[t] = log_prob
        dones[t] = torch.tensor(done).to(device)

        obs, reward, terminated, truncated, info = envs.step(action.cpu().detach().numpy())
        done = done | (np.array(terminated) | np.array(truncated))

        # Modify rewards to NOT consider data points after `done`
        reward = reward * ~done
        rewards[t] = torch.tensor(reward).to(device)

        if done.all():
            T = t
            break

    cum_discounted_rewards = discount_cumsum(rewards, dones, gamma=0.99, normalize=False, device=device)
    mean_episode_return = torch.sum(cum_discounted_rewards, axis=0) / torch.sum(~dones, axis=0)

    traj_info = {
        'log_probs': log_probs[:T],
        'rewards': rewards[:T],
        'dones': dones[:T],
    }

    return traj_info, torch.sum(rewards, axis=0), mean_episode_return

In [12]:
x = torch.rand((3, 28, 32))

In [13]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from sklearn.utils import shuffle

class HangmanEnv(gym.Env):
    def __init__(self, dataloader, max_seq_len=32, init_counter=0):
        super(HangmanEnv, self).__init__()

        self.dataset = shuffle(dataloader.dataset)
        self.counter = init_counter
        self.max_seq_len = max_seq_len
        self.action_space = spaces.Discrete(28)  # 26 possible actions (a-z) + '' + '_'
        self.observation_space = spaces.Box(low=0, high=27, shape=(self.max_seq_len,), dtype=int)

        self.hidden_word = None
        self.word_length = None
        self._reset_attributes()
    
    def _reset_attributes(self):
        self.guessed_letters = set()
        self.remaining_attempts = 6  # Maximum attempts
        self.current_state = np.zeros(self.max_seq_len, dtype=int)  # Initial state
        self.game_over = False

    def reset(self, *, seed=0, options=None):
        self.hidden_word = self.dataset[self.counter]
        self.word_length = len(self.hidden_word)
        self._reset_attributes()
        
        # Increment reset counter
        self.counter += 1

        current_word = ''.join([char if char in self.guessed_letters else '_' for char in self.hidden_word])
        self.current_state = self.word2state(current_word)
        return self.current_state, {'word': current_word, 'hidden_word': self.hidden_word, 'guessed_letters': self.guessed_letters}

    def generate_random_word(self):
        # Replace this with your logic for generating random words
        word_list = self.dataset
        idx = self.counter % len(word_list)
        self.counter += 1
        return word_list[idx]

    def step(self, action):
        if action in self.guessed_letters:
            print("You have already guessed that letter.")
        else:
            self.guessed_letters.add(action)
            if action in self.hidden_word:
                reward = 0
            else:
                reward = 0
                self.remaining_attempts -= 1

        if set(self.hidden_word) <= self.guessed_letters or self.remaining_attempts == 0:
            reward = 1 if set(self.hidden_word) <= self.guessed_letters else 0
            self.game_over = True

        current_word = ''.join([char if char in self.guessed_letters else '_' for char in self.hidden_word])
        self.current_state = self.word2state(current_word)
        return self.current_state, reward, self.game_over, self.game_over, {'word': current_word, 'hidden_word': self.hidden_word, 'guessed_letters': self.guessed_letters}

    def word2state(self, word):
        state = [27 if char == '_' else ord(char) - ord('a') + 1 for char in word]
        while len(state) < self.max_seq_len:
            state.append(0)
        return state

In [14]:
def get_valid_actions(guessed_letters):
    
    valid_actions = torch.ones((len(guessed_letters), 28)).to('cuda')
    valid_actions[:,  0] = 0.
    valid_actions[:, -1] = 0.
    
    for i, s in enumerate(guessed_letters):
        for char in s:
            idx = ord(char) - ord('a') + 1
            valid_actions[i, idx] = 0.
    
    return valid_actions

In [426]:
import torch
import torch.nn as nn
import numpy as np
from torch.distributions.categorical import Categorical


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

device = 'cuda'

class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.0).to(device))
        return -p_log_p.sum(-1)


class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        
        
        # pretrained model outputs raw logits of `expected` word from supervised learning
        self.pretrainedLLM = Transformer(src_vocab=28, d_model=128, max_seq_len=32, N=12, heads=8, dropout=0.1)
        self.pretrainedLLM.load_state_dict(torch.load('./weights/model_weights_epoch750_04042024'))
        
        # helper function
        self.softmax = nn.Softmax(dim=-1)
        
        # Flatten output logits from transformer and feed into feed-forward NN
        self.critic = nn.Sequential(
            layer_init(nn.Linear(32*28, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(32*28, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 28), std=0.01),
        )

    def get_value(self, x):
        with torch.no_grad():
            mask = (x != 0).unsqueeze(-2)
            logits = self.pretrainedLLM(x)
            probs = nn.functional.softmax(logits, dim=-1)
            probs_masked = probs * torch.stack((mask.squeeze(-2),) * probs.shape[-1], dim=-1)
            
        return self.critic(probs_masked.view((probs_masked.shape[0], -1)))

    def get_action_and_value(self, x, valid_actions, action=None):
        with torch.no_grad():
            mask = (x != 0).unsqueeze(-2)
            logits = self.pretrainedLLM(x)
            probs = nn.functional.softmax(logits, dim=-1)
            probs_masked = probs * torch.stack((mask.squeeze(-2),) * probs.shape[-1], dim=-1)
            
#             print(logits)
#             print(probs)

#             q = (1-probs+1e-9).log()
#             q = torch.matmul(1.*mask, q)
#             q = q.squeeze(1)
# #             print('before', torch.isnan(q).any().item())
# #             print(q)
#             q = torch.mul(q, valid_actions)
# #             q = q / torch.sum(q)
# #             print(valid_actions*1)
# #             print('after', torch.isnan(q).any().item())
# #             print(q)
# #             print('----------------------')
#             action_new = torch.argmin(q, dim=-1)
            
#             print(probs)
#             print(torch.isnan((1-probs).log()).any().item())
#             print(torch.min(q, dim=-1))
        
            probs = torch.matmul(1.*mask, probs)  # effectively adds the probs row-wise for each action / character
            probs = probs.squeeze(1)
            probs = probs / torch.sum(probs)

#         x = self.pretrainedLLM(x)
#         temp = nn.functional.softmax(x, dim=-1)
#         print(temp)
#         x = x.view((x.shape[0], -1))
        
#         logits = self.actor(x)
#         probs = CategoricalMasked(logits=logits, masks=valid_actions)
#         if action is None:
#             action = probs.sample()
#         logprob = probs.log_prob(action)
#         entropy = probs.entropy()

            fprobs = torch.mul(probs, valid_actions)
            fprobs = fprobs / torch.sum(fprobs)

            # Choose max probable actions as int form
            action = torch.argmax(fprobs, dim=-1)
            logprob = torch.max(fprobs, dim=-1)[0].log()
            entropy = -torch.where(valid_actions.type(torch.BoolTensor).to(device), fprobs*fprobs.log(), torch.tensor(0.0).to(device))
            entropy = entropy.sum(-1)

#         print(action.shape, logprob.shape, entropy.shape, self.critic(x).shape)
#         print(action.shape, logprob.shape, entropy.shape, self.critic(logits.view((logits.shape[0], -1))).shape)
#         print(action)
        
        return action, logprob, entropy, self.critic(probs_masked.view((probs_masked.shape[0], -1)))

In [448]:
def mini_sim(agent, envs, optimizer):
    
    n = envs.num_envs
    horizon = 32
    batch_size = horizon*n
    minibatch_size = batch_size // 4
    num_updates = 10000000 // batch_size
    
    device = 'cuda'
    
    # Initializing agent
#     agent = Agent(envs).to(device)
#     optimizer = torch.optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)
    
    # Initializing simulation matrices for the given batched episode
    observations = torch.zeros((horizon, n, *envs.single_observation_space.shape), dtype=int).to(device)
    actions = torch.zeros((horizon, n)).to(device)
    action_masks = torch.zeros((horizon, n) + (envs.single_action_space.n,)).to(device)
    values = torch.ones((horizon, n), dtype=torch.float).to(device)
    logprobs = torch.zeros((horizon, n), dtype=torch.float32).to(device)
    rewards = torch.zeros((horizon, n), dtype=torch.float32).to(device)
    dones = torch.ones((horizon, n), dtype=bool).to(device)
    
    
    # TRY NOT TO MODIFY: start the game
    global_t = 0
    start_time = time.time()
    next_obs, info = envs.reset()
    next_obs = torch.tensor(next_obs).to(device)
    next_done = torch.zeros((n,)).to(device)
    
    valid_actions = get_valid_actions(info['guessed_letters'])
    
    done = False
    cr = 0.
    
    wins = 0
    total_games = 0
#     for update in range(1, num_updates + 1):
    while True:
        
        for t in range(horizon):
            global_t += 1 * n
            observations[t] = next_obs
            dones[t] = next_done
            action_masks[t] = valid_actions
            
#             # Get action probs
#             probs = pgn(next_obs)
            
#             # Zero out invalid actions
#             b_probs = torch.mul(probs, valid_actions)
#             b_probs = b_probs / torch.sum(b_probs)
        
#             # Choose max probable actions as int form
#             action_ints = torch.argmax(b_probs, dim=-1)
            
             # ALGO LOGIC: action logic
            with torch.no_grad():
                action_ints, logprob, _, value = agent.get_action_and_value(next_obs, valid_actions)
                values[t] = value.flatten()
            actions[t] = action_ints
            logprobs[t] = logprob
        
            # Convert to action_int to action_str guesses
            action_strs =  [chr(idx-1 + ord('a')) for idx in action_ints]
            
            
            # Take step in the envs
            next_obs, reward, terminated, truncated, info = envs.step(action_strs)
            done = (terminated | truncated)
            
            # log data
            rewards[t] = torch.tensor(reward).to(device).view(-1)
            next_obs = torch.tensor(next_obs).to(device)
            next_done = torch.tensor(done).to(device)
            
            valid_actions = get_valid_actions(info['guessed_letters'])
        
        wins += rewards.sum()
        total_games += dones.sum()
        
        win_rate = wins / total_games
        
        mean_time_per_game = (time.time() - start_time) / total_games 
        
        print('\r  wins : %d \t total games : %d \t win rate : %.03f%% \t time_per_game : %.03f ms \t average reward : %.03f ' %(wins, total_games, 100*win_rate, 1000*mean_time_per_game, 0), end='')
        
#         '''
        gamma = 0.99
        gae_lambda = 0.95
        clip_coef = 0.1
        ent_coef = 0.01
        vf_coef = 0.5
        max_grad_norm = 0.5
        
        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            if True: # if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(horizon)):
                    if t == horizon - 1:
                        nextnonterminal = 1.0 - 1.*next_done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - 1.*dones[t + 1]
                        nextvalues = values[t + 1]
                    delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(horizon)):
                    if t == horizon - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + gamma * nextnonterminal * next_return
                advantages = returns - values
        
        # flatten the batch
        b_obs = observations.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        b_action_masks = action_masks.reshape((-1, action_masks.shape[-1]))
        
        
        # Optimizing the policy and value network
        b_inds = np.arange(batch_size)
        clipfracs = []
        for epoch in range(10):
            np.random.shuffle(b_inds)
            for start in range(0, batch_size, minibatch_size):
                end = start + minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(
                    b_obs[mb_inds],
                    b_action_masks[mb_inds],
                    b_actions.long()[mb_inds].T,
                )
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if True: # if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if True:  # if clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -clip_coef,
                        clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss + vf_coef * v_loss #- ent_coef * entropy_loss + v_loss * vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

            if False: # if args.target_kl is not None:
                if approx_kl > target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
#         '''

        
    envs.close()
    return cr

In [449]:
trainloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=False, num_workers=0)
valloader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False, num_workers=0)

In [450]:
# envs = gym.vector.SyncVectorEnv(
#         [lambda: HangmanEnv(trainloader) for i in range(trainloader.batch_size)]
#     )

In [None]:
# from env.hangman import HangmanEnv
import gymnasium as gym

# agent = Agent(envs).to(device)
# optimizer = torch.optim.Adam(agent.parameters(), lr=2.5e-5, eps=1e-5)

# def test_pgn(dataloader):
dataloader = valloader
    
envs = gym.vector.SyncVectorEnv(
    [lambda: HangmanEnv(dataloader) for i in range(dataloader.batch_size)]
)

agent = Agent(envs).to(device)
optimizer = torch.optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)


agent.train()

cr = mini_sim(agent, envs, optimizer)

#         print('\r  wins : %d \t total games : %d \t win rate : %.03f%% \t reward : %.03f \t average reward : %.03f ' %(wins, total_games, 100*win_rate, cr, avg_reward), end='')

  wins : 3880 	 total games : 6036 	 win rate : 64.281% 	 time_per_game : 23.784 ms 	 average reward : 0.000 

In [None]:
t_ = time.time()
test_pgn(trainloader)
print("\n", time.time() - t_)

In [None]:
envs.single_action_space.n

In [89]:
def make_env(idx):
    def thunk():
        env = HangmanEnv(dataloader=valloader, init_counter=0)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env
    return thunk

In [90]:
envs = gym.vector.SyncVectorEnv(
        [lambda: HangmanEnv(valloader) for i in range(3)]
    )

In [108]:
envs.action_space.sample()

array([ 6, 17, 15], dtype=int64)

In [100]:
# Example usage:
obs, info = envs.reset()
print(torch.tensor(obs))
print(get_valid_actions(info['guessed_letters']))
done = False
while not done:
#     action = ['a', 'b', 'b']
    action = input("Enter a letter to guess: ")
    obs, reward, terminated, truncated, info = envs.step(action)
    print(info['guessed_letters'])
    print(get_valid_actions(info['guessed_letters']))
    print(info['guessed_letters'].shape)
#     env.render()

tensor([[27, 27, 27, 27, 27, 27, 27, 27, 27,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [27, 27, 27, 27, 27, 27, 27, 27, 27,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],
       dtype=torch.int32)
tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]], device='cuda:0')


KeyboardInterrupt: Interrupted by user

In [80]:
def mini_sim2(sample):
    env = HangmanEnv(sample[0])
    n = len(sample[0])
    state = masker.mask(sample, 1)
    sample_mask, _ = create_masks(tokenizer.encode(sample))
    mask = sample_mask.to('cuda')
    y = sample_mask.squeeze(1).to('cuda')
    y_float = torch.where(y, 1., 0.)
    
    left = torch.ones((1, 28)).to('cuda')
    left[0,  0] = 0.
    left[0, -1] = 0.
    
    P = nn.Softmax(dim=-1)
    
    done = False
    
    cr = 0

    while not done:
        
        # print(state)
        
        state = tokenizer.encode(state)
        state = state.to('cuda')
        
        # q_probs = score / torch.sum(score)
        
        probs = pgn(state, mask)
        
        b_probs = torch.mul(probs, left)
        b_probs = b_probs / torch.sum(b_probs)
        b = torch.distributions.Categorical(probs=b_probs)

        action = b.sample()
        
        # using a greedy approach
        guess_id = torch.argmax(b_probs).item()
        
        # guess_id = action.item()
        guess = ids[guess_id]
        
        next_state, r, done, _ = env.step(guess)
        
        state = [''.join(next_state)]
#         print(state) #, guess, r, next_state)
        
        left[0, guess_id] = 0.
        
        cr += r
        # print(guess, cr)
    
    return cr

In [28]:
pgn.transformer.encoder.embed

Embedder(
  (embed): Embedding(28, 128, padding_idx=0)
)

In [81]:
# from env.hangman import Hangman, HangmanEnv

def test_pgn2(valloader):
    
    wins = 0
    reward = 0
    total_games = 0
    pgn.eval()
    for i, state in enumerate(valloader):
        
        if total_games > 10: return
        
        cr = mini_sim2(state)
        if cr > - 6:
            wins += 1
            # print(state)
        total_games += 1
        reward += cr
        
        avg_reward = reward / total_games
        win_rate = wins / total_games
        print('\r  wins : %d \t total games : %d \t win rate : %.03f%% \t reward : %.03f \t average reward : %.03f ' %(wins, total_games, 100*win_rate, cr, avg_reward), end='')

In [83]:
t_ = time.time()
test_pgn2(valloader)
print("\n", time.time() - t_)

  wins : 6 	 total games : 11 	 win rate : 54.545% 	 reward : -3.000 	 average reward : -5.364 
 1.878509521484375
