In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
import torch.nn.functional as F

# Load a smaller subset if memory is an issue
with open("words_250000_train.txt", "r") as f:
    word_list = [w.strip() for w in f if len(w.strip()) > 2 and len(w.strip()) <= 12]
    word_list = random.sample(word_list, 20000)  # reduce to 20k words if needed

# Build vocab
all_chars = sorted(set("".join(word_list)))
char_to_idx = {ch: i + 1 for i, ch in enumerate(all_chars)}
char_to_idx["_"] = len(char_to_idx) + 1  # MASK
idx_to_char = {i: ch for ch, i in char_to_idx.items()}
PAD_IDX = 0
vocab_size = len(char_to_idx) + 1

# Dataset
class HangmanDataset(Dataset):
    def __init__(self, words, max_len=16):
        self.data = []
        self.max_len = max_len
        for word in words:
            idx = random.randint(0, len(word) - 1)
            target = word[idx]
            masked = list(word)
            masked[idx] = "_"
            self.data.append(("".join(masked), target, idx))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        masked, target, pos = self.data[idx]
        x = [char_to_idx.get(c, 0) for c in masked]
        x += [PAD_IDX] * (self.max_len - len(x))
        return torch.tensor(x), torch.tensor(char_to_idx[target]), torch.tensor(pos)

# Attention Mechanism
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim * 2, 1)

    def forward(self, lstm_output):
        scores = self.attention(lstm_output)  # [batch, seq_len, 1]
        attention_weights = F.softmax(scores, dim=1)  # Normalize to get the attention weights
        context = torch.sum(lstm_output * attention_weights, dim=1)  # Weighted sum
        return context

# Model with Attention Mechanism
class BiLSTMHangman(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=512, layers=3, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        emb = self.embedding(x)  # [batch, seq_len, embed_dim]
        lstm_out, _ = self.lstm(emb)  # [batch, seq_len, hidden_dim * 2]
        context = self.attention(lstm_out)  # [batch, hidden_dim * 2]
        output = self.fc(self.dropout(context))  # [batch, vocab_size]
        return output

# Training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiLSTMHangman(vocab_size).to(device)
dataset = HangmanDataset(word_list)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)
criterion = nn.CrossEntropyLoss()

# Learning Rate Scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

for epoch in range(10):
    model.train()
    total_loss = 0
    for x, y, pos in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        x, y, pos = x.to(device), y.to(device), pos.to(device)
        optimizer.zero_grad()
        out = model(x)  # [batch, vocab_size]
        loss = criterion(out, y)
        loss.backward()

        # Gradient clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        
        optimizer.step()
        total_loss += loss.item()

    scheduler.step()  # Adjust learning rate based on scheduler
    print(f"Epoch {epoch+1} loss: {total_loss / len(dataloader):.4f}")

# Save model + mappings
torch.save(model.state_dict(), "bilstm_hangman_attention.pth")
torch.save({"char_to_idx": char_to_idx, "idx_to_char": idx_to_char}, "bilstm_vocab.pth")


Epoch 1: 100%|██████████| 313/313 [02:02<00:00,  2.56it/s]


Epoch 1 loss: 2.6429


Epoch 2: 100%|██████████| 313/313 [02:36<00:00,  2.01it/s]


Epoch 2 loss: 2.2915


Epoch 3: 100%|██████████| 313/313 [02:32<00:00,  2.05it/s]


Epoch 3 loss: 2.0968


Epoch 4: 100%|██████████| 313/313 [02:11<00:00,  2.38it/s]


Epoch 4 loss: 1.9986


Epoch 5:   4%|▍         | 13/313 [00:07<02:47,  1.79it/s]


KeyboardInterrupt: 