# Missing data imputation using optimal transport

In this notebook, we will show how to use optimal transport to impute missing values in an incomplete dataset.

The methods we will use are described in the following paper:

B. Muzellec, J. Josse, C. Boyer, M. Cuturi, [Missing Data Imputation using Optimal Transport](https://arxiv.org/pdf/2002.03860.pdf).

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import numpy as np
import pandas as pd

from sklearn.preprocessing import scale
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer

import os

from geomloss import SamplesLoss

from imputers import OTimputer, RRimputer

from utils import *
from data_loaders import dataset_loader
from softimpute import softimpute, cv_softimpute

import logging

# Send INFO-level messages from the root logger to the notebook output
# (the imputer's progress reports appear as "INFO:root:..." lines).
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Use float64 for tensors created without an explicit dtype.
# torch.set_default_dtype replaces the deprecated
# torch.set_default_tensor_type('torch.DoubleTensor').
torch.set_default_dtype(torch.float64)

In [13]:
def get_data(name, data_dir='autodl-tmp/6'):
    """Load the three CSV tables that make up dataset `name`.

    Files read from `data_dir`:
      - ``<name>_std.csv``   -> df_ori (used in this notebook as the ground truth, X_true)
      - ``<name>_clean.csv`` -> df_mis (used in this notebook as the data to impute, X_miss)
      - ``<name>_sign.csv``  -> mask

    Parameters
    ----------
    name : str
        Dataset prefix, e.g. ``'XY6'``.
    data_dir : str, optional
        Directory containing the CSV files. Defaults to the original
        hard-coded location ``'autodl-tmp/6'``. A trailing separator is allowed.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_ori, df_mis, mask)``, each read with ``pd.read_csv`` defaults.
    """
    def _read(suffix):
        # os.path.join works whether or not data_dir ends with a separator.
        return pd.read_csv(os.path.join(data_dir, name + suffix))

    df_ori = _read('_std.csv')
    df_mis = _read('_clean.csv')
    # NOTE(review): the encoding of the mask (which value marks a missing entry)
    # is not visible here — confirm against the data source.
    mask = _read('_sign.csv')
    return df_ori, df_mis, mask

In [12]:
# Fix the random seeds so the stochastic mini-batch optimisation is reproducible.
# NOTE(review): assumes OTimputer samples its batches from the numpy/torch global
# RNGs — confirm in imputers.py.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

df_ori, df_mis, mask = get_data('XY6')

# Cast explicitly to float64 (torch's default dtype in this notebook) so an
# integer-parsed column cannot yield an int64 tensor. copy=True keeps the
# tensors independent of the DataFrames, as np.array(...) did before.
X_true = torch.from_numpy(df_ori.to_numpy(dtype=np.float64, copy=True))
X_miss = torch.from_numpy(df_mis.to_numpy(dtype=np.float64, copy=True))

n, d = X_miss.shape
batchsize = 128  # mini-batch size passed to OTimputer
lr = 1e-2        # learning rate passed to OTimputer
# Regularisation parameter chosen from the data by pick_epsilon
# (presumably provided by `from utils import *`).
epsilon = pick_epsilon(X_miss)

# niter=1001 so that, with report_interval=500, iterations 0, 500 and 1000 are logged.
sk_imputer = OTimputer(eps=epsilon, batchsize=batchsize, lr=lr, niter=1001)
sk_imp, sk_maes, sk_rmses = sk_imputer.fit_transform(X_miss, verbose=True, report_interval=500, X_true=X_true)

INFO:root:batchsize = 128, epsilon = 0.0095
INFO:root:Iteration 0:	 Loss: 0.0381	 Validation MAE: 0.1780	RMSE: 0.2337
INFO:root:Iteration 500:	 Loss: 0.0214	 Validation MAE: 0.0895	RMSE: 0.1624
INFO:root:Iteration 1000:	 Loss: 0.0250	 Validation MAE: 0.0844	RMSE: 0.1593


In [10]:
# Export the imputed matrix with the ground-truth table's column names.
OUTPUT_PATH = 'autodl-tmp/Sink_50.csv'

# .cpu() is a no-op for CPU tensors but avoids a failure if the imputer ran on a GPU.
output = pd.DataFrame(sk_imp.detach().cpu().numpy(), columns=df_ori.columns)

# An imputation result should be complete; fail loudly rather than writing NaNs to disk.
assert not output.isna().any().any(), "imputed data still contains NaN values"

output.to_csv(OUTPUT_PATH, index=False)