In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from utils.disentanglement_metric import dlsbd_metric_mixture, apply_inverse_rotation, repeat_angles_n_gaussians

In [None]:
def make_rotation_matrix_2d(angles):
    """
    Build one 2D rotation matrix per angle.

    Args:
        angles: numpy array of angles (radians) of arbitrary shape ``S``,
            e.g. ``(num_angles, 1)``.

    Returns:
        Array of shape ``S + (2, 2)``; the trailing 2x2 block for an angle
        ``t`` is ``[[cos t, -sin t], [sin t, cos t]]``.
    """
    c = np.cos(angles)
    s = np.sin(angles)
    # Place the four matrix entries in row-major order along a new last axis,
    # then fold that axis into a 2x2 block.
    entries = np.stack([c, -s, s, c], axis=-1)
    return entries.reshape(*angles.shape, 2, 2)

def repeat_angles_n_gaussians(angles, n_gaussians):
    """
    Tile the angles over a new mixture-component (gaussian) axis.

    The new axis is inserted directly after the ``n_angles`` axis, so the
    result lines up with latents shaped
    ``(n_objects, n_angles, num_gaussians, z_dim)`` as consumed by
    ``apply_inverse_rotation`` (which reduces/broadcasts over that layout).

    :param angles: array of angles with shape ``(n_objects, n_angles)``, or
        ``(n_objects, n_angles, n_subspaces)`` when one angle per 2D subspace
        is given (torus case)
    :param n_gaussians: number of mixture components to repeat over
    :return: array of shape ``(n_objects, n_angles, n_gaussians)``
        (or ``(n_objects, n_angles, n_gaussians, n_subspaces)``)
    """
    angles = np.asarray(angles)
    # Axis 2, not 1: the gaussian axis must come *after* n_angles to match the
    # latent layout; inserting at axis 1 breaks broadcasting (or silently
    # transposes angles when n_angles == n_gaussians).
    return np.repeat(np.expand_dims(angles, axis=2), n_gaussians, axis=2)


def apply_inverse_rotation(z, angles):
    """
    Undo the rotation acting on the latent ``z`` by rotating by ``-angles``.

    A latent dimension of 4 is treated as a torus: two independent 2D
    rotations, one per consecutive 2D subspace, each using
    ``angles[..., k]``. A latent dimension of 2 is treated as a cylinder
    (a single 2D rotation).

    :param z: latent with shape (n_objects, n_angles, num_gaussians, z_dim)
    :param angles: angles broadcastable against ``z.shape[:-1]``, e.g.
        (n_objects, n_angles, num_gaussians); in the torus case an extra
        trailing axis of size 2 holds one angle per subspace
    :return: array with the inverse-rotated latents, same layout as ``z``
    :raises ValueError: if the latent dimension is neither 2 nor 4
    """
    latent_dim = z.shape[-1]
    # Inverse rotation for torus manifold
    if latent_dim == 4:
        print("Applying inverse rotation for torus manifold")
        inv_z_subspaces = [
            apply_inverse_2d_rotation(z[..., 2 * k:2 * (k + 1)], angles[..., k])
            for k in range(2)
        ]
        return np.concatenate(inv_z_subspaces, axis=-1)
    if latent_dim != 2:
        # A 2x2 rotation cannot act on other widths; fail with a clear message
        # instead of an opaque matmul shape error.
        raise ValueError(
            f"apply_inverse_rotation expects a latent dimension of 2 (cylinder) "
            f"or 4 (torus), got {latent_dim}"
        )
    # Inverse rotation for cylinder manifold
    print("Applying inverse rotation for cylinder manifold")
    return apply_inverse_2d_rotation(z, angles)


def apply_inverse_2d_rotation(z, angles):
    """
    Rotate 2D latent vectors by ``-angles``.

    :param z: array of shape (..., 2)
    :param angles: angles in radians, broadcastable against ``z.shape[:-1]``
    :return: array of the broadcast shape of ``z`` in which each trailing
        2-vector has been rotated by the negated angle
    """
    inv_rotations = make_rotation_matrix_2d(-np.asarray(angles))
    # Treat each latent as a column vector so matmul applies the (2, 2)
    # rotations over all leading (batch) axes at once.
    inv_z = np.matmul(inv_rotations, np.expand_dims(z, axis=-1))
    return np.squeeze(inv_z, axis=-1)


def estimate_kmeans_clusters(z_inv, n_stabilizers):
    """
    Estimate cluster centres of the inverse-rotated latents with k-means.

    NOTE(review): ``KMeans`` (presumably ``sklearn.cluster.KMeans``) is not
    imported in the visible cells; it must be in scope before calling this.

    :param z_inv: inverse-rotated latents, presumably shaped
        (n_objects, n_angles, num_gaussians, z_dim) — TODO confirm
    :param n_stabilizers: number of clusters. For ``z_dim == 4`` an indexable
        pair with one cluster count per 2D subspace; both counts must be equal
        for the per-subspace centres to be concatenated.
    :return: cluster centres with shape (1, 1, n_clusters, z_dim)
    """
    latent_dim = z_inv.shape[-1]
    if latent_dim == 4:
        z_inv_mean = []
        for num_subspace in range(latent_dim // 2):
            z_inv_subspace = z_inv[..., num_subspace * 2:(num_subspace + 1) * 2]
            # One row per 2D point of this subspace. Reshaping with the full
            # latent_dim here would glue pairs of points into 4-D rows.
            points = z_inv_subspace.reshape((-1, z_inv_subspace.shape[-1]))
            k_means = KMeans(n_clusters=n_stabilizers[num_subspace], random_state=0).fit(points)
            # (k, 2) -> (1, 1, k, 2) so the centres broadcast like z_inv
            z_inv_mean.append(k_means.cluster_centers_[np.newaxis, np.newaxis])
        z_inv_mean = np.concatenate(z_inv_mean, axis=-1)
    else:
        k_means = KMeans(n_clusters=n_stabilizers, random_state=0).fit(
            z_inv.reshape((-1, latent_dim)))
        z_inv_mean = k_means.cluster_centers_[np.newaxis, np.newaxis]
    return z_inv_mean


def estimate_mean_inv(z_inv):
    """
    Average the inverse-rotated latents over axis -3.

    For latents shaped (n_objects, n_angles, num_gaussians, z_dim), axis -3
    is the ``n_angles`` axis. The mean is coordinate-wise, so averaging each
    2D subspace of a 4-D (torus) latent separately and concatenating gives
    exactly the same result as averaging the full latent at once. No special
    case is needed for ``z_dim == 4``.

    :param z_inv: array with at least 3 dimensions
    :return: array with the same shape as ``z_inv`` except that axis -3 has size 1
    """
    return np.mean(z_inv, axis=-3, keepdims=True)


def calculate_dispersion(z_inv, distance_function: str = "euclidean"):
    """
    Measure how far the inverse-rotated latents spread around their mean.

    The reference point is ``estimate_mean_inv(z_inv)``, the mean over axis -3.

    :param z_inv: latent with shape (n_objects, n_angles, num_gaussians, z_dim)
    :param distance_function: one of
        ``"euclidean"``: per-point L2 distance to the mean;
        ``"cosine"``: per-point ``1 - cosine similarity`` with the mean;
        ``"chamfer"``: a scalar built from ``matrix_dist_numpy``.
        ``"cross-entropy"`` is reserved but not implemented.
    :return: per-point dispersion array (euclidean/cosine) or a scalar (chamfer)
    :raises NotImplementedError: for ``"cross-entropy"`` or an unknown name
    """
    z_inv_mean = estimate_mean_inv(z_inv)
    if distance_function == "euclidean":
        return np.linalg.norm(z_inv - z_inv_mean, axis=-1)
    if distance_function == "cosine":
        dot = np.sum(z_inv * z_inv_mean, axis=-1)
        norms = np.linalg.norm(z_inv, axis=-1) * np.linalg.norm(z_inv_mean, axis=-1)
        return 1 - dot / norms
    if distance_function == "chamfer":
        # NumPy reductions take ``axis`` and return the values directly; the
        # former torch-style ``.min(dim=-1)[0]`` raised TypeError on ndarrays.
        # NOTE(review): matrix_dist_numpy pairs points along axes 1/2, which
        # fits (batch, n_points, dim) inputs; for 4-D latents confirm that the
        # intended gaussian-to-gaussian pairs are actually formed.
        return matrix_dist_numpy(z_inv, z_inv_mean).min(axis=-1).sum(axis=-1).mean()
    raise NotImplementedError(f"distance_function={distance_function!r} is not supported")


def matrix_dist_numpy(z_mean_next, z_mean_pred):
    """
    Pairwise squared Euclidean distances between two point sets.

    For inputs shaped (batch, n, d) and (batch, m, d), returns shape
    (batch, n, m) where entry [b, i, j] is ``||z_mean_pred[b, j] - z_mean_next[b, i]||^2``.
    When the last dimension of ``z_mean_next`` is 3, the result is
    additionally summed over its last axis.
    """
    # Insert singleton axes so every "next" point is compared with every "pred" point.
    pairwise_diff = np.expand_dims(z_mean_pred, 1) - np.expand_dims(z_mean_next, 2)
    squared = np.sum(pairwise_diff ** 2, axis=-1)
    if z_mean_next.shape[-1] == 3:
        squared = np.sum(squared, axis=-1)
    return squared


def dlsbd_metric_mixture(z, angles, average: bool = True, distance_function: str = "euclidean"):
    """
    Compute the LSBD dispersion metric for embeddings of a mixture distribution.

    The generating angles are tiled over the mixture components. Each
    embedding is rotated back by its angle, and the spread of the resulting
    points around their mean (taken over the angle axis) is measured.

    :param z: embeddings in Z_G with shape
        (num_objects, num_angles, num_gaussians, latent_dim); latent_dim 2 is
        handled as a single rotation, latent_dim 4 as two 2D subspaces
    :param angles: angles used to generate the dataset the embeddings were
        extracted from, shape (num_objects, num_angles)
    :param average: if True, return the mean of the dispersion values;
        otherwise return them as produced by ``calculate_dispersion``
    :param distance_function: forwarded to ``calculate_dispersion``
    :return: scalar mean dispersion, or the dispersion array when ``average`` is False
    """
    components = z.shape[-2]
    angles_per_component = repeat_angles_n_gaussians(angles, components)
    z_invariant = apply_inverse_rotation(z, angles_per_component)
    spread = calculate_dispersion(z_invariant, distance_function=distance_function)
    return np.mean(spread) if average else spread


In [None]:
n_orbits = 1
n_angles = 20
n_omega = 3         # angles span [0, 2*pi / n_omega)
n_gaussians = 3
noise = 0.0
rng = np.random.default_rng(0)  # seeded so a non-zero `noise` is reproducible

# Ground-truth angles, shape (n_orbits, n_angles)
true_angles = np.linspace(0, 1, n_angles, endpoint=False) * 2 * np.pi / n_omega
true_angles = np.repeat(np.expand_dims(true_angles, axis=0), n_orbits, axis=0)

# One copy of the angles per gaussian: (n_orbits, n_angles, n_gaussians)
expanded_true_angles = np.repeat(np.expand_dims(true_angles, axis=-1), n_gaussians, axis=-1)
noisy_angles = expanded_true_angles + 2 * np.pi * (rng.random(expanded_true_angles.shape) - 1 / 2) * noise
# Evenly spaced phase offset per gaussian, broadcast over orbits and angles
phases = np.expand_dims(np.arange(0, n_gaussians) * 2 * np.pi / n_gaussians, axis=(0, 1))
noisy_angles = noisy_angles + phases
# Unit-circle embeddings: (n_orbits, n_angles, n_gaussians, 2)
embeddings = np.stack([np.cos(noisy_angles), np.sin(noisy_angles)], axis=-1)

inv_z = apply_inverse_rotation(embeddings, expanded_true_angles)
mean_inv = estimate_mean_inv(inv_z)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# plt.get_cmap works across Matplotlib versions (matplotlib.cm.get_cmap was removed in 3.9)
cmap = plt.get_cmap('Reds')
n_plot_angles = embeddings.shape[-3]
for i in range(n_plot_angles):
    ax.scatter(embeddings[0, i, :, 0], embeddings[0, i, :, 1],
               color=cmap(i / n_plot_angles), edgecolors="k")
ax.scatter(inv_z[0, :, :, 0], inv_z[0, :, :, 1], marker="*", s=100, label="Inverse")
# Scaled by 1.1, presumably to offset the mean markers from the unit-norm inverse points
ax.scatter(mean_inv[0, :, :, 0] * 1.1, mean_inv[0, :, :, 1] * 1.1,
           marker="*", c="r", s=100, label="Inverse Mean")
ax.set(title="Embeddings and inverse-rotated embeddings",
       xlabel="$z_1$", ylabel="$z_2$", aspect="equal")
ax.legend()

dispersion = dlsbd_metric_mixture(embeddings, true_angles)
mean_dispersion = np.mean(dispersion)
print(mean_dispersion)


# plt.scatter(embeddings[0, :, 0], embeddings[0,:, 1])



# dlsbd(embeddings, k_values=[[-1, 1]], factor_manifold="cylinder")
