In [137]:
%load_ext autoreload
%autoreload 2

import evals
from util import Logger

from collections import Counter
import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler as opt_sched
from torch.autograd import Variable

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [138]:
# Batch size. NOTE: N_BATCH is reassigned to 64 in the later training cell before
# the visible training loop reads it, so this value is not the one that loop uses.
N_BATCH = 256
# Not referenced by any code in the visible part of this notebook.
N_EMBED = 512

In [139]:
def unwrap(var):
    """Return the contents of a torch Variable/tensor as a NumPy array.

    The value is read through ``.data`` (so autograd history is ignored) and is
    copied to host memory before conversion, because ``.numpy()`` raises on
    CUDA tensors. ``Decoder.decode`` explicitly supports CUDA initial states and
    passes its softmax output to this helper, so it has to work on any device.

    Args:
        var: a torch Variable or tensor, on CPU or GPU.

    Returns:
        numpy.ndarray holding the same values, with the same shape as ``var``.
    """
    # `.cpu()` is a no-op for tensors that already live in host memory.
    return var.data.cpu().numpy()

In [140]:
class QATask(object):
    """A target/distractor pair described by (field, value) attribute pairs.

    Each object is encoded as an ``N_FIELDS x N_VALUES`` indicator matrix
    (1 where a (field, value) pair is present) flattened to a vector of length
    ``N_FIELDS * N_VALUES``.

    Instances are built through ``cls`` rather than the hard-coded ``QATask``
    name, so subclasses that override ``N_FIELDS`` / ``N_VALUES`` get instances
    of their own type built with their own constants.
    """

    # Number of attribute fields; `sample` requires exactly two.
    N_FIELDS = 2
    # Number of distinct values each field can take.
    N_VALUES = 4

    @classmethod
    def sample(cls):
        """Draw a random task.

        One field is picked as the "common" field. With probability 0.75 the
        target and the distractor both carry the same random value for it;
        otherwise neither object mentions that field. In the other field the two
        objects always carry two different values.

        Returns:
            An instance of ``cls``.
        """
        n_fields = cls.N_FIELDS
        n_values = cls.N_VALUES
        # `1 - common_field` below is only a valid field index with two fields.
        assert n_fields == 2
        # Two distinct values must exist, or the rejection loop never ends.
        assert n_values >= 2

        tar_pairs = []
        dis_pairs = []
        # Both draws happen even when the shared pair ends up unused.
        common_field = np.random.randint(n_fields)
        common_value = np.random.randint(n_values)
        if np.random.random() < 0.75:
            tar_pairs.append((common_field, common_value))
            dis_pairs.append((common_field, common_value))

        other_field = 1 - common_field
        # Rejection-sample two different values for the distinguishing field.
        while True:
            ov1, ov2 = np.random.randint(n_values, size=2)
            if ov1 != ov2:
                break
        tar_pairs.append((other_field, ov1))
        dis_pairs.append((other_field, ov2))

        return cls.create_with(tar_pairs, dis_pairs)

    @classmethod
    def create_with(cls, tar_pairs, dis_pairs):
        """Build a task from explicit (field, value) pairs.

        Args:
            tar_pairs: iterable of ``(field, value)`` tuples for the target.
            dis_pairs: iterable of ``(field, value)`` tuples for the distractor.

        Returns:
            An instance of ``cls`` whose ``tar_feats`` / ``dis_feats`` are the
            flattened indicator vectors and whose ``pretty`` is a two-element
            list of strings: the sorted pairs rendered as ``'f:v'`` and
            concatenated without a separator, target first.
        """
        n_fields = cls.N_FIELDS
        n_values = cls.N_VALUES

        tar_feats = np.zeros((n_fields, n_values))
        dis_feats = np.zeros((n_fields, n_values))

        # Each pair is a tuple, so it indexes a single (field, value) cell.
        for pair in tar_pairs:
            tar_feats[pair] = 1
        for pair in dis_pairs:
            dis_feats[pair] = 1

        assert n_fields == 2
        pretty = [''.join('%d:%d' % pair for pair in sorted(pairs)) for pairs in [tar_pairs, dis_pairs]]
        return cls(tar_feats.ravel(), dis_feats.ravel(), pretty)

    def __init__(self, tar_feats, dis_feats, pretty):
        """Store the target features, distractor features and display strings."""
        self.tar_feats = tar_feats
        self.dis_feats = dis_feats
        self.pretty = pretty

In [141]:
N_HIDDEN = 256


class Decoder(nn.Module):
    """Single-layer GRU that emits symbols from one-hot encoded inputs.

    The GRU consumes one-hot vectors over the vocabulary directly (there is no
    embedding layer) and a linear layer maps its outputs to vocabulary logits.
    """

    def __init__(self, vocab, start_sym, stop_sym):
        """
        Args:
            vocab: mapping from symbol to integer id.
            start_sym: key of ``vocab`` fed as the first input of every sequence.
            stop_sym: key of ``vocab`` that ends a sequence during decoding.
        """
        super().__init__()
        self._vocab = vocab
        self._start_id = vocab[start_sym]
        self._stop_id = vocab[stop_sym]

        n_symbols = len(vocab)
        self._rnn = nn.GRU(input_size=n_symbols, hidden_size=N_HIDDEN, num_layers=1)
        self._predict = nn.Linear(N_HIDDEN, n_symbols)
        self._softmax = nn.Softmax(dim=1)

    def forward(self, state, inp):
        """Run the GRU over ``inp`` starting from ``state``.

        Returns:
            ``(final_state, logits)`` where ``final_state`` is the second value
            returned by the GRU and ``logits`` is the linear projection of the
            GRU's per-step outputs.
        """
        outputs, final_state = self._rnn(inp, state)
        return final_state, self._predict(outputs)

    def decode(self, init_state, max_len, sample=False):
        """Generate one token sequence per batch element.

        Args:
            init_state: initial GRU state; must have three dimensions, the
                second of which is the batch size.
            max_len: maximum number of tokens generated after the start token.
            sample: draw each token from the softmax distribution when True,
                take the arg-max when False.

        Returns:
            A list with one token list per batch element. Every list starts
            with the start id and stops growing once the stop id has been
            appended (the stop id itself is kept).
        """
        _, n_batch, _ = init_state.shape
        n_symbols = len(self._vocab)

        if sample:
            def pick(probs):
                return np.random.choice(probs.size, p=probs)
        else:
            def pick(probs):
                return probs.argmax()

        sequences = [[self._start_id] for _ in range(n_batch)]
        prev_toks = [self._start_id] * n_batch
        finished = [False] * n_batch
        hidden = init_state

        for _ in range(max_len):
            # One-hot encode the previous token of every sequence: (1, batch, vocab).
            one_hot = np.zeros((1, n_batch, n_symbols))
            for row, tok in enumerate(prev_toks):
                one_hot[0, row, tok] = 1
            step_inp = Variable(torch.FloatTensor(one_hot))
            if init_state.is_cuda:
                step_inp = step_inp.cuda()

            hidden, step_logits = self(hidden, step_inp)
            step_probs = unwrap(self._softmax(step_logits.squeeze(0)))

            # A token is chosen for every row, finished or not, so the random
            # draws do not depend on which sequences have already stopped.
            prev_toks = [pick(row_probs) for row_probs in step_probs]
            for row, tok in enumerate(prev_toks):
                if finished[row]:
                    continue
                sequences[row].append(tok)
                finished[row] = tok == self._stop_id

            if all(finished):
                break
        return sequences

class SModel(nn.Module):
    """Speaker model: turns an object feature vector into a symbol sequence.

    The observation is projected to the decoder's initial hidden state; the
    decoder then generates (``sample``) or scores (``forward``) a message.
    """

    def __init__(self, n_vocab):
        """
        Args:
            n_vocab: number of message symbols; they are named 'a', 'b', ...
                and receive ids 2, 3, ... after '<s>' (0) and '</s>' (1).
        """
        super().__init__()
        self._rep = nn.Linear(QATask.N_FIELDS * QATask.N_VALUES, N_HIDDEN)
        symbols = ['<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab)]
        self.vocab = {sym: idx for idx, sym in enumerate(symbols)}
        self._decoder = Decoder(self.vocab, '<s>', '</s>')
        self._nll = nn.CrossEntropyLoss(reduce=False)
        self._softmax = nn.Softmax(dim=1)
        self._log_softmax = nn.LogSoftmax(dim=1)

    def unencode(self, t):
        """Return the symbol whose id equals ``t``, or None if there is none."""
        return next((sym for sym, idx in self.vocab.items() if idx == t), None)

    def sample(self, obs, max_len, max=False):
        """Decode a message for each observation.

        Args:
            obs: batch of observation feature vectors.
            max_len: maximum number of generated tokens per message.
            max: take the arg-max token at every step when True; sample from
                the predicted distribution when False.

        Returns:
            Whatever ``Decoder.decode`` returns: one token list per row.
        """
        init_state = self._rep(obs).unsqueeze(0)
        return self._decoder.decode(init_state, max_len, sample=not max)

    def forward(self, obs, msg, msg_tgts, mask):
        """Score given messages under the speaker.

        Args:
            obs: batch of observation feature vectors.
            msg: decoder inputs; the logits come out as (time, batch, vocab).
            msg_tgts: target token ids, reshaped here to time * batch entries.
            mask: (time, batch) weights multiplied into the per-step terms.

        Returns:
            ``(nll, ent)``: the masked negative log-likelihood and the masked
            entropy of the predictive distribution, each summed over time.
        """
        init_state = self._rep(obs).unsqueeze(0)
        logits = self._decoder(init_state, msg)[1]
        time, batch, vocab = logits.shape

        flat_logits = logits.view(time * batch, vocab)
        flat_tgts = msg_tgts.view(time * batch)

        step_nll = self._nll(flat_logits, flat_tgts).view(time, batch)
        step_ent = -(self._softmax(flat_logits) * self._log_softmax(flat_logits)).sum(dim=1)

        nll = step_nll * mask
        ent = step_ent.view(time, batch) * mask
        return nll.sum(dim=0), ent.sum(dim=0)
        
class LModel(nn.Module):
    """Listener model: picks the target out of (target, distractor) given a message.

    The message is reduced to a bag of symbols (its one-hot steps summed over
    time), projected to the hidden size, and scored against the projected
    target and distractor features with a dot product.
    """

    def __init__(self, s_model):
        """
        Args:
            s_model: speaker whose ``vocab`` fixes the width of the message
                features this listener accepts.
        """
        super().__init__()
        n_vocab = len(s_model.vocab)
        self._rep = nn.Linear(QATask.N_FIELDS * QATask.N_VALUES, N_HIDDEN)
        self._proj = nn.Linear(n_vocab, N_HIDDEN)
        self._nll = nn.CrossEntropyLoss(reduce=False)
        self._softmax = nn.Softmax(dim=1)

    def forward(self, tar_feats, dis_feats, msg_feats):
        """Score both candidates and sample a guess.

        Args:
            tar_feats: batch of target feature vectors.
            dis_feats: batch of distractor feature vectors.
            msg_feats: message features, summed over dimension 0 before use.

        Returns:
            ``(nll, correct)``: the per-example cross-entropy with the target
            (index 0) as the correct class, and a NumPy boolean array that is
            True where the guess sampled from the softmax was the target.
        """
        tar_rep = self._rep(tar_feats)
        dis_rep = self._rep(dis_feats)

        bag = msg_feats.sum(dim=0)
        query = self._proj(bag).squeeze(0)

        tar_score = (query * tar_rep).sum(1)
        dis_score = (query * dis_rep).sum(1)
        scores = torch.stack((tar_score, dis_score), dim=1)

        # Column 0 holds the target's score, so the gold label is always 0.
        gold = Variable(torch.LongTensor(scores.shape[0]).zero_())
        nll = self._nll(scores, gold)

        # The guess is sampled from the predicted distribution, not arg-maxed.
        guess_probs = unwrap(self._softmax(scores))
        guesses = np.asarray([np.random.choice(2, p=p_row) for p_row in guess_probs])
        return nll, (guesses == 0)

In [150]:
# Number of tasks sampled per training step (replaces the earlier N_BATCH = 256).
N_BATCH = 64
# Maximum number of tokens the speaker may generate after the start token.
MAX_LEN = 8
# Number of letter symbols in the speaker vocabulary (besides '<s>' and '</s>').
VOCAB_SIZE = 16
# Penalty per message token, applied to len(msg) - 1 in the speaker's reward.
COMM_COST = .01
# Penalty per occurrence of token id 0 after the first position of a message.
START_COST = .05
# Weight of the entropy term subtracted from the speaker's surrogate loss.
ENT_BONUS = .01

def str_dist(x, y):
    """Distance between the normalised symbol histograms of two sequences.

    Each sequence is reduced to the relative frequency of its symbols and the
    result is the total mass by which ``x`` exceeds ``y``. When both sequences
    are non-empty both histograms sum to one, so this equals the mass by which
    ``y`` exceeds ``x`` and the measure is symmetric.

    Empty sequences are handled explicitly. ``Counter`` subtraction drops
    non-positive entries, so without the guard ``str_dist('', 'abc')`` would be
    0 while ``str_dist('abc', '')`` is 1. An empty sequence is now at distance
    1 from every non-empty one, in either argument order, and at distance 0
    from another empty sequence.

    Args:
        x: iterable of hashable symbols (e.g. a string).
        y: iterable of hashable symbols (e.g. a string).

    Returns:
        A number in [0, 1].
    """
    c1 = Counter(x)
    c2 = Counter(y)
    s1 = sum(c1.values())
    s2 = sum(c2.values())
    if s1 == 0 or s2 == 0:
        # An empty histogram cannot be normalised; see the docstring.
        return 0.0 if s1 == s2 else 1.0
    p1 = Counter({k: v / s1 for k, v in c1.items()})
    p2 = Counter({k: v / s2 for k, v in c2.items()})
    # Counter subtraction keeps only the strictly positive differences.
    return sum((p1 - p2).values())

def validate(s_model, l_model):
    """Print the speaker's greedy messages and score them with ``evals.comp_eval``.

    A message is decoded (arg-max) for every single-attribute object and for
    every two-attribute object. The messages are printed as a table and then
    handed to ``evals.comp_eval`` together with string concatenation as the
    composition function and ``str_dist`` as the distance.

    The start token is always stripped from a decoded message. The trailing
    token is stripped only when it is the stop token: a message that reached
    ``MAX_LEN`` without emitting ``'</s>'`` ends in a real symbol, which an
    unconditional ``msg[1:-1]`` would discard.

    Args:
        s_model: speaker providing ``sample``, ``unencode`` and ``vocab``.
        l_model: accepted for call compatibility; not used here.

    Returns:
        The mean of the values returned by ``evals.comp_eval``.
    """
    stop_id = s_model.vocab['</s>']

    def get_msg(task):
        """Greedily decode the message for ``task`` and render it as a string."""
        obs = Variable(torch.FloatTensor([task.tar_feats]))
        msg, = s_model.sample(obs, MAX_LEN, max=True)
        body = msg[1:]
        if body and body[-1] == stop_id:
            body = body[:-1]
        return ''.join(s_model.unencode(t) for t in body)

    prim = []
    comp = []

    # Header row: messages for objects that only specify field 1.
    line = ['']
    for v2 in range(QATask.N_VALUES):
        task = QATask.create_with([(1, v2)], [])
        msg = get_msg(task)
        prim.append(('1:%d' % v2, msg))
        line.append('%s:   ' % v2 + msg)
    print(''.join('%-20s' % l for l in line))

    # One row per field-0 value: the primitive message, then every combination.
    for v1 in range(QATask.N_VALUES):
        line = []
        task = QATask.create_with([(0, v1)], [])
        msg = get_msg(task)
        prim.append(('0:%d' % v1, msg))
        line.append('%s: ' % v1 + msg)
        for v2 in range(QATask.N_VALUES):
            task = QATask.create_with([(0, v1), (1, v2)], [])
            msg = get_msg(task)
            comp.append((('0:%d' % v1, '1:%d' % v2), msg))
            line.append('%s,%s: ' % (v1, v2) + msg)
        print(''.join('%-20s' % l for l in line))

    e_prim, r_prim = zip(*prim)
    e_comp, r_comp = zip(*comp)
    return np.mean(evals.comp_eval(r_prim, e_prim, r_comp, e_comp, lambda x, y: x + y, str_dist))

# Train speaker and listener jointly. Each of the 10 restarts builds fresh models
# and fresh Adam optimizers, then runs 10 epochs of 50 batches, printing a
# validation table and summary line after every epoch.
for restart in range(10):
    s_model = SModel(VOCAB_SIZE)
    l_model = LModel(s_model)

    s_opt = optim.Adam(s_model.parameters(), lr=1e-3)
    l_opt = optim.Adam(l_model.parameters(), lr=1e-3)

    for i_epoch in range(10):
        # Running sums over the epoch's 50 batches; divided by 50 when printed.
        loss = 0
        acc = 0
        #print('L' if i_epoch % 2 == 0 else 'S')
        for i in range(50):
            tasks = [QATask.sample() for _ in range(N_BATCH)]
            tar_feats = Variable(torch.FloatTensor([t.tar_feats for t in tasks]))
            dis_feats = Variable(torch.FloatTensor([t.dis_feats for t in tasks]))

            # The speaker samples (does not arg-max) one message per target.
            msgs = s_model.sample(tar_feats, MAX_LEN)

            # Re-encode the sampled messages as padded arrays: at step t the
            # input is token t (one-hot), the target is token t+1, and the mask
            # is 1 for real steps and 0 for padding.
            msg_feats = np.zeros((MAX_LEN, N_BATCH, len(s_model.vocab)))
            msg_targets = np.zeros((MAX_LEN, N_BATCH))
            msg_mask = np.zeros((MAX_LEN, N_BATCH))
            for i_msg, msg in enumerate(msgs):
                for t in range(len(msg)-1):
                    msg_feats[t, i_msg, msg[t]] = 1
                    msg_targets[t, i_msg] = msg[t+1]
                    msg_mask[t, i_msg] = 1
            msg_feats = Variable(torch.FloatTensor(msg_feats))
            msg_targets = Variable(torch.LongTensor(msg_targets))
            msg_mask = Variable(torch.FloatTensor(msg_mask))

            # Listener update: minimise the mean cross-entropy of picking the target.
            l_nll, l_acc = l_model(tar_feats, dis_feats, msg_feats)
            l_loss = l_nll.mean()

            l_opt.zero_grad()
            l_loss.backward()
            # NOTE(review): newer PyTorch releases spell this clip_grad_norm_ --
            # confirm against the installed version.
            nn.utils.clip_grad_norm(l_model.parameters(), 1) 
            l_opt.step()
            loss += unwrap(l_nll).mean()

            # Per-example reward for the speaker, computed from the detached
            # listener loss so no gradient reaches the listener through it:
            # minus (listener nll - its batch mean + length cost + cost for
            # every token id 0 that appears after the first position).
            dl_nll = l_nll.detach()
            reward = -(
                dl_nll - dl_nll.mean() 
                + COMM_COST * Variable(torch.FloatTensor([len(msg)-1 for msg in msgs]))
                + START_COST * Variable(torch.FloatTensor([sum(1 for t in msg[1:] if t == 0) for msg in msgs]))
                )
            # Speaker update: the nll of the speaker's own sampled message is
            # weighted by the reward, minus an entropy bonus. This looks like a
            # score-function (REINFORCE-style) surrogate -- confirm intent.
            s_nll, s_ent = s_model(tar_feats, msg_feats, msg_targets, msg_mask)
            s_surr = s_nll * reward - ENT_BONUS * s_ent
            s_surr_loss = s_surr.mean()
            s_opt.zero_grad()
            s_surr_loss.backward()
            nn.utils.clip_grad_norm(s_model.parameters(), 1) 
            s_opt.step()
            # `loss` therefore mixes the listener loss and the speaker surrogate.
            loss += unwrap(s_surr).mean()

            acc += l_acc.mean()
        comp = validate(s_model, l_model)
        # Printed columns: mean combined loss, mean listener accuracy, validate() score.
        print('%0.3f %0.3f %0.3f' % (loss / 50, acc / 50, comp))
        print()

                    0:                  1:                  2:                  3:                  
0: kkkk             0,0:                0,1: kkkkkkk        0,2: kkkkkkk        0,3:                
1:                  1,0:                1,1:                1,2:                1,3:                
2:                  2,0:                2,1:                2,2:                2,3:                
3:                  3,0:                3,1:                3,2:                3,3:                
-0.764 0.515 0.000

                    0:                  1:   ipppppp        2:   kkkkkkk        3:   ppppppp        
0: kkkpppp          0,0: eeepfff        0,1: iiiiiii        0,2: kkkkkkk        0,3: kpppppp        
1:                  1,0:                1,1:                1,2:                1,3:                
2:                  2,0:                2,1: g              2,2: kkkkkkk        2,3:                
3: ppppppp          3,0: ppppppp        3,1: ppppppp        3,2: kkkkkp

KeyboardInterrupt: 