In [1]:
#from google.colab import drive
#drive.mount('/content/drive')

In [2]:
%%capture
!pip install wandb --upgrade

In [3]:
!cp drive/MyDrive/Uni/Masterarbeit/data/imputation/* .

BRITS implementation in colab

In [4]:
import os
import time

import numpy as np
import pandas as pd
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter

from tqdm import tqdm
import math
import wandb

In [5]:
SEQ_LEN = 28
RNN_HID_SIZE = 64
INPUT_SIZE = 9

In [6]:
def get_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu') 
    return device
device = get_device()
print(device)

cuda:0


#Network Classes

Other than in the original BRITS, we removed the classification loss, as we do not have a classificaition task. Further we did a hard split into train and test set as it is good practice. The code is based on, but updated, cleaned and converted to Python3:
https://github.com/NIPS-BRITS/BRITS

In [7]:
class FeatureRegression(nn.Module):
    def __init__(self, input_size):
        super(FeatureRegression, self).__init__()
        self.build(input_size)

    def build(self, input_size):
        self.W = Parameter(torch.Tensor(input_size, input_size)).to(device)
        self.b = Parameter(torch.Tensor(input_size)).to(device)

        m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size)
        m = m.to(device)
        self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(0))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, x):
        z_h = F.linear(x, self.W * self.m, self.b)
        return z_h

class TemporalDecay(nn.Module):
    def __init__(self, input_size, output_size, diag = False):
        super(TemporalDecay, self).__init__()
        self.diag = diag

        self.build(input_size, output_size)

    def build(self, input_size, output_size):
        self.W = Parameter(torch.Tensor(output_size, input_size)).to(device)
        self.b = Parameter(torch.Tensor(output_size)).to(device)

        if self.diag == True:
            assert(input_size == output_size)
            m = torch.eye(input_size, input_size).to(device)
            self.register_buffer('m', m)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(0))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, d):
        if self.diag == True:
            gamma = F.relu(F.linear(d, self.W * self.m, self.b))
        else:
            gamma = F.relu(F.linear(d, self.W, self.b))
        gamma = torch.exp(-gamma)
        return gamma

In [8]:
class RITS(nn.Module):
    def __init__(self):
        super(RITS, self).__init__()
        self.build()

    def build(self):
        self.rnn_cell = nn.LSTMCell(INPUT_SIZE * 2, RNN_HID_SIZE).to(device)

        self.temp_decay_h = TemporalDecay(input_size = INPUT_SIZE, output_size = RNN_HID_SIZE, diag = False)
        self.temp_decay_x = TemporalDecay(input_size = INPUT_SIZE, output_size = INPUT_SIZE, diag = True)

        self.hist_reg = nn.Linear(RNN_HID_SIZE, INPUT_SIZE).to(device)
        self.feat_reg = FeatureRegression(INPUT_SIZE).to(device)

        self.weight_combine = nn.Linear(INPUT_SIZE * 2, INPUT_SIZE).to(device)

        self.dropout = nn.Dropout(p = 0.25).to(device)
        self.out = nn.Linear(RNN_HID_SIZE, 1).to(device)

    def forward(self, data, direct):
        values = data[direct]['values'].to(device)
        masks = data[direct]['masks'].to(device)
        deltas = data[direct]['deltas'].to(device)

        evals = data[direct]['evals'].to(device)
        eval_masks = data[direct]['eval_masks'].to(device)

        h = torch.zeros((values.size()[0], RNN_HID_SIZE)).to(device)
        c = torch.zeros((values.size()[0], RNN_HID_SIZE)).to(device)
          

        x_loss = 0.0
        y_loss = 0.0

        imputations = []

        for t in range(SEQ_LEN):
            x = values[:, t, :]
            m = masks[:, t, :]
            d = deltas[:, t, :]

            gamma_h = self.temp_decay_h(d)
            gamma_x = self.temp_decay_x(d)

            h = h * gamma_h

            x_h = self.hist_reg(h)
            x_loss += torch.sum(torch.abs(x - x_h) * m) / (torch.sum(m) + 1e-5)

            x_c =  m * x +  (1 - m) * x_h

            z_h = self.feat_reg(x_c)
            x_loss += torch.sum(torch.abs(x - z_h) * m) / (torch.sum(m) + 1e-5)

            alpha = self.weight_combine(torch.cat([gamma_x, m], dim = 1))

            c_h = alpha * z_h + (1 - alpha) * x_h
            x_loss += torch.sum(torch.abs(x - c_h) * m) / (torch.sum(m) + 1e-5)

            c_c = m * x + (1 - m) * c_h

            inputs = torch.cat([c_c, m], dim = 1)

            h, c = self.rnn_cell(inputs, (h, c))

            imputations.append(c_c.unsqueeze(dim = 1))

        imputations = torch.cat(imputations, dim = 1)

        return {'loss': x_loss / SEQ_LEN,
                'imputations': imputations,
                'evals': evals, 
                'eval_masks': eval_masks}

    def run_on_batch(self, data, optimizer):
        ret = self(data, direct = 'forward')

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

In [9]:
class BRITS(nn.Module):
    def __init__(self):
        super(BRITS, self).__init__()
        self.build()

    def build(self):
        self.rits_f = RITS()
        self.rits_b = RITS()

    def forward(self, data):
        ret_f = self.rits_f(data, 'forward')
        ret_b = self.reverse(self.rits_b(data, 'backward'))

        ret = self.merge_ret(ret_f, ret_b)

        return ret

    def merge_ret(self, ret_f, ret_b):
        loss_f = ret_f['loss']
        loss_b = ret_b['loss']
        loss_c = self.get_consistency_loss(ret_f['imputations'], ret_b['imputations'])

        loss = loss_f + loss_b + loss_c

        imputations = (ret_f['imputations'] + ret_b['imputations']) / 2

        ret_f['loss'] = loss
        ret_f['imputations'] = imputations

        return ret_f

    def get_consistency_loss(self, pred_f, pred_b):
        #loss old:
        #loss = torch.pow(pred_f - pred_b, 2.0).mean()
        #return loss
        loss = torch.abs(pred_f - pred_b).mean() * 1e-1
        return loss

    def reverse(self, ret):
        def reverse_tensor(tensor_):
            if tensor_.dim() <= 1:
                return tensor_
            indices = range(tensor_.size()[1])[::-1]
            indices = torch.tensor(indices, requires_grad=False).long()#.requires_grad_(requires_grad=False) 

            if torch.cuda.is_available():
                indices = indices.cuda()

            return tensor_.index_select(1, indices)

        for key in ret:
            ret[key] = reverse_tensor(ret[key])

        return ret

    def run_on_batch(self, data, optimizer):
        ret = self(data)

        if optimizer is not None:
            optimizer.zero_grad()
            ret['loss'].backward()
            optimizer.step()

        return ret

# Dataset

The Data is strucutred as follows:
Each Entry in the Dataset is one Timeseries with n steps.
It has a `forward` and a `backward` direction. For the both RITS Networks
Each has following entries:

*   `values`: data after elimination of values
*   `masks`: indicating if data is missing
*   `deltas`: timedeltas since last recorded data
*   `evals`: ground truth
*   `eval_masks`: 1 if is ground truth and missing in values 0 otherwise

In [10]:
data_set_dict_entries=['values','masks','deltas','evals','eval_masks']

class MySet2(Dataset):
    def __init__(self,content_path):
        super(MySet2, self).__init__()
        content = open(content_path,'rb')
        recs = pickle.load(content)
        content.close()
        self.forward = self.to_tensor_dict([x['forward'] for x in recs])
        self.backward = self.to_tensor_dict([x['backward'] for x in recs])

    def __len__(self):
        return len(self.forward[data_set_dict_entries[0]])
    
    def to_tensor_dict(self,recs):
      return_dict = {}
      for dict_key in data_set_dict_entries:
        tens = torch.FloatTensor([[x[dict_key][0:INPUT_SIZE]for x in r]for r in recs])
        return_dict[dict_key] = tens 
      return return_dict
    
    def __getitem__(self, idx):
      forward = {}
      backward = {}
      for dict_key in data_set_dict_entries:
        forward[dict_key] = self.forward[dict_key][idx]
        backward[dict_key] = self.backward[dict_key][idx]
      return {'forward':forward,'backward':backward}

def collate_fn2(recs):
  batch_size = len(recs)
  forward = {}
  backward = {}
  for dict_key in data_set_dict_entries:
      forward[dict_key] = torch.empty(batch_size,SEQ_LEN,INPUT_SIZE)
      backward[dict_key] = torch.empty(batch_size,SEQ_LEN,INPUT_SIZE)
  for idx,x in enumerate(recs):
    for dict_key in data_set_dict_entries:
      forward[dict_key][idx] = x['forward'][dict_key]
      backward[dict_key][idx] = x['backward'][dict_key] 
  return {'forward': forward, 'backward': backward}

# Error Functions

In [11]:
indexes= {'calories':1,'carbs':2,'fat':3,'protein':4}
def load_normalizations(file):
  content = open(file,'rb')
  ret = pickle.load(content)
  content.close()
  return ret
normalizations = load_normalizations('brits_normalization.pickle')

def revert_norm(value,name):
  if not use_norm:
    #only if we actually use normalized data
    return value
  return value * normalizations['std'][name] + normalizations['mean'][name]

def get_missing_and_index(index_name):
  index = indexes[index_name]
  filter = (ret['eval_masks'][:,:,index]==1) #get all rows that are missing to compare to ground truth
  return index,filter

def print_impu_real(ret,index_name):
  index, filter = get_missing_and_index(index_name)
  impu = revert_norm(ret['imputations'][:,:,index][filter],index_name)
  real = revert_norm(ret['evals'][:,:,index][filter],index_name)
  px = pd.DataFrame()
  px.insert(0,"Real",real.numpy())
  px.insert(0,"Imputation",impu.detach().numpy())
  return px

def get_abs_error_val(ret, index_name):
  """Calculates the absolute error of the brits imputation for each missing value and returns the mean error"""
  index, filter = get_missing_and_index(index_name)
  impu = revert_norm(ret['imputations'][:,:,index][filter],index_name)
  real = revert_norm(ret['evals'][:,:,index][filter],index_name)
  return torch.abs(impu-real).mean().item()

def get_abs_personal_mean_impu_error(data,index_name):
  """Calculates the absolute error of the personal mean imputation for each missing value and returns the mean error"""
  index = indexes[index_name]
  forward = data['forward']
  error = torch.tensor(())
  for x in range(forward['eval_masks'].size(dim=0)):
    filter = (forward['eval_masks'][x,:,index]==1)#get all rows that are missing
    z =  revert_norm(forward['values'][x,:,index][filter==False],index_name).mean()#calculate the mean of all non missing
    real =  revert_norm(forward['evals'][x,:,index][filter],index_name)
    error = torch.cat((error,torch.abs(real-z)), dim=0)
  return error.mean().item()

def get_abs_mean_impu_error(data,index_name):
  """Calculates the absolute error of the mean imputation for each missing value and returns the mean error"""
  index, filter = get_missing_and_index(index_name)
  forward = data['forward']
  z =  revert_norm(forward['values'][:,:,index][filter==False],index_name).mean()#calculate the mean of all non missing
  real =  revert_norm(forward['evals'][:,:,index][filter],index_name)
  mean_erro = torch.abs(real-z).mean().item()
  return mean_erro

def get_rel_error_val(ret,index_name):
  """Calculates the relative error of the brits imputation for each missing value and returns the mean error
  It uses the normalized values to remove problems with 0 division
  """
  index, filter = get_missing_and_index(index_name)
  impu = ret['imputations'][:,:,index][filter]
  real = ret['evals'][:,:,index][filter]
  return torch.abs((impu-real)/(torch.abs(real)+1)).mean().item()

def get_errors_impu(ret):
  return_dict = {}
  for index_name in indexes.keys():
    #rel_err = get_rel_error_val(ret,index_name)
    abs_err = get_abs_error_val(ret, index_name)
    return_dict[index_name] = abs_err
  return return_dict

def get_errors_bench(data):
  ret = {}
  for index_name in indexes.keys():
    abs_err = get_abs_mean_impu_error(data,index_name)
    ret[index_name] = abs_err
  return ret

#Training

In [14]:
use_norm = True
path_train = './brits_train.pickle'
path_test = './brits_test.pickle'
if not use_norm:
  path_train = './brits_train_nonnorm.pickle'
  path_test = './brits_test_nonnorm.pickle'
test_set = MySet2(path_test)
train_set = MySet2(path_train)

if len(test_set) == len(train_set):
  raise Exception("Length of test and train set are equal. Is there a Mistake?")

In [15]:
config = dict(
    lr = 1e-3,
    batch_size= 500,
    use_norm = use_norm,
    num_workers=1,
    dataset="My Fitnesspal Small",
    epochs=4000,
    test_set_size= len(test_set),
    train_set_size= len(train_set),
    early_stop_after=2
)

In [16]:
wandb.login()
run = wandb.init(project="brits", entity="gege-hoho", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mgege-hoho[0m (use `wandb login --relogin` to force relogin)


In [17]:
model = BRITS()
optimizer = optim.Adam(model.parameters(), lr = config['lr'] )

train_iter = DataLoader(dataset = train_set,batch_size = config['batch_size'],
                        shuffle = True,pin_memory = True, 
                        collate_fn = collate_fn2, 
                        num_workers=config["num_workers"])

test_iter = DataLoader(dataset = test_set,batch_size = config['batch_size'],shuffle = False,pin_memory = True, collate_fn = collate_fn2)

In [19]:
t0 = time.time()
last_errors = {}
for epoch in tqdm(range(config['epochs'])):
  model.train()
  train_loss = 0.0
  for idx, data in enumerate(train_iter):
    ret = model.run_on_batch(data, optimizer)
    train_loss += ret['loss'].item()
  train_loss = train_loss/(idx + 1.0)
  wandb.log({"train-loss": train_loss})

  model.eval()
  test_loss = 0.0
  for idx,data in enumerate(test_iter):
    ret = model.run_on_batch(data, None)
    test_loss += ret['loss'].item()
  test_loss = test_loss/(idx+ 1.0)
  test_errors = get_errors_impu(ret)
  wandb.log({"test-loss": test_loss})
  wandb.log({"test-errors": test_errors})

  #early stopping
  count = sum([1 for k,v in last_errors.items() if v < test_errors[k]])
  if count >= 2:
    #more than 2 values went worse so we stop
    print("\n Early Stopping")
    break
print("\n")
print(f"{device} {time.time()-t0}")

100%|██████████| 4000/4000 [25:16<00:00,  2.64it/s]




cuda:0 1516.259934425354


# Testing against MEAN

In [151]:
model.eval()
for idx, data in enumerate(train_iter):
  ret = model.run_on_batch(data, optimizer)
  break
index_name = 'calories'
mean_erro = get_abs_mean_impu_error(data,index_name)
personal_mean_erro = get_abs_personal_mean_impu_error(data,index_name)
rnn_erro = get_abs_error_val(ret,index_name)

print(f"RNN   {rnn_erro}")
print(f"MEAN  {mean_erro}")
print(f"MEANP {personal_mean_erro}")

RNN   393.9133605957031
MEAN  248.78421020507812
MEANP 257.0036926269531


In [123]:
"""
use_norm = False
path_train = './brits_train.pickle'
path_train2 = './brits_train_nonnorm.pickle'
train_set = MySet2(path_train)
train_set2 = MySet2(path_train2)
train_iter = DataLoader(dataset = train_set,batch_size = 5000,
                        shuffle = False,pin_memory = True, 
                        collate_fn = collate_fn2, 
                        num_workers=config["num_workers"])
train_iter2 = DataLoader(dataset = train_set2,batch_size = 5000,
                        shuffle = False,pin_memory = True, 
                        collate_fn = collate_fn2, 
                        num_workers=config["num_workers"])

for d in train_iter:
  break
for d2 in train_iter2:
  break
d2 = d2['forward']
d = d['forward']
d2 = d2['values'][:,:,1].view(8288)
d = d['values'][:,:,1].view(8288)
dnorm = revert_norm(d,'calories').numpy()
print(dnorm.mean())
print(d2.mean().item())
px = pd.DataFrame()
px["d"] = d.numpy()
px["dnorm"] = dnorm
px["d2"] = d2.numpy()
px["d-d2"] = px["dnorm"]-px["d2"]
px[(px["d-d2"] >=0.1)]
#px
"""

-3.861616
456.36041259765625


Unnamed: 0,d,dnorm,d2,d-d2
