# Visualize fitted beta distributions

In [None]:
import os
import plotly.graph_objects as go
import torch

from semalign3d.core import data_classes
from semalign3d.core.data import (
    setup_paths,
    raw_data_utils,
    sem3d_data_utils,
	keypoint_processing,
    augmentations
)
from semalign3d.core.geometry import geom_filter, geom_features

In [None]:
# Locate the run configuration next to the scripts directory (relative to the
# notebook's working directory) and derive the dataset/output paths from it.
run_config_path = os.getcwd() + "/../scripts/run_config.yaml"
paths = setup_paths.setup_paths(run_config_path)

# All category names registered for the configured dataset.
all_categories = raw_data_utils.DATASET_NAME_TO_CATEGORIES[paths.dataset_name]

In [None]:
category = "aeroplane"
use_vggt = False

# Load the data for `category`. The flags below switch off loading of the
# geometry statistics and of the sparse/dense point clouds; those are not
# needed for this visualization (the stats are loaded explicitly later).
sem_align_3d_data = sem3d_data_utils.load_data(
    paths=paths,
    category=category,
    use_vggt=use_vggt,
    load_geom_stats=False,
    do_load_sparse_pc=False,
    do_load_dense_pc=False,
)

In [None]:
# Ground-truth keypoints of the training split, passed through
# correct_gt_kpts_xyz.
gt_kpts_data = keypoint_processing.correct_gt_kpts_xyz(
    sem_align_3d_data.full_data.processed_data_train
)

# Noisy augmentation of the keypoints. These settings are the ones that were
# used when fitting the beta distributions.
# NOTE(review): no RNG seed is set here; if generate_noisy_gt_kpts_data is
# random, these samples differ from the ones seen during fitting, and the
# stored per-sample mask used later may not line up with them — confirm.
n_samples_per_img, noise_rate = 10, 0.2
gt_kpts_data_aug = augmentations.generate_noisy_gt_kpts_data(
    gt_kpts_data,
    noise_rate=noise_rate,
    n_samples_per_img=n_samples_per_img,
)

In [None]:
# Load the precomputed geometry statistics for `category` (these hold the
# fitted beta parameters read further below) and the relation combinations
# they were computed over.
load_geom_stats_suffix = ""
# weights_only=False lets torch.load unpickle arbitrary Python objects
# (presumably required because these files hold project objects rather than
# plain tensors); only load files from a trusted source.
geom_stats_partial = torch.load(
    f"{paths.geom_stats_dir}/{category}_geom_stats{load_geom_stats_suffix}{paths.suffix}.pt",
    weights_only=False,
)
geom_relation_combinations_partial = torch.load(
    f"{paths.geom_stats_dir}/{category}_geom_combinations{load_geom_stats_suffix}{paths.suffix}.pt",
    weights_only=False,
)

# Keep only the entries that geom_filter.filter_valid_stats accepts. The
# filtered copies are only inspected here; the later cells index the
# unfiltered objects.
(
    geom_stats_partial_filtered,
    geom_relation_combinations_partial_filtered,
) = geom_filter.filter_valid_stats(
    geom_stats_partial,
    geom_relation_combinations_partial,
)
# Displayed as the cell output: shape of the filtered e1 entries.
geom_relation_combinations_partial_filtered.e1.shape

In [None]:
geom_feature_idx = 400

# One-element batch holding only the edge pair at geom_feature_idx. The index
# refers to the *unfiltered* combinations so that it lines up with the
# unfiltered geom_stats_partial indexed in the plotting code.
# NOTE(review): t and v are taken from entry 0 rather than geom_feature_idx;
# presumably they are placeholders because only edge_pair_ratios is used —
# confirm.
geom_relation_combinations_batch = data_classes.GeomRelationCombinations(
    t=geom_relation_combinations_partial.t[:1],
    v=geom_relation_combinations_partial.v[:1],
    e1=geom_relation_combinations_partial.e1[geom_feature_idx : geom_feature_idx + 1],
    e2=geom_relation_combinations_partial.e2[geom_feature_idx : geom_feature_idx + 1],
)

# Geometric features of that single combination, evaluated on the augmented
# keypoint coordinates.
geom_fts = geom_features.compute_geom_features_batched(
    geom_relation_combinations=geom_relation_combinations_batch,
    vertices=gt_kpts_data_aug.xyz,
)

from scipy.stats import beta
def create_beta_dist_line(a, b, n_points=100):
    """Sample the Beta(a, b) probability density on a uniform grid over [0, 1].

    Args:
        a: First shape parameter (alpha). May be a Python number, a NumPy
            scalar, or a single-element torch tensor (CPU or CUDA, with or
            without ``requires_grad``).
        b: Second shape parameter (beta); same accepted types as ``a``.
        n_points: Number of grid points, both endpoints included.
            Defaults to 100.

    Returns:
        A tuple ``(x, y)``: ``x`` is a torch tensor of ``n_points`` evenly
        spaced locations from 0 to 1, and ``y`` is a NumPy array with the
        Beta(a, b) density at each location. The density may be infinite at
        0 when ``a < 1`` and at 1 when ``b < 1``.
    """
    x = torch.linspace(0, 1, n_points)
    # Convert the shape parameters with float(), which goes through
    # Tensor.item(). Handing tensors straight to SciPy makes it call
    # np.asarray on them, and that fails for CUDA tensors and for tensors
    # that require grad.
    y = beta.pdf(x.numpy(), a=float(a), b=float(b))
    return x, y

# Per-sample mask stored with the fitted statistics; column geom_feature_idx
# is used below as a boolean row index into the computed features.
fts_m = geom_stats_partial.m_edge_pair_ratios
print("n samples:", int(torch.sum(fts_m[:, geom_feature_idx])))

# Feature values of the masked samples (column 0 is the batch's single
# combination).
fts = geom_fts.edge_pair_ratios[fts_m[:, geom_feature_idx], 0]

# Density curve of the fitted Beta(alpha, beta) for this feature.
x, y = create_beta_dist_line(
    geom_stats_partial.edge_pair_ratios_alpha[geom_feature_idx],
    geom_stats_partial.edge_pair_ratios_beta[geom_feature_idx],
)

# Overlay the histogram, normalised as a probability density, with the fitted
# pdf so both share the same y scale.
go.Figure(
    data=[
        go.Histogram(
            x=fts,
            nbinsx=50,
            name="Edge Pair Ratios",
            histnorm="probability density",
        ),
        go.Scatter(
            x=x,
            y=y,
            mode="lines",
            name="Fitted Beta Distribution",
            line={"color": "red"},
        ),
    ]
)