In [1]:
import sys
import os
import glob
import pathlib
import torch

import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from omegaconf import OmegaConf

from equivariant_pose_graph.dataset.point_cloud_data_module import MultiviewDataModule
from equivariant_pose_graph.utils.se3 import random_se3
from equivariant_pose_graph.utils.load_model_utils import load_merged_model
from equivariant_pose_graph.utils.visualizations import plot_all_predictions
from equivariant_pose_graph.training.flow_equivariance_training_module_nocentering_multimodal import EquivarianceTrainingModule, EquivarianceTrainingModule_WithPZCondX
from equivariant_pose_graph.dataset.point_cloud_data_module import MultiviewDataModule
from equivariant_pose_graph.models.transformer_flow import ResidualFlow_DiffEmbTransformer, AlignedFrameDecoder
from equivariant_pose_graph.models.multimodal_transformer_flow import Multimodal_ResidualFlow_DiffEmbTransformer, Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX

# Make CUDA device index 1 the current device for this process; the bare
# `.cuda()` calls made later in this notebook target the current device.
torch.cuda.set_device(1)

def toDisplay(x, target_dim = 2):
    """
    Convert a tensor to a numpy array with at most ``target_dim`` dimensions.

    Extra leading dimensions are dropped by repeatedly taking element 0 of the
    first axis; the result is then detached, moved to the CPU and converted
    with ``.numpy()``.

    Args:
        x: tensor to convert.
        target_dim: maximum number of dimensions kept in the result.

    Returns:
        numpy array holding the selected slice of ``x``.
    """
    reduced = x
    surplus_dims = x.dim() - target_dim
    # range() over a non-positive count is empty, so tensors that already
    # have few enough dimensions pass through untouched.
    for _ in range(surplus_dims):
        reduced = reduced[0]
    as_cpu = reduced.detach().cpu()
    return as_cpu.numpy()


def plot_multi_np(plist):
    """
    Draw several point clouds in a single interactive plotly 3D scatter plot.

    Each cloud gets its own colour from a fixed 10-entry palette; the palette
    wraps around when more than 10 clouds are given, so any number of clouds
    can be plotted.

    Args:
        plist: list of numpy arrays of shape (1, num_points, 3). Arrays with
            more than two dimensions are reduced to 2-D by ``toDisplay``,
            which keeps index 0 of each leading axis; 2-D arrays of shape
            (num_points, 3) are used as they are.

    Returns:
        The ``go.Figure`` that was built (it is also shown via ``fig.show()``).
    """
    colors = [
        '#1f77b4',  # muted blue
        '#ff7f0e',  # safety orange
        '#2ca02c',  # cooked asparagus green
        '#d62728',  # brick red
        '#9467bd',  # muted purple
        '#e377c2',  # raspberry yogurt pink
        '#8c564b',  # chestnut brown
        '#7f7f7f',  # middle gray
        '#bcbd22',  # curry yellow-green
        '#17becf'   # blue-teal
    ]
    # Stride used to subsample points before plotting; 1 keeps every point.
    skip = 1
    go_data = []
    for i, cloud in enumerate(plist):
        p_dp = toDisplay(torch.from_numpy(cloud))
        # Cycle through the palette instead of indexing past its end.
        color = colors[i % len(colors)]
        plot = go.Scatter3d(x=p_dp[::skip,0], y=p_dp[::skip,1], z=p_dp[::skip,2],
                     mode='markers', marker=dict(size=2, color=color,
                     symbol='circle'))
        go_data.append(plot)

    # aspectmode='data' scales the axes in proportion to the data ranges.
    layout = go.Layout(
        scene=dict(
            aspectmode='data'
        )
    )

    fig = go.Figure(data=go_data, layout=layout)
    fig.show()
    return fig

def get_dm(cfg):
    """
    Create a ``MultiviewDataModule`` from ``cfg`` and run its ``setup()``.

    Almost every argument is forwarded from the config field of the same
    name. The exceptions are:
      * ``rotation_variance``, ``conval_rotation_variance`` and
        ``demo_mod_rot_var`` are multiplied by pi/180 before being passed;
      * ``demo_mod_k_range`` is assembled from ``demo_mod_k_range_min`` and
        ``demo_mod_k_range_max``;
      * ``return_rpdiff_mesh_files`` is taken from
        ``cfg.compute_rpdiff_min_errors``.

    Args:
        cfg: config object exposing the fields read below as attributes.

    Returns:
        The data module, after ``setup()`` has been called on it.
    """
    rad_per_deg = np.pi/180

    dm_kwargs = dict(
        # Dataset location / selection
        dataset_root=cfg.dataset_root,
        test_dataset_root=cfg.test_dataset_root,
        dataset_index=cfg.dataset_index,
        action_class=cfg.action_class,
        anchor_class=cfg.anchor_class,
        dataset_size=cfg.dataset_size,
        rotation_variance=rad_per_deg * cfg.rotation_variance,
        translation_variance=cfg.translation_variance,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        cloud_type=cfg.cloud_type,
        num_points=cfg.num_points,
        # Overfitting options
        overfit=cfg.overfit,
        overfit_distractor_aug=cfg.overfit_distractor_aug,
        num_overfit_transforms=cfg.num_overfit_transforms,
        seed_overfit_transforms=cfg.seed_overfit_transforms,
        set_Y_transform_to_identity=cfg.set_Y_transform_to_identity,
        set_Y_transform_to_overfit=cfg.set_Y_transform_to_overfit,
        num_demo=cfg.num_demo,
        # Occlusion / scaling / augmentation options
        synthetic_occlusion=cfg.synthetic_occlusion,
        ball_radius=cfg.ball_radius,
        plane_standoff=cfg.plane_standoff,
        bottom_surface_z_clipping_height=cfg.bottom_surface_z_clipping_height,
        scale_point_clouds=cfg.scale_point_clouds,
        scale_point_clouds_min=cfg.scale_point_clouds_min,
        scale_point_clouds_max=cfg.scale_point_clouds_max,
        distractor_anchor_aug=cfg.distractor_anchor_aug,
        demo_mod_k_range=[cfg.demo_mod_k_range_min, cfg.demo_mod_k_range_max],
        demo_mod_rot_var=cfg.demo_mod_rot_var * np.pi/180,
        demo_mod_trans_var=cfg.demo_mod_trans_var,
        multimodal_transform_base=cfg.multimodal_transform_base,
        action_rot_sample_method=cfg.action_rot_sample_method,
        anchor_rot_sample_method=cfg.anchor_rot_sample_method,
        distractor_rot_sample_method=cfg.distractor_rot_sample_method,
        skip_failed_occlusion=cfg.skip_failed_occlusion,
        min_num_cameras=cfg.min_num_cameras,
        max_num_cameras=cfg.max_num_cameras,
        # "conval_*" validation-set options
        use_consistent_validation_set=cfg.use_consistent_validation_set,
        use_all_validation_sets=cfg.use_all_validation_sets,
        conval_rotation_variance=rad_per_deg * cfg.conval_rotation_variance,
        conval_translation_variance=cfg.conval_translation_variance,
        conval_synthetic_occlusion=cfg.conval_synthetic_occlusion,
        conval_scale_point_clouds=cfg.conval_scale_point_clouds,
        conval_action_rot_sample_method=cfg.conval_action_rot_sample_method,
        conval_anchor_rot_sample_method=cfg.conval_anchor_rot_sample_method,
        conval_distractor_rot_sample_method=cfg.conval_distractor_rot_sample_method,
        conval_min_num_cameras=cfg.conval_min_num_cameras,
        conval_max_num_cameras=cfg.conval_max_num_cameras,
        conval_downsample_type=cfg.conval_downsample_type,
        conval_gaussian_noise_mu=cfg.conval_gaussian_noise_mu,
        conval_gaussian_noise_std=cfg.conval_gaussian_noise_std,
        use_class_labels=cfg.use_class_labels,
        # Per-object occlusion options
        action_occlusion_class=cfg.action_occlusion_class,
        action_plane_occlusion=cfg.action_plane_occlusion,
        action_ball_occlusion=cfg.action_ball_occlusion,
        action_bottom_surface_occlusion=cfg.action_bottom_surface_occlusion,
        anchor_occlusion_class=cfg.anchor_occlusion_class,
        anchor_plane_occlusion=cfg.anchor_plane_occlusion,
        anchor_ball_occlusion=cfg.anchor_ball_occlusion,
        anchor_bottom_surface_occlusion=cfg.anchor_bottom_surface_occlusion,
        # Downsampling / noise
        downsample_type=cfg.downsample_type,
        gaussian_noise_mu=cfg.gaussian_noise_mu,
        gaussian_noise_std=cfg.gaussian_noise_std,
        return_rpdiff_mesh_files=cfg.compute_rpdiff_min_errors,
    )

    data_module = MultiviewDataModule(**dm_kwargs)
    data_module.setup()
    return data_module

def get_model(cfg):
    """
    Build the training module described by ``cfg`` and load its weights.

    Steps:
      1. Build the inner decoder network selected by ``cfg.decoder_type``.
      2. Wrap it in ``Multimodal_ResidualFlow_DiffEmbTransformer`` and then in
         an ``EquivarianceTrainingModule``.
      3. If ``cfg.init_cond_x`` is set, additionally build the
         ``*_WithPZCondX`` network and training module around them.
      4. Move the outermost module to the current CUDA device and put it in
         train mode.
      5. Load weights: either the full checkpoint ``cfg.checkpoint_file``
         (into the p(z|X) module when ``cfg.load_cond_x`` is set, otherwise
         into the base module), or, when no full checkpoint is given, the
         optional pretrained embedding networks
         ``cfg.checkpoint_file_action`` / ``cfg.checkpoint_file_anchor``.

    Side effects:
        Sets ``cfg.return_debug = True`` and prints progress messages.

    Args:
        cfg: config object exposing the fields read below as attributes.

    Returns:
        The ``EquivarianceTrainingModule_WithPZCondX`` when
        ``cfg.init_cond_x`` is set, otherwise the
        ``EquivarianceTrainingModule``.

    Raises:
        ValueError: for an unsupported ``cfg.decoder_type``, for the
            inconsistent training configurations checked below, for
            ``cfg.load_cond_x`` without ``cfg.init_cond_x``, or for an
            unknown p(z|X) encoder type while loading pretraining.
        NotImplementedError: when pretraining is requested for the
            "1_dgcnn" p(z|X) encoder type.
    """
    # Imported locally: the pretraining branches below use nn.Conv2d and
    # nn.BatchNorm2d, and the visible imports of this notebook do not bring
    # `nn` into scope.
    import torch.nn as nn

    TP_input_dims = Multimodal_ResidualFlow_DiffEmbTransformer.TP_INPUT_DIMS[cfg.conditioning]
    if cfg.conditioning in ["latent_z_linear", "hybrid_pos_delta_l2norm", "hybrid_pos_delta_l2norm_global"]:
        TP_input_dims += cfg.latent_z_linear_size # Hacky way to add the dynamic latent z to the input dims

    # if cfg.conditioning in ["latent_z_linear"]:
    #     assert not cfg.freeze_embnn and not cfg.freeze_z_embnn and not cfg.freeze_residual_flow, "Probably don't want to freeze the network when training the latent model"
    if cfg.decoder_type == "taxpose":
        inner_network = ResidualFlow_DiffEmbTransformer(
            emb_dims=cfg.emb_dims,
            input_dims=TP_input_dims,
            emb_nn=cfg.emb_nn,
            return_flow_component=cfg.return_flow_component,
            center_feature=cfg.center_feature,
            inital_sampling_ratio=cfg.inital_sampling_ratio,
            pred_weight=cfg.pred_weight,
            freeze_embnn=cfg.freeze_embnn,
            conditioning_size=cfg.latent_z_linear_size if cfg.conditioning in ["latent_z_linear_internalcond", "hybrid_pos_delta_l2norm_internalcond", "hybrid_pos_delta_l2norm_global_internalcond"] else 0,
            multilaterate=cfg.multilaterate,
            sample=cfg.mlat_sample,
            mlat_nkps=cfg.mlat_nkps,
            pred_mlat_weight=cfg.pred_mlat_weight,
            conditioning_type=cfg.taxpose_conditioning_type,
            flow_head_use_weighted_sum=cfg.flow_head_use_weighted_sum,
            flow_head_use_selected_point_feature=cfg.flow_head_use_selected_point_feature,
            post_encoder_input_dims=cfg.post_encoder_input_dims,
            flow_direction=cfg.flow_direction,
            )
    elif cfg.decoder_type in ["flow", "point"]:
        inner_network = AlignedFrameDecoder(
            emb_dims=cfg.emb_dims,
            input_dims=TP_input_dims,
            flow_direction=cfg.flow_direction,
            head_output_type=cfg.decoder_type,
            flow_frame=cfg.flow_frame,
        )
    else:
        # Without this branch an unknown decoder type would only surface
        # later as an unbound `inner_network`.
        raise ValueError(f"Unsupported decoder_type: {cfg.decoder_type!r}")

    network = Multimodal_ResidualFlow_DiffEmbTransformer(
        residualflow_diffembtransformer=inner_network,
        gumbel_temp=cfg.gumbel_temp,
        freeze_residual_flow=cfg.freeze_residual_flow,
        center_feature=cfg.center_feature,
        freeze_z_embnn=cfg.freeze_z_embnn,
        division_smooth_factor=cfg.division_smooth_factor,
        add_smooth_factor=cfg.add_smooth_factor,
        conditioning=cfg.conditioning,
        latent_z_linear_size=cfg.latent_z_linear_size,
        taxpose_centering=cfg.taxpose_centering,
        use_action_z=cfg.use_action_z,
        pzY_encoder_type=cfg.pzY_encoder_type,
        pzY_dropout_goal_emb=cfg.pzY_dropout_goal_emb,
        pzY_transformer=cfg.pzY_transformer,
        pzY_transformer_embnn_dims=cfg.pzY_transformer_embnn_dims,
        pzY_transformer_emb_dims=cfg.pzY_transformer_emb_dims,
        pzY_input_dims=cfg.pzY_input_dims,
        pzY_embedding_routine=cfg.pzY_embedding_routine,
        pzY_embedding_option=cfg.pzY_embedding_option,
        hybrid_cond_logvar_limit=cfg.hybrid_cond_logvar_limit,
        latent_z_cond_logvar_limit=cfg.latent_z_cond_logvar_limit,
    )

    model = EquivarianceTrainingModule(
        network,
        lr=cfg.lr,
        image_log_period=cfg.image_logging_period,
        flow_supervision=cfg.flow_supervision,
        point_loss_type=cfg.point_loss_type,
        action_weight=cfg.action_weight,
        anchor_weight=cfg.anchor_weight,
        displace_weight=cfg.displace_weight,
        consistency_weight=cfg.consistency_weight,
        smoothness_weight=cfg.smoothness_weight,
        rotation_weight=cfg.rotation_weight,
        #latent_weight=cfg.latent_weight,
        weight_normalize=cfg.weight_normalize,
        softmax_temperature=cfg.softmax_temperature,
        vae_reg_loss_weight=cfg.vae_reg_loss_weight,
        sigmoid_on=cfg.sigmoid_on,
        min_err_across_racks_debug=cfg.min_err_across_racks_debug,
        error_mode_2rack=cfg.error_mode_2rack,
        n_samples=cfg.pzY_n_samples,
        get_errors_across_samples=cfg.pzY_get_errors_across_samples,
        use_debug_sampling_methods=cfg.pzY_use_debug_sampling_methods,
        return_flow_component=cfg.return_flow_component,
        plot_encoder_distribution=cfg.plot_encoder_distribution,
        joint_infonce_loss_weight=cfg.pzY_joint_infonce_loss_weight,
        spatial_distance_regularization_type=cfg.spatial_distance_regularization_type,
        spatial_distance_regularization_weight=cfg.spatial_distance_regularization_weight,
        hybrid_cond_regularize_all=cfg.hybrid_cond_regularize_all,
        pzY_taxpose_infonce_loss_weight=cfg.pzY_taxpose_infonce_loss_weight,
        pzY_taxpose_occ_infonce_loss_weight=cfg.pzY_taxpose_occ_infonce_loss_weight,
        decoder_type=cfg.decoder_type,
        flow_frame=cfg.flow_frame,
        compute_rpdiff_min_errors=cfg.compute_rpdiff_min_errors,
        rpdiff_descriptions_path=cfg.rpdiff_descriptions_path,
    )

    # Guard against configurations that would train parts of the model the
    # user most likely meant to keep frozen / pre-trained.
    if not cfg.pzX_adversarial and not cfg.joint_train_prior and cfg.init_cond_x and (not cfg.freeze_embnn or not cfg.freeze_residual_flow):
        raise ValueError("YOU PROBABLY DIDN'T MEAN TO DO JOINT TRAINING")
    if not cfg.joint_train_prior and cfg.init_cond_x and cfg.checkpoint_file is None:
        raise ValueError("YOU PROBABLY DIDN'T MEAN TO TRAIN BOTH P(Z|X) AND P(Z|Y) FROM SCRATCH")

    # Stays None unless cfg.init_cond_x is set, so later branches can detect
    # a missing p(z|X) module instead of hitting an unbound name.
    model_cond_x = None

    if cfg.init_cond_x:
        network_cond_x = Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX(
            residualflow_embnn=network,
            encoder_type=cfg.pzcondx_encoder_type,
            shuffle_for_pzX=cfg.shuffle_for_pzX,
            use_action_z=cfg.use_action_z,
            pzX_transformer=cfg.pzX_transformer,
            pzX_transformer_embnn_dims=cfg.pzX_transformer_embnn_dims,
            pzX_transformer_emb_dims=cfg.pzX_transformer_emb_dims,
            pzX_input_dims=cfg.pzX_input_dims,
            pzX_dropout_goal_emb=cfg.pzX_dropout_goal_emb,
            hybrid_cond_pzX_sample_latent=cfg.hybrid_cond_pzX_sample_latent,
            pzX_embedding_routine=cfg.pzX_embedding_routine,
            pzX_embedding_option=cfg.pzX_embedding_option,
        )

        model_cond_x = EquivarianceTrainingModule_WithPZCondX(
            network_cond_x,
            model,
            goal_emb_cond_x_loss_weight=cfg.goal_emb_cond_x_loss_weight,
            joint_train_prior=cfg.joint_train_prior,
            freeze_residual_flow=cfg.freeze_residual_flow,
            freeze_z_embnn=cfg.freeze_z_embnn,
            freeze_embnn=cfg.freeze_embnn,
            n_samples=cfg.pzX_n_samples,
            get_errors_across_samples=cfg.pzX_get_errors_across_samples,
            use_debug_sampling_methods=cfg.pzX_use_debug_sampling_methods,
            plot_encoder_distribution=cfg.plot_encoder_distribution,
            pzX_use_pzY_z_samples=cfg.pzX_use_pzY_z_samples,
            goal_emb_cond_x_loss_type=cfg.goal_emb_cond_x_loss_type,
            joint_infonce_loss_weight=cfg.pzX_joint_infonce_loss_weight,
            spatial_distance_regularization_type=cfg.spatial_distance_regularization_type,
            spatial_distance_regularization_weight=cfg.spatial_distance_regularization_weight,
            overwrite_loss=cfg.pzX_overwrite_loss,
            pzX_adversarial=cfg.pzX_adversarial,
            hybrid_cond_pzX_regularize_type=cfg.hybrid_cond_pzX_regularize_type,
            hybrid_cond_pzX_sample_latent=cfg.hybrid_cond_pzX_sample_latent,
        )

        model_cond_x.cuda()
        model_cond_x.train()
    else:
        model.cuda()
        model.train()

    # NOTE(review): the replacement nn.Conv2d / nn.BatchNorm2d layers created
    # in the pretraining branches below are constructed after the .cuda()
    # calls above, so they are not moved by them -- confirm they are moved to
    # the right device before use.
    if cfg.checkpoint_file is not None:
        print("loaded checkpoint from")
        print(cfg.checkpoint_file)
        if not cfg.load_cond_x:
            model.load_state_dict(torch.load(cfg.checkpoint_file)['state_dict'])

            if cfg.init_cond_x and cfg.load_pretraining_for_conditioning:
                if cfg.checkpoint_file_action is not None:
                    if model_cond_x.model_with_cond_x.encoder_type == "1_dgcnn":
                        raise NotImplementedError()
                    elif model_cond_x.model_with_cond_x.encoder_type == "2_dgcnn":
                        # Give conv5/bn5 512 channels while the pretrained
                        # 'embnn_state_dict' is loaded (presumably to match the
                        # pretrained shapes), then replace them with fresh
                        # layers producing TP_input_dims-3 channels.
                        model_cond_x.model_with_cond_x.p_z_cond_x_embnn_action.conv5 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
                        model_cond_x.model_with_cond_x.p_z_cond_x_embnn_action.bn5 = nn.BatchNorm2d(512)
                        model_cond_x.model_with_cond_x.p_z_cond_x_embnn_action.load_state_dict(
                            torch.load(cfg.checkpoint_file_action)['embnn_state_dict'])
                        model_cond_x.model_with_cond_x.p_z_cond_x_embnn_action.conv5 = nn.Conv2d(512, TP_input_dims-3, kernel_size=1, bias=False)
                        model_cond_x.model_with_cond_x.p_z_cond_x_embnn_action.bn5 = nn.BatchNorm2d(TP_input_dims-3)
                        print("----Action Pretraining for p(z|X) Loaded!----")
                    else:
                        raise ValueError()
                if cfg.checkpoint_file_anchor is not None:
                    if model_cond_x.model_with_cond_x.encoder_type == "1_dgcnn":
                        raise NotImplementedError()
                    elif model_cond_x.model_with_cond_x.encoder_type == "2_dgcnn":
                        print("--Not loading p(z|X) for anchor for now--")
                        pass
                        # model_cond_x.model_with_cond_x.p_z_cond_x_embnn_anchor.conv5 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
                        # model_cond_x.model_with_cond_x.p_z_cond_x_embnn_anchor.bn5 = nn.BatchNorm2d(512)
                        # model_cond_x.model_with_cond_x.p_z_cond_x_embnn_anchor.load_state_dict(
                        #     torch.load(cfg.checkpoint_file_anchor)['embnn_state_dict'])
                        # model_cond_x.model_with_cond_x.p_z_cond_x_embnn_anchor.conv5 = nn.Conv2d(512, TP_input_dims-3, kernel_size=1, bias=False)
                        # model_cond_x.model_with_cond_x.p_z_cond_x_embnn_anchor.bn5 = nn.BatchNorm2d(TP_input_dims-3)
                        # print("--Anchor Pretraining for p(z|X) Loaded!--")
                    else:
                        raise ValueError()
        else:
            if model_cond_x is None:
                # load_cond_x asks for the checkpoint to go into the p(z|X)
                # module, which only exists when init_cond_x is set.
                raise ValueError("cfg.load_cond_x requires cfg.init_cond_x to be set")
            model_cond_x.load_state_dict(torch.load(cfg.checkpoint_file)['state_dict'])

    else:
        if cfg.checkpoint_file_action is not None:
            if cfg.load_pretraining_for_taxpose:
                # Give conv1 3*2 input channels while the pretrained weights
                # are loaded, then replace it with a fresh layer taking
                # TP_input_dims*2 input channels.
                model.model.tax_pose.emb_nn_action.conv1 = nn.Conv2d(3*2, 64, kernel_size=1, bias=False)
                model.model.tax_pose.emb_nn_action.load_state_dict(
                    torch.load(cfg.checkpoint_file_action)['embnn_state_dict'])
                model.model.tax_pose.emb_nn_action.conv1 = nn.Conv2d(TP_input_dims*2, 64, kernel_size=1, bias=False)
                print(
                '-----------------------Pretrained EmbNN Action Model Loaded!-----------------------')
            if cfg.load_pretraining_for_conditioning:
                if not cfg.init_cond_x:
                    print("---Not Loading p(z|Y) Pretraining For Now---")
                    pass
                    # model.model.emb_nn_objs_at_goal.conv5 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
                    # model.model.emb_nn_objs_at_goal.bn5 = nn.BatchNorm2d(512)
                    # model.model.emb_nn_objs_at_goal.load_state_dict(
                    #         torch.load(cfg.checkpoint_file_action)['embnn_state_dict'])
                    # model.model.emb_nn_objs_at_goal.conv5 = nn.Conv2d(512, TP_input_dims-3, kernel_size=1, bias=False)
                    # model.model.emb_nn_objs_at_goal.bn5 = nn.BatchNorm2d(TP_input_dims-3)
                    # print("----Action Pretraining for p(z|Y) Loaded!----")

        if cfg.checkpoint_file_anchor is not None:
            if cfg.load_pretraining_for_taxpose:
                model.model.tax_pose.emb_nn_anchor.conv1 = nn.Conv2d(3*2, 64, kernel_size=1, bias=False)
                model.model.tax_pose.emb_nn_anchor.load_state_dict(
                    torch.load(cfg.checkpoint_file_anchor)['embnn_state_dict'])
                model.model.tax_pose.emb_nn_anchor.conv1 = nn.Conv2d(TP_input_dims*2, 64, kernel_size=1, bias=False)
                print(
                '-----------------------Pretrained EmbNN Anchor Model Loaded!-----------------------')
            if cfg.load_pretraining_for_conditioning:
                if not cfg.init_cond_x:
                    print("---Not Loading p(z|Y) Pretraining For Now---")
                    pass
                    # if cfg.checkpoint_file_action is None:
                    #     model.model.emb_nn_objs_at_goal.conv5 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
                    #     model.model.emb_nn_objs_at_goal.bn5 = nn.BatchNorm2d(512)
                    #     model.model.emb_nn_objs_at_goal.load_state_dict(
                    #             torch.load(cfg.checkpoint_file_action)['embnn_state_dict'])
                    #     model.model.emb_nn_objs_at_goal.conv5 = nn.Conv2d(512, TP_input_dims-3, kernel_size=1, bias=False)
                    #     model.model.emb_nn_objs_at_goal.bn5 = nn.BatchNorm2d(TP_input_dims-3)
                    #     print("----Anchor Pretraining for p(z|Y) Loaded! (because action pretraining is not present)----")

    cfg.return_debug = True
    # model = load_merged_model(pzY_checkpoint_path, pzX_checkpoint_path, cfg.conditioning, True, cfg)
    # model.cuda()

    # Hand back the outermost module that was built.
    if cfg.init_cond_x:
        model = model_cond_x
    return model

###########################################################################################################

# cfg = OmegaConf.load("/home/odonca/workspace/rpad/equivariant_pose_graph/configs/experimental/train_pzX-dgcnn-transformer-gradclip1-se3-action-upright-anchor_pzY-pn2_gradclip1e-3_uniformz-action_uniformz-anchor_noreg.yaml")
# cfg.checkpoint_file = '/home/odonca/workspace/rpad/data/equivariant_pose_graph/logs/residual_flow_occlusion/2024-01-22_223338/residual_flow_occlusion/rsacw3qn/checkpoints/epoch_2000_global_step_250000.ckpt'
# Load the experiment config from YAML, then override individual fields for
# this interactive session before building the data module and the model.
cfg = OmegaConf.load("/home/odonca/workspace/rpad/equivariant_pose_graph/configs/icra2024/train_2rackvariety_densez_learned_prior.yaml")
# Checkpoint whose 'state_dict' entry get_model() loads.
cfg.checkpoint_file = '/home/odonca/workspace/rpad/data/equivariant_pose_graph/logs/icra2024/2024-02-28_131608/icra2024/8tv2ecpr/checkpoints/epoch_1040_global_step_130000.ckpt'
# With load_cond_x set, get_model() loads the checkpoint into the p(z|X)
# training module instead of the base training module.
cfg.load_cond_x = True
cfg.batch_size = 1
cfg.use_class_labels = False
cfg.pzY_input_dims = 3

# Point the data module at the train / test render directories.
cfg.dataset_root = '/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders'
cfg.test_dataset_root = '/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/test_data_duprack_bothmugrack/renders'
cfg.distractor_anchor_aug = False

# get_dm() constructs the data module and calls setup() on it.
dm = get_dm(cfg)

# val_dataloader() is indexed here, so it evidently returns a sequence of
# loaders; only the first one is used. The iterator is kept as a global so
# later cells can pull successive batches with next(val_iter).
val_dataloader = dm.val_dataloader()[0]
val_iter = iter(val_dataloader)

# get_model() leaves the module in train mode; switch to eval mode for inference.
model = get_model(cfg)
model.eval()

TRAIN Dataset
/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders
/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders/3_teleport_obj_points.npz
Using 10 files: 
[PosixPath('/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders/3_teleport_obj_points.npz'), PosixPath('/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders/9996_teleport_obj_points.npz'), PosixPath('/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders/9993_teleport_obj_points.npz'), PosixPath('/home/odonca/workspace/rpad/data/equivariant_pose_graph/data/duprack_bothmugrack/train_data_duprack_bothmugrack/renders/11_teleport_obj_points.npz'), PosixPath('/home/odonca/workspace/rpad/data/equivariant

EquivarianceTrainingModule_WithPZCondX(
  (model): Multimodal_ResidualFlow_DiffEmbTransformer(
    (tax_pose): ResidualFlow_DiffEmbTransformer(
      (emb_nn_action): DGCNN(
        (conv1): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv4): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv5): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn4): BatchNorm2d(256, eps=1e-05, mo

In [8]:
from pytorch3d.transforms import Transform3d, Translate, matrix_to_axis_angle, Rotate, random_rotations
from equivariant_pose_graph.utils.se3 import dualflow2pose, flow2pose, get_translation, get_degree_angle, dense_flow_loss, pure_translation_se3

# Get the data
# Pull the next validation batch and move the point clouds to the model's device.
batch = next(val_iter)
points_action = batch['points_action'].to(model.device)
points_anchor = batch['points_anchor'].to(model.device)
points_trans_action = batch['points_action_trans'].to(model.device)
points_trans_anchor = batch['points_anchor_trans'].to(model.device)
# Fall back to the untransformed clouds when the batch has no '*_onetrans' entries.
points_onetrans_action = batch['points_action_onetrans'].to(model.device) if 'points_action_onetrans' in batch else batch['points_action'].to(model.device)
points_onetrans_anchor = batch['points_anchor_onetrans'].to(model.device) if 'points_anchor_onetrans' in batch else batch['points_anchor'].to(model.device)

# Wrap the batch's 'T0' / 'T1' matrices as Transform3d objects.
# T0 is not used further in this cell; T1's inverse is applied before plotting.
T0 = Transform3d(matrix=batch['T0']).to(model.device)
T1 = Transform3d(matrix=batch['T1']).to(model.device)

# Run the model
# The result is iterated below, one entry (a dict) per element.
model_outputs = model.model(points_trans_action, 
                            points_trans_anchor, 
                            points_onetrans_action, 
                            points_onetrans_anchor, 
                            n_samples=2, 
                            sampling_method='gumbel')
# flow_fix_model_outputs = flow_fix_model.model(points_trans_action, 
#                             points_trans_anchor, 
#                             points_onetrans_action, 
#                             points_onetrans_anchor, 
#                             n_samples=1, 
#                             sampling_method='gumbel')

# First entry to plot: the transformed anchor cloud mapped through T1.inverse().
# [0] selects the first batch element (batch_size is set to 1 in the setup cell).
all_predicted_points = [T1.inverse().transform_points(points_trans_anchor)[0].detach().cpu().numpy()]
for model_output in model_outputs:
    x_action = model_output['flow_action']
    x_anchor = model_output['flow_anchor']
    # goal_emb is read here but not used further in this cell.
    goal_emb = model_output['goal_emb']

    # Get the prediction from the model forward pass
    # Keep only the first three channels of every cloud (presumably xyz --
    # confirm). These assignments rebind the notebook-level variables, so
    # after the first iteration the slicing is a no-op.
    points_action = points_action[:, :, :3]
    points_anchor = points_anchor[:, :, :3]
    points_trans_action = points_trans_action[:, :, :3]
    points_trans_anchor = points_trans_anchor[:, :, :3]

    # If the model reports sampled point indices, gather the matching points.
    # NOTE(review): the indices are always read from model_outputs[0], not
    # from the current model_output -- confirm that all samples share the
    # same sampled indices.
    if "sampled_ixs_action" in model_outputs[0]:
        ixs_action = model_outputs[0]["sampled_ixs_action"].unsqueeze(-1)
        sampled_points_action = torch.take_along_dim(
            points_action, ixs_action, dim=1
        )
        sampled_points_trans_action = torch.take_along_dim(
            points_trans_action, ixs_action, dim=1
        )
    else:
        sampled_points_action = points_action
        sampled_points_trans_action = points_trans_action

    if "sampled_ixs_anchor" in model_outputs[0]:
        ixs_anchor = model_outputs[0]["sampled_ixs_anchor"].unsqueeze(-1)
        sampled_points_anchor = torch.take_along_dim(
            points_anchor, ixs_anchor, dim=1
        )
        sampled_points_trans_anchor = torch.take_along_dim(
            points_trans_anchor, ixs_anchor, dim=1
        )
    else:
        sampled_points_anchor = points_anchor
        sampled_points_trans_anchor = points_trans_anchor

    # Split each raw model output into a flow part and a weight part.
    pred_flow_action, pred_w_action = model.extract_flow_and_weight(x_action)
    pred_flow_anchor, pred_w_anchor = model.extract_flow_and_weight(x_anchor)

    # Turn the two flows into a single Transform3d for the action object.
    # NOTE(review): the keyword is spelled "normalization_scehme" here;
    # presumably that matches the parameter name in dualflow2pose -- confirm.
    pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, 
                                xyz_tgt=sampled_points_trans_anchor,
                                flow_src=pred_flow_action, 
                                flow_tgt=pred_flow_anchor,
                                weights_src=pred_w_action, 
                                weights_tgt=pred_w_anchor,
                                return_transform3d=True, 
                                normalization_scehme="softmax",
                                temperature=1)

    # Apply the predicted transform to the full (not just sampled) action cloud.
    pred_points_action = pred_T_action.transform_points(points_trans_action)

    # Map through T1.inverse(), the same mapping applied to the anchor cloud above.
    all_predicted_points.append(T1.inverse().transform_points(pred_points_action)[0].detach().cpu().numpy())

# flow_fix_all_predicted_points = [points_action[0].detach().cpu().numpy(), 
#                         T1.inverse().transform_points(points_trans_anchor)[0].detach().cpu().numpy()]
# for model_output in flow_fix_model_outputs:
#     x_action = model_output['flow_action']
#     x_anchor = model_output['flow_anchor']
#     goal_emb = model_output['goal_emb']

#     # Get the prediction from the model forward pass
#     points_action = points_action[:, :, :3]
#     points_anchor = points_anchor[:, :, :3]
#     points_trans_action = points_trans_action[:, :, :3]
#     points_trans_anchor = points_trans_anchor[:, :, :3]

#     if "sampled_ixs_action" in model_outputs[0]:
#         ixs_action = model_outputs[0]["sampled_ixs_action"].unsqueeze(-1)
#         sampled_points_action = torch.take_along_dim(
#             points_action, ixs_action, dim=1
#         )
#         sampled_points_trans_action = torch.take_along_dim(
#             points_trans_action, ixs_action, dim=1
#         )
#     else:
#         sampled_points_action = points_action
#         sampled_points_trans_action = points_trans_action

#     if "sampled_ixs_anchor" in model_outputs[0]:
#         ixs_anchor = model_outputs[0]["sampled_ixs_anchor"].unsqueeze(-1)
#         sampled_points_anchor = torch.take_along_dim(
#             points_anchor, ixs_anchor, dim=1
#         )
#         sampled_points_trans_anchor = torch.take_along_dim(
#             points_trans_anchor, ixs_anchor, dim=1
#         )
#     else:
#         sampled_points_anchor = points_anchor
#         sampled_points_trans_anchor = points_trans_anchor

#     pred_flow_action, pred_w_action = model.extract_flow_and_weight(x_action)
#     pred_flow_anchor, pred_w_anchor = model.extract_flow_and_weight(x_anchor)

#     pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, 
#                                 xyz_tgt=sampled_points_trans_anchor,
#                                 flow_src=pred_flow_action, 
#                                 flow_tgt=pred_flow_anchor,
#                                 weights_src=pred_w_action, 
#                                 weights_tgt=pred_w_anchor,
#                                 return_transform3d=True, 
#                                 normalization_scehme="softmax",
#                                 temperature=1)

#     pred_points_action = pred_T_action.transform_points(points_trans_action)

#     flow_fix_all_predicted_points.append(T1.inverse().transform_points(pred_points_action)[0].detach().cpu().numpy())


# Draw every cloud collected in all_predicted_points in one 3D scatter figure,
# one colour per cloud.
plot_multi_np(all_predicted_points)
# plot_multi_np(flow_fix_all_predicted_points)

# predictions = []
# for model_output in model_outputs:
#     prediction = model.predict(model_output, points_trans_action, points_trans_anchor)
#     predictions.append({**prediction, **model_output})
# print(prediction.keys())

# plot_all_predictions(points_onetrans_action, points_onetrans_anchor, predictions, cfg)

In [3]:
import torch
import torch.nn.functional as F

def _row_shift(row, sample_idx, plotting_offset, sample_offset):
    """Return the [x, y, z] shift for visualization row ``row`` of sample ``sample_idx``."""
    return [
        row * p_off + sample_idx * s_off
        for p_off, s_off in zip(plotting_offset, sample_offset)
    ]


def _marker_trace(points, shift, marker, name):
    """Build one Scatter3d marker trace for an (N, 3) array translated by ``shift``."""
    return go.Scatter3d(
        mode="markers",
        marker=marker,
        x=points[:, 0] + shift[0],
        y=points[:, 1] + shift[1],
        z=points[:, 2] + shift[2],
        name=name,
        scene="scene1",
    )


def _weighted_marker(weights, colorscale, size):
    """Marker dict that colors each point by its entry in ``weights``."""
    return {"size": size, "color": weights, "colorscale": colorscale, "line": {"width": 0}}


def _append_row(
    traces,
    shift,
    anchor_points,
    action_points,
    anchor_selected,
    action_selected,
    anchor_marker,
    action_marker,
    selected_marker,
    anchor_name,
    action_name,
    anchor_selected_name="anchor selected point",
    action_selected_name="action selected point",
):
    """Append the four traces of one row: anchor cloud, its selected point, action cloud, its selected point."""
    traces.append(_marker_trace(anchor_points, shift, anchor_marker, anchor_name))
    traces.append(
        _marker_trace(anchor_selected[None, :], shift, dict(selected_marker), anchor_selected_name)
    )
    traces.append(_marker_trace(action_points, shift, action_marker, action_name))
    traces.append(
        _marker_trace(action_selected[None, :], shift, dict(selected_marker), action_selected_name)
    )


def plot_flow_debug(points_action_list, points_anchor_list, predicted_points_list, model_output_list, flow_fix=False):
    """Show a Plotly 3-D figure of per-point model quantities for several samples.

    Each sample is drawn as a stack of "rows" (shifted along +y by multiples of
    ``plotting_offset``); successive samples are shifted along +x by
    ``sample_offset``. Every row shows the anchor cloud and the predicted
    (transformed) action cloud, colored by a softmax over some per-point model
    output, plus the point picked by the ``trans_sample_*`` one-hot vectors.

    Rows:
        0  softmax of ``goal_emb`` (plasma), selected point marked with an "x"
        2  softmax of the last channel of ``dense_trans_pt_*``
        3  (only if ``not flow_fix``) softmax of the second-to-last flow channel
        4  (only if ``not flow_fix``) softmax of the last channel of ``corr_points_*``
        5  (only if ``not flow_fix``) softmax of the last residual-flow channel
        5  plain anchor (red) / action (blue) clouds
        6  black clouds with the selected points highlighted
    Row 1 (flow-head weights) is currently disabled.

    NOTE(review): with ``flow_fix=False`` the residual-flow row and the plain
    clouds both sit at row 5 and therefore overlap, as in the original layout.

    Args:
        points_action_list: per-sample action point tensors; assumed to be
            shaped (1, N_action, 3) since ``shape[1]`` is used as the point
            count and the batch axis is squeezed — TODO confirm.
        points_anchor_list: per-sample anchor point tensors, same assumption.
        predicted_points_list: per-sample predicted action points, used as the
            "transformed action" cloud.
        model_output_list: per-sample dicts of model outputs. Keys read:
            ``trans_sample_action``, ``trans_sample_anchor``, ``goal_emb``,
            ``dense_trans_pt_action``, ``dense_trans_pt_anchor`` and, when
            ``flow_fix`` is False, ``flow_*``, ``corr_points_*`` and
            ``residual_flow_*``.
        flow_fix: when True, skip the three rows that visualize the
            flow / correspondence / residual-flow weights.

    The four lists are iterated with ``zip``, so extra entries in longer lists
    are ignored. Returns ``None``; the figure is displayed via ``fig.show()``.
    If the lists are empty nothing is shown.
    """
    plotting_offset = [0, 0.7, 0]
    sample_offset = [1, 0, 0]

    # Marker styles for the "selected point" overlays.
    cross_marker = {"size": 10, "color": "red", "line": {"width": 1}, "symbol": "x"}
    large_dot_marker = {"size": 10, "color": "red", "line": {"width": 0}}
    small_dot_marker = {"size": 6, "color": "red", "line": {"width": 0}}
    highlight_marker = {"size": 14, "color": "red", "line": {"width": 1}}

    all_traces = []
    last_sample = None  # (sample_idx, action points, anchor points) of the final sample

    samples = zip(points_action_list, points_anchor_list, predicted_points_list, model_output_list)
    for sample_idx, (points_action, points_anchor, predicted_points, model_output) in enumerate(samples):

        def shift(row, _idx=sample_idx):
            return _row_shift(row, _idx, plotting_offset, sample_offset)

        # Point clouds as numpy arrays: action, transformed (predicted) action, anchor.
        points_action_data = np.array(points_action.squeeze(0).detach().cpu())
        points_trans_action_data = np.array(predicted_points.squeeze(0).detach().cpu())
        points_anchor_data = np.array(points_anchor.squeeze(0).detach().cpu())
        last_sample = (sample_idx, points_action_data, points_anchor_data)

        # One-hot vectors for the selected action / anchor points.
        action_onehot = model_output["trans_sample_action"].detach().cpu()
        anchor_onehot = model_output["trans_sample_anchor"].detach().cpu()
        action_sel_idx = int(action_onehot[0].argmax())
        anchor_sel_idx = int(anchor_onehot[0].argmax())

        action_selected_point = points_action_data[action_sel_idx]
        trans_action_selected_point = points_trans_action_data[action_sel_idx]
        anchor_selected_point = points_anchor_data[anchor_sel_idx]

        # goal_emb's last axis is split as [action points | anchor points], so
        # the split index is the number of ACTION points for both halves.
        # (The original sliced the anchor half from points_anchor.shape[1],
        # which is only correct when both clouds have the same size.)
        goal_emb = model_output["goal_emb"]
        num_action = points_action.shape[1]
        goal_emb_norm_action = F.softmax(goal_emb[0, :, :num_action], dim=-1).detach().cpu()
        goal_emb_norm_anchor = F.softmax(goal_emb[0, :, num_action:], dim=-1).detach().cpu()

        traces = []

        # Row 0: distribution over points from the goal embedding.
        _append_row(
            traces, shift(0),
            points_anchor_data, points_trans_action_data,
            anchor_selected_point, trans_action_selected_point,
            anchor_marker=_weighted_marker(goal_emb_norm_anchor[0], "plasma", 6),
            action_marker=_weighted_marker(goal_emb_norm_action[0], "plasma", 6),
            selected_marker=cross_marker,
            anchor_name="goal embedding anchor",
            action_name="goal embedding action",
        )

        # Row 2: last channel of the dense conditioning points.
        dense_pt_action = model_output["dense_trans_pt_action"][:, -1, :]
        dense_pt_anchor = model_output["dense_trans_pt_anchor"][:, -1, :]
        dense_pt_action_norm = F.softmax(dense_pt_action, dim=-1).detach().cpu()
        dense_pt_anchor_norm = F.softmax(dense_pt_anchor, dim=-1).detach().cpu()
        _append_row(
            traces, shift(2),
            points_anchor_data, points_trans_action_data,
            anchor_selected_point, trans_action_selected_point,
            anchor_marker=_weighted_marker(dense_pt_anchor_norm[0], "viridis", 6),
            action_marker=_weighted_marker(dense_pt_action_norm[0], "viridis", 6),
            selected_marker=large_dot_marker,
            anchor_name="conditioning value anchor",
            action_name="conditioning value action",
        )

        if not flow_fix:
            # Row 3: second-to-last flow channel. flow_* is permuted to
            # (batch, channel, point) first; the original referenced these
            # tensors without defining them, raising NameError.
            flow_action = model_output["flow_action"].permute(0, 2, 1)
            flow_anchor = model_output["flow_anchor"].permute(0, 2, 1)
            used_weights_action_norm = F.softmax(flow_action[:, -2, :], dim=-1).detach().cpu()
            used_weights_anchor_norm = F.softmax(flow_anchor[:, -2, :], dim=-1).detach().cpu()
            _append_row(
                traces, shift(3),
                points_anchor_data, points_trans_action_data,
                anchor_selected_point, trans_action_selected_point,
                anchor_marker=_weighted_marker(used_weights_anchor_norm[0], "viridis", 4),
                action_marker=_weighted_marker(used_weights_action_norm[0], "viridis", 4),
                selected_marker=small_dot_marker,
                anchor_name="4th dimension weights anchor",
                action_name="4th dimension weights action",
            )

            # Row 4: last channel of the correspondence points.
            weight_sum_action = model_output["corr_points_action"][:, -1, :]
            weight_sum_anchor = model_output["corr_points_anchor"][:, -1, :]
            weight_sum_action_norm = F.softmax(weight_sum_action, dim=-1).detach().cpu()
            weight_sum_anchor_norm = F.softmax(weight_sum_anchor, dim=-1).detach().cpu()
            _append_row(
                traces, shift(4),
                points_anchor_data, points_trans_action_data,
                anchor_selected_point, trans_action_selected_point,
                anchor_marker=_weighted_marker(weight_sum_anchor_norm[0], "viridis", 4),
                action_marker=_weighted_marker(weight_sum_action_norm[0], "viridis", 4),
                selected_marker=small_dot_marker,
                anchor_name="",
                action_name="transformed action",
            )

            # Row 5: last channel of the residual flow.
            residual_flow_action = model_output["residual_flow_action"].permute(0, 2, 1)[:, -1, :]
            residual_flow_anchor = model_output["residual_flow_anchor"].permute(0, 2, 1)[:, -1, :]
            residual_flow_action_norm = F.softmax(residual_flow_action, dim=-1).detach().cpu()
            residual_flow_anchor_norm = F.softmax(residual_flow_anchor, dim=-1).detach().cpu()
            _append_row(
                traces, shift(5),
                points_anchor_data, points_trans_action_data,
                anchor_selected_point, trans_action_selected_point,
                anchor_marker=_weighted_marker(residual_flow_anchor_norm[0], "viridis", 4),
                action_marker=_weighted_marker(residual_flow_action_norm[0], "viridis", 4),
                selected_marker=small_dot_marker,
                anchor_name="",
                action_name="transformed action",
            )

        # Row 5: plain anchor (red) and transformed action (blue) clouds.
        traces.append(
            _marker_trace(
                points_anchor_data, shift(5),
                {"size": 6, "color": 'red', "line": {"width": 0}},
                "anchor",
            )
        )
        traces.append(
            _marker_trace(
                points_trans_action_data, shift(5),
                {"size": 6, "color": 'blue', "line": {"width": 0}},
                "action",
            )
        )

        # Row 6: black clouds with the selected points highlighted.
        _append_row(
            traces, shift(6),
            points_anchor_data, points_trans_action_data,
            anchor_selected_point, trans_action_selected_point,
            anchor_marker={"size": 6, "color": 'black', "line": {"width": 0}},
            action_marker={"size": 6, "color": 'black', "line": {"width": 0}},
            selected_marker=highlight_marker,
            anchor_name="anchor selected vis",
            action_name="action selected vis",
            anchor_selected_name="anchor selected vis point",
            action_selected_name="action selected vis point",
        )

        all_traces.extend(traces)

    if last_sample is None:
        # No samples were supplied; there is nothing to size the scene from.
        return

    fig = go.Figure()
    fig.add_traces(all_traces)

    # Size a cubic scene around the last sample's clouds plus the anchor cloud
    # shifted to the farthest row/sample, so all rows and samples stay in view.
    last_idx, last_action_data, last_anchor_data = last_sample
    all_data = np.concatenate(
        [
            last_action_data,
            last_anchor_data,
            last_anchor_data + np.array(plotting_offset)*6 + np.array(sample_offset)*last_idx,
        ],
        axis=0,
    )
    all_data_mean = np.mean(all_data, axis=0)
    all_max = max(np.abs(all_data[:, axis] - all_data_mean[axis]).max() for axis in range(3))
    scene1 = dict(
        xaxis=dict(nticks=10, range=[all_data_mean[0] - all_max, all_data_mean[0] + all_max]),
        yaxis=dict(nticks=10, range=[all_data_mean[1] - all_max, all_data_mean[1] + all_max]),
        zaxis=dict(nticks=10, range=[all_data_mean[2] - all_max, all_data_mean[2] + all_max]),
        aspectratio=dict(x=1, y=1, z=1),
    )
    fig.update_layout(
        scene1=scene1,
        showlegend=True,
        margin=dict(l=0, r=0, b=0, t=40),
        legend=dict(x=1.0, y=1),
        width=2480,
        height=1440,
        template="simple_white",
    )

    fig.show()
    
        
# Build the per-sample inputs for plot_flow_debug: for every sampled model
# output, recover the predicted action pose from the predicted flows and map
# the predicted action points back through T1's inverse.

# Keep only the first three (xyz) channels. This is loop-invariant, so do it once.
points_action = points_action[:, :, :3]
points_anchor = points_anchor[:, :, :3]
points_trans_action = points_trans_action[:, :, :3]
points_trans_anchor = points_trans_anchor[:, :, :3]

points_action_list = []
points_anchor_list = []
pred_points_list = []
model_output_list = []
for model_output in model_outputs:
    print(model_output.keys())

    x_action = model_output['flow_action']
    x_anchor = model_output['flow_anchor']
    goal_emb = model_output['goal_emb']

    # If the model subsampled its inputs, gather the same points so they line
    # up with the predicted flows. Use THIS output's indices: the original read
    # model_outputs[0] on every iteration, which pairs later outputs' flows
    # with the first output's sampled points.
    if "sampled_ixs_action" in model_output:
        ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1)
        sampled_points_action = torch.take_along_dim(
            points_action, ixs_action, dim=1
        )
        sampled_points_trans_action = torch.take_along_dim(
            points_trans_action, ixs_action, dim=1
        )
    else:
        sampled_points_action = points_action
        sampled_points_trans_action = points_trans_action

    if "sampled_ixs_anchor" in model_output:
        ixs_anchor = model_output["sampled_ixs_anchor"].unsqueeze(-1)
        sampled_points_anchor = torch.take_along_dim(
            points_anchor, ixs_anchor, dim=1
        )
        sampled_points_trans_anchor = torch.take_along_dim(
            points_trans_anchor, ixs_anchor, dim=1
        )
    else:
        sampled_points_anchor = points_anchor
        sampled_points_trans_anchor = points_trans_anchor

    pred_flow_action, pred_w_action = model.extract_flow_and_weight(x_action)
    pred_flow_anchor, pred_w_anchor = model.extract_flow_and_weight(x_anchor)

    # NOTE(review): the keyword is spelled "normalization_scehme" here;
    # presumably that matches dualflow2pose's signature — verify before renaming.
    pred_T_action = dualflow2pose(
        xyz_src=sampled_points_trans_action,
        xyz_tgt=sampled_points_trans_anchor,
        flow_src=pred_flow_action,
        flow_tgt=pred_flow_anchor,
        weights_src=pred_w_action,
        weights_tgt=pred_w_anchor,
        return_transform3d=True,
        normalization_scehme="softmax",
        temperature=1,
    )

    # Apply the predicted pose to the full transformed action cloud, then map
    # the result back through the inverse of T1 and keep the first batch element.
    pred_points_action = pred_T_action.transform_points(points_trans_action)
    pred_points = T1.inverse().transform_points(pred_points_action)[0].detach()

    points_action_list.append(points_action)
    points_anchor_list.append(points_anchor)
    pred_points_list.append(pred_points)
    model_output_list.append(model_output)

plot_flow_debug(points_action_list, points_anchor_list, pred_points_list, model_output_list, flow_fix=True)
    
# Disabled scratch block: the body is a no-op (`pass`). The commented-out code
# that follows repeats the prediction pipeline for `flow_fix_model_outputs[0]`.
# NOTE: its final call passes single tensors/dicts to plot_flow_debug, which now
# takes lists, so it would need wrapping in lists before being re-enabled.
if True:
    pass
    # model_output = flow_fix_model_outputs[0]
    # print(model_output.keys())

    # x_action = model_output['flow_action']
    # x_anchor = model_output['flow_anchor']
    # goal_emb = model_output['goal_emb']

    # # Get the prediction from the model forward pass
    # points_action = points_action[:, :, :3]
    # points_anchor = points_anchor[:, :, :3]
    # points_trans_action = points_trans_action[:, :, :3]
    # points_trans_anchor = points_trans_anchor[:, :, :3]

    # if "sampled_ixs_action" in model_outputs[0]:
    #     ixs_action = model_outputs[0]["sampled_ixs_action"].unsqueeze(-1)
    #     sampled_points_action = torch.take_along_dim(
    #         points_action, ixs_action, dim=1
    #     )
    #     sampled_points_trans_action = torch.take_along_dim(
    #         points_trans_action, ixs_action, dim=1
    #     )
    # else:
    #     sampled_points_action = points_action
    #     sampled_points_trans_action = points_trans_action

    # if "sampled_ixs_anchor" in model_outputs[0]:
    #     ixs_anchor = model_outputs[0]["sampled_ixs_anchor"].unsqueeze(-1)
    #     sampled_points_anchor = torch.take_along_dim(
    #         points_anchor, ixs_anchor, dim=1
    #     )
    #     sampled_points_trans_anchor = torch.take_along_dim(
    #         points_trans_anchor, ixs_anchor, dim=1
    #     )
    # else:
    #     sampled_points_anchor = points_anchor
    #     sampled_points_trans_anchor = points_trans_anchor

    # pred_flow_action, pred_w_action = model.extract_flow_and_weight(x_action)
    # pred_flow_anchor, pred_w_anchor = model.extract_flow_and_weight(x_anchor)

    # pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, 
    #                             xyz_tgt=sampled_points_trans_anchor,
    #                             flow_src=pred_flow_action, 
    #                             flow_tgt=pred_flow_anchor,
    #                             weights_src=pred_w_action, 
    #                             weights_tgt=pred_w_anchor,
    #                             return_transform3d=True, 
    #                             normalization_scehme="softmax",
    #                             temperature=1)

    # pred_points_action = pred_T_action.transform_points(points_trans_action)
    # pred_points = T1.inverse().transform_points(pred_points_action)[0].detach()

    # plot_flow_debug(points_action, points_anchor, pred_points, model_output, flow_fix=True)    


dict_keys(['flow_action', 'residual_flow_action', 'corr_flow_action', 'corr_points_action', 'scores_action', 'flow_anchor', 'residual_flow_anchor', 'corr_flow_anchor', 'corr_points_anchor', 'scores_anchor', 'goal_emb', 'dense_trans_pt_action', 'dense_trans_pt_anchor', 'trans_pt_action', 'trans_pt_anchor', 'trans_sample_action', 'trans_sample_anchor', 'action_points_and_cond', 'anchor_points_and_cond'])
dict_keys(['flow_action', 'residual_flow_action', 'corr_flow_action', 'corr_points_action', 'scores_action', 'flow_anchor', 'residual_flow_anchor', 'corr_flow_anchor', 'corr_points_anchor', 'scores_anchor', 'goal_emb', 'dense_trans_pt_action', 'dense_trans_pt_anchor', 'trans_pt_action', 'trans_pt_anchor', 'trans_sample_action', 'trans_sample_anchor', 'action_points_and_cond', 'anchor_points_and_cond'])
