In [1]:
import collections

import numpy as np
from chainer.datasets import TupleDataset


class DatasetParseError(Exception):
    """Raised when a line of the ranking data file holds an unacceptable value."""


# Backward-compatible alias for the original, misspelled class name.
DatsetParseError = DatasetParseError


def _parse_single(line):
    data = line.strip().split()[:48]
    rank = int(data[0])
    qid = int(data[1][4:])
    val = map(lambda x: float(x.split(':')[1]), data[2:])
    return qid, rank, val


def create_dataset(path, size=-1):
    """
    Create dataset from MQ2007 data.

    .. warning:: It will create dataset with label in range [0, 1, 2]
        It should be no problem for Permutation Probability Loss
        but do not plug in other loss function.

    Lines are grouped by query id. Each query becomes one example
    ``(vectors, scores, length)``:

    * ``vectors``: float32 ``(max_docs, n_features)``, zero-padded.
    * ``scores``: float32 ``(max_docs,)`` relevance labels, zero-padded.
    * ``length``: int32 number of real (unpadded) documents of the query.

    ``max_docs`` is the largest number of documents of any query in the file.

    Args:
        path (str): path to the text file (one document per line).
        size (int): if positive, return a random subset of at most ``size``
            queries. Otherwise return all of them.

    Returns:
        TupleDataset: ``(vectors_pad, scores_pad, length)``.

    Raises:
        DatasetParseError: a relevance label is not 0, 1 or 2.
        ValueError: the file contains no documents.
    """
    # Keep queries in first-seen order so the result is deterministic.
    data = collections.OrderedDict()
    with open(path, mode='r') as fin:
        # One document per line: "<rank> qid:<id> <idx>:<value> ...".
        for lineno, line in enumerate(fin, 1):
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            q, r, v = _parse_single(line)
            if r not in {0, 1, 2}:
                raise DatasetParseError(
                    "L%d: Score must be 0, 1 or 2, but found %d" %
                    (lineno, r)
                )
            ranks, feats = data.setdefault(q, ([], []))
            ranks.append(r)
            feats.append(list(v))
    if not data:
        raise ValueError("no documents found in %r" % (path,))

    scores = [np.array(ranks, dtype=np.float32) for ranks, _ in data.values()]
    vectors = [np.array(feats, dtype=np.float32) for _, feats in data.values()]

    n_queries = len(vectors)
    max_docs = max(len(s) for s in scores)
    n_features = vectors[-1].shape[-1]
    vectors_pad = np.zeros((n_queries, max_docs, n_features), dtype=np.float32)
    scores_pad = np.zeros((n_queries, max_docs), dtype=np.float32)
    length = np.empty(n_queries, dtype=np.int32)
    for i, (s, v) in enumerate(zip(scores, vectors)):
        vectors_pad[i, :len(v), :] = v
        scores_pad[i, :len(s)] = s
        length[i] = len(v)

    if size > 0:
        # Sample data AFTER all data has been loaded, because there might
        # be bias in data ordering.
        ind = np.random.permutation(n_queries)[:size]
        return TupleDataset(vectors_pad[ind], scores_pad[ind], length[ind])
    return TupleDataset(vectors_pad, scores_pad, length)

In [2]:
from __future__ import absolute_import, division, print_function

import chainer
import chainer.functions as F
import chainer.links as L
import chainer.initializers as I
import numpy as np


class ListNet(chainer.Chain):
    """Three-layer MLP that maps each document feature vector to a score.

    The input may have any number of leading axes. All of them are flattened
    before the MLP and restored afterwards, so the output shape is
    ``x.shape[:-1]`` (one scalar per document).

    Args:
        input_size (int): number of features per document.
        n_units (int): width of both hidden layers.
        dropout (float): dropout ratio applied to the input and to each hidden
            layer. A value that is not > 0 disables dropout.
    """

    def __init__(self, input_size, n_units, dropout):
        glorot = I.GlorotUniform
        super(ListNet, self).__init__(
            l1=L.Linear(input_size, n_units, initialW=glorot()),
            l2=L.Linear(n_units, n_units, initialW=glorot()),
            l3=L.Linear(n_units, 1, initialW=glorot(), nobias=True),
        )
        # Stored as a persistent value of the link rather than a plain attribute.
        self.add_persistent("_dropout", dropout)

    def _maybe_dropout(self, h, train):
        """Apply dropout to ``h`` only when a positive ratio is configured."""
        if not self._dropout > 0.:
            return h
        return F.dropout(h, self._dropout, train=train)

    def __call__(self, x, train=True):
        lead_shape = list(x.shape)[:-1]
        hidden = F.reshape(x, (np.prod(lead_shape), -1))
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(self._maybe_dropout(hidden, train)))
        scores = self.l3(self._maybe_dropout(hidden, train))
        return F.reshape(scores, lead_shape)

In [3]:
import logging

import chainer
import chainer.functions as F
import numpy as np


def mean_average_precision(probs, labels, length, th):
    """
    Calculate mean average precision.

    Label 1 and 2 are regarded "correct" as in the original evaluation
    script https://onedrive.live.com/?authkey=%21ACnoZZSZVfHPJd0&cid=8FEADC23D838BDA8&id=8FEADC23D838BDA8%21122&parId=8FEADC23D838BDA8%21107&o=OneUp

    For each query, documents are ranked by descending predicted score. The
    average precision is the mean of precision@k over the positions k (within
    the top ``th``) that hold a correct document. A query with no correct
    document in the top ``th`` contributes 0.

    Args:
        probs (numpy.ndarray): 2d array of predicted scores, one row per
            query, e.g. [[0.1, 0.8, 0.3, 0.1, 0.6, 0.6, 0.2], [0., ...]...].
            Only the first ``length[i]`` entries of row ``i`` are used.
        labels (numpy.ndarray): 2d array of ground-truth labels with the
            same layout as ``probs``. Each value must be 0, 1 or 2.
        length (sequence of int): number of real documents in each query.
        th (int): cutoff. Only the top ``th`` ranked documents are inspected.

    Returns:
        float: mean average precision over all queries (0.0 if there are none).
    """
    num_queries = len(probs)
    if num_queries == 0:
        return 0.0

    total = 0.0
    for q in range(num_queries):
        order = probs[q][:length[q]].argsort()[::-1]  # best score first
        candidates = labels[q, order]
        precisions = []
        num_correct = 0.
        for pos in range(min(th, len(candidates))):
            if candidates[pos] >= 1:
                num_correct += 1
                precisions.append(num_correct / (pos + 1))
        if precisions:
            total += sum(precisions) / len(precisions)
    return total / float(num_queries)


def logsumexp(x, mask, zero_pad, axis):
    """log(sum(exp(x))) along ``axis``, counting only entries where ``mask`` is True.

    Entries where ``mask`` is False are replaced by the matching ``zero_pad``
    entry before summing. Callers in this notebook pass an all-zero
    ``zero_pad``.
    NOTE(review): there is no max-shift, so very large ``x`` can overflow exp().
    """
    exps = F.exp(x)
    kept = F.where(mask, exps, zero_pad)
    total = F.sum(kept, axis=axis)
    return F.log(total)


def logsoftmax_no_mask(x, mask, zero_pad, axis):
    """Log-softmax of ``x`` along ``axis`` with a masked normaliser.

    The normaliser sums exp(x) only over positions where ``mask`` is True.
    The output at masked-out positions is NOT replaced (it is just
    ``x - normaliser``); use ``logsoftmax`` for a fully masked result.

    NOTE(review): no max-shift is applied (a global-max shift was previously
    left commented out here), so very large ``x`` can overflow exp().
    """
    x_logsumexp = logsumexp(x, mask, zero_pad, axis)
    # Re-insert the reduced axis so the normaliser broadcasts back onto x.
    # Using ``axis`` (instead of a hard-coded 1) keeps this correct for any axis.
    return x - F.broadcast_to(F.expand_dims(x_logsumexp, axis), x.shape)


def logsoftmax(x, mask, zero_pad, axis):
    """Masked log-softmax along ``axis``.

    Positions where ``mask`` is False take the matching ``zero_pad`` value
    instead of a log-probability.
    """
    log_probs = logsoftmax_no_mask(x, mask, zero_pad, axis)
    masked = F.where(mask, log_probs, zero_pad)
    return masked


def softmax(x, mask, zero_pad, axis):
    """Masked softmax along ``axis``.

    Computed as exp of the masked-normaliser log-softmax. Positions where
    ``mask`` is False take the matching ``zero_pad`` value.
    """
    log_probs = logsoftmax_no_mask(x, mask, zero_pad, axis)
    return F.where(mask, F.exp(log_probs), zero_pad)



def permutation_probability_loss(x, t, length):
    """Top-one permutation probability (ListNet, k=1) loss.

    Both the predicted scores ``x`` and the target scores ``t`` are turned
    into top-one probability distributions by a masked softmax over each row.
    The loss is sum_j p_t(j) * (log p_t(j) - log p_x(j)), summed over all rows
    and divided by the number of rows. That is the KL divergence
    KL(p_t || p_x), which differs from the cross entropy only by the entropy
    of p_t. That entropy does not depend on ``x``, so the gradients w.r.t.
    ``x`` are the same.

    Args:
        x (Variable): 2d array of unnormalised scores. Axis 0 is the batch
            (queries), axis 1 the documents of each query.
        t (Variable): 2d float32 array of ground-truth scores, same shape as x.
        length (numpy.ndarray): number of valid documents for each row.
            Columns at or beyond ``length[i]`` are padding and are ignored.

    Returns:
        Variable: A variable holding a scalar array of the loss.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    column_ids = np.arange(n_cols).reshape(1, -1)
    valid = np.tile(column_ids, (n_rows, 1)) < length.reshape(-1, 1)
    mask = chainer.Variable(valid)
    padding = chainer.Variable(np.zeros(x.shape, dtype=x.dtype))

    log_p_x = logsoftmax(x, mask, padding, axis=1)
    log_p_t = logsoftmax(t, mask, padding, axis=1)

    # At padded positions both log-probs are 0, so each term is exp(0) * 0 = 0
    # and padding contributes nothing.
    p_t = F.exp(log_p_t)
    per_item = p_t * log_p_t - p_t * log_p_x
    return F.sum(per_item) / float(n_rows)


def clip_data(x, l):
    """Trim axis 1 of ``x`` to the longest length in ``l``.

    Args:
        x: array whose second axis is a padded per-item axis.
        l: per-row valid lengths.

    Returns:
        ``x`` restricted to its first ``max(l)`` columns.
    """
    longest = max(l)
    return x[:, slice(None, longest)]


def _run_batch(model, optimizer, batch, device, train):
    """Forward one minibatch and, when training, take one optimizer step.

    Args:
        model: scoring model, called as ``model(x, train=train)``.
        optimizer: optimizer to step, or None for evaluation. It must be
            non-None exactly when ``train`` is True.
        batch: list of ``(vectors, scores, length)`` examples.
        device: device id forwarded to ``concat_examples``.
        train (bool): training-mode flag passed to the model.

    Returns:
        tuple(float, float): the batch loss and the batch mean average precision.
    """
    assert train == (optimizer is not None)
    model.cleargrads()

    xs, ts, lengths = chainer.dataset.concat_examples(batch, device=device)
    # Drop trailing columns that are padding for every query in this batch.
    xs, ts = clip_data(xs, lengths), clip_data(ts, lengths)

    scores = model(chainer.Variable(xs), train=train)
    loss = permutation_probability_loss(scores, chainer.Variable(ts), lengths)
    # A cutoff of 100000 effectively inspects every document of a query.
    acc = mean_average_precision(scores.data, ts, lengths, 100000)

    if optimizer is not None:
        loss.backward()
        optimizer.update()
    return float(loss.data), acc


def forward_pred(model, dataset, device=None, batch_size=4):
    """Evaluate ``model`` on ``dataset`` without updating any parameters.

    Args:
        model: scoring model, run with ``train=False``.
        dataset: dataset of ``(vectors, scores, length)`` examples.
        device: device id forwarded to ``concat_examples`` (None = CPU).
        batch_size (int): evaluation minibatch size (default 4, as before).

    Returns:
        tuple(float, float): per-example average loss and average mean
        average precision. Returns ``(nan, nan)`` for an empty dataset.
    """
    n_examples = len(dataset)
    if n_examples == 0:
        nan = float('nan')
        return nan, nan

    loss_sum = 0.
    acc_sum = 0.
    iterator = chainer.iterators.SerialIterator(dataset, batch_size=batch_size,
                                                repeat=False, shuffle=False)
    for batch in iterator:
        l, a = _run_batch(model, None, batch, device, False)
        # Batch values are batch means. Weighting by batch size yields a
        # per-example mean even when the last batch is smaller.
        loss_sum += l * len(batch)
        acc_sum += a * len(batch)
    return loss_sum / float(n_examples), acc_sum / float(n_examples)


def train(model, optimizer, train_itr, n_epoch, dev=None, device=None,
          tmp_dir='tmp.model', lr_decay=0.995):
    """Train ``model`` and finally restore the parameters with the best loss.

    After every completed epoch:

    * the model is evaluated on ``dev`` (if given) and a report line is logged;
    * the parameters are saved to ``tmp_dir`` when the selection loss improves;
    * the optimizer's learning rate ``alpha`` is multiplied by ``lr_decay``.

    The selection loss is the dev loss, or that epoch's training loss when
    ``dev`` is None. At the end, the best saved parameters are loaded back
    into ``model``. If nothing was saved, the current parameters are kept.

    Args:
        model: chainer model being trained.
        optimizer: optimizer set up on ``model``. It must expose ``alpha``.
        train_itr: iterator over training minibatches (a ``SerialIterator``).
        n_epoch (int): number of epochs to train.
        dev: optional dev dataset used for early stopping.
        device: device id forwarded to the batch helpers.
        tmp_dir (str): file path where the best model is saved (npz).
        lr_decay (float): per-epoch multiplicative learning-rate decay.
    """
    report_tmpl = "[{:>3d}] T/loss={:0.6f} T/acc={:0.6f} D/loss={:0.6f} D/acc={:0.6f} lr={:0.6f}"
    n_train = float(len(train_itr.dataset))
    loss = 0.
    acc = 0.
    min_loss = float('inf')
    min_epoch = None

    if n_epoch > 0:
        for batch in train_itr:
            # Train on the batch first: SerialIterator sets ``is_new_epoch`` on
            # the batch that *closes* an epoch. That batch therefore belongs to
            # the epoch being finished and must be counted before reporting.
            l, a = _run_batch(model, optimizer, batch, device, True)
            loss += l * len(batch)
            acc += a * len(batch)
            if not train_itr.is_new_epoch:
                continue

            # ``epoch`` has already been incremented, so the finished one is epoch - 1.
            finished = train_itr.epoch - 1
            loss = loss / n_train
            acc = acc / n_train
            if dev is not None:
                loss_dev, acc_dev = forward_pred(model, dev, device=device)
                criterion = loss_dev
            else:
                loss_dev = acc_dev = float('nan')
                criterion = loss
            logging.info(report_tmpl.format(
                finished, loss, acc, loss_dev, acc_dev, optimizer.alpha))
            if criterion < min_loss:
                min_loss = criterion
                min_epoch = finished
                chainer.serializers.save_npz(tmp_dir, model)

            loss = 0.
            acc = 0.
            optimizer.alpha *= lr_decay
            if train_itr.epoch >= n_epoch:
                break

    if min_epoch is None:
        # Loading here would pick up a stale or missing file from another run.
        logging.warning('no model was saved; keeping the current parameters')
        return
    logging.info('loading early stopped-model at epoch {}'.format(min_epoch))
    chainer.serializers.load_npz(tmp_dir, model)

In [None]:
import argparse
import logging

import chainer

# The dataset loader, model and training helpers are defined in the cells
# above. In the package layout they were imported like this instead:
#from listnet import dataset, training
#from listnet.model import ListNet

# Show the INFO-level progress/report lines emitted during training.
logging.basicConfig(level=logging.INFO)

def run(args):
    """Load train/dev/test splits, train ListNet and log the test metrics.

    Args:
        args: namespace with ``train``, ``dev`` and ``test`` file paths and
            an optional ``gpu`` id.
    """
    # Hyper-parameters of the logged experiment.
    n_units = 200
    dropout = 0.0
    alpha = 0.0007
    weight_decay = 0.0005
    n_epoch = 1000

    gpu = getattr(args, 'gpu', None)
    if gpu is not None:
        # Everything below runs with device=None. The loss also builds its
        # mask/padding with numpy, so GPU use presumably would not work as-is.
        logging.warning("--gpu=%s is ignored; running on CPU", gpu)

    logging.info("Loading dataset")

    trains = create_dataset(args.train)
    logging.info("loaded {} sets for training".format(len(trains)))

    dev = create_dataset(args.dev)
    logging.info("loaded {} sets for dev".format(len(dev)))

    test = create_dataset(args.test)
    logging.info("loaded {} sets for test".format(len(test)))

    # trains[0][0] is the padded (max_docs, n_features) matrix of one query.
    listnet = ListNet(trains[0][0].shape[1], n_units, dropout)
    optimizer = chainer.optimizers.Adam(alpha=alpha)
    optimizer.setup(listnet)
    optimizer.add_hook(chainer.optimizer.WeightDecay(weight_decay))
    # Optional: optimizer.add_hook(chainer.optimizer.GradientClipping(5.))

    train_itr = chainer.iterators.SerialIterator(trains, batch_size=1)
    train(listnet, optimizer, train_itr, n_epoch, dev=dev, device=None)
    loss, acc = forward_pred(listnet, test, device=None)
    logging.info("Test => loss={:0.6f} acc={:0.6f}".format(loss, acc))

if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--train', required=True, type=str,
                   help='MQ2007/MQ2008-style ranking train .txt file path')
    p.add_argument('--dev', required=True, type=str,
                   help='MQ2007/MQ2008-style ranking dev .txt file path')
    p.add_argument('--test', required=True, type=str,
                   help='MQ2007/MQ2008-style ranking test .txt file path')

    # optional
    p.add_argument('-g', '--gpu', type=int, default=None,
                   help="GPU number (currently ignored; training runs on CPU)")
    # Arguments are passed explicitly, presumably because inside a notebook
    # sys.argv holds the kernel's own arguments rather than ours.
    args = p.parse_args(["--train", "MQ2008/Fold1/train.txt",
                         "--dev", "MQ2008/Fold1/vali.txt",
                         "--test", "MQ2008/Fold1/test.txt"])

    run(args)

INFO:root:Loading dataset
INFO:root:loaded 471 sets for training
INFO:root:loaded 157 sets for dev
INFO:root:loaded 156 sets for test
INFO:root:[  0] T/loss=0.142821 T/acc=0.451238 D/loss=0.148315 D/acc=0.487140 lr=0.000700
INFO:root:[  1] T/loss=0.137200 T/acc=0.468169 D/loss=0.147032 D/acc=0.496971 lr=0.000696
INFO:root:[  2] T/loss=0.136525 T/acc=0.475992 D/loss=0.146266 D/acc=0.506027 lr=0.000693
INFO:root:[  3] T/loss=0.135640 T/acc=0.477695 D/loss=0.147637 D/acc=0.515359 lr=0.000690
INFO:root:[  4] T/loss=0.135373 T/acc=0.482647 D/loss=0.148527 D/acc=0.522586 lr=0.000686
INFO:root:[  5] T/loss=0.134315 T/acc=0.485358 D/loss=0.143353 D/acc=0.516742 lr=0.000683
INFO:root:[  6] T/loss=0.132754 T/acc=0.480707 D/loss=0.149556 D/acc=0.501629 lr=0.000679
INFO:root:[  7] T/loss=0.132707 T/acc=0.484308 D/loss=0.146763 D/acc=0.493562 lr=0.000676
INFO:root:[  8] T/loss=0.133133 T/acc=0.483089 D/loss=0.147581 D/acc=0.508903 lr=0.000672
INFO:root:[  9] T/loss=0.132194 T/acc=0.496119 D/loss=0.

INFO:root:[ 90] T/loss=0.090060 T/acc=0.558016 D/loss=0.173404 D/acc=0.484554 lr=0.000446
INFO:root:[ 91] T/loss=0.090318 T/acc=0.552522 D/loss=0.163917 D/acc=0.483026 lr=0.000444
INFO:root:[ 92] T/loss=0.089397 T/acc=0.554545 D/loss=0.164941 D/acc=0.494598 lr=0.000441
INFO:root:[ 93] T/loss=0.089295 T/acc=0.558467 D/loss=0.164723 D/acc=0.479170 lr=0.000439
INFO:root:[ 94] T/loss=0.089214 T/acc=0.558295 D/loss=0.167384 D/acc=0.484254 lr=0.000437
INFO:root:[ 95] T/loss=0.087805 T/acc=0.567494 D/loss=0.163780 D/acc=0.492810 lr=0.000435
INFO:root:[ 96] T/loss=0.088584 T/acc=0.561060 D/loss=0.178974 D/acc=0.493628 lr=0.000433
INFO:root:[ 97] T/loss=0.087695 T/acc=0.557601 D/loss=0.173326 D/acc=0.499136 lr=0.000430
INFO:root:[ 98] T/loss=0.087851 T/acc=0.560427 D/loss=0.167595 D/acc=0.489561 lr=0.000428
INFO:root:[ 99] T/loss=0.088234 T/acc=0.560517 D/loss=0.161404 D/acc=0.509424 lr=0.000426
INFO:root:[100] T/loss=0.086817 T/acc=0.565704 D/loss=0.163929 D/acc=0.482426 lr=0.000424
INFO:root:

INFO:root:[182] T/loss=0.070842 T/acc=0.585344 D/loss=0.186485 D/acc=0.474538 lr=0.000281
INFO:root:[183] T/loss=0.071069 T/acc=0.587732 D/loss=0.179803 D/acc=0.474761 lr=0.000280
INFO:root:[184] T/loss=0.071089 T/acc=0.587563 D/loss=0.179676 D/acc=0.471400 lr=0.000278
INFO:root:[185] T/loss=0.070681 T/acc=0.582908 D/loss=0.184562 D/acc=0.478597 lr=0.000277
INFO:root:[186] T/loss=0.070313 T/acc=0.586684 D/loss=0.185755 D/acc=0.476697 lr=0.000276
INFO:root:[187] T/loss=0.071008 T/acc=0.587157 D/loss=0.181622 D/acc=0.479383 lr=0.000274
INFO:root:[188] T/loss=0.070517 T/acc=0.587578 D/loss=0.183738 D/acc=0.481029 lr=0.000273
INFO:root:[189] T/loss=0.070484 T/acc=0.585657 D/loss=0.186160 D/acc=0.477504 lr=0.000271
INFO:root:[190] T/loss=0.070215 T/acc=0.589141 D/loss=0.180322 D/acc=0.484693 lr=0.000270
INFO:root:[191] T/loss=0.070575 T/acc=0.588895 D/loss=0.186704 D/acc=0.475733 lr=0.000269
INFO:root:[192] T/loss=0.069568 T/acc=0.585129 D/loss=0.177348 D/acc=0.475296 lr=0.000267
INFO:root: