# MorseGraph Example 5: Learned Dynamics

This notebook walks through a complete machine learning workflow for analyzing a dynamical system from data. The steps are:
1. Generating a time-series dataset from a system (here, the Hénon map).
2. Training an autoencoder to learn a low-dimensional latent representation of the state space.
3. Training a model to approximate the system's dynamics in this latent space.
4. Using the `LearnedDynamics` class to compute a Morse graph of the original system from the learned latent dynamics.

In [None]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

# Import the MorseGraph components
from morsegraph.grids import UniformGrid
from morsegraph.dynamics import LearnedDynamics
from morsegraph.core import Model
from morsegraph.analysis import compute_morse_graph
from morsegraph.plot import plot_morse_graph, plot_morse_sets
from morsegraph.models import Encoder, Decoder, LatentDynamics
from morsegraph.training import Training

## 1. Generate and Prepare Data

We need a dataset of consecutive-state pairs `(x_t, x_{t+1})`, which we generate by iterating the Hénon map from a single initial condition. The pairs are then split 80/20 into training and validation sets and wrapped in PyTorch `DataLoader`s.

In [None]:
def henon_map(x, a=1.4, b=0.3):
    """Standard Henon map, applied row-wise to an array of shape (n, 2)."""
    x_next = 1 - a * x[:, 0]**2 + x[:, 1]
    y_next = b * x[:, 0]
    return np.column_stack([x_next, y_next])

# Generate a trajectory
num_points = 10000
x0 = np.array([[0.1, 0.1]])
trajectory = np.zeros((num_points, 2))
trajectory[0] = x0
for i in range(num_points - 1):
    trajectory[i+1] = henon_map(trajectory[i:i+1])

# Create (x_t, x_{t+1}) pairs
X_t = trajectory[:-1]
X_tau = trajectory[1:]

# Convert to torch tensors
X_t_tensor = torch.from_numpy(X_t).float()
X_tau_tensor = torch.from_numpy(X_tau).float()

# Create DataLoader
dataset = TensorDataset(X_t_tensor, X_tau_tensor)
# Using 80% for training, 20% for validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

print(f"Generated {len(dataset)} data points.")
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
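
Before training, it can be worth checking that the generated points actually fall inside the region that will later be gridded. The sketch below is plain Python and independent of MorseGraph; `henon_step` and `fraction_inside` are hypothetical helper names, and the bounds are assumed to hold one `(min, max)` pair per coordinate.

```python
def henon_step(x, y, a=1.4, b=0.3):
    """One iteration of the Henon map on plain floats."""
    return 1.0 - a * x * x + y, b * x


def fraction_inside(points, bounds):
    """Fraction of points inside an axis-aligned box.

    `bounds` holds one (min, max) pair per coordinate (assumed convention).
    """
    points = list(points)
    inside = sum(
        all(lo <= value <= hi for value, (lo, hi) in zip(point, bounds))
        for point in points
    )
    return inside / len(points)


# Rebuild an orbit like the one above, starting from (0.1, 0.1).
orbit = [(0.1, 0.1)]
for _ in range(9999):
    orbit.append(henon_step(*orbit[-1]))

bounds = [(-1.5, 1.5), (-0.4, 0.4)]
inside_fraction = fraction_inside(orbit, bounds)
print(f"{inside_fraction:.1%} of the points lie inside the domain")
```

Because `fraction_inside` only iterates over its arguments, it should also accept a NumPy array of points such as `trajectory`, provided the bounds follow the same per-coordinate convention.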

## 2. Train the ML Models

Now we instantiate the `Encoder`, `Decoder`, and `LatentDynamics` models, and use the `Training` class to train them on our dataset.

In [None]:
# Model parameters
input_dim = 2
latent_dim = 2  # For simplicity, we use a latent space of the same dimension
output_dim = 2

# Instantiate models
encoder = Encoder(input_dim, latent_dim)
decoder = Decoder(latent_dim, output_dim)
latent_dynamics = LatentDynamics(latent_dim)

# Instantiate training manager
trainer = Training(encoder, decoder, latent_dynamics, learning_rate=0.001)

# Run training
trainer.train(train_loader, epochs=50, val_loader=val_loader, dynamics_weight=0.5)

# Save models (optional)
# trainer.save_models("henon_models")
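
The internals of `Training` are not shown in this notebook, so the following is only one plausible reading of what `dynamics_weight` controls, not the library's documented loss. A common objective for this kind of model adds a one-step prediction term to the autoencoder reconstruction term. The symbols $E$ (encoder), $D$ (decoder), $G$ (latent dynamics), and $\lambda$ (weight) are introduced here for illustration:

```latex
\mathcal{L}
  = \frac{1}{N} \sum_{t} \left\| D(E(x_t)) - x_t \right\|^2
  + \lambda \, \frac{1}{N} \sum_{t} \left\| D(G(E(x_t))) - x_{t+1} \right\|^2
```

Under this assumption, `dynamics_weight=0.5` would play the role of $\lambda$, so the prediction error counts half as much as the reconstruction error.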

## 3. Analyze the Learned Dynamics

With the trained models in hand, we create a `LearnedDynamics` object, which acts as the interface between the Morse graph computation and the neural network models. The rest of the analysis proceeds as in the other examples: define a grid on the original state space, build the model, and compute the map graph and the Morse graph.
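
One way to picture what such an interface has to do is to push a grid box through a point map and enclose the result in a slightly enlarged box. The sketch below illustrates that idea only and is not the `LearnedDynamics` implementation: `box_image` is a hypothetical helper, the true Hénon map stands in for the decoder-dynamics-encoder composition, and `bloat` is treated as an absolute padding, which may differ from how `bloat_factor` is defined in the library.

```python
from itertools import product


def henon(point, a=1.4, b=0.3):
    """Stand-in point map; a learned model would replace this."""
    x, y = point
    return (1.0 - a * x * x + y, b * x)


def box_image(f, lower, upper, bloat=0.0):
    """Bounding box of f at the corners and center of a box, padded by `bloat`.

    Sampling a few points is not a rigorous enclosure of f(box); the padding
    is a heuristic safety margin.
    """
    samples = list(product(*zip(lower, upper)))  # all corners of the box
    samples.append(tuple((lo + hi) / 2 for lo, hi in zip(lower, upper)))
    images = [f(p) for p in samples]
    dim = len(lower)
    img_lower = tuple(min(q[i] for q in images) - bloat for i in range(dim))
    img_upper = tuple(max(q[i] for q in images) + bloat for i in range(dim))
    return img_lower, img_upper


lower, upper = (0.0, 0.0), (0.1, 0.1)
tight = box_image(henon, lower, upper)
padded = box_image(henon, lower, upper, bloat=0.05)
print("without padding:", tight)
print("with padding:   ", padded)
```

A larger padding makes it less likely that a transition is missed because of model error, at the cost of a coarser map graph with more edges.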

In [None]:
# 1. Create the LearnedDynamics object
dynamics = LearnedDynamics(
    encoder=encoder,
    dynamics_model=latent_dynamics,
    decoder=decoder,
    bloat_factor=0.2,  # May need a larger bloat factor for learned models
)

# 2. Define the grid on the original state space
subdivisions = [32, 32]
domain = np.array([[-1.5, 1.5], [-0.4, 0.4]])
grid = UniformGrid(bounds=domain, subdivisions=subdivisions)

# 3. Create the model
model = Model(grid, dynamics)

# 4. Compute the map and Morse graphs
print("\nComputing map graph using learned dynamics...")
map_graph = model.compute_map_graph()
morse_graph, morse_sets = compute_morse_graph(map_graph)
print("Computation complete.")
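
In the standard combinatorial setting, the Morse sets are the recurrent strongly connected components of the map graph: components with more than one node, or a single node with a self-loop. The sketch below shows that idea on a toy graph stored as a dictionary. It is not the code behind `compute_morse_graph`, whose implementation is not shown here, and both function names are hypothetical.

```python
def strongly_connected_components(graph):
    """Kosaraju's algorithm on a dict {node: [successors]}.

    Every node must appear as a key, even if it has no successors.
    """
    # First pass: record nodes in order of DFS completion.
    order, seen = [], set()
    for root in graph:
        if root in seen:
            continue
        seen.add(root)
        stack = [(root, iter(graph[root]))]
        while stack:
            node, successors = stack[-1]
            for nxt in successors:
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append((nxt, iter(graph[nxt])))
                    break
            else:
                # All successors explored: the node is finished.
                order.append(node)
                stack.pop()

    # Build the reversed graph.
    reverse = {node: [] for node in graph}
    for node, successors in graph.items():
        for nxt in successors:
            reverse[nxt].append(node)

    # Second pass: sweep the reversed graph in decreasing finish time.
    components, assigned = [], set()
    for root in reversed(order):
        if root in assigned:
            continue
        assigned.add(root)
        component, stack = set(), [root]
        while stack:
            node = stack.pop()
            component.add(node)
            for prev in reverse[node]:
                if prev not in assigned:
                    assigned.add(prev)
                    stack.append(prev)
        components.append(component)
    return components


def recurrent_components(graph):
    """Keep components with more than one node, or a node with a self-loop."""
    return [
        c for c in strongly_connected_components(graph)
        if len(c) > 1 or any(n in graph[n] for n in c)
    ]


# Toy map graph: 1 and 2 form a cycle, 3 maps to itself, 0 and 4 are transient.
toy_graph = {0: [1], 1: [2], 2: [1, 3], 3: [3], 4: [0]}
toy_morse_sets = recurrent_components(toy_graph)
print(toy_morse_sets)
```

In the toy graph the cycle between nodes 1 and 2 can reach the self-loop at node 3 but not the other way round, which is the kind of ordering a Morse graph records between Morse sets.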

## 4. Visualize the Results

Finally, we plot the results. The Morse sets live on the grid in the original state space, but the transitions between grid cells were determined by the learned dynamics in the latent space. If training was successful, the Morse sets should closely match those of the true system; the points of the original trajectory are overlaid in red for comparison.

In [None]:
# Plot the Morse graph
fig, ax = plt.subplots()
plot_morse_graph(morse_graph, ax=ax)
ax.set_title("Morse Graph from Learned Dynamics")
plt.show()

# Plot the Morse sets
fig, ax = plt.subplots(figsize=(10, 5))
plot_morse_sets(morse_graph, grid, ax=ax)
ax.set_title("Morse Sets from Learned Dynamics")
# Plot the original data attractor for comparison
ax.scatter(trajectory[:, 0], trajectory[:, 1], s=1, c='red', alpha=0.1)
plt.show()
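
The visual comparison can be complemented with a rough numerical score: the fraction of grid cells visited by the data that also belong to a Morse set. The sketch below assumes that Morse sets can be expressed as a set of integer cell indices `(i, j)` on a uniform grid with one `(min, max)` pair per coordinate. The actual data structures returned by MorseGraph may differ, and `cell_index` and `coverage` are hypothetical helpers.

```python
def cell_index(point, bounds, subdivisions):
    """Index of the uniform-grid cell containing `point`, or None if outside."""
    index = []
    for value, (lo, hi), n in zip(point, bounds, subdivisions):
        if not lo <= value <= hi:
            return None
        # Clamp so that value == hi lands in the last cell.
        index.append(min(int((value - lo) / (hi - lo) * n), n - 1))
    return tuple(index)


def coverage(points, cells, bounds, subdivisions):
    """Fraction of distinct visited cells that belong to `cells`."""
    visited = {cell_index(p, bounds, subdivisions) for p in points}
    visited.discard(None)  # ignore points outside the grid
    if not visited:
        return 0.0
    return len(visited & cells) / len(visited)


bounds = [(-1.5, 1.5), (-0.4, 0.4)]
subdivisions = [32, 32]

# Hand-picked points and a made-up Morse set, for illustration only.
points = [(0.63, 0.19), (0.64, 0.19), (-1.2, 0.31)]
example_cells = {(22, 23)}
score = coverage(points, example_cells, bounds, subdivisions)
print(f"coverage: {score:.2f}")
```

With real data, `points` would be the rows of `trajectory` and `example_cells` would be built from the computed Morse sets. A score close to 1 indicates that the learned Morse sets contain the observed attractor, although it says nothing about how much extra area they include.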