# Coordinate Shifts Parameter Estimation

This notebook demonstrates a challenging test case for parameter estimation using the `LorenzParameterEstimation` package: learning coordinate shifts in the Lorenz 63 system.

## Problem Description

We modify the standard Lorenz 63 system by introducing coordinate shifts:
- x → x - x_s
- y → y - y_s  
- z → z - z_s

The goal is to learn the shift parameters (x_s, y_s, z_s) with the package's Enzyme.jl integration, starting from the unstable initial condition `randn(3) .+ (0, 0, 10)`. The shifts move the whole attractor, not just the starting point, so they change the system's long-term statistics. This turns an initial-condition problem into a climate problem.

**Key insight**: The shifts enter the equations only as constant offsets subtracted from the state variables. The right-hand side therefore depends smoothly on them, which makes this a good test case for gradient-based optimization and for the package's Enzyme-based gradients.
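A change of variables makes the role of the shifts explicit. Write the classic Lorenz 63 vector field as $f$, its flow as $\varphi_t$, and the shift vector as $s = (x_s, y_s, z_s)$. The shifted system is then $\dot u = f(u - s)$, and substituting $v = u - s$ gives:

$$
\dot v = f(v),\quad v(0) = u_0 - s
\quad\Longrightarrow\quad
u(t; s) = s + \varphi_t(u_0 - s),
\qquad
\frac{\partial u(t; s)}{\partial s} = I - D\varphi_t(u_0 - s).
$$

The shifted system is therefore the classic system translated by $s$. Its fixed points are $s$ and $s + \big(\pm\sqrt{\beta(\rho-1)},\ \pm\sqrt{\beta(\rho-1)},\ \rho-1\big)$ (signs taken together), and its attractor is the classic attractor moved by $s$. In the sensitivity, the identity term is the constant contribution of the translation. The term $D\varphi_t$ is the ordinary sensitivity to the initial condition, which for chaotic dynamics typically grows with $t$.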

## Setup and Imports

```julia
# Activate the project environment first, so every package below is loaded
# from the project's environment rather than the default one
using Pkg
Pkg.activate("../../..")  # Activate the main project

using DifferentialEquations
using Plots
using Random
using LinearAlgebra
using Statistics
using Enzyme

# Import our LorenzParameterEstimation package
using LorenzParameterEstimation

# Set random seed for reproducibility
Random.seed!(42)
```

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m   Installed[22m[39m DifferentialEquations ─────────────── v7.16.1
[32m[1m   Installed[22m[39m OrdinaryDiffEqRosenbrock ──────────── v1.18.1
[32m[1m   Installed[22m[39m OrdinaryDiffEqRKN ─────────────────── v1.5.0
[32m[1m   Installed[22m[39m OrdinaryDiffEqSDIRK ───────────────── v1.7.0
[32m[1m   Installed[22m[39m AlmostBlockDiagonals ──────────────── v0.1.10
[32m[1m   Installed[22m[39m BoundaryValueDiffEqFIRK ───────────── v1.9.0
[32m[1m   Installed[22m[39m TimerOutputs ──────────────────────── v0.5.29
[32m[1m   Installed[22m[39m SciMLJacobianOperators ────────────── v0.1.11
[32m[1m   Installed[22m[39m OrdinaryDiffEqStabilizedRK ────────── v1.4.0
[32m[1m   Installed[22m[39m NonlinearSolve ────────────────────── v4.11.0
[32m[1m   Installed[22m[39m BoundaryValueDiffEqMIRK ───────────── v1.9.0
[32m[1m   Inst

## Base System Setup

First, let's use the package's standard parameters and utilities.

```julia
# Use package standard parameters
base_params = classic_params()  # Standard Lorenz 63 parameters from package
println("Base Lorenz 63 parameters: σ=$(base_params.σ), ρ=$(base_params.ρ), β=$(base_params.β)")

# We'll extend the package to handle coordinate shifts
# This showcases how to extend the package for new test cases
```

## Extended System: Coordinate Shifts

Now we implement the shifted Lorenz system where each variable is shifted:
- x → x - x_s
- y → y - y_s  
- z → z - z_s

The implementation below lives outside the package. It reuses the package's `L63Parameters` for σ, ρ and β, and passes the shifts as the ODE parameter vector, so the shifts are the quantities the optimizer updates.
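Written out component by component, with σ, ρ and β taken from `base_params`, the shifted system is:

$$
\begin{aligned}
\dot x &= \sigma\big[(y - y_s) - (x - x_s)\big],\\
\dot y &= (x - x_s)\big[\rho - (z - z_s)\big] - (y - y_s),\\
\dot z &= (x - x_s)(y - y_s) - \beta\,(z - z_s).
\end{aligned}
$$

Setting $x_s = y_s = z_s = 0$ recovers the classic Lorenz 63 equations.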

```julia
# Extended Lorenz 63 system with coordinate shifts.
# σ, ρ, β come from the package; the shifts are passed as the ODE parameter
# vector `p`, so they are what the optimizer updates.
function lorenz63_shifted!(du, u, shifts, t, base_params::L63Parameters)
    x, y, z = u
    x_s, y_s, z_s = shifts  # shift parameters to be learned

    # Apply shifts: x -> x - x_s, y -> y - y_s, z -> z - z_s
    x_shifted = x - x_s
    y_shifted = y - y_s
    z_shifted = z - z_s

    # Lorenz 63 right-hand side evaluated at the shifted coordinates
    du[1] = base_params.σ * (y_shifted - x_shifted)
    du[2] = x_shifted * (base_params.ρ - z_shifted) - y_shifted
    du[3] = x_shifted * y_shifted - base_params.β * z_shifted
    return nothing
end

# Closure with the in-place (du, u, p, t) signature expected by DifferentialEquations.jl
function make_shifted_system(base_params::L63Parameters)
    return (du, u, shifts, t) -> lorenz63_shifted!(du, u, shifts, t, base_params)
end

# True shift parameters (what we want to learn)
true_shifts = [2.0, -1.5, 3.0]  # [x_s, y_s, z_s]
println("True shift parameters: x_s=$(true_shifts[1]), y_s=$(true_shifts[2]), z_s=$(true_shifts[3])")
```
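The shifts enter only through `u - s`. Integrating the shifted system from `u0` should therefore give the same trajectory as integrating the unshifted system from `u0 - s` and adding `s` back. The minimal sketch below checks this without DifferentialEquations.jl.

It makes these assumptions:

- The classic values σ = 10, ρ = 28 and β = 8/3 are hard-coded rather than read from `classic_params()`.
- The helpers `shifted_l63_rhs`, `rk4_step` and `integrate_rk4` are hypothetical names introduced only for this illustration.
- Both runs use the same fixed-step RK4 scheme.

```julia
# Minimal sketch: the shifted Lorenz 63 system is a translated copy of the classic one.
# Assumes classic parameters σ = 10, ρ = 28, β = 8/3 (hard-coded, not read from classic_params()).
# `shifted_l63_rhs`, `rk4_step` and `integrate_rk4` are hypothetical helpers, not package API.

# Lorenz 63 right-hand side evaluated at u - s (s = zeros(3) gives the classic system)
function shifted_l63_rhs(u, s; σ=10.0, ρ=28.0, β=8 / 3)
    x, y, z = u .- s
    return [σ * (y - x), x * (ρ - z) - y, x * y - β * z]
end

# One classic fourth-order Runge-Kutta step of size h
function rk4_step(u, s, h)
    k1 = shifted_l63_rhs(u, s)
    k2 = shifted_l63_rhs(u .+ (h / 2) .* k1, s)
    k3 = shifted_l63_rhs(u .+ (h / 2) .* k2, s)
    k4 = shifted_l63_rhs(u .+ h .* k3, s)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Take n fixed RK4 steps from u_start and return the final state
function integrate_rk4(u_start, s, h, n)
    u = copy(u_start)
    for _ in 1:n
        u = rk4_step(u, s, h)
    end
    return u
end

s_demo = [2.0, -1.5, 3.0]      # same values as true_shifts
u0_demo = [0.5, -0.3, 10.2]    # an arbitrary start near (0, 0, 10)
h_demo, n_demo = 0.01, 100     # 100 steps of 0.01, i.e. up to t = 1

# Shifted system started at u0
u_shifted = integrate_rk4(u0_demo, s_demo, h_demo, n_demo)
# Classic system started at u0 - s, then translated back by s
u_translated = integrate_rk4(u0_demo .- s_demo, zeros(3), h_demo, n_demo) .+ s_demo

println("max |difference| = ", maximum(abs.(u_shifted .- u_translated)))
```

With the same integrator on both sides, the two results should agree up to floating-point rounding. The check only confirms the algebra of the shift, not the accuracy of the integrator.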

## Generate Reference Data

Generate a reference trajectory from the unstable initial condition and store it in the package's `L63Solution` type.

```julia
# Initial condition: randn(3) .+ (0, 0, 10), a random perturbation of (0, 0, 10)
# that starts close to the attractor but in an unstable region
u0 = randn(3) .+ [0.0, 0.0, 10.0]
println("Initial conditions: $u0")

# Time parameters
tspan = (0.0, 20.0)
dt = 0.01

# Create the shifted system using our base parameters
shifted_system = make_shifted_system(base_params)

# Generate reference data with true shift parameters
prob_ref = ODEProblem(shifted_system, u0, tspan, true_shifts)
sol_ref = solve(prob_ref, Tsit5(), saveat=dt)

# Stack the saved states into a 3×N matrix (one column per saved time)
reference_data = reduce(hcat, sol_ref.u)
reference_solution = L63Solution(
    reference_data,
    sol_ref.t,
    base_params,  # Use base parameters for the solution
    u0
)

println("Generated reference trajectory with $(length(sol_ref.t)) time points")
println("Final state: $(sol_ref.u[end])")
```

## Parameter Estimation with Enzyme

Now we'll set up parameter estimation using Enzyme.jl for gradient computation, following the package's design patterns.
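The loss implemented in `shift_loss_function` is the mean squared error over the $N$ saved times $t_k$ and the three state components. Its gradient follows from the chain rule:

$$
L(s) = \frac{1}{3N}\sum_{k=1}^{N}\big\lVert u(t_k; s) - u_{\mathrm{ref}}(t_k)\big\rVert_2^2,
\qquad
\nabla_s L = \frac{2}{3N}\sum_{k=1}^{N}\left(\frac{\partial u(t_k; s)}{\partial s}\right)^{\!\top}\big(u(t_k; s) - u_{\mathrm{ref}}(t_k)\big).
$$

Enzyme obtains this gradient by differentiating through the solver, so the trajectory sensitivities $\partial u(t_k; s)/\partial s$ at every saved time contribute to the sum.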

```julia
# Custom loss function for shift parameter estimation:
# mean squared error between predicted and reference trajectories
function shift_loss_function(shifts, reference_sol, base_params, u0, tspan, dt)
    # Solve ODE with current shift guess
    shifted_sys = make_shifted_system(base_params)
    prob = ODEProblem(shifted_sys, u0, tspan, shifts)

    try
        sol = solve(prob, Tsit5(), saveat=dt)

        # Return early if the solver failed; current SciML versions report
        # return codes as `ReturnCode` values rather than Symbols
        if sol.retcode != ReturnCode.Success
            return Inf
        end

        # Stack the saved states into a 3×N matrix, matching the reference data
        predicted_data = reduce(hcat, sol.u)

        # Package-style error (like window_rmse, but without the square root)
        return sum(abs2, predicted_data .- reference_sol.trajectory) / length(predicted_data)
    catch
        return Inf
    end
end

# Initial parameter guess for shifts
shift_guess = [0.0, 0.0, 0.0]  # Start with no shifts
println("Initial shift guess: $shift_guess")
println("True shifts: $true_shifts")

# Test initial loss
initial_loss = shift_loss_function(shift_guess, reference_solution, base_params, u0, tspan, dt)
println("Initial loss: $initial_loss")
```

## Optimization with Enzyme Gradients

Compute gradients with Enzyme's reverse-mode automatic differentiation and update the shifts by plain gradient descent, following the package's approach.
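Reverse-mode gradients through an ODE solver are worth validating, and comparing them with central finite differences is one simple way to do it.

This is a hedged sketch. `fd_gradient` is a hypothetical helper, not part of the package or of Enzyme, and the step `h` is an assumption that may need tuning. The block checks the helper on a function whose gradient is known exactly.

```julia
# Central-difference gradient: ∂f/∂pᵢ ≈ (f(p + h eᵢ) - f(p - h eᵢ)) / 2h.
# `fd_gradient` is a hypothetical helper for cross-checking AD gradients, not package or Enzyme API.
function fd_gradient(f, p::AbstractVector{<:Real}; h=1e-6)
    p0 = float.(collect(p))          # floating-point copy, so integer inputs also work
    g = similar(p0)
    for i in eachindex(p0)
        p_plus = copy(p0)
        p_minus = copy(p0)
        p_plus[i] += h
        p_minus[i] -= h
        g[i] = (f(p_plus) - f(p_minus)) / (2h)
    end
    return g
end

# Check on a function with a known gradient: ∇ Σ pᵢ² = 2p
sum_of_squares(p) = sum(abs2, p)
p_test = [2.0, -1.5, 3.0]
g_fd = fd_gradient(sum_of_squares, p_test)
println("finite-difference gradient: ", g_fd)   # ≈ [4.0, -3.0, 6.0]
```

In this notebook, compare `fd_gradient(loss_wrapper, shift_guess)` with the `grad` filled by `compute_gradient!(grad, copy(shift_guess))`. A large disagreement means either the AD setup or the finite-difference step needs a closer look.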

```julia
# Wrapper used to evaluate the loss at the current shifts
function loss_wrapper(shifts)
    return shift_loss_function(shifts, reference_solution, base_params, u0, tspan, dt)
end

# Gradient computation using Enzyme reverse mode.
# Enzyme accumulates into the shadow array, so `grad` is zeroed before each call.
# The fixed problem data are passed as `Const` arguments rather than captured globals.
function compute_gradient!(grad, shifts)
    fill!(grad, 0.0)
    try
        autodiff(Reverse, shift_loss_function, Active,
                 Duplicated(shifts, grad),
                 Const(reference_solution), Const(base_params),
                 Const(u0), Const(tspan), Const(dt))
    catch e
        println("Gradient computation failed: $e")
        fill!(grad, 0.0)
    end
    return nothing
end

# Package-style optimization function (similar to modular_train! internals)
function optimize_shifts(initial_shifts; max_iters=200, lr=1e-4, verbose=true)
    shifts = copy(initial_shifts)
    grad = zeros(length(shifts))

    best_loss = Inf
    best_shifts = copy(shifts)
    loss_history = Float64[]

    verbose && println("Starting optimization with Enzyme gradients...")

    for iter in 1:max_iters
        # Compute loss and gradient at the current shifts
        current_loss = loss_wrapper(shifts)
        compute_gradient!(grad, shifts)

        # Track the best result before updating, so best_shifts matches best_loss
        if current_loss < best_loss
            best_loss = current_loss
            best_shifts = copy(shifts)
        end

        push!(loss_history, current_loss)

        # Print progress
        if verbose && (iter % 20 == 0 || iter == 1)
            println("Iter $iter: Loss = $(round(current_loss, digits=8)), Shifts = $(round.(shifts, digits=4))")
        end

        # Early stopping if loss is very small
        if current_loss < 1e-12
            verbose && println("Converged at iteration $iter")
            break
        end

        # Gradient-descent update
        shifts .-= lr .* grad
    end

    return best_shifts, best_loss, loss_history
end

# Run optimization
optimized_shifts, final_loss, loss_history = optimize_shifts(shift_guess)

println("\n" * "="^50)
println("OPTIMIZATION RESULTS")
println("="^50)
println("Final shifts: $(round.(optimized_shifts, digits=6))")
println("True shifts:  $(round.(true_shifts, digits=6))")
println("Errors:       $(round.(optimized_shifts .- true_shifts, digits=6))")
println("Final loss:   $(round(final_loss, digits=10))")
println("Relative error: $(round(norm(optimized_shifts .- true_shifts) / norm(true_shifts) * 100, digits=4))%")
```

## Visualization and Analysis

Compare the original and reconstructed trajectories to validate our parameter estimation.
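The trajectories are chaotic, so a pointwise comparison can look poor even when the learned shifts are close. One complementary diagnostic is the first time at which the learned and reference trajectories separate by more than a tolerance. This is an assumption of ours, not something the package provides.

`divergence_time` is a hypothetical helper. It takes the saved times and two lists of state vectors, as in `sol_ref.t`, `sol_ref.u` and `sol_learned.u`, and is checked below on synthetic data.

```julia
using LinearAlgebra

# First saved time at which the Euclidean distance between two trajectories exceeds `tol`;
# returns `nothing` if they never separate that far.
# `divergence_time` is a hypothetical helper, not part of LorenzParameterEstimation.
function divergence_time(t, ua, ub; tol=1.0)
    for k in eachindex(t, ua, ub)
        if norm(ua[k] .- ub[k]) > tol
            return t[k]
        end
    end
    return nothing
end

# Synthetic check: the second "trajectory" drifts away from the first at speed 0.3
t_demo = collect(0.0:0.5:5.0)
ua_demo = [[ti, 0.0, 0.0] for ti in t_demo]
ub_demo = [[ti, 0.3 * ti, 0.0] for ti in t_demo]
println(divergence_time(t_demo, ua_demo, ub_demo))   # 3.5
```

The notebook's two solutions share the same `saveat` grid, so the corresponding call is `divergence_time(sol_ref.t, sol_ref.u, sol_learned.u)`. The default tolerance of 1.0 is an arbitrary choice.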

```julia
# Generate trajectory with learned parameters
learned_system = make_shifted_system(base_params)
prob_learned = ODEProblem(learned_system, u0, tspan, optimized_shifts)
sol_learned = solve(prob_learned, Tsit5(), saveat=dt)

# Create comparison plots
# 1. 3D phase space plot
x_ref = [u[1] for u in sol_ref.u]
y_ref = [u[2] for u in sol_ref.u]
z_ref = [u[3] for u in sol_ref.u]

x_learned = [u[1] for u in sol_learned.u]
y_learned = [u[2] for u in sol_learned.u]
z_learned = [u[3] for u in sol_learned.u]

p1 = plot(x_ref, y_ref, z_ref, 
          label="Reference (true shifts)", 
          color=:blue, alpha=0.8, linewidth=2,
          title="Lorenz 63 with Coordinate Shifts",
          xlabel="x", ylabel="y", zlabel="z")

plot!(p1, x_learned, y_learned, z_learned, 
      label="Learned shifts", 
      color=:red, alpha=0.8, linewidth=1, linestyle=:dash)

# 2. Time series comparison
p2 = plot(sol_ref.t, x_ref, label="Reference x", color=:blue, linewidth=2,
          title="Time Series Comparison", xlabel="Time", ylabel="x")
plot!(p2, sol_learned.t, x_learned, label="Learned x", color=:red, linestyle=:dash)

# 3. Loss convergence
p3 = plot(1:length(loss_history), log10.(loss_history .+ 1e-16), 
          label="Loss", color=:green, linewidth=2,
          title="Optimization Convergence", xlabel="Iteration", ylabel="log₁₀(Loss)")

# Combine plots
combined_plot = plot(p1, p2, p3, layout=(1,3), size=(1200, 400))
display(combined_plot)
```

## Gradient Analysis

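To test whether the shifts give smooth gradients, we evaluate the loss on a 21 × 21 grid of (x_s, y_s) values within ±1 of the true shifts, keeping z_s at its true value. Gradient-based optimization works best when the landscape around the true shifts is smooth and has a single basin.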
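A heatmap alone can hide grid-scale ruggedness, so a simple numerical check on individual slices of the landscape can help. The sketch below counts strict interior local minima along a 1-D slice: a smooth, single-basin slice has exactly one.

`count_local_minima` is a hypothetical helper, and the two test slices are synthetic. The first is a quadratic; the second adds a grid-scale oscillation.

```julia
# Count strict interior local minima along a 1-D slice of loss values.
# A smooth single-basin slice (such as a quadratic) has exactly one.
# `count_local_minima` is a hypothetical helper, not package API.
function count_local_minima(v::AbstractVector{<:Real})
    n = 0
    for k in firstindex(v)+1:lastindex(v)-1
        if v[k] < v[k-1] && v[k] < v[k+1]
            n += 1
        end
    end
    return n
end

xs = range(-1.0, 1.0, length=21)                       # same resolution as n_points = 21
smooth_slice = [(x - 0.2)^2 for x in xs]               # single quadratic basin
rugged_slice = [(x - 0.2)^2 + 0.5 * isodd(k) for (k, x) in enumerate(xs)]  # grid-scale oscillation

println(count_local_minima(smooth_slice))   # 1
println(count_local_minima(rugged_slice))   # 10
```

Apply it to the notebook's grid with, for example, `count_local_minima(loss_landscape[:, 11])`. Column 11 is the centre of the 21-point `y_s_range`, so this is the x_s slice through the true y_s. A count above 1 would mean the landscape is not a single smooth basin at this resolution.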

```julia
# Evaluate the loss on a grid around the optimum
param_range = 1.0
n_points = 21
x_s_range = range(true_shifts[1] - param_range, true_shifts[1] + param_range, length=n_points)
y_s_range = range(true_shifts[2] - param_range, true_shifts[2] + param_range, length=n_points)

# Compute loss landscape for x_s and y_s (keeping z_s fixed at true value)
loss_landscape = zeros(n_points, n_points)

println("Computing loss landscape...")
for (i, x_s) in enumerate(x_s_range)
    for (j, y_s) in enumerate(y_s_range)
        params = [x_s, y_s, true_shifts[3]]
        loss_landscape[i, j] = loss_wrapper(params)
    end
end

# Plot loss landscape (rows of loss_landscape index x_s, columns index y_s)
p_loss = heatmap(y_s_range, x_s_range, log10.(loss_landscape .+ 1e-16),
                 xlabel="y_s", ylabel="x_s", 
                 title="Loss Landscape (z_s fixed)\nlog₁₀(Loss)",
                 color=:viridis)
scatter!(p_loss, [true_shifts[2]], [true_shifts[1]], 
         marker=:star, markersize=10, color=:red, label="True optimum")
scatter!(p_loss, [optimized_shifts[2]], [optimized_shifts[1]], 
         marker=:circle, markersize=8, color=:yellow, label="Found optimum")

display(p_loss)

println("\nCheck the landscape around the true shifts: a smooth, near-quadratic basin is what gradient-based optimization needs.")
```