Skip to content

Commit

Permalink
Merge pull request #3430 from riya-singh28/unsorted_segment_sum
Browse files Browse the repository at this point in the history
Unsorted_segment_sum and unittest
  • Loading branch information
rbharath committed Jun 19, 2023
2 parents 67b83a2 + e1dc911 commit d0761e6
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 1 deletion.
51 changes: 50 additions & 1 deletion deepchem/utils/pytorch_utils.py
@@ -1,7 +1,7 @@
"""Utility functions for working with PyTorch."""

import torch
from typing import Callable, Union
from typing import Callable, Union, List


def get_activation(fn: Union[Callable, str]):
Expand All @@ -14,3 +14,52 @@ def get_activation(fn: Union[Callable, str]):
if isinstance(fn, str):
return getattr(torch.nn.functional, fn)
return fn


def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor,
num_segments: int) -> torch.Tensor:
"""Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
Parameters
----------
data: torch.Tensor
A tensor whose segments are to be summed.
segment_ids: torch.Tensor
The segment indices tensor.
num_segments: int
The number of segments.
Returns
-------
tensor: torch.Tensor
Examples
--------
>>> segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64)
>>> data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
>>> num_segments = 2
>>> result = unsorted_segment_sum(data=data,
segment_ids=segment_ids,
num_segments=num_segments)
>>> result
tensor([[5., 5., 5., 5.],
[5., 6., 7., 8.]])
"""
# length of segment_ids.shape should be 1
assert len(segment_ids.shape) == 1

# Shape of segment_ids should be equal to first dimension of data
assert segment_ids.shape[-1] == data.shape[0]

s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0],
*data.shape[1:])

# data.shape and segment_ids.shape should be equal
assert data.shape == segment_ids.shape
shape: List[int] = [num_segments] + list(data.shape[1:])
tensor: torch.Tensor = torch.zeros(*shape).scatter_add(
0, segment_ids, data.float())
tensor = tensor.type(data.dtype)
return tensor
Binary file added deepchem/utils/test/assets/result_segment_sum.npy
Binary file not shown.
29 changes: 29 additions & 0 deletions deepchem/utils/test/test_pytorch_utils.py
@@ -0,0 +1,29 @@
import deepchem as dc
import numpy as np
import pytest
try:
import torch
has_torch = True
except ModuleNotFoundError:
has_torch = False


@pytest.mark.torch
def test_unsorted_segment_sum():

segment_ids = torch.Tensor([0, 1, 0]).to(torch.int64)
data = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
num_segments = 2

# length of segment_ids.shape should be 1
assert len(segment_ids.shape) == 1

# Shape of segment_ids should be equal to first dimension of data
assert segment_ids.shape[-1] == data.shape[0]
result = dc.utils.pytorch_utils.unsorted_segment_sum(
data=data, segment_ids=segment_ids, num_segments=num_segments)

assert np.allclose(
np.array(result),
np.load("deepchem/utils/test/assets/result_segment_sum.npy"),
atol=1e-04)
5 changes: 5 additions & 0 deletions docs/source/api_reference/utils.rst
Expand Up @@ -286,3 +286,8 @@ The utilites here are used to create an object that contains information about a
:members:

.. autofunction:: deepchem.utils.dftutils.hashstr

Pytorch Utilities
-----------------

.. autofunction:: deepchem.utils.pytorch_utils.unsorted_segment_sum

0 comments on commit d0761e6

Please sign in to comment.