In [None]:
import json
import os
import random
import re
import string
from collections import Counter, OrderedDict
from string import ascii_lowercase, ascii_uppercase, digits, punctuation
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Union

import _jsonnet
import torch

from code.agents import SenderInput, ReceiverOutput, RnnSenderReinforce, RnnReceiverDeterministic
from code.analyser import Analyser as Scorer
from code.dataloader import DataHandler
from code.game import ReinforceGame as Game
from code.loss import ReconstructionLoss

In [None]:
class objectview:
    """Expose a dict's keys as attributes, e.g. ``view.learning_rate``.

    This mirrors the attribute-style access of an ``argparse.Namespace``.
    The dict passed in becomes the instance's ``__dict__`` itself (it is not
    copied), so later changes to either one are visible through the other.
    """

    def __init__(self, d):
        # Same effect as ``self.__dict__ = d``; ``d`` must be a dict.
        object.__setattr__(self, "__dict__", d)

# Evaluate the jsonnet config to a JSON string, parse it, and expose the keys as attributes.
_config_json = _jsonnet.evaluate_file('learnability_config.jsonnet')
args = objectview(json.loads(_config_json))

In [None]:
# Project data handler (code.dataloader.DataHandler) built from the parsed config.
data = DataHandler(args)

In [None]:
# Hyperparameters read from the jsonnet config.
# Vocabulary size is args.signal_chars for the comp and null languages;
# the tok language needs args.signal_chars + 2 instead.
vocab_size = args.signal_chars
embedding_size, hidden_size = args.embedding_size, args.hidden_size
cell_type = args.rnn_cell
signal_len = args.signal_len - 1  # NOTE(review): reason for subtracting 1 is not visible here — confirm

lr, sender_entropy = args.learning_rate, args.sender_entropy
gram_fn = args.gram_fn  # selects which dicts/<gram_fn>_dict.json file is loaded

In [None]:
# Load the grammar dictionary selected by `gram_fn` in the config.
# JSON text is UTF-8 by specification; pass the encoding explicitly so the read
# does not depend on the platform's locale default.
with open(f"dicts/{gram_fn}_dict.json", encoding="utf-8") as infile:
    grammar = json.load(infile)

In [None]:
# Training split exposed by the data handler as `comp_train_set`.
train = data.comp_train_set

In [None]:
# NOTE(review): assumes train[:] yields (signals, meanings) tensors and that every
# meaning row has 5 * 32 entries — confirm against DataHandler.
flat_meanings = train[:][1]
meanings = flat_meanings.view(len(flat_meanings), 5, 32)
# One integer per slot: the argmax over the 32 values of each of the 5 slots.
fmeanings = meanings.argmax(dim=-1)

signals = train[:][0]
# Plain Python lists of symbol ids, one per signal.
messages = [signal.tolist() for signal in signals]
# Same messages with every 0 symbol removed.
# NOTE(review): 0 presumably marks padding / end of message — confirm.
smessages = [[symbol for symbol in message if symbol != 0] for message in messages]

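Split meanings and their messages into redundant and non-redundant groups (a meaning is redundant when positions 0 and 3, or positions 1 and 4, of its per-slot argmax vector are equal).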

In [None]:
# Accumulators for the redundant / non-redundant split: meanings and their messages.
redmeanings, nonredmeanings, redmessages, nonredmessages = [], [], [], []

In [None]:
# Split meanings, together with their position-aligned messages, into "redundant"
# and "non-redundant" groups. A meaning is redundant when positions 0 and 3, or
# positions 1 and 4, of its per-slot argmax vector are equal.
# The lists are (re)created here so that re-running this cell on its own does not
# append duplicates to lists left over from a previous run.
assert len(fmeanings) == len(smessages), "meanings and messages must be aligned"
redmeanings, nonredmeanings = [], []
redmessages, nonredmessages = [], []
for meaning, message in zip(fmeanings, smessages):
    if torch.equal(meaning[0], meaning[3]) or torch.equal(meaning[1], meaning[4]):
        redmeanings.append(meaning)
        redmessages.append(message)
    else:
        nonredmeanings.append(meaning)
        nonredmessages.append(message)

Convert each message (a list of symbol ids) to a string, mapping ids 1–26 to 'a'–'z' and 27–52 to 'A'–'Z'.

In [None]:
# Every symbol id across all messages, flattened into one list.
allnos = [symbol for message in messages for symbol in message]
# Symbol id -> letter: 1..26 -> 'a'..'z', 27..52 -> 'A'..'Z'; other ids have no entry.
charmapping = dict(enumerate(string.ascii_letters, start=1))

In [None]:
def _to_charmessage(message):
    """Encode one message (a list of symbol ids) as a string using ``charmapping``.

    Raises:
        KeyError: if an id has no entry in ``charmapping``.
    """
    return "".join(charmapping[symbol] for symbol in message)


# String form of every message, kept position-aligned with the id-list versions.
redcharmessages = [_to_charmessage(message) for message in redmessages]
nonredcharmessages = [_to_charmessage(message) for message in nonredmessages]

Compute the Jaccard similarity between the most frequent n-grams of the message sets.

In [None]:
def jaccard_similarity(list1, list2):
    """Return the Jaccard similarity |A ∩ B| / |A ∪ B| of the distinct items of two iterables.

    Duplicates are ignored because both inputs are converted to sets. When both
    inputs are empty the ratio is 0/0; this returns 1.0 (two empty sets are
    treated as identical) instead of raising ZeroDivisionError.

    Args:
        list1: iterable of hashable items.
        list2: iterable of hashable items.

    Returns:
        A float in [0, 1].
    """
    set1, set2 = set(list1), set(list2)
    union = set1 | set2
    if not union:
        return 1.0
    return len(set1 & set2) / len(union)

In [None]:
def _ngram_frequencies(vocab, n):
    """Count frequency-weighted length-``n`` n-grams over the words of ``vocab``.

    Every occurrence of an n-gram inside a word contributes that word's
    frequency. Keys are the n-gram's symbols joined into a single string.

    Returns:
        dict of n-gram string -> summed frequency, ordered by descending
        frequency; ties keep the order in which n-grams were first encountered.
    """
    counts = Counter()
    for word, frequency in vocab.items():
        # len(word) - n + 1 windows; none when the word is shorter than n.
        for start in range(len(word) - n + 1):
            counts["".join(word[start:start + n])] += frequency
    # most_common() sorts by count, descending, keeping first-encounter order for ties.
    return dict(counts.most_common())


def get_pair_stats(vocab):
    """Return frequency-weighted unigram, bigram and trigram counts for ``vocab``.

    Args:
        vocab: mapping of word -> frequency. Each word is read symbol by symbol
            (characters, for the message strings built in this notebook).

    Returns:
        (unigrams, bigrams, trigrams): three dicts mapping an n-gram string to
        its summed frequency, each ordered by descending frequency (ties keep
        first-encounter order).
    """
    return tuple(_ngram_frequencies(vocab, n) for n in (1, 2, 3))

In [None]:
# Message string -> number of occurrences, for the full non-redundant and redundant sets.
# Counter preserves first-encounter key order; dict() keeps the plain-dict type.
nonredict = dict(Counter(nonredcharmessages))
redict = dict(Counter(redcharmessages))

In [None]:
# Draw two random samples of non-redundant messages, each the size of the redundant set.
# A fixed seed makes the samples, and every statistic computed from them, reproducible;
# change SAMPLE_SEED to draw different samples.
SAMPLE_SEED = 0
rng = random.Random(SAMPLE_SEED)
n_sample = len(redcharmessages)
assert n_sample <= len(nonredcharmessages), (
    "fewer non-redundant than redundant messages; cannot draw an equal-size sample"
)

# Sample positions rather than values so the complement is well defined even
# when identical messages occur more than once.
sample1_idx = rng.sample(range(len(nonredcharmessages)), n_sample)
nonredcharmessages1 = [nonredcharmessages[i] for i in sample1_idx]
sampled_positions = set(sample1_idx)
# Non-redundant messages (symbol-id lists) not picked for the first sample.
# nonredmessages and nonredcharmessages are position-aligned, so positions match.
othernonred = [message for i, message in enumerate(nonredmessages) if i not in sampled_positions]
# NOTE(review): this second sample is drawn from the full non-redundant pool and can
# overlap the first; if the two were meant to be disjoint, draw it from the complement.
nonredcharmessages2 = rng.sample(nonredcharmessages, n_sample)

In [None]:
# Message string -> number of occurrences for the two non-redundant samples and
# the redundant set; these feed the n-gram statistics below.
nonredict1 = dict(Counter(nonredcharmessages1))
nonredict2 = dict(Counter(nonredcharmessages2))
redict = dict(Counter(redcharmessages))

In [None]:
# Frequency-weighted n-gram counts per message set:
#   r* -> redundant set, o* -> first non-redundant sample, s* -> second sample.
rdict1, rdict2, rdict3 = get_pair_stats(redict)
odict1, odict2, odict3 = get_pair_stats(nonredict1)
sdict1, sdict2, sdict3 = get_pair_stats(nonredict2)

red_frequencies = {"unigram": rdict1, "bigram": rdict2, "trigram": rdict3}
nonred_frequencies = {"unigram": odict1, "bigram": odict2, "trigram": odict3}
samp_frequencies = {"unigram": sdict1, "bigram": sdict2, "trigram": sdict3}

In [None]:
# Display the redundant set's bigram -> count table.
red_frequencies['bigram']

In [None]:
# Jaccard similarity between the TOP_K_NGRAMS most frequent n-grams of two message
# sets, for unigrams, bigrams and trigrams. For every n-gram order:
#   *_and_jaccard    : second non-redundant sample vs. redundant set
#   *_nonred_jaccard : second non-redundant sample vs. first non-redundant sample
TOP_K_NGRAMS = 100


def _top_ngrams(frequencies, order, k=TOP_K_NGRAMS):
    """Return the keys of the ``k`` highest-count entries of ``frequencies[order]``.

    ``frequencies`` maps 'unigram' / 'bigram' / 'trigram' to an n-gram -> count
    dict. Entries with equal counts keep the dict's existing order.
    """
    ranked = sorted(frequencies[order].items(), key=lambda item: item[1], reverse=True)
    return [ngram for ngram, _count in ranked[:k]]


sorted_reds_unis = _top_ngrams(red_frequencies, 'unigram')
sorted_other_unis = _top_ngrams(nonred_frequencies, 'unigram')
sorted_samps_unis = _top_ngrams(samp_frequencies, 'unigram')
uni_and_jaccard = jaccard_similarity(sorted_samps_unis, sorted_reds_unis)
uni_nonred_jaccard = jaccard_similarity(sorted_samps_unis, sorted_other_unis)

sorted_reds_bis = _top_ngrams(red_frequencies, 'bigram')
sorted_other_bis = _top_ngrams(nonred_frequencies, 'bigram')
sorted_samps_bis = _top_ngrams(samp_frequencies, 'bigram')
bi_and_jaccard = jaccard_similarity(sorted_samps_bis, sorted_reds_bis)
bi_nonred_jaccard = jaccard_similarity(sorted_samps_bis, sorted_other_bis)

sorted_reds_tris = _top_ngrams(red_frequencies, 'trigram')
sorted_other_tris = _top_ngrams(nonred_frequencies, 'trigram')
sorted_samps_tris = _top_ngrams(samp_frequencies, 'trigram')
# Same sample-vs-redundant pairing as the unigram and bigram rows.
tri_and_jaccard = jaccard_similarity(sorted_samps_tris, sorted_reds_tris)
tri_nonred_jaccard = jaccard_similarity(sorted_samps_tris, sorted_other_tris)

In [None]:
# Unigram Jaccard: second sample vs. redundant set, second sample vs. first sample.
print(uni_and_jaccard, uni_nonred_jaccard)

In [None]:
# Bigram Jaccard: second sample vs. redundant set, second sample vs. first sample.
print(bi_and_jaccard, bi_nonred_jaccard)

In [None]:
# Trigram Jaccard similarities (tri_and_jaccard, tri_nonred_jaccard).
print(tri_and_jaccard, tri_nonred_jaccard)