In [None]:
import os, sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
sys.path.append(os.path.abspath(".."))
from cdbs.funcs import *
from cdbs.network import FlattenAnythingModel
# Input data and experiment outputs, both given relative to the current working directory.
data_root, expt_root = "../data", "../expt"

In [None]:
model_name = "open__human-hand"
cfg_name = "experiment-01"

# Fresh output folder for this (model, config) pair.
# WARNING: any results previously saved in this folder are deleted.
save_folder = os.path.join(expt_root, model_name + "__" + cfg_name)
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)
# makedirs (not mkdir) so that a missing expt_root is created as well instead of raising
# FileNotFoundError on a first run.
os.makedirs(save_folder)
# Training log. The folder was just (re)created empty, so no stale log file can exist here.
log_file = os.path.join(save_folder, "log.txt")

In [None]:
# Model, optimizer and learning-rate schedule.
#   max_lr / min_lr : range of the cosine-annealed learning rate
#   num_epc         : number of training iterations (one sampled point cloud per iteration)
#   opt_diff_itv    : the differential (conformality) loss term is added every opt_diff_itv iterations
net = FlattenAnythingModel().train().cuda()
max_lr, min_lr, num_epc, opt_diff_itv = 1e-3, 1e-5, 10000, 5
optimizer = optim.AdamW(net.parameters(), lr=max_lr, weight_decay=1e-8)
# Stepped once per training iteration, decaying from max_lr to min_lr over num_epc steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epc, eta_min=min_lr)
L1Loss, L2Loss = nn.L1Loss(), nn.MSELoss()
# Log roughly 100 times per run; clamp to 1 so short runs (num_epc < 100) never get a
# zero interval (a modulo by zero in the logging condition).
write_log_itv = max(1, num_epc // 100)

In [None]:
# Input mesh plus its dense pre-computed point samples. Each ';'-separated row holds
# xyz coordinates in columns 0:3 and a normal in columns 3:6.
mesh_path = os.path.join(data_root, "input_meshes", model_name + ".obj")
samples_path = os.path.join(data_root, "sampled_points", model_name + ".txt")

mesh_v, mesh_vn, mesh_f = load_mesh_model_vfn(mesh_path)
pre_samplings = np.loadtxt(samples_path, dtype="float32", delimiter=";")  # (num_pts_pres, 6)
num_pts_pres = pre_samplings.shape[0]

pre_samplings[:, 3:6] = rescale_normals(pre_samplings[:, 3:6], scale=1.0)
# Move to the GPU with a leading batch dimension.
pre_samplings = torch.tensor(pre_samplings).unsqueeze(0).float().cuda()  # [1, num_pts_pres, 6]

In [None]:
N = 10000 # number of input 3D points at each training iteration
grid_height, grid_width = int(np.sqrt(N)), int(np.sqrt(N))
G = torch.tensor(build_2d_grids(grid_height, grid_width).reshape(-1, 2)).unsqueeze(0).float().cuda() # [1, M, 2]
M = G.size(1) # number of grid 2D points

# Pre-collect num_col_samp subsets of N points each. Training then picks one subset per iteration.
# Each subset is farthest-point-sampled from a random base of up to 100000 pre-sampled points.
# The base size is clamped to num_pts_pres, because np.random.choice(..., replace=False)
# raises ValueError when asked for more points than exist.
num_col_samp = 20
num_pts_base = min(100000, num_pts_pres)
collected_samplings = []
for i in tqdm(range(num_col_samp)):
    sampling_base = pre_samplings[:, np.random.choice(num_pts_pres, num_pts_base, replace=False), :]
    collected_samplings.append(index_points(sampling_base, get_fps_idx(sampling_base[:, :, 0:3], N)))
collected_samplings = torch.cat(collected_samplings, dim=0) # [num_col_samp, N, 6]

In [None]:
#### loop-invariant setup (hoisted out of the training loop)
# Repulsion radius for the unwrapping loss: a quarter of the node spacing of a grid with
# ceil(sqrt(M)) nodes per side spanning an interval of length 2.
# NOTE(review): presumably that interval is the extent of the bounding-box-normalized UV domain;
# confirm against uv_bounding_box_normalization.
rep_th = (2 / (np.ceil(np.sqrt(M)) - 1)) * 0.25
# Dedicated generator for choosing which pre-collected sample to train on each iteration.
# It does not depend on numpy's global random state, so helpers that (re)seed that state cannot
# affect it. Pass a seed, e.g. np.random.default_rng(0), for a reproducible run.
sample_rng = np.random.default_rng()


def _show_uv_preview(uv, color):
    """Show a small inline scatter plot of a UV tensor (assumed [1, K, 2] -- confirm)."""
    uv_np = ts2np(uv.squeeze(0))
    plt.figure(figsize=(2.5, 2.5))
    plt.axis('off')
    plt.scatter(uv_np[:, 0], uv_np[:, 1], s=0.2, c=color)
    plt.show()


for epc_idx in tqdm(range(1, num_epc+1)):
    net.zero_grad()
    input_pc = collected_samplings[sample_rng.integers(num_col_samp), ...].unsqueeze(0) # [1, N, 6]
    P = input_pc[:, :, 0:3] # [1, N, 3], point coordinates at this training iteration
    P_gtn = input_pc[:, :, 3:6] # [1, N, 3], ground-truth point normals at this training iteration

    #### forward pass
    P_opened, Q, P_cycle, P_cycle_n, Q_hat, P_hat, P_hat_n, P_hat_opened, Q_hat_cycle = net(G, P)
    Q_normalized = uv_bounding_box_normalization(Q)
    Q_hat_normalized = uv_bounding_box_normalization(Q_hat)
    Q_hat_cycle_normalized = uv_bounding_box_normalization(Q_hat_cycle)

    #### wrapping loss
    L_wrap = chamfer_distance_cuda(P_hat, P)

    #### unwrapping loss
    L_unwrap = (
        compute_repulsion_loss(Q_normalized, 8, rep_th)
        + compute_repulsion_loss(Q_hat_normalized, 8, rep_th)
        + compute_repulsion_loss(Q_hat_cycle_normalized, 8, rep_th)
    )

    #### cycle consistency on points
    L_cc_p = L1Loss(P, P_cycle) + L1Loss(Q_hat, Q_hat_cycle)

    #### cycle consistency on normals
    L_cc_n = compute_normal_cos_sim_loss(P_gtn, P_cycle_n)

    #### cutting offset bound loss (currently disabled; kept for reference)
    # concat_P_opened = torch.cat(P_opened, dim=0) # [num_cuts, N, 3]
    # offset_lengths = ((concat_P_opened ** 2).sum(dim=-1) + 1e-8).sqrt() # [num_cuts, N]
    # max_offset_len = 0.25
    # L_cob = F.relu(offset_lengths - max_offset_len).mean()

    #### overall loss function
    loss = L_wrap + L_unwrap*0.01 + L_cc_p*0.01 + L_cc_n*0.005 # + L_cob*0.005
    if epc_idx == 1 or np.mod(epc_idx, opt_diff_itv) == 0:
        # Differential terms are comparatively costly, so they are evaluated only periodically.
        _, e1, e2 = compute_differential_properties(P_cycle, Q)
        L_conf = L1Loss(e1, e2)
        # L_isom is logged for monitoring only; it is not part of `loss`.
        L_isom = (e1 - 1).abs().mean() + (e2 - 1).abs().mean()
        loss = loss + L_conf*0.01

    curr_lr = optimizer.param_groups[0]["lr"]
    loss.backward()
    optimizer.step()
    scheduler.step()

    if np.mod(epc_idx, write_log_itv) == 0:
        # NOTE: L_conf / L_isom are recomputed only on differential-loss iterations, so the values
        # logged here are fresh only when write_log_itv is a multiple of opt_diff_itv
        # (true for the default settings); otherwise they come from an earlier iteration.
        info = {
            "epoch": align_number(epc_idx, 6),
            "L_wrap": "%.6f" % L_wrap.item(),
            "L_unwrap": "%.6f" % L_unwrap.item(),
            "L_cc_p": "%.6f" % L_cc_p.item(),
            "L_cc_n": "%.6f" % L_cc_n.item(),
            # "L_cob": "%.6f" % L_cob.item(),
            "L_conf": "%.6f" % L_conf.item(),
            "L_isom": "%.6f" % L_isom.item(),
        }
        write_to = "".join(info_k + ": " + info_v + " " for info_k, info_v in info.items())
        with open(log_file, "a") as f:
            f.write(write_to + "\n")

        _show_uv_preview(Q_hat_normalized, "r")
        _show_uv_preview(Q_hat_cycle_normalized, "g")
        # Color the UVs of the input points by their ground-truth normals mapped from [-1, 1] to [0, 1].
        _show_uv_preview(Q_normalized, (ts2np(P_gtn.squeeze(0)) + 1) / 2)

#### release training-time gradients / cached GPU memory before exporting
net.zero_grad()
torch.cuda.empty_cache()

#### save network parameters and outputs
torch.save(net.state_dict(), os.path.join(save_folder, "fam.pth"))


def _export_ply(file_name, points, **kwargs):
    """Write a point cloud (plus optional per-point attributes) into save_folder."""
    save_pc_as_ply(os.path.join(save_folder, file_name), points, **kwargs)


# Re-normalize predicted normals to unit length. All exported normals are scaled by 0.15
# (presumably to keep them short when visualized).
P_cycle_n = F.normalize(P_cycle_n, dim=-1)
P_hat_n = F.normalize(P_hat_n, dim=-1)
_export_ply("last_epoch__input_pc.ply", input_pc[0, :, 0:3], normals=input_pc[0, :, 3:6]*0.15)
_export_ply("last_epoch__P_cycle_with_n.ply", P_cycle[0], normals=P_cycle_n[0]*0.15)
_export_ply("last_epoch__P_hat_with_n.ply", P_hat[0], normals=P_hat_n[0]*0.15)
_export_ply("last_epoch__P_opened.ply", P_opened.squeeze(0))
_export_ply("last_epoch__P_hat_opened.ply", P_hat_opened.squeeze(0))


def _save_uv_figure(file_name, uv, color):
    """Save a large, axis-free scatter plot of a UV tensor (assumed [1, K, 2]) into save_folder."""
    uv_np = ts2np(uv.squeeze(0))
    plt.figure(figsize=(16, 16))
    plt.axis('off')
    plt.scatter(uv_np[:, 0], uv_np[:, 1], s=2.0, c=color)
    plt.savefig(os.path.join(save_folder, file_name), dpi=400, bbox_inches="tight")
    plt.close()


# Last-epoch UV layouts. The input-point UVs are colored by ground-truth normals mapped to [0, 1].
for uv_file_name, uv_tensor, uv_color in (
    ("Q_hat_normalized.png", Q_hat_normalized, "r"),
    ("Q_hat_cycle_normalized.png", Q_hat_cycle_normalized, "g"),
    ("Q_normalized.png", Q_normalized, (ts2np(P_gtn.squeeze(0)) + 1) / 2),
):
    _save_uv_figure(uv_file_name, uv_tensor, uv_color)

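## Evaluation

Run the trained network on the full set of pre-sampled points (not an FPS subset), then export the resulting point clouds, the detected edge points and a checker-textured point cloud.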

In [None]:
# Evaluate on every pre-sampled point rather than on an FPS subset.
input_pc_eval = pre_samplings  # [1, num_pts_pres, 6]
P_eval, P_gtn_eval = input_pc_eval[:, :, 0:3], input_pc_eval[:, :, 3:6]  # each [1, num_pts_pres, 3]

# Square 2D grid with (at most) as many nodes as there are evaluation points.
grid_height_eval = grid_width_eval = int(np.sqrt(num_pts_pres))
G_eval = (
    torch.tensor(build_2d_grids(grid_height_eval, grid_width_eval).reshape(-1, 2))
    .unsqueeze(0)
    .float()
    .cuda()
)  # [1, M_eval, 2]
M_eval = G_eval.size(1)

torch.cuda.empty_cache()
# NOTE(review): net is still in the .train() mode it was constructed with; if the model contains
# BatchNorm/Dropout layers, consider net.eval() here -- confirm against FlattenAnythingModel.
with torch.no_grad():
    (P_opened_eval, Q_eval, P_cycle_eval, P_cycle_n_eval, Q_hat_eval,
     P_hat_eval, P_hat_n_eval, P_hat_opened_eval, Q_hat_cycle_eval) = net(G_eval, P_eval)

Q_eval_normalized, Q_hat_eval_normalized, Q_hat_cycle_eval_normalized = (
    uv_bounding_box_normalization(uv) for uv in (Q_eval, Q_hat_eval, Q_hat_cycle_eval)
)

# Unit-length predicted normals; exported scaled by 0.15 (presumably for visualization length).
P_cycle_n_eval, P_hat_n_eval = (
    F.normalize(vec, dim=-1) for vec in (P_cycle_n_eval, P_hat_n_eval)
)  # each [1, num_pts_pres, 3]

for ply_name, ply_points, ply_normals in (
    ("P_cycle_eval_with_n.ply", P_cycle_eval, P_cycle_n_eval),
    ("P_hat_eval_with_n.ply", P_hat_eval, P_hat_n_eval),
):
    save_pc_as_ply(os.path.join(save_folder, ply_name), ply_points[0], normals=ply_normals[0]*0.15)

for ply_name, ply_points in (
    ("P_opened_eval.ply", P_opened_eval),
    ("P_hat_opened_eval.ply", P_hat_opened_eval),
):
    save_pc_as_ply(os.path.join(save_folder, ply_name), ply_points.squeeze(0))

# Flag candidate edge points from the normalized UVs with extract_edge_points.
is_edge_threshold = 0.02 # smaller threshold -> more edge points
edge_mask_eval = extract_edge_points(P_eval, Q_eval_normalized, 1, is_edge_threshold) # [num_pts_pres]

if edge_mask_eval.sum().item() != 0:
    P_eval_edge = P_eval[:, edge_mask_eval, :] # [1, num_pts_eval_edge, 3]
    num_pts_eval_edge = P_eval_edge.size(1)
    print('[{}] points judged to be on edges.'.format(num_pts_eval_edge))
    save_pc_as_ply(os.path.join(save_folder, "P_eval_edge.ply"), P_eval_edge.squeeze(0))
else:
    print('no edge found yet.')

# Color each evaluation point from a checker-map image (binarize=True), looked up via its normalized UV.
tex_img_path = os.path.join(data_root, "texture_images", "checker_maps", 'v1_r3.png')
tex_img_resolution = 1024
texture_uv, texture_rgb = load_texture_map(tex_img_path, tex_img_resolution, binarize=True)
# NOTE(review): texture_uv is matched against the 2D UVs below, so it is presumably
# [1, tex_img_resolution**2, 2] rather than [.., 3] -- confirm against load_texture_map.
texture_uv, texture_rgb = (
    torch.tensor(arr).unsqueeze(0).float().cuda() for arr in (texture_uv, texture_rgb)
)  # texture_rgb: [1, tex_img_resolution**2, 3]

# Presumably the index of the nearest texture sample (in UV space) for each point -- confirm knn_search arg order.
nearest_texel_idx = knn_search(texture_uv, Q_eval_normalized, 1).squeeze(-1)
matched_rgb = index_points(texture_rgb, nearest_texel_idx)  # [1, num_pts_pres, 3]
P_eval_textured = torch.cat((P_eval, matched_rgb), dim=-1)  # [1, num_pts_pres, 6]
save_pc_as_ply(
    os.path.join(save_folder, "P_eval_textured_with_n.ply"),
    P_eval_textured[0, :, 0:3],
    colors=P_eval_textured[0, :, 3:6],
    normals=P_gtn_eval[0] * 0.15,
)