In [1]:
import torch
import nestedtensor
from IPython.display import Markdown, display

def print_eval(s):
    """Show the expression string ``s`` as a highlighted prompt, then print its value.

    The expression is rendered as bold, dark-red Markdown prefixed with ``$``,
    then evaluated, and ``str`` of the result is printed followed by a blank line.

    ``s`` is evaluated against this module's global namespace only (the
    notebook namespace when this helper is defined in a cell). The helper's own
    locals (``s``, ``colorS``) therefore never shadow a notebook variable of
    the same name.

    NOTE: ``eval`` executes arbitrary code. Pass only expressions written in
    this notebook, never external input.
    """
    colorS = "<span style='color:darkred'>$ {}</span>".format(s)
    display(Markdown('**{}**'.format(colorS)))
    # Explicit globals: a bare eval(s) would also resolve names against the
    # locals of this function.
    print('{}\n'.format(str(eval(s, globals()))))

## Custom nn.functionals

By default, all nn.functional functions are implemented as tensorwise functions. However, in some cases we want to support custom semantics that arise from slight modifications to the lifted function. Take nn.functional.conv2d as an example.



In [2]:
# Components without a batch dimension: (channels, height, width), each with a
# different spatial size.
nt = nestedtensor.nested_tensor(
    [torch.rand(3, h, w) for h, w in ((10, 30), (20, 40), (30, 50))]
)
# The same shapes with an explicit leading dimension of size 1.
nt1 = nestedtensor.nested_tensor(
    [torch.rand(1, 3, h, w) for h, w in ((10, 30), (20, 40), (30, 50))]
)
# conv2d filters: 64 output channels, 3 input channels, 7x7 kernel.
weight = torch.rand(64, 3, 7, 7)
print_eval("nt.size()")

**<span style='color:darkred'>$ nt.size()</span>**

(3, 3, None, None)



By default, the lifted conv2d fails on `nt`, because its components do not have a batch dimension.

However, NestedTensors implement a version of conv2d that doesn't require a batch dimension for ease of use and for efficiency (more on that later).

In [3]:
# `nt` holds 3-dimensional components (no batch dimension); `weight` is the
# (64, 3, 7, 7) tensor defined in the previous cell.
print_eval("torch.nn.functional.conv2d(nt, weight).size()")

**<span style='color:darkred'>$ torch.nn.functional.conv2d(nt, weight).size()</span>**

(3, 64, None, None)



We have a similar story for nn.functional.embedding_bag. Unless it is given `offsets`, the lifted version only works on components of batch size 1, which is an unnecessary annoyance. We therefore extend the lifted embedding_bag to also support components of dimension 1 when `offsets` is None.

In [10]:
# Integer indices in [0, 10): the uniform samples are scaled BEFORE the int64
# conversion requested through `dtype`.
nt3 = nestedtensor.nested_tensor([
    torch.rand(30) * 10,
    torch.rand(40) * 10,
    torch.rand(50) * 10
], dtype=torch.int64)
# Same recipe for the two-level nested case. The previous version multiplied
# the finished int64 NestedTensor by 10 instead; assuming the dtype conversion
# truncates like torch.Tensor.to, every sample in [0, 1) had already become 0,
# so all indices were 0.
nt4 = nestedtensor.nested_tensor([
    [
        torch.rand(1, 30) * 10,
    ],
    [
        torch.rand(1, 40) * 10,
        torch.rand(1, 50) * 10
    ]
], dtype=torch.int64)


In [11]:
# Embedding table: 100 rows of width 256 (replaces the conv2d `weight` above).
weight = torch.rand(100, 256)
# Functional form on the flat (nt3) and the two-level nested (nt4) index tensors.
print_eval("torch.nn.functional.embedding_bag(nt3, weight).nested_size()")
print_eval("torch.nn.functional.embedding_bag(nt4, weight).nested_size()")
# Module form on the same inputs. The former `nt2` call was removed: no cell
# defines `nt2`, so it raised NameError on a fresh kernel.
print_eval("torch.nn.EmbeddingBag(100, 256)(nt3).nested_size()")
print_eval("torch.nn.EmbeddingBag(100, 256)(nt4).nested_size()")

**<span style='color:darkred'>$ torch.nn.functional.embedding_bag(nt3, weight).nested_size()</span>**

RuntimeError: step must be nonzero

In [None]:
# NOTE(review): this rebinds `nt3` (created with dtype=torch.int64 above) to
# its float() result, so earlier cells that use `nt3` see a different tensor
# if they are re-run after this one.
nt3 = nt3.float()
print_eval("nt3")
print_eval("nt3.size()")
print_eval("nt3.nested_size()")
print_eval("nestedtensor.nested_tensor(nt3.nested_size(1))")
# NOTE(review): `nt4` is rebound here as well and no longer refers to the
# two-level index tensor built earlier. The division presumably scales each
# component by its size along dimension 1 — confirm.
nt4 = nt3 / nestedtensor.nested_tensor(nt3.nested_size(1))
print_eval("nt4")
print_eval("nt4.size()")

In [None]:
# Three matrices that share the column count (10) but differ in row count.
nt5 = nestedtensor.nested_tensor(
    [torch.rand(rows, 10) for rows in (30, 40, 50)]
)
print_eval("nt5.nested_size()")
# Matrix product with a single dense (10, 5) tensor.
print_eval("torch.mm(nt5, torch.rand(10, 5)).nested_size()")

In [None]:
# Call argmax(1) on nt5 (defined in the previous cell) and show the result,
# its size, and the output of to_tensor() on it.
# NOTE(review): to_tensor() presumably yields a regular torch.Tensor — confirm.
print_eval("nt5.argmax(1)")
print_eval("nt5.argmax(1).size()")
print_eval("nt5.argmax(1).to_tensor()")

In [None]:
# THIS IS TEMPORARILY DISABLED
# print_eval("nt5.nested_size()")
# print_eval("nt5.argmax(2).nested_size()")
# print_eval("torch.nn.functional.cross_entropy(nt5, nt5.argmax(2))")

In [None]:
# Square matrices of increasing size.
nt6 = nestedtensor.nested_tensor(
    [torch.rand(n, n) for n in (10, 20, 30)]
)
# lu() is indexed at 0 and 1 below; show the size of each returned item.
print_eval("nt6.lu()[0].size()")
print_eval("nt6.lu()[1].size()")

In [None]:
# Two-level nesting: every (r, c) component of nt7 lines up with a (c, r)
# component at the same position in nt8.
nt7 = nestedtensor.nested_tensor([
    [torch.rand(1, 10), torch.rand(2, 20)],
    [torch.rand(3, 30)],
])
nt8 = nestedtensor.nested_tensor([
    [torch.rand(10, 1), torch.rand(20, 2)],
    [torch.rand(30, 3)],
])
print_eval("torch.mm(nt7, nt8)")