In [39]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [40]:
%%capture
!pip install wandb --upgrade

In [41]:
# Copy the imputation data files from the mounted Drive into the current directory.
# NOTE: the relative path only resolves when the working directory is /content,
# the parent of the /content/drive mount point.
!cp drive/MyDrive/Uni/Masterarbeit/data/imputation/* .

BRITS implementation in Colab

In [42]:
import os
import time

import numpy as np
import pandas as pd
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
import math

from torch.nn.parameter import Parameter

In [43]:
# Number of time steps per series (loop bound in RITS.forward, batch shape in collate_fn2).
SEQ_LEN = 28
# Hidden-state size of the LSTM cell inside each RITS network.
RNN_HID_SIZE = 64
# Number of features kept per time step.
INPUT_SIZE = 9

Unlike the original BRITS, we removed the classification loss, since we do not have a classification task. Furthermore, we use a hard split into a training set and a test set, as is good practice. The code is based on the following repository, but has been updated, cleaned up, and converted to Python 3:
https://github.com/NIPS-BRITS/BRITS

In [44]:
def get_device():
    """Return the first CUDA device when CUDA is available, otherwise the CPU."""
    name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)
device = get_device()
print(device)

cpu


In [45]:
class FeatureRegression(nn.Module):
    """Linear map between features whose diagonal is masked out.

    The buffer ``m`` holds ones everywhere except on the diagonal and is
    multiplied into the weight matrix on every forward pass, so each output
    feature is computed from the *other* input features only.

    Args:
        input_size: number of features; the weight matrix is
            ``input_size x input_size``.
    """
    def __init__(self, input_size):
        super(FeatureRegression, self).__init__()
        self.build(input_size)

    def build(self, input_size):
        """Create the parameters and the off-diagonal mask on ``device``."""
        # Allocate directly on the target device. `Parameter(...).to(device)`
        # returns a plain, non-leaf tensor whenever a copy is made (CPU -> CUDA);
        # such a tensor is not registered as a parameter and would therefore be
        # missing from `model.parameters()` and never be optimised.
        self.W = Parameter(torch.empty(input_size, input_size, device=device))
        self.b = Parameter(torch.empty(input_size, device=device))

        # Ones off the diagonal, zeros on it.
        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
        self.register_buffer('m', m.to(device))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise W and b uniformly in [-1/sqrt(n), 1/sqrt(n)], n = W.size(0)."""
        stdv = 1. / math.sqrt(self.W.size(0))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Apply the diagonal-free linear map to ``x``."""
        z_h = F.linear(x, self.W * self.m, self.b)
        return z_h

class TemporalDecay(nn.Module):
    """Decay factor ``exp(-relu(W d + b))`` computed from time deltas.

    Args:
        input_size: number of delta features fed in.
        output_size: number of decay factors produced.
        diag: when true, only the diagonal of ``W`` is used (requires
            ``input_size == output_size``), so every output depends on its own
            delta only.
    """
    def __init__(self, input_size, output_size, diag = False):
        super(TemporalDecay, self).__init__()
        self.diag = diag

        self.build(input_size, output_size)

    def build(self, input_size, output_size):
        """Create the parameters (and the diagonal mask, if requested) on ``device``."""
        # Allocate directly on the target device. `Parameter(...).to(device)`
        # returns a plain, non-leaf tensor whenever a copy is made (CPU -> CUDA);
        # such a tensor is not registered as a parameter and would therefore be
        # missing from `model.parameters()` and never be optimised.
        self.W = Parameter(torch.empty(output_size, input_size, device=device))
        self.b = Parameter(torch.empty(output_size, device=device))

        if self.diag:
            assert(input_size == output_size)
            # Identity mask: keeps only the diagonal entries of W.
            m = torch.eye(input_size, input_size).to(device)
            self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise W and b uniformly in [-1/sqrt(n), 1/sqrt(n)], n = W.size(0)."""
        stdv = 1. / math.sqrt(self.W.size(0))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, d):
        """Return the decay factors for the deltas ``d``; each value lies in (0, 1]."""
        if self.diag:
            gamma = F.relu(F.linear(d, self.W * self.m, self.b))
        else:
            gamma = F.relu(F.linear(d, self.W, self.b))
        # relu makes the exponent non-positive, hence the (0, 1] range.
        gamma = torch.exp(-gamma)
        return gamma

In [46]:
class RITS(nn.Module):
    """Recurrent imputation network for one direction of a time series.

    At every time step the network estimates the feature vector from the
    decayed hidden state (history estimate), refines it from the other
    features (feature estimate), blends both, and feeds the completed vector
    together with its mask into an LSTM cell.
    """
    def __init__(self):
        super(RITS, self).__init__()
        self.build()

    def build(self):
        """Create all sub-modules."""
        # Input of the cell: completed features concatenated with their mask.
        self.rnn_cell = nn.LSTMCell(INPUT_SIZE * 2, RNN_HID_SIZE).to(device)

        self.temp_decay_h = TemporalDecay(input_size = INPUT_SIZE, output_size = RNN_HID_SIZE, diag = False)
        self.temp_decay_x = TemporalDecay(input_size = INPUT_SIZE, output_size = INPUT_SIZE, diag = True)

        self.hist_reg = nn.Linear(RNN_HID_SIZE, INPUT_SIZE).to(device)
        self.feat_reg = FeatureRegression(INPUT_SIZE).to(device)

        self.weight_combine = nn.Linear(INPUT_SIZE * 2, INPUT_SIZE).to(device)

        # Not referenced by forward() in this class; kept so the set of
        # sub-modules (and therefore the state_dict keys) stays unchanged.
        self.dropout = nn.Dropout(p = 0.25).to(device)
        self.out = nn.Linear(RNN_HID_SIZE, 1).to(device)

    def forward(self, data, direct):
        """Run the network over SEQ_LEN steps of one direction.

        Args:
            data: dict holding, under the key ``direct``, the tensors
                'values', 'masks', 'deltas', 'evals' and 'eval_masks'.
                The first three are indexed as ``[:, t, :]``.
            direct: key selecting the direction inside ``data``
                ('forward' or 'backward' in this notebook).

        Returns:
            dict with 'loss' (sum of the three masked absolute errors per
            step, divided by SEQ_LEN), 'imputations' (completed vectors of all
            steps concatenated along dim 1), and the unchanged 'evals' and
            'eval_masks'.
        """
        values = data[direct]['values'].to(device)
        masks = data[direct]['masks'].to(device)
        deltas = data[direct]['deltas'].to(device)

        evals = data[direct]['evals'].to(device)
        eval_masks = data[direct]['eval_masks'].to(device)

        h = torch.zeros((values.size()[0], RNN_HID_SIZE)).to(device)
        c = torch.zeros((values.size()[0], RNN_HID_SIZE)).to(device)

        x_loss = 0.0

        imputations = []

        for t in range(SEQ_LEN):
            x = values[:, t, :]
            m = masks[:, t, :]
            d = deltas[:, t, :]

            # Normaliser shared by the three loss terms of this step;
            # the epsilon guards against a step with an all-zero mask.
            norm = torch.sum(m) + 1e-5

            gamma_h = self.temp_decay_h(d)
            gamma_x = self.temp_decay_x(d)

            # Decay the hidden state before it is used.
            h = h * gamma_h

            # History-based estimate, scored on entries where the mask is 1.
            x_h = self.hist_reg(h)
            x_loss += torch.sum(torch.abs(x - x_h) * m) / norm

            # Keep x where the mask is 1, use the estimate elsewhere.
            x_c =  m * x +  (1 - m) * x_h

            # Feature-based estimate computed from the completed vector.
            z_h = self.feat_reg(x_c)
            x_loss += torch.sum(torch.abs(x - z_h) * m) / norm

            # Blend weights derived from the feature decay and the mask.
            alpha = self.weight_combine(torch.cat([gamma_x, m], dim = 1))

            c_h = alpha * z_h + (1 - alpha) * x_h
            x_loss += torch.sum(torch.abs(x - c_h) * m) / norm

            c_c = m * x + (1 - m) * c_h

            inputs = torch.cat([c_c, m], dim = 1)

            h, c = self.rnn_cell(inputs, (h, c))

            imputations.append(c_c.unsqueeze(dim = 1))

        imputations = torch.cat(imputations, dim = 1)

        return {'loss': x_loss / SEQ_LEN,
                'imputations': imputations,
                'evals': evals,
                'eval_masks': eval_masks}

    def run_on_batch(self, data, optimizer):
        """Run the 'forward' direction on a batch; take an optimiser step if one is given."""
        ret = self(data, direct = 'forward')

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

In [47]:
class BRITS(nn.Module):
    """Bidirectional imputation model built from two RITS networks.

    One RITS processes the 'forward' direction and one the 'backward'
    direction; their imputations are averaged and a consistency term
    penalises disagreement between them.
    """
    def __init__(self):
        super(BRITS, self).__init__()
        self.build()

    def build(self):
        """Create the two directional networks."""
        self.rits_f = RITS()
        self.rits_b = RITS()

    def forward(self, data):
        """Run both directions on ``data`` and merge their results."""
        ret_f = self.rits_f(data, 'forward')
        # The backward result is flipped along dim 1 so it lines up with the forward one.
        ret_b = self.reverse(self.rits_b(data, 'backward'))

        ret = self.merge_ret(ret_f, ret_b)

        return ret

    def merge_ret(self, ret_f, ret_b):
        """Combine both results.

        The merged loss is forward loss + backward loss + consistency loss and
        the merged imputations are the element-wise average. ``ret_f`` is
        updated in place and returned, so 'evals' and 'eval_masks' are the
        forward ones.
        """
        loss_f = ret_f['loss']
        loss_b = ret_b['loss']
        loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations'])

        loss = loss_f + loss_b + loss_c

        imputations = (ret_f['imputations'] + ret_b['imputations']) / 2

        ret_f['loss'] = loss
        ret_f['imputations'] = imputations

        return ret_f

    def get_consistency_loss(self, pred_f, pred_b):
        """Mean absolute difference between both predictions, scaled by 0.1."""
        loss = torch.abs(pred_f - pred_b).mean() * 1e-1
        return loss

    def reverse(self, ret):
        """Flip every tensor in ``ret`` along dim 1 (in place on the dict).

        Tensors with at most one dimension (such as the scalar loss) are left
        untouched.
        """
        def reverse_tensor(tensor_):
            if tensor_.dim() <= 1:
                return tensor_
            # torch.flip works on whatever device the tensor lives on, so no
            # separate index tensor has to be created and moved to the GPU.
            return torch.flip(tensor_, dims=[1])

        for key in ret:
            ret[key] = reverse_tensor(ret[key])

        return ret

    def run_on_batch(self, data, optimizer):
        """Run the model on a batch; take an optimiser step if one is given."""
        ret = self(data)

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

In [48]:
def load_normalizations(file):
  """Load and return the pickled object stored at ``file``.

  In this notebook the result is indexed as ``['std'][name]`` and
  ``['mean'][name]``.
  """
  # NOTE: pickle.load can execute arbitrary code; only load trusted files.
  # The with-block closes the file even when unpickling fails.
  with open(file,'rb') as content:
    return pickle.load(content)

The data is structured as follows:
Each entry in the dataset is one time series with n steps.
It has a `forward` and a `backward` direction, one for each of the two RITS networks.
Each direction has the following entries:

*   `values`: data after elimination of values
*   `masks`: 1 if the value is present in `values`, 0 if it is missing
*   `deltas`: time deltas since the last recorded data
*   `evals`: ground truth
*   `eval_masks`: 1 if a ground-truth value exists but is missing in `values`, 0 otherwise

In [49]:
# Keys every direction of a record provides; also the keys of each sample dict.
data_set_dict_entries=['values','masks','deltas','evals','eval_masks']
class MySet2(Dataset):
    """Dataset over pickled records, held in memory as tensors.

    The pickle at ``content_path`` must contain a list of records; every
    record has a 'forward' and a 'backward' list of per-step dicts keyed by
    ``data_set_dict_entries``. Only the first INPUT_SIZE features of every
    step are kept.
    """
    def __init__(self,content_path):
        super(MySet2, self).__init__()
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        # The with-block closes the file even when unpickling fails.
        with open(content_path,'rb') as content:
            recs = pickle.load(content)
        self.forward = self.to_tensor_dict([x['forward'] for x in recs])
        self.backward = self.to_tensor_dict([x['backward'] for x in recs])

    def __len__(self):
        """Number of series, i.e. the first dimension of the stored tensors."""
        return len(self.forward[data_set_dict_entries[0]])

    def to_tensor_dict(self,recs):
        """Stack the per-step lists of all records into one float tensor per key.

        The resulting tensors are indexed (record, step, feature).
        """
        return {
            dict_key: torch.FloatTensor([[x[dict_key][0:INPUT_SIZE] for x in r] for r in recs])
            for dict_key in data_set_dict_entries
        }

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as {'forward': {...}, 'backward': {...}}."""
        forward = {dict_key: self.forward[dict_key][idx] for dict_key in data_set_dict_entries}
        backward = {dict_key: self.backward[dict_key][idx] for dict_key in data_set_dict_entries}
        return {'forward':forward,'backward':backward}

In [50]:
def collate_fn2(recs):
  """Assemble a list of dataset samples into one batch.

  For both directions and every key in ``data_set_dict_entries`` a tensor of
  shape (len(recs), SEQ_LEN, INPUT_SIZE) is allocated and filled row by row
  with the corresponding tensor of each sample.
  """
  n = len(recs)
  batch = {}
  for direction in ('forward', 'backward'):
    stacked = {}
    for key in data_set_dict_entries:
      buf = torch.empty(n,SEQ_LEN,INPUT_SIZE)
      for row, sample in enumerate(recs):
        buf[row] = sample[direction][key]
      stacked[key] = buf
    batch[direction] = stacked
  return batch

In [51]:
# Maps a feature name to its position in the last tensor dimension; the same
# names are used as keys into normalizations['mean'] / normalizations['std'].
indexes= {'meal':0,'calories':1,'carbs':2,'fat':3,'protein':4}

In [52]:
# Training configuration.
epochs = 100
batch_size = 500
# Selects between the normalized and the non-normalized pickles; revert_norm
# reads this flag as well.
use_norm = True
if use_norm:
  path_train = './brits_train.pickle'
  path_test = './brits_test.pickle'
else:
  path_train = './brits_train_nonnorm.pickle'
  path_test = './brits_test_nonnorm.pickle'
normalizations = load_normalizations('brits_normalization.pickle')

model = BRITS()
optimizer = optim.Adam(model.parameters(), lr = 1e-3)
train_set = MySet2(path_train)
# Pinned host memory only helps host-to-GPU copies, so request it only when CUDA is available.
train_iter = DataLoader(dataset = train_set,batch_size = batch_size,shuffle = True,pin_memory = torch.cuda.is_available(), collate_fn = collate_fn2)

In [53]:
# Train for `epochs` passes over the training set and report the wall-clock time.
t0 = time.time()
for epoch in range(epochs):
  model.train()

  run_loss = 0.0

  for idx, data in enumerate(train_iter):
    ret = model.run_on_batch(data, optimizer)

    run_loss += ret['loss'].item()

    # end='' keeps the cursor on the current line, so the leading '\r' makes
    # the next batch overwrite it instead of printing one line per batch.
    print('\r Progress epoch {}, {:.2f}%, average loss {}'.format(epoch, (idx + 1) * 100.0 / len(train_iter), run_loss / (idx + 1.0)), end='', flush=True)

  # Terminate the progress line so every epoch keeps its final summary.
  print()

print(f"{device} {time.time()-t0}")

 Progress epoch 0, 100.00%, average loss 46.651329040527344
 Progress epoch 1, 100.00%, average loss 46.3891716003418
 Progress epoch 2, 100.00%, average loss 46.13164520263672
 Progress epoch 3, 100.00%, average loss 45.8769416809082
 Progress epoch 4, 100.00%, average loss 45.632198333740234
 Progress epoch 5, 100.00%, average loss 45.39121627807617
 Progress epoch 6, 100.00%, average loss 45.15153884887695
 Progress epoch 7, 100.00%, average loss 44.91617202758789
 Progress epoch 8, 100.00%, average loss 44.68472671508789
 Progress epoch 9, 100.00%, average loss 44.455814361572266
 Progress epoch 10, 100.00%, average loss 44.22711944580078
 Progress epoch 11, 100.00%, average loss 43.999755859375
 Progress epoch 12, 100.00%, average loss 43.773197174072266
 Progress epoch 13, 100.00%, average loss 43.549922943115234
 Progress epoch 14, 100.00%, average loss 43.32828903198242
 Progress epoch 15, 100.00%, average loss 43.10932540893555
 Progress epoch 16, 100.00%, average loss 42.8933

Testing against the mean baseline

In [94]:
test_set = MySet2(path_test)
# Pinned host memory only helps host-to-GPU copies, so request it only when CUDA is available.
test_iter = DataLoader(dataset = test_set,batch_size = batch_size,shuffle = False,pin_memory = torch.cuda.is_available(), collate_fn = collate_fn2)
model.eval()
# Inference only (optimizer is None), so do not record the autograd graph.
with torch.no_grad():
  for idx, data in enumerate(test_iter):
    ret = model.run_on_batch(data, None)
    # NOTE(review): only the first batch (at most `batch_size` series) is
    # evaluated; `data` and `ret` of that batch are what the cells below use.
    break

In [95]:
# Sanity check: equally sized sets hint at the same file having been loaded twice.
if len(test_set) == len(train_set):
  raise Exception("Length of test and train set are equal. Is there a Mistake?")

In [96]:
def revert_norm(value,name):
  """Undo the normalization of ``value`` for the feature ``name``.

  When ``use_norm`` is set, returns ``value * std + mean`` using the entries
  of ``normalizations`` for ``name``; otherwise ``value`` is returned as is.
  """
  if use_norm:
    std = normalizations['std'][name]
    mean = normalizations['mean'][name]
    return value * std + mean
  # Data was never normalized, nothing to revert.
  return value

In [99]:
def print_impu_real(ret,index_name):
  """Return a DataFrame pairing imputed and ground-truth values of one feature.

  Only positions whose eval mask equals 1 are included. Both columns are
  passed through ``revert_norm`` first.

  Args:
    ret: model output dict with 'eval_masks', 'imputations' and 'evals',
      each indexed as [:, :, feature].
    index_name: feature name, a key of ``indexes``.

  Returns:
    DataFrame with the columns 'Imputation' and 'Real' (in that order).
  """
  index = indexes[index_name]
  # Positions that have a held-out ground-truth value to compare against.
  missing = (ret['eval_masks'][:,:,index]==1)
  impu = revert_norm(ret['imputations'][:,:,index][missing],index_name)
  real = revert_norm(ret['evals'][:,:,index][missing],index_name)
  # .detach().cpu() makes the conversion work for tensors that require grad
  # or live on a GPU.
  return pd.DataFrame({
    "Imputation": impu.detach().cpu().numpy(),
    "Real": real.detach().cpu().numpy(),
  })

def calc_error_on_row(ret, index_name):
  """Mean absolute error of the model's imputations for one feature.

  Compares 'imputations' with 'evals' at the positions whose eval mask is 1,
  after both have been passed through ``revert_norm``.
  """
  col = indexes[index_name]
  held_out = ret['eval_masks'][:,:,col] == 1
  predicted = revert_norm(ret['imputations'][:,:,col][held_out], index_name)
  truth = revert_norm(ret['evals'][:,:,col][held_out], index_name)
  abs_err = torch.abs(predicted - truth)
  return abs_err.mean()

def calculate_personal_mean_error(data,index_name):
  """Mean absolute error of a per-series mean baseline for one feature.

  For every series in ``data['forward']`` the baseline is the mean of its
  'values' at positions whose eval mask is not 1; the error is measured
  against 'evals' at positions whose eval mask is 1. All errors are pooled
  and averaged.
  """
  col = indexes[index_name]
  fwd = data['forward']
  # Start from an empty tensor so the pooled result matches an incremental concat.
  pooled = [torch.tensor(())]
  for row in range(fwd['eval_masks'].size(dim=0)):
    held_out = fwd['eval_masks'][row,:,col] == 1
    # NOTE(review): everything with eval mask != 1 counts as "non-missing"
    # here; confirm that entries absent from the raw data should be included.
    own_mean = revert_norm(fwd['values'][row,:,col][~held_out], index_name).mean()
    truth = revert_norm(fwd['evals'][row,:,col][held_out], index_name)
    pooled.append(torch.abs(truth - own_mean))
  return torch.cat(pooled, dim=0).mean()

def calc_mean_error(data,index_name):
  """Mean absolute error of a global mean baseline for one feature.

  The baseline is the mean of all 'values' in ``data['forward']`` at
  positions whose eval mask is not 1; the error is measured against 'evals'
  at positions whose eval mask is 1.
  """
  col = indexes[index_name]
  fwd = data['forward']
  held_out = fwd['eval_masks'][:,:,col] == 1
  baseline = revert_norm(fwd['values'][:,:,col][~held_out], index_name).mean()
  truth = revert_norm(fwd['evals'][:,:,col][held_out], index_name)
  return torch.abs(truth - baseline).mean()

# Compare the model against the two mean baselines for one feature, using
# `data` and `ret` from the evaluation cell above.
index_name = 'calories'
mean_erro = calc_mean_error(data,index_name)
personal_mean_erro = calculate_personal_mean_error(data,index_name)
rnn_erro = calc_error_on_row(ret,index_name)

# RNN = model imputations, MEAN = global mean baseline, MEANP = per-series mean baseline.
print(f"RNN   {rnn_erro}")
print(f"MEAN  {mean_erro}")
print(f"MEANP {personal_mean_erro}")

RNN   221.11209106445312
MEAN  285.0473937988281
MEANP 296.0644836425781


In [None]:
# Show imputed and ground-truth values of the selected feature side by side.
print_impu_real(ret,index_name)

In [None]:
# Display the loaded normalization statistics.
normalizations

In [None]:
#use_norm = True
def flatten(t):
    """Concatenate the items of every sub-iterable of ``t`` into one flat list."""
    flat = []
    for sublist in t:
        flat.extend(sublist)
    return flat
  

# Mean of the de-normalized 'calories' entries of `values` over all steps of
# all forward test series.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# The with-block closes the file even when unpickling fails.
with open(path_test,'rb') as in_file:
    recs = pickle.load(in_file)
recs = [x['forward'] for x in recs]
recs = [[revert_norm(y['values'][indexes['calories']],'calories')for y in x]for x in recs]
recs = flatten(recs)
recs = np.array(recs)
recs.mean()