In [None]:
import pickle
import argparse
import sys

import numpy as np
import scipy.spatial
import scipy.stats

import torch
import torch.nn as nn
from torch.nn import functional as F

from tensorboardX import SummaryWriter

from model import ILTrainer, LSTMModel
from utils import *
from data import *

# Use the GPU whenever torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- Experiment configuration -------------------------------------------------
SEED = 42               # handed to seed_torch() before anything else runs
VOCAB_SIZE = 10         # size argument for AgentVocab
MAX_LENGTH = 5          # message length for the generated language and LSTMModel
BATCH_SIZE = 64         # batch size for the dataloaders and language inference
NUM_ITERATIONS = 1_000  # training steps per generation
LOG_INTERVAL = 100      # run validation and print every this many steps
NUM_GENERATIONS = 20    # number of iterated-learning generations

## Seed and load data

In [None]:
# Seed first, before the vocabulary, metadata or any model is created, so the run
# is reproducible. NOTE(review): presumably seed_torch seeds torch (and possibly
# numpy/random) — confirm in utils.
seed_torch(seed=SEED)

# Agent vocabulary (its full_vocab_size sizes the LSTM below) and the encoded
# metadata that ILDataset pairs with each generation's language.
vocab = AgentVocab(VOCAB_SIZE)
meta = get_encoded_metadata()

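## Run Iterated Learning

Starting from a fixed-length language built by `generate_uniform_language_fixed_length`, each generation trains a fresh `LSTMModel` on the current language. It then uses the trained model to infer the language that the next generation learns from.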

In [None]:
# Initial language: one fixed-length message (MAX_LENGTH) per entry of `meta`.
language = generate_uniform_language_fixed_length(vocab, len(meta), MAX_LENGTH)

for g in range(NUM_GENERATIONS):
    print("Generation {}/{}".format(g + 1, NUM_GENERATIONS))

    # Each generation learns from the language produced by the previous one.
    dataset = ILDataset(meta, language)
    train_dataloader, valid_dataloader, test_dataloader = split_dataset_into_dataloaders(
        dataset, batch_size=BATCH_SIZE)

    # A fresh learner every generation; move it to `device` before building the
    # optimizer, as recommended by the torch.optim documentation.
    model = LSTMModel(vocab.full_vocab_size, MAX_LENGTH)
    trainer = ILTrainer(model)
    trainer.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # Train for exactly NUM_ITERATIONS steps, re-iterating the dataloader as many
    # epochs as needed. The inner loop must stop as soon as the budget is spent;
    # otherwise it would run on to the end of the current epoch.
    i = 0
    while i < NUM_ITERATIONS:
        batches_this_epoch = 0
        for (batch, targets) in train_dataloader:
            batches_this_epoch += 1
            loss, acc = train_one_batch(trainer, batch, targets, optimizer)
            if i % LOG_INTERVAL == 0:
                valid_loss_meter, valid_acc_meter, sequences = evaluate(
                    trainer, valid_dataloader)
                print(
                    "{}/{} Iterations: val loss: {}, val accuracy: {}".format(
                        i,
                        NUM_ITERATIONS,
                        valid_loss_meter.avg,
                        valid_acc_meter.avg))
            i += 1
            if i >= NUM_ITERATIONS:
                break
        # An empty training loader would otherwise leave `i` unchanged and spin forever.
        if batches_this_epoch == 0:
            raise ValueError(
                "train_dataloader yielded no batches in generation {}; "
                "cannot train".format(g))

    # The trained learner's output becomes the next generation's training language.
    language = infer_new_language(trainer, dataset, batch_size=BATCH_SIZE)