In [41]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


#import torchinfo

from tqdm.notebook import tqdm
import itertools
import os
import pickle

In [42]:
class Args:
    def __init__(self):
        self.metadata_path = '/mnt/NAS-TVS872XT/dataset/'
        self.root = self.metadata_path
        self.annotation_path = self.metadata_path
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 16
        self.NUM_WORKERS = 8  # kinetics:8, ucf101:24

        self.CLIP_DURATION = (8 * 8) / 30  # (num_frames * sampling_rate)/fps
        self.VIDEO_NUM_SUBSAMPLED = 8  # 事前学習済みモデルに合わせて16→8
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400

args = Args()

In [43]:
transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
            transforms.Lambda(lambda x: x / 255.),
            # Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
            ShortSideScale(size=256),
            # RandomShortSideScale(min_size=256, max_size=320,),
            # CenterCropVideo(crop_size=(256, 256)),
            CenterCrop(256),
            # RandomCrop(224),
            RandomHorizontalFlip(),
        ]),
    ),
    ApplyTransformToKey(
        key="label",
        transform=transforms.Lambda(lambda x: x),
    ),
    RemoveKey("audio"),
])

In [59]:
root_SSv2 = '/mnt/NAS-TVS872XT/dataset/something-something-v2/'

# train_dataset = SSv2(
#             label_name_file=root_SSv2+"something-something-v2-labels.json",
#             video_label_file=root_SSv2+"something-something-v2-train.json",
#             video_path_label_file=root_SSv2+"PySlowFast/train.csv",
#             video_path_prefix="/tmp/",
#             clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
#             video_sampler=RandomSampler,
#             # decode_audio=False,
#             transform=transform,
#         )

val_dataset = SSv2(
            label_name_file=root_SSv2+"something-something-v2-labels.json",
            video_label_file=root_SSv2+"something-something-v2-validation.json",
            video_path_label_file=root_SSv2+"PySlowFast/val.csv",
            video_path_prefix="/tmp/",
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            # decode_audio=False,
            transform=transform,
        )


In [60]:
# print(len(train_dataset.video_sampler))
print(len(val_dataset.video_sampler))

24777


In [61]:
data = val_dataset.__next__()

FileNotFoundError: [Errno 2] No such file or directory: '/tmp/54222/54222_000001.jpg'

In [34]:
class LimitDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos

In [35]:
train_loader = DataLoader(LimitDataset(train_dataset),
                            batch_size=args.BATCH_SIZE,
                            drop_last=True,
                            num_workers=args.NUM_WORKERS)

In [36]:
len(train_loader)

AttributeError: 

In [37]:
data = iter(train_loader).__next__()

AttributeError: 