# eval_single_pair.py
import os
import torch
from models import create_model_mixamo
from data_preprocess.Mixamo import create_dataset
import parser.parser_mixamo as option_parser
def eval_prepare(args):
    """Build the character groups and file ids for a single-pair evaluation.

    The character name of each BVH file is taken from its parent directory,
    i.e. the second-to-last '/'-separated component of the path.  The group
    that receives the input motion is chosen by whether the input character
    name ends with ``'_m'`` (group 1) or not (group 0).

    Args:
        args: namespace providing ``input_bvh`` and ``target_bvh`` (paths of
            the form ``.../<character>/<file>.bvh``) and ``test_type``
            (``'intra'`` or ``'cross'``).

    Returns:
        ``(character, file_id, src_id)``:
        ``character`` is a list of two lists of character names;
        ``file_id`` has the same nesting, holding ``args.input_bvh`` where the
        motion is loaded from the input file and ``0`` elsewhere;
        ``src_id`` is the index (0 or 1) of the group holding the input.

    Raises:
        ValueError: if ``test_type`` is neither ``'intra'`` nor ``'cross'``.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    input_name = args.input_bvh.split('/')[-2]
    target_name = args.target_bvh.split('/')[-2]
    input_in_group_1 = input_name.endswith('_m')

    if args.test_type == 'intra':
        # The input/target pair shares one group; the other group is filled
        # with a fixed character pair read with file id 0 (presumably a
        # placeholder clip -- TODO confirm against the dataset loader).
        pair = [input_name, target_name]
        if input_in_group_1:
            return ([['BigVegas', 'BigVegas'], pair],
                    [[0, 0], [args.input_bvh, args.input_bvh]],
                    1)
        return ([pair, ['Goblin_m', 'Goblin_m']],
                [[args.input_bvh, args.input_bvh], [0, 0]],
                0)

    if args.test_type == 'cross':
        # Input and target characters go into different groups; only the
        # input's slot reads the input file.
        if input_in_group_1:
            return [[target_name], [input_name]], [[0], [args.input_bvh]], 1
        return [[input_name], [target_name]], [[args.input_bvh], [0]], 0

    raise ValueError('Unknown test type: {!r}'.format(args.test_type))
def recover_space(file):
    """Turn every '_' in the final path component of *file* back into a space.

    Everything up to and including the last '/' is returned unchanged, so
    underscores in directory names (e.g. character folders) are preserved.
    """
    head, slash, basename = file.rpartition('/')
    return head + slash + basename.replace('_', ' ')
def main():
    """Retarget one BVH clip with a trained model and copy the result out.

    Command-line arguments added on top of ``option_parser.get_parser()``:
        --input_bvh, --target_bvh: ``.../<character>/<file>.bvh`` paths; '_'
            in the file-name part is turned back into a space.
        --test_type: ``'intra'`` or ``'cross'`` (see ``eval_prepare``).
        --output_filename: destination the generated BVH is copied to.
        --model_dir: directory containing ``para.txt`` with the model options.
        --epoch: checkpoint epoch passed to ``model.load``.
    """
    import shutil  # local import keeps this change self-contained

    parser = option_parser.get_parser()
    parser.add_argument('--input_bvh', type=str, required=True)
    parser.add_argument('--target_bvh', type=str, required=True)
    # choices= rejects a typo at parse time instead of failing later.
    parser.add_argument('--test_type', type=str, required=True,
                        choices=['intra', 'cross'])
    parser.add_argument('--output_filename', type=str, required=True)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--epoch', type=int, required=True)
    args = parser.parse_args()

    # Callers encode spaces in file names as '_'; undo that here.
    args.input_bvh = recover_space(args.input_bvh)
    args.target_bvh = recover_space(args.target_bvh)
    args.output_filename = recover_space(args.output_filename)

    character_names, file_id, src_id = eval_prepare(args)
    output_character_name = args.target_bvh.split('/')[-2]
    output_filename = args.output_filename
    test_device = args.cuda_device
    eval_seq = args.eval_seq
    epoch = args.epoch

    # Rebuild the model options from the first line of para.txt; its first
    # token is dropped (presumably the program name -- TODO confirm).
    para_path = os.path.join(args.model_dir, 'para.txt')
    with open(para_path, 'r') as para_file:
        argv_ = para_file.readline().split()[1:]
    args = option_parser.get_parser().parse_args(argv_)
    args.model = 'pan'
    args.cuda_device = test_device if torch.cuda.is_available() else 'cpu'
    args.is_train = False
    args.rotation = 'quaternion'
    args.eval_seq = eval_seq

    dataset = create_dataset(args, character_names)
    model = create_model_mixamo(args, character_names, dataset)
    model.load(epoch=epoch)

    # For every group: load each clip, add a leading axis, normalise with
    # that character's dataset.mean / dataset.var entries, and concatenate
    # the group's clips along dim 0.
    input_motion = []
    for i, character_group in enumerate(character_names):
        input_group = []
        for j in range(len(character_group)):
            new_motion = dataset.get_item(i, j, file_id[i][j])
            new_motion.unsqueeze_(0)
            new_motion = (new_motion - dataset.mean[i][j]) / dataset.var[i][j]
            input_group.append(new_motion)
        input_group = torch.cat(input_group, dim=0)
        input_motion.append([input_group, list(range(len(character_group)))])

    model.set_input(input_motion)
    model.test()

    # Copy with shutil rather than shelling out to `cp`: user-supplied names
    # are never interpreted by a shell, no POSIX `cp` is required, and a
    # failed copy raises instead of being silently ignored.
    result_path = os.path.join(model.bvh_path, output_character_name,
                               '0_{}.bvh'.format(src_id))
    shutil.copy(result_path, output_filename)
# Run the evaluation only when this file is executed as a script, so that
# importing it (e.g. to reuse eval_prepare) has no side effects.
if __name__ == '__main__':
    main()