# Data Imputation - MisGAN

In [1]:
# A bit of setup
import os, sys
from os.path import isfile, join
from tqdm import tqdm

# Put the parent of the working directory at the front of sys.path (the working
# directory itself is not changed). Presumably this is what lets the project
# imports below resolve when the notebook is started from its own sub-folder.
# os.path.dirname replaces the earlier manual search for '/' and '\\', which
# split at a backslash inside a POSIX directory name and produced an empty
# entry for the filesystem root.
sys.path.insert(0, os.path.dirname(os.getcwd()))


from Imputation import Imputation

import argparse
from misgan.modules import preprocess
from misgan.modules import train
from misgan.modules import test
from misgan.modules import evaluate
from misgan.modules import impute

%matplotlib inline
%reload_ext autoreload
%autoreload 2

### Define MisGAN Modules

In [2]:
class MisGAN(Imputation):
    """Wrapper that runs each `Imputation` stage and then the matching
    routine from `misgan.modules`.

    Every option is read from ``self.args`` when a stage is called, so a run
    is steered by assigning attributes on that object before invoking the
    stage. The positional ``*args`` accepted by the stage methods are ignored;
    ``**kwargs`` are forwarded to the base-class stage only.
    """

    def __init__(self, args):
        """Initialise the base class, then keep a reference to `args`."""
        super(MisGAN, self).__init__()
        self.args = args

    def preprocess(self, *args, **kwargs):
        """Run base-class preprocessing, then `preprocess.preprocess`.

        When ``self.args.input`` is truthy only that path is processed.
        Otherwise every regular file directly inside ``<os.pardir>/data`` is
        processed, with a progress bar.

        Returns whatever ``preprocess.preprocess(self.args, data, fpaths)``
        returns.
        """
        base = super(MisGAN, self)
        single_input = self.args.input

        if single_input:
            loaded, _ = base.preprocess(single_input, self.args, **kwargs)
            names = [single_input]
            # NOTE(review): this branch stores a one-element list holding a
            # 3-tuple, while the directory branch stores the base-class return
            # value unchanged. Confirm both shapes are accepted downstream.
            datasets = [[(loaded, "", None)]]
        else:
            data_dir = join(os.pardir, "data")
            names, datasets = [], []
            for entry in tqdm(os.listdir(data_dir)):
                candidate = join(data_dir, entry)
                if not isfile(candidate):
                    # Skip sub-directories and anything else that is not a file.
                    continue
                # Only the bare file name is recorded, not the joined path.
                names.append(entry)
                datasets.append(base.preprocess(candidate, self.args, **kwargs))

        return preprocess.preprocess(self.args, datasets, names)

    def train(self, *args, **kwargs):
        """Run the base-class train stage, then `train.train` on ``args.fname``."""
        base = super(MisGAN, self)
        base.train(self.args.fname, self.args, **kwargs)
        train.train(self.args.fname)

    def test(self, *args, **kwargs):
        """Run the base-class test stage, then `test.test` with the same
        ``args.model`` / ``args.fname`` pair."""
        base = super(MisGAN, self)
        base.test(self.args.model, self.args.fname, self.args, **kwargs)
        test.test(self.args.model, self.args.fname)

    def impute(self, *args, **kwargs):
        """Run the base-class impute stage on ``<os.pardir>/<args.fname>`` and
        pass the result to `impute.impute`.

        The second element returned by the base class is kept on
        ``self.impute_data``.
        """
        target = join(os.pardir, self.args.fname)
        base_result = super(MisGAN, self).impute(self.args.model, target, self.args, **kwargs)
        _, self.impute_data = base_result
        impute.impute(self.args, self.args.model, self.impute_data)

    def evaluate(self, *args, **kwargs):
        """Run the base-class evaluate stage on ``<os.pardir>/<args.fname>``
        and return the result of `evaluate.evaluate`.

        The second element returned by the base class is kept on
        ``self.eval_data``.
        """
        target = join(os.pardir, self.args.fname)
        base_result = super(MisGAN, self).evaluate(self.args.model, target, self.args, **kwargs)
        _, self.eval_data = base_result
        # NOTE(review): hard-coded, dataset-specific file name. Nothing in this
        # class reads `self.model` (the call below uses `self.args.model`);
        # confirm whether the base class depends on it before removing.
        self.model = 'wdbc_imputer.pth'
        return evaluate.evaluate(self.args, self.args.model, self.eval_data)

    def load_model(self, *args, **kwargs):
        """Intentionally a no-op; returns None."""

    def save_model(self, *args, **kwargs):
        """Intentionally a no-op; returns None."""

### Define MisGAN Default Arguments

In [3]:

class Args():
    """Mutable bag of run options consumed by `MisGAN`.

    Every option defaults to ``False`` ("not set"), so ``Args()`` behaves
    exactly as before. Options may also be supplied as keyword arguments,
    e.g. ``Args(split=0.8)``, instead of being assigned afterwards.

    Options:
        fname: file / dataset name handed to the train, test, impute and
            evaluate stages.
        model: model name handed to the test, impute and evaluate stages.
        input: when truthy, the single path that preprocessing works on.
        split: set to 0.8 before preprocessing in this notebook; presumably
            the train fraction (TODO confirm in misgan.modules).
        ratio, ims: not read anywhere in this notebook; presumably consumed
            inside misgan.modules (TODO confirm).

    Raises:
        TypeError: if an unknown option name is passed, so that a misspelt
            keyword is not silently ignored.
    """

    # Names of all supported options, in their original declaration order.
    OPTIONS = ("fname", "model", "ratio", "split", "ims", "input")

    def __init__(self, **overrides):
        unknown = sorted(set(overrides) - set(self.OPTIONS))
        if unknown:
            raise TypeError("unknown option(s): {0}".format(", ".join(unknown)))
        for option in self.OPTIONS:
            setattr(self, option, overrides.get(option, False))

# Start from the all-False defaults; later cells set options on `misgan.args`.
args = Args()
# MisGAN stores this same object as `misgan.args` (no copy is made).
misgan = MisGAN(args)

### Preprocessing

In [4]:
# presumably the fraction of rows kept for training — TODO confirm in misgan.modules.preprocess
misgan.args.split=0.8
# `args.input` is still False here, so MisGAN.preprocess walks every regular
# file in <os.pardir>/data rather than a single input file.
misgan.preprocess()

# List the contents of ./data (relative to the working directory).
for f in os.listdir('data'):
    print(f)

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.26it/s]


letter-recognition.csv_test.data_loader
letter-recognition.csv_train.data_loader
spambase.csv_test.data_loader
spambase.csv_train.data_loader
spambase2.csv_test.data_loader
spambase2.csv_train.data_loader
wdbc.csv_.data_loader
wdbc.csv_test.data_loader
wdbc.csv_train.data_loader


### Training

In [5]:
misgan.args.fname = "wdbc.csv_train"

# Skip training when a critic checkpoint for this dataset already exists.
# The path is built from `fname` so the guard always refers to the dataset
# that would actually be trained.
critic_checkpoint = "checkpoint/{0}_data_critic.pth".format(misgan.args.fname)

if os.path.exists(critic_checkpoint):
    # This branch used to print `f`, a loop variable left over from an earlier
    # cell, which produced an unrelated file name (or a NameError if that cell
    # had not been run).
    print("Found {0}; skipping training".format(critic_checkpoint))
else:
    misgan.train()

# List the contents of ./checkpoint.
for f in os.listdir('checkpoint'):
    print(f)

wdbc.csv_train.data_loader
letter-recognition.csv_train_data_critic.pth
letter-recognition.csv_train_data_gen.pth
letter-recognition.csv_train_imputer.pth
letter-recognition.csv_train_impute_critic.pth
letter-recognition.csv_train_mask_critic.pth
letter-recognition.csv_train_mask_gen.pth
spambase.csv_test_data_critic.pth
spambase.csv_test_data_gen.pth
spambase.csv_test_imputer.pth
spambase.csv_test_impute_critic.pth
spambase.csv_test_mask_critic.pth
spambase.csv_test_mask_gen.pth
wdbc.csv_train_data_critic.pth
wdbc.csv_train_data_gen.pth
wdbc.csv_train_imputer.pth
wdbc.csv_train_impute_critic.pth
wdbc.csv_train_mask_critic.pth
wdbc.csv_train_mask_gen.pth


### Impute

In [6]:
# MisGAN.impute joins os.pardir with `fname`, so this refers to <os.pardir>/data/wdbc.csv.
misgan.args.fname = "data/wdbc.csv"
# Same name that was used as `fname` in the training cell.
misgan.args.model = "wdbc.csv_train"
misgan.impute()

# List the contents of ./result after imputation.
for f in os.listdir('result'):
    print(f)

  cols = [int(x % y) for x, y in zip(indices, rows)]
100%|████████████████████████████████████████████████████████████████████████████████| 941/941 [01:27<00:00, 10.78it/s]


CSV Saved
29_Apr_2019_02_00_24_impute_data.csv


### Evaluation

In [7]:
# MisGAN.evaluate joins os.pardir with `fname`, so this refers to <os.pardir>/data/wdbc.csv.
misgan.args.fname = "data/wdbc.csv"
# Same name that was used as `fname` in the training cell.
misgan.args.model = "wdbc.csv_train"
# `rmse` is whatever evaluate.evaluate returns; it is printed as the RMSE below.
rmse = misgan.evaluate()
print("RMSE = {0}".format(rmse))

100%|████████████████████████████████████████████████████████████████████████████████| 941/941 [02:58<00:00,  4.38it/s]


Imputed data length = 18207
CSV Saved
Evaluateion RMSE: 236.03297406111258
RMSE = 236.03297406111258
