In [101]:
import torch
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_xs', pretrained=True)
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_s', pretrained=True)
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_m', pretrained=True)

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [102]:
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


import torchinfo

from tqdm.notebook import tqdm
import itertools
import os
import pickle

In [103]:
class Args:
    def __init__(self):
        self.metadata_path = '/mnt/NAS-TVS872XT/dataset/'
        self.root = self.metadata_path
        self.annotation_path = self.metadata_path
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 16
        self.NUM_WORKERS = 8  # kinetics:8, ucf101:24
        # self.CLIP_DURATION = 16 / 25
        self.CLIP_DURATION = (8 * 8) / 30  # (num_frames * sampling_rate)/fps
        self.VIDEO_NUM_SUBSAMPLED = 16  # 事前学習済みモデルに合わせて16→8
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400
        self.MODEL_NAME = "x3d_m"

args = Args()

In [104]:
class LimitDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos

In [105]:
def get_kinetics(subset):
    """
    Kinetics400のデータセットを取得

    Args:
        subset (str): "train" or "val" or "test"

    Returns:
        pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset: 取得したデータセット
    """
    args = Args()
    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                ShortSideScale(size=256),
                # RandomShortSideScale(min_size=256, max_size=320,),
                # CenterCropVideo(crop_size=(256, 256)),
                CenterCrop(224),
                # RandomCrop(224),
                RandomHorizontalFlip(),
            ]),
        ),
        ApplyTransformToKey(
            key="label",
            transform=transforms.Lambda(lambda x: x),
        ),
        RemoveKey("audio"),
    ])

    root_kinetics = '/mnt/NAS-TVS872XT/dataset/Kinetics400/'

    if subset == "test":
        dataset = Kinetics(
            data_path=root_kinetics + "test_list.txt",
            video_path_prefix=root_kinetics + 'test/',
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset
    else:
        dataset = Kinetics(
            data_path=root_kinetics + subset,
            video_path_prefix=root_kinetics + subset,
            clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
            video_sampler=RandomSampler,
            decode_audio=False,
            transform=transform,
        )
        return dataset

    return False

In [106]:
def make_loader(dataset):
    """
    データローダーを作成

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset): get_datasetメソッドで取得したdataset

    Returns:
        torch.utils.data.DataLoader: 取得したデータローダー
    """
    args = Args()
    loader = DataLoader(LimitDataset(dataset),
                        batch_size=args.BATCH_SIZE,
                        drop_last=True,
                        num_workers=args.NUM_WORKERS)
    return loader

In [107]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_org = torch.hub.load(
        'facebookresearch/pytorchvideo', args.MODEL_NAME, pretrained=True)
model_org = model_org.to(device)
dataset = get_kinetics("test")

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [108]:
data = dataset.__next__()
data = data["video"]
print(data.shape)

torch.Size([3, 16, 224, 224])


In [109]:
batch_size = 1

torchinfo.summary(
    model_org,
    input_size=(batch_size,3,16,224,224),
    depth=4,
    col_names=["input_size",
               "output_size"],
    row_settings=("var_names",)
)

Layer (type (var_name))                                      Input Shape               Output Shape
Net                                                          --                        --
├─ModuleList (blocks)                                        --                        --
│    └─ResStage (1)                                          --                        --
│    │    └─ModuleList (res_blocks)                          --                        --
│    └─ResStage (2)                                          --                        --
│    │    └─ModuleList (res_blocks)                          --                        --
│    └─ResStage (3)                                          --                        --
│    │    └─ModuleList (res_blocks)                          --                        --
│    └─ResStage (4)                                          --                        --
│    │    └─ModuleList (res_blocks)                          --                        --


In [123]:
class ReconstructNet(nn.Module):
    def __init__(self):
        super().__init__()
        model = torch.hub.load(
        'facebookresearch/pytorchvideo', args.MODEL_NAME, pretrained=True)

        self.res_blocks = nn.Sequential(
            # model.blocks[0],
            # model.blocks[1]
            model.blocks
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.res_blocks(x)
        return x
    


model_new = ReconstructNet()

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [124]:
model_new = ReconstructNet()
model_new = model_new.to(device)

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [125]:
torchinfo.summary(
    model_new,
    depth=4
)

Layer (type:depth-idx)                                            Param #
ReconstructNet                                                    --
├─Sequential: 1-1                                                 --
│    └─ModuleList: 2-1                                            --
│    │    └─ResNetBasicStem: 3-1                                  --
│    │    │    └─Conv2plus1d: 4-1                                 768
│    │    │    └─BatchNorm3d: 4-2                                 48
│    │    │    └─ReLU: 4-3                                        --
│    │    └─ResStage: 3-2                                         --
│    │    │    └─ModuleList: 4-4                                  15,370
│    │    └─ResStage: 3-3                                         --
│    │    │    └─ModuleList: 4-5                                  73,248
│    │    └─ResStage: 3-4                                         --
│    │    │    └─ModuleList: 4-6                                  569,256
│    │    └─Res