diff --git a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py
index b7bc604ee83..18633c553dc 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py
@@ -25,7 +25,7 @@
 from bigdl.orca.learn.pytorch.pytorch_pyspark_worker import PytorchPysparkWorker
 from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
     convert_predict_xshards_to_dataframe, make_data_creator, update_predict_xshards, \
-    process_xshards_of_pandas_dataframe
+    reload_dataloader_creator
 from bigdl.orca.data import SparkXShards
 from bigdl.orca import OrcaContext
 from bigdl.orca.learn.base_estimator import BaseEstimator
@@ -287,8 +287,8 @@ def transform_func(iter, init_params, param):
                 "data should be either an instance of SparkXShards or a "
                 "callable function, but got type: {}".format(type(data)))
 
-            params["data_creator"] = data
-            params["validation_data_creator"] = validation_data
+            params["data_creator"] = reload_dataloader_creator(data)
+            params["validation_data_creator"] = reload_dataloader_creator(validation_data)
 
             def transform_func(iter, init_param, param):
                 return PytorchPysparkWorker(**init_param).train_epochs(**param)
@@ -474,7 +474,7 @@ def transform_func(iter, init_param, param):
             res = data.rdd.repartition(self.num_workers).barrier() \
                 .mapPartitions(lambda iter: transform_func(iter, init_params, params)).collect()
         else:
-            params["data_creator"] = data
+            params["data_creator"] = reload_dataloader_creator(data)
 
             def transform_func(iter, init_param, param):
                 return PytorchPysparkWorker(**init_param).validate(**param)
diff --git a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py
index fa643e06c11..80336f4c594 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_ray_estimator.py
@@ -27,7 +27,7 @@
 from bigdl.orca.learn.pytorch.pytorch_ray_worker import PytorchRayWorker
 from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
     convert_predict_xshards_to_dataframe, update_predict_xshards, \
-    process_xshards_of_pandas_dataframe
+    process_xshards_of_pandas_dataframe, reload_dataloader_creator
 from bigdl.orca.ray import OrcaRayContext
 from bigdl.orca.learn.ray_estimator import Estimator as OrcaRayEstimator
 from bigdl.dllib.utils.file_utils import enable_multi_fs_load, enable_multi_fs_save
@@ -360,8 +360,8 @@ def data_creator(config, batch_size):
                 " Ray Dataset or a callable function, but"
                 " got type: {}".format(type(data)))
 
-            params["data_creator"] = data
-            params["validation_data_creator"] = validation_data
+            params["data_creator"] = reload_dataloader_creator(data)
+            params["validation_data_creator"] = reload_dataloader_creator(validation_data)
 
         success, worker_stats = self._train_epochs(**params)
 
         epoch_stats = list(map(list, zip(*worker_stats)))
@@ -504,7 +504,8 @@ def data_creator(config, batch_size):
                 "data should be either an instance of SparkXShards or a callable"
                 " function, but got type: {}".format(type(data)))
 
-        params = dict(data_creator=data, batch_size=batch_size, num_steps=num_steps,
+        params = dict(data_creator=reload_dataloader_creator(data),
+                      batch_size=batch_size, num_steps=num_steps,
                       profile=profile, info=info)
 
         worker_stats = ray.get([w.validate.remote(**params) for w in self.remote_workers])
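Both estimators now route callable data creators through reload_dataloader_creator (added in python/orca/src/bigdl/orca/learn/utils.py further down) before handing them to workers, so every worker-side dataloader yields (features, target) pairs. A minimal sketch of the resulting behavior, assuming only that helper and a standard (config, batch_size) creator; the creator and data below are illustrative, not part of the patch:

import torch
from torch.utils.data import DataLoader, TensorDataset
from bigdl.orca.learn.utils import reload_dataloader_creator

def train_loader_creator(config, batch_size):
    # Toy single-input dataset; any (config, batch_size) creator works.
    x = torch.randn(100, 50)
    y = torch.randint(0, 2, (100, 1)).float()
    return DataLoader(TensorDataset(x, y), batch_size=batch_size)

wrapped = reload_dataloader_creator(train_loader_creator)
loader = wrapped({}, 32)
features, target = next(iter(loader))
# A single tensor input stays a tensor; multi-input batches come back as a
# list, which the training operator then unpacks (see training_operator.py).
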
diff --git a/python/orca/src/bigdl/orca/learn/pytorch/training_operator.py b/python/orca/src/bigdl/orca/learn/pytorch/training_operator.py
index 5b1cdf77b5f..c3a3898c51d 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/training_operator.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/training_operator.py
@@ -37,7 +37,7 @@
 from bigdl.orca.learn.metrics import Metric
 from bigdl.orca.learn.pytorch.utils import (TimerCollection, AverageMeterCollection,
-                                            NUM_SAMPLES)
+                                            NUM_SAMPLES, get_batchsize)
 from bigdl.orca.learn.pytorch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
                                                 SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -268,15 +268,21 @@ def train_batch(self, batch, batch_info):
         """
         # unpack features into list to support multiple inputs model
-        *features, target = batch
-        # If features is already a tuple, we don't give it an extra list dimension.
-        already_list = (isinstance(features[0], tuple) or isinstance(features[0], list))
-        if len(features) == 1 and already_list:
-            features = features[0]
+        features, target = batch
 
         # Compute output.
         with self.timers.record("fwd"):
-            output = self.model(*features)
+            if torch.is_tensor(features):
+                output = self.model(features)
+            elif isinstance(features, dict):
+                output = self.model(**features)
+            elif isinstance(features, (tuple, list)):
+                output = self.model(*features)
+            else:
+                invalidInputError(False,
+                                  "Features should be tensor, list/tuple or dict, "
+                                  "but got {}".format(type(features)))
+
             if isinstance(output, tuple) or isinstance(output, list):
                 # Then target is also assumed to be a tuple or list.
                 loss = self.criterion(*output, *target)
@@ -292,7 +298,7 @@ def train_batch(self, batch, batch_info):
         with self.timers.record("apply"):
             self.optimizer.step()
 
-        return {"train_loss": loss.item(), NUM_SAMPLES: features[0].size(0)}
+        return {"train_loss": loss.item(), NUM_SAMPLES: get_batchsize(features)}
 
     def validate(self, val_iterator, info, metrics, num_steps=None):
         """Runs one standard validation pass over the val_iterator.
@@ -329,7 +335,7 @@ def validate(self, val_iterator, info, metrics, num_steps=None):
                 batch_info = {"batch_idx": batch_idx}
                 batch_info.update(info)
                 output, target, loss = self.forward_batch(batch, batch_info)
-                num_samples = target.size(0)
+                num_samples = get_batchsize(target)
                 total_samples += num_samples
                 losses.append(loss.item() * num_samples)
                 for metric in metrics.values():
@@ -394,15 +400,21 @@ def forward_batch(self, batch, batch_info):
             calculate averages.
         """
         # unpack features into list to support multiple inputs model
-        *features, target = batch
-        # If features is already a tuple, we don't give it an extra list dimension.
-        already_list = (isinstance(features[0], tuple) or isinstance(features[0], list))
-        if len(features) == 1 and already_list:
-            features = features[0]
+        features, target = batch
 
         # compute output
         with self.timers.record("eval_fwd"):
-            output = self.model(*features)
+            if torch.is_tensor(features):
+                output = self.model(features)
+            elif isinstance(features, dict):
+                output = self.model(**features)
+            elif isinstance(features, (tuple, list)):
+                output = self.model(*features)
+            else:
+                invalidInputError(False,
+                                  "Features should be tensor, list/tuple or dict, "
+                                  "but got {}".format(type(features)))
+
             loss = self.criterion(output, target)
 
         return output, target, loss
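The train and eval paths now dispatch on the feature type rather than always star-unpacking the batch. A condensed sketch of the three call conventions (the TwoInput module is hypothetical, for illustration only):

import torch
import torch.nn as nn

class TwoInput(nn.Module):
    # Hypothetical model with two positional inputs.
    def forward(self, x1, x2):
        return x1 + x2

model = TwoInput()
x = torch.randn(4, 3)

out = model(*[x, x])               # list/tuple features -> model(*features)
out = model(**{"x1": x, "x2": x})  # dict features -> model(**features)
# A bare tensor feature would be passed directly as model(features).
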
diff --git a/python/orca/src/bigdl/orca/learn/pytorch/utils.py b/python/orca/src/bigdl/orca/learn/pytorch/utils.py
index 6100dad35b2..6903fd6be77 100644
--- a/python/orca/src/bigdl/orca/learn/pytorch/utils.py
+++ b/python/orca/src/bigdl/orca/learn/pytorch/utils.py
@@ -280,3 +280,12 @@ def get_filesystem(filepath):
     from fsspec.core import url_to_fs
     fs, _ = url_to_fs(str(filepath))
     return fs
+
+
+def get_batchsize(input):
+    if isinstance(input, (list, tuple)):
+        return get_batchsize(input[0])
+    elif isinstance(input, dict):
+        return get_batchsize(list(input.values())[0])
+    else:
+        return input.size(0)
diff --git a/python/orca/src/bigdl/orca/learn/utils.py b/python/orca/src/bigdl/orca/learn/utils.py
index deb59d7d2ce..325a8ab7bcb 100644
--- a/python/orca/src/bigdl/orca/learn/utils.py
+++ b/python/orca/src/bigdl/orca/learn/utils.py
@@ -459,6 +459,29 @@ def data_creator(config, batch_size):
     return data_creator
 
 
+def make_dataloader_list_wrapper(func):
+    import torch
+
+    def make_feature_list(batch):
+        if func is not None:
+            batch = func(batch)
+        *features, target = batch
+        if len(features) == 1 and torch.is_tensor(features[0]):
+            features = features[0]
+        return features, target
+
+    return make_feature_list
+
+
+def reload_dataloader_creator(dataloader_func):
+    def reload_dataloader(config, batch_size):
+        dataloader = dataloader_func(config, batch_size)
+        dataloader.collate_fn = make_dataloader_list_wrapper(dataloader.collate_fn)
+        return dataloader
+
+    return reload_dataloader if dataloader_func else None
+
+
 def data_length(data):
     x = data["x"]
     if isinstance(x, np.ndarray):
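get_batchsize recovers the batch dimension by recursing into the first element of whatever container it receives, so nested list/dict inputs report the same sample count as a plain tensor. A quick sanity sketch, assuming the helper added to bigdl/orca/learn/pytorch/utils.py above:

import torch
from bigdl.orca.learn.pytorch.utils import get_batchsize

t = torch.randn(16, 50)
assert get_batchsize(t) == 16                # plain tensor -> size(0)
assert get_batchsize([t, t]) == 16           # list: recurse into first element
assert get_batchsize({"input1": t}) == 16    # dict: recurse into first value
assert get_batchsize(([t], {"x": t})) == 16  # arbitrary nesting
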
diff --git a/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pyspark_backend.py b/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pyspark_backend.py
index 52ee2c1b147..a5ed581a8cd 100644
--- a/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pyspark_backend.py
+++ b/python/orca/test/bigdl/orca/learn/ray/pytorch/test_estimator_pyspark_backend.py
@@ -222,7 +222,27 @@ def test_spark_xshards(self):
         x_rdd = sc.parallelize(np.random.rand(4000, 1, 50).astype(np.float32))
         # torch 1.7.1+ requires target size same as output size, which is (batch, 1)
         y_rdd = sc.parallelize(np.random.randint(0, 2, size=(4000, 1, 1)).astype(np.float32))
-        rdd = x_rdd.zip(y_rdd).map(lambda x_y: {'x': x_y[0], 'y': x_y[1]})
+        rdd = x_rdd.zip(y_rdd).map(lambda x_y: {'x': {"input_": x_y[0]}, 'y': x_y[1]})
         train_rdd, val_rdd = rdd.randomSplit([0.9, 0.1])
         train_xshards = SparkXShards(train_rdd)
         val_xshards = SparkXShards(val_rdd)
         train_stats = estimator.fit(train_xshards, validation_data=val_xshards,
                                     batch_size=256, epochs=2)
         print(train_stats)
         val_stats = estimator.evaluate(val_xshards, batch_size=128)
         print(val_stats)
+
+    def test_spark_xshards_of_dict(self):
+        from bigdl.dllib.nncontext import init_nncontext
+        from bigdl.orca.data import SparkXShards
+        estimator = get_estimator(workers_per_node=1,
+                                  model_fn=lambda config: MultiInputNet())
+        sc = init_nncontext()
+        x1_rdd = sc.parallelize(np.random.rand(4000, 1, 25).astype(np.float32))
+        x2_rdd = sc.parallelize(np.random.rand(4000, 1, 25).astype(np.float32))
+        # torch 1.7.1+ requires target size same as output size, which is (batch, 1)
+        y_rdd = sc.parallelize(np.random.randint(0, 2, size=(4000, 1, 1)).astype(np.float32))
+        rdd = x1_rdd.zip(x2_rdd).zip(y_rdd).map(
+            lambda x_y: {'x': {"input1": x_y[0][0], "input2": x_y[0][1]}, 'y': x_y[1]})
+        train_rdd, val_rdd = rdd.randomSplit([0.9, 0.1])
+        train_xshards = SparkXShards(train_rdd)
+        val_xshards = SparkXShards(val_rdd)
+        train_stats = estimator.fit(train_xshards, validation_data=val_xshards,
+                                    batch_size=256, epochs=2)
+        print(train_stats)
+        val_stats = estimator.evaluate(val_xshards, batch_size=128)
+        print(val_stats)
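For reference, the shard layout these tests feed the estimator: each shard dict maps 'x' to a dict of named inputs and 'y' to the labels, and the worker then calls the model with keyword arguments (model(input1=..., input2=...)). A sketch of one shard, mirroring the shapes above; MultiInputNet is defined elsewhere in this test file and assumed to accept input1 and input2:

import numpy as np

shard = {
    "x": {"input1": np.random.rand(8, 1, 25).astype(np.float32),
          "input2": np.random.rand(8, 1, 25).astype(np.float32)},
    "y": np.random.randint(0, 2, size=(8, 1, 1)).astype(np.float32),
}
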
"ComplicatedInputDataset": ComplicatedInputDataset} + def train_data_loader(config, batch_size): - train_dataset = LinearDataset(size=config.get("data_size", 1000)) + train_dataset = DataSetMap[config.get("dataset", "LinearDataset")](size=config.get("data_size", 1000), + nested_input=config.get("nested_input", False)) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size @@ -49,7 +114,8 @@ def train_data_loader(config, batch_size): return train_loader def val_data_loader(config, batch_size): - val_dataset = LinearDataset(size=config.get("val_size", 400)) + val_dataset = DataSetMap[config.get("dataset", "LinearDataset")](size=config.get("val_size", 400), + nested_input=config.get("nested_input", False)) validation_loader = torch.utils.data.DataLoader( val_dataset, batch_size=batch_size @@ -78,9 +144,73 @@ def forward(self, input_): y = self.out_act(a3) return y +class DictInputNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(50, 50) + self.relu1 = nn.ReLU() + self.dout = nn.Dropout(0.2) + self.fc2 = nn.Linear(50, 100) + self.prelu = nn.PReLU(1) + self.out = nn.Linear(100, 1) + self.out_act = nn.Sigmoid() + + def forward(self, input_): + a1 = self.fc1(input_['x']) + h1 = self.relu1(a1) + dout = self.dout(h1) + a2 = self.fc2(dout) + h2 = self.prelu(a2) + a3 = self.out(h2) + y = self.out_act(a3) + return y + +class SingleListInputModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(2, 1) + self.out_act = nn.Sigmoid() + + def forward(self, input_list): + x = torch.cat(input_list, dim=1) + x = self.fc(x) + x = self.out_act(x) + return x + +class MultiInputModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(2, 1) + self.out_act = nn.Sigmoid() + + def forward(self, x1, x2): + x = torch.cat((x1, x2), dim=1) + x = self.fc(x) + x = self.out_act(x) + return x + +class ComplicatedInputModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(4, 1) + self.out_act = nn.Sigmoid() + + def forward(self, x1_x2, x3_dict, x4): + x = torch.cat((x1_x2[0], x1_x2[1], x3_dict['x3'], x4), dim=1) + x = self.fc(x) + x = self.out_act(x) + return x + + +ModelMap = {"Net": Net, + "SingleListInputModel": SingleListInputModel, + "MultiInputModel": MultiInputModel, + "DictInputNet": DictInputNet, + "ComplicatedInputModel": ComplicatedInputModel} + def get_model(config): torch.manual_seed(0) - return Net() + return ModelMap[config.get("model", "Net")]() def get_optimizer(model, config): return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2)) @@ -119,6 +249,126 @@ def test_train(self): dacc = (end_val_stats["Accuracy"] - start_val_stats["Accuracy"]) print(f"dLoss: {dloss}, dAcc: {dacc}") assert dloss < 0 < dacc, "training sanity check failed. loss increased!" 
@@ -119,6 +249,126 @@ def test_train(self):
         dacc = (end_val_stats["Accuracy"] - start_val_stats["Accuracy"])
         print(f"dLoss: {dloss}, dAcc: {dacc}")
         assert dloss < 0 < dacc, "training sanity check failed. loss increased!"
+
+    def test_singlelist_input(self):
+        estimator = Estimator.from_torch(model=get_model,
+                                         optimizer=get_optimizer,
+                                         loss=nn.BCELoss(),
+                                         metrics=Accuracy(),
+                                         config={"lr": 1e-2,
+                                                 "model": "SingleListInputModel",
+                                                 "dataset": "SingleListDataset",
+                                                 "nested_input": True},
+                                         workers_per_node=2,
+                                         backend="ray",
+                                         sync_stats=True)
+        start_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(start_val_stats)
+
+        train_stats = estimator.fit(train_data_loader, epochs=1, batch_size=32)
+        print(train_stats)
+
+        end_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(end_val_stats)
+
+        assert 0 < end_val_stats["Accuracy"] < 1
+        assert estimator.get_model()
+
+        # sanity check that training worked
+        dloss = end_val_stats["val_loss"] - start_val_stats["val_loss"]
+        dacc = (end_val_stats["Accuracy"] - start_val_stats["Accuracy"])
+        print(f"dLoss: {dloss}, dAcc: {dacc}")
+        assert dloss < 0 < dacc, "training sanity check failed. loss increased!"
+
+    def test_multi_input(self):
+        estimator = Estimator.from_torch(model=get_model,
+                                         optimizer=get_optimizer,
+                                         loss=nn.BCELoss(),
+                                         metrics=Accuracy(),
+                                         config={"lr": 1e-2,
+                                                 "model": "MultiInputModel",
+                                                 "dataset": "SingleListDataset",
+                                                 "nested_input": False},
+                                         workers_per_node=2,
+                                         backend="ray",
+                                         sync_stats=True)
+        start_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(start_val_stats)
+
+        train_stats = estimator.fit(train_data_loader, epochs=1, batch_size=32)
+        print(train_stats)
+
+        end_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(end_val_stats)
+
+        assert 0 < end_val_stats["Accuracy"] < 1
+        assert estimator.get_model()
+
+        # sanity check that training worked
+        dloss = end_val_stats["val_loss"] - start_val_stats["val_loss"]
+        dacc = (end_val_stats["Accuracy"] - start_val_stats["Accuracy"])
+        print(f"dLoss: {dloss}, dAcc: {dacc}")
+        assert dloss < 0 < dacc, "training sanity check failed. loss increased!"
+
+    def test_dict_input(self):
+        estimator = Estimator.from_torch(model=get_model,
+                                         optimizer=get_optimizer,
+                                         loss=nn.BCELoss(),
+                                         metrics=Accuracy(),
+                                         config={"lr": 1e-2,
+                                                 "model": "DictInputNet",
+                                                 "dataset": "LinearDataset",
+                                                 "nested_input": True},
+                                         workers_per_node=2,
+                                         backend="ray",
+                                         sync_stats=True)
+
+        start_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(start_val_stats)
+
+        train_stats = estimator.fit(train_data_loader, epochs=1, batch_size=32)
+        print(train_stats)
+
+        end_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(end_val_stats)
+
+        assert 0 < end_val_stats["Accuracy"] < 1
+        assert estimator.get_model()
+
+        # sanity check that training worked
+        dloss = end_val_stats["val_loss"] - start_val_stats["val_loss"]
+        dacc = (end_val_stats["Accuracy"] - start_val_stats["Accuracy"])
+        print(f"dLoss: {dloss}, dAcc: {dacc}")
+        assert dloss < 0 < dacc, "training sanity check failed. loss increased!"
+
+    def test_complicated_input(self):
+        estimator = Estimator.from_torch(model=get_model,
+                                         optimizer=get_optimizer,
+                                         loss=nn.BCELoss(),
+                                         metrics=Accuracy(),
+                                         config={"lr": 1e-2,
+                                                 "model": "ComplicatedInputModel",
+                                                 "dataset": "ComplicatedInputDataset"},
+                                         workers_per_node=2,
+                                         backend="ray",
+                                         sync_stats=True)
+        start_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(start_val_stats)
+
+        train_stats = estimator.fit(train_data_loader, epochs=1, batch_size=32)
+        print(train_stats)
+
+        end_val_stats = estimator.evaluate(val_data_loader, batch_size=32)
+        print(end_val_stats)
+
+        assert 0 < end_val_stats["Accuracy"] < 1
+        assert estimator.get_model()
+
+        # sanity check that training worked
+        dloss = end_val_stats["val_loss"] - start_val_stats["val_loss"]
+        dacc = (end_val_stats["Accuracy"] - start_val_stats["Accuracy"])
+        print(f"dLoss: {dloss}, dAcc: {dacc}")
+        assert dloss < 0 < dacc, "training sanity check failed. loss increased!"
 
 
 if __name__ == "__main__":