In [None]:
from manifold_flow.flows import ManifoldFlow
from manifold_flow import transforms
from manifold_flow.architectures.vector_transforms import create_vector_transform

import numpy as np
import torch
from tqdm import tqdm
from math import sqrt
import matplotlib.pyplot as plt

In [None]:
# Define the M-Flow model: a 1-D latent manifold (latent_dim=1) in 2-D data space (data_dim=2).
# These architecture hyper-parameters must match the ones the checkpoint below was trained
# with; otherwise `load_state_dict` (strict by default) raises on missing/mismatched keys.
params = {
    "batch_size"         : 512,  # Not used in this notebook (presumably the training batch size)
    "n_flow_steps"       : 16,   # Depth (#layers)  of the "outer transform"
    "hidden_features"    : 100,  # Width (#neurons) of the "outer transform"
    "n_transform_blocks" : 2     # Passed as `num_transform_blocks`; presumably #residual blocks per layer network — TODO confirm in create_vector_transform
}
mflow = ManifoldFlow(
    data_dim=2,
    latent_dim=1,
    inner_transform=transforms.ConditionalAffineScalarTransform(features=1),
    outer_transform=create_vector_transform(
        dim                  = 2,
        flow_steps           = params["n_flow_steps"],
        hidden_features      = params["hidden_features"],
        num_transform_blocks = params["n_transform_blocks"]
    )
)

MODEL_PATH = "../data/models/spiral_mflow_nrotations_1.2_successful_A.pt"
# map_location="cpu": every later cell converts results with `.numpy()` (CPU only), and a
# checkpoint saved from GPU tensors would otherwise fail to load on a CPU-only machine.
mflow.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

In [None]:
# Sample many points from the model to visualise the learned manifold
mflow.eval()
# Fixed seed so re-running the notebook draws the same samples
# (assumes `mflow.sample` draws from torch's global RNG — TODO confirm).
torch.manual_seed(0)
# Only the sample values are needed, so skip building an autograd graph for 10k points.
with torch.no_grad():
    manifold_points = mflow.sample(n=10_000).detach().numpy()

In [None]:
def generate_grid_tensor(
        n_points_per_axis: int,
        min_val:           float,
        max_val:           float
    ) -> torch.Tensor:
    """Return the points of a square, evenly spaced 2-D grid as a ``(N**2, 2)`` tensor.

    Both axes span ``[min_val, max_val]`` (endpoints included) with
    ``n_points_per_axis`` samples each. Rows are ordered with the first
    coordinate varying slowest ("ij" order): ``[x0, y0], [x0, y1], ..., [x1, y0], ...``.

    Args:
        n_points_per_axis: Number of samples along each axis.
        min_val: Lower bound of both axes.
        max_val: Upper bound of both axes.

    Returns:
        Tensor of shape ``(n_points_per_axis ** 2, 2)``.
    """
    axis = torch.linspace(min_val, max_val, n_points_per_axis)
    # Stack the two coordinate planes on a trailing dim, then collapse the (i, j)
    # dims; row-major collapsing yields the same order as flattening each plane.
    coord_planes = torch.stack(torch.meshgrid(axis, axis, indexing="ij"), dim=-1)
    return coord_planes.reshape(-1, 2)

In [None]:
# Visualise how the inverse outer transform maps a regular latent-space grid into
# ambient space: each grid point is drawn faintly at its latent position and opaquely
# at its image, both in the same colour.
n_samples = 10000

# Generate points on a regular grid (int(sqrt(n_samples)) points per axis)
grid_min, grid_max = -10, 10
n_points_per_axis  = int(sqrt(n_samples))
grid_points        = generate_grid_tensor(n_points_per_axis, grid_min, grid_max)

# One colour per grid point, sweeping through the "plasma" colormap in grid order
grid_colours = plt.colormaps["plasma"](np.linspace(0, 1, grid_points.shape[0]))[:, :3]

# Transform grid points from latent space to ambient data space
# (values only, so no autograd graph is needed)
with torch.no_grad():
    points_proj = mflow.outer_transform.inverse(grid_points)[0].detach().numpy()

# Colour the latent grid row(s) closest to y = 0 pink.
# With an even number of points per axis (100 here) linspace(-10, 10, 100) has no
# sample at y = 0: the nearest rows sit at y = ±10/99. An exact/tolerance test against 0
# therefore selects nothing. Instead, pick every row whose |y| is within half a grid step
# of the smallest |y|. That is exactly the y = 0 row when the grid contains it, and
# otherwise the two rows straddling y = 0.
grid_step          = (grid_max - grid_min) / (n_points_per_axis - 1)
abs_latent_y       = grid_points[:, 1].abs()
zero_levelset_mask = abs_latent_y <= abs_latent_y.min() + grid_step / 2
grid_colours[zero_levelset_mask.numpy()] = (1.0, 0.0, 1.0)

# Define arrays used for plotting the projection vectors via `plt.quiver` 
# x_start, y_start = grid_points.T.numpy()
# x_end,   y_end   = points_proj.T
# dx               = x_end - x_start
# dy               = y_end - y_start

# Normalise vectors for visual clarity (optional)
# norm = np.linalg.norm(np.array([dx, dy]), axis=0) * 4
# dx  /= np.where(norm > 1, norm, 1)  # Normalise vectors with norm > 1
# dy  /= np.where(norm > 1, norm, 1)  # Normalise vectors with norm > 1

# Plotting (grid_colours already has exactly one colour per grid point)
assert len(grid_colours) == grid_points.shape[0] == len(points_proj)
plt.figure(figsize=(5, 5), dpi=200)
plt.scatter(*manifold_points.T, s=1, alpha=0.5, c="darkmagenta")
plt.scatter(*grid_points.T,     s=5, alpha=0.2, c=grid_colours, lw=0)
plt.scatter(*points_proj.T,     s=5, alpha=1,   c=grid_colours, lw=0)

# Plot projection vectors
# plt.quiver(x_start, y_start, dx, dy, scale_units="xy", scale=1, color=grid_colours)

plt.gca().set_aspect("equal", adjustable="box")
plt.axis("off")
plt.xlim(-2.3, 2.3)
plt.ylim(-2.3, 2.3)
plt.tight_layout()

# Transform a horizontal line from latent space to ambient space

In [None]:
# Map a single horizontal latent-space line (second coordinate fixed at `y_val`)
# through the inverse outer transform to see which ambient-space curve it becomes.
n_samples     = 10000
grid_colours  = plt.colormaps["plasma"](np.linspace(0, 1, n_samples))[:, :3]

# Build the latent line: x spans [-x_range, x_range], y is constant at y_val
x_range = 5
y_val   = 1
xs = torch.linspace(-1, 1, n_samples) * x_range
ys = torch.ones_like(xs) * y_val
line_points = torch.column_stack([xs, ys])

# Transform points from latent space to ambient space (values only, no autograd graph)
with torch.no_grad():
    points_proj = mflow.outer_transform.inverse(line_points)[0].detach().numpy()

# Plotting. Colours and line points are both built from `n_samples`, so they already
# match one-to-one; the check guards against either being changed independently.
assert len(grid_colours) == line_points.shape[0] == len(points_proj)
plt.figure(figsize=(5, 5), dpi=200)
plt.scatter(*manifold_points.T, s=1, alpha=0.02, c="darkmagenta")
plt.scatter(*line_points.T,     s=2, alpha=1,    c=grid_colours, lw=0)
plt.scatter(*points_proj.T,     s=5, alpha=1,    c=grid_colours, lw=0)

plt.gca().set_aspect("equal", adjustable="box")
plt.axis("off")
plt.xlim(-2.3, 2.3)
plt.ylim(-2.3, 2.3)
plt.tight_layout()

## Generate animation frames of the horizontal-line transform

In [None]:
# Disabled cell (uncomment to run): sweeps the latent line's y-value over 120 frames
# from -2.3 to 2.3. For each frame it plots the line and its image under the inverse
# outer transform, then saves the figure as a numbered PNG.
# NOTE: plt.savefig does not create directories, so
# ../figures/spiral_horizontal_line_transform/ must exist before running this cell.
# # Generate colours for the sample points
# n_samples    = 10_000
# grid_colours = plt.colormaps["plasma"](np.linspace(0, 1, n_samples))[:, :3]

# # Generate x values for the line
# x_range = 5
# xs      = torch.linspace(-1, 1, n_samples) * x_range

# # Create 120 frames
# num_frames = 120
# y_vals     = np.linspace(-2.3, 2.3, num_frames)

# for frame, y_val in enumerate(tqdm(y_vals)):
#     # Generate samples and project them to the manifold
#     ys = torch.ones_like(xs) * y_val
#     line_points = torch.column_stack([xs, ys])

#     # Transform points from latent space to ambient space
#     points_proj = mflow.outer_transform.inverse(line_points)[0].detach().numpy()

#     # Plotting
#     plt.figure(figsize=(5, 5), dpi=200)
#     plt.scatter(*manifold_points.T, s=1, alpha=0.02, c="darkmagenta")
#     plt.scatter(*line_points.T,     s=2, alpha=1,    c=grid_colours, lw=0)
#     plt.scatter(*points_proj.T,     s=5, alpha=1,    c=grid_colours, lw=0)

#     plt.gca().set_aspect("equal", adjustable="box")
#     plt.axis("off")
#     plt.xlim(-2.3, 2.3)
#     plt.ylim(-2.3, 2.3)
#     plt.tight_layout()

#     # Save the frame
#     plt.savefig(f"../figures/spiral_horizontal_line_transform/frame_{frame:03d}.png", bbox_inches="tight", pad_inches=0)
#     plt.close()