In [5]:
from video_diffusion import *

In [11]:
# Load the source video from disk and print a summary of the resulting tensor.
# `load_video_frames_from_file` and `inspect` are not defined in this notebook;
# presumably they are provided by the `video_diffusion` star import - confirm.
fname = 'data/mandelbrot_zoom.mp4'
frames = load_video_frames_from_file(fname)
# The recorded output for this cell lists shape, dtype, max, min, mean and std of `frames`.
inspect('frames', frames)


frames:
shape: torch.Size([1459, 3, 360, 548])
dtype: torch.float32
max: tensor(1.)
min: tensor(0.)
mean: tensor(0.3186)
std: tensor(0.3247)



In [12]:
# Wrap the loaded frames in a dataset of 5-frame sequences and sanity-check it.
# NOTE(review): a class named AutoregressiveFrames is defined in a later cell (In [17]);
# when this cell runs on a fresh kernel the name must instead come from the
# `video_diffusion` star import - confirm, so the notebook survives Restart & Run All.
dataset = AutoregressiveFrames(frames, seq_len=5)
print('len dataset', len(dataset))
# Shape of a single item; the recorded output shows (5, 3, 360, 548).
print(dataset[0].shape)

len dataset 1455
torch.Size([5, 3, 360, 548])


In [17]:
import torch
from torch.utils.data import Dataset, DataLoader

class AutoregressiveFrames(Dataset):
    """Sliding-window dataset over a stack of frames.

    Item ``i`` is ``frames[i : i + seq_len]`` (a slice along dim 0), so
    consecutive items overlap by ``seq_len - 1`` frames and the dataset holds
    ``max(0, T - seq_len + 1)`` items, where ``T = frames.shape[0]``.

    Args:
        frames: Tensor whose first dimension indexes the frames.
        seq_len: Number of consecutive frames in each item; must be >= 1.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, frames: torch.Tensor, seq_len: int):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.frames = frames
        self.seq_len = seq_len
        self.T = frames.shape[0]

    def __len__(self):
        """Return the number of full-length windows (0 if ``T < seq_len``)."""
        return max(0, self.T - self.seq_len + 1)

    def __getitem__(self, idx):
        """Return the ``seq_len``-frame window that starts at window index ``idx``.

        Negative indices count from the end, like Python sequences
        (``dataset[-1]`` is the last full window).

        Raises:
            IndexError: If ``idx`` does not address an existing window.
        """
        num_windows = len(self)
        # Normalise negative indices up front. Slicing with a raw negative start
        # would wrap around and yield an empty or truncated window instead of failing.
        start = idx + num_windows if idx < 0 else idx
        if not 0 <= start < num_windows:
            raise IndexError(
                f"Index {idx} out of range for {num_windows} windows "
                f"(seq_len {self.seq_len}, tensor length {self.T})"
            )
        # Every in-range start has seq_len frames available, so no padding is needed.
        return self.frames[start:start + self.seq_len]

# Debugging example: run the dataset on a tiny counting tensor so that the first
# value of every window identifies which frame the window starts at.
# NOTE: this rebinds the notebook-level names `frames` and `dataset`.
frames = torch.arange(96, dtype=torch.float32).reshape(96, 1, 1, 1)  # (T=96, C=1, H=1, W=1)
seq_len = 10
dataset = AutoregressiveFrames(frames, seq_len)
dataloader = DataLoader(dataset, shuffle=False, batch_size=4)

# batch[:, 0, 0, 0, 0] picks the first frame's single value from each window in the batch.
for i, batch in enumerate(dataloader):
    print(f"Batch {i}: Shape {batch.shape}, Values: {batch[:, 0, 0, 0, 0]}")


Batch 0: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([0., 1., 2., 3.])
Batch 1: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([4., 5., 6., 7.])
Batch 2: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([ 8.,  9., 10., 11.])
Batch 3: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([12., 13., 14., 15.])
Batch 4: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([16., 17., 18., 19.])
Batch 5: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([20., 21., 22., 23.])
Batch 6: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([24., 25., 26., 27.])
Batch 7: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([28., 29., 30., 31.])
Batch 8: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([32., 33., 34., 35.])
Batch 9: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([36., 37., 38., 39.])
Batch 10: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([40., 41., 42., 43.])
Batch 11: Shape torch.Size([4, 10, 1, 1, 1]), Values: tensor([44., 45., 46., 47.])
Batch 12: Shape torch.

In [25]:
# Split a 5-long dim 1 into chunks of size 3: the last chunk keeps the remainder (2).
seq = torch.randn(2, 5, 10, 10)
x = seq.split(3, dim=1)
print(x[0].shape, x[1].shape, sep='\n')

torch.Size([2, 3, 10, 10])
torch.Size([2, 2, 10, 10])


IndexError: tuple index out of range