In [1]:
# import torch
# model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_xs', pretrained=True)
# model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_s', pretrained=True)
# model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_m', pretrained=True)

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


import torchinfo

from tqdm.notebook import tqdm
import itertools
import os
import pickle

In [27]:
class Args:
    def __init__(self):
        self.metadata_path = '/mnt/NAS-TVS872XT/dataset/'
        self.root = self.metadata_path
        self.annotation_path = self.metadata_path
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 8
        self.NUM_WORKERS = 8  # kinetics:8, ucf101:24
        # self.CLIP_DURATION = 16 / 25
        self.CLIP_DURATION = (8 * 8) / 30  # (num_frames * sampling_rate)/fps
        self.VIDEO_NUM_SUBSAMPLED = 16  # 事前学習済みモデルに合わせて16→8
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400
        self.MODEL_NAME = "x3d_m"

args = Args()

In [28]:
class LimitDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos

In [29]:
def get_kinetics(subset):
    """
    Kinetics400のデータセットを取得

    Args:
        subset (str): "train" or "val" or "test"

    Returns:
        pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset: 取得したデータセット
    """
    args = Args()
    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                ShortSideScale(size=256),
                # RandomShortSideScale(min_size=256, max_size=320,),
                # CenterCropVideo(crop_size=(256, 256)),
                CenterCrop(224),
                # RandomCrop(224),
                RandomHorizontalFlip(),
            ]),
        ),
        ApplyTransformToKey(
            key="label",
            transform=transforms.Lambda(lambda x: x),
        ),
        RemoveKey("audio"),
    ])

    root_kinetics = '/mnt/NAS-TVS872XT/dataset/Kinetics400/'

    if subset == "test":
        dataset = Kinetics(
            data_path=root_kinetics + "test_list.txt",
            video_path_prefix=root_kinetics + 'test/',
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset
    else:
        dataset = Kinetics(
            data_path=root_kinetics + subset,
            video_path_prefix=root_kinetics + subset,
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset

    return False

In [30]:
def make_loader(dataset):
    """
    データローダーを作成

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset): get_datasetメソッドで取得したdataset

    Returns:
        torch.utils.data.DataLoader: 取得したデータローダー
    """
    args = Args()
    loader = DataLoader(LimitDataset(dataset),
                        batch_size=args.BATCH_SIZE,
                        drop_last=True,
                        num_workers=args.NUM_WORKERS)
    return loader

In [31]:
dataset = get_kinetics("val")

In [32]:
data = dataset.__next__()
data = data["video"]
print(data.shape)

torch.Size([3, 16, 224, 224])


In [33]:
dataloader = make_loader(dataset)
data_from_loader = iter(dataloader).__next__()

In [34]:
batch = data_from_loader["video"]

print(batch.shape)

torch.Size([8, 3, 16, 224, 224])


In [35]:
batch = batch.permute(0,2,1,3,4)
print(batch.shape)

torch.Size([8, 16, 3, 224, 224])


In [37]:
video_data_list = []
for i in range(args.BATCH_SIZE):
    video_data_list.append(batch[i])

In [38]:
print(video_data_list[0].shape)

torch.Size([16, 3, 224, 224])


In [39]:
new_batch = torch.cat(video_data_list, dim=0)
print(new_batch.shape)

torch.Size([128, 3, 224, 224])


In [48]:
def change_dim(list):
    args = Args()
    data_list = []
    for i in range(args.BATCH_SIZE):
        data_list.append(batch[i])
        new_batch = torch.cat(video_data_list, dim=0)
    return new_batch

5次元を4次元に変更完了

以下ではデータのシフトをテストする．

In [60]:
# ダミーデータ用意
dummy_data_list = []

for i in range(args.BATCH_SIZE):
    dummy_data = torch.linspace(i, i, 3*16*224*224)
    dummy_data = dummy_data.view(16,3,224,224)
    dummy_data_list.append(dummy_data)

print(dummy_data_list[0].shape)
# print(dummy_data_list[5])  # 全ての要素が5（後でシフトが上手くできたか確認できるように）

torch.Size([16, 3, 224, 224])


In [56]:
dummy_batch = change_dim(dummy_data_list)
print(dummy_batch.shape)
print(dummy_batch[0:8,:,:,:])

torch.Size([128, 3, 224, 224])
tensor([[[[ 0.5891,  0.5949,  0.6040,  ...,  2.3747,  2.3747,  2.3747],
          [ 0.5768,  0.5884,  0.6169,  ...,  2.4096,  2.4096,  2.4096],
          [ 0.5844,  0.5996,  0.6554,  ...,  2.4096,  2.4096,  2.4096],
          ...,
          [-0.8436,  0.4482,  1.4101,  ..., -1.6779, -1.6891, -1.7404],
          [ 1.7025,  2.0984,  2.0718,  ..., -1.6863, -1.6922, -1.7427],
          [ 1.9739,  1.9707,  1.9812,  ..., -1.6923, -1.6956, -1.7557]],

         [[ 0.6762,  0.6820,  0.6912,  ...,  2.4096,  2.4096,  2.4096],
          [ 0.6988,  0.7015,  0.7040,  ...,  2.4096,  2.4096,  2.4096],
          [ 0.7064,  0.7128,  0.7426,  ...,  2.4096,  2.4096,  2.4096],
          ...,
          [-0.6344,  0.6573,  1.6081,  ..., -1.5733, -1.5846, -1.6359],
          [ 1.9089,  2.3076,  2.2784,  ..., -1.5817, -1.5876, -1.6382],
          [ 2.1830,  2.1799,  2.1904,  ..., -1.5877, -1.5910, -1.6512]],

         [[ 0.6065,  0.6123,  0.6215,  ...,  2.4096,  2.4096,  2.4096],

In [53]:
dummy_data_list_from_batch = []
for i in range(args.BATCH_SIZE):
    dummy_data_from_batch = dummy_batch[i:i+16, :, :, :]
    print(dummy_data_from_batch.shape)
    dummy_data_list_from_batch.append(dummy_data_from_batch)

print(dummy_data_list_from_batch[2])
    

torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
torch.Size([16, 3, 224, 224])
tensor([[[[ 6.4924e-01,  6.4661e-01,  6.2268e-01,  ...,  2.3747e+00,
            2.3747e+00,  2.3747e+00],
          [ 6.2703e-01,  6.3671e-01,  6.5114e-01,  ...,  2.4096e+00,
            2.4096e+00,  2.4096e+00],
          [ 6.4080e-01,  6.4964e-01,  6.8160e-01,  ...,  2.4096e+00,
            2.4096e+00,  2.4096e+00],
          ...,
          [-8.5465e-01, -9.3751e-01, -9.1507e-01,  ..., -1.2194e+00,
           -5.6206e-01, -1.0906e+00],
          [-8.5382e-01, -1.1270e+00, -1.1599e+00,  ..., -5.0026e-01,
           -4.1224e-01, -7.7708e-01],
          [-1.2133e+00, -1.5635e+00, -1.2646e+00,  ..., -5.1982e-01,
           -1.0190e+00, -1.4233e+00]],

         [[ 7.3638e-01,  7.3376e-01,  7.0983e-01,  ...,  2.4096e+00,
            2.4096e+00,  2.4096e