In [None]:
import matplotlib.pyplot as plt
import numpy as np

from frechet_fda.data_generation_tools import (
    gen_params_scenario_one,
    make_truncnorm_pdf,
)
from frechet_fda.distribution_tools import (
    frechet_mean,
    get_optimal_range,
    inverse_log_qd_transform,
    log_qd_transform,
    make_distribution_objects,
    mean_func,
)
from frechet_fda.fda_tools import (
    compute_centered_data,
    compute_cov_function,
    compute_principal_components,
    gen_qdtransformation_pcs,
    k_optimal,
    karhunen_loeve,
    mode_of_variation,
    total_frechet_variance,
)

In [None]:
# Settings for the exploratory part: sample size, number of grid points per
# density, and the (symmetric) truncation point of the support
n, grid_size, trunc = 200, 10000, 3
# Draw one (mu, sigma) pair per density
mus, sigmas = gen_params_scenario_one(n)
# sigmas are sorted in place (mus stay in their drawn order): summing
# Distribution instances goes wrong when the sigmas are left unsorted
sigmas.sort()

In [None]:
# First pass: evaluate every truncated normal pdf on the common support
# [-trunc, trunc] and turn the results into Distribution class objects
pdfs = make_truncnorm_pdf(-trunc, trunc, mus, sigmas, grid_size=grid_size)
my_pdfs = make_distribution_objects(pdfs)
my_cdfs = [density.integrate() for density in my_pdfs]
my_qfs = [dist_func.invert() for dist_func in my_cdfs]
my_qdfs = [quantile_func.differentiate() for quantile_func in my_qfs]
# Numerical correction: for smaller sigmas the range is shortened, which gets rid
# of numerical artifacts when computing integrals, derivatives and means later
new_ranges = get_optimal_range(my_pdfs)
# Second pass: regenerate each pdf separately on its own corrected range. Each
# call receives a single (mu, sigma) pair, and [0] picks the first returned entry.
pdfs2 = [
    make_truncnorm_pdf(
        new_ranges[idx][0],
        new_ranges[idx][1],
        mus[idx],
        sigmas[idx],
        grid_size=grid_size,
    )[0]
    for idx in range(n)
]

In [None]:
# Build the chain of distribution objects from the range-corrected pdfs:
# pdf -> cdf (integrate) -> quantile function (invert) -> quantile density
# (differentiate)
new_pdfs = make_distribution_objects(pdfs2)
new_cdfs = [density.integrate() for density in new_pdfs]
new_qfs = [dist_func.invert() for dist_func in new_cdfs]
new_qdfs = [quantile_func.differentiate() for quantile_func in new_qfs]

In [None]:
# Center the densities around their mean, just to see whether it works.
# The comparison below shows clearly how inappropriate it is to apply fda
# methods directly on densities.
mean_pdf, centered_pdfs = compute_centered_data(new_pdfs)
# Compare the centered and the original density with the smallest sigma
narrowest = sigmas.argmin()
centered_pdfs[narrowest].compare(new_pdfs[narrowest])

In [None]:
# Compute the covariance function from the centered densities; it is passed to
# compute_principal_components in the next step
covariance_function = compute_cov_function(centered_pdfs)

In [None]:
# Principal components of the covariance function, computed on the grid
# (x values) of the first centered density
support_grid = centered_pdfs[0].x
eigenvalues, eigenfunctions = compute_principal_components(
    support_grid, covariance_function
)

## Transformation FPCA

In [None]:
# Round trip: apply the log quantile density transform to the pdf sample and map
# the result back, to test whether the inverse works
log_qdfs = log_qd_transform(new_pdfs)
inverse_log_qdfs = inverse_log_qd_transform(log_qdfs)
# The original density is shifted by 0.05 before comparing -- presumably so that
# the two curves do not lie exactly on top of each other in the plot
shifted_reference = new_pdfs[0] + 0.05
inverse_log_qdfs[0].compare(shifted_reference)

In [None]:
# Transformation FPCA on the log quantile density transforms. The helper returns
# four objects, unpacked in this order: mean function, eigenvalues,
# eigenfunctions and FPC scores.
transformation_pcs = gen_qdtransformation_pcs(log_qdfs)
(
    mean_log_qdfs,
    eigenvalues_log_qdfs,
    eigenfunctions_log_qdfs,
    fpc_scores_log_qdfs,
) = transformation_pcs

In [None]:
# Karhunen-Loève decomposition of the transforms, called with K=1
# (presumably: keep only the first principal component)
n_components = 1
truncated_representations_transforms = karhunen_loeve(
    mean_log_qdfs,
    eigenfunctions_log_qdfs,
    fpc_scores_log_qdfs,
    K=n_components,
)

In [None]:
# Map the truncated representations back to density space and compare the first
# one with the density it approximates
truncated_representations = inverse_log_qd_transform(
    truncated_representations_transforms
)
first_truncated, first_original = truncated_representations[0], new_pdfs[0]
first_truncated.compare(first_original)

In [None]:
# Modes of variation of the transformed functions: one mode per
# (eigenvalue, eigenfunction) pair, all computed with the same alpha
alpha = 5e-3
variation_modes_transforms = [
    mode_of_variation(mean_log_qdfs, eigenvalue, eigenfunction, alpha=alpha)
    for eigenvalue, eigenfunction in zip(
        eigenvalues_log_qdfs, eigenfunctions_log_qdfs
    )
]
# Compare the first two modes
variation_modes_transforms[0].compare(variation_modes_transforms[1])

In [None]:
# Map the modes of variation back to density space and compare the first two
variation_modes = inverse_log_qd_transform(variation_modes_transforms)
first_mode, second_mode = variation_modes[0], variation_modes[1]
first_mode.compare(second_mode)

In [None]:
# Compute Fréchet mean of the sample of densities; it is reused below as the
# first argument of total_frechet_variance
f_mean = frechet_mean(new_pdfs)

In [None]:
# Compute Fréchet variance of the sample, using the Fréchet mean computed above;
# the result is handed to k_optimal below
total_variance = total_frechet_variance(f_mean, new_pdfs)

In [None]:
# Try the function that finds the optimal truncated representation. The first
# argument (0.5) is presumably the targeted fraction of variance explained.
# NOTE: this rebinds `truncated_representations` from the K=1 example above.
k_optimal_result = k_optimal(
    0.5,
    total_variance,
    new_pdfs,
    mean_log_qdfs,
    eigenfunctions_log_qdfs,
    fpc_scores_log_qdfs,
)
optimal_k, fraction_explained, truncated_representations = k_optimal_result
# Show the chosen K together with the fraction it explains
optimal_k, fraction_explained

# Simulation

In [None]:
# Simulation design: m replications for each of the sample sizes
m = 3
sample_sizes = [50, 100, 200]
# Object arrays with one row per replication and one column per sample size;
# they hold the Fréchet means and the cross sectional means of every run
storage_shape = (m, len(sample_sizes))
stored_f_means = np.empty(storage_shape, dtype="object")
stored_cs_means = np.empty(storage_shape, dtype="object")

In [None]:
# Monte Carlo simulation: for each of the m replications and each sample size,
# generate a sample of truncated normal pdfs, compute the transformation FPCA
# objects, and store the Fréchet mean and the cross sectional mean of the sample.
#
# NOTE: the loop rebinds the notebook-level names n, mus, sigmas, new_pdfs, ...
# After it has finished they refer to the last simulated sample.

# Grid size and truncation point are the same for every run, so set them once
grid_size = 10000
trunc = 3

print("Simulating...", end="\r")
for i in range(m):
    # Try different sample sizes
    for j, n in enumerate(sample_sizes):
        # j index only used for storing the means below!
        # Seed per (replication, sample size) pair, built by concatenating the
        # digits of i and n
        seed_num = int(str(i) + str(n))
        mus, sigmas = gen_params_scenario_one(n, seed=seed_num)
        # Sort sigmas, because when summing Distribution instances something
        # goes wrong otherwise
        sigmas.sort()

        # Generate pdfs within truncation points
        pdfs = make_truncnorm_pdf(-trunc, trunc, mus, sigmas, grid_size=grid_size)
        # Make Distribution class objects
        my_pdfs = make_distribution_objects(pdfs)
        my_cdfs = [pdf.integrate() for pdf in my_pdfs]
        my_qfs = [cdf.invert() for cdf in my_cdfs]
        my_qdfs = [qf.differentiate() for qf in my_qfs]
        # For numerical correction: shorten the range for smaller sigmas to get
        # rid of numerical artifacts when computing integrals, derivatives and
        # means later
        new_ranges = get_optimal_range(my_pdfs)
        # Generate pdfs again, this time within individual ranges.
        # The comprehension index is called k so that it cannot be confused with
        # the replication index i of the outer loop.
        pdfs2 = [
            make_truncnorm_pdf(
                new_ranges[k][0],
                new_ranges[k][1],
                mus[k],
                sigmas[k],
                grid_size=grid_size,
            )[0]
            for k in range(n)
        ]

        # Generate all the distribution objects
        new_pdfs = make_distribution_objects(pdfs2)
        new_cdfs = [pdf.integrate() for pdf in new_pdfs]
        new_qfs = [cdf.invert() for cdf in new_cdfs]
        new_qdfs = [qf.differentiate() for qf in new_qfs]

        # Transform pdf sample
        log_qdfs = log_qd_transform(new_pdfs)

        # Compute transformation FPCA objects
        (
            mean_log_qdfs,
            eigenvalues_log_qdfs,
            eigenfunctions_log_qdfs,
            fpc_scores_log_qdfs,
        ) = gen_qdtransformation_pcs(log_qdfs)

        # Compute and store Fréchet mean and cross sectional mean of this sample
        stored_f_means[i, j] = frechet_mean(new_pdfs)
        stored_cs_means[i, j] = mean_func(new_pdfs)

        # Compute Fréchet variance around the Fréchet mean of THIS sample.
        # (Using the notebook-level `f_mean` here would silently plug in the mean
        # of the exploratory sample from the first part of the notebook.)
        total_variance = total_frechet_variance(stored_f_means[i, j], new_pdfs)
        # Try function that finds optimal trunc representation
        optimal_k, fraction_explained, truncated_representations = k_optimal(
            0.5,
            total_variance,
            new_pdfs,
            mean_log_qdfs,
            eigenfunctions_log_qdfs,
            fpc_scores_log_qdfs,
        )
    # Progress report after each finished replication
    perc = int(100 * (i + 1) / m)
    print(f"Simulating...{perc}%", end="\r")

In [None]:
# Calculate true center of distribution for plotting against estimates:
# a truncated normal pdf with mu = 0 and sigma = 1 on [-trunc, trunc].
# The result of make_truncnorm_pdf is indexed with [0] (presumably the function
# returns one entry per parameter pair -- TODO confirm). Entries [0] and [1] of
# std_normal are used as x and y values in the plots below.
std_normal = make_truncnorm_pdf(-trunc, trunc, 0, 1, grid_size=grid_size)[0]

In [None]:
# Aggregate over all simulation runs. Columns 0, 1, 2 of the storage arrays hold
# the results for the sample sizes 50, 100, 200.
# Fréchet means of the stored Fréchet means:
mean_of_f_means50, mean_of_f_means100, mean_of_f_means200 = (
    frechet_mean(stored_f_means[:, col]) for col in range(3)
)
# The stored cross sectional means are aggregated with frechet_mean as well:
mean_of_cs_means50, mean_of_cs_means100, mean_of_cs_means200 = (
    frechet_mean(stored_cs_means[:, col]) for col in range(3)
)

In [None]:
# Plot the aggregated Fréchet means for each sample size (dashed) against the
# true center (solid black)
fig, ax = plt.subplots()
for f_mean_curve, curve_label in (
    (mean_of_f_means50, "f-mean 50"),
    (mean_of_f_means100, "f-mean 100"),
    (mean_of_f_means200, "f-mean 200"),
):
    ax.plot(f_mean_curve.x, f_mean_curve.y, label=curve_label, linestyle="--")
true_x, true_y = std_normal[0], std_normal[1]
ax.plot(true_x, true_y, label="Standard Normal", color="Black")
plt.legend()
plt.show()

In [None]:
# Plot the Fréchet mean and the cross sectional mean for sample size 200
# (both dashed) against the true center (solid black)
fig, ax = plt.subplots()
dashed = {"linestyle": "--"}
ax.plot(mean_of_f_means200.x, mean_of_f_means200.y, label="f-mean 200", **dashed)
ax.plot(mean_of_cs_means200.x, mean_of_cs_means200.y, label="cs-mean 200", **dashed)
true_x, true_y = std_normal[0], std_normal[1]
ax.plot(true_x, true_y, label="Standard Normal", color="black")
plt.legend()
plt.show()