In [11]:
import pickle
import argparse
import sys

import numpy as np
import scipy.spatial
import scipy.stats

import torch
import torch.nn as nn
from torch.nn import functional as F

from tensorboardX import SummaryWriter

from data import AgentVocab, ILDataset, get_encoded_metadata, generate_uniform_language_fixed_length
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 42
VOCAB_SIZE = 10
MAX_LENGTH = 5
BATCH_SIZE = 64
NUM_ITERATIONS = 1000
LOG_INTERVAL = 100
NUM_GENERATIONS = 20

#### Seed and load data

In [2]:
seed_torch(seed=SEED)

vocab = AgentVocab(VOCAB_SIZE)
meta = get_encoded_metadata()

### Model

In [3]:
class LSTMModel(nn.Module):
    def __init__(self,
        vocab_size,
        max_length,
        hidden_size=256
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length+1
        self.hidden_size = hidden_size

        # transformation of the input to hidden
        self.linear_in = nn.Linear(14, hidden_size)
        # recurrent cell
        self.rnn = nn.LSTMCell(hidden_size, hidden_size)
        # from a hidden state to the vocab
        self.linear_out = nn.Linear(
            hidden_size, vocab_size
        )
            
    def forward(self, inputs, hidden=None):
        """
        Performs a forward pass
        """
        batch_size = inputs.shape[0]
        
        x = self.linear_in(inputs)
        outputs = []
        
        h = torch.zeros([batch_size, self.hidden_size], device=device)
        c = torch.zeros([batch_size, self.hidden_size], device=device)
        
        state = (h, c)
        
        for i in range(self.max_length):
            state = self.rnn(x, state)
            h, _ = state
            outputs.append(self.linear_out(h))
        
        outputs = F.softmax(torch.stack(outputs, dim=1), dim=2)
        return outputs

### Trainer

In [4]:
class ILTrainer(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, batch, targets):
        batch_size, seq_length  = targets.shape[0], targets.shape[1]

        predictions = self.model(batch)
        loss = self.criterion(predictions.reshape(-1, predictions.shape[2]), targets.reshape(-1))
        
        # Calculate accuracy
        sequences = torch.argmax(predictions, dim=2)
        accuracy = (torch.sum(sequences == targets).type(torch.float) / (batch_size*seq_length))

        return loss, torch.mean(accuracy), sequences

### Infer new langguage

In [5]:
def infer_new_language(model, full_dataset):
    dataloader = DataLoader(full_dataset, batch_size=BATCH_SIZE)
    loss, acc, sequences = evaluate(model, dataloader)
    return sequences

#### Metric

## Run Iterated Learning

In [6]:
def compositionality_metrics(compositional_representation, messages, samples=5000):
    """
    Approximates Topological Similarity and TRE score from https://arxiv.org/abs/1902.07181
    Args:
        compositional_representation (np.array): one-hot encoded compositional, size N*C
        messages (np.array): one-hot encoded messages, size N*M
        samples (int, optional): default 5000 - number of pairs to sample
    Returns:
        topological_similarity (float): correlation between similarity of pairs in representation/messages
        tre_score (float): L1 distance for similarity pairs between representation/messages
                           Note: This is only an approximation of the orinal score,
                                 it correlates heavily with the original TRE, but does
                                 not have the same magnitude.
                                 The full implementation is not yet implemented.
    """
    assert compositional_representation.shape[0] == messages.shape[0]

    sim_representation = np.zeros(samples)
    sim_messages = np.zeros(samples)

    for i in range(samples):
        rnd = np.random.choice(len(messages), 2, replace=False)
        s1, s2 = rnd[0], rnd[1]

        sim_representation[i] = scipy.spatial.distance.cosine(
            compositional_representation[s1], compositional_representation[s2]
        )

        sim_messages[i] = scipy.spatial.distance.cosine(messages[s1], messages[s2])

    topological_similarity = scipy.stats.pearsonr(sim_messages, sim_representation)[0]
    tre_score = np.linalg.norm(sim_representation - sim_messages, ord=1)

    return (topological_similarity, tre_score)






language = generate_uniform_language_fixed_length(vocab, len(meta), MAX_LENGTH)

for g in range(NUM_GENERATIONS):
    dataset = ILDataset(meta, language)
    
    train_dataloader, valid_dataloader, test_dataloader = split_dataset_into_dataloaders(dataset, batch_size=BATCH_SIZE)
    
    model = LSTMModel(vocab.full_vocab_size, MAX_LENGTH)
    optimizer = torch.optim.Adam(model.parameters())
    trainer = ILTrainer(model)
    trainer.to(device)

    i = 0
    while i < NUM_ITERATIONS:
        for (batch, targets) in train_dataloader:
            loss, acc = train_one_batch(trainer, batch, targets, optimizer)
            if i % LOG_INTERVAL == 0:
                valid_loss_meter, valid_acc_meter, sequences = evaluate(
                    trainer, valid_dataloader)
                print(
                    "{}/{} Iterations: val loss: {}, val accuracy: {}".format(
                        i,
                        NUM_ITERATIONS,
                        valid_loss_meter.avg,
                        valid_acc_meter.avg))
            i += 1

    language = infer_new_language(trainer, dataset)

0/1000 Iterations: val loss: 2.561823010444641, val accuracy: 0.17129629850387573
100/1000 Iterations: val loss: 2.4114818572998047, val accuracy: 0.2361111119389534
200/1000 Iterations: val loss: 2.414752244949341, val accuracy: 0.24247685074806213
300/1000 Iterations: val loss: 2.42464280128479, val accuracy: 0.2282986119389534
400/1000 Iterations: val loss: 2.4286916255950928, val accuracy: 0.23423032462596893
500/1000 Iterations: val loss: 2.4319578409194946, val accuracy: 0.24377893656492233
600/1000 Iterations: val loss: 2.433031439781189, val accuracy: 0.24638310074806213
700/1000 Iterations: val loss: 2.4328125715255737, val accuracy: 0.24609375
800/1000 Iterations: val loss: 2.433192253112793, val accuracy: 0.24276620894670486
900/1000 Iterations: val loss: 2.4349664449691772, val accuracy: 0.2404513880610466
0/1000 Iterations: val loss: 2.5610945224761963, val accuracy: 0.17737267911434174
100/1000 Iterations: val loss: 2.2065787315368652, val accuracy: 0.5138888955116272
200

In [41]:

    



get_topographical_similarity()

0.2853646369238426