In [101]:
import torch
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_xs', pretrained=True)
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_s', pretrained=True)
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_m', pretrained=True)

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [102]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


import torchinfo

from tqdm.notebook import tqdm
import itertools
import os
import pickle

In [103]:
class Args:
    def __init__(self):
        self.metadata_path = '/mnt/NAS-TVS872XT/dataset/'
        self.root = self.metadata_path
        self.annotation_path = self.metadata_path
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 8
        self.NUM_WORKERS = 8  # kinetics:8, ucf101:24
        # self.CLIP_DURATION = 16 / 25
        self.CLIP_DURATION = (8 * 8) / 30  # (num_frames * sampling_rate)/fps
        self.VIDEO_NUM_SUBSAMPLED = 16  # 事前学習済みモデルに合わせて16→8
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400
        self.MODEL_NAME = "x3d_m"

args = Args()

In [104]:
class LimitDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos

In [105]:
def get_kinetics(subset):
    """
    Kinetics400のデータセットを取得

    Args:
        subset (str): "train" or "val" or "test"

    Returns:
        pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset: 取得したデータセット
    """
    args = Args()
    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                ShortSideScale(size=256),
                # RandomShortSideScale(min_size=256, max_size=320,),
                # CenterCropVideo(crop_size=(256, 256)),
                CenterCrop(224),
                # RandomCrop(224),
                RandomHorizontalFlip(),
            ]),
        ),
        ApplyTransformToKey(
            key="label",
            transform=transforms.Lambda(lambda x: x),
        ),
        RemoveKey("audio"),
    ])

    root_kinetics = '/mnt/NAS-TVS872XT/dataset/Kinetics400/'

    if subset == "test":
        dataset = Kinetics(
            data_path=root_kinetics + "test_list.txt",
            video_path_prefix=root_kinetics + 'test/',
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset
    else:
        dataset = Kinetics(
            data_path=root_kinetics + subset,
            video_path_prefix=root_kinetics + subset,
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset

    return False

In [106]:
def make_loader(dataset):
    """
    データローダーを作成

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset): get_datasetメソッドで取得したdataset

    Returns:
        torch.utils.data.DataLoader: 取得したデータローダー
    """
    args = Args()
    loader = DataLoader(LimitDataset(dataset),
                        batch_size=args.BATCH_SIZE,
                        drop_last=True,
                        num_workers=args.NUM_WORKERS)
    return loader

In [107]:
# dataset = get_kinetics("val")

In [108]:
# data = dataset.__next__()
# data = data["video"]
# print(data.shape)

In [109]:
# dataloader = make_loader(dataset)
# data_from_loader = iter(dataloader).__next__()

In [110]:
# batch = data_from_loader["video"]

# print(batch.shape)

In [111]:
# batch = batch.permute(0,2,1,3,4)
# print(batch.shape)

In [112]:
# video_data_list = []
# for i in range(args.BATCH_SIZE):
#     video_data_list.append(batch[i])

In [113]:
# print(video_data_list[0].shape)

In [114]:
# new_batch = torch.cat(video_data_list, dim=0)
# print(new_batch.shape)

In [115]:
def change_dim(list):
    new_batch = torch.cat(list, dim=0)
    return new_batch

In [116]:
# test_list = []
# input_0 = torch.zeros(16,3,224,224)
# input_1 = torch.ones(16,3,224,224)
# test_list.append(input_0)
# test_list.append(input_1)
# # print(test_list[1])
# test_batch = change_dim(test_list)
# print(test_batch.shape)
# # print(test_batch[16:32, :])

5次元を4次元に変更完了

In [117]:
# # ダミーデータ用意
# dummy_data_list = []

# for i in range(args.BATCH_SIZE):
#     dummy_data = torch.linspace(i, i, 3*16*224*224)
#     dummy_data = dummy_data.view(16,3,224,224)
#     dummy_data_list.append(dummy_data)

# print(dummy_data_list[0].shape)
# # print(dummy_data_list[5])  # 全ての要素が5（後でシフトが上手くできたか確認できるように）

In [118]:
# dummy_batch = change_dim(dummy_data_list)
# print(dummy_batch.shape)
# # print(dummy_batch[16*5:16*6,:])

In [119]:
# dummy_data_list_from_batch = []
# for i in range(args.BATCH_SIZE):
#     dummy_data_from_batch = dummy_batch[i*16:(i+1)*16, :, :, :]
#     # if i == 5:
#     #     print(dummy_data_from_batch.shape)
#     #     print(dummy_data_from_batch)
#     dummy_data_list_from_batch.append(dummy_data_from_batch)

# # print(dummy_data_list_from_batch[5])
    

### ビデオデータをフレームに分割し画像認識モデルに流すための実験

In [127]:
dataset = get_kinetics("val")
dataset.video_sampler._num_samples = 100

dataloader = make_loader(dataset)

In [128]:
def make_new_batch(x: torch.Tensor) -> torch.Tensor:
    x = x.permute(0,2,1,3,4)
    video_data_list = []
    for i in range(x.size()[0]):
        video_data_list.append(x[i])
    new_batch = torch.cat(video_data_list, dim=0)
    return new_batch

In [129]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_epoch = 1

with tqdm(range(num_epoch)) as pbar_epoch:
    for epoch in pbar_epoch:
        pbar_epoch.set_description("[Epoch {}]".format(epoch))

        with tqdm(enumerate(dataloader),
                  total=len(dataloader),
                  leave=True) as pbar_batch:
            
            for batch_idx, batch in pbar_batch:
                inputs = batch['video'].to(device)
                targets = batch['label'].to(device)
                if batch_idx == 0:
                    # new_batch = make_new_batch(inputs)
                    # print(inputs.shape)
                    # print(new_batch.shape)
                    targets_list = []
                    num_frame = inputs.size()[2]
                    for i in range(targets.size()[0]):
                        # print(targets[i])
                        target_id = targets[i].item()
                        # print(target_id)
                        label = torch.full((1,8), target_id)
                        # print(label.shape)
                        # print(label)
                        targets_list.append(label)
                    new_targets = torch.cat(targets_list, dim=1)
                    new_targets = torch.squeeze(new_targets)
                    print(targets)
                    print(new_targets)
                    print(targets.shape)
                    print(new_targets.shape)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12.0), HTML(value='')))

tensor([223,  36, 179, 392,  71, 350, 308, 163], device='cuda:0')
tensor([223, 223, 223, 223, 223, 223, 223, 223,  36,  36,  36,  36,  36,  36,
         36,  36, 179, 179, 179, 179, 179, 179, 179, 179, 392, 392, 392, 392,
        392, 392, 392, 392,  71,  71,  71,  71,  71,  71,  71,  71, 350, 350,
        350, 350, 350, 350, 350, 350, 308, 308, 308, 308, 308, 308, 308, 308,
        163, 163, 163, 163, 163, 163, 163, 163])
torch.Size([8])
torch.Size([64])


