# In [3]:
import torch
from torch.nn.parameter import Parameter
import torch.nn as nn

import import_ipynb
import math
import utils
import argparse
import data_loader

import rits
from sklearn import metrics

from ipdb import set_trace


class Model(nn.Module):
    """Bidirectional wrapper around two ``rits.Model`` instances.

    One sub-model is called with the direction tag ``'forward'`` and the
    other with ``'backward'``.  The backward result is flipped along
    dimension 1 and then merged with the forward result: the two losses are
    summed together with a consistency term, and the imputations are
    averaged.
    """

    def __init__(self):
        super().__init__()
        self.build()

    def build(self):
        """Instantiate the two directional sub-models."""
        self.rits_f = rits.Model()
        self.rits_b = rits.Model()

    def forward(self, data, seq_len):
        """Run both directions on ``data`` and return the merged result.

        Args:
            data: Batch passed unchanged to both sub-models (format is
                defined by ``rits.Model`` -- not visible here).
            seq_len: Sequence length forwarded to both sub-models; also
                stored on ``self.seq_len``.

        Returns:
            The forward sub-model's result dict with ``'loss'`` and
            ``'imputations'`` replaced by the merged values.
        """
        self.seq_len = seq_len
        ret_f = self.rits_f(data, 'forward', seq_len)
        ret_b = self.reverse(self.rits_b(data, 'backward', seq_len))

        return self.merge_ret(ret_f, ret_b)

    def merge_ret(self, ret_f, ret_b):
        """Combine the forward and backward result dicts.

        The merged loss is ``loss_f + loss_b + consistency_loss`` and the
        merged imputations are the element-wise mean of both directions.

        Note:
            ``ret_f`` is updated in place and returned; every other key it
            carries is left as produced by the forward sub-model.
        """
        imputations_f = ret_f['imputations']
        imputations_b = ret_b['imputations']

        loss_c = self.get_consistency_loss(imputations_f, imputations_b)
        loss = ret_f['loss'] + ret_b['loss'] + loss_c

        ret_f['loss'] = loss
        ret_f['imputations'] = (imputations_f + imputations_b) / 2

        return ret_f

    def get_consistency_loss(self, pred_f, pred_b):
        """Return the root-mean-square difference between two predictions.

        NOTE(review): the derivative of ``sqrt`` is unbounded at zero, so
        exactly identical predictions may produce non-finite gradients --
        confirm whether an epsilon is wanted here.
        """
        return torch.sqrt(torch.pow(pred_f - pred_b, 2.0).mean())

    def reverse(self, ret):
        """Flip every multi-dimensional tensor in ``ret`` along dimension 1.

        Tensors with at most one dimension (e.g. a scalar loss) and values
        that are not tensors are left untouched.  ``ret`` is modified in
        place and also returned.

        Dimension 1 is presumably the time axis of the sub-model outputs --
        TODO confirm against ``rits.Model``.
        """
        def _flip(value):
            if not torch.is_tensor(value) or value.dim() <= 1:
                return value
            # torch.flip runs on the tensor's own device, so no index
            # tensor has to be built and moved to CUDA by hand (which broke
            # for CPU tensors on machines where CUDA is available).
            return torch.flip(value, dims=[1])

        for key in ret:
            ret[key] = _flip(ret[key])

        return ret

    def run_on_batch(self, data, optimizer, seq_len):
        """Run one batch and, if an optimizer is given, take one step.

        Args:
            data: Batch passed to ``forward``.
            optimizer: Optimizer to step on ``ret['loss']``; pass ``None``
                to only compute the outputs without updating parameters.
            seq_len: Sequence length passed to ``forward``.

        Returns:
            The merged result dict produced by ``forward``.
        """
        ret = self(data, seq_len=seq_len)

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret