Additional Text Classification Datasets (#1121)
ANarayan committed Mar 17, 2021
1 parent 6a09fdd commit ed5ae80
Showing 17 changed files with 461 additions and 8 deletions.
53 changes: 53 additions & 0 deletions ludwig/datasets/amazon_review_polarity/__init__.py
@@ -0,0 +1,53 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import TarDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin


def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True):
    dataset = AmazonPolarity(cache_dir=cache_dir)
    return dataset.load(split=split)


class AmazonPolarity(TarDownloadMixin, MultifileJoinProcessMixin,
                     CSVLoadMixin, BaseDataset):
    """The Amazon Reviews Polarity dataset.

    Details:
        34,686,770 Amazon reviews from 6,643,669 users on 2,441,053
        products, from the Stanford Network Analysis Project (SNAP).
        The polarity subset contains 1,800,000 training samples and
        200,000 testing samples in each sentiment class.

    Dataset source:
        Character-level Convolutional Networks for Text Classification
        Xiang Zhang et al., 2015
    """

    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="amazon_review_polarity",
                         cache_dir=cache_dir)

    def process_downloaded_dataset(self):
        # the raw CSVs ship without a header row
        super().process_downloaded_dataset(header=None)
        processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
                                                self.csv_filename))
        processed_df.columns = ['label', 'review_title', 'review_text',
                                'split']
        processed_df.to_csv(
            os.path.join(self.processed_dataset_path, self.csv_filename),
            index=False
        )



8 changes: 8 additions & 0 deletions ludwig/datasets/amazon_review_polarity/config.yaml
@@ -0,0 +1,8 @@
version: 1.0
download_urls:
- "https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz"
split_filenames:
train_file: train.csv
test_file: test.csv
download_file_type: csv
csv_filename: amazon_review_polarity.csv
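
A minimal usage sketch for the new loader, assuming the standard ludwig.datasets split encoding (0 = train, 2 = test; only the train value is visible in this diff, so treat the test value as an assumption):

from ludwig.datasets.amazon_review_polarity import load

# single DataFrame with label, review_title, review_text, split columns
df = load(split=False)

# recover the original partitions from the split column
train_df = df[df['split'] == 0]   # set by MultifileJoinProcessMixin
test_df = df[df['split'] == 2]    # assumed test encoding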
53 changes: 53 additions & 0 deletions ludwig/datasets/amazon_reviews/__init__.py
@@ -0,0 +1,53 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import TarDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin


def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True):
    dataset = AmazonReviews(cache_dir=cache_dir)
    return dataset.load(split=split)


class AmazonReviews(TarDownloadMixin, MultifileJoinProcessMixin,
                    CSVLoadMixin, BaseDataset):
    """The Amazon Reviews (full star-rating) dataset.

    Details:
        34,686,770 Amazon reviews from 6,643,669 users on 2,441,053
        products, from the Stanford Network Analysis Project (SNAP).
        This dataset contains 600,000 training samples and 130,000
        testing samples in each of the five star-rating classes.

    Dataset source:
        Character-level Convolutional Networks for Text Classification
        Xiang Zhang et al., 2015
    """

    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="amazon_reviews", cache_dir=cache_dir)

    def process_downloaded_dataset(self):
        # the raw CSVs ship without a header row
        super().process_downloaded_dataset(header=None)
        processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
                                                self.csv_filename))
        processed_df.columns = ['label', 'review_title', 'review_text',
                                'split']
        processed_df.to_csv(
            os.path.join(self.processed_dataset_path, self.csv_filename),
            index=False
        )



8 changes: 8 additions & 0 deletions ludwig/datasets/amazon_reviews/config.yaml
@@ -0,0 +1,8 @@
version: 1.0
download_urls:
- "https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz"
split_filenames:
train_file: train.csv
test_file: test.csv
download_file_type: csv
csv_filename: amazon_reviews.csv
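
To show where these loaders plug in, a hedged end-to-end sketch with the Ludwig API; the config dict is illustrative (encoder and training settings omitted, not taken from this commit), and the feature names mirror the columns set in process_downloaded_dataset above:

from ludwig.api import LudwigModel
from ludwig.datasets.amazon_reviews import load

df = load(split=False)

# illustrative config: text input, category output
config = {
    'input_features': [{'name': 'review_text', 'type': 'text'}],
    'output_features': [{'name': 'label', 'type': 'category'}],
}

model = LudwigModel(config)
train_stats = model.train(dataset=df)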
51 changes: 51 additions & 0 deletions ludwig/datasets/dbpedia/__init__.py
@@ -0,0 +1,51 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import TarDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import MultifileJoinProcessMixin


def load(cache_dir=DEFAULT_CACHE_LOCATION, split=True):
    dataset = DBPedia(cache_dir=cache_dir)
    return dataset.load(split=split)


class DBPedia(TarDownloadMixin, MultifileJoinProcessMixin,
              CSVLoadMixin, BaseDataset):
    """The DBPedia Ontology dataset.

    Details:
        40,000 training samples and 5,000 testing samples from each of
        14 non-overlapping ontology classes of DBpedia 2014.

    Dataset source:
        Character-level Convolutional Networks for Text Classification
        Xiang Zhang et al., 2015
    """

    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="dbpedia", cache_dir=cache_dir)

    def process_downloaded_dataset(self):
        # the raw CSVs ship without a header row
        super().process_downloaded_dataset(header=None)
        processed_df = pd.read_csv(os.path.join(self.processed_dataset_path,
                                                self.csv_filename))
        processed_df.columns = ['label', 'title', 'content', 'split']
        processed_df.to_csv(
            os.path.join(self.processed_dataset_path, self.csv_filename),
            index=False
        )



8 changes: 8 additions & 0 deletions ludwig/datasets/dbpedia/config.yaml
@@ -0,0 +1,8 @@
version: 1.0
download_urls:
- "https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz"
split_filenames:
train_file: train.csv
test_file: test.csv
download_file_type: csv
csv_filename: dbpedia.csv
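
With split=True the loader defers to CSVLoadMixin to return separate frames; the sketch below assumes a (train, test, validation) return order, which is not visible in this diff and should be checked against the mixin:

from ludwig.datasets.dbpedia import load

# assumed order: (train, test, validation) -- verify in CSVLoadMixin
train_df, test_df, val_df = load(split=True)
print(len(train_df), len(test_df), len(val_df))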
42 changes: 42 additions & 0 deletions ludwig/datasets/ethos_binary/__init__.py
@@ -0,0 +1,42 @@
#! /usr/bin/env python
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os

import pandas as pd

from ludwig.datasets.base_dataset import BaseDataset, DEFAULT_CACHE_LOCATION
from ludwig.datasets.mixins.download import UncompressedFileDownloadMixin
from ludwig.datasets.mixins.load import CSVLoadMixin
from ludwig.datasets.mixins.process import IdentityProcessMixin


def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False):
    dataset = EthosBinary(cache_dir=cache_dir)
    return dataset.load(split=split)


class EthosBinary(UncompressedFileDownloadMixin, IdentityProcessMixin,
                  CSVLoadMixin, BaseDataset):
    """The ETHOS Hate Speech dataset (binary variant).

    Source paper:
        ETHOS: an Online Hate Speech Detection Dataset
        Ioannis Mollas, Zoe Chrysopoulou, Stamatis Karlos and
        Grigorios Tsoumakas
    """

    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="ethos_binary", cache_dir=cache_dir)

    def load_processed_dataset(self, split):
        # the raw file is semicolon-delimited; the dataset ships as a
        # single file, so the split argument is not used here
        dataset_csv = os.path.join(self.processed_dataset_path,
                                   self.csv_filename)
        data_df = pd.read_csv(dataset_csv, sep=';')
        return data_df
5 changes: 5 additions & 0 deletions ludwig/datasets/ethos_binary/config.yaml
@@ -0,0 +1,5 @@
version: 1.0
download_urls:
- https://raw.githubusercontent.com/intelligence-csd-auth-gr/Ethos-Hate-Speech-Dataset/master/ethos/ethos_data/Ethos_Dataset_Binary.csv
download_file_type: csv
csv_filename: Ethos_Dataset_Binary.csv
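
The load_processed_dataset override above exists because the raw ETHOS file is semicolon-delimited; a self-contained sketch of the same read, with column names taken from the upstream repository (an assumption, not shown in this diff):

import pandas as pd

# ';' separator: pandas' default ',' would collapse each row into one field
df = pd.read_csv('Ethos_Dataset_Binary.csv', sep=';')
# expected columns per the upstream repo: 'comment', 'isHate'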
39 changes: 38 additions & 1 deletion ludwig/datasets/mixins/download.py
@@ -17,14 +17,14 @@
import gzip
import os
import shutil
import tarfile
import urllib.request
from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile

from tqdm import tqdm


class TqdmUpTo(tqdm):
"""Provides progress bar for `urlretrieve`.
@@ -68,6 +68,43 @@ def download_raw_dataset(self):
def download_urls(self):
return self.config["download_urls"]

class TarDownloadMixin:
    """Downloads the compressed tar archive containing the training data
    and extracts its contents."""

    config: dict
    raw_dataset_path: str
    raw_temp_path: str

    def download_raw_dataset(self):
        """Download the raw dataset, extract the tar archive and store
        its contents in the cache location.
        """
        os.makedirs(self.raw_temp_path, exist_ok=True)
        for url in self.download_urls:
            filename = url.split('/')[-1]
            with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024,
                          miniters=1, desc=filename) as t:
                urllib.request.urlretrieve(
                    url,
                    os.path.join(self.raw_temp_path, filename),
                    t.update_to
                )

            # these archives extract into a directory named after the
            # file stem, e.g. amazon_review_polarity_csv
            download_folder_name = filename.split('.')[0]
            file_path = os.path.join(self.raw_temp_path, filename)
            with tarfile.open(file_path) as tar_file:
                tar_file.extractall(path=self.raw_temp_path)

            # flatten the extracted directory so the process mixins can
            # find the split files directly in the raw dataset path
            for f in os.scandir(os.path.join(self.raw_temp_path,
                                             download_folder_name)):
                shutil.copyfile(f, os.path.join(self.raw_temp_path, f.name))

        os.rename(self.raw_temp_path, self.raw_dataset_path)

    @property
    def download_urls(self):
        return self.config['download_urls']
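
The mixin slots into the composition pattern the new dataset classes use: one base class per lifecycle stage (download, process, load), with BaseDataset last in the MRO to drive the pipeline. A schematic with a hypothetical dataset name:

# hypothetical dataset assembled from the same building blocks
class MyTarDataset(TarDownloadMixin, MultifileJoinProcessMixin,
                   CSVLoadMixin, BaseDataset):
    def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION):
        super().__init__(dataset_name="my_tar_dataset", cache_dir=cache_dir)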

class GZipDownloadMixin:
"""Downloads the gzip archive file containing the training data and extracts the contents."""
12 changes: 8 additions & 4 deletions ludwig/datasets/mixins/process.py
@@ -36,7 +36,7 @@ class MultifileJoinProcessMixin:
    raw_dataset_path: str
    processed_dataset_path: str

-    def read_file(self, filetype, filename):
+    def read_file(self, filetype, filename, header):
        if filetype == 'json':
            file_df = pd.read_json(
                os.path.join(self.raw_dataset_path, filename))
@@ -48,17 +48,21 @@ def read_file(self, filetype, filename):
                os.path.join(self.raw_dataset_path, filename))
        elif filetype == 'csv':
            file_df = pd.read_csv(
-                os.path.join(self.raw_dataset_path, filename))
+                os.path.join(self.raw_dataset_path, filename), header=header)
        else:
            raise ValueError(f'Unsupported file type: {filetype}')
        return file_df

-    def process_downloaded_dataset(self):
+    def process_downloaded_dataset(self, header=0):
+        """Processes the dataset.
+        :param header: pandas-style header arg (0 = header row, None = none)
+        """
        downloaded_files = self.download_filenames
        filetype = self.download_file_type
        all_files = []
        for split_name, filename in downloaded_files.items():
-            file_df = self.read_file(filetype, filename)
+            file_df = self.read_file(filetype, filename, header)
            if split_name == 'train_file':
                file_df['split'] = 0
            elif split_name == 'val_file':
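
The header plumbing matters because the new datasets ship headerless CSVs: with pandas, header=None assigns integer column names and keeps the first row as data, which is why the dataset classes rename columns afterwards. A small self-contained illustration:

import pandas as pd
from io import StringIO

raw = StringIO('2,Great phone,Works as advertised\n'
               '1,Disappointed,Stopped charging after a week\n')

# header=None: columns are named 0, 1, 2 and the first line stays data;
# the default header=0 would consume it as column names
df = pd.read_csv(raw, header=None)
df.columns = ['label', 'review_title', 'review_text']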
9 changes: 6 additions & 3 deletions ludwig/datasets/sst2/__init__.py
@@ -19,9 +19,12 @@


def load(cache_dir=DEFAULT_CACHE_LOCATION, split=False,
-         include_subtrees=False, convert_parentheses=True):
+         include_subtrees=False, convert_parentheses=True,
+         remove_duplicates=False):
    dataset = SST2(cache_dir=cache_dir, include_subtrees=include_subtrees,
-                   convert_parentheses=convert_parentheses)
+                   convert_parentheses=convert_parentheses,
+                   remove_duplicates=remove_duplicates)
    return dataset.load(split=split)


@@ -54,7 +57,7 @@ def __init__(self, cache_dir=DEFAULT_CACHE_LOCATION,
                         include_subtrees=include_subtrees,
                         discard_neutral=True,
                         convert_parentheses=convert_parentheses,
-                         remove_duplicates=False)
+                         remove_duplicates=remove_duplicates)

    def get_sentiment_label(self, id2sent, phrase_id):
        sentiment = id2sent[phrase_id]
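
A short usage sketch of the new flag; split=False mirrors the function default and avoids assuming the split return order:

from ludwig.datasets.sst2 import load

# drop duplicate phrases while loading; returns a single DataFrame
df = load(split=False, remove_duplicates=True)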