In [1]:
from comet_ml import Experiment

import torch
import torch.nn as nn
import torchvision
import torchinfo
import pytorchvideo
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler

from torchvision import transforms

from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101,
    RandomClipSampler,
    UniformClipSampler,
    Kinetics
)

# from torchvision.transforms._transforms_video import (
#     CenterCropVideo,
#     NormalizeVideo,
# )

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)

from tqdm import tqdm
from collections import OrderedDict
import itertools
import os
import argparse

In [2]:
class Args:
    def __init__(self):
        self.metadata_path = '/mnt/NAS-TVS872XT/dataset/'
        self.root = self.metadata_path
        self.annotation_path = self.metadata_path
        self.NUM_EPOCH = 4
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 32
        self.NUM_WORKERS = 32
        # self.CLIP_DURATION = 16 / 25
        self.CLIP_DURATION = (8 * 8) / 30  # (num_frames * sampling_rate)/fps
        self.VIDEO_NUM_SUBSAMPLED = 8
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400


In [3]:
class LimitDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )
        # self.num_videos = make_num_videos(self.dataset)
        self.num_videos = len(self.dataset.video_sampler)

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.num_videos

In [41]:
class WrapperDataset(pytorchvideo.data.LabeledVideoDataset):
    def __init__(self, dataset, labeld_video_paths, clip_sampler):
        super().__init__()
        self.dataset = dataset
        self.labeled_video_paths = dataset.labeled_video_paths
        self.clip_sampler = dataset.clip_sampler
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )
        # self.num_videos = make_num_videos(self.dataset)
        self.num_videos = len(self.dataset.video_sampler)

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.num_videos

In [42]:
# class WrapperDataset(pytorchvideo.data.LabeledVideoDataset):
#     def __init__(self, dataset):
#         super().__init__()
#         self.dataset = dataset
#         self.labeled_video_paths = dataset.labeled_video_paths
#         self.clip_sampler = dataset.clip_sampler
#         self.dataset_iter = itertools.chain.from_iterable(
#             itertools.repeat(iter(dataset), 2)
#         )
#         # self.num_videos = make_num_videos(self.dataset)
#         self.num_videos = len(self.dataset.video_sampler)

#     def __getitem__(self, index):
#         return next(self.dataset_iter)

#     def __len__(self):
#         return self.num_videos

In [43]:
def get_ucf101(subset):
    """
    ucf101のデータセットを取得

    Args:
        subset (str): "train" or "test"

    Returns:
        pytorchvideo.data.LabeledVideoDataset: 取得したデータセット
    """
    subset_root_Ucf101 = 'ucfTrainTestlist/trainlist01.txt'
    if subset == "test":
        subset_root_Ucf101 = 'ucfTrainTestlist/testlist.txt'

    args = Args()
    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                RandomShortSideScale(min_size=256, max_size=320,),
                RandomCrop(224),
                RandomHorizontalFlip(),
            ]),
        ),
        ApplyTransformToKey(
            key="label",
            transform=transforms.Lambda(lambda x: x - 1),
        ),
        RemoveKey("audio"),
    ])

    # root_ucf101 = '/mnt/dataset/UCF101/'
    root_ucf101 = '/mnt/dataset/UCF101/'

    dataset = Ucf101(
        data_path=root_ucf101 + subset_root_Ucf101,
        video_path_prefix=root_ucf101 + 'video/',
        clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
        video_sampler=RandomSampler,
        decode_audio=False,
        transform=transform,
    )

    return dataset

In [44]:
def get_kinetics(subset):
    """
    Kinetics400のデータセットを取得

    Args:
        subset (str): "train" or "val" or "test"

    Returns:
        pytorchvideo.data.LabeledVideoDataset: 取得したデータセット
    """
    args = Args()
    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                ShortSideScale(size=256),
                # RandomShortSideScale(min_size=256, max_size=320,),
                # CenterCropVideo(crop_size=(256, 256)),
                CenterCrop(256),
                # RandomCrop(224),
                RandomHorizontalFlip(),
            ]),
        ),
        ApplyTransformToKey(
            key="label",
            transform=transforms.Lambda(lambda x: x),
        ),
        RemoveKey("audio"),
    ])

    root_kinetics = '/mnt/dataset/Kinetics400/'

    if subset == "test":
        dataset = Kinetics(
            data_path=root_kinetics + "test_list.txt",
            video_path_prefix=root_kinetics + 'test/',
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset
    else:
        dataset = Kinetics(
            data_path=root_kinetics + subset,
            video_path_prefix=root_kinetics + subset,
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset

    return False

In [45]:
def make_loader(dataset):
    """
    データローダーを作成

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset): get_datasetメソッドで取得したdataset

    Returns:
        torch.utils.data.DataLoader: 取得したデータローダー
    """
    args = Args()
    loader = DataLoader(LimitDataset(dataset),
                        batch_size=args.BATCH_SIZE,
                        drop_last=True,
                        num_workers=args.NUM_WORKERS)
    return loader

In [48]:
def make_multi_loader(dataset):
    args = Args()
    loader = DataLoader(dataset,
                        batch_size=args.BATCH_SIZE,
                        drop_last=True,
                        num_workers=args.NUM_WORKERS)
    return loader

In [46]:
ucf_dataset = get_ucf101("test")
kinetics_dataset = get_kinetics("test")
print(ucf_dataset.num_videos)
print(kinetics_dataset.num_videos)

3783
35357


https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py

```python
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.
    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.
    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)
        return total
```

In [47]:
chain_datasets = ucf_dataset + kinetics_dataset
print(type(chain_datasets))
concat_datasets = LimitDataset(ucf_dataset) + LimitDataset(kinetics_dataset)
print(type(concat_datasets))
wrapper_datasets = WrapperDataset(ucf_dataset, labeld_video_paths="a", clip_sampler="b") + WrapperDataset(kinetics_dataset, labeld_video_paths="a", clip_sampler="b")
print(type(wrapper_datasets))


<class 'torch.utils.data.dataset.ChainDataset'>
<class 'torch.utils.data.dataset.ConcatDataset'>


TypeError: __init__() missing 2 required positional arguments: 'labeled_video_paths' and 'clip_sampler'

In [14]:
print(type(chain_datasets.datasets))
print(type(chain_datasets.datasets[0]))
print(chain_datasets.__len__())

<class 'list'>
<class 'pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset'>


TypeError: object of type 'LabeledVideoDataset' has no len()

In [39]:
print(type(concat_datasets.datasets))
print(type(concat_datasets.datasets[0]))

<class 'list'>
<class '__main__.LimitDataset'>


In [49]:
chain_loader = make_multi_loader(chain_datasets)

In [50]:
concat_loader = make_multi_loader(concat_datasets)

In [65]:
chain_batch = iter(chain_loader).__next__()
print(type(chain_batch))
chain_inputs = chain_batch["video"]
# print(type(chain_inputs))
print(chain_inputs.shape)
chain_labels = chain_batch["label"]
# print(type(chain_labels))
print(chain_labels.shape)

<class 'dict'>
<class 'torch.Tensor'>
torch.Size([32])


In [66]:
concat_batch = iter(concat_loader).__next__()
print(type(concat_batch))
concat_inputs = concat_batch["video"]
# print(type(concat_inputs))
print(concat_inputs.shape)
concat_labels = concat_batch["label"]
# print(type(concat_labels))
print(concat_labels.shape)

<class 'dict'>
<class 'torch.Tensor'>
torch.Size([32])


In [67]:
class MyUcf101(pytorchvideo.data.LabeledVideoDataset):

    def __next__(self):
        dict = super().__next__()
        dict["dataset"] = "ucf101"
        return dict

class MyKinetics(pytorchvideo.data.LabeledVideoDataset):

    def __next__(self):
        dict = super().__next__()
        dict["dataset"] = "kinetics"
        return dict

## torchvision.dataset.DatasetNameとpytorchvicdo.data.LaveldVideoDatasetの違い
- torchvisionはクラスを呼び出しているのでtorchvisionのクラスを継承したクラスを定義して新たに定義したクラスにtorchvisionと同じ引数を与えて呼び出せる
- pytorchvideoはメソッドを呼び出して返り値がクラスなのでwrapperのようなことをするしかない？
- LimitDatasetの___getitem__()でデータセットの名前も同時に返すとか（dictに入らないしデータセットごとにLimitDatasetクラスを用意しなければならない）

In [68]:
my_ucf_dataset = MyUcf101(ucf_dataset)
my_kinetics_dataset = MyKinetics(kinetics_dataset)

TypeError: __init__() missing 1 required positional argument: 'clip_sampler'