Skip to content

Commit

Permalink
Add distributions wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Jan 7, 2019
1 parent a9e97ce commit bb00592
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
8 changes: 8 additions & 0 deletions namedtensor/test_core.py
Expand Up @@ -3,6 +3,7 @@
import torch
from collections import OrderedDict
import pytest
import torch.nn.functional as F


def make_tensors(sizes):
Expand Down Expand Up @@ -40,6 +41,13 @@ def test_apply():
assert (ntorch.abs(ntensor.sum("alpha") - 1.0) < 1e-5).all()


def test_apply2():
base = torch.zeros([10, 2, 50])
ntensor = ntorch.tensor(base, ("alpha", "beta", "gamma"))
ntensor = ntensor.op(F.softmax, dim="alpha")
assert (ntorch.abs(ntensor.sum("alpha") - 1.0) < 1e-5).all()


@pytest.mark.xfail
def test_fail():
for base1, base2 in zip(
Expand Down
2 changes: 1 addition & 1 deletion namedtensor/test_module.py
Expand Up @@ -9,7 +9,7 @@ def __init__(self):
self.linear = nn.Linear(10, 20)

def forward(self, inp):
return inp.op(self.linear, inhid="outhid")
return inp.op(self.linear, outhid="inhid")


class NTModule(nn.Module):
Expand Down
19 changes: 16 additions & 3 deletions namedtensor/torch_helpers.py
Expand Up @@ -56,10 +56,23 @@ def access(self, dims):
return self.transpose(*term)._tensor

def op(self, axis_op, dim=None, **kwargs):
kwargs = {}
"Apply an op that may change dimensions sizes "
func_args = {}
if dim is not None:
kwargs["dim"] = self._schema.get(dim)
return self._new(axis_op(self._tensor, **kwargs), updates=kwargs)
func_args["dim"] = self._schema.get(dim)
out = self._new(
axis_op(self._tensor, **func_args),
updates={
(v[0] if isinstance(v, tuple) else v): k
for k, v in kwargs.items()
},
)

for k, v in self.shape.items():
assert (
k not in out.shape or v == out.shape[k]
), "name needs to change for updated dimensions"
return out

def __add__(self, b):
return self.add(b)
Expand Down

0 comments on commit bb00592

Please sign in to comment.