Skip to content

Commit

Permalink
Updata to Pytorch0.4
Browse files Browse the repository at this point in the history
Add new model AAGCN
  • Loading branch information
lshiwjx committed Dec 17, 2019
1 parent f05d2c0 commit a3795f8
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 56 deletions.
23 changes: 18 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
# 2s-AGCN
Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition in CVPR19

# Note

~~PyTorch version should be 0.3! For PyTorch0.4 or higher, the codes need to be modified.~~ \
Now we have updated the code to >=Pytorch0.4. \
A new model named AAGCN is added, which can achieve better performance.

# Data Preparation

- Download the raw data from [NTU-RGB+D][https://github.com/shahroudy/NTURGB-D] and [Skeleton-Kinetics][https://github.com/yysijie/st-gcn]. Then put them under the data directory:
- Download the raw data from [NTU-RGB+D](https://github.com/shahroudy/NTURGB-D) and [Skeleton-Kinetics](https://github.com/yysijie/st-gcn). Then put them under the data directory:

-data\
-kinetics_raw\
Expand Down Expand Up @@ -54,11 +60,18 @@ Then combine the generated scores with:
Please cite the following paper if you use this repository in your reseach.

@inproceedings{2sagcn2019cvpr,
title = {Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition},
author = {Lei Shi and Yifan Zhang and Jian Cheng and Hanqing Lu},
booktitle = {CVPR},
year = {2019},
title = {Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition},
author = {Lei Shi and Yifan Zhang and Jian Cheng and Hanqing Lu},
booktitle = {CVPR},
year = {2019},
}

@article{shi_skeleton-based_2019,
title = {Skeleton-{Based} {Action} {Recognition} with {Multi}-{Stream} {Adaptive} {Graph} {Convolutional} {Networks}},
journal = {arXiv:1912.06971 [cs]},
author = {Shi, Lei and Zhang, Yifan and Cheng, Jian and LU, Hanqing},
month = dec,
year = {2019},
}
# Contact
For any questions, feel free to contact: `lei.shi@nlpr.ia.ac.cn`
107 changes: 68 additions & 39 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@
#!/usr/bin/env python
from __future__ import print_function

import argparse
import inspect
import os
import time
import numpy as np
import yaml
import pickle
import random
import shutil
import time
from collections import OrderedDict

import numpy as np
# torch
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm
from tensorboardX import SummaryWriter
import shutil
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
import random
import inspect
import torch.backends.cudnn as cudnn


class GradualWarmupScheduler(_LRScheduler):
def __init__(self, optimizer, total_epoch, after_scheduler=None):
self.total_epoch = total_epoch
self.after_scheduler = after_scheduler
self.finished = False
self.last_epoch = -1
super().__init__(optimizer)

def get_lr(self):
return [base_lr * (self.last_epoch + 1) / self.total_epoch for base_lr in self.base_lrs]

def step(self, epoch=None, metric=None):
if self.last_epoch >= self.total_epoch - 1:
if metric is None:
return self.after_scheduler.step(epoch)
else:
return self.after_scheduler.step(metric, epoch)
else:
return super(GradualWarmupScheduler, self).step(epoch)


def init_seed(_):
Expand Down Expand Up @@ -235,11 +258,14 @@ def load_model(self):
[[k.split('module.')[-1],
v.cuda(output_device)] for k, v in weights.items()])

keys = list(weights.keys())
for w in self.arg.ignore_weights:
if weights.pop(w, None) is not None:
self.print_log('Sucessfully Remove Weights: {}.'.format(w))
else:
self.print_log('Can Not Remove Weights: {}.'.format(w))
for key in keys:
if w in key:
if weights.pop(key, None) is not None:
self.print_log('Sucessfully Remove Weights: {}.'.format(key))
else:
self.print_log('Can Not Remove Weights: {}.'.format(key))

try:
self.model.load_state_dict(weights)
Expand Down Expand Up @@ -275,10 +301,12 @@ def load_optimizer(self):
else:
raise ValueError()

self.lr_scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.1,
patience=10, verbose=True,
threshold=1e-4, threshold_mode='rel',
cooldown=0)
lr_scheduler_pre = optim.lr_scheduler.MultiStepLR(
self.optimizer, milestones=self.arg.step, gamma=0.1)

self.lr_scheduler = GradualWarmupScheduler(self.optimizer, total_epoch=self.arg.warm_up_epoch,
after_scheduler=lr_scheduler_pre)
self.print_log('using warm up, epoch: {}'.format(self.arg.warm_up_epoch))

def save_arg(self):
# save arg
Expand Down Expand Up @@ -370,13 +398,13 @@ def train(self, epoch, save_model=False):
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
loss_value.append(loss.data[0])
loss_value.append(loss.data.item())
timer['model'] += self.split_time()

value, predict_label = torch.max(output.data, 1)
acc = torch.mean((predict_label == label.data).float())
self.train_writer.add_scalar('acc', acc, self.global_step)
self.train_writer.add_scalar('loss', loss.data[0], self.global_step)
self.train_writer.add_scalar('loss', loss.data.item(), self.global_step)
self.train_writer.add_scalar('loss_l1', l1, self.global_step)
# self.train_writer.add_scalar('batch_time', process.iterable.last_duration, self.global_step)

Expand Down Expand Up @@ -423,26 +451,27 @@ def eval(self, epoch, save_score=False, loader_name=['test'], wrong_file=None, r
step = 0
process = tqdm(self.data_loader[ln])
for batch_idx, (data, label, index) in enumerate(process):
data = Variable(
data.float().cuda(self.output_device),
requires_grad=False,
volatile=True)
label = Variable(
label.long().cuda(self.output_device),
requires_grad=False,
volatile=True)
output = self.model(data)
if isinstance(output, tuple):
output, l1 = output
l1 = l1.mean()
else:
l1 = 0
loss = self.loss(output, label)
score_frag.append(output.data.cpu().numpy())
loss_value.append(loss.data[0])

_, predict_label = torch.max(output.data, 1)
step += 1
with torch.no_grad():
data = Variable(
data.float().cuda(self.output_device),
requires_grad=False,
volatile=True)
label = Variable(
label.long().cuda(self.output_device),
requires_grad=False,
volatile=True)
output = self.model(data)
if isinstance(output, tuple):
output, l1 = output
l1 = l1.mean()
else:
l1 = 0
loss = self.loss(output, label)
score_frag.append(output.data.cpu().numpy())
loss_value.append(loss.data.item())

_, predict_label = torch.max(output.data, 1)
step += 1

if wrong_file is not None or result_file is not None:
predict = list(predict_label.cpu().numpy())
Expand Down
2 changes: 1 addition & 1 deletion model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import agcn
from . import agcn, aagcn
Loading

0 comments on commit a3795f8

Please sign in to comment.