# -*- coding: utf-8 -*-
"""Reuters topic classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ..utils.data_utils import get_file
from ..preprocessing.sequence import _remove_long_seq
import numpy as np
import json
import warnings


def load_data(path='reuters.npz', num_words=None, skip_top=0,
              maxlen=None, test_split=0.2, seed=113,
              start_char=1, oov_char=2, index_from=3, **kwargs):
    """Loads the Reuters newswire classification dataset.

    # Arguments
        path: where to cache the data (relative to `~/.keras/datasets`).
        num_words: max number of words to include. Words are ranked
            by how often they occur (in the training set) and only
            the most frequent words are kept.
        skip_top: skip the top N most frequently occurring words
            (which may not be informative).
        maxlen: sequences longer than this will be filtered out.
        test_split: fraction of the dataset to be used as test data.
        seed: random seed for sample shuffling.
        start_char: the start of a sequence will be marked with this character.
            Set to 1 because 0 is usually the padding character.
        oov_char: words that were cut out because of the `num_words`
            or `skip_top` limit will be replaced with this character.
        index_from: index actual words with this index and higher.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

    Note that the 'out of vocabulary' character is only used for
    words that were present in the training set but are not included
    because they're not making the `num_words` cut here.
    Words that were not seen in the training set but are in the test set
    have simply been skipped.
    """
    # Legacy support
    if 'nb_words' in kwargs:
        warnings.warn('The `nb_words` argument in `load_data` '
                      'has been renamed `num_words`.')
        num_words = kwargs.pop('nb_words')
    if kwargs:
        raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
    path = get_file(path,
                    origin='https://s3.amazonaws.com/text-datasets/reuters.npz',
                    file_hash='87aedbeb0cb229e378797a632c1997b6')
    with np.load(path, allow_pickle=True) as f:
        xs, labels = f['x'], f['y']
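
    # Shuffle samples and labels in unison; the fixed seed keeps the
    # resulting train/test split reproducible across calls.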
    rng = np.random.RandomState(seed)
    indices = np.arange(len(xs))
    rng.shuffle(indices)
    xs = xs[indices]
    labels = labels[indices]
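
    # Shift word indices up by `index_from` so the lowest indices stay
    # reserved for special markers, optionally prepending the start
    # marker to every sequence.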
    if start_char is not None:
        xs = [[start_char] + [w + index_from for w in x] for x in xs]
    elif index_from:
        xs = [[w + index_from for w in x] for x in xs]

    if maxlen:
        xs, labels = _remove_long_seq(maxlen, xs, labels)

    if not num_words:
        num_words = max([max(x) for x in xs])
    # by convention, use 2 as OOV word
    # reserve 'index_from' (=3 by default) characters:
    # 0 (padding), 1 (start), 2 (OOV)
    if oov_char is not None:
        xs = [[w if skip_top <= w < num_words else oov_char for w in x] for x in xs]
    else:
        xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
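
    # Split the shuffled data positionally: the first (1 - test_split)
    # fraction becomes the training set, the remainder the test set.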
    idx = int(len(xs) * (1 - test_split))
    x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
    x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])

    return (x_train, y_train), (x_test, y_test)


def get_word_index(path='reuters_word_index.json'):
    """Retrieves the dictionary mapping words to word indices.

    # Arguments
        path: where to cache the data (relative to `~/.keras/datasets`).

    # Returns
        The word index dictionary.
    """
    path = get_file(
        path,
        origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json',
        file_hash='4d44cc38712099c9e383dc6e5f11a921')
    with open(path) as f:
        return json.load(f)
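

if __name__ == '__main__':
    # Minimal usage sketch (an illustrative addition, not part of the
    # upstream module). Because of the relative imports above, run it
    # with package context, e.g. `python -m keras.datasets.reuters`.
    (x_train, y_train), (x_test, y_test) = load_data(num_words=10000)
    print('train samples:', len(x_train), 'test samples:', len(x_test))

    # Decode the first training sample back to text. `get_word_index`
    # maps words to their raw ranks, so add the default `index_from`
    # offset (3) to line up with the shifted indices in the data; the
    # marker names below are placeholders chosen for this demo.
    word_index = get_word_index()
    index_word = {rank + 3: word for word, rank in word_index.items()}
    index_word.update({0: '<pad>', 1: '<start>', 2: '<oov>'})
    print(' '.join(index_word.get(i, '?') for i in x_train[0]))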