In [1]:
%load_ext autoreload
%autoreload 2

from rep_data.data import Dataset
import evals
from util import Logger

import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler as opt_sched

In [2]:
N_BATCH = 256
N_EMBED = 512

In [3]:
def unwrap(var):
    return var.data.numpy()

In [4]:
class Model(nn.Module):
    def __init__(self, dataset):
        super().__init__()
        self._emb = nn.Embedding(dataset.n_vocab, N_EMBED)
        self._pred = nn.Linear(N_EMBED, dataset.n_vocab)
        self._loss = nn.CrossEntropyLoss()
    
    def forward(self, batch):
        emb = self._emb(batch.ctx)
        hid = emb.sum(dim=1)
        logits = self._pred(hid)
        loss = self._loss(logits, batch.tgt)
        return loss
    
    def represent(self, indices):
        return self._pred.weight[indices, :]

In [10]:
EPOCH = 'epoch'
TRN_LOSS = 'trn loss'
ISOM = 'isom'
HOM = 'hom'
LOG_KEYS = [EPOCH, TRN_LOSS, ISOM,  HOM]
LOG_FMTS = ['d',   '.3f',    '.3f', '.3f']
        
def validate(dataset, model, logger):
    comp_batch = dataset.get_comp_batch()
    reps_uni1 = unwrap(model.represent(comp_batch.uni1))
    reps_uni2 = unwrap(model.represent(comp_batch.uni2))
    reps_uni = np.concatenate((reps_uni1, reps_uni2))
    reps_bi = unwrap(model.represent(comp_batch.bi))
    exprs_uni = comp_batch.uni1 + comp_batch.uni2
    exprs_bi = list(zip(comp_batch.uni1, comp_batch.uni2))
    comp = evals.comp_eval(reps_uni, exprs_uni, reps_bi, exprs_bi, lambda x, y: x + y, evals.cos_dist)
    cstr = ['%.3f' % n for n in sorted(comp)]
    logger.update(HOM, np.mean(comp))

def train(dataset, model):
    opt = optim.Adam(model.parameters(), lr=1e-3)
    sched = opt_sched.ReduceLROnPlateau(opt, factor=0.5, verbose=True, mode='max')
    logger = Logger(LOG_KEYS, LOG_FMTS)
    logger.begin()
    validate(dataset, model, logger)
    logger.print()
    
    for i in range(10):
        trn_loss = 0
        for j in range(200):
            batch = dataset.get_batch(N_BATCH)
            loss = model(batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            trn_loss += unwrap(loss)[0]
        trn_loss /= 100
        
        logger.update(EPOCH, i)
        logger.update(TRN_LOSS, trn_loss)
        validate(dataset, model, logger)
        #sched.step(val_acc)
        logger.print()

In [12]:
for ctx in [1, 3, 5, 7]:
    print('CTX %d' % ctx)
    dataset = Dataset(ctx)
    model = Model(dataset)
    train(dataset, model)

CTX 1
|        epoch |     trn loss |         isom |          hom |
|              |              |              |        1.004 |
|            0 |       14.575 |              |        0.908 |
|            1 |       13.307 |              |        0.884 |
|            2 |       12.796 |              |        0.869 |
|            3 |       12.534 |              |        0.851 |
|            4 |       12.291 |              |        0.840 |
|            5 |       12.049 |              |        0.829 |
|            6 |       11.891 |              |        0.821 |
|            7 |       11.742 |              |        0.812 |
|            8 |       11.656 |              |        0.807 |
|            9 |       11.521 |              |        0.803 |
CTX 3
|        epoch |     trn loss |         isom |          hom |
|              |              |              |        1.003 |
|            0 |       16.715 |              |        0.915 |
|            1 |       14.676 |              |        0.86