In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler

from torchvision import transforms

from pytorchvideo.models import x3d
from pytorchvideo.data import Ucf101, RandomClipSampler, UniformClipSampler, Kinetics

# from torchvision.transforms._transforms_video import (
#     CenterCropVideo,
#     NormalizeVideo,
# )

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)

from tqdm.notebook import tqdm
from collections import OrderedDict
import itertools
import os
import pickle

In [12]:
class Args:
    """Plain settings container for the notebook: dataset paths plus clip/loader sizes."""

    def __init__(self):
        # The three path attributes all point at the same dataset root.
        dataset_root = '/mnt/NAS-TVS872XT/dataset/'
        self.metadata_path = dataset_root
        self.root = self.annotation_path = dataset_root

        # Clip and batching sizes.
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 16
        self.NUM_WORKERS = 8  # kinetics: 8, ucf101: 24

        # Clip length in seconds: (num_frames * sampling_rate) / fps
        num_frames, sampling_rate, fps = 8, 8, 30
        self.CLIP_DURATION = (num_frames * sampling_rate) / fps

        # Lowered from 16 to 8 to match the pretrained model.
        self.VIDEO_NUM_SUBSAMPLED = 8

        # Number of classes per dataset.
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400


args = Args()

In [13]:
# Steps applied, in order, to the "video" entry of each sample.
video_pipeline = Compose([
    # Keep VIDEO_NUM_SUBSAMPLED frames per clip.
    UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
    # Divide raw pixel values by 255 before normalising.
    Lambda(lambda x: x / 255.),
    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
    ShortSideScale(size=256),
    CenterCrop(256),
    # NOTE(review): the notebook passes this one pipeline to the train, val and
    # test datasets alike, so the random flip is active for evaluation clips too.
    RandomHorizontalFlip(),
])

# The label is passed through unchanged.
label_passthrough = Lambda(lambda x: x)

transform = Compose([
    ApplyTransformToKey(key="video", transform=video_pipeline),
    ApplyTransformToKey(key="label", transform=label_passthrough),
    RemoveKey("audio"),
])

In [14]:
root_kinetics = '/mnt/NAS-TVS872XT/dataset/Kinetics400/'


def _build_kinetics(data_path, video_path_prefix):
    """Build one Kinetics split; everything except the two paths is shared by all splits.

    Each call creates its own RandomClipSampler with the clip duration from ``args``
    and reuses the notebook-level ``transform``.
    """
    return Kinetics(
        data_path=data_path,
        video_path_prefix=video_path_prefix,
        clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
        video_sampler=RandomSampler,
        decode_audio=False,
        transform=transform,
    )


# train and val use the same directory for data_path and video_path_prefix;
# test reads a list file and prefixes entries with the test/ directory.
train_dir = f"{root_kinetics}train"
val_dir = f"{root_kinetics}val"

train_dataset = _build_kinetics(train_dir, train_dir)

test_dataset = _build_kinetics(f"{root_kinetics}test_list.txt", f"{root_kinetics}test/")

val_dataset = _build_kinetics(val_dir, val_dir)

In [15]:
class LimitDataset(torch.utils.data.Dataset):
    """Map-style view over an iterable dataset.

    ``__getitem__`` ignores the requested index and returns the next item
    produced by iterating the wrapped dataset; ``__len__`` reports the wrapped
    dataset's ``num_videos``.

    When the wrapped dataset's iterator runs out, a fresh ``iter(dataset)`` is
    started, so the wrapper can keep serving items. The previous
    ``itertools.repeat(iter(dataset), 2)`` repeated the *same* iterator object,
    so its second pass was always empty and ``StopIteration`` escaped from
    ``__getitem__``.

    Args:
        dataset: an iterable exposing a ``num_videos`` attribute.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        # Created eagerly, as before, so the first pass behaves exactly as it did.
        self.dataset_iter = iter(dataset)

    def __getitem__(self, index):
        """Return the next sample from the wrapped dataset (``index`` is unused).

        Raises:
            IndexError: if a freshly started iterator yields no sample at all.
        """
        try:
            return next(self.dataset_iter)
        except StopIteration:
            # Current pass is exhausted; fall through and start a new one.
            pass

        self.dataset_iter = iter(self.dataset)
        try:
            return next(self.dataset_iter)
        except StopIteration:
            # Never let StopIteration leak: callers iterating over this object
            # would read it as a normal end of data.
            raise IndexError("wrapped dataset produced no samples") from None

    def __len__(self):
        """Number of items advertised to the DataLoader: the wrapped ``num_videos``."""
        return self.dataset.num_videos

In [18]:
# Keyword arguments shared by all three loaders.
loader_options = dict(
    batch_size=args.BATCH_SIZE,
    # NOTE(review): this also discards the final partial batch of val and test.
    drop_last=True,
    num_workers=args.NUM_WORKERS,
)

# Each split is wrapped in LimitDataset before being handed to DataLoader.
train_loader = DataLoader(LimitDataset(train_dataset), **loader_options)
val_loader = DataLoader(LimitDataset(val_dataset), **loader_options)
test_loader = DataLoader(LimitDataset(test_dataset), **loader_options)

In [20]:
# Pull a single batch from the test loader for inspection.
data = next(iter(test_loader))



torch.Size([16])

In [23]:
# Slice rather than index, so the leading batch axis is kept in the displayed shape.
first_clip = data["video"][:1]
first_clip.shape

torch.Size([1, 3, 8, 256, 256])

torch.Size([16])