In [1]:
from comet_ml import Experiment

import torch
import torch.nn as nn
import torchvision
import torchinfo
import pytorchvideo
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler

from torchvision import transforms

from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101,
    RandomClipSampler,
    UniformClipSampler,
    Kinetics
)

# from torchvision.transforms._transforms_video import (
#     CenterCropVideo,
#     NormalizeVideo,
# )

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)

from tqdm import tqdm
from collections import OrderedDict
import itertools
import os
import argparse

In [2]:
class Args:
    """Container of hard-coded experiment settings, read as attributes.

    Holds dataset paths, clip/batch/loader sizes, the clip duration used by the
    clip samplers, the number of frames kept per clip, and per-dataset class
    counts.
    """

    def __init__(self):
        # metadata_path, root and annotation_path all point at the same directory.
        dataset_dir = '/mnt/NAS-TVS872XT/dataset/'
        self.metadata_path = dataset_dir
        self.root = dataset_dir
        self.annotation_path = dataset_dir

        # Training / loading hyper-parameters.
        self.NUM_EPOCH = 4
        self.FRAMES_PER_CLIP = 16
        self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 32
        self.NUM_WORKERS = 32

        # Clip length in seconds: (num_frames * sampling_rate) / fps.
        # Previous setting: 16 / 25
        self.CLIP_DURATION = (8 * 8) / 30
        # Frames kept per clip by UniformTemporalSubsample.
        self.VIDEO_NUM_SUBSAMPLED = 8

        # Number of classes per dataset.
        self.UCF101_NUM_CLASSES = 101
        self.KINETIC400_NUM_CLASSES = 400


In [3]:
class LimitDataset(torch.utils.data.Dataset):
    """Map-style wrapper around an iterable video dataset.

    ``__getitem__`` ignores ``index`` and returns the next sample drawn from
    ``iter(dataset)``, which is chained twice in a row. ``len()`` is the
    length of the wrapped dataset's ``video_sampler``.

    NOTE(review): the second pass only produces samples if calling ``iter()``
    on the object returned by ``iter(dataset)`` restarts it. For an ordinary
    exhausted iterator, nothing more is yielded and ``__getitem__`` raises
    StopIteration.

    Args:
        dataset: iterable dataset exposing a ``video_sampler`` that supports
            ``len()``.
        dataset_name (str, optional): when given, each sample (assumed to be
            a dict) gets a ``"dataset_name"`` entry with this value, so batches
            mixing several datasets can be told apart. This replaces the need
            for one copy-pasted wrapper class per dataset. Defaults to None,
            which leaves samples unchanged.
    """

    def __init__(self, dataset, dataset_name=None):
        super().__init__()
        self.dataset = dataset
        self.dataset_name = dataset_name
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )
        self.num_videos = len(self.dataset.video_sampler)

    def __getitem__(self, index):
        # `index` is deliberately unused: samples come from the iterator in order.
        sample = next(self.dataset_iter)
        if self.dataset_name is not None:
            sample["dataset_name"] = self.dataset_name
        return sample

    def __len__(self):
        return self.num_videos

In [4]:
# class KineticsLimitDataset(torch.utils.data.Dataset):
#     def __init__(self, dataset):
#         super().__init__()
#         self.dataset = dataset
#         self.dataset_iter = itertools.chain.from_iterable(
#             itertools.repeat(iter(dataset), 2)
#         )
#         # self.num_videos = make_num_videos(self.dataset)
#         self.num_videos = len(self.dataset.video_sampler)

#     def __getitem__(self, index):
#         return next(self.dataset_iter), "kinetics"
    

#     def __len__(self):
#         return self.num_videos

In [5]:
class KineticsLimitDataset(torch.utils.data.Dataset):
    """Map-style view over a Kinetics video dataset that tags every sample.

    Each ``__getitem__`` call ignores ``index``, pulls the next sample from
    ``iter(dataset)`` (chained twice in a row) and sets
    ``sample["dataset_name"] = "kinetics"`` in place before returning it.
    ``len()`` equals the length of ``dataset.video_sampler``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        source = iter(dataset)
        # Walk the same source twice back to back.
        self.dataset_iter = itertools.chain.from_iterable(itertools.repeat(source, 2))
        self.num_videos = len(self.dataset.video_sampler)

    def __len__(self):
        return self.num_videos

    def __getitem__(self, index):
        sample = next(self.dataset_iter)
        sample["dataset_name"] = "kinetics"
        return sample

In [6]:
# class Ucf101LimitDataset(torch.utils.data.Dataset):
#     def __init__(self, dataset):
#         super().__init__()
#         self.dataset = dataset
#         self.dataset_iter = itertools.chain.from_iterable(
#             itertools.repeat(iter(dataset), 2)
#         )
#         # self.num_videos = make_num_videos(self.dataset)
#         self.num_videos = len(self.dataset.video_sampler)

#     def __getitem__(self, index):
#         return next(self.dataset_iter), "Ucf101"
    

#     def __len__(self):
#         return self.num_videos

In [7]:
class Ucf101LimitDataset(torch.utils.data.Dataset):
    """Map-style view over a UCF101 video dataset that tags every sample.

    Each ``__getitem__`` call ignores ``index``, pulls the next sample from
    ``iter(dataset)`` (chained twice in a row) and sets
    ``sample["dataset_name"] = "ucf101"`` in place before returning it.
    ``len()`` equals the length of ``dataset.video_sampler``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        source = iter(dataset)
        # Walk the same source twice back to back.
        self.dataset_iter = itertools.chain.from_iterable(itertools.repeat(source, 2))
        self.num_videos = len(self.dataset.video_sampler)

    def __len__(self):
        return self.num_videos

    def __getitem__(self, index):
        sample = next(self.dataset_iter)
        sample["dataset_name"] = "ucf101"
        return sample

In [8]:
# class WrapperDataset(pytorchvideo.data.LabeledVideoDataset):
#     def __init__(self, dataset, labeld_video_paths, clip_sampler):
#         super().__init__()
#         self.dataset = dataset
#         self.labeled_video_paths = dataset.labeled_video_paths
#         self.clip_sampler = dataset.clip_sampler
#         self.dataset_iter = itertools.chain.from_iterable(
#             itertools.repeat(iter(dataset), 2)
#         )
#         # self.num_videos = make_num_videos(self.dataset)
#         self.num_videos = len(self.dataset.video_sampler)

#     def __getitem__(self, index):
#         return next(self.dataset_iter)

#     def __len__(self):
#         return self.num_videos

In [9]:
# class WrapperDataset(pytorchvideo.data.LabeledVideoDataset):
#     def __init__(self, dataset):
#         super().__init__()
#         self.dataset = dataset
#         self.labeled_video_paths = dataset.labeled_video_paths
#         self.clip_sampler = dataset.clip_sampler
#         self.dataset_iter = itertools.chain.from_iterable(
#             itertools.repeat(iter(dataset), 2)
#         )
#         # self.num_videos = make_num_videos(self.dataset)
#         self.num_videos = len(self.dataset.video_sampler)

#     def __getitem__(self, index):
#         return next(self.dataset_iter)

#     def __len__(self):
#         return self.num_videos

In [10]:
def get_ucf101(subset):
    """
    Build the UCF101 dataset for one split.

    Args:
        subset (str): "train" or "test". Any value other than "test" selects
            the ``trainlist01.txt`` split file.

    Returns:
        pytorchvideo.data.LabeledVideoDataset: the constructed dataset.
    """
    args = Args()
    root_ucf101 = '/mnt/dataset/UCF101/'
    if subset == "test":
        split_file = 'ucfTrainTestlist/testlist.txt'
    else:
        split_file = 'ucfTrainTestlist/trainlist01.txt'

    video_pipeline = Compose([
        UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
        transforms.Lambda(lambda x: x / 255.),
        Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
        ShortSideScale(size=256),
        CenterCrop(256),
        # NOTE(review): random flip is applied for every split, including "test".
        RandomHorizontalFlip(),
    ])
    transform = Compose([
        ApplyTransformToKey(key="video", transform=video_pipeline),
        # Shift labels down by one (presumably 1-based in the split file -> 0-based).
        ApplyTransformToKey(key="label", transform=transforms.Lambda(lambda x: x - 1)),
        RemoveKey("audio"),
    ])

    return Ucf101(
        data_path=root_ucf101 + split_file,
        video_path_prefix=root_ucf101 + 'video/',
        clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
        video_sampler=RandomSampler,
        decode_audio=False,
        transform=transform,
    )

In [11]:
def get_kinetics(subset):
    """
    Build the Kinetics-400 dataset for one split.

    Args:
        subset (str): "train" or "val" or "test". "test" reads the list file
            ``test_list.txt`` with videos under ``test/``. Any other value is
            appended to the Kinetics root and used as both ``data_path`` and
            ``video_path_prefix``.

    Returns:
        pytorchvideo.data.LabeledVideoDataset: the constructed dataset.
    """
    args = Args()
    root_kinetics = '/mnt/dataset/Kinetics400/'

    # Only the paths differ between splits; the dataset construction is shared.
    if subset == "test":
        data_path = root_kinetics + "test_list.txt"
        video_path_prefix = root_kinetics + 'test/'
    else:
        data_path = root_kinetics + subset
        video_path_prefix = root_kinetics + subset

    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                ShortSideScale(size=256),
                CenterCrop(256),
                # NOTE(review): random flip is applied for every split, including "test".
                RandomHorizontalFlip(),
            ]),
        ),
        # Labels pass through unchanged (the former identity Lambda was a no-op).
        RemoveKey("audio"),
    ])

    return Kinetics(
        data_path=data_path,
        video_path_prefix=video_path_prefix,
        clip_sampler=RandomClipSampler(clip_duration=args.CLIP_DURATION),
        video_sampler=RandomSampler,
        decode_audio=False,
        transform=transform,
    )

In [12]:
def make_loader(dataset):
    """
    Create a DataLoader over ``LimitDataset(dataset)``.

    Batch size and worker count come from ``Args``; incomplete final batches
    are dropped. No shuffling is requested.

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset):
            dataset returned by one of the get_* helpers.

    Returns:
        torch.utils.data.DataLoader: the loader.
    """
    cfg = Args()
    wrapped = LimitDataset(dataset)
    return DataLoader(
        wrapped,
        batch_size=cfg.BATCH_SIZE,
        num_workers=cfg.NUM_WORKERS,
        drop_last=True,
    )

In [13]:
def make_multi_loader(dataset):
    """
    Create a shuffling DataLoader over an already-wrapped (map-style) dataset.

    Intended for combined datasets such as ``Ucf101LimitDataset(...) +
    KineticsLimitDataset(...)``. ``shuffle=True`` requires a map-style
    dataset, since DataLoader rejects shuffle for an IterableDataset.

    Args:
        dataset (torch.utils.data.Dataset): map-style dataset to batch.

    Returns:
        torch.utils.data.DataLoader: loader using ``Args`` batch size and
        worker count, dropping the incomplete final batch.
    """
    cfg = Args()
    return DataLoader(
        dataset,
        shuffle=True,
        batch_size=cfg.BATCH_SIZE,
        num_workers=cfg.NUM_WORKERS,
        drop_last=True,
    )

In [15]:
# Build the UCF101 and Kinetics-400 "test" splits and print each dataset's num_videos.
ucf_dataset = get_ucf101("test")
# ucf_dataset.video_sampler._num_samples = 100
kinetics_dataset = get_kinetics("test")
print(ucf_dataset.num_videos)
print(kinetics_dataset.num_videos)

3783
35357


https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py



```python:dataset.py
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.
    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.
    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        # Yield every sample of each sub-dataset in turn, lazily.
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        # Sums len() of each sub-dataset; iterates self.datasets again, so a
        # one-shot iterable passed as `datasets` would be consumed here.
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)
        return total
```



```python
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.
    This class is useful to assemble different existing datasets.
    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running totals: r[i] is the combined length of sequence[0..i].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Negative indices count from the end of the concatenation.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Binary search for the first sub-dataset whose cumulative size exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # Convert the global index into an index local to that sub-dataset.
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    # Misspelled legacy alias of cumulative_sizes, kept with a deprecation warning.
    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
```

In [15]:
# chain_datasets = ucf_dataset + kinetics_dataset
# print(type(chain_datasets))
# concat_datasets = LimitDataset(ucf_dataset) + LimitDataset(kinetics_dataset)
# print(type(concat_datasets))
# `+` on two map-style Datasets yields a torch ConcatDataset (see printed type);
# each wrapper adds a "dataset_name" entry to its samples.
limit_datasets = Ucf101LimitDataset(ucf_dataset) + KineticsLimitDataset(kinetics_dataset)
print(type(limit_datasets))
# wrapper_datasets = WrapperDataset(ucf_dataset, labeld_video_paths="a", clip_sampler="b") + WrapperDataset(kinetics_dataset, labeld_video_paths="a", clip_sampler="b")
# print(type(wrapper_datasets))


<class 'torch.utils.data.dataset.ConcatDataset'>


In [16]:
# print(type(chain_datasets.datasets))
# print(type(chain_datasets.datasets[0]))
# print(chain_datasets.__len__())

In [18]:
# print(type(concat_datasets.datasets))
# print(type(concat_datasets.datasets[0]))

In [19]:
# chain_loader = make_multi_loader(chain_datasets)

In [20]:
# concat_loader = make_multi_loader(concat_datasets)

In [21]:
# Shuffled, fixed-size (drop_last) batches over the combined UCF101 + Kinetics dataset.
limit_loader = make_multi_loader(limit_datasets)

In [22]:
# chain_batch = iter(chain_loader).__next__()
# print(type(chain_batch))
# chain_inputs = chain_batch["video"]
# # print(type(chain_inputs))
# print(chain_inputs.shape)
# chain_labels = chain_batch["label"]
# # print(type(chain_labels))
# print(chain_labels.shape)

### 下の2つのセルの結果からデータセット名を加えても実行時間は変わらないことが確認できた
- DataLoaderを`shuffle=False`にするともっと速かった
- DataLoaderでシャッフルするならVideoSamplerはRandomでなくてもよさそう
  - そっちの方が早いかどうかも後で検証

In [23]:
# concat_batch = iter(concat_loader).__next__()
# print(type(concat_batch))
# print(concat_batch.keys())
# concat_inputs = concat_batch["video"]
# # print(type(concat_inputs))
# print(concat_inputs.shape)
# concat_labels = concat_batch["label"]
# # print(type(concat_labels))
# print(concat_labels.shape)

In [24]:
# Pull a single batch to inspect its collated keys and tensor shapes.
limit_batch = next(iter(limit_loader))
print(type(limit_batch))
print(limit_batch.keys())

limit_inputs = limit_batch["video"]
print(limit_inputs.shape)

limit_labels = limit_batch["label"]
print(limit_labels.shape)

<class 'dict'>
dict_keys(['video', 'video_name', 'video_index', 'clip_index', 'aug_index', 'label', 'dataset_name'])
torch.Size([32, 3, 8, 256, 256])
torch.Size([32])


In [39]:
# limit_batch, data_name = iter(limit_loader).__next__()
# print(data_name)
# print(type(data_name))
# print(type(limit_batch))
# print(limit_batch.keys())
# limit_inputs = limit_batch["video"]
# # print(type(limit_inputs))
# print(limit_inputs.shape)
# limit_labels = limit_batch["label"]
# # print(type(limit_labels))
# print(limit_labels.shape)

In [25]:
# Per-sample dataset tags, collated into a list of strings (one per batch element).
limit_batch["dataset_name"]

['kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'ucf101',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'ucf101',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics',
 'ucf101',
 'kinetics',
 'kinetics',
 'kinetics',
 'kinetics']

In [20]:
# class MyUcf101(pytorchvideo.data.LabeledVideoDataset):

#     def __next__(self):
#         dict = super().__next__()
#         dict["dataset"] = "ucf101"
#         return dict

# class MyKinetics(pytorchvideo.data.LabeledVideoDataset):

#     def __next__(self):
#         dict = super().__next__()
#         dict["dataset"] = "kinetics"
#         return dict

In [None]:
# y_ucf_dataset = MyUcf101(ucf_dataset)
# my_kinetics_dataset = MyKinetics(kinetics_dataset)

## torchvision.datasets.DatasetName と pytorchvideo.data.LabeledVideoDataset の違い
- torchvisionはクラスを呼び出しているのでtorchvisionのクラスを継承したクラスを定義して新たに定義したクラスにtorchvisionと同じ引数を与えて呼び出せる
- pytorchvideoは関数（`Ucf101`、`Kinetics`など）を呼び出し、その返り値が`LabeledVideoDataset`のインスタンスなので、wrapperのようなことをするしかない？
- LimitDatasetの```__getitem__()```でデータセットの名前も同時に返すとか（dictに入らないしデータセットごとにLimitDatasetクラスを用意しなければならない）
  - ```__getitem__()```ではバッチごと持ってくるので1つ1つのデータにデータセット名を与えられないので（バッチ単位で同じデータセット名を与えてしまう）やっぱりdictに入れるのが良さそう

## 以下で学習ループを回しデータによってアダプタを変更するコードを設計

In [32]:
class Adapter(nn.Module):
    """Residual 1x1-convolution adapter for 2-D feature maps.

    Computes ``bn2(conv1(bn1(x)) + x)``. Input and output both have ``dim``
    channels and the same spatial size (the conv is 1x1 with stride 1).

    Args:
        dim (int): number of channels of the incoming feature map.
    """

    def __init__(self, dim):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        branch = self.conv1(self.bn1(x))
        return self.bn2(branch + x)

    
class VideoToFrame(nn.Module):
    """Flatten a batch of clips into a batch of frames.

    Maps (batch, channel, frame, H, W) to (batch * frame, channel, H, W). The
    frames of each clip stay contiguous and in order, so row ``b * frame + t``
    is frame ``t`` of clip ``b``.
    """

    def __init__(self):
        super().__init__()

    def make_new_inputs(self, inputs):
        """
        Split video tensors into per-frame image tensors.

        Args:
            inputs (torch.Tensor): 5-D tensor indexed (batch, channel, frame, H, W).
        Returns:
            torch.Tensor: tensor of shape (batch * frame, channel, H, W).
        """
        frames_first = inputs.transpose(1, 2)  # -> (batch, frame, channel, H, W)
        return frames_first.flatten(0, 1)

    def forward(self, x):
        return self.make_new_inputs(x)

class FrameAvg(nn.Module):
    """Average per-frame outputs into one output per video.

    Rows are expected in the order produced by VideoToFrame: all frames of
    clip 0, then all frames of clip 1, and so on.
    """

    def __init__(self):
        super().__init__()

    def frame_out_to_video_out(self, input: torch.Tensor, batch_size, num_frame) -> torch.Tensor:
        """
        Convert per-frame outputs into per-video outputs by averaging over frames.

        Args:
            input (torch.Tensor): per-frame outputs, batch_size * num_frame rows.
            batch_size (int): number of videos.
            num_frame (int): frames per video.

        Returns:
            torch.Tensor: (batch_size, features) per-video outputs.
        """
        per_video = input.reshape(batch_size, num_frame, -1)
        return per_video.mean(dim=1)

    def forward(self, x, batch_size, num_frame):
        return self.frame_out_to_video_out(x, batch_size, num_frame)




## バッチ内に異なるデータセットがある場合のモデルの設計の問題点
- モデル側からはどのデータがどのデータセットからなのかが不明
  - モデルにどのデータがどのデータセットからなのかの情報を入力とともに与えることで解決
- バッチ内で通過するモジュールが異なる（アダプタと出力層）
  - テンソルを複数（データセットの数）に分けてそれぞれモジュールに追加させる
    - バッチごとにあるデータセットからの数は異なるのでモジュールに流れるデータ数も異なる
    - データ数が異なるがモジュールはうまく学習できるのか？（一般的なモデルで言うとバッチサイズが毎回異なるのに学習がうまく進むのかということ）
    - データ数が偏っている場合はあるアダプタだけが学習が進むようなことが起きる？
      - アダプタと出力層の学習だけなら問題ない気がするがドメイン非依存のパラメータも同時に学習する場合はドメイン非依存のパラメータがデータ数の多いアダプタからの出力から主に学習をするのでデータ数の少ないアダプタの学習が遅れる？
  - テンソルをデータセットごとに分ける方法
    - adapter
      - reshapeで(batch*frame, channel, h, w) -> (batch, frame, channel, h, w)
      - 理想は(batch, frame, channel, h, w) -> (num_domain, batch_of_domain, frame, channel, h, w)
        - データがデータセットごとに順番かつ同数になっていたらreshapeでこれが可能
    - head（出力層）
      - (batch*frame, dim) -> (batch, frame, dim) -> (num_domain, batch_of_domain, frame, dim)
      - テンソルを切り取って (num_domain, batch_of_domain, frame, dim) -> (batch_of_domain, frame, dim) × num_domain
      - (batch_of_domain, frame, dim) -> (batch_of_domain*frame, dim)
      - あとはdomainに合わせたクラス数を用いてそれぞれのドメインごとに (batch_of_domain*frame, dim) -> (batch_of_domain*frame, num_class) -> (batch_of_domain, frame, num_class)
      - フレーム方向（dim=1）で平均 (batch_of_domain, frame, num_class) -> (batch_of_domain, num_class)
      - 最後にドメインごとの結果を連結はできない（クラス数が異なるから）
        - lossはどうやって計算する？
          - ドメインごとにロスを設計？
            - ドメインごとに学習しているのと同じだからダメ
          - ドメインごとにロスの和を最終的なロスとする
            - 中部大のmulti head型のmix loss
  - 疑問点
    - 学習の際は全てのデータセットでデータ数を揃えても問題ない？
      - その方が同じくらい学習したアダプタからの出力をドメイン非依存のモジュールが受け取れるので学習が安定しそう
- 中部大のmulti head型の内容
  - バッチは全て同じデータセットのデータを使う
  - バッチごとに誤差伝播をするのではなくて全てのデータセットのバッチを流して累積した誤差を逆伝播する
    - 学習時間とメモリ使用量の削減のために自動混合精度を用いて学習 (https://arxiv.org/pdf/1710.03740.pdf)
      - おそらくマルチドメインのためのものではなく、単に半精度と単精度の演算を組み合わせて高速化する方法のことだと思う

In [33]:
class ReconstructNet(nn.Module):
    """Frame-wise ResNet-152 video classifier with per-dataset linear heads.

    Each clip is split into frames (VideoToFrame). Every frame goes through a
    pretrained torchvision ResNet-152 trunk (its fc layer is not kept), the
    UCF101 head produces per-frame logits, and FrameAvg averages them into
    per-video logits.

    Only parameters whose names appear in ``update_param_names`` keep
    ``requires_grad=True``; every other parameter is frozen.

    NOTE(review):
      * forward() always applies ``ucf_head`` (101 outputs). ``kinetics_head``
        is built but never called, so Kinetics labels (0..399) cannot be
        scored by this model as written.
      * The "adapter.*" names in ``update_param_names`` match nothing while
        ``self.adapter`` is commented out, so only ``ucf_head`` is trainable.
      * ``requires_grad = False`` does not stop BatchNorm layers from updating
        their running statistics when the module is in train() mode.
      * ``pretrained=True`` is the legacy torchvision API (newer releases use
        ``weights=``); confirm against the installed torchvision version.
    """
    def __init__(self):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=True)
        model_num_features = model.fc.in_features
        ucf_num_class = 101
        kinetics_num_class = 400

        self.video_to_frame = VideoToFrame()
        # Stem of the backbone (conv -> bn -> relu -> maxpool).
        self.net_bottom = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool
        )
        
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.avgpool = model.avgpool

        # self.adapter = Adapter(512)

        # One classification head per dataset, replacing the backbone's fc layer.
        self.ucf_head = nn.Sequential(
            nn.Linear(model_num_features, ucf_num_class)
        )

        self.kinetics_head = nn.Sequential(
            nn.Linear(model_num_features, kinetics_num_class)
        )

        self.frame_avg = FrameAvg()
        
        # Names of the parameters to train.
        self.update_param_names = ["adapter.bn1.weight", "adapter.bn1.bias",
                            "adapter.conv1.weight", "adapter.conv1.bias",
                            "adapter.bn2.weight", "adapter.bn2.bias", 
                            "ucf_head.0.weight", "ucf_head.0.bias"]
        # Disable gradients for every other parameter so that it stays fixed.
        for name, param in self.named_parameters():
            if name in self.update_param_names:
                param.requires_grad = True
            else:
                param.requires_grad = False



    def forward(self, x):
        # x is indexed (batch, channel, frame, H, W); dim 2 is the frame axis.
        batch_size = x.size(0)
        num_frame = x.size(2)

        # (batch, C, T, H, W) -> (batch * T, C, H, W)
        x = self.video_to_frame(x)
        x = self.net_bottom(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # x = self.adapter(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x,1)
        # Per-frame UCF101 logits, then averaged back to one row per video.
        x = self.ucf_head(x)
        x = self.frame_avg(x, batch_size, num_frame)
        return x

In [34]:
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times, e.g. a batch mean weighted by batch size.

        Args:
            val: number or tensor. Tensors, including subclasses such as
                ``nn.Parameter``, are converted with ``.item()``.
            n (int): weight of this observation. With ``n=0`` and no prior
                data, ``avg`` stays 0 instead of dividing by zero.
        """
        # isinstance (not type ==) so Tensor subclasses are converted too.
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

def top1(outputs, targets):
    """Top-1 accuracy: fraction of rows whose arg-max in ``outputs`` equals ``targets``.

    Args:
        outputs (torch.Tensor): (batch, num_classes) scores.
        targets (torch.Tensor): (batch,) class indices.

    Returns:
        float: accuracy in [0, 1].
    """
    predicted = outputs.max(dim=1).indices
    n_correct = predicted.eq(targets).sum().item()
    return n_correct / outputs.size(0)

In [37]:
dataset_ucf = get_ucf101("test")
dataset_kinetics = get_kinetics("test")
# Wrap the datasets built in this cell rather than the earlier ucf_dataset /
# kinetics_dataset, so the cell is self-contained under Restart & Run All.
datasets = Ucf101LimitDataset(dataset_ucf) + KineticsLimitDataset(dataset_kinetics)

In [38]:
# Shuffled, fixed-size (drop_last) batches over the combined dataset `datasets`.
loader = make_multi_loader(datasets)

In [35]:
# Use the CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ReconstructNet().to(device)

In [36]:
# SGD over all parameters; frozen ones (requires_grad=False) never receive a
# gradient, so only parameters with requires_grad=True are actually updated.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

In [None]:
# Training loop.
# ReconstructNet.forward only applies the 101-class ucf_head, so feeding it
# Kinetics samples (labels 0..399) makes CrossEntropyLoss fail on out-of-range
# targets. Until a Kinetics head is wired into forward(), each mixed batch is
# filtered down to its UCF101 samples; batches without any are skipped.
num_epochs = Args().NUM_EPOCH

acc_list = []
loss_list = []
step = 0

with tqdm(range(num_epochs)) as pbar_epoch:
    for epoch in pbar_epoch:
        pbar_epoch.set_description("[Epoch %d]" % (epoch))

        with tqdm(enumerate(loader),
                  total=len(loader),
                  leave=True) as pbar_batch:

            train_loss = AverageMeter()
            train_acc = AverageMeter()
            model.train()

            for batch_idx, batch in pbar_batch:
                pbar_batch.set_description("[Epoch :{}]".format(epoch))

                # Keep only the samples the UCF101 head can classify.
                ucf_mask = torch.tensor(
                    [name == "ucf101" for name in batch["dataset_name"]],
                    dtype=torch.bool,
                )
                if not ucf_mask.any():
                    step += 1
                    continue

                inputs = batch['video'][ucf_mask].to(device)
                labels = batch['label'][ucf_mask].to(device)
                bs = inputs.size(0)  # number of UCF101 samples in this batch

                optimizer.zero_grad()
                # The model already averages per-frame logits into per-video logits.
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss.update(loss, bs)
                train_acc.update(top1(outputs, labels), bs)

                # Running epoch averages first, then the current batch's values.
                pbar_batch.set_postfix_str(
                    ' | avg loss={:6.04f} , avg top1={:6.04f}'
                    ' | loss={:6.04f} , top1={:6.04f}'
                    ''.format(
                    train_loss.avg, train_acc.avg,
                    train_loss.val, train_acc.val,
                ))

                # experiment.log_metric("batch_accuracy", train_acc.val, step=step)
                step += 1

            acc_list.append(train_acc.avg)
            loss_list.append(train_loss.avg)
        # Mean of the per-epoch averages seen so far.
        pbar_epoch.set_postfix(OrderedDict(
            acc=sum(acc_list)/len(acc_list),
            loss=sum(loss_list)/len(loss_list)
        ))
        # experiment.log_metric("epoch_accuracy", train_acc.val, step=epoch)
        # experiment.log_metric("epoch_accuracy", train_acc.val, step=epoch)

