In [1]:
from comet_ml import Experiment
import torch
# Load the pretrained 'x3d_m' model from the facebookresearch/pytorchvideo torch.hub repo.
# NOTE(review): this module-level `model` is not referenced by the cells visible below
# (the ReconstructNet classes load their own copy) -- confirm it is still needed.
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_m', pretrained=True)

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_main


In [2]:
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


import torchinfo

import numpy as np
from tqdm.notebook import tqdm
import itertools
import os
import pickle
import random
import matplotlib.pyplot as plt
import shutil
from sklearn import mixture
from sklearn import svm
from sklearn import decomposition
import os.path as osp
import argparse

In [3]:
class Args:
    """Container for the fixed hyperparameters and dataset constants of this notebook.

    Every setting is exposed as a plain instance attribute (e.g. ``Args().BATCH_SIZE``).
    """

    def __init__(self):
        settings = {
            # Training loop length.
            "NUM_EPOCH": 3,
            "FRAMES_PER_CLIP": 16,
            "STEP_BETWEEN_CLIPS": 16,
            # DataLoader settings.
            "BATCH_SIZE": 32,
            "NUM_WORKERS": 32,
            # Clip durations, computed as (num_frames * sampling_rate) / fps.
            "kinetics_clip_duration": (8 * 8) / 30,
            "ucf101_clip_duration": 16 / 25,
            # Number of frames kept by the temporal subsampling transform.
            "VIDEO_NUM_SUBSAMPLED": 16,
            # Number of classes per dataset.
            "UCF101_NUM_CLASSES": 101,
            "KINETIC400_NUM_CLASSES": 400,
        }
        for name, value in settings.items():
            setattr(self, name, value)

In [4]:
class Adapter2D(nn.Module):
    """Frame-wise residual adapter: BatchNorm2d -> 1x1 Conv2d -> skip add -> BatchNorm2d.

    ``forward`` takes a 5-D tensor laid out as (batch, channel, frame, height, width),
    folds the frame axis into the batch axis so each frame is processed as a 2-D
    image, and restores the original layout afterwards.

    Args:
        dim (int): number of channels of the input (and output) tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.bn2 = nn.BatchNorm2d(dim)

    def video_to_frame(self, inputs):
        """Reshape (batch, channel, frame, H, W) into (batch * frame, channel, H, W)."""
        batch_size = inputs.size(0)
        num_frame = inputs.size(2)

        # Move the frame axis next to the batch axis before merging the two.
        inputs = inputs.permute(0, 2, 1, 3, 4)
        outputs = inputs.reshape(batch_size * num_frame,
                                 inputs.size(2),
                                 inputs.size(3),
                                 inputs.size(4))

        return outputs

    def frame_to_video(
            self, input: torch.Tensor, batch_size, num_frame, channel, height, width) -> torch.Tensor:
        """Inverse of ``video_to_frame``: (batch * frame, channel, H, W) -> (batch, channel, frame, H, W)."""
        output = input.reshape(batch_size, num_frame, channel, height, width)
        output = output.permute(0, 2, 1, 3, 4)
        return output

    def forward(self, x):
        """Apply the adapter to ``x`` and return a tensor of the same shape.

        Args:
            x (torch.Tensor): 5-D tensor (batch, channel, frame, height, width).

        Returns:
            torch.Tensor: adapted tensor with the same shape as ``x``.
        """
        batch_size = x.size(0)
        channel = x.size(1)
        num_frame = x.size(2)
        height = x.size(3)
        # Read the width explicitly; reusing the height here broke non-square inputs.
        width = x.size(4)

        x = self.video_to_frame(x)

        residual = x
        out = self.bn1(x)
        out = self.conv1(out)
        out += residual
        out = self.bn2(out)

        out = self.frame_to_video(out, batch_size, num_frame, channel, height, width)

        return out

In [5]:
class ReconstructNet(nn.Module):
    """Pretrained ``x3d_m`` backbone with per-domain ``Adapter2D`` modules and classifiers.

    The hub model's ``blocks[0..3]`` and ``blocks[4]`` are shared between domains.
    After ``blocks[4]`` a domain-specific adapter is applied, followed by the hub
    head's ``pool`` and ``dropout`` and a domain-specific linear classifier
    (101 outputs for "UCF101", the hub model's own ``proj`` layer for "Kinetics").
    """

    # Domain names accepted by ``forward``.
    _DOMAINS = ("UCF101", "Kinetics")

    def __init__(self):
        super().__init__()
        model = torch.hub.load(
            'facebookresearch/pytorchvideo', "x3d_m", pretrained=True)
        self.model_num_features = model.blocks[5].proj.in_features
        self.ucf_num_class = 101
        self.kinetics_num_class = 400

        self.net_bottom = nn.Sequential(
            model.blocks[0],
            model.blocks[1],
            model.blocks[2],
            model.blocks[3],
        )

        # Registered but not used by ``forward`` (the corresponding call is disabled).
        # NOTE(review): 96 presumably matches the channel count after blocks[3] -- confirm.
        self.ucf_adapter0 = Adapter2D(96)
        self.kinetics_adapter0 = Adapter2D(96)

        self.blocks4 = model.blocks[4]

        # NOTE(review): 192 presumably matches the channel count after blocks[4] -- confirm.
        self.adapter1_ucf = Adapter2D(192)
        self.adapter1_kinetics = Adapter2D(192)

        self.net_top = nn.Sequential(
            model.blocks[5].pool,
            model.blocks[5].dropout
        )

        self.linear_ucf = nn.Linear(self.model_num_features, self.ucf_num_class)
        # The Kinetics classifier reuses the hub model's projection layer.
        self.linear_kinetics = model.blocks[5].proj

    def forward(self, x: torch.Tensor, domain) -> torch.Tensor:
        """Classify ``x`` with the adapter and classifier of ``domain``.

        Args:
            x (torch.Tensor): input passed to the first backbone block.
            domain (str): "UCF101" or "Kinetics".

        Returns:
            torch.Tensor: 2-D tensor whose last dimension is the number of classes
            of ``domain`` (101 or 400).

        Raises:
            ValueError: if ``domain`` is not one of the supported names. Previously
                an unknown domain silently returned unclassified features.
        """
        if domain not in self._DOMAINS:
            raise ValueError(
                "unknown domain {!r}; expected one of {}".format(domain, self._DOMAINS))

        if domain == "UCF101":
            adapter = self.adapter1_ucf
            classifier = self.linear_ucf
            num_class = self.ucf_num_class
        else:
            adapter = self.adapter1_kinetics
            classifier = self.linear_kinetics
            num_class = self.kinetics_num_class

        x = self.net_bottom(x)
        x = self.blocks4(x)
        x = adapter(x)
        x = self.net_top(x)

        # Move the channel axis last so the linear layer acts on the feature dimension.
        x = x.permute(0, 2, 3, 4, 1)
        x = classifier(x)
        return x.reshape(-1, num_class)




In [6]:
class Adapter3D(nn.Module):
    """Residual adapter that mixes information across the frame axis.

    ``forward`` takes a 5-D tensor laid out as (batch, channel, frame, height, width).
    The batch and channel axes are merged so that the 1x1 ``Conv2d`` treats the
    frame axis as its channel dimension; the result is reshaped back, added to the
    input and normalised.

    Args:
        channel_dim (int): number of channels of the input (used by both BatchNorm3d layers).
        frame_dim (int): number of frames of the input (in/out channels of the 1x1 conv).
    """

    def __init__(self, channel_dim, frame_dim):
        super().__init__()
        self.bn1 = nn.BatchNorm3d(channel_dim)
        self.conv1 = nn.Conv2d(frame_dim, frame_dim, 1)
        self.bn2 = nn.BatchNorm3d(channel_dim)

    def video_to_frame_swap(self, inputs):
        """Reshape (batch, channel, frame, H, W) into (batch * channel, frame, H, W)."""
        batch_size = inputs.size(0)
        channel = inputs.size(1)
        num_frame = inputs.size(2)

        outputs = inputs.reshape(batch_size * channel,
                                 num_frame,
                                 inputs.size(3),
                                 inputs.size(4))

        return outputs

    def frame_to_video(
            self, input: torch.Tensor, batch_size, num_frame, channel, height, width) -> torch.Tensor:
        """Inverse of ``video_to_frame_swap``: restore (batch, channel, frame, H, W)."""
        output = input.reshape(batch_size, channel, num_frame, height, width)
        return output

    def forward(self, x):
        """Apply the adapter to ``x`` and return a tensor of the same shape.

        Args:
            x (torch.Tensor): 5-D tensor (batch, channel, frame, height, width).

        Returns:
            torch.Tensor: adapted tensor with the same shape as ``x``.
        """
        batch_size = x.size(0)
        channel = x.size(1)
        num_frame = x.size(2)
        height = x.size(3)
        # Read the width explicitly; reusing the height here broke non-square inputs.
        width = x.size(4)

        residual = x

        out = self.bn1(x)
        out = self.video_to_frame_swap(out)
        out = self.conv1(out)
        out = self.frame_to_video(out, batch_size, num_frame, channel, height, width)
        out += residual
        out = self.bn2(out)

        return out

In [7]:
class ReconstructNet3D(nn.Module):
    """Pretrained ``x3d_m`` backbone with per-domain ``Adapter3D`` modules and classifiers.

    The hub model's ``blocks[0..3]`` and ``blocks[4]`` are shared between domains.
    After ``blocks[4]`` a domain-specific adapter is applied, followed by the hub
    head's ``pool`` and ``dropout`` and a domain-specific linear classifier
    (101 outputs for "UCF101", the hub model's own ``proj`` layer for "Kinetics").
    """

    # Domain names accepted by ``forward``.
    _DOMAINS = ("UCF101", "Kinetics")

    def __init__(self):
        super().__init__()
        model = torch.hub.load(
            'facebookresearch/pytorchvideo', "x3d_m", pretrained=True)
        self.model_num_features = model.blocks[5].proj.in_features
        self.ucf_num_class = 101
        self.kinetics_num_class = 400
        # Frame count handed to the adapters; their 1x1 conv is sized with it, so the
        # feature map reaching the adapter must have this many frames.
        self.num_frame = 16

        self.net_bottom = nn.Sequential(
            model.blocks[0],
            model.blocks[1],
            model.blocks[2],
            model.blocks[3],
        )

        self.blocks4 = model.blocks[4]

        # NOTE(review): 192 presumably matches the channel count after blocks[4] -- confirm.
        self.adapter1_ucf = Adapter3D(192, self.num_frame)
        self.adapter1_kinetics = Adapter3D(192, self.num_frame)

        self.net_top = nn.Sequential(
            model.blocks[5].pool,
            model.blocks[5].dropout
        )

        # The Kinetics classifier reuses the hub model's projection layer.
        self.linear_kinetics = model.blocks[5].proj
        self.linear_ucf = nn.Linear(self.model_num_features, self.ucf_num_class)

    def forward(self, x: torch.Tensor, domain) -> torch.Tensor:
        """Classify ``x`` with the adapter and classifier of ``domain``.

        Args:
            x (torch.Tensor): input passed to the first backbone block.
            domain (str): "UCF101" or "Kinetics".

        Returns:
            torch.Tensor: 2-D tensor whose last dimension is the number of classes
            of ``domain`` (101 or 400).

        Raises:
            ValueError: if ``domain`` is not one of the supported names. Previously
                an unknown domain silently returned unclassified features.
        """
        if domain not in self._DOMAINS:
            raise ValueError(
                "unknown domain {!r}; expected one of {}".format(domain, self._DOMAINS))

        if domain == "UCF101":
            adapter = self.adapter1_ucf
            classifier = self.linear_ucf
            num_class = self.ucf_num_class
        else:
            adapter = self.adapter1_kinetics
            classifier = self.linear_kinetics
            num_class = self.kinetics_num_class

        x = self.net_bottom(x)
        x = self.blocks4(x)
        x = adapter(x)
        x = self.net_top(x)

        # Move the channel axis last so the linear layer acts on the feature dimension.
        x = x.permute(0, 2, 3, 4, 1)
        x = classifier(x)
        return x.reshape(-1, num_class)




## 学習させてみる

In [8]:
def get_kinetics(subset):
    """
    Build the Kinetics400 dataset for one split.

    Only the "train" split uses the random augmentation (random short-side scale,
    random crop, random horizontal flip). Every other split, including "test",
    uses the deterministic short-side scale + centre crop. Previously "test"
    fell through to the training augmentation.

    Args:
        subset (str): "train" or "val" or "test"

    Returns:
        pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset: the constructed dataset
    """
    args = Args()

    if subset == "train":
        spatial_steps = [
            RandomShortSideScale(min_size=256, max_size=320,),
            RandomCrop(224),
            RandomHorizontalFlip(),
        ]
    else:
        spatial_steps = [
            ShortSideScale(256),
            CenterCrop(224),
        ]

    transform = Compose([
        ApplyTransformToKey(
            key="video",
            transform=Compose([
                UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
                # Scale pixel values by 1/255 before normalising.
                transforms.Lambda(lambda x: x / 255.),
                Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                *spatial_steps,
            ]),
        ),
        RemoveKey("audio"),
    ])

    root_kinetics = '/mnt/dataset/Kinetics400/'

    # The test split is described by a list file; the other splits are passed as
    # a path under the dataset root for both data_path and video_path_prefix.
    if subset == "test":
        data_path = root_kinetics + "test_list.txt"
        video_path_prefix = root_kinetics + 'test/'
    else:
        data_path = root_kinetics + subset
        video_path_prefix = root_kinetics + subset

    return Kinetics(
        data_path=data_path,
        video_path_prefix=video_path_prefix,
        clip_sampler=RandomClipSampler(
            clip_duration=args.kinetics_clip_duration),
        video_sampler=RandomSampler,
        decode_audio=False,
        transform=transform,
    )


In [9]:
def get_ucf101(subset):
    """
    Build the UCF101 dataset for one split.

    "train" selects the training list and the random augmentation; any other
    value selects the test list and the deterministic transform.

    Args:
        subset (str): "train" or "test"

    Returns:
        pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset: the constructed dataset
    """
    is_train = subset == "train"

    if is_train:
        subset_root_Ucf101 = 'ucfTrainTestlist/trainlist01.txt'
    else:
        subset_root_Ucf101 = 'ucfTrainTestlist/testlist.txt'

    args = Args()

    # Steps applied to the video tensor before the split-specific spatial steps.
    video_steps = [
        UniformTemporalSubsample(args.VIDEO_NUM_SUBSAMPLED),
        transforms.Lambda(lambda x: x / 255.),
        Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
    ]
    if is_train:
        video_steps += [
            RandomShortSideScale(min_size=256, max_size=320,),
            RandomCrop(224),
            RandomHorizontalFlip(),
        ]
    else:
        video_steps += [
            ShortSideScale(256),
            CenterCrop(224),
        ]

    video_transform = ApplyTransformToKey(
        key="video",
        transform=Compose(video_steps),
    )
    # Subtract 1 from every label.
    label_transform = ApplyTransformToKey(
        key="label",
        transform=transforms.Lambda(lambda x: x - 1),
    )
    transform = Compose([video_transform, label_transform, RemoveKey("audio")])

    root_ucf101 = '/mnt/dataset/UCF101/'

    return Ucf101(
        data_path=root_ucf101 + subset_root_Ucf101,
        video_path_prefix=root_ucf101 + 'video/',
        clip_sampler=RandomClipSampler(
            clip_duration=args.ucf101_clip_duration),
        video_sampler=RandomSampler,
        decode_audio=False,
        transform=transform,
    )


In [10]:
class LimitDataset(torch.utils.data.Dataset):
    """Map-style wrapper that serves items from an iterable dataset sequentially.

    ``__getitem__`` ignores the requested index and returns the next item of a
    single shared iterator over ``dataset``; ``__len__`` reports
    ``dataset.num_videos``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        # The same iterator object is chained twice, so once it is exhausted the
        # second pass yields nothing further.
        source = iter(dataset)
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(source, 2))

    def __getitem__(self, index):
        # `index` is intentionally unused: items come out in iteration order.
        item = next(self.dataset_iter)
        return item

    def __len__(self):
        return self.dataset.num_videos

In [11]:
def make_loader(dataset):
    """
    Wrap a dataset in ``LimitDataset`` and build a DataLoader for it.

    Batch size and worker count come from ``Args``; incomplete final batches are
    dropped and shuffling is enabled.

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset): dataset returned by one of the get_* functions

    Returns:
        torch.utils.data.DataLoader: the constructed data loader
    """
    cfg = Args()
    wrapped = LimitDataset(dataset)
    return DataLoader(
        wrapped,
        batch_size=cfg.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.NUM_WORKERS,
        drop_last=True,
    )

In [12]:
class AverageMeter(object):
    """
    Tracks the latest value and the running weighted average of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the average.

        A ``torch.Tensor`` value is converted to a Python number with ``.item()``.
        """
        latest = val.item() if isinstance(val, torch.Tensor) else val
        self.val = latest
        self.sum += latest * n
        self.count += n
        self.avg = self.sum / self.count


def top1(outputs, targets):
    """Return the fraction of rows of ``outputs`` whose max index along dim 1 equals ``targets``."""
    predicted = outputs.max(1)[1]
    num_correct = predicted.eq(targets).sum().item()
    return num_correct / outputs.size(0)



In [16]:
import os.path as osp
import shutil

def save_checkpoint(state, is_best, filename, best_model_file, dir_data_name):
    """Save ``state.state_dict()`` to disk, optionally copying it as the best model.

    Args:
        state: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        is_best (bool): when true, the saved file is also copied to ``best_model_file``.
        filename (str): checkpoint file name inside ``dir_data_name``.
        best_model_file (str): file name of the "best" copy inside ``dir_data_name``.
        dir_data_name (str): target directory; created if missing. An empty string
            means the current working directory.
    """
    file_path = osp.join(dir_data_name, filename)
    if dir_data_name:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(dir_data_name, exist_ok=True)
    torch.save(state.state_dict(), file_path)
    if is_best:
        shutil.copyfile(file_path, osp.join(dir_data_name, best_model_file))

In [17]:
def _run_validation(model, loader, domain, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    The caller is responsible for switching the model to eval mode beforehand.

    Returns:
        tuple: ``(loss_meter, acc_meter)`` AverageMeter objects holding the
        sample-weighted loss and top-1 accuracy over the loader.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    with torch.no_grad():
        for val_batch in loader:
            inputs = val_batch['video'].to(device)
            labels = val_batch['label'].to(device)
            bs = inputs.size(0)

            val_outputs = model(inputs, domain)
            loss = criterion(val_outputs, labels)

            loss_meter.update(loss, bs)
            acc_meter.update(top1(val_outputs, labels), bs)

    return loss_meter, acc_meter


def train():
    """Jointly train ``ReconstructNet3D`` on Kinetics and UCF101 and log to Comet.

    Each optimisation step accumulates the gradients of one Kinetics batch and one
    UCF101 batch before calling ``optimizer.step()``. The Kinetics loader drives
    the epoch; the UCF101 loader is restarted whenever it runs out of batches.
    Every 300 steps both validation loaders are evaluated in full.

    The Comet API key is read from the ``COMET_API_KEY`` environment variable
    instead of being hard-coded in the notebook. The experiment is always closed,
    even when training raises.
    """
    args = Args()
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    train_dataset_ucf = get_ucf101("train")
    val_dataset_ucf = get_ucf101("val")
    train_loader_ucf = make_loader(train_dataset_ucf)
    val_loader_ucf = make_loader(val_dataset_ucf)

    train_dataset_kinetics = get_kinetics("train")
    val_dataset_kinetics = get_kinetics("val")
    train_loader_kinetics = make_loader(train_dataset_kinetics)
    val_loader_kinetics = make_loader(val_dataset_kinetics)

    model = ReconstructNet3D()
    model = model.to(device)
    torch.backends.cudnn.benchmark = True

    lr = 0.001
    weight_decay = 5e-5
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    hyper_params = {
        "Dataset": "UCF101, Kinetics",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
        "adapter mode": "adapter3d",
        "Adapter": "adp:1",
    }

    # Never commit credentials: the key comes from the environment.
    experiment = Experiment(
        api_key=os.environ.get("COMET_API_KEY"),
        project_name="feeature-extract",
        workspace="kazukiomi",
    )

    try:
        experiment.add_tag('pytorch')
        experiment.log_parameters(hyper_params)

        num_epochs = args.NUM_EPOCH
        # Number of optimisation steps between two full validation passes.
        val_interval = 300
        ucf_batches = iter(train_loader_ucf)

        step = 0

        with tqdm(range(num_epochs)) as pbar_epoch:
            for epoch in pbar_epoch:
                pbar_epoch.set_description("[Epoch %d]" % (epoch))

                train_loss_kinetics = AverageMeter()
                train_acc_kinetics = AverageMeter()
                train_loss_ucf = AverageMeter()
                train_acc_ucf = AverageMeter()

                with tqdm(train_loader_kinetics,
                          total=len(train_loader_kinetics),
                          leave=True) as pbar_train_batch:

                    model.train()

                    for batch in pbar_train_batch:
                        pbar_train_batch.set_description(
                            "[Epoch :{}]".format(epoch))

                        optimizer.zero_grad()

                        # --- Kinetics half of the step ---
                        inputs = batch['video'].to(device)
                        labels = batch['label'].to(device)
                        bs_kinetics = inputs.size(0)

                        outputs = model(inputs, "Kinetics")
                        loss_kinetics = criterion(outputs, labels)
                        loss_kinetics.backward()
                        train_loss_kinetics.update(loss_kinetics, bs_kinetics)
                        train_acc_kinetics.update(top1(outputs, labels), bs_kinetics)

                        # --- UCF101 half of the step ---
                        # Restart the UCF101 loader when it is exhausted rather than
                        # relying on index arithmetic against len(loader).
                        try:
                            ucf_batch = next(ucf_batches)
                        except StopIteration:
                            ucf_batches = iter(train_loader_ucf)
                            ucf_batch = next(ucf_batches)

                        inputs = ucf_batch['video'].to(device)
                        labels = ucf_batch['label'].to(device)
                        bs_ucf = inputs.size(0)

                        outputs = model(inputs, "UCF101")
                        loss_ucf = criterion(outputs, labels)
                        loss_ucf.backward()

                        # Gradients of both losses have been accumulated by now.
                        optimizer.step()

                        train_loss_ucf.update(loss_ucf, bs_ucf)
                        train_acc_ucf.update(top1(outputs, labels), bs_ucf)

                        pbar_train_batch.set_postfix_str(
                            ' | batch_loss_kinetics={:6.04f} , batch_top1_kinetics={:6.04f}'
                            ' | batch_loss_ucf={:6.04f} , batch_top1_ucf={:6.04f}'
                            ''.format(
                                train_loss_kinetics.val, train_acc_kinetics.val,
                                train_loss_ucf.val, train_acc_ucf.val,
                            ))

                        experiment.log_metric(
                            "batch_accuracy_kinetics", train_acc_kinetics.val, step=step)
                        experiment.log_metric(
                            "batch_loss_kinetics", train_loss_kinetics.val, step=step)
                        experiment.log_metric(
                            "batch_accuracy_ucf", train_acc_ucf.val, step=step)
                        experiment.log_metric(
                            "batch_loss_ucf", train_loss_ucf.val, step=step)
                        experiment.log_metric(
                            "total_acc", (train_acc_kinetics.val+train_acc_ucf.val)/2, step=step)
                        experiment.log_metric(
                            "total_loss", (train_loss_kinetics.val+train_loss_ucf.val)/2, step=step)
                        step += 1

                        if step % val_interval == 0:
                            # Full validation pass on both domains.
                            model.eval()
                            val_loss_kinetics, val_acc_kinetics = _run_validation(
                                model, val_loader_kinetics, "Kinetics", criterion, device)
                            val_loss_ucf, val_acc_ucf = _run_validation(
                                model, val_loader_ucf, "UCF101", criterion, device)

                            experiment.log_metric(
                                "val_accuracy_kinetics", val_acc_kinetics.avg, step=step)
                            experiment.log_metric(
                                "val_loss_kinetics", val_loss_kinetics.avg, step=step)
                            experiment.log_metric(
                                "val_accuracy_ucf", val_acc_ucf.avg, step=step)
                            experiment.log_metric(
                                "val_loss_ucf", val_loss_ucf.avg, step=step)
                            experiment.log_metric(
                                "val_total_acc", (val_acc_kinetics.avg+val_acc_ucf.avg)/2, step=step)
                            experiment.log_metric(
                                "val_total_loss", (val_loss_kinetics.avg+val_loss_ucf.avg)/2, step=step)

                            # Back to training mode for the next step.
                            model.train()
    finally:
        # Close the Comet experiment even if training was interrupted.
        experiment.end()


In [18]:
# Run the joint Kinetics/UCF101 training defined above; progress is reported through
# the tqdm bars and the Comet experiment created inside train().
train()

Using cache found in /home/omi/.cache/torch/hub/facebookresearch_pytorchvideo_main
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/kazukiomi/feeature-extract/350193b0795545368c9ee5e75ecdd68d



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7464.0), HTML(value='')))

In [10]:
# Scratch check of the loader: "val" is not "train", so get_ucf101 selects its
# test list and the deterministic (centre-crop) transform.
test_dataset = get_ucf101("val")
test_loader = make_loader(test_dataset)

In [48]:
# Inspect the loader type and its reported number of batches.
print(type(test_loader))
print(len(test_loader))

# test_itr = iter(test_loader)
# Wrap the loader in enumerate so each fetched batch comes with its index.
test_enm = enumerate(test_loader)

<class 'torch.utils.data.dataloader.DataLoader'>
29


In [51]:
# test_itr.next()
# Start a fresh pass over the loader and fetch the first (index, batch) pair.
test_enm = enumerate(test_loader)

i, batch = test_enm.__next__()
print(i)

0


In [52]:
# Pull 28 more batches from the enumerate object created in the previous cell and
# print each batch index.
for j in range(28):
    i, batch = test_enm.__next__()
    print(i)

# for j in range(28):
#     batch = test_itr.next()
#     print(j)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28


In [29]:
# Re-create the enumerate wrapper to restart iteration over the loader, then fetch
# one batch and print its index.
test_enm = enumerate(test_loader)
i, batch = test_enm.__next__()
print(i)

# test_itr.reset()


In [None]:
def test_loader():
    args = Args()
    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    train_dataset_ucf = get_ucf101("train")
    val_dataset_ucf = get_ucf101("val")
    train_loader_ucf = make_loader(train_dataset_ucf)
    val_loader_ucf = make_loader(val_dataset_ucf)

    train_dataset_kinetics = get_kinetics("train")
    val_dataset_kinetics = get_kinetics("val")
    train_loader_kinetics = make_loader(train_dataset_kinetics)
    val_loader_kinetics = make_loader(val_dataset_kinetics)

    model = ReconstructNet()
    model = model.to(device)
    # model = torch.nn.DataParallel(model)
    torch.backends.cudnn.benchmark = True

    lr = 0.01
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    num_epochs = args.NUM_EPOCH

    with tqdm(range(num_epochs)) as pbar_epoch:
        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Epoch %d]" % (epoch))

            """Training mode"""

            train_loss = AverageMeter()
            train_acc = AverageMeter()

            with tqdm(enumerate(train_loader_ucf),
                      total=len(train_loader_ucf),
                      leave=True) as pbar_train_batch:

                model.train()

                for batch_idx, batch in pbar_train_batch:
                    pbar_train_batch.set_description(
                        "[Epoch :{}]".format(epoch))

                    inputs = batch['video'].to(device)
                    labels = batch['label'].to(device)