In [None]:
# This block of code was taken from:
# Farhad Nawaz, Tianyu Li, Nikolai Matni, Nadia Figueroa,
# "Learning Complex Motion Plans using Neural ODEs with Safety and Stability Guarantees"

!pip install gmr
import gmr
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy import interpolate

import jax
import jax.nn as jnn
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

# Reading data from text file
#
# Format handled below: a line that is exactly 'New trajectory' opens a new
# trajectory, and each following line carries two space-separated floats.

flag = None
traj_all = []
traj_c = 0

file_name = '/content/trajs_loop_4.txt'

# Context manager: the handle is closed as soon as the lines are in memory
# (the previous bare open() was never closed).
with open(file_name, 'r') as f:
    data_text = f.readlines()

for i in data_text:
    if i == 'New trajectory\n':
        flag = 'new'
        # Flush the trajectory collected so far before starting the next one.
        if traj_c > 0:
            traj_all.append(traj)
        traj_c += 1
        continue
    if not i.strip():
        # A blank line parses to an empty array and would break reshape((1, 2)).
        continue
    if flag == 'new':
        traj = np.fromstring(i, dtype=float, sep=' ').reshape((1, 2))
        flag = 'old'
        continue
    if flag == 'old':
        traj = np.concatenate((traj, np.fromstring(i, dtype=float, sep=' ').reshape((1, 2))))
# The last trajectory has no following header line, so append it explicitly.
traj_all.append(traj)

# Per-trajectory min-max normalisation into [-0.5, 0.5] plus a matching
# normalised time axis in [0, 1].
traj_all_norm = []
ts_norm = []
# Defined up front so `scaler` exists even when there is nothing to iterate.
scaler = MinMaxScaler(feature_range=(-0.5,0.5))
scaler_all = []
for i in range(len(traj_all)):
    # A fresh scaler per trajectory: re-fitting one shared instance made every
    # entry of scaler_all the same object, holding only the last fit.
    scaler = MinMaxScaler(feature_range=(-0.5,0.5))
    scaler.fit(traj_all[i])
    scaler_all.append(scaler)
    ysti = scaler.transform(traj_all[i])
    traj_all_norm.append(ysti)
    ts_norm.append(jnp.linspace(0,1,num=ysti.shape[0]))

# State dimensionality, read from the last trajectory parsed above.
dim = traj.shape[1]

# Number of samples each trajectory is resampled to.
nsamples = 300

# Number of extra (augmented) trajectories.
# NOTE(review): the loops below index ts_norm / traj_all_norm up to
# traj_c + data_aug, but those lists hold one entry per parsed trajectory, so
# a non-zero value looks like it would raise IndexError — confirm before use.
data_aug = 0

traj_all_process = jnp.zeros((traj_c + data_aug, nsamples, dim))

seed = 1385

key = jax.random.PRNGKey(seed)

# One PRNG key per trajectory; each is split again per dimension in the loop.
key_trajs = jax.random.split(key, num=traj_c + data_aug)

for i in range(traj_c + data_aug):
  key_dim = jax.random.split(key_trajs[i], num=dim)
  for j in range(dim):
    # Interpolant of coordinate j over the normalised time axis.
    f = interpolate.interp1d(ts_norm[i], traj_all_norm[i][:, j])
    ts_new = np.linspace(0, 1, nsamples)
    range_traj = max(traj_all_norm[i][:, j]) - min(traj_all_norm[i][:, j])
    # Noise amplitude as a fraction of the coordinate range; with 0 the
    # uniform draw below has minval == maxval == 0.
    scale = 0
    traj_new = f(ts_new) + jax.random.uniform(key_dim[j], shape=ts_new.shape, minval=-scale*range_traj, maxval=scale*range_traj)
    traj_all_process = traj_all_process.at[i, :, j].set(traj_new)

# Forward differences along the time axis, padded with one zero row so the
# result has nsamples rows, then appended to the positions on the last axis.
traj_d = jnp.diff(traj_all_process, axis=1)
traj_d = jnp.concatenate((traj_d, jnp.zeros((traj_c + data_aug, 1, dim))), axis=1)
traj_d_all = jnp.concatenate((traj_all_process, traj_d), axis=2)

# Plot every trajectory: red dot = first sample, green dot = last sample.
for i in range(traj_c + data_aug):
    plt.plot(traj_all_process[i][0, 0], traj_all_process[i][0, 1], 'ro')
    plt.plot(traj_all_process[i][:, 0], traj_all_process[i][:, 1])
    plt.plot(traj_all_process[i][-1, 0], traj_all_process[i][-1, 1], 'go')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Curve parameter: 300 samples over one full period [0, 2*pi].
t = np.linspace(0, 2*np.pi, 300)

# Lemniscate of Gerono (figure-eight): x = sin(t), y = sin(t) * cos(t).
x = np.sin(t)
y = np.sin(t) * np.cos(t)

# Halve both coordinates so x lies in [-0.5, 0.5] (y ends up in [-0.25, 0.25]).
x = 0.5 * x
y = 0.5 * y

# Plot the curve with its first (red) and last (green) samples marked.
# NOTE: x and y are notebook globals that later cells overwrite.
plt.figure(figsize=(5,5))
plt.plot(x, y, label="Figure-8 Trajectory")
plt.plot(x[0], y[0], "ro", label="Start")
plt.plot(x[-1], y[-1], "go", label="End")
plt.xlim([-0.5, 0.5])
plt.ylim([-0.5, 0.5])
plt.gca().set_aspect('equal', adjustable='box')
plt.legend()
plt.title("Horizontal 8-shaped trajectory")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Parameters
n_points = 300
x_max = 1.0
y_max = 0.25  # reduce vertical amplitude to decrease area
theta = np.pi / 4  # 45 degrees rotation

# Forward path (pick) - unsymmetric: half of the samples, x goes 0 -> x_max
# while y follows a sine arch scaled by a factor that grows with t.
t_up = np.linspace(0, 1, n_points//2)
x_up = x_max * t_up
y_up = y_max * np.sin(np.pi * t_up) * (1 + 0.2 * t_up)  # smaller deviation

# Backward path (place) - unsymmetric: x goes x_max -> 0 with a mirrored
# (negative) arch and a different growth factor.
t_down = np.linspace(0, 1, n_points//2)
x_down = x_max * (1 - t_down)
y_down = -y_max * np.sin(np.pi * t_down) * (1 + 0.1 * t_down**2)  # smaller deviation

# Concatenate the two halves into one loop (these overwrite the globals x, y).
x = np.concatenate([x_up, x_down])
y = np.concatenate([y_up, y_down])

# Rotation matrix for 45 degrees (counter-clockwise for positive theta)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])

# coords has one row per coordinate, so R @ coords rotates every point at once.
coords = np.stack([x, y], axis=0)
rotated_coords = R @ coords
x_rot, y_rot = rotated_coords[0, :], rotated_coords[1, :]

# Stack into (n_points, 2)
bh_like_traj_rot = np.column_stack((x_rot, y_rot))

# Plot the rotated loop with its first (red) and last (green) samples marked.
plt.figure(figsize=(6,6))
plt.plot(x_rot, y_rot, label="Rotated Unsymmetric Pick-and-Place Loop")
plt.plot(x_rot[0], y_rot[0], "ro", label="Start")
plt.plot(x_rot[-1], y_rot[-1], "go", label="End")
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Closed Rotated Unsymmetric Pick-and-Place Trajectory (Reduced Area)")
plt.gca().set_aspect('equal', adjustable='box')
plt.grid(True)
plt.legend()
plt.show()


In [None]:
# Stack whatever the globals x and y currently hold into an (N, 2) array.
# NOTE(review): the next cell rebinds `traj`, so when the notebook is run top
# to bottom this assignment is immediately overwritten.
traj = np.column_stack((x, y))
traj.shape

In [None]:
# Use the rotated pick-and-place loop as the working trajectory.
# (The bare .shape expression is not the cell's last line, so it displays nothing.)
bh_like_traj_rot.shape
traj = bh_like_traj_rot

In [None]:
from scipy.interpolate import interp1d
# Alternative source: the first resampled trajectory loaded from file.
#drawing = np.array(traj_all_process[0, :, :])
# Reference ("expert") drawing that the helper functions below distort.
drawing = traj

def normalize_drawing(d):
    """Centre a point set on its centroid and scale it into [-1, 1].

    Args:
        d: array of points with one row per point; columns are coordinates.

    Returns:
        A new array: `d` minus its per-column mean, divided by the largest
        absolute coordinate of the centred data. If every point coincides
        (largest absolute coordinate is 0), the centred all-zero array is
        returned as is instead of dividing 0 by 0 and producing NaNs.
    """
    d = d - np.mean(d, axis=0)
    max_range = np.max(np.abs(d))
    if max_range == 0:
        # Degenerate drawing: nothing to scale.
        return d
    return d / max_range

def rotate_and_scale_drawing(d, angle_degrees=0.0, scale_x=1.0, scale_y=1.0):
    """Scale a 2-D point set per axis, then rotate it about the origin.

    Args:
        d: (N, 2) array of points, one per row.
        angle_degrees: rotation angle in degrees (counter-clockwise when positive).
        scale_x: factor applied to the first coordinate before rotating.
        scale_y: factor applied to the second coordinate before rotating.

    Returns:
        (N, 2) array of transformed points.
    """
    angle = np.radians(angle_degrees)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot = np.array([
        [cos_a, -sin_a],
        [sin_a,  cos_a]
    ])
    scl = np.diag([scale_x, scale_y])
    # Points are rows, so the column-vector transform rot @ scl is applied
    # through its transpose on the right.
    return d @ (rot @ scl).T

def add_noise(drawing, noise_scale=0.13):
    """Return `drawing` plus zero-mean Gaussian noise.

    The noise is drawn from NumPy's global random state with standard
    deviation `noise_scale` and the same shape as `drawing`.
    """
    jitter = np.random.normal(scale=noise_scale, size=drawing.shape)
    return drawing + jitter

def simplify_and_interpolate(drawing, keep_ratio=0.05, noise_scale=0.2):
    """Keep a random subset of points, jitter them, and cubic-interpolate back.

    Args:
        drawing: (N, D) array of points, one per row.
        keep_ratio: fraction of the N points to keep.
        noise_scale: standard deviation of the Gaussian noise added to the
            kept points.

    Returns:
        (N, D) array sampled from a cubic interpolant through the kept, noisy
        points. The kept points are spaced uniformly in the interpolation
        parameter regardless of their original indices.

    Raises:
        ValueError: if `drawing` has fewer than 4 points, the minimum a cubic
            interpolant needs.
    """
    total = len(drawing)
    min_points = 4  # cubic interp1d cannot be built from fewer samples
    if total < min_points:
        raise ValueError(
            f"simplify_and_interpolate needs at least {min_points} points, got {total}"
        )
    # Clamp so short drawings or small keep_ratio values still give a valid
    # interpolant, and an oversized keep_ratio cannot exceed the population.
    num_points = min(max(int(total * keep_ratio), min_points), total)
    idx = np.sort(np.random.choice(total, num_points, replace=False))
    # Fancy indexing copies, so the caller's array is never modified.
    simplified = drawing[idx]
    simplified += np.random.normal(scale=noise_scale, size=simplified.shape)
    f_interp = interp1d(np.linspace(0, 1, num_points), simplified, axis=0, kind='cubic')
    return f_interp(np.linspace(0, 1, total))

def distort_shape(drawing):
    """Shrink every coordinate of every point by an independent random factor.

    Factors are drawn uniformly from [0.9, 1) using NumPy's global random
    state, one per point and per column (two columns).
    """
    n_rows = len(drawing)
    per_point_gain = np.random.uniform(0.9, 1, size=(n_rows, 2))
    return drawing * per_point_gain

def exaggerate_parts(drawing, start=100, end=250, factor=1.3):
    """Return a copy of `drawing` with rows start:end multiplied by `factor`.

    The input array is left untouched; rows outside the slice are unchanged.
    """
    segment = slice(start, end)
    modified = drawing.copy()
    modified[segment] *= factor
    return modified

def create_childlike_versions(drawing):
    """Build four distorted variants of a drawing.

    The drawing is normalised, then four variants are produced (noisy,
    simplified, distorted, exaggerated) and each is given its own rotation
    and uniform scale.

    Returns:
        Tuple (version1, version2, version3, version4).
    """
    base = normalize_drawing(drawing)

    # Evaluation order is kept as noise -> simplify -> distort -> exaggerate,
    # because the first three draw from NumPy's global random state.
    variants = [
        add_noise(base),
        simplify_and_interpolate(base),
        distort_shape(base),
        exaggerate_parts(base),
    ]

    # (angle_degrees, scale_x, scale_y) applied to the matching variant.
    poses = [
        (5, 1, 1),
        (-5, 1.1, 1.1),
        (10, 1.2, 1.2),
        (-10, 1.3, 1.3),
    ]

    return tuple(
        rotate_and_scale_drawing(variant, *pose)
        for variant, pose in zip(variants, poses)
    )


# Normalise the reference drawing in place (create_childlike_versions
# normalises its argument again internally).
drawing = normalize_drawing(drawing)
child1, child2, child3, child4 = create_childlike_versions(drawing)


# One panel per drawing: the reference first, then the four variants.
fig, axs = plt.subplots(1, 5, figsize=(22, 4))
titles = ['Expert (Original)', 'Child 1: Noisy', 'Child 2: Simplified',
          'Child 3: Distorted', 'Child 4: Exaggerated']

for ax, data, title in zip(axs, [drawing, child1, child2, child3, child4], titles):
    ax.plot(data[:, 0], data[:, 1], lw=2)
    ax.set_title(title)
    ax.axis('equal')
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Stack the four variants along a new leading axis (one entry per variant).
noisy_d = np.array([child1, child2, child3, child4])

In [None]:
# One panel per stacked trajectory in noisy_d.
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for i in range(4):
    # NOTE: x and y are notebook globals; this loop overwrites them.
    x = noisy_d[i, :, 0]
    y = noisy_d[i, :, 1]

    ax = axes[i]
    ax.plot(x, y, lw=2)
    ax.set_title(f"Trajectory {i+1}")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_aspect('equal')
    ax.grid(True)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.special import digamma

def compute_mi_knn(x, y, k=5):
    """
    Compute mutual information between x and y using KSG estimator (K nearest neighbors).

    x, y: arrays of shape (N, d_x) and (N, d_y)
    k: number of neighbors

    Returns the estimate
        digamma(k) + digamma(N) - mean(digamma(nx + 1) + digamma(ny + 1)),
    where nx / ny count, for every sample, the other samples that lie within
    that sample's joint-space k-th-neighbour distance in the x / y marginal
    space (Chebyshev metric).
    """
    assert x.shape[0] == y.shape[0]
    n = x.shape[0]

    data = np.hstack((x, y))
    tree = NearestNeighbors(metric='chebyshev')
    tree.fit(data)
    dist, _ = tree.kneighbors(data, n_neighbors=k+1)  # include itself
    eps = dist[:, k]  # distance to k-th neighbor

    # Shrink the radius slightly so the marginal counts use a strict "<".
    # Clamp at 0: with duplicated samples eps is 0, a negative radius would
    # match nothing (not even the point itself) and lead to digamma(0).
    radii = np.maximum(eps - 1e-15, 0.0)

    def _count_within(points):
        """Count, per sample, the other samples inside that sample's radius."""
        marginal = NearestNeighbors(metric='chebyshev').fit(points)
        # One query per sample with a scalar radius, which is how
        # radius_neighbors documents its `radius` argument.
        return np.array([
            len(marginal.radius_neighbors([p], radius=r, return_distance=False)[0]) - 1  # exclude itself
            for p, r in zip(points, radii)
        ])

    # Count neighbors in marginal spaces
    nx = _count_within(x)
    ny = _count_within(y)

    mi = digamma(k) + digamma(n) - np.mean(digamma(nx + 1) + digamma(ny + 1))
    return mi

# Example data: 4 trajectories, 300 steps, 2D states
# Reset NumPy's global random state (nothing else in this cell draws from it).
np.random.seed(0)
states = noisy_d

# Unpack the dataset dimensions. NOTE: T is a notebook global that a later
# cell rebinds to a float duration.
n_traj, T, state_dim = states.shape

def compute_traj_mi_score(traj, k=5):
    """
    For one trajectory, compute total mutual information score between
    states and actions (actions = diff of states).

    traj: array of shape (T, state_dim)
    k: number of neighbours passed to compute_mi_knn

    Returns the per-sample MI estimate multiplied by the number of
    state/action pairs (T - 1) of this trajectory.
    """
    actions = np.diff(traj, axis=0)  # shape (T-1, state_dim)
    states_trim = traj[:-1]          # align states (T-1, state_dim)
    # Compute MI between states_trim and actions (both shape (T-1, state_dim))
    mi_score = compute_mi_knn(states_trim, actions, k)
    # Scale by this trajectory's own pair count rather than the notebook-level
    # global T, which another cell rebinds to a different quantity.
    n_pairs = actions.shape[0]
    return mi_score * n_pairs  # sum MI over all pairs (approximate total)

# Compute MI scores for all trajectories, printing each one and collecting
# them in `scores` in trajectory order.
scores = []
for i in range(n_traj):
    score = compute_traj_mi_score(states[i])
    print(f"Trajectory {i} MI score: {score:.4f}")
    scores.append(score)


In [None]:
import torch
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.special import digamma

# ---- Your existing loss functions ----

def smoothness_losst(pred):
    """Sum of squared step lengths of a (T, 2) trajectory, as a Python float."""
    step_x = torch.diff(pred[:, 0], dim=0)
    step_y = torch.diff(pred[:, 1], dim=0)
    squared_steps = step_x ** 2 + step_y ** 2
    return squared_steps.sum().item()

def closed_shape_loss(pred):
    """Euclidean gap between the first and last point of `pred`, as a float."""
    gap = pred[0] - pred[-1]
    return torch.norm(gap, dim=-1).item()

def symmetry_loss(pred):
    """Mean nearest-point distance between `pred` and its left-right mirror.

    The (T, 2) points are mirrored across the vertical line through their
    centroid; for every mirrored point the distance to the closest original
    point is taken, and the mean of those distances is returned as a float.
    """
    centre_x = torch.mean(pred, dim=0)[0]
    mirrored = torch.stack((2 * centre_x - pred[:, 0], pred[:, 1]), dim=1)
    # pairwise[i, j] = distance from mirrored point i to original point j
    pairwise = torch.norm(mirrored.unsqueeze(1) - pred.unsqueeze(0), dim=2)
    nearest = torch.min(pairwise, dim=1).values
    return torch.mean(nearest).item()

def compute_scores(pred):
    """Evaluate the three shape metrics on `pred` and return them by name.

    Keys are 'Smoothness', 'Closed Shape' and 'Symmetry', in that order. The
    metric functions are looked up when this function is called.
    """
    scores = {}
    scores['Smoothness'] = smoothness_losst(pred)
    scores['Closed Shape'] = closed_shape_loss(pred)
    scores['Symmetry'] = symmetry_loss(pred)
    return scores

# ---- Mutual Information estimator using sklearn KNN ----

def compute_mi_knn(x, y, k=5):
    """
    Compute mutual information between x and y using KSG estimator (K nearest neighbors).

    x, y: arrays of shape (N, d_x) and (N, d_y)
    k: number of neighbors

    Returns
        digamma(k) + digamma(N) - mean(digamma(nx + 1) + digamma(ny + 1)),
    where nx / ny count, per sample, the other samples within that sample's
    joint-space k-th-neighbour distance in the x / y marginal space
    (Chebyshev metric).
    """
    n = x.shape[0]
    data = np.hstack((x, y))
    tree = NearestNeighbors(metric='chebyshev')
    tree.fit(data)
    dist, _ = tree.kneighbors(data, n_neighbors=k+1)
    # Distance to the k-th neighbour, shrunk slightly so the marginal counts
    # use a strict "<". Clamped at 0: for duplicated samples the distance is 0
    # and a negative radius would match nothing, leading to digamma(0).
    eps = np.maximum(dist[:, k] - 1e-15, 0.0)

    # The "- 1" in both counts removes the query point itself.
    tree_x = NearestNeighbors(metric='chebyshev').fit(x)
    nx = np.array([len(tree_x.radius_neighbors([point], radius=eps[i], return_distance=False)[0]) - 1 for i, point in enumerate(x)])

    tree_y = NearestNeighbors(metric='chebyshev').fit(y)
    ny = np.array([len(tree_y.radius_neighbors([point], radius=eps[i], return_distance=False)[0]) - 1 for i, point in enumerate(y)])

    mi = digamma(k) + digamma(n) - np.mean(digamma(nx + 1) + digamma(ny + 1))
    return mi

def compute_mi_score_torch(pred, k=5):
    """
    pred: torch tensor (T, 2) states
    Compute MI between states and actions (diff of states)

    The tensors are detached before conversion so the function also accepts
    tensors that require grad (calling .numpy() on those raises). The result
    is a NumPy scalar and carries no gradient.
    """
    states = pred[:-1].detach().cpu().numpy()      # (T-1, 2)
    actions = (pred[1:] - pred[:-1]).detach().cpu().numpy()  # (T-1, 2)
    return compute_mi_knn(states, actions, k)

# ---- Your input tensors ----
# drawing, child1, child2, child3, child4 come from the cells above.
# torch.as_tensor keeps this cell re-runnable: on a second run `drawing` is
# already a float32 tensor and is passed through unchanged, whereas
# torch.tensor(tensor) emits a UserWarning.
drawing = torch.as_tensor(drawing, dtype=torch.float32)
chil1 = torch.as_tensor(child1, dtype=torch.float32)
chil2 = torch.as_tensor(child2, dtype=torch.float32)
chil3 = torch.as_tensor(child3, dtype=torch.float32)
chil4 = torch.as_tensor(child4, dtype=torch.float32)

drawings = [drawing, chil1, chil2, chil3, chil4]

# Compute existing losses
score_dicts = [compute_scores(d) for d in drawings]

# Compute MI scores for each trajectory
mi_scores = [compute_mi_score_torch(d) for d in drawings]

# Combine all metrics to normalize (including MI)
weights = {
    'Symmetry': 1.0,
    'Closed Shape': 0.5,
    'Smoothness': 0.1,
    'Mutual Information': 1.0  # weight for MI (used only in weighted sum later)
}

# Extract metrics from score_dicts and add MI: one list per metric, with one
# value per drawing in the order of `drawings`.
all_metrics = {key: [d[key] for d in score_dicts] for key in weights if key != 'Mutual Information'}
all_metrics['Mutual Information'] = mi_scores

# Normalize scores between 0 and 1 per metric
def normalize_list(vals):
    """Min-max scale a sequence of numbers into [0, 1].

    Args:
        vals: sequence of numbers.

    Returns:
        List of floats with the same length as `vals`. All zeros when the
        values span less than 1e-8; an empty list for empty input (min/max of
        an empty array would otherwise raise).
    """
    arr = np.array(vals)
    if arr.size == 0:
        return []
    min_val, max_val = arr.min(), arr.max()
    if max_val - min_val < 1e-8:
        # Constant input: avoid dividing by (almost) zero.
        return [0.0 for _ in vals]
    return ((arr - min_val) / (max_val - min_val)).tolist()

# Min-max normalise each metric across the drawings once, then regroup the
# values into one {metric: value} dict per drawing.
normalized_columns = {name: normalize_list(values) for name, values in all_metrics.items()}
normalized_scores = []
for i in range(len(drawings)):
    norm_score = {name: normalized_columns[name][i] for name in all_metrics}
    normalized_scores.append(norm_score)

# Now compute final scores as (1-beta)*MI - beta*(weighted sum of other losses)
beta = 0.5  # tune this parameter [0,1]

loss_names = [name for name in weights if name != 'Mutual Information']

final_scores = []
for norm_score in normalized_scores:
    weighted_loss = sum(norm_score[name] * weights[name] for name in loss_names)
    combined = (1 - beta) * norm_score['Mutual Information'] - beta * weighted_loss
    final_scores.append(combined)

# Normalize final scores between 0 and 1 for interpretability
final_scores_np = np.array(final_scores)
min_fs, max_fs = final_scores_np.min(), final_scores_np.max()
span_fs = max_fs - min_fs
final_scores_norm = (
    [0.0 for _ in final_scores]
    if span_fs < 1e-8
    else ((final_scores_np - min_fs) / span_fs).tolist()
)

# Titles with scores, in the same order as `drawings`
drawing_labels = [
    'Original Expert',
    'Child 1: Noisy',
    'Child 2: Simplified',
    'Child 3: Distorted',
    'Child 4: Exaggerated',
]
titles_with_scores = [
    f'{label} | Final Score: {value:.3f}'
    for label, value in zip(drawing_labels, final_scores_norm)
]

for title in titles_with_scores:
    print(title)


In [None]:
import matplotlib.pyplot as plt

# One panel per drawing, titled with its normalised final score and annotated
# with the raw metric values.
fig, axs = plt.subplots(1, 5, figsize=(22, 5))

for ax, data, title, score, mi in zip(axs, drawings, titles_with_scores, score_dicts, mi_scores):
    # `drawings` holds torch tensors; convert for plotting.
    data_np = data.numpy()
    ax.plot(data_np[:, 0], data_np[:, 1], lw=2)
    ax.axis('equal')
    ax.axis('off')

    # Combine existing losses and add Mutual Information (on a copy, so
    # score_dicts itself is left unchanged).
    score_with_mi = score.copy()
    score_with_mi['Mutual Information'] = mi

    score_text = '\n'.join([f"{k}: {v:.3f}" for k, v in score_with_mi.items()])
    ax.text(0.05, 0.95, score_text, transform=ax.transAxes, fontsize=12,
            verticalalignment='top', bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
    ax.set_title(title, fontsize=11)

plt.tight_layout()
plt.show()


In [None]:
class PeriodicDMP:
    """Gaussian-basis least-squares model of a trajectory on a fixed time grid.

    The model stores one weight per basis function and degree of freedom;
    `fit` solves for the weights and `generate` evaluates the weighted basis
    sum on the stored 300-sample time grid. `beta` is stored but not read by
    any method of this class.
    """

    def __init__(self, n_dof=2, n_basis=20, tau=1.0, alpha=25.0, beta=10.0):
        """
        Args:
            n_dof: number of trajectory columns to model.
            n_basis: number of Gaussian basis functions.
            tau: end of the time grid (the grid spans [0, tau]).
            alpha: sharpness of the Gaussian basis functions.
            beta: stored on the instance only.
        """
        self.n_dof = n_dof
        self.n_basis = n_basis
        self.tau = tau
        self.alpha = alpha
        self.beta = beta

        self.time = np.linspace(0, tau, 300)
        self.psi = self._compute_basis_functions(self.time)

        # Random initial weights; overwritten by fit().
        self.weights = np.random.randn(self.n_basis, self.n_dof)

    def _compute_basis_functions(self, time):
        """Return the (len(time), n_basis) matrix of Gaussian activations.

        Centres are spread evenly over [0, 1].
        """
        centers = np.linspace(0, 1, self.n_basis)
        psi = np.exp(-self.alpha * (time[:, None] - centers[None, :]) ** 2)
        return psi

    def fit(self, trajectory):
        """Least-squares fit of the basis weights to `trajectory`.

        Args:
            trajectory: array with one row per time-grid sample (300 rows)
                and at least n_dof columns.
        """
        for i in range(self.n_dof):
            self.weights[:, i] = np.linalg.lstsq(self.psi, trajectory[:, i], rcond=None)[0]

    def generate(self, goal_position):
        """Evaluate the fitted model on the stored time grid.

        Args:
            goal_position: accepted for interface compatibility; not used.

        Returns:
            (y, dy): positions of shape (len(time), n_dof) and their time
            derivative of the same shape.
        """
        y = self.psi @ self.weights
        # Differentiate along the time axis. The previous per-row call took
        # the gradient across the DOF columns with spacing time[i], which is
        # 0 for the first sample.
        dy = np.gradient(y, self.time, axis=0)

        return y, dy

In [None]:
# Fit the basis-function model to the first (noisy) variant.
example_trajectory = child1
dmp = PeriodicDMP(n_dof=2)

dmp.fit(example_trajectory)

# Last sample of the example; passed to generate() below.
goal_position = example_trajectory[-1]

# Keep the generated positions only; the derivative is discarded.
target_trajectory, _ = dmp.generate(goal_position)

# Overlay the fitted curve (red) on the original samples (dashed blue).
plt.plot(target_trajectory[:, 0], target_trajectory[:, 1], label='Generated by DMP', color='r')
plt.plot(example_trajectory[:, 0], example_trajectory[:, 1], label='Original Trajectory', color='b', linestyle='dashed')
plt.title("Original vs. Generated Trajectory")
plt.legend()
plt.axis('equal')
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim


class LSTMModel(nn.Module):
    """Stacked batch-first LSTM followed by a per-time-step linear layer."""

    def __init__(self, input_size=2, hidden_size=128, output_size=2, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map (batch, time, input_size) to (batch, time, output_size)."""
        hidden_seq, _ = self.lstm(x)
        return self.fc(hidden_seq)

def smoothness_loss(pred):
    """Weighted sum of squared 1st, 2nd and 3rd finite differences over time.

    Args:
        pred: tensor of shape (batch, time, 2) holding x/y positions.

    Returns:
        Scalar tensor 0.2 * L1 + 0.3 * L2 + 0.5 * L3, where Ln is the sum of
        squared n-th differences of x and y along the time axis.
    """
    # Differences are taken along dim=1 (time). Differencing along dim=0, the
    # batch axis, yields an empty tensor for a batch of one and therefore a
    # loss that is always zero.
    dx = torch.diff(pred[:, :, 0], dim=1)
    dy = torch.diff(pred[:, :, 1], dim=1)
    d1 = dx**2 + dy**2
    loss_d1 = torch.sum(d1)

    ddx = torch.diff(dx, dim=1)
    ddy = torch.diff(dy, dim=1)
    d2 = ddx**2 + ddy**2
    loss_d2 = torch.sum(d2)

    dddx = torch.diff(ddx, dim=1)
    dddy = torch.diff(ddy, dim=1)
    d3 = dddx**2 + dddy**2
    loss_d3 = torch.sum(d3)

    loss = 0.2 * loss_d1 + 0.3 * loss_d2 + 0.5 * loss_d3

    return loss

def closed_shape_loss(pred):
    """Distance between the first and last time step of each trajectory.

    Args:
        pred: tensor of shape (batch, time, 2).

    Returns:
        Scalar tensor: norm of (first point - last point) over the batch.
    """
    # Index the time axis (dim 1). Indexing the last axis compares the x
    # column with the y column instead of the two end points.
    return torch.norm(pred[:, 0, :] - pred[:, -1, :])

def symmetry_loss(pred):
    """Mean nearest-point distance between a trajectory and its mirror image.

    `pred` is squeezed on dim 0 (so a (1, T, 2) batch becomes (T, 2)), the
    points are mirrored across the vertical line through their centroid, and
    the mean distance from each mirrored point to its closest original point
    is returned as a scalar tensor.
    """
    points = pred.squeeze(0)

    centre_x = torch.mean(points, dim=0)[0]
    mirrored = torch.stack((2 * centre_x - points[:, 0], points[:, 1]), dim=1)

    # pairwise[i, j]: distance from mirrored point i to original point j -> [T, T]
    pairwise = torch.norm(mirrored[:, None, :] - points[None, :, :], dim=2)

    nearest = torch.min(pairwise, dim=1).values
    return torch.mean(nearest)


def shape_preservation_loss(pred, dmp_trajectory, weight_edge=12):
    """Edge-weighted sum of point-wise distances to a reference trajectory.

    Args:
        pred: tensor of shape (T, 2).
        dmp_trajectory: reference of the same shape, given as a tensor or as
            anything torch.tensor accepts (e.g. a NumPy array).
        weight_edge: weight of the first and last point; all others weigh 1.

    Returns:
        Scalar tensor: sum over points of weight * Euclidean distance.
    """
    if isinstance(dmp_trajectory, torch.Tensor):
        # torch.tensor(tensor) emits a UserWarning; detach and cast instead.
        dmp_tensor = dmp_trajectory.detach().to(dtype=pred.dtype, device=pred.device)
    else:
        dmp_tensor = torch.tensor(dmp_trajectory, dtype=pred.dtype, device=pred.device)

    diffs = torch.norm(pred - dmp_tensor, dim=-1)

    # Emphasise the two end points of the trajectory.
    weights = 1*torch.ones_like(diffs)
    weights[0] = weight_edge
    weights[-1] = weight_edge

    loss = torch.sum(weights * diffs)
    return loss

# Model, reconstruction loss and optimiser. The training cell below creates
# all three again before training.
model = LSTMModel(input_size=2, hidden_size=128, output_size=2, num_layers=2)

mse_loss_function = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Reference shape: the curve produced by the fitted basis-function model.
expert_trajectory = np.array(target_trajectory)

# One rotated/scaled copy of the reference per input trajectory.
target1 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=0, scale_x=1, scale_y=1)
target2 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=-10, scale_x=1.1, scale_y=1.1)
target3 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=5, scale_x=1.2, scale_y=1.2)
target4 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=-15, scale_x=1.3, scale_y=1.3)

target_shapes = [target1, target2, target3, target4]

child_drawings_np = np.stack([child1, child2, child3, child4])
inputs = torch.tensor(child_drawings_np, dtype=torch.float32)
# The MSE target is the input itself (reconstruction objective).
targets = inputs.clone()

target_shapes_torch = [torch.tensor(t, dtype=torch.float32) for t in target_shapes]

# Fresh model / loss / optimiser for this training run.
model = LSTMModel(input_size=2, hidden_size=128, output_size=2, num_layers=2)
mse_loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 500
for epoch in range(epochs):
    model.train()
    total_epoch_loss = 0.0

    # One optimiser step per trajectory (batch size 1).
    for i in range(inputs.shape[0]):
        input_i = inputs[i].unsqueeze(0)
        target_i = targets[i].unsqueeze(0)
        target_shape_i = target_shapes_torch[i]

        output_i = model(input_i)

        mse_loss = mse_loss_function(output_i, target_i)
        sm_loss = smoothness_loss(output_i)
        cs_loss = closed_shape_loss(output_i)
        sy_loss = symmetry_loss(output_i)
        # shape_preservation_loss takes the un-batched (T, 2) output.
        shape_loss = shape_preservation_loss(output_i.squeeze(0), target_shape_i)

        total_loss = 0.9*mse_loss + 0.1 * sm_loss + 0.1 * cs_loss + 0.1 * sy_loss + 0.1 * shape_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        total_epoch_loss += total_loss.item()

    # Report the mean per-trajectory loss every 50 epochs.
    if epoch % 50 == 0:
        avg_loss = total_epoch_loss / inputs.shape[0]
        print(f"Epoch [{epoch}/{epochs}], Avg Loss per Trajectory: {avg_loss:.4f}")

# Run the trained model on all inputs at once, without gradient tracking.
model.eval()
with torch.no_grad():
    cleaned_drawings = model(inputs)  # (4, 300, 2)
    cleaned_drawings_np = cleaned_drawings.numpy()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.special import digamma

# Your existing LSTM model and losses unchanged
class LSTMModel(nn.Module):
    """Sequence-to-sequence regressor: batch-first LSTM stack plus linear head."""

    def __init__(self, input_size=2, hidden_size=128, output_size=2, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return one output vector per time step of `x` (batch, time, features)."""
        features, _state = self.lstm(x)
        return self.fc(features)

def smoothness_loss(pred):
    """Weighted sum of squared 1st, 2nd and 3rd finite differences over time.

    Args:
        pred: tensor of shape (batch, time, 2) holding x/y positions.

    Returns:
        Scalar tensor 0.2 * L1 + 0.3 * L2 + 0.5 * L3, where Ln is the sum of
        squared n-th differences of x and y along the time axis.
    """
    # Differences run along dim=1 (time). Along dim=0, the batch axis, a batch
    # of one gives an empty tensor and the loss is identically zero.
    dx = torch.diff(pred[:, :, 0], dim=1)
    dy = torch.diff(pred[:, :, 1], dim=1)
    d1 = dx**2 + dy**2
    loss_d1 = torch.sum(d1)

    ddx = torch.diff(dx, dim=1)
    ddy = torch.diff(dy, dim=1)
    d2 = ddx**2 + ddy**2
    loss_d2 = torch.sum(d2)

    dddx = torch.diff(ddx, dim=1)
    dddy = torch.diff(ddy, dim=1)
    d3 = dddx**2 + dddy**2
    loss_d3 = torch.sum(d3)

    loss = 0.2 * loss_d1 + 0.3 * loss_d2 + 0.5 * loss_d3
    return loss

def closed_shape_loss(pred):
    """Distance between the first and last time step of each trajectory.

    Args:
        pred: tensor of shape (batch, time, 2).

    Returns:
        Scalar tensor: norm of (first point - last point) over the batch.
    """
    # Select along the time axis (dim 1); selecting along the last axis
    # subtracts the y column from the x column instead.
    return torch.norm(pred[:, 0, :] - pred[:, -1, :])

def symmetry_loss(pred):
    """Mean distance from each mirrored point to its nearest original point.

    `pred` is squeezed on dim 0, mirrored across the vertical line through its
    centroid, and compared against itself; the result is a scalar tensor.
    """
    points = pred.squeeze(0)
    centre_x = points.mean(dim=0)[0]
    mirrored_x = 2 * centre_x - points[:, 0]
    mirrored = torch.stack((mirrored_x, points[:, 1]), dim=1)
    # Row i holds the distances from mirrored point i to every original point.
    pairwise = torch.norm(mirrored[:, None, :] - points[None, :, :], dim=2)
    closest = torch.min(pairwise, dim=1).values
    return torch.mean(closest)

def shape_preservation_loss(pred, dmp_trajectory, weight_edge=12):
    """Edge-weighted sum of point-wise distances between `pred` and a reference.

    The reference may be a tensor (detached and copied) or array-like data
    (converted with torch.tensor); either way it is cast to pred's dtype and
    device. The first and last points are weighted by `weight_edge`, all
    other points by 1.
    """
    if not isinstance(dmp_trajectory, torch.Tensor):
        # Array-like input, e.g. a NumPy array.
        reference = torch.tensor(dmp_trajectory, dtype=pred.dtype, device=pred.device)
    else:
        reference = dmp_trajectory.detach().clone().to(dtype=pred.dtype, device=pred.device)

    point_errors = torch.norm(pred - reference, dim=-1)

    edge_weights = torch.ones_like(point_errors)
    edge_weights[0] = weight_edge
    edge_weights[-1] = weight_edge

    return torch.sum(edge_weights * point_errors)


# Mutual Information using KNN-based estimator (scikit-learn + numpy)
def compute_mi_knn(x, y, k=5):
    """KSG (k-nearest-neighbour) estimate of the mutual information of x and y.

    x, y: arrays of shape (N, d_x) and (N, d_y)
    k: number of neighbours

    Returns
        digamma(k) + digamma(N) - mean(digamma(nx + 1) + digamma(ny + 1)),
    where nx / ny count, per sample, the other samples within that sample's
    joint-space k-th-neighbour distance in the x / y marginal space
    (Chebyshev metric).
    """
    n = x.shape[0]
    data = np.hstack((x, y))
    tree = NearestNeighbors(metric='chebyshev').fit(data)
    dist, _ = tree.kneighbors(data, n_neighbors=k+1)
    # k-th neighbour distance, shrunk slightly for a strict "<" in the
    # marginal counts. Clamped at 0: duplicated samples give distance 0, and a
    # negative radius would match nothing and lead to digamma(0).
    eps = np.maximum(dist[:, k] - 1e-15, 0.0)

    # The "- 1" in both counts removes the query point itself.
    tree_x = NearestNeighbors(metric='chebyshev').fit(x)
    nx = np.array([len(tree_x.radius_neighbors([point], radius=eps[i], return_distance=False)[0]) - 1 for i, point in enumerate(x)])

    tree_y = NearestNeighbors(metric='chebyshev').fit(y)
    ny = np.array([len(tree_y.radius_neighbors([point], radius=eps[i], return_distance=False)[0]) - 1 for i, point in enumerate(y)])

    mi = digamma(k) + digamma(n) - np.mean(digamma(nx + 1) + digamma(ny + 1))
    return mi

def compute_mi_score_torch(pred, k=5):
    """MI estimate between the states of `pred` and their one-step differences.

    `pred` is indexed along dim 0 as time. Both arrays are detached and moved
    to NumPy before estimation, so the returned value carries no gradient.
    """
    current = pred[:-1]
    step = pred[1:] - current
    states = current.detach().cpu().numpy()
    actions = step.detach().cpu().numpy()
    return compute_mi_knn(states, actions, k)


# --- State visitation rate & reward calculation ---
def compute_state_visitation_rate(trajectories, grid_size=50):
    """
    Histogram the visited 2-D states of all trajectories on a shared grid.

    trajectories: list of numpy arrays of shape (T, 2)
    grid_size: number of bins per axis for discretizing state space
    Returns:
      visitation_map: 2D np.array of visit counts normalized to sum to 1,
        indexed [y_bin, x_bin]
      states_grids: list of (T, 2) integer arrays, one per trajectory, whose
        columns are (x_bin, y_bin) for each state
    """
    # Grid bounds come from the union of all points.
    stacked = np.vstack(trajectories)
    x_min, y_min = stacked.min(axis=0)
    x_max, y_max = stacked.max(axis=0)

    # grid_size bins need grid_size + 1 edges per axis.
    x_bins = np.linspace(x_min, x_max, grid_size + 1)
    y_bins = np.linspace(y_min, y_max, grid_size + 1)

    visitation_map = np.zeros((grid_size, grid_size))
    last_bin = grid_size - 1

    states_grids = []
    for traj in trajectories:
        # digitize is 1-based; clipping folds the maximum into the last bin.
        x_idx = np.clip(np.digitize(traj[:, 0], bins=x_bins) - 1, 0, last_bin)
        y_idx = np.clip(np.digitize(traj[:, 1], bins=y_bins) - 1, 0, last_bin)
        states_grids.append(np.stack([x_idx, y_idx], axis=1))
        # Unbuffered add so repeated cells are counted once per visit
        # (y is row, x is col).
        np.add.at(visitation_map, (y_idx, x_idx), 1)

    visitation_map /= visitation_map.sum()  # Normalize to sum=1 (probability)
    return visitation_map, states_grids

def compute_trajectory_reward(states_grid, visitation_map):
    """
    Mean visitation probability of the grid cells a trajectory passes through.

    states_grid: (T, 2) integer array of grid indices with columns
        (x_bin, y_bin), as produced by compute_state_visitation_rate
    visitation_map: (grid_size, grid_size) visitation probabilities indexed
        [y_bin, x_bin]
    Returns the sum of the visited cells' probabilities divided by T, or 0.0
    for an empty trajectory.
    """
    states_grid = np.asarray(states_grid)
    if states_grid.size == 0:
        # Nothing visited: avoid dividing by zero.
        return 0.0
    # Column 0 is x and column 1 is y, while the map is (row=y, col=x).
    # Unpacking each row as (y, x) would read the map transposed.
    x_idx = states_grid[:, 0]
    y_idx = states_grid[:, 1]
    return np.mean(visitation_map[y_idx, x_idx])

# --- Your training setup ---
# Reference shape: the curve produced by the fitted basis-function model.
expert_trajectory = np.array(target_trajectory)

# One rotated/scaled copy of the reference per input trajectory.
target1 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=0, scale_x=1, scale_y=1)
target2 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=-10, scale_x=1.1, scale_y=1.1)
target3 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=5, scale_x=1.2, scale_y=1.2)
target4 = rotate_and_scale_drawing(expert_trajectory, angle_degrees=-15, scale_x=1.3, scale_y=1.3)

target_shapes = [target1, target2, target3, target4]

child_drawings_np = np.stack([child1, child2, child3, child4])
inputs = torch.tensor(child_drawings_np, dtype=torch.float32)
# The MSE target is the input itself (reconstruction objective).
targets = inputs.clone()
target_shapes_torch = [torch.tensor(t, dtype=torch.float32) for t in target_shapes]

# Fresh model / loss / optimiser for this training run.
model = LSTMModel(input_size=2, hidden_size=128, output_size=2, num_layers=2)
mse_loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Beta initialization: trade-off between the MI term (1 - beta) and the
# weighted losses (beta); adapted after every epoch below.
beta = 0.5
policy_rewards_history = []

epochs = 500
for epoch in range(epochs):
    model.train()
    total_epoch_loss = 0.0
    # Despite the name, this collects the epoch's output trajectories.
    epoch_policy_rewards = []

    # One optimiser step per trajectory (batch size 1).
    for i in range(inputs.shape[0]):
        input_i = inputs[i].unsqueeze(0)
        target_i = targets[i].unsqueeze(0)
        target_shape_i = target_shapes_torch[i]

        output_i = model(input_i)

        # Compute MI score on output
        # NOTE(review): compute_mi_score_torch detaches and goes through
        # NumPy, so mi_score changes the reported loss value but contributes
        # no gradient to the update.
        mi_score = compute_mi_score_torch(output_i.squeeze(0))

        # Compute losses
        mse_loss = mse_loss_function(output_i, target_i)
        sm_loss = smoothness_loss(output_i)
        cs_loss = closed_shape_loss(output_i)
        sy_loss = symmetry_loss(output_i)
        shape_loss = shape_preservation_loss(output_i.squeeze(0), target_shape_i)

        # Weighted sum of other losses (excluding MI)
        weighted_loss = 0.9 * mse_loss + 0.1 * sm_loss + 0.1 * cs_loss + 0.1 * sy_loss + 0.1 * shape_loss

        # Total loss with beta weighting
        total_loss = (1 - beta) * (-mi_score) + beta * weighted_loss
        # Note: We negate mi_score to maximize MI (since optimizer minimizes)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        total_epoch_loss += total_loss.item()

        # Save trajectory output to compute visitation later
        epoch_policy_rewards.append(output_i.squeeze(0).cpu().detach().numpy())

    # Compute visitation map and reward for this epoch's trajectories
    visitation_map, states_grids = compute_state_visitation_rate(epoch_policy_rewards)
    rewards = [compute_trajectory_reward(sg, visitation_map) for sg in states_grids]
    mean_reward = np.mean(rewards)
    policy_rewards_history.append(mean_reward)

    # Update beta (if not first epoch)
    if epoch > 0:
        prev_reward = policy_rewards_history[-2]
        delta_reward = mean_reward - prev_reward

        # Earlier, unbounded update rule kept for reference:
        # # If reward increased, increase beta -> more weight on losses, less on MI
        # # If reward decreased, decrease beta -> more weight on MI
        # beta += 0.1 * delta_reward
        # tanh bounds each update to at most 0.1 in magnitude.
        beta += 0.1 * np.tanh(10 * delta_reward)

        beta = max(0.0, min(1.0, beta))  # clip beta between 0 and 1

    if epoch % 50 == 0:
        print(f"Epoch [{epoch}/{epochs}], Avg Loss: {total_epoch_loss / inputs.shape[0]:.4f}, Beta: {beta:.3f}, Mean Reward: {mean_reward:.4f}")

# Final evaluation: run the trained model on all inputs at once, without
# gradient tracking.
model.eval()
with torch.no_grad():
    cleaned_drawings = model(inputs)
    cleaned_drawings_np = cleaned_drawings.numpy()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

child_titles = ['Teacher 1', 'Teacher 2', 'Teacher 3', 'Teacher 4']
cleaned_titles = ['Cleaned 1', 'Cleaned 2', 'Cleaned 3', 'Cleaned 4']

target_trajectories = [target1, target2, target3, target4]

# Shared axis limits: gather every x / y value of the inputs and the targets
# (the model outputs are not included).
all_x = np.concatenate([child[:, 0] for child in [child1, child2, child3, child4]] +
                       [target[:, 0] for target in target_trajectories])
all_y = np.concatenate([child[:, 1] for child in [child1, child2, child3, child4]] +
                       [target[:, 1] for target in target_trajectories])

# NOTE: x_min, x_max, y_min, y_max overwrite notebook globals of the same name.
x_min, x_max = all_x.min(), all_x.max()
y_min, y_max = all_y.min(), all_y.max()

# 10 % padding on every side.
x_margin = 0.1 * (x_max - x_min)
y_margin = 0.1 * (y_max - y_min)

x_limits = (x_min - x_margin, x_max + x_margin)
y_limits = (y_min - y_margin, y_max + y_margin)

# Figure 1: each input drawing against its target.
fig, axs = plt.subplots(1, 4, figsize=(18, 4))
for ax, child, target, title in zip(axs, [child1, child2, child3, child4], target_trajectories, child_titles):
    ax.plot(target[:, 0], target[:, 1], color='red', lw=3, linestyle='--', label='Target Trajectory')
    ax.plot(child[:, 0], child[:, 1], lw=2, label='Child Drawing')
    ax.set_title(title)

    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    ax.set_aspect('equal')

    ax.tick_params(axis='both', which='both', length=5, labelsize=8)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend()

# Figure 2: each model output against the same target.
fig, axs = plt.subplots(1, 4, figsize=(18, 4))
for ax, cleaned, target, title in zip(axs, cleaned_drawings_np, target_trajectories, cleaned_titles):
    ax.plot(target[:, 0], target[:, 1], color='red', lw=3, linestyle='--', label='Target Trajectory')
    ax.plot(cleaned[:, 0], cleaned[:, 1], lw=2, color='blue', label='Cleaned Drawing')
    ax.set_title(title)

    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    ax.set_aspect('equal')

    ax.tick_params(axis='both', which='both', length=5, labelsize=8)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend()

# Applies to the current (second) figure only.
plt.tight_layout()
plt.show()


In [None]:
# NOTE(review): the seed and the x / y sine wave below are not used by the
# rest of this cell; they only overwrite the notebook globals x and y.
np.random.seed(0)
x = np.linspace(0, 4 * np.pi, 300)
y = np.sin(x)
# `drawing` may be a NumPy array or a torch tensor, depending on which
# earlier cells have been run.
trajectory = drawing

# Uniform subsampling: 60 evenly spaced indices in [0, 299].
indices_uniform = np.linspace(0, 299, 60, dtype=int)
trajectory_uniform = trajectory[indices_uniform]

# Block averaging: mean of every 5 consecutive points (needs exactly 300 rows).
trajectory_avg = trajectory.reshape(60, 5, 2).mean(axis=1)

plt.figure(figsize=(12, 5))

plt.subplot(1, 3, 1)
plt.plot(trajectory[:, 0], trajectory[:, 1], label='Original', color='gray')
plt.title('Original (300 points)')
plt.xlabel('x')
plt.ylabel('y')

# The downsampled curves are drawn at twice the scale of the grey original.
plt.subplot(1, 3, 2)
plt.plot(trajectory[:, 0], trajectory[:, 1], color='gray', alpha=0.3)
plt.plot(2*trajectory_uniform[:, 0], 2*trajectory_uniform[:, 1], 'o-', label='Uniform Sampled', color='blue')
plt.title('Uniform Sampling (60 points)')

plt.subplot(1, 3, 3)
plt.plot(trajectory[:, 0], trajectory[:, 1], color='gray', alpha=0.3)
plt.plot(2*trajectory_avg[:, 0], 2*trajectory_avg[:, 1], 'o-', label='Block Averaged', color='green')
plt.title('Block Averaging (60 points)')

plt.tight_layout()
plt.show()

In [None]:
# Display the shape of the stacked variant dataset.
noisy_d.shape

In [None]:
# Total duration assigned to each trajectory.
# NOTE: this rebinds the notebook-level name T, which an earlier cell uses for
# the number of time steps.
T = 6.0
# Build the time grid with as many samples as noisy_d has along its time axis
# (axis 1), so dt matches the data being differentiated. A fixed 60-sample
# grid gives a dt that does not correspond to the spacing of noisy_d.
ts_new = np.linspace(0, T, noisy_d.shape[1])
dt = ts_new[1] - ts_new[0]

# Finite-difference velocities along the time axis.
vel_standard = np.gradient(noisy_d, dt, axis=1)

# Append the velocities to the positions on the last axis.
traj_standard = np.concatenate([noisy_d, vel_standard], axis=-1)

In [None]:
import os
# Save the position + velocity dataset to ./config/Data_Trajs/noisy_dataset.npy
# relative to the current working directory, creating the folder if needed.
curr_path = os.getcwd()
save_dir = os.path.join(curr_path, 'config', 'Data_Trajs')
os.makedirs(save_dir, exist_ok=True)
data_file = os.path.join(save_dir, 'noisy_dataset.npy')
with open(data_file, 'wb') as f:
    np.save(f, traj_standard)

In [None]:
# Display the shape of the uniformly subsampled reference trajectory.
trajectory_uniform.shape

In [None]:
# Save the subsampled reference, scaled by 2 (the same factor used when it was
# plotted above), next to the noisy dataset. Reuses save_dir from the cell above.
data_file = os.path.join(save_dir, 'draw_ref.npy')
with open(data_file, 'wb') as f:
    np.save(f, 2*trajectory_uniform)