From bcc3d40544fc606288adbaa5aa51723023e790f2 Mon Sep 17 00:00:00 2001 From: Shabab Ayub Date: Tue, 8 Feb 2022 14:31:32 -0800 Subject: [PATCH 1/3] Preproc: re-map sparse features to contiguous ids and freq thresholding [1/n] Summary: **Preproc for dlrm inspired by NVIDIA DLRM Preproc: ** https://catalog.ngc.nvidia.com/orgs/nvidia/resources/dlrm_for_pytorch/advanced (under dataset guidelines) - Re-map sparse ids to contiguous integers (`with this you can have an embedding table of size num_categories x emb_dim`) - Frequency thresholding; if an id shows up less than T times, remap it to a value of 1 (`Fit model on particular GPU`, `Capture all rarely occurring categories into one because otherwise for these categories you would overfit`) full details of benefits of this preprocessing: - https://github.com/NVIDIA/DeepLearningExamples/issues/1062#issuecomment-1015172876 Differential Revision: D33998505 fbshipit-source-id: a7a2fb6bcbfbb4ffa347cd3663f3f1a87a56b9aa --- torchrec/datasets/criteo.py | 111 +++++++++++++++++- .../scripts/contiguous_preproc_criteo.py | 75 ++++++++++++ ...rocess_criteo.py => npy_preproc_criteo.py} | 0 ...s_criteo.py => test_npy_preproc_criteo.py} | 2 +- torchrec/datasets/tests/test_criteo.py | 58 +++++++++ 5 files changed, 244 insertions(+), 2 deletions(-) create mode 100644 torchrec/datasets/scripts/contiguous_preproc_criteo.py rename torchrec/datasets/scripts/{preprocess_criteo.py => npy_preproc_criteo.py} (100%) rename torchrec/datasets/scripts/tests/{test_preprocess_criteo.py => test_npy_preproc_criteo.py} (95%) diff --git a/torchrec/datasets/criteo.py b/torchrec/datasets/criteo.py index 97abe5f80..288d537e0 100644 --- a/torchrec/datasets/criteo.py +++ b/torchrec/datasets/criteo.py @@ -33,7 +33,7 @@ ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor - +FREQUENCY_THRESHOLD = 3 INT_FEATURE_COUNT = 13 CAT_FEATURE_COUNT = 26 DEFAULT_LABEL_NAME = "label" @@ -375,6 +375,115 @@ def load_npy_range( data = np.fromfile(fin, dtype=dtype, count=num_entries) return data.reshape((num_rows, row_size)) + @staticmethod + def sparse_to_contiguous( + in_files: List[str], + output_dir: str, + frequency_threshold: int = FREQUENCY_THRESHOLD, + columns: int = CAT_FEATURE_COUNT, + path_manager_key: str = PATH_MANAGER_KEY, + output_file_suffix: str = "_contig_freq.npy", + ) -> None: + """ + Convert all sparse .npy files to have contiguous integers. Store in a separate + .npy file. All input files must be processed together because columns + can have matching IDs between files. Hence, they must be transformed + together. Also, the transformed IDs are not unique between columns. IDs + that appear less than frequency_threshold amount of times will be remapped + to have a value of 1. + + Example transformation, frequenchy_threshold of 2: + day_0_sparse.npy + | col_0 | col_1 | + ----------------- + | abc | xyz | + | iop | xyz | + + day_1_sparse.npy + | col_0 | col_1 | + ----------------- + | iop | tuv | + | lkj | xyz | + + day_0_sparse_contig.npy + | col_0 | col_1 | + ----------------- + | 1 | 2 | + | 2 | 2 | + + day_1_sparse_contig.npy + | col_0 | col_1 | + ----------------- + | 2 | 1 | + | 1 | 2 | + + Args: + in_files List[str]: Input directory of npy files. + out_dir (str): Output directory of processed npy files. + frequency_threshold: IDs occuring less than this frequency will be remapped to a value of 1. + path_manager_key (str): Path manager key used to load from different filesystems. + + Returns: + None. + """ + + # Load each .npy file of sparse features. 
Transformations are made along the columns. + # Thereby, transpose the input to ease operations. + # E.g. file_to_features = {"day_0_sparse": [array([[3,6,7],[7,9,3]]} + file_to_features: Dict[str, np.ndarray] = {} + for f in in_files: + name = os.path.basename(f).split(".")[0] + file_to_features[name] = np.load(f).transpose() + print(f"Successfully loaded file: {f}") + + # Iterate through each column in each file and map the sparse ids to contiguous ids. + for col in range(columns): + print(f"Processing column: {col}") + + # Iterate through each row in each file for the current column and determine the + # frequency of each sparse id. + sparse_to_frequency: Dict[int, int] = {} + if frequency_threshold > 1: + for f in file_to_features: + for _, sparse in enumerate(file_to_features[f][col]): + if sparse in sparse_to_frequency: + sparse_to_frequency[sparse] += 1 + else: + sparse_to_frequency[sparse] = 1 + + # Iterate through each row in each file for the current column and remap each + # sparse id to a contiguous id. The contiguous ints start at a value of 2 so that + # infrequenct IDs (determined by the frequency_threshold) can be remapped to 1. + running_sum = 2 + sparse_to_contiguous_int: Dict[int, int] = {} + + for f in file_to_features: + print(f"Processing file: {f}") + + for i, sparse in enumerate(file_to_features[f][col]): + if sparse not in sparse_to_contiguous_int: + # If the ID appears less than frequency_threshold amount of times + # remap the value to 1. + if ( + frequency_threshold > 1 + and sparse_to_frequency[sparse] < frequency_threshold + ): + sparse_to_contiguous_int[sparse] = 1 + else: + sparse_to_contiguous_int[sparse] = running_sum + running_sum += 1 + + # Re-map sparse value to contiguous in place. + file_to_features[f][col][i] = sparse_to_contiguous_int[sparse] + + path_manager = PathManagerFactory().get(path_manager_key) + for f, features in file_to_features.items(): + output_file = os.path.join(output_dir, f + output_file_suffix) + with path_manager.open(output_file, "wb") as fout: + print(f"Writing file: {output_file}") + # Transpose back the features when saving, as they were transposed when loading. + np.save(fout, features.transpose()) + class InMemoryBinaryCriteoIterDataPipe(IterableDataset): """ diff --git a/torchrec/datasets/scripts/contiguous_preproc_criteo.py b/torchrec/datasets/scripts/contiguous_preproc_criteo.py new file mode 100644 index 000000000..6133e7b1c --- /dev/null +++ b/torchrec/datasets/scripts/contiguous_preproc_criteo.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This script preprocesses the sparse feature files (binary npy) to such that +# the IDs become contiguous (with frequency thresholding applied). +# The results are saved in new binary (npy) files. + +import argparse +import os +import sys +from typing import List + +from torchrec.datasets.criteo import BinaryCriteoUtils + + +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Criteo sparse -> contiguous preprocessing script. " + ) + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Input directory containing the sparse features in numpy format (.npy). 
Files in the directory " + "should be named day_{0-23}_sparse.npy.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Output directory to store npy files.", + ) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> None: + """ + This function processes the sparse features (.npy) to be contiguous + and saves the result in a separate (.npy) file. + + Args: + argv (List[str]): Command line args. + + Returns: + None. + """ + + args = parse_args(argv) + input_dir = args.input_dir + output_dir = args.output_dir + + # Look for files that end in "_sparse.npy" since this processing is + # only applied to sparse data. + input_files = list( + map( + lambda f: os.path.join(input_dir, f), + list(filter(lambda f: f.endswith("_sparse.npy"), os.listdir(input_dir))), + ) + ) + if not input_files: + raise ValueError( + f"There are no files that end with '_sparse.npy' in this directory: {input_dir}" + ) + + print(f"Processing files in: {input_files}. Outputs will be saved to {output_dir}.") + BinaryCriteoUtils.sparse_to_contiguous(input_files, output_dir) + print("Done processing.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/datasets/scripts/preprocess_criteo.py b/torchrec/datasets/scripts/npy_preproc_criteo.py similarity index 100% rename from torchrec/datasets/scripts/preprocess_criteo.py rename to torchrec/datasets/scripts/npy_preproc_criteo.py diff --git a/torchrec/datasets/scripts/tests/test_preprocess_criteo.py b/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py similarity index 95% rename from torchrec/datasets/scripts/tests/test_preprocess_criteo.py rename to torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py index c3a8b2945..e1fc73a9e 100644 --- a/torchrec/datasets/scripts/tests/test_preprocess_criteo.py +++ b/torchrec/datasets/scripts/tests/test_npy_preproc_criteo.py @@ -11,7 +11,7 @@ import numpy as np from torchrec.datasets.criteo import INT_FEATURE_COUNT, CAT_FEATURE_COUNT -from torchrec.datasets.scripts.preprocess_criteo import main +from torchrec.datasets.scripts.npy_preproc_criteo import main from torchrec.datasets.test_utils.criteo_test_utils import CriteoTest diff --git a/torchrec/datasets/tests/test_criteo.py b/torchrec/datasets/tests/test_criteo.py index 3969fd2ae..5888b590b 100644 --- a/torchrec/datasets/tests/test_criteo.py +++ b/torchrec/datasets/tests/test_criteo.py @@ -191,6 +191,64 @@ def test_load_npy_range(self) -> None: full[start_row : start_row + num_rows_to_select], partial ) + def test_sparse_to_contiguous_ids(self) -> None: + # Build the day .npy files. 3 days, 3 columns, 9 rows. 
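+        # Contiguous ids start at 2 (see running_sum in sparse_to_contiguous), so the
+        # first unseen id in a column maps to 2, the next to 3, and so on, scanning
+        # the day files in order. Ids below the frequency threshold are remapped to 1.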
+ unprocessed_data = [ + np.array([[10, 70, 10], [20, 80, 20], [30, 90, 30]]), # day 0 + np.array([[20, 70, 40], [30, 80, 50], [40, 90, 60]]), # day 1 + np.array([[20, 70, 70], [20, 80, 80], [30, 90, 90]]), # day 2 + ] + + expected_data_no_freq_threshold = [ + np.array([[2, 2, 2], [3, 3, 3], [4, 4, 4]]), # day 0 + np.array([[3, 2, 5], [4, 3, 6], [5, 4, 7]]), # day 1 + np.array([[3, 2, 8], [3, 3, 9], [4, 4, 10]]), # day 2 + ] + self._validate_sparse_to_contiguous_preproc( + unprocessed_data, expected_data_no_freq_threshold, 0, 3 + ) + + expected_data_freq_threshold_2 = [ + np.array([[1, 2, 1], [2, 3, 1], [3, 4, 1]]), # day 0 + np.array([[2, 2, 1], [3, 3, 1], [1, 4, 1]]), # day 1 + np.array([[2, 2, 1], [2, 3, 1], [3, 4, 1]]), # day 2 + ] + self._validate_sparse_to_contiguous_preproc( + unprocessed_data, expected_data_freq_threshold_2, 2, 3 + ) + + def _validate_sparse_to_contiguous_preproc( + self, + unprocessed_data: List[np.ndarray], + expected_data: List[np.ndarray], + freq_threshold: int, + columns: int, + ) -> None: + # Save the unprocessed data to temporary directory. + temp_input_dir: str + temp_output_dir: str + with tempfile.TemporaryDirectory() as temp_input_dir, tempfile.TemporaryDirectory() as temp_output_dir: + input_files = [] + for i, data in enumerate(unprocessed_data): + file = os.path.join(temp_input_dir, f"day_{i}_sparse.npy") + input_files.append(file) + np.save(file, data) + + BinaryCriteoUtils.sparse_to_contiguous( + input_files, temp_output_dir, freq_threshold, columns + ) + + output_files = list( + map( + lambda f: os.path.join(temp_output_dir, f), + os.listdir(temp_output_dir), + ) + ) + output_files.sort() + for day, file in enumerate(output_files): + processed_data = np.load(file) + self.assertTrue(np.array_equal(expected_data[day], processed_data)) + class TestInMemoryBinaryCriteoIterDataPipe(CriteoTest): def _validate_batch( From 83a80dac3da0f04cef77a9afae923eeaec14bb18 Mon Sep 17 00:00:00 2001 From: Shabab Ayub Date: Tue, 8 Feb 2022 14:31:32 -0800 Subject: [PATCH 2/3] Preproc: shuffle dataset [2/n] Summary: Shuffles the dataset by creating the full dataset from the split .npy files. Outputs the shuffled dataset in split format (labels, sparse, dense) .npy files. Differential Revision: D34007646 fbshipit-source-id: e6c265bf5572619e7b6c9647dc095919c70f76d1 --- examples/dlrm/data/dlrm_dataloader.py | 2 +- torchrec/datasets/criteo.py | 140 +++++++++++++++++++++++++ torchrec/datasets/tests/test_criteo.py | 71 +++++++++++++ 3 files changed, 212 insertions(+), 1 deletion(-) diff --git a/examples/dlrm/data/dlrm_dataloader.py b/examples/dlrm/data/dlrm_dataloader.py index f6872346a..24e66a56c 100644 --- a/examples/dlrm/data/dlrm_dataloader.py +++ b/examples/dlrm/data/dlrm_dataloader.py @@ -15,12 +15,12 @@ CAT_FEATURE_COUNT, DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES, + DAYS, InMemoryBinaryCriteoIterDataPipe, ) from torchrec.datasets.random import RandomRecDataset STAGES = ["train", "val", "test"] -DAYS = 24 def _get_random_dataloader( diff --git a/torchrec/datasets/criteo.py b/torchrec/datasets/criteo.py index 288d537e0..dbe918594 100644 --- a/torchrec/datasets/criteo.py +++ b/torchrec/datasets/criteo.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. 
import os +import time from typing import ( Iterator, Any, @@ -36,6 +37,7 @@ FREQUENCY_THRESHOLD = 3 INT_FEATURE_COUNT = 13 CAT_FEATURE_COUNT = 26 +DAYS = 24 DEFAULT_LABEL_NAME = "label" DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)] DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)] @@ -484,6 +486,144 @@ def sparse_to_contiguous( # Transpose back the features when saving, as they were transposed when loading. np.save(fout, features.transpose()) + @staticmethod + def shuffle( + input_dir_labels_and_dense: str, + input_dir_sparse: str, + output_dir_shuffled: str, + rows_per_day: Dict[int, int], + output_dir_full_set: Optional[str] = None, + days: int = DAYS, + int_columns: int = INT_FEATURE_COUNT, + sparse_columns: int = CAT_FEATURE_COUNT, + path_manager_key: str = PATH_MANAGER_KEY, + ) -> None: + """ + Shuffle the dataset. Expects the files to be in .npy format and the data + to be split by day and by dense, sparse and label data. + Dense data must be in: day_x_dense.npy + Sparse data must be in: day_x_sparse.npy + Labels data must be in: day_x_labels.npy + + The dataset will be reconstructed, shuffled and then split back into + separate dense, sparse and labels files. + + Args: + input_dir_labels_and_dense (str): Input directory of labels and dense npy files. + input_dir_sparse (str): Input directory of sparse npy files. + output_dir_shuffled (str): Output directory for shuffled labels, dense and sparse npy files. + rows_per_day Dict[int, int]: Number of rows in each file. + output_dir_full_set (str): Output directory of the full dataset, if desired. + days (int): Number of day files. + int_columns (int): Number of columns with dense features. + columns (int): Total number of columns. + path_manager_key (str): Path manager key used to load from different filesystems. 
+ """ + + total_rows = sum(rows_per_day.values()) + columns = int_columns + sparse_columns + 1 # add 1 for label column + full_dataset = np.zeros((total_rows, columns), dtype=np.float32) + curr_first_row = 0 + curr_last_row = 0 + for d in range(0, days): + curr_last_row += rows_per_day[d] + + # dense + path_to_file = os.path.join( + input_dir_labels_and_dense, f"day_{d}_dense.npy" + ) + data = np.load(path_to_file) + print( + f"Day {d} dense- {curr_first_row}-{curr_last_row} loaded files - {time.time()} - {path_to_file}" + ) + + full_dataset[curr_first_row:curr_last_row, 0:int_columns] = data + del data + + # sparse + path_to_file = os.path.join(input_dir_sparse, f"day_{d}_sparse.npy") + data = np.load(path_to_file) + print( + f"Day {d} sparse- {curr_first_row}-{curr_last_row} loaded files - {time.time()} - {path_to_file}" + ) + + full_dataset[curr_first_row:curr_last_row, int_columns : columns - 1] = data + del data + + # labels + path_to_file = os.path.join( + input_dir_labels_and_dense, f"day_{d}_labels.npy" + ) + data = np.load(path_to_file) + print( + f"Day {d} labels- {curr_first_row}-{curr_last_row} loaded files - {time.time()} - {path_to_file}" + ) + + full_dataset[curr_first_row:curr_last_row, columns - 1 :] = data + del data + + curr_first_row = curr_last_row + + path_manager = PathManagerFactory().get(path_manager_key) + + # Save the full dataset + if output_dir_full_set is not None: + full_output_file = os.path.join(output_dir_full_set, "full.npy") + with path_manager.open(full_output_file, "wb") as fout: + print(f"Writing full set file: {full_output_file}") + np.save(fout, full_dataset) + + print("Shuffling dataset") + np.random.shuffle(full_dataset) + + # Slice and save each portion into dense, sparse and labels + curr_first_row = 0 + curr_last_row = 0 + for d in range(0, days): + curr_last_row += rows_per_day[d] + + # write dense columns + shuffled_dense_file = os.path.join( + output_dir_shuffled, f"day_{d}_dense.npy" + ) + with path_manager.open(shuffled_dense_file, "wb") as fout: + print( + f"Writing rows {curr_first_row}-{curr_last_row-1} dense file: {shuffled_dense_file}" + ) + np.save(fout, full_dataset[curr_first_row:curr_last_row, 0:int_columns]) + + # write sparse columns + shuffled_sparse_file = os.path.join( + output_dir_shuffled, f"day_{d}_sparse.npy" + ) + with path_manager.open(shuffled_sparse_file, "wb") as fout: + print( + f"Writing rows {curr_first_row}-{curr_last_row-1} sparse file: {shuffled_sparse_file}" + ) + np.save( + fout, + full_dataset[ + curr_first_row:curr_last_row, int_columns : columns - 1 + ].astype(np.int32), + ) + + # write labels columns + shuffled_labels_file = os.path.join( + output_dir_shuffled, f"day_{d}_labels.npy" + ) + with path_manager.open(shuffled_labels_file, "wb") as fout: + print( + f"Writing rows {curr_first_row}-{curr_last_row-1} labels file: {shuffled_labels_file}" + ) + np.save( + fout, + full_dataset[curr_first_row:curr_last_row, columns - 1 :].astype( + np.int32 + ), + ) + + curr_first_row = curr_last_row + class InMemoryBinaryCriteoIterDataPipe(IterableDataset): """ diff --git a/torchrec/datasets/tests/test_criteo.py b/torchrec/datasets/tests/test_criteo.py index 5888b590b..ce5601d59 100644 --- a/torchrec/datasets/tests/test_criteo.py +++ b/torchrec/datasets/tests/test_criteo.py @@ -249,6 +249,77 @@ def _validate_sparse_to_contiguous_preproc( processed_data = np.load(file) self.assertTrue(np.array_equal(expected_data[day], processed_data)) + def test_shuffle(self) -> None: + """ + To ensure that the shuffle preserves the 
sanity of the input (no missing values), each row will + be uniquely identifiable by the value in the labels column. Each row will have a unique sequence. + The row ID will map to this sequence. The output map of row IDs to sequences must be the same as + the input map of row IDs to sequences. + """ + + days: int = 3 # need type annotation to be captured in local function + int_columns = 3 + cat_columns = 3 + + temp_input_dir: str + temp_output_dir: str + with tempfile.TemporaryDirectory() as temp_input_dir, tempfile.TemporaryDirectory() as temp_output_dir: + dense_data = [ # 3 columns, 3 rows per day + np.array( + [[i, i + 1, i + 2], [i + 3, i + 4, i + 5], [i + 6, i + 7, i + 8]] + ) + for i in range(days) + ] + sparse_data = [ + np.array( + [[i, i + 1, i + 2], [i + 3, i + 4, i + 5], [i + 6, i + 7, i + 8]] + ) + for i in range(days) + ] + labels_data = [np.array([[i], [i + 3], [i + 6]]) for i in range(3)] + + def save_data_list(data: List[np.ndarray], data_type: str) -> None: + for day, data in enumerate(data): + file = os.path.join(temp_input_dir, f"day_{day}_{data_type}.npy") + np.save(file, data) + + save_data_list(dense_data, "dense") + save_data_list(sparse_data, "sparse") + save_data_list(labels_data, "labels") + + rows_per_day = {0: 3, 1: 3, 2: 3} + BinaryCriteoUtils.shuffle( + temp_input_dir, + temp_input_dir, + temp_output_dir, + rows_per_day, + None, + days, + int_columns, + cat_columns, + ) + + # The label is the row id in this test. + def row_id_to_sequence(data_dir: str) -> Dict[int, List[int]]: + id_to_sequence = {} + for d in range(days): + label_data = np.load(os.path.join(data_dir, f"day_{d}_labels.npy")) + dense_data = np.load(os.path.join(data_dir, f"day_{d}_dense.npy")) + sparse_data = np.load(os.path.join(data_dir, f"day_{d}_sparse.npy")) + + for row in range(len(label_data)): + label = label_data[row][0] + id_to_sequence[label] = [label] + id_to_sequence[label].extend(dense_data[row]) + id_to_sequence[label].extend(sparse_data[row]) + + return id_to_sequence + + self.assertEqual( + row_id_to_sequence(temp_input_dir), + row_id_to_sequence(temp_output_dir), + ) + class TestInMemoryBinaryCriteoIterDataPipe(CriteoTest): def _validate_batch( From e00bc56fcdf0a374f79d07a7da609caf5cbf49cf Mon Sep 17 00:00:00 2001 From: Shabab Ayub Date: Tue, 8 Feb 2022 14:31:51 -0800 Subject: [PATCH 3/3] Shuffle training batches [3/n] (#10) Summary: Pull Request resolved: https://github.com/facebookresearch/torchrec/pull/10 The original fb dlrm implementation shuffled batches to get their final results. 
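For reference, `shuffle_batches` permutes the rows of each loaded batch: the dense, sparse and label arrays are shuffled with one shared index permutation so each row stays aligned across the three arrays. A minimal standalone sketch of that in-unison shuffle (the helper name here is illustrative; the actual logic lives in `InMemoryBinaryCriteoIterDataPipe._np_arrays_to_batch` in the diff below):

```python
from typing import Tuple

import numpy as np


def shuffle_in_unison(
    dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # One permutation of the row indices, applied to all three arrays so each
    # row's dense features, sparse ids and label stay together.
    shuffler = np.random.permutation(len(dense))
    return dense[shuffler], sparse[shuffler], labels[shuffler]
```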
Reviewed By: colin2328 Differential Revision: D34008000 fbshipit-source-id: b008c79841d75590f709150156455d1e9a68805a --- examples/dlrm/data/dlrm_dataloader.py | 1 + examples/dlrm/dlrm_main.py | 6 ++++++ torchrec/datasets/criteo.py | 10 ++++++++++ 3 files changed, 17 insertions(+) diff --git a/examples/dlrm/data/dlrm_dataloader.py b/examples/dlrm/data/dlrm_dataloader.py index 24e66a56c..809a3cf2e 100644 --- a/examples/dlrm/data/dlrm_dataloader.py +++ b/examples/dlrm/data/dlrm_dataloader.py @@ -85,6 +85,7 @@ def is_final_day(s: str) -> bool: batch_size=args.batch_size, rank=rank, world_size=world_size, + shuffle_batches=args.shuffle_batches, hashes=args.num_embeddings_per_feature if args.num_embeddings is None else ([args.num_embeddings] * CAT_FEATURE_COUNT), diff --git a/examples/dlrm/dlrm_main.py b/examples/dlrm/dlrm_main.py index d70be948a..99da2e2ee 100644 --- a/examples/dlrm/dlrm_main.py +++ b/examples/dlrm/dlrm_main.py @@ -133,6 +133,12 @@ def parse_args(argv: List[str]) -> argparse.Namespace: default=15.0, help="Learning rate.", ) + parser.add_argument( + "--shuffle_batches", + type=bool, + default=False, + help="Shuffle each batch during training.", + ) parser.set_defaults(pin_memory=None) return parser.parse_args(argv) diff --git a/torchrec/datasets/criteo.py b/torchrec/datasets/criteo.py index dbe918594..9595f2e5c 100644 --- a/torchrec/datasets/criteo.py +++ b/torchrec/datasets/criteo.py @@ -641,6 +641,7 @@ class InMemoryBinaryCriteoIterDataPipe(IterableDataset): batch_size (int): batch size. rank (int): rank. world_size (int): world size. + shuffle_batches (bool): Whether to shuffle batches hashes (Optional[int]): List of max categorical feature value for each feature. Length of this list should be CAT_FEATURE_COUNT. path_manager_key (str): Path manager key used to load from different @@ -667,6 +668,7 @@ def __init__( batch_size: int, rank: int, world_size: int, + shuffle_batches: bool = False, hashes: Optional[List[int]] = None, path_manager_key: str = PATH_MANAGER_KEY, ) -> None: @@ -676,6 +678,7 @@ def __init__( self.batch_size = batch_size self.rank = rank self.world_size = world_size + self.shuffle_batches = shuffle_batches self.hashes = hashes self.path_manager_key = path_manager_key self.path_manager: PathManager = PathManagerFactory().get(path_manager_key) @@ -739,6 +742,13 @@ def _load_data_for_rank(self) -> None: def _np_arrays_to_batch( self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray ) -> Batch: + if self.shuffle_batches: + # Shuffle all 3 in unison + shuffler = np.random.permutation(len(dense)) + dense = dense[shuffler] + sparse = sparse[shuffler] + labels = labels[shuffler] + return Batch( dense_features=torch.from_numpy(dense), sparse_features=KeyedJaggedTensor(
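Taken together, the three diffs give a preprocessing flow of: remap sparse ids to contiguous values (with frequency thresholding), shuffle the per-day splits, and optionally shuffle rows inside each training batch. A minimal sketch of driving the first two steps, assuming per-day `day_{d}_dense.npy` / `day_{d}_sparse.npy` / `day_{d}_labels.npy` files already exist (e.g. produced by `npy_preproc_criteo.py`); the directories and row counts below are placeholders, and only the `BinaryCriteoUtils` entry points come from the diffs above:

```python
import os

from torchrec.datasets.criteo import DAYS, BinaryCriteoUtils

# Placeholder directories -- not paths taken from this patch stack.
INPUT_DIR = "/data/criteo_npy"
CONTIG_DIR = "/data/criteo_contig"
SHUFFLED_DIR = "/data/criteo_shuffled"

# 1. Remap sparse ids to contiguous ints; ids seen fewer than
#    FREQUENCY_THRESHOLD times collapse to the value 1. Overriding the output
#    suffix keeps the day_{d}_sparse.npy naming that the shuffle step expects
#    (the default suffix is "_contig_freq.npy").
sparse_files = [os.path.join(INPUT_DIR, f"day_{d}_sparse.npy") for d in range(DAYS)]
BinaryCriteoUtils.sparse_to_contiguous(
    sparse_files, CONTIG_DIR, output_file_suffix=".npy"
)

# 2. Rebuild the full dataset, shuffle it, and write it back out as per-day
#    dense / sparse / labels files. The counts must match the real number of
#    rows in each day file; the value used here is illustrative only.
rows_per_day = {d: 10_000 for d in range(DAYS)}
BinaryCriteoUtils.shuffle(
    input_dir_labels_and_dense=INPUT_DIR,
    input_dir_sparse=CONTIG_DIR,
    output_dir_shuffled=SHUFFLED_DIR,
    rows_per_day=rows_per_day,
)
```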