Skip to content

Commit

Permalink
Fix validate_input_col for nn.Module or Callable (#96213)
Browse files Browse the repository at this point in the history
Forward fix the problem introduced in pytorch/pytorch#95067

Not all `Callable` objects have `__name__` implemented. Using `repr` as the backup solution to get function name or reference.
Pull Request resolved: pytorch/pytorch#96213
Approved by: https://github.com/NivekT
  • Loading branch information
ejguan authored and cyyever committed Mar 12, 2023
1 parent 06d6fe1 commit dc5d538
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
15 changes: 15 additions & 0 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import numpy as np

import torch
import torch.nn as nn
import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.graph_settings
Expand Down Expand Up @@ -663,6 +664,16 @@ def _mod_3_test(x):
lambda_fn3 = lambda x: x >= 5 # noqa: E731


class Add1Module(nn.Module):
def forward(self, x):
return x + 1


class Add1Callable:
def __call__(self, x):
return x + 1


class TestFunctionalIterDataPipe(TestCase):

def _serialization_test_helper(self, datapipe, use_dill):
Expand Down Expand Up @@ -1326,6 +1337,10 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
_helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
_helper(lambda data: (data[0], data[1], int(data[2])), int, 2)

# Handle nn.Module and Callable (without __name__ implemented)
_helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0)
_helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)

@suppress_warnings # Suppress warning for lambda fn
def test_map_dict_with_col_iterdatapipe(self):
def fn_11(d):
Expand Down
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]
continue

if isinstance(fn, functools.partial):
fn_name = fn.func.__name__
fn_name = getattr(fn.func, "__name__", repr(fn.func))
else:
fn_name = fn.__name__
fn_name = getattr(fn, "__name__", repr(fn))

if len(non_default_kw_only) > 0:
raise ValueError(
Expand Down

0 comments on commit dc5d538

Please sign in to comment.