def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor,
                         num_segments: int) -> torch.Tensor:
    """Computes the sum along segments of a tensor.

    Analogous to tf.unsorted_segment_sum: every slice ``data[i]`` is added to
    row ``segment_ids[i]`` of the output, and rows that receive no slice stay
    zero.

    Parameters
    ----------
    data: torch.Tensor
        A tensor whose segments are to be summed. Its first dimension indexes
        the slices that are grouped.
    segment_ids: torch.Tensor
        A 1-D ``torch.int64`` tensor holding one segment index per slice of
        ``data``. It is moved to ``data.device`` before use.
    num_segments: int
        The number of segments, i.e. the size of the first dimension of the
        result.

    Returns
    -------
    tensor: torch.Tensor
        A tensor of shape ``(num_segments, *data.shape[1:])`` with the same
        dtype and device as ``data``.

    Raises
    ------
    AssertionError
        If ``segment_ids`` is not 1-D, or if its length differs from the first
        dimension of ``data``.

    Examples
    --------
    >>> segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64)
    >>> data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
    >>> num_segments = 2
    >>> result = unsorted_segment_sum(data=data,
    ...                               segment_ids=segment_ids,
    ...                               num_segments=num_segments)
    >>> result
    tensor([[5., 5., 5., 5.],
            [5., 6., 7., 8.]])

    """
    # Raise explicitly instead of using ``assert`` so the checks survive
    # ``python -O`` while callers still see the same exception type.
    if segment_ids.dim() != 1:
        raise AssertionError("segment_ids must be a 1-D tensor")
    if segment_ids.shape[0] != data.shape[0]:
        raise AssertionError(
            "segment_ids must have the same length as the first dimension "
            "of data")

    # Accumulate in the input dtype so float64 and large integers are summed
    # exactly. Only dtypes narrower than float32 keep the float32 accumulator.
    low_precision = (torch.bool, torch.float16, torch.bfloat16)
    acc_dtype = torch.float32 if data.dtype in low_precision else data.dtype

    # scatter_add needs an index with the same shape as the source: broadcast
    # each id over the trailing dimensions as a view (no copy).
    index = segment_ids.to(data.device)
    index = index.reshape(index.shape[0], *([1] * (data.dim() - 1)))
    index = index.expand_as(data)

    out_shape = [num_segments] + list(data.shape[1:])
    # Allocate on the input's device so non-CPU tensors work too.
    summed = torch.zeros(out_shape, dtype=acc_dtype, device=data.device)
    summed = summed.scatter_add(0, index, data.to(acc_dtype))
    return summed.to(data.dtype)
@pytest.mark.torch
def test_unsorted_segment_sum():
    """Check ``unsorted_segment_sum`` against the stored reference result."""
    import os

    segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64)
    data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
    num_segments = 2

    # Sanity-check the fixture itself: ids are 1-D, one per row of data.
    assert len(segment_ids.shape) == 1
    assert segment_ids.shape[-1] == data.shape[0]

    result = dc.utils.pytorch_utils.unsorted_segment_sum(
        data=data, segment_ids=segment_ids, num_segments=num_segments)

    # Resolve the reference array relative to this file so the test does not
    # depend on the directory pytest is launched from.
    expected_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 "assets", "result_segment_sum.npy")
    assert np.allclose(np.array(result), np.load(expected_path), atol=1e-04)