From 4508961d9d1181e47c3e00cd3958661893e50e68 Mon Sep 17 00:00:00 2001 From: riya-singh28 Date: Wed, 14 Jun 2023 19:58:35 +0530 Subject: [PATCH 1/5] add unsorted_segment_sum and unittest --- deepchem/utils/pytorch_utils.py | 40 ++++++++++++++++++ .../utils/test/assets/result_segment_sum.npy | Bin 0 -> 160 bytes deepchem/utils/test/test_pytorch_utils.py | 25 +++++++++++ 3 files changed, 65 insertions(+) create mode 100644 deepchem/utils/test/assets/result_segment_sum.npy create mode 100644 deepchem/utils/test/test_pytorch_utils.py diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py index ac44ea1da7..c7c372ad1c 100644 --- a/deepchem/utils/pytorch_utils.py +++ b/deepchem/utils/pytorch_utils.py @@ -14,3 +14,43 @@ def get_activation(fn: Union[Callable, str]): if isinstance(fn, str): return getattr(torch.nn.functional, fn) return fn + + +def unsorted_segment_sum(data, segment_ids, num_segments): + """Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum. + + Parameters + ---------- + data: torch.Tensor + A tensor whose segments are to be summed. + segment_ids: torch.Tensor + The segment indices tensor. + num_segments: int + The number of segments. + + Returns + ------- + tensor: torch.Tensor + + Examples + -------- + >>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) + >>> tf.math.segment_sum(c, tf.constant([0, 0, 1])).numpy() + array([[5, 5, 5, 5], + [5, 6, 7, 8]], dtype=int32) + + """ + # segment_ids.shape should be a prefix of data.shape + assert all([i in data.shape for i in segment_ids.shape]) + + if len(segment_ids.shape) == 1: + s = torch.prod(torch.tensor(data.shape[1:])).long() + segment_ids = segment_ids.repeat_interleave(s).view( + segment_ids.shape[0], *data.shape[1:]) + + # data.shape and segment_ids.shape should be equal + assert data.shape == segment_ids.shape + shape = [num_segments] + list(data.shape[1:]) + tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float()) + tensor = tensor.type(data.dtype) + return tensor diff --git a/deepchem/utils/test/assets/result_segment_sum.npy b/deepchem/utils/test/assets/result_segment_sum.npy new file mode 100644 index 0000000000000000000000000000000000000000..4acce6c14772ca77b39e1f3a7ba55f7eee8f744e GIT binary patch literal 160 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC%^qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= lXCxM+0{I$7ItnJ5ItsN4WCJc%1_lOn%mx%=2VxE&1^~$v8=e3F literal 0 HcmV?d00001 diff --git a/deepchem/utils/test/test_pytorch_utils.py b/deepchem/utils/test/test_pytorch_utils.py new file mode 100644 index 0000000000..f5f35bee28 --- /dev/null +++ b/deepchem/utils/test/test_pytorch_utils.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest +try: + import torch + has_torch = True +except ModuleNotFoundError: + has_torch = False + +from deepchem.utils.pytorch_utils import unsorted_segment_sum + + +@pytest.mark.torch +def test_unsorted_segment_sum(): + + segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64) + data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]]) + num_segments = 2 + + result = unsorted_segment_sum(data=data, + segment_ids=segment_ids, + num_segments=num_segments) + + assert np.allclose(np.array(result), + np.load("deepchem/utils/test/assets/result_segment_sum.npy"), + atol=1e-04) From 0352b09cb322908317347c8c38f7fb507ca444c2 Mon Sep 17 00:00:00 2001 From: riya-singh28 Date: Wed, 14 Jun 2023 21:12:06 +0530 Subject: [PATCH 2/5] correct example for unsorted_segment_sum --- deepchem/utils/pytorch_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py index c7c372ad1c..7de4930ffa 100644 --- a/deepchem/utils/pytorch_utils.py +++ b/deepchem/utils/pytorch_utils.py @@ -34,10 +34,15 @@ def unsorted_segment_sum(data, segment_ids, num_segments): Examples -------- - >>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) - >>> tf.math.segment_sum(c, tf.constant([0, 0, 1])).numpy() - array([[5, 5, 5, 5], - [5, 6, 7, 8]], dtype=int32) + >>> segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64) + >>> data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]]) + >>> num_segments = 2 + >>> result = unsorted_segment_sum(data=data, + segment_ids=segment_ids, + num_segments=num_segments) + >>> result + tensor([[5., 5., 5., 5.], + [5., 6., 7., 8.]]) """ # segment_ids.shape should be a prefix of data.shape From e6dfc9e28f58bb1ae457fabe8ee7dc21caa673eb Mon Sep 17 00:00:00 2001 From: riya-singh28 Date: Fri, 16 Jun 2023 18:52:57 +0530 Subject: [PATCH 3/5] add some minor fixes --- deepchem/utils/pytorch_utils.py | 15 ++++++++++----- deepchem/utils/test/test_pytorch_utils.py | 8 ++++---- docs/source/api_reference/utils.rst | 5 +++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py index 7de4930ffa..dc05efe9e7 100644 --- a/deepchem/utils/pytorch_utils.py +++ b/deepchem/utils/pytorch_utils.py @@ -1,7 +1,7 @@ """Utility functions for working with PyTorch.""" import torch -from typing import Callable, Union +from typing import Callable, Union, List def get_activation(fn: Union[Callable, str]): @@ -16,7 +16,8 @@ def get_activation(fn: Union[Callable, str]): return fn -def unsorted_segment_sum(data, segment_ids, num_segments): +def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, + num_segments: int) -> torch.Tensor: """Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum. Parameters @@ -45,8 +46,11 @@ def unsorted_segment_sum(data, segment_ids, num_segments): [5., 6., 7., 8.]]) """ + # length of segment_ids.shape should be 1 + assert len(segment_ids.shape) == 1 + # segment_ids.shape should be a prefix of data.shape - assert all([i in data.shape for i in segment_ids.shape]) + assert segment_ids.shape[-1] == data.shape[0] if len(segment_ids.shape) == 1: s = torch.prod(torch.tensor(data.shape[1:])).long() @@ -55,7 +59,8 @@ def unsorted_segment_sum(data, segment_ids, num_segments): # data.shape and segment_ids.shape should be equal assert data.shape == segment_ids.shape - shape = [num_segments] + list(data.shape[1:]) - tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float()) + shape: List[int] = [num_segments] + list(data.shape[1:]) + tensor: torch.Tensor = torch.zeros(*shape).scatter_add( + 0, segment_ids, data.float()) tensor = tensor.type(data.dtype) return tensor diff --git a/deepchem/utils/test/test_pytorch_utils.py b/deepchem/utils/test/test_pytorch_utils.py index f5f35bee28..cdacfabe01 100644 --- a/deepchem/utils/test/test_pytorch_utils.py +++ b/deepchem/utils/test/test_pytorch_utils.py @@ -15,11 +15,11 @@ def test_unsorted_segment_sum(): segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64) data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]]) num_segments = 2 - result = unsorted_segment_sum(data=data, segment_ids=segment_ids, num_segments=num_segments) - assert np.allclose(np.array(result), - np.load("deepchem/utils/test/assets/result_segment_sum.npy"), - atol=1e-04) + assert np.allclose( + np.array(result), + np.load("deepchem/utils/test/assets/result_segment_sum.npy"), + atol=1e-04) diff --git a/docs/source/api_reference/utils.rst b/docs/source/api_reference/utils.rst index f6c29c8088..12d4d0e048 100644 --- a/docs/source/api_reference/utils.rst +++ b/docs/source/api_reference/utils.rst @@ -286,3 +286,8 @@ The utilites here are used to create an object that contains information about a :members: .. autofunction:: deepchem.utils.dftutils.hashstr + +Pytorch Utilities +----------------- + +.. autofunction:: deepchem.utils.pytorch_utils.unsorted_segment_sum From 2d939ea6eba7397e7c864daf12398d8a5f38b65a Mon Sep 17 00:00:00 2001 From: riya-singh28 Date: Fri, 16 Jun 2023 19:35:27 +0530 Subject: [PATCH 4/5] update test_unsorted_segment_sum --- deepchem/utils/pytorch_utils.py | 2 +- deepchem/utils/test/test_pytorch_utils.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py index dc05efe9e7..bb73f202a0 100644 --- a/deepchem/utils/pytorch_utils.py +++ b/deepchem/utils/pytorch_utils.py @@ -49,7 +49,7 @@ def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, # length of segment_ids.shape should be 1 assert len(segment_ids.shape) == 1 - # segment_ids.shape should be a prefix of data.shape + # Shape of segment_ids should be equal to first dimension of data assert segment_ids.shape[-1] == data.shape[0] if len(segment_ids.shape) == 1: diff --git a/deepchem/utils/test/test_pytorch_utils.py b/deepchem/utils/test/test_pytorch_utils.py index cdacfabe01..40563a7a5c 100644 --- a/deepchem/utils/test/test_pytorch_utils.py +++ b/deepchem/utils/test/test_pytorch_utils.py @@ -15,6 +15,12 @@ def test_unsorted_segment_sum(): segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64) data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]]) num_segments = 2 + + # length of segment_ids.shape should be 1 + assert len(segment_ids.shape) == 1 + + # Shape of segment_ids should be equal to first dimension of data + assert segment_ids.shape[-1] == data.shape[0] result = unsorted_segment_sum(data=data, segment_ids=segment_ids, num_segments=num_segments) From e1dc911511b4262b0e05d03af3f01e72701ddea0 Mon Sep 17 00:00:00 2001 From: riya-singh28 Date: Fri, 16 Jun 2023 23:20:26 +0530 Subject: [PATCH 5/5] remove if statement from unsorted_segment_sum --- deepchem/utils/pytorch_utils.py | 7 +++---- deepchem/utils/test/test_pytorch_utils.py | 8 +++----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/deepchem/utils/pytorch_utils.py b/deepchem/utils/pytorch_utils.py index bb73f202a0..7a9f3b597b 100644 --- a/deepchem/utils/pytorch_utils.py +++ b/deepchem/utils/pytorch_utils.py @@ -52,10 +52,9 @@ def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, # Shape of segment_ids should be equal to first dimension of data assert segment_ids.shape[-1] == data.shape[0] - if len(segment_ids.shape) == 1: - s = torch.prod(torch.tensor(data.shape[1:])).long() - segment_ids = segment_ids.repeat_interleave(s).view( - segment_ids.shape[0], *data.shape[1:]) + s = torch.prod(torch.tensor(data.shape[1:])).long() + segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], + *data.shape[1:]) # data.shape and segment_ids.shape should be equal assert data.shape == segment_ids.shape diff --git a/deepchem/utils/test/test_pytorch_utils.py b/deepchem/utils/test/test_pytorch_utils.py index 40563a7a5c..6bddb4c0f2 100644 --- a/deepchem/utils/test/test_pytorch_utils.py +++ b/deepchem/utils/test/test_pytorch_utils.py @@ -1,3 +1,4 @@ +import deepchem as dc import numpy as np import pytest try: @@ -6,8 +7,6 @@ except ModuleNotFoundError: has_torch = False -from deepchem.utils.pytorch_utils import unsorted_segment_sum - @pytest.mark.torch def test_unsorted_segment_sum(): @@ -21,9 +20,8 @@ def test_unsorted_segment_sum(): # Shape of segment_ids should be equal to first dimension of data assert segment_ids.shape[-1] == data.shape[0] - result = unsorted_segment_sum(data=data, - segment_ids=segment_ids, - num_segments=num_segments) + result = dc.utils.pytorch_utils.unsorted_segment_sum( + data=data, segment_ids=segment_ids, num_segments=num_segments) assert np.allclose( np.array(result),