In [1]:
from google.colab import drive 
drive.mount("/content/drive")
project_dir = "/content/drive/MyDrive/6.806 6.864 Final Project"

Mounted at /content/drive


In [2]:
%%bash
pip install tokenizers
pip install sacrebleu

Collecting tokenizers
  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
Installing collected packages: tokenizers
Successfully installed tokenizers-0.10.2
Collecting sacrebleu
  Downloading https://files.pythonhosted.org/packages/7e/57/0c7ca4e31a126189dab99c19951910bd081dea5bbd25f24b77107750eae7/sacrebleu-1.5.1-py3-none-any.whl (54kB)
Collecting portalocker==2.0.0
  Downloading https://files.pythonhosted.org/packages/89/a6/3814b7107e0788040870e8825eebf214d72166adf656ba7d4bf14759a06a/portalocker-2.0.0-py2.py3-none-any.whl
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-2.0.0 sacrebleu-1.5.1


In [3]:
import math
from tqdm import tqdm

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils import data
from torch import cuda
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from tokenizers import Tokenizer
from tokenizers.trainers import BpeTrainer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace

import os
import json

import sacrebleu

device = "cuda" if cuda.is_available() else "cpu"

train_input_path = os.path.join(project_dir, "hearthstone", "train_hs.in")
train_target_path = os.path.join(project_dir, "hearthstone", "train_hs.out")
test_input_path = os.path.join(project_dir, "hearthstone", "test_hs.in")
test_target_path = os.path.join(project_dir, "hearthstone", "test_hs.out")

# Parsing
In this first section, we parse the input and outputs texts. This consists of tokenizing the texts and converting tokens into IDs

In [None]:
PAD_ID = 0
PAD_TOKEN = "PAD_INDEX"
SOS_ID = 2
SOS_TOKEN = "SOS_INDEX"
EOS_ID = 3
EOS_TOKEN = "EOS_INDEX"

In [4]:
def read_file(path):
    """Reads file and returns a list of each line"""
    with open(path) as f:
        return f.readlines()

def tokenize_input_corpus(inputs, mode, token2id=None):
    """Tokenize corpus (list of texts) and returns IDs of each text"""
    line_tokens = [line.split() for line in inputs]
    max_length = max([len(tokens) for tokens in line_tokens])
    if not token2id:
        special_tokens = [
            "PAD_INDEX", "UNK", "SOS_INDEX", "EOS_INDEX", "NAME_END", "ATK_END", "DEF_END",
            "COST_END", "DUR_END", "TYPE_END", "PLAYER_CLS_END", "RACE_END", "RARITY_END"
        ]
        token2id = { token: idx for (idx, token) in enumerate(special_tokens) }
    tokenized_corpus = []
    line_lens = []
    for tokens in line_tokens:
        tokens = line.split()
        word_ids, token2id, ids_len = process_line(tokens, token2id, mode, 
                                                     max_length)
        tokenized_corpus.append(word_ids)
        line_lens.append(ids_len)
    return tokenized_corpus, vocab_dict, line_lens
        
def tokens_to_ids(tokens, token2id, mode, max_length):
    """Convert list of tokens to IDs"""
    token_ids = []
    for token in tokens:
        if token in token2id:
            token_ids.append(token2id[token])
        elif mode == "train":
            token_ids.append(len(token2id))
            token2id[token] = len(token2id)
        else:
            token_ids.append(token2id["UNK"])
    ids_len = len(token_ids) + 2 
    token_ids = [token2id["SOS_INDEX"]] + token_ids + \
                   [token2id["EOS_INDEX"]] + \
                   [token2id["PAD_INDEX"]]*(max_length - ids_len)
    return token_ids, token2id, ids_len

def batch_tokens_to_ids(line_tokens, token2id, mode, max_length=None):
    """Convert tokens to IDs for a batch and returns long tensor for whole batch of IDs"""
    if max_length is None:
        max_length = max([len(line) for line in line_tokens])
    line_ids = []
    line_lens = []
    for tokens in line_tokens:
        ids, token2id, line_len = tokens_to_ids(tokens, token2id, mode, max_length=max_length)
        line_ids.append(ids)
        line_lens.append(line_len)
    return torch.LongTensor(line_ids).to(device), token2id, torch.LongTensor(line_lens).to(device)

In [5]:
def save_mapping(mapping, path):
    with open(path, 'w') as f:
        json.dump(mapping, f)
        
def load_mapping(path):
    with open(path) as f:
        return json.load(f)

## Byte Pair Encoding Tokenizer
We experiment with tokenizing using Huggingface's byte pair encoding tokenizer

In [6]:
def get_trained_tokenizer(train_paths):
    """Use huggingface tokenizer and train on corpus"""
    tokenizer = Tokenizer(BPE())
    tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer()
    tokenizer.train(files=train_paths)
    return tokenizer

In [None]:
# Train tokenizer for inputs and targets
input_special_tokens = [
    "PAD_INDEX", "UNK", "SOS_INDEX", "EOS_INDEX", "NAME_END", "ATK_END", "DEF_END",
    "COST_END", "DUR_END", "TYPE_END", "PLAYER_CLS_END", "RACE_END", "RARITY_END"
]
train_raw_inputs = read_file(train_input_path)
input_tokenizer = get_trained_tokenizer([train_input_path])

target_special_tokens = ["PAD_INDEX", "UNK", "SOS_INDEX", "EOS_INDEX"]
train_raw_targets = read_file(train_target_path)
target_tokenizer = get_trained_tokenizer([train_target_path])

In [None]:
# Get tokens for inputs and targets
input_line_tokens = [e.tokens for e in input_tokenizer.encode_batch(train_raw_inputs)]
target_line_tokens = [e.tokens for e in target_tokenizer.encode_batch(train_raw_targets)]

# Create mapping of tokens to IDs for inputs and targets
input_tokens = set([t for line in input_line_tokens for t in line]) - set(input_special_tokens)
all_input_tokens = input_special_tokens + sorted(list(input_tokens))
input_token2id = { token: id for (id, token) in enumerate(all_input_tokens) }
input_id2token = { id: token for (token, id) in input_token2id.items() }

target_tokens = set([t for line in target_line_tokens for t in line]) - set(target_special_tokens)
all_target_tokens = target_special_tokens + sorted(list(target_tokens))
target_token2id = { token: id for (id, token) in enumerate(all_target_tokens) }
target_id2token = { id: token for (token, id) in target_token2id.items() }

# Get longest length of tokens (+2 for start and end tokens)
input_max_seq_len = max([len(line) for line in input_line_tokens]) + 2
target_max_seq_len = max([len(line) for line in target_line_tokens]) + 2

In [None]:
# Set to True if you want to save these mappings
save_mappings = False
if save_mappings:
    save_mapping(input_token2id, os.path.join(project_dir, "input_token2id.json"))
    save_mapping(target_token2id, os.path.join(project_dir, "target_token2id.json"))

# Set to True if you want to load previously used mappings
load_mappings = False
if load_mappings:
    input_token2id = load_mapping(os.path.join(project_dir, "input_token2id.json"))
    target_token2id = load_mapping(os.path.join(project_dir, "target_token2id.json"))
    input_id2token = { id: token for (token, id) in input_token2id.items() }
    target_id2token = { id: token for (token, id) in target_token2id.items() }

# Organizing Data
We will organize the input and output tokens above into a dataset object for easier use during training. This dataset will convert tokens into padded ID sequences.

In [7]:
class SimpleHearthstoneDataset(data.Dataset):
    """Simple dataset with input and output tokens as IDs"""
    def __init__(self, input_line_tokens, target_line_tokens, input_token2id, target_token2id, mode, input_max_seq_len=None, target_max_seq_len=None):
        self.input_line_ids, _, self.input_line_lens = batch_tokens_to_ids(input_line_tokens, input_token2id, mode, max_length=input_max_seq_len)
        self.target_line_ids, _, self.target_line_lens = batch_tokens_to_ids(target_line_tokens, target_token2id, mode, max_length=target_max_seq_len)
        
    def __len__(self):
        return len(self.input_line_ids)
    
    def __getitem__(self, idx):
        return self.input_line_ids[idx], self.target_line_ids[idx], self.input_line_lens[idx], self.target_line_lens[idx]

In [None]:
# Split validation and training data
validation_ratio = 0.1
train_size = int((1 - validation_ratio) * len(input_line_tokens))
train_input_line_tokens = input_line_tokens[: train_size]
train_target_line_tokens = target_line_tokens[: train_size]
validation_input_line_tokens = input_line_tokens[train_size:]
validation_target_line_tokens = target_line_tokens[train_size:]

# Create datasets for validation and training
simple_train_dataset = SimpleHearthstoneDataset(train_input_line_tokens, train_target_line_tokens, input_token2id, target_token2id, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)
simple_validation_dataset = SimpleHearthstoneDataset(validation_input_line_tokens, validation_target_line_tokens, input_token2id, target_token2id, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)

# Simple Seq2Seq Model
We start by experimenting with a simple seq2seq model w/attention to get benchmark performance
1. Tokenize each word without separating into fields
2. Encode input sequences using bi-RNN
3. Use last hidden layer and outputs of encoder in decoder to generate output tokens

In [8]:
class SimpleHearthstoneEncoder(nn.Module):
    """Simple encoder for hearthstone tokens"""
    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers=3, dropout=0.1):
        super(SimpleHearthstoneEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.GRU(
            input_size=embedding_size,
            hidden_size=hidden_size,
            num_layers=3,
            dropout=dropout,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, inputs, lengths, max_seq_length=None):
        """
        :param inputs: 3d tensor of shape (batch_size, max_seq_length, embed_size)
        :param lengths: 1d tensor of shape (batch_size,)

        :return: (outputs, finals) where outputs is 3d tensor of shape (batch_size, max_seq_length, hidden_size)
                and finals is 3d tensor of shape (num_layers, batch_size, 2*hidden_size)
        """
        if max_seq_length is None:
            max_seq_length = inputs.size(1)
            
        embedded_inputs = self.embedding(inputs)
        packed = pack_padded_sequence(embedded_inputs, lengths.cpu(), batch_first=True, enforce_sorted=False)
        outputs, hidden = self.rnn(packed)
        outputs, _ = pad_packed_sequence(outputs, batch_first=True, total_length=max_seq_length)

        forward_hidden = hidden[::2]
        backward_hidden = hidden[1::2]
        hidden = torch.cat([forward_hidden, backward_hidden], dim=2)

        return outputs, hidden

In [9]:
# Decoder with attention architecture based on
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
# http://www.davidsbatista.net/blog/2020/01/25/Attention-seq2seq/
# https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html

class SimpleHearthstoneDecoder(nn.Module):
    """Simple decoder with attention for hearthstone tokens to python code"""
    def __init__(self, vocab_size, embedding_size, hidden_size, enc_hidden_size, enc_max_seq_length, rnn_num_layers=3, dropout=0.1):
        # TODO: add dropout?
        super(SimpleHearthstoneDecoder, self).__init__()
        self.bridge = nn.Linear(enc_hidden_size, hidden_size)
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        
        self.dropout = nn.Dropout(p=dropout)
        self.attention = nn.Linear(embedding_size + hidden_size, enc_max_seq_length)
        self.combine_attention = nn.Linear(enc_hidden_size + embedding_size, hidden_size)
        
        self.rnn = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, num_layers=rnn_num_layers)

    def forward_step(self, prev_embed, hidden, encoder_outputs):
        """
        :param prev_embed: 3d tensor of shape (batch_size, 1, embed_size) containing word embeddings
                from previous time step
        :param hidden: 3d tensor of shape (num_layers, batch_size, hidden_size) representing current decoder hidden state
        :param encoder_outputs: 3d tensor of shape (batch_size, max_seq_length, enc_hidden_size) representing output layers of encoder
                for all time steps

        :return: [pre_output, hidden] of current time step
        """
        # Use previous embedding and last hidden layer to compute attention weights
        concat_prev_embed = torch.cat((prev_embed.squeeze(1), hidden[-1]), dim=-1)
        attention_raw = self.attention(concat_prev_embed)
        attention_weights = F.softmax(attention_raw)
        # Apply attention weights to encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)

        # Combine context vector with previous embedding
        prev_embed_with_context = torch.cat((prev_embed, context), dim=-1).squeeze(1)
        prev_embed_with_context = self.combine_attention(prev_embed_with_context)
        prev_embed_with_context = F.relu(prev_embed_with_context).unsqueeze(1)

        return self.rnn(prev_embed_with_context, hidden)

    def forward(self, inputs, encoder_outputs, encoder_hidden, hidden=None, max_output_len=None):
        """
        :param inputs: 3d tensor of shape (batch_size, max_seq_length) with target sentences
        :param encoder_outputs: 3d tensor of shape (batch_size, max_seq_length, enc_hidden_size) with output layers from encoder
        :param encoder_hidden: 3d tensor of shape (num_enc_layers, batch_size, hidden_size) with final encoder hidden state
        :param hidden: 3d tensor of shape (1, batch_size, hidden_size) with hidden state from previous time step
        :param max_output_len: int maximum length of output sequence
        """
        # Initialize values if not given
        if max_output_len is None:
            max_output_len = inputs.size(1)
        if hidden is None:
            hidden = self.init_hidden(encoder_hidden)
            
        embedded = self.embedding(inputs)
        dropped_embedded = self.dropout(embedded)

        # Generate output and hidden for each word
        pre_output_vectors = []
        for i in range(max_output_len):
            prev_embed = dropped_embedded[:, i].unsqueeze(1)
            pre_output, hidden = self.forward_step(prev_embed, hidden, encoder_outputs)
            pre_output_vectors.append(pre_output)

        outputs = torch.cat(pre_output_vectors, dim=1)
        return outputs, hidden

    def init_hidden(self, encoder_hidden):
        """
        :param encoder_hidden: 3d tensor of shape (num_enc_layers, batch_size, hidden_size) with final encoder hidden state
        """
        return torch.tanh(self.bridge(encoder_hidden))

In [10]:
class SimpleHearthstoneEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, generator):
        """
        Inputs:
          - `encoder`: an `Encoder` object.
          - `decoder`: a `Decoder` object.
          - `generator`: a `Generator` object. Essentially a linear mapping. See
              the next code cell.
        """
        super(SimpleHearthstoneEncoderDecoder, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, src_ids, trg_ids, src_lengths):
        """Take in and process masked source and target sequences.

        Inputs:
          `src_ids`: a 2d-tensor of shape (batch_size, max_seq_length) representing
            a batch of source sentences of word ids.
          `trg_ids`: a 2d-tensor of shape (batch_size, max_seq_length) representing
            a batch of target sentences of word ids.
          `src_lengths`: a 1d-tensor of shape (batch_size,) representing the
            sequence length of `src_ids`.

        Returns the decoder outputs, see the above cell.
        """
        encoder_outputs, encoder_hidden = self.encode(src_ids, src_lengths)
        return self.decode(trg_ids[:, :-1], encoder_outputs, encoder_hidden)

    def encode(self, src_ids, src_lengths):
        return self.encoder(src_ids, src_lengths)

    def decode(self, trg_ids, encoder_outputs, encoder_hidden, decoder_hidden=None):
        return self.decoder(trg_ids, encoder_outputs, encoder_hidden, hidden=decoder_hidden)

In [11]:
class Generator(nn.Module):
    """Define standard linear + softmax generation step."""
    def __init__(self, hidden_size, vocab_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

# Training

In [12]:
class SimpleLossCompute:
    """A simple loss compute and train function."""

    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                                y.contiguous().view(-1))
        loss = loss / norm

        if self.opt is not None:    # training mode
            loss.backward()            
            self.opt.step()
            self.opt.zero_grad()

        return loss.data.item() * norm

In [13]:
def run_epoch(data_loader, model, loss_compute):
    """Standard Training and Logging Function"""
    total_tokens = 0
    total_loss = 0

    for i, (src_ids_BxT, trg_ids_BxL, src_lengths_B, trg_lengths_B) in enumerate(tqdm(data_loader, position=0, leave=True)):
        # We define some notations here to help you understand the loaded tensor
        # shapes:
        #     `B`: batch size
        #     `T`: max sequence length of source sentences
        #     `L`: max sequence length of target sentences; due to our preprocessing
        #        in the beginning, `L` == `T` == 50
        # An example of `src_ids_BxT` (when B = 2):
        #     [[2, 4, 6, 7, ..., 4, 3, 0, 0, 0],
        #    [2, 8, 6, 5, ..., 9, 5, 4, 3, 0]]
        # The corresponding `src_lengths_B` would be [47, 49].

        src_ids_BxT = src_ids_BxT.to(device)
        src_lengths_B = src_lengths_B.to(device)
        trg_ids_BxL = trg_ids_BxL.to(device)

        del trg_lengths_B     # unused

        output, _ = model(src_ids_BxT, trg_ids_BxL, src_lengths_B)

        loss = loss_compute(x=output, y=trg_ids_BxL[:, 1:],
                            norm=src_ids_BxT.size(0))
        total_loss += loss
        total_tokens += (trg_ids_BxL[:, 1:] != PAD_ID).data.sum().item()

    print(f"Total loss: {math.exp(total_loss / float(total_tokens))}")

    return math.exp(total_loss / float(total_tokens))

def train(model, train_data_loader, val_data_loader, num_epochs, learning_rate):
    # Set `ignore_index` as PAD_INDEX so that pad tokens won't be included when
    # computing the loss.
    criterion = nn.NLLLoss(reduction="sum", ignore_index=PAD_ID)
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Keep track of dev ppl for each epoch.
    dev_ppls = []

    for epoch in range(num_epochs):
        print("Epoch", epoch)

        model.train()
        train_ppl = run_epoch(data_loader=train_data_loader, model=model,
                                loss_compute=SimpleLossCompute(model.generator,
                                                             criterion, optim))

        model.eval()
        with torch.no_grad():        
            dev_ppl = run_epoch(data_loader=val_data_loader, model=model,
                                loss_compute=SimpleLossCompute(model.generator,
                                                             criterion, None))
            print("Validation perplexity: %f" % dev_ppl)
            dev_ppls.append(dev_ppl)
        
    return dev_ppls

In [None]:
# Set params
input_embedding_size = 256
input_hidden_size = 256
target_embedding_size = 256
target_hidden_size = 256
batch_size = 8

# Create data objects
simple_train_dataloader = data.DataLoader(simple_train_dataset, batch_size=batch_size, shuffle=True)
simple_validation_dataloader = data.DataLoader(simple_validation_dataset, batch_size=batch_size, shuffle=True)
input_vocab_size = len(input_token2id)
target_vocab_size = len(target_token2id)

# Create models
simple_encoder = SimpleHearthstoneEncoder(input_vocab_size, input_embedding_size, input_hidden_size).to(device)
simple_decoder = SimpleHearthstoneDecoder(target_vocab_size, target_embedding_size, target_hidden_size, 2 * input_hidden_size, input_max_seq_len).to(device)
simple_generator = Generator(target_hidden_size, target_vocab_size)
simple_encoder_decoder = SimpleHearthstoneEncoderDecoder(simple_encoder, simple_decoder, simple_generator).to(device)

In [None]:
# Train model
epochs = 20
lr = 1e-3

train(simple_encoder_decoder, simple_train_dataloader, simple_validation_dataloader, epochs, lr)

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.81it/s]

Total loss: 96.70452770506758


100%|██████████| 7/7 [00:01<00:00,  5.76it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 66.02021267971398
Validation perplexity: 66.020213
Epoch 1


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.81it/s]

Total loss: 48.98931493740771


100%|██████████| 7/7 [00:01<00:00,  5.62it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 51.30249561738432
Validation perplexity: 51.302496
Epoch 2


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.95it/s]

Total loss: 38.906787834542925


100%|██████████| 7/7 [00:01<00:00,  5.58it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 45.68239112190183
Validation perplexity: 45.682391
Epoch 3


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.89it/s]

Total loss: 33.67026832979503


100%|██████████| 7/7 [00:01<00:00,  5.55it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 41.78727966139444
Validation perplexity: 41.787280
Epoch 4


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.84it/s]

Total loss: 29.324395184191822


100%|██████████| 7/7 [00:01<00:00,  5.60it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 36.482113245105026
Validation perplexity: 36.482113
Epoch 5


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.71it/s]

Total loss: 24.781192485138135


100%|██████████| 7/7 [00:01<00:00,  5.65it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 30.42686051854323
Validation perplexity: 30.426861
Epoch 6


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.70it/s]

Total loss: 19.975010359931307


100%|██████████| 7/7 [00:01<00:00,  5.65it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 24.09215548001891
Validation perplexity: 24.092155
Epoch 7


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.42it/s]

Total loss: 14.178256448330295


100%|██████████| 7/7 [00:01<00:00,  5.65it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 15.96575728577981
Validation perplexity: 15.965757
Epoch 8


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.27it/s]

Total loss: 10.097060952148745


100%|██████████| 7/7 [00:01<00:00,  5.50it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 12.537885190707557
Validation perplexity: 12.537885
Epoch 9


100%|██████████| 60/60 [00:32<00:00,  1.85it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.89it/s]

Total loss: 8.041673375853048


100%|██████████| 7/7 [00:01<00:00,  5.67it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 10.167672190626027
Validation perplexity: 10.167672
Epoch 10


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.95it/s]

Total loss: 6.3707693740538565


100%|██████████| 7/7 [00:01<00:00,  5.76it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 8.53414204206735
Validation perplexity: 8.534142
Epoch 11


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.84it/s]

Total loss: 5.346270869216869


100%|██████████| 7/7 [00:01<00:00,  5.64it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 7.436633749346687
Validation perplexity: 7.436634
Epoch 12


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.57it/s]

Total loss: 4.636551228584552


100%|██████████| 7/7 [00:01<00:00,  5.39it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 6.738077019986307
Validation perplexity: 6.738077
Epoch 13


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.64it/s]

Total loss: 4.139951898358672


100%|██████████| 7/7 [00:01<00:00,  5.72it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 6.109578465705943
Validation perplexity: 6.109578
Epoch 14


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.54it/s]

Total loss: 3.747990331742934


100%|██████████| 7/7 [00:01<00:00,  5.59it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.562359923912533
Validation perplexity: 5.562360
Epoch 15


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.93it/s]

Total loss: 3.400493724109707


100%|██████████| 7/7 [00:01<00:00,  5.70it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.191183346838957
Validation perplexity: 5.191183
Epoch 16


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.83it/s]

Total loss: 3.160959319564019


100%|██████████| 7/7 [00:01<00:00,  5.64it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.915829559045294
Validation perplexity: 4.915830
Epoch 17


100%|██████████| 60/60 [00:32<00:00,  1.87it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.83it/s]

Total loss: 2.937889689410582


100%|██████████| 7/7 [00:01<00:00,  5.61it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.741284766207096
Validation perplexity: 4.741285
Epoch 18


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.83it/s]

Total loss: 2.7627043984451154


100%|██████████| 7/7 [00:01<00:00,  5.74it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.41464438651577
Validation perplexity: 4.414644
Epoch 19


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.60it/s]

Total loss: 2.5979389893193146


100%|██████████| 7/7 [00:01<00:00,  5.63it/s]

Total loss: 4.25753033179605
Validation perplexity: 4.257530





[66.02021267971398,
 51.30249561738432,
 45.68239112190183,
 41.78727966139444,
 36.482113245105026,
 30.42686051854323,
 24.09215548001891,
 15.96575728577981,
 12.537885190707557,
 10.167672190626027,
 8.53414204206735,
 7.436633749346687,
 6.738077019986307,
 6.109578465705943,
 5.562359923912533,
 5.191183346838957,
 4.915829559045294,
 4.741284766207096,
 4.41464438651577,
 4.25753033179605]

In [None]:
# Set to True if you want to save this model
# MAKE SURE TO SAVE YOUR MAPPINGS AS WELL WITH THE CELL EARLIER IN THE NOTEBOOK
save_model = False
if save_model:
    torch.save(simple_encoder_decoder.state_dict(), os.path.join(project_dir, "simple_encoder_decoder.pt"))

# Set to True if you want to load a previously trained model
# MAKE SURE TO LOAD MODEL'S CORRESPONDING MAPPINGS WITH THE CELL EARLIER IN THE NOTEBOOK
load_model = False
if load_model:
    simple_encoder_decoder.load_state_dict(torch.load(os.path.join(project_dir, "simple_encoder_decoder.pt"), map_location=device))
    simple_encoder_decoder.eval()

# Decoding
Here we will use the trained model to decode input texts into generated code

In [14]:
def tokens_to_text(ids, mapping):
    # TODO: handle newlines and spaces
    return "".join([mapping[id] for id in ids])

In [None]:
# Read and tokenize test inputs and targets
test_raw_inputs = read_file(test_input_path)
test_raw_targets = read_file(test_target_path)
test_input_line_tokens = [e.tokens for e in input_tokenizer.encode_batch(test_raw_inputs)]
test_target_line_tokens = [e.tokens for e in target_tokenizer.encode_batch(test_raw_targets)]

# Truncate line tokens so it matches training data
trunc_test_input_line_tokens = []
for line_tokens in test_input_line_tokens:
    if len(line_tokens) > input_max_seq_len - 2:
        trunc_test_input_line_tokens.append(line_tokens[:input_max_seq_len - 2])
    else:
        trunc_test_input_line_tokens.append(line_tokens)
        
trunc_test_target_line_tokens = []
for line_tokens in test_target_line_tokens:
    if len(line_tokens) > target_max_seq_len - 2:
        trunc_test_target_line_tokens.append(line_tokens[:target_max_seq_len - 2])
    else:
        trunc_test_target_line_tokens.append(line_tokens)

simple_test_dataset = SimpleHearthstoneDataset(trunc_test_input_line_tokens, trunc_test_target_line_tokens, input_token2id, target_token2id, "test", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)

## Greedy Decoding
We first try decoding using a greedy approach, taking the most likely token at each step and expanding that sequence

In [None]:
def greedy_decoding(model, src_ids, src_lengths, max_len):
    """Greedily decode a sentence for EncoderDecoder. Make sure to chop off the 
         EOS token!"""

    with torch.no_grad():
        encoder_outputs, encoder_hidden = model.encode(src_ids, src_lengths)
        prev_y = torch.ones(1, 1).fill_(SOS_ID).type_as(src_ids)
    
    output = []
    hidden = None

    for i in range(max_len):
        with torch.no_grad():
            outputs, hidden = model.decode(prev_y, encoder_outputs, encoder_hidden, hidden)
            prob = model.generator(outputs[:, -1])
        d, next_word = torch.max(prob, dim=1)
        next_word = next_word.data.item()
        output.append(next_word)
        prev_y = torch.ones(1, 1).type_as(src_ids).fill_(next_word)

    output = np.array(output)

    # Cut off everything starting from </s>.
    first_eos = np.where(output == EOS_ID)[0]
    if len(first_eos) > 0:
        output = output[:first_eos[0]]

    return output

In [None]:
def spot_check_greedy(model, dataset, idx=None, n=1):
    """Compare a (random) generated and target sequence using greedy search"""
    for i in range(n):
        if idx is None:
            idx = np.random.randint(0, len(dataset))
        inp_ids, trg_ids, inp_lens, trg_lens = dataset[idx: idx+1]
        greedy_decoded = greedy_decoding(model, inp_ids, inp_lens, target_max_seq_len)
        stripped_trg_ids = trg_ids[0][trg_ids[0] != PAD_ID][1:-1].tolist()
        stripped_inp_ids = inp_ids[0][inp_ids[0] != PAD_ID][1:-1].tolist()
        print("===============================")
        print(f"Input: {tokens_to_text(stripped_inp_ids, input_id2token)}")
        print(f"Expected:\n\n\t{tokens_to_text(stripped_trg_ids, target_id2token)}\n\n-got-\n\n\t{tokens_to_text(greedy_decoded, target_id2token)}")
        print("===============================")

In [None]:
# Spot check for training set
spot_check_greedy(simple_encoder_decoder, simple_train_dataset)

Input: SacrificialPactNAME_END-1ATK_END-1DEF_END0COST_END-1DUR_ENDSpellTYPE_ENDWarlockPLAYER_CLS_ENDNILRACE_ENDCommonRARITY_ENDDestroyaDemon.Restore#5Healthtoyourhero.
Expected:

	classSacrificialPact(SpellCard):§def__init__(self):§super().__init__("SacrificialPact",0,CHARACTER_CLASS.WARLOCK,CARD_RARITY.COMMON,target_func=hearthbreaker.targeting.find_spell_target,filter_func=lambdacharacter:character.card.minion_type==MINION_TYPE.DEMON)§§defuse(self,player,game):§super().use(player,game)§self.target.die(self)§player.hero.heal(player.effective_heal_power(5),self)§

-got-

	classArcaneMissiles(SpellCard):§def__init__(self):§super().__init__("HolySpirit",2,CHARACTER_CLASS.SHAMAN,CARD_RARITY.COMMON,target_func=hearthbreaker.targeting.find_minion_spell_target)§§defuse(self,player,game):§super().use(player,game)§self.target.damage(player.effective_spell_damage(2),self)§


  attention_weights = F.softmax(attention_raw)


In [None]:
# Spot check for testing set
spot_check_greedy(simple_encoder_decoder, simple_test_dataset)



Input: MadUNKBomberNAME_END5ATK_END4DEF_END5COST_END-1DUR_ENDMinionTYPE_ENDNeutralPLAYER_CLS_ENDNILRACE_ENDRareRARITY_END<b>Battlecry:</b>Deal6damagerandomlysplitbetweenallothercharacters.
Expected:

	classMadUNKBomber(MinionCard):§def__init__(self):§super().__init__("MadUNKBomber",5,CHARACTER_CLASS.ALL,CARD_RARITY.RARE,battlecry=Battlecry(Damage(1),CharacterSelector(players=BothPlayer(),picker=RandomPicker(6))))§§defcreate_minion(self,player):§returnMinion(5,4)§

-got-

	classDruidOfTheClaw(MinionCard):§def__init__(self):§super().__init__("Ancientof",3,CHARACTER_CLASS.ALL,CARD_RARITY.COMMON)§§defcreate_minion(self,player):§returnMinion(2,2,effects=[Effect(TurnEnded(),ActionTag(Give(ChangeAttack(1)),SelfSelector()))])§


## Beam Search
We also try decoding using beam search, expanding the top k most likely sequences at each step

In [None]:
def beam_search_decoding(model, src_ids, src_lengths, max_len, k=25):
    """Keep expanding top k most likely sequences"""
    with torch.no_grad():
        encoder_outputs, encoder_hidden = model.encode(src_ids, src_lengths)
    
    # Keep track of top outputs stores as (log prob, output ID seq, hidden)
    top_outputs = [(0, [SOS_ID], None)]

    for i in range(max_len):
        new_top_outputs = []
        for log_prob, output, hidden in top_outputs:
            # Get last token of candidate output sequence and use as input to decoder
            prev_y = torch.ones(1, 1).type_as(src_ids).fill_(output[-1])
            probs = None
            h = None
            with torch.no_grad():
                o, h = model.decode(prev_y, encoder_outputs, encoder_hidden, hidden)
                probs = model.generator(o[:, -1])
            # Get top k log probs and ids
            topk_log_probs, topk_ids = torch.topk(probs,k, dim=1)
            for token_log_prob, token_id in zip(topk_log_probs[0], topk_ids[0]):
                new_top_outputs.append((log_prob + token_log_prob.data.item(), output + [token_id.data.item()], h))
        # Get top k most likely output sequences up to this point
        new_top_outputs = sorted(new_top_outputs, key=lambda d: d[0], reverse=True)
        top_outputs = new_top_outputs[:k]
    
    # Get the most likely output sequence of all top outputs
    output = np.array(max(top_outputs, key=lambda d: d[0])[1])

    # Cut off everything starting from </s>.
    first_eos = np.where(output == EOS_ID)[0]
    if len(first_eos) > 0:
        output = output[:first_eos[0]]

    return output[1:]

In [None]:
def spot_check_beam(model, dataset, idx=None, k=10):
    """Compare a (random) generated and target sequence using beam search"""
    if idx is None:
        idx = np.random.randint(0, len(dataset))
    inp_ids, trg_ids, inp_lens, trg_lens = dataset[idx: idx+1]
    beam_decoded = beam_search_decoding(model, inp_ids, inp_lens, target_max_seq_len)
    stripped_trg_ids = trg_ids[0][trg_ids[0] != PAD_ID][1:-1].tolist()
    stripped_trg_ids = trg_ids[0][trg_ids[0] != PAD_ID][1:-1].tolist()
    stripped_inp_ids = inp_ids[0][inp_ids[0] != PAD_ID][1:-1].tolist()
    print("===============================")
    print(f"Input: {tokens_to_text(stripped_inp_ids, input_id2token)}")
    print(f"Expected:\n\n\t{tokens_to_text(stripped_trg_ids, target_id2token)}\n\n-got-\n\n\t{tokens_to_text(beam_decoded, target_id2token)}")
    print("===============================")

In [None]:
# Spot check for training set
spot_check_beam(simple_encoder_decoder, simple_train_dataset)



Input: CaptainGreenskinNAME_END5ATK_END4DEF_END5COST_END-1DUR_ENDMinionTYPE_ENDNeutralPLAYER_CLS_ENDPirateRACE_ENDLegendaryRARITY_END<b>Battlecry:</b>Giveyourweapon+1/+1.
Expected:

	classCaptainGreenskin(MinionCard):§def__init__(self):§super().__init__("CaptainGreenskin",5,CHARACTER_CLASS.ALL,CARD_RARITY.LEGENDARY,minion_type=MINION_TYPE.PIRATE,battlecry=Battlecry([IncreaseWeaponAttack(1),IncreaseDurability()],WeaponSelector()))§§defcreate_minion(self,player):§returnMinion(5,4)§

-got-

	classDruidOfTheClaw(MinionCard):§def__init__(self):§super().__init__("AncientTotem",3,CHARACTER_CLASS.ALL,CARD_RARITY.COMMON,minion_type=MINION_TYPE.MECH)§§defcreate_minion(self,player):§returnMinion(2,3,effects=[Effect(TurnStarted(),ActionTag(Give(ChangeAttack(1)),PlayerSelector()))])§


In [None]:
# Spot check for testing set
spot_check_beam(simple_encoder_decoder, simple_test_dataset)



Input: UNKUNKUNKVUNKUNKNAME_END-1ATK_END-1DEF_END5COST_END-1DUR_ENDSpellTYPE_ENDPaladinPLAYER_CLS_ENDNILRACE_ENDCommonRARITY_ENDDraw2cards.Costs(1)lessforeachminionthatdiedthisturn.
Expected:

	classUNKUNKUNKUNKVUNKUNK(SpellCard):§def__init__(self):§super().__init__("UNKUNKUNKUNKVUNKUNK",5,CHARACTER_CLASS.PALADIN,CARD_RARITY.COMMON,buffs=[Buff(ManaChange(Count(DeadMinionSelector(players=BothPlayer())),-1))])§§defuse(self,player,game):§super().use(player,game)§forUNKinrange(0,2):§player.draw()§

-got-

	classLightOfTheNaaru(SpellCard):§def__init__(self):§super().__init__("Arcaneofs",1,CHARACTER_CLASS.SHAMAN,CARD_RARITY.COMMON,target_func=hearthbreaker.targeting.find_minion_spell_target)§§defuse(self,player,game):§super().use(player,game)§self.target.damage(player.effective_spell_damage(2),self)§


# Evaluation
We evaluate our models using accuracy (exact match) and BLEU scores

In [None]:
def evaluate_accuracy(model, test_dataset, decoder, pad_token):
    """
    :param model: model to evaluate
    :param test_dataset: test dataset to evaluate that yields (input, target_tokens, input_length (or empty value), target_length (or empty value))
            target_tokens should have sequence that starts and ends with SOS and EOS tokens respectively and may be padded with pad_token
    :param decoder: decoder to evaluate with; returns a list of predicted tokens with SOS, EOS, and PAD tokens removed
    :param pad_token: padding token used in target_tokens
    """
    matches = []
    for i in tqdm(range(len(test_dataset)), position=0, leave=True):
        inp, trg_tokens, inp_len, trg_len = test_dataset[i: i+1]
        trunc_trg_tokens = trg_tokens[0][trg_tokens[0] != pad_token][1:-1].tolist()

        pred_tokens = decoder(model, inp, inp_len, trg_len)
        
        matches.append(1 if pred_tokens == trg_tokens else 0)
    return matches

In [None]:
def evaluate_bleu(model, test_dataset, decoder, token2str, pad_token):
    """
    :param model: model to evaluate
    :param test_dataset: test dataset to evaluate that yields (input, target_tokens, input_length (or empty value), target_length (or empty value))
            target_tokens should have sequence that starts and ends with SOS and EOS tokens respectively and may be padded with pad_token
    :param decoder: decoder to evaluate with; returns a list of predicted tokens with SOS, EOS, and PAD tokens removed
    :param token2str: mapping from token to string
    :param pad_token: padding token used in target_tokens
    """
    bleu_scores = []
    for i in tqdm(range(len(test_dataset)), position=0, leave=True):
        inp, trg_tokens, inp_len, trg_len = test_dataset[i: i+1]
        trunc_trg_tokens = trg_tokens[0][trg_tokens[0] != pad_token][1:-1].tolist()

        pred_tokens = decoder(model, inp, inp_len, trg_len)
        pred_text = " ".join([token2str[t] for t in pred_tokens])
        trg_text = " ".join([token2str[t] for t in trunc_trg_tokens])

        bleu_scores.append(sacrebleu.raw_corpus_bleu([pred_text], [[trg_text]], 0.01).score)

    return bleu_scores

## Simple Model

In [None]:
# Beam Search
simple_beam_matches = evaluate_accuracy(simple_encoder_decoder, simple_test_dataset, beam_search_decoding, PAD_ID)
simple_beam_bleus = evaluate_bleu(simple_encoder_decoder, simple_test_dataset, beam_search_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [03:08<00:00,  2.85s/it]
100%|██████████| 66/66 [03:07<00:00,  2.84s/it]


In [None]:
# Greedy Search
simple_greedy_matches = evaluate_accuracy(simple_encoder_decoder, simple_test_dataset, greedy_decoding, PAD_ID)
simple_greedy_bleus = evaluate_bleu(simple_encoder_decoder, simple_test_dataset, greedy_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [00:04<00:00, 15.17it/s]
100%|██████████| 66/66 [00:04<00:00, 14.99it/s]


In [None]:
print("Metrics for Simple UNK Encoder/Decoder")
print(f"Beam Accuracy: {sum(simple_beam_matches) / len(simple_beam_matches)}")
print(f"Beam BLEU: {sum(simple_beam_bleus) / len(simple_beam_bleus)}\n")
print(f"Greedy Accuracy: {sum(simple_greedy_matches) / len(simple_greedy_matches)}")
print(f"Greedy BLEU: {sum(simple_greedy_bleus) / len(simple_greedy_bleus)}")

Metrics for Simple UNK Encoder/Decoder
Beam Accuracy: 0.0
Beam BLEU: 42.81463596876857

Greedy Accuracy: 0.0
Greedy BLEU: 42.900648880536515


# Enforced UNK Tokens
Our simple model may be achieving 0 accuracy because it cannot handle UNK tokens due to the training data not having UNK tokens. Here we will try to train the simple model again, this time with the training data tokenized probabilistically to be replaced with UNK with probability p.

In [None]:
def tokens_to_ids_enforced_unk(tokens, token2id, mode, max_length, unk_prob=0.5):
    """Convert list of tokens to IDs"""
    token_ids = []
    for token in tokens:
        if token in token2id:
            if mode == "train":
                if np.random.random() >= unk_prob:
                    token_ids.append(token2id[token])
                else:    
                    token_ids.append(token2id["UNK"])
            else:
                token_ids.append(token2id[token])
        elif mode == "train":
            token_ids.append(len(token2id))
            token2id[token] = len(token2id)
        else:
            token_ids.append(token2id["UNK"])
    ids_len = len(token_ids) + 2 
    token_ids = [token2id["SOS_INDEX"]] + token_ids + \
                   [token2id["EOS_INDEX"]] + \
                   [token2id["PAD_INDEX"]]*(max_length - ids_len)
    return token_ids, token2id, ids_len

def batch_tokens_to_ids_enforced_unk(line_tokens, token2id, mode, max_length=None, unk_prob=0.5):
    """Convert tokens to IDs for a batch and returns long tensor for whole batch of IDs"""
    if max_length is None:
        max_length = max([len(line) for line in line_tokens])
    line_ids = []
    line_lens = []
    for tokens in line_tokens:
        ids, token2id, line_len = tokens_to_ids_enforced_unk(tokens, token2id, mode, max_length=max_length, unk_prob=unk_prob)
        line_ids.append(ids)
        line_lens.append(line_len)
    return torch.LongTensor(line_ids).to(device), token2id, torch.LongTensor(line_lens).to(device)

In [None]:
class EnforcedUNKHearthstoneDataset(data.Dataset):
    """Simple dataset with forced UNK tokens"""
    def __init__(self, input_line_tokens, target_line_tokens, input_token2id, target_token2id, mode, input_max_seq_len=None, target_max_seq_len=None, unk_prob=0.5):
        self.input_line_ids, _, self.input_line_lens = batch_tokens_to_ids_enforced_unk(input_line_tokens, input_token2id, mode, max_length=input_max_seq_len, unk_prob=unk_prob)
        self.target_line_ids, _, self.target_line_lens = batch_tokens_to_ids_enforced_unk(target_line_tokens, target_token2id, mode, max_length=target_max_seq_len, unk_prob=unk_prob)
        
    def __len__(self):
        return len(self.input_line_ids)
    
    def __getitem__(self, idx):
        return self.input_line_ids[idx], self.target_line_ids[idx], self.input_line_lens[idx], self.target_line_lens[idx]

In [None]:
unk_prob = 0.1

# Split validation and training data
validation_ratio = 0.1
train_size = int((1 - validation_ratio) * len(input_line_tokens))
train_input_line_tokens = input_line_tokens[: train_size]
train_target_line_tokens = target_line_tokens[: train_size]
validation_input_line_tokens = input_line_tokens[train_size:]
validation_target_line_tokens = target_line_tokens[train_size:]

# Create datasets for validation and training
enforced_unk_train_dataset = EnforcedUNKHearthstoneDataset(train_input_line_tokens, train_target_line_tokens, input_token2id, target_token2id, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len, unk_prob=unk_prob)
enforced_unk_validation_dataset = EnforcedUNKHearthstoneDataset(validation_input_line_tokens, validation_target_line_tokens, input_token2id, target_token2id, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len, unk_prob=unk_prob)

In [None]:
# Set params
input_embedding_size = 256
input_hidden_size = 256
target_embedding_size = 256
target_hidden_size = 256
batch_size = 8

# Create data objects
enforced_unk_train_dataloader = data.DataLoader(enforced_unk_train_dataset, batch_size=batch_size, shuffle=True)
enforced_unk_validation_dataloader = data.DataLoader(enforced_unk_validation_dataset, batch_size=batch_size, shuffle=True)
input_vocab_size = len(input_token2id)
target_vocab_size = len(target_token2id)

# Create models
enforced_unk_encoder = SimpleHearthstoneEncoder(input_vocab_size, input_embedding_size, input_hidden_size).to(device)
enforced_unk_decoder = SimpleHearthstoneDecoder(target_vocab_size, target_embedding_size, target_hidden_size, 2 * input_hidden_size, input_max_seq_len).to(device)
enforced_unk_generator = Generator(target_hidden_size, target_vocab_size)
enforced_unk_encoder_decoder = SimpleHearthstoneEncoderDecoder(enforced_unk_encoder, enforced_unk_decoder, enforced_unk_generator).to(device)

In [None]:
# Train model
epochs = 20
lr = 1e-3

train(enforced_unk_encoder_decoder, enforced_unk_train_dataloader, enforced_unk_validation_dataloader, epochs, lr)

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.87it/s]

Total loss: 92.77032767430049


100%|██████████| 7/7 [00:01<00:00,  5.80it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 64.59443659079196
Validation perplexity: 64.594437
Epoch 1


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.64it/s]

Total loss: 48.891635205487574


100%|██████████| 7/7 [00:01<00:00,  5.75it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 51.11495101568877
Validation perplexity: 51.114951
Epoch 2


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.52it/s]

Total loss: 39.58172748103038


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 45.10766823932047
Validation perplexity: 45.107668
Epoch 3


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.00it/s]

Total loss: 34.66082653726583


100%|██████████| 7/7 [00:01<00:00,  5.66it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 40.62151746055084
Validation perplexity: 40.621517
Epoch 4


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.08it/s]

Total loss: 31.303155629711213


100%|██████████| 7/7 [00:01<00:00,  5.73it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 38.234697182824284
Validation perplexity: 38.234697
Epoch 5


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.96it/s]

Total loss: 29.403740340650828


100%|██████████| 7/7 [00:01<00:00,  5.81it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 37.95302318738754
Validation perplexity: 37.953023
Epoch 6


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.93it/s]

Total loss: 27.342854793617203


100%|██████████| 7/7 [00:01<00:00,  5.77it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 35.35450890742274
Validation perplexity: 35.354509
Epoch 7


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.85it/s]

Total loss: 25.692043881454772


100%|██████████| 7/7 [00:01<00:00,  5.77it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 35.300711434534136
Validation perplexity: 35.300711
Epoch 8


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.70it/s]

Total loss: 24.736581774388515


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 34.07126083265607
Validation perplexity: 34.071261
Epoch 9


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.95it/s]

Total loss: 23.38962841131841


100%|██████████| 7/7 [00:01<00:00,  5.86it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 31.867095242302625
Validation perplexity: 31.867095
Epoch 10


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.71it/s]

Total loss: 22.935115550322706


100%|██████████| 7/7 [00:01<00:00,  5.75it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 31.88045793660857
Validation perplexity: 31.880458
Epoch 11


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.51it/s]

Total loss: 23.694296748085513


100%|██████████| 7/7 [00:01<00:00,  5.73it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 32.179020954045555
Validation perplexity: 32.179021
Epoch 12


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.10it/s]

Total loss: 21.569570093892345


100%|██████████| 7/7 [00:01<00:00,  5.78it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 28.929138871489776
Validation perplexity: 28.929139
Epoch 13


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.88it/s]

Total loss: 20.089082940110465


100%|██████████| 7/7 [00:01<00:00,  5.81it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 29.224560952621612
Validation perplexity: 29.224561
Epoch 14


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.06it/s]

Total loss: 18.31020873646927


100%|██████████| 7/7 [00:01<00:00,  5.90it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 26.314070551501736
Validation perplexity: 26.314071
Epoch 15


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.94it/s]

Total loss: 18.510124516125533


100%|██████████| 7/7 [00:01<00:00,  5.76it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 26.767510583979487
Validation perplexity: 26.767511
Epoch 16


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.74it/s]

Total loss: 17.466616161335683


100%|██████████| 7/7 [00:01<00:00,  5.34it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 25.515796022626485
Validation perplexity: 25.515796
Epoch 17


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.90it/s]

Total loss: 16.647960925848313


100%|██████████| 7/7 [00:01<00:00,  5.80it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 25.311507377062533
Validation perplexity: 25.311507
Epoch 18


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.91it/s]

Total loss: 17.43338158722167


100%|██████████| 7/7 [00:01<00:00,  5.78it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 25.215000358228714
Validation perplexity: 25.215000
Epoch 19


100%|██████████| 60/60 [00:31<00:00,  1.92it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.75it/s]

Total loss: 16.805323656696036


100%|██████████| 7/7 [00:01<00:00,  5.72it/s]

Total loss: 25.660071331196782
Validation perplexity: 25.660071





[64.59443659079196,
 51.11495101568877,
 45.10766823932047,
 40.62151746055084,
 38.234697182824284,
 37.95302318738754,
 35.35450890742274,
 35.300711434534136,
 34.07126083265607,
 31.867095242302625,
 31.88045793660857,
 32.179020954045555,
 28.929138871489776,
 29.224560952621612,
 26.314070551501736,
 26.767510583979487,
 25.515796022626485,
 25.311507377062533,
 25.215000358228714,
 25.660071331196782]

In [None]:
# Set to True if you want to save this model
# MAKE SURE TO SAVE YOUR MAPPINGS AS WELL WITH THE CELL EARLIER IN THE NOTEBOOK
save_model = False
if save_model:
    torch.save(enforced_unk_encoder_decoder.state_dict(), os.path.join(project_dir, "enforced_unk_encoder_decoder.pt"))

# Set to True if you want to load a previously trained model
# MAKE SURE TO LOAD MODEL'S CORRESPONDING MAPPINGS WITH THE CELL EARLIER IN THE NOTEBOOK
load_model = False
if load_model:
    enforced_unk_encoder_decoder.load_state_dict(torch.load(os.path.join(project_dir, "enforced_unk_encoder_decoder.pt"), map_location=device))
    enforced_unk_encoder_decoder.eval()

## Evaluate Model

In [None]:
# Read and tokenize test inputs and targets
test_raw_inputs = read_file(test_input_path)
test_raw_targets = read_file(test_target_path)
test_input_line_tokens = [e.tokens for e in input_tokenizer.encode_batch(test_raw_inputs)]
test_target_line_tokens = [e.tokens for e in target_tokenizer.encode_batch(test_raw_targets)]

# Truncate line tokens so it matches training data
trunc_test_input_line_tokens = []
for line_tokens in test_input_line_tokens:
    if len(line_tokens) > input_max_seq_len - 2:
        trunc_test_input_line_tokens.append(line_tokens[:input_max_seq_len - 2])
    else:
        trunc_test_input_line_tokens.append(line_tokens)
        
trunc_test_target_line_tokens = []
for line_tokens in test_target_line_tokens:
    if len(line_tokens) > target_max_seq_len - 2:
        trunc_test_target_line_tokens.append(line_tokens[:target_max_seq_len - 2])
    else:
        trunc_test_target_line_tokens.append(line_tokens)

simple_test_dataset = SimpleHearthstoneDataset(trunc_test_input_line_tokens, trunc_test_target_line_tokens, input_token2id, target_token2id, "test", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)

In [None]:
spot_check_greedy(enforced_unk_encoder_decoder, simple_test_dataset)
# spot_check_beam(enforced_unk_encoder_decoder, simple_test_dataset)



Input: SummonUNKPortalNAME_END0ATK_END4DEF_END4COST_END-1DUR_ENDMinionTYPE_ENDWarlockPLAYER_CLS_ENDNILRACE_ENDCommonRARITY_ENDYourminionscost(2)less,bUNKUNKUNKlessUNK(1).
Expected:

	classSummonUNKPortal(MinionCard):§def__init__(self):§super().__init__("SummonUNKPortal",4,CHARACTER_CLASS.WARLOCK,CARD_RARITY.COMMON)§§defcreate_minion(self,player):§returnMinion(0,4,auras=[Aura(ManaChange(-2,1,mUNKUNK=1),CardSelector(condition=IsMinion()))])§

-got-

	classUNK(MinionCard):§def__init__(self):§super().__init__("UNKUNKUNK",",,,CHARACTER_CLASS.ALLCARD_RARITY.COMMON,defdef((self((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((((


In [None]:
# Beam Search
enforced_unk_beam_matches = evaluate_accuracy(enforced_unk_encoder_decoder, simple_test_dataset, beam_search_decoding, PAD_ID)
enforced_unk_beam_bleus = evaluate_bleu(enforced_unk_encoder_decoder, simple_test_dataset, beam_search_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [03:04<00:00,  2.79s/it]
100%|██████████| 66/66 [03:03<00:00,  2.79s/it]


In [None]:
# Greedy Search
enforced_unk_greedy_matches = evaluate_accuracy(enforced_unk_encoder_decoder, simple_test_dataset, greedy_decoding, PAD_ID)
enforced_unk_greedy_bleus = evaluate_bleu(enforced_unk_encoder_decoder, simple_test_dataset, greedy_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [00:04<00:00, 15.12it/s]
100%|██████████| 66/66 [00:04<00:00, 14.98it/s]


In [None]:
print("Metrics for Enforced UNK Encoder/Decoder")
print(f"Beam Accuracy: {sum(enforced_unk_beam_matches) / len(enforced_unk_beam_matches)}")
print(f"Beam BLEU: {sum(enforced_unk_beam_bleus) / len(enforced_unk_beam_bleus)}\n")
print(f"Greedy Accuracy: {sum(enforced_unk_greedy_matches) / len(enforced_unk_greedy_matches)}")
print(f"Greedy BLEU: {sum(enforced_unk_greedy_bleus) / len(enforced_unk_greedy_bleus)}")

Metrics for Enforced UNK Encoder/Decoder
Beam Accuracy: 0.0
Beam BLEU: 28.52484777391203

Greedy Accuracy: 0.0
Greedy BLEU: 27.989055259631616


## Alternative Approach
Instead of replacing tokens in a sequence with random UNK, we replace entire tokens with UNK throughout all sequences (from input to output)

In [None]:
unk_prob = 0.1

# Choose dropped tokens
all_tokens = set(input_token2id.keys()).union(set(target_token2id.keys()))
dropped_tokens = set()
for token in all_tokens:
    if np.random.random() < unk_prob:
        dropped_tokens.add(token)

# Drop tokens from input and target mappings
dropped_input_token2id = {
    token: (id if token not in dropped_tokens else input_token2id["UNK"])
    for (token, id) in input_token2id.items()
}

dropped_target_token2id = {
    token: (id if token not in dropped_tokens else target_token2id["UNK"])
    for (token, id) in target_token2id.items()
}

In [None]:
# Split validation and training data
validation_ratio = 0.1
train_size = int((1 - validation_ratio) * len(input_line_tokens))
train_input_line_tokens = input_line_tokens[: train_size]
train_target_line_tokens = target_line_tokens[: train_size]
validation_input_line_tokens = input_line_tokens[train_size:]
validation_target_line_tokens = target_line_tokens[train_size:]

# Create datasets for validation and training
enforced_unk_train_dataset_v2 = SimpleHearthstoneDataset(train_input_line_tokens, train_target_line_tokens, dropped_input_token2id, dropped_target_token2id, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)
enforced_unk_validation_dataset_v2 = SimpleHearthstoneDataset(validation_input_line_tokens, validation_target_line_tokens, dropped_input_token2id, dropped_target_token2id, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)

In [None]:
# Set params
input_embedding_size = 256
input_hidden_size = 256
target_embedding_size = 256
target_hidden_size = 256
batch_size = 8

# Create data objects
enforced_unk_train_dataloader_v2 = data.DataLoader(enforced_unk_train_dataset_v2, batch_size=batch_size, shuffle=True)
enforced_unk_validation_dataloader_v2 = data.DataLoader(enforced_unk_validation_dataset_v2, batch_size=batch_size, shuffle=True)
input_vocab_size = len(input_token2id)
target_vocab_size = len(target_token2id)

# Create models
enforced_unk_encoder_v2 = SimpleHearthstoneEncoder(input_vocab_size, input_embedding_size, input_hidden_size).to(device)
enforced_unk_decoder_v2 = SimpleHearthstoneDecoder(target_vocab_size, target_embedding_size, target_hidden_size, 2 * input_hidden_size, input_max_seq_len).to(device)
enforced_unk_generator_v2 = Generator(target_hidden_size, target_vocab_size)
enforced_unk_encoder_decoder_v2 = SimpleHearthstoneEncoderDecoder(enforced_unk_encoder_v2, enforced_unk_decoder_v2, enforced_unk_generator_v2).to(device)

In [None]:
# Train model
epochs = 20
lr = 1e-3

train(enforced_unk_encoder_decoder_v2, enforced_unk_train_dataloader_v2, enforced_unk_validation_dataloader_v2, epochs, lr)

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.00it/s]

Total loss: 86.99063554375284


100%|██████████| 7/7 [00:01<00:00,  5.77it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 55.93571853527773
Validation perplexity: 55.935719
Epoch 1


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.78it/s]

Total loss: 43.19354521168168


100%|██████████| 7/7 [00:01<00:00,  5.74it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 42.52581145639964
Validation perplexity: 42.525811
Epoch 2


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.59it/s]

Total loss: 33.871814307446826


100%|██████████| 7/7 [00:01<00:00,  5.54it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 37.51456801933417
Validation perplexity: 37.514568
Epoch 3


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.73it/s]

Total loss: 28.973027322808345


100%|██████████| 7/7 [00:01<00:00,  5.69it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 33.54435115696399
Validation perplexity: 33.544351
Epoch 4


100%|██████████| 60/60 [00:32<00:00,  1.85it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.71it/s]

Total loss: 25.5705341044714


100%|██████████| 7/7 [00:01<00:00,  5.55it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 30.02757694986674
Validation perplexity: 30.027577
Epoch 5


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.90it/s]

Total loss: 22.495538624085615


100%|██████████| 7/7 [00:01<00:00,  5.62it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 26.165340942865896
Validation perplexity: 26.165341
Epoch 6


100%|██████████| 60/60 [00:32<00:00,  1.85it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.76it/s]

Total loss: 17.797076533503574


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 20.305135698322616
Validation perplexity: 20.305136
Epoch 7


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.68it/s]

Total loss: 13.774245278179588


100%|██████████| 7/7 [00:01<00:00,  5.62it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 15.337506187230248
Validation perplexity: 15.337506
Epoch 8


100%|██████████| 60/60 [00:32<00:00,  1.85it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.87it/s]

Total loss: 10.406968874387482


100%|██████████| 7/7 [00:01<00:00,  5.61it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 11.057919975613332
Validation perplexity: 11.057920
Epoch 9


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.71it/s]

Total loss: 6.795201434413


100%|██████████| 7/7 [00:01<00:00,  5.62it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 7.6366299014346435
Validation perplexity: 7.636630
Epoch 10


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.82it/s]

Total loss: 4.969993699617301


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.974474289644345
Validation perplexity: 5.974474
Epoch 11


100%|██████████| 60/60 [00:32<00:00,  1.85it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.84it/s]

Total loss: 4.07069735218965


100%|██████████| 7/7 [00:01<00:00,  5.59it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.290117000226588
Validation perplexity: 5.290117
Epoch 12


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.88it/s]

Total loss: 3.5759226856719555


100%|██████████| 7/7 [00:01<00:00,  5.58it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.854024420297158
Validation perplexity: 4.854024
Epoch 13


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.01it/s]

Total loss: 3.2191803802486065


100%|██████████| 7/7 [00:01<00:00,  5.70it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.384808802926134
Validation perplexity: 4.384809
Epoch 14


100%|██████████| 60/60 [00:32<00:00,  1.85it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.69it/s]

Total loss: 2.9456561541236947


100%|██████████| 7/7 [00:01<00:00,  5.67it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.153415953208622
Validation perplexity: 4.153416
Epoch 15


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.49it/s]

Total loss: 2.746035385736981


100%|██████████| 7/7 [00:01<00:00,  5.58it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.0123031717407995
Validation perplexity: 4.012303
Epoch 16


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.96it/s]

Total loss: 2.595272552695728


100%|██████████| 7/7 [00:01<00:00,  5.69it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 3.811992156954705
Validation perplexity: 3.811992
Epoch 17


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  6.00it/s]

Total loss: 2.4350947595766734


100%|██████████| 7/7 [00:01<00:00,  5.64it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 3.7009473205194627
Validation perplexity: 3.700947
Epoch 18


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.88it/s]

Total loss: 2.3192863135169977


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 3.668884895560412
Validation perplexity: 3.668885
Epoch 19


100%|██████████| 60/60 [00:32<00:00,  1.86it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.89it/s]

Total loss: 2.2279127879187373


100%|██████████| 7/7 [00:01<00:00,  5.79it/s]

Total loss: 3.6172720972582355
Validation perplexity: 3.617272





[55.93571853527773,
 42.52581145639964,
 37.51456801933417,
 33.54435115696399,
 30.02757694986674,
 26.165340942865896,
 20.305135698322616,
 15.337506187230248,
 11.057919975613332,
 7.6366299014346435,
 5.974474289644345,
 5.290117000226588,
 4.854024420297158,
 4.384808802926134,
 4.153415953208622,
 4.0123031717407995,
 3.811992156954705,
 3.7009473205194627,
 3.668884895560412,
 3.6172720972582355]

In [None]:
# Set to True if you want to save this model
# MAKE SURE TO SAVE YOUR MAPPINGS AS WELL WITH THE CELL EARLIER IN THE NOTEBOOK
save_model = False
if save_model:
    torch.save(enforced_unk_encoder_decoder_v2.state_dict(), os.path.join(project_dir, "enforced_unk_encoder_decoder_v2.pt"))

# Set to True if you want to load a previously trained model
# MAKE SURE TO LOAD MODEL'S CORRESPONDING MAPPINGS WITH THE CELL EARLIER IN THE NOTEBOOK
load_model = False
if load_model:
    enforced_unk_encoder_decoder_v2.load_state_dict(torch.load(os.path.join(project_dir, "enforced_unk_encoder_decoder_v2.pt"), map_location=device))
    enforced_unk_encoder_decoder_v2.eval()

In [None]:
spot_check_greedy(enforced_unk_encoder_decoder_v2, simple_test_dataset)
# spot_check_beam(enforced_unk_encoder_decoder, simple_test_dataset)



Input: UNKUNKUNKNAME_END-1ATK_END-1DEF_END1COST_END-1DUR_ENDSpellTYPE_ENDRoguePLAYER_CLS_ENDNILRACE_ENDCommonRARITY_ENDGiveyourminions<b>Stealth</b>untilyournextturn.
Expected:

	classUNKUNKUNK(SpellCard):§def__init__(self):§super().__init__("UNKUNKUNK",1,CHARACTER_CLASS.ROGUE,CARD_RARITY.COMMON)§§defuse(self,player,game):§super().use(player,game)§forminioninplayer.minions:§ifnotminion.stealth:§minion.add_buff(BuffUntil(Stealth(),TurnStarted()))§

-got-

	UNKUNK(MinionCard):§def__init__(self):§super().__init__("UNKUNK",UNK,CHARACTER_CLASS.ALL,CARD_RARITY.COMMON)§§defcreate_minion(self,player):§returnMinion(UNK,UNK,UNKUNKTrue)§


In [None]:
# Beam Search
enforced_unk_v2_beam_matches = evaluate_accuracy(enforced_unk_encoder_decoder_v2, simple_test_dataset, beam_search_decoding, PAD_ID)
enforced_unk_v2_beam_bleus = evaluate_bleu(enforced_unk_encoder_decoder_v2, simple_test_dataset, beam_search_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [03:04<00:00,  2.79s/it]
100%|██████████| 66/66 [03:03<00:00,  2.78s/it]


In [None]:
# Greedy Search
enforced_unk_v2_greedy_matches = evaluate_accuracy(enforced_unk_encoder_decoder_v2, simple_test_dataset, greedy_decoding, PAD_ID)
enforced_unk_v2_greedy_bleus = evaluate_bleu(enforced_unk_encoder_decoder_v2, simple_test_dataset, greedy_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [00:04<00:00, 15.17it/s]
100%|██████████| 66/66 [00:04<00:00, 15.02it/s]


In [None]:
print("Metrics for Enforced UNK Encoder/Decoder v2")
print(f"Beam Accuracy: {sum(enforced_unk_beam_matches) / len(enforced_unk_beam_matches)}")
print(f"Beam BLEU: {sum(enforced_unk_beam_bleus) / len(enforced_unk_beam_bleus)}\n")
print(f"Greedy Accuracy: {sum(enforced_unk_greedy_matches) / len(enforced_unk_greedy_matches)}")
print(f"Greedy BLEU: {sum(enforced_unk_greedy_bleus) / len(enforced_unk_greedy_bleus)}")

Metrics for Enforced UNK Encoder/Decoder v2
Beam Accuracy: 0.0
Beam BLEU: 28.52484777391203

Greedy Accuracy: 0.0
Greedy BLEU: 27.989055259631616


## Yet Another Approach
Here we will try yet another approach, replacing the least frequent tokens with UNK instead of choosing tokens at random.

In [None]:
unk_ratio = 0.25
token_counts = {}
all_line_tokens = input_line_tokens + target_line_tokens

# Get all the token counts between input and target
for line in all_line_tokens:
    for token in line:
        token_counts[token] = token_counts.get(token, 0) + 1

# Drop the least used tokens
tokens_sorted_by_count = [token for (token, count) in sorted(token_counts.items(), key=lambda i: i[1])]
dropped_tokens = set(tokens_sorted_by_count[:int(unk_ratio * len(tokens_sorted_by_count))])

# Drop tokens from input and target mappings
dropped_input_token2id_v3 = {
    token: (id if token not in dropped_tokens else input_token2id["UNK"])
    for (token, id) in input_token2id.items()
}

dropped_target_token2id_v3 = {
    token: (id if token not in dropped_tokens else target_token2id["UNK"])
    for (token, id) in target_token2id.items()
}

In [None]:
# Split validation and training data
validation_ratio = 0.1
train_size = int((1 - validation_ratio) * len(input_line_tokens))
train_input_line_tokens = input_line_tokens[: train_size]
train_target_line_tokens = target_line_tokens[: train_size]
validation_input_line_tokens = input_line_tokens[train_size:]
validation_target_line_tokens = target_line_tokens[train_size:]

# Create datasets for validation and training
enforced_unk_train_dataset_v3 = SimpleHearthstoneDataset(train_input_line_tokens, train_target_line_tokens, dropped_input_token2id_v3, dropped_target_token2id_v3, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)
enforced_unk_validation_dataset_v3 = SimpleHearthstoneDataset(validation_input_line_tokens, validation_target_line_tokens, dropped_input_token2id_v3, dropped_target_token2id_v3, "train", input_max_seq_len=input_max_seq_len, target_max_seq_len=target_max_seq_len)

In [None]:
# Set params
input_embedding_size = 256
input_hidden_size = 256
target_embedding_size = 256
target_hidden_size = 256
batch_size = 8

# Create data objects
enforced_unk_train_dataloader_v3 = data.DataLoader(enforced_unk_train_dataset_v3, batch_size=batch_size, shuffle=True)
enforced_unk_validation_dataloader_v3 = data.DataLoader(enforced_unk_validation_dataset_v3, batch_size=batch_size, shuffle=True)
input_vocab_size = len(input_token2id)
target_vocab_size = len(target_token2id)

# Create models
enforced_unk_encoder_v3 = SimpleHearthstoneEncoder(input_vocab_size, input_embedding_size, input_hidden_size).to(device)
enforced_unk_decoder_v3 = SimpleHearthstoneDecoder(target_vocab_size, target_embedding_size, target_hidden_size, 2 * input_hidden_size, input_max_seq_len).to(device)
enforced_unk_generator_v3 = Generator(target_hidden_size, target_vocab_size)
enforced_unk_encoder_decoder_v3 = SimpleHearthstoneEncoderDecoder(enforced_unk_encoder_v3, enforced_unk_decoder_v3, enforced_unk_generator_v3.to(device))

In [None]:
# Train model
epochs = 20
lr = 1e-3

train(enforced_unk_encoder_decoder_v3, enforced_unk_train_dataloader_v3, enforced_unk_validation_dataloader_v3, epochs, lr)

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.83it/s]

Total loss: 92.90155431578977


100%|██████████| 7/7 [00:01<00:00,  5.81it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 70.86083590341445
Validation perplexity: 70.860836
Epoch 1


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.63it/s]

Total loss: 47.77642286643535


100%|██████████| 7/7 [00:01<00:00,  5.66it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 54.9881297865004
Validation perplexity: 54.988130
Epoch 2


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.91it/s]

Total loss: 36.58463477760536


100%|██████████| 7/7 [00:01<00:00,  5.74it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 46.88114481707129
Validation perplexity: 46.881145
Epoch 3


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.93it/s]

Total loss: 30.313693006299705


100%|██████████| 7/7 [00:01<00:00,  5.70it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 39.78882473213219
Validation perplexity: 39.788825
Epoch 4


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.94it/s]

Total loss: 23.486893803743


100%|██████████| 7/7 [00:01<00:00,  5.79it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 30.59235461348468
Validation perplexity: 30.592355
Epoch 5


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.94it/s]

Total loss: 17.381082393832063


100%|██████████| 7/7 [00:01<00:00,  5.74it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 22.398059777594625
Validation perplexity: 22.398060
Epoch 6


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.60it/s]

Total loss: 12.217228434710618


100%|██████████| 7/7 [00:01<00:00,  5.62it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 15.917244522912503
Validation perplexity: 15.917245
Epoch 7


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.91it/s]

Total loss: 8.697692718890112


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 12.321492994599591
Validation perplexity: 12.321493
Epoch 8


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:00,  6.09it/s]

Total loss: 6.640472151253793


100%|██████████| 7/7 [00:01<00:00,  5.93it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 9.42590698210788
Validation perplexity: 9.425907
Epoch 9


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.97it/s]

Total loss: 5.311904868396282


100%|██████████| 7/7 [00:01<00:00,  5.75it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 8.10437417005206
Validation perplexity: 8.104374
Epoch 10


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.75it/s]

Total loss: 4.495995793279119


100%|██████████| 7/7 [00:01<00:00,  5.65it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 7.056398593629289
Validation perplexity: 7.056399
Epoch 11


100%|██████████| 60/60 [00:31<00:00,  1.88it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.90it/s]

Total loss: 4.020671981885799


100%|██████████| 7/7 [00:01<00:00,  5.73it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 6.462063131907395
Validation perplexity: 6.462063
Epoch 12


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.74it/s]

Total loss: 3.6450487368526563


100%|██████████| 7/7 [00:01<00:00,  5.84it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.997051400354593
Validation perplexity: 5.997051
Epoch 13


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.84it/s]

Total loss: 3.3660196887087332


100%|██████████| 7/7 [00:01<00:00,  5.77it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.830583713406175
Validation perplexity: 5.830584
Epoch 14


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.17it/s]

Total loss: 3.134064816442468


100%|██████████| 7/7 [00:01<00:00,  5.54it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.398652191580625
Validation perplexity: 5.398652
Epoch 15


100%|██████████| 60/60 [00:31<00:00,  1.91it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.92it/s]

Total loss: 2.9368403540203487


100%|██████████| 7/7 [00:01<00:00,  5.72it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 5.109109934093419
Validation perplexity: 5.109110
Epoch 16


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.43it/s]

Total loss: 2.7602158116509057


100%|██████████| 7/7 [00:01<00:00,  5.64it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.879119976635583
Validation perplexity: 4.879120
Epoch 17


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.57it/s]

Total loss: 2.600889293316594


100%|██████████| 7/7 [00:01<00:00,  5.65it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.640098611702336
Validation perplexity: 4.640099
Epoch 18


100%|██████████| 60/60 [00:31<00:00,  1.90it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.63it/s]

Total loss: 2.471539292243573


100%|██████████| 7/7 [00:01<00:00,  5.78it/s]
  0%|          | 0/60 [00:00<?, ?it/s]

Total loss: 4.52993085945324
Validation perplexity: 4.529931
Epoch 19


100%|██████████| 60/60 [00:31<00:00,  1.89it/s]
 14%|█▍        | 1/7 [00:00<00:01,  5.76it/s]

Total loss: 2.3470666426455615


100%|██████████| 7/7 [00:01<00:00,  5.68it/s]

Total loss: 4.324327860800476
Validation perplexity: 4.324328





[70.86083590341445,
 54.9881297865004,
 46.88114481707129,
 39.78882473213219,
 30.59235461348468,
 22.398059777594625,
 15.917244522912503,
 12.321492994599591,
 9.42590698210788,
 8.10437417005206,
 7.056398593629289,
 6.462063131907395,
 5.997051400354593,
 5.830583713406175,
 5.398652191580625,
 5.109109934093419,
 4.879119976635583,
 4.640098611702336,
 4.52993085945324,
 4.324327860800476]

In [None]:
# Set to True if you want to save this model
save_model = False
if save_model:
    torch.save(enforced_unk_encoder_decoder_v3.state_dict(), os.path.join(project_dir, "enforced_unk_encoder_decoder_v3.pt"))

# Set to True if you want to load a previously trained model
load_model = False
if load_model:
    enforced_unk_encoder_decoder_v3.load_state_dict(torch.load(os.path.join(project_dir, "enforced_unk_encoder_decoder_v3.pt"), map_location=device))
    enforced_unk_encoder_decoder_v3.eval()

In [None]:
# spot_check_greedy(enforced_unk_encoder_decoder_v3, simple_test_dataset)
spot_check_beam(enforced_unk_encoder_decoder_v3, simple_test_dataset)



Input: UNKUNKUNKNAME_END4ATK_END7DEF_END6COST_END-1DUR_ENDMinionTYPE_ENDNeutralPLAYER_CLS_ENDNILRACE_ENDCommonRARITY_END<b>SpellDamage+1</b>
Expected:

	classUNKUNKUNK(MinionCard):§def__init__(self):§super().__init__("UNKUNKUNK",6,CHARACTER_CLASS.ALL,CARD_RARITY.COMMON)§§defcreate_minion(self,player):§returnMinion(4,7,spell_damage=1)§

-got-

	classUNK(MinionCard):§def__init__(self):§super().__init__("Druidofthe",3,CHARACTER_CLASS.ALL,CARD_RARITY.COMMON,minion_type=MINION_TYPE.MECH)§§defcreate_minion(self,player):§returnMinion(2,3,effects=[Effect(TurnEnded(),ActionTag(Give(ChangeAttack(1)),SelfSelector()))])§


In [None]:
# Beam Search
enforced_unk_v3_beam_matches = evaluate_accuracy(enforced_unk_encoder_decoder_v3, simple_test_dataset, beam_search_decoding, PAD_ID)
enforced_unk_v3_beam_bleus = evaluate_bleu(enforced_unk_encoder_decoder_v3, simple_test_dataset, beam_search_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [03:06<00:00,  2.83s/it]
100%|██████████| 66/66 [03:12<00:00,  2.92s/it]


In [None]:
# Greedy Search
enforced_unk_v3_greedy_matches = evaluate_accuracy(enforced_unk_encoder_decoder_v3, simple_test_dataset, greedy_decoding, PAD_ID)
enforced_unk_v3_greedy_bleus = evaluate_bleu(enforced_unk_encoder_decoder_v3, simple_test_dataset, greedy_decoding, target_id2token, PAD_ID)

100%|██████████| 66/66 [00:04<00:00, 14.64it/s]
100%|██████████| 66/66 [00:04<00:00, 14.22it/s]


In [None]:
print("Metrics for Enforced UNK Encoder/Decoder v3")
print(f"Beam Accuracy: {sum(enforced_unk_v3_beam_matches) / len(enforced_unk_v3_beam_matches)}")
print(f"Beam BLEU: {sum(enforced_unk_v3_beam_bleus) / len(enforced_unk_v3_beam_bleus)}\n")
print(f"Greedy Accuracy: {sum(enforced_unk_v3_greedy_matches) / len(enforced_unk_v3_greedy_matches)}")
print(f"Greedy BLEU: {sum(enforced_unk_v3_greedy_bleus) / len(enforced_unk_v3_greedy_bleus)}")

Metrics for Enforced UNK Encoder/Decoder v3
Beam Accuracy: 0.0
Beam BLEU: 43.46861853141471

Greedy Accuracy: 0.0
Greedy BLEU: 43.06489261164009


# C2W Representations
We experiment with an alternative approach to word representations, using [C2W](https://arxiv.org/pdf/1508.02096.pdf)

In [15]:
chars = "!\"#$%&'’()*+,-./:;<=>?@[\\]^_`{|}~§ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789\n "
char2id = { c: id for (id, c) in enumerate(chars) }
char2id["PAD"] = len(char2id)
# char2id["SOS"] = len(char2id)
# char2id["EOS"] = len(char2id)
id2char = { id: c for (c, id) in char2id.items() }

In [16]:
def read_file(path):
    """Reads file and returns a list of each line"""
    with open(path) as f:
        return f.readlines()

def get_card_specs_from_lines(lines):
    return [card_spec_from_line(line) for line in lines]

def tokenize_by_space(lines):
    tokenized = []
    for line in lines:
        tokenized.append(line.split())
    return tokenized

def tokenize_by_char(token_lists, char2id, max_token_len=None, max_line_len=None):
    """
    :param token_lists: list of lists where each line contains all the tokens in one line
    :param char2id: mapping from character to ID to encode with
    :return: 3 tensors
            - (num_lines, max_line_length, max_token_length) containing all the encoded words and lines
            - (num_lines, max_line_length) containing the lengths of each token
            - (num_lines,) containing all the lengths of each line
    """
    # Get max token and line lengths
    if max_token_len is None:
        max_token_len = max([len(token) for tokens in token_lists for token in tokens])
    if max_line_len is None:
        max_line_len = max([len(tokens) for tokens in token_lists]) + 2
    # Create padding token
    line_pad = [char2id["PAD"]] * max_token_len
#     sos = [char2id["SOS"]] * max_token_len
#     eos = [char2id["EOS"]] * max_token_len
    total_token_lens = []
    total_line_lens = []
    total_encoded = []
    for tokens in token_lists:
        token_lens = []
        encoded = []
        for token in tokens:
            token_lens.append(len(token))
            # Pad each encoded token to reach max size
            encoded.append([char2id[c] for c in token] + [char2id["PAD"]] * (max_token_len - len(token)))
        # Pad token lengths entry to reach max line length
        total_token_lens.append([1] + token_lens + [1] + [1] * (max_line_len - (len(tokens) + 2)))
        total_line_lens.append(len(tokens) + 2) # 2 for the start and end paddings
        # Pad encoded line to reach max line length
#         total_encoded.append([sos] + encoded + [eos] + [line_pad] * (max_line_len - (len(tokens) + 2)))
        total_encoded.append([line_pad] + encoded + [line_pad] + [line_pad] * (max_line_len - (len(tokens) + 2)))
    return torch.LongTensor(total_encoded).to(device), torch.IntTensor(total_token_lens).to(device), torch.IntTensor(total_line_lens).to(device)

In [17]:
class C2W(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, num_chars):
        super(C2W, self).__init__()

        self.embedding = nn.Embedding(num_chars, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, 1, 
                            batch_first=True, bidirectional=True)
        self.output = nn.Linear(2 * hidden_size, output_size)
        self.output_size = output_size
        
    def forward(self, seq, token_lens):
        """
        :param seq: tensor (line_len, max_token_len) containing sequence
        :param token_lens: tensor (line_len,) containing token lengths of each word
        :returns: tensor (1, line_len, output_size)
        """
        embedded = self.embedding(seq) # (line_len, max_token_len, embedding_size)
        packed_embedded = pack_padded_sequence(embedded, token_lens, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.lstm(packed_embedded) # (2 * num_layers, line_len, hidden_size)
        forward_hidden = hidden[::2]
        backward_hidden = hidden[1::2]
        combined = torch.cat([forward_hidden, backward_hidden], dim=2) # (num_layers, line_len, 2 * hidden_size)
        
        return self.output(combined[-1:]) # (1, line_len, output_size)

## C2W Encoder
We will experiment using an ecoder that uses C2W to encode words. This will differ in that the input will be tokenized by word instead of using BPE and the decoder will generate characters.

In [18]:
class HearthstoneDatasetC2W(data.Dataset):
    def __init__(self, raw_input_lines, raw_target_lines, char2id, input_max_token_len, input_max_seq_len, target_max_token_len, target_max_seq_len):
        assert len(raw_input_lines) == len(raw_target_lines)
        self.raw_input_lines = raw_input_lines
        self.raw_target_lines = raw_target_lines

        self.input_words, self.input_token_lengths, self.input_line_lengths = tokenize_by_char(tokenize_by_space(raw_input_lines), char2id, max_token_len=input_max_token_len, max_line_len=input_max_seq_len)
        self.target_words, self.target_token_lengths, self.target_line_lengths = tokenize_by_char([list(line) for line in raw_target_lines], char2id, max_token_len=target_max_token_len, max_line_len=target_max_seq_len)

    def __len__(self):
        return len(self.input_words)

    def __getitem__(self, idx):
        input_line_length = self.input_line_lengths[idx]
        inputs = self.input_words[idx]
        input_lengths = self.input_token_lengths[idx]

        target_line_length = self.target_line_lengths[idx]
        targets = self.target_words[idx]
        target_lengths = self.target_token_lengths[idx]

        return inputs.to(device), targets.to(device).squeeze(1), input_lengths.to("cpu"), target_lengths.to("cpu")

In [19]:
class C2WHearthstoneEncoder(nn.Module):
    """Encoder for hearthstone tokens using C2W"""
    def __init__(self, c2w, hidden_size, num_layers=3, dropout=0.1):
        super(C2WHearthstoneEncoder, self).__init__()
        self.c2w = c2w
        self.rnn = nn.GRU(
            input_size=c2w.output_size,
            hidden_size=hidden_size,
            num_layers=3,
            dropout=dropout,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, inputs, token_lens):
        """
        :param inputs: 3d tensor of shape (1, line_len, max_token_len) with the word embeddings for a single line (batch size is 1)
        :param token_lens: tensor (1, line_len) containing token lengths of each word

        :return: (outputs, hidden) where outputs is 3d tensor of shape (1, max_seq_length, hidden_size)
                and hidden is 3d tensor of shape (num_layers, 1, 2*hidden_size)
        """
        inputs = inputs.squeeze(0)
        token_lens = token_lens.squeeze(0).to("cpu")
        embedded_inputs = self.c2w(inputs, token_lens)

        outputs, hidden = self.rnn(embedded_inputs)

        forward_hidden = hidden[::2]
        backward_hidden = hidden[1::2]
        hidden = torch.cat([forward_hidden, backward_hidden], dim=2)

        return outputs, hidden

In [None]:
raw_input_lines = read_file(train_input_path)
raw_target_lines = read_file(train_target_path)
input_max_token_len = max([len(word) for line in raw_input_lines for word in line.split()])
input_max_seq_len = max([len(line.split()) for line in raw_input_lines]) + 2
target_max_seq_len = max([len(list(line)) for line in raw_target_lines]) + 2

validation_ratio = 0.1
train_size = int((1 - validation_ratio) * len(raw_input_lines))
val_size = int(validation_ratio * len(raw_input_lines))
c2w_training_dataset = HearthstoneDatasetC2W(raw_input_lines[:train_size], raw_target_lines[:train_size], char2id, input_max_token_len, input_max_seq_len, 1, target_max_seq_len)
c2w_validation_dataset = HearthstoneDatasetC2W(raw_input_lines[train_size:], raw_target_lines[train_size:], char2id, input_max_token_len, input_max_seq_len, 1, target_max_seq_len)

In [None]:
c2w_char_embed_size = 100
c2w_hidden_size = 300
c2w_output_size = 300
input_hidden_size = 300
num_chars = len(char2id)

target_embedding_size = 100
target_hidden_size = 300

# Create data objects
c2w_training_loader = data.DataLoader(c2w_training_dataset, batch_size=1, shuffle=True)
c2w_validation_loader = data.DataLoader(c2w_validation_dataset, batch_size=1, shuffle=True)

# Create models
c2w = C2W(c2w_char_embed_size, c2w_hidden_size, c2w_output_size, num_chars).to(device)
c2w_encoder = C2WHearthstoneEncoder(c2w, input_hidden_size).to(device)
c2w_decoder = SimpleHearthstoneDecoder(num_chars, target_embedding_size, target_hidden_size, 2 * input_hidden_size, input_max_seq_len).to(device)
c2w_generator = Generator(target_hidden_size, num_chars)
c2w_encoder_decoder = SimpleHearthstoneEncoderDecoder(c2w_encoder, c2w_decoder, c2w_generator).to(device)

In [None]:
def run_epoch_c2w(data_loader, model, loss_compute):
    """Standard Training and Logging Function"""
    total_tokens = 0
    total_loss = 0

    for i, (src_ids_BxT, trg_ids_BxL, src_lengths_B, trg_lengths_B) in enumerate(tqdm(data_loader, position=0, leave=True)):
        # We define some notations here to help you understand the loaded tensor
        # shapes:
        #     `B`: batch size
        #     `T`: max sequence length of source sentences
        #     `L`: max sequence length of target sentences; due to our preprocessing
        #        in the beginning, `L` == `T` == 50
        # An example of `src_ids_BxT` (when B = 2):
        #     [[2, 4, 6, 7, ..., 4, 3, 0, 0, 0],
        #    [2, 8, 6, 5, ..., 9, 5, 4, 3, 0]]
        # The corresponding `src_lengths_B` would be [47, 49].

        src_ids_BxT = src_ids_BxT.to(device)
        src_lengths_B = src_lengths_B.to(device)
        trg_ids_BxL = trg_ids_BxL.to(device)

        del trg_lengths_B     # unused

        output, _ = model(src_ids_BxT, trg_ids_BxL, src_lengths_B)

        loss = loss_compute(x=output, y=trg_ids_BxL[:, 1:],
                            norm=src_ids_BxT.size(0))
        
        if i % 20 == 0:
            print(f"Iteration {i} Loss: {loss}")
        total_loss += loss
        total_tokens += (trg_ids_BxL[:, 1:] != char2id["PAD"]).data.sum().item() + 1

    print(f"Total loss: {math.exp(total_loss / float(total_tokens))}")

    return math.exp(total_loss / float(total_tokens))

def train_c2w(model, train_data_loader, val_data_loader, num_epochs, learning_rate):
    # Set `ignore_index` as PAD_INDEX so that pad tokens won't be included when
    # computing the loss.
    criterion = nn.NLLLoss(reduction="sum", ignore_index=char2id["PAD"])
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Keep track of dev ppl for each epoch.
    dev_ppls = []

    for epoch in range(num_epochs):
        print("Epoch", epoch)

        model.train()
        train_ppl = run_epoch_c2w(data_loader=train_data_loader, model=model,
                                loss_compute=SimpleLossCompute(model.generator,
                                                             criterion, optim))

        model.eval()
        with torch.no_grad():        
            dev_ppl = run_epoch_c2w(data_loader=val_data_loader, model=model,
                                loss_compute=SimpleLossCompute(model.generator,
                                                             criterion, None))
            print("Validation perplexity: %f" % dev_ppl)
            dev_ppls.append(dev_ppl)
        
    return dev_ppls

In [None]:
# Train model
epochs = 1
lr = 1e-3

train_c2w(c2w_encoder_decoder, c2w_training_loader, c2w_validation_loader, epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  attention_weights = F.softmax(attention_raw)
  0%|          | 1/479 [00:03<27:20,  3.43s/it]

Iteration 0 Loss: 1004.1509399414062


  4%|▍         | 21/479 [01:10<26:02,  3.41s/it]

Iteration 20 Loss: 1014.3773193359375


  9%|▊         | 41/479 [02:18<24:19,  3.33s/it]

Iteration 40 Loss: 746.728759765625


 13%|█▎        | 61/479 [03:25<23:36,  3.39s/it]

Iteration 60 Loss: 861.34228515625


 17%|█▋        | 81/479 [04:32<22:24,  3.38s/it]

Iteration 80 Loss: 826.9150390625


 21%|██        | 101/479 [05:40<21:04,  3.34s/it]

Iteration 100 Loss: 478.9886779785156


 25%|██▌       | 121/479 [06:40<15:25,  2.58s/it]

Iteration 120 Loss: 489.9083557128906


 29%|██▉       | 141/479 [07:38<15:41,  2.78s/it]

Iteration 140 Loss: 1254.7677001953125


 34%|███▎      | 161/479 [08:34<15:48,  2.98s/it]

Iteration 160 Loss: 521.7062377929688


 38%|███▊      | 181/479 [09:27<12:07,  2.44s/it]

Iteration 180 Loss: 245.38674926757812


 42%|████▏     | 201/479 [10:21<13:20,  2.88s/it]

Iteration 200 Loss: 577.2071533203125


 46%|████▌     | 221/479 [11:21<13:21,  3.11s/it]

Iteration 220 Loss: 725.4498291015625


 50%|█████     | 241/479 [12:24<12:06,  3.05s/it]

Iteration 240 Loss: 382.6413879394531


 54%|█████▍    | 261/479 [13:26<11:34,  3.18s/it]

Iteration 260 Loss: 205.62701416015625


 59%|█████▊    | 281/479 [14:29<10:21,  3.14s/it]

Iteration 280 Loss: 525.160400390625


 63%|██████▎   | 301/479 [15:33<09:15,  3.12s/it]

Iteration 300 Loss: 701.9708251953125


 67%|██████▋   | 321/479 [16:36<08:04,  3.07s/it]

Iteration 320 Loss: 1247.2027587890625


 71%|███████   | 341/479 [17:32<05:55,  2.57s/it]

Iteration 340 Loss: 182.8584747314453


 75%|███████▌  | 361/479 [18:21<05:33,  2.83s/it]

Iteration 360 Loss: 201.04940795898438


 80%|███████▉  | 381/479 [19:20<04:46,  2.93s/it]

Iteration 380 Loss: 137.69503784179688


 84%|████████▎ | 401/479 [20:20<03:41,  2.84s/it]

Iteration 400 Loss: 873.7568359375


 88%|████████▊ | 421/479 [21:06<02:18,  2.38s/it]

Iteration 420 Loss: 292.6247863769531


 92%|█████████▏| 441/479 [21:57<01:46,  2.81s/it]

Iteration 440 Loss: 134.40345764160156


 96%|█████████▌| 461/479 [22:48<00:42,  2.38s/it]

Iteration 460 Loss: 109.41586303710938


100%|██████████| 479/479 [23:34<00:00,  2.95s/it]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 4.369875180479339


  2%|▏         | 1/54 [00:00<00:31,  1.68it/s]

Iteration 0 Loss: 171.69740295410156


 39%|███▉      | 21/54 [00:13<00:22,  1.48it/s]

Iteration 20 Loss: 585.6018676757812


 76%|███████▌  | 41/54 [00:25<00:07,  1.70it/s]

Iteration 40 Loss: 210.47496032714844


100%|██████████| 54/54 [00:32<00:00,  1.65it/s]

Total loss: 2.0583144041299906
Validation perplexity: 2.058314





[2.0583144041299906]

In [None]:
# Train for another epoch
train_c2w(c2w_encoder_decoder, c2w_training_loader, c2w_validation_loader, epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  attention_weights = F.softmax(attention_raw)
  0%|          | 1/479 [00:03<25:07,  3.15s/it]

Iteration 0 Loss: 104.9936752319336


  4%|▍         | 21/479 [01:05<23:38,  3.10s/it]

Iteration 20 Loss: 234.01702880859375


  9%|▊         | 41/479 [02:07<23:16,  3.19s/it]

Iteration 40 Loss: 143.1583709716797


 13%|█▎        | 61/479 [03:10<22:10,  3.18s/it]

Iteration 60 Loss: 279.5332336425781


 17%|█▋        | 81/479 [04:12<20:19,  3.07s/it]

Iteration 80 Loss: 139.3435821533203


 21%|██        | 101/479 [05:15<19:42,  3.13s/it]

Iteration 100 Loss: 232.7386474609375


 25%|██▌       | 121/479 [06:18<18:48,  3.15s/it]

Iteration 120 Loss: 255.1192626953125


 29%|██▉       | 141/479 [07:18<15:51,  2.82s/it]

Iteration 140 Loss: 114.7650375366211


 34%|███▎      | 161/479 [08:16<15:04,  2.84s/it]

Iteration 160 Loss: 89.224609375


 38%|███▊      | 181/479 [09:15<15:27,  3.11s/it]

Iteration 180 Loss: 183.22389221191406


 42%|████▏     | 201/479 [10:18<14:30,  3.13s/it]

Iteration 200 Loss: 173.48146057128906


 46%|████▌     | 221/479 [11:15<12:23,  2.88s/it]

Iteration 220 Loss: 94.85865020751953


 50%|█████     | 241/479 [12:17<12:18,  3.10s/it]

Iteration 240 Loss: 170.0771942138672


 54%|█████▍    | 261/479 [13:14<09:24,  2.59s/it]

Iteration 260 Loss: 142.26699829101562


 59%|█████▊    | 281/479 [14:11<10:12,  3.09s/it]

Iteration 280 Loss: 57.66069793701172


 63%|██████▎   | 301/479 [15:13<09:30,  3.21s/it]

Iteration 300 Loss: 230.29859924316406


 67%|██████▋   | 321/479 [16:11<08:13,  3.12s/it]

Iteration 320 Loss: 190.509521484375


 71%|███████   | 341/479 [17:14<07:17,  3.17s/it]

Iteration 340 Loss: 72.6450424194336


 75%|███████▌  | 361/479 [18:19<06:30,  3.31s/it]

Iteration 360 Loss: 184.3108367919922


 80%|███████▉  | 381/479 [19:22<05:06,  3.13s/it]

Iteration 380 Loss: 163.19735717773438


 84%|████████▎ | 401/479 [20:24<04:03,  3.12s/it]

Iteration 400 Loss: 151.927734375


 88%|████████▊ | 421/479 [21:27<03:02,  3.14s/it]

Iteration 420 Loss: 206.09600830078125


 92%|█████████▏| 441/479 [22:29<02:02,  3.22s/it]

Iteration 440 Loss: 159.31797790527344


 96%|█████████▌| 461/479 [23:33<00:58,  3.24s/it]

Iteration 460 Loss: 68.4783935546875


100%|██████████| 479/479 [24:31<00:00,  3.07s/it]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.7333565065245211


  2%|▏         | 1/54 [00:01<00:53,  1.01s/it]

Iteration 0 Loss: 400.7561950683594


 39%|███▉      | 21/54 [00:18<00:28,  1.18it/s]

Iteration 20 Loss: 176.44232177734375


 76%|███████▌  | 41/54 [00:30<00:09,  1.44it/s]

Iteration 40 Loss: 416.2169189453125


100%|██████████| 54/54 [00:39<00:00,  1.35it/s]

Total loss: 1.7130133513195525
Validation perplexity: 1.713013





[1.7130133513195525]

In [None]:
# Train for another epoch
train_c2w(c2w_encoder_decoder, c2w_training_loader, c2w_validation_loader, epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  attention_weights = F.softmax(attention_raw)
  0%|          | 1/479 [00:03<25:49,  3.24s/it]

Iteration 0 Loss: 118.39209747314453


  4%|▍         | 21/479 [00:59<23:37,  3.09s/it]

Iteration 20 Loss: 120.62686157226562


  9%|▊         | 41/479 [02:04<23:20,  3.20s/it]

Iteration 40 Loss: 127.58795166015625


 13%|█▎        | 61/479 [03:09<22:13,  3.19s/it]

Iteration 60 Loss: 90.86248779296875


 17%|█▋        | 81/479 [04:12<20:50,  3.14s/it]

Iteration 80 Loss: 197.1774444580078


 21%|██        | 101/479 [05:16<21:00,  3.33s/it]

Iteration 100 Loss: 118.23661041259766


 25%|██▌       | 121/479 [06:19<19:29,  3.27s/it]

Iteration 120 Loss: 173.5605010986328


 29%|██▉       | 141/479 [07:25<18:35,  3.30s/it]

Iteration 140 Loss: 89.86589813232422


 34%|███▎      | 161/479 [08:32<17:19,  3.27s/it]

Iteration 160 Loss: 126.35882568359375


 38%|███▊      | 181/479 [09:34<15:37,  3.15s/it]

Iteration 180 Loss: 1107.270263671875


 42%|████▏     | 201/479 [10:37<14:29,  3.13s/it]

Iteration 200 Loss: 83.2724609375


 46%|████▌     | 221/479 [11:40<13:35,  3.16s/it]

Iteration 220 Loss: 151.109130859375


 50%|█████     | 241/479 [12:43<12:35,  3.18s/it]

Iteration 240 Loss: 224.6641845703125


 54%|█████▍    | 261/479 [13:36<07:16,  2.00s/it]

Iteration 260 Loss: 147.69778442382812


 59%|█████▊    | 281/479 [14:37<10:25,  3.16s/it]

Iteration 280 Loss: 69.943115234375


 63%|██████▎   | 301/479 [15:37<08:00,  2.70s/it]

Iteration 300 Loss: 92.1741943359375


 67%|██████▋   | 321/479 [16:36<08:27,  3.21s/it]

Iteration 320 Loss: 113.05504608154297


 71%|███████   | 341/479 [17:39<07:15,  3.16s/it]

Iteration 340 Loss: 108.65818786621094


 75%|███████▌  | 361/479 [18:37<05:06,  2.60s/it]

Iteration 360 Loss: 134.15457153320312


 80%|███████▉  | 381/479 [19:27<04:30,  2.76s/it]

Iteration 380 Loss: 120.38597106933594


 84%|████████▎ | 401/479 [20:12<03:11,  2.46s/it]

Iteration 400 Loss: 143.81491088867188


 88%|████████▊ | 421/479 [21:04<02:16,  2.35s/it]

Iteration 420 Loss: 117.31340789794922


 92%|█████████▏| 441/479 [21:53<01:35,  2.52s/it]

Iteration 440 Loss: 95.00065612792969


 96%|█████████▌| 461/479 [22:44<00:53,  2.97s/it]

Iteration 460 Loss: 137.13197326660156


100%|██████████| 479/479 [23:24<00:00,  2.93s/it]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.5379131258143661


  2%|▏         | 1/54 [00:00<00:29,  1.82it/s]

Iteration 0 Loss: 147.22579956054688


 39%|███▉      | 21/54 [00:11<00:18,  1.80it/s]

Iteration 20 Loss: 212.70162963867188


 76%|███████▌  | 41/54 [00:22<00:07,  1.78it/s]

Iteration 40 Loss: 51.27867126464844


100%|██████████| 54/54 [00:30<00:00,  1.79it/s]

Total loss: 1.613265676804163
Validation perplexity: 1.613266





[1.613265676804163]

In [None]:
# Train for another epoch
train_c2w(c2w_encoder_decoder, c2w_training_loader, c2w_validation_loader, epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  attention_weights = F.softmax(attention_raw)
  0%|          | 1/479 [00:02<23:52,  3.00s/it]

Iteration 0 Loss: 176.2230224609375


  4%|▍         | 21/479 [00:46<14:07,  1.85s/it]

Iteration 20 Loss: 407.892333984375


  9%|▊         | 41/479 [01:22<12:52,  1.76s/it]

Iteration 40 Loss: 131.88021850585938


 13%|█▎        | 61/479 [02:18<21:26,  3.08s/it]

Iteration 60 Loss: 157.58567810058594


 17%|█▋        | 81/479 [03:06<12:19,  1.86s/it]

Iteration 80 Loss: 337.029541015625


 21%|██        | 101/479 [03:54<18:41,  2.97s/it]

Iteration 100 Loss: 142.92381286621094


 25%|██▌       | 121/479 [04:55<18:06,  3.04s/it]

Iteration 120 Loss: 75.4714584350586


 29%|██▉       | 141/479 [05:56<17:24,  3.09s/it]

Iteration 140 Loss: 354.0461120605469


 34%|███▎      | 161/479 [06:57<16:04,  3.03s/it]

Iteration 160 Loss: 109.08141326904297


 38%|███▊      | 181/479 [07:58<15:23,  3.10s/it]

Iteration 180 Loss: 131.9510955810547


 42%|████▏     | 201/479 [08:59<13:53,  3.00s/it]

Iteration 200 Loss: 170.85128784179688


 46%|████▌     | 221/479 [09:59<13:16,  3.09s/it]

Iteration 220 Loss: 145.7288360595703


 50%|█████     | 241/479 [11:00<11:58,  3.02s/it]

Iteration 240 Loss: 470.9762878417969


 54%|█████▍    | 261/479 [12:02<11:21,  3.12s/it]

Iteration 260 Loss: 99.40811157226562


 59%|█████▊    | 281/479 [13:03<09:55,  3.01s/it]

Iteration 280 Loss: 91.61802673339844


 63%|██████▎   | 301/479 [14:04<09:11,  3.10s/it]

Iteration 300 Loss: 69.38383483886719


 67%|██████▋   | 321/479 [15:02<08:07,  3.08s/it]

Iteration 320 Loss: 122.82331085205078


 71%|███████   | 341/479 [15:59<06:55,  3.01s/it]

Iteration 340 Loss: 111.33098602294922


 75%|███████▌  | 361/479 [17:00<05:56,  3.02s/it]

Iteration 360 Loss: 509.963623046875


 80%|███████▉  | 381/479 [18:02<05:09,  3.16s/it]

Iteration 380 Loss: 278.86407470703125


 84%|████████▎ | 401/479 [19:00<02:59,  2.31s/it]

Iteration 400 Loss: 278.9239196777344


 88%|████████▊ | 421/479 [20:05<03:07,  3.24s/it]

Iteration 420 Loss: 142.5436553955078


 92%|█████████▏| 441/479 [21:07<02:00,  3.17s/it]

Iteration 440 Loss: 198.1768798828125


 96%|█████████▌| 461/479 [22:10<00:55,  3.08s/it]

Iteration 460 Loss: 167.74264526367188


100%|██████████| 479/479 [23:07<00:00,  2.90s/it]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.7783313433274683


  2%|▏         | 1/54 [00:00<00:30,  1.72it/s]

Iteration 0 Loss: 238.40606689453125


 39%|███▉      | 21/54 [00:19<00:35,  1.06s/it]

Iteration 20 Loss: 373.0147399902344


 76%|███████▌  | 41/54 [00:41<00:14,  1.10s/it]

Iteration 40 Loss: 174.46994018554688


100%|██████████| 54/54 [00:51<00:00,  1.04it/s]

Total loss: 1.8271525224257794
Validation perplexity: 1.827153





[1.8271525224257794]

In [None]:
# Train for another epoch
train_c2w(c2w_encoder_decoder, c2w_training_loader, c2w_validation_loader, epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  attention_weights = F.softmax(attention_raw)
  0%|          | 1/479 [00:03<25:01,  3.14s/it]

Iteration 0 Loss: 264.27606201171875


  4%|▍         | 21/479 [01:08<24:12,  3.17s/it]

Iteration 20 Loss: 178.02565002441406


  9%|▊         | 41/479 [02:14<24:44,  3.39s/it]

Iteration 40 Loss: 82.88799285888672


 13%|█▎        | 61/479 [03:19<22:00,  3.16s/it]

Iteration 60 Loss: 163.76466369628906


 17%|█▋        | 81/479 [04:25<21:36,  3.26s/it]

Iteration 80 Loss: 279.50128173828125


 21%|██        | 101/479 [05:29<19:53,  3.16s/it]

Iteration 100 Loss: 268.5705871582031


 25%|██▌       | 121/479 [06:34<19:32,  3.27s/it]

Iteration 120 Loss: 165.120361328125


 29%|██▉       | 141/479 [07:37<17:15,  3.06s/it]

Iteration 140 Loss: 175.19920349121094


 34%|███▎      | 161/479 [08:40<16:44,  3.16s/it]

Iteration 160 Loss: 480.13031005859375


 38%|███▊      | 181/479 [09:42<15:37,  3.15s/it]

Iteration 180 Loss: 129.5928497314453


 42%|████▏     | 201/479 [10:45<14:41,  3.17s/it]

Iteration 200 Loss: 248.69854736328125


 46%|████▌     | 221/479 [11:49<13:48,  3.21s/it]

Iteration 220 Loss: 171.96432495117188


 50%|█████     | 241/479 [12:51<12:29,  3.15s/it]

Iteration 240 Loss: 125.17884063720703


 54%|█████▍    | 261/479 [13:53<11:37,  3.20s/it]

Iteration 260 Loss: 171.51480102539062


 59%|█████▊    | 281/479 [14:56<10:16,  3.12s/it]

Iteration 280 Loss: 82.65476989746094


 63%|██████▎   | 301/479 [15:59<09:29,  3.20s/it]

Iteration 300 Loss: 99.76980590820312


 67%|██████▋   | 321/479 [17:05<09:03,  3.44s/it]

Iteration 320 Loss: 162.74586486816406


 71%|███████   | 341/479 [18:07<07:28,  3.25s/it]

Iteration 340 Loss: 167.47909545898438


 75%|███████▌  | 361/479 [19:12<06:17,  3.20s/it]

Iteration 360 Loss: 190.69244384765625


 80%|███████▉  | 381/479 [20:15<04:48,  2.94s/it]

Iteration 380 Loss: 234.4292449951172


 84%|████████▎ | 401/479 [21:19<04:01,  3.10s/it]

Iteration 400 Loss: 114.69639587402344


 88%|████████▊ | 421/479 [22:23<03:05,  3.19s/it]

Iteration 420 Loss: 233.8186492919922


 92%|█████████▏| 441/479 [23:25<01:57,  3.09s/it]

Iteration 440 Loss: 162.4336395263672


 96%|█████████▌| 461/479 [24:29<00:57,  3.21s/it]

Iteration 460 Loss: 92.80748748779297


100%|██████████| 479/479 [25:26<00:00,  3.19s/it]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.7841438718447378


  2%|▏         | 1/54 [00:00<00:30,  1.72it/s]

Iteration 0 Loss: 174.72935485839844


 39%|███▉      | 21/54 [00:15<00:25,  1.28it/s]

Iteration 20 Loss: 277.00164794921875


 76%|███████▌  | 41/54 [00:29<00:10,  1.29it/s]

Iteration 40 Loss: 280.65606689453125


100%|██████████| 54/54 [00:38<00:00,  1.39it/s]

Total loss: 1.8372679519248165
Validation perplexity: 1.837268





[1.8372679519248165]

In [None]:
# Set to True if you want to save this model
save_model = False
if save_model:
    torch.save(c2w_encoder_decoder.state_dict(), os.path.join(project_dir, "c2w_encoder_decoder_drop_3e.pt"))

# Set to True if you want to load a previously trained model
load_model = True
if load_model:
    c2w_encoder_decoder.load_state_dict(torch.load(os.path.join(project_dir, "c2w_encoder_decoder_3e.pt"), map_location=device))
    c2w_encoder_decoder.eval()

## C2W Decoding

In [None]:
# Read and tokenize test inputs and targets
test_raw_inputs = read_file(test_input_path)
test_raw_targets = read_file(test_target_path)

# Truncate line tokens so it matches training data
trunc_test_inputs = []
for line in test_raw_inputs:
    tokens = line.split()
    if len(tokens) > input_max_seq_len - 2:
        trunc_test_inputs.append(" ".join(tokens[:input_max_seq_len - 2]))
    else:
        trunc_test_inputs.append(line)
        
trunc_test_raw_targets = []
for line in test_raw_targets:
    if len(line) > target_max_seq_len - 2:
        trunc_test_raw_targets.append(line[:target_max_seq_len - 2])
    else:
        trunc_test_raw_targets.append(line)

c2w_test_dataset = HearthstoneDatasetC2W(trunc_test_inputs, trunc_test_raw_targets, char2id, input_max_token_len, input_max_seq_len, 1, target_max_seq_len)

In [None]:
def greedy_decoding_c2w(model, src_ids, src_lengths, max_len):
    """Greedily decode a sentence for EncoderDecoder. Make sure to chop off the 
         EOS token!"""

    with torch.no_grad():
        encoder_outputs, encoder_hidden = model.encode(src_ids, src_lengths)
#         prev_y = torch.ones(1, 1).fill_(char2id["SOS"]).type_as(src_ids)
        prev_y = torch.ones(1, 1).fill_(char2id["PAD"]).type_as(src_ids)
    
    output = []
    hidden = None

    for i in range(max_len):
        with torch.no_grad():
            outputs, hidden = model.decode(prev_y, encoder_outputs, encoder_hidden, hidden)
            prob = model.generator(outputs[:, -1])
        d, next_word = torch.max(prob, dim=1)
        next_word = next_word.data.item()
        output.append(next_word)
        prev_y = torch.ones(1, 1).type_as(src_ids).fill_(next_word)

    output = np.array(output)

    # Cut off everything starting from </s>.
#     first_pad = np.where(output[1:] == char2id["EOS"])[0]
    first_pad = np.where(output[1:] == char2id["PAD"])[0]
    if len(first_pad) > 0:
        output = output[:first_pad[0]]

    return output

In [None]:
def spot_check_greedy_c2w(model, dataset, idx=None, n=1):
    """Compare a (random) generated and target sequence using greedy search"""
    for i in range(n):
        if idx is None:
            idx = np.random.randint(0, len(dataset))
        inp_ids, trg_ids, inp_lens, trg_lens = dataset[idx: idx+1]
        greedy_decoded = greedy_decoding_c2w(model, inp_ids, inp_lens, target_max_seq_len)
        inp_ids = inp_ids[0][1:]
        trg_ids = trg_ids[0][1:]
        stripped_trg_ids = trg_ids[trg_ids != char2id["PAD"]].tolist()[:-1]
        stripped_inp_ids = inp_ids[inp_ids != char2id["PAD"]].tolist()
#         stripped_inp_ids = inp_ids[inp_ids != char2id["PAD"]]
#         stripped_inp_ids = stripped_inp_ids[stripped_inp_ids != char2id["EOS"]].tolist()
        print("===============================")
        print(f"Input: {tokens_to_text(stripped_inp_ids, id2char)}")
        print(f"Expected:\n\n\t{tokens_to_text(stripped_trg_ids, id2char)}\n\n-got-\n\n\t{tokens_to_text(greedy_decoded, id2char)}")
        print("===============================")

In [None]:
spot_check_greedy_c2w(c2w_encoder_decoder, c2w_test_dataset)

  attention_weights = F.softmax(attention_raw)


Input: ConcealNAME_END-1ATK_END-1DEF_END1COST_END-1DUR_ENDSpellTYPE_ENDRoguePLAYER_CLS_ENDNILRACE_ENDCommonRARITY_ENDGiveyourminions<b>Stealth</b>untilyournextturn.
Expected:

	class Conceal(SpellCard):§    def __init__(self):§        super().__init__("Conceal", 1, CHARACTER_CLASS.ROGUE, CARD_RARITY.COMMON)§§    def use(self, player, game):§        super().use(player, game)§        for minion in player.minions:§            if not minion.stealth:§                minion.add_buff(BuffUntil(Stealth(), TurnStarted()))§

-got-

	class Sindind(MinionCard):§    def __init__(self):§        super().__init__("Sring Cring", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.MECH)§§    def create_minion(self, player):§        return Minion(2, 5)§


In [None]:
def beam_search_decoding_c2w(model, src_ids, src_lengths, max_len, k=25):
    """Keep expanding top k most likely sequences"""
    with torch.no_grad():
        encoder_outputs, encoder_hidden = model.encode(src_ids, src_lengths)
    
    # Keep track of top outputs stores as (log prob, output ID seq, hidden)
    top_outputs = [(0, [char2id["SOS"]], None)]

    for i in range(max_len):
        new_top_outputs = []
        for log_prob, output, hidden in top_outputs:
            # Get last token of candidate output sequence and use as input to decoder
            prev_y = torch.ones(1, 1).type_as(src_ids).fill_(output[-1])
            probs = None
            h = None
            with torch.no_grad():
                o, h = model.decode(prev_y, encoder_outputs, encoder_hidden, hidden)
                probs = model.generator(o[:, -1])
            # Get top k log probs and ids
            topk_log_probs, topk_ids = torch.topk(probs,k, dim=1)
            for token_log_prob, token_id in zip(topk_log_probs[0], topk_ids[0]):
                new_top_outputs.append((log_prob + token_log_prob.data.item(), output + [token_id.data.item()], h))
        # Get top k most likely output sequences up to this point
        new_top_outputs = sorted(new_top_outputs, key=lambda d: d[0], reverse=True)
        top_outputs = new_top_outputs[:k]
    
    # Get the most likely output sequence of all top outputs
    output = np.array(max(top_outputs, key=lambda d: d[0])[1])

    # Cut off everything starting from </s>.
    first_pad = np.where(output[1:] == char2id["EOS"])[0]
    if len(first_pad) > 0:
        output = output[:first_pad[0]]

    return output[1:]

In [None]:
def spot_check_beam_c2w(model, dataset, idx=None, n=1):
    """Compare a (random) generated and target sequence using greedy search"""
    for i in range(n):
        if idx is None:
            idx = np.random.randint(0, len(dataset))
        inp_ids, trg_ids, inp_lens, trg_lens = dataset[idx: idx+1]
        greedy_decoded = beam_search_decoding_c2w(model, inp_ids, inp_lens, target_max_seq_len)
        inp_ids = inp_ids[0][1:]
        trg_ids = trg_ids[0][1:]
        stripped_trg_ids = trg_ids[trg_ids != char2id["PAD"]].tolist()[:-1]
        stripped_inp_ids = inp_ids[inp_ids != char2id["PAD"]]
        stripped_inp_ids = stripped_inp_ids[stripped_inp_ids != char2id["EOS"]].tolist()
        print("===============================")
        print(f"Input: {tokens_to_text(stripped_inp_ids, id2char)}")
        print(f"Expected:\n\n\t{tokens_to_text(stripped_trg_ids, id2char)}\n\n-got-\n\n\t{tokens_to_text(greedy_decoded, id2char)}")
        print("===============================")

In [None]:
spot_check_beam_c2w(c2w_encoder_decoder, c2w_test_dataset)

  attention_weights = F.softmax(attention_raw)


Input: ManaWraithNAME_END2ATK_END2DEF_END2COST_END-1DUR_ENDMinionTYPE_ENDNeutralPLAYER_CLS_ENDNILRACE_ENDRareRARITY_ENDALLminionscost(1)more.
Expected:

	class ManaWraith(MinionCard):§    def __init__(self):§        super().__init__("Mana Wraith", 2, CHARACTER_CLASS.ALL, CARD_RARITY.RARE)§§    def create_minion(self, player):§        return Minion(2, 2, auras=[Aura(ManaChange(1), CardSelector(BothPlayer(), IsMinion()))])§


-got-

	class SpellCard(MinionCard):§    def __init__(self):§        super().__init__("Ancient", 1, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.BEAST)§§    def create_minion(self, player):§        return Minion(1, 1, effects=[Effect(TurnEnded(player=BothPlayer()), ActionTag(Give([Buff(ChangeAttack(2)), SelfSelector()))])§


## Evaluate

In [None]:
def evaluate_accuracy_c2w(model, test_dataset, decoder, pad_token):
    """
    :param model: model to evaluate
    :param test_dataset: test dataset to evaluate that yields (input, target_tokens, input_length (or empty value), target_length (or empty value))
            target_tokens should have sequence that starts and ends with SOS and EOS tokens respectively and may be padded with pad_token
    :param decoder: decoder to evaluate with; returns a list of predicted tokens with SOS, EOS, and PAD tokens removed
    :param pad_token: padding token used in target_tokens
    """
    matches = []
    for i in tqdm(range(len(test_dataset)), position=0, leave=True):
        inp, trg_tokens, inp_len, trg_len = test_dataset[i: i+1]
        trunc_trg_tokens = trg_tokens[0][trg_tokens[0] != pad_token].tolist()

        pred_tokens = decoder(model, inp, inp_len, target_max_seq_len)
        
        matches.append(1 if pred_tokens == trg_tokens else 0)
    return matches

In [None]:
def evaluate_bleu_c2w(model, test_dataset, decoder, token2str, pad_token):
    """
    :param model: model to evaluate
    :param test_dataset: test dataset to evaluate that yields (input, target_tokens, input_length (or empty value), target_length (or empty value))
            target_tokens should have sequence that starts and ends with SOS and EOS tokens respectively and may be padded with pad_token
    :param decoder: decoder to evaluate with; returns a list of predicted tokens with SOS, EOS, and PAD tokens removed
    :param token2str: mapping from token to string
    :param pad_token: padding token used in target_tokens
    """
    bleu_scores = []
    for i in tqdm(range(len(test_dataset)), position=0, leave=True):
        inp, trg_tokens, inp_len, trg_len = test_dataset[i: i+1]
        trunc_trg_tokens = trg_tokens[0][trg_tokens[0] != pad_token].tolist()

        pred_tokens = decoder(model, inp, inp_len, target_max_seq_len)
        pred_text = "".join([id2char[t] for t in pred_tokens])
        trg_text = "".join([id2char[t] for t in trunc_trg_tokens])

        bleu_scores.append(sacrebleu.raw_corpus_bleu([pred_text], [[trg_text]], 0.01).score)

    return bleu_scores

In [None]:
# Beam Search
# c2w_beam_matches = evaluate_accuracy_c2w(c2w_encoder_decoder, c2w_test_dataset, beam_search_decoding_c2w, char2id["PAD"])
# c2w_beam_bleus = evaluate_bleu_c2w(c2w_encoder_decoder, c2w_test_dataset, beam_search_decoding_c2w, id2char, char2id["PAD"])

In [None]:
# Greedy Search
c2w_greedy_matches = evaluate_accuracy_c2w(c2w_encoder_decoder, c2w_test_dataset, greedy_decoding_c2w, char2id["PAD"])
c2w_greedy_bleus = evaluate_bleu_c2w(c2w_encoder_decoder, c2w_test_dataset, greedy_decoding_c2w, id2char, char2id["PAD"])

  attention_weights = F.softmax(attention_raw)
100%|██████████| 66/66 [01:03<00:00,  1.05it/s]
100%|██████████| 66/66 [01:03<00:00,  1.04it/s]


In [None]:
print("Metrics for C2W Encoder / Decoder")
# print(f"Beam Accuracy: {sum(c2w_beam_matches) / len(c2w_beam_matches)}")
# print(f"Beam BLEU: {sum(c2w_beam_bleus) / len(c2w_beam_bleus)}\n")
print(f"Greedy Accuracy: {sum(c2w_greedy_matches) / len(c2w_greedy_matches)}")
print(f"Greedy BLEU: {sum(c2w_greedy_bleus) / len(c2w_greedy_bleus)}")

Metrics for C2W Encoder / Decoder
Greedy Accuracy: 0.0
Greedy BLEU: 13.058270175547005


# Separate Fields
We previously treated our input sequence as one long string. We experiment with separating each field in the card spec to see if we can apply different attention weights to each.

In [20]:
class HearthstoneCardSpec(object):
    def __init__(self, name, attack, defense, cost, durability, card_type, player_cls, race, rarity, description):
        self.name = name
        self.attack = attack
        self.defense = defense
        self.cost = cost
        self.durability = durability
        self.card_type = card_type
        self.player_cls = player_cls
        self.race = race
        self.rarity = rarity
        self.description = description

class HearthstoneCardSpecTokenized(object):
    def __init__(self, card_spec, char2id):
        self.tokenized_name, self.name_token_lengths, self.name_line_length = tokenize_by_char([card_spec.name.split()], char2id)
        self.tokenized_name = self.tokenized_name.squeeze(0)[1:-1]
        self.name_token_lengths = self.name_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.name_line_length = self.name_line_length.item() - 2

        self.tokenized_attack, self.attack_token_lengths, self.attack_line_length = tokenize_by_char([card_spec.attack.split()], char2id)
        self.tokenized_attack = self.tokenized_attack.squeeze(0)[1:-1]
        self.attack_token_lengths = self.attack_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.attack_line_length = self.attack_line_length.item() - 2

        self.tokenized_defense, self.defense_token_lengths, self.defense_line_length = tokenize_by_char([card_spec.defense.split()], char2id)
        self.tokenized_defense = self.tokenized_defense.squeeze(0)[1:-1]
        self.defense_token_lengths = self.defense_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.defense_line_length = self.defense_line_length.item() - 2

        self.tokenized_cost, self.cost_token_lengths, self.cost_line_length = tokenize_by_char([card_spec.cost.split()], char2id)
        self.tokenized_cost = self.tokenized_cost.squeeze(0)[1:-1]
        self.cost_token_lengths = self.cost_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.cost_line_length = self.cost_line_length.item() - 2

        self.tokenized_durability, self.durability_token_lengths, self.durability_line_length = tokenize_by_char([card_spec.durability.split()], char2id)
        self.tokenized_durability = self.tokenized_durability.squeeze(0)[1:-1]
        self.durability_token_lengths = self.durability_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.durability_line_length = self.durability_line_length.item() - 2

        self.tokenized_card_type, self.card_type_token_lengths, self.card_type_line_length = tokenize_by_char([card_spec.card_type.split()], char2id)
        self.tokenized_card_type = self.tokenized_card_type.squeeze(0)[1:-1]
        self.card_type_token_lengths = self.card_type_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.card_type_line_length = self.card_type_line_length.item() - 2

        self.tokenized_player_cls, self.player_cls_token_lengths, self.player_cls_line_length = tokenize_by_char([card_spec.player_cls.split()], char2id)
        self.tokenized_player_cls = self.tokenized_player_cls.squeeze(0)[1:-1]
        self.player_cls_token_lengths = self.player_cls_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.player_cls_line_length = self.player_cls_line_length.item() - 2

        self.tokenized_race, self.race_token_lengths, self.race_line_length = tokenize_by_char([card_spec.race.split()], char2id)
        self.tokenized_race = self.tokenized_race.squeeze(0)[1:-1]
        self.race_token_lengths = self.race_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.race_line_length = self.race_line_length.item() - 2

        self.tokenized_rarity, self.rarity_token_lengths, self.rarity_line_length = tokenize_by_char([card_spec.rarity.split()], char2id)
        self.tokenized_rarity = self.tokenized_rarity.squeeze(0)[1:-1]
        self.rarity_token_lengths = self.rarity_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.rarity_line_length = self.rarity_line_length.item() - 2

        self.tokenized_description, self.description_token_lengths, self.description_line_length = tokenize_by_char([card_spec.description.split()], char2id)
        self.tokenized_description = self.tokenized_description.squeeze(0)[1:-1]
        self.description_token_lengths = self.description_token_lengths.squeeze(0).to("cpu")[1:-1]
        self.description_line_length = self.description_line_length.item() - 2

def card_spec_from_line(line):
    name, rest = line.split("NAME_END")
    attack, rest = rest.split("ATK_END")
    defense, rest = rest.split("DEF_END")
    cost, rest = rest.split("COST_END")
    durability, rest = rest.split("DUR_END")
    card_type, rest = rest.split("TYPE_END")
    player_cls, rest = rest.split("PLAYER_CLS_END")
    race, rest = rest.split("RACE_END")
    rarity, description = rest.split("RARITY_END")
    return HearthstoneCardSpec(name.strip(), attack.strip(), defense.strip(), cost.strip(), durability.strip(), card_type.strip(), player_cls.strip(), race.strip(), rarity.strip(), description.strip())

In [21]:
class HearthstoneDatasetByField(data.Dataset):
    def __init__(self, raw_input_lines, raw_target_lines, char2id):
        assert len(raw_input_lines) == len(raw_target_lines)
        self.raw_input_lines = raw_input_lines
        self.raw_target_lines = raw_target_lines

        self.hearthstone_card_specs = [card_spec_from_line(line) for line in raw_input_lines]
        self.hearthstone_card_specs_encoded = [HearthstoneCardSpecTokenized(card_spec, char2id) for card_spec in self.hearthstone_card_specs]
        self.tokenized_targets_ids = [torch.LongTensor([char2id["PAD"]] + [char2id[c] for c in line] + [char2id["PAD"]]).to(device) for line in raw_target_lines]

    def __len__(self):
        return len(self.raw_input_lines)

    def __getitem__(self, idx):
        return self.hearthstone_card_specs_encoded[idx], self.tokenized_targets_ids[idx]

In [22]:
raw_input_lines = read_file(train_input_path)
raw_target_lines = read_file(train_target_path)
max_target_seq_len = max([len(line) for line in raw_target_lines])
validation_ratio = 0.1
train_size = int((1 - validation_ratio) * len(raw_input_lines))
val_size = int(validation_ratio * len(raw_input_lines))
c2c_training_dataset = HearthstoneDatasetByField(raw_input_lines[:train_size], raw_target_lines[:train_size], char2id)
c2c_validation_dataset = HearthstoneDatasetByField(raw_input_lines[train_size:], raw_target_lines[train_size:], char2id)

# Card2Code Model
We build a version of the Card2Code to compare performance

## Encoder / Attention (section 4)
1. Separate each field (e.g. name, description, health...)
2. Tokenize each field and get representations using C2W (representations are learned on character level)
3. Feed fields with multiple words through a Bi-LSTM to get context aware representation (not sure if this becomes the field representation or it's concatenated to the tokens)
4. Apply linear layers on all token representations to map to same dimension
5. Compute scalar attention coefficients $a_{ki}$ for token $x_{ki}$ by solving $a_{ki} = softmax(v(f(x_{ki}), h_{t-1}))$
	- $f$ is the mapping function of the linear layers in 4
	- $v$ is a function that concatenates $f(x_{ki})$ and $h_{t-1}$ then feeds it through linear -> tanh -> linear layers
	- $h_{t-1}$ is the previous state in the RNN
6. Compute overall input vector representation as summation of tokens and attention coefficients $z_t = \sum_{k, i} a_{ki} * f(x_{ki})$
7. Compute new RNN hidden state $h_t = g(y_{t-1}, h_{t-1}, z_t)$
    - Calculate in decoder
	- $g$ uses an LSTM
	- $y_{t-1}$ is the previous context encoded at character level


In [23]:
class HearthstoneEncoder(nn.Module):
    """
    Fields: name, atk, def, cost, dur, type, player cls, race, rarity, description
        - name and description are text fields
    """
    def __init__(self, c2w, text_field_hidden_size, output_size, text_field_num_layers=1):
        super(HearthstoneEncoder, self).__init__()
        self.c2w = c2w
        self.text_field_encoder = nn.LSTM(
            input_size=c2w.output_size,
            hidden_size=text_field_hidden_size,
            num_layers=text_field_num_layers,
            batch_first=True,
            bidirectional=True
        )
        # Field projection layers
        self.singular_field_encoding = nn.Linear(c2w.output_size, output_size)
        self.text_field_encoding = nn.Linear(2 * text_field_hidden_size, output_size)
        self.output_size = output_size

    def encode_singular_field(self, field, token_lengths):
        embedding = self.c2w(field, token_lengths) # (num_layers=1, num_tokens_in_field, embedding_size)
        return self.singular_field_encoding(embedding.squeeze(0)) # (num_tokens_in_field, output_size)

    def encode_text_field(self, field, token_lengths):
        embedding = self.c2w(field, token_lengths) # (num_layers=1, num_tokens_in_field, embedding_size)
        output, (_, _) = self.text_field_encoder(embedding) # (1, num_tokens_in_field, 2 * hidden_size)
        return self.text_field_encoding(output.squeeze(0)) # (num_tokens_in_field, output_size)

    def forward(self, card_spec_tokenized):
        """
        :param card_spec_tokenized: HearthstoneCardSpecTokenized object containing the tokenized specs for 1 card
        :return: list of length <# fields> containing tensors (num_tokens_in_field, output_size) where each tensor contains
                encodings of all the tokens in a single field; fields are ordered as [name, atk, def, cost, dur, type, cls, race, rarity, desc]
        """
        # Encode each field - only name and description are text fields
        # All will be of size (num_tokens_in_field, output_size)
        encoded_name = self.encode_text_field(card_spec_tokenized.tokenized_name, card_spec_tokenized.name_token_lengths)
        encoded_attack = self.encode_singular_field(card_spec_tokenized.tokenized_attack, card_spec_tokenized.attack_token_lengths)
        encoded_defense = self.encode_singular_field(card_spec_tokenized.tokenized_defense, card_spec_tokenized.defense_token_lengths)
        encoded_cost = self.encode_singular_field(card_spec_tokenized.tokenized_cost, card_spec_tokenized.cost_token_lengths)
        encoded_durability = self.encode_singular_field(card_spec_tokenized.tokenized_durability, card_spec_tokenized.durability_token_lengths)
        encoded_card_type = self.encode_singular_field(card_spec_tokenized.tokenized_card_type, card_spec_tokenized.card_type_token_lengths)
        encoded_player_cls = self.encode_singular_field(card_spec_tokenized.tokenized_player_cls, card_spec_tokenized.player_cls_token_lengths)
        encoded_race = self.encode_singular_field(card_spec_tokenized.tokenized_race, card_spec_tokenized.race_token_lengths)
        encoded_rarity = self.encode_singular_field(card_spec_tokenized.tokenized_rarity, card_spec_tokenized.rarity_token_lengths)
        encoded_description = self.encode_text_field(card_spec_tokenized.tokenized_description, card_spec_tokenized.description_token_lengths)
        all_encoded_fields = [encoded_name, encoded_attack, encoded_defense, encoded_cost, encoded_durability, encoded_card_type, encoded_player_cls, encoded_race, encoded_rarity, encoded_description]
        
        return all_encoded_fields

In [24]:
class HearthstoneAttention(nn.Module):
    def __init__(self, encoding_output_size, decoder_hidden_size):
        super(HearthstoneAttention, self).__init__()
        self.attention_weights = nn.Linear(encoding_output_size + decoder_hidden_size, 1)
        self.decoder_hidden_size = decoder_hidden_size

    def forward(self, all_encoded_fields, prev_decoder_hidden_state=None):
        """
        :param all_encoded_fields: list of length <# fields> containing tensors (num_tokens_in_field, encoding_output_size) where each tensor contains
                encodings of all the tokens in a single field
        :param prev_decoder_hidden: tensor (1, decoder_hidden_size) representing final hidden state of decoder from previous timestep
        :return: vector of size (encoding_output_size,) representing attention as a linear combination of all the token representations
        """
        if prev_decoder_hidden_state is None:
            prev_decoder_hidden_state = torch.zeros((1, self.decoder_hidden_size)).to(device)
        encoded_input_tokens = torch.cat(all_encoded_fields, dim=0) # (total_tokens, encoding_output_size)
        concat_encoded_tokens = torch.cat([encoded_input_tokens, prev_decoder_hidden_state.squeeze(0).repeat(encoded_input_tokens.size(0), 1).to(device)], dim=1) # (total_tokens, projection_size + decoder_hidden_size)

        # Compute attention weights by applying nonlinearity (tanh) and using linear layer to map to 1 dimension
        raw_attention_weights = self.attention_weights(torch.tanh(concat_encoded_tokens)) # (total_tokens, 1)
        attention_weights = F.softmax(raw_attention_weights, dim=0) # (total_tokens, 1)
        
        attention_vector = torch.mul(attention_weights, encoded_input_tokens).sum(dim=0) # (encoding_output_size,)

        return attention_vector

In [25]:
# Params taken from setup section
hearthstone_c2w = C2W(100, 300, 300, len(chars) + 1).to(device) # char embd size, hidden size, output size, vocab size
hearthstone_encoder = HearthstoneEncoder(hearthstone_c2w, 300, 300).to(device) # c2w, text field hidden size, output size
hearthstone_attention = HearthstoneAttention(300, 300).to(device)

hearthstone_c2w.eval()
hearthstone_encoder.eval()
hearthstone_attention.eval()

HearthstoneAttention(
  (attention_weights): Linear(in_features=600, out_features=1, bias=True)
)

In [26]:
# Sample of how to use models to compute encodings and attention
sample_spec_input, sample_target = c2c_training_dataset[0]
all_encoded_fields = hearthstone_encoder(sample_spec_input)
attention_vector = hearthstone_attention(all_encoded_fields)

## Card2Code Encoder With Simple Decoder
This differs from card2code by not using pointer network architectures.

In [27]:
class C2CDecoder(nn.Module):
    def __init__(self, hidden_size, attention, enc_output_size, vocab_size):
        super(C2CDecoder, self).__init__()
        # [name, atk, def, cost, dur, type, cls, race, rarity, desc]
        self.attention = attention
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.dropout = nn.Dropout(p=0.1)
        self.combine_attention = nn.Linear(enc_output_size + hidden_size, hidden_size)
        
        self.rnn_num_layers = 1
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, num_layers=self.rnn_num_layers, dropout=0.1)
        
        
    def forward_step(self, prev_embed, hidden, all_encoded_fields):
        """
        :param prev_embed: 3d tensor of shape (batch_size, 1, embed_size) containing word embeddings
                from previous time step
        :param hidden: 3d tensor of shape (num_layers, batch_size, hidden_size) representing current decoder hidden state

        :return: [pre_output, hidden] of current time step
        """
        attn = self.attention(all_encoded_fields, hidden[-1])
        concat_attn = torch.cat((prev_embed.squeeze(1), attn.unsqueeze(0)), dim=-1)
        combined = self.combine_attention(concat_attn).unsqueeze(1)
        
        return self.rnn(combined, hidden)

    def forward(self, inputs, all_encoded_fields, hidden=None, max_output_len=None):
        # Initialize values if not given
        if max_output_len is None:
            max_output_len = inputs.size(1)
        if hidden is None:
            hidden = torch.zeros((self.rnn_num_layers, inputs.size(0), self.hidden_size)).to(device)
            
        embedded = self.embedding(inputs)
        dropped_embedded = self.dropout(embedded)

        # Generate output and hidden for each word
        pre_output_vectors = []
        for i in range(max_output_len):
            prev_embed = dropped_embedded[:, i].unsqueeze(1)
            pre_output, hidden = self.forward_step(prev_embed, hidden, all_encoded_fields)
            pre_output_vectors.append(pre_output)

        outputs = torch.cat(pre_output_vectors, dim=1)
        return outputs, hidden

In [28]:
class C2CEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, generator):
        """
        Inputs:
          - `encoder`: an `Encoder` object.
          - `decoder`: a `Decoder` object.
          - `generator`: a `Generator` object. Essentially a linear mapping. See
              the next code cell.
        """
        super(C2CEncoderDecoder, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, card_spec_tokenized, target):
        """Take in and process masked source and target sequences.
        Returns the decoder outputs, see the above cell.
        """
        all_encoded_fields = self.encode(card_spec_tokenized)
        return self.decode(target[:,:-1], all_encoded_fields)

    def encode(self, card_spec_tokenized):
        return self.encoder(card_spec_tokenized)

    def decode(self, target, all_encoded_fields, decoder_hidden=None):
        return self.decoder(target, all_encoded_fields, hidden=decoder_hidden)

In [29]:
def run_epoch_c2c(dataset, model, loss_compute):
    """Standard Training and Logging Function"""
    total_tokens = 0
    total_loss = 0

    for i in tqdm(range(len(dataset)), position=0, leave=True):
        card_spec, target = dataset[i]
        target = target.unsqueeze(0)
        output, _ = model(card_spec, target)
        
        loss = loss_compute(x=output, y=target[:, 1:],
                            norm=target.size(0))
        total_loss += loss
        total_tokens += target.size(1) - 1
        
        if i % 100 == 0:
            print(f"Iteration {i} Loss: {loss}")

    print(f"Total loss: {math.exp(total_loss / float(total_tokens))}")

    return math.exp(total_loss / float(total_tokens))

def train_c2c(model, train_dataset, val_dataset, num_epochs, learning_rate):
    # Set `ignore_index` as PAD_INDEX so that pad tokens won't be included when
    # computing the loss.
    criterion = nn.NLLLoss(reduction="sum")
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Keep track of dev ppl for each epoch.
    dev_ppls = []

    for epoch in range(num_epochs):
        print("Epoch", epoch)

        model.train()
        train_ppl = run_epoch_c2c(dataset=train_dataset, model=model,
                                loss_compute=SimpleLossCompute(model.generator,
                                                             criterion, optim))

        model.eval()
        with torch.no_grad():        
            dev_ppl = run_epoch_c2c(dataset=val_dataset, model=model,
                                loss_compute=SimpleLossCompute(model.generator,
                                                             criterion, None))
            print("Validation perplexity: %f" % dev_ppl)
            dev_ppls.append(dev_ppl)
        
    return dev_ppls

In [32]:
# Set params
c2w_char_embed_size = 100
c2w_hidden_size = 300
c2w_output_size = 300
encoder_text_hidden_size = 300
encoder_output_size = 300
decoder_output_size = 300
decoder_hidden_size = 300
batch_size = 1

# Params taken from setup section
c2c_c2w = C2W(c2w_char_embed_size, c2w_hidden_size, c2w_output_size, len(char2id)).to(device) # char embd size, hidden size, output size, vocab size
c2c_encoder = HearthstoneEncoder(c2c_c2w, encoder_text_hidden_size, encoder_output_size).to(device) # c2w, text field hidden size, output size
c2c_attention = HearthstoneAttention(encoder_output_size, decoder_output_size).to(device)
c2c_decoder = C2CDecoder(decoder_hidden_size, c2c_attention, encoder_output_size, len(char2id)).to(device)
c2c_generator = Generator(decoder_hidden_size, len(char2id)).to(device)
c2c_encoder_decoder = C2CEncoderDecoder(c2c_encoder, c2c_decoder, c2c_generator).to(device)

  "num_layers={}".format(dropout, num_layers))


In [33]:
num_epochs = 1
lr = 1e-3

train_c2c(c2c_encoder_decoder, c2c_training_dataset, c2c_validation_dataset, num_epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/479 [00:00<02:36,  3.06it/s]

Iteration 0 Loss: 1289.3021240234375


 21%|██        | 101/479 [00:38<02:16,  2.76it/s]

Iteration 100 Loss: 290.7304382324219


 42%|████▏     | 201/479 [01:21<02:12,  2.10it/s]

Iteration 200 Loss: 536.466064453125


 63%|██████▎   | 301/479 [02:02<01:31,  1.95it/s]

Iteration 300 Loss: 507.6236572265625


 84%|████████▎ | 401/479 [02:44<00:27,  2.83it/s]

Iteration 400 Loss: 123.83995056152344


100%|██████████| 479/479 [03:19<00:00,  2.41it/s]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 2.8841519765791883


  4%|▎         | 2/54 [00:00<00:12,  4.13it/s]

Iteration 0 Loss: 744.932861328125


100%|██████████| 54/54 [00:08<00:00,  6.53it/s]

Total loss: 2.012009187257607
Validation perplexity: 2.012009





[2.012009187257607]

In [48]:
# Train for 1 more epoch
train_c2c(c2c_encoder_decoder, c2c_training_dataset, c2c_validation_dataset, num_epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/479 [00:00<02:39,  3.00it/s]

Iteration 0 Loss: 194.34280395507812


 21%|██        | 101/479 [00:38<02:16,  2.76it/s]

Iteration 100 Loss: 126.01023864746094


 42%|████▏     | 201/479 [01:22<02:15,  2.06it/s]

Iteration 200 Loss: 334.14080810546875


 63%|██████▎   | 301/479 [02:04<01:34,  1.89it/s]

Iteration 300 Loss: 326.5388488769531


 84%|████████▎ | 401/479 [02:47<00:27,  2.83it/s]

Iteration 400 Loss: 101.48735809326172


100%|██████████| 479/479 [03:21<00:00,  2.38it/s]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.735957484606414


  4%|▎         | 2/54 [00:00<00:12,  4.03it/s]

Iteration 0 Loss: 590.7501220703125


100%|██████████| 54/54 [00:08<00:00,  6.54it/s]

Total loss: 1.7810753438023887
Validation perplexity: 1.781075





[1.7810753438023887]

In [59]:
# Train for 1 more epoch
train_c2c(c2c_encoder_decoder, c2c_training_dataset, c2c_validation_dataset, num_epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/479 [00:00<02:44,  2.90it/s]

Iteration 0 Loss: 152.2542724609375


 21%|██        | 101/479 [00:38<02:17,  2.75it/s]

Iteration 100 Loss: 96.17437744140625


 42%|████▏     | 201/479 [01:22<02:16,  2.03it/s]

Iteration 200 Loss: 274.4189453125


 63%|██████▎   | 301/479 [02:03<01:32,  1.93it/s]

Iteration 300 Loss: 255.27357482910156


 84%|████████▎ | 401/479 [02:46<00:27,  2.80it/s]

Iteration 400 Loss: 94.45476531982422


100%|██████████| 479/479 [03:20<00:00,  2.38it/s]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.5817935244476857


  4%|▎         | 2/54 [00:00<00:12,  4.06it/s]

Iteration 0 Loss: 513.6979370117188


100%|██████████| 54/54 [00:08<00:00,  6.56it/s]

Total loss: 1.695475679824379
Validation perplexity: 1.695476





[1.695475679824379]

In [70]:
# Train for 1 more epoch
train_c2c(c2c_encoder_decoder, c2c_training_dataset, c2c_validation_dataset, num_epochs, lr)

  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/479 [00:00<02:43,  2.92it/s]

Iteration 0 Loss: 125.07328796386719


 21%|██        | 101/479 [00:39<02:21,  2.67it/s]

Iteration 100 Loss: 78.52368927001953


 42%|████▏     | 201/479 [01:23<02:15,  2.05it/s]

Iteration 200 Loss: 227.59451293945312


 63%|██████▎   | 301/479 [02:04<01:33,  1.91it/s]

Iteration 300 Loss: 227.42515563964844


 84%|████████▎ | 401/479 [02:47<00:28,  2.74it/s]

Iteration 400 Loss: 90.56700897216797


100%|██████████| 479/479 [03:21<00:00,  2.37it/s]
  0%|          | 0/54 [00:00<?, ?it/s]

Total loss: 1.5010941351463838


  4%|▎         | 2/54 [00:00<00:13,  3.99it/s]

Iteration 0 Loss: 432.57843017578125


100%|██████████| 54/54 [00:08<00:00,  6.61it/s]

Total loss: 1.6648790582181232
Validation perplexity: 1.664879





[1.6648790582181232]

In [77]:
# Set to True if you want to save this model
save_model = False
if save_model:
    torch.save(c2c_encoder_decoder.state_dict(), os.path.join(project_dir, "5-17 Fork", "c2c_encoder_decoder_4e.pt"))

# Set to True if you want to load a previously trained model
load_model = False
if load_model:
    c2c_encoder_decoder.load_state_dict(torch.load(os.path.join(project_dir, "5-17 Fork", "c2c_encoder_decoder_2e.pt"), map_location=device))
    c2c_encoder_decoder.eval()

## Evaluate

In [35]:
# Read and tokenize test inputs and targets
test_raw_inputs = read_file(test_input_path)
test_raw_targets = read_file(test_target_path)

trunc_test_raw_targets = []
for line in test_raw_targets:
    if len(line) > max_target_seq_len - 2:
        trunc_test_raw_targets.append(line[:max_target_seq_len - 2])
    else:
        trunc_test_raw_targets.append(line)

c2c_test_dataset = HearthstoneDatasetByField(test_raw_inputs, trunc_test_raw_targets, char2id)

In [36]:
def greedy_decoding_c2c(model, card_spec_tokenized, max_len):
    """Greedily decode a sentence for EncoderDecoder. Make sure to chop off the 
         EOS token!"""

    with torch.no_grad():
        all_encoded_fields = model.encode(card_spec_tokenized)
        prev_y = torch.ones(1, 1).fill_(char2id["PAD"]).long().to(device)
    
    output = []
    hidden = None

    for i in range(max_len):
        with torch.no_grad():
            outputs, hidden = model.decode(prev_y, all_encoded_fields, hidden)
            prob = model.generator(outputs[:, -1])
        d, next_word = torch.max(prob, dim=1)
        next_word = next_word.data.item()
        output.append(next_word)
        prev_y = torch.ones(1, 1).fill_(next_word).long().to(device)

    output = np.array(output)

    # Cut off everything starting from </s>.
    first_pad = np.where(output == char2id["PAD"])[0]
    if len(first_pad) > 0:
        output = output[:first_pad[0]]

    
    return output

In [50]:
def spot_check_greedy_c2c(model, dataset, idx=None, n=1):
    """Compare a (random) generated and target sequence using greedy search"""
    for i in range(n):
        if idx is None:
            idx = np.random.randint(0, len(dataset))
        card_spec_tokenized, targets = dataset[idx: idx+1]
        greedy_decoded = greedy_decoding_c2c(model, card_spec_tokenized[0], max_target_seq_len)
        targets = targets[0][1:-1]
        stripped_trg = targets[targets != char2id["PAD"]].tolist()
        print("===============================")
        print(f"Expected:\n\n\t{tokens_to_text(stripped_trg, id2char)}\n\n-got-\n\n\t{tokens_to_text(greedy_decoded, id2char)}")
        print("===============================")

In [80]:
spot_check_greedy_c2c(c2c_encoder_decoder, c2c_test_dataset)

Expected:

	class DarkscaleHealer(MinionCard):§    def __init__(self):§        super().__init__("Darkscale Healer", 5, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, battlecry=Battlecry(Heal(2), CharacterSelector()))§§    def create_minion(self, player):§        return Minion(4, 5)§


-got-

	class Stormper(MinionCard):§    def __init__(self):§        super().__init__("Starger", 3, CHARACTER_CLASS.ALL, CARD_RARITY.COMMON, minion_type=MINION_TYPE.BEAST)§§    def create_minion(self, player):§        return Minion(3, 5, deathrattle=Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Deathrattle(Dea

In [None]:
def beam_search_decoding_c2c(model, card_spec_tokenized, max_len, k=25):
    """Keep expanding top k most likely sequences"""
    with torch.no_grad():
        all_encoded_fields = model.encode(card_spec_tokenized)
    
    # Keep track of top outputs stores as (log prob, output ID seq, hidden)
    top_outputs = [(0, [char2id["SOS"]], None)]

    for i in range(max_len):
        new_top_outputs = []
        for log_prob, output, hidden in top_outputs:
            # Get last token of candidate output sequence and use as input to decoder
            prev_y = torch.ones(1, 1).long().fill_(output[-1]).to(device)
            probs = None
            h = None
            with torch.no_grad():
                o, h = model.decode(prev_y, all_encoded_fields, hidden)
                probs = model.generator(o[:, -1])
            # Get top k log probs and ids
            topk_log_probs, topk_ids = torch.topk(probs,k, dim=1)
            for token_log_prob, token_id in zip(topk_log_probs[0], topk_ids[0]):
                new_top_outputs.append((log_prob + token_log_prob.data.item(), output + [token_id.data.item()], h))
        # Get top k most likely output sequences up to this point
        new_top_outputs = sorted(new_top_outputs, key=lambda d: d[0], reverse=True)
        top_outputs = new_top_outputs[:k]
    
    # Get the most likely output sequence of all top outputs
    output = np.array(max(top_outputs, key=lambda d: d[0])[1])

    # Cut off everything starting from </s>.
    first_pad = np.where(output[1:] == char2id["EOS"])[0]
    if len(first_pad) > 0:
        output = output[:first_pad[0]]

    return output[1:]

In [None]:
def spot_check_beam_c2c(model, dataset, idx=None, n=1):
    """Compare a (random) generated and target sequence using greedy search"""
    for i in range(n):
        if idx is None:
            idx = np.random.randint(0, len(dataset))
        card_spec_tokenized, targets = dataset[idx: idx+1]
        greedy_decoded = beam_search_decoding_c2c(model, card_spec_tokenized[0], max_target_seq_len)
        targets = targets[0][1:-1]
        stripped_trg = targets[targets != char2id["PAD"]].tolist()
        print("===============================")
        print(f"Expected:\n\n\t{tokens_to_text(stripped_trg, id2char)}\n\n-got-\n\n\t{tokens_to_text(greedy_decoded, id2char)}")
        print("===============================")

In [56]:
spot_check_beam_c2c(c2c_encoder_decoder, c2c_test_dataset)

NameError: ignored

In [66]:
def evaluate_accuracy_c2c(model, test_dataset, decoder):
    """
    :param model: model to evaluate
    :param test_dataset: test dataset to evaluate that yields (input, target_tokens, input_length (or empty value), target_length (or empty value))
            target_tokens should have sequence that starts and ends with SOS and EOS tokens respectively and may be padded with pad_token
    :param decoder: decoder to evaluate with; returns a list of predicted tokens with SOS, EOS, and PAD tokens removed
    :param pad_token: padding token used in target_tokens
    """
    matches = []
    for i in tqdm(range(len(test_dataset)), position=0, leave=True):
        card_spec_tokenized, targets = test_dataset[i: i+1]
        
        targets = targets[0][1:-1]
        stripped_trg = targets[targets != char2id["PAD"]].tolist()

        pred_tokens = decoder(model, card_spec_tokenized[0], max_target_seq_len).tolist()
        
        matches.append(1 if pred_tokens == stripped_trg else 0)
    return matches

In [67]:
def evaluate_bleu_c2c(model, test_dataset, decoder, token2str):
    """
    :param model: model to evaluate
    :param test_dataset: test dataset to evaluate that yields (input, target_tokens, input_length (or empty value), target_length (or empty value))
            target_tokens should have sequence that starts and ends with SOS and EOS tokens respectively and may be padded with pad_token
    :param decoder: decoder to evaluate with; returns a list of predicted tokens with SOS, EOS, and PAD tokens removed
    :param token2str: mapping from token to string
    :param pad_token: padding token used in target_tokens
    """
    bleu_scores = []
    for i in tqdm(range(len(test_dataset)), position=0, leave=True):
        card_spec_tokenized, targets = test_dataset[i: i+1]
        targets = targets[0][1:-1]
        stripped_trg = targets[targets != char2id["PAD"]].tolist()

        pred_tokens = decoder(model, card_spec_tokenized[0], max_target_seq_len)
        pred_text = "".join([id2char[t] for t in pred_tokens])
        trg_text = "".join([id2char[t] for t in stripped_trg])

        bleu_scores.append(sacrebleu.raw_corpus_bleu([pred_text], [[trg_text]], 0.01).score)

    return bleu_scores

In [81]:
# Greedy Search
c2c_greedy_matches = evaluate_accuracy_c2c(c2c_encoder_decoder, c2c_test_dataset, greedy_decoding_c2c)
c2c_greedy_bleus = evaluate_bleu_c2c(c2c_encoder_decoder, c2c_test_dataset, greedy_decoding_c2c, id2char)

100%|██████████| 66/66 [01:19<00:00,  1.21s/it]
100%|██████████| 66/66 [01:19<00:00,  1.21s/it]


In [83]:
print("Metrics for C2W Encoder / Decoder")
# print(f"Beam Accuracy: {sum(c2w_beam_matches) / len(c2w_beam_matches)}")
# print(f"Beam BLEU: {sum(c2w_beam_bleus) / len(c2w_beam_bleus)}\n")
print(f"Greedy Accuracy: {sum(c2c_greedy_matches) / len(c2c_greedy_matches)}")
print(f"Greedy BLEU: {sum(c2c_greedy_bleus) / len(c2c_greedy_bleus)}")

Metrics for C2W Encoder / Decoder
Greedy Accuracy: 0.0
Greedy BLEU: 12.389094297166572


## Decode (section 5)
1. Select a predictor $r_t$ using probabilities from $softmax(h_{t-1}, z_t)$ to generate a sequence $s_t$
	- 1 predictor for each field that copies from the field
	- 1 predictor that generates characters
	- $|x| + 1$ predictors for $|x|$ fields
	- $h_{t-1}$ and $z_t$ are from encode
2. If generate char selected, generate the char using probabilities from $softmax(h_t)$
3. If generate field is selected for a singular field (one word), copy all the characters from that word with probability 1
4. If generate field is selected for a text field (multiple words), copy word from text based on pointer network probability
	- Probability of word c_i is $p(c_i) = softmax(v(h(c_i), q))$
	- $h$ is a representation of word $c_i$
        - $f(x_{ki})$ from attention calculation
	- $v$ is a function that concatenates $h(c_i)$ and $q$ then feeds it through linear -> tanh -> linear layers
	- $q$ is a concatenation of $h_{t-1}$ and $z_t$ from encode
5. At each time step, generate states $S = (r_t, s_t)$ based on scores $V(S) = \log P(s_t|y_1...y_{t-1}, x, r_t) + \log P(r_t | y_1...y_{t-1}, x) + V(prev(S))$
	- Expand top $n$ states
	- All states producing same output up to that point are merged by summing their probabilities

In [None]:
class HearthstoneDecoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, enc_output_size, 
                 attention, predictor, generator, num_of_fields=10,
                 vocab_size=98, num_layers=1):
        super(HearthstoneDecoder, self).__init__()
        self.trg_embed = nn.Embedding(vocab_size, embedding_size) # embed_vec for each char
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.LSTM(enc_output_size + embedding_size, hidden_size, num_layers,
                           batch_first=True)  # z_t, h_(t-1) -> h_t
        self.attention = attention  # HearthstoneAttention
        self.predictor = predictor # 
        self.generator = generator
        self.num_of_fields = num_of_fields
        self.hidden2pred = nn.Linear(hidden_size, num_of_fields + 1) # map hidden vectors to num. of predictors size vec
    def forward_step(self, prev_embed, hidden, all_enc_fields, mode="train"):

        attention_vec = self.attention(all_enc_fields, hidden[0]).view(1,1,-1) # (1,1,enc_output_size)
        prev_embed_attention = torch.cat([attention_vec, prev_embed], dim=2) 
        print(prev_embed_attention.size())
        _, hidden = self.rnn(prev_embed_attention, hidden)
        pred_vec = self.hidden2pred(hidden[0])
        pred_probs = F.log_softmax(pred_vec)
        
        # assuming predictors return the length of the output embed vector seq

        return pred_probs, attention_vec, hidden

    def forward(self, trg_seq, all_enc_fields, input_card, hidden=None, max_output_len=None, 
                mode="train"):

        if max_output_len is None:
            max_output_len = trg_seq.size(-1)
        
        if hidden is None:
            hidden = (torch.zeros((1,1,self.hidden_size)).to(device), 
                      torch.zeros((1,1,self.hidden_size)).to(device))
        
        trg_emb = self.trg_embed(trg_seq)

        predictor_selection = []
        attentions = []
        hiddens = [hidden[0]]
        outputs = []
        if mode == "train":    
            i = 0
            while i < max_output_len:
                prev_embed = trg_emb[:,i].unsqueeze(1)
                pred_probs, attention_vec, hidden = self.forward_step(prev_embed, hidden, all_enc_fields)
                pred_id = torch.argmax(pred_probs).item()
                #predictor = self.predictors(pred_id)
                if pred_id == self.num_of_fields:
                    output, seq_len = self.generator(hidden[0]), 1
                else:
                    output = self.predictor(pred_id, all_enc_fields, input_card) #, attention_vec, prev_hidden)
                    seq_len = len(output) 
                i = i + seq_len
                if type(output) is list:
                    output = F.one_hot(torch.tensor([output]), num_classes=97).float()
                outputs.append(output)
        else:
            for i in range(max_output_len):
                prev_embed = trg_emb[:,i].unsqueeze(1)
                pred_probs, attention_vec, hidden = self.forward_step(prev_embed, hidden, all_enc_fields)    
                predictor_selection.append(pred_probs)
                attentions.append(attentions)
                hiddens.append(hidden)
        
        return predictor_selection, attentions, hiddens, torch.cat(outputs, dim=0).to(device)


In [None]:
class HearthstoneEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        """
        Inputs:
          - `encoder`: an `Encoder` object.
          - `decoder`: a `Decoder` object.
          - `generator`: a `Generator` object. Essentially a linear mapping. See
              the next code cell.
        """
        super(HearthstoneEncoderDecoder, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, card_spec_tokenized, input_card, trg_ids):
        """Take in and process masked source and target sequences.

        Inputs:
          `src_ids`: a 2d-tensor of shape (batch_size, max_seq_length) representing
            a batch of source sentences of word ids.
          `trg_ids`: a 2d-tensor of shape (batch_size, max_seq_length) representing
            a batch of target sentences of word ids.
          `src_lengths`: a 1d-tensor of shape (batch_size,) representing the
            sequence length of `src_ids`.

        Returns the decoder outputs, see the above cell.
        """
        all_encoded_fields = self.encode(card_spec_tokenized)
        _ , _ , _ , outputs = self.decode(trg_ids, all_encoded_fields, input_card)
        return outputs

    def encode(self, card_spec_tokenized):
        return self.encoder(card_spec_tokenized)

    def decode(self, trg_ids, all_encoded_fields, input_card):
        return self.decoder(trg_ids, all_encoded_fields, input_card)

In [None]:
class HearthstonePredictor(nn.Module):
    def __init__(self, encoder_output_size, dataset):
        super(HearthstonePredictor, self).__init__()
        self.max_seq_len = max([len(card_spec.description.split(" ")) for card_spec in training_dataset.hearthstone_card_specs])
        self.proj = nn.Linear(encoder_output_size, self.max_seq_len)
        self.input_size = encoder_output_size
        self.softmax = nn.Softmax(dim = 0)
    
    def forward(self, idx, all_encoded_fields, input_card):
        input_seq = all_encoded_fields[idx]
        word_seq = None
        if idx == 0:
            word_seq = input_card.name.split(" ")
        elif idx == 1:
            word_seq = input_card.attack.split(" ")
        elif idx == 2:
            word_seq = input_card.defense.split(" ")
        elif idx == 3:
            word_seq = input_card.cost.split(" ")
        elif idx == 4:
            word_seq = input_card.durability.split(" ")
        elif idx == 5:
            word_seq = input_card.card_type.split(" ")
        elif idx == 6:
            word_seq = input_card.player_cls.split(" ")
        elif idx == 7:
            word_seq = input_card.race.split(" ")
        elif idx == 8:
            word_seq = input_card.rarity.split(" ")
        elif idx == 9:
            word_seq = input_card.description.split(" ")

        vocab_prob = self.softmax(self.proj(input_seq))
        word_id = torch.argmax(vocab_prob).item()
        id = word_id % len(word_seq)
        copy_seq = [char2id[char] for char in word_seq[id]]
        return copy_seq

class HearthstoneGenerator(nn.Module):
    def __init__(self, enc_output_size, vocab_size=98):
        super(HearthstoneGenerator, self).__init__()
        self.proj = nn.Linear(enc_output_size, vocab_size, bias=False)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

In [None]:
class HearthLossCompute:
    def __init__(self, criterion, opt=None):
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        
        loss = self.criterion(x.contiguous().view(-1, x.size(-1))[:y.size(-1),:],
                                y.contiguous().view(-1))
        loss = loss / norm

        if self.opt is not None:    # training mode
            loss.backward()            
            self.opt.step()
            self.opt.zero_grad()

        return loss.data.item() * norm

In [None]:
def run_epoch(size, data, model, loss_compute):
    """Standard Training and Logging Function"""
    total_tokens = 0
    total_loss = 0

    for i in tqdm(range(size)):
        card_spec_tok, input_card, trg_raw = data[i]
        trg_ids = [[char2id[ch] for ch in trg_raw[:-1]]]
        trg_ids = torch.tensor(trg_ids).to(device)
        outputs = model(card_spec_tok, input_card, trg_ids)
        loss = loss_compute(x=outputs, y=trg_ids,
                            norm=trg_ids.size(0))
        total_loss += loss
        total_tokens += (trg_ids != 97).data.sum().item()

    print(f"Total loss: {math.exp(total_loss / float(total_tokens))}")

    return math.exp(total_loss / float(total_tokens))

def train(model, train_size, val_size, train_data, val_data, num_epochs, learning_rate=0.1):
    # Set `ignore_index` as PAD_INDEX so that pad tokens won't be included when
    # computing the loss.
    criterion = nn.NLLLoss(reduction="sum", ignore_index=97)
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Keep track of dev ppl for each epoch.
    dev_ppls = []

    for epoch in range(num_epochs):
        print("Epoch", epoch)

        model.train()
        train_ppl = run_epoch(train_size, train_data, model=model,
                                loss_compute=HearthLossCompute(criterion, optim))

        model.eval()
        with torch.no_grad():        
            dev_ppl = run_epoch(val_size, val_data, model=model,
                                loss_compute=HearthLossCompute(criterion, None))
            print("Validation perplexity: %f" % dev_ppl)
            dev_ppls.append(dev_ppl)
        
    return dev_ppls

In [None]:
hearth_pred = HearthstonePredictor(300, training_dataset)
hearth_gen = HearthstoneGenerator(300)
hearth_decoder = HearthstoneDecoder(300, 300, 300, hearthstone_attention, 
                                    hearth_pred, hearth_gen)
hearth_encoder_decoder = HearthstoneEncoderDecoder(hearthstone_encoder, hearth_decoder)
training_loader = data.DataLoader(training_dataset, batch_size=1, shuffle=True)
validation_loader = data.DataLoader(validation_dataset, batch_size=1, shuffle=True)

In [None]:
epochs = 20
lr = 1e-3
hearth_encoder_decoder = hearth_encoder_decoder.to(device)
train(hearth_encoder_decoder, train_size, val_size, training_dataset, validation_dataset, epochs, lr)



  0%|          | 0/479 [00:00<?, ?it/s]

Epoch 0
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Size([1, 1, 600])
torch.Si




RuntimeError: ignored