In [1]:
import torch
from collections import defaultdict
import timeit

items = [torch.randint(0, 10, (10, 10))]
keys = ['input_ids']
pad_tokens = {'input_ids': 0}
max_length = 10

# Method 1: Using defaultdict and torch.full
def method1():
    padded_tensors = defaultdict(lambda: torch.zeros((len(items), max_length), dtype=torch.int64))
    for key in keys:
        padded_tensors[key] = torch.full((len(items), max_length), pad_tokens.get(key, 0), dtype=torch.int64)
    return padded_tensors

# Method 2: Creating torch.full directly
def method2():
    padded_tensors = {key: torch.full((len(items), max_length), pad_tokens.get(key, 0), dtype=torch.int64) for key in keys}
    return padded_tensors

# Time comparison
time_method1 = timeit.timeit(method1, number=1000)
time_method2 = timeit.timeit(method2, number=1000)

print(f"Method 1 (defaultdict) time: {time_method1:.6f} seconds")
print(f"Method 2 (direct creation) time: {time_method2:.6f} seconds")

# Use the faster method
padded_tensors = method2() if time_method2 < time_method1 else method1()

Method 1 (defaultdict) time: 0.001292 seconds
Method 2 (direct creation) time: 0.001142 seconds
