Skip to content

Commit

Permalink
make batches.
Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 10, 2021
1 parent b3e776d commit 33f7072
Showing 1 changed file with 40 additions and 28 deletions.
68 changes: 40 additions & 28 deletions demo/guide-python/external_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,22 @@
from typing import Callable, List, Tuple
import tempfile
import numpy as np
import cupy as cp
from time import time


def make_batches(
    n_samples_per_batch: int, n_features: int, n_batches: int
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Create ``n_batches`` pairs of random ``(X, y)`` arrays.

    The generator is seeded, so repeated calls produce identical data.
    Returns two parallel lists: feature matrices of shape
    ``(n_samples_per_batch, n_features)`` and target vectors of shape
    ``(n_samples_per_batch,)``.
    """
    rng = np.random.RandomState(1994)
    # Draw X then y for each batch so the seeded call order matches
    # one (X, y) pair per loop iteration.
    pairs = [
        (
            rng.randn(n_samples_per_batch, n_features),
            rng.randn(n_samples_per_batch),
        )
        for _ in range(n_batches)
    ]
    if not pairs:
        return [], []
    X, y = (list(part) for part in zip(*pairs))
    return X, y
# from sklearn.datasets import make_regression
from cuml.datasets import make_regression


class Iterator(xgboost.DataIter):
"""A custom iterator for loading files in batches."""

def __init__(self, file_paths: List[Tuple[str, str]]) -> None:
    """Store the batch file list and register XGBoost's cache prefix.

    file_paths: one ``(X_path, y_path)`` pair per batch; batches are
        loaded lazily, one at a time, via ``next``.
    """
    self._file_paths = file_paths
    self._it = 0  # index of the next batch to serve
    # XGBoost generates its external-memory cache files under this
    # prefix.  The diff left two competing super().__init__ calls; keep
    # the relative "./cache" prefix — the hard-coded
    # "/home/fis/Others/" path is machine-specific and breaks the demo
    # everywhere else.
    super().__init__(cache_prefix=os.path.join(".", "cache"))

def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
X_path, y_path = self._file_paths[self._it]
Expand All @@ -58,27 +48,42 @@ def next(self, input_data: Callable) -> int:
X, y = self.load_file()
input_data(data=X, label=y)
self._it += 1
print("It:", self._it)
return 1

def reset(self) -> None:
    """Rewind the iterator so iteration restarts at the first file."""
    self._it = 0


def main(tmpdir: str, external: int) -> xgboost.Booster:
# generate some random data for demo
batches = make_batches(2048, 32, 64)
def make_batches(
    tmpdir: str, n_samples_per_batch: int, n_features: int, n_batches: int
) -> List[Tuple[str, str]]:
    """Generate regression batches on disk and return their file paths.

    Each batch is written as ``X-<i>.npy`` / ``y-<i>.npy`` under
    ``tmpdir``.  Files that already exist are reused, so repeated runs
    skip the expensive data generation.  Returns one
    ``(X_path, y_path)`` pair per batch.
    """
    base_seed = 1994
    files = []
    for i in range(n_batches):
        X_path = os.path.join(tmpdir, "X-" + str(i) + ".npy")
        y_path = os.path.join(tmpdir, "y-" + str(i) + ".npy")
        files.append((X_path, y_path))
        # Reuse batches produced by a previous run.
        if os.path.exists(X_path) and os.path.exists(y_path):
            continue
        # Vary the seed per batch: passing the same constant seed to
        # every make_regression call would make all batches identical.
        X, y = make_regression(
            n_samples=n_samples_per_batch,
            n_features=n_features,
            random_state=base_seed + i,
        )
        cp.save(X_path, X)
        cp.save(y_path, y)
    return files


def main(tmpdir: str, external: int) -> xgboost.Booster:
# generate some random data for demo
files = make_batches(tmpdir, int(2 ** 20), 64, 64)
print("Finish data generation.")
missing = np.NaN
parameters = {"max_depth": 8, "tree_method": "approx", "nthread": 16}
missing = cp.NaN
parameters = {"max_depth": 8, "updater": "grow_histmaker,prune", "nthread": 16}
# parameters = {"max_depth": 8, "tree_method": "gpu_hist", "nthread": 16}

rounds = 16

Expand All @@ -87,6 +92,7 @@ def main(tmpdir: str, external: int) -> xgboost.Booster:
it = Iterator(files)
# For non-data arguments, specify it here once instead of passing them by the
# `next` method.
# Xy = xgboost.DeviceQuantileDMatrix(it, missing=missing, enable_categorical=False)
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
end = time()
print("Duration::DMatrix:", end - start)
Expand All @@ -104,8 +110,8 @@ def main(tmpdir: str, external: int) -> xgboost.Booster:
print("Duration::Train:", end - start)
else:
X_l, y_l = batches
X = np.concatenate(X_l, axis=0)
y = np.concatenate(y_l, axis=0)
X = cp.concatenate(X_l, axis=0)
y = cp.concatenate(y_l, axis=0)
start = time()
Xy = xgboost.DMatrix(X, y, missing=missing, enable_categorical=False)
end = time()
Expand All @@ -125,8 +131,14 @@ def main(tmpdir: str, external: int) -> xgboost.Booster:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Registered exactly once: the diff residue added --external twice,
    # which makes argparse raise on the duplicate option string.
    parser.add_argument(
        "--external", type=int, choices=[0, 1], required=False, default=1
    )
    args = parser.parse_args()

    # Use a persistent directory rather than a tempdir so the generated
    # batches survive between runs and can be reused by make_batches.
    # makedirs(exist_ok=True) replaces the racy exists-then-mkdir check.
    cache_dir = "./cache"
    os.makedirs(cache_dir, exist_ok=True)
    main(cache_dir, args.external)

0 comments on commit 33f7072

Please sign in to comment.