Add TabNet Datasets (#1153)
* add poker hand dataset

* set split=True for goemotions dataset

* add sarcos dataset

* fix kaggle mixin + minor fixes

* add mushroom_edibility dataset + fix imports

Co-authored-by: Piero Molino <w4nderlust@gmail.com>
kanishk16 and w4nderlust committed Apr 23, 2021
1 parent 45dc644 commit 6e14a6e
Showing 13 changed files with 229 additions and 22 deletions.
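The new loaders follow the existing ludwig.datasets pattern: each module exposes a load() function that downloads, caches, and processes its dataset and returns pandas DataFrames. A minimal usage sketch — the (train, test, validation) return order under split=True is an assumption based on the load-mixin convention, not something shown in this diff:

from ludwig.datasets import goemotions, mushroom_edibility

# goemotions now defaults to split=True; the (train, test, validation)
# order of the returned frames is assumed here.
train_df, test_df, val_df = goemotions.load()

# mushroom_edibility defaults to split=False and returns a single
# DataFrame that still carries the 'split' column.
mushroom_df = mushroom_edibility.load()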
8 changes: 4 additions & 4 deletions ludwig/datasets/ames_housing/__init__.py
@@ -22,11 +22,11 @@
 from ludwig.datasets.mixins.process import MultifileJoinProcessMixin


-def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False, kaggle_username=None, kaggle_api_key=None):
+def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False, kaggle_username=None, kaggle_key=None):
     dataset = AmesHousing(
         cache_dir=cache_dir,
         kaggle_username=kaggle_username,
-        kaggle_api_key=kaggle_api_key
+        kaggle_key=kaggle_key
     )
     return dataset.load(split=split)

@@ -42,8 +42,8 @@ class AmesHousing(CSVLoadMixin, MultifileJoinProcessMixin, KaggleDownloadMixin,
     def __init__(self,
                  cache_dir=DEFAULT_CACHE_LOCATION,
                  kaggle_username=None,
-                 kaggle_api_key=None):
+                 kaggle_key=None):
         self.kaggle_username = kaggle_username
-        self.kaggle_api_key = kaggle_api_key
+        self.kaggle_key = kaggle_key
         self.is_kaggle_competition = True
         super().__init__(dataset_name='ames_housing', cache_dir=cache_dir)
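With the rename from kaggle_api_key to kaggle_key, Kaggle-backed loaders are called as sketched below (placeholder credential values, for illustration only):

from ludwig.datasets import ames_housing

df = ames_housing.load(
    kaggle_username="your_username",  # placeholder, not a real account
    kaggle_key="your_api_token",      # placeholder, not a real token
)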
3 changes: 0 additions & 3 deletions ludwig/datasets/dbpedia/__init__.py
@@ -46,6 +46,3 @@ def process_downloaded_dataset(self):
             os.path.join(self.processed_dataset_path, self.csv_filename),
             index=False
         )
-
-
-
8 changes: 4 additions & 4 deletions ludwig/datasets/goemotions/__init__.py
@@ -17,10 +17,10 @@
 from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
 from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin
 from ludwig.datasets.mixins.load import CSVLoadMixin
-from ludwig.datasets.mixins.process import *
+from ludwig.datasets.mixins.process import MultifileJoinProcessMixin


-def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False):
+def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True):
     dataset = GoEmotions(cache_dir=cache_dir)
     return dataset.load(split=split)

@@ -37,9 +37,9 @@ class GoEmotions(UncompressedFileDownloadMixin, MultifileJoinProcessMixin,
     def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
         super().__init__(dataset_name="goemotions", cache_dir=cache_dir)

-    def read_file(self, filetype, filename):
+    def read_file(self, filetype, filename, header=None):
         file_df = pd.read_table(os.path.join(self.raw_dataset_path, filename),
-                                header=None)
+                                header=header)
         return file_df

     def process_downloaded_dataset(self):
5 changes: 2 additions & 3 deletions ludwig/datasets/mixins/kaggle.py
@@ -19,7 +19,7 @@ class KaggleDownloadMixin:
     raw_temp_path: str
     name: str
     kaggle_username: str
-    kaggle_api_key: str
+    kaggle_key: str
     is_kaggle_competition: bool

     def download_raw_dataset(self):
@@ -29,7 +29,7 @@ def download_raw_dataset(self):
         kaggle.json file we lookup the passed in username and the api key and
         perform authentication.
         """
-        with self.update_env(KAGGLE_USERNAME=self.kaggle_username, KAGGLE_API_KEY=self.kaggle_api_key):
+        with self.update_env(KAGGLE_USERNAME=self.kaggle_username, KAGGLE_KEY=self.kaggle_key):
             # Call authenticate explicitly to pick up new credentials if necessary
             api = create_kaggle_client()
             api.authenticate()
@@ -64,4 +64,3 @@ def competition_name(self):
     @property
     def archive_filename(self):
         return self.config["archive_filename"]
-
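The rename matters because the official kaggle client reads its credentials from the KAGGLE_USERNAME and KAGGLE_KEY environment variables; KAGGLE_API_KEY is not a name it recognizes. update_env itself is not shown in this diff; a minimal sketch of the context-manager behavior the with-block above relies on (an assumption, not Ludwig's actual implementation):

import os
from contextlib import contextmanager

@contextmanager
def update_env(**variables):
    # Temporarily overlay the given variables onto os.environ and
    # restore the previous values on exit.
    previous = {name: os.environ.get(name) for name in variables}
    os.environ.update({n: v for n, v in variables.items() if v is not None})
    try:
        yield
    finally:
        for name, value in previous.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value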
1 change: 0 additions & 1 deletion ludwig/datasets/mixins/load.py
@@ -86,4 +86,3 @@ def load_processed_dataset(self, split) -> Union[pd.DataFrame,
     @property
     def parquet_filename(self):
         return self.config["parquet_filename"]
-
6 changes: 3 additions & 3 deletions ludwig/datasets/mixins/process.py
@@ -45,8 +45,8 @@ def read_file(self, filetype, filename, header=0):
                 os.path.join(self.raw_dataset_path, filename), lines=True)
         elif filetype == 'tsv':
             file_df = pd.read_table(
-                os.path.join(self.raw_dataset_path, filename))
-        elif filetype == 'csv':
+                os.path.join(self.raw_dataset_path, filename))
+        elif filetype == 'csv' or filetype == 'data':
             file_df = pd.read_csv(
                 os.path.join(self.raw_dataset_path, filename), header=header)
         else:
@@ -56,7 +56,7 @@ def read_file(self, filetype, filename, header=0):
     def process_downloaded_dataset(self, header=0):
         """Processes dataset
-        :param header: indicates whether raw data files contain headers
+        :param header: indicates whether raw data files contain headers
         """
         downloaded_files = self.download_filenames
         filetype = self.download_file_type
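The new 'data' branch means UCI-style .data files are handled exactly like CSVs. As a standalone illustration (not part of the diff), the mushroom file added below is comma-separated with no header row:

import pandas as pd

# header=None because .data files from the UCI repository typically
# ship without a header row; columns come back as 0, 1, 2, ...
df = pd.read_csv("agaricus-lepiota.data", header=None)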
58 changes: 58 additions & 0 deletions ludwig/datasets/mushroom_edibility/__init__.py
@@ -0,0 +1,58 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2021 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin

def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False):
    dataset = MushroomEdibility(cache_dir=cache_dir)
    return dataset.load(split=split)

class MushroomEdibility(UncompressedFileDownloadMixin, MultifileJoinProcessMixin,
                        CSVLoadMixin, BaseDataset):
    """
    The Mushroom Edibility dataset
    Additional Details:
    http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.names
    """
    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="mushroom_edibility", cache_dir=cache_dir)

    def process_downloaded_dataset(self):
        super().process_downloaded_dataset(header=None)
        processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
                                                self.csv_filename))
        columns = [
            "class", "cap-shape", "cap-surface", "cap-color", "bruises?", "odor",
            "gill-attachment", "gill-spacing", "gill-size", "gill-color",
            "stalk-shape", "stalk-root", "stalk-surface-above-ring",
            "stalk-surface-below-ring", "stalk-color-above-ring",
            "stalk-color-below-ring", "veil-type", "veil-color", "ring-number",
            "ring-type", "spore-print-color", "population", "habitat", "split"
        ]
        processed_df.columns = columns
        processed_df.to_csv(
            os.path.join(self.processed_dataset_path, self.csv_filename),
            index=False
        )

7 changes: 7 additions & 0 deletions ludwig/datasets/mushroom_edibility/config.yaml
@@ -0,0 +1,7 @@
version: 1.0
download_urls:
  - "http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"
split_filenames:
  train_file: agaricus-lepiota.data
download_file_type: data
csv_filename: mushroom_edibility.csv
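A usage sketch for the new loader (column names come from process_downloaded_dataset above; the 'e'/'p' label values follow the UCI documentation):

from ludwig.datasets import mushroom_edibility

df = mushroom_edibility.load()
# 'class' is the edibility target ('e' = edible, 'p' = poisonous) and
# 'split' is appended by MultifileJoinProcessMixin.
print(df["class"].value_counts())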
54 changes: 54 additions & 0 deletions ludwig/datasets/poker_hand/__init__.py
@@ -0,0 +1,54 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2021 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin


def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True):
    dataset = PokerHand(cache_dir=cache_dir)
    return dataset.load(split=split)

class PokerHand(UncompressedFileDownloadMixin, MultifileJoinProcessMixin,
                CSVLoadMixin, BaseDataset):
    """
    The Poker Hand dataset
    Additional Details:
    http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand.names
    """
    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="poker_hand", cache_dir=cache_dir)

    def process_downloaded_dataset(self):
        super().process_downloaded_dataset(header=None)
        processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
                                                self.csv_filename))
        columns = [
            "S1", "C1", "S2", "C2", "S3", "C3", "S4", "C4", "S5", "C5", "hand", "split"
        ]
        processed_df.columns = columns
        processed_df.to_csv(
            os.path.join(self.processed_dataset_path, self.csv_filename),
            index=False
        )

9 changes: 9 additions & 0 deletions ludwig/datasets/poker_hand/config.yaml
@@ -0,0 +1,9 @@
version: 1.0
download_urls:
  - "http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand-training-true.data"
  - "http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand-testing.data"
split_filenames:
  train_file: poker-hand-training-true.data
  test_file: poker-hand-testing.data
download_file_type: data
csv_filename: poker_hand.csv
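A corresponding usage sketch (the (train, test, validation) return order under the split=True default is assumed, as in the earlier example):

from ludwig.datasets import poker_hand

train_df, test_df, _ = poker_hand.load()
# Each row encodes five cards as suit/rank pairs S1/C1 .. S5/C5,
# with the poker hand class in 'hand'.
print(train_df[["S1", "C1", "hand"]].head())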
75 changes: 75 additions & 0 deletions ludwig/datasets/sarcos/__init__.py
@@ -0,0 +1,75 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2021 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from scipy.io import loadmat
import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin

def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True):
    dataset = Sarcos(cache_dir=cache_dir)
    return dataset.load(split=split)

class Sarcos(UncompressedFileDownloadMixin, MultifileJoinProcessMixin,
             CSVLoadMixin, BaseDataset):
    """
    The Sarcos dataset
    Details:
        The data relates to an inverse dynamics problem for a seven
        degrees-of-freedom SARCOS anthropomorphic robot arm. The
        task is to map from a 21-dimensional input space (7 joint
        positions, 7 joint velocities, 7 joint accelerations) to the
        corresponding 7 joint torques. There are 44,484 training
        examples and 4,449 test examples. The first 21 columns are
        the input variables, and the 22nd column is used as the target
        variable.
    Dataset source:
        Locally Weighted Projection Regression: An O(n) Algorithm for
        Incremental Real Time Learning in High Dimensional Space,
        S. Vijayakumar and S. Schaal, Proc. ICML 2000.
        http://www.gaussianprocess.org/gpml/data/
    """
    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="sarcos", cache_dir=cache_dir)

    def read_file(self, filetype, filename, header=0):
        mat = loadmat(os.path.join(self.raw_dataset_path, filename))
        file_df = pd.DataFrame(mat[filename.split('.')[0]])
        return file_df

    def process_downloaded_dataset(self):
        super().process_downloaded_dataset()
        processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
                                                self.csv_filename))
        columns = []
        columns += [f'position_{i}' for i in range(1, 8)]
        columns += [f'velocity_{i}' for i in range(1, 8)]
        columns += [f'acceleration_{i}' for i in range(1, 8)]
        columns += [f'torque_{i}' for i in range(1, 8)]
        columns += ['split']

        processed_df.columns = columns
        processed_df.to_csv(
            os.path.join(self.processed_dataset_path, self.csv_filename),
            index=False
        )

9 changes: 9 additions & 0 deletions ludwig/datasets/sarcos/config.yaml
@@ -0,0 +1,9 @@
version: 1.0
download_urls:
  - "http://www.gaussianprocess.org/gpml/data/sarcos_inv.mat"
  - "http://www.gaussianprocess.org/gpml/data/sarcos_inv_test.mat"
split_filenames:
  train_file: sarcos_inv.mat
  test_file: sarcos_inv_test.mat
download_file_type: mat
csv_filename: sarcos.csv
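Given the column names assigned in process_downloaded_dataset, separating inputs from targets is straightforward (a sketch; the (train, test, validation) return order is again assumed):

from ludwig.datasets import sarcos

train_df, test_df, _ = sarcos.load()
feature_cols = (
    [f"position_{i}" for i in range(1, 8)]
    + [f"velocity_{i}" for i in range(1, 8)]
    + [f"acceleration_{i}" for i in range(1, 8)]
)  # the 21 input dimensions
target_cols = [f"torque_{i}" for i in range(1, 8)]  # the 7 joint torques
X, y = train_df[feature_cols], train_df[target_cols]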
8 changes: 4 additions & 4 deletions ludwig/datasets/titanic/__init__.py
@@ -21,11 +21,11 @@
 from ludwig.datasets.mixins.load import CSVLoadMixin


-def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False, kaggle_username=None, kaggle_api_key=None):
+def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False, kaggle_username=None, kaggle_key=None):
     dataset = Titanic(
         cache_dir=cache_dir,
         kaggle_username=kaggle_username,
-        kaggle_api_key=kaggle_api_key
+        kaggle_key=kaggle_key
     )
     return dataset.load(split=split)

@@ -41,9 +41,9 @@ class Titanic(CSVLoadMixin, KaggleDownloadMixin, BaseDataset):
     def __init__(self,
                  cache_dir=DEFAULT_CACHE_LOCATION,
                  kaggle_username=None,
-                 kaggle_api_key=None):
+                 kaggle_key=None):
         self.kaggle_username = kaggle_username
-        self.kaggle_api_key = kaggle_api_key
+        self.kaggle_key = kaggle_key
         self.is_kaggle_competition = True
         super().__init__(dataset_name='titanic', cache_dir=cache_dir)
