In [1]:
# PyTorch: print the installed version and whether CUDA is available
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# MMAction2: print the installed version
import mmaction
print(mmaction.__version__)

# MMCV: print what get_compiling_cuda_version / get_compiler_version report
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
for report in (get_compiling_cuda_version, get_compiler_version):
    print(report())

# Suppress every warning from this point on, then import the remaining
# utilities (the filter is installed before tqdm.auto is imported).
import warnings
warnings.filterwarnings('ignore')
import time
from tqdm.auto import tqdm

1.11.0+cu113 True
0.24.0
11.3
GCC 7.3


### Config

In [2]:
# Recognizer config file and pretrained weights used by the rest of the notebook:
# `config` is loaded with Config.fromfile, `checkpoint` is assigned to cfg.load_from.
# Note that `config` is a path relative to the notebook's working directory.
config = '../mmaction2/configs/recognition/tsm/tsm_my_config.py'
checkpoint = 'https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_sthv2_rgb/tsm_r50_256h_1x1x8_50e_sthv2_rgb_20210816-032aa4da.pth'

# Disabled alternative: TimeSformer config/checkpoint pair. Uncomment both lines to use it.
# config = '../mmaction2/configs/recognition/timesformer/timesformer_divST_8x32x1_15e_kinetics400_rgb.py'
# checkpoint = 'https://download.openmmlab.com/mmaction/recognition/timesformer/timesformer_divST_8x32x1_15e_kinetics400_rgb/timesformer_divST_8x32x1_15e_kinetics400_rgb-3f8e5d03.pth'


In [3]:
from mmcv import Config
from mmcv.runner import set_random_seed
import os.path as osp

# Load the base config selected in the previous cell; everything below
# overrides it for the binary squat dataset and single-GPU fine-tuning.
cfg = Config.fromfile(config)

# NOTE(review): machine-specific absolute path; change it when running elsewhere.
DATASET = '/home/umi/projects/WorkoutDetector/data/Binary/'
# Single seed used for both cfg.seed and set_random_seed.
SEED = 0

# Modify dataset type and path.
# NOTE(review): only the top-level `dataset_type` is changed here; the
# per-split `cfg.data.*.type` entries are left as defined in the config file.
cfg.dataset_type = 'RawframeDataset'
cfg.data_root = DATASET
cfg.data_root_val = DATASET
# osp.join keeps the paths valid whether or not DATASET ends with a slash.
cfg.ann_file_train = osp.join(DATASET, 'squat-train.txt')
cfg.ann_file_val = osp.join(DATASET, 'squat-val.txt')
cfg.ann_file_test = osp.join(DATASET, 'squat-test.txt')

cfg.data.test.ann_file = cfg.ann_file_test
cfg.data.test.data_prefix = DATASET

cfg.data.train.ann_file = cfg.ann_file_train
cfg.data.train.data_prefix = cfg.data_root

cfg.data.val.ann_file = cfg.ann_file_val
cfg.data.val.data_prefix = cfg.data_root_val

cfg.setdefault('omnisource', False)

# Binary task: two output classes.
cfg.model.cls_head.num_classes = 2

# Single-GPU adjustments: the per-GPU batch size is divided by 8 (never
# below 1) and the learning rate is divided by 8 * 16 = 128.
# NOTE(review): the original comment only mentioned dividing the LR by 8 for
# the 8-GPU -> 1-GPU change; the extra /16 is not explained by that alone,
# so confirm it is intentional.
cfg.data.videos_per_gpu =  max(1, cfg.data.videos_per_gpu // 8)
cfg.optimizer.lr = cfg.optimizer.lr / 8 / 16
cfg.total_epochs = 10
cfg.load_from = checkpoint
# cfg.resume_from = osp.join(cfg.work_dir, 'latest.pth')
cfg.checkpoint_config.interval = 5

cfg.seed = SEED
set_random_seed(SEED, deterministic=False)
cfg.gpu_ids = range(1)

# Save the best
cfg.evaluation.save_best='auto'

# Set the timestamped work_dir BEFORE building the logger hooks, so that the
# config snapshot handed to Weights & Biases records the directory this run
# actually writes to.
cfg.work_dir = f'./work_dirs/tsm_8_binary_squat_{time.strftime("%Y%m%d_%H%M")}'

cfg.log_config = dict(
    interval=50,
    hooks=[
        # dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
        # `config={**cfg}` is a snapshot of the config at this point.
        dict(type='WandbLoggerHook', init_kwargs=dict(project='Binary-squat', 
                                                      config={**cfg})),
])
# print(cfg.pretty_text)


### Train a new recognizer

In [4]:
import mmcv
from mmaction.apis import train_model
from mmaction.datasets import build_dataset
from mmaction.models import build_model

# Training dataset, wrapped in a list as it is passed to train_model
datasets = [build_dataset(cfg.data.train)]

# Recognizer built from the model section of the config; train/test settings
# are forwarded when the config defines them (cfg.get returns None otherwise)
model = build_model(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'),
)

# Ensure the output directory exists, then run non-distributed training
# with validation enabled
output_dir = osp.abspath(cfg.work_dir)
mmcv.mkdir_or_exist(output_dir)
train_model(model, datasets, cfg, distributed=False, validate=True)



2022-06-07 19:55:40,166 - mmaction - INFO - These parameters in pretrained checkpoint are not loaded: {'fc.weight', 'fc.bias'}


load checkpoint from torchvision path: torchvision://resnet50


KeyError: 'wandbLoggerHook is not in the hook registry'

### Test the trained recognizer

In [None]:
from mmaction.apis import single_gpu_test
from mmaction.datasets import build_dataloader
from mmcv.parallel import MMDataParallel

# Build a test dataloader: one video per batch, no shuffling, non-distributed
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
data_loader = build_dataloader(
        dataset,
        videos_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

# Wrap the model only once so that re-running this cell does not nest
# MMDataParallel wrappers around each other.
if not isinstance(model, MMDataParallel):
    model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader)

# Work on a copy: popping keys from cfg.evaluation itself would alter the
# shared config and make a second run of this cell fail with KeyError.
eval_config = dict(cfg.evaluation)
# Drop the options that configure the evaluation hook during training (same
# key list as MMAction2's tools/test.py) so only metric-related arguments
# reach dataset.evaluate. This includes the `save_best` set earlier.
for hook_only_key in ('interval', 'tmpdir', 'start', 'gpu_collect',
                      'save_best', 'rule', 'by_epoch', 'broadcast_bn_buffers'):
    eval_config.pop(hook_only_key, None)
eval_res = dataset.evaluate(outputs, **eval_config)

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 304/304, 12.2 task/s, elapsed: 25s, ETA:     0s
Evaluating top_k_accuracy ...

top1_acc	0.6743
top5_acc	1.0000

Evaluating mean_class_accuracy ...

mean_acc	0.6743
