3 changes: 2 additions & 1 deletion examples/dlrm/data/dlrm_dataloader.py
@@ -15,12 +15,12 @@
CAT_FEATURE_COUNT,
DEFAULT_CAT_NAMES,
DEFAULT_INT_NAMES,
DAYS,
InMemoryBinaryCriteoIterDataPipe,
)
from torchrec.datasets.random import RandomRecDataset

STAGES = ["train", "val", "test"]
DAYS = 24


def _get_random_dataloader(
@@ -85,6 +85,7 @@ def is_final_day(s: str) -> bool:
batch_size=args.batch_size,
rank=rank,
world_size=world_size,
shuffle_batches=args.shuffle_batches,
hashes=args.num_embeddings_per_feature
if args.num_embeddings is None
else ([args.num_embeddings] * CAT_FEATURE_COUNT),
6 changes: 6 additions & 0 deletions examples/dlrm/dlrm_main.py
@@ -133,6 +133,12 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
default=15.0,
help="Learning rate.",
)
parser.add_argument(
"--shuffle_batches",
type=bool,
default=False,
help="Shuffle each batch during training.",
)
parser.set_defaults(pin_memory=None)
return parser.parse_args(argv)

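A note on the new --shuffle_batches flag: argparse's type=bool converts the raw string with bool(), so any non-empty value (including the literal string "False") parses as True; only omitting the flag reliably yields the default False. Below is a minimal, standalone sketch of a stricter parser; the str_to_bool helper is illustrative and not part of this PR:

    import argparse

    def str_to_bool(value: str) -> bool:
        # Accept common textual spellings of booleans; reject anything else.
        if value.lower() in ("true", "1", "yes"):
            return True
        if value.lower() in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--shuffle_batches", type=str_to_bool, default=False)

    assert parser.parse_args([]).shuffle_batches is False
    assert parser.parse_args(["--shuffle_batches", "False"]).shuffle_batches is False
    assert parser.parse_args(["--shuffle_batches", "True"]).shuffle_batches is True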
261 changes: 260 additions & 1 deletion torchrec/datasets/criteo.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import os
import time
from typing import (
Iterator,
Any,
@@ -33,9 +34,10 @@
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


FREQUENCY_THRESHOLD = 3
INT_FEATURE_COUNT = 13
CAT_FEATURE_COUNT = 26
DAYS = 24
DEFAULT_LABEL_NAME = "label"
DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)]
DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)]
@@ -375,6 +377,253 @@ def load_npy_range(
data = np.fromfile(fin, dtype=dtype, count=num_entries)
return data.reshape((num_rows, row_size))

@staticmethod
def sparse_to_contiguous(
in_files: List[str],
output_dir: str,
frequency_threshold: int = FREQUENCY_THRESHOLD,
columns: int = CAT_FEATURE_COUNT,
path_manager_key: str = PATH_MANAGER_KEY,
output_file_suffix: str = "_contig_freq.npy",
) -> None:
"""
Convert all sparse .npy files to use contiguous integer IDs and store the
results in separate .npy files. All input files must be processed together
because the same ID can appear in a given column across files, so they
must share one mapping. The remapped IDs are not unique across columns.
IDs that appear fewer than frequency_threshold times are remapped to 1.

Example transformation with a frequency_threshold of 2:
day_0_sparse.npy
| col_0 | col_1 |
-----------------
| abc | xyz |
| iop | xyz |

day_1_sparse.npy
| col_0 | col_1 |
-----------------
| iop | tuv |
| lkj | xyz |

day_0_sparse_contig.npy
| col_0 | col_1 |
-----------------
| 1 | 2 |
| 2 | 2 |

day_1_sparse_contig.npy
| col_0 | col_1 |
-----------------
| 2 | 1 |
| 1 | 2 |

Args:
in_files (List[str]): Input .npy files to process.
output_dir (str): Output directory for the processed .npy files.
frequency_threshold (int): IDs occurring fewer than this many times are remapped to 1.
columns (int): Number of sparse feature columns.
path_manager_key (str): Path manager key used to load from different filesystems.
output_file_suffix (str): Suffix appended to each output file name.

Returns:
None.
"""

# Load each .npy file of sparse features. Transformations are made along the columns,
# so transpose the input to ease the per-column operations.
# E.g. file_to_features = {"day_0_sparse": array([[3, 6, 7], [7, 9, 3]])}
file_to_features: Dict[str, np.ndarray] = {}
for f in in_files:
name = os.path.basename(f).split(".")[0]
file_to_features[name] = np.load(f).transpose()
print(f"Successfully loaded file: {f}")

# Iterate through each column in each file and map the sparse ids to contiguous ids.
for col in range(columns):
print(f"Processing column: {col}")

# Iterate through each row in each file for the current column and determine the
# frequency of each sparse id.
sparse_to_frequency: Dict[int, int] = {}
if frequency_threshold > 1:
for f in file_to_features:
for _, sparse in enumerate(file_to_features[f][col]):
if sparse in sparse_to_frequency:
sparse_to_frequency[sparse] += 1
else:
sparse_to_frequency[sparse] = 1

# Iterate through each row in each file for the current column and remap each
# sparse id to a contiguous id. The contiguous ints start at a value of 2 so that
# infrequent IDs (determined by the frequency_threshold) can be remapped to 1.
running_sum = 2
sparse_to_contiguous_int: Dict[int, int] = {}

for f in file_to_features:
print(f"Processing file: {f}")

for i, sparse in enumerate(file_to_features[f][col]):
if sparse not in sparse_to_contiguous_int:
# If the ID appears less than frequency_threshold amount of times
# remap the value to 1.
if (
frequency_threshold > 1
and sparse_to_frequency[sparse] < frequency_threshold
):
sparse_to_contiguous_int[sparse] = 1
else:
sparse_to_contiguous_int[sparse] = running_sum
running_sum += 1

# Re-map sparse value to contiguous in place.
file_to_features[f][col][i] = sparse_to_contiguous_int[sparse]

path_manager = PathManagerFactory().get(path_manager_key)
for f, features in file_to_features.items():
output_file = os.path.join(output_dir, f + output_file_suffix)
with path_manager.open(output_file, "wb") as fout:
print(f"Writing file: {output_file}")
# Transpose back the features when saving, as they were transposed when loading.
np.save(fout, features.transpose())

@staticmethod
def shuffle(
input_dir_labels_and_dense: str,
input_dir_sparse: str,
output_dir_shuffled: str,
rows_per_day: Dict[int, int],
output_dir_full_set: Optional[str] = None,
days: int = DAYS,
int_columns: int = INT_FEATURE_COUNT,
sparse_columns: int = CAT_FEATURE_COUNT,
path_manager_key: str = PATH_MANAGER_KEY,
) -> None:
"""
Shuffle the dataset. Expects the files to be in .npy format, with the data
split by day and into dense, sparse and label files:
Dense data must be in: day_x_dense.npy
Sparse data must be in: day_x_sparse.npy
Label data must be in: day_x_labels.npy

The dataset will be reconstructed, shuffled and then split back into
separate dense, sparse and labels files.

Args:
input_dir_labels_and_dense (str): Input directory of labels and dense npy files.
input_dir_sparse (str): Input directory of sparse npy files.
output_dir_shuffled (str): Output directory for shuffled labels, dense and sparse npy files.
rows_per_day (Dict[int, int]): Number of rows in each day file, keyed by day.
output_dir_full_set (Optional[str]): Output directory for the full dataset, if desired.
days (int): Number of day files.
int_columns (int): Number of columns with dense features.
sparse_columns (int): Number of columns with sparse features.
path_manager_key (str): Path manager key used to load from different filesystems.
"""

total_rows = sum(rows_per_day.values())
columns = int_columns + sparse_columns + 1 # add 1 for label column
full_dataset = np.zeros((total_rows, columns), dtype=np.float32)
curr_first_row = 0
curr_last_row = 0
for d in range(0, days):
curr_last_row += rows_per_day[d]

# dense
path_to_file = os.path.join(
input_dir_labels_and_dense, f"day_{d}_dense.npy"
)
data = np.load(path_to_file)
print(
f"Day {d} dense- {curr_first_row}-{curr_last_row} loaded files - {time.time()} - {path_to_file}"
)

full_dataset[curr_first_row:curr_last_row, 0:int_columns] = data
del data

# sparse
path_to_file = os.path.join(input_dir_sparse, f"day_{d}_sparse.npy")
data = np.load(path_to_file)
print(
f"Day {d} sparse- {curr_first_row}-{curr_last_row} loaded files - {time.time()} - {path_to_file}"
)

full_dataset[curr_first_row:curr_last_row, int_columns : columns - 1] = data
del data

# labels
path_to_file = os.path.join(
input_dir_labels_and_dense, f"day_{d}_labels.npy"
)
data = np.load(path_to_file)
print(
f"Day {d} labels- {curr_first_row}-{curr_last_row} loaded files - {time.time()} - {path_to_file}"
)

full_dataset[curr_first_row:curr_last_row, columns - 1 :] = data
del data

curr_first_row = curr_last_row

path_manager = PathManagerFactory().get(path_manager_key)

# Save the full dataset
if output_dir_full_set is not None:
full_output_file = os.path.join(output_dir_full_set, "full.npy")
with path_manager.open(full_output_file, "wb") as fout:
print(f"Writing full set file: {full_output_file}")
np.save(fout, full_dataset)

print("Shuffling dataset")
np.random.shuffle(full_dataset)

# Slice and save each portion into dense, sparse and labels
curr_first_row = 0
curr_last_row = 0
for d in range(0, days):
curr_last_row += rows_per_day[d]

# write dense columns
shuffled_dense_file = os.path.join(
output_dir_shuffled, f"day_{d}_dense.npy"
)
with path_manager.open(shuffled_dense_file, "wb") as fout:
print(
f"Writing rows {curr_first_row}-{curr_last_row-1} dense file: {shuffled_dense_file}"
)
np.save(fout, full_dataset[curr_first_row:curr_last_row, 0:int_columns])

# write sparse columns
shuffled_sparse_file = os.path.join(
output_dir_shuffled, f"day_{d}_sparse.npy"
)
with path_manager.open(shuffled_sparse_file, "wb") as fout:
print(
f"Writing rows {curr_first_row}-{curr_last_row-1} sparse file: {shuffled_sparse_file}"
)
np.save(
fout,
full_dataset[
curr_first_row:curr_last_row, int_columns : columns - 1
].astype(np.int32),
)

# write labels columns
shuffled_labels_file = os.path.join(
output_dir_shuffled, f"day_{d}_labels.npy"
)
with path_manager.open(shuffled_labels_file, "wb") as fout:
print(
f"Writing rows {curr_first_row}-{curr_last_row-1} labels file: {shuffled_labels_file}"
)
np.save(
fout,
full_dataset[curr_first_row:curr_last_row, columns - 1 :].astype(
np.int32
),
)

curr_first_row = curr_last_row


class InMemoryBinaryCriteoIterDataPipe(IterableDataset):
"""
@@ -392,6 +641,7 @@ class InMemoryBinaryCriteoIterDataPipe(IterableDataset):
batch_size (int): batch size.
rank (int): rank.
world_size (int): world size.
shuffle_batches (bool): Whether to shuffle the rows within each loaded batch.
hashes (Optional[int]): List of max categorical feature value for each feature.
Length of this list should be CAT_FEATURE_COUNT.
path_manager_key (str): Path manager key used to load from different
@@ -418,6 +668,7 @@ def __init__(
batch_size: int,
rank: int,
world_size: int,
shuffle_batches: bool = False,
hashes: Optional[List[int]] = None,
path_manager_key: str = PATH_MANAGER_KEY,
) -> None:
@@ -427,6 +678,7 @@ def __init__(
self.batch_size = batch_size
self.rank = rank
self.world_size = world_size
self.shuffle_batches = shuffle_batches
self.hashes = hashes
self.path_manager_key = path_manager_key
self.path_manager: PathManager = PathManagerFactory().get(path_manager_key)
@@ -490,6 +742,13 @@ def _load_data_for_rank(self) -> None:
def _np_arrays_to_batch(
self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray
) -> Batch:
if self.shuffle_batches:
# Shuffle all 3 in unison
shuffler = np.random.permutation(len(dense))
dense = dense[shuffler]
sparse = sparse[shuffler]
labels = labels[shuffler]

return Batch(
dense_features=torch.from_numpy(dense),
sparse_features=KeyedJaggedTensor(
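The shuffle_batches path in _np_arrays_to_batch works because a single permutation is applied to all three arrays, so each row's dense features, sparse features and label stay aligned after shuffling. A small self-contained sketch of the same in-unison shuffle (the toy shapes below are made up for illustration):

    import numpy as np

    # Toy batch: 4 rows of dense features, sparse ids and labels.
    dense = np.arange(8, dtype=np.float32).reshape(4, 2)
    sparse = np.arange(12, dtype=np.int32).reshape(4, 3)
    labels = np.array([[0], [1], [0], [1]], dtype=np.int32)

    # One shared permutation keeps each row's dense/sparse/label triple together.
    shuffler = np.random.permutation(len(dense))
    dense, sparse, labels = dense[shuffler], sparse[shuffler], labels[shuffler]

    # Rows are reordered identically, so shapes (and row pairings) are preserved.
    assert dense.shape == (4, 2) and sparse.shape == (4, 3) and labels.shape == (4, 1)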