In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys
import os
import torch
import torchvision.transforms as transforms

from visualize import update_config, add_path

# Register the project's `lib` directory through add_path (imported from
# visualize above). Presumably this extends sys.path so that the `dataset`,
# `config` and `models` imports that follow resolve -- confirm in visualize.py.
# The path is relative, so the notebook has to be started from the repo root.
lib_path = osp.join('lib')
add_path(lib_path)

import dataset as dataset
from config import cfg
import models
import os
import torchvision.transforms as T

# Pin this process to one physical GPU. The variable is only honoured if it is
# set before the first CUDA call, so keep this line ahead of any device work.
os.environ['CUDA_VISIBLE_DEVICES'] ='3'

file_name = 'experiments/TP_H_w48_256x192_stage3_1_4_d96_h192_relu_enc6_mh1.yaml' # choose a yaml file
# Fail early with a clear message when the experiment file is missing. This
# replaces a bare open() whose handle was never read nor closed.
if not osp.isfile(file_name):
    raise FileNotFoundError('experiment config not found: {}'.format(file_name))
update_config(cfg, file_name)

model_name = 'T-H-A6'
assert model_name in ['T-R', 'T-H','T-H-L','T-R-A4', 'T-H-A6', 'T-H-A5', 'T-H-A4' ,'T-R-A4-DirectAttention']

# Per-channel mean/std applied after ToTensor.
normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

# Look the dataset class up by name on the `dataset` module (attribute lookup
# instead of eval on a config-derived string). From here on the name `dataset`
# refers to the dataset *instance*; the inference cell indexes it directly.
# `T` is used rather than `transforms` because a later cell rebinds
# `transforms` to utils.transforms.
dataset_cls = getattr(dataset, cfg.DATASET.DATASET)
dataset = dataset_cls(
    cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
    T.Compose([
      T.ToTensor(),
      normalize,
    ])
  )


# Prefer the GPU, but let the notebook still run on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): is_train=True is kept from the original; what it toggles inside
# get_pose_net is not visible here -- confirm it is wanted for inference.
model = getattr(models, cfg.MODEL.NAME).get_pose_net(
    cfg, is_train=True
)

if cfg.TEST.MODEL_FILE:
    print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
    # Deserialize onto the CPU first so a checkpoint saved from a different
    # CUDA index still loads when only one device is visible; the model is
    # moved to `device` right below.
    state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
else:
    raise ValueError("please choose one ckpt in cfg.TEST.MODEL_FILE")

model.to(device)
print("model params:{:.3f}M".format(sum(p.numel() for p in model.parameters())/1000**2))

In [None]:
import numpy as np
from core.inference import get_final_preds
from utils import transforms, vis
import cv2


# Scale factor used to map heatmap coordinates back to input-image pixels
# (the value 4 comes from the original notebook).
HEATMAP_STRIDE = 4

with torch.no_grad():
    model.eval()

    # First sample of the dataset; element 0 of the sample is used as the image.
    img = dataset[0][0]

    # Add the batch dimension the model expects.
    inputs = img.to(device).unsqueeze(0)
    outputs = model(inputs)
    # When the model returns a list, only its last element is used.
    output = outputs[-1] if isinstance(outputs, list) else outputs

    # clone() keeps get_final_preds on its own buffer even when `output`
    # already lives on the CPU (in which case .cpu() would not copy).
    preds, maxvals = get_final_preds(
            cfg, output.clone().cpu().numpy(), None, None, transform_back=False)

# from heatmap_coord to original_image_coord
query_locations = np.asarray(preds[0]) * HEATMAP_STRIDE + 0.5

In [None]:
import math
import torchvision

# (last joint index of the group, colour) pairs; colour values are passed to
# cv2.circle unchanged. Indices above the last group use the default colour.
_JOINT_COLOR_GROUPS = (
    (4, (28, 1, 255)),
    (8, (22, 128, 0)),
    (12, (62, 255, 255)),
    (16, (254, 255, 0)),
    (20, (203, 192, 255)),
)
_DEFAULT_JOINT_COLOR = (28, 1, 255)


def _joint_color(idx):
    '''Return the drawing colour for the joint at position idx.'''
    for last_idx, color in _JOINT_COLOR_GROUPS:
        if idx <= last_idx:
            return color
    return _DEFAULT_JOINT_COLOR


def save_batch_image_with_joints(batch_image, batch_joints, file_name, nrow=8, padding=2):
    '''
    Tile the images into one grid, mark every joint on it and write the result
    to file_name.

    batch_image: tensor [batch_size, channel, height, width]; a single
        [channel, height, width] image is accepted and treated as a batch of one.
    batch_joints: array-like [batch_size, num_joints, >=2] holding (x, y, ...)
        pixel coordinates relative to each image; a [num_joints, >=2] array is
        accepted for a single image. The caller's data is never modified.
    file_name: output path handed to cv2.imwrite.
    nrow: images per grid row, forwarded to make_grid.
    padding: border width in pixels between tiles, forwarded to make_grid.

    Raises ValueError when the number of joint sets differs from the number of
    images.
    '''
    if batch_image.dim() == 3:
        batch_image = batch_image.unsqueeze(0)

    # Work on a private float copy: the offsets added below must not leak back
    # into the caller's array (that made repeated calls drift).
    if hasattr(batch_joints, 'detach'):
        batch_joints = batch_joints.detach().cpu().numpy()
    joints_per_image = np.array(batch_joints, dtype=float)
    if joints_per_image.ndim == 2:
        joints_per_image = joints_per_image[np.newaxis]

    nmaps = batch_image.size(0)
    if joints_per_image.shape[0] != nmaps:
        raise ValueError(
            'got {} joint sets for {} images'.format(joints_per_image.shape[0], nmaps))

    grid = torchvision.utils.make_grid(batch_image, nrow, padding, True)

    # CHW values scaled to 0..255 -> HWC uint8; copy() yields a contiguous,
    # writable array for OpenCV to draw on.
    # NOTE(review): cv2.imwrite treats the channels as BGR -- confirm the
    # channel order of the incoming images.
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    ndarr = ndarr.copy()

    xmaps = min(nrow, nmaps)
    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)
    # make_grid hands a lone image back without any border, so only shift by
    # `padding` when the grid really is larger than a single image.
    border = padding if grid.size(1) > batch_image.size(2) else 0

    for k in range(nmaps):
        # Row / column of image k inside the grid.
        y, x = divmod(k, xmaps)
        for idx, joint in enumerate(joints_per_image[k]):
            # Shift by the tile origin so that every image in the grid shows
            # its own ground-truth coordinates.
            px = x * width + border + joint[0]
            py = y * height + border + joint[1]
            cv2.circle(ndarr, (int(px), int(py)), 1, _joint_color(idx), 2)
    cv2.imwrite(file_name, ndarr)

In [None]:
filename = '{}_gt.jpg'.format('demo')

# Hand over a copy so the drawing helper cannot alter query_locations in place;
# this keeps the cell safe to re-run and the array valid for later cells.
save_batch_image_with_joints(img, query_locations.copy(), filename)