In [None]:
import torch
from torch.nested import nested_tensor
import torch.nn as nn
import time
from tqdm import tqdm
import numpy as np

torch.manual_seed(42)

# Benchmark configuration.
batch_size = 64      # sequences per batch
max_seq_len = 100    # longest sampled sequence length (inclusive upper bound)
embed_dim = 768      # feature width of every token
num_batches = 200    # number of batches that will be timed

# One row per batch: each entry is a sequence length drawn from [10, max_seq_len].
seq_lengths = torch.randint(10, max_seq_len + 1, (num_batches, batch_size))

# base_batches[b][i] is a (length, embed_dim) tensor of standard-normal values;
# lengths differ within a batch, which is what the nested-tensor path exercises.
base_batches = [
    [torch.randn(length, embed_dim) for length in row_lengths]
    for row_lengths in seq_lengths.tolist()
]

In [None]:
# Prefer the GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Pack each batch's list of variable-length tensors into a single nested tensor.
nested_batches = list(map(nested_tensor, base_batches))


class TransformerEncoderWithNested(nn.Module):
    """One post-norm Transformer encoder layer applied to a nested (ragged) batch.

    Each constituent of the nested input is run through self-attention and a
    feed-forward network independently, as an unbatched 2-D sequence, so no
    padding or attention mask is required.

    Args:
        embed_dim: Feature width of every token (``nn.MultiheadAttention``
            requires it to be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the position-wise feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, nested_input):
        """Encode every sequence of ``nested_input``.

        Args:
            nested_input: Nested tensor whose constituents are presumably
                ``(seq_len, embed_dim)`` sequences (TODO confirm for other callers).

        Returns:
            A nested tensor with one encoded sequence per input constituent,
            each the same shape as its input, still attached to the autograd graph.
        """
        outputs = []
        for seq in nested_input.unbind():
            # Unbatched (seq_len, embed_dim) input is accepted by MultiheadAttention.
            attn_output, _ = self.attention(seq, seq, seq)
            seq = self.norm1(seq + attn_output)
            seq = self.norm2(seq + self.ff(seq))
            outputs.append(seq)
        # as_nested_tensor preserves autograd history; torch.nested.nested_tensor
        # would copy into a detached leaf and cut gradients to all parameters.
        return torch.nested.as_nested_tensor(outputs)

    
    
model_nested = TransformerEncoderWithNested(embed_dim=embed_dim, num_heads=8, ff_dim=256).to(device)


def _synchronize():
    """Wait for all queued CUDA work to finish (no-op on CPU) so wall-clock timings are real."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


# One untimed warm-up pass so one-off costs (CUDA context / library init,
# allocator growth) do not inflate the first measured batch.
if nested_batches:
    model_nested(nested_batches[0].to(device))
    _synchronize()

# NOTE(review): autograd is enabled here; if this is meant to be pure inference
# timing, wrap the loop in torch.inference_mode() — and do the same in any
# baseline this is compared against.
forward_times_nested = []
for nested_batch in tqdm(nested_batches, total=num_batches):
    nested_batch = nested_batch.to(device)
    # CUDA ops are asynchronous: synchronize before starting and before stopping
    # the clock, otherwise only kernel-launch overhead would be measured.
    _synchronize()
    start_time = time.perf_counter()
    output = model_nested(nested_batch)
    _synchronize()
    forward_times_nested.append(time.perf_counter() - start_time)

print(f"Forward pass time with NestedTensors: {sum(forward_times_nested)/len(forward_times_nested):.6f} seconds/batch")

In [None]:
# Persist the raw per-batch forward times (seconds) for later comparison.
np.save("nested_time.npy", np.asarray(forward_times_nested))