In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, Dataset, DataLoader
import functorch
import matplotlib.pyplot as plt

import os
from datetime import datetime
import time

import Double_Pendulum.Lumped_Mass.robot_parameters as robot_parameters
import Double_Pendulum.Lumped_Mass.transforms as transforms
import Double_Pendulum.Lumped_Mass.dynamics as dynamics
import Learning.loss_terms as loss_terms
#import Plotting.plotters_h1h2 as plotters_h1h2
import Learning.training_data as training_data
import Plotting.theta_visualiser as theta_visualizer

import Models.autoencoders as autoencoders

from functools import partial

%load_ext autoreload
%autoreload 2

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Robot parameter set shared by the cells below.
rp = robot_parameters.LUMPED_PARAMETERS
print(rp)

# Pair of analytic mapping functions handed to the theta plotter.
mapping_functions = (transforms.analytic_theta_1, transforms.analytic_theta_2)
th_plotter = theta_visualizer.theta_plotter(rp=rp, n_lines=50, mapping_functions=mapping_functions)
# Optional outputs, currently disabled:
#th_plotter.make_figure("theta_subset_5_2_image.png")
#th_plotter.make_animation("theta_subset_5_2.mp4", duration = 4, fps = 20, stride = 1)

In [None]:
def mask_points(q1_split, clockwise = False, max_points = 3000):

    """
    Returns a set of [q1, q2] points based on "q1_split" limits on q1 and a
    direction-dependent band on q2.

    Args:
        q1_split:   (low, high) pair; only points with low <= q1 <= high are kept.
        clockwise:  If True keep points with q1 - pi <= q2 <= q1,
                    otherwise keep points with q1 <= q2 <= q1 + pi.
        max_points: Upper bound on the number of returned points. Defaults to
                    3000, the value that used to be hard-coded.

    Returns:
        Tensor holding the selected rows of training_data.points, on "device".
        A warning is printed when fewer than "max_points" points survive the masks.
    """

    # Retrieve training points
    points = training_data.points.to(device)
    q1, q2 = points[:, 0], points[:, 1]

    # Mask on the admissible q1 interval
    width_mask = (q1 >= q1_split[0]) & (q1 <= q1_split[1])

    # Only the mask for the requested direction is built
    if clockwise:
        direction_mask = (q2 >= q1 - torch.pi) & (q2 <= q1)
    else:
        direction_mask = (q2 >= q1) & (q2 <= q1 + torch.pi)

    points = points[width_mask & direction_mask][:max_points]

    if points.size(0) < max_points:
        print("Warning: Only", points.size(0), "points in dataset.")

    return points

In [None]:
def points_plotter(points, extend = None):

    """
    Draws a q1-vs-q2 scatter plot of the points used to train the Autoencoder.

    "points" is indexed as points[:, 0] (q1) and points[:, 1] (q2) and is moved
    to the CPU before plotting. "extend" is accepted but not used.
    """

    q1_values = points[:, 0].cpu().numpy()
    q2_values = points[:, 1].cpu().numpy()
    axis_limit = 2*torch.pi

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(q1_values, q2_values, alpha=0.6, edgecolors='k', s=20)
    ax.set_title('Scatter Plot of q1 vs q2')
    ax.set_xlabel('q1')
    ax.set_ylabel('q2')
    ax.set_xlim(-axis_limit, axis_limit)
    ax.set_ylim(-axis_limit, axis_limit)
    ax.grid(True)
    plt.show()

In [None]:
# Select the counterclockwise subset (clockwise = False) over q1 in [-pi, pi].
masked_points = mask_points((-torch.pi, torch.pi), clockwise = False)
# wrap_to_pi receives a clone, so masked_points itself is not modified here.
# NOTE(review): presumably wraps the angles into [-pi, pi] - confirm in transforms.
deshifted_points = transforms.wrap_to_pi(masked_points.clone())
# Visual check of both point sets. points_plotter accepts `extend` but does not use it.
points_plotter(masked_points, extend="ccw")
points_plotter(deshifted_points, extend="ccw")

In [None]:


def make_dataset(points):

    """
    Precompute the mass matrix and input matrix for every training point so
    that they do not have to be evaluated inside the training loop.

    Iterates over "points" one row at a time and returns a
    TensorDataset of (q, M_q, A_q).
    """

    q_list, mass_list, input_list = [], [], []

    for q in points:
        # The point is passed for both arguments after rp; only the first
        # of the three returned matrices is kept.
        M_q, _C_q, _G_q = dynamics.dynamical_matrices(rp, q, q)
        A_q = dynamics.input_matrix(rp, q)

        q_list.append(q)
        mass_list.append(M_q)
        input_list.append(A_q)

    # Stack the per-point results along a new leading dimension
    return TensorDataset(
        torch.stack(q_list),
        torch.stack(mass_list),
        torch.stack(input_list),
    )


In [None]:
def make_dataloaders(dataset, batch_size = 512, train_part = 0.7):

    """
    Randomly splits "dataset" into a training and a validation subset and
    wraps each in a DataLoader.

    "train_part" is the fraction of samples assigned to training (rounded down);
    the remainder goes to validation. Only the training loader shuffles.
    Returns (train_dataloader, val_dataloader).
    """

    n_total = len(dataset)
    n_train = int(train_part * n_total)
    n_val = n_total - n_train

    # Random, non-overlapping partition of the dataset
    train_subset, val_subset = torch.utils.data.random_split(dataset, [n_train, n_val])

    # Options shared by both loaders
    loader_options = dict(batch_size=batch_size, num_workers=0)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_options)

    return train_loader, val_loader

In [None]:
import torch.nn.functional as F

def loss_fun(q, theta, q_hat, M_q, A_q, J_h_enc, J_h_dec):

    """
    Loss function for training the Autoencoder. Loss terms are the following:
    l_recon:     Loss between input- and reconstructed variable. (MSE)
    l_diag:      Loss on diagonal terms of mass matrix in theta-space:
                 mean of exp(-(d - 1)) - 1 over entries with d < 1, zero elsewhere.
    l_off_diag:  Loss of off-diagonal term M_th[:, 0, 1] of mass matrix in theta-space.
                 (mean square)
    l_input:     Loss driving row 1 of the input matrix in theta-space to zero.
                 (mean square)
    l_input_jac: Loss between row 0 of the encoder Jacobian and column 0 of A_q. (MSE)
    l_theta:     Loss between theta[:, 0] and the first analytic theta coordinate. (MSE)

    Args:
        q, theta, q_hat: Batched input, latent and reconstructed coordinates.
        M_q, A_q:        Batched mass- and input matrices in q-space.
        J_h_enc:         Batched encoder Jacobian.
        J_h_dec:         Batched decoder Jacobian, used as the inverse of J_h_enc.

    Returns:
        (loss, terms): "loss" is the weighted sum of all terms times lamda and
        stays attached to the autograd graph. "terms" is a detached CPU tensor
        with the six weighted terms times lamda, ordered
        [recon, diag, off_diag, input, input_jac, theta], intended for logging.
    """

    l_recon = F.mse_loss(q, q_hat, reduction="mean")

    # Forward Jacobian from the encoder, inverse Jacobian taken from the decoder
    J_h = J_h_enc
    J_h_inv = J_h_dec
    J_h_inv_trans = torch.transpose(J_h_inv, 1, 2)

    # Mass- and input matrix expressed in theta-space
    M_th = J_h_inv_trans @ M_q @ J_h_inv
    A_th = J_h_inv_trans @ A_q

    # Loss inspired by Pietro Pustina's paper on input decoupling:
    # https://arxiv.org/pdf/2306.07258
    l_input_jac = F.mse_loss(J_h[:, 0, :], A_q[:, :, 0], reduction="mean")

    # Loss on the first coordinate theta, again from Pietro Pustina's analytic formulation
    theta_ana = torch.vmap(transforms.analytic_theta, in_dims = (None, 0))(rp, q)
    l_theta = F.mse_loss(theta[:, 0], theta_ana[:, 0], reduction="mean")

    # Enforce inertial decoupling
    l_off_diag = torch.mean((M_th[:, 0, 1])**2)
    diag_values = torch.diagonal(M_th, dim1=1, dim2=2)
    # Only diagonal entries below 1 are penalised; the mask zeroes the rest
    l_diag = torch.mean((-1 + torch.exp(-(diag_values - 1))) * (diag_values < 1).float())

    ## input decoupling loss
    l_input = torch.mean((A_th[:, 1]**2))

    f_recon        = 10.
    f_diag         = 1.
    f_off_diag     = 1.
    f_input        = 1.
    f_input_jac    = 1.
    f_theta        = 1.

    weighted_terms = [l_recon * f_recon, l_diag * f_diag, l_off_diag * f_off_diag,
                      l_input * f_input, l_input_jac * f_input_jac, l_theta * f_theta]

    # Same left-to-right summation order as writing the sum out by hand
    loss_sum = weighted_terms[0]
    for term in weighted_terms[1:]:
        loss_sum = loss_sum + term

    # Logging copy: stacked once, then detached and moved to the CPU in a single
    # transfer instead of converting each graph-attached scalar separately.
    terms = torch.stack(weighted_terms).detach().cpu()

    lamda = 200

    return lamda * loss_sum, lamda * terms

In [None]:

#model.load_state_dict(torch.load(load_path, weights_only=True))


def train_AE_model(rp, device, lr, num_epochs, q1_split, train_dataloader, val_dataloader, current_time):

    """
    Executes training loop for Autoencoder.

    Args:
        rp:               Robot parameters handed to the model constructor.
        device:           Torch device the model and batches are moved to.
        lr:               Initial learning rate for Adam.
        num_epochs:       Number of epochs; the learning rate is halved over this span.
        q1_split:         (low, high) q1 limits of the dataset, recorded as metadata only.
        train_dataloader: Yields (q, M_q, A_q) batches used for optimisation.
        val_dataloader:   Yields (q, M_q, A_q) batches used for validation.
        current_time:     String used to build the checkpoint file name.

    Returns:
        (model, train_losses, val_losses, file_path). The loss lists hold one
        per-sample average per epoch. "file_path" is where the model is meant to
        be stored; nothing is written to it here.
    """

    model = autoencoders.Autoencoder_double(rp).to(device)  # Move model to GPU
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)#,  weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5 ** (1 / num_epochs))

    train_losses = []
    val_losses = []
    save_directory = os.path.join(os.getcwd(), "Models/Split_AEs")
    os.makedirs(save_directory, exist_ok=True)
    file_name = f"Lumped_Mass_{current_time}.pth"
    file_path = os.path.join(save_directory, file_name)

    # NOTE(review): this run metadata is assembled but never written anywhere
    # in this function - confirm whether it should be saved next to the model.
    run_metadata = {"q1_low" : q1_split[0],
                    "q1_high" : q1_split[1],
                    "lr" : lr,
                    "epochs" : num_epochs,
                    "file_name" : file_name}

    # Guard against an empty split so the averages below cannot divide by zero
    n_train = max(len(train_dataloader.dataset), 1)
    n_val = max(len(val_dataloader.dataset), 1)

    for epoch in range(num_epochs):
        epoch_start = time.time()

        # Training phase
        model.train()
        train_loss = 0
        train_loss_terms = torch.zeros(6)
        for q, M_q, A_q in train_dataloader:
            q = q.to(device)
            M_q = M_q.to(device)
            A_q = A_q.to(device)

            theta, J_h, q_hat, J_h_dec, J_h_ana = model(q)

            loss, batch_terms = loss_fun(q, theta, q_hat, M_q, A_q, J_h, J_h_dec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss_fun returns batch means, so weight by the batch size before
            # dividing by the dataset size to get a true per-sample average.
            batch_len = q.size(0)
            train_loss += loss.item() * batch_len
            train_loss_terms += batch_terms.detach().cpu() * batch_len
        train_loss /= n_train
        train_loss_terms /= n_train
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0
        val_loss_terms = torch.zeros(6)
        with torch.no_grad():
            for q, M_q, A_q in val_dataloader:
                q = q.to(device)
                M_q = M_q.to(device)
                A_q = A_q.to(device)

                theta, J_h, q_hat, J_h_dec, J_h_ana = model(q)

                loss, batch_terms = loss_fun(q, theta, q_hat, M_q, A_q, J_h, J_h_dec)

                batch_len = q.size(0)
                val_loss += loss.item() * batch_len
                val_loss_terms += batch_terms.detach().cpu() * batch_len
        val_loss /= n_val
        val_loss_terms /= n_val
        val_losses.append(val_loss)

        # Wall-clock time of this epoch only (training + validation)
        epoch_duration = time.time() - epoch_start
        scheduler.step()
        tlt = train_loss_terms
        print(
            f'Epoch [{epoch + 1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Duration: {epoch_duration:.2f} seconds')
        print(
            f'l_recon: {tlt[0]:.4f}, l_diag: {tlt[1]:.4f}, l_off_diag: {tlt[2]:.4f}, l_input: {tlt[3]:.4f}, l_input_jac: {tlt[4]:.4f}, l_theta: {tlt[5]:.4f}'
        )

    return(model, train_losses, val_losses, file_path)


In [None]:
def save_model(model, file_path):
    """
    Writes the model's state_dict to "file_path" and prints the destination.
    """
    state = model.state_dict()
    torch.save(state, file_path)
    print(f"Model parameters saved to {file_path}")


In [None]:
# Sanity check at a single configuration: compare the learned Jacobian and the
# resulting theta-space mass matrix against their analytic counterparts.
q_test = torch.tensor([[2, -0.]]).to(device)
# NOTE(review): `model` is not assigned at notebook level before this cell.
# This cell unpacks four values, whereas train_AE_model unpacks five from
# Autoencoder_double and Autoencoder2.forward returns two - confirm which model
# this cell is meant for.
theta, J_h, q_hat, J_h_ana = model(q_test)

# Pseudo-inverse of the learned Jacobian
J_h_inv = torch.linalg.pinv(J_h)
J_h_inv_trans = J_h_inv.transpose(1, 2)




print("J_h:\n", J_h.detach().cpu().numpy()[0])
print("J_h_inv:\n", J_h_inv.detach().cpu().numpy()[0])

# q_test[0] is passed for both arguments after rp.
M_q, C_q, G_q = dynamics.dynamical_matrices(rp, q_test[0], q_test[0])
# Add a leading batch dimension to M_q before the transform.
M_q = M_q.unsqueeze(0)
M_th, C_th, G_th = transforms.transform_dynamical_matrices(M_q, C_q, G_q, J_h, device)
print("M_q:", M_q)
print("M_th:\n", M_th)


# Off-diagonal entry relative to the geometric mean of the diagonal entries;
# the 1e-6 keeps the square root away from zero.
off_dia = M_th[:, 1,0]
diag_elements = M_th[:, [0, 1], [0, 1]]
diag_product = torch.sqrt(diag_elements[:, 0] * diag_elements[:, 1] + 1e-6)
M_th_ratio = off_dia/diag_product
print("M_th_ratio:", M_th_ratio)

# M_q is printed a second time here (same tensor as above).
print("M_q:\n", M_q)

print("J_h_ana:\n", J_h_ana.detach().cpu().numpy()[0], "\n")
J_h_ana_inv = torch.linalg.pinv(J_h_ana)
J_h_ana_inv_trans = J_h_ana_inv.transpose(1,2)

print("J_h_ana_inv:\n", J_h_ana_inv.detach().cpu().numpy()[0], "\n")


# Analytic theta-space matrices for comparison with M_th above.
M_th_ana, C_th_ana, G_th_ana = dynamics.dynamical_matrices_th(rp, q_test[0], q_test[0]) 
print("M_th_ana:", M_th_ana.detach().cpu().numpy())

In [None]:
# Extract epochs and losses
# NOTE(review): `outputs` is only filled by the training loop in the last cell
# of this notebook (tuples starting with (epoch, loss, ...)), so this cell fails
# on a fresh top-to-bottom run unless that loop has been executed first.
epochs = [entry[0] for entry in outputs]
losses = [entry[1].item() for entry in outputs]

# Plot loss as a function of epoch
plt.figure(figsize=(8, 6))
plt.plot(epochs, losses, marker='', label='Training Loss')
plt.title('Training Loss vs. Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
# Logarithmic y-axis
plt.yscale('log')
plt.grid(True)
plt.legend()
plt.show()

In [None]:
import random
import string

# Render every parameter tensor of `model` as one colour-coded table.
# Table width; each row is a list of `max_neurons` cells.
max_neurons = 4
blank_layer = [None for _ in range(max_neurons)]

table_layers = []
for idx, param in enumerate(model.parameters()):
    layer = param.data
    num_parallel = layer.shape[0]
    # Left offset that centres the values within the row.
    # NOTE(review): becomes negative if a parameter has more than max_neurons rows.
    side_padding = int((max_neurons - num_parallel)/2)
    
    # Even indices are treated as weight matrices, odd indices as bias vectors.
    # NOTE(review): assumes model.parameters() alternates weight, bias - confirm.
    if idx % 2 == 0:
        
        # Header row, then one table row per column of the weight matrix
        table_layer = blank_layer.copy()
        table_layer[0] = "weights" + str(idx//2+1)
        table_layers.append(table_layer)
        for i in range(layer.shape[1]):
            table_layer = blank_layer.copy()
            for j in range(num_parallel):
                table_layer[j+side_padding] = '{:.2e}'.format(layer[j][i].item())
            table_layers.append(table_layer)
        # Spacer row (the same blank_layer list object is appended each time)
        table_layers.append(blank_layer)
            
    else:  
        
        # Header row, then a single row holding the bias values
        table_layer = blank_layer.copy()
        table_layer[0] = "bias" + str(idx//2+1)
        table_layers.append(table_layer)
        table_layer = blank_layer.copy()
        for j in range(num_parallel):
            table_layer[j+side_padding] = '{:.2e}'.format(layer[j].item())
        table_layers.append(table_layer)
        table_layers.append(blank_layer)

# Parse the formatted strings back into numbers for colouring; empty cells and
# header labels stay at 0.
numeric_values = np.zeros((len(table_layers), max_neurons))
for i, row in enumerate(table_layers):
    for j, item in enumerate(row):
        if item not in (None, "weights1", "weights2", "bias1", "bias2"):  # Replace with relevant layer names
            try:
                numeric_values[i, j] = (float(item))
            except ValueError:
                # Any other label that is not a number is left at 0
                pass
        
# NOTE(review): min_val and max_val are not used later in this cell.
min_val, max_val = numeric_values.min(), numeric_values.max()


# Step 2: Apply a logarithmic transformation, setting a small threshold to avoid log(0)
threshold = 1e-5
log_values = np.log10(np.clip(np.abs(numeric_values), threshold, None))

# Normalize the log-scaled values to range between 0 and 1
# NOTE(review): divides by zero if every log value is identical.
normalized_values = (log_values - log_values.min()) / (log_values.max() - log_values.min())
colors = plt.cm.Blues(normalized_values)

        
# Plot the table
fig, ax = plt.subplots(figsize=(10, 6))
ax.axis('tight')
ax.axis('off')

# Create table
table = plt.table(cellText=table_layers, cellColours=colors, loc='center', cellLoc='center')

plt.show()

In [None]:
class Autoencoder2(nn.Module):
    """
    Small autoencoder whose 4-dimensional latent vector is also read as a
    batched 2x2 matrix (named J_h_inv in this class).

    Encoder: Linear(2, 3) -> Sigmoid -> Linear(3, 4)
    Decoder: Linear(4, 3) -> Sigmoid -> Linear(3, 2)
    """

    def __init__(self, rp):
        super().__init__()

        encoder_layers = [
            nn.Linear(2, 3),
            nn.Sigmoid(),
            nn.Linear(3, 4),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(4, 3),
            nn.Sigmoid(),
            nn.Linear(3, 2),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # Stored for later use; not read inside this class
        self.rp = rp

    def encoder_nn(self, q):
        """
        Encode q and return (J_h_inv, latent), where J_h_inv stacks
        latent[:, 0:2] and latent[:, 2:4] as the two rows of a matrix.
        """
        latent = self.encoder(q)
        first_row, second_row = latent[:, 0:2], latent[:, 2:4]
        return torch.stack((first_row, second_row), dim=1), latent

    def forward(self, q):
        """
        Return (J_h_inv, q_hat), with q_hat decoded from the latent vector.
        """
        J_h_inv, latent = self.encoder_nn(q)
        return J_h_inv, self.decoder(latent)

In [None]:
%%time
%matplotlib widget

current_time = datetime.now().strftime("%Y%m%d%H%M")
save_directory = os.path.join(os.getcwd(), "Models")
os.makedirs(save_directory, exist_ok=True)
file_name = f"Lumped_Mass_{current_time}.pth"
file_path = os.path.join(save_directory, file_name)

load_path = os.path.normpath("/home/kian/Documents/Thesis/ICS_fork/ics-pa-sv/Kian_code/Models/Lumped_Mass_202411271337.pth")


rp = robot_parameters.LUMPED_PARAMETERS
num_epochs = 2401

print(file_path)
model = Autoencoder2(rp).to(device)  # Move model to GPU
#model.load_state_dict(torch.load(load_path, weights_only=True))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)#, weight_decay=1e-6)
#optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2, alpha=0.99, eps=1e-08)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1 / num_epochs))



l_weights = [1,
             1,
             1e-1,
             1e-2]
outputs = []



plt.ion()

for epoch in range(num_epochs):
    for index, batch in enumerate(train_dataloader):
        batch_size = batch[0].shape[0]
        q = batch[0][:, 0:2].to(device)
        q.requires_grad = True
        q_d = batch[0][:, 2:4].to(device)
        
        J_h_inv, q_hat = model(q)  
        J_h_inv_trans = J_h_inv.transpose(1,2)
                
        matrices_vmap = torch.vmap(dynamics.dynamical_matrices, 
                                   in_dims=(None, 0, 0))

        M_q, C_q, G_q = matrices_vmap(rp, q, q_d)
        
        M_th, C_th, G_th = transforms.transform_dynamical_from_inverse(M_q, C_q, G_q, J_h_inv, J_h_inv_trans)      
        
        #loss_reconstruction = loss_terms.loss_reconstruction(q, q_hat)
        loss_diagonality_geo_mean = loss_terms.loss_diagonality_geo_mean(M_th, batch_size, device)
        loss_diagonality_trace = loss_terms.loss_diagonality_trace(M_th, batch_size, device)
        loss_diagonality_smallest = loss_terms.loss_diagonality_smallest(M_th, batch_size, device)
        ### Use J@J^T = eye to avoid needing to calculate the Jacobian inverse for efficiency. 
        #loss_J_h_unitary = loss_terms.loss_J_h_unitary(J_h, batch_size, device)
        #loss_J_h_cheat = loss_terms.loss_J_h_cheat(J_h, J_h_ana)
        #loss_M_th_cheat = loss_terms.loss_M_th_cheat(M_th, rp, q, q_d, batch_size)
        #l1_norm = loss_terms.loss_l1(model)
        
        #loss_diagonality = 10 * loss_diagonality_geo_mean + loss_diagonality_smallest + 100 * loss_diagonality_trace
        loss_diagonality = loss_diagonality_geo_mean


        loss = loss_diagonality #+ 0.2 * loss_J_h_unitary

        
        #loss = loss_J_h_cheat


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

 

    if epoch % 400 == 0:
        print(f'Epoch:{epoch+1}, Loss:{loss.item():.9f}')#, LR:{scheduler.get_last_lr():.7f}')
        print("Weighted loss_diagonality_geo_mean:", loss_diagonality_geo_mean)
        #print("Weighted loss_diagonality_smallest:", loss_diagonality_smallest)
        #print("Weighted loss_diagonality_trace:", 100 * loss_diagonality_trace)
        #print("Weighted loss Jh unitary:", 0.2 * loss_J_h_unitary)
    if epoch % 1200 == 0 and epoch > 0:
        #plotters.plot_h2(model, device, rp, epoch)
        #plotters.plot_J_h(model, device, rp, epoch, plot_index = 0)
        #plotters.plot_J_h(model, device, rp, epoch, plot_index = 1)
        plotters.plot_decoupling_inv(model, device, rp, epoch)
    scheduler.step()

    outputs.append((epoch, loss, q, q_hat, J_h_inv, M_th))


torch.save(model.state_dict(), file_path)
print(f"Model parameters saved to {file_path}")