In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, Dataset, DataLoader
import functorch
import matplotlib.pyplot as plt

import os
from datetime import datetime
import time

import Double_Pendulum.Lumped_Mass.robot_parameters as robot_parameters
import Double_Pendulum.Lumped_Mass.transforms as transforms
import Double_Pendulum.Lumped_Mass.dynamics as dynamics
import Learning.loss_terms as loss_terms
#import Plotting.plotters_h1h2 as plotters_h1h2
import Learning.training_data as training_data
import Plotting.theta_visualiser as theta_visualizer

import Models.autoencoders as autoencoders

from functools import partial

%load_ext autoreload
%autoreload 2

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# NOTE(review): `rp` is the LUMPED_PARAMETERS object of the robot_parameters module itself,
# not a copy, so setting m1 here also affects every later
# `robot_parameters.LUMPED_PARAMETERS` lookup in this kernel (until the module is reloaded).
rp = robot_parameters.LUMPED_PARAMETERS
rp['m1'] = 0.0
print(rp)

# Elbow branch used for training data selection (False -> counterclockwise in mask_points).
train_clockwise = False

In [None]:
def mask_points(q1_split, clockwise=False, max_points=3000):
    """
    Select training configurations [q1, q2] inside a q1 window and on one elbow branch.

    Args:
        q1_split:   (low, high) inclusive bounds on q1.
        clockwise:  if True keep points with q1 - pi <= q2 <= q1 ("clockwise" branch),
                    otherwise keep points with q1 <= q2 <= q1 + pi ("counterclockwise").
        max_points: maximum number of rows returned (default 3000). The FIRST rows that
                    pass the masks are kept, in the order of training_data.points; pass
                    None to keep all of them.

    Returns:
        Tensor of the selected points, moved to the notebook-level `device`.
        A warning is printed when fewer than `max_points` points survive the masks.
    """
    # Retrieve training points (columns: q1, q2)
    points = training_data.points.to(device)
    q1, q2 = points[:, 0], points[:, 1]

    width_mask = (q1 >= q1_split[0]) & (q1 <= q1_split[1])
    if clockwise:
        branch_mask = (q2 >= q1 - torch.pi) & (q2 <= q1)
    else:
        branch_mask = (q2 >= q1) & (q2 <= q1 + torch.pi)

    points = points[width_mask & branch_mask]

    if max_points is not None:
        # NOTE(review): truncation is not random; if training_data.points is ordered
        # (e.g. a grid) this biases the selection towards its first region.
        points = points[:max_points]
        if points.size(0) < max_points:
            print("Warning: Only", points.size(0), "points in dataset.")

    return points

In [None]:
def points_plotter(points, extend = None):
    """
    Scatter-plot the (q1, q2) configurations used to train the Autoencoder.

    Args:
        points: tensor indexed as points[:, 0] = q1 and points[:, 1] = q2; copied to CPU for plotting.
        extend: accepted for call-site compatibility but not used.

    Both axes are fixed to [-2*pi, 2*pi].
    """
    q1_values = points[:, 0].cpu().numpy()
    q2_values = points[:, 1].cpu().numpy()
    axis_limit = 2 * torch.pi

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(q1_values, q2_values, alpha=0.6, edgecolors='k', s=20)
    ax.set_title('Scatter Plot of q1 vs q2')
    ax.set_xlabel('q1')
    ax.set_ylabel('q2')
    ax.set_xlim(-axis_limit, axis_limit)
    ax.set_ylim(-axis_limit, axis_limit)
    ax.grid(True)
    plt.show()

In [None]:
masked_points = mask_points((-torch.pi, torch.pi), clockwise=train_clockwise)
# clone() keeps masked_points intact in case wrap_to_pi modifies its input in place.
deshifted_points = transforms.wrap_to_pi(masked_points.clone())

# Raw masked points first, then the same points wrapped by wrap_to_pi.
for points_to_plot in (masked_points, deshifted_points):
    points_plotter(points_to_plot, extend="ccw")

In [None]:


def make_dataset(points):
    """
    Precompute the mass and input matrices of every training point so training does not redo it.

    For each configuration q in `points`:
      * M_q is the first output of dynamics.dynamical_matrices(rp, q, q). q is also passed
        as the third argument, presumably as a velocity placeholder; only M_q is kept.
        TODO confirm the signature.
      * A_q is dynamics.input_matrix(rp, q).

    Uses the notebook-level robot parameters `rp`.

    Returns:
        TensorDataset of (q, M_q, A_q), each stacked along a new leading dimension.
    """
    configurations, mass_matrices, input_matrices = [], [], []
    for q in points:
        M_q, _, _ = dynamics.dynamical_matrices(rp, q, q)
        configurations.append(q)
        mass_matrices.append(M_q)
        input_matrices.append(dynamics.input_matrix(rp, q))

    return TensorDataset(
        torch.stack(configurations),
        torch.stack(mass_matrices),
        torch.stack(input_matrices),
    )


In [None]:
def make_dataloaders(dataset, batch_size = 512, train_part = 0.7, seed = None):
    """
    Randomly split `dataset` into a training and a validation DataLoader.

    Args:
        dataset:    map-style dataset supporting len().
        batch_size: batch size of both loaders.
        train_part: fraction of samples assigned to training, 0 < train_part <= 1;
                    int(train_part * len(dataset)) samples go to training, the rest to validation.
        seed:       optional int. When given, the split uses its own seeded generator and is
                    reproducible. When None, the global torch RNG is used (previous behaviour).

    Returns:
        (train_dataloader, val_dataloader). The training loader shuffles; the validation loader does not.

    Raises:
        ValueError: if train_part is outside (0, 1].
    """
    if not 0.0 < train_part <= 1.0:
        raise ValueError(f"train_part must be in (0, 1], got {train_part}")

    train_size = int(train_part * len(dataset))
    val_size = len(dataset) - train_size

    # A dedicated generator makes the split reproducible without touching the global RNG.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    return train_dataloader, val_dataloader

In [None]:
import torch.nn.functional as F

def loss_fun(q, theta, q_hat, M_q, A_q, J_h_enc, J_h_dec):

    """
    Loss function for training the Autoencoder.

    Args:
        q:        input configurations, indexed as (batch, 2).
        theta:    encoder output; only theta[:, 0] is used (in l_theta).
        q_hat:    reconstruction of q.
        M_q:      batched q-space mass matrices (batch, n, n).
        A_q:      batched q-space input matrices, indexed as A_q[:, :, 0], i.e. (batch, n, 1).
        J_h_enc:  batched encoder Jacobian (batch, n, n).
        J_h_dec:  batched decoder Jacobian, used as the inverse of J_h_enc.

    Loss terms, in the order they are returned:
        l_recon:      MSE between q and q_hat.
        l_inertia:    MSE between M_th = J_dec^T M_q J_dec and the identity (inertial decoupling).
        l_input:      mean of the squared second entry of A_th = J_dec^T A_q (input decoupling).
        l_input_jac:  MSE between the first row of J_h_enc and A_q[:, :, 0]; inspired by
                      Pietro Pustina's paper on input decoupling, https://arxiv.org/pdf/2306.07258
        l_theta:      MSE between theta[:, 0] and the analytic first coordinate from
                      transforms.analytic_theta (global `rp`); evaluated only when f_theta != 0.

    Returns:
        (loss, terms). `loss` is lamda * (weighted sum of the terms) and keeps the autograd graph.
        `terms` is a detached CPU tensor of the 5 weighted terms, also scaled by lamda,
        intended for logging.
    """
    # Loss weights (a zero weight removes the term from the gradient).
    f_recon        = 10.
    f_inertia      = 10.
    f_input        = 1.
    f_input_jac    = 0.
    f_theta        = 0.
    lamda          = 100   # global scale applied to the total and to every logged term

    l_recon = F.mse_loss(q, q_hat, reduction="mean")

    # Map mass and input matrices to theta-space; the decoder Jacobian acts as J_h^{-1}.
    J_h_inv = J_h_dec
    J_h_inv_trans = torch.transpose(J_h_inv, 1, 2)
    M_th = J_h_inv_trans @ M_q @ J_h_inv
    A_th = J_h_inv_trans @ A_q

    # Inertial decoupling: M_th should equal the identity.
    M_eye = torch.eye(M_th.size(-1), device=M_th.device, dtype=M_th.dtype).expand(M_th.size())
    l_inertia = F.mse_loss(M_th, M_eye, reduction="mean")

    # Input decoupling: the second theta coordinate should receive no input.
    l_input = torch.mean(A_th[:, 1] ** 2)

    # Input decoupling expressed on the encoder Jacobian (Pustina et al.).
    l_input_jac = F.mse_loss(J_h_enc[:, 0, :], A_q[:, :, 0], reduction="mean")

    # Supervise theta_0 with the analytic transform. Skipped when unweighted: the per-batch vmap
    # is wasted work then, and 0 * NaN would still turn the total loss into NaN.
    if f_theta != 0:
        theta_ana = torch.vmap(transforms.analytic_theta, in_dims=(None, 0))(rp, q)
        l_theta = F.mse_loss(theta[:, 0], theta_ana[:, 0], reduction="mean")
    else:
        l_theta = torch.zeros((), device=l_recon.device, dtype=l_recon.dtype)

    weighted_terms = [l_recon * f_recon, l_inertia * f_inertia, l_input * f_input,
                      l_input_jac * f_input_jac, l_theta * f_theta]
    loss_sum = sum(weighted_terms)

    # Detached copy for logging, so callers can accumulate it on the CPU.
    logged_terms = torch.stack([term.detach() for term in weighted_terms]).cpu()

    return lamda * loss_sum, lamda * logged_terms

In [None]:

#model.load_state_dict(torch.load(load_path, weights_only=True))


def _run_epoch(model, dataloader, device, optimizer=None):
    """
    Run one pass over `dataloader` and return the per-sample mean loss and loss-term vector.

    With an optimizer the model is trained (train mode, gradients on). Without one it is only
    evaluated (eval mode, no gradients). Each batch's value from loss_fun (a batch mean) is
    weighted by the batch size, so the result is a true per-sample mean even when the last
    batch is smaller. An empty loader yields NaN values instead of a ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_terms = torch.zeros(5)
    n_samples = 0
    with torch.set_grad_enabled(training):
        for q, M_q, A_q in dataloader:
            q, M_q, A_q = q.to(device), M_q.to(device), A_q.to(device)

            theta, J_h, q_hat, J_h_dec, _ = model(q)
            loss, terms = loss_fun(q, theta, q_hat, M_q, A_q, J_h, J_h_dec)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch = q.size(0)
            total_loss += loss.item() * batch
            total_terms += terms * batch
            n_samples += batch

    if n_samples == 0:
        return float("nan"), torch.full((5,), float("nan"))
    return total_loss / n_samples, total_terms / n_samples


def train_AE_model(rp, device, lr, num_epochs, q1_split, train_dataloader, val_dataloader, current_time):

    """
    Train a fresh autoencoders.Autoencoder_double with Adam and an exponentially decaying learning rate.

    Args:
        rp:               robot parameters passed to autoencoders.Autoencoder_double.
        device:           torch device for the model and every batch.
        lr:               initial Adam learning rate; it has halved after num_epochs epochs.
        num_epochs:       number of passes over train_dataloader.
        q1_split:         (low, high) q1 window of the data; accepted for interface
                          compatibility, not used by the training itself.
        train_dataloader: yields (q, M_q, A_q) batches used for optimisation.
        val_dataloader:   yields (q, M_q, A_q) batches evaluated without gradients after every epoch.
        current_time:     timestamp string embedded in the checkpoint file name.

    Returns:
        (model, train_losses, val_losses, file_path).
        - Losses are per-epoch, per-sample means of loss_fun's value, so the training and
          validation curves are on the same scale.
        - file_path lies in ./Models/Split_AEs. The directory is created here, but nothing is
          written; saving is left to the caller.
    """
    model = autoencoders.Autoencoder_double(rp).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # gamma ** num_epochs == 0.5, i.e. the learning rate halves over the whole run.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5 ** (1 / max(num_epochs, 1)))

    save_directory = os.path.join(os.getcwd(), "Models/Split_AEs")
    os.makedirs(save_directory, exist_ok=True)
    file_name = f"Lumped_Mass_{current_time}.pth"
    file_path = os.path.join(save_directory, file_name)

    train_losses = []
    val_losses = []
    start_time = time.time()

    for epoch in range(num_epochs):
        train_loss, tlt = _run_epoch(model, train_dataloader, device, optimizer)
        val_loss, _ = _run_epoch(model, val_dataloader, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Time elapsed since training started (cumulative, not per epoch).
        epoch_duration = time.time() - start_time
        scheduler.step()

        print(
            f'Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Duration: {epoch_duration:.2f} seconds')
        print(
            f'l_recon: {tlt[0]:.4f}, l_inertia: {tlt[1]:.4f}, l_input: {tlt[2]:.4f}, l_input_jac: {tlt[3]:.4f}, l_theta: {tlt[4]:.4f}'
        )

    return model, train_losses, val_losses, file_path


In [None]:
def save_model(model, file_path):
    """Write `model.state_dict()` to `file_path` with torch.save and report where it went."""
    state = model.state_dict()
    torch.save(state, file_path)
    print(f"Model parameters saved to {file_path}")


In [None]:
def plot_loss(train_losses, val_losses, log = False):
    """
    Plot the per-epoch training and validation loss curves on a single axis.

    Args:
        train_losses: sequence of training-loss values, one per epoch.
        val_losses:   sequence of validation-loss values, one per epoch.
        log:          if True, use a logarithmic y-axis.

    Depending on the loss scale, a fixed y-limit (e.g. ax.set_ylim((-1, 40))) or `log=True`
    may make the curves easier to read.
    """
    fig, ax = plt.subplots(figsize=(5, 3))
    for values, label in ((train_losses, "Training Loss"), (val_losses, "Validation Loss")):
        ax.plot(values, label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.set_title("Training and Validation Loss over Epochs")
    ax.grid(True)
    if log:
        ax.set_yscale("log")
    plt.show()


In [None]:
import Plotting.plotters_simple as plotters_simple

def make_plot_dataloader(dataset, stride = 10):
    """
    Thin out a TensorDataset for plotting and wrap it in a single-batch DataLoader.

    Args:
        dataset: a TensorDataset, e.g. the (q, M_q, A_q) dataset built by make_dataset.
                 Every tensor in it is subsampled the same way.
        stride:  keep every `stride`-th sample (default 10) to reduce visual clutter.

    Returns:
        DataLoader that yields the whole thinned dataset as one unshuffled batch,
        or no batches at all if the thinned dataset is empty.

    Raises:
        ValueError: if stride < 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    sampled = [tensor[::stride] for tensor in dataset.tensors]
    plot_dataset = TensorDataset(*sampled)

    # DataLoader rejects batch_size=0, so use 1 for an empty dataset (it then yields nothing).
    plot_dataloader = DataLoader(plot_dataset, batch_size=max(len(plot_dataset), 1),
                                 shuffle=False, num_workers=0)
    return plot_dataloader

In [None]:
def plot_model_performance(model, plot_dataloader, show_learned_inertia=False):
    """
    Visually compare a trained Autoencoder with the analytic transform it returns.

    For every (q, M_q, A_q) batch of `plot_dataloader` it plots, via plotters_simple.plot_3d_double:
      * the analytic vs learned theta_0 and theta_1 over q,
      * the theta-space input matrix, learned (J_dec^T A_q) and analytic (J_ana^{-T} A_q),
        and prints the share of learned samples with |A_0| > 0.6 and |A_1| < 0.3,
      * optionally (show_learned_inertia=True) the learned theta-space mass matrix
        M_th = J_dec^T M_q J_dec over q and over theta, with a few summary percentages,
      * the analytic theta-space mass matrix over q and over the learned theta.

    Uses the notebook-level `device` and `plotters_simple`.
    """
    model.eval()
    with torch.no_grad():
        for (q, M_q, A_q) in plot_dataloader:
            q = q.to(device)
            M_q = M_q.to(device)
            A_q = A_q.to(device)

            theta, J_h, q_hat, J_h_dec, J_h_ana = model(q)
            theta_ana = model.theta_ana(q)

            # Learned map: the decoder Jacobian is used as the inverse of the encoder Jacobian.
            J_h_inv = J_h_dec
            J_h_inv_trans = torch.transpose(J_h_inv, 1, 2)
            A_th = (J_h_inv_trans @ A_q).squeeze(-1)

            # Analytic map: invert the analytic Jacobian explicitly.
            J_h_inv_ana = torch.linalg.inv(J_h_ana)
            J_h_inv_trans_ana = torch.transpose(J_h_inv_ana, 1, 2)
            M_th_ana = J_h_inv_trans_ana @ M_q @ J_h_inv_ana
            A_th_ana = (J_h_inv_trans_ana @ A_q).squeeze(-1)

            plotters_simple.plot_3d_double(q, theta_ana[:, 0], theta[:, 0], "th0", "analytical", "learned", "q_0", "q_1", "th0", device)
            plotters_simple.plot_3d_double(q, theta_ana[:, 1], theta[:, 1], "th1", "analytical", "learned", "q_0", "q_1", "th1", device)

            plotters_simple.plot_3d_double(q, A_th[:, 0], A_th[:, 1], "Input decoupling", "A0", "A1", "q_0", "q_1", "A", device)
            plotters_simple.plot_3d_double(q, A_th_ana[:, 0], A_th_ana[:, 1], "Input decoupling ana", "A0", "A1", "q_0", "q_1", "A", device)
            A_th_np = A_th.cpu().detach().numpy()
            print("Percentage of abs(A_0) > 0.6:", 100 * np.sum(np.abs(A_th_np[:, 0]) > 0.6)/A_th_np[:, 0].size, "%")
            print("Percentage of abs(A_1) < 0.3:", 100 * np.sum(np.abs(A_th_np[:, 1]) < 0.3)/A_th_np[:, 1].size, "%")

            if show_learned_inertia:
                M_th = J_h_inv_trans @ M_q @ J_h_inv
                plotters_simple.plot_3d_double(q, M_th[:, 0, 0], M_th[:, 0, 1], "M_th", "00", "01", "q_0", "q_1", "M_th", device)
                plotters_simple.plot_3d_double(q, M_th[:, 1, 0], M_th[:, 1, 1], "M_th", "10", "11", "q_0", "q_1", "M_th", device)
                M_th_np = M_th.cpu().detach().numpy()
                print("Percentage of abs(M_00) > 1.0:", 100 * np.sum(np.abs(M_th_np[:, 0, 0]) > 1.0)/M_th_np[:, 0, 0].size, "%")
                print("Percentage of abs(M_01) < 0.2:", 100 * np.sum(np.abs(M_th_np[:, 0, 1]) < 0.2)/M_th_np[:, 0, 1].size, "%")
                print("Percentage of abs(M_11) > 1.0:", 100 * np.sum(np.abs(M_th_np[:, 1, 1]) > 1.0)/M_th_np[:, 1, 1].size, "%")
                plotters_simple.plot_3d_double(theta, M_th[:, 0, 0], M_th[:, 0, 1], "M_th vs. th", "00", "01", "th_0", "th_1", "M_th", device)
                plotters_simple.plot_3d_double(theta, M_th[:, 1, 0], M_th[:, 1, 1], "M_th vs. th", "10", "11", "th_0", "th_1", "M_th", device)

            plotters_simple.plot_3d_double(q, M_th_ana[:, 0, 0], M_th_ana[:, 0, 1], "M_th_ana", "00", "01", "q_0", "q_1", "M_th", device)
            plotters_simple.plot_3d_double(q, M_th_ana[:, 1, 0], M_th_ana[:, 1, 1], "M_th_ana", "10", "11", "q_0", "q_1", "M_th", device)

            plotters_simple.plot_3d_double(theta, M_th_ana[:, 0, 0], M_th_ana[:, 0, 1], "M_th_ana vs. th", "00", "01", "th_0", "th_1", "M_th", device)
            plotters_simple.plot_3d_double(theta, M_th_ana[:, 1, 0], M_th_ana[:, 1, 1], "M_th_ana vs. th", "10", "11", "th_0", "th_1", "M_th", device)
        

In [None]:
def _first_output_column(network_output, column):
    """Return column `column` of the first element of a network's output tuple, detached from autograd."""
    return network_output[0][:, column].detach()


def theta_1_single(model, q):
    """First latent coordinate (column 0) of model.encoder_nn(q)."""
    return _first_output_column(model.encoder_nn(q), 0)


def theta_2_single(model, q):
    """Second latent coordinate (column 1) of model.encoder_nn(q)."""
    return _first_output_column(model.encoder_nn(q), 1)


def q_hat_1_single(model, theta):
    """First reconstructed coordinate (column 0) of model.decoder_nn(theta)."""
    return _first_output_column(model.decoder_nn(theta), 0)


def q_hat_2_single(model, theta):
    """Second reconstructed coordinate (column 1) of model.decoder_nn(theta)."""
    return _first_output_column(model.decoder_nn(theta), 1)

In [None]:
# --- Run configuration ---
q1_split = (-torch.pi, torch.pi)
batch_size = 512
train_part = 0.7
num_epochs = 1001
lr = 1e-3

# Same object as in the setup cell, not a fresh copy: if that cell ran, m1 is already 0.0 here.
rp = robot_parameters.LUMPED_PARAMETERS

plt.ion()
current_time = datetime.now().strftime("%Y%m%d%H%M")

# --- Data ---
shifted_points = mask_points(q1_split, clockwise=train_clockwise)
points_plotter(shifted_points)
dataset = make_dataset(shifted_points)
train_dataloader, val_dataloader = make_dataloaders(
    dataset=dataset, batch_size=batch_size, train_part=train_part)

# --- Training ---
outputs = []  # NOTE(review): never used below
model, train_losses, val_losses, file_path = train_AE_model(
    rp, device, lr, num_epochs, q1_split, train_dataloader, val_dataloader, current_time)

plot_loss(train_losses, val_losses, log=True)

# Encoder/decoder coordinate functions bound to the trained model, for the theta visualiser.
mapping_functions = tuple(
    partial(fn, model)
    for fn in (theta_1_single, theta_2_single, q_hat_1_single, q_hat_2_single)
)

save_model(model, file_path)


In [None]:
th_plotter = theta_visualizer.theta_plotter(
    rp=rp,
    n_lines=50,
    device=device,
    mapping_functions=mapping_functions,
    mask_split=q1_split,
)

# Optional exports (disabled):
# th_plotter.make_figure("theta_learned_full.png")
# th_plotter.make_animation("theta_learned_full.mp4", duration=4, fps=20)

In [None]:
%matplotlib widget


q1_split = (-torch.pi, torch.pi)
plt.ion()

# Load a previously trained checkpoint; this replaces any `model` trained earlier in the session.
# NOTE(review): hard-coded relative path, resolved against the current working directory.
model_location = 'Models/Split_AEs/Lumped_Mass_202503051257.pth'
model = autoencoders.Autoencoder_double(rp).to(device)
model.load_state_dict(torch.load(model_location, weights_only=True))

plot_points = mask_points(q1_split, clockwise=train_clockwise)
points_plotter(plot_points[::5])  # every 5th point only, to keep the scatter readable

plot_dataset = make_dataset(plot_points)
plot_dataloader = make_plot_dataloader(plot_dataset)
plot_model_performance(model, plot_dataloader)

In [None]:
import warnings
# NOTE(review): this silences every warning for the rest of the kernel session, not just this cell.
warnings.filterwarnings("ignore")

model_ana = autoencoders.Analytic_transformer(rp)

# Models compared below, in plotting order, with their display names.
models, model_names = [model_ana, model], ["Analytic", "Neural Network"]

def check_clockwise_vectorized(q):
    """
    Classify each configuration as elbow clockwise or counterclockwise.

    Args:
        q: tensor indexed as q[:, 0] = q1 and q[:, 1] = q2, i.e. shape (N, 2).

    Returns:
        (cw_mask, ccw_mask): boolean tensors of shape (N,).
        - A row is counterclockwise when q1 <= q2 <= q1 + pi (both ends inclusive).
        - Every other row is clockwise, so the two masks are exact complements.
        - No angle wrapping is applied, so q2 values shifted by 2*pi are classified as given.
    """
    q1, q2 = q[:, 0], q[:, 1]
    ccw_mask = (q2 >= q1) & (q2 <= q1 + torch.pi)
    return ~ccw_mask, ccw_mask


# Dense grid of configurations: q1 in [-pi, pi], q2 in [-pi, 2*pi], n_points values each.
n_points = 200
q1_vals = torch.linspace(-np.pi, np.pi, n_points)
q2_vals = torch.linspace(-np.pi, 2*np.pi, n_points)

# indexing='ij': the first grid axis is q1, the second is q2.
q1_grid, q2_grid = torch.meshgrid(q1_vals, q2_vals, indexing='ij')

# One configuration per row, shape (n_points*n_points, 2).
q_grid = torch.stack([q1_grid.flatten(), q2_grid.flatten()], dim=1).to(device)

print(q_grid.size())
print(q_grid)

# Quantities that depend only on q_grid (not on the model) are computed once.
# Presumably the end-effector position, with q2 treated as an absolute link angle.
x_end = rp["l1"] * torch.cos(q_grid[:, 0]) + rp["l2"] * torch.cos(q_grid[:, 1])
y_end = rp["l1"] * torch.sin(q_grid[:, 0]) + rp["l2"] * torch.sin(q_grid[:, 1])

# Only the counterclockwise configurations are plotted.
_, ccw_mask = check_clockwise_vectorized(q_grid)
x_end_ccw = x_end[ccw_mask].detach().cpu().numpy()
y_end_ccw = y_end[ccw_mask].detach().cpu().numpy()

# `enc_model` rather than `model` so the notebook-level `model` is not rebound by the loop.
for enc_model, model_name in zip(models, model_names):
    print(model_name)
    theta_out = enc_model.encoder(q_grid)

    theta_ccw = (
        theta_out[:, 0][ccw_mask].detach().cpu().numpy(),
        theta_out[:, 1][ccw_mask].detach().cpu().numpy(),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, values, label in zip(axes, theta_ccw, ("Theta0", "Theta1")):
        sc = ax.scatter(x_end_ccw, y_end_ccw, c=values, cmap='viridis', s=5)
        ax.set_title(f"{label} - Counterclockwise")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        plt.colorbar(sc, ax=ax)

    plt.tight_layout()
    plt.show()
