In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler

In [2]:
import os
import matplotlib.pyplot as plt
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
%matplotlib inline

In [3]:
BATCH_SIZE = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_float32_matmul_precision("high")

In [4]:
START_TOKN = "<"
END_TOKN = ">"
MISS_TOKN = "*"
PAD_TOKN = "#"

In [5]:
def read(file):
    f = open(file, 'r', encoding='utf-8')
    data = f.read().splitlines()
    f.close()
    return data

train_data = read('train_data.txt')
val_data = read('val_data.txt')
len(train_data), len(val_data)

(204570, 22730)

In [6]:
MAX_WORD_LEN = max([len(w) for w in train_data]) + 1 # extra padding
BLOCK_SIZE = MAX_WORD_LEN + 2
MAX_WORD_LEN, BLOCK_SIZE

(30, 32)

In [7]:
CHARS = sorted(list(set("".join(train_data))) + [START_TOKN, END_TOKN, MISS_TOKN, PAD_TOKN])
"".join(CHARS)

'#*<>abcdefghijklmnopqrstuvwxyz'

In [8]:
stoi = {ch: i for i, ch in enumerate(CHARS)}
itos = {i: ch for i, ch in enumerate(CHARS)}

text = START_TOKN + "hello" + MISS_TOKN + "world" + PAD_TOKN + END_TOKN
tokns = [stoi[ch] for ch in text]
"".join([itos[i] for i in tokns])

'<hello*world#>'

In [9]:
class Attention(nn.Module):
    def __init__(self, emb=32, heads=4):
        super().__init__()
        self.heads = heads
        self.key = nn.Linear(emb, emb, bias=False)
        self.qry = nn.Linear(emb, emb)
        self.val = nn.Linear(emb, emb)
        self.proj = nn.Linear(emb, emb)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, z, attn_mask=None, rel_pos_emb=None):
        B, Tx, C = x.shape                                                # (B, Tx, C)
        Tz = z.shape[1]                                                   # (B, Tz, C)
        H = self.heads
        q = self.qry(x).view(B, Tx, H, C // H).transpose(1, 2)            # (B, H, Tx, C//H)
        k = self.key(z).view(B, Tz, H, C // H).transpose(1, 2)            # (B, H, Tz, C//H)
        v = self.val(z).view(B, Tz, H, C // H).transpose(1, 2)            # (B, H, Tz, C//H)
        att = (q @ k.transpose(-1, -2)) * (1.0 / np.sqrt(k.size(-1)))     # (B, H, Tx, Tz)
        if rel_pos_emb is not None:
            att += rel_pos_emb                                            # (B, H, Tx, Tz)
        if attn_mask is not None:
            att = att.masked_fill(attn_mask == 1, -1e9)                   # (B, H, Tx, Tz)
        att = self.dropout(F.softmax(att, dim=-1))                        # (B, H, Tx, Tz)
        y = (att @ v).transpose(1, 2)                                     # (B, Tx, H, C//H)
        y = y.contiguous().view(B, Tx, C)                                 # (B, Tx, C)
        y = self.dropout(self.proj(y))                                    # (B, Tx, C)
        return y

In [10]:
class FeedForward(nn.Module):
    def __init__(self, emb=32, mlp_emb=128):
        super().__init__()
        self.lin = nn.Linear(emb, mlp_emb)
        self.proj = nn.Linear(mlp_emb, emb)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):                                                    # (B, T, C)
        x = self.lin(x)                                                      # (B, T, M)
        x = F.gelu(x)                                                        # (B, T, M)
        y = self.proj(x)                                                     # (B, T, C)
        y = self.dropout(y)                                                  # (B, T, C)
        return y

In [11]:
class RelativePositionBias(nn.Module):
    def __init__(self, heads=4, block_size=BLOCK_SIZE):
        super().__init__()
        self.block_size = block_size
        self.bias = nn.Embedding(2*block_size+1, heads)
        
    def forward(self, Tx, Tz):
        pos = torch.arange(Tx, device=DEVICE)
        rel_pos = pos[:, None] - pos[None, :]                                # (Tx, Tz)
        rel_pos = rel_pos + self.block_size                                  # (Tx, Tz)
        bias = self.bias(rel_pos)                                            # (Tx, Tz, H)
        return bias.permute(2, 0, 1)                                         # (H, Tx, Tz)

In [12]:
class Block(nn.Module):
    def __init__(self, emb=32, heads=4, mlp_emb=128, x_attn=False, rel_pos=False):
        super().__init__()
        self.self_norm = nn.LayerNorm(emb)
        self.self_attn = Attention(emb, heads)
        self.self_res_scale = nn.Parameter(torch.tensor(0.1))
        self.feed_fwd_norm = nn.LayerNorm(emb)
        self.feed_fwd = FeedForward(emb, mlp_emb)
        self.feed_fwd_res_scale = nn.Parameter(torch.tensor(0.1))
        self.x_attn = x_attn
        if self.x_attn:
            self.cross_norm = nn.LayerNorm(emb)
            self.cross_attn = Attention(emb, heads)
            self.cross_res_scale = nn.Parameter(torch.tensor(0.1))
        self.rel_pos = rel_pos
        if self.rel_pos:
            self.rel_pos_bias = RelativePositionBias(heads) 
            
    def forward(self, x, z=None, self_mask=None, x_mask=None):
        nx = self.self_norm(x)                                              # (B, Tx, C)
        rel_pos_emb = None
        if self.rel_pos:
            rel_pos_emb = self.rel_pos_bias(nx.size(1), nx.size(1))         # (H, Tx, Tx)
            rel_pos_emb = rel_pos_emb.unsqueeze(0)                          # (1, H, Tx, Tx)
        self_attn = self.self_attn(nx, nx, self_mask, rel_pos_emb)          # (B, Tx, C)
        attn = x + self.self_res_scale * self_attn                          # (B, Tx, C)
        if self.x_attn:
            nx = self.cross_norm(attn)                                      # (B, Tx, C)
            x_attn = self.cross_attn(nx, z, x_mask)                         # (B, Tx, C)
            attn = attn + self.cross_res_scale * x_attn                     # (B, Tx, C)
        nx = self.feed_fwd_norm(attn)                                       # (B, Tx, C)
        ff = attn + self.feed_fwd_res_scale * self.feed_fwd(nx)             # (B, Tx, C)
        return ff

In [13]:
class Transformer(nn.Module):
    def __init__(self, emb=32, heads=4, mlp_emb=128, decoder_layers=4, encoder_layers=2):
        super().__init__()
        self.char_embed = nn.Embedding(len(CHARS), emb)
        self.guess_embed = nn.Linear(len(CHARS), emb)
        
        self.guess_blocks = nn.Sequential(*[Block(emb, heads, mlp_emb) for _ in range(encoder_layers)])
        self.decoder_blocks = nn.ModuleList([Block(emb, heads, mlp_emb, True, True) for _ in range(decoder_layers)])
        
        self.decoder_norm = nn.LayerNorm(emb)
        self.decoder = nn.Linear(emb, len(CHARS), bias=False)
        self.decoder.weight = self.char_embed.weight


    def forward(self, x, g):
        B, T = x.shape                                                        # (B, T)
        char_emb = self.char_embed(x)                                         # (B, T, C)
        g_emb = self.guess_embed(g).unsqueeze(1).expand(-1, T, -1)            # (B, T, C)

        pos_mask = (x == stoi[PAD_TOKN])                                      # (B, T)
        
        attn_mask = pos_mask.unsqueeze(1) | pos_mask.unsqueeze(2)             # (B, T, T)
        attn_mask = attn_mask.unsqueeze(1)                                    # (B, 1, T, T)
        
        qry_mask = pos_mask.unsqueeze(1).unsqueeze(2)                         # (B, 1, 1, T)
        qry_mask = qry_mask.expand(-1, -1, T, -1)                             # (B, 1, T, T)
        
        g_z = self.guess_blocks(g_emb)                                        # (B, T, C)
        out = char_emb                                                        # (B, T, C)
        for block in self.decoder_blocks:
            out = block(out, g_z, self_mask=attn_mask, x_mask=qry_mask)       # (B, T, C)  
        
        out = self.decoder_norm(out)                                          # (B, T, C)
        out = self.decoder(out)                                               # (B, T, V)
        out = out.transpose(-1, -2)                                           # (B, V, T)
        decoder_mask = (x == stoi[MISS_TOKN])                                 # (B, T)
        num_masked = decoder_mask.sum(dim=1, keepdim=True)                    # (B, 1)
        masked_out = out.masked_fill(~decoder_mask.unsqueeze(1), 0.0)         # (B, V, T)
        logits = masked_out.sum(dim=-1) / num_masked                          # (B, V)        
        return logits

In [14]:
@torch.no_grad()
def predict(model, word, guessed):
    model.eval()

    word =  word.replace(".", MISS_TOKN)
    x_ch = (START_TOKN + word + END_TOKN + PAD_TOKN*BLOCK_SIZE)[:BLOCK_SIZE]
    x = torch.tensor([stoi[ch] for ch in x_ch], device=DEVICE)
    g = torch.zeros(len(CHARS), device=DEVICE)
    g_tkn = [stoi[ch] for ch in guessed + list(word) if ch != MISS_TOKN]
    g[g_tkn] = 1.0
    logits = model(x.unsqueeze(0), g.unsqueeze(0))
    logits = logits.squeeze(0)

    preds = {itos[i]: logits[i].item() for i in range(len(logits))}
    return sorted(preds.items(), key=lambda item: item[1], reverse=True)

In [15]:
model = Transformer(emb=256, heads=8, mlp_emb=1024, decoder_layers=8, encoder_layers=4)
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.to("cuda")

sum(p.numel() for p in model.parameters())

11601760

In [16]:
class HangmanInit(Dataset):
    def __init__(self, words, max_length=MAX_WORD_LEN):
        self.words = words
        self.max_length = max_length
        
    def __len__(self):
        return len(self.words)
    
    def __getitem__(self, idx):
        word = self.words[idx]
        
        x_tkn = (START_TOKN + word + END_TOKN + PAD_TOKN * BLOCK_SIZE)[:BLOCK_SIZE]
        x = torch.tensor([stoi[c] for c in x_tkn], device=DEVICE)

        g = torch.zeros(len(CHARS), device=DEVICE)
        y = torch.zeros(len(CHARS), device=DEVICE)
        y[[stoi[c] for c in list(set(word))]] = 1.0

        return (x, g, y)

In [17]:
def get_loader(ds):
    loader = DataLoader(ds, batch_size=BATCH_SIZE, sampler=RandomSampler(ds, replacement=True, num_samples=BATCH_SIZE))
    return loader

In [18]:
class GameEnv:
    def __init__(self, words, g, y, max_tries=6):
        self.words = words
        self.g = g
        self.y = y
        x_mask = (self.words != stoi[START_TOKN]) & (self.words != stoi[END_TOKN]) & (self.words != stoi[PAD_TOKN])
        self.x = self.words.masked_fill(x_mask, stoi[MISS_TOKN])
        self.max_tries = max_tries
        self.fails = torch.zeros(self.y.size(0), dtype=torch.int, device=DEVICE)
        
    def get_state(self):
        return (self.x, self.g)
    
    def is_done(self):
        return len(self.g) == 0
        
    def update(self, guesses):
        B, V = self.y.shape
        T = self.x.size(1)

        correct_guesses = self.y[torch.arange(B, device=DEVICE), guesses].bool()
        new_guesses = F.one_hot(guesses, V).float()

        self.g = torch.maximum(self.g, new_guesses)
        self.y = torch.clamp(self.y - new_guesses * correct_guesses.unsqueeze(1).float(), min=0.0)

        expanded_guesses = guesses.unsqueeze(1).expand(-1, T)
        mask = self.words == expanded_guesses
        self.x = torch.where(mask, expanded_guesses, self.x)

        self.fails[~correct_guesses] += 1
        won = (self.y == 0).all(dim=1)
        lost = (self.fails >= self.max_tries)

        keep = ~(won | lost)
        self.words = self.words[keep]
        self.x = self.x[keep]
        self.g = self.g[keep]
        self.y = self.y[keep]
        self.fails = self.fails[keep]

        return won.sum().item()

In [19]:
@torch.no_grad()
def predict(model, x, g):
    model.eval()
    with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
        logits = model(x, g)
    masked_logits = logits.masked_fill(g.bool(), -1e9)
    guesses = masked_logits.argmax(dim=1)
    return guesses

In [20]:
@torch.no_grad()
def validate(model, loader):
    model.eval()
    w_init, g_init, y_init = next(iter(loader))
    env = GameEnv(w_init, g_init, y_init)
    total_wins = 0

    while not env.is_done():
        x, g = env.get_state()
        guesses = predict(model, x, g)
        total_wins += env.update(guesses)

    return total_wins/w_init.size(0)

In [21]:
def get_win_rates(model, data, batches=10**3):
    wins = []
    loader = get_loader(HangmanInit(data))
    
    for _ in range(batches):
        wins.append(validate(model, loader))
        if (_+1) % max(1, (batches//10)) == 0 or (_+1) == batches:
            print(f"Step {_+1}: Win Rate={sum(wins)/len(wins)*100:.2f}%")
    
    num_games = BATCH_SIZE*len(wins)
    num_wins = sum(wins)*BATCH_SIZE
    p = num_wins/num_games
    var = p*(1 - p)/num_games
    std = np.sqrt(var)
    print()
    print(f"Won {int(num_wins)} out of {int(num_games)} games; E[p]={p:.4f}, Var(p)={var:.4f}, std(p)={std:.4f}")
    print(f"Win Rate between {(p - std)*100:.2f}% and {(p + std)*100:.2f}% with 65% probability")
    print(f"Win Rate between {(p - 2*std)*100:.2f}% and {(p + 2*std)*100:.2f}% with 95% probability")
    print(f"Win Rate between {(p - 3*std)*100:.2f}% and {(p + 3*std)*100:.2f}% with 99.7% probability")

In [22]:
print("Training dataset...")
get_win_rates(model, train_data)

Training dataset...
Step 100: Win Rate=71.19%
Step 200: Win Rate=71.80%
Step 300: Win Rate=71.18%
Step 400: Win Rate=71.21%
Step 500: Win Rate=71.24%
Step 600: Win Rate=71.18%
Step 700: Win Rate=71.00%
Step 800: Win Rate=71.05%
Step 900: Win Rate=71.11%
Step 1000: Win Rate=71.04%

Won 45464 out of 64000 games; E[p]=0.7104, Var(p)=0.0000, std(p)=0.0018
Win Rate between 70.86% and 71.22% with 65% probability
Win Rate between 70.68% and 71.40% with 95% probability
Win Rate between 70.50% and 71.58% with 99.7% probability


In [23]:
print("Validation dataset...")
get_win_rates(model, val_data)

Validation dataset...
Step 100: Win Rate=68.44%
Step 200: Win Rate=66.91%
Step 300: Win Rate=66.65%
Step 400: Win Rate=66.86%
Step 500: Win Rate=66.91%
Step 600: Win Rate=66.88%
Step 700: Win Rate=66.98%
Step 800: Win Rate=66.98%
Step 900: Win Rate=66.98%
Step 1000: Win Rate=66.92%

Won 42826 out of 64000 games; E[p]=0.6692, Var(p)=0.0000, std(p)=0.0019
Win Rate between 66.73% and 67.10% with 65% probability
Win Rate between 66.54% and 67.29% with 95% probability
Win Rate between 66.36% and 67.47% with 99.7% probability
