In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


#import torchinfo

from tqdm.notebook import tqdm
import itertools
import os
import pickle

In [None]:
class Args:
    """Experiment configuration for this notebook (paths, clip sampling, loader settings).

    Args:
        metadata_path: dataset root directory; ``root`` and ``annotation_path`` are
            set to the same value. Defaults to the NAS mount used so far.
    """

    def __init__(self, metadata_path='/mnt/NAS-TVS872XT/dataset/'):
        self.metadata_path = metadata_path
        self.root = self.metadata_path
        self.annotation_path = self.metadata_path
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 16
        self.NUM_WORKERS = 8  # kinetics:8, ucf101:24

        # Frames kept per clip by UniformTemporalSubsample; reduced from 16 to 8
        # to match the pretrained model.
        self.VIDEO_NUM_SUBSAMPLED = 8
        # Clip length in seconds = (num_frames * sampling_rate) / fps.
        # Derived from VIDEO_NUM_SUBSAMPLED so the clip duration cannot silently
        # drift out of sync when the frame count is changed again.
        self.SAMPLING_RATE = 8
        self.FPS = 30
        self.CLIP_DURATION = (self.VIDEO_NUM_SUBSAMPLED * self.SAMPLING_RATE) / self.FPS
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400

args = Args()

In [None]:
# Per-sample preprocessing applied to each dict produced by the dataset.
#   video: temporal subsample -> divide by 255 (assumes raw 0..255 pixel values — confirm)
#          -> short side to 256 -> 256x256 centre crop -> random horizontal flip.
#   label: passed through unchanged (identity lambda).
#   audio: the "audio" entry is removed.
# Alternatives tried earlier and currently disabled: mean/std Normalize(0.45, 0.225),
# RandomShortSideScale(256..320), CenterCropVideo, RandomCrop(224).
# NOTE(review): RandomHorizontalFlip makes validation samples non-deterministic;
# confirm that is intended for this split.
transform = Compose(
    transforms=[
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                transforms=[
                    UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                    Lambda(lambda clip: clip / 255.),
                    ShortSideScale(size=256),
                    CenterCrop(256),
                    RandomHorizontalFlip(),
                ]
            ),
        ),
        ApplyTransformToKey(
            key="label",
            transform=Lambda(lambda label: label),
        ),
        RemoveKey("audio"),
    ]
)

In [None]:
root_SSv2 = '/mnt/NAS-TVS872XT/dataset/something-something-v2/'

# Only the validation split is built here. The training split (train JSON,
# PySlowFast/train.csv, video_path_prefix="/tmp/", SequentialSampler) is disabled.

# Something-Something-v2 validation split: label names / per-video labels come
# from the JSON files under anno/, and the clip list from the PySlowFast-style CSV.
# video_path_prefix presumably points at locally staged frames — confirm.
val_dataset = SSv2(
    label_name_file=os.path.join(root_SSv2, "anno/something-something-v2-labels.json"),
    video_label_file=os.path.join(root_SSv2, "anno/something-something-v2-validation.json"),
    video_path_label_file=os.path.join(root_SSv2, "PySlowFast/val.csv"),
    video_path_prefix="/tmp/ssv2_frame",
    clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
    video_sampler=RandomSampler,
    transform=transform,
)


In [None]:
# Sanity check: number of videos the validation sampler covers, then the
# shape of one transformed clip pulled straight from the iterable dataset.
num_val_videos = len(val_dataset.video_sampler)
print(num_val_videos)
data = next(val_dataset)
clip_shape = data["video"].shape
print(clip_shape)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Split the single sample pulled above into per-frame images for plotting.
# The transpose treats data["video"] as (C, T, H, W) and produces (T, H, W, C):
# one channels-last image per frame, the layout imshow expects.
# The result goes into a new variable instead of overwriting data["video"]:
# overwriting replaced the tensor with a NumPy array, so re-running this cell
# failed with AttributeError ('numpy.ndarray' has no attribute 'numpy').
video_frames = data["video"].numpy().transpose(1, 2, 3, 0)

# One (H, W, C) array per frame. Iterating over the frame axis uses every frame
# (rather than a hard-coded 8) and, unlike slicing + np.squeeze, never drops a
# size-1 height/width/channel dimension.
frame_list = list(video_frames)
    

In [None]:
# Show the frames of frame_list on a grid with 4 panels per row. The row count
# follows the number of frames, so frames are neither dropped nor indexed past
# the end of frame_list when the clip length changes.
cols = 4
rows = max(1, -(-len(frame_list) // cols))  # ceil(len / cols), at least one row

# squeeze=False keeps `axes` 2-D even for a single row, so axes.flat always works.
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False, tight_layout=True)

for frame_id, ax in enumerate(axes.flat):
    if frame_id < len(frame_list):
        ax.set_title("frame:" + str(frame_id))
        ax.imshow(frame_list[frame_id])
    else:
        ax.axis("off")  # no frame for this panel
plt.show()

In [None]:
class LimitDataset(torch.utils.data.Dataset):
    """Expose an iterable video dataset through the map-style ``Dataset`` API.

    ``__getitem__`` ignores the requested index and returns the next item of an
    internal stream that iterates over ``dataset`` ``num_repeats`` times in a row.
    Items therefore come in the dataset's own iteration order regardless of the
    DataLoader's sampler. Once the stream is exhausted, ``next`` raises StopIteration.

    Args:
        dataset: an iterable that also exposes ``video_sampler``; the length of
            ``video_sampler`` is reported as this dataset's length.
        num_repeats: number of consecutive passes over ``dataset`` the stream makes
            (default 2).
    """

    def __init__(self, dataset, num_repeats=2):
        super().__init__()
        self.dataset = dataset
        # Repeat the iterable itself rather than a single iterator built from it:
        # chain.from_iterable calls iter() on every element, so each pass starts a
        # fresh iteration. Repeating one iterator object made the second pass empty
        # for any iterable whose iterator is not self-resetting (e.g. a list).
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(dataset, num_repeats)
        )
        self.num_videos = len(self.dataset.video_sampler)

    def __getitem__(self, index):
        # `index` is intentionally unused; see the class docstring.
        return next(self.dataset_iter)

    def __len__(self):
        return self.num_videos

In [None]:
# Wrap the iterable validation dataset so DataLoader can index it map-style;
# clips are batched in pairs and an incomplete last batch is dropped.
val_loader = DataLoader(
    dataset=LimitDataset(val_dataset),
    batch_size=2,
    drop_last=True,
)

In [None]:
# Number of videos in the sampler, then number of batches the loader yields (one per line).
print(len(val_dataset.video_sampler), len(val_loader), sep="\n")

In [None]:
# Scratch arithmetic (16 x 1548). Presumably batch size x a batch count, i.e. the
# number of videos covered by full batches — TODO confirm what 1548 refers to.
1548 * 16

In [None]:
# Pull just the first batch from a fresh loader iterator.
data_from_loader = next(iter(val_loader))

In [None]:

# Visualise one clip from the first DataLoader batch.
# Each batch element is treated as (C, T, H, W) and transposed to (T, H, W, C),
# giving one channels-last image per frame for imshow. Iterating over the batch
# tensor (instead of a hard-coded range(2)) keeps this in sync with batch_size.
video = [clip.numpy().transpose(1, 2, 3, 0) for clip in data_from_loader["video"]]

video_id = 0  # which clip of the batch to show: 0 .. len(video) - 1

# One (H, W, C) image per frame of the selected clip — every frame, not a fixed 8.
frame_list = list(video[video_id])

# Grid with 4 panels per row; rows follow the frame count so nothing is dropped
# and a short clip cannot index past the end of frame_list.
cols = 4
rows = max(1, -(-len(frame_list) // cols))  # ceil(len / cols), at least one row

# squeeze=False keeps `axes` 2-D even for a single row, so axes.flat always works.
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False, tight_layout=True)

for frame_id, ax in enumerate(axes.flat):
    if frame_id < len(frame_list):
        ax.set_title("frame:" + str(frame_id))
        ax.imshow(frame_list[frame_id])
    else:
        ax.axis("off")  # no frame for this panel
plt.show()
