In [1]:
%load_ext autoreload
%autoreload 2

import random

import editdistance
import numpy as np
import sexpdata
import torch
import zss
from torch import nn, optim
from torch.autograd import Variable
from torch.optim import lr_scheduler as opt_sched

from data import Dataset
import evals

In [2]:
# Model widths and batch sizes used throughout the notebook.
N_EMBED = 64     # width of the feature/label embeddings and of each predictor vector
N_HIDDEN = 256   # hidden width of the MLP that generates the predictor
N_BATCH = 64     # first argument to dataset.get_train_batch (presumably tasks per batch)
N_MTRAIN = 16    # passed to every dataset.get_*_batch call (presumably labelled examples per task)
SEED = 0         # single seed for every RNG seeded below

# Training is stochastic (weight init, injected predictor noise, batch sampling),
# so seed the Python, NumPy and torch RNGs to make Restart & Run All reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def unwrap(var):
    """Return the NumPy array behind a torch Variable/tensor.

    Going through ``.data`` drops the autograd wrapper first, so this also
    works on values that participate in the graph (losses, outputs).
    NOTE(review): ``.numpy()`` only supports CPU tensors.
    """
    raw_tensor = var.data
    return raw_tensor.numpy()

In [9]:
class Model(nn.Module):
    """Builds a per-task linear predictor from labelled "mtrain" examples and
    uses it to classify the "mpred" examples of the same task.
    """

    def __init__(self, dataset, noise_scale=1.0):
        """Create the embedding layers and the predictor-generating MLP.

        Args:
            dataset: object exposing ``n_features``, the width of one input
                feature vector.
            noise_scale: multiplier on the standard-normal noise added to the
                hidden predictor representation in ``forward`` (previously a
                hard-coded ``1``; ``0`` disables the noise).
        """
        super().__init__()
        self._noise_scale = noise_scale
        self._emb_feats = nn.Linear(dataset.n_features, N_EMBED)
        # labels arrive as width-2 vectors (presumably one-hot; TODO confirm in Dataset)
        self._emb_label = nn.Linear(2, N_EMBED)
        self._make_pred_1 = nn.Sequential(
            nn.Linear(N_EMBED, N_HIDDEN),
            nn.ReLU())
        self._make_pred_2 = nn.Sequential(
            nn.Linear(N_HIDDEN, N_EMBED))
        self._loss = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        """Score ``batch.mpred_feats`` with a predictor built from ``batch.mtrain_*``.

        Args:
            batch: object with ``mtrain_feats``, ``mtrain_labels``,
                ``mpred_feats`` and ``mpred_labels``. The feature tensors are
                3-D (mean over dim 1, sum over dim 2); assumed
                (tasks, examples, features), so TODO confirm against Dataset.
                ``mpred_labels`` are the targets for BCEWithLogitsLoss, so
                they must be floating point.

        Returns:
            ``(loss, accuracy, pred_labels, predictor)``:
            - BCE-with-logits loss.
            - Mean accuracy of the thresholded logits.
            - The 0/1 predicted labels as longs.
            - One predictor vector of width ``N_EMBED`` per task.
        """
        emb_mtrain_feats = self._emb_feats(batch.mtrain_feats)
        emb_mtrain_labels = self._emb_label(batch.mtrain_labels)
        # elementwise product ties each example's feature embedding to its label embedding
        emb_mtrain = emb_mtrain_feats * emb_mtrain_labels

        emb_pred_feats = self._emb_feats(batch.mpred_feats)

        hidden = self._make_pred_1(emb_mtrain.mean(dim=1))
        # Out-of-place add: ReLU's backward reads its own output, so the old
        # in-place `+=` made loss.backward() fail on current PyTorch.
        noise = Variable(torch.randn(*hidden.shape))
        hidden = hidden + self._noise_scale * noise
        predictor = self._make_pred_2(hidden)
        exp_predictor = predictor.unsqueeze(1).expand_as(emb_pred_feats)

        # one dot product per mpred example -> a logit per (task, example)
        pred_logits = (emb_pred_feats * exp_predictor).sum(dim=2)
        pred_labels = (pred_logits > 0).long()
        loss = self._loss(pred_logits, batch.mpred_labels)
        accuracy = (pred_labels == batch.mpred_labels.long()).float().mean()
        return loss, accuracy, pred_labels, predictor

In [18]:
# Memo table for _tree_distance, keyed by the serialized pair of trees.
_tree_distance_cache = {}


def _sexp_children(node):
    """Children of a parsed s-expression node: a list's tail, nothing for an atom."""
    return node[1:] if isinstance(node, list) else []


def _sexp_label(node):
    """Label of a parsed s-expression node: the head's value for a list, else the atom's value."""
    return node[0].value() if isinstance(node, list) else node.value()


def _tree_distance(x, y):
    """Unit-cost tree edit distance between two parsed s-expressions, memoized.

    Nodes are compared by label (0 if equal, 1 otherwise) via
    ``zss.simple_distance``; results are cached under the
    ``sexpdata.dumps`` serialization of the ordered pair ``(x, y)``.
    """
    cache_key = (sexpdata.dumps(x), sexpdata.dumps(y))
    if cache_key in _tree_distance_cache:
        return _tree_distance_cache[cache_key]
    distance = zss.simple_distance(
        x,
        y,
        _sexp_children,
        _sexp_label,
        lambda a, b: 0 if a == b else 1)
    _tree_distance_cache[cache_key] = distance
    return distance

def eval_metric(model, reps, lfs):
    """Run ``evals.metric`` comparing representation distances to LF tree distances.

    Args:
        model: unused; kept so all evaluation helpers share one signature.
        reps: array of representations, one row per entry of ``lfs``.
        lfs: logical forms as s-expression strings; parsed with sexpdata.
    """
    parsed_lfs = list(map(sexpdata.loads, lfs))

    def rep_distance(a, b):
        # squared Euclidean distance between two representation vectors
        return ((a - b) ** 2).sum()

    return evals.metric(reps, parsed_lfs, rep_distance, _tree_distance)

def eval_comp(model, reps1, reps2):
    """Run ``evals.comp`` on two row-aligned arrays of representations.

    ``validate`` passes the representations of each logical form in ``reps1``
    and of its negation in ``reps2``. ``model`` is unused.
    """
    rep_similarity = None  # no custom similarity is handed to evals.comp
    return evals.comp(reps1, reps2, rep_similarity)

class Logger(object):
    """Accumulates one row of per-epoch metrics and prints it as a table row.

    Call ``begin`` once to print the header. Then, per epoch, call ``update``
    once for each name in ``KEYS``, followed by ``print`` to emit the row and
    reset.
    """
    EPOCH = 'epoch'
    TRN_LOSS = 'trn loss'
    TRN_ACC = 'trn acc'
    VAL_ACC = 'val acc'
    METRIC_LABEL = 'metric (lab)'
    METRIC_INTENT = 'metric (int)'
    COMP_INTENT = 'comp (int)'
    # column order of the printed table
    KEYS = [EPOCH, TRN_LOSS, TRN_ACC, VAL_ACC, METRIC_LABEL, METRIC_INTENT, COMP_INTENT]

    def __init__(self):
        # values recorded for the row currently being built
        self.data = {}

    @staticmethod
    def _emit_row(cells):
        """Print already-formatted cells framed as ``| a | b | ... |``."""
        print('| %s |' % ' | '.join(cells))

    def begin(self):
        """Print the header: each key right-aligned in a 12-character column."""
        self._emit_row(['%12s' % key for key in self.KEYS])

    def update(self, key, value):
        """Record ``value`` under ``key``; a key may be set only once per row."""
        assert key not in self.data
        self.data[key] = value

    def print(self):
        """Print the recorded values (3 decimals, KEYS order), then start a new row."""
        row_values = [self.data[key] for key in self.KEYS]
        self._emit_row(['%12.3f' % value for value in row_values])
        self.data.clear()

def _predicted_lfs(dataset, batch, preds):
    """Name, via ``dataset.name``, the logical form matching each task's predicted labels (one row of ``preds`` per task)."""
    return [
        dataset.name(batch.mids[m], batch.indices[m][1], unwrap(preds[m, :]))
        for m in range(preds.shape[0])]


def _negation_pairs(lfs):
    """Return ``(i, j)`` pairs where ``lfs[i] == '(not %s)' % lfs[j]``, keeping the smallest such ``j`` per ``i``.

    Built in one pass over a dict instead of the previous nested O(n^2)
    scan; the resulting pairs are identical, in the same order.
    """
    first_index_by_negation = {}
    for j, lf in enumerate(lfs):
        # setdefault keeps the first j, matching the old scan's `break`
        first_index_by_negation.setdefault('(not %s)' % lf, j)
    return [(i, first_index_by_negation[lf])
            for i, lf in enumerate(lfs) if lf in first_index_by_negation]


def validate(dataset, model, logger):
    """Evaluate ``model`` and record validation metrics in ``logger``.

    Records:
    - VAL_ACC: accuracy on the validation batch.
    - METRIC_LABEL: ``eval_metric`` against the batch's own ``lfs``.
    - METRIC_INTENT: ``eval_metric`` against lfs named from the model's predictions.
    - COMP_INTENT: ``eval_comp`` on representations of lf / ``(not lf)`` pairs
      found in the full batch.

    Returns:
        The validation accuracy as a Python float.
    """
    val_batch = dataset.get_val_batch(N_MTRAIN)
    _, val_acc, val_preds, val_reps = model(val_batch)
    # .item() handles both a 0-d array (PyTorch >= 0.4) and a shape-(1,)
    # array (older PyTorch); the former `[0]` raised IndexError on 0-d.
    val_acc = unwrap(val_acc).item()
    val_reps_np = unwrap(val_reps)
    predicted_lfs = _predicted_lfs(dataset, val_batch, val_preds)

    logger.update(logger.VAL_ACC, val_acc)
    logger.update(logger.METRIC_LABEL, eval_metric(model, val_reps_np, val_batch.lfs))
    logger.update(logger.METRIC_INTENT, eval_metric(model, val_reps_np, predicted_lfs))

    full_batch = dataset.get_full_batch(N_MTRAIN)
    _, _, full_preds, full_reps = model(full_batch)
    full_lfs = _predicted_lfs(dataset, full_batch, full_preds)

    pairs = _negation_pairs(full_lfs)
    reps1 = np.asarray([unwrap(full_reps[i, :]) for i, _ in pairs])
    reps2 = np.asarray([unwrap(full_reps[j, :]) for _, j in pairs])
    logger.update(logger.COMP_INTENT, eval_comp(model, reps1, reps2))

    return val_acc

def train(dataset, model, n_epochs=100, n_steps=100, lr=1e-2):
    """Train ``model`` with Adam, validating and printing a metrics row each epoch.

    Args:
        dataset: supplies training batches via ``get_train_batch`` plus the
            validation/full batches that ``validate`` requests.
        model: module whose forward returns ``(loss, acc, preds, reps)``.
        n_epochs: number of epochs; each ends with validation and a log row.
        n_steps: optimizer steps per epoch; the logged training loss and
            accuracy are averaged over these.
        lr: initial Adam learning rate.

    Raises:
        ValueError: if ``n_steps`` is not positive (the per-epoch averages
            would divide by zero).
    """
    if n_steps < 1:
        raise ValueError('n_steps must be positive, got %r' % (n_steps,))
    opt = optim.Adam(model.parameters(), lr=lr)
    # mode='max': the monitored value is validation accuracy; factor=0.5
    # halves the learning rate when it plateaus.
    sched = opt_sched.ReduceLROnPlateau(opt, factor=0.5, verbose=True, mode='max')
    logger = Logger()
    logger.begin()

    for epoch in range(n_epochs):
        trn_loss = 0.0
        trn_acc = 0.0
        for _ in range(n_steps):
            batch = dataset.get_train_batch(N_BATCH, N_MTRAIN)
            loss, acc, _, _ = model(batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # .item() accepts both 0-d (PyTorch >= 0.4) and shape-(1,) arrays
            trn_loss += unwrap(loss).item()
            trn_acc += unwrap(acc).item()
        trn_loss /= n_steps
        trn_acc /= n_steps

        logger.update(logger.EPOCH, epoch)
        logger.update(logger.TRN_LOSS, trn_loss)
        logger.update(logger.TRN_ACC, trn_acc)
        val_acc = validate(dataset, model, logger)
        sched.step(val_acc)
        logger.print()

In [19]:
# Build the data source and the model, then run the training loop
# (prints a header and one metrics row per epoch).
dataset = Dataset()
model = Model(dataset=dataset)
train(dataset=dataset, model=model)

|        epoch |     trn loss |      trn acc |      val acc | metric (lab) | metric (int) |   comp (int) |
|        0.000 |        0.790 |        0.576 |        0.659 |        0.089 |        0.117 |       83.876 |
|        1.000 |        0.633 |        0.655 |        0.660 |        0.092 |        0.110 |      124.672 |
|        2.000 |        0.625 |        0.663 |        0.661 |        0.098 |        0.104 |      118.366 |
|        3.000 |        0.622 |        0.665 |        0.682 |        0.143 |        0.162 |      179.449 |
|        4.000 |        0.625 |        0.664 |        0.674 |        0.133 |        0.116 |       92.614 |
|        5.000 |        0.623 |        0.669 |        0.691 |        0.109 |        0.126 |       20.128 |
|        6.000 |        0.598 |        0.696 |        0.694 |        0.114 |        0.128 |      115.648 |
|        7.000 |        0.567 |        0.723 |        0.748 |        0.097 |        0.211 |       64.521 |
|        8.000 |        0.529 |      

KeyboardInterrupt: 