In [None]:
%%capture
!pip install wandb --upgrade
!pip install ujson

BRITS implementation in Colab

In [None]:
import os
import time

import ujson as json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
import math

from torch.nn.parameter import Parameter

In [None]:
# Number of time steps processed per sequence (RITS.forward loops over range(SEQ_LEN)).
SEQ_LEN = 28
# Size of the LSTM hidden and cell state in each RITS network.
RNN_HID_SIZE = 64
# Number of features per time step; the dataset keeps only the first INPUT_SIZE columns.
INPUT_SIZE = 9

Unlike the original BRITS, we removed the classification loss, as we do not have a classification task. Furthermore, we use a hard split into train and test sets, as is good practice. The code is based on the original implementation, but has been updated, cleaned up and converted to Python 3:
https://github.com/NIPS-BRITS/BRITS

In [None]:
class FeatureRegression(nn.Module):
    """Linear layer that estimates every feature from the *other* features.

    The square weight matrix is multiplied element-wise by a fixed mask whose
    diagonal is zero, so output ``i`` never depends on input ``i``.
    """

    def __init__(self, input_size):
        super().__init__()
        self.build(input_size)

    def build(self, input_size):
        """Allocate weight and bias, register the zero-diagonal mask, then initialise."""
        self.W = Parameter(torch.Tensor(input_size, input_size))
        self.b = Parameter(torch.Tensor(input_size))

        # Ones everywhere except on the diagonal; stored as a buffer so it
        # follows the module across devices without being trained.
        off_diagonal = 1.0 - torch.eye(input_size, input_size)
        self.register_buffer('m', off_diagonal)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from ``[-1/sqrt(n), 1/sqrt(n)]``, n = rows of ``W``."""
        bound = 1. / math.sqrt(self.W.size(0))
        # Weight first, then bias, so the random stream is consumed in a fixed order.
        for param in (self.W, self.b):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, x):
        """Return ``x @ (W * m).T + b``."""
        masked_weight = self.W * self.m
        return F.linear(x, masked_weight, self.b)

class TemporalDecay(nn.Module):
    """Turns time gaps ``d`` into decay factors ``exp(-relu(W d + b))``.

    With ``diag=True`` the weight is masked down to its diagonal, so each
    output decays only with the gap of the matching input feature; this
    requires ``input_size == output_size``.
    """

    def __init__(self, input_size, output_size, diag = False):
        super().__init__()
        self.diag = diag
        self.build(input_size, output_size)

    def build(self, input_size, output_size):
        """Allocate weight and bias (plus the identity mask when ``diag``), then initialise."""
        self.W = Parameter(torch.Tensor(output_size, input_size))
        self.b = Parameter(torch.Tensor(output_size))

        if self.diag == True:
            assert(input_size == output_size)
            # Identity mask: only registered in the diagonal variant.
            self.register_buffer('m', torch.eye(input_size, input_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from ``[-1/sqrt(n), 1/sqrt(n)]``, n = rows of ``W``."""
        bound = 1. / math.sqrt(self.W.size(0))
        # Weight first, then bias, so the random stream is consumed in a fixed order.
        for param in (self.W, self.b):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, d):
        """Return the decay factor for the time gaps ``d``."""
        weight = self.W * self.m if self.diag == True else self.W
        # relu keeps the exponent non-positive, so the result never exceeds 1.
        return torch.exp(-F.relu(F.linear(d, weight, self.b)))

In [None]:
class RITS(nn.Module):
    """Unidirectional recurrent imputation network (one direction of BRITS).

    At every time step the network estimates the current features from the
    decayed hidden state, refines that estimate from the other features,
    blends both estimates, fills the missing entries with the blend and feeds
    the completed vector (plus the mask) to an LSTM cell.
    """

    def __init__(self):
        super(RITS, self).__init__()
        self.build()

    def build(self):
        """Create the sub-modules; all sizes come from the module-level constants."""
        # The cell input is the completed feature vector concatenated with its mask.
        self.rnn_cell = nn.LSTMCell(INPUT_SIZE * 2, RNN_HID_SIZE)

        self.temp_decay_h = TemporalDecay(input_size = INPUT_SIZE, output_size = RNN_HID_SIZE, diag = False)
        self.temp_decay_x = TemporalDecay(input_size = INPUT_SIZE, output_size = INPUT_SIZE, diag = True)

        self.hist_reg = nn.Linear(RNN_HID_SIZE, INPUT_SIZE)
        self.feat_reg = FeatureRegression(INPUT_SIZE)

        self.weight_combine = nn.Linear(INPUT_SIZE * 2, INPUT_SIZE)

        # Neither module is used by forward(); they are kept so that the
        # parameter set, initialisation order and state_dict keys stay unchanged.
        self.dropout = nn.Dropout(p = 0.25)
        self.out = nn.Linear(RNN_HID_SIZE, 1)

    @staticmethod
    def _masked_mae(estimate, target, mask):
        """Mean absolute error over the entries where ``mask`` is 1 (epsilon avoids 0/0)."""
        return torch.sum(torch.abs(target - estimate) * mask) / (torch.sum(mask) + 1e-5)

    def forward(self, data, direct):
        """Impute one direction of a batch.

        Args:
            data: mapping with a ``direct`` entry holding the tensors
                ``values``, ``masks``, ``deltas``, ``evals`` and ``eval_masks``.
                The first three are indexed as ``[:, t, :]`` for
                ``t < SEQ_LEN``.
            direct: key of the direction to process (``'forward'`` or
                ``'backward'`` in this notebook).

        Returns:
            dict with ``loss`` (reconstruction loss averaged over the time
            steps), ``imputations`` (observed values with gaps filled in) and
            the untouched ``evals`` / ``eval_masks`` of that direction.
        """
        values = data[direct]['values']
        masks = data[direct]['masks']
        deltas = data[direct]['deltas']

        evals = data[direct]['evals']
        eval_masks = data[direct]['eval_masks']

        # Allocate the recurrent state on the same device (and dtype) as the
        # input. Moving it to CUDA just because a GPU exists breaks as soon as
        # the model and the data live on the CPU, which is the case here.
        batch = values.size()[0]
        h = values.new_zeros((batch, RNN_HID_SIZE))
        c = values.new_zeros((batch, RNN_HID_SIZE))

        x_loss = 0.0

        imputations = []

        for t in range(SEQ_LEN):
            x = values[:, t, :]
            m = masks[:, t, :]
            d = deltas[:, t, :]

            gamma_h = self.temp_decay_h(d)
            gamma_x = self.temp_decay_x(d)

            # Fade the hidden state according to the time since the last observation.
            h = h * gamma_h

            # Estimate 1: from the history (hidden state) only.
            x_h = self.hist_reg(h)
            x_loss += self._masked_mae(x_h, x, m)

            # Keep observed values (m == 1), fill the rest with the history estimate.
            x_c =  m * x +  (1 - m) * x_h

            # Estimate 2: from the other features at the same time step.
            z_h = self.feat_reg(x_c)
            x_loss += self._masked_mae(z_h, x, m)

            # Learned blend of both estimates.
            alpha = self.weight_combine(torch.cat([gamma_x, m], dim = 1))

            c_h = alpha * z_h + (1 - alpha) * x_h
            x_loss += self._masked_mae(c_h, x, m)

            c_c = m * x + (1 - m) * c_h

            inputs = torch.cat([c_c, m], dim = 1)

            h, c = self.rnn_cell(inputs, (h, c))

            imputations.append(c_c.unsqueeze(dim = 1))

        imputations = torch.cat(imputations, dim = 1)

        return {'loss': x_loss / SEQ_LEN,
                'imputations': imputations,
                'evals': evals,
                'eval_masks': eval_masks}

    def run_on_batch(self, data, optimizer):
        """Run the forward direction; take one optimisation step if ``optimizer`` is given."""
        ret = self(data, direct = 'forward')

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

In [None]:
class BRITS(nn.Module):
    """Bidirectional imputation model: two RITS networks, one per time direction.

    The backward result is flipped back into forward time order, both
    imputations are averaged, and a consistency term penalises disagreement
    between the two directions.
    """

    def __init__(self):
        super(BRITS, self).__init__()
        self.build()

    def build(self):
        """Create the forward and the backward RITS network."""
        self.rits_f = RITS()
        self.rits_b = RITS()

    def forward(self, data):
        """Run both directions on ``data`` and return the merged result dict."""
        ret_f = self.rits_f(data, 'forward')
        ret_b = self.reverse(self.rits_b(data, 'backward'))

        ret = self.merge_ret(ret_f, ret_b)

        return ret

    def merge_ret(self, ret_f, ret_b):
        """Combine both directions; ``ret_f`` is updated in place and returned.

        The loss is the sum of both direction losses plus the consistency
        loss; the imputation is the mean of both imputations.
        """
        loss_f = ret_f['loss']
        loss_b = ret_b['loss']
        loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations'])

        loss = loss_f + loss_b + loss_c

        imputations = (ret_f['imputations'] + ret_b['imputations']) / 2

        ret_f['loss'] = loss
        ret_f['imputations'] = imputations

        return ret_f

    def get_consistency_loss(self, pred_f, pred_b):
        """Mean absolute difference between both imputations, scaled by 0.1."""
        loss = torch.abs(pred_f - pred_b).mean() * 1e-1
        return loss

    def reverse(self, ret):
        """Flip every tensor in ``ret`` along dimension 1, in place, and return ``ret``.

        Tensors with at most one dimension (e.g. the scalar loss) are left
        untouched.
        """
        def reverse_tensor(tensor_):
            if tensor_.dim() <= 1:
                return tensor_
            # torch.flip works on whatever device the tensor lives on. Building
            # an index tensor and moving it to CUDA whenever a GPU exists fails
            # when the data itself is on the CPU.
            return torch.flip(tensor_, dims = [1])

        for key in ret:
            ret[key] = reverse_tensor(ret[key])

        return ret

    def run_on_batch(self, data, optimizer):
        """Run the model on a batch; take one optimisation step if ``optimizer`` is given."""
        ret = self(data)

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

In [None]:
def load_normalizations(file):
  """Load the normalisation statistics stored as JSON on the first line of ``file``.

  Args:
    file: path of the text file; only its first line is parsed.

  Returns:
    The decoded JSON value of the first line.

  Raises:
    ValueError: if the first line is empty or is not valid JSON.
  """
  # The context manager closes the handle; only one line has to be read.
  with open(file) as handle:
    first_line = handle.readline()
  if not first_line.strip():
    raise ValueError(f"Normalization file {file!r} has an empty first line.")
  return json.loads(first_line)

The data is structured as follows:
Each entry in the dataset is one time series with n steps.
It has a `forward` and a `backward` direction, one for each of the two RITS networks.
Each direction has the following entries:

*   `values`: data after elimination of values
*   `masks`: indicates whether a value is present (1 = observed, 0 = missing)
*   `deltas`: timedeltas since last recorded data
*   `evals`: ground truth
*   `eval_masks`: 1 if the ground truth is known but the value was removed from `values`, 0 otherwise

In [None]:
data_set_dict_entries=['values','masks','deltas','evals','eval_masks']
class MySet2(Dataset):
    """In-memory dataset read from a JSON-lines file (one time series per line).

    Every record must contain a ``forward`` and a ``backward`` list of time
    steps; every step is a mapping holding one feature list for each name in
    ``data_set_dict_entries``.
    """

    def __init__(self,content_path, input_size=None):
        """Read and tensorise the whole file.

        Args:
            content_path: path of the JSON-lines file.
            input_size: number of leading features kept per time step;
                defaults to the module-level ``INPUT_SIZE``.
        """
        super(MySet2, self).__init__()
        self.input_size = INPUT_SIZE if input_size is None else input_size
        # The context manager closes the file; blank lines (for example a
        # trailing newline at the end of the export) are skipped, not parsed.
        with open(content_path) as handle:
            recs = [json.loads(line) for line in handle if line.strip()]
        self.forward = self.to_tensor_dict([rec['forward'] for rec in recs])
        self.backward = self.to_tensor_dict([rec['backward'] for rec in recs])

    def __len__(self):
        """Number of time series in the dataset."""
        return len(self.forward[data_set_dict_entries[0]])

    def to_tensor_dict(self,recs):
        """Convert a list of sequences (lists of step dicts) into one float tensor per entry name."""
        return {
            dict_key: torch.tensor(
                [[step[dict_key][0:self.input_size] for step in seq] for seq in recs],
                dtype=torch.float32,
            )
            for dict_key in data_set_dict_entries
        }

    def __getitem__(self, idx):
        """Return ``{'forward': {...}, 'backward': {...}}`` for the series at ``idx``."""
        return {
            'forward': {key: self.forward[key][idx] for key in data_set_dict_entries},
            'backward': {key: self.backward[key][idx] for key in data_set_dict_entries},
        }

In [None]:
def collate_fn2(recs, keys=None):
  """Stack dataset items into one batch.

  Args:
    recs: list of items shaped ``{'forward': {key: tensor}, 'backward': {key: tensor}}``.
    keys: entry names to collate; defaults to ``data_set_dict_entries``.

  Returns:
    ``{'forward': {...}, 'backward': {...}}`` where each tensor gains a leading
    batch dimension and is converted to float32.
  """
  if keys is None:
    keys = data_set_dict_entries
  directions = ('forward', 'backward')

  if len(recs) == 0:
    # Nothing to stack: keep the historical result, empty tensors of the configured shape.
    return {d: {key: torch.empty(0, SEQ_LEN, INPUT_SIZE) for key in keys} for d in directions}

  # torch.stack requires identical sample shapes, so a mis-shaped sample raises
  # instead of being silently broadcast into a pre-allocated buffer.
  return {
    d: {key: torch.stack([rec[d][key] for rec in recs]).float() for key in keys}
    for d in directions
  }

In [None]:
# Feature name -> position on the last axis of the dataset tensors (used by the error functions below).
indexes= {'meal':0,'calories':1,'carbs':2,'fat':3,'protein':4}

In [None]:
# Training configuration.
epochs = 50
batch_size = 64
use_norm = True

# The normalised and the raw export live side by side; pick the pair matching `use_norm`.
if use_norm:
  path_train_json = './brits_json.json'
  path_test_json = './brits_test.json'
else:
  path_train_json = './brits_json_nonnorm.json'
  path_test_json = './brits_test_nonnorm.json'
# Loaded unconditionally; revert_norm only reads it when `use_norm` is set.
normalizations = load_normalizations('brits_normalization.json')

# Model, optimiser and shuffled training loader.
model = BRITS()
optimizer = optim.Adam(model.parameters(), lr = 1e-3)
train_set = MySet2(path_train_json)
train_iter = DataLoader(
  dataset = train_set,
  batch_size = batch_size,
  shuffle = True,
  pin_memory = True,
  collate_fn = collate_fn2,
)

In [None]:
for epoch in range(epochs):
  model.train()

  run_loss = 0.0
  n_batches = len(train_iter)

  for idx, data in enumerate(train_iter):
    ret = model.run_on_batch(data, optimizer)

    run_loss += ret['loss'].item()

    # end='' keeps the cursor on the current line so the leading '\r' really
    # overwrites the previous progress message. The former trailing comma was a
    # Python 2 idiom that only builds a discarded tuple in Python 3.
    print('\r Progress epoch {}, {:.2f}%, average loss {}'.format(epoch, (idx + 1) * 100.0 / n_batches, run_loss / (idx + 1.0)), end='', flush=True)

  # Finish the progress line so every epoch keeps its final summary.
  print()

Testing against MEAN

In [None]:
test_set = MySet2(path_test_json)
test_iter = DataLoader(dataset = test_set,batch_size = batch_size,shuffle = False,pin_memory = True, collate_fn = collate_fn2)
model.eval()
# Only the first test batch is evaluated; `data` and `ret` are read by the cells below.
# no_grad() avoids building an autograd graph that nothing uses during evaluation.
with torch.no_grad():
  for idx, data in enumerate(test_iter):
    ret = model.run_on_batch(data, None)
    break
  else:
    # Without this, an empty loader would silently leave `data`/`ret` from training.
    raise RuntimeError("The test loader produced no batch; is the test set empty?")

In [None]:
# Identical sizes hint that the same file was loaded for both splits.
n_test, n_train = len(test_set), len(train_set)
if n_test == n_train:
  raise Exception("Length of test and train set are equal. Is there a Mistake?")

In [None]:

def revert_norm(value,name):
  """Map ``value`` back to its original scale using the statistics stored for ``name``.

  Returns ``value`` unchanged when the notebook runs on non-normalised data
  (``use_norm`` is falsy).
  """
  if use_norm:
    std = normalizations['std'][name]
    mean = normalizations['mean'][name]
    return value * std + mean
  # The data was never normalised, so there is nothing to undo.
  return value

In [None]:

def calc_error_on_row(ret, index_name):
  """Mean absolute error of the model's imputations on the held-out entries of one feature.

  Args:
    ret: model output with ``eval_masks``, ``imputations`` and ``evals``.
    index_name: feature name, a key of ``indexes``.
  """
  col = indexes[index_name]
  # Entries with eval_mask == 1 are the ones compared against the ground truth.
  held_out = ret['eval_masks'][:, :, col] == 1
  imputed = revert_norm(ret['imputations'][:, :, col][held_out], index_name)
  truth = revert_norm(ret['evals'][:, :, col][held_out], index_name)
  return (imputed - truth).abs().mean()

def calculate_personal_mean_error(data,index_name):
  """Baseline error: impute each held-out entry with the mean of that series' observed values.

  Args:
    data: batch as produced by the collate function; only ``data['forward']`` is read.
    index_name: feature name, a key of ``indexes``.

  Returns:
    Mean absolute error over all held-out entries of the batch (NaN if no
    series has both held-out and observed entries).
  """
  index = indexes[index_name]
  forward = data['forward']
  errors = []
  for x in range(forward['eval_masks'].size(dim=0)):
    held_out = (forward['eval_masks'][x,:,index]==1)#entries hidden from the model, ground truth known
    # The baseline may only use values that were really observed (masks == 1).
    # `eval_masks != 1` also selects entries that were never recorded at all and
    # only hold a placeholder in `values`, which distorts the mean.
    observed = (forward['masks'][x,:,index]==1)
    if not held_out.any() or not observed.any():
      # Nothing to score, or no observation to average: skip instead of injecting NaN.
      continue
    z =  revert_norm(forward['values'][x,:,index][observed],index_name).mean()
    real =  revert_norm(forward['evals'][x,:,index][held_out],index_name)
    errors.append(torch.abs(real-z))
  if not errors:
    return torch.tensor(float('nan'))
  # Concatenate once at the end instead of growing a tensor inside the loop.
  return torch.cat(errors, dim=0).mean()

def calc_mean_error(data,index_name):
  """Baseline error: impute every held-out entry with the batch-wide mean of the observed values.

  Args:
    data: batch as produced by the collate function; only ``data['forward']`` is read.
    index_name: feature name, a key of ``indexes``.

  Returns:
    Mean absolute error over all held-out entries of the batch.
  """
  index = indexes[index_name]
  forward = data['forward']
  held_out = (forward['eval_masks'][:,:,index]==1)#entries hidden from the model, ground truth known
  # The baseline may only use values that were really observed (masks == 1).
  # `eval_masks != 1` also selects entries that were never recorded at all and
  # only hold a placeholder in `values`, which distorts the mean.
  observed = (forward['masks'][:,:,index]==1)
  z =  revert_norm(forward['values'][:,:,index][observed],index_name).mean()
  real =  revert_norm(forward['evals'][:,:,index][held_out],index_name)
  mean_erro = torch.abs(real-z).mean()
  return mean_erro

# Feature under evaluation; must be a key of `indexes`.
index_name = 'calories'

# Model error first, then the two mean-imputation baselines on the same batch.
rnn_erro = calc_error_on_row(ret, index_name)
mean_erro = calc_mean_error(data, index_name)
personal_mean_erro = calculate_personal_mean_error(data, index_name)

print(f"RNN   {rnn_erro}")
print(f"MEAN  {mean_erro}")
print(f"MEANP {personal_mean_erro}")

RNN   341.09503173828125
MEAN  295.6275939941406
MEANP 304.252685546875


In [None]:
# Display the loaded normalisation statistics (the values revert_norm reads).
normalizations

{'mean': {'calories': 569.9514961389962,
  'carbs': 57.65986969111969,
  'cholest': 82.33252895752896,
  'fat': 23.61534749034749,
  'fiber': 6.033301158301159,
  'protein': 29.91686776061776,
  'sodium': 737.3712596525097,
  'sugar': 20.981901544401545},
 'std': {'calories': 19.033275991779742,
  'carbs': 6.547162660719812,
  'cholest': 11.447796506294267,
  'fat': 4.473921516546562,
  'fiber': 2.448608392557053,
  'protein': 4.818375431259857,
  'sodium': 31.317557844872915,
  'sugar': 4.7081826053650495}}