In [1]:
%load_ext autoreload
%autoreload 2

from cls_data.data import Dataset
import evals
from util import Logger

import editdistance
import numpy as np
import sexpdata
import torch
from torch import nn, optim
from torch.optim import lr_scheduler as opt_sched
from torch.autograd import Variable

In [2]:
# Hyper-parameters shared by the model and the training/validation helpers.
N_BATCH = 128    # first argument of dataset.get_train_batch
N_MTRAIN = 150   # handed to every dataset.get_*_batch call; presumably the
                 # number of support examples per task -- TODO confirm
N_EMBED = 128    # output width of the feature and label embeddings
N_HIDDEN = 256   # width of the hidden layer that builds the predictor

In [3]:
def unwrap(var):
    """Convert a torch tensor/Variable into a NumPy array.

    The value is copied to host memory first (a no-op when it is already
    there), so the conversion does not fail for tensors that live on another
    device.  Going through ``.data`` drops the autograd history, which lets
    this work on tensors that require grad.

    Args:
        var: a torch tensor or Variable.

    Returns:
        numpy.ndarray holding the values of ``var``.
    """
    return var.data.cpu().numpy()

In [4]:
class Model(nn.Module):
    """Predicts query labels from a set of labelled support examples.

    Each support example is embedded as the elementwise product of a linear
    feature embedding and a linear label embedding.  The embeddings are
    averaged over dim 1, passed through a small MLP (with Gaussian noise added
    after the first layer), and the resulting ``predictor`` vector gates the
    embedded query features before a final linear layer emits one logit per
    query example.

    Args:
        dataset: object exposing ``n_features``, the input width of the
            feature embedding.
        noise_scale: multiplier applied to the standard-normal noise added to
            the hidden representation.  Defaults to 1.0, the value that used
            to be hard-coded; 0 disables the noise entirely.
    """

    def __init__(self, dataset, noise_scale=1.0):
        super().__init__()
        self._noise_scale = noise_scale
        self._emb_feats = nn.Linear(dataset.n_features, N_EMBED)
        # Labels are embedded from vectors of width 2.
        self._emb_label = nn.Linear(2, N_EMBED)
        self._make_pred_1 = nn.Sequential(
            nn.Linear(N_EMBED, N_HIDDEN),
            nn.ReLU())
        self._make_pred_2 = nn.Sequential(
            # Disabled extra hidden layer:
            #nn.Linear(N_HIDDEN, N_HIDDEN),
            #nn.ReLU(),
            nn.Linear(N_HIDDEN, N_EMBED))
        # Disabled alternative: build the predictor with a GRU over the examples.
        #self._pred_rnn = nn.GRU(input_size=N_EMBED, hidden_size=N_HIDDEN, num_layers=1, batch_first=True)
        #self._make_pred = nn.Linear(N_HIDDEN, N_EMBED)
        self._pred = nn.Linear(N_EMBED, 1)
        self._loss = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        """Run the model on one batch.

        Args:
            batch: object with ``mtrain_feats``, ``mtrain_labels``,
                ``mpred_feats`` and ``mpred_labels``.  The indexing below
                reduces over dim 1 and squeezes dim 2 of the query logits, so
                inputs are assumed to be laid out as (task, example, feature)
                -- TODO confirm against Dataset.

        Returns:
            Tuple ``(loss, accuracy, pred_labels, predictor)``: the
            BCE-with-logits loss, the mean agreement between thresholded
            logits and ``batch.mpred_labels``, the thresholded predictions
            and the per-task predictor vector.
        """
        emb_mtrain_feats = self._emb_feats(batch.mtrain_feats)
        emb_mtrain_labels = self._emb_label(batch.mtrain_labels)
        emb_mtrain = emb_mtrain_feats * emb_mtrain_labels

        emb_pred_feats = self._emb_feats(batch.mpred_feats)

        rep = self._make_pred_1(emb_mtrain.mean(dim=1))
        if self._noise_scale:
            # The noise is drawn on every call; it is not tied to
            # self.training.  type_as() gives it the same tensor type as
            # `rep`, so the addition no longer assumes a default CPU tensor.
            noise = Variable(torch.randn(*rep.shape)).type_as(rep)
            noised_rep = rep + self._noise_scale * noise
        else:
            noised_rep = rep
        predictor = self._make_pred_2(noised_rep)
        # Repeat the per-task predictor for every query example.
        exp_predictor = predictor.unsqueeze(1).expand_as(emb_pred_feats)
        # Disabled alternative (GRU predictor):
        #_, pred_hidden = self._pred_rnn(emb_mtrain)
        #predictor = self._make_pred(pred_hidden.squeeze(0))

        # Disabled alternative (plain dot product instead of a learned readout):
        #pred_logits = (emb_pred_feats * exp_predictor).sum(dim=2)
        pred_logits = self._pred(emb_pred_feats * exp_predictor).squeeze(2)
        # A logit above 0 corresponds to a predicted label of 1.
        pred_labels = (pred_logits > 0).long()
        loss = self._loss(pred_logits, batch.mpred_labels)
        accuracy = (pred_labels == batch.mpred_labels.long()).float().mean()
        return loss, accuracy, pred_labels, predictor

In [5]:
def eval_isom_tree(reps, lfs):
    """Score ``reps`` against ``lfs`` with ``evals.isom``.

    ``reps`` are compared with ``evals.cos_dist`` and ``lfs`` with
    ``evals.tree_dist``; the result of ``evals.isom`` is returned unchanged.
    """
    rep_dist = evals.cos_dist
    lf_dist = evals.tree_dist
    return evals.isom(reps, lfs, rep_dist, lf_dist)

def eval_isom_ext(reps, exts):
    """Score ``reps`` against ``exts`` with ``evals.isom``.

    ``reps`` are compared with ``evals.cos_dist`` and ``exts`` with
    ``evals.l1_dist``; the result of ``evals.isom`` is returned unchanged.
    """
    rep_dist = evals.cos_dist
    ext_dist = evals.l1_dist
    return evals.isom(reps, exts, rep_dist, ext_dist)

def eval_hom(model, reps1, reps2):
    """Forward ``reps1`` and ``reps2`` to ``evals.hom``.

    ``model`` is accepted but not used.  ``None`` is passed where a distance
    function would go; how ``evals.hom`` treats ``None`` is not visible from
    this notebook.
    """
    # Disabled alternative distance:
    #sim_rep = lambda x, y: ((x-y)**2).sum()
    dist_fn = None
    return evals.hom(reps1, reps2, dist_fn)

# Column titles used as keys for the training logger.
EPOCH, TRN_LOSS, TRN_ACC, VAL_ACC = 'epoch', 'trn loss', 'trn acc', 'val acc'
ISOM_TREE, ISOM_EXT, ISOM_CHK = 'isom (r-t)', 'isom (r-e)', 'isom (t-e)'
HOM, HOM2 = 'hom', 'hom2'
# Disabled columns:
#METRIC_LABEL = 'metric (lab)'
#METRIC_INTENT = 'metric (int)'
#COMP_INTENT = 'comp (int)'

# One (key, format spec) pair per logged column, in display order.  Keeping
# the pairs together stops the two parallel lists from drifting apart.
_LOG_COLUMNS = (
    (EPOCH, 'd'),
    (TRN_LOSS, '.3f'),
    (TRN_ACC, '.3f'),
    (VAL_ACC, '.3f'),
    (ISOM_TREE, '.3f'),
    (ISOM_EXT, '.3f'),
    (ISOM_CHK, '.3f'),
    (HOM, '.3f'),
    (HOM2, '.3f'),
)
LOG_KEYS = [key for key, _ in _LOG_COLUMNS]
LOG_FMTS = [fmt for _, fmt in _LOG_COLUMNS]

def validate(dataset, model, logger):
    """Evaluate ``model`` on a validation batch and log accuracy and isom scores.

    Args:
        dataset: provides ``get_val_batch`` and ``name``.
        model: callable returning ``(loss, accuracy, predictions, reps)`` for
            a batch.
        logger: receives the VAL_ACC, ISOM_TREE, ISOM_EXT and ISOM_CHK values
            through ``logger.update``.

    Returns:
        The validation accuracy as a NumPy scalar.
    """
    val_batch = dataset.get_val_batch(N_MTRAIN)
    _, val_acc, val_preds, val_reps = model(val_batch)
    # ravel() first: the accuracy may come back as a 1-element array or as a
    # 0-dim array depending on the torch version, and plain [0] indexing
    # raises IndexError on the latter.
    val_acc = unwrap(val_acc).ravel()[0]
    logger.update(VAL_ACC, val_acc)

    # dataset.name yields one pair per task; the pairs are split into the two
    # sequences compared below.
    named = [
        dataset.name(val_batch.mids[m], val_batch.indices[m][1], unwrap(val_preds[m, :]))
        for m in range(val_preds.shape[0])]
    true_lfs, full_exts = zip(*named)

    reps = unwrap(val_reps)
    logger.update(ISOM_TREE, eval_isom_tree(reps, true_lfs))
    logger.update(ISOM_EXT, eval_isom_ext(reps, full_exts))
    # Reference score between the two non-learned descriptions themselves.
    logger.update(ISOM_CHK, evals.isom(true_lfs, full_exts, evals.tree_dist, evals.l1_dist))

    return val_acc
    
def validate_hom(dataset, model, logger):
    """Log the HOM and HOM2 scores computed on a batch from ``get_hom_batch``.

    Each group is a ``(parent, child1, child2)`` triple of indices into the
    batch.  A parent representation is compared, via ``evals.hom`` with
    ``evals.cos_dist``, against the average of its two children's
    representations.  HOM only uses groups whose members all satisfy
    ``hom_batch.lfs[i] == true_lfs[i]``; HOM2 uses every index.

    Args:
        dataset: provides ``get_hom_batch`` and ``name``.
        model: callable returning ``(loss, accuracy, predictions, reps)`` for
            a batch.
        logger: receives the HOM and HOM2 values through ``logger.update``.
    """
    hom_batch, groups = dataset.get_hom_batch(N_MTRAIN)
    _, _, hom_preds, hom_reps = model(hom_batch)
    hom_reps = unwrap(hom_reps)

    named = [
        dataset.name(hom_batch.mids[m], hom_batch.indices[m][1], unwrap(hom_preds[m, :]))
        for m in range(hom_preds.shape[0])]
    true_lfs, _ = zip(*named)
    n_items = len(hom_batch.lfs)
    good_ids = [i for i in range(n_items) if hom_batch.lfs[i] == true_lfs[i]]
    all_ids = list(range(n_items))

    for ids, key in ((good_ids, HOM), (all_ids, HOM2)):
        # Set membership keeps the group filter linear in the number of groups.
        allowed = set(ids)
        good_groups = [g for g in groups if all(gg in allowed for gg in g)]
        if not good_groups:
            # Nothing to compare: record "undefined" directly rather than
            # handing empty lists to evals.hom.
            logger.update(key, float('nan'))
            continue
        parents = [hom_reps[p] for p, _, _ in good_groups]
        children = [(hom_reps[c1] + hom_reps[c2])/2 for _, c1, c2 in good_groups]
        logger.update(key, evals.hom(parents, children, evals.cos_dist))

def train(dataset, model, n_epochs=100, n_steps=50, lr=1e-2):
    """Train ``model`` with Adam, logging one row of metrics per epoch.

    A row of validation metrics is printed before any training.  After each
    epoch the validation accuracy drives a ReduceLROnPlateau scheduler
    (``mode='max'``, ``factor=0.5``).

    Args:
        dataset: provides ``get_train_batch`` plus whatever ``validate`` and
            ``validate_hom`` need.
        model: module returning ``(loss, accuracy, predictions, reps)`` for a
            batch.
        n_epochs: number of epochs to run (default 100, as before).
        n_steps: optimizer steps per epoch (default 50, as before).
        lr: initial Adam learning rate (default 1e-2, as before).

    Raises:
        ValueError: if ``n_steps`` is smaller than 1, since the per-epoch
            averages would be undefined.
    """
    if n_steps < 1:
        raise ValueError('n_steps must be at least 1')

    opt = optim.Adam(model.parameters(), lr=lr)
    sched = opt_sched.ReduceLROnPlateau(opt, factor=0.5, verbose=True, mode='max')
    logger = Logger(LOG_KEYS, LOG_FMTS, width=10)
    logger.begin()

    # Baseline row for the untrained model.
    validate(dataset, model, logger)
    validate_hom(dataset, model, logger)
    logger.print()

    for epoch in range(n_epochs):
        trn_loss = 0
        trn_acc = 0
        for _ in range(n_steps):
            batch = dataset.get_train_batch(N_BATCH, N_MTRAIN)
            loss, acc, _, _ = model(batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # ravel() first: loss/acc may be 1-element or 0-dim depending on
            # the torch version, and [0] alone fails on 0-dim arrays.
            trn_loss += unwrap(loss).ravel()[0]
            trn_acc += unwrap(acc).ravel()[0]
        trn_loss /= n_steps
        trn_acc /= n_steps

        logger.update(EPOCH, epoch)
        logger.update(TRN_LOSS, trn_loss)
        logger.update(TRN_ACC, trn_acc)
        val_acc = validate(dataset, model, logger)
        validate_hom(dataset, model, logger)
        sched.step(val_acc)
        logger.print()

In [6]:
# Seed the random number generators this notebook imports directly, so that
# reruns start from the same initial weights and noise draws.
# NOTE(review): Dataset may draw from other random sources (for example the
# stdlib `random` module) -- confirm and seed those too if so.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = Dataset()
model = Model(dataset)
train(dataset, model)

|      epoch |   trn loss |    trn acc |    val acc | isom (r-t) | isom (r-e) | isom (t-e) |        hom |       hom2 |
|            |            |            |      0.504 |      0.006 |      0.092 |      0.005 |      0.983 |      0.986 |


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


|          0 |      0.535 |      0.685 |      0.698 |      0.149 |      0.910 |      0.133 |        nan |      0.378 |
|          1 |      0.508 |      0.700 |      0.701 |      0.141 |      0.929 |      0.117 |        nan |      0.506 |
|          2 |      0.482 |      0.719 |      0.717 |      0.049 |      0.862 |      0.018 |      0.391 |      0.362 |
|          3 |      0.444 |      0.747 |      0.764 |      0.068 |      0.842 |     -0.028 |      0.464 |      0.451 |
|          4 |      0.410 |      0.777 |      0.779 |      0.038 |      0.837 |     -0.016 |      0.357 |      0.337 |
|          5 |      0.380 |      0.809 |      0.815 |     -0.010 |      0.881 |     -0.046 |      0.278 |      0.322 |
|          6 |      0.352 |      0.831 |      0.836 |     -0.025 |      0.886 |     -0.032 |      0.266 |      0.280 |
|          7 |      0.332 |      0.844 |      0.850 |     -0.011 |      0.889 |     -0.056 |      0.246 |      0.265 |
|          8 |      0.301 |      0.865 |      0.

KeyboardInterrupt: 