In [None]:
import sys
import os
import glob
import pathlib
import torch

import numpy as np
import matplotlib.pyplot as plt
import plotly.io as pio
pio.renderers.default = "vscode"
import plotly.graph_objects as go
from omegaconf import OmegaConf

from equivariant_pose_graph.dataset.point_cloud_data_module import MultiviewDataModule
from equivariant_pose_graph.utils.se3 import random_se3
from equivariant_pose_graph.utils.load_model_utils import load_merged_model
from equivariant_pose_graph.utils.visualizations import plot_all_predictions
from equivariant_pose_graph.training.flow_equivariance_training_module_nocentering_multimodal import EquivarianceTrainingModule, EquivarianceTrainingModule_WithPZCondX
from equivariant_pose_graph.dataset.point_cloud_data_module import MultiviewDataModule
from equivariant_pose_graph.models.transformer_flow import ResidualFlow_DiffEmbTransformer
from equivariant_pose_graph.models.multimodal_transformer_flow import Multimodal_ResidualFlow_DiffEmbTransformer, Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX

# Make GPU 0 the current CUDA device, so the bare `.cuda()` calls made later in
# this notebook (e.g. when building the model) place modules/tensors on it.
torch.cuda.set_device(0)

def toDisplay(x, target_dim=2):
    """Reduce a tensor to ``target_dim`` dims and return it as a CPU NumPy array.

    Extra leading dimensions are removed by repeatedly taking element 0
    (e.g. a ``(1, N, 3)`` batch becomes ``(N, 3)``). The result is detached
    from autograd before conversion.

    Args:
        x: torch tensor to convert.
        target_dim: number of dimensions to keep.

    Returns:
        ``numpy.ndarray`` with ``min(x.dim(), target_dim)`` dimensions.
    """
    reduced = x
    # One `[0]` per surplus leading dimension.
    for _ in range(max(x.dim() - target_dim, 0)):
        reduced = reduced[0]
    return reduced.detach().cpu().numpy()


def plot_multi_np(plist):
    """Show several point clouds together in one interactive Plotly 3D scatter.

    Args:
        plist: list of NumPy arrays. Each may be ``(num_points, 3+)`` or carry
            extra leading dims such as ``(1, num_points, 3+)``; leading dims
            are stripped by taking index 0 (via ``toDisplay``), and only the
            first three columns are plotted as x, y, z. Cloud ``i`` is drawn
            in ``colors[i % len(colors)]``, so more than ten clouds reuse
            colours instead of raising ``IndexError``.

    Returns:
        The ``plotly.graph_objects.Figure``; it is also displayed via ``fig.show()``.
    """
    colors = [
        '#1f77b4',  # muted blue
        '#ff7f0e',  # safety orange
        '#2ca02c',  # cooked asparagus green
        '#d62728',  # brick red
        '#9467bd',  # muted purple
        '#e377c2',  # raspberry yogurt pink
        '#8c564b',  # chestnut brown
        '#7f7f7f',  # middle gray
        '#bcbd22',  # curry yellow-green
        '#17becf'   # blue-teal
    ]
    skip = 1  # plot every `skip`-th point; raise this to thin out dense clouds
    go_data = []
    for i, cloud in enumerate(plist):
        p_dp = toDisplay(torch.from_numpy(cloud))
        go_data.append(go.Scatter3d(
            x=p_dp[::skip, 0], y=p_dp[::skip, 1], z=p_dp[::skip, 2],
            mode='markers',
            marker=dict(size=2, color=colors[i % len(colors)], symbol='circle'),
        ))

    # 'data' aspect mode scales the axes in proportion to the data ranges, so
    # the clouds are not visually stretched.
    layout = go.Layout(scene=dict(aspectmode='data'))

    fig = go.Figure(data=go_data, layout=layout)
    fig.show()
    return fig

def get_dm(cfg):
    """Build a ``MultiviewDataModule`` from ``cfg`` and call its ``setup()``.

    This is a direct field-by-field mapping from the config onto the data
    module's constructor. The only transformations applied are:

    * ``rotation_variance``, ``demo_mod_rot_var`` and ``conval_rotation_variance``
      are multiplied by ``pi/180`` (i.e. the config values are treated as
      degrees and passed on in radians);
    * ``demo_mod_k_range`` is assembled as
      ``[cfg.demo_mod_k_range_min, cfg.demo_mod_k_range_max]``.

    Args:
        cfg: OmegaConf config providing every field read below.

    Returns:
        The data module, after ``dm.setup()`` has been called (with no stage argument).
    """
    dm = MultiviewDataModule(
        dataset_root=cfg.dataset_root,
        test_dataset_root=cfg.test_dataset_root,
        dataset_index=cfg.dataset_index,
        action_class=cfg.action_class,
        anchor_class=cfg.anchor_class,
        dataset_size=cfg.dataset_size,
        rotation_variance=np.pi/180 * cfg.rotation_variance,
        translation_variance=cfg.translation_variance,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        cloud_type=cfg.cloud_type,
        num_points=cfg.num_points,
        overfit=cfg.overfit,
        overfit_distractor_aug=cfg.overfit_distractor_aug,
        num_overfit_transforms=cfg.num_overfit_transforms,
        seed_overfit_transforms=cfg.seed_overfit_transforms,
        set_Y_transform_to_identity=cfg.set_Y_transform_to_identity,
        set_Y_transform_to_overfit=cfg.set_Y_transform_to_overfit,
        num_demo=cfg.num_demo,
        synthetic_occlusion=cfg.synthetic_occlusion,
        ball_radius=cfg.ball_radius,
        plane_standoff=cfg.plane_standoff,
        bottom_surface_z_clipping_height=cfg.bottom_surface_z_clipping_height,
        scale_point_clouds=cfg.scale_point_clouds,
        scale_point_clouds_min=cfg.scale_point_clouds_min,
        scale_point_clouds_max=cfg.scale_point_clouds_max,
        distractor_anchor_aug=cfg.distractor_anchor_aug,
        demo_mod_k_range=[cfg.demo_mod_k_range_min, cfg.demo_mod_k_range_max],
        demo_mod_rot_var=cfg.demo_mod_rot_var * np.pi/180,
        demo_mod_trans_var=cfg.demo_mod_trans_var,
        multimodal_transform_base=cfg.multimodal_transform_base,
        action_rot_sample_method=cfg.action_rot_sample_method,
        anchor_rot_sample_method=cfg.anchor_rot_sample_method,
        distractor_rot_sample_method=cfg.distractor_rot_sample_method,
        skip_failed_occlusion=cfg.skip_failed_occlusion,
        min_num_cameras=cfg.min_num_cameras,
        max_num_cameras=cfg.max_num_cameras,
        use_consistent_validation_set=cfg.use_consistent_validation_set,
        use_all_validation_sets=cfg.use_all_validation_sets,
        conval_rotation_variance=np.pi/180 * cfg.conval_rotation_variance,
        conval_translation_variance=cfg.conval_translation_variance,
        conval_synthetic_occlusion=cfg.conval_synthetic_occlusion,
        conval_scale_point_clouds=cfg.conval_scale_point_clouds,
        conval_action_rot_sample_method=cfg.conval_action_rot_sample_method,
        conval_anchor_rot_sample_method=cfg.conval_anchor_rot_sample_method,
        conval_distractor_rot_sample_method=cfg.conval_distractor_rot_sample_method,
        conval_min_num_cameras=cfg.conval_min_num_cameras,
        conval_max_num_cameras=cfg.conval_max_num_cameras,
        conval_downsample_type=cfg.conval_downsample_type,
        conval_gaussian_noise_mu=cfg.conval_gaussian_noise_mu,
        conval_gaussian_noise_std=cfg.conval_gaussian_noise_std,
        use_class_labels=cfg.use_class_labels,
        action_occlusion_class=cfg.action_occlusion_class,
        action_plane_occlusion=cfg.action_plane_occlusion,
        action_ball_occlusion=cfg.action_ball_occlusion,
        action_bottom_surface_occlusion=cfg.action_bottom_surface_occlusion,
        anchor_occlusion_class=cfg.anchor_occlusion_class,
        anchor_plane_occlusion=cfg.anchor_plane_occlusion,
        anchor_ball_occlusion=cfg.anchor_ball_occlusion,
        anchor_bottom_surface_occlusion=cfg.anchor_bottom_surface_occlusion,
        downsample_type=cfg.downsample_type,
        gaussian_noise_mu=cfg.gaussian_noise_mu,
        gaussian_noise_std=cfg.gaussian_noise_std,
    )

    dm.setup()
    return dm

def get_model(cfg):
    """Build the multimodal flow model described by ``cfg`` and load its weights.

    Construction:
      1. ``ResidualFlow_DiffEmbTransformer`` (inner flow network),
      2. wrapped in ``Multimodal_ResidualFlow_DiffEmbTransformer``,
      3. wrapped in ``EquivarianceTrainingModule``;
      4. if ``cfg.init_cond_x``: additionally wrapped in
         ``Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX`` and
         ``EquivarianceTrainingModule_WithPZCondX``.

    Weight loading:
      * ``cfg.checkpoint_file`` set: load the full ``state_dict`` into the
        training module (or into the p(z|X) module when ``cfg.load_cond_x``),
        optionally followed by p(z|X) action-encoder pretraining.
      * otherwise: optionally load pretrained embedding networks from
        ``cfg.checkpoint_file_action`` / ``cfg.checkpoint_file_anchor``.

    The returned module is moved to CUDA and put into eval mode only after all
    loading and layer replacement is finished, so the freshly created
    replacement layers also end up on the GPU and in eval mode.

    Side effect: sets ``cfg.return_debug = True``.

    Args:
        cfg: OmegaConf config providing the fields read below.

    Returns:
        ``EquivarianceTrainingModule_WithPZCondX`` if ``cfg.init_cond_x`` is set,
        otherwise ``EquivarianceTrainingModule``.

    Raises:
        ValueError: for config combinations that look like mistakes (joint
            training / training both priors from scratch), when
            ``cfg.load_cond_x`` is set without ``cfg.init_cond_x``, or for an
            unknown p(z|X) encoder type while loading pretraining.
        NotImplementedError: for the ``"1_dgcnn"`` p(z|X) encoder while
            loading pretraining.
    """
    from torch import nn  # needed for the layer swaps while loading pretraining

    # Fail fast on config combinations that are almost certainly mistakes,
    # before spending time building networks.
    if not cfg.pzX_adversarial and not cfg.joint_train_prior and cfg.init_cond_x and (not cfg.freeze_embnn or not cfg.freeze_residual_flow):
        raise ValueError("YOU PROBABLY DIDN'T MEAN TO DO JOINT TRAINING")
    if not cfg.joint_train_prior and cfg.init_cond_x and cfg.checkpoint_file is None:
        raise ValueError("YOU PROBABLY DIDN'T MEAN TO TRAIN BOTH P(Z|X) AND P(Z|Y) FROM SCRATCH")
    if cfg.checkpoint_file is not None and cfg.load_cond_x and not cfg.init_cond_x:
        raise ValueError("cfg.load_cond_x=True requires cfg.init_cond_x=True (no p(z|X) model to load the checkpoint into)")

    def _load_ckpt_entry(path, key):
        # Load onto the CPU: load_state_dict copies values into the module's
        # parameters wherever they live, and the module is moved to CUDA at the end.
        return torch.load(path, map_location="cpu")[key]

    TP_input_dims = Multimodal_ResidualFlow_DiffEmbTransformer.TP_INPUT_DIMS[cfg.conditioning]
    if cfg.conditioning in ["latent_z_linear", "hybrid_pos_delta_l2norm"]:
        TP_input_dims += cfg.latent_z_linear_size  # Hacky way to add the dynamic latent z to the input dims

    # Conditioning modes that feed the latent z inside the inner network
    # (via conditioning_size) instead of widening its input dims.
    internal_cond_modes = (
        "latent_z_linear_internalcond",
        "hybrid_pos_delta_l2norm_internalcond",
        "hybrid_pos_delta_l2norm_global_internalcond",
    )
    inner_network = ResidualFlow_DiffEmbTransformer(
        emb_dims=cfg.emb_dims,
        input_dims=TP_input_dims,
        emb_nn=cfg.emb_nn,
        return_flow_component=cfg.return_flow_component,
        center_feature=cfg.center_feature,
        inital_sampling_ratio=cfg.inital_sampling_ratio,
        pred_weight=cfg.pred_weight,
        freeze_embnn=cfg.freeze_embnn,
        conditioning_size=cfg.latent_z_linear_size if cfg.conditioning in internal_cond_modes else 0,
        multilaterate=cfg.multilaterate,
        sample=cfg.mlat_sample,
        mlat_nkps=cfg.mlat_nkps,
        pred_mlat_weight=cfg.pred_mlat_weight,
        conditioning_type=cfg.taxpose_conditioning_type,
        flow_head_use_weighted_sum=cfg.flow_head_use_weighted_sum,
        flow_head_use_selected_point_feature=cfg.flow_head_use_selected_point_feature,
        post_encoder_input_dims=cfg.post_encoder_input_dims,
        flow_direction=cfg.flow_direction,
    )

    network = Multimodal_ResidualFlow_DiffEmbTransformer(
        residualflow_diffembtransformer=inner_network,
        gumbel_temp=cfg.gumbel_temp,
        freeze_residual_flow=cfg.freeze_residual_flow,
        center_feature=cfg.center_feature,
        freeze_z_embnn=cfg.freeze_z_embnn,
        division_smooth_factor=cfg.division_smooth_factor,
        add_smooth_factor=cfg.add_smooth_factor,
        conditioning=cfg.conditioning,
        latent_z_linear_size=cfg.latent_z_linear_size,
        taxpose_centering=cfg.taxpose_centering,
        use_action_z=cfg.use_action_z,
        pzY_encoder_type=cfg.pzY_encoder_type,
        pzY_dropout_goal_emb=cfg.pzY_dropout_goal_emb,
        pzY_transformer=cfg.pzY_transformer,
        pzY_transformer_embnn_dims=cfg.pzY_transformer_embnn_dims,
        pzY_transformer_emb_dims=cfg.pzY_transformer_emb_dims,
        pzY_input_dims=cfg.pzY_input_dims,
        hybrid_cond_logvar_limit=cfg.hybrid_cond_logvar_limit,
    )

    model = EquivarianceTrainingModule(
        network,
        lr=cfg.lr,
        image_log_period=cfg.image_logging_period,
        flow_supervision=cfg.flow_supervision,
        point_loss_type=cfg.point_loss_type,
        action_weight=cfg.action_weight,
        anchor_weight=cfg.anchor_weight,
        displace_weight=cfg.displace_weight,
        consistency_weight=cfg.consistency_weight,
        smoothness_weight=cfg.smoothness_weight,
        rotation_weight=cfg.rotation_weight,
        weight_normalize=cfg.weight_normalize,
        softmax_temperature=cfg.softmax_temperature,
        vae_reg_loss_weight=cfg.vae_reg_loss_weight,
        sigmoid_on=cfg.sigmoid_on,
        min_err_across_racks_debug=cfg.min_err_across_racks_debug,
        error_mode_2rack=cfg.error_mode_2rack,
        n_samples=cfg.pzY_n_samples,
        get_errors_across_samples=cfg.pzY_get_errors_across_samples,
        use_debug_sampling_methods=cfg.pzY_use_debug_sampling_methods,
        return_flow_component=cfg.return_flow_component,
        plot_encoder_distribution=cfg.plot_encoder_distribution,
        joint_infonce_loss_weight=cfg.pzY_joint_infonce_loss_weight,
        spatial_distance_regularization_type=cfg.spatial_distance_regularization_type,
        spatial_distance_regularization_weight=cfg.spatial_distance_regularization_weight,
        hybrid_cond_regularize_all=cfg.hybrid_cond_regularize_all,
        pzY_taxpose_infonce_loss_weight=cfg.pzY_taxpose_infonce_loss_weight,
        pzY_taxpose_occ_infonce_loss_weight=cfg.pzY_taxpose_occ_infonce_loss_weight,
    )

    if cfg.init_cond_x:
        network_cond_x = Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX(
            residualflow_embnn=network,
            encoder_type=cfg.pzcondx_encoder_type,
            shuffle_for_pzX=cfg.shuffle_for_pzX,
            use_action_z=cfg.use_action_z,
            pzX_transformer=cfg.pzX_transformer,
            pzX_transformer_embnn_dims=cfg.pzX_transformer_embnn_dims,
            pzX_transformer_emb_dims=cfg.pzX_transformer_emb_dims,
            pzX_input_dims=cfg.pzX_input_dims,
            pzX_dropout_goal_emb=cfg.pzX_dropout_goal_emb,
            hybrid_cond_pzX_sample_latent=cfg.hybrid_cond_pzX_sample_latent,
        )

        model_cond_x = EquivarianceTrainingModule_WithPZCondX(
            network_cond_x,
            model,
            goal_emb_cond_x_loss_weight=cfg.goal_emb_cond_x_loss_weight,
            joint_train_prior=cfg.joint_train_prior,
            freeze_residual_flow=cfg.freeze_residual_flow,
            freeze_z_embnn=cfg.freeze_z_embnn,
            freeze_embnn=cfg.freeze_embnn,
            n_samples=cfg.pzX_n_samples,
            get_errors_across_samples=cfg.pzX_get_errors_across_samples,
            use_debug_sampling_methods=cfg.pzX_use_debug_sampling_methods,
            plot_encoder_distribution=cfg.plot_encoder_distribution,
            pzX_use_pzY_z_samples=cfg.pzX_use_pzY_z_samples,
            goal_emb_cond_x_loss_type=cfg.goal_emb_cond_x_loss_type,
            joint_infonce_loss_weight=cfg.pzX_joint_infonce_loss_weight,
            spatial_distance_regularization_type=cfg.spatial_distance_regularization_type,
            spatial_distance_regularization_weight=cfg.spatial_distance_regularization_weight,
            overwrite_loss=cfg.pzX_overwrite_loss,
            pzX_adversarial=cfg.pzX_adversarial,
            hybrid_cond_pzX_regularize_type=cfg.hybrid_cond_pzX_regularize_type,
            hybrid_cond_pzX_sample_latent=cfg.hybrid_cond_pzX_sample_latent,
        )

    if cfg.checkpoint_file is not None:
        print("loaded checkpoint from")
        print(cfg.checkpoint_file)
        if not cfg.load_cond_x:
            model.load_state_dict(_load_ckpt_entry(cfg.checkpoint_file, 'state_dict'))

            if cfg.init_cond_x and cfg.load_pretraining_for_conditioning:
                pzX_net = model_cond_x.model_with_cond_x
                if cfg.checkpoint_file_action is not None:
                    if pzX_net.encoder_type == "1_dgcnn":
                        raise NotImplementedError()
                    elif pzX_net.encoder_type == "2_dgcnn":
                        pzX_action_embnn = pzX_net.p_z_cond_x_embnn_action
                        # Temporarily give the encoder a 512-channel head so the
                        # pretrained state dict matches, load it, then swap in a
                        # fresh head with TP_input_dims - 3 output channels
                        # (presumably the non-xyz conditioning channels).
                        pzX_action_embnn.conv5 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
                        pzX_action_embnn.bn5 = nn.BatchNorm2d(512)
                        pzX_action_embnn.load_state_dict(
                            _load_ckpt_entry(cfg.checkpoint_file_action, 'embnn_state_dict'))
                        pzX_action_embnn.conv5 = nn.Conv2d(512, TP_input_dims - 3, kernel_size=1, bias=False)
                        pzX_action_embnn.bn5 = nn.BatchNorm2d(TP_input_dims - 3)
                        print("----Action Pretraining for p(z|X) Loaded!----")
                    else:
                        raise ValueError()
                if cfg.checkpoint_file_anchor is not None:
                    if pzX_net.encoder_type == "1_dgcnn":
                        raise NotImplementedError()
                    elif pzX_net.encoder_type == "2_dgcnn":
                        # Anchor pretraining for p(z|X) is intentionally disabled for now.
                        # To re-enable, mirror the action branch above using
                        # p_z_cond_x_embnn_anchor and cfg.checkpoint_file_anchor.
                        print("--Not loading p(z|X) for anchor for now--")
                    else:
                        raise ValueError()
        else:
            # Checkpoint contains the full p(z|X) training module.
            model_cond_x.load_state_dict(_load_ckpt_entry(cfg.checkpoint_file, 'state_dict'))

    else:
        if cfg.checkpoint_file_action is not None:
            if cfg.load_pretraining_for_taxpose:
                emb_nn_action = model.model.tax_pose.emb_nn_action
                # Swap in a 3*2-channel conv1 so the pretrained state dict matches,
                # load it, then replace conv1 with a freshly initialised layer that
                # takes TP_input_dims*2 channels (conv1's pretrained weights are not kept).
                emb_nn_action.conv1 = nn.Conv2d(3 * 2, 64, kernel_size=1, bias=False)
                emb_nn_action.load_state_dict(
                    _load_ckpt_entry(cfg.checkpoint_file_action, 'embnn_state_dict'))
                emb_nn_action.conv1 = nn.Conv2d(TP_input_dims * 2, 64, kernel_size=1, bias=False)
                print(
                '-----------------------Pretrained EmbNN Action Model Loaded!-----------------------')
            if cfg.load_pretraining_for_conditioning:
                if not cfg.init_cond_x:
                    # p(z|Y) pretraining (into model.model.emb_nn_objs_at_goal) is
                    # intentionally disabled for now.
                    print("---Not Loading p(z|Y) Pretraining For Now---")

        if cfg.checkpoint_file_anchor is not None:
            if cfg.load_pretraining_for_taxpose:
                emb_nn_anchor = model.model.tax_pose.emb_nn_anchor
                # Same conv1 swap as for the action embedding network above.
                emb_nn_anchor.conv1 = nn.Conv2d(3 * 2, 64, kernel_size=1, bias=False)
                emb_nn_anchor.load_state_dict(
                    _load_ckpt_entry(cfg.checkpoint_file_anchor, 'embnn_state_dict'))
                emb_nn_anchor.conv1 = nn.Conv2d(TP_input_dims * 2, 64, kernel_size=1, bias=False)
                print(
                '-----------------------Pretrained EmbNN Anchor Model Loaded!-----------------------')
            if cfg.load_pretraining_for_conditioning:
                if not cfg.init_cond_x:
                    # p(z|Y) pretraining from the anchor checkpoint is intentionally disabled for now.
                    print("---Not Loading p(z|Y) Pretraining For Now---")

    cfg.return_debug = True

    if cfg.init_cond_x:
        model = model_cond_x

    # Move to GPU and switch to eval mode only now: the layers created by the
    # swaps above are new CPU modules in training mode, and an earlier
    # .cuda()/.eval() call would not have reached them.
    model.cuda()
    model.eval()
    return model

###########################################################################################################

# Config and checkpoint for the model under evaluation. The commented-out pair
# is an earlier run kept for reference.
# cfg = OmegaConf.load("/home/odonca/workspace/rpad/equivariant_pose_graph/configs/experimental/train_pzX-dgcnn-transformer-gradclip1-se3-action-upright-anchor_pzY-pn2_gradclip1e-3_uniformz-action_uniformz-anchor_noreg.yaml")
# cfg.checkpoint_file = '/home/odonca/workspace/rpad/data/equivariant_pose_graph/logs/residual_flow_occlusion/2024-01-22_223338/residual_flow_occlusion/rsacw3qn/checkpoints/epoch_2000_global_step_250000.ckpt'
cfg = OmegaConf.load("/home/odonca/workspace/rpad/equivariant_pose_graph/configs/evalgap/joint_train_pzX-dgcnn-transformer_pzY-pn2_gc1e-3_se3-upright_noreg_flow-fix-one-head_with-cl_synthocc0.8_2rackvariety.yaml")
cfg.checkpoint_file = '/home/odonca/workspace/rpad/data/equivariant_pose_graph/logs/residual_flow_occlusion/2024-02-14_001809/residual_flow_occlusion/n9a6o3sz/checkpoints/epoch_838_global_step_140000.ckpt'
# Makes get_model load this checkpoint into the p(z|X) wrapper module.
# NOTE(review): this relies on cfg.init_cond_x being enabled in the YAML above — confirm.
cfg.load_cond_x = True
# One example per batch for interactive inspection.
cfg.batch_size = 1

dm = get_dm(cfg)

# val_dataloader() is indexed here, so it returns a collection of loaders
# (presumably one per validation set); use the first one.
val_dataloader = dm.val_dataloader()[0]
val_iter = iter(val_dataloader)

model = get_model(cfg)
model.eval()


In [None]:
from pytorch3d.transforms import Transform3d, Translate, matrix_to_axis_angle, Rotate, random_rotations
from equivariant_pose_graph.utils.se3 import dualflow2pose, flow2pose, get_translation, get_degree_angle, dense_flow_loss, pure_translation_se3

# Visualise p(z|X) predictions on the next validation batch.
# Each run of this cell consumes one batch from val_iter.
batch = next(val_iter)
points_action = batch['points_action'].to(model.device)
points_anchor = batch['points_anchor'].to(model.device)
points_trans_action = batch['points_action_trans'].to(model.device)
points_trans_anchor = batch['points_anchor_trans'].to(model.device)
# Fall back to the plain clouds when the batch has no "onetrans" variant.
points_onetrans_action = batch['points_action_onetrans'].to(model.device) if 'points_action_onetrans' in batch else batch['points_action'].to(model.device)
points_onetrans_anchor = batch['points_anchor_onetrans'].to(model.device) if 'points_anchor_onetrans' in batch else batch['points_anchor'].to(model.device)

T0 = Transform3d(matrix=batch['T0']).to(model.device)
T1 = Transform3d(matrix=batch['T1']).to(model.device)

print(f'points_trans_action.shape: {points_trans_action.shape}')
print(f'points_trans_anchor.shape: {points_trans_anchor.shape}')

# Inference only: no autograd graph is needed.
with torch.no_grad():
    # The model receives the full clouds, including any channels beyond xyz.
    model_outputs = model.model_with_cond_x(points_trans_action,
                                            points_trans_anchor,
                                            None,
                                            None,
                                            n_samples=1,
                                            sampling_method='gumbel')

    # Only xyz is used from here on.
    points_action = points_action[:, :, :3]
    points_anchor = points_anchor[:, :, :3]
    points_trans_action = points_trans_action[:, :, :3]
    points_trans_anchor = points_trans_anchor[:, :, :3]

    # Reference clouds: the original action cloud, and the anchor mapped back by T1^-1.
    all_predicted_points = [points_action[0].detach().cpu().numpy(),
                            T1.inverse().transform_points(points_trans_anchor)[0].detach().cpu().numpy()]
    for model_output in model_outputs:
        x_action = model_output['flow_action']
        x_anchor = model_output['flow_anchor']
        goal_emb = model_output['goal_emb']

        # If this output reports subsampled indices, pick the matching points so
        # the xyz passed to dualflow2pose lines up with the predicted flows.
        if "sampled_ixs_action" in model_output:
            ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1)
            sampled_points_action = torch.take_along_dim(points_action, ixs_action, dim=1)
            sampled_points_trans_action = torch.take_along_dim(points_trans_action, ixs_action, dim=1)
        else:
            sampled_points_action = points_action
            sampled_points_trans_action = points_trans_action

        if "sampled_ixs_anchor" in model_output:
            ixs_anchor = model_output["sampled_ixs_anchor"].unsqueeze(-1)
            sampled_points_anchor = torch.take_along_dim(points_anchor, ixs_anchor, dim=1)
            sampled_points_trans_anchor = torch.take_along_dim(points_trans_anchor, ixs_anchor, dim=1)
        else:
            sampled_points_anchor = points_anchor
            sampled_points_trans_anchor = points_trans_anchor

        pred_flow_action, pred_w_action = model.extract_flow_and_weight(x_action)
        pred_flow_anchor, pred_w_anchor = model.extract_flow_and_weight(x_anchor)

        pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action,
                                      xyz_tgt=sampled_points_trans_anchor,
                                      flow_src=pred_flow_action,
                                      flow_tgt=pred_flow_anchor,
                                      weights_src=pred_w_action,
                                      weights_tgt=pred_w_anchor,
                                      return_transform3d=True,
                                      normalization_scehme="softmax",
                                      temperature=1)

        # Apply the predicted pose to the full action cloud, then map back by T1^-1.
        pred_points_action = pred_T_action.transform_points(points_trans_action)
        all_predicted_points.append(T1.inverse().transform_points(pred_points_action)[0].detach().cpu().numpy())

plot_multi_np(all_predicted_points)


In [None]:
import glob
# Directory of saved eval point clouds (*_init_obj_points.npz) read by load_eval_data.
# (glob is also imported at the top of the file.)
# Earlier 2-rack arbitrary eval run (2024-02-07):
# eval_dir = '/home/odonca/workspace/rpad/taxpose_j/logs/eval_mug_arbitrary_10/2024-02-07/11-49-59/pointclouds'
# Active eval run (2024-02-14).
# NOTE(review): both runs were labelled "2 rack arbitrary"; confirm which setting this one is.
eval_dir = '/home/odonca/workspace/rpad/taxpose_j/logs/eval_mug_arbitrary_10/2024-02-14/13-20-16/pointclouds'

def load_eval_data(path):
    """Load action/anchor point clouds from every ``*_init_obj_points.npz`` in ``path``.

    Each file must contain a ``clouds`` array of points and a ``classes``
    array of per-point labels; points labelled 0 form the action cloud and
    points labelled 1 the anchor cloud. Both clouds are translated by the
    action cloud's centroid, so the action cloud is centred at the origin.

    Files are processed in sorted filename order, so results are
    reproducible (``glob`` itself returns files in arbitrary order).

    Assumes ``clouds`` is ``(N, D)`` (presumably D == 3, xyz) and ``classes``
    is ``(N,)`` — TODO confirm against the eval writer.

    Args:
        path: directory containing the ``.npz`` files.

    Returns:
        ``(action_points_list, anchor_points_list)``: two equal-length lists of
        float32 tensors shaped ``(1, n_i, D)``; entry ``i`` of both lists
        comes from the same file.
    """
    filenames = sorted(glob.glob(os.path.join(path, '*_init_obj_points.npz')))

    action_points_list = []
    anchor_points_list = []

    for filename in filenames:
        # NOTE: allow_pickle=True lets a crafted .npz execute arbitrary code
        # on load; only point this at trusted eval outputs.
        with np.load(filename, allow_pickle=True) as point_data:
            points_raw_np = point_data['clouds']
            classes_raw_np = point_data['classes']

        # Boolean-mask indexing already returns copies of the selected rows.
        points_action_np = points_raw_np[classes_raw_np == 0]
        action_centroid = points_action_np.mean(axis=0)
        points_action_np = points_action_np - action_centroid
        points_anchor_np = points_raw_np[classes_raw_np == 1] - action_centroid

        action_points_list.append(torch.from_numpy(points_action_np).float().unsqueeze(0))
        anchor_points_list.append(torch.from_numpy(points_anchor_np).float().unsqueeze(0))

    return action_points_list, anchor_points_list
    
action_points_list, anchor_points_list = load_eval_data(eval_dir)
print(f'len(action_points_list): {len(action_points_list)}')

# One-shot iterators; each run of the next cell takes one (action, anchor) pair via next().
action_points_iter, anchor_points_iter = iter(action_points_list), iter(anchor_points_list)

In [None]:
from equivariant_pose_graph.utils.occlusion_utils import ball_occlusion, plane_occlusion, bottom_surface_occlusion
from pytorch3d.ops import sample_farthest_points
from equivariant_pose_graph.utils.visualizations import plot_taxposed_embeddings

# Set to True to plane-occlude the action cloud before inference.
APPLY_SYNTHETIC_OCCLUSION = False

# Take the next (action, anchor) pair from the eval data loaded in the previous cell.
action_points = next(action_points_iter)
anchor_points = next(anchor_points_iter)

print(f'action_points.shape: {action_points.shape}')
print(f'anchor_points.shape: {anchor_points.shape}')

if APPLY_SYNTHETIC_OCCLUSION:
    # action_points = ball_occlusion(action_points[0], radius=cfg.ball_radius).unsqueeze(0)
    # anchor_points = ball_occlusion(anchor_points[0], radius=cfg.ball_radius).unsqueeze(0)
    action_points = plane_occlusion(action_points[0], stand_off=cfg.plane_standoff).unsqueeze(0)
    # anchor_points = plane_occlusion(anchor_points[0], stand_off=cfg.plane_standoff).unsqueeze(0)

# Downsample both clouds to cfg.num_points with farthest point sampling.
action_points, _ = sample_farthest_points(action_points, K=cfg.num_points, random_start_point=True)
anchor_points, _ = sample_farthest_points(anchor_points, K=cfg.num_points, random_start_point=True)

print(f'action_points.shape: {action_points.shape}')
print(f'anchor_points.shape: {anchor_points.shape}')

# sample_farthest_points already returns tensors; move them straight to the model's device.
action_points = action_points.to(model.device)
anchor_points = anchor_points.to(model.device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    model_outputs = model.model_with_cond_x(action_points,
                                            anchor_points,
                                            None,
                                            None,
                                            n_samples=1,
                                            sampling_method='gumbel')

    all_predicted_points = [action_points[0].detach().cpu().numpy(),
                            anchor_points[0].detach().cpu().numpy()]

    # Only xyz is used for pose estimation and plotting.
    action_points = action_points[:, :, :3]
    anchor_points = anchor_points[:, :, :3]

    for model_output in model_outputs:
        x_action = model_output['flow_action']
        x_anchor = model_output['flow_anchor']
        goal_emb = model_output['goal_emb']
        goal_emb_cond_x = model_output['goal_emb_cond_x']

        # If this output reports subsampled indices, pick the matching points so
        # the xyz passed to dualflow2pose lines up with the predicted flows.
        if "sampled_ixs_action" in model_output:
            ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1)
            sampled_action_points = torch.take_along_dim(action_points, ixs_action, dim=1)
        else:
            sampled_action_points = action_points

        if "sampled_ixs_anchor" in model_output:
            ixs_anchor = model_output["sampled_ixs_anchor"].unsqueeze(-1)
            sampled_anchor_points = torch.take_along_dim(anchor_points, ixs_anchor, dim=1)
        else:
            sampled_anchor_points = anchor_points

        pred_flow_action, pred_w_action = model.extract_flow_and_weight(x_action)
        pred_flow_anchor, pred_w_anchor = model.extract_flow_and_weight(x_anchor)

        pred_T_action = dualflow2pose(xyz_src=sampled_action_points,
                                      xyz_tgt=sampled_anchor_points,
                                      flow_src=pred_flow_action,
                                      flow_tgt=pred_flow_anchor,
                                      weights_src=pred_w_action,
                                      weights_tgt=pred_w_anchor,
                                      return_transform3d=True,
                                      normalization_scehme="softmax",
                                      temperature=1)

        pred_points_action = pred_T_action.transform_points(action_points)
        all_predicted_points.append(pred_points_action[0].detach().cpu().numpy())

# plot_multi_np(all_predicted_points)
print(f'model_outputs[0].keys(): {model_outputs[0].keys()}')
# Attach the last predicted action cloud to the first output dict for the embedding plot.
model_outputs[0]['pred_points_action'] = torch.Tensor(all_predicted_points[-1]).unsqueeze(0)
print(f'pred_points_action.shape: {all_predicted_points[-1].shape}')
plot_taxposed_embeddings(points_action=action_points[:1],
                         points_anchor=anchor_points[:1],
                         ans=model_outputs[0],
                         hydra_cfg=cfg,)
