# Import libraries

In [None]:
import os
import random
import numpy as np
import pandas as pd
from math import sqrt
import math
import os
import random

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim

In [None]:
!pip install autoimpute
from autoimpute.imputations import MultipleImputer

# Missingness method

In [None]:
def _mask_features(data, num_embeddings, threshold):
    """Return a NaN-corrupted copy of ``data`` and the boolean mask that was applied."""
    corrupted = data.copy()
    n_rows, n_cols = corrupted.shape
    n_features = n_cols - num_embeddings

    # One uniform draw per feature cell; a cell becomes missing where draw <= threshold.
    draws = np.random.uniform(size=(n_rows, n_features))

    # The trailing embedding columns are never masked, so pad the mask with False.
    keep_embeddings = np.zeros((n_rows, num_embeddings), dtype=bool)
    mask = np.c_[draws <= threshold, keep_embeddings]

    # np.nan (lower case) is the spelling that still exists in NumPy >= 2.0.
    corrupted[mask] = np.nan
    return corrupted, mask


def missing_method(test_data, train_data, num_embeddings, threshold=0.2):
    """Randomly blank out feature cells of the test and train sets.

    Both inputs are copied; the originals are left untouched. The last
    ``num_embeddings`` columns of each array are treated as embedding columns
    and never receive missing values.

    Parameters
    ----------
    test_data, train_data : 2-D arrays
        Data to corrupt. Must be able to hold NaN.
    num_embeddings : int
        Number of trailing columns to leave intact.
    threshold : float, default 0.2
        A cell is set to NaN when its uniform [0, 1) draw is ``<= threshold``.

    Returns
    -------
    (corrupted_test, corrupted_train, test_mask)
        ``test_mask`` is True where ``corrupted_test`` was set to NaN. The mask
        used for the train set is not returned.
    """
    # Test set is drawn first, then the train set, so the global NumPy RNG is
    # consumed in the same order as before.
    test_data, mask = _mask_features(test_data, num_embeddings, threshold)
    train_data, _ = _mask_features(train_data, num_embeddings, threshold)

    return test_data, train_data, mask

# Imputation Methods

## Autoencoder

In [None]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a dropout layer applied to its input.

    Layer widths go dim -> 0.7*dim -> 0.5*dim -> 0.2*dim in the encoder and
    back up in the decoder (each width truncated with ``int``), with Tanh
    between the linear layers.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Hidden widths, computed once and shared by encoder and decoder.
        wide = int(dim * 0.7)
        middle = int(dim * 0.5)
        bottleneck = int(dim * 0.2)

        # Dropout on the raw input; acts as input corruption while training.
        self.drop_out = nn.Dropout(p=0.2)

        # Compress the input down to the bottleneck width.
        self.encoder = nn.Sequential(
            nn.Linear(dim, wide),
            nn.Tanh(),
            nn.Linear(wide, middle),
            nn.Tanh(),
            nn.Linear(middle, bottleneck),
        )

        # Expand the bottleneck code back to the input width.
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck, middle),
            nn.Tanh(),
            nn.Linear(middle, wide),
            nn.Tanh(),
            nn.Linear(wide, dim),
        )

    def forward(self, x):
        """Reconstruct ``x``; the result always has shape ``(-1, dim)``."""
        flat = x.view(-1, self.dim)
        code = self.encoder(self.drop_out(flat))
        reconstruction = self.decoder(code)
        return reconstruction.view(-1, self.dim)

In [None]:
def training(num_epochs, model, train_loader, criterion, optimizer):
    """Train ``model`` to reconstruct its own input and print the loss per epoch.

    Parameters
    ----------
    num_epochs : int
        Number of full passes over ``train_loader``.
    model : torch.nn.Module
        Model to optimise; each batch is used as both input and target.
    train_loader : iterable of tensors
        Yields one batch of features per iteration.
    criterion : callable
        Loss called as ``criterion(outputs, batch_features)``.
    optimizer : torch.optim.Optimizer
        Optimiser that updates the parameters of ``model``.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches (there is no loss to average).

    Notes
    -----
    The model is switched to training mode and left in it; call
    ``model.eval()`` before inference.
    """
    # Move batches to wherever the model's parameters live instead of relying
    # on a module-level ``device`` variable being defined.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # Dropout is only active in training mode.
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0

        for batch_features in train_loader:
            if target_device is not None:
                batch_features = batch_features.to(target_device)

            # reset the gradients back to zero
            optimizer.zero_grad()

            # reconstruct the batch and score it against itself
            outputs = model(batch_features)
            train_loss = criterion(outputs, batch_features)

            # backpropagate and update the parameters
            train_loss.backward()
            optimizer.step()

            running_loss += train_loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # mean mini-batch loss for this epoch
        epoch_loss = running_loss / num_batches
        print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, num_epochs, epoch_loss))

## Mean and Median Imputation

In [None]:
# takes the test dataset and the dataset containing the average/median values by category
# imputes the values found in the dataset by category in place of the NaN values in the test dataset
# returns pandas DataFrame

def impute_traditional_by_category(X_test, df, indices):
    """Fill NaN cells of ``X_test`` with the per-food values stored in ``df``.

    Parameters
    ----------
    X_test : pandas.DataFrame
        Test set whose index holds food ids. It is modified in place.
    df : pandas.DataFrame
        Reference table with an ``fdc_id`` column plus one column per feature
        that may need filling. If an id appears more than once, its first row
        is used. ``df`` itself is not modified.
    indices : iterable
        Food ids (index labels of ``X_test``) whose rows should be imputed.

    Returns
    -------
    pandas.DataFrame
        The same ``X_test`` object, with the NaN cells of the selected rows
        replaced. Each NaN cell is filled individually.

    Raises
    ------
    KeyError
        If a row that needs filling has no matching ``fdc_id`` in ``df``, or a
        column that needs filling is absent from ``df``.
    """
    # Index the reference table once instead of scanning it for every food id.
    reference = df.drop_duplicates(subset='fdc_id').set_index('fdc_id')

    # Rows that were requested for imputation.
    selected = X_test.index.isin(indices)

    for col in X_test.columns:
        gaps = selected & X_test[col].isna().to_numpy()
        if not gaps.any():
            continue
        # Look up the replacement value for each food id that has a gap here.
        replacements = reference.loc[X_test.index[gaps], col].to_numpy()
        X_test.loc[gaps, col] = replacements

    return X_test

In [None]:
# uses a multiple imputer to fill the NaN values in the test dataset, calculates mean/median by column
# returns pandas DataFrame

def impute_traditional_all_foods(train, test, method):
  """Fit a single-imputation MultipleImputer on ``train`` and apply it to ``test``.

  ``method`` is forwarded as the imputer's ``strategy``. Returns the second
  element of the first entry of the list produced by ``transform``.
  """
  fitted = MultipleImputer(1, strategy=method, return_list=True)
  fitted.fit(train)
  imputations = fitted.transform(test)
  first_imputation = imputations[0]
  return first_imputation[1]

# Evaluation

In [None]:
def rmse_error(test_data, imputed_data, num_cols, mask):
  """Sum, over the first ``num_cols`` columns, the RMSE on the masked cells.

  Parameters
  ----------
  test_data : 2-D numpy array
      Ground-truth values.
  imputed_data : 2-D numpy array
      Imputed / reconstructed values, row-aligned with ``test_data``.
  num_cols : int
      Number of leading columns to score.
  mask : 2-D boolean numpy array
      True where a value was removed and therefore has to be scored.

  Returns
  -------
  float
      Sum of the per-column RMSE values. A column with no masked cell has
      nothing to score and contributes 0 instead of raising on empty input.
  """
  rmse_sum = 0.0

  for i in range(num_cols):
    scored = mask[:, i]
    if not scored.any():
      # e.g. embedding columns, which are never masked
      continue

    residual = test_data[:, i][scored] - imputed_data[:, i][scored]
    rmse_sum += float(np.sqrt(np.mean(residual ** 2)))

  return rmse_sum

# Data Preparation

In [None]:
# Run-wide settings shared by every experiment below.
# num_epochs: number of passes handed to training()
num_epochs = 200
# test_size: fraction of rows held out by train_test_split
test_size = 0.2
# use_cuda: selects "cuda" vs "cpu" when the torch device is created
use_cuda = False
# batch_size: mini-batch size given to the DataLoaders
batch_size  = 1

# CSV files with the per-food mean / median reference values
data_mean = 'datasets/average datasets/traditionalMethods/food_mean.csv'
data_median = 'datasets/average datasets/traditionalMethods/food_median.csv'

In [None]:
# Load the per-food mean table and discard the two non-feature columns
# (the category id and the unnamed column read from the CSV).
df_mean = pd.read_csv(data_mean).drop(columns=['food_category_id', 'Unnamed: 0'])
df_mean.describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,1323,1325,1329,1330,1331,1333,1334,1335,1404,1405,1406,1409,1411,1414,2003,2004,2005,2006,2007,2008,2009,2010,2012,2013,2014,2015,2016,2018,2019,2020,2021,2022,2023,2024,2025,2026,2028,2029,2032,2033
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,0.243665,1.942243,1.824697,0.164999,0.472742,1.999714,2.561062,0.147034,0.118187,0.120046,0.046356,0.008754,0.000476,0.681114,0.728577,0.049253,8.349769,8.36934,0.055411,0.002459,0.350355,0.005555,0.021892,0.106772,27.659061,0.476038,14.595597,55.746662,125.003892,133.394887,14.730179,0.384163,2.983579,0.077504,12.187135,0.149979,42.944411,0.840258,2.086684,...,0.00248,0.0,0.001246,0.000236,7.630675e-07,2e-06,6.5e-05,1.7e-05,0.016719,0.000279,0.001158,2.1e-05,0.000652,0.0,0.0,0.0,0.0,9.3e-05,0.0,0.000153,0.002787,0.000194,0.005167,0.0,0.00017,1.8e-05,0.020393,0.00191,2.1e-05,9e-05,4e-06,0.000217,2.5e-05,4.4e-05,3.9e-05,0.00011,0.298283,0.004426,0.005723,0.02257
std,254401.8,0.220034,5.131636,1.626531,0.445364,1.407237,2.322706,7.750389,1.676538,0.170909,0.170768,0.066607,0.022466,0.001082,1.547442,1.655274,0.088366,6.015826,9.725673,0.336683,0.003077,0.866652,0.009656,0.039604,0.268325,36.719285,1.102398,36.788753,104.198028,312.532122,574.034787,46.850594,0.685919,9.489527,0.216968,77.76698,0.408316,135.856952,0.72705,2.920551,...,0.004219,0.0,0.002653,0.000311,2.094377e-06,6e-06,0.000102,2.4e-05,0.062646,0.000978,0.001561,3e-05,0.001585,0.0,0.0,0.0,0.0,0.00041,0.0,0.000214,0.00398,0.000502,0.017287,0.0,0.000669,9.7e-05,0.073219,0.008556,2.9e-05,0.00018,1.8e-05,0.000493,6.2e-05,9.3e-05,0.000208,0.000393,0.359319,0.006378,0.008247,0.037044
min,319877.0,0.0,0.0,0.111598,0.0,0.000945,0.259574,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000787,1.080851,0.0,0.0,0.0,0.0,0.0,0.0,0.019685,0.001024,0.0,0.0,0.0,0.0,0.0,0.000315,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,325648.0,0.039103,0.032765,0.213767,0.029723,0.126017,0.852374,0.0,0.014362,0.002979,0.00151,0.0,0.0,0.0,0.0,0.0,0.002864,7.083392,3.566526,0.009194,0.0,0.000399,0.0,0.0,0.011009,2.582981,0.08935,1.771225,5.056801,19.483948,11.090738,0.0,0.034903,0.0,0.002225,0.0,0.001896,0.0,0.101719,0.0,...,0.000156,0.0,7e-06,2.4e-05,0.0,0.0,0.0,0.0,0.004215,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.6e-05,0.000145,0.0,0.0,0.0,0.001804,0.000368,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2e-06,0.0,0.0,0.0,0.0
50%,331410.0,0.22813,0.153873,2.072731,0.050521,0.268274,1.537129,0.0,0.014362,0.002979,0.005394,6.6e-05,0.0,0.0,0.0,0.0,0.002864,7.640763,6.428218,0.016087,0.001852,0.000399,0.0,0.0,0.054283,6.659193,0.113892,2.494296,27.334123,19.621362,66.594918,0.0,0.246153,0.0,0.006796,0.0,0.008731,0.0,1.210188,0.0,...,0.001174,0.0,0.00029,6.2e-05,0.0,0.0,5.7e-05,0.0,0.006797,7e-05,0.000115,0.0,0.000152,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7e-06,5.4e-05,0.000988,0.0,8e-06,0.0,0.002971,0.000368,0.0,8.7e-05,0.0,0.000125,1.7e-05,4.7e-05,3e-06,1.9e-05,0.143865,0.0,0.0,0.0
75%,747468.0,0.347477,0.170609,2.739955,0.197184,0.377012,1.609859,0.031965,0.13,0.28,0.344147,0.142326,0.001274,0.0,0.0,0.0,0.047227,8.114625,6.732629,0.048356,0.005681,0.138035,0.01435,0.045217,0.077488,40.1777,0.230601,3.248511,37.944601,32.73991,71.179343,0.0,0.246153,0.0,0.01087,5.087418,0.025827,0.0,1.210188,6.291549,...,0.00298,0.0,0.001881,0.000432,1.173709e-06,4e-06,0.0001,5.1e-05,0.006797,8.9e-05,0.001902,6.5e-05,0.000408,0.0,0.0,0.0,0.0,0.000103,0.0,0.00046,0.006659,0.000159,0.001459,0.0,5e-05,1e-06,0.009273,0.000636,5.8e-05,0.000101,1e-06,0.000125,2.8e-05,5e-05,3e-06,1.9e-05,0.476995,0.013615,0.017606,0.048557
max,1105897.0,1.041708,17.350339,7.7,7.377778,10.745404,28.518519,27.170017,36.97037,0.509611,0.424066,0.142326,0.0993,0.002934,4.196521,4.488952,0.442857,39.162234,119.259259,7.392593,0.015404,3.078965,0.0252,0.106441,1.839764,131.759966,3.967752,130.819593,381.446989,1116.573367,4352.914894,163.731128,2.514563,33.163528,0.765211,587.404255,1.439618,475.002375,2.564158,6.291549,...,0.018017,0.0,0.017298,0.001777,1.184834e-05,4.3e-05,0.000545,5.1e-05,0.456394,0.005574,0.007629,6.5e-05,0.008614,0.0,0.0,0.0,0.0,0.002992,0.0,0.00046,0.016773,0.002872,0.109142,0.0,0.003732,0.000559,0.514488,0.06213,7.6e-05,0.000894,0.000104,0.00284,0.000416,0.000628,0.001441,0.002182,1.174141,0.013615,0.017606,0.123579


In [None]:
# Load the per-food median table (first CSV column becomes the index) and
# discard the category id column.
df_median = pd.read_csv(data_median, index_col=0).drop(columns=['food_category_id'])
df_median.describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,1323,1325,1329,1330,1331,1333,1334,1335,1404,1405,1406,1409,1411,1414,2003,2004,2005,2006,2007,2008,2009,2010,2012,2013,2014,2015,2016,2018,2019,2020,2021,2022,2023,2024,2025,2026,2028,2029,2032,2033
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,0.0,2.417543,0.142532,0.0,0.013915,0.0,3.310736,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.865044,0.0,0.0,0.0,0.368859,0.0,0.0,0.0,13.674781,0.424638,14.9343,43.813277,133.14918,0.0,19.522549,0.28879,3.508661,0.085287,0.0,0.15744,34.63678,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
std,254401.8,0.0,7.085115,0.438829,0.0,0.115348,0.0,10.53008,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.170618,0.0,0.0,0.0,1.173188,0.0,0.0,0.0,43.493809,1.350597,47.499817,139.351873,423.492346,0.0,62.093135,0.918521,11.159596,0.271264,0.0,0.500751,110.165239,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,1105897.0,0.0,23.3,2.41,0.0,0.97,0.0,36.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,60.3,0.0,0.0,0.0,4.1,0.0,0.0,0.0,152.0,4.72,166.0,487.0,1480.0,0.0,217.0,3.21,39.0,0.948,0.0,1.75,385.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
# Torch device used for the models and their inputs; chosen by the use_cuda flag.
device = torch.device("cuda" if use_cuda else "cpu")

# 20 Columns

In [None]:
# Settings for the 20-feature experiment.
data_path = 'datasets/average datasets/byColumnsWithEmbeddings/20_columns_embeddings.csv'
# Number of trailing embedding columns; missing_method never masks these.
num_embeddings = 5
# NOTE(review): presumably the number of feature columns in this file; confirm it is used by a later cell.
num_columns = 20

In [None]:
# Read the experiment data (first CSV column is the index) and summarise
# only the rows that have no missing value in any column.
dataset = pd.read_csv(data_path, index_col=0)
dataset.dropna().describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,0,1,2,3,4
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,1.824697,27.659061,14.595597,125.003892,0.384163,55.746662,0.476038,0.077504,0.149979,8.349769,0.472742,0.243665,133.394887,0.350355,1.942243,2.561062,42.944411,0.308332,0.116491,0.065003,21488.189968,176.402583,0.119704,-0.093912,-0.048493
std,254401.8,7.838466,99.924214,45.754175,382.322293,1.05534,153.43266,1.48336,0.264892,0.515558,23.549114,4.416751,0.999601,1736.846193,1.186155,6.583455,9.446549,187.200801,1.605577,0.597369,0.559044,254447.573253,3958.010969,1.112342,0.694392,0.598292
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.527263
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133692.806088,-2328.507401,-0.796369,-0.667107,-0.477074
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127823.125712,-153.274226,-0.002647,-0.011945,-0.071942
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288262.097477,3032.870912,0.983044,0.44114,0.323268
max,1105897.0,99.4,1390.0,386.0,2520.0,7.66,997.0,12.7,1.92,7.87,96.7,99.6,13.2,40700.0,11.5,79.9,43.9,1790.0,22.8,8.61,17.0,646711.616853,7635.496843,2.647689,1.869347,1.616847


In [None]:
# `data` is another name for `dataset`, not a copy.
data = dataset
# Hold out `test_size` of the rows. No random_state is passed, so the split
# (and every number reported below) differs from run to run.
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split (display only).
train_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,0,1,2,3,4
count,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0
mean,482254.0,1.755201,27.743609,14.484309,124.078024,0.380677,55.346433,0.474776,0.077031,0.148934,8.414429,0.476104,0.23448,139.155284,0.355341,1.915318,2.559018,42.233441,0.308766,0.118179,0.063822,22990.22972,203.642129,0.120423,-0.094587,-0.045482
std,255957.8,7.687655,101.585144,45.636608,381.855342,1.053683,153.490134,1.485049,0.263929,0.514145,23.687789,4.515724,0.967373,1822.974657,1.197866,6.542063,9.447906,184.659671,1.595443,0.603438,0.523031,256004.025403,3967.478183,1.109581,0.694324,0.598217
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.481988
25%,325622.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133718.782631,-2326.205329,-0.788413,-0.671057,-0.476737
50%,331415.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127818.032271,-111.096255,-0.002088,-0.011191,-0.071399
75%,747515.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288310.230486,3047.83048,0.982963,0.441077,0.330094
max,1105897.0,99.4,1390.0,386.0,2510.0,7.66,997.0,10.8,1.92,7.87,96.7,99.6,13.0,40700.0,11.4,79.9,43.9,1790.0,22.8,8.61,16.8,646711.616853,7635.496843,2.647689,1.869347,1.597387


In [None]:
# Summary statistics of the test split (display only).
test_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,0,1,2,3,4
count,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,474742.6,2.102678,27.32087,15.040748,128.707364,0.398108,57.347577,0.481087,0.079392,0.154159,8.091129,0.459294,0.280404,110.3533,0.330408,2.049947,2.569241,45.788287,0.306598,0.109738,0.069724,15480.03096,67.4444,0.116832,-0.091213,-0.060537
std,248036.9,8.410532,93.001944,46.227551,384.235236,1.062026,153.221347,1.476856,0.26875,0.52125,22.988673,3.997199,1.118717,1338.044656,1.138109,6.74668,9.442917,197.048894,1.645802,0.572513,0.684504,248081.146906,3918.771433,1.123528,0.694789,0.598559
min,319882.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139550.845832,-7411.3678,-2.816578,-1.537302,-1.527263
25%,325731.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133608.254984,-2345.183653,-0.820487,-0.646966,-0.482865
50%,331403.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127830.256528,-330.340365,-0.006454,-0.014591,-0.076192
75%,747294.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288084.845758,2974.013474,0.984178,0.441246,0.321993
max,1105887.0,82.6,938.0,369.0,2520.0,7.05,986.0,12.7,1.84,4.01,96.0,99.4,13.2,40100.0,11.5,48.1,42.5,1610.0,22.2,7.42,17.0,646701.429977,7630.592063,2.647168,1.869343,1.616847


In [None]:
# Keep the food id of every test row before the column is dropped; it is
# used later as the index for the per-food mean/median lookup.
indices = test_data['fdc_id']
indices

767      321893
11351    790103
7753     333420
3071     325328
1309     322529
          ...  
3525     326013
8207     334668
11660    790480
5439     329186
2603     324520
Name: fdc_id, Length: 2621, dtype: int64

In [None]:
# The food id is an identifier, not a feature: remove it from both splits.
train_data = train_data.drop(columns='fdc_id')
test_data = test_data.drop(columns='fdc_id')

# Remaining column labels (features followed by embeddings), reused when
# arrays are wrapped back into DataFrames.
columns = train_data.columns

In [None]:
# Convert both splits to plain numpy arrays.
train_data = train_data.to_numpy()
# fillna() returns a new frame; filling in place on a frame produced by the
# split triggers pandas' SettingWithCopyWarning.
test_data = test_data.fillna(0).to_numpy()

# Column count of the full dataset, minus one for the fdc_id column.
data = dataset.values
rows, cols = data.shape
cols -= 1

# Untouched copy of the test set, kept as ground truth for the evaluation.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Corrupt copies of the test and train sets with NaN; `mask` marks the test
# cells that were removed (embedding columns are never masked).
missed_data, missed_data_train, mask = missing_method(test_data, train_data, num_embeddings)
# Non-NaN counts below show how many values survived per column (display only).
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
count,2128.0,2103.0,2081.0,2088.0,2107.0,2127.0,2096.0,2105.0,2098.0,2116.0,2100.0,2067.0,2093.0,2115.0,2080.0,2085.0,2138.0,2111.0,2069.0,2088.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,2.127702,28.920114,14.914512,126.25,0.408519,58.964269,0.489957,0.079634,0.159527,8.198341,0.446043,0.306473,86.846632,0.330875,2.03124,2.564427,45.905753,0.324708,0.111738,0.067591,15480.03096,67.4444,0.116832,-0.091213,-0.060537
std,8.600823,95.757893,45.769414,380.203401,1.074592,155.364526,1.495226,0.270365,0.526523,23.213875,3.877333,1.18457,879.867413,1.127328,6.722095,9.429371,198.052557,1.712762,0.574624,0.700947,248081.146906,3918.771433,1.123528,0.694789,0.598559
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139550.845832,-7411.3678,-2.816578,-1.537302,-1.527263
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133608.254984,-2345.183653,-0.820487,-0.646966,-0.482865
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127830.256528,-330.340365,-0.006454,-0.014591,-0.076192
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288084.845758,2974.013474,0.984178,0.441246,0.321993
max,82.6,938.0,369.0,2030.0,6.63,986.0,12.7,1.84,4.01,96.0,99.4,13.2,38200.0,11.0,48.1,42.5,1610.0,22.2,7.42,17.0,646701.429977,7630.592063,2.647168,1.869343,1.616847


In [None]:
# traditional methods test sets - impute across all foods
X_train = pd.DataFrame(missed_data_train, columns = columns)
X_test_mean_all = pd.DataFrame(missed_data, columns = columns)
X_test_median_all = pd.DataFrame(missed_data, columns = columns)

X_test_mean_all = impute_traditional_all_foods(X_train, X_test_mean_all, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, X_test_median_all, 'median')

In [None]:
# Traditional methods test sets - impute by food category
X_test_mean = pd.DataFrame(missed_data, columns = columns, index=indices)
X_test_median = X_test_mean.copy()
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Min-max scaling learned from the (uncorrupted) training set only, then
# applied to every array that takes part in training or evaluation.
scaler = MinMaxScaler().fit(train_data)

train_data = scaler.transform(train_data)
test_data = scaler.transform(test_data)

X_test_mean, X_test_median, X_test_mean_all, X_test_median_all = (
    scaler.transform(imputed_set)
    for imputed_set in (X_test_mean, X_test_median, X_test_mean_all, X_test_median_all)
)

# Ground truth goes through the same scaling so errors are comparable.
y_test = scaler.transform(y_test)

In [None]:
# datasets without embeddings for autoencoder
train_data_noEmbeddings = pd.DataFrame(train_data, columns = columns)
train_data_noEmbeddings = np.array(train_data_noEmbeddings.iloc[:, :-num_embeddings]) # df = df.iloc[: , :-1]
test_mean_noEmbeddings = np.array(pd.DataFrame(X_test_mean, columns = columns).iloc[:, :-num_embeddings])
test_median_noEmbeddings = np.array(pd.DataFrame(X_test_median, columns = columns).iloc[:, :-num_embeddings])
test_mean_noEmbeddings_all = np.array(pd.DataFrame(X_test_mean_all, columns = columns).iloc[:, :-num_embeddings])
test_median_noEmbeddings_all = np.array(pd.DataFrame(X_test_median_all, columns = columns).iloc[:, :-num_embeddings])

In [None]:
# Autoencoder input at test time: the test set with gaps pre-filled by the
# column-wise mean, as a float32 tensor.
missed_data = torch.from_numpy(X_test_mean_all).float()

# Training tensor and its shuffled mini-batch loader (features + embeddings).
train_data = torch.from_numpy(train_data).float()
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True)

In [None]:
# Same as above, but for the feature-only arrays (no embedding columns).
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
train_loader_noEm = torch.utils.data.DataLoader(
    train_data_noEmbeddings, batch_size=batch_size, shuffle=True)

## Autoencoder - without embeddings

In [None]:
# Autoencoder whose input width is the feature columns only (embeddings excluded).
model = Autoencoder(dim=cols-num_embeddings).to(device)

In [None]:
# Mean-squared reconstruction loss, optimised with plain SGD.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train on the feature-only loader; prints the reconstruction loss per epoch.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.01529659
epoch : 2/200, recon loss = 0.01418361
epoch : 3/200, recon loss = 0.01256301
epoch : 4/200, recon loss = 0.00741972
epoch : 5/200, recon loss = 0.00631517
epoch : 6/200, recon loss = 0.00621053
epoch : 7/200, recon loss = 0.00613695
epoch : 8/200, recon loss = 0.00609543
epoch : 9/200, recon loss = 0.00604837
epoch : 10/200, recon loss = 0.00601578
epoch : 11/200, recon loss = 0.00595585
epoch : 12/200, recon loss = 0.00587455
epoch : 13/200, recon loss = 0.00579482
epoch : 14/200, recon loss = 0.00552471
epoch : 15/200, recon loss = 0.00499643
epoch : 16/200, recon loss = 0.00423957
epoch : 17/200, recon loss = 0.00383543
epoch : 18/200, recon loss = 0.00374121
epoch : 19/200, recon loss = 0.00365074
epoch : 20/200, recon loss = 0.00368156
epoch : 21/200, recon loss = 0.00358289
epoch : 22/200, recon loss = 0.00359027
epoch : 23/200, recon loss = 0.00357586
epoch : 24/200, recon loss = 0.00363234
epoch : 25/200, recon loss = 0.00358037
epoch : 2

In [None]:
# Evaluation mode disables the input dropout.
model.eval()

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device)).cpu().numpy()

# Score the reconstruction on the masked cells of the feature columns.
rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

1.1560837637956913


## Autoencoder - with embeddings

In [None]:
# Autoencoder whose input width covers the feature columns and the embedding columns.
model = Autoencoder(dim=cols).to(device)

In [None]:
# Mean-squared reconstruction loss, optimised with plain SGD.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train on the loader that includes the embedding columns; prints the reconstruction loss per epoch.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.02462699
epoch : 2/200, recon loss = 0.02235821
epoch : 3/200, recon loss = 0.01862748
epoch : 4/200, recon loss = 0.01543054
epoch : 5/200, recon loss = 0.01503805
epoch : 6/200, recon loss = 0.01475612
epoch : 7/200, recon loss = 0.01449390
epoch : 8/200, recon loss = 0.01418221
epoch : 9/200, recon loss = 0.01377359
epoch : 10/200, recon loss = 0.01337756
epoch : 11/200, recon loss = 0.01296752
epoch : 12/200, recon loss = 0.01257062
epoch : 13/200, recon loss = 0.01242249
epoch : 14/200, recon loss = 0.01220449
epoch : 15/200, recon loss = 0.01187324
epoch : 16/200, recon loss = 0.01163460
epoch : 17/200, recon loss = 0.01161640
epoch : 18/200, recon loss = 0.01143047
epoch : 19/200, recon loss = 0.01141854
epoch : 20/200, recon loss = 0.01130841
epoch : 21/200, recon loss = 0.01125204
epoch : 22/200, recon loss = 0.01104177
epoch : 23/200, recon loss = 0.01103405
epoch : 24/200, recon loss = 0.01080708
epoch : 25/200, recon loss = 0.01060850
epoch : 2

In [None]:
# Evaluation mode disables the input dropout.
model.eval()

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    filled_data = model(missed_data.to(device)).cpu().numpy()

# Score the feature columns only. The embedding columns are never masked, so
# they hold no cells to score, and every other method in this notebook is
# evaluated on cols-num_embeddings columns as well.
rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

1.1978401246034787


## Mean and Median Imputation

### By category

In [None]:
# Baseline: per-food mean imputation, scored on the masked feature cells.
imputed = X_test_mean
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

1.5219354232992517


In [None]:
# Baseline: per-food median imputation, scored on the masked feature cells.
imputed = X_test_median
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

1.6189507261319402


### Across all foods

In [None]:
# Baseline: column-wise mean imputation, scored on the masked feature cells.
imputed = X_test_mean_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

2.0916638036635202


In [None]:
# Baseline: column-wise median imputation, scored on the masked feature cells.
imputed = X_test_median_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

2.178197307340046


# 40 Columns

In [None]:
# Settings for the 40-feature experiment.
data_path = 'datasets/average datasets/byColumnsWithEmbeddings/40_columns_embeddings.csv'
# Number of trailing embedding columns; missing_method never masks these.
num_embeddings = 10
# NOTE(review): presumably the number of feature columns in this file; confirm it is used by a later cell.
num_columns = 40

In [None]:
# Read the experiment data (first CSV column is the index) and summarise
# only the rows that have no missing value in any column.
dataset = pd.read_csv(data_path, index_col=0)
dataset.dropna().describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,1165,0,1,2,3,4,5,6,7,8,9
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,1.824697,27.659061,14.595597,125.003892,0.384163,55.746662,0.476038,0.077504,0.149979,8.349769,0.472742,0.243665,133.394887,0.350355,1.942243,2.561062,42.944411,0.308332,0.116491,0.065003,0.60573,0.211003,0.023788,14.730179,2.983579,78.72934,12.225029,0.005267,0.005886,0.016719,0.004499,0.00387,0.00248,0.118839,0.002989,0.015767,0.012258,0.008632,0.038133,0.006928,21488.189968,176.402583,0.119704,-0.093912,-0.048493,-0.007831,0.051473,0.02311,0.06316,0.043647
std,254401.8,7.838466,99.924214,45.754175,382.322293,1.05534,153.43266,1.48336,0.264892,0.515558,23.549114,4.416751,0.999601,1736.846193,1.186155,6.583455,9.446549,187.200801,1.605577,0.597369,0.559044,4.440086,1.941342,0.148635,56.459805,11.848737,302.68343,49.0407,0.030796,0.0431,0.223785,0.039142,0.065227,0.015745,0.828604,0.01705,0.10619,0.071481,0.054902,0.974266,0.056734,254447.573253,3958.010969,1.112342,0.694392,0.598293,0.511537,0.453723,0.451006,0.41745,0.426074
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.527264,-1.442526,-1.189873,-1.064457,-1.076825,-0.829557
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133692.806088,-2328.507401,-0.796369,-0.667107,-0.477074,-0.343775,-0.23611,-0.286783,-0.177237,-0.292813
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127823.125712,-153.274226,-0.002647,-0.011945,-0.071966,-0.022448,0.01203,0.040891,0.040572,-0.032707
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288262.097477,3032.870912,0.983044,0.44114,0.32327,0.289041,0.257717,0.227563,0.309656,0.274336
max,1105897.0,99.4,1390.0,386.0,2520.0,7.66,997.0,12.7,1.92,7.87,96.7,99.6,13.2,40700.0,11.5,79.9,43.9,1790.0,22.8,8.61,17.0,73.2,52.6,3.4,311.0,93.0,1710.0,449.0,0.485,0.85,7.95,0.774,1.83,0.303,14.0,0.222,1.93,1.86,0.989,42.4,1.92,646711.616853,7635.496843,2.647689,1.869347,1.616847,2.041569,1.70062,1.764974,1.547869,1.348304


In [None]:
# `data` is another name for `dataset`, not a copy.
data = dataset
# Hold out `test_size` of the rows. No random_state is passed, so the split
# (and every number reported below) differs from run to run.
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split (display only).
train_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,1165,0,1,2,3,4,5,6,7,8,9
count,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0
mean,483705.8,1.887024,27.318962,14.395221,123.668447,0.380351,55.097482,0.471182,0.076515,0.148187,8.491958,0.488262,0.25033,122.514594,0.342884,1.927489,2.517378,42.45703,0.315906,0.118237,0.069707,0.614797,0.199738,0.02407,14.494945,2.935969,77.261351,11.97563,0.005308,0.006134,0.016559,0.004545,0.003895,0.002414,0.122711,0.003012,0.015654,0.012506,0.008658,0.044606,0.007359,24442.898852,170.58904,0.119443,-0.093123,-0.05268,-0.01159,0.052038,0.022672,0.063588,0.045477
std,255333.3,8.019547,99.862708,45.531646,380.806938,1.051379,152.479658,1.48038,0.263713,0.515968,23.679751,4.518943,1.014003,1618.739582,1.172874,6.563284,9.370949,186.008101,1.641078,0.605666,0.602119,4.534074,1.77633,0.150727,56.09916,11.764307,299.832117,48.511157,0.030991,0.044737,0.216067,0.039463,0.066432,0.015003,0.842898,0.017028,0.104826,0.073419,0.055229,1.087993,0.059635,255379.454754,3933.834962,1.111404,0.692907,0.599123,0.506658,0.454648,0.452795,0.421476,0.425888
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.577342,-1.527264,-1.442526,-1.189873,-1.006986,-1.076825,-0.829557
25%,325691.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133648.238488,-2307.662167,-0.798023,-0.663338,-0.477425,-0.344402,-0.236082,-0.28688,-0.177357,-0.292446
50%,331468.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127763.532461,-159.406808,0.006781,-0.011148,-0.072708,-0.023806,0.011936,0.040795,0.041052,-0.032263
75%,747603.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288399.875032,3016.23688,0.981714,0.441011,0.3228,0.288232,0.261384,0.227468,0.311777,0.276468
max,1105896.0,99.2,1380.0,386.0,2520.0,7.66,997.0,12.7,1.92,7.87,96.0,99.6,13.2,40700.0,11.5,79.9,43.9,1730.0,22.8,8.61,17.0,73.2,52.4,3.4,311.0,93.0,1710.0,449.0,0.485,0.85,7.95,0.774,1.83,0.303,14.0,0.222,1.93,1.86,0.989,42.4,1.92,646710.59817,7635.496843,2.647689,1.869347,1.616847,2.02881,1.70062,1.764974,1.547869,1.228624


In [None]:
# Summary statistics of the test split (sanity check against the full dataset).
test_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,1165,0,1,2,3,4,5,6,7,8,9
count,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,468935.5,1.575387,29.019458,15.3971,130.34567,0.399412,58.34338,0.495464,0.081459,0.157148,7.781011,0.41066,0.217005,176.916063,0.380237,2.001263,2.735799,44.893934,0.278035,0.109507,0.046184,0.569463,0.256064,0.022656,15.671118,3.174018,84.601297,13.222625,0.005103,0.004891,0.017362,0.004317,0.003768,0.002741,0.103353,0.002899,0.016219,0.011265,0.008528,0.012241,0.005205,9669.354429,199.656755,0.120749,-0.097069,-0.031746,0.007202,0.049214,0.02486,0.061447,0.036329
std,250340.7,7.063825,100.1774,46.633877,388.352177,1.071102,157.189722,1.495347,0.269571,0.513949,23.014814,3.982042,0.939498,2145.02653,1.237634,6.664468,9.742956,191.921278,1.454837,0.563017,0.334787,4.042945,2.494442,0.139977,57.881263,12.181039,313.819887,51.101416,0.030006,0.035801,0.252354,0.037836,0.060177,0.01842,0.76873,0.017142,0.111496,0.063141,0.053586,0.101575,0.043194,250385.040634,4053.954334,1.1163,0.700426,0.594781,0.530437,0.450082,0.443857,0.401017,0.426822
min,319882.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139550.845832,-7412.348752,-2.571016,-1.585647,-1.327981,-1.442505,-1.189847,-1.064457,-1.061308,-0.808859
25%,325420.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133925.066958,-2397.174257,-0.788091,-0.671897,-0.475166,-0.33434,-0.238444,-0.265503,-0.17075,-0.294423
50%,330904.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-128338.58185,-98.834426,-0.018865,-0.019108,-0.066464,-0.015501,0.012356,0.048059,0.02843,-0.03413
75%,746994.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,287779.239352,3128.023695,0.990449,0.441551,0.339117,0.305718,0.247664,0.227888,0.300948,0.264765
max,1105897.0,99.4,1390.0,358.0,1810.0,7.46,997.0,9.89,1.83,3.9,96.7,99.3,13.0,40300.0,11.2,48.1,43.2,1790.0,22.6,6.86,7.1,69.9,52.6,3.32,284.0,79.0,1650.0,370.0,0.435,0.765,7.68,0.726,1.79,0.264,13.4,0.194,1.77,1.59,0.787,2.14,1.02,646711.616853,7630.592063,2.647168,1.869328,1.597387,2.041569,1.592385,1.742486,1.484113,1.348304


In [None]:
# Keep the fdc_id of every test row before the column is dropped;
# it is used later to index the by-category imputation frames.
indices = test_data['fdc_id']
indices

1309      322529
3496      325966
8951      335679
12777    1105515
9704      747274
          ...   
4254      327175
4402      327370
8410      335104
5694      329579
5764      329677
Name: fdc_id, Length: 2621, dtype: int64

In [None]:
# Drop the fdc_id column from both splits (its test values were saved in `indices`),
# then record the remaining column order for rebuilding DataFrames later.
del train_data['fdc_id']
del test_data['fdc_id']
columns = train_data.columns

In [None]:
# Convert both splits to NumPy arrays for masking, scaling and torch.
train_data = train_data.to_numpy()
# Rebind instead of fillna(..., inplace=True): test_data is a slice produced by
# train_test_split, and filling it in place triggers pandas' SettingWithCopyWarning.
test_data = test_data.fillna(0).to_numpy()
data = dataset.values
rows, cols = data.shape
cols -= 1  # `dataset` still contains the fdc_id column that was dropped from the splits

# Uncorrupted copy of the test set; used later as the reference passed to rmse_error.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Corrupt copies of both splits: missing_method sets entries to NaN where a uniform draw is <= 0.2,
# leaving the last `num_embeddings` columns untouched. `mask` marks the entries removed from the test set.
missed_data, missed_data_train, mask = missing_method(test_data, train_data, num_embeddings)
# Per-column counts below 2621 show how many test values were removed.
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
count,2081.0,2098.0,2065.0,2083.0,2080.0,2086.0,2057.0,2087.0,2109.0,2121.0,2092.0,2113.0,2095.0,2079.0,2109.0,2108.0,2129.0,2090.0,2097.0,2129.0,2044.0,2104.0,2123.0,2054.0,2100.0,2085.0,2086.0,2074.0,2104.0,2121.0,2099.0,2097.0,2103.0,2109.0,2086.0,2109.0,2110.0,2130.0,2122.0,2092.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,1.625132,29.615825,15.899952,124.521843,0.382135,57.97651,0.516393,0.080287,0.156195,7.633182,0.44055,0.235542,185.649642,0.373641,1.932248,2.718221,44.560263,0.288953,0.101189,0.04857,0.589414,0.248271,0.022463,16.201558,3.297667,88.138609,12.980585,0.005081,0.004599,0.01913,0.004288,0.003748,0.00278,0.106002,0.002992,0.01584,0.0114,0.007671,0.013134,0.005845,9669.354429,199.656755,0.120749,-0.097069,-0.031746,0.007202,0.049214,0.02486,0.061447,0.036329
std,7.297463,101.869652,47.24955,380.291582,1.058164,155.631801,1.532044,0.268996,0.515979,22.701211,4.427868,0.988084,2235.922331,1.221313,6.602178,9.71453,192.1643,1.452063,0.540959,0.345796,4.141566,2.465606,0.140614,58.688585,12.455273,319.828725,50.17165,0.030617,0.034833,0.27946,0.03797,0.05805,0.018405,0.792091,0.017595,0.113064,0.064075,0.048936,0.105973,0.04695,250385.040634,4053.954334,1.1163,0.700426,0.594781,0.530437,0.450082,0.443857,0.401017,0.426822
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139550.845832,-7412.348752,-2.571016,-1.585647,-1.327981,-1.442505,-1.189847,-1.064457,-1.061308,-0.808859
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133925.066958,-2397.174257,-0.788091,-0.671897,-0.475166,-0.33434,-0.238444,-0.265503,-0.17075,-0.294423
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-128338.58185,-98.834426,-0.018865,-0.019108,-0.066464,-0.015501,0.012356,0.048059,0.02843,-0.03413
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,287779.239352,3128.023695,0.990449,0.441551,0.339117,0.305718,0.247664,0.227888,0.300948,0.264765
max,99.4,1390.0,358.0,1810.0,7.46,995.0,9.89,1.83,3.9,96.7,99.3,13.0,40300.0,11.2,48.1,42.4,1790.0,14.0,6.86,7.1,65.0,52.6,3.32,279.0,79.0,1640.0,331.0,0.435,0.765,7.68,0.726,1.79,0.252,13.4,0.194,1.77,1.59,0.787,2.14,1.02,646711.616853,7630.592063,2.647168,1.869328,1.597387,2.041569,1.592385,1.742486,1.484113,1.348304


In [None]:
# Traditional methods test sets - impute across all foods.
# X_train is the corrupted training split the imputer is fitted on.
X_train = pd.DataFrame(missed_data_train, columns=columns)
# copy=True gives each frame its own buffer. Without it both frames wrap the same
# `missed_data` array, so an in-place fill by one imputation run would leak into the
# other (the by-category cell guards against this with an explicit .copy()).
X_test_mean_all = pd.DataFrame(missed_data, columns=columns, copy=True)
X_test_median_all = pd.DataFrame(missed_data, columns=columns, copy=True)

# impute_traditional_all_foods is defined elsewhere in the notebook.
X_test_mean_all = impute_traditional_all_foods(X_train, X_test_mean_all, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, X_test_median_all, 'median')

In [None]:
# Traditional methods test sets - impute by food category
# Rows of the corrupted test set are indexed by their fdc_id values.
X_test_mean = pd.DataFrame(missed_data, columns = columns, index=indices)
# Independent copy taken before imputation so the median run starts from the same corrupted data.
X_test_median = X_test_mean.copy()
# df_mean / df_median are built elsewhere in the notebook; keep only rows whose fdc_id is in the test split.
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Min-max scaling: the scaler is fitted on the clean training split only and
# then applied unchanged to every other matrix.
scaler = MinMaxScaler().fit(train_data)

# Clean splits and the ground-truth copy of the test set.
train_data = scaler.transform(train_data)
test_data = scaler.transform(test_data)
y_test = scaler.transform(y_test)

# Baseline imputations: by food category ...
X_test_mean = scaler.transform(X_test_mean)
X_test_median = scaler.transform(X_test_median)
# ... and across all foods.
X_test_mean_all = scaler.transform(X_test_mean_all)
X_test_median_all = scaler.transform(X_test_median_all)

In [None]:
# Datasets without embeddings for the autoencoder: drop the trailing `num_embeddings` columns.
# The inputs are the arrays produced by scaler.transform, so they are sliced directly
# instead of being wrapped in a DataFrame and converted back; np.array(...) still
# returns an independent copy, as the original round-trip did.
train_data_noEmbeddings = np.array(train_data[:, :-num_embeddings])
test_mean_noEmbeddings = np.array(X_test_mean[:, :-num_embeddings])
test_median_noEmbeddings = np.array(X_test_median[:, :-num_embeddings])
test_mean_noEmbeddings_all = np.array(X_test_mean_all[:, :-num_embeddings])
test_median_noEmbeddings_all = np.array(X_test_median_all[:, :-num_embeddings])

In [None]:
# Autoencoder input (with embeddings): the test set whose gaps were filled
# with the all-foods mean, converted to a float32 tensor.
missed_data = torch.from_numpy(X_test_mean_all).float()

# Clean, scaled training split as a float32 tensor, served in shuffled mini-batches.
train_data = torch.from_numpy(train_data).float()
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [None]:
# Autoencoder input (without embeddings): the all-foods mean-imputed test set
# with the embedding columns removed, converted to a float32 tensor.
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

# Matching training tensor and its shuffled mini-batch loader.
train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
train_loader_noEm = torch.utils.data.DataLoader(train_data_noEmbeddings, batch_size=batch_size, shuffle=True)

## Autoencoder - without embeddings

In [None]:
# Autoencoder sized for the non-embedding columns only (`cols` already excludes fdc_id).
model = Autoencoder(dim=cols-num_embeddings).to(device)

In [None]:
# Mean-squared reconstruction loss, optimised with plain SGD (learning rate 0.01).
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train on the embedding-free loader; `training` and `num_epochs` are defined elsewhere in the notebook.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.01119547
epoch : 2/200, recon loss = 0.00909315
epoch : 3/200, recon loss = 0.00602033
epoch : 4/200, recon loss = 0.00438604
epoch : 5/200, recon loss = 0.00417261
epoch : 6/200, recon loss = 0.00411483
epoch : 7/200, recon loss = 0.00406816
epoch : 8/200, recon loss = 0.00399762
epoch : 9/200, recon loss = 0.00392510
epoch : 10/200, recon loss = 0.00380531
epoch : 11/200, recon loss = 0.00364834
epoch : 12/200, recon loss = 0.00346005
epoch : 13/200, recon loss = 0.00331112
epoch : 14/200, recon loss = 0.00316740
epoch : 15/200, recon loss = 0.00311514
epoch : 16/200, recon loss = 0.00300763
epoch : 17/200, recon loss = 0.00295820
epoch : 18/200, recon loss = 0.00293936
epoch : 19/200, recon loss = 0.00289737
epoch : 20/200, recon loss = 0.00286893
epoch : 21/200, recon loss = 0.00285394
epoch : 22/200, recon loss = 0.00283548
epoch : 23/200, recon loss = 0.00280714
epoch : 24/200, recon loss = 0.00280493
epoch : 25/200, recon loss = 0.00275761
epoch : 2

In [None]:
# Evaluate the autoencoder trained without embeddings on the corrupted test set.
model.eval()  # switch the dropout layer to inference mode

# Inference only: no_grad() avoids recording an autograd graph for the forward pass.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device))
filled_data = filled_data.cpu().numpy()

# rmse_error is defined elsewhere in the notebook; the error is computed over the
# non-embedding columns using the corruption mask.
rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

1.8480004416035005


## Autoencoder - with embeddings

In [None]:
# Autoencoder sized for all columns, embeddings included (`cols` already excludes fdc_id).
model = Autoencoder(dim=cols).to(device)

In [None]:
# Mean-squared reconstruction loss, optimised with plain SGD (learning rate 0.01).
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train on the loader that includes embedding columns; `training` and `num_epochs` are defined elsewhere in the notebook.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.01978195
epoch : 2/200, recon loss = 0.01637038
epoch : 3/200, recon loss = 0.01545645
epoch : 4/200, recon loss = 0.01315527
epoch : 5/200, recon loss = 0.01141412
epoch : 6/200, recon loss = 0.01078181
epoch : 7/200, recon loss = 0.01042307
epoch : 8/200, recon loss = 0.01011216
epoch : 9/200, recon loss = 0.00977376
epoch : 10/200, recon loss = 0.00956904
epoch : 11/200, recon loss = 0.00940188
epoch : 12/200, recon loss = 0.00927159
epoch : 13/200, recon loss = 0.00917862
epoch : 14/200, recon loss = 0.00906874
epoch : 15/200, recon loss = 0.00896763
epoch : 16/200, recon loss = 0.00881151
epoch : 17/200, recon loss = 0.00871635
epoch : 18/200, recon loss = 0.00862613
epoch : 19/200, recon loss = 0.00843092
epoch : 20/200, recon loss = 0.00834418
epoch : 21/200, recon loss = 0.00828966
epoch : 22/200, recon loss = 0.00818139
epoch : 23/200, recon loss = 0.00812029
epoch : 24/200, recon loss = 0.00805194
epoch : 25/200, recon loss = 0.00792620
epoch : 2

In [None]:
# Evaluate the autoencoder trained with embeddings on the corrupted test set.
model.eval()  # switch the dropout layer to inference mode

# Inference only: no_grad() avoids recording an autograd graph for the forward pass.
with torch.no_grad():
    filled_data = model(missed_data.to(device))
filled_data = filled_data.cpu().numpy()

# rmse_error is defined elsewhere in the notebook; here it receives the full column count.
rmse_sum = rmse_error(test_data, filled_data, cols, mask)

print(rmse_sum)

1.8551008859606322


## Mean and Median Imputation

### By category

In [None]:
# Baseline: by-category mean imputation scored against the scaled ground truth y_test
# (rmse_error is defined elsewhere in the notebook).
imputed = X_test_mean
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

2.531217748333226


In [None]:
# Baseline: by-category median imputation scored against the scaled ground truth y_test
# (rmse_error is defined elsewhere in the notebook).
imputed = X_test_median
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

2.680654619816815


### Across all foods

In [None]:
# Baseline: all-foods mean imputation scored against the scaled ground truth y_test
# (rmse_error is defined elsewhere in the notebook).
imputed = X_test_mean_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

3.350671086407284


In [None]:
# Baseline: all-foods median imputation scored against the scaled ground truth y_test
# (rmse_error is defined elsewhere in the notebook).
imputed = X_test_median_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

3.4719214819705706


# 80 Columns

In [None]:
# Configuration for the 80-column experiment.
# NOTE(review): 'avarage' is presumably the directory's actual spelling on disk - confirm before renaming.
data_path = 'datasets/avarage datasets/byColumnsWithEmbeddings/80_columns_embeddings.csv'
# Number of trailing embedding columns; missing_method never corrupts these.
num_embeddings = 20
# Presumably the number of non-embedding feature columns; not referenced in the cells shown here.
num_columns = 80

In [None]:
# Load the table (the first CSV column is the row index) and summarise
# only the rows that have no missing value in any column.
dataset = pd.read_csv(data_path, index_col=0)
dataset.dropna(axis=0, how='any').describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,...,1277,1292,1258,1293,1405,1010,1170,1109,1177,1178,1312,2009,1334,1257,1100,1107,1223,1224,1210,1211,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,1.824697,27.659061,14.595597,125.003892,0.384163,55.746662,0.476038,0.077504,0.149979,8.349769,0.472742,0.243665,133.394887,0.350355,1.942243,2.561062,42.944411,0.308332,0.116491,0.065003,0.60573,0.211003,0.023788,14.730179,2.983579,78.72934,12.225029,0.005267,0.005886,0.016719,0.004499,0.00387,0.00248,0.118839,0.002989,0.015767,0.012258,0.008632,0.038133,...,0.002201,0.434562,0.319591,0.150495,0.000279,0.147034,0.011803,0.073772,1.055017,0.027458,9.3e-05,0.002787,6.5e-05,0.008114,12.187135,8.910568,0.016636,0.027917,0.002017,0.007339,21488.189968,176.402583,0.119704,-0.093912,-0.048494,-0.007831,0.051473,0.023113,0.063163,0.04367,-0.005261,-0.053342,0.032862,0.017149,0.053913,-0.050025,-0.006816,-0.001294,0.04543,-0.002778
std,254401.8,7.838466,99.924214,45.754175,382.322293,1.05534,153.43266,1.48336,0.264892,0.515558,23.549114,4.416751,0.999601,1736.846193,1.186155,6.583455,9.446549,187.200801,1.605577,0.597369,0.559044,4.440086,1.941342,0.148635,56.459805,11.848737,302.68343,49.0407,0.030796,0.0431,0.223785,0.039142,0.065227,0.015745,0.828604,0.01705,0.10619,0.071481,0.054902,0.974266,...,0.025626,4.338678,2.546196,1.960438,0.002974,2.828219,0.085062,1.032461,10.123472,0.206489,0.001497,0.029047,0.000516,0.083699,237.511065,175.426202,0.181747,0.318716,0.022631,0.07859,254447.573253,3958.010969,1.112342,0.694392,0.598293,0.511537,0.453723,0.451007,0.417454,0.426071,0.400175,0.345475,0.349559,0.332055,0.321275,0.292891,0.270026,0.284946,0.26171,0.269164
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.527263,-1.442526,-1.189856,-1.064453,-1.077035,-0.829624,-0.933852,-0.969973,-1.097073,-0.920129,-1.006213,-1.000462,-0.798443,-0.685798,-0.669546,-0.809746
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133692.806088,-2328.507401,-0.796369,-0.667107,-0.477073,-0.343775,-0.236112,-0.286794,-0.177182,-0.292727,-0.303204,-0.332177,-0.196668,-0.219358,-0.16023,-0.217182,-0.194962,-0.175535,-0.1318,-0.159046
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127823.125712,-153.274226,-0.002647,-0.011945,-0.071964,-0.022449,0.012034,0.040864,0.040584,-0.032775,-0.006456,0.045001,0.064457,-0.007091,0.051132,-0.054635,-0.020075,-0.031549,0.015217,-0.036377
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288262.097477,3032.870912,0.983044,0.44114,0.323271,0.289042,0.257716,0.227569,0.309666,0.274537,0.26712,0.16633,0.248007,0.188549,0.251005,0.108285,0.190913,0.178225,0.197493,0.175122
max,1105897.0,99.4,1390.0,386.0,2520.0,7.66,997.0,12.7,1.92,7.87,96.7,99.6,13.2,40700.0,11.5,79.9,43.9,1790.0,22.8,8.61,17.0,73.2,52.6,3.4,311.0,93.0,1710.0,449.0,0.485,0.85,7.95,0.774,1.83,0.303,14.0,0.222,1.93,1.86,0.989,42.4,...,0.612,74.2,82.5,58.1,0.171,100.0,1.64,28.3,348.0,4.4,0.086,0.712,0.022,1.59,7430.0,8750.0,5.32,9.7,0.63,2.1,646711.616853,7635.496843,2.647689,1.869347,1.616849,2.041568,1.700615,1.764944,1.547783,1.34823,1.122909,1.235692,1.230902,1.301706,1.225573,1.03207,0.770407,0.843408,1.053437,0.875425


In [None]:
# Split the full table into train and test partitions (test_size is set elsewhere in the notebook).
# No random_state is passed here, so the split depends on NumPy's global RNG state.
data = dataset
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split (sanity check against the full dataset).
train_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,...,1277,1292,1258,1293,1405,1010,1170,1109,1177,1178,1312,2009,1334,1257,1100,1107,1223,1224,1210,1211,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
count,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,...,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0
mean,480840.6,1.788952,27.818199,14.610349,125.796356,0.381077,55.589088,0.473794,0.077637,0.150204,8.280277,0.466951,0.235234,130.758585,0.355857,1.938891,2.572338,42.944601,0.312186,0.116781,0.064787,0.635888,0.215604,0.024394,14.746471,2.996175,78.773655,12.361551,0.005233,0.005803,0.015959,0.004814,0.00439,0.002434,0.120633,0.002984,0.015861,0.012295,0.008884,0.039484,...,0.002451,0.420075,0.306654,0.147405,0.000279,0.165155,0.011789,0.078353,0.988936,0.027461,9.9e-05,0.002724,6.6e-05,0.008259,11.711017,8.692865,0.016277,0.027779,0.001996,0.007305,21577.236256,166.880759,0.116692,-0.094338,-0.047668,-0.012898,0.056047,0.022257,0.061427,0.043563,-0.005482,-0.054468,0.031648,0.018106,0.054145,-0.051153,-0.005925,5.8e-05,0.046189,-0.002405
std,254587.4,7.742546,100.368856,45.772164,383.926355,1.048002,153.150946,1.47435,0.264757,0.516615,23.518922,4.413584,0.98138,1700.110237,1.19887,6.577799,9.472052,187.192828,1.609271,0.593455,0.563034,4.639686,1.932338,0.155189,56.453297,11.869546,302.28416,49.438662,0.030464,0.042131,0.208596,0.041514,0.071241,0.015491,0.85113,0.017098,0.106631,0.070982,0.055063,1.007735,...,0.027812,4.227437,2.382598,1.910046,0.003049,3.143208,0.085169,1.058159,9.861488,0.208564,0.001648,0.028177,0.000529,0.084488,230.035809,177.214654,0.177415,0.311426,0.022198,0.077407,254633.088377,3965.821318,1.112802,0.695036,0.59731,0.509655,0.454499,0.449308,0.418755,0.426148,0.401053,0.343192,0.350124,0.332365,0.321075,0.293762,0.26905,0.286023,0.261,0.270059
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.527263,-1.442526,-1.189856,-1.064453,-1.077035,-0.829624,-0.933852,-0.969973,-1.097073,-0.920129,-1.006163,-1.000462,-0.798443,-0.685798,-0.669546,-0.809746
25%,325613.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133727.696151,-2357.200364,-0.795388,-0.666598,-0.477031,-0.343848,-0.235564,-0.286745,-0.177279,-0.293252,-0.303703,-0.332336,-0.199342,-0.221195,-0.159572,-0.217351,-0.192035,-0.174844,-0.132597,-0.158719
50%,331442.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127790.01835,-144.692521,-0.006754,-0.011181,-0.071705,-0.02687,0.013067,0.040808,0.038015,-0.032916,-0.006045,0.043369,0.064454,-0.00721,0.050987,-0.055414,-0.019909,-0.031226,0.01551,-0.034718
75%,747464.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288258.022725,3048.413899,0.982095,0.441213,0.323066,0.282833,0.26406,0.227521,0.305305,0.274315,0.267013,0.165323,0.246141,0.187702,0.251093,0.10823,0.19122,0.178754,0.197991,0.176335
max,1105897.0,99.4,1390.0,369.0,2520.0,7.66,997.0,12.7,1.85,7.87,96.7,99.6,13.2,40300.0,11.5,79.9,43.4,1650.0,22.8,8.61,17.0,73.2,52.6,3.4,311.0,91.0,1710.0,410.0,0.474,0.806,7.95,0.774,1.83,0.303,14.0,0.222,1.77,1.86,0.989,42.4,...,0.612,73.2,48.4,57.6,0.171,100.0,1.64,28.3,348.0,4.4,0.086,0.712,0.022,1.59,7260.0,8750.0,5.02,9.22,0.63,2.0,646711.616853,7635.496843,2.647689,1.869347,1.616849,2.041568,1.700615,1.764944,1.547783,1.34823,1.122909,1.167115,1.222555,1.289735,1.225573,1.03207,0.770407,0.843408,1.053437,0.875414


In [None]:
# Summary statistics of the test split (sanity check against the full dataset).
test_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,...,1277,1292,1258,1293,1405,1010,1170,1109,1177,1178,1312,2009,1334,1257,1100,1107,1223,1224,1210,1211,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
count,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,...,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,480396.3,1.967676,27.02251,14.536589,121.834033,0.396509,56.376955,0.485013,0.076968,0.14908,8.627738,0.495906,0.27739,143.940099,0.328348,1.955654,2.51596,42.943647,0.292919,0.11533,0.065867,0.4851,0.192602,0.021364,14.665013,2.933193,78.552079,11.678939,0.005403,0.006216,0.01976,0.00324,0.001788,0.002662,0.111663,0.00301,0.015391,0.01211,0.007625,0.032726,...,0.001202,0.492513,0.37134,0.162857,0.000279,0.074552,0.011859,0.055448,1.319344,0.027448,6.8e-05,0.00304,6.3e-05,0.007536,14.091606,9.781381,0.018074,0.028465,0.002101,0.007475,21132.004815,214.48988,0.131754,-0.09221,-0.051795,0.012438,0.033177,0.026538,0.07011,0.044096,-0.004377,-0.048841,0.037719,0.013325,0.052985,-0.045513,-0.01038,-0.006703,0.042393,-0.004267
std,253705.9,8.210923,98.141815,45.690826,375.893218,1.084308,154.58216,1.519119,0.265481,0.511403,23.671963,4.430164,1.068915,1876.920647,1.133821,6.607275,9.345496,187.268411,1.590927,0.612892,0.542897,3.528785,1.977214,0.118853,56.496558,11.767255,304.333119,47.420738,0.032094,0.04678,0.276343,0.027657,0.031095,0.016725,0.731701,0.016862,0.104423,0.073456,0.054255,0.827105,...,0.013715,4.758129,3.11585,2.150577,0.002654,0.685144,0.084649,0.9225,11.107897,0.198013,0.000584,0.032298,0.000462,0.080481,265.353791,168.112176,0.198158,0.346407,0.024288,0.083169,253752.326858,3927.135566,1.110629,0.691941,0.602314,0.518598,0.450228,0.457808,0.412216,0.425847,0.39672,0.35449,0.347314,0.330849,0.322132,0.289392,0.273918,0.280586,0.264559,0.2656
min,319878.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139554.920584,-7410.386832,-2.489668,-1.577342,-1.433074,-1.4424,-1.189727,-0.922758,-1.021707,-0.808912,-0.933827,-0.946624,-0.961907,-0.89094,-1.006213,-0.894695,-0.797834,-0.601732,-0.616257,-0.809725
25%,325784.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133554.264517,-2265.726243,-0.797956,-0.669478,-0.477206,-0.343567,-0.248446,-0.28681,-0.176744,-0.28696,-0.293944,-0.331787,-0.178625,-0.218948,-0.163446,-0.216294,-0.19915,-0.181633,-0.126746,-0.159625
50%,331177.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-128060.480019,-165.539435,0.024403,-0.016796,-0.072881,-0.007731,-0.004707,0.041421,0.041124,-0.031619,-0.011872,0.047214,0.06447,-0.006434,0.051343,-0.050258,-0.021204,-0.041344,0.011042,-0.038586
75%,747474.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288268.209605,2978.84736,0.989184,0.440794,0.338293,0.321224,0.23786,0.227693,0.328289,0.275653,0.267565,0.169586,0.254462,0.189796,0.250228,0.110422,0.190131,0.174171,0.193967,0.170483
max,1105891.0,99.1,1070.0,386.0,1800.0,7.44,997.0,12.3,1.92,3.87,95.2,99.3,13.2,40700.0,11.4,34.2,43.9,1790.0,22.2,7.47,16.7,70.5,52.3,1.45,284.0,93.0,1650.0,449.0,0.485,0.85,7.72,0.663,1.46,0.285,13.7,0.192,1.93,1.69,0.87,41.8,...,0.576,74.2,82.5,58.1,0.043,18.8,1.3,27.6,172.0,3.9,0.011,0.687,0.008,1.18,7430.0,5260.0,5.32,9.7,0.62,2.1,646705.504729,7634.515887,2.64748,1.869342,1.560331,1.99033,1.592393,1.764828,1.361376,1.224556,1.068824,1.235692,1.230902,1.301706,1.182738,1.032062,0.74623,0.702585,1.053217,0.875425


In [None]:
# Keep the fdc_id of every test row before the column is dropped from the split.
indices = test_data['fdc_id']
indices

9313      336041
2440      324243
2257      323937
1041      322215
12468    1105140
          ...   
11687     790512
3464      325917
11720     790553
2688      324702
5900      329898
Name: fdc_id, Length: 2621, dtype: int64

In [None]:
# Drop the fdc_id column from both splits (its test values were saved in `indices`),
# then record the remaining column order for rebuilding DataFrames later.
del train_data['fdc_id']
del test_data['fdc_id']
columns = train_data.columns

In [None]:
# Convert both splits to NumPy arrays for masking, scaling and torch.
train_data = train_data.to_numpy()
# Rebind instead of fillna(..., inplace=True): test_data is a slice produced by
# train_test_split, and filling it in place triggers pandas' SettingWithCopyWarning.
test_data = test_data.fillna(0).to_numpy()
data = dataset.values
rows, cols = data.shape
cols -= 1  # `dataset` still contains the fdc_id column that was dropped from the splits

# Uncorrupted copy of the test set, taken before any values are removed.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Corrupt copies of both splits: missing_method sets entries to NaN where a uniform draw is <= 0.2,
# leaving the last `num_embeddings` columns untouched. `mask` marks the entries removed from the test set.
missed_data, missed_data_train, mask = missing_method(test_data, train_data, num_embeddings)
# Per-column counts below 2621 show how many test values were removed.
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
count,2109.0,2078.0,2092.0,2100.0,2123.0,2085.0,2137.0,2119.0,2110.0,2086.0,2088.0,2079.0,2117.0,2126.0,2125.0,2110.0,2120.0,2138.0,2114.0,2104.0,2088.0,2118.0,2109.0,2123.0,2083.0,2091.0,2100.0,2140.0,2087.0,2087.0,2097.0,2130.0,2107.0,2058.0,2122.0,2120.0,2065.0,2095.0,2091.0,2093.0,...,2076.0,2107.0,2101.0,2092.0,2125.0,2102.0,2086.0,2085.0,2105.0,2094.0,2108.0,2133.0,2068.0,2096.0,2046.0,2067.0,2111.0,2099.0,2094.0,2092.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,1.967435,26.70308,13.71305,120.066667,0.39764,58.003357,0.493776,0.080866,0.152113,8.641716,0.500747,0.254367,161.042985,0.323612,1.978193,2.486318,43.015943,0.297366,0.109253,0.05802,0.43734,0.183017,0.021095,14.1691,3.165098,73.178384,11.415,0.005707,0.00572,0.021977,0.003471,0.002084,0.002667,0.103992,0.003213,0.014764,0.011882,0.008267,0.037143,0.006607,...,0.001267,0.508391,0.367177,0.180243,0.000321,0.071613,0.011859,0.054974,1.31639,0.025907,8e-05,0.003531,5.6e-05,0.009091,15.702639,7.275278,0.017606,0.033414,0.001969,0.007199,21132.004815,214.48988,0.131754,-0.09221,-0.051795,0.012438,0.033177,0.026538,0.07011,0.044096,-0.004377,-0.048841,0.037719,0.013325,0.052985,-0.045513,-0.01038,-0.006703,0.042393,-0.004267
std,8.333341,97.186923,43.918,372.767341,1.09067,155.870192,1.52722,0.272309,0.514505,23.644255,4.442314,0.983221,2083.637258,1.131511,6.652319,9.294418,185.051257,1.64744,0.594519,0.426314,3.169038,1.860204,0.115439,55.54696,12.216444,294.649725,45.966284,0.032902,0.045211,0.307572,0.029392,0.034465,0.016328,0.67384,0.017801,0.097919,0.069045,0.057652,0.923418,0.05845,...,0.015056,4.994738,3.059325,2.373638,0.002881,0.651726,0.083499,0.908443,11.203809,0.186068,0.00064,0.035284,0.000442,0.089476,285.237165,158.288625,0.207159,0.383715,0.021935,0.081271,253752.326858,3927.135566,1.110629,0.691941,0.602314,0.518598,0.450228,0.457808,0.412216,0.425847,0.39672,0.35449,0.347314,0.330849,0.322132,0.289392,0.273918,0.280586,0.264559,0.2656
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139554.920584,-7410.386832,-2.489668,-1.577342,-1.433074,-1.4424,-1.189727,-0.922758,-1.021707,-0.808912,-0.933827,-0.946624,-0.961907,-0.89094,-1.006213,-0.894695,-0.797834,-0.601732,-0.616257,-0.809725
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133554.264517,-2265.726243,-0.797956,-0.669478,-0.477206,-0.343567,-0.248446,-0.28681,-0.176744,-0.28696,-0.293944,-0.331787,-0.178625,-0.218948,-0.163446,-0.216294,-0.19915,-0.181633,-0.126746,-0.159625
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-128060.480019,-165.539435,0.024403,-0.016796,-0.072881,-0.007731,-0.004707,0.041421,0.041124,-0.031619,-0.011872,0.047214,0.06447,-0.006434,0.051343,-0.050258,-0.021204,-0.041344,0.011042,-0.038586
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288268.209605,2978.84736,0.989184,0.440794,0.338293,0.321224,0.23786,0.227693,0.328289,0.275653,0.267565,0.169586,0.254462,0.189796,0.250228,0.110422,0.190131,0.174171,0.193967,0.170483
max,99.1,1070.0,350.0,1800.0,7.44,997.0,12.3,1.92,3.87,94.7,99.3,13.2,40700.0,11.4,34.2,43.9,1790.0,22.2,7.38,7.74,67.8,52.3,1.34,284.0,93.0,1650.0,294.0,0.485,0.85,7.72,0.663,1.46,0.285,7.5,0.192,1.39,1.69,0.87,41.8,1.09,...,0.576,74.2,82.5,58.1,0.043,18.8,1.3,27.6,172.0,3.9,0.011,0.687,0.008,1.18,7430.0,5260.0,5.32,9.7,0.6,2.1,646705.504729,7634.515887,2.64748,1.869342,1.560331,1.99033,1.592393,1.764828,1.361376,1.224556,1.068824,1.235692,1.230902,1.301706,1.182738,1.032062,0.74623,0.702585,1.053217,0.875425


In [None]:
# traditional methods test sets - impute across all foods
X_train = pd.DataFrame(missed_data_train, columns=columns)
# Give each test frame its own copy of the corrupted data. A DataFrame built
# from a 2-D ndarray may share that array's memory, so an imputer that fills
# in place would otherwise leak values into the other frame and into
# missed_data, which is reused for the by-category imputation.
X_test_mean_all = pd.DataFrame(missed_data, columns=columns, copy=True)
X_test_median_all = pd.DataFrame(missed_data, columns=columns, copy=True)

# impute_traditional_all_foods is defined elsewhere in the notebook; it is
# given the corrupted training frame, the frame to fill and the strategy name.
X_test_mean_all = impute_traditional_all_foods(X_train, X_test_mean_all, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, X_test_median_all, 'median')

In [None]:
# Traditional methods test sets - impute by food category
# Corrupted test data indexed by fdc_id; the median frame is copied *before*
# the mean frame is imputed so that it still contains the NaNs.
X_test_mean = pd.DataFrame(missed_data, columns = columns, index=indices)
X_test_median = X_test_mean.copy()
# Restrict the df_mean / df_median tables (built elsewhere in the notebook;
# presumably per-category statistics -- confirm) to the foods in the test split.
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Learn the min-max scaling from the uncorrupted training matrix only, then
# push every other matrix through that same fitted scaler.
scaler = MinMaxScaler()
train_data = scaler.fit_transform(train_data)

# Test data and its ground-truth copy.
test_data, y_test = scaler.transform(test_data), scaler.transform(y_test)

# Imputed by food category.
X_test_mean, X_test_median = (
    scaler.transform(X_test_mean),
    scaler.transform(X_test_median),
)

# Imputed across all foods.
X_test_mean_all, X_test_median_all = (
    scaler.transform(X_test_mean_all),
    scaler.transform(X_test_median_all),
)

In [None]:
# datasets without embeddings for autoencoder
# The embedding features are the trailing num_embeddings columns of every
# scaled matrix, so slicing them off leaves the remaining feature columns.
# np.array(...) makes each result an independent copy.
train_data_noEmbeddings = np.array(train_data[:, :-num_embeddings])

# imputed by food category
test_mean_noEmbeddings = np.array(X_test_mean[:, :-num_embeddings])
test_median_noEmbeddings = np.array(X_test_median[:, :-num_embeddings])

# imputed across all foods
test_mean_noEmbeddings_all = np.array(X_test_mean_all[:, :-num_embeddings])
test_median_noEmbeddings_all = np.array(X_test_median_all[:, :-num_embeddings])

In [None]:
# Input for the with-embeddings autoencoder at test time: the scaled,
# all-foods mean-imputed test matrix as a float tensor.
missed_data = torch.from_numpy(X_test_mean_all).float()

# Scaled training matrix, served in shuffled mini-batches.
train_data = torch.from_numpy(train_data).float()
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Same as the previous cell but for the embedding-free matrices: the
# mean-imputed test input and the shuffled training loader.
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
train_loader_noEm = torch.utils.data.DataLoader(
    train_data_noEmbeddings,
    batch_size=batch_size,
    shuffle=True,
)

## Autoencoder - without embeddings

In [None]:
# Autoencoder whose input width excludes the embedding columns.
model = Autoencoder(dim=cols-num_embeddings).to(device)

In [None]:
# Mean-squared-error reconstruction loss, optimised with plain SGD.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train on the embedding-free loader (training() is defined elsewhere in the notebook).
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.00723645
epoch : 2/200, recon loss = 0.00623200
epoch : 3/200, recon loss = 0.00608631
epoch : 4/200, recon loss = 0.00575263
epoch : 5/200, recon loss = 0.00497927
epoch : 6/200, recon loss = 0.00397156
epoch : 7/200, recon loss = 0.00338542
epoch : 8/200, recon loss = 0.00317286
epoch : 9/200, recon loss = 0.00310097
epoch : 10/200, recon loss = 0.00307003
epoch : 11/200, recon loss = 0.00305066
epoch : 12/200, recon loss = 0.00301768
epoch : 13/200, recon loss = 0.00299938
epoch : 14/200, recon loss = 0.00296950
epoch : 15/200, recon loss = 0.00292809
epoch : 16/200, recon loss = 0.00288915
epoch : 17/200, recon loss = 0.00283167
epoch : 18/200, recon loss = 0.00278632
epoch : 19/200, recon loss = 0.00272362
epoch : 20/200, recon loss = 0.00268940
epoch : 21/200, recon loss = 0.00263934
epoch : 22/200, recon loss = 0.00258300
epoch : 23/200, recon loss = 0.00256760
epoch : 24/200, recon loss = 0.00254809
epoch : 25/200, recon loss = 0.00252104
epoch : 2

In [None]:
# Switch to evaluation mode (the model contains a Dropout layer).
model.eval()

# Inference only: run the forward pass without recording an autograd graph.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device))
filled_data = filled_data.cpu().numpy()

# Score the reconstruction against the scaled test data over the
# non-embedding columns (rmse_error is defined elsewhere in the notebook).
rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

3.1464000386280997


## Autoencoder - with embeddings

In [None]:
# Autoencoder over all columns, embeddings included.
model = Autoencoder(dim=cols).to(device)

In [None]:
# Mean-squared-error reconstruction loss, optimised with plain SGD.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
# Train on the full (features + embeddings) loader.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.01493986
epoch : 2/200, recon loss = 0.01184460
epoch : 3/200, recon loss = 0.01175992
epoch : 4/200, recon loss = 0.01161368
epoch : 5/200, recon loss = 0.01129655
epoch : 6/200, recon loss = 0.01056569
epoch : 7/200, recon loss = 0.00956796
epoch : 8/200, recon loss = 0.00903610
epoch : 9/200, recon loss = 0.00883364
epoch : 10/200, recon loss = 0.00874422
epoch : 11/200, recon loss = 0.00869116
epoch : 12/200, recon loss = 0.00863454
epoch : 13/200, recon loss = 0.00858350
epoch : 14/200, recon loss = 0.00852487
epoch : 15/200, recon loss = 0.00844616
epoch : 16/200, recon loss = 0.00835073
epoch : 17/200, recon loss = 0.00825278
epoch : 18/200, recon loss = 0.00815671
epoch : 19/200, recon loss = 0.00804786
epoch : 20/200, recon loss = 0.00797374
epoch : 21/200, recon loss = 0.00788826
epoch : 22/200, recon loss = 0.00784343
epoch : 23/200, recon loss = 0.00779354
epoch : 24/200, recon loss = 0.00773754
epoch : 25/200, recon loss = 0.00769605
epoch : 2

In [None]:
# Switch to evaluation mode (the model contains a Dropout layer).
model.eval()

# Inference only: run the forward pass without recording an autograd graph.
with torch.no_grad():
    filled_data = model(missed_data.to(device))
filled_data = filled_data.cpu().numpy()

# Score the reconstruction against the scaled test data over all columns
# (rmse_error is defined elsewhere in the notebook).
rmse_sum = rmse_error(test_data, filled_data, cols, mask)

print(rmse_sum)

3.1930650168502


## Mean and Median Imputation

### By category

In [None]:
# Baseline: mean imputation by food category, scored over the non-embedding columns.
imputed = X_test_mean
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

4.240093997588478


In [None]:
# Baseline: median imputation by food category, scored over the non-embedding columns.
imputed = X_test_median
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

4.456391764474138


### Across all foods

In [None]:
# Baseline: mean imputation across all foods, scored over the non-embedding columns.
imputed = X_test_mean_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

5.005264371707832


In [None]:
# Baseline: median imputation across all foods, scored over the non-embedding columns.
imputed = X_test_median_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

5.1325066538984725


# 120 Columns

In [None]:
# Configuration for the 120-column experiment.
data_path = 'datasets/average datasets/byColumnsWithEmbeddings/120_columns_embeddings.csv'
num_embeddings = 30  # trailing embedding columns; missing_method never corrupts them
num_columns = 120  # presumably the number of non-embedding feature columns -- confirm

In [None]:
# Load the dataset, using the first CSV column as the row index.
dataset = pd.read_csv(data_path, index_col=0)
# Summary statistics over the rows that contain no missing values.
dataset[dataset.notna().all(axis=1)].describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,...,1180,1194,1128,1195,1005,1075,1196,1197,1198,1084,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,1.824697,27.659061,14.595597,125.003892,0.384163,55.746662,0.476038,0.077504,0.149979,8.349769,0.472742,0.243665,133.394887,0.350355,1.942243,2.561062,42.944411,0.308332,0.116491,0.065003,0.60573,0.211003,0.023788,14.730179,2.983579,78.72934,12.225029,0.005267,0.005886,0.016719,0.004499,0.00387,0.00248,0.118839,0.002989,0.015767,0.012258,0.008632,0.038133,...,0.928287,0.039206,0.002494,0.726326,0.164999,0.002459,0.04541,0.042976,0.046929,0.021892,21488.189968,176.402583,0.119704,-0.093912,-0.048494,-0.007831,0.051473,0.023113,0.063163,0.04367,-0.005261,-0.053342,0.032862,0.017148,0.053914,-0.050016,-0.006804,-0.001309,0.045397,-0.002817,-0.00487,0.015158,-0.002832,-0.008018,0.007639,-0.005292,0.013278,0.02641,-0.008193,-0.003744
std,254401.8,7.838466,99.924214,45.754175,382.322293,1.05534,153.43266,1.48336,0.264892,0.515558,23.549114,4.416751,0.999601,1736.846193,1.186155,6.583455,9.446549,187.200801,1.605577,0.597369,0.559044,4.440086,1.941342,0.148635,56.459805,11.848737,302.68343,49.0407,0.030796,0.0431,0.223785,0.039142,0.065227,0.015745,0.828604,0.01705,0.10619,0.071481,0.054902,0.974266,...,15.719211,0.824407,0.055645,14.277867,2.905073,0.034625,1.333663,0.624933,1.671006,0.319893,254447.573253,3958.010969,1.112342,0.694392,0.598293,0.511537,0.453723,0.451007,0.417454,0.426072,0.400175,0.345475,0.349559,0.332055,0.321282,0.292925,0.270047,0.284938,0.261697,0.269206,0.240664,0.243863,0.245151,0.247991,0.235566,0.205593,0.224944,0.212094,0.209662,0.2061
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.527263,-1.442526,-1.189856,-1.064453,-1.077035,-0.829626,-0.933856,-0.969972,-1.097135,-0.920029,-1.006142,-1.000203,-0.798836,-0.687075,-0.668804,-0.801703,-0.726147,-0.698734,-0.73337,-0.848113,-0.673804,-0.609317,-0.622568,-0.618201,-0.605566,-0.802281
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133692.806088,-2328.507401,-0.796369,-0.667107,-0.477073,-0.343775,-0.236112,-0.286794,-0.177179,-0.292722,-0.303209,-0.332181,-0.196565,-0.219471,-0.160211,-0.217184,-0.194592,-0.17529,-0.131532,-0.159,-0.129686,-0.142597,-0.128993,-0.177006,-0.164276,-0.14966,-0.136069,-0.112281,-0.131025,-0.103372
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127823.125712,-153.274226,-0.002647,-0.011945,-0.071964,-0.022449,0.012034,0.040864,0.040584,-0.032776,-0.006464,0.045016,0.064438,-0.007061,0.051129,-0.054894,-0.020014,-0.030517,0.015432,-0.034379,-0.022279,0.000418,-0.023101,0.012883,-0.022581,-0.015932,0.015818,0.012556,-0.039363,0.00318
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288262.097477,3032.870912,0.983044,0.44114,0.323271,0.289041,0.257716,0.227569,0.309666,0.274538,0.267124,0.166322,0.24799,0.188549,0.251011,0.109037,0.190487,0.17819,0.198249,0.177465,0.098791,0.180564,0.128409,0.157268,0.136539,0.119725,0.154657,0.164012,0.105063,0.111023
max,1105897.0,99.4,1390.0,386.0,2520.0,7.66,997.0,12.7,1.92,7.87,96.7,99.6,13.2,40700.0,11.5,79.9,43.9,1790.0,22.8,8.61,17.0,73.2,52.6,3.4,311.0,93.0,1710.0,449.0,0.485,0.85,7.95,0.774,1.83,0.303,14.0,0.222,1.93,1.86,0.989,42.4,...,404.0,42.7,2.51,376.0,99.6,1.1,65.5,13.1,126.0,9.2,646711.616853,7635.496843,2.647689,1.869347,1.616849,2.041568,1.700616,1.764945,1.547788,1.348236,1.122911,1.235694,1.23092,1.30151,1.225628,1.031464,0.770631,0.843875,1.049396,0.867503,1.478596,0.921429,0.756656,0.908209,0.921199,0.965138,0.774711,0.770132,0.968808,0.881655


In [None]:
data = dataset
# NOTE(review): no random_state is passed, so the split depends on the global
# NumPy RNG state -- confirm a seed is set earlier if reproducibility matters.
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split.
train_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,...,1180,1194,1128,1195,1005,1075,1196,1197,1198,1084,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
count,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,...,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0
mean,480730.7,1.838933,27.191816,14.418333,124.497043,0.37742,55.049886,0.472891,0.076689,0.147316,8.377778,0.465995,0.247112,126.185712,0.347301,1.912456,2.568743,42.08534,0.315506,0.119174,0.067715,0.623412,0.223398,0.024106,14.656143,2.980733,78.750095,12.213277,0.00536,0.005941,0.017104,0.004821,0.004366,0.002521,0.119267,0.003027,0.015955,0.011941,0.008555,0.043831,...,0.865958,0.035702,0.002479,0.663859,0.16813,0.002428,0.054016,0.043018,0.052289,0.022568,21467.008441,182.062106,0.119423,-0.095347,-0.048735,-0.007134,0.054944,0.023179,0.063675,0.043259,-0.004838,-0.0513,0.030691,0.017925,0.05288,-0.050042,-0.004916,-0.001487,0.045954,-0.004037,-0.004376,0.015646,-0.001045,-0.00797,0.006989,-0.004153,0.012878,0.027332,-0.007871,-0.002815
std,254394.7,7.873509,98.818185,45.429656,382.049756,1.047533,152.629159,1.481276,0.263418,0.507523,23.608917,4.309402,1.01059,1662.616839,1.178584,6.533157,9.457161,184.404872,1.624915,0.606916,0.59361,4.49624,2.050528,0.151031,56.189047,11.840536,302.770335,49.054157,0.031136,0.043461,0.226249,0.041357,0.070544,0.015894,0.83607,0.017135,0.10689,0.071769,0.054803,1.087551,...,15.102529,0.693033,0.056443,13.714635,2.924708,0.035324,1.488693,0.619556,1.854747,0.329757,254440.586435,3952.715216,1.109124,0.693789,0.599498,0.508734,0.453486,0.45005,0.417313,0.425748,0.400349,0.344332,0.351216,0.333003,0.321272,0.293894,0.270204,0.283863,0.263103,0.268481,0.241821,0.243146,0.245953,0.249616,0.235283,0.205396,0.225057,0.212131,0.210375,0.207533
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139555.939272,-7413.329674,-3.349638,-1.585647,-1.482007,-1.442526,-1.189856,-1.064453,-1.077035,-0.829626,-0.933843,-0.969972,-1.097135,-0.920029,-1.006142,-1.000139,-0.798836,-0.661197,-0.668804,-0.801681,-0.726147,-0.698734,-0.73337,-0.848113,-0.673804,-0.609317,-0.622568,-0.618201,-0.605566,-0.802281
25%,325636.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133704.266328,-2325.696385,-0.788362,-0.664546,-0.477113,-0.343361,-0.235777,-0.28633,-0.177309,-0.293599,-0.303225,-0.331717,-0.19922,-0.219429,-0.162129,-0.217265,-0.194685,-0.175304,-0.133387,-0.16027,-0.129654,-0.139515,-0.1285,-0.177134,-0.166202,-0.149504,-0.137358,-0.112711,-0.130938,-0.103349
50%,331391.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127842.480782,-115.754324,-0.006292,-0.018057,-0.07186,-0.021788,0.012815,0.040862,0.041067,-0.032651,-0.004219,0.045433,0.064433,-0.006994,0.050675,-0.055315,-0.016151,-0.0308,0.015893,-0.0345,-0.022281,0.005188,-0.019857,0.013883,-0.022115,-0.013533,0.015496,0.015006,-0.039344,0.003225
75%,747484.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288278.651157,3037.817807,0.982128,0.441242,0.32482,0.293505,0.26827,0.227682,0.314777,0.274099,0.267319,0.16643,0.242591,0.186729,0.250783,0.109022,0.19073,0.179107,0.199225,0.175066,0.098883,0.180564,0.12901,0.157526,0.135655,0.122485,0.154603,0.166079,0.10526,0.11341
max,1105895.0,99.4,1380.0,369.0,2520.0,7.46,997.0,12.7,1.85,4.17,96.7,99.6,13.2,40700.0,11.5,79.9,43.9,1790.0,22.8,8.61,17.0,73.2,52.6,3.4,286.0,93.0,1650.0,449.0,0.485,0.85,7.72,0.774,1.83,0.303,14.0,0.222,1.93,1.86,0.989,42.4,...,404.0,42.7,2.51,376.0,99.6,1.1,65.5,12.2,126.0,9.2,646709.579482,7635.496843,2.647689,1.869347,1.616849,2.041568,1.700616,1.764945,1.547652,1.228567,1.122911,1.166864,1.23092,1.30151,1.225628,1.031456,0.770631,0.843875,1.048969,0.867503,1.478596,0.881955,0.756656,0.908209,0.921199,0.965138,0.774711,0.770132,0.968808,0.881655


In [None]:
# Summary statistics of the test split.
test_data.describe()

Unnamed: 0,fdc_id,1004,1087,1090,1092,1095,1091,1089,1098,1101,1051,1007,1002,1093,1079,1003,1009,1102,1265,1266,1264,1315,1316,1314,1094,1097,1137,1146,1300,1299,1404,1267,1273,1323,1167,1271,1304,1166,1175,1263,...,1180,1194,1128,1195,1005,1075,1196,1197,1198,1084,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
count,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,...,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,480836.0,1.767753,29.528043,15.304655,127.031286,0.411137,58.533766,0.488626,0.080763,0.160633,8.237734,0.499729,0.229878,162.231591,0.362572,2.061393,2.53034,46.380694,0.279637,0.105758,0.054155,0.535003,0.161425,0.022514,15.026326,2.994964,78.646318,12.272034,0.004897,0.005666,0.015179,0.003214,0.001885,0.002314,0.11713,0.002836,0.015014,0.013524,0.008942,0.015341,...,1.177604,0.053224,0.002552,0.976192,0.152476,0.002579,0.010988,0.042808,0.025486,0.019191,21572.916075,153.76449,0.120828,-0.088173,-0.047529,-0.01062,0.037588,0.022852,0.061117,0.045312,-0.006952,-0.06151,0.041546,0.014043,0.058051,-0.049911,-0.014357,-0.000596,0.043168,0.002061,-0.006847,0.013206,-0.00998,-0.008207,0.010237,-0.009849,0.014877,0.022725,-0.009483,-0.00746
std,254478.5,7.697916,104.229568,47.032037,383.476861,1.085789,156.604018,1.491882,0.270735,0.546488,23.312503,4.823097,0.954442,2006.587891,1.216119,6.780884,9.405717,197.989572,1.525738,0.557532,0.391267,4.208065,1.422409,0.138656,57.54002,11.883744,302.393274,48.996164,0.029397,0.04163,0.213682,0.028586,0.036916,0.015139,0.798193,0.01671,0.103358,0.070314,0.055309,0.119789,...,17.976672,1.215442,0.052341,16.338307,2.82568,0.03168,0.164785,0.646115,0.447997,0.276983,254524.057706,3979.800284,1.125335,0.696901,0.593561,0.522688,0.454492,0.454899,0.418094,0.427441,0.399548,0.349958,0.34278,0.328279,0.321352,0.289075,0.269336,0.289248,0.25603,0.272083,0.236014,0.246748,0.241832,0.241427,0.236721,0.206354,0.224524,0.211945,0.206822,0.200258
min,319878.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139554.920584,-7401.55824,-2.962405,-1.577342,-1.527263,-1.442463,-1.189817,-1.006999,-0.950851,-0.808687,-0.933856,-0.969439,-0.961805,-0.889532,-1.006043,-1.000203,-0.798091,-0.687075,-0.613786,-0.801703,-0.678342,-0.648339,-0.733317,-0.847795,-0.570389,-0.609083,-0.548479,-0.617991,-0.605527,-0.723781
25%,325730.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133609.273672,-2363.368248,-0.820904,-0.671543,-0.476913,-0.345315,-0.246092,-0.288083,-0.170787,-0.283309,-0.302658,-0.334026,-0.184878,-0.222849,-0.152735,-0.216593,-0.193624,-0.174951,-0.125512,-0.157261,-0.129686,-0.151529,-0.136937,-0.176882,-0.154057,-0.150147,-0.13336,-0.110359,-0.132679,-0.104974
50%,331483.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127748.761484,-252.350686,0.026713,-0.010729,-0.072314,-0.025157,-0.004068,0.040868,0.028471,-0.03344,-0.013315,0.040263,0.064911,-0.008833,0.054967,-0.051972,-0.027597,-0.03028,0.0132,-0.03403,-0.022273,-0.013266,-0.032071,0.007956,-0.022689,-0.022648,0.01677,0.003926,-0.039613,-0.002187
75%,747425.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288218.293889,3014.232766,0.990554,0.434246,0.323139,0.277699,0.234861,0.226925,0.279577,0.287347,0.262978,0.165174,0.267027,0.190587,0.253358,0.109219,0.189096,0.171073,0.188646,0.185413,0.098563,0.180564,0.124176,0.156376,0.141215,0.110624,0.155096,0.154093,0.100122,0.109099
max,1105897.0,88.1,1390.0,386.0,2480.0,7.66,995.0,12.3,1.92,7.87,95.4,99.6,12.9,40100.0,11.2,51.1,42.7,1460.0,22.6,7.46,7.78,71.7,51.3,2.69,311.0,91.0,1710.0,338.0,0.469,0.806,7.95,0.721,1.76,0.264,13.7,0.196,1.68,1.57,0.866,2.37,...,360.0,42.7,1.64,334.0,80.8,0.73,6.0,13.1,10.8,8.1,646711.616853,7634.515887,2.647585,1.869309,1.433469,2.02881,1.592535,1.758867,1.547788,1.348236,1.068685,1.235694,1.200106,1.246412,1.140985,1.031464,0.746734,0.806215,1.049396,0.867459,1.478362,0.921429,0.756642,0.764919,0.832335,0.949775,0.774605,0.769898,0.961489,0.790081


In [None]:
# Keep the test rows' fdc_id values: they are used further down as the row
# index and lookup key for the by-category imputation.
indices = test_data['fdc_id']
indices

10512     748595
12618    1105328
5891      329882
9742      747334
8399      335084
          ...   
12133    1104708
7710      333356
5455      329214
6016      330136
8229      334695
Name: fdc_id, Length: 2621, dtype: int64

In [None]:
# Remove the fdc_id column from both splits (in place) so only feature
# columns remain.
del train_data['fdc_id']
del test_data['fdc_id']
# Remaining column labels, reused below to rebuild DataFrames from arrays.
columns = train_data.columns

In [None]:
# Convert both splits to plain NumPy arrays for corruption and scaling.
train_data = train_data.to_numpy()
# Rebind instead of filling in place: test_data is a slice returned by
# train_test_split, so fillna(..., inplace=True) on it raises
# SettingWithCopyWarning.
test_data = test_data.fillna(0).to_numpy()

# cols = number of dataset columns minus one (the fdc_id column dropped above).
data = dataset.values
rows, cols = data.shape
cols -= 1

# Uncorrupted copy of the test set, kept as the reference for the RMSE.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Randomly replace entries of the non-embedding columns with NaN in copies of
# both splits; `mask` marks the corrupted cells of the test copy.
missed_data, missed_data_train, mask = missing_method(test_data, train_data, num_embeddings)
# Per-column non-NaN counts make the injected missingness visible.
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149
count,2079.0,2131.0,2122.0,2116.0,2114.0,2112.0,2101.0,2086.0,2104.0,2071.0,2086.0,2115.0,2104.0,2096.0,2098.0,2087.0,2097.0,2102.0,2109.0,2098.0,2086.0,2053.0,2091.0,2081.0,2129.0,2076.0,2088.0,2074.0,2119.0,2121.0,2093.0,2129.0,2108.0,2122.0,2104.0,2108.0,2109.0,2072.0,2064.0,2115.0,...,2095.0,2063.0,2100.0,2082.0,2096.0,2084.0,2116.0,2135.0,2073.0,2099.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,1.822968,29.281089,15.401979,127.583176,0.413018,58.667614,0.482922,0.082035,0.162047,8.10099,0.582287,0.235206,111.011882,0.379628,2.164971,2.341284,49.029375,0.243778,0.105295,0.046996,0.452872,0.150144,0.024538,13.952427,3.262001,75.279865,12.383621,0.005118,0.00544,0.013107,0.003535,0.001947,0.002442,0.121552,0.00249,0.015886,0.014331,0.008495,0.014514,0.007035,...,1.207351,0.056714,0.002152,0.855187,0.133941,0.002893,0.010491,0.040094,0.018186,0.016532,21572.916075,153.76449,0.120828,-0.088173,-0.047529,-0.01062,0.037588,0.022852,0.061117,0.045312,-0.006952,-0.06151,0.041546,0.014043,0.058051,-0.049911,-0.014357,-0.000596,0.043168,0.002061,-0.006847,0.013206,-0.00998,-0.008207,0.010237,-0.009849,0.014877,0.022725,-0.009483,-0.00746
std,7.688367,105.779651,46.881167,386.194247,1.093257,156.451379,1.484678,0.272904,0.555224,23.085199,5.388674,0.954813,1455.138966,1.238272,6.922343,9.048164,203.083355,1.361023,0.54989,0.371828,3.720943,1.443742,0.150145,55.487474,12.358336,295.218949,48.990889,0.030959,0.041524,0.18333,0.030463,0.040624,0.015608,0.825986,0.015069,0.107239,0.072712,0.054448,0.119956,0.052296,...,18.558856,1.34506,0.045922,15.254636,2.57118,0.034639,0.128727,0.61766,0.376355,0.2354,254524.057706,3979.800284,1.125335,0.696901,0.593561,0.522688,0.454492,0.454899,0.418094,0.427441,0.399548,0.349958,0.34278,0.328279,0.321352,0.289075,0.269336,0.289248,0.25603,0.272083,0.236014,0.246748,0.241832,0.241427,0.236721,0.206354,0.224524,0.211945,0.206822,0.200258
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-139554.920584,-7401.55824,-2.962405,-1.577342,-1.527263,-1.442463,-1.189817,-1.006999,-0.950851,-0.808687,-0.933856,-0.969439,-0.961805,-0.889532,-1.006043,-1.000203,-0.798091,-0.687075,-0.613786,-0.801703,-0.678342,-0.648339,-0.733317,-0.847795,-0.570389,-0.609083,-0.548479,-0.617991,-0.605527,-0.723781
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-133609.273672,-2363.368248,-0.820904,-0.671543,-0.476913,-0.345315,-0.246092,-0.288083,-0.170787,-0.283309,-0.302658,-0.334026,-0.184878,-0.222849,-0.152735,-0.216593,-0.193624,-0.174951,-0.125512,-0.157261,-0.129686,-0.151529,-0.136937,-0.176882,-0.154057,-0.150147,-0.13336,-0.110359,-0.132679,-0.104974
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-127748.761484,-252.350686,0.026713,-0.010729,-0.072314,-0.025157,-0.004068,0.040868,0.028471,-0.03344,-0.013315,0.040263,0.064911,-0.008833,0.054967,-0.051972,-0.027597,-0.03028,0.0132,-0.03403,-0.022273,-0.013266,-0.032071,0.007956,-0.022689,-0.022648,0.01677,0.003926,-0.039613,-0.002187
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,288218.293889,3014.232766,0.990554,0.434246,0.323139,0.277699,0.234861,0.226925,0.279577,0.287347,0.262978,0.165174,0.267027,0.190587,0.253358,0.109219,0.189096,0.171073,0.188646,0.185413,0.098563,0.180564,0.124176,0.156376,0.141215,0.110624,0.155096,0.154093,0.100122,0.109099
max,88.1,1390.0,386.0,2480.0,7.66,995.0,12.3,1.92,7.87,95.4,99.6,12.9,38300.0,11.2,51.1,42.7,1460.0,15.8,6.53,7.78,68.6,51.3,2.69,283.0,91.0,1600.0,338.0,0.469,0.806,7.95,0.721,1.76,0.264,13.7,0.193,1.68,1.57,0.866,2.37,0.939,...,360.0,42.7,1.58,334.0,80.8,0.73,2.3,13.1,10.8,8.1,646711.616853,7634.515887,2.647585,1.869309,1.433469,2.02881,1.592535,1.758867,1.547788,1.348236,1.068685,1.235694,1.200106,1.246412,1.140985,1.031464,0.746734,0.806215,1.049396,0.867459,1.478362,0.921429,0.756642,0.764919,0.832335,0.949775,0.774605,0.769898,0.961489,0.790081


In [None]:
# traditional methods test sets - impute across all foods
X_train = pd.DataFrame(missed_data_train, columns=columns)
# Give each test frame its own copy of the corrupted data. A DataFrame built
# from a 2-D ndarray may share that array's memory, so an imputer that fills
# in place would otherwise leak values into the other frame and into
# missed_data, which is reused for the by-category imputation.
X_test_mean_all = pd.DataFrame(missed_data, columns=columns, copy=True)
X_test_median_all = pd.DataFrame(missed_data, columns=columns, copy=True)

# impute_traditional_all_foods is defined elsewhere in the notebook; it is
# given the corrupted training frame, the frame to fill and the strategy name.
X_test_mean_all = impute_traditional_all_foods(X_train, X_test_mean_all, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, X_test_median_all, 'median')

In [None]:
# Traditional methods test sets - impute by food category
# Corrupted test data indexed by fdc_id; the median frame is copied *before*
# the mean frame is imputed so that it still contains the NaNs.
X_test_mean = pd.DataFrame(missed_data, columns = columns, index=indices)
X_test_median = X_test_mean.copy()
# Restrict the df_mean / df_median tables (built elsewhere in the notebook;
# presumably per-category statistics -- confirm) to the foods in the test split.
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Min-max scaling: the scaler is fitted on the training split only and then
# applied to every matrix used for training or scoring.
scaler = MinMaxScaler().fit(train_data)

# Clean data: training input plus the two copies of the test ground truth.
train_data = scaler.transform(train_data)
test_data = scaler.transform(test_data)
y_test = scaler.transform(y_test)

# Baseline imputations: by category first, then across all foods.
X_test_mean = scaler.transform(X_test_mean)
X_test_median = scaler.transform(X_test_median)
X_test_mean_all = scaler.transform(X_test_mean_all)
X_test_median_all = scaler.transform(X_test_median_all)

In [None]:
def _without_embeddings(values, n_embeddings):
    """Return a copy of ``values`` without its trailing ``n_embeddings`` columns."""
    values = np.asarray(values)
    # Explicit stop index: ``[:, :-n]`` would select zero columns when n == 0.
    return values[:, :values.shape[1] - n_embeddings].copy()


# Datasets without embeddings for the autoencoder: the embedding columns are the
# last num_embeddings columns of every scaled matrix.
train_data_noEmbeddings = _without_embeddings(train_data, num_embeddings)
test_mean_noEmbeddings = _without_embeddings(X_test_mean, num_embeddings)
test_median_noEmbeddings = _without_embeddings(X_test_median, num_embeddings)
test_mean_noEmbeddings_all = _without_embeddings(X_test_mean_all, num_embeddings)
test_median_noEmbeddings_all = _without_embeddings(X_test_median_all, num_embeddings)

In [None]:
# Float tensors for the autoencoder that sees every column.
# The test-time input is the corrupted test split pre-filled by the
# across-all-foods mean imputation.
train_data = torch.from_numpy(train_data).float()
missed_data = torch.from_numpy(X_test_mean_all).float()

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Float tensors for the autoencoder trained without the embedding columns.
# The test-time input is the across-all-foods mean imputation, embeddings removed.
train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

train_loader_noEm = torch.utils.data.DataLoader(
    dataset=train_data_noEmbeddings,
    batch_size=batch_size,
    shuffle=True,
)

## Autoencoder - without embeddings

In [None]:
# Autoencoder whose input width excludes the embedding columns.
model = Autoencoder(dim=cols - num_embeddings)
model = model.to(device)

In [None]:
# Plain SGD on the model weights, scored with a mean-squared reconstruction loss.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

In [None]:
# Train the no-embeddings autoencoder on the loader built from train_data_noEmbeddings.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.00581219
epoch : 2/200, recon loss = 0.00472384
epoch : 3/200, recon loss = 0.00469371
epoch : 4/200, recon loss = 0.00465589
epoch : 5/200, recon loss = 0.00459831
epoch : 6/200, recon loss = 0.00449245
epoch : 7/200, recon loss = 0.00428894
epoch : 8/200, recon loss = 0.00391550
epoch : 9/200, recon loss = 0.00337470
epoch : 10/200, recon loss = 0.00294085
epoch : 11/200, recon loss = 0.00269202
epoch : 12/200, recon loss = 0.00257523
epoch : 13/200, recon loss = 0.00252215
epoch : 14/200, recon loss = 0.00249634
epoch : 15/200, recon loss = 0.00247268
epoch : 16/200, recon loss = 0.00246725
epoch : 17/200, recon loss = 0.00245641
epoch : 18/200, recon loss = 0.00244636
epoch : 19/200, recon loss = 0.00243947
epoch : 20/200, recon loss = 0.00242779
epoch : 21/200, recon loss = 0.00241965
epoch : 22/200, recon loss = 0.00240268
epoch : 23/200, recon loss = 0.00238925
epoch : 24/200, recon loss = 0.00236624
epoch : 25/200, recon loss = 0.00235028
epoch : 2

In [None]:
# Score the no-embeddings autoencoder on the corrupted test split.
model.eval()  # inference mode: turns off the model's dropout layer

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device))
filled_data = filled_data.cpu().numpy()

# Only the non-embedding columns are scored.
rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

4.011076875370579


## Autoencoder - with embeddings

In [None]:
# Autoencoder whose input width covers every column, embeddings included.
model = Autoencoder(dim=cols)
model = model.to(device)

In [None]:
# Plain SGD on the model weights, scored with a mean-squared reconstruction loss.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

In [None]:
# Train the with-embeddings autoencoder on the loader built from the full train_data.
training(num_epochs, model, train_loader, criterion, optimizer)

epoch : 1/200, recon loss = 0.01557843
epoch : 2/200, recon loss = 0.00973100
epoch : 3/200, recon loss = 0.00969905
epoch : 4/200, recon loss = 0.00966056
epoch : 5/200, recon loss = 0.00960115
epoch : 6/200, recon loss = 0.00950456
epoch : 7/200, recon loss = 0.00932529
epoch : 8/200, recon loss = 0.00900926
epoch : 9/200, recon loss = 0.00853238
epoch : 10/200, recon loss = 0.00807931
epoch : 11/200, recon loss = 0.00779773
epoch : 12/200, recon loss = 0.00765350
epoch : 13/200, recon loss = 0.00758035
epoch : 14/200, recon loss = 0.00753677
epoch : 15/200, recon loss = 0.00750988
epoch : 16/200, recon loss = 0.00749236
epoch : 17/200, recon loss = 0.00746782
epoch : 18/200, recon loss = 0.00745426
epoch : 19/200, recon loss = 0.00743363
epoch : 20/200, recon loss = 0.00741140
epoch : 21/200, recon loss = 0.00739428
epoch : 22/200, recon loss = 0.00736950
epoch : 23/200, recon loss = 0.00733880
epoch : 24/200, recon loss = 0.00731051
epoch : 25/200, recon loss = 0.00726884
epoch : 2

In [None]:
# Score the with-embeddings autoencoder on the corrupted test split.
model.eval()  # inference mode: turns off the model's dropout layer

# Inference only, so skip building the autograd graph.
with torch.no_grad():
    filled_data = model(missed_data.to(device))
filled_data = filled_data.cpu().numpy()

# NOTE(review): this passes `cols` while every other evaluation passes
# `cols-num_embeddings`; confirm rmse_error only scores masked entries, otherwise
# this number also includes the embedding columns and is not comparable.
rmse_sum = rmse_error(test_data, filled_data, cols, mask)

print(rmse_sum)

4.268289142753645


## Mean and Median Imputation

### By category

In [None]:
# Baseline: mean imputation by food category, scored on the non-embedding columns.
imputed = X_test_mean
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

5.519246675443418


In [None]:
# Baseline: median imputation by food category, scored on the non-embedding columns.
imputed = X_test_median
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

5.730640468576924


### Across all foods

In [None]:
# Baseline: mean imputation across all foods, scored on the non-embedding columns.
imputed = X_test_mean_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

6.3815988582216585


In [None]:
# Baseline: median imputation across all foods, scored on the non-embedding columns.
imputed = X_test_median_all
error = rmse_error(y_test, imputed, cols-num_embeddings, mask)
print(error)

6.5167976604056905


# 198 Columns

In [None]:
# Configuration for this section. The CSV name is derived from num_columns so the
# column count is stated once and cannot drift from the file being loaded.
num_embeddings = 50
num_columns = 198
data_path = (
    'datasets/average datasets/byColumnsWithEmbeddings/'
    '{}_columns_embeddings.csv'.format(num_columns)
)

In [None]:
# Load the section's dataset (first CSV column becomes the index) and summarise
# only the rows that contain no missing values.
dataset = pd.read_csv(data_path, index_col=0)
dataset.dropna(axis=0, how='any').describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
count,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,...,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0,13105.0
mean,480751.7,0.243665,1.942243,1.824697,0.164999,0.472742,1.999714,2.561062,0.147034,0.118187,0.120046,0.046356,0.008754,0.000476,0.681114,0.728577,0.049253,8.349769,8.36934,0.055411,0.002459,0.350355,0.005555,0.021892,0.106772,27.659061,0.476038,14.595597,55.746662,125.003892,133.394887,14.730179,0.384163,2.983579,0.077504,12.187135,0.149979,42.944411,0.840258,2.086684,...,-0.005261,-0.053342,0.032862,0.017149,0.053914,-0.050015,-0.006802,-0.001312,0.045397,-0.002815,-0.004867,0.015156,-0.002852,-0.00799,0.007587,-0.005105,0.013259,0.026135,-0.009025,-0.004237,-0.00899,-0.00029,0.009895,-0.005317,0.001269,0.006334,0.018435,0.004424,0.009392,-0.00317,-0.001646,-0.00683,0.003214,-0.000149,0.004693,-0.004719,0.000873,-0.001455,0.006576,0.003596
std,254401.8,0.999601,6.583455,7.838466,2.905073,4.416751,25.783702,9.446549,2.828219,1.055735,1.016245,0.441966,0.129818,0.022249,14.844129,15.911796,1.303063,23.549114,107.911712,1.491246,0.034625,1.186155,0.082114,0.319893,2.306569,99.924214,1.48336,45.754175,153.43266,382.322293,1736.846193,56.459805,1.05534,11.848737,0.264892,237.511065,0.515558,187.200801,5.232485,29.37536,...,0.400175,0.345475,0.349559,0.332054,0.321282,0.292924,0.270047,0.284935,0.261685,0.269203,0.240661,0.243893,0.245143,0.248005,0.235594,0.205819,0.225068,0.212354,0.209393,0.205605,0.208641,0.187673,0.193694,0.181126,0.188895,0.180591,0.17262,0.162121,0.180927,0.171087,0.170579,0.162898,0.166174,0.156805,0.165867,0.149952,0.151873,0.145215,0.1513,0.148431
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.933855,-0.969972,-1.097137,-0.920027,-1.006138,-1.000187,-0.798815,-0.687112,-0.668615,-0.801637,-0.726105,-0.698135,-0.734302,-0.848955,-0.669846,-0.604,-0.624276,-0.615452,-0.597283,-0.806154,-0.673437,-0.584369,-0.800978,-0.701795,-0.66374,-0.69752,-0.542536,-0.560396,-0.75379,-0.567381,-0.571239,-0.533034,-0.571749,-0.680382,-0.648647,-0.538644,-0.550678,-0.580334,-0.568859,-0.518081
25%,325648.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.303209,-0.332181,-0.196565,-0.21947,-0.160205,-0.21719,-0.194588,-0.175356,-0.131521,-0.159043,-0.129569,-0.142713,-0.128844,-0.176975,-0.164039,-0.152442,-0.136194,-0.11314,-0.133148,-0.100405,-0.112391,-0.121311,-0.102899,-0.127997,-0.112148,-0.103513,-0.087577,-0.118496,-0.090855,-0.124071,-0.098138,-0.108403,-0.089859,-0.061942,-0.089926,-0.083567,-0.099206,-0.084546,-0.092055,-0.094561
50%,331410.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.006464,0.045016,0.064438,-0.007061,0.051118,-0.054914,-0.020087,-0.030643,0.015376,-0.034391,-0.022269,0.000457,-0.022652,0.013189,-0.022747,-0.014965,0.015073,0.016768,-0.04464,-0.002055,-0.023092,-0.00499,0.001245,-0.016104,0.015313,0.010249,0.006607,0.005072,0.007793,-0.01332,-0.003852,-0.013888,0.001045,-0.004789,0.016949,-0.004967,0.005749,-0.009788,-0.003402,0.014599
75%,747468.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.267125,0.166323,0.24799,0.188549,0.251014,0.109046,0.190527,0.178192,0.198413,0.177397,0.099106,0.180667,0.128173,0.157417,0.136158,0.119462,0.155412,0.157847,0.104137,0.111113,0.074621,0.119803,0.117055,0.095796,0.125334,0.100478,0.106755,0.110569,0.081504,0.121099,0.097413,0.104568,0.101486,0.103675,0.091954,0.081006,0.087104,0.080598,0.092365,0.091704
max,1105897.0,13.2,79.9,99.4,99.6,99.6,833.0,43.9,100.0,25.9,23.6,6.34,3.83,1.04,377.0,429.0,67.9,96.7,3480.0,99.8,1.1,11.5,2.9,9.2,94.6,1390.0,12.7,386.0,997.0,2520.0,40700.0,311.0,7.66,93.0,1.92,7430.0,7.87,1790.0,97.7,912.0,...,1.122911,1.235694,1.23092,1.301511,1.225611,1.031472,0.77072,0.843955,1.049345,0.867486,1.478752,0.921318,0.756673,0.907832,0.918719,0.97021,0.778253,0.770962,0.961512,0.903499,0.642813,0.90378,0.722286,0.874324,0.835497,0.680991,0.846139,0.705718,0.985508,0.545425,0.757822,0.626882,0.575618,0.746988,0.949548,0.502315,0.644586,0.563422,0.705493,0.757814


In [None]:
# Split the full dataset (fdc_id column still included) into train and test frames.
# test_size is defined in an earlier cell.
data = dataset
train_data, test_data = train_test_split(data, test_size=test_size)

In [None]:
# Summary statistics of the training split.
train_data.describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
count,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,...,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0,10484.0
mean,479456.2,0.24107,1.92887,1.821707,0.161388,0.461399,2.026159,2.518903,0.137158,0.118248,0.119258,0.046907,0.009217,0.000595,0.679321,0.732259,0.045408,8.340591,8.480275,0.046689,0.002474,0.344878,0.00579,0.02253,0.109062,27.787963,0.470211,14.459805,55.374952,123.405857,120.136112,14.461274,0.381541,2.938544,0.076159,12.661293,0.148089,42.606524,0.864411,2.053987,...,-0.004485,-0.053181,0.03519,0.016537,0.055051,-0.049332,-0.005589,-0.001322,0.047027,-0.003175,-0.003616,0.015436,-0.002492,-0.00639,0.007534,-0.00621,0.014451,0.026644,-0.009088,-0.003639,-0.008473,0.000794,0.010315,-0.003527,0.001452,0.006197,0.019294,0.003951,0.009452,-0.00393,-0.001654,-0.007448,0.003339,-0.001735,0.004656,-0.004393,0.000464,-5.8e-05,0.00596,0.003356
std,253272.3,0.993474,6.563354,7.902323,2.84605,4.304038,25.760127,9.375699,2.657333,1.057639,1.015281,0.442991,0.135252,0.024874,14.94027,16.128424,1.207659,23.558816,107.815927,1.235632,0.03539,1.176192,0.082493,0.321813,2.326975,101.368809,1.478138,45.646184,153.060188,379.104934,1575.954298,56.017479,1.053653,11.787381,0.262443,244.012777,0.516502,187.724232,5.320714,29.167232,...,0.399808,0.345308,0.347456,0.332473,0.321406,0.292574,0.270887,0.285244,0.262761,0.269514,0.241773,0.244525,0.246485,0.24756,0.236307,0.205018,0.225697,0.213409,0.209506,0.205961,0.209026,0.187867,0.1947,0.181905,0.189907,0.180628,0.172939,0.162417,0.180656,0.171054,0.170816,0.163941,0.166224,0.15738,0.166602,0.150019,0.152127,0.145858,0.151741,0.148118
min,319877.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.933843,-0.969439,-1.097137,-0.920027,-1.006138,-1.000187,-0.798815,-0.687112,-0.668615,-0.801637,-0.726105,-0.698135,-0.734302,-0.831774,-0.661366,-0.604,-0.624276,-0.615452,-0.597283,-0.806154,-0.673437,-0.584369,-0.787943,-0.701795,-0.66374,-0.69752,-0.529472,-0.536528,-0.75379,-0.557471,-0.571239,-0.533034,-0.571749,-0.680382,-0.648647,-0.538644,-0.550678,-0.580334,-0.559276,-0.518081
25%,325638.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.303365,-0.331833,-0.193712,-0.220025,-0.159651,-0.216723,-0.190823,-0.176393,-0.130259,-0.160182,-0.1311,-0.143404,-0.129463,-0.173109,-0.164184,-0.152609,-0.134107,-0.113856,-0.133126,-0.099676,-0.112175,-0.121307,-0.104513,-0.124913,-0.11212,-0.103636,-0.085785,-0.118501,-0.089202,-0.126307,-0.097939,-0.108631,-0.090715,-0.062085,-0.092483,-0.083392,-0.097852,-0.084365,-0.092098,-0.094591
50%,331398.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.003707,0.045433,0.064924,-0.007322,0.051658,-0.053177,-0.019007,-0.030674,0.017497,-0.034493,-0.02216,0.002732,-0.020873,0.013576,-0.022006,-0.015323,0.015488,0.016824,-0.044433,0.000929,-0.021161,-0.003651,0.001622,-0.012053,0.015342,0.010372,0.00714,0.004978,0.007067,-0.013731,-0.00392,-0.013905,0.000959,-0.005007,0.014839,-0.004766,0.005744,-0.009697,-0.0044,0.013266
75%,747459.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.267167,0.165864,0.247223,0.186875,0.252065,0.109015,0.191541,0.181806,0.199292,0.178718,0.099278,0.180669,0.128743,0.157482,0.13659,0.117972,0.156187,0.158928,0.103986,0.115242,0.073736,0.124609,0.121336,0.095865,0.125497,0.101867,0.106555,0.110043,0.081471,0.121044,0.097395,0.10524,0.099313,0.103633,0.092808,0.081048,0.086873,0.081748,0.092386,0.092551
max,1105896.0,13.2,79.9,99.4,99.6,99.6,833.0,43.9,100.0,25.9,23.6,6.34,3.62,1.04,377.0,429.0,67.9,96.0,3480.0,99.8,1.1,11.5,2.3,9.2,94.6,1390.0,12.7,386.0,997.0,2520.0,40300.0,311.0,7.66,93.0,1.92,7430.0,7.87,1790.0,97.7,903.0,...,1.122911,1.235694,1.222554,1.301511,1.225611,1.031472,0.77072,0.843955,1.048918,0.867486,1.478667,0.921318,0.756673,0.907832,0.918719,0.955452,0.778253,0.770962,0.961512,0.903499,0.642813,0.90378,0.722286,0.874324,0.835497,0.652537,0.83418,0.705718,0.985508,0.545425,0.757822,0.626882,0.575618,0.746988,0.949548,0.502315,0.607102,0.563422,0.705493,0.757814


In [None]:
# Summary statistics of the test split.
test_data.describe()

Unnamed: 0,fdc_id,1002,1003,1004,1005,1007,1008,1009,1010,1011,1012,1013,1014,1024,1032,1039,1050,1051,1062,1063,1075,1079,1082,1084,1085,1087,1089,1090,1091,1092,1093,1094,1095,1097,1098,1100,1101,1102,1103,1105,...,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
count,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,...,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,485933.8,0.254044,1.995738,1.836654,0.179443,0.518111,1.893934,2.729699,0.186539,0.117944,0.123197,0.044151,0.006902,0.0,0.688287,0.71385,0.064632,8.386478,7.925601,0.090298,0.002396,0.372262,0.004617,0.019344,0.097612,27.143457,0.499348,15.138764,57.233499,131.396032,186.429989,15.805799,0.394651,3.163716,0.082882,10.2905,0.157542,44.295956,0.743647,2.217474,...,-0.008367,-0.053989,0.023549,0.019596,0.049367,-0.052746,-0.011657,-0.001276,0.038876,-0.001373,-0.009874,0.014033,-0.004292,-0.014387,0.007798,-0.000686,0.008489,0.024095,-0.008769,-0.006628,-0.01106,-0.004628,0.008216,-0.012476,0.000541,0.006882,0.014999,0.006314,0.009149,-0.000132,-0.001614,-0.00436,0.002712,0.006195,0.004842,-0.006025,0.00251,-0.007046,0.009037,0.004557
std,258854.5,1.02387,6.664248,7.579117,3.130588,4.842051,25.882445,9.72479,3.427916,1.048284,1.020284,0.437916,0.105316,0.0,14.455965,15.017005,1.629949,23.51472,108.313495,2.23876,0.031385,1.22518,0.080585,0.312139,2.223482,93.940655,1.504124,46.188392,154.934092,394.939205,2268.675857,58.19401,1.062197,12.091657,0.274454,209.52813,0.511793,185.121625,4.863353,30.198858,...,0.4017,0.346211,0.357762,0.330427,0.320807,0.294364,0.266655,0.283749,0.257284,0.268003,0.236137,0.241392,0.239741,0.249726,0.232766,0.208973,0.222512,0.208108,0.208984,0.204199,0.20712,0.186865,0.189641,0.177829,0.184825,0.180475,0.171326,0.160948,0.182043,0.171217,0.16966,0.158665,0.166007,0.15435,0.162924,0.149703,0.15087,0.142503,0.149526,0.149701
min,319908.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.933855,-0.969972,-0.973889,-0.889484,-0.844214,-0.975344,-0.795263,-0.683452,-0.613335,-0.801561,-0.678256,-0.625027,-0.734182,-0.848955,-0.669846,-0.603996,-0.549889,-0.615429,-0.597251,-0.639653,-0.673297,-0.525369,-0.800978,-0.603322,-0.65687,-0.595664,-0.542536,-0.560396,-0.5283,-0.567381,-0.546243,-0.532969,-0.571732,-0.606464,-0.621206,-0.53861,-0.515188,-0.532556,-0.568859,-0.486389
25%,325666.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.300802,-0.333182,-0.211713,-0.219276,-0.162027,-0.220277,-0.199176,-0.173755,-0.138855,-0.157256,-0.125563,-0.140941,-0.12602,-0.188958,-0.163733,-0.151088,-0.139832,-0.108629,-0.133923,-0.110559,-0.113465,-0.121356,-0.10051,-0.132482,-0.112254,-0.100508,-0.102012,-0.118485,-0.091904,-0.118641,-0.098746,-0.105721,-0.087408,-0.061454,-0.078814,-0.084274,-0.099628,-0.087041,-0.09185,-0.094435
50%,331467.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.014279,0.039952,0.050841,-0.003235,0.049275,-0.058812,-0.026589,-0.030488,0.002919,-0.033887,-0.022722,-0.007847,-0.028247,0.007928,-0.02412,-0.014729,0.013604,0.01652,-0.045117,-0.003871,-0.029207,-0.008502,-0.001047,-0.020061,0.01525,0.009938,0.004786,0.007806,0.011801,-0.011018,-0.00367,-0.013831,0.003276,-0.003831,0.025561,-0.009699,0.005768,-0.00995,-0.002146,0.015221
75%,747544.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.26616,0.166972,0.254764,0.193289,0.241754,0.109078,0.174927,0.169065,0.193289,0.17314,0.0981,0.180645,0.125638,0.157089,0.134833,0.12675,0.150886,0.149021,0.104876,0.096763,0.076813,0.115245,0.098988,0.09281,0.118545,0.095405,0.108273,0.110751,0.081606,0.127006,0.09746,0.102299,0.105051,0.107226,0.08918,0.080794,0.09344,0.075128,0.09228,0.085729
max,1105897.0,13.2,40.9,98.9,99.6,99.6,654.0,42.7,99.9,24.8,22.9,5.73,3.83,0.0,341.0,369.0,62.8,96.7,2740.0,99.8,0.82,11.3,2.9,8.3,94.0,1260.0,10.0,357.0,997.0,2460.0,40700.0,277.0,7.44,81.0,1.85,6170.0,3.44,1640.0,81.3,912.0,...,1.06885,1.16737,1.23092,1.266974,1.128774,1.031444,0.752794,0.805872,1.049345,0.867322,1.478752,0.881879,0.756606,0.763748,0.878603,0.97021,0.777881,0.770679,0.961082,0.849681,0.576198,0.609812,0.676092,0.818344,0.675452,0.680991,0.846139,0.700253,0.928878,0.511358,0.702707,0.620772,0.572546,0.701231,0.621155,0.483569,0.644586,0.563413,0.606385,0.7166


In [None]:
# Keep the fdc_id of every test row before the column is dropped from the splits.
indices = test_data['fdc_id']
indices

7626     333186
7678     333305
6030     330174
10633    748750
2944     325099
          ...  
3560     326060
7279     332543
4019     326847
1607     322877
2520     324378
Name: fdc_id, Length: 2621, dtype: int64

In [None]:
# fdc_id is an identifier, not a feature: remove it from both splits in place.
del train_data['fdc_id']
del test_data['fdc_id']
# Remaining column labels, reused when rebuilding DataFrames from NumPy arrays.
columns = train_data.columns

In [None]:
# Convert both splits to NumPy and record the feature count.
train_data = train_data.to_numpy()
# Rebind instead of fillna(inplace=True): test_data comes out of train_test_split,
# and filling it in place triggers pandas' SettingWithCopyWarning.
# NOTE(review): only the test split has NaNs replaced by 0; confirm the training
# split never contains missing values.
test_data = test_data.fillna(0).to_numpy()
data = dataset.values
rows, cols = data.shape
cols -= 1  # dataset still carries the fdc_id column, which is not a feature

# Untouched ground truth for scoring the imputations.
y_test = test_data.copy()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  downcast=downcast,


In [None]:
# Corrupt both splits: missing_method sets randomly chosen non-embedding entries
# to NaN and returns the corrupted test data, the corrupted train data and the
# boolean mask of the entries removed from the test data.
missed_data, missed_data_train, mask = missing_method(test_data, train_data, num_embeddings)
pd.DataFrame(missed_data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247
count,2094.0,2057.0,2110.0,2112.0,2128.0,2100.0,2096.0,2105.0,2111.0,2090.0,2084.0,2131.0,2083.0,2117.0,2081.0,2127.0,2111.0,2074.0,2093.0,2095.0,2075.0,2098.0,2067.0,2142.0,2113.0,2087.0,2087.0,2151.0,2086.0,2085.0,2079.0,2103.0,2122.0,2118.0,2119.0,2125.0,2092.0,2124.0,2122.0,2103.0,...,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0,2621.0
mean,0.272364,1.939626,1.76163,0.173873,0.527603,1.359048,2.784098,0.168546,0.123302,0.108555,0.045763,0.007405,0.0,0.601323,0.590581,0.07118,8.551644,5.432015,0.074233,0.002869,0.364964,0.005434,0.017078,0.098735,27.946522,0.507834,14.693531,57.714551,133.066155,199.305995,16.016835,0.389544,2.962253,0.082634,9.752006,0.162012,45.254828,0.72726,2.500471,0.415597,...,-0.008367,-0.053989,0.023549,0.019596,0.049367,-0.052746,-0.011657,-0.001276,0.038876,-0.001373,-0.009874,0.014033,-0.004292,-0.014387,0.007798,-0.000686,0.008489,0.024095,-0.008769,-0.006628,-0.01106,-0.004628,0.008216,-0.012476,0.000541,0.006882,0.014999,0.006314,0.009149,-0.000132,-0.001614,-0.00436,0.002712,0.006195,0.004842,-0.006025,0.00251,-0.007046,0.009037,0.004557
std,1.096187,6.569316,7.18062,2.980519,4.902768,20.176795,9.807325,3.137919,1.126675,0.908371,0.437763,0.113087,0.0,13.862775,13.57747,1.766844,23.868884,80.292984,2.237338,0.034784,1.21451,0.088752,0.289027,2.368098,98.178407,1.521939,45.325499,154.395943,398.289975,2395.089148,58.614368,1.049673,11.754408,0.273274,190.699412,0.522117,187.593139,4.862275,33.054742,10.61256,...,0.4017,0.346211,0.357762,0.330427,0.320807,0.294364,0.266655,0.283749,0.257284,0.268003,0.236137,0.241392,0.239741,0.249726,0.232766,0.208973,0.222512,0.208108,0.208984,0.204199,0.20712,0.186865,0.189641,0.177829,0.184825,0.180475,0.171326,0.160948,0.182043,0.171217,0.16966,0.158665,0.166007,0.15435,0.162924,0.149703,0.15087,0.142503,0.149526,0.149701
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.933855,-0.969972,-0.973889,-0.889484,-0.844214,-0.975344,-0.795263,-0.683452,-0.613335,-0.801561,-0.678256,-0.625027,-0.734182,-0.848955,-0.669846,-0.603996,-0.549889,-0.615429,-0.597251,-0.639653,-0.673297,-0.525369,-0.800978,-0.603322,-0.65687,-0.595664,-0.542536,-0.560396,-0.5283,-0.567381,-0.546243,-0.532969,-0.571732,-0.606464,-0.621206,-0.53861,-0.515188,-0.532556,-0.568859,-0.486389
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.300802,-0.333182,-0.211713,-0.219276,-0.162027,-0.220277,-0.199176,-0.173755,-0.138855,-0.157256,-0.125563,-0.140941,-0.12602,-0.188958,-0.163733,-0.151088,-0.139832,-0.108629,-0.133923,-0.110559,-0.113465,-0.121356,-0.10051,-0.132482,-0.112254,-0.100508,-0.102012,-0.118485,-0.091904,-0.118641,-0.098746,-0.105721,-0.087408,-0.061454,-0.078814,-0.084274,-0.099628,-0.087041,-0.09185,-0.094435
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.014279,0.039952,0.050841,-0.003235,0.049275,-0.058812,-0.026589,-0.030488,0.002919,-0.033887,-0.022722,-0.007847,-0.028247,0.007928,-0.02412,-0.014729,0.013604,0.01652,-0.045117,-0.003871,-0.029207,-0.008502,-0.001047,-0.020061,0.01525,0.009938,0.004786,0.007806,0.011801,-0.011018,-0.00367,-0.013831,0.003276,-0.003831,0.025561,-0.009699,0.005768,-0.00995,-0.002146,0.015221
75%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.26616,0.166972,0.254764,0.193289,0.241754,0.109078,0.174927,0.169065,0.193289,0.17314,0.0981,0.180645,0.125638,0.157089,0.134833,0.12675,0.150886,0.149021,0.104876,0.096763,0.076813,0.115245,0.098988,0.09281,0.118545,0.095405,0.108273,0.110751,0.081606,0.127006,0.09746,0.102299,0.105051,0.107226,0.08918,0.080794,0.09344,0.075128,0.09228,0.085729
max,13.2,40.9,82.0,99.6,99.6,408.0,42.0,99.8,24.8,22.9,5.05,3.83,0.0,341.0,369.0,62.8,96.7,1640.0,99.8,0.82,11.3,2.9,8.3,94.0,1260.0,9.28,357.0,982.0,2460.0,40700.0,277.0,7.44,81.0,1.85,5080.0,3.44,1640.0,81.3,912.0,316.0,...,1.06885,1.16737,1.23092,1.266974,1.128774,1.031444,0.752794,0.805872,1.049345,0.867322,1.478752,0.881879,0.756606,0.763748,0.878603,0.97021,0.777881,0.770679,0.961082,0.849681,0.576198,0.609812,0.676092,0.818344,0.675452,0.680991,0.846139,0.700253,0.928878,0.511358,0.702707,0.620772,0.572546,0.701231,0.621155,0.483569,0.644586,0.563413,0.606385,0.7166


In [None]:
# Traditional baselines - impute across all foods.
# X_train is the corrupted training split handed to impute_traditional_all_foods
# as its first argument; each strategy gets its own corrupted test frame.
# copy=True: pandas wraps a 2-D ndarray without copying by default, so without it
# both test frames (and missed_data itself) would share one buffer and any
# in-place write during imputation would leak between the mean and median runs.
X_train = pd.DataFrame(missed_data_train, columns=columns, copy=True)
X_test_mean_all = pd.DataFrame(missed_data, columns=columns, copy=True)
X_test_median_all = pd.DataFrame(missed_data, columns=columns, copy=True)

X_test_mean_all = impute_traditional_all_foods(X_train, X_test_mean_all, 'mean')
X_test_median_all = impute_traditional_all_foods(X_train, X_test_median_all, 'median')

In [None]:
# Traditional baselines - impute by food category.
# The corrupted test matrix is indexed by `indices` so rows can be matched against
# df_mean / df_median (defined elsewhere in the notebook; presumably per-food
# category statistics keyed by fdc_id - verify against their definition).
# copy=True keeps this frame independent of missed_data, which pandas would
# otherwise wrap without copying.
X_test_mean = pd.DataFrame(missed_data, columns=columns, index=indices, copy=True)
X_test_median = X_test_mean.copy()
df_mean_cat = df_mean[df_mean['fdc_id'].isin(indices)]
X_test_mean = impute_traditional_by_category(X_test_mean, df_mean_cat, indices)
df_median_cat = df_median[df_median['fdc_id'].isin(indices)]
X_test_median = impute_traditional_by_category(X_test_median, df_median_cat, indices)

In [None]:
# Min-max scaling: the scaler is fitted on the training split only and then
# applied to every matrix used for training or scoring.
scaler = MinMaxScaler().fit(train_data)

# Clean data: training input plus the two copies of the test ground truth.
train_data = scaler.transform(train_data)
test_data = scaler.transform(test_data)
y_test = scaler.transform(y_test)

# Baseline imputations: by category first, then across all foods.
X_test_mean = scaler.transform(X_test_mean)
X_test_median = scaler.transform(X_test_median)
X_test_mean_all = scaler.transform(X_test_mean_all)
X_test_median_all = scaler.transform(X_test_median_all)

In [None]:
def _without_embeddings(values, n_embeddings):
    """Return a copy of ``values`` without its trailing ``n_embeddings`` columns."""
    values = np.asarray(values)
    # Explicit stop index: ``[:, :-n]`` would select zero columns when n == 0.
    return values[:, :values.shape[1] - n_embeddings].copy()


# Datasets without embeddings for the autoencoder: the embedding columns are the
# last num_embeddings columns of every scaled matrix.
train_data_noEmbeddings = _without_embeddings(train_data, num_embeddings)
test_mean_noEmbeddings = _without_embeddings(X_test_mean, num_embeddings)
test_median_noEmbeddings = _without_embeddings(X_test_median, num_embeddings)
test_mean_noEmbeddings_all = _without_embeddings(X_test_mean_all, num_embeddings)
test_median_noEmbeddings_all = _without_embeddings(X_test_median_all, num_embeddings)

In [None]:
# Float tensors for the autoencoder that sees every column.
# The test-time input is the corrupted test split pre-filled by the
# across-all-foods mean imputation.
train_data = torch.from_numpy(train_data).float()
missed_data = torch.from_numpy(X_test_mean_all).float()

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Float tensors for the autoencoder trained without the embedding columns.
# The test-time input is the across-all-foods mean imputation, embeddings removed.
train_data_noEmbeddings = torch.from_numpy(train_data_noEmbeddings).float()
missed_data_noEm = torch.from_numpy(test_mean_noEmbeddings_all).float()

train_loader_noEm = torch.utils.data.DataLoader(
    dataset=train_data_noEmbeddings,
    batch_size=batch_size,
    shuffle=True,
)

## Autoencoder - without embeddings

In [None]:
# Autoencoder whose input width excludes the embedding columns.
model = Autoencoder(dim=cols - num_embeddings)
model = model.to(device)

In [None]:
# Plain SGD on the model weights, scored with a mean-squared reconstruction loss.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

In [None]:
# Train the no-embeddings autoencoder on the loader built from train_data_noEmbeddings.
training(num_epochs, model, train_loader_noEm, criterion, optimizer)

epoch : 1/200, recon loss = 0.00376239
epoch : 2/200, recon loss = 0.00293869
epoch : 3/200, recon loss = 0.00290759
epoch : 4/200, recon loss = 0.00288642
epoch : 5/200, recon loss = 0.00285750
epoch : 6/200, recon loss = 0.00281565
epoch : 7/200, recon loss = 0.00275599
epoch : 8/200, recon loss = 0.00266495
epoch : 9/200, recon loss = 0.00252990
epoch : 10/200, recon loss = 0.00233611
epoch : 11/200, recon loss = 0.00214065
epoch : 12/200, recon loss = 0.00196707
epoch : 13/200, recon loss = 0.00185716
epoch : 14/200, recon loss = 0.00178805
epoch : 15/200, recon loss = 0.00174686
epoch : 16/200, recon loss = 0.00172417
epoch : 17/200, recon loss = 0.00171114
epoch : 18/200, recon loss = 0.00169873
epoch : 19/200, recon loss = 0.00168953
epoch : 20/200, recon loss = 0.00168483
epoch : 21/200, recon loss = 0.00168238
epoch : 22/200, recon loss = 0.00167891
epoch : 23/200, recon loss = 0.00167378
epoch : 24/200, recon loss = 0.00167387
epoch : 25/200, recon loss = 0.00166987
epoch : 2

In [None]:
# Inference mode: disables the model's nn.Dropout layer so the reconstruction
# is deterministic.
model.eval()

# No gradients are needed for evaluation; skipping autograd avoids building
# and holding a computation graph for the whole test matrix.
with torch.no_grad():
    filled_data = model(missed_data_noEm.to(device))
filled_data = filled_data.cpu().numpy()

# Score the reconstruction over the feature columns only.
# NOTE(review): this cell scores against test_data while the mean/median
# baseline cells score against y_test -- confirm both hold the same ground
# truth so the reported numbers are comparable.
rmse_sum = rmse_error(test_data, filled_data, cols-num_embeddings, mask)

print(rmse_sum)

5.215949090468736


## Autoencoder - with embeddings

In [None]:
# Autoencoder sized to every column (features plus embeddings), then moved
# to the configured device. Rebinds `model`, replacing the previous network.
model = Autoencoder(dim=cols)
model = model.to(device)

In [None]:
# Fresh MSE reconstruction loss and SGD optimizer bound to the parameters of
# the current model.
criterion = nn.MSELoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
)

In [None]:
# Fit the embedding-aware autoencoder on the full-width training batches.
training(
    num_epochs,
    model,
    train_loader,
    criterion,
    optimizer,
)

epoch : 1/200, recon loss = 0.01425203
epoch : 2/200, recon loss = 0.00729474
epoch : 3/200, recon loss = 0.00727854
epoch : 4/200, recon loss = 0.00726253
epoch : 5/200, recon loss = 0.00724003
epoch : 6/200, recon loss = 0.00721195
epoch : 7/200, recon loss = 0.00717784
epoch : 8/200, recon loss = 0.00713054
epoch : 9/200, recon loss = 0.00706220
epoch : 10/200, recon loss = 0.00696389
epoch : 11/200, recon loss = 0.00682904
epoch : 12/200, recon loss = 0.00666035
epoch : 13/200, recon loss = 0.00648561
epoch : 14/200, recon loss = 0.00634780
epoch : 15/200, recon loss = 0.00624832
epoch : 16/200, recon loss = 0.00618158
epoch : 17/200, recon loss = 0.00613688
epoch : 18/200, recon loss = 0.00610990
epoch : 19/200, recon loss = 0.00608614
epoch : 20/200, recon loss = 0.00607367
epoch : 21/200, recon loss = 0.00606013
epoch : 22/200, recon loss = 0.00605006
epoch : 23/200, recon loss = 0.00604083
epoch : 24/200, recon loss = 0.00603483
epoch : 25/200, recon loss = 0.00602478
epoch : 2

In [None]:
# Inference mode: disables the model's nn.Dropout layer so the reconstruction
# is deterministic.
model.eval()

# No gradients are needed for evaluation; skipping autograd avoids building
# and holding a computation graph for the whole test matrix.
with torch.no_grad():
    filled_data = model(missed_data.to(device))
filled_data = filled_data.cpu().numpy()

# NOTE(review): this call passes `cols`, whereas the no-embedding run and the
# mean/median baselines pass `cols-num_embeddings`; confirm rmse_error scores
# the same set of columns in each case. It also scores against test_data
# while the baseline cells use y_test -- confirm both hold the same ground
# truth.
rmse_sum = rmse_error(test_data, filled_data, cols, mask)

print(rmse_sum)

5.712720665341518


## Mean and Median Imputation

### By category

In [None]:
# Baseline ("By category" section): score the mean-imputed test matrix
# against y_test over the non-embedding columns.
imputed = X_test_mean
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

6.562400490537268


In [None]:
# Baseline ("By category" section): score the median-imputed test matrix
# against y_test over the non-embedding columns.
imputed = X_test_median
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

6.729803902928729


### Across all foods

In [None]:
# Baseline ("Across all foods" section): score the mean-imputed test matrix
# against y_test over the non-embedding columns.
imputed = X_test_mean_all
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

7.4467183945414215


In [None]:
# Baseline ("Across all foods" section): score the median-imputed test matrix
# against y_test over the non-embedding columns.
imputed = X_test_median_all
error = rmse_error(
    y_test,
    imputed,
    cols - num_embeddings,
    mask,
)
print(error)

7.573344537892341
