In [7]:
%load_ext autoreload
%autoreload 2

from data import Dataset
import evals

import editdistance
import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler as opt_sched

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
N_BATCH = 256
N_EMBED = 512

In [9]:
def unwrap(var):
    return var.data.numpy()

In [10]:
class Model(nn.Module):
    def __init__(self, dataset):
        super().__init__()
        self._emb = nn.Embedding(dataset.n_vocab, N_EMBED)
        self._pred = nn.Linear(N_EMBED, dataset.n_vocab)
        self._loss = nn.CrossEntropyLoss()
    
    def forward(self, batch):
        emb = self._emb(batch.ctx)
        hid = emb.sum(dim=1)
        logits = self._pred(hid)
        loss = self._loss(logits, batch.tgt)
        return loss
    
    def represent(self, indices):
        return self._pred.weight[indices, :]

In [11]:
def l2dist(x, y):
    return ((x-y)**2).sum()

def cos_dist(x, y):
    return 1 - ((x/np.linalg.norm(x) * y/np.linalg.norm(y))).sum()

def eval_metric(reps, words):
    sim_rep = cos_dist
    #sim_lf = editdistance.eval
    sim_word = lambda x, y: 2 - (len(set(x.split('__')) & set(y.split('__'))))
    return evals.metric(words, reps, sim_word, sim_rep)

def eval_hom(preds, trues):
    return evals.hom(preds, trues, cos_dist)

class Logger(object):
    EPOCH = 'epoch'
    TRN_LOSS = 'trn loss'
    METRIC = 'metric'
    HOM = 'hom'
    KEYS = [EPOCH, TRN_LOSS, METRIC, HOM]
    
    def __init__(self):
        self.data = {}
    
    def begin(self):
        print('| ' + ' | '.join('%12s' % k for k in self.KEYS) + ' |')
        
    def update(self, key, value):
        assert key not in self.data
        self.data[key] = value
        
    def print(self):
        print('| ' + ' | '.join('%12.3f' % self.data[k] for k in self.KEYS) + ' |')
        self.data.clear()
        
def validate(dataset, model, logger):
    comp_batch = dataset.get_comp_batch()
    reps_bi = unwrap(model.represent(comp_batch.bi))
    reps_pred = unwrap(model.represent(comp_batch.uni1) + model.represent(comp_batch.uni2))
    words_bi = [dataset.unencode(i) for i in comp_batch.bi]
    logger.update(Logger.METRIC, eval_metric(reps_bi, words_bi))
    logger.update(Logger.HOM, eval_hom(reps_pred, reps_bi))

def train(dataset, model):
    opt = optim.Adam(model.parameters(), lr=1e-3)
    sched = opt_sched.ReduceLROnPlateau(opt, factor=0.5, verbose=True, mode='max')
    logger = Logger()
    logger.begin()
    
    for i in range(100):
        trn_loss = 0
        for j in range(10):
            batch = dataset.get_batch(N_BATCH)
            loss = model(batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            trn_loss += unwrap(loss)[0]
        trn_loss /= 100
        
        logger.update(logger.EPOCH, i)
        logger.update(logger.TRN_LOSS, trn_loss)
        validate(dataset, model, logger)
        #sched.step(val_acc)
        logger.print()

In [12]:
dataset = Dataset()
model = Model(dataset)
train(dataset, model)

|        epoch |     trn loss |       metric |          hom |
|        0.000 |        1.285 |        0.080 |        0.994 |
|        1.000 |        1.196 |        0.083 |        0.994 |
|        2.000 |        1.075 |        0.088 |        0.993 |
|        3.000 |        0.973 |        0.088 |        0.991 |
|        4.000 |        0.923 |        0.084 |        0.988 |
|        5.000 |        0.903 |        0.084 |        0.986 |
|        6.000 |        0.878 |        0.087 |        0.982 |
|        7.000 |        0.851 |        0.091 |        0.979 |
|        8.000 |        0.855 |        0.092 |        0.975 |
|        9.000 |        0.869 |        0.091 |        0.970 |
|       10.000 |        0.836 |        0.093 |        0.966 |
|       11.000 |        0.845 |        0.095 |        0.962 |
|       12.000 |        0.836 |        0.097 |        0.958 |
|       13.000 |        0.828 |        0.100 |        0.954 |
|       14.000 |        0.825 |        0.102 |        0.950 |
|       

 99.000 |        0.614 |        0.129 |        0.834