In [1]:
import torch
from torch import nn, optim
import torch.utils.data as torch_data
import torch.nn.functional as F
import itertools as it
import numpy as np
import random
from itertools import combinations, product
from mutex import EncDec, Vocab, batch_seqs, Mutex
from absl import app, flags
import sys

In [2]:
# Device every batch and the model are moved to. Falls back to the CPU when no
# CUDA device is available so the notebook still runs on GPU-less machines.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Pseudo-word inventory for the input side; despite the `_list` suffix this is a set.
input_symbols_list = {
    'dax', 'lug', 'wif', 'zup', 'fep', 'blicket', 'kiki', 'tufa', 'gazzer',
}
# Color inventory for the output side; also a set.
output_symbols_list = {
    'RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK',
}
def encode(data,vocab):
    """Turn (input tokens, output tokens) pairs into id sequences.

    Each side is passed through `vocab.encode` and wrapped with the ids
    returned by `vocab.sos()` and `vocab.eos()`.

    Args:
        data: iterable of `(inp, out)` pairs of token sequences.
        vocab: object providing `sos()`, `eos()` and `encode(seq)`.

    Returns:
        A list of `(input_ids, output_ids)` tuples, one per pair in `data`.
    """
    def wrap(tokens):
        # Start marker, then the encoded tokens, then the end marker.
        return [vocab.sos()] + vocab.encode(tokens) + [vocab.eos()]

    return [(wrap(source), wrap(target)) for source, target in data]

In [4]:
def f1(w):
    """Unary operation: repeat the sequence three times (x -> X X X)."""
    return w + w + w


def f2(w1, w2):
    """Binary operation: surround the second sequence with the first (x, y -> X Y X)."""
    return w1 + w2 + w1


def f3(w1, w2):
    """Binary operation: swap the two sequences (x, y -> Y X)."""
    return w2 + w1


def get(vals, hsh):
    """Look up every element of `vals` in the mapping `hsh`, preserving order."""
    return [hsh[val] for val in vals]

def unary(f, w, colormap, fmap):
    """Build one example for a unary operation.

    Args:
        f: the operation; called as `f(w)` and used as a key into `fmap`.
        w: list of input words.
        colormap: mapping from each word to its output symbol.
        fmap: mapping from operation to the word that names it.

    Returns:
        `(w followed by the operation's name, f(w) mapped through colormap)`.
    """
    source = w + [fmap[f]]
    target = [colormap[token] for token in f(w)]
    return (source, target)

def binary(f, w1, w2, colormap, fmap):
    """Build one example for a binary operation applied to two word lists.

    Args:
        f: the operation; called as `f(w1, w2)` and used as a key into `fmap`.
        w1: list of words on the left of the operation name.
        w2: list of words on the right of the operation name.
        colormap: mapping from each word to its output symbol.
        fmap: mapping from operation to the word that names it.

    Returns:
        `(w1 + [name of f] + w2, f(w1, w2) mapped through colormap)`.
    """
    source = w1 + [fmap[f]] + w2
    target = [colormap[token] for token in f(w1, w2)]
    return (source, target)

def binary_mapped(f, p1, p2, fmap):
    """Combine two already-built examples with a binary operation.

    Unlike `binary`, the operands are `(input words, output symbols)` pairs, so
    `f` is applied directly to the two output lists rather than to words.

    Args:
        f: the operation; called on the two output lists and used as a key
            into `fmap`.
        p1: left `(inputs, outputs)` pair.
        p2: right `(inputs, outputs)` pair.
        fmap: mapping from operation to the word that names it.

    Returns:
        `(left inputs + [name of f] + right inputs, f(left outputs, right outputs))`.
    """
    left_in, left_out = p1
    right_in, right_out = p2
    combined_input = left_in + [fmap[f]] + right_in
    combined_output = f(left_out, right_out)
    return (combined_input, combined_output)

def samplef(fs,words,colormap,fmap):
    """Draw a random single-operation example.

    An operation is picked uniformly from `fs`. If it is `f1`, one random word
    is used as its argument; otherwise a random ordered pair of words (the two
    may be equal) is used.

    Args:
        fs: sequence of candidate operations.
        words: collection of words to draw arguments from.
        colormap: mapping from each word to its output symbol.
        fmap: mapping from operation to the word that names it.

    Returns:
        The `(inputs, outputs)` pair produced by `unary` or `binary`.
    """
    chosen = random.choice(fs)
    if chosen != f1:
        # Every other operation is treated as binary.
        left, right = random.choice(list(product(words,words)))
        return binary(chosen, [left], [right], colormap, fmap)
    word = random.choice(tuple(words))
    return unary(chosen, [word], colormap, fmap)

In [5]:
def generate_fig2_exp(input_symbols, output_symbols):
    """Generate a random study/test split of word -> color examples.

    Four input symbols are chosen as primitive words and paired with four
    output symbols; three further input symbols name the operations `f1`,
    `f2` and `f3`. The chosen color map and function names are printed.

    Populations are sorted into lists before sampling because
    `random.sample()` no longer accepts sets (TypeError on Python 3.11+), and
    because set iteration order of strings varies between interpreter runs,
    which would make the draw irreproducible even with a fixed `random` seed.

    Args:
        input_symbols: collection of at least 7 distinct, sortable input symbols.
        output_symbols: collection of at least 4 distinct, sortable output symbols.

    Returns:
        `(trn, tst)`: two lists of `(input words, output symbols)` pairs.

    Raises:
        ValueError: from `random.sample` if a population is too small.
    """
    input_pool = sorted(input_symbols)
    output_pool = sorted(output_symbols)

    words = random.sample(input_pool, 4)
    colors = random.sample(output_pool, 4)
    colormap = dict(zip(words, colors))
    fnames = random.sample([s for s in input_pool if s not in words], 3)
    fmap = dict(zip([f1, f2, f3], fnames))
    print("color map: ", colormap)
    print("function names: ", fnames)

    trn, tst = [], []

    # Primitives: every word on its own is a study item.
    for w in words:
        trn.append(([w], get([w], colormap)))

    # Unordered pairs of distinct words, used as arguments of f2 and f3.
    combs = list(combinations(words, r=2))

    # Function 1 : x f1 -> X X X
    trnwords = random.sample(words, 2)
    tstwords = random.sample([w for w in words if w not in trnwords], 2)
    for w in trnwords:
        trn.append(unary(f1, [w], colormap, fmap))
    for w in tstwords:
        tst.append(unary(f1, [w], colormap, fmap))

    # Function 2 : x f2 y -> X Y X
    trnpairs = random.sample(combs, 2)
    tstpairs = random.sample([p for p in combs if p not in trnpairs], 2)
    for (w1, w2) in trnpairs:
        trn.append(binary(f2, [w1], [w2], colormap, fmap))
    for (w1, w2) in tstpairs:
        tst.append(binary(f2, [w1], [w2], colormap, fmap))

    # Function 3 : x f3 y -> Y X
    trnpairs = random.sample(combs, 2)
    tstpairs = random.sample([p for p in combs if p not in trnpairs], 2)
    for (w1, w2) in trnpairs:
        trn.append(binary(f3, [w1], [w2], colormap, fmap))
    for (w1, w2) in tstpairs:
        tst.append(binary(f3, [w1], [w2], colormap, fmap))

    # Study compositions: a primitive combined (via f2 or f3) with an f1 example.
    for _ in range(2):
        w1 = random.choice(words)
        p1 = ([w1], [colormap[w1]])
        center = random.choice((f2, f3))
        p2 = samplef((f1,), words, colormap, fmap)
        trn.append(binary_mapped(center, p1, p2, fmap))

    # Order of the operations: fother (*) is always applied before fcenter (+).
    if random.random() > 0.5:
        fcenter, fother = f2, f3
    else:
        fcenter, fother = f3, f2

    for i in range(2):
        w1 = random.choice(words)
        p1 = ([w1], [colormap[w1]])
        p2 = samplef((fother,), words, colormap, fmap)
        # The second study item puts the fother example on the left.
        if i == 1:
            p2, p1 = p1, p2
        trn.append(binary_mapped(fcenter, p1, p2, fmap))

    # Test compositions: same constructions as the study ones, freshly sampled.
    for _ in range(2):
        w1 = random.choice(words)
        p1 = ([w1], [colormap[w1]])
        center = random.choice((f2, f3))
        p2 = samplef((f1,), words, colormap, fmap)
        tst.append(binary_mapped(center, p1, p2, fmap))

    w1 = random.choice(words)
    p1 = ([w1], [colormap[w1]])
    p2 = samplef((fother,), words, colormap, fmap)
    if random.random() > 0.5:
        p2, p1 = p1, p2
    tst.append(binary_mapped(fcenter, p1, p2, fmap))

    # Longest test items: an f1 example combined with an fother example.
    for _ in range(2):
        p1 = samplef((f1,), words, colormap, fmap)
        p2 = samplef((fother,), words, colormap, fmap)
        if random.random() > 0.5:
            p2, p1 = p1, p2
        tst.append(binary_mapped(fcenter, p1, p2, fmap))

    return trn, tst

In [6]:
# Draw one random study/test split to inspect; `main()` later draws its own.
study, test = generate_fig2_exp(input_symbols_list, output_symbols_list)

color map:  {'kiki': 'GREEN', 'wif': 'PURPLE', 'fep': 'BLUE', 'blicket': 'PINK'}
function names:  ['dax', 'gazzer', 'zup']


In [7]:
# Show the generated study (training) examples.
study

[(['kiki'], ['GREEN']),
 (['wif'], ['PURPLE']),
 (['fep'], ['BLUE']),
 (['blicket'], ['PINK']),
 (['wif', 'dax'], ['PURPLE', 'PURPLE', 'PURPLE']),
 (['blicket', 'dax'], ['PINK', 'PINK', 'PINK']),
 (['kiki', 'gazzer', 'wif'], ['GREEN', 'PURPLE', 'GREEN']),
 (['kiki', 'gazzer', 'fep'], ['GREEN', 'BLUE', 'GREEN']),
 (['wif', 'zup', 'blicket'], ['PINK', 'PURPLE']),
 (['kiki', 'zup', 'blicket'], ['PINK', 'GREEN']),
 (['fep', 'zup', 'kiki', 'dax'], ['GREEN', 'GREEN', 'GREEN', 'BLUE']),
 (['kiki', 'gazzer', 'blicket', 'dax'],
  ['GREEN', 'PINK', 'PINK', 'PINK', 'GREEN']),
 (['fep', 'gazzer', 'kiki', 'zup', 'wif'],
  ['BLUE', 'PURPLE', 'GREEN', 'BLUE']),
 (['blicket', 'zup', 'blicket', 'gazzer', 'wif'],
  ['PINK', 'PINK', 'PURPLE', 'PINK', 'PINK'])]

In [8]:
# Show the generated held-out test examples.
test

[(['kiki', 'dax'], ['GREEN', 'GREEN', 'GREEN']),
 (['fep', 'dax'], ['BLUE', 'BLUE', 'BLUE']),
 (['wif', 'gazzer', 'blicket'], ['PURPLE', 'PINK', 'PURPLE']),
 (['wif', 'gazzer', 'fep'], ['PURPLE', 'BLUE', 'PURPLE']),
 (['fep', 'zup', 'blicket'], ['PINK', 'BLUE']),
 (['wif', 'zup', 'fep'], ['BLUE', 'PURPLE']),
 (['wif', 'gazzer', 'kiki', 'dax'],
  ['PURPLE', 'GREEN', 'GREEN', 'GREEN', 'PURPLE']),
 (['fep', 'gazzer', 'wif', 'dax'],
  ['BLUE', 'PURPLE', 'PURPLE', 'PURPLE', 'BLUE']),
 (['kiki', 'gazzer', 'kiki', 'zup', 'fep'],
  ['GREEN', 'BLUE', 'GREEN', 'GREEN']),
 (['kiki', 'zup', 'fep', 'gazzer', 'wif', 'dax'],
  ['BLUE', 'GREEN', 'PURPLE', 'PURPLE', 'PURPLE', 'BLUE', 'GREEN']),
 (['fep', 'dax', 'gazzer', 'blicket', 'zup', 'kiki'],
  ['BLUE', 'BLUE', 'BLUE', 'GREEN', 'PINK', 'BLUE', 'BLUE', 'BLUE'])]

In [9]:
# Declare the hyper-parameters as absl flags.
FLAGS = flags.FLAGS
# Remove every flag currently present in the mapping returned by
# `FLAGS._flags()` before declaring the ones below.
# NOTE(review): `_flags()` is a private absl accessor, and this deletes all
# entries, not only the names defined here -- presumably so the cell can be
# re-run without duplicate-definition errors; confirm.
flags_dict = FLAGS._flags()
keys_list = [keys for keys in flags_dict]
for keys in keys_list: delattr(FLAGS,keys)
# NOTE(review): the help text for "dim" contains a typo ("trasnformer").
flags.DEFINE_integer("dim", 200, "trasnformer dimension")
flags.DEFINE_integer("n_layers", 1, "number of rnn layers")
flags.DEFINE_integer("n_batch", 1, "batch size")
flags.DEFINE_integer("n_epochs",100, "number of training epochs")
flags.DEFINE_float("lr", 0.001, "learning rate")
flags.DEFINE_float("dropout", 0.0, "dropout")
flags.DEFINE_string("save_model", "model.m", "model save location")
flags.DEFINE_integer("seed", 0, "random seed")

In [10]:
# Parse a dummy argv (program name only) so the flag defaults become readable
# through `FLAGS.<name>` inside the notebook.
FLAGS(['mutex.ipynb'])

['mutex.ipynb']

In [11]:
def collate(batch):
    """DataLoader collate function for (input ids, output ids) pairs.

    Splits the batch into its input and output sides, packs each side with
    `batch_seqs` and moves the result to `DEVICE`.

    Returns:
        `(inp, out)`: the two batched sides on `DEVICE`.
    """
    sources, targets = zip(*batch)
    batched_inp = batch_seqs(sources).to(DEVICE)
    batched_out = batch_seqs(targets).to(DEVICE)
    return batched_inp, batched_out

def pretrain(model, train_dataset, val_dataset):
    """Train `model` to predict each input sequence from its own prefix.

    Only the input side of every pair is used: the model is fed the sequence
    without its last position and scored, via `model.nllreduce`, against the
    sequence without its first position. Every third epoch the running loss is
    printed and the model's state dict is written to `FLAGS.save_model`.

    Args:
        model: module exposing `vocab`, `nllreduce` and a forward call of the
            form `model(None, length, prefix)`.
        train_dataset: sequence of `(input ids, output ids)` pairs.
        val_dataset: accepted for signature parity with `train`; not used here.
    """
    optimizer = optim.Adam(model.parameters(), lr=FLAGS.lr)
    loader = torch_data.DataLoader(
        train_dataset,
        batch_size=FLAGS.n_batch,
        shuffle=True,
        collate_fn=collate,
    )
    best_loss = np.inf
    for i_epoch in range(FLAGS.n_epochs):
        model.train()
        loss_sum, seq_count = 0, 0
        for tokens, _ in loader:
            # presumably `tokens` is laid out (time, batch) -- TODO confirm
            # against `batch_seqs`.
            prefix = tokens[:-1,:]
            pred, *_rest = model(None, prefix.shape[0], prefix)
            logits = pred.view(-1, len(model.vocab))
            targets = tokens[1:, :].view(-1)
            loss = model.nllreduce(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            # Weight each batch loss by the size of the second axis.
            batch_width = tokens.shape[1]
            loss_sum += loss.item() * batch_width
            seq_count += batch_width

        # Report and checkpoint on every third epoch only.
        if (i_epoch + 1) % 3 == 0:
            curr_loss = loss_sum / seq_count
            best_loss = min(best_loss, curr_loss)
            print(curr_loss)
            torch.save(model.state_dict(), FLAGS.save_model)

    print("best_loss", best_loss)

    
    
def train(model, train_dataset, val_dataset):
    """Train `model` on (input, output) pairs and periodically validate it.

    The model's forward call `model(inp, out)` is expected to return a tensor
    of losses whose mean is minimised with Adam. Gradients are clipped to norm
    0.5. Every third epoch, and additionally after the last epoch so that the
    reported "final" numbers really describe the final weights, the mean
    training loss is printed, `validate` is run on `val_dataset` and the state
    dict is saved to `FLAGS.save_model`.

    Args:
        model: the module to train.
        train_dataset: sequence of `(input ids, output ids)` pairs.
        val_dataset: sequence of pairs passed to `validate`.
    """
    opt = optim.Adam(model.parameters(), lr=FLAGS.lr)
    train_loader = torch_data.DataLoader(
        train_dataset, batch_size=FLAGS.n_batch, shuffle=True,
        collate_fn=collate
    )
    best_f1 = -np.inf
    best_acc = -np.inf
    # Bound up front so the summary below cannot raise when no validation
    # step was reached (e.g. FLAGS.n_epochs == 0).
    acc, f1 = None, None
    for i_epoch in range(FLAGS.n_epochs):
        model.train()
        train_loss = 0
        train_batches = 0
        for inp, out in train_loader:
            nll = model(inp, out)
            loss = nll.mean()
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            opt.step()
            train_loss += loss.item()
            train_batches += 1

        is_last_epoch = (i_epoch + 1) == FLAGS.n_epochs
        if (i_epoch + 1) % 3 != 0 and not is_last_epoch:
            continue

        # max(..., 1) keeps the report alive if the loader yielded no batch.
        print(train_loss / max(train_batches, 1))
        acc, f1 = validate(model, val_dataset)
        print(f"epoch_{i_epoch}_acc", acc)
        print(f"epoch_{i_epoch}_f1", f1)
        best_f1 = max(best_f1, f1)
        best_acc = max(best_acc, acc)
        print()
        torch.save(model.state_dict(), FLAGS.save_model)

    print("final_acc", acc)
    print("final_f1", f1)
    print("best_acc", best_acc)
    print("best_f1", best_f1)

def eval_format(vocab, seq):
    """Strip the sequence markers from `seq` and decode the remaining ids.

    The first element is always dropped (it is treated as the start marker).
    If the id returned by `vocab.eos()` occurs, it and everything after it are
    dropped as well. When no end marker is present, for example a sample cut
    off at its maximum length, every remaining token is kept; previously the
    last real token was discarded in that case.

    Args:
        vocab: object providing `eos()` and `decode(ids)`.
        seq: list of token ids.

    Returns:
        Whatever `vocab.decode` returns for the stripped id list.
    """
    eos = vocab.eos()
    if eos in seq:
        seq = seq[:seq.index(eos)]
    return vocab.decode(seq[1:])

def validate(model, val_dataset, vis=False, tag=()):
    """Greedy-decode every item of `val_dataset` and score the predictions.

    Accuracy is exact match between the decoded prediction and reference.
    Precision/recall are token-membership based: a predicted token counts as a
    true positive if it occurs anywhere in the reference, as a false positive
    otherwise; a reference token missing from the prediction counts as a false
    negative. Any ratio whose denominator is zero is reported as 0 instead of
    raising `ZeroDivisionError` (e.g. when every prediction is empty).

    Args:
        model: module exposing `vocab` and `sample(inp, temp, max_len, greedy)`.
        val_dataset: sequence of `(input ids, output ids)` pairs.
        vis: when true, print input, reference and prediction for each item.
        tag: sequence of strings joined with "/" into the printed item path.

    Returns:
        `(acc, f1)`.
    """
    model.eval()
    # NOTE(review): shuffling is kept as in the original; it makes the item
    # numbering printed with `vis=True` vary from call to call.
    val_loader = torch_data.DataLoader(
        val_dataset, batch_size=FLAGS.n_batch, shuffle=True,
        collate_fn=collate
    )
    total = 0
    correct = 0
    tp = 0
    fp = 0
    fn = 0
    path = "/" + "/".join(tag)
    with torch.no_grad():
        for inp, out in val_loader:
            pred, _ = model.sample(inp, temp=1.0, max_len=40, greedy=True)
            for i, pred_seq in enumerate(pred):
                ref = out[:, i].detach().cpu().numpy().tolist()
                ref = eval_format(model.vocab, ref)
                pred_here = eval_format(model.vocab, pred_seq)
                correct_here = pred_here == ref
                correct += correct_here
                tp_here = len([p for p in pred_here if p in ref])
                tp += tp_here
                fp_here = len([p for p in pred_here if p not in ref])
                fp += fp_here
                fn_here = len([p for p in ref if p not in pred_here])
                fn += fn_here
                if vis:
                    print(f"@{path}/{total}", correct_here, tp_here, fp_here, fn_here)
                    inp_lst = inp[:, i].detach().cpu().numpy().tolist()
                    print(eval_format(model.vocab, inp_lst))
                    print("gold", ref)
                    print("pred", pred_here)
                    print(pred_here == ref)
                    print()
                total += 1

    acc = correct / total if total else 0.0
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec = tp / (tp + fn) if (tp + fn) else 0.0
    if prec == 0 or rec == 0:
        f1 = 0
    else:
        f1 = 2 * prec * rec / (prec + rec)
    return acc, f1

In [12]:
def main(model=None):
    """Build the data and vocabulary, pretrain the input model, then train.

    Seeds `random`, NumPy and torch from `FLAGS.seed`. Python's `random` is
    seeded too because both `generate_fig2_exp` and the `py` sampler below
    draw from it. Symbols are added to the vocabulary in sorted order so token
    ids do not depend on set iteration order.

    Args:
        model: an existing model to continue training; a new `Mutex` is built
            when omitted.

    Returns:
        `(model, study, test)`: the trained model and the raw study/test
        examples that were generated for this run.
    """
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)
    vocab = Vocab()
    for sym in sorted(input_symbols_list.union(output_symbols_list)):
        vocab.add(sym)

    study, test = generate_fig2_exp(input_symbols_list, output_symbols_list)

    train_items, test_items = encode(study,vocab), encode(test,vocab)

    # Sorted so that `random.choices` below is reproducible under the seed.
    outlist = sorted(output_symbols_list)
    def py(batch, max_len):
        """Sample `batch` random output sequences of 1 to `max_len - 2` symbols."""
        ys = []
        for i in range(batch):
            length = random.randrange(1,max_len-1)
            symbols = random.choices(outlist, k=length)
            encoded = [vocab.sos()]  +  vocab.encode(symbols) + [vocab.eos()]
            ys.append(encoded)
        return batch_seqs(ys).to(DEVICE)

    if model is None:
        model = Mutex(vocab,
                      FLAGS.dim,
                      FLAGS.dim, py,
                      copy=False,
                      n_layers=FLAGS.n_layers,
                      self_att=False,
                      dropout=FLAGS.dropout,
                      lamda=0.1,
                      Nsample=50).to(DEVICE)

    # NOTE(review): the input model is pretrained on study AND test items;
    # `pretrain` only reads the input side of each pair -- confirm intended.
    pretrain(model.px, train_items + test_items, test_items)

    print("px samples: ")
    print(model.sample_px(5))
    print("py samples: ")
    print(model.sample_py(5))
    train(model, train_items, test_items)
    return model, study, test

In [13]:
# Re-import the local `mutex` module so edits made to it on disk are picked up
# without restarting the kernel, then rebind the names imported from it.
import importlib
import mutex
importlib.reload(mutex)
from mutex import EncDec, Vocab, batch_seqs, Mutex

In [14]:
# Run the full pipeline; this rebinds `study` and `test` to the split that
# `main()` generated for the model it returns.
model, study, test = main()

color map:  {'dax': 'GREEN', 'gazzer': 'PURPLE', 'fep': 'RED', 'blicket': 'PINK'}
function names:  ['tufa', 'zup', 'kiki']
1.2795107913017274
1.2489318227767945
1.1783106231689453
1.1710542941093445
1.0849985432624818
1.0763975143432618
1.0682159817218781
1.025890244245529
0.9886056041717529
1.0331876802444457
1.005784273147583
1.0127952337265014
1.0041201949119567
0.9804985404014588
0.9978327453136444
0.9939659976959229
0.9818767368793487
1.0010983395576476
0.9893605232238769
0.9752936637401581
0.9634255373477936
0.9820947217941284
0.9933621227741242
0.9714501488208771
0.9614493381977082
0.9651017999649047
0.9574231529235839
0.9480963158607483
0.9615914595127105
0.9597129058837891
0.9542734599113465
0.9421617722511292
0.9527459383010864
best_loss 0.9421617722511292
px samples: 
[['<s>', 'gazzer', 'kiki', 'blicket', '</s>'], ['<s>', 'gazzer', 'kiki', 'blicket', '</s>'], ['<s>', 'gazzer', 'tufa', '</s>'], ['<s>', 'gazzer', 'tufa', '</s>'], ['<s>', 'gazzer', 'zup', 'dax', 'tufa', '</s>']

In [15]:
# Set the model's `temp` attribute to 1.0, then call `sample_qxy` on five
# sequences drawn from the `py` sampler (called with max_len=5).
model.temp = 1.0
model.sample_qxy(model.py(5,5))

[['<s>', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW'],
 ['<s>', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW'],
 ['<s>', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW'],
 ['<s>', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW'],
 ['<s>', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW', 'YELLOW']]

In [16]:
# Set `temp` to 1.0 again and call `sample_px(50)` on the trained model.
model.temp = 1.0
model.sample_px(50)

[['<s>', 'gazzer', 'kiki', 'blicket', '</s>'],
 ['<s>', 'gazzer', '</s>'],
 ['<s>', 'dax', 'zup', 'fep', '</s>'],
 ['<s>', 'gazzer', 'kiki', 'blicket', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'kiki', 'blicket', '</s>'],
 ['<s>', 'blicket', '</s>'],
 ['<s>', 'gazzer', 'zup', 'dax', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'zup', 'dax', 'tufa', '</s>'],
 ['<s>', 'fep', 'zup', 'gazzer', 'kiki', 'dax', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'kiki', 'blicket', 'zup', 'fep', '</s>'],
 ['<s>', 'gazzer', 'kiki', 'blicket', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'kiki', 'blicket', '</s>'],
 ['<s>', 'dax', '</s>'],
 ['<s>', 'fep', 'kiki', 'blicket', 'zup', 'dax', '</s>'],
 ['<s>', 'fep', 'zup', 'blicket', '</s>'],
 ['<s>', 'blicket', '</s>'],
 ['<s>', 'fep', 'kiki', 'blicket', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'zup', 'dax', 'tufa', '</s>'],
 ['<s>', 'blicket', 'tufa', 'kiki', 'gazzer', 'zup', 'gazzer', '</s>'],
 ['<s>', 'dax', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'tufa', '</s>'],
 ['<s>', 'gazzer', 'zup', 'da

In [17]:
# Score the model on the held-out test items, printing every prediction;
# the cell's value is the (accuracy, f1) tuple returned by `validate`.
validate(model, encode(test,model.vocab), vis=True)

@//0 False 2 0 0
['gazzer', 'zup', 'fep']
gold ['PURPLE', 'RED', 'PURPLE']
pred ['RED', 'PURPLE']
False

@//1 False 1 1 0
['dax', 'tufa']
gold ['GREEN', 'GREEN', 'GREEN']
pred ['GREEN', 'RED']
False

@//2 False 3 1 1
['gazzer', 'kiki', 'blicket', 'zup', 'fep']
gold ['PINK', 'RED', 'PINK', 'PURPLE']
pred ['PINK', 'GREEN', 'PINK', 'RED']
False

@//3 False 3 0 0
['fep', 'zup', 'blicket']
gold ['RED', 'PINK', 'RED']
pred ['PINK', 'RED', 'PINK']
False

@//4 False 4 0 0
['blicket', 'tufa', 'kiki', 'gazzer', 'zup', 'gazzer']
gold ['PURPLE', 'PURPLE', 'PURPLE', 'PINK', 'PINK', 'PINK']
pred ['PINK', 'PINK', 'PINK', 'PURPLE']
False

@//5 False 3 0 0
['dax', 'kiki', 'fep']
gold ['RED', 'GREEN']
pred ['GREEN', 'RED', 'GREEN']
False

@//6 False 1 3 2
['gazzer', 'zup', 'dax', 'tufa']
gold ['PURPLE', 'GREEN', 'GREEN', 'GREEN', 'PURPLE']
pred ['PINK', 'GREEN', 'PINK', 'RED']
False

@//7 True 2 0 0
['gazzer', 'kiki', 'blicket']
gold ['PINK', 'PURPLE']
pred ['PINK', 'PURPLE']
True

@//8 False 1 1 0
['ga

(0.18181818181818182, 0.8125)

In [18]:
# Same check on the study (training) items, printing every prediction.
validate(model, encode(study,model.vocab), vis=True)

@//0 True 5 0 0
['blicket', 'zup', 'blicket', 'tufa']
gold ['PINK', 'PINK', 'PINK', 'PINK', 'PINK']
pred ['PINK', 'PINK', 'PINK', 'PINK', 'PINK']
True

@//1 True 2 0 0
['gazzer', 'kiki', 'fep']
gold ['RED', 'PURPLE']
pred ['RED', 'PURPLE']
True

@//2 True 3 0 0
['blicket', 'tufa']
gold ['PINK', 'PINK', 'PINK']
pred ['PINK', 'PINK', 'PINK']
True

@//3 True 4 0 0
['gazzer', 'kiki', 'blicket', 'tufa']
gold ['PINK', 'PINK', 'PINK', 'PURPLE']
pred ['PINK', 'PINK', 'PINK', 'PURPLE']
True

@//4 True 1 0 0
['blicket']
gold ['PINK']
pred ['PINK']
True

@//5 True 4 0 0
['fep', 'kiki', 'blicket', 'zup', 'dax']
gold ['PINK', 'GREEN', 'PINK', 'RED']
pred ['PINK', 'GREEN', 'PINK', 'RED']
True

@//6 True 3 0 0
['dax', 'zup', 'fep']
gold ['GREEN', 'RED', 'GREEN']
pred ['GREEN', 'RED', 'GREEN']
True

@//7 True 4 0 0
['blicket', 'zup', 'fep', 'kiki', 'fep']
gold ['RED', 'PINK', 'RED', 'PINK']
pred ['RED', 'PINK', 'RED', 'PINK']
True

@//8 True 3 0 0
['fep', 'tufa']
gold ['RED', 'RED', 'RED']
pred ['RED'

(1.0, 1.0)