## Notes

- Left-multiplying a matrix by a permutation matrix permutes its rows; right-multiplying permutes its columns.

## Imports

In [None]:
import numpy as np

import matplotlib.pyplot as plt
import plotly.express as px
import torch
from scipy.spatial.distance import cdist
import numpy as np
import time
from tqdm import tqdm
import pygmtools as pygm
import functools

pygm.set_backend("numpy")


def time_decorator(func):
    """Wrap ``func`` so that every call prints how long it took to run.

    The wrapper forwards all positional and keyword arguments unchanged and
    returns whatever ``func`` returns. Nothing is printed if ``func`` raises.

    Args:
        func: the callable to time.

    Returns:
        A wrapper with the same call signature as ``func``. Thanks to
        ``functools.wraps`` it also keeps ``func``'s ``__name__`` and
        ``__doc__``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be distorted by system clock adjustments.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"Function {func.__name__} took {end_time - start_time} seconds to run.")
        return result

    return wrapper

## Input

In [None]:
@time_decorator
def create_permuted_matrices(n, d, plot=False):
    """Sample a random matrix together with a row-permuted copy of it.

    Args:
        n: number of rows of the sampled matrix.
        d: number of columns of the sampled matrix.
        plot: when true, show the sampled permutation matrix with plotly.

    Returns:
        A tuple ``(Wa, Wb, P_AB_gt)``: ``Wa`` is the random ``(n, d)`` matrix,
        ``Wb`` is ``Wa`` with its rows permuted, and ``P_AB_gt`` is the
        permutation matrix that maps ``Wb`` back onto ``Wa``.
    """
    base = torch.rand(n * d).reshape(n, d)

    # Rows of the identity, reordered, form the permutation matrix.
    row_order = torch.randperm(n)
    perm_b_from_a = torch.eye(n)[row_order]

    if plot:
        px.imshow(perm_b_from_a).show()

    permuted = perm_b_from_a @ base

    # The inverse of a permutation matrix is its transpose.
    perm_a_from_b = perm_b_from_a.T

    # Sanity check: undoing the permutation must give back the original rows.
    recovered = perm_a_from_b @ permuted
    assert torch.all(recovered == base)

    return base, permuted, perm_a_from_b


# Sample a 32 x 128 matrix, its row-permuted copy, and the ground-truth
# permutation that maps the copy back onto the original.
Wa, Wb, P_AB_gt = create_permuted_matrices(n=32, d=128)

## Affinity matrix

In [None]:
@time_decorator
def build_affinity_matrix_inefficient(Wa, Wb, temperature=1e-2):
    """Build the matching affinity matrix with explicit loops (reference version).

    A candidate matching pairs row ``x`` of ``Wa`` with row ``y`` of ``Wb`` and
    lives at flat index ``x * num_neurons + y``. The affinity between the
    matchings ``(xi, yi)`` and ``(xj, yj)`` is::

        exp(-| ||Wa[xi] - Wa[xj]|| - ||Wb[yi] - Wb[yj]|| | / temperature)

    Args:
        Wa: 2-D tensor (or array-like); its number of rows defines
            ``num_neurons``.
        Wb: 2-D tensor (or array-like), indexed with the same row range as
            ``Wa``.
        temperature: scale of the exponential kernel; the default reproduces
            the previously hard-coded ``1e-2``.

    Returns:
        A NumPy array of shape ``(num_neurons**2, num_neurons**2)``.
    """
    # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
    # produces when handed an existing tensor; detach keeps .numpy() legal.
    Wa = torch.as_tensor(Wa).detach()
    Wb = torch.as_tensor(Wb).detach()

    num_neurons = len(Wa)
    num_matchings = num_neurons**2

    def pairwise_norms(W):
        # Row-to-row Euclidean distances, computed once instead of inside the
        # quadruple loop below.
        return [
            [torch.norm(W[i, :] - W[j, :], p=2) for j in range(num_neurons)]
            for i in range(num_neurons)
        ]

    dist_a = pairwise_norms(Wa)
    dist_b = pairwise_norms(Wb)

    S = torch.zeros((num_matchings, num_matchings))

    for xi in tqdm(range(num_neurons)):
        for yi in range(num_neurons):
            row = xi * num_neurons + yi
            for xj in range(num_neurons):
                dx = dist_a[xi][xj]
                for yj in range(num_neurons):
                    dy = dist_b[yi][yj]
                    S[row, xj * num_neurons + yj] = torch.exp(-torch.abs(dx - dy) / temperature)

    return S.numpy()

In [None]:
# Reference affinity matrix built with the explicit-loop implementation.
affinity = build_affinity_matrix_inefficient(Wa, Wb)

In [None]:
# The loop-based affinity matrix must be exactly symmetric.
assert np.all(affinity == affinity.T)
# Show the affinity matrix as an image.
fig = plt.imshow(affinity)

plt.show()

In [None]:
from backports.strenum import StrEnum
from enum import auto


class DiagContent(StrEnum):
    """Enum for diagonal content of affinity matrix"""

    # The diagonal of the affinity matrix is filled with ones.
    ONES = auto()
    # The diagonal holds a similarity computed from the distance between the
    # matched row of Wa and row of Wb.
    SIMILARITIES = auto()


@time_decorator
def build_affinity_matrix_vectorized(Wa, Wb, diag: DiagContent, temperature: float = 1e-2):
    """Build the matching affinity matrix with broadcasting instead of loops.

    A candidate matching pairs row ``x`` of ``Wa`` with row ``y`` of ``Wb`` and
    lives at flat index ``x * num_neurons + y``. The off-diagonal affinity
    between the matchings ``(xi, yi)`` and ``(xj, yj)`` is::

        exp(-| dist(Wa[xi], Wa[xj]) - dist(Wb[yi], Wb[yj]) | / temperature)

    Args:
        Wa: 2-D tensor (or array-like); its number of rows defines
            ``num_neurons``.
        Wb: 2-D tensor (or array-like) with the same number of rows as ``Wa``.
        diag: what to put on the diagonal. ``DiagContent.ONES`` writes ones;
            ``DiagContent.SIMILARITIES`` writes
            ``exp(-dist(Wa[x], Wb[y]) / temperature)`` for matching ``(x, y)``.
        temperature: scale of the exponential kernel; the default reproduces
            the previously hard-coded ``1e-2``.

    Returns:
        A NumPy array of shape ``(num_neurons**2, num_neurons**2)``.

    Raises:
        ValueError: if ``diag`` is not a supported ``DiagContent`` member.
    """
    # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
    # produces when handed an existing tensor; detach keeps .numpy() legal.
    Wa = torch.as_tensor(Wa).detach()
    Wb = torch.as_tensor(Wb).detach()

    num_neurons = Wa.size(0)
    num_matchings = num_neurons**2

    # Resolve the diagonal first so an unsupported option fails before the
    # expensive n^2 x n^2 computation instead of with an unbound local after it.
    if diag == DiagContent.ONES:
        diag_values = None
    elif diag == DiagContent.SIMILARITIES:
        Wa_Wb_distances = torch.cdist(Wa, Wb, p=2)
        diag_values = torch.exp(-torch.abs(Wa_Wb_distances) / temperature).reshape(num_matchings)
    else:
        raise ValueError(f"Unsupported diagonal content: {diag!r}")

    # Compute all pairwise Euclidean distances within Wa and within Wb
    Wa_distances = torch.cdist(Wa, Wa, p=2)
    Wb_distances = torch.cdist(Wb, Wb, p=2)

    # Broadcast to a 4-D tensor indexed as (xi, yi, xj, yj)
    distance_gap = Wa_distances[:, None, :, None] - Wb_distances[None, :, None, :]

    S = torch.exp(-torch.abs(distance_gap) / temperature)
    S = S.reshape(num_matchings, num_matchings)

    # Overwrite the diagonal in place rather than materialising an identity
    # mask and a dense diagonal matrix of the same n^2 x n^2 size.
    if diag_values is None:
        S.fill_diagonal_(1.0)
    else:
        S.diagonal().copy_(diag_values)

    return S.numpy()

In [None]:
# Build the affinity matrix with the vectorized implementation, using ones on
# the diagonal.
diag_content = DiagContent.ONES
affinity_vectorized = build_affinity_matrix_vectorized(Wa, Wb, diag=diag_content)

In [None]:
# Show the vectorized affinity matrix as an image.
fig = plt.imshow(affinity_vectorized)

plt.show()

# Symmetry is only checked up to an absolute tolerance of 5e-3 here.
assert np.all(np.abs(affinity_vectorized.T - affinity_vectorized) < 5e-3)

In [None]:
# With DiagContent.ONES every diagonal entry must be 1 (within 5e-3).
if diag_content == DiagContent.ONES:
    assert np.all(np.abs(np.diag(affinity_vectorized) - 1) < 5e-3)

## Matching

In [None]:
@time_decorator
def get_principal_eigenvector(M):
    """Return the principal eigenvector of ``M`` as a square matrix of magnitudes.

    ``M`` is decomposed with ``torch.linalg.eigh``. The eigenvector belonging
    to the largest eigenvalue is reshaped to ``(num_neurons, num_neurons)``
    with ``num_neurons = sqrt(M.size(0))``, and its absolute value is returned.

    Args:
        M: square 2-D tensor whose side length is a perfect square.
           ``torch.linalg.eigh`` treats it as symmetric/Hermitian.

    Returns:
        A ``(num_neurons, num_neurons)`` tensor with non-negative entries.

    Raises:
        ValueError: if ``M`` is not a square 2-D tensor or its side length is
            not a perfect square.
    """
    if M.dim() != 2 or M.size(0) != M.size(1):
        raise ValueError(f"Expected a square 2-D matrix, got shape {tuple(M.shape)}")

    num_matchings = M.size(0)
    # Plain-int square root, validated explicitly: a truncated float sqrt
    # would otherwise surface later as an opaque reshape error.
    num_neurons = round(num_matchings**0.5)
    if num_neurons**2 != num_matchings:
        raise ValueError(f"Matrix side {num_matchings} is not a perfect square")

    values, vectors = torch.linalg.eigh(M)
    principal_eigenvector = vectors[:, torch.argmax(values)]

    principal_eigenvector = principal_eigenvector.reshape(num_neurons, num_neurons)
    # An eigenvector's sign is arbitrary; taking magnitudes removes that ambiguity.
    principal_eigenvector = torch.abs(principal_eigenvector)

    return principal_eigenvector

In [None]:
# Principal eigenvector of the vectorized affinity matrix, as a square matrix.
principal_eigenvector = get_principal_eigenvector(torch.tensor(affinity_vectorized))

In [None]:
# Pairwise Euclidean (p=2) distances between the rows of Wa, and between the
# rows of Wb.
dist_aa = torch.cdist(Wa, Wa, p=2)
dist_bb = torch.cdist(Wb, Wb, p=2)

In [None]:
num_neurons = len(Wa)

# Add a leading axis of size 1 to each input before passing it to pygmtools
# (presumably a batch dimension — confirm against the pygmtools docs).
dist_aa_batched = np.expand_dims(dist_aa, axis=0)
dist_bb_batched = np.expand_dims(dist_bb, axis=0)
num_neurons_batched = np.expand_dims(num_neurons, axis=0)

# NOTE(review): ne1 / ne2 are returned here but None is passed for both in the
# build_aff_mat call below — confirm that this is intended.
conn1, edge1, ne1 = pygm.utils.dense_to_sparse(dist_aa_batched)
conn2, edge2, ne2 = pygm.utils.dense_to_sparse(dist_bb_batched)

# NOTE(review): gaussian_aff is created but not used in the visible cells; the
# edge affinity actually passed below is inner_prod_aff_fn.
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1)
inner_prod_aff_fn = pygm.utils.inner_prod_aff_fn

# Affinity matrix built from edge features only (both node features are None).
K = pygm.utils.build_aff_mat(
    node_feat1=None,
    edge_feat1=edge1,
    connectivity1=conn1,
    node_feat2=None,
    edge_feat2=edge2,
    connectivity2=conn2,
    n1=num_neurons_batched,
    ne1=None,
    n2=num_neurons_batched,
    ne2=None,
    edge_aff_fn=inner_prod_aff_fn,
)

In [None]:
# Solve the matching problem on K with pygm.sm (presumably spectral matching —
# confirm against the pygmtools docs) and drop the size-1 axes.
X = pygm.sm(K, num_neurons_batched, num_neurons_batched).squeeze()

# Post-process the solver output with pygm.hungarian.
X = pygm.hungarian(X)

In [None]:
# Show the matching returned by pygmtools as an image.
fig = plt.imshow(X)
plt.show()

In [None]:
# Sum of the element-wise product of the ground-truth permutation and X,
# divided by the number of neurons.
matching_accuracy = (P_AB_gt * X).sum() / num_neurons
matching_accuracy

In [None]:
# Show the (reshaped, absolute-valued) principal eigenvector as an image.
fig = plt.imshow(principal_eigenvector)
plt.show()

In [None]:
# Show the ground-truth permutation matrix for visual comparison.
fig = plt.imshow(P_AB_gt)
plt.show()

In [None]:
from torch import Tensor


@time_decorator
def extract_matching_leordeanu(principal_eigenvector: Tensor):
    """
    Greedily turn a score matrix into a one-to-one assignment.

    The largest remaining entry ``(i, j)`` is accepted repeatedly, after which
    row ``i`` and column ``j`` are excluded. The loop stops as soon as the
    largest remaining entry is exactly zero.

    principal_eigenvector: shape (num_neurons, num_neurons)

    Returns a long tensor of shape (num_neurons, num_neurons) with a 1 at every
    accepted (row, column) pair and 0 elsewhere. The input is not modified.
    """

    num_neurons = principal_eigenvector.shape[0]

    # Initialize the solution matrix
    x = torch.zeros((num_neurons, num_neurons)).type_as(principal_eigenvector).long()

    # argmax is undefined on an empty tensor; an empty problem has an empty solution.
    if principal_eigenvector.numel() == 0:
        return x

    # One working copy, zeroed in place as rows/columns get assigned, replaces
    # re-cloning and re-masking the full matrix on every iteration.
    remaining = principal_eigenvector.clone()

    while True:
        # Find the largest remaining score and its (row, column) position
        flat_index = remaining.argmax().item()
        i, j = divmod(flat_index, num_neurons)

        # A zero maximum means nothing assignable is left.
        if remaining[i, j] == 0:
            break

        # Accept the assignment
        x[i, j] = 1

        # Exclude row i and column j from all later picks
        remaining[i, :] = 0
        remaining[:, j] = 0

    return x

In [None]:
# Greedy assignment extracted from the principal eigenvector.
P_AB_leordeanu = extract_matching_leordeanu(principal_eigenvector)

In [None]:
from scipy.optimize import linear_sum_assignment
import scipy


def extract_matching_lap(principal_eigenvector):
    """Turn a score matrix into a permutation via a linear assignment problem.

    The assignment that maximizes the total selected score is computed with
    ``scipy.optimize.linear_sum_assignment``.

    Args:
        principal_eigenvector: square tensor of shape
            ``(num_neurons, num_neurons)``; it may live on any device and may
            require grad.

    Returns:
        A CPU ``float64`` tensor of shape ``(num_neurons, num_neurons)`` with a
        1 at every selected (row, column) pair and 0 elsewhere.
    """
    num_neurons = principal_eigenvector.shape[0]

    # detach() keeps the NumPy conversion legal for tensors that require grad.
    scores = principal_eigenvector.detach().cpu().numpy()

    # maximize=True states the objective directly instead of hand-building a
    # "max - score" cost matrix.
    row_ind, col_ind = linear_sum_assignment(scores, maximize=True)

    # Mark the selected pairs directly; no sparse-matrix round trip is needed.
    P_AB = np.zeros((num_neurons, num_neurons))
    P_AB[row_ind, col_ind] = 1.0

    return torch.from_numpy(P_AB)

In [None]:
# Assignment extracted from the principal eigenvector by solving a linear
# assignment problem.
P_AB_lap = extract_matching_lap(principal_eigenvector)

In [None]:
# Both extraction strategies are expected to produce the same assignment here.
assert torch.all(P_AB_leordeanu == P_AB_lap)

In [None]:
# The extracted assignment is expected to equal the ground-truth permutation.
assert torch.all(P_AB_leordeanu == P_AB_gt)

In [None]:
class EigenvectorPostprocess(StrEnum):
    """Enum for how the principal eigenvector is turned into a matching"""

    # Use extract_matching_leordeanu.
    LEORDEANU = auto()
    # Use extract_matching_lap.
    LAP = auto()


def compare_matching_algorithms(num_neurons, dim, eigenvec_postprocess: EigenvectorPostprocess):
    """Run one random matching experiment and return its accuracy.

    A random ``(num_neurons, dim)`` matrix and a row-permuted copy are sampled.
    The affinity matrix is built with ones on its diagonal, and the principal
    eigenvector is turned into a matching with the requested post-processing.

    Args:
        num_neurons: number of rows to match.
        dim: number of columns of the sampled matrices.
        eigenvec_postprocess: which extraction routine to apply to the
            principal eigenvector.

    Returns:
        A scalar tensor: the number of ground-truth assignments recovered,
        divided by ``num_neurons``. It lives on the GPU when one is available,
        otherwise on the CPU.

    Raises:
        ValueError: if ``eigenvec_postprocess`` is not a supported member.
        AssertionError: if the extracted matching does not map ``Wb`` exactly
            back onto ``Wa``.
    """
    # Resolve the extraction routine first so an unsupported option fails
    # immediately instead of with an unbound local after the heavy computation.
    if eigenvec_postprocess == EigenvectorPostprocess.LEORDEANU:
        extract_matching = extract_matching_leordeanu
    elif eigenvec_postprocess == EigenvectorPostprocess.LAP:
        extract_matching = extract_matching_lap
    else:
        raise ValueError(f"Unsupported eigenvector post-processing: {eigenvec_postprocess!r}")

    # Use the GPU when present, but keep the experiment runnable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    Wa, Wb, P_AB_gt = create_permuted_matrices(num_neurons, dim)

    affinity = build_affinity_matrix_vectorized(Wa, Wb, diag=DiagContent.ONES)
    affinity = torch.tensor(affinity).to(device)

    principal_eigenvector = get_principal_eigenvector(affinity)

    P_AB = extract_matching(principal_eigenvector)

    matching_accuracy = (P_AB_gt.to(device) * P_AB.to(device)).sum() / num_neurons

    # NOTE(review): this requires exact recovery, so any run with accuracy
    # below 1 raises instead of returning its accuracy — confirm that this is
    # intended for the accuracy histogram.
    assert torch.all(P_AB.cpu().float() @ Wb.cpu() == Wa.cpu())

    return matching_accuracy

In [None]:
# Repeat the experiment 20 times (128 neurons, 256 dimensions, greedy
# extraction) and collect the accuracy of every run.
all_accuracies = []
for i in tqdm(range(20)):
    acc = compare_matching_algorithms(128, 256, EigenvectorPostprocess.LEORDEANU)
    print(acc)

    all_accuracies.append(acc)

In [None]:
# Plot all the accuracies as a histogram; each one is moved to the CPU and
# converted to NumPy first.
fig = px.histogram([acc.cpu().numpy() for acc in all_accuracies])
fig.show()