# Import libraries

In [None]:
import os
import random
import numpy as np
import pandas as pd
from math import sqrt
import math
import os
import random

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim

In [None]:
!pip install autoimpute
from autoimpute.imputations import MultipleImputer

Collecting autoimpute
  Downloading autoimpute-0.12.2-py3-none-any.whl (100 kB)
[?25l[K     |███▎                            | 10 kB 21.4 MB/s eta 0:00:01[K     |██████▌                         | 20 kB 24.2 MB/s eta 0:00:01[K     |█████████▊                      | 30 kB 18.5 MB/s eta 0:00:01[K     |█████████████                   | 40 kB 15.7 MB/s eta 0:00:01[K     |████████████████▎               | 51 kB 5.6 MB/s eta 0:00:01[K     |███████████████████▌            | 61 kB 6.0 MB/s eta 0:00:01[K     |██████████████████████▊         | 71 kB 5.5 MB/s eta 0:00:01[K     |██████████████████████████      | 81 kB 6.2 MB/s eta 0:00:01[K     |█████████████████████████████▎  | 92 kB 6.1 MB/s eta 0:00:01[K     |████████████████████████████████| 100 kB 4.3 MB/s 
Installing collected packages: autoimpute
Successfully installed autoimpute-0.12.2


# Missingness method

In [None]:
def missing_method(test_data, train_data, num_embeddings, threshold=0.2):
    """Randomly blank out non-embedding cells of the test and train splits.

    Each non-embedding cell is set to NaN independently when a uniform draw on
    [0, 1) is <= ``threshold``. The last ``num_embeddings`` columns are embedding
    columns and are never blanked. The test split is corrupted first, then the
    train split, using NumPy's global RNG.

    Args:
        test_data: 2-D array (or DataFrame) of test rows. Not modified.
        train_data: 2-D array (or DataFrame) of training rows. Not modified. Its
            corrupted copy is used to fit the MultipleImputer for mean/median
            imputation.
        num_embeddings: number of trailing embedding columns to leave intact.
        threshold: probability that a non-embedding cell is blanked.

    Returns:
        ``(corrupted_test, corrupted_train, test_mask)``. ``test_mask`` is a
        boolean array with the test shape that is True where a cell was blanked.

    Raises:
        ValueError: if ``num_embeddings`` is negative or exceeds the column count.
    """

    def _corrupt(data):
        """Return a NaN-corrupted copy of ``data`` and the boolean mask used."""
        rows, cols = data.shape
        if not 0 <= num_embeddings <= cols:
            raise ValueError(
                f"num_embeddings must be between 0 and {cols}, got {num_embeddings}"
            )
        feature_cols = cols - num_embeddings

        corrupted = data.copy()
        # NaN cannot be stored in integer/bool NumPy arrays; upcast in that case.
        if isinstance(corrupted, np.ndarray) and not np.issubdtype(corrupted.dtype, np.floating):
            corrupted = corrupted.astype(float)

        # Embedding columns (the trailing ones) stay False so they are never blanked.
        mask = np.zeros((rows, cols), dtype=bool)
        mask[:, :feature_cols] = np.random.uniform(size=(rows, feature_cols)) <= threshold
        corrupted[mask] = np.nan  # np.NAN was removed in NumPy 2.0
        return corrupted, mask

    missed_test, test_mask = _corrupt(test_data)
    missed_train, _ = _corrupt(train_data)
    return missed_test, missed_train, test_mask

# Imputation Methods

## Autoencoder

In [None]:
class Autoencoder(nn.Module):
    """Denoising autoencoder used to impute missing values.

    The input is reshaped to ``(-1, dim)`` and corrupted with dropout, which is
    active only in training mode. It is then encoded through Tanh layers of
    roughly 70%, 50% and 20% of ``dim`` and decoded back to ``dim`` features.

    Args:
        dim: number of input (and output) features per row.

    Raises:
        ValueError: if ``dim`` is smaller than 1.
    """

    def __init__(self, dim):
        super(Autoencoder, self).__init__()
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        self.dim = dim

        # Dropout doubles as input corruption for the denoising objective.
        self.drop_out = nn.Dropout(p=0.2)

        # int() truncation makes a layer 0 wide for small dims (e.g. dim=3 gives
        # int(0.6) == 0), which would cut the output off from the input.
        # Clamp every hidden width to at least 1.
        h1 = max(1, int(dim * 0.7))
        h2 = max(1, int(dim * 0.5))
        h3 = max(1, int(dim * 0.2))

        # encoder architecture
        self.encoder = nn.Sequential(
            nn.Linear(dim, h1),
            nn.Tanh(),
            nn.Linear(h1, h2),
            nn.Tanh(),
            nn.Linear(h2, h3),
        )

        # decoder architecture (mirror of the encoder)
        self.decoder = nn.Sequential(
            nn.Linear(h3, h2),
            nn.Tanh(),
            nn.Linear(h2, h1),
            nn.Tanh(),
            nn.Linear(h1, dim),
        )

    def forward(self, x):
        """Reconstruct ``x``; returns a tensor of shape ``(-1, dim)``."""
        x = x.view(-1, self.dim)

        # Dropout corrupts the input during training only (identity in eval mode).
        x_missed = self.drop_out(x)

        z = self.encoder(x_missed)
        out = self.decoder(z)

        return out.view(-1, self.dim)

In [None]:
def training(num_epochs, model, train_loader, criterion, optimizer, device=None):
    """Train ``model`` to reconstruct its own input and return per-epoch losses.

    Each batch is moved to ``device`` and passed through ``model``. The loss
    ``criterion(outputs, batch)`` is backpropagated and ``optimizer`` is stepped.
    The mean batch loss of every epoch is printed and collected.

    Args:
        num_epochs: number of passes over ``train_loader``.
        model: module whose output has the same shape as its input.
        train_loader: iterable of tensor batches.
        criterion: reconstruction loss, called as ``criterion(outputs, batch)``.
        optimizer: optimizer over ``model``'s parameters.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        List with the mean reconstruction loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Training mode matters: the Autoencoder's dropout only corrupts inputs while
    # training. A model left in eval() would otherwise train without corruption.
    model.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for batch_features in train_loader:
            # load it to the active device
            batch_features = batch_features.to(device)

            # reset the gradients back to zero
            optimizer.zero_grad()

            # compute reconstructions and the reconstruction loss
            outputs = model(batch_features)
            train_loss = criterion(outputs, batch_features)

            # backpropagate and update the parameters
            train_loss.backward()
            optimizer.step()

            running_loss += train_loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # mean mini-batch loss for this epoch
        epoch_loss = running_loss / num_batches
        epoch_losses.append(epoch_loss)
        print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, num_epochs, epoch_loss))

    return epoch_losses

## Mean and Median Imputation

In [None]:
# Takes the test dataset and the dataset containing the average/median values by category.
# Imputes the values found in the dataset by category in place of the NaN values in the test dataset.
# Returns pandas DataFrame.

def impute_traditional_by_category(X_test, df, indices):
    """Fill NaNs in ``X_test`` from the matching ``fdc_id`` row of ``df``.

    Args:
        X_test: DataFrame indexed by fdc_id (index assumed unique). Modified in place.
        df: DataFrame with an ``fdc_id`` column plus one column per feature to
            impute. When an fdc_id appears more than once, the first row is used.
        indices: iterable of fdc_ids (row labels of ``X_test``) to impute.

    Returns:
        ``X_test`` itself, with the NaNs of the listed rows filled.

    Raises:
        KeyError: if a listed row has NaNs but its fdc_id is absent from ``df``,
            or if ``df`` lacks a column that is NaN in ``X_test``.
    """
    # Build the fdc_id -> row lookup once instead of scanning `df` for every row.
    lookup = df.drop_duplicates(subset='fdc_id', keep='first').set_index('fdc_id')

    for fdc_id in indices:
        row = X_test.loc[fdc_id]
        nan_cols = row.index[row.isna()]
        if nan_cols.empty:
            continue  # nothing to impute for this food
        if fdc_id not in lookup.index:
            raise KeyError(
                f"fdc_id {fdc_id} has missing values but no entry in the reference table"
            )
        # Write straight into X_test rather than into a copied slice.
        X_test.loc[fdc_id, nan_cols] = lookup.loc[fdc_id, nan_cols].to_numpy()

    return X_test

In [None]:
# Fill the NaNs of the test set with one MultipleImputer run whose per-column
# statistic (`method`, e.g. 'mean' or 'median') is learned from the train set.
# Returns a pandas DataFrame.

def impute_traditional_all_foods(train, test, method):
    """Impute ``test`` with column statistics (``method``) fitted on ``train``."""
    # A single imputation (n=1), returned in list form.
    imputer = MultipleImputer(1, strategy=method, return_list=True)
    imputer.fit(train)

    imputations = imputer.transform(test)
    # Element [1] of the first (and only) imputation entry is the imputed frame.
    # Presumably the entries are (imputation_number, DataFrame) pairs in
    # autoimpute's return_list format; verify against its docs.
    only_imputation = imputations[0]
    return only_imputation[1]

# Evaluation

In [None]:
def rmse_error(test_data, imputed_data, num_cols, mask):
    """Sum, over the first ``num_cols`` columns, the RMSE on the masked cells.

    Only cells where the boolean ``mask`` is True are compared. Columns with no
    masked cells contribute nothing. Returns 0 when no column has masked cells.
    """
    return sum(
        sqrt(mean_squared_error(test_data[:, col][mask[:, col]],
                                imputed_data[:, col][mask[:, col]]))
        for col in range(num_cols)
        if mask[:, col].sum() > 0
    )

# Data Preparation

In [None]:
num_epochs = 200
test_size = 0.2
use_cuda = False
batch_size = 1

# Seed every RNG the notebook draws from so a top-to-bottom run is repeatable.
# train_test_split (no random_state) and missing_method use NumPy's global RNG.
# Weight init, dropout and DataLoader shuffling use torch's RNG.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

data_mean = 'datasets/median datasets/traditionalMethods/food_mean.csv'
data_median = 'datasets/median datasets/traditionalMethods/food_median.csv'

In [None]:
# Per-category mean values. Load with index_col=0, consistent with how df_median
# is loaded: the CSV's unnamed first column is a saved row index, not a feature.
df_mean = pd.read_csv(data_mean, index_col=0)
del df_mean['food_category_id']
df_mean.describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,1323,1325,1329,1330,1331,1333,1334,1335,1404,1405,1406,1409,1411,1414,2003,2004,2005,2006,2007,2008,2009,2010,2012,2013,2014,2015,2016,2018,2019,2020,2021,2022,2023,2024,2025,2026,2028,2029,2032,2033
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,0.01792,0.368833,0.158792,0.0,0.039055,0.0,0.33962,0.030308,0.013052,0.015636,0.004294,0.000468,8e-05,0.094002,0.103418,0.0,0.804514,0.0,0.0,0.00031,0.050196,0.000746,0.003615,9e-05,2.803481,0.068301,2.102743,7.074479,17.721367,10.375581,1.904639,0.049905,0.361082,0.011032,0.959839,0.022555,5.589921,0.118821,0.313166,...,0.000225,0.0,0.0,0.0,0.0,5.103489e-07,4e-06,4e-06,0.00234,1.8e-05,8.2e-05,1e-06,4.6e-05,0.0,0.0,0.0,0.0,7e-06,0.0,6e-06,0.000176,0.0,0.000495,0.0,1e-05,2.954754e-07,0.0,0.0,7.670335e-07,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.042914,0.0,0.000924,0.003405
std,254401.8,0.015051,0.79271,0.17947,0.0,0.134327,0.0,1.024358,0.476577,0.018567,0.025224,0.006188,0.001836,0.000181,0.213566,0.234958,0.0,1.075676,0.0,0.0,0.000422,0.122506,0.001276,0.006272,0.000168,5.195553,0.173789,5.715402,16.491529,48.088008,52.579387,6.057867,0.109987,1.148451,0.031719,6.358265,0.062911,17.668482,0.207363,0.425251,...,0.000606,0.0,0.0,0.0,0.0,1.280017e-06,8e-06,6e-06,0.011316,5.3e-05,0.000164,2e-06,0.000123,0.0,0.0,0.0,0.0,3.7e-05,0.0,8e-06,0.000451,0.0,0.00215,0.0,4.3e-05,1.656785e-06,0.0,0.0,2.886749e-06,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.087419,0.0,0.001332,0.006002
min,319877.0,0.0,0.0,0.007897,0.0,0.000121,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.004,0.00016,0.0,0.0,0.0,0.0,0.0,8e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,325648.0,0.004464,0.025211,0.020099,0.0,0.014267,0.0,0.0,0.001502,0.000269,0.000128,0.0,0.0,0.0,0.0,0.0,0.0,0.348238,0.0,0.0,0.0,0.000141,0.0,0.0,0.0,0.348943,0.006395,0.167418,0.568905,1.40498,0.853045,0.0,0.003781,0.0,0.000101,0.0,0.000101,0.0,0.015248,0.0,...,1.5e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000363,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,331410.0,0.0218,0.124265,0.16848,0.0,0.019052,0.0,0.0,0.001502,0.000269,0.000281,0.0,0.0,0.0,0.0,0.0,0.0,0.52187,0.0,0.0,0.000128,0.000141,0.0,0.0,0.0,0.780612,0.006786,0.17738,2.350983,2.052367,3.488889,0.0,0.014632,0.0,0.000815,0.0,0.001328,0.0,0.096217,0.0,...,5.5e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000363,4e-06,9e-06,0.0,1.3e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.2e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.011478,0.0,0.0,0.0
75%,747468.0,0.0218,0.137845,0.16848,0.0,0.022228,0.0,0.009677,0.010914,0.022661,0.020941,0.01321,0.000148,0.0,0.0,0.0,0.0,0.990072,0.0,0.0,0.000776,0.020673,0.002133,0.009444,0.000221,2.44392,0.009367,0.273303,2.350983,3.163428,3.65557,0.0,0.014632,0.0,0.000989,0.304856,0.003092,0.0,0.096217,0.923318,...,0.000172,0.0,0.0,0.0,0.0,1.007811e-06,5e-06,1.2e-05,0.000623,5e-06,0.000105,3e-06,2.3e-05,0.0,0.0,0.0,0.0,4e-06,0.0,1.8e-05,0.000288,0.0,5.2e-05,0.0,3e-06,0.0,0.0,0.0,7.267442e-07,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.032412,0.0,0.002844,0.00622
max,1105897.0,0.071511,3.221739,0.840146,0.0,1.022487,0.0,3.592297,10.515789,0.049956,0.067865,0.01321,0.011411,0.00049,0.579173,0.637185,0.0,9.008696,0.0,0.0,0.002052,0.435028,0.003257,0.016334,0.000526,19.051672,0.616257,20.207903,59.12462,170.422492,398.367347,21.170732,0.385562,4.01355,0.111644,48.0,0.221515,61.77933,1.585556,0.923318,...,0.00343,0.0,0.0,0.0,0.0,8.583691e-06,4.8e-05,1.2e-05,0.082194,0.000275,0.000988,3e-06,0.000676,0.0,0.0,0.0,0.0,0.000272,0.0,1.8e-05,0.002536,0.0,0.01491,0.0,0.000245,9.584665e-06,0.0,0.0,1.648352e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.298876,0.0,0.002844,0.020698


In [None]:
# Per-category median values; the category id itself is not needed.
df_median = pd.read_csv(data_median, index_col=0).drop(columns='food_category_id')
df_median.describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,1323,1325,1329,1330,1331,1333,1334,1335,1404,1405,1406,1409,1411,1414,2003,2004,2005,2006,2007,2008,2009,2010,2012,2013,2014,2015,2016,2018,2019,2020,2021,2022,2023,2024,2025,2026,2028,2029,2032,2033
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
std,254401.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,1105897.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
device = torch.device("cuda") if use_cuda else torch.device("cpu")  # GPU only when requested

# 20 Columns

In [None]:
# 20-column subset; the last `num_embeddings` columns hold embeddings.
num_columns = 20
num_embeddings = 5
data_path = 'datasets/median datasets/byColumnsWithEmbeddings/20_columns_embeddings.csv'

In [None]:
# Load the dataset (first CSV column is the saved row index) and summarize it.
dataset = pd.read_csv(data_path, index_col=0)
dataset.describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,0,1,2,3,4
count,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0,5489.0
mean,473054.1,0.267373,0.439492,3.383676,1.775314,14.85225,0.05523,7.220077,0.063769,0.008607,0.01857,1.404218,0.076677,20.914739,0.037346,0.048068,0.004076,0.002324,0.003661,0.266908,0.052979,13778.458629,743.878433,0.176459,-0.126827,-0.09977
std,231230.5,3.186427,3.293347,35.794023,16.406068,130.127557,0.421648,56.057814,0.546446,0.087142,0.18702,10.151666,1.924846,739.212918,0.391203,0.531868,0.042929,0.031001,0.037562,2.846555,0.485522,231286.870011,3652.56376,1.138101,0.659163,0.594369
min,319883.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139549.827144,-7413.329674,-3.349638,-1.585647,-1.527283
25%,324889.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134465.990297,-1877.79491,-0.787643,-0.671177,-0.489657
50%,329680.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129585.45598,69.890008,0.049628,-0.181722,-0.171434
75%,747564.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288359.891527,3551.796646,1.17291,0.375246,0.319151
max,1105897.0,99.1,79.8,1020.0,354.0,2510.0,7.15,988.0,9.94,1.84,3.94,94.9,99.0,38500.0,12.8,13.4,0.908,0.982,1.63,95.4,10.9,646711.616853,7629.611107,2.647689,1.869345,1.403119


In [None]:
# NOTE(review): `dataset_test` is not created by any earlier cell in this
# notebook, so a fresh top-to-bottom run stopped here with NameError.
# Only summarize it when it exists.
if 'dataset_test' in globals():
    print(dataset_test.describe())

Unnamed: 0,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,0,1,2,3,4
count,5152.0,6639.0,5437.0,5435.0,5448.0,5437.0,5436.0,5441.0,5435.0,5436.0,5125.0,5993.0,6267.0,6187.0,6993.0,7020.0,7033.0,6937.0,7030.0,6506.0,7616.0,7616.0,7616.0,7616.0,7616.0
mean,0.000654,0.00177,0.077432,0.036909,0.426762,0.00055,0.050957,0.001136,0.000177,0.000265,0.116878,0.001223,0.037019,0.000305,0.000789,7.2e-05,5.6e-05,5.2e-05,0.000896,0.0,27044.743975,-232.588349,0.078801,-0.07019,-0.011533
std,0.020168,0.061385,3.221593,1.347541,14.158099,0.015376,1.524584,0.039911,0.005905,0.009933,3.1271,0.039491,1.294086,0.010231,0.024177,0.002376,0.001765,0.001732,0.039854,0.0,269788.791907,4116.285913,1.091648,0.71782,0.598418
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7412.348752,-2.816578,-1.506455,-1.327897
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133223.445583,-3030.663593,-0.800539,-0.659422,-0.473695
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127179.314878,-355.354324,-0.010516,0.062336,-0.033467
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288179.329072,2872.168462,0.883603,0.464539,0.351199
max,0.91,3.17,164.0,67.6,683.0,0.61,67.0,1.98,0.286,0.491,90.4,1.9,65.0,0.51,1.09,0.095,0.079,0.085,2.7,0.0,646710.59817,7635.496843,2.647376,1.869332,1.616787


In [None]:
# Hold out a `test_size` fraction of the rows for evaluation (unseeded split).
data = dataset
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split.
train_data.describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,0,1,2,3,4
count,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0,4391.0
mean,473272.8,0.240091,0.467051,3.370075,1.97463,16.492371,0.05806,7.575951,0.068162,0.009664,0.02097,1.353954,0.084359,23.889775,0.038032,0.052911,0.004433,0.002456,0.003858,0.280383,0.056069,13997.110479,747.34077,0.179845,-0.128129,-0.099523
std,232118.0,2.80721,3.414152,35.020266,17.673715,139.175649,0.433277,57.681247,0.571255,0.093407,0.202176,9.965624,2.141429,825.066562,0.40042,0.563096,0.045348,0.031504,0.039887,2.988386,0.50581,232174.293483,3659.134087,1.141939,0.657215,0.595458
min,319883.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139549.827144,-7413.329674,-3.349638,-1.585647,-1.481994
25%,324908.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134446.635223,-1887.851274,-0.793266,-0.66472,-0.489293
50%,329665.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129600.7363,64.985228,0.052235,-0.181708,-0.171482
75%,747586.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288382.812007,3533.158522,1.180983,0.374686,0.309593
max,1105894.0,81.2,79.8,1020.0,354.0,2510.0,7.15,988.0,9.94,1.84,3.94,94.9,99.0,38500.0,12.8,13.4,0.908,0.982,1.63,95.4,10.9,646708.560794,7629.611107,2.647689,1.869302,1.384489


In [None]:
# Summary statistics of the test split.
test_data.describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,0,1,2,3,4
count,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0,1098.0
mean,472179.6,0.376475,0.329281,3.438069,0.978233,8.29326,0.043916,5.796903,0.046202,0.004381,0.008971,1.605228,0.045956,9.017304,0.034599,0.0287,0.002648,0.001797,0.002874,0.213024,0.040619,12904.050366,730.03224,0.162916,-0.121621,-0.100755
std,227749.0,4.386727,2.756475,38.751108,9.783971,84.519973,0.37149,49.029391,0.433038,0.05524,0.106244,10.86633,0.427642,96.516985,0.352094,0.381885,0.031428,0.028912,0.026282,2.189705,0.39404,227805.899924,3627.794801,1.123032,0.667173,0.590263
min,319926.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139506.023559,-7048.414038,-2.569523,-1.504577,-1.527283
25%,324792.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134564.803033,-1834.145659,-0.764632,-0.684179,-0.489591
50%,329745.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129519.241267,81.171002,0.037505,-0.181744,-0.171039
75%,747487.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288282.216565,3646.949427,1.120113,0.375931,0.320162
max,1105897.0,99.1,47.7,704.0,179.0,1540.0,5.02,783.0,7.02,1.13,1.99,93.5,6.98,1770.0,7.63,7.07,0.746,0.906,0.424,34.5,5.7,646711.616853,7587.429999,2.64748,1.869345,1.403119


In [None]:
# fdc_id of every test row, captured before the column is dropped. Used as the
# row index for the per-category imputation.
indices = test_data['fdc_id']
indices

457     322082
4561    748655
3486    333373
1489    325334
3037    331370
         ...  
583     322321
3150    331708
1476    325303
1073    323596
3215    331991
Name: fdc_id, Length: 1098, dtype: int64

In [None]:
# fdc_id is an identifier, not a feature: remove it from both splits.
train_data = train_data.drop(columns='fdc_id')
test_data = test_data.drop(columns='fdc_id')
columns = train_data.columns

In [None]:
train_data = train_data.to_numpy()
# Rebind instead of fillna(inplace=True). test_data comes from train_test_split,
# and the in-place call raised pandas' SettingWithCopyWarning.
test_data = test_data.fillna(0).to_numpy()
data = dataset.values
rows, cols = data.shape
cols -= 1  # dataset still includes the fdc_id column; exclude it

# Unscaled ground truth for the test rows (scaled later with the other test sets).
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Blank random non-embedding cells in both splits; `mask` marks the blanked test cells.
missed_data, missed_data_train, mask = missing_method(test_data, train_data, num_embeddings)
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
count,893.0,878.0,871.0,866.0,894.0,865.0,886.0,883.0,870.0,877.0,881.0,869.0,856.0,871.0,882.0,873.0,864.0,910.0,861.0,875.0,1098.0,1098.0,1098.0,1098.0,1098.0
mean,0.298723,0.370433,3.098737,0.784296,7.573826,0.041584,4.17833,0.050102,0.005261,0.009185,1.606844,0.047975,8.669393,0.037313,0.02779,0.002947,0.001105,0.002699,0.205575,0.036229,12904.050366,730.03224,0.162916,-0.121621,-0.100755
std,3.446613,2.950847,35.595388,8.810859,78.65119,0.357915,41.916487,0.471775,0.061803,0.110894,10.873475,0.455003,87.906645,0.37794,0.36724,0.034779,0.010386,0.026264,2.092965,0.357754,227805.899924,3627.794801,1.123032,0.667173,0.590263
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139506.023559,-7048.414038,-2.569523,-1.504577,-1.527283
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134564.803033,-1834.145659,-0.764632,-0.684179,-0.489591
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129519.241267,81.171002,0.037505,-0.181744,-0.171039
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288282.216565,3646.949427,1.120113,0.375931,0.320162
max,81.7,47.7,704.0,179.0,1540.0,5.02,783.0,7.02,1.13,1.99,93.5,6.98,1770.0,7.63,7.07,0.746,0.17,0.424,29.7,5.5,646711.616853,7587.429999,2.64748,1.869345,1.403119


In [None]:
# traditional methods test sets - impute across all foods
# Fit on the corrupted training rows; fill each test copy with column means / medians.
X_train = pd.DataFrame(missed_data_train, columns=columns)
X_test_mean_all = impute_traditional_all_foods(
    X_train, pd.DataFrame(missed_data, columns=columns), 'mean')
X_test_median_all = impute_traditional_all_foods(
    X_train, pd.DataFrame(missed_data, columns=columns), 'median')

In [None]:
# Traditional methods test sets - impute by food category
# Per-category reference rows restricted to the foods in the test split.
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]

# The imputer mutates its input, so take the median copy before the mean pass.
X_test_mean = pd.DataFrame(missed_data, columns=columns, index=indices)
X_test_median = X_test_mean.copy()

X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Min-max scaling fitted on the training rows only, then applied to every test set.
scaler = MinMaxScaler()
train_data = scaler.fit_transform(train_data)
test_data = scaler.transform(test_data)
y_test = scaler.transform(y_test)
X_test_mean, X_test_median, X_test_mean_all, X_test_median_all = (
    scaler.transform(frame)
    for frame in (X_test_mean, X_test_median, X_test_mean_all, X_test_median_all)
)

In [None]:
# datasets without embeddings for autoencoder
# Datasets without embeddings for the autoencoder.
# The embedding features are the trailing `num_embeddings` columns. Slice with
# an explicit stop index: `iloc[:, :-num_embeddings]` keeps no columns at all
# when num_embeddings == 0.
def _drop_embeddings(values):
    """Return ``values`` as a NumPy array without its trailing embedding columns."""
    frame = pd.DataFrame(values, columns=columns)
    return np.array(frame.iloc[:, :frame.shape[1] - num_embeddings])


train_data_noEmbeddings = _drop_embeddings(train_data)
test_mean_noEmbeddings = _drop_embeddings(X_test_mean)
test_median_noEmbeddings = _drop_embeddings(X_test_median)
test_mean_noEmbeddings_all = _drop_embeddings(X_test_mean_all)
test_median_noEmbeddings_all = _drop_embeddings(X_test_median_all)

In [None]:
# Autoencoder input: test rows pre-filled by the across-all-foods mean imputation.
missed_data = X_test_mean_all

In [None]:
# Float tensors for the model: the test input to reconstruct and the training rows.
missed_data = torch.from_numpy(missed_data).float()
train_data = torch.from_numpy(train_data).float()

# Shuffled mini-batches of training rows.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [None]:
# Same mean-imputed test input, without the embedding columns.
missed_data_noEm = test_mean_noEmbeddings_all

In [None]:
# Float tensors for the embedding-free model.
missed_data_noEm = torch.from_numpy(missed_data_noEm).float()
train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()

# Shuffled mini-batches of embedding-free training rows.
train_loader_noEm = torch.utils.data.DataLoader(
    train_data_noEmbeddings, batch_size=batch_size, shuffle=True)

## Autoencoder - without embeddings

In [None]:
# Autoencoder over the non-embedding columns only.
model = Autoencoder(dim=cols-num_embeddings).to(device)

In [None]:
# Plain SGD on the reconstruction MSE.
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [None]:
# Train the embedding-free autoencoder on the uncorrupted training rows.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.00443350
epoch : 2/200, recon loss = 0.00236126
epoch : 3/200, recon loss = 0.00236263
epoch : 4/200, recon loss = 0.00235895
epoch : 5/200, recon loss = 0.00235628
epoch : 6/200, recon loss = 0.00235457
epoch : 7/200, recon loss = 0.00235172
epoch : 8/200, recon loss = 0.00235264
epoch : 9/200, recon loss = 0.00234703
epoch : 10/200, recon loss = 0.00234370
epoch : 11/200, recon loss = 0.00234270
epoch : 12/200, recon loss = 0.00234187
epoch : 13/200, recon loss = 0.00233628
epoch : 14/200, recon loss = 0.00233176
epoch : 15/200, recon loss = 0.00232819
epoch : 16/200, recon loss = 0.00232740
epoch : 17/200, recon loss = 0.00232583
epoch : 18/200, recon loss = 0.00232089
epoch : 19/200, recon loss = 0.00232122
epoch : 20/200, recon loss = 0.00231885
epoch : 21/200, recon loss = 0.00230741
epoch : 22/200, recon loss = 0.00229885
epoch : 23/200, recon loss = 0.00230183
epoch : 24/200, recon loss = 0.00229014
epoch : 25/200, recon loss = 0.00228620
epoch : 2

In [None]:
model.eval()  # dropout off: reconstruct without corrupting the input

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device)).cpu().numpy()

# Sum of per-column RMSEs over the cells blanked by missing_method.
rmse_sum = rmse_error(test_data, filled_data, cols - num_embeddings, mask)
print(rmse_sum)

0.5672658039561218


## Autoencoder - with embeddings

In [None]:
# Autoencoder over all feature columns, embeddings included.
model = Autoencoder(dim=cols).to(device)

In [None]:
# Plain SGD on the reconstruction MSE.
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [None]:
# Train the with-embeddings autoencoder on the uncorrupted training rows.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.01840921
epoch : 2/200, recon loss = 0.01233796
epoch : 3/200, recon loss = 0.01230825
epoch : 4/200, recon loss = 0.01228404
epoch : 5/200, recon loss = 0.01225366
epoch : 6/200, recon loss = 0.01221165
epoch : 7/200, recon loss = 0.01215797
epoch : 8/200, recon loss = 0.01210341
epoch : 9/200, recon loss = 0.01200975
epoch : 10/200, recon loss = 0.01190232
epoch : 11/200, recon loss = 0.01172687
epoch : 12/200, recon loss = 0.01151027
epoch : 13/200, recon loss = 0.01126146
epoch : 14/200, recon loss = 0.01095855
epoch : 15/200, recon loss = 0.01065575
epoch : 16/200, recon loss = 0.01031688
epoch : 17/200, recon loss = 0.01014744
epoch : 18/200, recon loss = 0.00997001
epoch : 19/200, recon loss = 0.00977952
epoch : 20/200, recon loss = 0.00960712
epoch : 21/200, recon loss = 0.00964399
epoch : 22/200, recon loss = 0.00952381
epoch : 23/200, recon loss = 0.00942325
epoch : 24/200, recon loss = 0.00944970
epoch : 25/200, recon loss = 0.00940373
epoch : 2

In [None]:
model.eval()  # dropout off: reconstruct without corrupting the input

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    filled_data = model(missed_data.to(device)).cpu().numpy()

# Sum of per-column RMSEs over the cells blanked by missing_method. Embedding
# columns are never masked, so passing all `cols` still scores only the others.
rmse_sum = rmse_error(test_data, filled_data, cols, mask)
print(rmse_sum)

0.6148321373948228


## Mean and Median Imputation

### By category

In [None]:
# By-category mean imputation: summed per-column RMSE on the blanked cells.
imputed = X_test_mean
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

0.7373841556424181


In [None]:
# By-category median imputation: summed per-column RMSE on the blanked cells.
imputed = X_test_median
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

0.7537959013937711


### Across all foods

In [None]:
# Across-all-foods mean imputation: summed per-column RMSE on the blanked cells.
imputed = X_test_mean_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

0.7521543524024294


In [None]:
# Across-all-foods median imputation: summed per-column RMSE on the blanked cells.
imputed = X_test_median_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

0.7537959013937711


# 40 Columns

In [None]:
# 40-column subset; the last `num_embeddings` columns hold embeddings.
num_columns = 40
num_embeddings = 10
data_path = 'datasets/median datasets/byColumnsWithEmbeddings/40_columns_embeddings.csv'

In [None]:
# Load the 40-column dataset and summarize only its fully observed rows.
dataset = pd.read_csv(data_path, index_col=0)
dataset.dropna(how='any').describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,1301,0,1,2,3,4,5,6,7,8,9
count,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0,3586.0
mean,502692.7,0.408932,0.670929,5.146403,2.699136,22.511991,0.084211,11.00251,0.097103,0.013061,0.02816,2.048229,0.116676,31.987451,0.05688,0.073018,0.00611,0.003505,0.005533,0.407379,0.078639,0.008883,0.016275,0.069503,0.028624,0.004933,0.144953,0.073117,0.001364,0.039099,0.001284,0.007001,0.001121,0.000723,0.004209,0.01475,0.000789,0.000851,2.593698,0.004178,0.00048,43418.079224,970.68815,0.211216,-0.205704,-0.029754,0.000743,0.071495,0.04028,0.044713,-0.078759
std,251745.1,3.9351,4.055251,44.170931,20.229914,160.298331,0.519253,69.0406,0.673526,0.10749,0.230679,12.141032,2.380474,914.407879,0.48278,0.656473,0.052795,0.038261,0.046288,3.513607,0.594537,0.079193,0.332079,0.722163,0.28282,0.053047,1.991216,1.354982,0.015516,0.748456,0.019748,0.174599,0.018962,0.008813,0.054285,0.703503,0.00816,0.01026,52.963515,0.102283,0.013246,251806.869897,3609.424939,1.120618,0.644494,0.622728,0.478089,0.50996,0.378452,0.397628,0.371373
min,319907.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139525.378632,-7413.329674,-3.349638,-1.585647,-1.527263,-1.442507,-1.189811,-0.904551,-0.994182,-0.829681
25%,324737.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134620.830874,-1752.696315,-0.693309,-0.70065,-0.47696,-0.343007,-0.248043,-0.212606,-0.170793,-0.340061
50%,329794.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129469.325555,612.361932,0.064898,-0.215095,-0.08551,-0.030935,0.042146,0.087137,0.026665,-0.094014
75%,748307.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289117.031396,3700.902034,1.197435,0.329588,0.421723,0.260839,0.35124,0.229062,0.272532,0.127306
max,1105897.0,99.1,79.8,1020.0,354.0,2510.0,7.15,988.0,9.94,1.84,3.94,94.9,99.0,38500.0,12.8,13.4,0.908,0.982,1.63,95.4,10.9,1.58,16.6,21.5,7.46,1.16,68.6,52.4,0.432,23.8,0.782,7.68,0.726,0.251,1.6,42.0,0.168,0.225,2340.0,5.38,0.729,646711.616853,7606.068163,2.647585,1.869342,1.403106,2.020087,1.700626,1.605172,1.361318,1.348159


In [None]:
# Random train/test split of the full table (test_size assumed defined in an earlier cell).
data = dataset
train_data, test_data = train_test_split(data, test_size=test_size, shuffle=True)

In [None]:
# Summary statistics of the training split.
train_summary = train_data.describe()
train_summary

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,1301,0,1,2,3,4,5,6,7,8,9
count,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0,2868.0
mean,503094.1,0.445697,0.674379,5.171897,2.647036,21.700139,0.085471,10.891562,0.091161,0.012714,0.027781,1.852144,0.131569,37.596932,0.059854,0.071965,0.006182,0.003226,0.005447,0.39493,0.07493,0.009018,0.018754,0.069934,0.030073,0.004899,0.135298,0.066929,0.001421,0.037333,0.001423,0.007866,0.001078,0.000748,0.004415,0.018106,0.000867,0.000925,2.721757,0.00493,0.000511,43818.356298,1033.097786,0.211218,-0.205922,-0.0404,-0.002237,0.072376,0.043217,0.04415,-0.079882
std,253192.4,4.248371,4.168764,46.166228,20.359215,158.596009,0.532789,69.301154,0.648585,0.10722,0.235777,11.480242,2.654518,1021.36889,0.516485,0.656287,0.054074,0.035207,0.047832,3.394912,0.597923,0.082208,0.368542,0.718821,0.294002,0.050585,1.722758,1.145576,0.016049,0.704894,0.021496,0.193768,0.018646,0.008678,0.055827,0.78654,0.008703,0.010844,56.120883,0.113813,0.014479,253254.546439,3618.745177,1.126598,0.640904,0.622946,0.476165,0.513075,0.37887,0.399487,0.371086
min,319907.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139525.378632,-7413.329674,-2.956202,-1.577342,-1.458191,-1.442507,-1.189811,-0.904551,-0.99278,-0.829681
25%,324504.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134857.421164,-1730.40968,-0.710266,-0.7003,-0.477381,-0.342827,-0.248075,-0.205069,-0.17641,-0.333945
50%,329713.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129551.839277,720.757602,0.066249,-0.215436,-0.10681,-0.035451,0.046421,0.087137,0.024192,-0.095396
75%,748287.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289097.16698,3928.729171,1.290903,0.320872,0.365652,0.259482,0.343577,0.229002,0.272556,0.128136
max,1105894.0,99.1,79.8,1020.0,354.0,2510.0,7.15,988.0,9.94,1.84,3.94,94.9,99.0,38500.0,12.8,13.4,0.908,0.906,1.63,95.4,10.9,1.58,16.6,21.5,7.46,1.16,60.5,50.9,0.432,22.3,0.782,7.68,0.726,0.251,1.6,42.0,0.168,0.225,2340.0,5.38,0.729,646708.560794,7606.068163,2.647585,1.869342,1.403106,2.020087,1.700626,1.605172,1.361318,1.348159


In [None]:
# Summary statistics of the test split.
test_summary = test_data.describe()
test_summary

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,1301,0,1,2,3,4,5,6,7,8,9
count,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0
mean,501089.4,0.262075,0.657145,5.044568,2.907242,25.754875,0.079178,11.445682,0.120836,0.014447,0.029674,2.831476,0.057187,9.58078,0.045,0.077223,0.005822,0.004618,0.005876,0.457103,0.093454,0.008347,0.006373,0.067783,0.022836,0.005068,0.183518,0.097837,0.001138,0.046156,0.000727,0.003547,0.001295,0.000624,0.003386,0.001347,0.000478,0.000554,2.082173,0.001175,0.000354,41819.200882,721.397291,0.211208,-0.204832,0.012771,0.012649,0.067976,0.028549,0.04696,-0.074271
std,246046.1,2.286616,3.568604,35.114568,19.717513,166.999732,0.461544,68.035836,0.765127,0.108628,0.209227,14.465653,0.390902,93.581629,0.313929,0.657658,0.047377,0.048598,0.039556,3.954546,0.580986,0.065832,0.090322,0.735864,0.232952,0.061954,2.820533,1.98282,0.01318,0.902251,0.010098,0.04771,0.020188,0.009341,0.047659,0.025286,0.005472,0.007484,37.835439,0.022419,0.006243,246106.256917,3563.555374,1.09718,0.659092,0.620469,0.485854,0.497656,0.376812,0.390381,0.37274
min,319927.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139505.004871,-7261.28149,-3.349638,-1.585647,-1.527263,-1.407771,-1.189721,-0.896693,-0.994182,-0.808921
25%,325496.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133847.391994,-1802.384714,-0.616373,-0.731912,-0.475357,-0.343517,-0.247988,-0.228328,-0.169285,-0.352771
50%,330196.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129059.812966,162.590387,0.05737,-0.18733,-0.066967,-0.018236,0.017192,0.087123,0.028252,-0.092625
75%,748362.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289173.56858,2956.111122,1.161595,0.341216,0.508456,0.264308,0.368696,0.229875,0.272372,0.122727
max,1105897.0,29.0,26.5,704.0,197.0,1620.0,3.53,600.0,7.29,1.15,2.14,90.8,6.98,1770.0,3.72,10.0,0.751,0.982,0.463,67.2,4.5,1.2,2.19,12.0,3.48,1.16,68.6,52.4,0.29,23.8,0.235,1.05,0.406,0.234,1.07,0.665,0.122,0.165,1000.0,0.59,0.155,646711.616853,7586.449043,2.64602,1.405733,1.384469,1.86167,1.592247,1.564844,1.106818,1.183382


In [None]:
# Keep the fdc_id values of the test rows before the column is dropped; they are
# used below to select the by-category imputation values.
indices = test_data.loc[:, 'fdc_id']
indices

3530    1105314
3193     790550
3382    1104948
904      324809
1525     327741
         ...   
2059     332584
1201     326625
2070     332690
1008     325402
1330     327061
Name: fdc_id, Length: 718, dtype: int64

In [None]:
# Remove the id column from both splits; what remains are the feature columns.
for split_frame in (train_data, test_data):
    del split_frame['fdc_id']
columns = train_data.columns

In [None]:
# Convert both splits to NumPy arrays for scaling and the autoencoders.
train_data = train_data.to_numpy()
# Rebind instead of fillna(inplace=True): test_data comes out of train_test_split
# flagged by pandas as a slice, so the in-place form emits SettingWithCopyWarning.
test_data = test_data.fillna(0).to_numpy()
data = dataset.values
rows, cols = data.shape
cols -= 1  # discount the fdc_id column, which was dropped from the splits

# Untouched copy of the (zero-filled) test split, used as RMSE ground truth.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Introduce artificial missingness (embedding columns are left intact by
# missing_method) and keep the returned mask for scoring.
missed_data, missed_data_train, mask = missing_method(
    test_data=test_data,
    train_data=train_data,
    num_embeddings=num_embeddings,
)
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
count,584.0,560.0,584.0,595.0,565.0,583.0,580.0,582.0,578.0,573.0,573.0,570.0,597.0,571.0,568.0,581.0,591.0,568.0,569.0,565.0,589.0,597.0,554.0,574.0,586.0,579.0,569.0,588.0,558.0,590.0,565.0,594.0,589.0,586.0,578.0,581.0,583.0,590.0,569.0,579.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0,718.0
mean,0.312483,0.642161,5.535959,2.926891,23.19646,0.0653,10.941379,0.091237,0.014888,0.023667,2.545724,0.059632,8.914573,0.04352,0.057449,0.006618,0.00502,0.005815,0.384183,0.091504,0.006311,0.007611,0.064094,0.027092,0.004962,0.227428,0.029583,0.001374,0.055663,0.000885,0.004377,0.000879,0.000649,0.004056,0.001464,0.000375,0.000648,2.415254,0.000359,0.000408,41819.200882,721.397291,0.211208,-0.204832,0.012771,0.012649,0.067976,0.028549,0.04696,-0.074271
std,2.530198,3.541929,38.04275,20.256992,151.457969,0.413429,65.598138,0.684744,0.111386,0.184516,13.504419,0.41083,93.669495,0.309532,0.580402,0.051345,0.052162,0.037938,3.603871,0.580504,0.045974,0.099018,0.691298,0.258024,0.064055,3.139836,0.371948,0.014553,1.021024,0.011135,0.053704,0.014699,0.010134,0.052693,0.027981,0.003388,0.008289,41.65839,0.004338,0.006943,246106.256917,3563.555374,1.09718,0.659092,0.620469,0.485854,0.497656,0.376812,0.390381,0.37274
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139505.004871,-7261.28149,-3.349638,-1.585647,-1.527263,-1.407771,-1.189721,-0.896693,-0.994182,-0.808921
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133847.391994,-1802.384714,-0.616373,-0.731912,-0.475357,-0.343517,-0.247988,-0.228328,-0.169285,-0.352771
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129059.812966,162.590387,0.05737,-0.18733,-0.066967,-0.018236,0.017192,0.087123,0.028252,-0.092625
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289173.56858,2956.111122,1.161595,0.341216,0.508456,0.264308,0.368696,0.229875,0.272372,0.122727
max,29.0,26.5,704.0,197.0,1620.0,3.5,600.0,7.29,1.15,2.14,89.9,6.98,1770.0,3.72,10.0,0.751,0.982,0.463,67.2,4.5,0.395,2.19,12.0,3.48,1.16,68.6,7.68,0.29,23.8,0.235,1.05,0.351,0.234,1.07,0.665,0.038,0.165,1000.0,0.085,0.155,646711.616853,7586.449043,2.64602,1.405733,1.384469,1.86167,1.592247,1.564844,1.106818,1.183382


In [None]:
# traditional methods test sets - impute across all foods
# copy=True: pd.DataFrame(ndarray) does not copy by default, so these frames would
# all share memory with `missed_data`. If the imputation helper fills in place,
# the mean fill could then leak into the median frame (and into later cells).
X_train = pd.DataFrame(missed_data_train, columns=columns, copy=True)
X_test_mean_all = pd.DataFrame(missed_data, columns=columns, copy=True)
X_test_median_all = pd.DataFrame(missed_data, columns=columns, copy=True)

X_test_mean_all = impute_traditional_all_foods(X_train, X_test_mean_all, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, X_test_median_all, 'median')

In [None]:
# Traditional methods test sets - impute by food category
# copy=True so this frame does not alias `missed_data` (pd.DataFrame does not
# copy an ndarray by default); in-place filling here must not touch the source.
X_test_mean = pd.DataFrame(missed_data, columns=columns, index=indices, copy=True)
X_test_median = X_test_mean.copy()
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Fit the min-max scaler on the training split only, then apply it everywhere.
# The scaler is fitted on a plain ndarray, so the imputed test sets (DataFrames
# with string column names) are converted with np.asarray first; transforming a
# DataFrame would trigger sklearn's "X has feature names, but MinMaxScaler was
# fitted without feature names" warning. Values and column order are unchanged.
scaler = MinMaxScaler()
scaler.fit(train_data)
train_data = scaler.transform(train_data)
test_data = scaler.transform(test_data)
X_test_mean = scaler.transform(np.asarray(X_test_mean))
X_test_median = scaler.transform(np.asarray(X_test_median))
X_test_mean_all = scaler.transform(np.asarray(X_test_mean_all))
X_test_median_all = scaler.transform(np.asarray(X_test_median_all))

y_test = scaler.transform(y_test)

In [None]:
# datasets without embeddings for autoencoder: the embedding columns are the
# trailing `num_embeddings` columns, so keep only the leading nutrient columns.
n_nutrients = len(columns) - num_embeddings


def _drop_embeddings(arr, n_keep, n_total):
    """Return a contiguous copy of the first *n_keep* columns of a 2-D array.

    Slicing ``[:, :n_keep]`` replaces the former ``iloc[:, :-num_embeddings]``,
    which selected no columns at all when ``num_embeddings == 0`` (``:-0``).

    Raises:
        ValueError: if *arr* is not 2-D with exactly *n_total* columns.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[1] != n_total:
        raise ValueError(f"expected a 2-D array with {n_total} columns, got shape {arr.shape}")
    return np.array(arr[:, :n_keep])


train_data_noEmbeddings = _drop_embeddings(train_data, n_nutrients, len(columns))
test_mean_noEmbeddings = _drop_embeddings(X_test_mean, n_nutrients, len(columns))
test_median_noEmbeddings = _drop_embeddings(X_test_median, n_nutrients, len(columns))
test_mean_noEmbeddings_all = _drop_embeddings(X_test_mean_all, n_nutrients, len(columns))
test_median_noEmbeddings_all = _drop_embeddings(X_test_median_all, n_nutrients, len(columns))

In [None]:
# Autoencoder input: the test set pre-filled with the across-all-foods mean.
missed_data = torch.from_numpy(X_test_mean_all).float()

train_data = torch.from_numpy(train_data).float()
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Same inputs with the embedding columns removed, for the plain autoencoder.
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
train_loader_noEm = torch.utils.data.DataLoader(
    dataset=train_data_noEmbeddings,
    batch_size=batch_size,
    shuffle=True,
)

## Autoencoder - without embeddings

In [None]:
# Autoencoder over the nutrient columns only (embedding columns excluded).
n_inputs = cols - num_embeddings
model = Autoencoder(dim=n_inputs).to(device)

In [None]:
# Plain SGD (lr = 0.01) on a mean-squared reconstruction loss.
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [None]:
# Train the no-embedding autoencoder with the helper `training` from an earlier
# cell (presumably it prints the per-epoch reconstruction loss logged below).
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.00438335
epoch : 2/200, recon loss = 0.00222015
epoch : 3/200, recon loss = 0.00221075
epoch : 4/200, recon loss = 0.00220858
epoch : 5/200, recon loss = 0.00220569
epoch : 6/200, recon loss = 0.00220007
epoch : 7/200, recon loss = 0.00219686
epoch : 8/200, recon loss = 0.00219413
epoch : 9/200, recon loss = 0.00219150
epoch : 10/200, recon loss = 0.00218875
epoch : 11/200, recon loss = 0.00218529
epoch : 12/200, recon loss = 0.00218136
epoch : 13/200, recon loss = 0.00217481
epoch : 14/200, recon loss = 0.00217557
epoch : 15/200, recon loss = 0.00216694
epoch : 16/200, recon loss = 0.00216696
epoch : 17/200, recon loss = 0.00215481
epoch : 18/200, recon loss = 0.00215674
epoch : 19/200, recon loss = 0.00215181
epoch : 20/200, recon loss = 0.00214487
epoch : 21/200, recon loss = 0.00214350
epoch : 22/200, recon loss = 0.00213536
epoch : 23/200, recon loss = 0.00213069
epoch : 24/200, recon loss = 0.00212067
epoch : 25/200, recon loss = 0.00211735
epoch : 2

In [None]:
model.eval()
# Inference only: no_grad avoids building an autograd graph for the whole test set.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device)).cpu().numpy()

# Score the reconstruction on the nutrient columns via rmse_error (defined earlier).
rmse_sum = rmse_error(test_data, filled_data, cols - num_embeddings, mask)

print(rmse_sum)

1.1857101329640536


## Autoencoder - with embeddings

In [None]:
# Autoencoder over every feature column, embeddings included.
n_inputs = cols
model = Autoencoder(dim=n_inputs).to(device)

In [None]:
# Fresh optimizer for the new model: plain SGD (lr = 0.01) on an MSE loss.
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [None]:
# Train the with-embedding autoencoder on the full feature set using the
# `training` helper from an earlier cell.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.01755902
epoch : 2/200, recon loss = 0.01008028
epoch : 3/200, recon loss = 0.01006691
epoch : 4/200, recon loss = 0.01005710
epoch : 5/200, recon loss = 0.01004560
epoch : 6/200, recon loss = 0.01004009
epoch : 7/200, recon loss = 0.01002531
epoch : 8/200, recon loss = 0.01001744
epoch : 9/200, recon loss = 0.01000099
epoch : 10/200, recon loss = 0.00999542
epoch : 11/200, recon loss = 0.00998688
epoch : 12/200, recon loss = 0.00997166
epoch : 13/200, recon loss = 0.00996486
epoch : 14/200, recon loss = 0.00995095
epoch : 15/200, recon loss = 0.00993475
epoch : 16/200, recon loss = 0.00992375
epoch : 17/200, recon loss = 0.00991056
epoch : 18/200, recon loss = 0.00989135
epoch : 19/200, recon loss = 0.00987593
epoch : 20/200, recon loss = 0.00984958
epoch : 21/200, recon loss = 0.00983564
epoch : 22/200, recon loss = 0.00980668
epoch : 23/200, recon loss = 0.00978511
epoch : 24/200, recon loss = 0.00975722
epoch : 25/200, recon loss = 0.00971403
epoch : 2

In [None]:
model.eval()
# Inference only: no_grad avoids building an autograd graph for the whole test set.
with torch.no_grad():
    filled_data = model(missed_data.to(device)).cpu().numpy()

# Score only the nutrient columns (cols - num_embeddings), as every other method in
# this comparison does. missing_method never masks embedding columns, so passing
# `cols` (the previous value) could make this number incomparable with the others.
rmse_sum = rmse_error(test_data, filled_data, cols - num_embeddings, mask)

print(rmse_sum)

1.3033921789107272


## Mean and Median Imputation

### By category

In [None]:
# Score by-category mean imputation on the nutrient columns.
imputed = X_test_mean
n_scored = cols - num_embeddings
error = rmse_error(y_test, imputed, n_scored, mask)
print(error)

1.437001084026095


In [None]:
# Score by-category median imputation on the nutrient columns.
imputed = X_test_median
n_scored = cols - num_embeddings
error = rmse_error(y_test, imputed, n_scored, mask)
print(error)

1.4981608899742396


### Across all foods

In [None]:
# Score across-all-foods mean imputation on the nutrient columns.
imputed = X_test_mean_all
n_scored = cols - num_embeddings
error = rmse_error(y_test, imputed, n_scored, mask)
print(error)

1.4958297743199493


In [None]:
# Score across-all-foods median imputation on the nutrient columns.
imputed = X_test_median_all
n_scored = cols - num_embeddings
error = rmse_error(y_test, imputed, n_scored, mask)
print(error)

1.4981608899742396


# 80 Columns

In [None]:
# Experiment configuration: 80 feature columns plus 20 embedding columns.
num_columns = 80
num_embeddings = 20
data_path = 'datasets/median datasets/byColumnsWithEmbeddings/80_columns_embeddings.csv'

In [None]:
# Load the table (first CSV column becomes the index) and summarise the rows
# that contain no missing values at all.
dataset = pd.read_csv(data_path, index_col=0)
complete_rows = dataset.dropna(how='any')
complete_rows.describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,...,1211,1212,1213,1214,1215,1217,1218,1219,1220,1221,1222,1225,1226,1227,1259,1180,1194,1277,1195,1261,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
count,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,...,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0,1453.0
mean,536969.3,1.009243,1.65585,12.701308,6.661459,55.559532,0.207832,27.154164,0.239649,0.032235,0.069498,5.055024,0.287956,78.944942,0.140379,0.180209,0.015078,0.00865,0.013655,1.005409,0.194081,0.021924,0.040167,0.171534,0.070644,0.012173,0.357744,0.180454,0.003368,0.096497,0.003168,0.017279,0.002767,0.001785,0.010388,0.036404,0.001948,0.0021,6.401239,0.010311,...,0.00995,0.010702,0.020067,0.017466,0.005196,0.011637,0.009125,0.011923,0.014065,0.006139,0.011599,0.008496,0.015729,0.012605,0.005653,1.27605,0.17075,0.001481,0.681074,0.007357,77698.045425,1112.760398,0.334271,-0.341645,-0.173203,0.017877,0.055166,0.042864,0.03929,-0.064322,0.075091,0.027414,0.026837,-0.003604,0.013585,0.133238,0.109719,-0.059729,0.068823,-0.051
std,253215.9,6.134032,6.242675,68.710899,31.369207,248.203881,0.799995,106.441999,1.042042,0.167059,0.358481,18.674405,3.733854,1435.524757,0.750825,1.022109,0.082138,0.059749,0.071966,5.466196,0.922122,0.123281,0.520877,1.126996,0.441043,0.082823,3.116616,2.124538,0.024241,1.173696,0.030934,0.274024,0.029718,0.01378,0.084921,1.105062,0.012734,0.01604,83.075321,0.160521,...,0.091211,0.102629,0.188514,0.155955,0.04393,0.119117,0.087895,0.111143,0.154035,0.057805,0.117576,0.083682,0.147997,0.122226,0.071469,11.601093,1.900425,0.019376,9.27899,0.178897,253287.956831,3639.782299,1.067522,0.541473,0.603847,0.472052,0.540352,0.364085,0.409249,0.434274,0.393722,0.406193,0.359171,0.408909,0.2827,0.372869,0.299674,0.268196,0.347295,0.328849
min,319920.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139512.135695,-7413.329674,-3.349638,-1.504577,-1.527263,-1.442317,-1.188229,-0.786603,-0.994322,-0.829621,-0.933825,-0.900247,-0.784511,-0.920093,-0.989982,-0.890915,-0.765894,-0.591957,-0.670206,-0.813081
25%,323637.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-135741.387693,-1577.1351,-0.555839,-0.700182,-0.51435,-0.343191,-0.250034,-0.201571,-0.169653,-0.35735,-0.074586,-0.242787,-0.221236,-0.26433,-0.205377,-0.135246,-0.112946,-0.199005,-0.176811,-0.222227
50%,330950.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-128291.722202,71.851925,0.476832,-0.438818,-0.198504,-0.022241,0.066977,0.087915,0.04151,-0.145689,0.132186,0.076594,0.098543,-0.00697,-0.021233,0.020182,0.116625,-0.06859,0.008035,-0.066421
75%,748644.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289460.074589,4779.953807,1.21248,0.062627,0.228929,0.318,0.388351,0.240122,0.272543,0.137788,0.302162,0.237307,0.259274,0.286566,0.234687,0.341327,0.344383,0.085595,0.212455,0.101205
max,1105897.0,99.1,79.8,1020.0,354.0,2510.0,7.15,988.0,9.94,1.84,3.94,94.9,99.0,38500.0,12.8,13.4,0.908,0.982,1.63,95.4,10.9,1.58,16.6,21.5,7.46,1.16,68.6,52.4,0.432,23.8,0.782,7.68,0.726,0.251,1.6,42.0,0.168,0.225,2340.0,5.38,...,1.95,2.31,4.1,3.07,0.62,2.87,1.73,2.3,3.93,1.27,2.85,1.77,2.85,2.67,2.04,338.0,42.7,0.578,313.0,6.71,646711.616853,7593.31546,2.563649,1.589162,1.304911,1.916178,1.700616,1.447543,1.110999,1.348233,1.122923,1.044835,0.954733,1.171595,0.885691,1.031238,0.749074,0.699881,0.891233,0.877797


In [None]:
# Random train/test split of the full table (test_size assumed defined in an earlier cell).
data = dataset
train_data, test_data = train_test_split(data, test_size=test_size, shuffle=True)

In [None]:
# Summary statistics of the training split.
train_summary = train_data.describe()
train_summary

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,...,1211,1212,1213,1214,1215,1217,1218,1219,1220,1221,1222,1225,1226,1227,1259,1180,1194,1277,1195,1261,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
count,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,...,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0,1162.0
mean,541475.0,1.039294,1.619785,12.520654,6.893287,56.935456,0.203735,27.256454,0.237461,0.033783,0.071288,4.667065,0.316059,90.435456,0.138098,0.158858,0.013874,0.007826,0.012552,0.878193,0.190792,0.021085,0.043058,0.161424,0.064894,0.0108,0.315028,0.14248,0.002983,0.079045,0.003114,0.011509,0.00238,0.001485,0.008913,0.044046,0.001843,0.001935,6.416523,0.011596,...,0.009211,0.010075,0.018723,0.015864,0.004526,0.010682,0.008409,0.010954,0.013237,0.005607,0.010549,0.00697,0.014455,0.011807,0.005985,1.309036,0.180379,0.001324,0.678313,0.008655,82205.392285,1071.789476,0.343991,-0.354484,-0.175293,0.01695,0.062733,0.044525,0.036329,-0.070495,0.071124,0.023649,0.020781,0.002016,0.014975,0.13392,0.111977,-0.057087,0.065823,-0.050172
std,253350.3,6.536275,6.380276,67.356572,32.630668,255.632172,0.804576,109.032443,1.068464,0.173173,0.369525,18.037169,4.16122,1603.022618,0.78575,0.976477,0.080447,0.056034,0.073299,5.20933,0.934676,0.126681,0.5713,1.125883,0.433715,0.074514,2.83011,1.71583,0.022887,0.996552,0.032266,0.202623,0.027223,0.011399,0.077815,1.235282,0.012918,0.015597,87.924476,0.177257,...,0.091798,0.104493,0.189794,0.152111,0.0405,0.121827,0.086622,0.110016,0.162207,0.057767,0.12034,0.078519,0.143166,0.124292,0.07598,12.444452,2.053998,0.01893,10.053031,0.199678,253423.061995,3591.07271,1.07172,0.536381,0.612123,0.468454,0.539846,0.362324,0.407476,0.434865,0.398475,0.407328,0.357431,0.405037,0.278349,0.369242,0.29648,0.266332,0.343125,0.333203
min,319920.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139512.135695,-7413.329674,-3.349638,-1.504577,-1.482007,-1.442317,-1.188229,-0.786603,-0.994322,-0.829621,-0.933825,-0.789601,-0.784511,-0.920093,-0.989982,-0.890915,-0.749688,-0.591957,-0.670206,-0.813081
25%,323702.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-135674.663631,-1596.138577,-0.576242,-0.700208,-0.51464,-0.342474,-0.24935,-0.201564,-0.174988,-0.373918,-0.089697,-0.243151,-0.241413,-0.26394,-0.198087,-0.13516,-0.091019,-0.195363,-0.176828,-0.226791
50%,331231.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-128005.470867,-75.534942,0.48134,-0.439431,-0.198943,-0.026045,0.077581,0.087921,0.041103,-0.172227,0.132104,0.076451,0.098544,-0.006944,-0.020234,0.020053,0.123724,-0.048948,0.006579,-0.072053
75%,748719.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289536.985534,4715.701066,1.291685,0.022044,0.229673,0.31751,0.382783,0.2402,0.272711,0.136146,0.302372,0.237109,0.242591,0.306223,0.236045,0.335081,0.344227,0.095181,0.212268,0.099422
max,1105897.0,99.1,79.8,1020.0,354.0,2510.0,7.15,988.0,9.94,1.84,3.94,94.9,99.0,38500.0,12.8,13.4,0.908,0.982,1.63,95.4,10.9,1.58,16.6,21.5,7.46,1.03,68.6,50.9,0.432,22.3,0.782,6.7,0.726,0.251,1.6,42.0,0.168,0.225,2340.0,5.38,...,1.95,2.31,4.1,3.07,0.62,2.87,1.73,2.3,3.93,1.27,2.85,1.77,2.85,2.67,2.04,338.0,42.7,0.578,313.0,6.71,646711.616853,7593.31546,2.563649,1.589162,1.304911,1.85508,1.700616,1.447543,1.110999,1.348233,1.122923,1.042327,0.95334,1.171595,0.885691,1.031238,0.749074,0.699752,0.891233,0.877797


In [None]:
# Summary statistics of the test split.
test_summary = test_data.describe()
test_summary

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,...,1211,1212,1213,1214,1215,1217,1218,1219,1220,1221,1222,1225,1226,1227,1259,1180,1194,1277,1195,1261,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
count,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,...,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0
mean,518977.1,0.889244,1.799863,13.42268,5.735739,50.065292,0.224192,26.745704,0.248385,0.026052,0.062351,6.604192,0.175739,33.061856,0.149485,0.265464,0.019887,0.011942,0.018058,1.513402,0.207216,0.025275,0.028625,0.211904,0.093605,0.017656,0.528313,0.332086,0.004904,0.166186,0.003381,0.04032,0.004316,0.002979,0.016278,0.005887,0.002368,0.002759,6.340206,0.005182,...,0.0129,0.013206,0.025436,0.023863,0.007873,0.015447,0.011983,0.01579,0.017375,0.008265,0.015794,0.014591,0.020814,0.01579,0.004326,1.14433,0.132302,0.002107,0.692096,0.002175,59699.636314,1276.3625,0.295456,-0.290376,-0.164854,0.021579,0.02495,0.036233,0.051114,-0.039674,0.090931,0.042449,0.051017,-0.026042,0.008033,0.130516,0.100706,-0.070276,0.080801,-0.054305
std,252310.5,4.163444,5.667945,73.99033,25.751373,216.324419,0.782582,95.572669,0.930723,0.140107,0.310965,20.990023,0.682616,166.252808,0.592227,1.18503,0.088551,0.072737,0.066314,6.374696,0.871611,0.108783,0.227182,1.132473,0.469244,0.109904,4.065866,3.28392,0.029021,1.707434,0.024968,0.459274,0.038105,0.020708,0.108733,0.064167,0.011981,0.017712,60.046859,0.056493,...,0.088921,0.094949,0.183536,0.170573,0.055559,0.107728,0.092903,0.11565,0.11597,0.058004,0.105907,0.10165,0.166063,0.113751,0.049598,7.337682,1.091308,0.021086,5.146821,0.024206,252379.624234,3830.166637,1.051521,0.559357,0.570535,0.486962,0.542245,0.371595,0.416756,0.431769,0.374381,0.401975,0.365659,0.423988,0.299884,0.387657,0.312468,0.275732,0.363838,0.311387
min,320040.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139389.911994,-7133.757178,-2.956202,-1.472188,-1.527263,-1.156725,-1.187958,-0.71106,-0.826223,-0.793025,-0.889331,-0.900247,-0.723801,-0.890852,-0.974736,-0.76869,-0.765894,-0.591726,-0.49774,-0.813007
25%,323522.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-135858.536813,-1509.692727,-0.511524,-0.700043,-0.513178,-0.34542,-0.271111,-0.200803,-0.159625,-0.311187,-0.058392,-0.241695,-0.148956,-0.273404,-0.21835,-0.148159,-0.143909,-0.21446,-0.176778,-0.191082
50%,330064.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129194.279784,242.538302,0.364432,-0.324264,-0.157514,-0.008188,0.027684,0.087851,0.068043,-0.130007,0.132643,0.077049,0.081896,-0.007015,-0.0357,0.042666,0.102032,-0.11072,0.012649,-0.06536
75%,748401.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289213.04274,4892.763841,1.189274,0.126343,0.145024,0.319763,0.407634,0.227687,0.272037,0.152054,0.301959,0.238672,0.317039,0.232187,0.227851,0.421239,0.345174,0.047149,0.226793,0.112462
max,1105430.0,30.9,27.8,892.0,182.0,1520.0,4.36,576.0,6.27,1.17,2.22,94.5,6.98,1770.0,4.24,10.0,0.751,0.884,0.404,67.2,5.8,0.96,2.94,11.1,3.48,1.16,60.5,52.4,0.29,23.8,0.308,7.68,0.538,0.234,1.07,0.87,0.122,0.195,1000.0,0.762,...,0.95,1.16,2.29,1.96,0.535,1.17,1.28,1.48,1.03,0.675,0.94,0.94,2.38,1.37,0.679,70.4,16.3,0.345,46.4,0.324,646235.908412,7476.600766,2.278418,1.316523,1.277786,1.916178,1.278692,1.432538,1.043114,1.155502,0.76083,1.044835,0.954733,0.904097,0.855834,1.03121,0.747299,0.699881,0.890798,0.803328


In [None]:
# Keep the fdc_id values of the test rows before the column is dropped; they are
# needed later to select the by-category imputation values.
indices = test_data.loc[:, 'fdc_id']
indices

710     330094
55      320280
1127    748827
829     746769
1070    748578
         ...  
133     321925
913     747841
155     322058
508     326913
67      320351
Name: fdc_id, Length: 291, dtype: int64

In [None]:
# 'fdc_id' is an identifier, not a model input: remove it from both splits
# (in place) before recording the feature column order.
for split in (train_data, test_data):
    del split['fdc_id']
columns = train_data.columns

In [None]:
# Convert both splits to NumPy. Missing test values become 0. The filled frame
# is reassigned rather than using fillna(inplace=True): test_data comes from
# train_test_split, and in-place mutation of it triggers pandas'
# SettingWithCopyWarning.
train_data = train_data.to_numpy()
test_data = test_data.fillna(0).to_numpy()

# Feature count: dataset width minus the 'fdc_id' column dropped above.
data = dataset.values
rows, cols = data.shape
cols -= 1

# Untouched copy of the test set, used as ground truth when scoring the
# traditional imputation baselines.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Artificially remove values from the non-embedding columns of both splits.
# missing_method masks entries whose uniform draw is <= 0.2, so roughly 20%.
# `mask` presumably flags the removed test entries.
missed_data, missed_data_train, mask = missing_method(
    test_data=test_data,
    train_data=train_data,
    num_embeddings=num_embeddings,
)
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
count,241.0,230.0,230.0,218.0,232.0,232.0,230.0,236.0,228.0,228.0,234.0,235.0,224.0,223.0,226.0,232.0,239.0,239.0,235.0,221.0,240.0,243.0,234.0,237.0,240.0,230.0,230.0,225.0,222.0,243.0,242.0,225.0,234.0,229.0,244.0,242.0,231.0,225.0,230.0,244.0,...,232.0,232.0,223.0,217.0,230.0,226.0,230.0,229.0,238.0,240.0,218.0,223.0,240.0,230.0,243.0,232.0,246.0,246.0,227.0,228.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0,291.0
mean,0.762199,1.665783,8.565217,6.134404,39.99569,0.2425,24.8,0.285932,0.032127,0.075539,6.812393,0.160681,34.584821,0.13852,0.236438,0.017267,0.007937,0.018749,1.604681,0.195475,0.027667,0.025012,0.14291,0.096515,0.02015,0.57017,0.315822,0.005716,0.138108,0.002259,0.04807,0.005187,0.003167,0.019611,0.003402,0.002847,0.003017,7.008889,0.006396,0.001598,...,0.010828,0.011375,0.029022,0.015825,0.007287,0.017235,0.007113,0.013319,0.018239,0.008104,0.015702,0.006785,0.025046,0.019761,0.005,1.146552,0.143089,0.002354,0.725551,0.002675,59699.636314,1276.3625,0.295456,-0.290376,-0.164854,0.021579,0.02495,0.036233,0.051114,-0.039674,0.090931,0.042449,0.051017,-0.026042,0.008033,0.130516,0.100706,-0.070276,0.080801,-0.054305
std,4.029165,5.487207,36.840416,27.258497,173.970503,0.807814,90.735009,1.018941,0.157348,0.346689,21.13212,0.586131,173.235737,0.576925,1.177114,0.081661,0.041133,0.068517,6.76404,0.853377,0.115317,0.235378,0.86219,0.475818,0.120118,4.505582,3.497503,0.032469,1.601421,0.021425,0.503422,0.043119,0.022704,0.121361,0.042761,0.01309,0.018633,67.707287,0.063501,0.013911,...,0.076375,0.074661,0.200274,0.131699,0.051543,0.115621,0.056039,0.113509,0.119734,0.06063,0.108474,0.067576,0.18262,0.12767,0.054212,7.3284,1.168539,0.022904,5.510333,0.027318,252379.624234,3830.166637,1.051521,0.559357,0.570535,0.486962,0.542245,0.371595,0.416756,0.431769,0.374381,0.401975,0.365659,0.423988,0.299884,0.387657,0.312468,0.275732,0.363838,0.311387
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139389.911994,-7133.757178,-2.956202,-1.472188,-1.527263,-1.156725,-1.187958,-0.71106,-0.826223,-0.793025,-0.889331,-0.900247,-0.723801,-0.890852,-0.974736,-0.76869,-0.765894,-0.591726,-0.49774,-0.813007
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-135858.536813,-1509.692727,-0.511524,-0.700043,-0.513178,-0.34542,-0.271111,-0.200803,-0.159625,-0.311187,-0.058392,-0.241695,-0.148956,-0.273404,-0.21835,-0.148159,-0.143909,-0.21446,-0.176778,-0.191082
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-129194.279784,242.538302,0.364432,-0.324264,-0.157514,-0.008188,0.027684,0.087851,0.068043,-0.130007,0.132643,0.077049,0.081896,-0.007015,-0.0357,0.042666,0.102032,-0.11072,0.012649,-0.06536
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289213.04274,4892.763841,1.189274,0.126343,0.145024,0.319763,0.407634,0.227687,0.272037,0.152054,0.301959,0.238672,0.317039,0.232187,0.227851,0.421239,0.345174,0.047149,0.226793,0.112462
max,30.9,27.8,264.0,182.0,1490.0,4.36,576.0,6.27,1.17,2.22,94.5,3.75,1770.0,4.24,10.0,0.67,0.39,0.404,67.2,5.8,0.96,2.94,8.32,3.48,1.16,60.5,52.4,0.29,23.8,0.308,7.68,0.538,0.234,1.07,0.665,0.122,0.195,1000.0,0.762,0.155,...,0.71,0.65,2.29,1.41,0.505,1.17,0.56,1.48,1.03,0.675,0.94,0.94,2.38,1.37,0.679,70.4,16.3,0.345,46.4,0.324,646235.908412,7476.600766,2.278418,1.316523,1.277786,1.916178,1.278692,1.432538,1.043114,1.155502,0.76083,1.044835,0.954733,0.904097,0.855834,1.03121,0.747299,0.699881,0.890798,0.803328


In [None]:
# Traditional baselines that ignore food category: each missing test value is
# filled with a mean / median statistic (presumably taken from the corrupted
# training frame X_train — see impute_traditional_all_foods).
X_train = pd.DataFrame(missed_data_train, columns=columns)
mean_input = pd.DataFrame(missed_data, columns=columns)
median_input = pd.DataFrame(missed_data, columns=columns)

X_test_mean_all = impute_traditional_all_foods(X_train, mean_input, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, median_input, 'median')

In [None]:
# Traditional baselines per food category: test rows are indexed by fdc_id,
# and only the df_mean / df_median rows for those test foods are passed on.
X_test_mean = pd.DataFrame(missed_data, columns=columns, index=indices)
X_test_median = X_test_mean.copy()

df_mean_cat = df_mean.loc[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)

df_median_cat = df_median.loc[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Min-max scale every array with statistics learned from the training split
# only, so nothing from the test set influences the scaling.
scaler = MinMaxScaler().fit(train_data)
train_data = scaler.transform(train_data)
test_data, y_test = scaler.transform(test_data), scaler.transform(y_test)
X_test_mean, X_test_median = (
    scaler.transform(X_test_mean),
    scaler.transform(X_test_median),
)
X_test_mean_all, X_test_median_all = (
    scaler.transform(X_test_mean_all),
    scaler.transform(X_test_median_all),
)

In [None]:
# Copies of the scaled arrays without the trailing embedding columns, for the
# autoencoder that only sees the non-embedding columns.
# The arrays are sliced to an explicit width instead of `[:, :-num_embeddings]`:
# with num_embeddings == 0 that slice would be `[:, :0]` and drop every column.
n_cols_noEm = train_data.shape[1] - num_embeddings
train_data_noEmbeddings = np.array(train_data[:, :n_cols_noEm])
test_mean_noEmbeddings = np.array(X_test_mean[:, :n_cols_noEm])
test_median_noEmbeddings = np.array(X_test_median[:, :n_cols_noEm])
test_mean_noEmbeddings_all = np.array(X_test_mean_all[:, :n_cols_noEm])
test_median_noEmbeddings_all = np.array(X_test_median_all[:, :n_cols_noEm])

In [None]:
# Autoencoder inputs (with embeddings):
# - the all-foods mean-imputed corrupted test set, as a float32 tensor;
# - a shuffled mini-batch loader over the full scaled training set.
missed_data = torch.from_numpy(X_test_mean_all).float()

train_data = torch.from_numpy(train_data).float()
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Autoencoder inputs (without embeddings):
# - the all-foods mean-imputed corrupted test set, as a float32 tensor;
# - a shuffled mini-batch loader over the scaled training set.
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
train_loader_noEm = torch.utils.data.DataLoader(
    dataset=train_data_noEmbeddings,
    batch_size=batch_size,
    shuffle=True,
)

## Autoencoder - without embeddings

In [None]:
# Autoencoder whose width matches the non-embedding columns.
model = Autoencoder(dim=cols - num_embeddings)
model = model.to(device)

In [None]:
# Plain SGD on a mean-squared-error reconstruction objective.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [None]:
# Train the no-embedding autoencoder on train_loader_noEm. `training` is
# defined earlier in the notebook; its printed per-epoch reconstruction loss
# appears in the cell output.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.00769151
epoch : 2/200, recon loss = 0.00512462
epoch : 3/200, recon loss = 0.00442063
epoch : 4/200, recon loss = 0.00419243
epoch : 5/200, recon loss = 0.00411083
epoch : 6/200, recon loss = 0.00408550
epoch : 7/200, recon loss = 0.00407591
epoch : 8/200, recon loss = 0.00406459
epoch : 9/200, recon loss = 0.00406478
epoch : 10/200, recon loss = 0.00406381
epoch : 11/200, recon loss = 0.00406205
epoch : 12/200, recon loss = 0.00405745
epoch : 13/200, recon loss = 0.00405906
epoch : 14/200, recon loss = 0.00405386
epoch : 15/200, recon loss = 0.00405228
epoch : 16/200, recon loss = 0.00404917
epoch : 17/200, recon loss = 0.00405015
epoch : 18/200, recon loss = 0.00404915
epoch : 19/200, recon loss = 0.00404418
epoch : 20/200, recon loss = 0.00404113
epoch : 21/200, recon loss = 0.00404603
epoch : 22/200, recon loss = 0.00403972
epoch : 23/200, recon loss = 0.00403707
epoch : 24/200, recon loss = 0.00403296
epoch : 25/200, recon loss = 0.00403535
epoch : 2

In [None]:
# Reconstruct the mean-imputed corrupted test set with the no-embedding
# autoencoder and score it against the scaled ground truth with rmse_error.
model.eval()
# Inference only: no_grad avoids building an autograd graph (saves memory)
# and makes the previous .detach() unnecessary.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device)).cpu().numpy()

rmse_sum = rmse_error(test_data, filled_data, cols - num_embeddings, mask)
print(rmse_sum)

3.2103825866236084


## Autoencoder - with embeddings

In [None]:
# Autoencoder over all feature columns, embeddings included.
model = Autoencoder(dim=cols)
model = model.to(device)

In [None]:
# Fresh optimizer bound to the new model's parameters; same MSE objective.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [None]:
# Train the with-embeddings autoencoder on train_loader, which covers every
# scaled training column, embeddings included.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.04190433
epoch : 2/200, recon loss = 0.01822679
epoch : 3/200, recon loss = 0.01242952
epoch : 4/200, recon loss = 0.01162838
epoch : 5/200, recon loss = 0.01152282
epoch : 6/200, recon loss = 0.01150328
epoch : 7/200, recon loss = 0.01149246
epoch : 8/200, recon loss = 0.01149595
epoch : 9/200, recon loss = 0.01148129
epoch : 10/200, recon loss = 0.01148172
epoch : 11/200, recon loss = 0.01148340
epoch : 12/200, recon loss = 0.01147649
epoch : 13/200, recon loss = 0.01147073
epoch : 14/200, recon loss = 0.01146298
epoch : 15/200, recon loss = 0.01146126
epoch : 16/200, recon loss = 0.01145656
epoch : 17/200, recon loss = 0.01145111
epoch : 18/200, recon loss = 0.01145069
epoch : 19/200, recon loss = 0.01144169
epoch : 20/200, recon loss = 0.01143143
epoch : 21/200, recon loss = 0.01142800
epoch : 22/200, recon loss = 0.01142187
epoch : 23/200, recon loss = 0.01142518
epoch : 24/200, recon loss = 0.01142261
epoch : 25/200, recon loss = 0.01141272
epoch : 2

In [None]:
# Reconstruct the mean-imputed corrupted test set (embeddings included) with
# the autoencoder and score it against the scaled ground truth.
model.eval()
# Inference only: no_grad avoids building an autograd graph (saves memory)
# and makes the previous .detach() unnecessary.
with torch.no_grad():
    filled_data = model(missed_data.to(device)).cpu().numpy()

rmse_sum = rmse_error(test_data, filled_data, cols, mask)
print(rmse_sum)

3.7267627947317603


## Mean and Median Imputation

### By category

In [None]:
# RMSE of the by-category mean baseline against the scaled ground truth, over
# the non-embedding columns. rmse_error is defined earlier in the notebook and
# presumably scores only the entries flagged in `mask`.
imputed = X_test_mean
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

3.9988857751070244


In [None]:
# RMSE of the by-category median baseline against the scaled ground truth,
# over the non-embedding columns.
imputed = X_test_median
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

4.128531007730607


### Across all foods

In [None]:
# RMSE of the all-foods mean baseline against the scaled ground truth, over
# the non-embedding columns.
imputed = X_test_mean_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

4.101660369906696


In [None]:
# RMSE of the all-foods median baseline against the scaled ground truth, over
# the non-embedding columns.
imputed = X_test_median_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

4.128531007730607


# 120 Columns

In [None]:
# Configuration for the 120-column experiment: the CSV holds 120 nutrient
# columns followed by 30 embedding columns (plus fdc_id).
num_columns = 120
num_embeddings = 30
data_path = 'datasets/median datasets/byColumnsWithEmbeddings/120_columns_embeddings.csv'

In [None]:
dataset = pd.read_csv(data_path, index_col=0)
# Summary statistics over the rows that have no missing values at all.
dataset.dropna().describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,...,1293,1075,1100,1013,1317,1334,2014,1082,1129,1108,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
count,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,...,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0,337.0
mean,771626.5,1.027211,3.90638,24.356083,22.15549,189.863501,0.4673,71.032641,0.782107,0.117151,0.255739,9.62724,0.78905,246.448071,0.229792,0.196641,0.013095,0.020585,0.014475,1.013828,0.640059,0.007777,0.00062,0.117496,0.030792,0.004721,0.537433,0.386142,0.00081,0.191128,0.000151,0.047754,0.004709,9e-06,0.0,0.0,6.5e-05,0.0,18.290801,0.0,...,0.434718,0.000119,32.583976,0.0,0.0,0.0,0.0,0.003264,0.043887,0.11276,312432.850279,-783.917728,0.16367,-0.328462,-0.258602,0.116251,0.363781,0.002436,0.029925,0.050554,-0.059879,-0.035958,0.001461,-0.027677,0.021421,0.176914,0.113005,-0.034515,0.040843,-0.048968,0.01786,-0.013098,0.154126,-0.018147,-0.04676,0.062921,-0.01228,-0.04012,-0.110309,-0.019184
std,213219.6,6.239342,9.746133,60.456972,56.257136,477.0423,1.167091,177.725365,1.934237,0.311297,0.66528,26.241902,7.639991,2962.211411,1.105766,0.984966,0.054277,0.106922,0.065381,5.85315,1.481542,0.041022,0.00546,1.072943,0.277453,0.055329,5.311087,4.104537,0.007462,1.739359,0.001377,0.558067,0.044122,0.000163,0.0,0.0,0.001198,0.0,169.769763,0.0,...,4.499659,0.002179,388.299498,0.0,0.0,0.0,0.0,0.041058,0.503694,0.956984,213235.605081,2565.046531,0.909753,0.435431,0.593571,0.467607,0.434828,0.467055,0.535015,0.591888,0.566249,0.251631,0.387459,0.321187,0.238679,0.409981,0.271503,0.230766,0.302564,0.412149,0.215225,0.201515,0.305158,0.311825,0.171811,0.127519,0.213181,0.231699,0.205817,0.133273
min,321360.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-138060.94032,-7413.329674,-2.4186,-1.39351,-1.527263,-0.879995,-0.62061,-0.71401,-0.810807,-0.829626,-0.933817,-0.726943,-0.784566,-0.67689,-0.697008,-0.891222,-0.529189,-0.592159,-0.491107,-0.801674,-0.432406,-0.515368,-0.631423,-0.663685,-0.41614,-0.44249,-0.52815,-0.621591,-0.51962,-0.240604
25%,748029.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288833.581454,-2276.516759,-0.498649,-0.632807,-0.542181,-0.140073,0.046914,-0.388377,-0.32321,-0.453906,-0.616111,-0.16316,-0.214346,-0.26481,-0.198246,-0.134907,-0.030266,-0.186126,-0.157757,-0.314304,-0.119965,-0.149135,-0.124148,-0.248377,-0.171209,0.005852,-0.176824,-0.194875,-0.194555,-0.084722
50%,748570.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289384.691672,-1183.277851,0.334084,-0.259017,-0.197087,0.014252,0.162557,0.056702,-0.007734,-0.017268,-0.013873,0.002181,0.064521,-0.056085,0.00811,0.05087,0.195688,-0.021982,0.084049,-0.087758,0.062746,-0.034608,0.253195,0.0583,-0.093896,0.097071,-0.010452,0.010053,-0.121303,8e-05
75%,790806.0,0.12,0.12,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,331648.736972,315.128979,0.982731,-0.011155,0.230799,0.351405,0.883851,0.447403,0.166913,0.626173,0.267278,0.113144,0.242557,0.044629,0.226877,0.377943,0.30123,0.120542,0.198932,0.293977,0.150355,0.035629,0.368802,0.164981,0.12139,0.130564,0.176848,0.155724,0.021858,0.062105
max,1105897.0,81.2,79.8,357.0,312.0,2510.0,7.15,988.0,9.94,1.51,3.94,94.9,99.0,38500.0,12.8,10.0,0.62,0.982,0.463,67.2,10.8,0.37,0.075,12.0,3.67,0.96,68.6,52.4,0.086,23.8,0.018,7.68,0.538,0.003,0.0,0.0,0.022,0.0,2340.0,0.0,...,57.6,0.04,5040.0,0.0,0.0,0.0,0.0,0.7,8.03,10.0,646711.616853,7013.590582,2.563649,1.096777,0.740316,1.383223,1.700616,0.719434,1.043124,1.348236,1.12291,0.886142,0.721888,0.908264,0.533856,1.031414,0.637376,0.611758,0.713774,0.866855,0.78019,0.614133,0.564858,0.719459,0.54037,0.501947,0.457145,0.64121,0.322989,0.389681


In [None]:
# Random row-wise train/test split. No random_state is passed, so the split
# comes from NumPy's global RNG and differs between runs unless it was seeded.
data = dataset
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split (fdc_id is still a column here).
train_data.describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,...,1293,1075,1100,1013,1317,1334,2014,1082,1129,1108,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
count,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,...,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0,269.0
mean,767588.9,0.986952,3.835576,24.304833,21.281041,183.04461,0.446283,67.360595,0.755836,0.111595,0.244071,9.744907,0.945539,304.449814,0.250595,0.218647,0.014989,0.023409,0.014338,0.924164,0.626766,0.008729,0.000613,0.094037,0.025818,0.001454,0.405892,0.449851,0.000725,0.226283,0.000167,0.057368,0.00439,1.1e-05,0.0,0.0,8.2e-05,0.0,17.67658,0.0,...,0.506952,0.000149,39.973234,0.0,0.0,0.0,0.0,0.003346,0.054981,0.141264,308393.3011,-716.994323,0.173784,-0.336296,-0.286419,0.118306,0.362688,0.002745,0.037585,0.046372,-0.040789,-0.040014,0.006509,-0.021715,0.014919,0.170805,0.118089,-0.033461,0.039829,-0.043893,0.023311,-0.005881,0.153866,-0.019001,-0.041101,0.071462,-0.020522,-0.040089,-0.101283,-0.020304
std,218182.6,6.572242,10.011393,61.692152,55.62306,470.946718,1.160966,175.998431,1.920198,0.305546,0.6544,26.661605,8.542757,3313.764102,1.208709,1.042753,0.059756,0.117465,0.062914,5.476658,1.478966,0.043599,0.005756,0.945905,0.262606,0.014499,4.238126,4.569294,0.006885,1.93411,0.001498,0.623213,0.042827,0.000183,0.0,0.0,0.001341,0.0,178.567338,0.0,...,5.009182,0.002439,434.311662,0.0,0.0,0.0,0.0,0.044344,0.563443,1.069649,218201.724493,2612.773615,0.924994,0.431466,0.593365,0.460096,0.429668,0.460022,0.542745,0.584625,0.575092,0.255556,0.399608,0.322307,0.241591,0.399307,0.271019,0.234389,0.293925,0.406676,0.211215,0.211939,0.299623,0.31759,0.172,0.125573,0.21216,0.229689,0.20311,0.130341
min,321360.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-138060.94032,-7413.329674,-2.4186,-1.39351,-1.331682,-0.879995,-0.422364,-0.71401,-0.810807,-0.829626,-0.933817,-0.726943,-0.784566,-0.591486,-0.697008,-0.891222,-0.529189,-0.592159,-0.490844,-0.801674,-0.432406,-0.515368,-0.619093,-0.663685,-0.41614,-0.44249,-0.52815,-0.621591,-0.51962,-0.240604
25%,748029.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288833.581454,-2256.897639,-0.498649,-0.632817,-0.732369,-0.139718,0.046914,-0.388285,-0.32321,-0.453359,-0.598517,-0.145094,-0.214342,-0.264403,-0.213009,-0.134663,-0.010938,-0.1862,-0.157527,-0.314304,-0.101802,-0.149138,-0.137883,-0.248389,-0.170957,0.022547,-0.193007,-0.194704,-0.194407,-0.084722
50%,748535.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289349.037591,-1157.285963,0.412631,-0.259257,-0.197494,0.015003,0.162041,0.056702,-0.007937,-0.016248,-0.013429,0.001362,0.104822,-0.056156,0.007152,0.064765,0.196325,-0.023588,0.084006,-0.089547,0.066347,-0.034599,0.276732,0.058324,-0.093715,0.100666,-0.010738,-0.009653,-0.120951,-0.035683
75%,790802.0,0.12,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,331644.66222,322.976471,0.986486,-0.011155,0.229741,0.350674,0.883438,0.413056,0.176709,0.625526,0.40632,0.108646,0.242571,0.049507,0.223283,0.36606,0.304743,0.120786,0.198559,0.294087,0.150344,0.04388,0.366574,0.173348,0.121429,0.130665,0.092749,0.15588,0.022154,0.05599
max,1105897.0,81.2,79.8,357.0,312.0,2510.0,7.15,988.0,9.94,1.51,3.94,94.9,99.0,38500.0,12.8,10.0,0.62,0.982,0.424,67.2,10.8,0.37,0.075,11.1,3.67,0.205,60.5,52.4,0.086,23.8,0.018,7.68,0.538,0.003,0.0,0.0,0.022,0.0,2340.0,0.0,...,57.6,0.04,5040.0,0.0,0.0,0.0,0.0,0.7,8.03,10.0,646711.616853,7013.590582,2.563649,1.096777,0.740316,1.383223,1.700616,0.719434,1.043124,1.348236,1.04043,0.886142,0.721888,0.908264,0.533856,1.031414,0.637376,0.611758,0.713774,0.866855,0.78019,0.614133,0.554478,0.719459,0.54037,0.501947,0.457145,0.612735,0.322989,0.389681


In [None]:
# Summary statistics of the test split (fdc_id is still a column here).
test_data.describe()

Unnamed: 0,fdc_id,1004,1003,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1093,1002,1167,1175,1165,1166,1103,1079,1170,1264,1265,1266,1314,1315,1316,1300,1109,1299,1404,1267,1323,1304,1263,1271,1306,1253,1262,...,1293,1075,1100,1013,1317,1334,2014,1082,1129,1108,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
count,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,...,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0
mean,787598.6,1.186471,4.186471,24.558824,25.614706,216.838235,0.550441,85.558824,0.886029,0.139132,0.301897,9.161765,0.17,17.0,0.1475,0.109588,0.005603,0.009412,0.015015,1.368529,0.692647,0.004015,0.000647,0.210294,0.050471,0.017647,1.057794,0.134118,0.001147,0.052059,8.8e-05,0.009721,0.005971,0.0,0.0,0.0,0.0,0.0,20.720588,0.0,...,0.148971,0.0,3.352941,0.0,0.0,0.0,0.0,0.002941,0.0,0.0,328412.831591,-1048.658844,0.123664,-0.297467,-0.148561,0.108122,0.368108,0.001211,-0.000379,0.067097,-0.135396,-0.019913,-0.018508,-0.051259,0.047143,0.201081,0.092891,-0.038687,0.044856,-0.069042,-0.003704,-0.041645,0.155156,-0.014768,-0.069146,0.029135,0.020323,-0.040245,-0.146014,-0.014752
std,193009.6,4.734797,8.679947,55.731869,58.997033,503.145023,1.196114,185.021288,1.999955,0.334557,0.70988,24.694637,0.562553,115.665627,0.528513,0.711551,0.020486,0.044506,0.07486,7.188399,1.501573,0.028597,0.004124,1.477608,0.331224,0.119578,8.322745,0.945001,0.009459,0.429288,0.000728,0.080158,0.049235,0.0,0.0,0.0,0.0,0.0,130.331706,0.0,...,1.032182,0.0,23.237749,0.0,0.0,0.0,0.0,0.024254,0.0,0.0,193010.641334,2366.444565,0.852123,0.452722,0.585803,0.499707,0.45795,0.497465,0.505984,0.62401,0.52707,0.236587,0.337061,0.317978,0.226676,0.452111,0.274502,0.21745,0.336902,0.435677,0.230788,0.151456,0.328482,0.2901,0.17048,0.130472,0.215659,0.241235,0.214011,0.145238
min,325271.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134076.851474,-7278.938713,-2.065087,-1.254119,-1.527263,-0.734534,-0.62061,-0.655768,-0.718804,-0.808637,-0.932306,-0.636982,-0.655093,-0.67689,-0.517714,-0.679069,-0.506081,-0.591562,-0.491107,-0.801608,-0.421706,-0.348043,-0.631423,-0.589524,-0.32046,-0.358027,-0.48756,-0.621512,-0.519324,-0.240564
25%,748050.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288855.228575,-2310.850192,-0.458492,-0.632297,-0.42741,-0.141481,0.039669,-0.439265,-0.293839,-0.483817,-0.635579,-0.220142,-0.220935,-0.288907,-0.142548,-0.156655,-0.118952,-0.185471,-0.157916,-0.300201,-0.156171,-0.110279,-0.07449,-0.247981,-0.188654,-0.06121,-0.138789,-0.199832,-0.213925,-0.086468
50%,748794.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289612.877789,-1332.386738,0.151198,-0.258421,-0.123562,-0.007414,0.257967,0.056694,0.003327,-0.033545,-0.015366,0.064887,0.006416,-0.055518,0.065346,0.042427,0.134746,-0.014803,0.084778,-0.086105,0.01038,-0.039556,0.250475,0.05819,-0.113329,0.094973,0.00936,0.017874,-0.123311,0.020609
75%,790833.8,0.1075,0.5,5.25,4.7,88.25,0.0325,8.75,0.04,0.0,0.007,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,331677.005565,-329.114135,0.838129,-0.019551,0.338631,0.374881,0.897982,0.491241,0.166461,0.627517,0.168867,0.153871,0.177942,0.03608,0.230379,0.461856,0.260068,0.089903,0.203315,0.283628,0.150491,0.026312,0.372282,0.163786,0.121254,0.129167,0.191044,0.155406,0.008723,0.067994
max,1105547.0,29.0,26.5,230.0,197.0,1620.0,3.82,571.0,7.29,1.17,2.79,90.3,2.98,946.0,2.58,5.79,0.124,0.325,0.463,52.0,4.5,0.233,0.032,12.0,2.62,0.96,68.6,7.68,0.078,3.54,0.006,0.661,0.406,0.0,0.0,0.0,0.0,0.0,1000.0,0.0,...,8.35,0.0,188.0,0.0,0.0,0.0,0.0,0.2,0.0,0.0,646355.094911,3177.071394,1.504308,0.685158,0.740278,1.200174,1.035529,0.719426,1.042537,1.177679,1.12291,0.370549,0.469612,0.588025,0.53243,1.03139,0.534997,0.494331,0.71373,0.797809,0.41623,0.499171,0.564858,0.529216,0.36043,0.207432,0.457002,0.64121,0.297019,0.320163


In [None]:
# fdc_id of every test row (a pandas Series keyed by the original row labels).
# It is captured here, before the column is deleted, because it is reused as
# the index of the by-category imputation frames and to select the matching
# rows of df_mean / df_median.
indices = test_data['fdc_id']
indices

286    1104913
38      335785
198     790429
44      746771
169     748578
        ...   
82      748018
220     790584
90      748061
194     790381
172     748599
Name: fdc_id, Length: 68, dtype: int64

In [None]:
# 'fdc_id' is an identifier, not a model input: remove it from both splits
# (in place) before recording the feature column order.
for split in (train_data, test_data):
    del split['fdc_id']
columns = train_data.columns

In [None]:
# Convert both splits to NumPy. Missing test values become 0. The filled frame
# is reassigned rather than using fillna(inplace=True): test_data comes from
# train_test_split, and in-place mutation of it triggers pandas'
# SettingWithCopyWarning.
train_data = train_data.to_numpy()
test_data = test_data.fillna(0).to_numpy()

# Feature count: dataset width minus the 'fdc_id' column dropped above.
data = dataset.values
rows, cols = data.shape
cols -= 1

# Untouched copy of the test set, used as ground truth when scoring the
# traditional imputation baselines.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Artificially remove values from the non-embedding columns of both splits.
# missing_method masks entries whose uniform draw is <= 0.2, so roughly 20%.
# `mask` presumably flags the removed test entries.
missed_data, missed_data_train, mask = missing_method(
    test_data=test_data,
    train_data=train_data,
    num_embeddings=num_embeddings,
)
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149
count,52.0,56.0,54.0,56.0,57.0,54.0,61.0,55.0,52.0,57.0,55.0,56.0,54.0,43.0,52.0,50.0,52.0,52.0,52.0,54.0,51.0,59.0,50.0,54.0,56.0,54.0,52.0,53.0,46.0,47.0,53.0,61.0,58.0,51.0,56.0,58.0,56.0,55.0,58.0,55.0,...,52.0,51.0,50.0,54.0,53.0,55.0,52.0,56.0,55.0,46.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0,68.0
mean,1.307115,4.660357,30.5,30.914286,230.877193,0.483704,68.590164,0.766727,0.118904,0.319386,8.4,0.139643,2.666667,0.094186,0.135077,0.0032,0.012077,0.019635,1.697308,0.755556,0.005353,0.000746,0.0,0.063556,0.021429,1.27037,0.027692,0.001472,0.076957,0.0,0.012472,0.006656,0.0,0.0,0.0,0.0,0.0,25.618182,0.0,0.001127,...,0.0,0.0,4.56,0.0,0.0,0.0,0.0,0.003571,0.0,0.0,328412.831591,-1048.658844,0.123664,-0.297467,-0.148561,0.108122,0.368108,0.001211,-0.000379,0.067097,-0.135396,-0.019913,-0.018508,-0.051259,0.047143,0.201081,0.092891,-0.038687,0.044856,-0.069042,-0.003704,-0.041645,0.155156,-0.014768,-0.069146,0.029135,0.020323,-0.040245,-0.146014,-0.014752
std,5.255073,9.006799,61.21652,63.844708,515.668604,1.148995,164.000851,1.937506,0.305125,0.72225,23.44441,0.543915,17.568196,0.414604,0.811765,0.013352,0.050684,0.085263,8.187086,1.568399,0.032993,0.004424,0.0,0.371271,0.131667,9.335278,0.199692,0.010714,0.521945,0.0,0.090795,0.051983,0.0,0.0,0.0,0.0,0.0,144.73379,0.0,0.00836,...,0.0,0.0,27.069193,0.0,0.0,0.0,0.0,0.026726,0.0,0.0,193010.641334,2366.444565,0.852123,0.452722,0.585803,0.499707,0.45795,0.497465,0.505984,0.62401,0.52707,0.236587,0.337061,0.317978,0.226676,0.452111,0.274502,0.21745,0.336902,0.435677,0.230788,0.151456,0.328482,0.2901,0.17048,0.130472,0.215659,0.241235,0.214011,0.145238
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-134076.851474,-7278.938713,-2.065087,-1.254119,-1.527263,-0.734534,-0.62061,-0.655768,-0.718804,-0.808637,-0.932306,-0.636982,-0.655093,-0.67689,-0.517714,-0.679069,-0.506081,-0.591562,-0.491107,-0.801608,-0.421706,-0.348043,-0.631423,-0.589524,-0.32046,-0.358027,-0.48756,-0.621512,-0.519324,-0.240564
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288855.228575,-2310.850192,-0.458492,-0.632297,-0.42741,-0.141481,0.039669,-0.439265,-0.293839,-0.483817,-0.635579,-0.220142,-0.220935,-0.288907,-0.142548,-0.156655,-0.118952,-0.185471,-0.157916,-0.300201,-0.156171,-0.110279,-0.07449,-0.247981,-0.188654,-0.06121,-0.138789,-0.199832,-0.213925,-0.086468
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,289612.877789,-1332.386738,0.151198,-0.258421,-0.123562,-0.007414,0.257967,0.056694,0.003327,-0.033545,-0.015366,0.064887,0.006416,-0.055518,0.065346,0.042427,0.134746,-0.014803,0.084778,-0.086105,0.01038,-0.039556,0.250475,0.05819,-0.113329,0.094973,0.00936,0.017874,-0.123311,0.020609
75%,0.0325,0.9625,13.5,9.7,101.0,0.015,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,331677.005565,-329.114135,0.838129,-0.019551,0.338631,0.374881,0.897982,0.491241,0.166461,0.627517,0.168867,0.153871,0.177942,0.03608,0.230379,0.461856,0.260068,0.089903,0.203315,0.283628,0.150491,0.026312,0.372282,0.163786,0.121254,0.129167,0.191044,0.155406,0.008723,0.067994
max,29.0,26.5,230.0,197.0,1620.0,3.82,571.0,7.29,1.17,2.79,90.3,2.98,129.0,2.49,5.79,0.068,0.325,0.463,52.0,4.5,0.233,0.032,0.0,2.62,0.96,68.6,1.44,0.078,3.54,0.0,0.661,0.406,0.0,0.0,0.0,0.0,0.0,1000.0,0.0,0.062,...,0.0,0.0,188.0,0.0,0.0,0.0,0.0,0.2,0.0,0.0,646355.094911,3177.071394,1.504308,0.685158,0.740278,1.200174,1.035529,0.719426,1.042537,1.177679,1.12291,0.370549,0.469612,0.588025,0.53243,1.03139,0.534997,0.494331,0.71373,0.797809,0.41623,0.499171,0.564858,0.529216,0.36043,0.207432,0.457002,0.64121,0.297019,0.320163


In [None]:
# Traditional baselines that ignore food category: each missing test value is
# filled with a mean / median statistic (presumably taken from the corrupted
# training frame X_train — see impute_traditional_all_foods).
X_train = pd.DataFrame(missed_data_train, columns=columns)
mean_input = pd.DataFrame(missed_data, columns=columns)
median_input = pd.DataFrame(missed_data, columns=columns)

X_test_mean_all = impute_traditional_all_foods(X_train, mean_input, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, median_input, 'median')

In [None]:
# Traditional baseline test sets: impute by food category.
# Both frames start from the same corrupted matrix, row-indexed by `indices`
# (the fdc_ids of the test foods); the median frame is an independent copy.
X_test_mean = pd.DataFrame(missed_data, columns=columns, index=indices)
X_test_median = X_test_mean.copy()

# Keep only the df_mean rows whose fdc_id belongs to a test food, then impute.
mean_rows = df_mean['fdc_id'].isin(indices)
df_mean_cat = df_mean.loc[mean_rows]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)

# Same procedure with the median statistics.
median_rows = df_median['fdc_id'].isin(indices)
df_median_cat = df_median.loc[median_rows]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Fit min-max scaling on the training split only, then apply that same
# transform to the training data, the ground truth and every baseline test set.
scaler = MinMaxScaler().fit(train_data)
(
    train_data,
    test_data,
    X_test_mean,
    X_test_median,
    X_test_mean_all,
    X_test_median_all,
    y_test,
) = [
    scaler.transform(matrix)
    for matrix in (
        train_data,
        test_data,
        X_test_mean,
        X_test_median,
        X_test_mean_all,
        X_test_median_all,
        y_test,
    )
]

In [None]:
# Datasets without the trailing embedding columns, for the embedding-free
# autoencoder. Slice by an explicit column count: `[:, :-num_embeddings]`
# would return zero columns if num_embeddings were 0.
n_non_embedding_cols = len(columns) - num_embeddings


def _without_embeddings(data):
    """Return a new array holding only the leading non-embedding columns of *data*.

    Wrapping in a DataFrame with ``columns`` also fails fast if *data* has
    the wrong number of columns.
    """
    frame = pd.DataFrame(data, columns=columns)
    return np.array(frame.iloc[:, :n_non_embedding_cols])


train_data_noEmbeddings = _without_embeddings(train_data)
test_mean_noEmbeddings = _without_embeddings(X_test_mean)
test_median_noEmbeddings = _without_embeddings(X_test_median)
test_mean_noEmbeddings_all = _without_embeddings(X_test_mean_all)
test_median_noEmbeddings_all = _without_embeddings(X_test_median_all)

In [None]:
# Tensors for the full-width autoencoder (embedding columns included).
# The test input is the all-foods mean-imputed matrix; training batches
# are drawn in shuffled order.
missed_data = torch.from_numpy(X_test_mean_all).float()
train_data = torch.from_numpy(train_data).float()
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True
)

In [None]:
# Tensors for the embedding-free autoencoder. The test input is the all-foods
# mean-imputed matrix with the embedding columns dropped.
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()
train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
train_loader_noEm = torch.utils.data.DataLoader(
    train_data_noEmbeddings, batch_size=batch_size, shuffle=True
)

## Autoencoder - without embeddings

In [None]:
# Fresh autoencoder whose input width excludes the embedding columns.
model = Autoencoder(dim=cols - num_embeddings)
model = model.to(device)

In [None]:
# Mean-squared reconstruction loss, optimized with plain SGD at lr=0.01.
criterion, optimizer = nn.MSELoss(), optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train the embedding-free autoencoder on the embedding-free training loader.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.01701812
epoch : 2/200, recon loss = 0.01564335
epoch : 3/200, recon loss = 0.01460507
epoch : 4/200, recon loss = 0.01379930
epoch : 5/200, recon loss = 0.01317630
epoch : 6/200, recon loss = 0.01267984
epoch : 7/200, recon loss = 0.01227968
epoch : 8/200, recon loss = 0.01196503
epoch : 9/200, recon loss = 0.01170211
epoch : 10/200, recon loss = 0.01147030
epoch : 11/200, recon loss = 0.01131449
epoch : 12/200, recon loss = 0.01114806
epoch : 13/200, recon loss = 0.01102623
epoch : 14/200, recon loss = 0.01092571
epoch : 15/200, recon loss = 0.01082791
epoch : 16/200, recon loss = 0.01076155
epoch : 17/200, recon loss = 0.01070576
epoch : 18/200, recon loss = 0.01064671
epoch : 19/200, recon loss = 0.01060539
epoch : 20/200, recon loss = 0.01057183
epoch : 21/200, recon loss = 0.01053298
epoch : 22/200, recon loss = 0.01050878
epoch : 23/200, recon loss = 0.01048132
epoch : 24/200, recon loss = 0.01046283
epoch : 25/200, recon loss = 0.01045733
epoch : 2

In [None]:
# Score the embedding-free autoencoder on the mean-imputed test input.
# NOTE(review): the autoencoder cells score against `test_data`, while the
# baseline cells score against `y_test`. Confirm both hold the same ground truth.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    filled_data = model(missed_data_noEm.to(device)).cpu().detach().numpy()

rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

4.426143613050631


## Autoencoder - with embeddings

In [None]:
# Fresh autoencoder over the full feature width (embedding columns included).
model = Autoencoder(dim=cols)
model = model.to(device)

In [None]:
# Same objective and optimizer settings as the embedding-free run.
criterion, optimizer = nn.MSELoss(), optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train the full-width autoencoder on the loader that keeps the embedding columns.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.06503114
epoch : 2/200, recon loss = 0.05709851
epoch : 3/200, recon loss = 0.05018301
epoch : 4/200, recon loss = 0.04397284
epoch : 5/200, recon loss = 0.03838408
epoch : 6/200, recon loss = 0.03344401
epoch : 7/200, recon loss = 0.02929400
epoch : 8/200, recon loss = 0.02600668
epoch : 9/200, recon loss = 0.02352470
epoch : 10/200, recon loss = 0.02178479
epoch : 11/200, recon loss = 0.02060667
epoch : 12/200, recon loss = 0.01985701
epoch : 13/200, recon loss = 0.01936538
epoch : 14/200, recon loss = 0.01904659
epoch : 15/200, recon loss = 0.01885361
epoch : 16/200, recon loss = 0.01872621
epoch : 17/200, recon loss = 0.01863980
epoch : 18/200, recon loss = 0.01858143
epoch : 19/200, recon loss = 0.01854193
epoch : 20/200, recon loss = 0.01853065
epoch : 21/200, recon loss = 0.01849820
epoch : 22/200, recon loss = 0.01847674
epoch : 23/200, recon loss = 0.01846886
epoch : 24/200, recon loss = 0.01846851
epoch : 25/200, recon loss = 0.01845999
epoch : 2

In [None]:
# Score the full-width autoencoder on the mean-imputed test input.
# NOTE(review): this passes `cols` as the width, while the other methods pass
# `cols-num_embeddings`, and it scores against `test_data` rather than `y_test`.
# Confirm rmse_error / mask make these comparable.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    filled_data = model(missed_data.to(device)).cpu().detach().numpy()

rmse_sum = rmse_error(test_data, filled_data, cols, mask)

print(rmse_sum)

4.745380934728905


## Mean and Median Imputation

### By category

In [None]:
# Baseline score: per-category mean imputation.
imputed = X_test_mean
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

7.466544038323449


In [None]:
# Baseline score: per-category median imputation.
imputed = X_test_median
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

4.993338681182911


### Across all foods

In [None]:
# Baseline score: mean imputation across all foods.
imputed = X_test_mean_all
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

5.439245085477065


In [None]:
# Baseline score: median imputation across all foods.
imputed = X_test_median_all
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

4.993338681182911
