In [162]:
import torch
import numpy as np

import plotly
import plotly.graph_objects as go
import numpy as np

# NOTE(review): this diverging RdBu heatmap centred on 0 is not referenced by any
# code in the visible cells; kept in case a later cell uses it.
heatmap = go.Heatmap(zmid=0, colorscale='RdBu', z=[-1, 0, 0])
def visualize(pos, faces, intensity=None):
  """Render a triangle mesh with Plotly's ``Mesh3d``, coloured by a per-vertex scalar.

  Args:
    pos: vertex positions of shape ``[n, 3]``. Accepts a ``torch.Tensor`` (any
      device / requires_grad; it is detached and copied to the CPU), a
      ``numpy.ndarray``, or any array-like that ``np.asarray`` accepts.
    faces: triangle vertex indices of shape ``[m, 3]``, with the same accepted types.
    intensity: optional per-vertex scalar with ``n`` values, mapped onto the colour
      scale. Defaults to all ones.

  Raises:
    ValueError: if ``pos`` or ``faces`` is not a 2-D ``[*, 3]`` array, or if
      ``intensity`` does not provide exactly one value per vertex.
  """
  def _to_numpy(value):
    # Tensors may live on the GPU or require grad: detach, move to the CPU and copy
    # so the resulting array never aliases the caller's tensor storage.
    if isinstance(value, torch.Tensor):
      return value.detach().cpu().clone().numpy()
    return np.asarray(value)

  pos = _to_numpy(pos)
  if pos.ndim != 2 or pos.shape[-1] != 3:
    raise ValueError("Vertices positions must have shape [n,3]")
  faces = _to_numpy(faces)
  if faces.ndim != 2 or faces.shape[-1] != 3:
    raise ValueError("Face indices must have shape [m,3]")
  if intensity is None:
    intensity = np.ones([pos.shape[0]])
  else:
    intensity = _to_numpy(intensity).reshape(-1)
    if intensity.shape[0] != pos.shape[0]:
      raise ValueError(
          "intensity must have one value per vertex ({} expected, got {})".format(
              pos.shape[0], intensity.shape[0]))

  # Input column 1 is plotted on Plotly's z axis and column 2 on its y axis.
  x, z, y = pos.T
  i, j, k = faces.T

  mesh = go.Mesh3d(x=x, y=y, z=z,
            color='lightpink',
            intensity=intensity,
            opacity=1,
            colorscale=[[0, 'gold'],[0.5, 'mediumturquoise'],[1, 'magenta']],
            i=i, j=j, k=k,
            showscale=True)
  # aspectmode="data" keeps the true proportions of the mesh.
  layout = go.Layout(scene=go.layout.Scene(aspectmode="data"))

  fig = go.Figure(data=[mesh],layout=layout)
  fig.update_layout(
      autosize=True,
      margin=dict(l=20, r=20, t=20, b=20))
  fig.show()

def visualize_pointcloud(pos, color, camera=None, figname="fig.png"):
    """Render a point cloud with Plotly's ``Scatter3d``, show it, and save it as an image.

    Args:
        pos: point positions of shape ``[n, 3]`` as a ``torch.Tensor`` or
            ``numpy.ndarray``. Input column 1 is plotted on Plotly's z axis and
            column 2 on its y axis.
        color: per-point scalar values (tensor, ndarray or array-like). They are
            negated before colouring, which reverses the direction of the ``RdBu``
            colour scale.
        camera: optional Plotly ``scene_camera`` dict. Defaults to an eye at
            ``(1.25, 1.25, 1.25)`` looking at the origin with ``z`` up.
        figname: output path passed to ``fig.write_image`` (rendered at 1600x1600 px).

    Raises:
        ValueError: if ``pos`` is neither a tensor nor an ndarray, or is not ``[n, 3]``.
    """
    # Validate and convert the inputs before building any Plotly objects.
    if isinstance(pos, torch.Tensor):
        pos = pos.detach().cpu().numpy()
    elif not isinstance(pos, np.ndarray):
        raise ValueError(
            "pos must be a torch.Tensor or numpy.ndarray, got {}".format(type(pos).__name__))
    if pos.ndim != 2 or pos.shape[-1] != 3:
        raise ValueError("pos must have shape [n,3], got {}".format(tuple(pos.shape)))
    x, z, y = pos.T

    if isinstance(color, torch.Tensor):
        color = color.detach().cpu().numpy()
    else:
        # Lists and other array-likes must become arrays so they can be negated.
        color = np.asarray(color)

    # Hide axes, grid and background so only the points are rendered.
    axis = dict(backgroundcolor="white",title="",gridcolor="white",showbackground=False,zerolinecolor="white",showticklabels=False)
    layout = go.Layout(scene=go.layout.Scene(aspectmode="data",xaxis=axis, yaxis=axis,zaxis=axis,bgcolor="white"))

    if camera is None:
        camera = dict(up=dict(x=0, y=0, z=1),
                  center=dict(x=0, y=0, z=0),
                  eye=dict(x=1.25, y=1.25, z=1.25))

    fig = go.Figure(data=[go.Scatter3d(
      x=x,
      y=y,
      z=z,
      mode='markers',
      marker=dict(
          size=8,
          color=-color,        # negated: reverses the RdBu mapping
          colorscale='RdBu',
          opacity=1
      ))], layout=layout)

    fig.update_layout(autosize=True,scene_camera=camera)
    fig.show()
    fig.write_image(figname,width=1600,height=1600)
    

def compare(pos1, faces1, pos2, faces2):
    """Show two meshes in one figure, colouring the second by per-vertex displacement.

    Both meshes are merged into a single vertex/face list and drawn with ``visualize``.
    The faces of the second mesh are offset by ``n`` so they index its own vertices.
    Vertices of the first mesh get intensity 0. Vertex ``k`` of the second mesh gets
    ``||pos1[k] - pos2[k]||_2``, so the meshes must be in vertex-to-vertex correspondence.

    Args:
        pos1, pos2: vertex-position tensors with identical shape (``[n, 3]``).
        faces1, faces2: integer face-index tensors (``[k, 3]``).

    Raises:
        ValueError: if ``pos1`` and ``pos2`` do not have the same shape.
    """
    if pos1.shape != pos2.shape:
        raise ValueError(
            "compare() needs corresponding meshes: pos1 has shape {} but pos2 has shape {}".format(
                tuple(pos1.shape), tuple(pos2.shape)))
    n, m = pos1.shape[0], pos2.shape[0]
    merged_pos = torch.cat([pos1, pos2], dim=0)
    # Shift the second mesh's face indices past the first mesh's vertices.
    merged_faces = torch.cat([faces1, faces2 + n], dim=0)
    color = torch.zeros([n + m], dtype=pos1.dtype, device=pos1.device)
    color[n:] = (pos1 - pos2).norm(p=2, dim=-1)
    visualize(merged_pos, merged_faces, color)
    
def read_obj(filename):
    """Load vertex positions and polygon faces from a Wavefront OBJ file.

    Only ``v`` (vertex) and ``f`` (face) records are used. Blank lines, comments and
    every other record type (``vn``, ``vt``, ``o``, ``g``, ``s``, ``usemtl``, ...) are
    skipped. Face entries may be written as ``v``, ``v/vt``, ``v//vn`` or ``v/vt/vn``;
    only the vertex index is kept. Negative (relative) indices are resolved against
    the vertices read so far.

    Args:
        filename: path to the ``.obj`` file.

    Returns:
        ``(vertices, faces)``. ``vertices`` is a float64 array with one row per ``v``
        record. ``faces`` is an int64 array of 0-based vertex indices with one row
        per ``f`` record.
    """
    vertices = []
    faces = []
    with open(filename, 'r') as file:
        for line in file:
            # split() tolerates repeated/leading whitespace; blank lines yield [].
            tokens = line.split()
            if not tokens:
                continue
            # Dispatch on the whole keyword, not its first character, so that
            # "vn"/"vt" records are not mistaken for vertices.
            keyword, values = tokens[0], tokens[1:]
            if keyword == "v":
                vertices.append([float(s) for s in values])
            elif keyword == "f":
                face = []
                for entry in values:
                    index = int(entry.split("/")[0])
                    if index < 0:
                        # Relative index: -1 refers to the most recent vertex.
                        index += len(vertices) + 1
                    face.append(index)
                faces.append(face)
            # Every other record type (comments included) is ignored.

    vertices = np.array(vertices, dtype=np.float64)
    # np.long was removed in NumPy 1.24; int64 is the explicit equivalent.
    faces = np.array(faces, dtype=np.int64)
    # OBJ indices are 1-based.
    return vertices, faces - 1




In [171]:
from os.path import join
import utils

# Directory holding the GeoA3 qualitative meshes (subject_<i>_*.obj); the rendered
# PNGs from GeoA3_ours_saveimg are written here too. Relative path: it resolves
# against the notebook's current working directory.
prefix = "../model_data/GeoA3_qualitative/"

def mc(x1, x2, f):
    """Square root of the absolute mean-curvature difference between two embeddings.

    Uses the first of the three values returned by the project helper
    ``utils.meancurvature(positions, faces)`` for each embedding. Presumably that is a
    per-vertex curvature tensor (TODO confirm against utils).
    """
    curvature_a, _, _ = utils.meancurvature(x1, f)
    curvature_b, _, _ = utils.meancurvature(x2, f)
    gap = curvature_a - curvature_b
    return gap.abs().sqrt()

def l2(x1, x2, f):
    """Euclidean distance between corresponding rows of ``x1`` and ``x2`` (over the last dim).

    ``f`` is unused; it is presumably kept so the signature matches ``mc``.
    """
    displacement = x1 - x2
    return displacement.norm(dim=-1, p=2)

def new_camera(c, r, x, z=1):
    """Build a Plotly ``scene_camera`` dict that looks at ``c`` from an orbit position.

    The eye sits at horizontal offset ``(r*sin(x), r*cos(x))`` from ``c``. At ``x=0``
    the eye is on the +y side, and at ``x=pi/2`` it is on the +x side. The eye height
    is the absolute value ``z``, and ``up`` is +z.

    Args:
        c: look-at point, a dict with keys ``"x"``, ``"y"`` and ``"z"``.
        r: orbit radius in the x/y plane.
        x: orbit angle in radians.
        z: absolute eye height (defaults to 1, the previously hard-coded value).

    Returns:
        dict with ``up``, ``center`` and ``eye`` entries. ``center`` is a copy of ``c``,
        so editing the returned camera never mutates the caller's point.
    """
    return dict(up=dict(x=0, y=0, z=1),
                center=dict(c),
                eye=dict(x=c["x"] + r * np.sin(x), y=c["y"] + r * np.cos(x), z=z))

# Default view: look at the origin from radius 2, orbit angle 1 rad.
center = {"x": 0, "y": 0, "z": 0}
camera = new_camera(center, 2, 1)

def GeoA3_ours_saveimg(i, camera: dict, title: str = ""):
    """Render and save the "ours" and "GeoA3" adversarial versions of subject ``i``.

    Reads ``subject_<i>_original.obj``, ``subject_<i>_adv_ours.obj`` and
    ``subject_<i>_adv_GeoA3.obj`` from ``prefix``. Each adversarial point cloud is shown
    with ``visualize_pointcloud``, coloured by its per-vertex L2 distance to the
    original, and saved as ``<prefix>/<title>_ours.png`` / ``<prefix>/<title>_GeoA3.png``.
    """
    def _load(tag):
        # One mesh of this subject, e.g. tag="adv_ours".
        return read_obj(join(prefix, "subject_{}_{}.obj".format(i, tag)))

    clean_vertices, clean_faces = _load("original")
    attacked = {"ours": _load("adv_ours")[0], "GeoA3": _load("adv_GeoA3")[0]}

    reference = torch.tensor(clean_vertices)
    face_index = torch.tensor(clean_faces, dtype=torch.long)
    attacked = {method: torch.tensor(verts) for method, verts in attacked.items()}

    for method, verts in attacked.items():
        out_path = join(prefix, title + "_" + method + ".png")
        visualize_pointcloud(verts,
                             color=l2(reference, verts, face_index),
                             camera=camera,
                             figname=out_path)



In [174]:
head_78 = {"x": 0.01, "y": 0.38, "z": 0.0}
foot_29 = {"x": -0.05, "y": 0, "z": -0.84}
hand_42 = {"x": 0.32, "y": -0.30, "z": 0.37}

# (subject, output title, look-at point, orbit radius, orbit angle, eye-height override)
shots = [
    (29, "feet_29", foot_29, 1, 0, -0.1),
    (78, "head_78", head_78, 2, 1, None),
    (42, "hand_42", hand_42, 0.75, 3, None),
]
for subject, shot_title, look_at, radius, angle, eye_z in shots:
    camera = new_camera(look_at, radius, angle)
    if eye_z is not None:
        # Lower the eye for the feet close-up.
        camera["eye"]["z"] = eye_z
    GeoA3_ours_saveimg(i=subject, camera=camera, title=shot_title)
