Skip to content

Commit

Permalink
remove if statement from unsorted_segment_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
riya-singh28 committed Jun 16, 2023
1 parent 2d939ea commit e1dc911
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
7 changes: 3 additions & 4 deletions deepchem/utils/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor,
# Shape of segment_ids should be equal to first dimension of data
assert segment_ids.shape[-1] == data.shape[0]

if len(segment_ids.shape) == 1:
s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(
segment_ids.shape[0], *data.shape[1:])
s = torch.prod(torch.tensor(data.shape[1:])).long()
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0],
*data.shape[1:])

# data.shape and segment_ids.shape should be equal
assert data.shape == segment_ids.shape
Expand Down
8 changes: 3 additions & 5 deletions deepchem/utils/test/test_pytorch_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import deepchem as dc
import numpy as np
import pytest
try:
Expand All @@ -6,8 +7,6 @@
except ModuleNotFoundError:
has_torch = False

from deepchem.utils.pytorch_utils import unsorted_segment_sum


@pytest.mark.torch
def test_unsorted_segment_sum():
Expand All @@ -21,9 +20,8 @@ def test_unsorted_segment_sum():

# Shape of segment_ids should be equal to first dimension of data
assert segment_ids.shape[-1] == data.shape[0]
result = unsorted_segment_sum(data=data,
segment_ids=segment_ids,
num_segments=num_segments)
result = dc.utils.pytorch_utils.unsorted_segment_sum(
data=data, segment_ids=segment_ids, num_segments=num_segments)

assert np.allclose(
np.array(result),
Expand Down

0 comments on commit e1dc911

Please sign in to comment.