#### Create synthetic plots from Canonical Correspondence Analysis (CCA) results

This notebook goes through the steps needed to create synthetic plots from CCA results using sample data from the `sknnr` package.  We demonstrate the use of different synthetic plot network types including:
- `ReferenceNetwork`: a point network where synthetic points are at the same locations as the original points
- `FuzzedNetwork`: a point network where synthetic points are randomly fuzzed a certain distance from the original points in each CCA dimension
- `QuantileMesh`: a mesh network where synthetic points are placed at the quantiles of the CCA scores in each dimension
- `EqualIntervalMesh`: a mesh network where synthetic points are placed at equal intervals of the CCA scores in each dimension

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sknnr.datasets import load_moscow_stjoes
from sknnr.transformers import CCATransformer

from synthetic_knn.synthetic_plots import SyntheticPlots
from synthetic_knn.point_networks import (
    ReferenceNetwork,
    FuzzedNetwork,
    QuantileMesh,
    EqualIntervalMesh
)

Load in the sample Moscow Mountain / St. Joes data from `sknnr` as separate `X` and `y` dataframes.

In [None]:
X, y = load_moscow_stjoes(return_X_y=True, as_frame=True)

Use the `CCATransformer` to transform these data into a 5-dimensional CCA space, and use the transformed `X` scores as the reference coordinates in that space.

In [None]:
estimator = CCATransformer(n_components=5).fit(X, y)
X_transformed = estimator.transform(X)

Create a network of synthetic plots, using the reference coordinates to define the CCA space.  Start with a `ReferenceNetwork`, in which the synthetic points are at the same locations as the original points.  The `network` argument must be an instance of a class that inherits from the `PointNetwork` superclass, and the `k` argument specifies how many (reference) neighbors to return for each synthetic plot.

In [None]:
reference_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=ReferenceNetwork(),
    k=10
)

Get the _k_ neighbors and distances associated with each synthetic plot.

In [None]:
d, n_idx = reference_network_plots.distances(), reference_network_plots.neighbors()
print(d[:5, :5])
print(n_idx[:5, :5])
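Conceptually, `distances()` and `neighbors()` return, for each synthetic plot, the _k_ closest reference plots sorted by distance. The sketch below reproduces that idea with brute-force NumPy on toy data, assuming Euclidean distance in the CCA space. The helper `brute_force_knn` is hypothetical, and `SyntheticPlots` may use a different metric or search strategy internally.

```python
import numpy as np


def brute_force_knn(synthetic, reference, k):
    """Hypothetical helper: return (distances, indices) of the k nearest
    reference points for each synthetic point, assuming Euclidean distance."""
    # Pairwise differences via broadcasting: (n_synthetic, n_reference, n_dims)
    diff = synthetic[:, np.newaxis, :] - reference[np.newaxis, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Stable sort so that tied distances keep the lower reference index first
    idx = np.argsort(dist, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(dist, idx, axis=1), idx


# Toy reference coordinates: 6 plots in a 2-dimensional "CCA" space
toy_rng = np.random.default_rng(0)
toy_reference = toy_rng.normal(size=(6, 2))

# A reference-style network reuses the reference coordinates as synthetic points
toy_d, toy_idx = brute_force_knn(toy_reference, toy_reference, k=3)
print(toy_d[:, 0])    # first-neighbor distances are all zero
print(toy_idx[:, 0])  # each point's first neighbor is itself
```

The toy variable names avoid overwriting `d` and `n_idx` from the cell above if this sketch is run inside the notebook.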

We can also crosswalk these neighbor indices back to the corresponding plot IDs.

In [None]:
n = reference_network_plots.neighbors(id_arr=np.array(y.index))
print(n[:5, :5])
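Translating positional neighbor indices into plot IDs can be done with NumPy fancy indexing: indexing an ID array with the integer index matrix returns an ID matrix of the same shape. This is a sketch of the idea with made-up IDs, assuming the ID array is in the same row order as the reference coordinates; it is not necessarily how `neighbors(id_arr=...)` is implemented.

```python
import numpy as np

# Hypothetical plot IDs, in the same row order as the reference coordinates
plot_ids = np.array([101, 205, 307, 412])

# Positional neighbor indices (rows: synthetic plots, columns: neighbor rank)
toy_n_idx = np.array([[0, 2, 1],
                      [1, 3, 0],
                      [2, 0, 3]])

# Fancy indexing maps every position to its ID while keeping the (rows, k) shape
toy_n_ids = plot_ids[toy_n_idx]
print(toy_n_ids)
```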

Notice that the first-neighbor distances are zero for the first five records.  For the `ReferenceNetwork`, the first neighbor is always the original point itself.  Verify that the first-neighbor distance is always zero and that the first neighbor ID matches the original point ID.

In [None]:
print("All zero distances:", np.all(d[:, 0] == 0))
print("All neighbors match:", np.all(n[:, 0] == y.index))

We can also retrieve synthetic plot neighbors and distances as pandas `DataFrame`s, which can be written out to CSV files.  Note that the synthetic plots receive dummy IDs, numbered sequentially from 1 to the number of synthetic plots.  The index of the dataframe is named `SYNTHETIC_PLOT_ID`.

In [None]:
distances_df = reference_network_plots.distances(as_frame=True)
neighbors_df = reference_network_plots.neighbors(id_arr=np.array(y.index), as_frame=True)
print(distances_df.head())
print(neighbors_df.head())

from pathlib import Path

# Make sure the output directory exists before writing (file names reflect k=10)
output_dir = Path("../data/networks")
output_dir.mkdir(parents=True, exist_ok=True)

distances_df.to_csv(output_dir / "reference_k10_distances.csv", float_format="%.6f")
neighbors_df.to_csv(output_dir / "reference_k10_neighbors.csv")

If desired, we can also capture the synthetic plot coordinates in a `DataFrame` and write them out to a CSV file.  Because this network may be used in a second-stage imputation, it is useful to retain these coordinates so that they can be used to train a `sklearn.neighbors.NearestNeighbors` estimator in the second stage.

In [None]:
coordinates_df = reference_network_plots.synthetic_coordinates(as_frame=True, prefix="CCA")
print(coordinates_df.head())
coordinates_df.to_csv(output_dir / "reference_k10_coordinates.csv", float_format="%.6f")
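The table layout described earlier (a `SYNTHETIC_PLOT_ID` index numbered from 1 and one column per CCA axis) can be mimicked with plain `pandas`, for example when preparing a compatible table by hand. This is a sketch on toy values: the `CCA1`, `CCA2`, ... column names are an assumption about how the `prefix` argument is applied, not confirmed library behavior.

```python
import numpy as np
import pandas as pd

# Toy synthetic coordinates: 3 synthetic plots in a 2-dimensional CCA space
toy_coords = np.array([[0.12, -0.40],
                       [0.55, 0.31],
                       [-0.27, 0.08]])

# Dummy synthetic plot IDs numbered sequentially from 1
index = pd.RangeIndex(start=1, stop=len(toy_coords) + 1, name="SYNTHETIC_PLOT_ID")

# Assumed column naming: prefix followed by a 1-based axis number (CCA1, CCA2, ...)
prefix = "CCA"
columns = [f"{prefix}{i + 1}" for i in range(toy_coords.shape[1])]

toy_coordinates_df = pd.DataFrame(toy_coords, index=index, columns=columns)
print(toy_coordinates_df)
```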

#### Using other network types

Now, use a `FuzzedNetwork` to add random noise to the reference coordinates.  Each fuzzed point is guaranteed to lie between `minimum_distance` and `maximum_distance` from its paired reference point, but it may lie closer than `minimum_distance` to *another* reference point.

In [None]:
fuzzed_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=FuzzedNetwork(minimum_distance=0.1, maximum_distance=0.5),
    k=10
)
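One way to satisfy the distance guarantee described above is to move each reference point in a random direction by a random length drawn from the `[minimum_distance, maximum_distance]` band. The sketch below does exactly that with NumPy; `fuzz_points` is a hypothetical helper, and `FuzzedNetwork` may use a different scheme (for example, perturbing each CCA dimension independently).

```python
import numpy as np


def fuzz_points(reference, minimum_distance, maximum_distance, seed=None):
    """Hypothetical sketch: offset each reference point by a random vector whose
    Euclidean length lies in [minimum_distance, maximum_distance]."""
    rng = np.random.default_rng(seed)
    n, dims = reference.shape
    # Random directions: normalized Gaussian draws are uniform on the unit sphere
    directions = rng.normal(size=(n, dims))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Random offset lengths within the requested distance band
    radii = rng.uniform(minimum_distance, maximum_distance, size=(n, 1))
    return reference + directions * radii


toy_reference = np.random.default_rng(1).normal(size=(50, 5))
toy_fuzzed = fuzz_points(toy_reference, minimum_distance=0.1, maximum_distance=0.5, seed=2)

# Distance from each fuzzed point back to its own (paired) reference point
toy_offsets = np.linalg.norm(toy_fuzzed - toy_reference, axis=1)
print(toy_offsets.min(), toy_offsets.max())
```

Nothing in this scheme stops a fuzzed point from landing within `minimum_distance` of a *different* reference point, which is the caveat noted above.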

Finally, create `QuantileMesh` and `EqualIntervalMesh` networks.  In these mesh networks, synthetic points are placed at the midpoints of the quantile or equal-interval bins of the CCA scores in each dimension.  The `n_bins` argument to each network's initializer specifies how many bins to use in each dimension.

In [None]:
quantile_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=QuantileMesh(n_bins=10),
    k=10
)

equal_interval_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=EqualIntervalMesh(n_bins=10),
    k=10
)
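The mesh construction can be sketched in NumPy: compute bin edges along each axis (equally wide bins, or bins bounded by quantiles), take the bin midpoints, and form the Cartesian product of the per-axis midpoints. The helpers `bin_midpoints` and `mesh` are hypothetical, and the exact edge definitions used by `QuantileMesh` and `EqualIntervalMesh` are an assumption here.

```python
import numpy as np


def bin_midpoints(values, n_bins, method):
    """Hypothetical helper: midpoints of n_bins bins along one axis."""
    if method == "equal_interval":
        # Equally wide bins spanning the observed range
        edges = np.linspace(values.min(), values.max(), n_bins + 1)
    elif method == "quantile":
        # Bins holding (roughly) equal numbers of reference points
        edges = np.quantile(values, np.linspace(0.0, 1.0, n_bins + 1))
    else:
        raise ValueError(f"unknown method: {method}")
    return (edges[:-1] + edges[1:]) / 2


def mesh(reference, n_bins, method):
    """Cartesian product of per-axis midpoints: n_bins ** n_dims points."""
    axes = [bin_midpoints(reference[:, j], n_bins, method) for j in range(reference.shape[1])]
    grids = np.meshgrid(*axes, indexing="ij")
    return np.column_stack([g.ravel() for g in grids])


toy_reference = np.random.default_rng(3).normal(size=(200, 3))

toy_equal_interval = mesh(toy_reference, n_bins=4, method="equal_interval")
toy_quantile = mesh(toy_reference, n_bins=4, method="quantile")
print(toy_equal_interval.shape, toy_quantile.shape)  # 4 ** 3 = 64 rows each
```

Because quantile edges follow the data density, quantile-based points concentrate where reference plots are dense, whereas equal-interval points spread evenly across the full range, including sparsely populated tails.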

With these mesh types, beware the curse of dimensionality.  The number of synthetic plots is `n_bins ** n_dimensions`, so an ordination with many components can produce a very large mesh even with only a few bins per component (axis).  At present, the number of synthetic plots is capped at 1,000,000, and an error is raised if a request would exceed this limit.

In [None]:
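# Create a fake set of reference coordinates with 100 plots and 10 dimensions
n_components = 10
rng = np.random.default_rng(42)
reference_coordinates = rng.normal(size=(100, n_components))

# Request five bins per component: 5 ** 10 = 9,765,625 synthetic plots,
# which exceeds the 1,000,000 cap
n_bins = 5

# Creating the plots raises an error; catch and report it so the rest of
# the notebook still runs when executed top to bottom
try:
    quantile_network_plots_error = SyntheticPlots(
        reference_coordinates=reference_coordinates,
        network=QuantileMesh(n_bins=n_bins),
    )
except Exception as err:
    print(f"{type(err).__name__}: {err}")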
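The arithmetic behind this error is simple: a full mesh has `n_bins ** n_dimensions` points. The sketch below checks a request against the 1,000,000 cap stated above and finds the largest bin count per axis that stays within it. The helpers are hypothetical, and treating exactly 1,000,000 plots as allowed is an assumption. Integer arithmetic is used instead of a floating-point root such as `1_000_000 ** (1 / 10)` to avoid rounding surprises.

```python
MAX_SYNTHETIC_PLOTS = 1_000_000  # cap described in the text above


def n_mesh_points(n_bins, n_dimensions):
    """Number of synthetic plots in a full mesh."""
    return n_bins ** n_dimensions


def max_bins_per_axis(n_dimensions, cap=MAX_SYNTHETIC_PLOTS):
    """Hypothetical helper: largest n_bins whose mesh does not exceed the cap."""
    if n_dimensions < 1:
        raise ValueError("n_dimensions must be at least 1")
    n_bins = 1
    while n_mesh_points(n_bins + 1, n_dimensions) <= cap:
        n_bins += 1
    return n_bins


requested = n_mesh_points(5, 10)
print(requested, requested > MAX_SYNTHETIC_PLOTS)  # 9765625 True
print(max_bins_per_axis(10))                       # 3, since 4 ** 10 = 1048576
print(max_bins_per_axis(5))                        # 15, since 16 ** 5 = 1048576
```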

#### Visualizing synthetic networks

It can be difficult to conceptualize what these plot networks look like in high-dimensional space.  We can use `matplotlib` to visualize the first three axes of each synthetic network, which shows how the reference plots and the different synthetic plot networks relate to one another.

Again, using the Moscow Mountain / St. Joes dataset, we'll create a CCA transformation and capture the plot scores for the first three CCA axes.

In [None]:
X, y = load_moscow_stjoes(return_X_y=True, as_frame=True)
estimator = CCATransformer(n_components=3).fit(X, y)
X_transformed = estimator.transform(X)

Now create the `FuzzedNetwork`, `QuantileMesh`, and `EqualIntervalMesh` networks as before, but using eight bins per axis for the mesh networks (there is no need to visualize the `ReferenceNetwork` as it is the same as the original points).

In [None]:
fuzzed_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=FuzzedNetwork(minimum_distance=0.1, maximum_distance=0.5),
    k=10
)

quantile_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=QuantileMesh(n_bins=8),
    k=10
)

equal_interval_network_plots = SyntheticPlots(
    reference_coordinates=X_transformed,
    network=EqualIntervalMesh(n_bins=8),
    k=10
)

Concatenate the original plot scores and the three synthetic networks into a single `pandas.DataFrame` to facilitate visualization.

In [None]:
# ignore_index=True gives every row a unique index instead of repeating 0..n-1 per group
plots = pd.concat([
    pd.DataFrame(X_transformed, columns=["x", "y", "z"]).assign(group="reference"),
    pd.DataFrame(fuzzed_network_plots.synthetic_coordinates(), columns=["x", "y", "z"]).assign(group="fuzzed"),
    pd.DataFrame(quantile_network_plots.synthetic_coordinates(), columns=["x", "y", "z"]).assign(group="quantile"),
    pd.DataFrame(equal_interval_network_plots.synthetic_coordinates(), columns=["x", "y", "z"]).assign(group="equal_interval"),
], ignore_index=True)

plots.head()

Finally, create paired 3D scatter plots of the original plots against each of the three synthetic networks.

In [None]:
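# One 3D panel per synthetic network, each overlaid on the reference plots
fig, ax = plt.subplots(ncols=3, subplot_kw={"projection": "3d"}, figsize=(18, 6))

ref = plots[plots["group"] == "reference"]

for axis, group in zip(ax, ["fuzzed", "equal_interval", "quantile"]):
    data = plots[plots["group"] == group]

    axis.scatter(ref["x"], ref["y"], ref["z"], label="reference", alpha=0.5)
    axis.scatter(data["x"], data["y"], data["z"], label=group, alpha=0.7)

    # Orthographic projection avoids perspective distortion of the mesh spacing
    axis.set_proj_type("ortho")
    axis.view_init(elev=30, azim=30)
    axis.set_title(group)
    axis.set_xlabel("CCA 1")
    axis.set_ylabel("CCA 2")
    axis.set_zlabel("CCA 3")
    axis.legend()

plt.tight_layout()
plt.show()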