General classifier (#406)
* Add a general tabular classifier.

* Separate Tabular Preprocessing

* Modify Predict function

* Modify Tabular Preprocessor

* Add example tabular_classification

* Add tabular examples.

* Add testing.

* Add preprocessing test and remove multiprocessing

* tabular

* update

* resolve conflicts in examples

* resolve conflicts test

* update data extraction method

* add comments

* Modify tabular tests and examples

* Delete three .pt files
qingquansong authored and haifeng-jin committed Jan 3, 2019
1 parent 36909f2 commit a71e595
Showing 7 changed files with 276 additions and 341 deletions.
199 changes: 87 additions & 112 deletions autokeras/tabular/tabular_preprocessor.py
@@ -1,11 +1,10 @@
 import os
 import numpy as np
 from pandas import DataFrame
 from scipy.stats import pearsonr
-import multiprocessing as mp
 
 LEVEL_HIGH = 32
 
 
 def parallel_function(labels, first_batch_keys, task):
     if task == 'label':
         if min(labels) > first_batch_keys:
@@ -83,11 +82,10 @@ def call_parallel(tasks):
     return results
 
 
-class TabularPreprocessor():
+class TabularPreprocessor:
     def __init__(self):
         """
-        This constructor is supposed to initialize data members.
-        Use triple quotes for function documentation.
+        Initialization function for tabular preprocessor.
         """
         self.num_cat_pair = {}
 
@@ -104,41 +102,32 @@ def __init__(self):
 
         self.rest = None
         self.budget = None
-        self.datainfo = None
+        self.data_info = None
+        self.n_time = None
+        self.n_num = None
+        self.n_cat = None
 
     def remove_useless(self, x):
         self.rest = np.where(np.max(x, 0) - np.min(x, 0) != 0)[0]
         return x[:, self.rest]
 
     def process_time(self, x):
-        cols = range(self.ntime)
+        cols = range(self.n_time)
         if len(cols) > 10:
             cols = cols[:10]
         x_time = x[:, cols]
         for i in cols:
-            for j in range(i+1, len(cols)):
-                x = np.append(x, np.expand_dims(x_time[:, i]-x_time[:, j], 1), 1)
+            for j in range(i + 1, len(cols)):
+                x = np.append(x, np.expand_dims(x_time[:, i] - x_time[:, j], 1), 1)
         return x
 
-    def extract_data(self, F, ncat, nmvc):
-
-        if type(F) == np.ndarray:
-            ret = F
-            n_rows = ret.shape[0]
-            n_num_col = ret.shape[1] - ncat - nmvc
-        else:
-            n_rows = F['numerical'].shape[0]
-            n_num_col = F['numerical'].shape[1]
-
-            data_list = [F['numerical']]
-            if ncat > 0:
-                data_list.append(F['CAT'])
-            if nmvc > 0:
-                data_list.append(F['MV'])
-            ret = np.concatenate(data_list, axis=1)
-
-        n_cat_col = nmvc + ncat
+    def extract_data(self, raw_x):
+        # only get numerical variables
+        ret = np.concatenate([raw_x['TIME'], raw_x['NUM'], raw_x['CAT']], axis=1)
+        n_rows = ret.shape[0]
+        n_num_col = ret.shape[1] - self.n_cat
+
+        n_cat_col = self.n_cat
         if n_cat_col <= 0:
             return ret.astype(np.float64)
 
@@ -155,12 +144,11 @@ def extract_data(self, F, ncat, nmvc):
 
         return ret.astype(np.float64)
 
-
-    def cat_to_num(self, X, ncat, nmvc, nnum, ntime, y=None):
+    def cat_to_num(self, x, y=None):
         if y is not None:
-            mark = ntime + nnum
+            mark = self.n_time + self.n_num
 
-            for col_index in range(ntime + nnum, ntime + nnum + ncat + nmvc):
+            for col_index in range(self.n_time + self.n_num, self.n_time + self.n_num + self.n_cat):
                 if self.n_first_batch_keys[col_index] <= LEVEL_HIGH:
                     self.num_cat_pair[mark] = (col_index,)
                     mark += 1
@@ -169,20 +157,15 @@ def cat_to_num(self, X, ncat, nmvc, nnum, ntime, y=None):
                     mark += 1
 
             mark_1 = 0
-            tasks_1 = []
+            tasks = []
             for i, cat_col_index1 in enumerate(self.high_level_cat_keys):
                 for cat_col_index2 in self.high_level_cat_keys[i + 1:]:
-                    tasks_1.append((X[:, (cat_col_index1, cat_col_index2)],
-                                    [y, cat_col_index1, cat_col_index2, mark_1],
-                                    'train_cat_cat'))
+                    tasks.append((x[:, (cat_col_index1, cat_col_index2)],
+                                  [y, cat_col_index1, cat_col_index2, mark_1],
+                                  'train_cat_cat'))
                     mark_1 += 1
 
-            # pool = mp.Pool()
-            # results = [pool.apply_async(parallel_function, t) for t in tasks_1]
-            # all_results = [result.get() for result in results]
-            # pool.close()
-            # pool.join()
-            all_results = call_parallel(tasks_1)
+            all_results = call_parallel(tasks)
 
             num_cat_pair_1 = {}
             pearsonr_dict_1 = {}
@@ -195,22 +178,16 @@ def cat_to_num(self, X, ncat, nmvc, nnum, ntime, y=None):
             num_cat_pair_1 = {i + mark: num_cat_pair_1[key] for i, key in enumerate(num_cat_pair_1)}
             self.num_cat_pair.update(num_cat_pair_1)
             mark += len(pearsonr_high_1)
-            print('num_cat_pair_1:', num_cat_pair_1)
 
             mark_2 = 0
             tasks_2 = []
             for cat_col_index in self.high_level_cat_keys:
-                for num_col_index in range(ntime, ntime + nnum):
-                    tasks_2.append((X[:, (num_col_index, cat_col_index)],
+                for num_col_index in range(self.n_time, self.n_time + self.n_num):
+                    tasks_2.append((x[:, (num_col_index, cat_col_index)],
                                     [y, num_col_index, cat_col_index, mark_2],
                                     'train_num_cat'))
                     mark_2 += 1
 
-            # pool = mp.Pool()
-            # results = [pool.apply_async(parallel_function, t) for t in tasks_2]
-            # all_results = [result.get() for result in results]
-            # pool.close()
-            # pool.join()
             all_results = call_parallel(tasks_2)
 
             num_cat_pair_2 = {}
@@ -230,104 +207,95 @@ def cat_to_num(self, X, ncat, nmvc, nnum, ntime, y=None):
         for key in self.order_num_cat_pair:
             if len(self.num_cat_pair[key]) == 1:
                 (col_index,) = self.num_cat_pair[key]
-                tasks.append((X[:, col_index], self.n_first_batch_keys[col_index], 'label'))
+                tasks.append((x[:, col_index], self.n_first_batch_keys[col_index], 'label'))
             if len(self.num_cat_pair[key]) == 2:
                 (col_index, col_index) = self.num_cat_pair[key]
-                tasks.append((X[:, col_index], self.n_first_batch_keys[col_index], 'frequency'))
+                tasks.append((x[:, col_index], self.n_first_batch_keys[col_index], 'frequency'))
             if len(self.num_cat_pair[key]) == 3:
                 (cat_col_index1, cat_col_index2, mu) = self.num_cat_pair[key]
-                tasks.append((X[:, (cat_col_index1,
+                tasks.append((x[:, (cat_col_index1,
                                     cat_col_index2)], self.n_first_batch_keys[cat_col_index1], 'cat_cat'))
             elif len(self.num_cat_pair[key]) == 4:
                 (num_col_index, cat_col_index, mu, a) = self.num_cat_pair[key]
-                tasks.append((X[:, (num_col_index, cat_col_index)], self.n_first_batch_keys[cat_col_index], 'num_cat'))
-
-        # pool = mp.Pool()
-        # results = [pool.apply_async(parallel_function, t) for t in tasks]
-        # all_num = X.shape[1] - ncat - nmvc
-        # results = [X[:, :all_num]] + [result.get() for result in results]
-        #
-        # ret = np.concatenate(results, axis=1)
-        # pool.close()
-        # pool.join()
+                tasks.append((x[:, (num_col_index, cat_col_index)], self.n_first_batch_keys[cat_col_index], 'num_cat'))
 
         results = call_parallel(tasks)
-        all_num = X.shape[1] - ncat - nmvc
-        results = [X[:, :all_num]] + results
+        all_num = x.shape[1] - self.n_cat
+        results = [x[:, :all_num]] + results
         ret = np.concatenate(results, axis=1)
 
-        return ret #, ret.shape[1] - all_num, 0
+        return ret

-    def fit(self, F, y=None, time_limit=None, datainfo=None):
+    def fit(self, raw_x, y, time_limit, data_info):
         """
         This function should train the model parameters.
         Args:
-            x: A numpy.ndarray instance containing the training data.
-            y: Training label matrix of dim num_train_samples * num_labels.
-            datainfo: Meta-features of the dataset, which describe:
-                the number of four different features including:
-                time, numerical, categorical, and multi-value categorical.
-        Both inputs X and y are numpy arrays.
-        If fit is called multiple times on incremental data (train, test1, test2, etc.)
-        you should warm-start your training from the pre-trained model. Past data will
-        NOT be available for re-training.
+            raw_x: a numpy.ndarray instance containing the training data.
+            y: training label vector.
+            time_limit: remaining time budget.
+            data_info: meta-features of the dataset, a numpy.ndarray describing the
+                feature type of each column in raw_x. The feature types include:
+                'TIME' for temporal features, 'NUM' for other numerical features,
+                and 'CAT' for categorical features.
         """
         # Get Meta-Feature
         self.budget = time_limit
-        self.datainfo = datainfo
-        [self.ntime, self.nnum, self.ncat, self.nmvc] = self.datainfo['loaded_feat_types']
+        self.data_info = data_info if data_info is not None else self.extract_data_info(raw_x)
+        print('QQ: {}'.format(self.data_info))
+
+        self.n_time = sum(self.data_info == 'TIME')
+        self.n_num = sum(self.data_info == 'NUM')
+        self.n_cat = sum(self.data_info == 'CAT')
+
+        self.total_samples = raw_x.shape[0]
+
+        print('QQ1: {}'.format(self.n_time))
+        print('QQ2: {}'.format(self.n_num))
+        print('QQ3: {}'.format(self.n_cat))
+        raw_x = {'TIME': raw_x[:, self.data_info == 'TIME'],
+                 'NUM': raw_x[:, self.data_info == 'NUM'],
+                 'CAT': raw_x[:, self.data_info == 'CAT']}
 
-        for col_index in range(self.nnum + self.ntime, self.nnum + self.ntime + self.ncat + self.nmvc):
-
+        for col_index in range(self.n_num + self.n_time, self.n_num + self.n_time + self.n_cat):
             self.cat_to_int_label[col_index] = {}
 
-        x = self.extract_data(F, self.ncat, self.nmvc)
+        x = self.extract_data(raw_x)
 
         d_size = x.shape[0] * x.shape[1] / self.budget
         print('d_size', d_size)
         if d_size > 35000:
             self.feature_add_high_cat = 0
         else:
             self.feature_add_high_cat = 10
 
-        for col_index in range(self.nnum + self.ntime, self.nnum + self.ntime + self.ncat + self.nmvc):
+        # Iterate cat features
+        for col_index in range(self.n_num + self.n_time, self.n_num + self.n_time + self.n_cat):
             self.n_first_batch_keys[col_index] = len(self.cat_to_int_label[col_index])
         high_level_cat_keys_tmp = sorted(self.n_first_batch_keys, key=self.n_first_batch_keys.get, reverse=True)[
                                   :self.feature_add_high_cat]
         for i in high_level_cat_keys_tmp:
             if self.n_first_batch_keys[i] > 1e2:
                 self.high_level_cat_keys.append(i)
 
         print('hig_order_cat_pair:', self.high_level_cat_keys)
         print('n_first_batch_keys:', self.n_first_batch_keys)
 
 
         # Convert NaN to zeros
         x = np.nan_to_num(x)
 
-        if datainfo is None:
-            self.nnum = x.shape[1]
-
         # Encode high-order categorical data to numerical with frequency
-        x = self.cat_to_num(x, self.ncat, self.nmvc, self.nnum, self.ntime, y)
+        x = self.cat_to_num(x, y)
 
         print('X.shape before remove_useless', x.shape)
         x = self.process_time(x)
         x = self.remove_useless(x)
         print('X.shape after remove_useless', x.shape)
 
         return x
 
-    def encode(self, F, time_limit=None, datainfo=None):
+    def encode(self, raw_x, time_limit=None):
         """
         This function should train the model parameters.
         Args:
-            x: A numpy.ndarray instance containing the training data.
-            y: Training label matrix of dim num_train_samples * num_labels.
-            datainfo: Meta-features of the dataset, which describe:
-                the number of four different features including:
-                time, numerical, categorical, and multi-value categorical.
+            raw_x: a numpy.ndarray instance containing the training/testing data.
+            time_limit: remaining time budget.
         Both inputs X and y are numpy arrays.
         If fit is called multiple times on incremental data (train, test1, test2, etc.)
         you should warm-start your training from the pre-trained model. Past data will
@@ -341,29 +309,36 @@ def encode(self, F, time_limit=None, datainfo=None):
         else:
             self.budget = time_limit
 
-        if datainfo is None:
-            if self.datainfo is None:
-                datainfo = {'loaded_feat_types': [0] * 4}
-                self.datainfo = datainfo
-            [self.ntime, self.nnum, self.ncat, self.nmvc] = self.datainfo['loaded_feat_types']
-        else:
-            self.datainfo = datainfo
-
-
-        x = self.extract_data(F, self.ncat, self.nmvc)
+        raw_x = {'TIME': raw_x[:, self.data_info == 'TIME'],
+                 'NUM': raw_x[:, self.data_info == 'NUM'],
+                 'CAT': raw_x[:, self.data_info == 'CAT']}
+        x = self.extract_data(raw_x)
 
         # Convert NaN to zeros
         x = np.nan_to_num(x)
-        if datainfo is None:
-            self.nnum = x.shape[1]
 
         # Encode high-order categorical data to numerical with frequency
-        x = self.cat_to_num(x, self.ncat, self.nmvc, self.nnum, self.ntime)
+        x = self.cat_to_num(x)
 
         print('X.shape before remove_useless', x.shape)
         x = self.process_time(x)
         if self.rest is not None:
             x = x[:, self.rest]
         print('X.shape after remove_useless', x.shape)
         return x
 

+    def extract_data_info(self, raw_x):
+        """
+        This function extracts the data info automatically based on the type of each feature in raw_x.
+        Args:
+            raw_x: a numpy.ndarray instance containing the training data.
+        """
+        data_info = []
+        row_num, col_num = raw_x.shape
+        for col_idx in range(col_num):
+            try:
+                raw_x[:, col_idx].astype(np.float)
+                data_info.append('NUM')
+            except:
+                data_info.append('CAT')
+        return np.array(data_info)
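
For orientation, below is a minimal usage sketch of the refactored interface, based only on the signatures and docstrings in the diff above. The synthetic table, label vector, and time_limit value are illustrative assumptions; the import path matches the file changed in this commit.

import numpy as np
from autokeras.tabular.tabular_preprocessor import TabularPreprocessor

# Synthetic table: two numerical columns and one categorical column.
raw_x = np.array([[1.0, 'red', 3.0],
                  [2.5, 'blue', 7.0],
                  [0.7, 'red', 5.0]], dtype=object)
y = np.array([0, 1, 0])

preprocessor = TabularPreprocessor()

# Passing data_info=None makes fit() fall back to extract_data_info(),
# which labels a column 'NUM' when every value casts to float, else 'CAT'.
x_train = preprocessor.fit(raw_x, y, time_limit=60, data_info=None)

# encode() reuses the meta-features learned during fit() to transform new rows.
x_test = preprocessor.encode(raw_x, time_limit=60)

Letting data_info default to automatic inference keeps the public API down to a single ndarray per call, at the cost of one attempted full-column cast per feature.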
