In [106]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
# %matplotlib inline

Getting Started with Nested Tensors
===================================

Nested tensors generalize the shape of regular dense tensors, allowing
for representation of ragged-sized data.

-   for a regular tensor, each dimension is regular and has a size
-   for a nested tensor, not all dimensions have regular sizes; some of
    them are ragged

Nested tensors are a natural solution for representing sequential data
within various domains:

-   in NLP, sentences can have variable lengths, so a batch of sentences
    forms a nested tensor
-   in CV, images can have variable shapes, so a batch of images forms a
    nested tensor

In this tutorial, we will demonstrate basic usage of nested tensors and
motivate their usefulness for operating on sequential data of varying
lengths with a real-world example. In particular, they are invaluable
for building transformers that can efficiently operate on ragged
sequential inputs. Below, we present an implementation of multi-head
attention using nested tensors that, combined usage of `torch.compile`,
out-performs operating naively on tensors with padding.

Nested tensors are currently a prototype feature and are subject to
change.


In [107]:
import numpy as np
import timeit
import torch
import torch.nn.functional as F

from torch import nn

torch.manual_seed(1)
np.random.seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Nested tensor initialization
============================

From the Python frontend, a nested tensor can be created from a list of
tensors. We denote nt\[i\] as the ith tensor component of a
nestedtensor.


In [108]:
nt = torch.nested.nested_tensor([torch.arange(12).reshape(
    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
print(f"{nt=}")

nt=nested_tensor([
  tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
          [ 6.,  7.,  8.,  9., 10., 11.]], device='cuda:0'),
  tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
          [ 6.,  7.,  8.,  9., 10., 11.],
          [12., 13., 14., 15., 16., 17.]], device='cuda:0')
], device='cuda:0')


By padding every underlying tensor to the same shape, a nestedtensor can
be converted to a regular tensor.


In [109]:
padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
print(f"{padded_out_tensor=}")

padded_out_tensor=tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [ 0.,  0.,  0.,  0.,  0.,  0.]],

        [[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]]], device='cuda:0')


All tensors posses an attribute for determining if they are nested;


In [110]:
print(f"nt is nested: {nt.is_nested}")
print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")

nt is nested: True
padded_out_tensor is nested: False


It is common to construct nestedtensors from batches of irregularly
shaped tensors. i.e. dimension 0 is assumed to be the batch dimension.
Indexing dimension 0 gives back the first underlying tensor component.


In [111]:
print("First underlying tensor component:", nt[0], sep='\n')
print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')

# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
print(f"First underlying tensor component is nested: {nt[0].is_nested}")

First underlying tensor component:
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.]], device='cuda:0')
last column of 2nd underlying tensor component:
tensor([ 5., 11., 17.], device='cuda:0')
First underlying tensor component is nested: False


An important note is that slicing in dimension 0 has not been supported
yet. Which means it not currently possible to construct a view that
combines the underlying tensor components.


Nested Tensor Operations
========================

As each operation must be explicitly implemented for nestedtensors,
operation coverage for nestedtensors is currently narrower than that of
regular tensors. For now, only basic operations such as index, dropout,
softmax, transpose, reshape, linear, bmm are covered. However, coverage
is being expanded. If you need certain operations, please file an
[issue](https://github.com/pytorch/pytorch) to help us prioritize
coverage.

**reshape**

The reshape op is for changing the shape of a tensor. Its full semantics
for regular tensors can be found
[here](https://pytorch.org/docs/stable/generated/torch.reshape.html).
For regular tensors, when specifying the new shape, a single dimension
may be -1, in which case it is inferred from the remaining dimensions
and the number of elements.

The semantics for nestedtensors are similar, except that -1 no longer
infers. Instead, it inherits the old size (here 2 for `nt[0]` and 3 for
`nt[1]`). -1 is the only legal size to specify for a jagged dimension.


In [112]:
nt_reshaped = nt.reshape(2, -1, 2, 3)
print(f"{nt_reshaped=}")

nt_reshaped=nested_tensor([
  tensor([[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.]],
  
          [[ 6.,  7.,  8.],
           [ 9., 10., 11.]]], device='cuda:0'),
  tensor([[[ 0.,  1.,  2.],
           [ 3.,  4.,  5.]],
  
          [[ 6.,  7.,  8.],
           [ 9., 10., 11.]],
  
          [[12., 13., 14.],
           [15., 16., 17.]]], device='cuda:0')
], device='cuda:0')


**transpose**

The transpose op is for swapping two dimensions of a tensor. Its full
semantics can be found
[here](https://pytorch.org/docs/stable/generated/torch.transpose.html).
Note that for nestedtensors dimension 0 is special; it is assumed to be
the batch dimension, so transposes involving nestedtensor dimension 0
are not supported.


In [113]:
nt_transposed = nt_reshaped.transpose(1, 2)
print(f"{nt_transposed=}")

nt_transposed=nested_tensor([
  tensor([[[ 0.,  1.,  2.],
           [ 6.,  7.,  8.]],
  
          [[ 3.,  4.,  5.],
           [ 9., 10., 11.]]], device='cuda:0'),
  tensor([[[ 0.,  1.,  2.],
           [ 6.,  7.,  8.],
           [12., 13., 14.]],
  
          [[ 3.,  4.,  5.],
           [ 9., 10., 11.],
           [15., 16., 17.]]], device='cuda:0')
], device='cuda:0')


**others**

Other operations have the same semantics as for regular tensors.
Applying the operation on a nestedtensor is equivalent to applying the
operation to the underlying tensor components, with the result being a
nestedtensor as well.


In [114]:
nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")

nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")

nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")

Result of Matmul:
 nested_tensor([
  tensor([[[  0.7781,   1.7332,   2.5551,  -1.7998],
           [ -6.3416,   0.6039,   3.3571, -21.6835]],
  
          [[ -3.0563,   1.1609,  -6.8225,  19.4126],
           [ -7.3476,  -0.8315, -15.4485,  44.0489]]], device='cuda:0'),
  tensor([[[ -0.7215,   3.0998,  -0.2846,   4.7335,   3.6254],
           [-17.8239,   9.9335,  14.5221,  25.6358,  15.9261],
           [-34.9263,  16.7672,  29.3289,  46.5381,  28.2268]],
  
          [[  5.9445,   3.1823,   7.7202, -15.5639,   9.8096],
           [ 13.5947,   9.8521,  19.5695, -38.9003,  20.3403],
           [ 21.2450,  16.5219,  31.4188, -62.2367,  30.8710]]], device='cuda:0')
], device='cuda:0')
Result of Dropout:
 nested_tensor([
  tensor([[[  0.8646,   1.9258,   2.8390,  -1.9998],
           [ -0.0000,   0.6710,   3.7301, -24.0928]],
  
          [[ -3.3959,   1.2899,  -0.0000,  21.5696],
           [ -8.1640,  -0.9239, -17.1650,  48.9432]]], device='cuda:0'),
  tensor([[[ -0.8017,   3.4442,  -0.

Why Nested Tensor
=================


When data is sequential, it is often the case that each sample has a
different length. For example, in a batch of sentences, each sentence
has a different number of words. A common technique for handling varying
sequences is to manually pad each data tensor to the same shape in order
to form a batch. For example, we have 2 sentences with different lengths
and a vocabulary In order to represent his as single tensor we pad with
0 to the max length in the batch.


In [115]:
sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                               torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")

padded_sentences=tensor([[1., 2., 0.],
        [3., 4., 5.]])
nested_sentences=nested_tensor([
  tensor([1., 2.]),
  tensor([3., 4., 5.])
])


This technique of padding a batch of data to its max length is not
optimal. The padded data is not needed for computation and wastes memory
by allocating larger tensors than necessary. Further, not all operations
have the same semnatics when applied to padded data. For matrix
multiplications in order to ignore the padded entries, one needs to pad
with 0 while for softmax one has to pad with -inf to ignore specific
entries. The primary objective of nested tensor is to facilitate
operations on ragged data using the standard PyTorch tensor UX, thereby
eliminating the need for inefficient and complex padding and masking.


In [116]:
padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))

tensor([[0.2689, 0.7311, 0.0000],
        [0.0900, 0.2447, 0.6652]])
nested_tensor([
  tensor([0.2689, 0.7311]),
  tensor([0.0900, 0.2447, 0.6652])
])


Let us take a look at a practical example: the multi-head attention
component utilized in
[Transformers](https://arxiv.org/pdf/1706.03762.pdf). We can implement
this in such a way that it can operate on either padded or nested
tensors.


In [117]:
class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
                 nheads: int, dropout_p: float = 0.0):
        super().__init__()
        self.nheads = nheads
        self.dropout_p = dropout_p
        self.query_proj = nn.Linear(E_q, E_total)
        self.key_proj = nn.Linear(E_k, E_total)
        self.value_proj = nn.Linear(E_v, E_total)
        self.qkv_proj = nn.Linear(E_q, 3 * E_total)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # # Step 1. Apply input projection
        # # TODO: demonstrate packed projection
        # query = self.query_proj(query)
        # key = self.key_proj(key)
        # value = self.value_proj(value)

        query, key, value = self.qkv_proj(query).chunk(3, dim=-1)
        print("hello")

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=True)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

set hyperparameters following [the Transformer
paper](https://arxiv.org/pdf/1706.03762.pdf)


In [118]:
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

except for dropout probability: set to 0 for correctness check


In [119]:
dropout_p = 0.0

Let us generate some realistic fake data from Zipf\'s law.


In [120]:
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)

Create nested tensor batch inputs


In [121]:
def gen_batch(N, E_q, E_k, E_v, device):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    query = torch.nested.nested_tensor([
        torch.randn(l.item(), E_q, device=device)
        for l in sentence_lengths
    ] , layout=torch.jagged, requires_grad=True)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, device=device)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths

query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)

Generate padded forms of query, key, value for comparison


In [122]:
def jagged_to_padded(jt, padding_val):
    # TODO: do jagged -> padded directly when this is supported
    return torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(list(jt.unbind())),
        padding_val)

padded_query, padded_key, padded_value = (
    jagged_to_padded(t, 0.0) for t in (query, key, value)
)

Construct the model


In [123]:
mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)

Check correctness and performance


In [124]:
def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)

output_nested, time_nested = benchmark(mha, query, key, value)
output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    output_padded[i, entry_length:] = 0.0

print("=== without torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
print("nested tensor multi-head attention takes", time_nested, "seconds")
print("padded tensor multi-head attention takes", time_padded, "seconds")

# warm up compile first...
compiled_mha = torch.compile(mha)
compiled_mha(query, key, value)
# ...now benchmark
compiled_output_nested, compiled_time_nested = benchmark(
    compiled_mha, query, key, value)

# warm up compile first...
compiled_mha(padded_query, padded_key, padded_value)
# ...now benchmark
compiled_output_padded, compiled_time_padded = benchmark(
    compiled_mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    compiled_output_padded[i, entry_length:] = 0.0

print("=== with torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")

hello
hello
=== without torch.compile ===
nested and padded calculations differ by 1.1920928955078125e-06
nested tensor multi-head attention takes 0.059736272000009194 seconds
padded tensor multi-head attention takes 0.09523707599146292 seconds
hello
hello
hello
hello
=== with torch.compile ===
nested and padded calculations differ by 1.1920928955078125e-06
nested tensor multi-head attention takes 0.006561717003933154 seconds
padded tensor multi-head attention takes 0.025924878995283507 seconds


Note that without `torch.compile`, the overhead of the python subclass
nested tensor can make it slower than the equivalent computation on
padded tensors. However, once `torch.compile` is enabled, operating on
nested tensors gives a multiple x speedup. Avoiding wasted computation
on padding becomes only more valuable as the percentage of padding in
the batch increases.


In [125]:
print(f"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}")

Nested speedup: 3.951


Conclusion
==========

In this tutorial, we have learned how to perform basic operations with
nested tensors and how implement multi-head attention for transformers
in a way that avoids computation on padding. For more information, check
out the docs for the
[torch.nested](https://pytorch.org/docs/stable/nested.html) namespace.


In [126]:
dim = 64
heads = 8
temp_qkv = torch.cat([query, query, query],dim = -1)
qkv = temp_qkv.unflatten(-1, [3, heads, dim])
print(qkv.shape)
query_list = list(qkv)

max_length = 128
print("Max length:", max_length)

# Print the shapes of the tensors in the query list
for t in query_list:
    print("Tensor shape:", t.shape)



torch.Size([512, j19, 3, 8, 64])
Max length: 128
Tensor shape: torch.Size([18, 3, 8, 64])
Tensor shape: torch.Size([37, 3, 8, 64])
Tensor shape: torch.Size([5, 3, 8, 64])
Tensor shape: torch.Size([37, 3, 8, 64])
Tensor shape: torch.Size([14, 3, 8, 64])
Tensor shape: torch.Size([5, 3, 8, 64])
Tensor shape: torch.Size([26, 3, 8, 64])
Tensor shape: torch.Size([70, 3, 8, 64])
Tensor shape: torch.Size([2, 3, 8, 64])
Tensor shape: torch.Size([20, 3, 8, 64])
Tensor shape: torch.Size([8, 3, 8, 64])
Tensor shape: torch.Size([16, 3, 8, 64])
Tensor shape: torch.Size([64, 3, 8, 64])
Tensor shape: torch.Size([34, 3, 8, 64])
Tensor shape: torch.Size([36, 3, 8, 64])
Tensor shape: torch.Size([1, 3, 8, 64])
Tensor shape: torch.Size([29, 3, 8, 64])
Tensor shape: torch.Size([1, 3, 8, 64])
Tensor shape: torch.Size([2, 3, 8, 64])
Tensor shape: torch.Size([8, 3, 8, 64])
Tensor shape: torch.Size([12, 3, 8, 64])
Tensor shape: torch.Size([17, 3, 8, 64])
Tensor shape: torch.Size([14, 3, 8, 64])
Tensor shape: to

In [147]:
import torch

@torch.enable_grad
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


qkv = torch.cat([query, query, query],dim = -1)
qkv = qkv.unflatten(-1, [3, heads, dim])

print(qkv.shape)

print("qkv0 requires_grad", qkv.requires_grad)

# Example cos and sin tensors
max_length = 128
dim = 64
cos = torch.ones(1, max_length, 3, 1, dim, device=device)
sin = torch.zeros(1, max_length, 3, 1, dim, device=device)
cos = cos.squeeze(0)
sin = sin.squeeze(0)

# Unbind the nested tensor into a list of regular tensors
qkv_list = list(qkv.unbind())

time1 = timeit.default_timer()
# Apply the pointwise multiplication to each tensor in the list

result_list = []
for t in qkv_list:
    length = t.shape[0]
    cos_slice = cos[:length]
    sin_slice = sin[:length]
    result = (t * cos_slice) + (rotate_half(t) * sin_slice)
    result_list.append(result)

time2 = timeit.default_timer()

[(t * cos[:t.shape[0]]) + (rotate_half(t) * sin[:t.shape[0]]) for t in qkv_list]

time3 = timeit.default_timer()

# Reassemble the list of tensors back into a nested tensor
result_nested_tensor = torch.nested.as_nested_tensor(result_list)
time4 = timeit.default_timer()

print("Time taken for pointwise multiplication:", time2 - time1)
print("Time taken for pointwise multiplication with list comprehension:", time3 - time2)
print("Time taken for reassembly:", time4 - time3)

time1 = timeit.default_timer()
# Apply the pointwise multiplication to each tensor in the list

result_list = []
for t in qkv_list:
    length = t.shape[0]
    cos_slice = cos[:length]
    sin_slice = sin[:length]
    result = torch.einsum('...ij,...ij->...ij', t, cos_slice) + torch.einsum('...ij,...ij->...ij', rotate_half(t), sin_slice)
    result_list.append(result)

time2 = timeit.default_timer()

[(torch.einsum('...ij,...ij->...ij', t, cos[:t.shape[0]]) + torch.einsum('...ij,...ij->...ij', rotate_half(t), sin[:t.shape[0]])) for t in qkv_list]

time3 = timeit.default_timer()

# Reassemble the list of tensors back into a nested tensor
result_nested_tensor = torch.nested.as_nested_tensor(result_list)
time4 = timeit.default_timer()

print("Time taken for einsum:", time2 - time1)
print("Time taken for einsum with list comprehension:", time3 - time2)
print("Time taken for reassembly:", time4 - time3)

time1 = timeit.default_timer()
# Apply the pointwise multiplication to each tensor in the list

print(cos.shape)
cos = torch.nested.nested_tensor([cos[:length] for length in qkv.offsets().diff().tolist()], layout=torch.jagged, requires_grad=True)
sin = torch.nested.nested_tensor([sin[:length] for length in qkv.offsets().diff().tolist()], layout=torch.jagged, requires_grad=True)
print(cos.shape)

time2 = timeit.default_timer()

print("qkv1 requires grad: ", qkv.requires_grad)

qkv_temp = (qkv._values * cos._values) + (rotate_half(qkv)._values * sin._values)

time3 = timeit.default_timer()

print("qkv_temp requires grad: ", qkv_temp.requires_grad)
# Reassemble the list of tensors back into a nested tensor
qkv = torch.nested.nested_tensor_from_jagged(qkv_temp, qkv.offsets())
print(qkv.shape)

time4 = timeit.default_timer()

print("Time taken for cos and sin reshaping:", time2 - time1)
print("Time taken for pointwise multiplication with reshaped cos and sin:", time3 - time2)
print("Time taken for reassembly:", time4 - time3)


# Check if gradients are preserved
loss = qkv.flatten(-3, -1).sum(dim=-1)
loss = loss.to_padded_tensor(0.0)
print("Loss:", loss.shape)
loss.backward()
print("Gradients:")
for t in qkv.unbind():
    print(t.grad)


torch.Size([512, j19, 3, 8, 64])
qkv0 requires_grad True
Time taken for pointwise multiplication: 0.0513258810096886
Time taken for pointwise multiplication with list comprehension: 0.06395369699748699
Time taken for reassembly: 0.0026697169960243627
Time taken for einsum: 0.08307680799043737
Time taken for einsum with list comprehension: 0.06599850799830165
Time taken for reassembly: 0.003352902000187896
torch.Size([128, 3, 1, 64])
torch.Size([512, j60, 3, 1, 64])
qkv1 requires grad:  True
qkv_temp requires grad:  False
torch.Size([512, j19, 3, 8, 64])
Time taken for cos and sin reshaping: 0.005554736009798944
Time taken for pointwise multiplication with reshaped cos and sin: 0.0014531909982906654
Time taken for reassembly: 0.00040532198909204453


NotImplementedError: aten.to_padded_tensor.default

In [None]:
def packed_tensor_from_jagged(tensor):
    offsets = tensor.offsets()
    return torch.cat([t for t in tensor.unbind()], dim = 0), offsets

def jagged_from_packed_tensor(tensor, offsets):
    return torch.nested.nested_tensor_from_jagged(tensor, offsets)

def coerce_offsets(src, tgt):
    assert torch.eq(src.offsets(), tgt.offsets()).all().item()
    assert src._ragged_idx == tgt._ragged_idx

    def mb_get_size(t):
        return t.shape[0] if t is not None else None

    return torch.nested.nested_tensor_from_jagged(
        src.values(),
        tgt.offsets(),
        None,
        src._ragged_idx,
        mb_get_size(src._max_seqlen_tensor) if tgt._max_seqlen_tensor is None else mb_get_size(src._max_seqlen_tensor),
        mb_get_size(src._min_seqlen_tensor) if tgt._min_seqlen_tensor is None else mb_get_size(src._min_seqlen_tensor),
    )

qkv = torch.nested.nested_tensor([torch.randn(64, 1536), torch.randn(128, 1536), torch.randn(256, 1536), torch.randn(512, 1536)], device=device, requires_grad=True, layout=torch.jagged)
qkv = qkv.unflatten(-1, [3, heads, dim])
# qkv_padded = qkv.to_padded_tensor(0.0)

print(qkv.shape)
print("qkv0 requires_grad", qkv.requires_grad)

qkv, offsets = packed_tensor_from_jagged(qkv)
print(qkv.shape)
print("qkv requires_grad", qkv.requires_grad)

qkv = jagged_from_packed_tensor(qkv, offsets)
print(qkv.shape)
print("qkv requires_grad", qkv.requires_grad)

NameError: name 'torch' is not defined

In [1]:
import torch

seq_len = 1024
batch_size = 512

cu_seqlens = torch.arange(
                0, (batch_size + 1) * seq_len, step=seq_len,
                dtype=torch.int32
)

print(cu_seqlens)

tensor([     0,   1024,   2048,   3072,   4096,   5120,   6144,   7168,   8192,
          9216,  10240,  11264,  12288,  13312,  14336,  15360,  16384,  17408,
         18432,  19456,  20480,  21504,  22528,  23552,  24576,  25600,  26624,
         27648,  28672,  29696,  30720,  31744,  32768,  33792,  34816,  35840,
         36864,  37888,  38912,  39936,  40960,  41984,  43008,  44032,  45056,
         46080,  47104,  48128,  49152,  50176,  51200,  52224,  53248,  54272,
         55296,  56320,  57344,  58368,  59392,  60416,  61440,  62464,  63488,
         64512,  65536,  66560,  67584,  68608,  69632,  70656,  71680,  72704,
         73728,  74752,  75776,  76800,  77824,  78848,  79872,  80896,  81920,
         82944,  83968,  84992,  86016,  87040,  88064,  89088,  90112,  91136,
         92160,  93184,  94208,  95232,  96256,  97280,  98304,  99328, 100352,
        101376, 102400, 103424, 104448, 105472, 106496, 107520, 108544, 109568,
        110592, 111616, 112640, 113664, 

In [10]:
import torch
import torch.nn as nn

from torchviz import make_dot


from model import SEDD_nested

# Example usage
model = SEDD_nested(
input_tensor = torch.randn(8, 1, 768, requires_grad=True)

def forward_hook(module, input, output):
    print(f"Forward hook for {module.__class__.__name__}: input shape {input[0].shape}, output shape {output.shape}")

def backward_hook(module, grad_input, grad_output):
    print(f"Backward hook for {module.__class__.__name__}: grad_input shape {grad_input[0].shape}, grad_output shape {grad_output[0].shape}")


# Register hooks
for name, module in model.named_modules():
    module.register_forward_hook(forward_hook)
    module.register_backward_hook(backward_hook)

# Perform forward and backward pass
output = model(input_tensor, None, None)
loss = output.sum()

make_dot(output.sum(), params=dict(model.named_parameters()))

loss.backward()

Input shape: torch.Size([8, 1, 768])
Forward hook for LayerNorm: input shape torch.Size([8, 1, 768]), output shape torch.Size([8, 1, 768])
Forward hook for LayerNorm: input shape torch.Size([8, 1, 768]), output shape torch.Size([8, 1, 768])
After norm1 shape: torch.Size([8, 1, 768])
Forward hook for Linear: input shape torch.Size([8, 1, 768]), output shape torch.Size([8, 1, 2304])
After attn_qkv shape: torch.Size([8, 1, 2304])
Forward hook for DDiTBlock: input shape torch.Size([8, 1, 768]), output shape torch.Size([8, 1, 768])
Backward hook for LayerNorm: grad_input shape torch.Size([8, 1, 768]), grad_output shape torch.Size([8, 1, 768])
Backward hook for LayerNorm: grad_input shape torch.Size([8, 1, 768]), grad_output shape torch.Size([8, 1, 768])
Backward hook for DDiTBlock: grad_input shape torch.Size([8, 1, 768]), grad_output shape torch.Size([8, 1, 768])


In [16]:
import torch
import torch.nn as nn

def packed_tensor_from_jagged(tensor):
    offsets = tensor.offsets()
    return torch.cat([t for t in tensor.unbind()], dim = 0), offsets

def modulate(x, shift, scale):

    if scale is not None:
        x = x * (1 + scale)
    if shift is not None:
        x = x + shift
    return x

class test_model(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.modulation = nn.Linear(dim, 2 * dim)

    def forward(self, x, c): # x is a ragged tensor (batch_size=4, j, dim=64), c is a regular tensor (batch_size=4, dim=64)
        shift, scale = self.modulation(c).chunk(2, dim=-1)
        shift, scale = shift.unsqueeze(1), scale.unsqueeze(1) # I think it has something to do with this unsqueeze

        return modulate(x, shift, scale)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = test_model(64).to(device)

### This seems to work fine
batch =torch.randn(4, 512, 64, device=device, requires_grad=True) # batch_size=4, j=512, dim=64
c = torch.randn(4, 64, device=device, requires_grad=True) # batch_size=4, dim=64

output = model(batch, c)
loss = output.sum(dim=-1).mean()
loss.backward()
###

### Bug here, when using nested tensors
batch = torch.nested.nested_tensor([torch.randn(64, 64), torch.randn(128, 64), torch.randn(256, 64), torch.randn(512, 64)], device=device, requires_grad=True, layout=torch.jagged) # batch_size=4, j=jagged, dim=64
c = torch.randn(4, 64, device=device, requires_grad=True) # batch_size=4, dim=64

output = model(batch, c)
output, offsets = packed_tensor_from_jagged(output)
loss = output.sum(dim=-1).mean()
loss.backward() # This line throws an error (Function AddBackward0 returned an invalid gradient at index 0 - got [1, 4, 64] but expected shape compatible with [4, 1, 64])
###

RuntimeError: Function AddBackward0 returned an invalid gradient at index 0 - got [1, 4, 64] but expected shape compatible with [4, 1, 64]