In [None]:
import torch
from torch.nn import functional as F
import csv
import egg.core as core
from discourse_lang import DataHandler
from string import ascii_lowercase, punctuation, digits
import json
import _jsonnet
from egg.zoo.channel.archs import Receiver, Sender
from reconstructionloss import ReconstructionLoss
import re
import os

N.B. Languages are referred to in code as follows:

No Elision --> Comp

Pronoun --> Tok

Pro-drop --> Null

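To change the language used, change `dataset`, `gram_fn`, and `setting` in `learnability_config.jsonnet`. For the Pronoun (Tok) language, also set `vocab_size = args.signal_chars+2` in the configuration cell below. Note that the data loaders (`data.*_comp_loader`) and output paths in this notebook are hard-coded to the Comp language and must be changed too.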

Options are:
- No Elision: redlarge_comp
- Pronoun: redlarge_tok
- Pro-drop: redlarge_null

In [None]:
class objectview:
    '''
    Attribute-style view over a dictionary, mimicking an argparse.Namespace.

    Every key of ``d`` becomes readable as ``obj.key`` (attribute access, not
    subscripting). The dictionary itself becomes the instance's ``__dict__``
    (it is shared, not copied), so later changes to either side are visible
    through the other.
    '''

    def __init__(self, d):
        # Alias the mapping directly as the attribute namespace.
        self.__dict__ = d

# Evaluate the Jsonnet config to JSON text, parse it into a dict, and wrap it
# so settings are read as attributes (args.hidden_size, args.learning_rate, ...).
args = objectview(
    json.loads(
        _jsonnet.evaluate_file('learnability_config.jsonnet')
    )
)

In [None]:
# Build the dataset/loader container from the parsed config. Later cells read
# its Comp-language splits (data.train_comp_loader, data.comp_train_set, ...).
data = DataHandler(args)

In [None]:
# Model/training hyperparameters pulled from the config.
# vocab_size: args.signal_chars for the Comp and Null languages;
# for the Tok language use args.signal_chars + 2 instead.
vocab_size = args.signal_chars
embedding_size, hidden_size = args.embedding_size, args.hidden_size
cell_type = args.rnn_cell
# The sender's max_len is one less than the configured signal length.
signal_len = args.signal_len - 1

lr = args.learning_rate
sender_entropy, gram_fn = args.sender_entropy, args.gram_fn

In [None]:
# Grammar dictionary for the selected language (file chosen by `gram_fn`).
# JSON is UTF-8 by spec; be explicit so loading doesn't depend on the OS locale.
with open(f"dicts/{gram_fn}_dict.json", encoding="utf-8") as infile:
    grammar = json.load(infile)

# Message alphabet: 'E' marks end-of-sequence, followed by the first
# vocab_size - 1 symbols of `initial_chars` (vocab_size characters in total,
# provided vocab_size - 1 <= len(initial_chars)).
initial_chars = ascii_lowercase + punctuation + digits
msg_chars = 'E' + initial_chars[:vocab_size - 1]

In [None]:
def load_ckp(checkpoint_fpath, model, optimizer):
    '''
    Restore a model and its optimizer from a checkpoint file.

    The checkpoint is indexed positionally: [1] is the model state dict,
    [2] the optimizer state dict, and [0] is returned as the last value
    (presumably the saved epoch -- confirm against the saving code).
    Tensors are mapped to CPU on load. torch.load unpickles, so only load
    trusted checkpoint files.

    Returns:
        (model, model.sender, model.receiver, optimizer, checkpoint[0])
    '''
    saved = torch.load(checkpoint_fpath, map_location=torch.device('cpu'))
    model.load_state_dict(saved[1])
    sender_module, receiver_module = model.sender, model.receiver
    optimizer.load_state_dict(saved[2])
    return model, sender_module, receiver_module, optimizer, saved[0]

Training the model

In [None]:
def quick_train(model, optimizer, train_dataset, val_dataset, epochs, run,
                save_dir="receiver_state_dicts_v2/comp", n_symbols=32):
    '''
    Train `model` with per-slot cross-entropy and validate after every epoch.

    Element [0] of the model output is reshaped to (-1, n_symbols) so that each
    slot is one classification example; targets are one-hot vectors with the
    same flattened layout, reduced to class indices with argmax.

    Args:
        model: module called as model(data); element [0] of its output holds logits.
        optimizer: optimizer over the parameters to train.
        train_dataset, val_dataset: iterables of (data, target) batches.
        epochs: number of passes over train_dataset.
        run: run identifier; the model's state dict is saved after each epoch
            to <save_dir>/<run>/<epoch>.tar.
        save_dir: root directory for checkpoints (created if missing).
        n_symbols: number of symbol classes per slot.

    Returns:
        (train_checkpoints, val_checkpoints): lists of (epoch, mean batch loss).

    Raises:
        ValueError: if either dataset yields no batches.
    '''
    def _batch_loss(data, target):
        # Flatten slots so every slot is an independent cross-entropy example.
        logits = model(data)[0].reshape(-1, n_symbols)
        labels = target.reshape(-1, n_symbols).argmax(dim=-1)
        return F.cross_entropy(logits, labels)

    # NOTE(review): attribute consumed by the model/wrapper (presumably to return
    # raw outputs) -- confirm against the receiver implementation.
    model.return_raw = True

    run_dir = os.path.join(save_dir, str(run))
    os.makedirs(run_dir, exist_ok=True)

    train_checkpoints = []
    val_checkpoints = []

    for epoch in range(epochs):
        # Training pass.
        model.train()
        total_loss, n_batches = 0.0, 0
        for data, target in train_dataset:
            optimizer.zero_grad()
            loss = _batch_loss(data, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError("train_dataset yielded no batches")
        train_loss = total_loss / n_batches
        train_checkpoints.append((epoch, train_loss))
        print(f'Train Epoch: {epoch}, mean loss: {train_loss}')

        # Validation pass (every epoch), then checkpoint the trained model itself.
        model.eval()
        vtotal_loss, vn_batches = 0.0, 0
        with torch.no_grad():
            for data, target in val_dataset:
                vtotal_loss += _batch_loss(data, target).item()
                vn_batches += 1
        if vn_batches == 0:
            raise ValueError("val_dataset yielded no batches")
        val_loss = vtotal_loss / vn_batches
        val_checkpoints.append((epoch, val_loss))
        print(f'Val Epoch: {epoch}, val mean loss: {val_loss}')
        torch.save(model.state_dict(), os.path.join(run_dir, f"{epoch}.tar"))

    model.train()
    return train_checkpoints, val_checkpoints

In [None]:
# Build a sender/receiver pair and the EGG game wrapping them, then select the
# Comp-language loaders. Only the receiver is trained in this notebook.
# NOTE(review): cell is hard-coded to 'gru' although the config provides
# args.rnn_cell (cell_type) -- confirm which is intended.
sender = Sender(n_features=160, n_hidden=hidden_size)

sender = core.RnnSenderReinforce(
    sender,
    vocab_size,
    embedding_size,
    hidden_size,
    cell='gru',
    max_len=signal_len,
    num_layers=1,
    )

receiver = Receiver(n_features=160, n_hidden=hidden_size)
receiver = core.RnnReceiverDeterministic(
    receiver,
    vocab_size,
    embedding_size,
    hidden_size,
    cell='gru',
    num_layers=1,
    )

loss = ReconstructionLoss(5, 32)
game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=sender_entropy,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        )

training_loader = data.train_comp_loader
val_loader = data.val_comp_loader
test_loader = data.test_comp_loader

# Only the receiver's parameters are optimized; the sender and game are not trained.
optimizer = torch.optim.Adam(receiver.parameters(), lr=lr)


In [None]:
# Train a freshly initialized receiver 10 times and save every run's
# per-epoch train/val losses as JSON.
# Create the output directory up front so a missing folder cannot discard
# the results after all runs have finished training.
os.makedirs("learnability_final_checkpoints", exist_ok=True)

checkpoints = []
for n in range(10):
    # Seed each run: reproducible, while different runs still differ.
    torch.manual_seed(n)

    sender = Sender(n_features=160, n_hidden=hidden_size)

    sender = core.RnnSenderReinforce(
        sender,
        vocab_size,
        embedding_size,
        hidden_size,
        cell='gru',
        max_len=signal_len,
        num_layers=1,
        )

    receiver = Receiver(n_features=160, n_hidden=hidden_size)
    receiver = core.RnnReceiverDeterministic(
        receiver,
        vocab_size,
        embedding_size,
        hidden_size,
        cell='gru',
        num_layers=1,
        )

    loss = ReconstructionLoss(5, 32)
    game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=sender_entropy,
            receiver_entropy_coeff=0.0,
            length_cost=0.0,
            )

    training_loader = data.train_comp_loader
    val_loader = data.val_comp_loader
    test_loader = data.test_comp_loader

    # Only the receiver's parameters are optimized; the sender and game are not trained.
    optimizer = torch.optim.Adam(receiver.parameters(), lr=lr)

    train_checkpoints, val_checkpoints = quick_train(receiver, optimizer, training_loader, val_loader, epochs=100, run=n)
    checkpoints.append({str(n): [train_checkpoints, val_checkpoints]})

with open("learnability_final_checkpoints/receiver_comp_checkpoints", 'w') as outf:
    json.dump(checkpoints, outf)

Testing the model

In [None]:
def test(test_loader, path): 
    # Load the model that we saved at the end of the training loop 
    model = receiver 
    model.load_state_dict(torch.load(path)) 
    model.eval()
    #running_accuracy = 0
    exact_accuracy = 0
    partial_accuracy = 0
    total = 0

    with torch.no_grad(): 
        for data, target in test_loader: 
            
            inputs, outputs = data, target
            
            outputs = outputs.to(torch.float32)
            outputs = outputs.view(outputs.shape[0], 5, 32)
            outputs = outputs.argmax(dim=-1)

            predicted_outputs = model(inputs)[0]
            predicted_outputs = predicted_outputs.view(predicted_outputs.shape[0], 5, 32)
            predicted = predicted_outputs.argmax(dim=-1)
            
            exact = (torch.sum((predicted == outputs).detach(), dim=1) == 5).sum().item()
            partial = (predicted == outputs).sum().item()

            total += outputs.size(0)

            exact_accuracy += exact
            partial_accuracy += partial
        
        #print(total)
        print(f"exact: {exact_accuracy/total}, partial: {partial_accuracy/(total*5)}")
        return ({exact_accuracy/total}, {partial_accuracy/(total*5)})
        
        #print('inputs is: %d %%' % (100 * running_accuracy / total))

Test Accuracy Calculation (find the earliest epoch with 100% test accuracy)
- This may take a minute or two per run, since every saved checkpoint is loaded and evaluated on the train, validation, and test splits.

In [None]:
# Evaluate every saved checkpoint of one run on train/val/test and report the
# earliest epoch whose test-set accuracy reaches 100%.
# NOTE(review): quick_train saves to receiver_state_dicts_v2/comp/<run>/; this
# reads a different directory -- confirm which run directory is intended.
directory = "learnability_state_dicts/comp/1/"
res = []
for filename in os.listdir(directory):
    # Checkpoints are named "<epoch>.tar"; skip anything else (e.g. .DS_Store).
    # Matching the bare filename avoids depending on the OS path separator.
    match = re.fullmatch(r'(\d+)\.tar', filename)
    if match is None:
        continue
    f = os.path.join(directory, filename)
    res.append((int(match.group(1)), [test(training_loader, f), test(val_loader, f), test(test_loader, f)]))

sres = sorted(res, key=lambda x: x[0])

# test() returns ({exact}, {partial}); take the test split's partial accuracy.
test_accs = [(ep, next(iter(results[2][1]))) for ep, results in sres]
for ep, acc in test_accs:
    if acc == 1.0:
        print()
        print("Earliest Epoch: ", ep)
        break
else:
    print("No checkpoint reached 100% test accuracy.")

Predictive Ambiguity Calculation

In [None]:
# Fresh receiver with the same architecture used in training; a saved
# checkpoint is loaded into it in the next cell.
receiver = core.RnnReceiverDeterministic(
    Receiver(n_features=160, n_hidden=hidden_size),
    vocab_size,
    embedding_size,
    hidden_size,
    cell='gru',
    num_layers=1,
)

In [None]:
# Load one trained receiver checkpoint (run 9, epoch 25, per quick_train's
# <run>/<epoch>.tar naming) for the analysis below. map_location keeps this
# working on CPU-only machines even if the checkpoint was saved from a GPU.
model = receiver
model.load_state_dict(torch.load("receiver_state_dicts_v2/comp/9/25.tar", map_location=torch.device('cpu')))
model.eval()

In [None]:
# Raw datasets for the Comp-language splits; the training split yields
# (message, target) pairs, as the next cell unpacks them.
train_set, val_set, test_set = data.comp_train_set, data.comp_val_set, data.comp_test_set

In [None]:
# Training set: bucket each message by which slots of its target repeat.
# Target slot 0 is compared with slot 3 ("noun" in the bucket names) and
# slot 1 with slot 4 ("verb").
all_msgs = []
red_msgs = []
semired_noun_msgs = []
semired_verb_msgs = []
fullred_msgs = []
other_msgs = []

for msg, dat in train_set:
    all_msgs.append(msg)
    dat = dat.view(5, 32)
    noun_repeated = torch.equal(dat[0], dat[3])
    verb_repeated = torch.equal(dat[1], dat[4])

    if not (noun_repeated or verb_repeated):
        other_msgs.append(msg)
        continue

    # At least one slot repeats.
    red_msgs.append(msg)
    if noun_repeated and verb_repeated:
        fullred_msgs.append(msg)
    elif noun_repeated:
        semired_noun_msgs.append(msg)
    else:
        semired_verb_msgs.append(msg)

In [None]:
# Batch each bucket of training messages.
# NOTE(review): torch.stack raises on an empty list, so every bucket must be non-empty.
all_signals = torch.stack(all_msgs)
red_signals = torch.stack(red_msgs)
semired_noun_signals = torch.stack(semired_noun_msgs)
semired_verb_signals = torch.stack(semired_verb_msgs)
fullred_signals = torch.stack(fullred_msgs)
other_signals = torch.stack(other_msgs)

# Inference only: no_grad avoids building autograd graphs over the whole training set.
with torch.no_grad():
    all_outputs = model(all_signals)[0]
    red_outputs = model(red_signals)[0]
    semired_noun_outputs = model(semired_noun_signals)[0]
    semired_verb_outputs = model(semired_verb_signals)[0]
    fullred_outputs = model(fullred_signals)[0]
    other_outputs = model(other_signals)[0]


def _slot_entropies(outputs, n_slots=5, n_symbols=32):
    '''Entropy of the predicted symbol distribution, per example and per slot.

    `outputs` holds logits that are reshaped to (batch, n_slots, n_symbols);
    the result has shape (batch, n_slots).
    '''
    return torch.distributions.Categorical(
        logits=outputs.view(len(outputs), n_slots, n_symbols)
    ).entropy()


all_reconents = _slot_entropies(all_outputs)
red_reconents = _slot_entropies(red_outputs)
semired_noun_reconents = _slot_entropies(semired_noun_outputs)
semired_verb_reconents = _slot_entropies(semired_verb_outputs)
fullred_reconents = _slot_entropies(fullred_outputs)
other_reconents = _slot_entropies(other_outputs)

# Per-slot mean entropies; filled in when the CSV is written.
mean_all_entrops = []
mean_red_entrops = []
mean_semired_noun_entrops = []
mean_semired_verb_entrops = []
mean_fullred_entrops = []
mean_other_entrops = []

Write the per-slot mean entropies to a CSV file

In [None]:
# Write one header row and one data row: for each message bucket, the mean
# reconstruction entropy of each of the 5 target slots.
# newline='' is required by the csv module (otherwise blank rows appear on Windows).
with open("receivercompreconents_v2.csv", 'w', newline='') as outf:
    writer = csv.writer(outf)

    # To export another split, change the "train." prefix here and the dataset
    # used to build the buckets.
    header = [
        "train.reconent1", "train.reconent2", "train.reconent3", "train.reconent4", "train.reconent5",
        "train.semirednoun_reconent1", "train.semirednoun_reconent2", "train.semirednoun_reconent3", "train.semirednoun_reconent4", "train.semirednoun_reconent5",
        "train.semiredverb_reconent1", "train.semiredverb_reconent2", "train.semiredverb_reconent3", "train.semiredverb_reconent4", "train.semiredverb_reconent5",
        "train.red_reconent1", "train.red_reconent2", "train.red_reconent3", "train.red_reconent4", "train.red_reconent5",
        "train.allred_reconent1", "train.allred_reconent2", "train.allred_reconent3", "train.allred_reconent4", "train.allred_reconent5",
        "train.other_reconent1", "train.other_reconent2", "train.other_reconent3", "train.other_reconent4", "train.other_reconent5",
    ]
    writer.writerow(header)

    # Bucket order fixes the column order: 5 columns per bucket, matched to `header`.
    # NOTE(review): this order writes fullred means under the "red_" labels and
    # red (noun-or-verb repeated) means under the "allred_" labels. It reproduces
    # the existing file layout -- confirm which mapping is intended.
    buckets = [
        (all_reconents, mean_all_entrops),
        (semired_noun_reconents, mean_semired_noun_entrops),
        (semired_verb_reconents, mean_semired_verb_entrops),
        (fullred_reconents, mean_fullred_entrops),
        (red_reconents, mean_red_entrops),
        (other_reconents, mean_other_entrops),
    ]

    csv_data = []
    for reconents, mean_entrops in buckets:
        for i in range(len(reconents[0])):
            mean_entrops.append(reconents[:, i].mean().item())
        csv_data.extend(mean_entrops[:5])

    writer.writerow(csv_data)