Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unsorted_segment_sum and unittest #3430

Merged
merged 5 commits into from Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 51 additions & 1 deletion deepchem/utils/pytorch_utils.py
@@ -1,7 +1,7 @@
"""Utility functions for working with PyTorch."""

import torch
from typing import Callable, Union
from typing import Callable, Union, List


def get_activation(fn: Union[Callable, str]):
Expand All @@ -14,3 +14,53 @@ def get_activation(fn: Union[Callable, str]):
if isinstance(fn, str):
return getattr(torch.nn.functional, fn)
return fn


def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor,
num_segments: int) -> torch.Tensor:
"""Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

Parameters
----------
data: torch.Tensor
A tensor whose segments are to be summed.
segment_ids: torch.Tensor
The segment indices tensor.
num_segments: int
The number of segments.

Returns
-------
tensor: torch.Tensor

Examples
--------
>>> segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64)
>>> data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
>>> num_segments = 2
>>> result = unsorted_segment_sum(data=data,
segment_ids=segment_ids,
num_segments=num_segments)
>>> result
tensor([[5., 5., 5., 5.],
[5., 6., 7., 8.]])

"""
# length of segment_ids.shape should be 1
assert len(segment_ids.shape) == 1

# Shape of segment_ids should be equal to first dimension of data
assert segment_ids.shape[-1] == data.shape[0]

if len(segment_ids.shape) == 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove this if statement because of the earlier assert

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(
segment_ids.shape[0], *data.shape[1:])

# data.shape and segment_ids.shape should be equal
assert data.shape == segment_ids.shape
shape: List[int] = [num_segments] + list(data.shape[1:])
tensor: torch.Tensor = torch.zeros(*shape).scatter_add(
0, segment_ids, data.float())
tensor = tensor.type(data.dtype)
return tensor
Binary file added deepchem/utils/test/assets/result_segment_sum.npy
Binary file not shown.
31 changes: 31 additions & 0 deletions deepchem/utils/test/test_pytorch_utils.py
@@ -0,0 +1,31 @@
import numpy as np
import pytest
try:
import torch
has_torch = True
except ModuleNotFoundError:
has_torch = False

from deepchem.utils.pytorch_utils import unsorted_segment_sum


@pytest.mark.torch
def test_unsorted_segment_sum():
    """Check unsorted_segment_sum on a small, hand-computable example."""
    segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64)
    data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
    num_segments = 2

    result = unsorted_segment_sum(data=data,
                                  segment_ids=segment_ids,
                                  num_segments=num_segments)

    # Rows 0 and 2 are summed into segment 0; row 1 alone forms segment 1.
    # The expected values are spelled out inline so the test does not depend
    # on the working directory or on a binary asset file.
    expected = np.array([[5., 5., 5., 5.], [5., 6., 7., 8.]])

    assert result.shape == (num_segments, data.shape[1])
    assert np.allclose(result.numpy(), expected, atol=1e-04)
5 changes: 5 additions & 0 deletions docs/source/api_reference/utils.rst
Expand Up @@ -286,3 +286,8 @@ The utilities here are used to create an object that contains information about a
:members:

.. autofunction:: deepchem.utils.dftutils.hashstr

Pytorch Utilities
-----------------

.. autofunction:: deepchem.utils.pytorch_utils.unsorted_segment_sum