-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add poker hand dataset * set split=True for goemotions datatset * add sarcos dataset * fix kaggle mixin + minor fixes * add mushroom_edibility dataset + fix imports Co-authored-by: Piero Molino <w4nderlust@gmail.com>
- Loading branch information
1 parent
45dc644
commit 6e14a6e
Showing
13 changed files
with
229 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2021 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import os | ||
import pandas as pd | ||
|
||
from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION | ||
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin | ||
from ludwig.datasets.mixins.load import CSVLoadMixin | ||
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin | ||
|
||
def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False): | ||
dataset = MushroomEdibility(cache_dir=cache_dir) | ||
return dataset.load(split=split) | ||
|
||
class MushroomEdibility(UncompressedFileDownloadMixin, MultifileJoinProcessMixin, | ||
CSVLoadMixin, BaseDataset): | ||
""" | ||
The Mushroom Edibility dataset | ||
Additional Details: | ||
http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.names | ||
""" | ||
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION): | ||
super().__init__(dataset_name="mushroom_edibility", cache_dir=cache_dir) | ||
|
||
def process_downloaded_dataset(self): | ||
super().process_downloaded_dataset(header=None) | ||
processed_df = pd.read_csv(os.path.join(self.processed_dataset_path, | ||
self.csv_filename)) | ||
columns = [ | ||
"class", "cap-shape", "cap-surface", "cap-color", "bruises?", "odor", | ||
"gill-attachment", "gill-spacing", "gill-size", "gill-color", | ||
"stalk-shape", "stalk-root", "stalk-surface-above-ring", | ||
"stalk-surface-below-ring", "stalk-color-above-ring", | ||
"stalk-color-below-ring", "veil-type", "veil-color", "ring-number", | ||
"ring-type", "spore-print-color", "population", "habitat", "split" | ||
] | ||
processed_df.columns = columns | ||
processed_df.to_csv( | ||
os.path.join(self.processed_dataset_path, self.csv_filename), | ||
index=False | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
version: 1.0 | ||
download_urls: | ||
- "http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data" | ||
split_filenames: | ||
train_file: agaricus-lepiota.data | ||
download_file_type: data | ||
csv_filename: mushroom_edibility.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2021 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import os | ||
import pandas as pd | ||
|
||
from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION | ||
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin | ||
from ludwig.datasets.mixins.load import CSVLoadMixin | ||
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin | ||
|
||
|
||
def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True): | ||
dataset = PokerHand(cache_dir=cache_dir) | ||
return dataset.load(split=split) | ||
|
||
class PokerHand(UncompressedFileDownloadMixin, MultifileJoinProcessMixin, | ||
CSVLoadMixin, BaseDataset): | ||
""" | ||
The Poker Hand dataset | ||
Additional Details: | ||
http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand.names | ||
""" | ||
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION): | ||
super().__init__(dataset_name="poker_hand", cache_dir=cache_dir) | ||
|
||
def process_downloaded_dataset(self): | ||
super().process_downloaded_dataset(header=None) | ||
processed_df = pd.read_csv(os.path.join(self.processed_dataset_path, | ||
self.csv_filename)) | ||
columns = [ | ||
"S1", "C1", "S2", "C2", "S3", "C3", "S4", "C4", "S5", "C5", "hand", "split" | ||
] | ||
processed_df.columns = columns | ||
processed_df.to_csv( | ||
os.path.join(self.processed_dataset_path, self.csv_filename), | ||
index=False | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
version: 1.0 | ||
download_urls: | ||
- "http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand-training-true.data" | ||
- "http://archive.ics.uci.edu/ml/machine-learning-databases/poker/poker-hand-testing.data" | ||
split_filenames: | ||
train_file: poker-hand-training-true.data | ||
test_file: poker-hand-testing.data | ||
download_file_type: data | ||
csv_filename: poker_hand.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#! /usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright (c) 2021 Uber Technologies, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import os | ||
from scipy.io import loadmat | ||
import pandas as pd | ||
|
||
from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION | ||
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin | ||
from ludwig.datasets.mixins.load import CSVLoadMixin | ||
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin | ||
|
||
def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True): | ||
dataset = Sarcos(cache_dir=cache_dir) | ||
return dataset.load(split=split) | ||
|
||
class Sarcos(UncompressedFileDownloadMixin, MultifileJoinProcessMixin, | ||
CSVLoadMixin, BaseDataset): | ||
""" | ||
The Sarcos dataset | ||
Details: | ||
The data relates to an inverse dynamics problem for a seven | ||
degrees-of-freedom SARCOS anthropomorphic robot arm. The | ||
task is to map from a 21-dimensional input space (7 joint | ||
positions, 7 joint velocities, 7 joint accelerations) to the | ||
corresponding 7 joint torques. There are 44,484 training | ||
examples and 4,449 test examples. The first 21 columns are | ||
the input variables, and the 22nd column is used as the target | ||
variable. | ||
Dataset source: | ||
Locally Weighted Projection RegressionL: An O(n) Algorithm for | ||
Incremental Real Time Learning in High Dimensional Space, | ||
S. Vijayakumar and S. Schaal, Proc ICML 2000. | ||
http://www.gaussianprocess.org/gpml/data/ | ||
""" | ||
def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION): | ||
super().__init__(dataset_name="sarcos", cache_dir=cache_dir) | ||
|
||
def read_file(self, filetype, filename, header=0): | ||
mat = loadmat(os.path.join(self.raw_dataset_path, filename)) | ||
file_df = pd.DataFrame(mat[filename.split('.')[0]]) | ||
return file_df | ||
|
||
def process_downloaded_dataset(self): | ||
super().process_downloaded_dataset() | ||
processed_df = pd.read_csv(os.path.join(self.processed_dataset_path, | ||
self.csv_filename)) | ||
columns = [] | ||
columns += [f'position_{i}' for i in range(1, 8)] | ||
columns += [f'velocity_{i}' for i in range(1, 8)] | ||
columns += [f'acceleration_{i}' for i in range(1, 8)] | ||
columns += [f'torque_{i}' for i in range(1, 8)] | ||
columns += ['split'] | ||
|
||
processed_df.columns = columns | ||
processed_df.to_csv( | ||
os.path.join(self.processed_dataset_path, self.csv_filename), | ||
index=False | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
version: 1.0 | ||
download_urls: | ||
- "http://www.gaussianprocess.org/gpml/data/sarcos_inv.mat" | ||
- "http://www.gaussianprocess.org/gpml/data/sarcos_inv_test.mat" | ||
split_filenames: | ||
train_file: sarcos_inv.mat | ||
test_file: sarcos_inv_test.mat | ||
download_file_type: mat | ||
csv_filename: sarcos.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters