Orca: Align the data analysis method of dataloader and dataframe (#5763)
* wrapper

* ray backend

* pyspark

* more uts
leonardozcm committed Sep 23, 2022
1 parent 4d72bd7 commit 8d0f17e
Showing 7 changed files with 344 additions and 29 deletions.
@@ -25,7 +25,7 @@
from bigdl.orca.learn.pytorch.pytorch_pyspark_worker import PytorchPysparkWorker
from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
convert_predict_xshards_to_dataframe, make_data_creator, update_predict_xshards, \
- process_xshards_of_pandas_dataframe
+ process_xshards_of_pandas_dataframe, reload_dataloader_creator
from bigdl.orca.data import SparkXShards
from bigdl.orca import OrcaContext
from bigdl.orca.learn.base_estimator import BaseEstimator
@@ -287,8 +287,8 @@ def transform_func(iter, init_params, param):
"data should be either an instance of SparkXShards or a "
"callable function, but got type: {}".format(type(data)))

params["data_creator"] = data
params["validation_data_creator"] = validation_data
params["data_creator"] = reload_dataloader_creator(data)
params["validation_data_creator"] = reload_dataloader_creator(validation_data)

def transform_func(iter, init_param, param):
return PytorchPysparkWorker(**init_param).train_epochs(**param)
@@ -474,7 +474,7 @@ def transform_func(iter, init_param, param):
res = data.rdd.repartition(self.num_workers).barrier() \
.mapPartitions(lambda iter: transform_func(iter, init_params, params)).collect()
else:
params["data_creator"] = data
params["data_creator"] = reload_dataloader_creator(data)

def transform_func(iter, init_param, param):
return PytorchPysparkWorker(**init_param).validate(**param)
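Both hunks above are from the PySpark estimator path (note the PytorchPysparkWorker import): user-supplied data-creator callables are now wrapped with reload_dataloader_creator before being sent to the workers. Below is a minimal sketch of what that wrapping does for a plain single-input DataLoader; the creator name and toy data are illustrative, and running it assumes torch and bigdl-orca are installed.

import torch
from torch.utils.data import DataLoader, TensorDataset
from bigdl.orca.learn.utils import reload_dataloader_creator

def train_loader_creator(config, batch_size):
    # toy single-input dataset: 100 samples with 50 features, binary targets
    x = torch.randn(100, 50)
    y = torch.randint(0, 2, (100, 1)).float()
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)

# equivalent to params["data_creator"] = reload_dataloader_creator(data) above
wrapped_creator = reload_dataloader_creator(train_loader_creator)
loader = wrapped_creator(config={}, batch_size=32)
features, target = next(iter(loader))
# a single tensor feature is passed through unchanged, so features is simply the x batch
print(type(features), features.shape, target.shape)

The hunks that follow apply the same wrapping in the Ray estimator (note the PytorchRayWorker import).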
@@ -27,7 +27,7 @@
from bigdl.orca.learn.pytorch.pytorch_ray_worker import PytorchRayWorker
from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
convert_predict_xshards_to_dataframe, update_predict_xshards, \
- process_xshards_of_pandas_dataframe
+ process_xshards_of_pandas_dataframe, reload_dataloader_creator
from bigdl.orca.ray import OrcaRayContext
from bigdl.orca.learn.ray_estimator import Estimator as OrcaRayEstimator
from bigdl.dllib.utils.file_utils import enable_multi_fs_load, enable_multi_fs_save
@@ -360,8 +360,8 @@ def data_creator(config, batch_size):
" Ray Dataset or a callable function, but"
" got type: {}".format(type(data)))

params["data_creator"] = data
params["validation_data_creator"] = validation_data
params["data_creator"] = reload_dataloader_creator(data)
params["validation_data_creator"] = reload_dataloader_creator(validation_data)
success, worker_stats = self._train_epochs(**params)

epoch_stats = list(map(list, zip(*worker_stats)))
@@ -504,7 +504,8 @@ def data_creator(config, batch_size):
"data should be either an instance of SparkXShards or a callable"
" function, but got type: {}".format(type(data)))

- params = dict(data_creator=data, batch_size=batch_size, num_steps=num_steps,
+ params = dict(data_creator=reload_dataloader_creator(data),
+               batch_size=batch_size, num_steps=num_steps,
profile=profile, info=info)

worker_stats = ray.get([w.validate.remote(**params) for w in self.remote_workers])
python/orca/src/bigdl/orca/learn/pytorch/training_operator.py (27 additions, 15 deletions)
@@ -37,7 +37,7 @@

from bigdl.orca.learn.metrics import Metric
from bigdl.orca.learn.pytorch.utils import (TimerCollection, AverageMeterCollection,
-                                             NUM_SAMPLES)
+                                             NUM_SAMPLES, get_batchsize)
from bigdl.orca.learn.pytorch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -268,15 +268,21 @@ def train_batch(self, batch, batch_info):
"""
# unpack features into list to support multiple inputs model
- *features, target = batch
- # If features is already a tuple, we don't give it an extra list dimension.
- already_list = (isinstance(features[0], tuple) or isinstance(features[0], list))
- if len(features) == 1 and already_list:
-     features = features[0]
+ features, target = batch

# Compute output.
with self.timers.record("fwd"):
- output = self.model(*features)
+ if torch.is_tensor(features):
+     output = self.model(features)
+ elif isinstance(features, dict):
+     output = self.model(**features)
+ elif isinstance(features, (tuple, list)):
+     output = self.model(*features)
+ else:
+     invalidInputError(False,
+                       "Features should be tensor, list/tuple or dict, "
+                       "but got {}".format(type(features)))

if isinstance(output, tuple) or isinstance(output, list):
# Then target is also assumed to be a tuple or list.
loss = self.criterion(*output, *target)
@@ -292,7 +298,7 @@ def train_batch(self, batch, batch_info):
with self.timers.record("apply"):
self.optimizer.step()

return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}
return {"train_loss": loss.item(), NUM_SAMPLES: get_batchsize(features)}

def validate(self, val_iterator, info, metrics, num_steps=None):
"""Runs one standard validation pass over the val_iterator.
@@ -329,7 +335,7 @@ def validate(self, val_iterator, info, metrics, num_steps=None):
batch_info = {"batch_idx": batch_idx}
batch_info.update(info)
output, target, loss = self.forward_batch(batch, batch_info)
- num_samples = target.size(0)
+ num_samples = get_batchsize(target)
total_samples += num_samples
losses.append(loss.item() * num_samples)
for metric in metrics.values():
@@ -394,15 +400,21 @@ def forward_batch(self, batch, batch_info):
calculate averages.
"""
# unpack features into list to support multiple inputs model
- *features, target = batch
- # If features is already a tuple, we don't give it an extra list dimension.
- already_list = (isinstance(features[0], tuple) or isinstance(features[0], list))
- if len(features) == 1 and already_list:
-     features = features[0]
+ features, target = batch

# compute output
with self.timers.record("eval_fwd"):
- output = self.model(*features)
+ if torch.is_tensor(features):
+     output = self.model(features)
+ elif isinstance(features, dict):
+     output = self.model(**features)
+ elif isinstance(features, (tuple, list)):
+     output = self.model(*features)
+ else:
+     invalidInputError(False,
+                       "Features should be tensor, list/tuple or dict, "
+                       "but got {}".format(type(features)))

loss = self.criterion(output, target)

return output, target, loss
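With the dispatch above, the training operator no longer assumes features is a positional list: a tensor is passed directly, a dict is expanded into keyword arguments, and a tuple/list is expanded positionally. A small illustration of the dict branch follows; TwoInputNet and its layer sizes are made up for this example and are not part of the PR.

import torch
import torch.nn as nn

class TwoInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(25, 8)
        self.fc2 = nn.Linear(25, 8)
        self.out = nn.Linear(16, 1)

    def forward(self, input1, input2):
        # the dict keys of the batch must match these argument names
        h = torch.cat([self.fc1(input1), self.fc2(input2)], dim=-1)
        return torch.sigmoid(self.out(h))

model = TwoInputNet()
features = {"input1": torch.randn(4, 25), "input2": torch.randn(4, 25)}
output = model(**features)  # the isinstance(features, dict) branch above
print(output.shape)         # torch.Size([4, 1])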
python/orca/src/bigdl/orca/learn/pytorch/utils.py (9 additions, 0 deletions)
@@ -280,3 +280,12 @@ def get_filesystem(filepath):
from fsspec.core import url_to_fs
fs, _ = url_to_fs(str(filepath))
return fs


def get_batchsize(input):
if isinstance(input, (list, tuple)):
return get_batchsize(input[0])
elif isinstance(input, dict):
return get_batchsize(list(input.values())[0])
else:
return input.size(0)
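
get_batchsize recurses into the first element of a list/tuple, or the first value of a dict, until it reaches a tensor, and it implicitly assumes every entry shares the same leading batch dimension. A quick usage sketch (assumes torch and bigdl-orca are available):

import torch
from bigdl.orca.learn.pytorch.utils import get_batchsize

x = torch.zeros(32, 10)
assert get_batchsize(x) == 32                                            # plain tensor
assert get_batchsize([x, torch.zeros(32, 5)]) == 32                      # multi-input list
assert get_batchsize({"input1": x, "input2": torch.zeros(32, 5)}) == 32  # dict of named inputs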
python/orca/src/bigdl/orca/learn/utils.py (23 additions, 0 deletions)
@@ -459,6 +459,29 @@ def data_creator(config, batch_size):
return data_creator


def make_dataloader_list_wrapper(func):
import torch

def make_feature_list(batch):
if func is not None:
batch = func(batch)
*features, target = batch
if len(features) == 1 and torch.is_tensor(features[0]):
features = features[0]
return features, target

return make_feature_list


def reload_dataloader_creator(dataloader_func):
def reload_dataloader(config, batch_size):
dataloader = dataloader_func(config, batch_size)
dataloader.collate_fn = make_dataloader_list_wrapper(dataloader.collate_fn)
return dataloader

return reload_dataloader if dataloader_func else None


def data_length(data):
x = data["x"]
if isinstance(x, np.ndarray):
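make_dataloader_list_wrapper chains the loader's existing collate_fn and then regroups the collated batch from (x1, ..., xn, y) into (features, target), leaving a lone tensor feature unwrapped; this is what lets the operator call self.model(*features) for multi-input models. A sketch of the multi-input case, with an illustrative creator and toy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset
from bigdl.orca.learn.utils import reload_dataloader_creator

def multi_input_loader_creator(config, batch_size):
    # toy two-input dataset: each sample is (x1, x2, y)
    x1 = torch.randn(100, 25)
    x2 = torch.randn(100, 25)
    y = torch.randint(0, 2, (100, 1)).float()
    return DataLoader(TensorDataset(x1, x2, y), batch_size=batch_size)

loader = reload_dataloader_creator(multi_input_loader_creator)(config={}, batch_size=16)
features, target = next(iter(loader))
# features is now [x1_batch, x2_batch], ready for output = self.model(*features)
print(len(features), features[0].shape, target.shape)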
@@ -222,7 +222,27 @@ def test_spark_xshards(self):
x_rdd = sc.parallelize(np.random.rand(4000, 1, 50).astype(np.float32))
# torch 1.7.1+ requires target size same as output size, which is (batch, 1)
y_rdd = sc.parallelize(np.random.randint(0, 2, size=(4000, 1, 1)).astype(np.float32))
- rdd = x_rdd.zip(y_rdd).map(lambda x_y: {'x': x_y[0], 'y': x_y[1]})
+ rdd = x_rdd.zip(y_rdd).map(lambda x_y: {'x': {"input_":x_y[0]}, 'y': x_y[1]})
train_rdd, val_rdd = rdd.randomSplit([0.9, 0.1])
train_xshards = SparkXShards(train_rdd)
val_xshards = SparkXShards(val_rdd)
train_stats = estimator.fit(train_xshards, validation_data=val_xshards,
batch_size=256, epochs=2)
print(train_stats)
val_stats = estimator.evaluate(val_xshards, batch_size=128)
print(val_stats)

def test_spark_xshards_of_dict(self):
from bigdl.dllib.nncontext import init_nncontext
from bigdl.orca.data import SparkXShards
estimator = get_estimator(workers_per_node=1,
model_fn=lambda config: MultiInputNet())
sc = init_nncontext()
x1_rdd = sc.parallelize(np.random.rand(4000, 1, 25).astype(np.float32))
x2_rdd = sc.parallelize(np.random.rand(4000, 1, 25).astype(np.float32))
# torch 1.7.1+ requires target size same as output size, which is (batch, 1)
y_rdd = sc.parallelize(np.random.randint(0, 2, size=(4000, 1, 1)).astype(np.float32))
rdd = x1_rdd.zip(x2_rdd).zip(y_rdd).map(lambda x_y: {'x': {"input1":x_y[0][0], "input2":x_y[0][1]}, 'y': x_y[1]})
train_rdd, val_rdd = rdd.randomSplit([0.9, 0.1])
train_xshards = SparkXShards(train_rdd)
val_xshards = SparkXShards(val_rdd)
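The new test_spark_xshards_of_dict case feeds XShards whose 'x' entry is itself a dict of named arrays, so the features should reach the operator as a dict and be unpacked with model(**features); the keys then need to line up with the model's forward() argument names. A sketch of one shard record in that format (MultiInputNet and get_estimator are defined elsewhere in the test suite and are not shown in this diff):

import numpy as np

# one record of the SparkXShards built in test_spark_xshards_of_dict
shard_record = {
    "x": {"input1": np.random.rand(1, 25).astype(np.float32),
          "input2": np.random.rand(1, 25).astype(np.float32)},
    "y": np.random.randint(0, 2, size=(1, 1)).astype(np.float32),
}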