In [1]:
import pickle
import argparse
import sys
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.utils.data import random_split, DataLoader
from tensorboardX import SummaryWriter

from data import AgentVocab, ILDataset, get_encoded_metadata, generate_uniform_language_fixed_length
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 42
VOCAB_SIZE = 10
MAX_LENGTH = 5
BATCH_SIZE = 64
NUM_ITERATIONS = 1000
LOG_INTERVAL = 100

#### Seed and load data

In [2]:
seed_torch(seed=SEED)

vocab = AgentVocab(VOCAB_SIZE)
meta = get_encoded_metadata()

### Model

In [3]:
class LSTMModel(nn.Module):
    def __init__(self,
        vocab_size,
        max_length,
        hidden_size=256
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length+1
        self.hidden_size = hidden_size

        # transformation of the input to hidden
        self.linear_in = nn.Linear(14, hidden_size)
        # recurrent cell
        self.rnn = nn.LSTMCell(hidden_size, hidden_size)
        # from a hidden state to the vocab
        self.linear_out = nn.Linear(
            hidden_size, vocab_size
        )
            
    def forward(self, inputs, hidden=None):
        """
        Performs a forward pass
        """
        batch_size = inputs.shape[0]
        
        x = self.linear_in(inputs)
        outputs = []
        
        h = torch.zeros([batch_size, self.hidden_size], device=device)
        c = torch.zeros([batch_size, self.hidden_size], device=device)
        
        state = (h, c)
        
        for i in range(self.max_length):
            state = self.rnn(x, state)
            h, _ = state
            outputs.append(self.linear_out(h))
        
        outputs = F.softmax(torch.stack(outputs, dim=1), dim=2)
        return outputs

### Trainer

In [4]:
class ILTrainer(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, batch, targets):
        batch_size, seq_length  = targets.shape[0], targets.shape[1]

        predictions = self.model(batch)
        loss = self.criterion(predictions.reshape(-1, predictions.shape[2]), targets.reshape(-1))
        
        # Calculate accuracy
        sequences = torch.argmax(predictions, dim=2)
        accuracy = (torch.sum(sequences == targets).type(torch.float) / (batch_size*seq_length))

        return loss, torch.mean(accuracy), sequences

In [6]:
language = generate_uniform_language_fixed_length(vocab, len(meta), MAX_LENGTH)

while g < NUM_GENERATIONS:
    dataset = ILDataset(meta, language)
    train_length = int(0.6 * len(meta))
    valid_length = int(0.1 * len(meta))
    test_length = len(meta) - train_length - valid_length

    train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_length, valid_length, test_length])    
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    model = LSTMModel(vocab.full_vocab_size, MAX_LENGTH)
    optimizer = torch.optim.Adam(model.parameters())
    trainer = ILTrainer(model)
    trainer.to(device)

    i = 0
    while i < NUM_ITERATIONS:
        for (batch, targets) in train_dataloader:
            loss, acc = train_one_batch(trainer, batch, targets, optimizer)
            if i % LOG_INTERVAL == 0:
                valid_loss_meter, valid_acc_meter, valid_messages = evaluate(
                    trainer, valid_dataloader)
                print(
                    "{}/{} Iterations: val loss: {}, val accuracy: {}".format(
                        i,
                        NUM_ITERATIONS,
                        valid_loss_meter.avg,
                        valid_acc_meter.avg))
            i += 1

    language = 

0/1000 Iterations: val loss: 2.5617374181747437, val accuracy: 0.08579282462596893
100/1000 Iterations: val loss: 2.4102375507354736, val accuracy: 0.2335069477558136
200/1000 Iterations: val loss: 2.4126192331314087, val accuracy: 0.24840857088565826
300/1000 Iterations: val loss: 2.4184833765029907, val accuracy: 0.2408854141831398
400/1000 Iterations: val loss: 2.422072410583496, val accuracy: 0.24551504105329514
500/1000 Iterations: val loss: 2.428613066673279, val accuracy: 0.24320023506879807
600/1000 Iterations: val loss: 2.4352664947509766, val accuracy: 0.2352430522441864
700/1000 Iterations: val loss: 2.4373056888580322, val accuracy: 0.23625578731298447
800/1000 Iterations: val loss: 2.440883159637451, val accuracy: 0.2339409664273262
900/1000 Iterations: val loss: 2.4406795501708984, val accuracy: 0.2352430522441864
