In [1]:
import os
# Work from the directory one level above where the kernel started, so that
# relative paths such as "./data/Sintel_frames" resolve. The target is
# resolved only once per kernel, so re-running this cell does not keep
# climbing further up the directory tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

In [7]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

In [3]:
def frame_separator(frame_path):
    """Extract the integer frame index from a frame image path.

    The stem is the file name up to its first ``.``. The index is the part of
    the stem after its last ``_``, e.g. ``".../frame_0317.png"`` -> ``317``.

    Args:
        frame_path: Path to a frame image, as a ``str`` or ``os.PathLike``.

    Returns:
        The frame index as an ``int``.

    Raises:
        ValueError: If the last ``_``-separated piece of the stem is not an
            integer.
    """
    # os.path.basename honours the platform's separators ("\\" as well as "/"
    # on Windows), unlike a hard-coded split("/") on the raw path string.
    file_name = os.path.basename(os.fspath(frame_path))
    stem = file_name.split(".", 1)[0]
    return int(stem.rsplit("_", 1)[-1])

class SingleVideoDataset(ImageFolder):
    """Frames of one video, loaded via ``ImageFolder``, yielding ``(image, frame_id)``.

    The folder class label is discarded. Each item is instead paired with the
    frame index that ``frame_separator`` parses from its file name.

    NOTE(review): if ``root`` holds several sub-directories, all their frames
    end up in one sample list and, when ``index_slice`` is given, are sorted
    together by frame index. The class presumably expects a single video
    sub-directory; confirm with callers.

    Args:
        root: Directory handed to ``ImageFolder``.
        index_slice: Optional ``(start, stop)`` pair (or a ``slice`` object).
            It selects samples by position after sorting them by frame
            index; ``stop`` is exclusive. ``None`` keeps every sample, in the
            order ``ImageFolder`` produced.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, root, index_slice=None, transform=None):
        super().__init__(
            root=root,
            transform=transform,
            target_transform=None,
        )
        if index_slice is not None:
            if isinstance(index_slice, slice):
                frame_slice = index_slice
            else:
                frame_slice = slice(index_slice[0], index_slice[1])
            # Sort numerically by frame index so positions in the slice follow
            # frame order even when file names are not zero-padded.
            ordered = sorted(self.samples, key=lambda sample: frame_separator(sample[0]))
            self.samples = ordered[frame_slice]
            # ImageFolder also exposes ``imgs`` and ``targets`` derived from the
            # full sample list; keep them consistent with the trimmed samples.
            self.imgs = self.samples
            self.targets = [sample[1] for sample in self.samples]

    def __getitem__(self, idx):
        """Return ``(image, frame_id)`` for the ``idx``-th retained sample.

        ``image`` is the loaded image with ``transform`` applied (if set).
        ``frame_id`` is the integer parsed from the file name. It takes the
        place of the class label ``ImageFolder`` would normally return.
        """
        path = self.samples[idx][0]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return image, frame_separator(path)

In [8]:
# Preprocessing configuration. MEAN/STD are the ImageNet channel statistics
# commonly used with torchvision pretrained models.
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
FRAME_RANGE = (314, 424)  # positions after sorting by frame index; stop is exclusive
BATCH_SIZE = 10
SEED = 0

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

dataset = SingleVideoDataset("./data/Sintel_frames", index_slice=FRAME_RANGE, transform=transform)
# A seeded generator makes the shuffled batch (and the frame ids printed
# below) reproducible across kernel restarts.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
samples, frame_ids = next(iter(dataloader))

print(samples.shape)
print(frame_ids)

torch.Size([10, 3, 224, 224])
tensor([317, 330, 356, 368, 386, 385, 334, 375, 335, 395])
