In [None]:
import torch
from torch.nested import nested_tensor
import torch.nn as nn
import time
from tqdm import tqdm
import numpy as np

# Fix the global RNG so the synthetic benchmark data is reproducible.
torch.manual_seed(42)

# Benchmark configuration.
batch_size = 64     # sequences per batch
max_seq_len = 100   # longest sequence that can be drawn (inclusive)
embed_dim = 768     # feature width of every token
num_batches = 200   # number of batches to generate


# One random length in [10, max_seq_len] per sequence, laid out as (num_batches, batch_size).
seq_lengths = torch.randint(low=10, high=max_seq_len + 1, size=(num_batches, batch_size))

# base_batches[i][j] is a (seq_lengths[i, j], embed_dim) tensor of standard-normal noise.
# Tensors are drawn batch by batch, sequence by sequence, so the RNG stream is consumed
# in a fixed order.
base_batches = []
for lengths_in_batch in seq_lengths.tolist():
    sequences = []
    for length in lengths_in_batch:
        sequences.append(torch.randn(length, embed_dim))
    base_batches.append(sequences)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Build the dense (zero-padded) version of every batch together with its key-padding mask.
#   padded_batches[i]  : (len(batch), max_seq_len, embed_dim) tensor, zeros after each sequence's end
#   attention_masks[i] : (len(batch), max_seq_len) bool tensor, True on padded positions
padded_batches = []
attention_masks = []

positions = torch.arange(max_seq_len)

for batch in base_batches:
    rows = []
    row_masks = []
    for sequence in batch:
        n_real = sequence.size(0)
        # Append zero rows so every sequence ends up exactly max_seq_len long.
        filler = torch.zeros(max_seq_len - n_real, embed_dim)
        rows.append(torch.cat((sequence, filler), dim=0))
        # Positions at or beyond the true length are padding.
        row_masks.append(positions >= n_real)
    padded_batches.append(torch.stack(rows))
    attention_masks.append(torch.stack(row_masks))

    
class TransformerEncoderWithPadding(nn.Module):
    """One Transformer encoder layer for zero-padded batches.

    The layer applies multi-head self-attention followed by a two-layer ReLU
    feed-forward network. Each sub-layer is wrapped in a residual connection
    and a LayerNorm applied after the addition.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        """Create the attention, feed-forward and normalisation sub-modules.

        Args:
            embed_dim: Size of the last (feature) dimension of the input.
            num_heads: Number of attention heads.
            ff_dim: Hidden width of the feed-forward network.
        """
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, attention_mask):
        """Encode a padded batch.

        Args:
            x: Input tensor; its first two dimensions are swapped before
                attention and swapped back on return, so the output has the
                same layout as the input.
            attention_mask: Passed to the attention module as
                ``key_padding_mask``; True marks key positions to ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention module is built with its default layout, so swap the
        # first two dimensions on the way in and again on the way out.
        seq_first = x.transpose(0, 1)
        attended, _ = self.attention(
            seq_first, seq_first, seq_first, key_padding_mask=attention_mask
        )
        after_attention = self.norm1(seq_first + attended)
        encoded = self.norm2(after_attention + self.ff(after_attention))
        return encoded.transpose(0, 1)

    
    
# Benchmark: wall-clock time of one forward pass per padded batch.
model_padded = TransformerEncoderWithPadding(embed_dim=embed_dim, num_heads=8, ff_dim=256).to(device)

forward_times_padded = []
for padded_batch, mask in tqdm(zip(padded_batches, attention_masks), total=num_batches):
    # Host-to-device copies stay outside the timed region.
    padded_batch = padded_batch.to(device)
    mask = mask.to(device)
    # CUDA work is queued asynchronously: without an explicit synchronize the
    # timer would only capture kernel-launch overhead (and could be billed for
    # copies still in flight). Drain the queue before starting the clock ...
    if device.type == "cuda":
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    output = model_padded(padded_batch, mask)
    # ... and again before stopping it, so the kernels have actually finished.
    if device.type == "cuda":
        torch.cuda.synchronize()
    forward_times_padded.append(time.perf_counter() - start_time)

print(f"Forward pass time with padding: {sum(forward_times_padded)/len(forward_times_padded):.6f} seconds/batch")

In [None]:
# Persist the per-batch forward-pass timings collected above as a NumPy array.
np.save(file="padding_time.npy", arr=np.asarray(forward_times_padded))