In [1]:
%load_ext autoreload
%autoreload 2

from cls2_data.data import Dataset
import evals
from util import Logger

import editdistance
import numpy as np
import sexpdata
import torch
from torch import nn, optim
from torch.optim import lr_scheduler as opt_sched
from torch.autograd import Variable

In [2]:
# Size hyper-parameters for the notebook.
# NOTE(review): N_EMBED and N_HIDDEN are not referenced by any of the visible
# cells (Model hard-codes its own layer widths) — confirm whether they are
# still needed.
N_EMBED = 128
N_HIDDEN = 256
# Batch size handed to dataset.get_train_batch() inside train().
N_BATCH = 128

In [3]:
def unwrap(var):
    """Convert a torch tensor / autograd ``Variable`` into a NumPy array.

    The value is taken from ``var.data`` (so no autograd graph is involved)
    and moved to the CPU before conversion, which makes the helper usable for
    tensors living on another device as well.

    Scalar results are returned as a one-element 1-D array.  Callers in this
    notebook read scalars with ``unwrap(x)[0]`` / ``x, = unwrap(x)``; PyTorch
    versions that produce 0-dim tensors for reductions such as ``.mean()``
    would otherwise hand back a 0-d array on which both idioms raise.

    Args:
        var: tensor-like object exposing ``.data``.

    Returns:
        ``numpy.ndarray`` with at least one dimension; arrays that already
        have one or more dimensions keep their shape.
    """
    return np.atleast_1d(var.data.cpu().numpy())

In [4]:
class Model(nn.Module):
    """Few-shot binary classifier built from a shared conv/FC encoder.

    Every example image in ``batch.feats_in`` is encoded, the encodings of one
    batch entry are summed over the example axis and pushed through a
    ``Linear`` + ``Tanh`` head to give a 64-d "predictor" vector.  The query
    image in ``batch.feats_out`` is encoded with the same conv/FC stack, and
    the dot product of predictor and query encoding is used as the logit for
    ``BCEWithLogitsLoss`` against ``batch.label_out``.
    """

    # Flattened size of the conv stack output (16 channels of 5x5 maps).
    _N_CONV_FLAT = 16 * 5 * 5
    # Width of the encoder output and of the predictor vector.
    _N_REP = 64

    def __init__(self, dataset):
        """Create the layers.  ``dataset`` is accepted but not used here."""
        super().__init__()
        conv_layers = [
            nn.Conv2d(3, 6, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self._conv_part = nn.Sequential(*conv_layers)

        fc_layers = [
            nn.Linear(self._N_CONV_FLAT, 128),
            nn.ReLU(),
            nn.Linear(128, self._N_REP),
            nn.ReLU(),
        ]
        self._fc_part = nn.Sequential(*fc_layers)

        head_layers = [
            nn.Linear(self._N_REP, self._N_REP),
            nn.Tanh(),
        ]
        self._pred_part = nn.Sequential(*head_layers)

        self._loss = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        """Score one batch.

        Args:
            batch: object with ``feats_in`` (5-D: batch, examples, channels and
                two spatial axes), ``feats_out`` (fed to the conv stack as is
                and flattened to one row per batch entry) and ``label_out``
                (targets compared against the thresholded scores).

        Returns:
            Tuple ``(loss, accuracy, labels, predictor)``: the BCE-with-logits
            loss, the mean agreement between ``labels`` and
            ``batch.label_out``, the hard 0/1 predictions (``score > 0``) and
            the per-entry predictor vectors.
        """
        examples = batch.feats_in
        n_batch, n_ex, n_chan, dim_a, dim_b = examples.shape
        n_flat = n_batch * n_ex

        # Encode all example images at once by folding the example axis into
        # the batch axis.
        example_maps = self._conv_part(
            examples.view(n_flat, n_chan, dim_a, dim_b))
        example_reps = self._fc_part(
            example_maps.view(n_flat, self._N_CONV_FLAT))

        # Sum the example encodings of each batch entry, then apply the head.
        pooled = example_reps.view(n_batch, n_ex, self._N_REP).sum(dim=1)
        predictor = self._pred_part(pooled)

        # Encode the query image with the same conv/FC stack.
        query_maps = self._conv_part(batch.feats_out)
        rep_out = self._fc_part(query_maps.view(n_batch, self._N_CONV_FLAT))

        # Dot product acts as the logit; positive logit means label 1.
        score = (predictor * rep_out).sum(dim=1)
        labels = (score > 0).float()
        loss = self._loss(score, batch.label_out)
        accuracy = (labels == batch.label_out).float().mean()

        return loss, accuracy, labels, predictor

In [11]:
def eval_isom_tree(reps, lfs):
    """Run ``evals.isom`` on representations and logical forms.

    ``evals.cos_dist`` and ``evals.tree_dist`` are passed as the third and
    fourth arguments.
    NOTE(review): presumably the first distance applies to ``reps`` and the
    second to ``lfs`` — confirm against ``evals.isom``.
    """
    first_dist = evals.cos_dist
    second_dist = evals.tree_dist
    return evals.isom(reps, lfs, first_dist, second_dist)

def eval_isom_ext(reps, exts):
    """Run ``evals.isom`` on representations and extensions.

    ``evals.cos_dist`` and ``evals.l1_dist`` are passed as the third and
    fourth arguments.
    NOTE(review): presumably the first distance applies to ``reps`` and the
    second to ``exts`` — confirm against ``evals.isom``.
    """
    first_dist = evals.cos_dist
    second_dist = evals.l1_dist
    return evals.isom(reps, exts, first_dist, second_dist)

def eval_hom(model, reps1, reps2):
    """Forward two sets of representations to ``evals.hom``.

    ``model`` is accepted but not used.  The third argument of ``evals.hom``
    is always ``None`` here; an earlier draft passed a squared-Euclidean
    function, ``lambda x, y: ((x - y) ** 2).sum()``, in that slot.
    NOTE(review): what ``evals.hom`` does with ``None`` is not visible from
    this notebook — confirm.
    """
    third_arg = None
    return evals.hom(reps1, reps2, third_arg)

def info(reps):
    """Mean per-dimension entropy (in nats) of bucketed representations.

    Every coordinate value ``v`` is assigned to one of 10 buckets with index
    ``clip(5 + trunc(v * 10), 0, 9)``; for each dimension the bucket counts
    (plus ``1e-7`` smoothing) are normalised into a distribution, and the
    entropies of these distributions are averaged over dimensions.

    The number of dimensions is taken from the input instead of being fixed
    at 64: wider inputs no longer raise ``IndexError`` and narrower inputs no
    longer add never-filled (uniform, maximum-entropy) rows to the mean.
    For 64-wide input the bucket counts are built exactly as before.

    Args:
        reps: 2-D collection of shape (n_samples, n_dims).  May be a NumPy
            array, a nested sequence, a torch tensor / ``Variable`` (also one
            that tracks gradients) or a sequence of 1-D tensors.  All rows
            must have the same length.

    Returns:
        The mean entropy as a NumPy float.  Input without any sample falls
        back to 64 empty dimensions unless a 2-D array fixes the width,
        which yields ``log(10)``.

    Raises:
        ValueError: if non-empty input is not 2-D.
    """
    n_buckets = 10
    default_dims = 64  # width used when the input does not reveal one

    def to_numpy(obj):
        # torch tensors / Variables expose ``detach``; arrays and lists don't.
        return obj.data.cpu().numpy() if hasattr(obj, 'detach') else obj

    reps = to_numpy(reps)
    if not isinstance(reps, np.ndarray):
        reps = np.asarray([to_numpy(rep) for rep in reps])

    if reps.ndim == 2:
        n_dims = reps.shape[1]
    elif reps.size == 0:
        n_dims = default_dims
    else:
        raise ValueError(
            'info() expects 2-D input (n_samples, n_dims), got shape %s'
            % (reps.shape,))

    buckets = np.zeros((n_dims, n_buckets))
    if reps.size:
        # Truncation toward zero (not floor) mirrors ``int(v * 10)``: bucket 5
        # therefore spans (-0.1, 0.1) and values beyond +/-0.5 saturate in the
        # outermost buckets.  Kept as is so numbers stay comparable with
        # earlier runs.
        index = np.clip(5 + np.trunc(reps * 10).astype(int), 0, n_buckets - 1)
        for dim in range(n_dims):
            buckets[dim] = np.bincount(index[:, dim], minlength=n_buckets)

    buckets += 1e-7  # avoids log(0) for empty buckets
    probs = buckets / buckets.sum(axis=1, keepdims=True)
    entropies = -(probs * np.log(probs)).sum(axis=1)
    return entropies.mean()

# Column names used as keys for logger.update(); they also appear as the
# header row of the printed training table.
EPOCH = 'epoch'
TRN_LOSS = 'trn loss'
TRN_ACC = 'trn acc'
VAL_ACC = 'val acc'
CVAL_ACC = 'cval acc'
INFO_TX = 'I(T;X)'
# NOTE(review): the three ISOM_* keys are defined but not listed in LOG_KEYS
# and are not updated anywhere in the visible cells.
ISOM_TREE = 'isom (r-t)'
ISOM_EXT = 'isom (r-e)'
ISOM_CHK = 'isom (t-e)'
HOM = 'hom'
CHOM = 'c_hom'
# Columns handed to Logger in train(), with one entry of LOG_FMTS per key
# (kept column-aligned so the pairing is easy to read).
# NOTE(review): the LOG_FMTS entries look like format specs — confirm against
# util.Logger.
LOG_KEYS = [EPOCH, TRN_LOSS, TRN_ACC, VAL_ACC, CVAL_ACC, HOM,   CHOM,  INFO_TX]
LOG_FMTS = ['d',   '.3f',    '.3f',   '.3f',   '.3f',    '.3f', '.3f', '.3f']

def validate(dataset, model, logger):
    """Evaluate ``model`` on the held-out batches and record the metrics.

    Logs ``VAL_ACC`` and ``CVAL_ACC`` (accuracy on the batches returned by
    ``dataset.get_val_batch()`` / ``dataset.get_cval_batch()``), ``HOM`` and
    ``CHOM`` (mean of ``evals.comp_eval`` for the val / cval representations
    against the representations of ``dataset.get_prim_batch()``, composing
    with addition and comparing with ``evals.cos_dist``) and ``INFO_TX``
    (``info`` of the val representations).

    Args:
        dataset: provides ``get_val_batch``, ``get_cval_batch`` and
            ``get_prim_batch``; the returned batches expose ``lf``.
        model: callable returning ``(loss, accuracy, labels, reps)``.
        logger: object with ``update(key, value)``.

    Returns:
        The validation accuracy as a scalar.
    """
    def scalar(value):
        # Works whether unwrap() yields a 0-d or a one-element array, so the
        # code does not depend on how the torch version shapes reductions.
        return np.asarray(unwrap(value)).ravel()[0]

    def compose(x, y):
        # Composition operator handed to evals.comp_eval.
        return x + y

    val_batch = dataset.get_val_batch()
    _, val_acc, _, val_reps = model(val_batch)
    val_acc = scalar(val_acc)
    logger.update(VAL_ACC, val_acc)

    cval_batch = dataset.get_cval_batch()
    _, cval_acc, _, cval_reps = model(cval_batch)
    logger.update(CVAL_ACC, scalar(cval_acc))

    prim_batch = dataset.get_prim_batch()
    _, _, _, prim_reps = model(prim_batch)

    # Convert once; every consumer below receives plain NumPy arrays.
    prim_np = prim_reps.data.numpy()
    val_np = val_reps.data.numpy()
    cval_np = cval_reps.data.numpy()

    for key, reps_np, batch in ((HOM, val_np, val_batch),
                                (CHOM, cval_np, cval_batch)):
        comp = evals.comp_eval(
            prim_np, prim_batch.lf,
            reps_np, batch.lf,
            compose,
            evals.cos_dist)
        logger.update(key, np.mean(comp))

    logger.update(INFO_TX, info(val_np))

    return val_acc

def train(dataset, model, n_epochs=100, n_batches=100, lr=1e-3):
    """Train ``model`` with Adam and print one log row per epoch.

    A validation pass is logged before any training.  Each epoch then runs
    ``n_batches`` optimisation steps on batches of ``N_BATCH`` items from
    ``dataset.get_train_batch``, logs the mean training loss / accuracy and
    the validation metrics, and steps a ``ReduceLROnPlateau`` scheduler
    (``mode='max'``, ``factor=0.5``) on the validation accuracy.

    Args:
        dataset: data source passed on to ``validate`` and used for batches.
        model: module returning ``(loss, accuracy, labels, reps)`` for a batch.
        n_epochs: number of epochs (default 100, the previous fixed value).
        n_batches: optimisation steps per epoch (default 100, as before).
        lr: initial Adam learning rate (default 1e-3, as before).

    Returns:
        None; progress is reported through the logger.
    """
    def scalar(value):
        # Works whether unwrap() yields a 0-d or a one-element array, so the
        # code does not depend on how the torch version shapes reductions.
        return float(np.asarray(unwrap(value)).ravel()[0])

    opt = optim.Adam(model.parameters(), lr=lr)
    sched = opt_sched.ReduceLROnPlateau(opt, factor=0.5, verbose=True, mode='max')
    logger = Logger(LOG_KEYS, LOG_FMTS, width=10)
    logger.begin()

    # Baseline row for the untrained model.
    validate(dataset, model, logger)
    logger.print()

    for epoch in range(n_epochs):
        trn_loss = 0.0
        trn_acc = 0.0
        for _ in range(n_batches):
            batch = dataset.get_train_batch(N_BATCH)
            loss, acc, _, _ = model(batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            trn_loss += scalar(loss)
            trn_acc += scalar(acc)
        # max() keeps an n_batches of 0 from dividing by zero.
        steps = max(n_batches, 1)
        trn_loss /= steps
        trn_acc /= steps

        logger.update(EPOCH, epoch)
        logger.update(TRN_LOSS, trn_loss)
        logger.update(TRN_ACC, trn_acc)
        val_acc = validate(dataset, model, logger)
        sched.step(val_acc)
        logger.print()

In [None]:
# Build the dataset only when the kernel does not already hold one: the cell
# then works after "Restart Kernel -> Run All" (previously `dataset` was
# undefined on a fresh kernel) while an already-loaded dataset is still reused
# when the cell is re-run.
if 'dataset' not in globals():
    dataset = Dataset()
model = Model(dataset)
train(dataset, model)

|      epoch |   trn loss |    trn acc |    val acc |   cval acc |        hom |      c_hom |     I(T;X) |
|            |            |            |      0.480 |      0.533 |      0.029 |      0.023 |      0.263 |
|          0 |      0.663 |      0.592 |      0.672 |      0.709 |      0.355 |      0.276 |      1.398 |
|          1 |      0.644 |      0.639 |      0.682 |      0.713 |      0.357 |      0.293 |      1.415 |
|          2 |      0.602 |      0.681 |      0.726 |      0.711 |      0.449 |      0.418 |      1.681 |
|          3 |      0.590 |      0.689 |      0.704 |      0.669 |      0.410 |      0.406 |      1.666 |
|          4 |      0.586 |      0.691 |      0.718 |      0.699 |      0.419 |      0.382 |      1.780 |
|          5 |      0.569 |      0.705 |      0.704 |      0.722 |      0.433 |      0.372 |      1.775 |
|          6 |      0.561 |      0.709 |      0.720 |      0.720 |      0.402 |      0.358 |      1.854 |
|          7 |      0.560 |      0.716 |      