In [None]:
import torch
import sys
import os
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import matplotlib.cm as cm
import scipy
from datetime import datetime
import json

# Put the parent directory of the notebook first on sys.path so the project
# packages imported below (Double_Pendulum, Plotting) resolve to this checkout.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.insert(0, module_path)
print(sys.path)

import Double_Pendulum.Learning.autoencoders as autoencoders
import Double_Pendulum.robot_parameters as robot_parameters
import Double_Pendulum.transforms as transforms
import Double_Pendulum.dynamics as dynamics
import Plotting.pendulum_plot as pendulum_plot

from functools import partial



%load_ext autoreload
%autoreload 2


# Run torch computations on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Lumped robot parameters shared by the plotter, the analytic model and the
# equilibrium solver below.
rp = robot_parameters.LUMPED_PARAMETERS
model_cw = False

# Project helpers built from the same parameter set.
plotter = pendulum_plot.Anim_plotter(rp)
model_ana = autoencoders.Analytic_transformer(rp)

In [None]:
import torch

def find_q_des_equilibrium(rp, q_des, max_iter=100, tol=1e-6, line_search_fn=None):
    """Search for a configuration whose gravity load the input can cancel.

    Starting from ``q_des`` as the initial guess, L-BFGS minimises the squared
    norm of the residual

        u_ff = (A^T G) / (A^T A)
        r    = G - A * u_ff

    where G = G(q) is the gravity term returned by ``dynamics.dynamical_matrices``
    (evaluated at zero joint velocity) and A = A(q) is ``dynamics.input_matrix``.
    Assuming A(q) is a single column (one input), u_ff is the least-squares
    input and r is the component of G that no input along A can cancel.

    Args:
        rp: robot parameters, forwarded to the ``dynamics`` functions.
        q_des: initial guess for the joint configuration (torch tensor). It is
            not modified.
        max_iter: maximum number of L-BFGS iterations.
        tol: used for both ``tolerance_grad`` and ``tolerance_change``.
        line_search_fn: forwarded to ``torch.optim.LBFGS``. ``None`` (default)
            keeps the original fixed-step behaviour. ``"strong_wolfe"`` is
            usually more robust on this non-convex residual.

    Returns:
        The optimised configuration as a detached tensor with the same dtype
        and device as ``q_des``. No convergence check is performed.
    """
    # detach() before clone() so the copy is never recorded in q_des's graph.
    q = q_des.detach().clone().requires_grad_(True)
    opt = torch.optim.LBFGS(
        [q],
        max_iter=max_iter,
        tolerance_grad=tol,
        tolerance_change=tol,
        line_search_fn=line_search_fn,
    )

    def closure():
        opt.zero_grad()
        # Only the gravity term is needed; evaluate the dynamics at zero velocity.
        _, _, G_q = dynamics.dynamical_matrices(rp, q, torch.zeros_like(q))
        A_q = dynamics.input_matrix(rp, q)
        # Least-squares feed-forward input and the part of G it cannot cancel.
        u_ff = (A_q.T @ G_q) / (A_q.T @ A_q)
        r = G_q - A_q * u_ff
        loss = (r ** 2).sum()
        loss.backward()
        return loss

    opt.step(closure)
    return q.detach()


In [None]:

def forward_kinematics_np(rp, q_np):
    """NumPy forward kinematics of the end effector only.

    Both joint angles enter as absolute angles:

        x = l0 * cos(q0) + l1 * cos(q1)
        y = l0 * sin(q0) + l1 * sin(q1)

    Args:
        rp: mapping providing the link lengths ``'l0'`` and ``'l1'``.
        q_np: joint angles, array-like of shape (2,) or a batch of shape (..., 2).

    Returns:
        End-effector position(s): shape (2,) for a single configuration,
        shape (..., 2) for a batch.
    """
    q = np.asarray(q_np)
    # Indexing the last axis supports a single configuration and batches alike.
    q0, q1 = q[..., 0], q[..., 1]
    x = rp['l0'] * np.cos(q0) + rp['l1'] * np.cos(q1)
    y = rp['l0'] * np.sin(q0) + rp['l1'] * np.sin(q1)
    return np.stack([x, y], axis=-1)

def find_q_eq_np(rp, q_des_np):
    """NumPy wrapper around ``find_q_des_equilibrium``.

    Args:
        rp: robot parameters, forwarded to the solver.
        q_des_np: initial guess for the joint configuration (array-like).

    Returns:
        The configuration returned by the solver, as a float32 NumPy array on
        the CPU.
    """
    # The solver runs in float32 on the notebook-wide `device`. as_tensor also
    # accepts plain lists, not only ndarrays.
    q_des_t = torch.as_tensor(q_des_np, dtype=torch.float32, device=device)
    q_eq_t = find_q_des_equilibrium(rp, q_des_t)
    return q_eq_t.cpu().numpy()

# ------------------------------------------------------------------
# 2) Sample q_des and compute q_eq
# ------------------------------------------------------------------
N = 100
SEED = 0
# Seeded generator so that Restart & Run All reproduces the same samples
# (and therefore the same figure and saved video).
rng = np.random.default_rng(SEED)
# Sample each joint uniformly in [-pi, +pi).
qs_des = rng.uniform(-np.pi, np.pi, size=(N, 2))
# One equilibrium solve per sample, each initialised at its own q_des.
qs_eq = np.array([find_q_eq_np(rp, qd) for qd in qs_des])
assert qs_eq.shape == qs_des.shape == (N, 2)

# Cartesian end-effector positions before (q_des) and after (q_eq) the solve.
pos_des = np.array([forward_kinematics_np(rp, q) for q in qs_des])  # shape (N, 2)
pos_eq = np.array([forward_kinematics_np(rp, q) for q in qs_eq])    # shape (N, 2)

# ------------------------------------------------------------------
# 3) Set up figure
# ------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(6, 6))
R = rp['l0'] + rp['l1']  # sum of the two link lengths
ax.set_aspect('equal', adjustable='box')
ax.set(
    xlim=(-R * 1.2, R * 1.2),
    ylim=(-R * 1.2, R * 1.2),
    xlabel='x',
    ylabel='y',
    title='q_des → q_eq mapping',
)
ax.grid(True, linestyle='--', alpha=0.5)

# Reference geometry: a dashed circle of radius R and a dashed vertical line
# at x = rp['xa'].
circle = ax.add_patch(plt.Circle((0, 0), R, fill=False, linestyle='--', color='gray'))
ax.axvline(rp['xa'], linestyle='--', color='gray')

# One marker per sample, initially drawn at its q_des end-effector position.
scat = ax.scatter(pos_des[:, 0], pos_des[:, 1], c='C0', label='end-effector')
ax.legend(loc='upper right')
# ------------------------------------------------------------------
# 4) Animation update: interpolate positions
# ------------------------------------------------------------------
num_frames = 100


def update(frame):
    """FuncAnimation callback: draw frame ``frame`` of the q_des -> q_eq morph.

    Every marker is placed on the straight Cartesian segment between its
    q_des and q_eq end-effector positions, a fraction
    ``frame / (num_frames - 1)`` of the way along (0 on the first frame,
    1 on the last). Returns the modified artists, as blitting requires.
    """
    alpha = frame / (num_frames - 1)
    blended = (1 - alpha) * pos_des + alpha * pos_eq
    scat.set_offsets(blended)
    return (scat,)


anim = FuncAnimation(
    fig,
    update,
    frames=num_frames,
    interval=50,
    blit=True,
)

# ------------------------------------------------------------------
# 5) Save, then show
# ------------------------------------------------------------------
# Save before plt.show(). With a blocking GUI backend, show() only returns
# after the window is closed, and the inline notebook backend may close the
# figure once it has been rendered, so saving afterwards is unreliable.
# Writing an .mp4 needs a movie writer such as ffmpeg to be installed.
anim.save('q_des_to_q_eq.mp4', fps=30, dpi=200)

# NOTE: the static inline backend renders a single frame here. Use an
# interactive backend (e.g. `%matplotlib widget`) to watch the animation live.
plt.show()