Update single image demo

hassony2 committed Jun 10, 2019
1 parent fb70baa commit fcb83c9
Showing 7 changed files with 113 additions and 92 deletions.
102 changes: 47 additions & 55 deletions image_demo.py
@@ -12,105 +12,97 @@

 from mano_train.netscripts.reload import reload_model
 from mano_train.visualize import displaymano
-from mano_train.demo.attention import AttentionHook
 from mano_train.demo.preprocess import prepare_input, preprocess_frame


 def forward_pass_3d(model, input_image, pred_obj=True):
     sample = {}
     sample[TransQueries.images] = input_image
     sample[BaseQueries.sides] = [args.hand_side]
     sample[TransQueries.joints3d] = input_image.new_ones((1, 21, 3)).float()
-    sample['root'] = 'wrist'
+    sample["root"] = "wrist"
     if pred_obj:
-        sample[TransQueries.objpoints3d] = input_image.new_ones((1, 600,
-                                                                 3)).float()
+        sample[TransQueries.objpoints3d] = input_image.new_ones(
+            (1, 600, 3)
+        ).float()
     _, results, _ = model.forward(sample, no_loss=True)

     return results


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        '--resume',
+        "--resume",
         type=str,
-        help='Path to checkpoint',
-        default='release_models/obman/checkpoint.pth.tar')
-    parser.add_argument('--debug', action='store_true')
-    parser.add_argument('--hand_side', default='left')
-    parser.add_argument('--pred_obj', action='store_true')
-    parser.add_argument('--image_path', help='Path to video')
+        help="Path to checkpoint",
+        default="release_models/obman/checkpoint.pth.tar",
+    )
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--hand_side", default="left")
+    parser.add_argument("--pred_obj", action="store_true")
+    parser.add_argument(
+        "--image_path",
+        help="Path to image",
+        default="readme_assets/images/can.jpg",
+    )
     parser.add_argument(
-        '--no_beta', action='store_true', help='Force shape to average')
+        "--no_beta", action="store_true", help="Force shape to average"
+    )
     parser.add_argument(
-        '--flip_left_right',
-        action='store_true',
-        help='Force shape to average')
+        "--flip_left_right", action="store_true", help="Force shape to average"
+    )
     args = parser.parse_args()
     argutils.print_args(args)

     checkpoint = os.path.dirname(args.resume)
-    with open(os.path.join(checkpoint, 'opt.pkl'), 'rb') as opt_f:
+    with open(os.path.join(checkpoint, "opt.pkl"), "rb") as opt_f:
         opts = pickle.load(opt_f)

     # Initialize network
     model = reload_model(args.resume, opts, no_beta=args.no_beta)

     model.eval()

-    print('Please input image of {} hand !'.format(args.hand_side))
+    print(
+        "Input image is processed flipped and unflipped "
+        "(as left and right hand), both outputs are displayed"
+    )

     # load faces of hand
-    with open('misc/mano/MANO_RIGHT.pkl', 'rb') as p_f:
-        mano_right_data = pickle.load(p_f, encoding='latin1')
-    faces = mano_right_data['f']
-
-    # Add attention map
-    attention_hand = AttentionHook(model.module.base_net)
-    if hasattr(model.module, 'atlas_base_net'):
-        attention_atlas = AttentionHook(model.module.atlas_base_net)
-        has_atlas_encoder = True
-    else:
-        has_atlas_encoder = False
+    with open("misc/mano/MANO_RIGHT.pkl", "rb") as p_f:
+        mano_right_data = pickle.load(p_f, encoding="latin1")
+    faces = mano_right_data["f"]

     fig = plt.figure(figsize=(4, 4))
     fig.clf()
     frame = cv2.imread(args.image_path)
     frame = preprocess_frame(frame)
-    input_image = prepare_input(frame)
-    cv2.imshow('input', frame)
-    blend_img_hand = attention_hand.blend_map(frame)
-    cv2.imshow('attention hand', blend_img_hand)
-    if has_atlas_encoder:
-        blend_img_atlas = attention_atlas.blend_map(frame)
-        cv2.imshow('attention atlas', blend_img_atlas)
+    cv2.imshow("input", frame)
     img = Image.fromarray(frame.copy())
     hand_crop = cv2.resize(np.array(img), (256, 256))

-    noflip_hand_image = prepare_input(
-        hand_crop, flip_left_right=False)
-    flip_hand_image = prepare_input(
-        hand_crop, flip_left_right=True)
+    noflip_hand_image = prepare_input(hand_crop, flip_left_right=False)
+    flip_hand_image = prepare_input(hand_crop, flip_left_right=True)
     noflip_output = forward_pass_3d(model, noflip_hand_image)
     flip_output = forward_pass_3d(model, flip_hand_image)
-    flip_verts = flip_output['verts'].cpu().detach().numpy()[0]
-    noflip_verts = noflip_output['verts'].cpu().detach().numpy()[0]
-    ax = fig.add_subplot(1, 2, 1, projection='3d')
+    flip_verts = flip_output["verts"].cpu().detach().numpy()[0]
+    noflip_verts = noflip_output["verts"].cpu().detach().numpy()[0]
+    ax = fig.add_subplot(1, 2, 1, projection="3d")
+    ax.title.set_text("unflipped input")
     displaymano.add_mesh(ax, flip_verts, faces, flip_x=True)
-    if 'objpoints3d' in flip_output:
-        objverts = flip_output['objpoints3d'].cpu().detach().numpy()[0]
+    if "objpoints3d" in flip_output:
+        objverts = flip_output["objpoints3d"].cpu().detach().numpy()[0]
         displaymano.add_mesh(
-            ax, objverts, flip_output['objfaces'], flip_x=True, c='r')
-    ax = fig.add_subplot(1, 2, 2, projection='3d')
+            ax, objverts, flip_output["objfaces"], flip_x=True, c="r"
+        )
+    ax = fig.add_subplot(1, 2, 2, projection="3d")
+    ax.title.set_text("flipped input")
     displaymano.add_mesh(ax, noflip_verts, faces, flip_x=True)
-    if 'objpoints3d' in noflip_output:
-        objverts = noflip_output['objpoints3d'].cpu().detach().numpy()[0]
+    if "objpoints3d" in noflip_output:
+        objverts = noflip_output["objpoints3d"].cpu().detach().numpy()[0]
         displaymano.add_mesh(
-            ax, objverts, noflip_output['objfaces'], flip_x=True, c='r')
+            ax, objverts, noflip_output["objfaces"], flip_x=True, c="r"
+        )
     plt.show()
-    buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8)
-    buf.shape = (w, h, 4)
-    # Captured right hand of user is seen as right (mirror effect)
-    cv2.imshow('pose estimation', cv2.flip(frame, 1))
-    cv2.imshow('mesh', buf)
-    cv2.waitKey(0)
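
For reference, with the new defaults the updated demo can be invoked from the repository root as, for example (all flags appear in the diff above; the paths assume the released ObMan checkpoint and the newly added sample image):

    python image_demo.py --resume release_models/obman/checkpoint.pth.tar --image_path readme_assets/images/can.jpg --pred_obj
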
68 changes: 41 additions & 27 deletions mano_train/netscripts/reload.py
@@ -18,11 +18,7 @@
 from mano_train.netscripts.get_datasets import get_dataset
 from mano_train.modelutils import modelio
 from mano_train.visualize import visualizemeshes
-from mano_train.networks.branches.emutils import batch_emd
-from mano_train.networks.branches.contactloss import (
-    compute_contact_loss,
-    batch_pairwise_dist,
-)
+from mano_train.networks.branches.contactloss import compute_contact_loss

 from handobjectdatasets.queries import BaseQueries, TransQueries

@@ -37,15 +33,21 @@ def save_obj(filename, verticies, faces):

 def get_opts(resume_checkpoint):
     if resume_checkpoint.endswith("tar"):
-        resume_checkpoint = os.path.join("/", *resume_checkpoint.split("/")[:-1])
+        resume_checkpoint = os.path.join(
+            "/", *resume_checkpoint.split("/")[:-1]
+        )
     opt_path = os.path.join(resume_checkpoint, "opt.pkl")
     with open(opt_path, "rb") as p_f:
         opts = pickle.load(p_f)
     return opts


 def reload_model(
-    model_path, checkpoint_opts, mano_root="misc/mano", ico_divisions=3, no_beta=False
+    model_path,
+    checkpoint_opts,
+    mano_root="misc/mano",
+    ico_divisions=3,
+    no_beta=False,
 ):
     if "absolute_lambda" not in checkpoint_opts:
         checkpoint_opts["absolute_lambda"] = 0
@@ -109,7 +111,9 @@ def reload_model(
         modelio.load_checkpoint(model, resume_path=model_path, strict=True)
     except RuntimeError:
         traceback.print_exc()
-        warnings.warn("Couldn' load model in strict mode, trying without strict")
+        warnings.warn(
+            "Couldn' load model in strict mode, trying without strict"
+        )
         modelio.load_checkpoint(model, resume_path=model_path, strict=False)
     return model

@@ -187,10 +191,14 @@ def get_samples_score_interval(
 ):
     assert (
         interval[0] >= 0 and interval[0] <= 1
-    ), "bounds of interval should be in [0, 1], got lower bound {}".format(interval[0])
+    ), "bounds of interval should be in [0, 1], got lower bound {}".format(
+        interval[0]
+    )
     assert (
         interval[1] >= 0 and interval[1] <= 1
-    ), "bounds of interval should be in [0, 1], got upper bound {}".format(interval[1])
+    ), "bounds of interval should be in [0, 1], got upper bound {}".format(
+        interval[1]
+    )
     assert (
         interval[0] < interval[1]
     ), "Lower bound {} should be lower then upper bound {}".format(
@@ -200,7 +208,9 @@ def get_samples_score_interval(
     upper_idx = math.ceil(interval[1] * len(sorted_losses))
     selected_losses = sorted_losses[lower_idx:upper_idx]

-    selected_samples = [sorted_samples[idx] for idx in range(lower_idx, upper_idx)]
+    selected_samples = [
+        sorted_samples[idx] for idx in range(lower_idx, upper_idx)
+    ]
     if reverse:
         selected_losses = list(reversed(selected_losses))
         selected_samples = list(reversed(selected_samples))
@@ -279,7 +289,6 @@ def show_meshes(
     save_root="/sequoia/data2/yhasson/code/mano_train/data/results",
     show_contacts=False,
     show_gt=True,
-    get_emd=False,
     show_losses=[
         "mano_verts3d",
         "mano_joints3d",
@@ -294,8 +303,6 @@ def show_meshes(
         loader(ConcatDataloader): dataloader
         model: trained neural network
     """
-    if get_emd:
-        show_losses.append("atlas_emd")
     renderers = []
     if isinstance(model, (list, tuple)):
         models = model
@@ -315,11 +322,6 @@ def show_meshes(
             _, results, losses = model.forward(
                 sample, no_loss=False, force_objects=force_objects
             )
-            if get_emd:
-                emd_vals = batch_emd(
-                    sample[TransQueries.objpoints3d], results["objpoints3d"]
-                )
-                losses["atlas_emd"] = np.mean(emd_vals)
             if model_idx == 0:
                 filter_losses = OrderedDict(
                     (loss_name, [loss_val.item()])
@@ -437,7 +439,9 @@ def render_mesh(
     )
     save_img_folder = os.path.join(save_model_root, "images")
     os.makedirs(save_img_folder, exist_ok=True)
-    save_img_path = os.path.join(save_img_folder, "{:08d}.png".format(sample_idx))
+    save_img_path = os.path.join(
+        save_img_folder, "{:08d}.png".format(sample_idx)
+    )
     img = (sample[TransQueries.images][0] + 0.5).permute(1, 2, 0)
     scipy.misc.toimage(img, cmin=0, cmax=1).save(save_img_path)
     scale = 0.001
@@ -447,11 +451,15 @@ def render_mesh(
     save_path = os.path.join(
         save_pkl_folder, "mesh_penetr_{:04d}.pkl".format(sample_idx)
     )
-    save_meshes_dict(save_path, pred_points, pred_faces, pred_verts, mano_faces)
+    save_meshes_dict(
+        save_path, pred_points, pred_faces, pred_verts, mano_faces
+    )
     save_obj_path = os.path.join(
         save_obj_folder, "{:08d}_obj.obj".format(sample_idx)
     )
-    obj_mesh = trimesh.load({"vertices": pred_points, "faces": pred_faces})
+    obj_mesh = trimesh.load(
+        {"vertices": pred_points, "faces": pred_faces}
+    )
     trimesh.repair.fix_normals(obj_mesh)
     obj_verts = np.array(obj_mesh.vertices)
     obj_faces = np.array(obj_mesh.faces)
@@ -477,7 +485,7 @@ def render_mesh(
     scene_children_base = [c, p3js.AmbientLight(intensity=0.4)]
     scene_children = scene_children_base + hand_obj_children
     if show_contacts:
-        missed_loss, penetr_loss, contact_infos, metrics = compute_contact_loss(
+        miss_loss, pen_loss, contact_infos, metrics = compute_contact_loss(
            torch.Tensor(pred_verts).unsqueeze(0).cuda(),
            mano_faces,
            torch.Tensor(pred_points).unsqueeze(0).cuda(),
@@ -494,8 +502,12 @@ def render_mesh(
        penetrating_close_verts = (
            all_close_matches[0][all_penetr_masks[0]].cpu().numpy()
        )
-        missed_verts = torch.Tensor(pred_verts)[all_missed_masks[0]].cpu().numpy()
-        missed_close_verts = all_close_matches[0][all_missed_masks[0]].cpu().numpy()
+        missed_verts = (
+            torch.Tensor(pred_verts)[all_missed_masks[0]].cpu().numpy()
+        )
+        missed_close_verts = (
+            all_close_matches[0][all_missed_masks[0]].cpu().numpy()
+        )

        attraction_lines_children = visualizemeshes.lines_children(
            missed_verts, missed_close_verts, color="green"
@@ -504,7 +516,9 @@ def render_mesh(
            penetrating_verts, penetrating_close_verts, color="orange"
        )
        scene_children = (
-            scene_children + attraction_lines_children + repulsion_lines_children
+            scene_children
+            + attraction_lines_children
+            + repulsion_lines_children
        )
    if show_gt:
        if TransQueries.joints3d in sample:
@@ -522,7 +536,7 @@ def render_mesh(
        renderer = p3js.Renderer(
            camera=c, scene=scene, controls=[controls], width=400, height=400
        )
-        display(renderer)
+        # display(renderer)
    else:
        renderer = None
    return renderer
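
For reference, a minimal sketch of how get_opts and reload_model fit together, mirroring the pattern used in image_demo.py above (the checkpoint path is an assumed example):

    from mano_train.netscripts.reload import get_opts, reload_model

    resume = "release_models/obman/checkpoint.pth.tar"  # assumed checkpoint location
    opts = get_opts(resume)  # reads opt.pkl from the folder containing the checkpoint
    model = reload_model(resume, opts, no_beta=False)
    model.eval()
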
12 changes: 4 additions & 8 deletions mano_train/netscripts/simulate.py
@@ -18,16 +18,12 @@ def full_simul(
     sample_step=1,
     workers=8,
     cluster=False,
+    vhacd_exe=None,
 ):
-    if cluster:
-        vhacd_exe = "/sequoia/data1/yhasson/tools/"
-        "v-hacd/build/linux/test/testVHACD"
-    else:
-        vhacd_exe = (
-            "/sequoia/data1/yhasson/code/pose_3d/mano_train/thirdparty/"
-        )
-        "v-hacd/build/linux/test/testVHACD"
     assert os.path.exists(exp_id), "{} does not exists!".format(exp_id)
+    assert os.path.exists(vhacd_exe), (
+        f"VHACD executable {vhacd_exe}" "does not exists!"
+    )
     save_pickles = sorted(
         [
             os.path.join(exp_id, filename)
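
With this change the caller must supply the V-HACD binary explicitly instead of relying on the removed hard-coded paths. A hedged sketch (full_simul's leading arguments beyond exp_id are not visible in this hunk, so only the relevant keyword is shown; both paths are assumed examples):

    from mano_train.netscripts.simulate import full_simul

    full_simul(
        "path/to/experiment",  # exp_id: assumed folder containing the saved pickles
        vhacd_exe="thirdparty/v-hacd/build/linux/test/testVHACD",  # assumed local V-HACD build
    )
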
20 changes: 18 additions & 2 deletions mano_train/options/simulopts.py
@@ -1,4 +1,7 @@
 def add_simul_opts(parser):
+    """
+    Options for the physical simulation of object-in-hand stability
+    """
     parser.add_argument(
         "--wait_time", default=0, type=float, help="Wait time for simulation"
     )
@@ -7,12 +10,25 @@ def add_simul_opts(parser):
         "--batch_step", default=1, type=int, help="Step between batches"
     )
     parser.add_argument(
-        "--sample_step", default=1, type=int, help="Step between samples in batch"
+        "--sample_step",
+        default=1,
+        type=int,
+        help="Step between samples in batch",
     )
     parser.add_argument(
         "--workers", default=8, type=int, help="Step between samples in batch"
     )
     parser.add_argument(
-        "--sample_vis_freq", default=100, type=int, help="Step between samples in batch"
+        "--sample_vis_freq",
+        default=100,
+        type=int,
+        help="Step between samples in batch",
     )
     parser.add_argument("--cluster", action="store_true")
+    parser.add_argument(
+        "--vhacd_exe",
+        default=(
+            "/sequoia/data1/yhasson/tools/v-hacd/" "build/linux/test/testVHACD"
+        ),
+        help="Path to VHACD executable",
+    )
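
A minimal sketch of how these options are consumed (assumed driver code, not part of this diff): add_simul_opts extends an existing parser, and the parsed args.vhacd_exe can then be forwarded to full_simul above.

    import argparse

    from mano_train.options.simulopts import add_simul_opts

    parser = argparse.ArgumentParser()
    add_simul_opts(parser)
    args = parser.parse_args()
    print(args.vhacd_exe)  # defaults to the /sequoia/... path unless overridden
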
2 changes: 2 additions & 0 deletions mano_train/simulation/simulate.py
@@ -206,6 +206,8 @@ def run_simulation(
                 "Cannot compute convex hull "
                 "decomposition for {}".format(obj_tmp_fname)
             )
+        else:
+            print(f"Succeeded vhacd decomp of {obj_tmp_fname}")

         obj_collision_id = p.createCollisionShape(
             p.GEOM_MESH, fileName=obj_tmp_fname, physicsClientId=conn_id
Binary file added readme_assets/images/can.jpg