In [None]:
import matplotlib.pyplot as plt
import numpy as np
import random

In [None]:
# Read the co-activation counts from a comma-separated text file of integers.
matrix = np.loadtxt('correlations.txt', dtype=int, delimiter=',')

# Sanity check: the matrix must be square (same number of rows and columns).
rows, cols = matrix.shape
assert cols == rows

In [None]:
# The diagonal is treated as each neuron's self-frequency. Rank neurons from
# most to least frequent, then permute rows and columns with the same ranking
# so the diagonal stays on the diagonal.
self_frequencies = np.diag(matrix)
sort_indices = np.argsort(self_frequencies)[::-1]
sorted_matrix = matrix[np.ix_(sort_indices, sort_indices)]
sorted_self_frequencies = np.diag(sorted_matrix)

print(f"{sort_indices = }")
print(f"diagonal = {np.diag(sorted_matrix)}")

In [None]:
# Raw co-activation counts, with neurons ordered by descending self-frequency.
fig, ax = plt.subplots(figsize=(12, 10))
image = ax.imshow(sorted_matrix, interpolation="none")
fig.colorbar(image, ax=ax, label="occurrences")
ax.set_title("non-zero co-activations")
ax.set_xlabel("neuron (sorted by self-frequency, descending)")
ax.set_ylabel("neuron (sorted by self-frequency, descending)")
ax.grid(False)
fig.tight_layout()
plt.show()

In [None]:
# Zoom in on the most active neurons. After the descending sort they occupy
# the first TOP_K rows and columns of sorted_matrix.
TOP_K = 256
fig, ax = plt.subplots(figsize=(12, 10))
image = ax.imshow(sorted_matrix[:TOP_K, :TOP_K], interpolation="none")
fig.colorbar(image, ax=ax, label="occurrences")
ax.set_title(f"non-zero co-activations (top {TOP_K} neurons by self-frequency)")
ax.set_xlabel("neuron (sorted by self-frequency, descending)")
ax.set_ylabel("neuron (sorted by self-frequency, descending)")
ax.grid(False)
fig.tight_layout()
plt.show()

In [None]:
# Scale each co-activation by the self-frequencies of the two neurons involved:
#     scaled[i, j] = C[i, j] / sqrt(C[i, i] * C[j, j])
# This is vectorised with an outer product instead of an n x n Python loop.
# The frequencies are cast to float first, so the product f_i * f_j cannot
# overflow the integer dtype used by np.loadtxt(..., dtype=int).
# A neuron with self-frequency 0 still yields nan/inf, as element-wise division does.
sorted_freqs_float = sorted_self_frequencies.astype(float)
scaled_matrix = sorted_matrix / np.sqrt(np.outer(sorted_freqs_float, sorted_freqs_float))

In [None]:
# Heatmap of the frequency-normalised co-activations (neurons in descending
# self-frequency order).
fig, ax = plt.subplots(figsize=(12, 10))
image = ax.imshow(scaled_matrix, interpolation="none")
fig.colorbar(image, ax=ax, label="scaled co-activation")
ax.set_title("scaled co-activations")
ax.grid(False)
fig.tight_layout()
plt.show()

In [None]:
def cost_function(
    order: np.ndarray, corr: "np.ndarray | None" = None, window: int = 7
) -> np.int64:
    """Calculate the cost of a visit order.

    The cost is the negated sum of co-activations between each neuron and the
    ``window`` neurons that follow it in ``order``. Lower (more negative) is
    better, so orders that place strongly co-activating neurons close together
    are preferred.

    Args:
        order: sequence of row/column indices into ``corr`` (normally a
            permutation of all neurons).
        corr: square co-activation matrix. Defaults to the notebook-level
            ``matrix`` loaded from ``correlations.txt``.
        window: number of following positions counted as neighbours. The
            default of 7 matches the ``order[i + 1 : i + 8]`` neighbourhood.

    Returns:
        The cost as ``np.int64``.
    """
    if corr is None:
        corr = matrix
    order = np.asarray(order)
    total = np.int64(0)
    # Summing by offset equals summing corr[order[i], order[i+1 : i+window+1]]
    # over every position i, with windows truncated at the end of the order.
    # Every position contributes, including the last neuron. Each pair is
    # weighted equally, so relabelling neurons does not change the cost.
    for offset in range(1, window + 1):
        if offset >= len(order):
            break
        total += corr[order[:-offset], order[offset:]].sum(dtype=np.int64)
    return -total


def swap_two(order: np.ndarray) -> np.ndarray:
    """Exchange the elements at two distinct, randomly chosen positions.

    The swap happens in place and the same object is returned. ``random.sample``
    raises ValueError when ``order`` has fewer than two elements.
    """
    i, j = random.sample(range(len(order)), 2)
    held = order[i]
    order[i] = order[j]
    order[j] = held
    return order


def reverse_segment(order: np.ndarray) -> np.ndarray:
    """Reverse, in place, the elements between two random positions (both inclusive).

    The same object is returned. ``random.sample`` raises ValueError when
    ``order`` has fewer than two elements.
    """
    p, q = random.sample(range(len(order)), 2)
    span = slice(min(p, q), max(p, q) + 1)
    # Reverse from an explicit copy so the read and the write never share memory.
    order[span] = order[span][::-1].copy()
    return order


def simulated_annealing(
    order: np.ndarray,
    initial_temp=1,
    cooling_rate=0.9995,
    min_temp=1e-2,
    log_every: int = 500,
) -> tuple[np.ndarray, np.int64]:
    """Simulated annealing to optimise the visit order.

    Starting from ``order``, each iteration proposes a neighbour (``swap_two``
    or ``reverse_segment``, picked uniformly at random). The neighbour is
    accepted by the Metropolis rule: always if it lowers ``cost_function``,
    otherwise with probability ``exp((current_cost - new_cost) / temp)``. The
    temperature is multiplied by ``cooling_rate`` after every proposal, and the
    loop stops once it is no longer above ``min_temp``.

    Args:
        order: starting order. It is copied and never modified.
        initial_temp: starting temperature.
        cooling_rate: geometric cooling factor applied each iteration.
        min_temp: stopping temperature.
        log_every: print progress every ``log_every`` iterations; ``0`` or a
            negative value disables logging.

    Returns:
        ``(best_order, best_cost)``: the lowest-cost order seen and its cost.
    """
    # NumPy slicing (`arr[:]`) returns a view, not a copy. The move functions
    # mutate their argument in place, so every order is copied explicitly.
    # Otherwise rejected moves would leak into the current state, best_order
    # would track the current order, and the caller's array would change.
    current_order = np.array(order, copy=True)
    current_cost = cost_function(current_order)
    best_order, best_cost = current_order.copy(), current_cost
    temp = initial_temp
    iteration = 0

    while temp > min_temp:
        move = random.choice([swap_two, reverse_segment])
        new_order = move(current_order.copy())
        new_cost = cost_function(new_order)

        if (
            new_cost < current_cost
            or np.exp((current_cost - new_cost) / temp) > random.random()
        ):
            current_order, current_cost = new_order, new_cost
            if new_cost < best_cost:
                best_order, best_cost = new_order.copy(), new_cost

        temp *= cooling_rate
        iteration += 1
        if log_every > 0 and iteration % log_every == 0:
            print(f"{temp = }, {best_cost = }")

    return best_order, best_cost


In [None]:
# Start the search from the neurons ordered by ascending self-frequency (the
# diagonal of the co-activation matrix). The identity order
# np.arange(len(matrix)) is an alternative starting point.
start_order = np.argsort(np.diag(matrix))
final_best_order, final_best_cost = simulated_annealing(start_order)

print(f"best cost = {final_best_cost}")
print(f"best order = {final_best_order}")

In [None]:
# Reorder rows and columns by the annealed order, then normalise:
#     scaled[i, j] = C[i, j] / sqrt(C[i, i] * C[j, j])
# This is vectorised with an outer product instead of an n x n Python loop.
# The frequencies are cast to float first so the product cannot overflow the
# integer dtype of `matrix`. A zero self-frequency still yields nan/inf.
factored_matrix = matrix[final_best_order][:, final_best_order]
factored_self_freqs = np.diag(factored_matrix)
factored_freqs_float = factored_self_freqs.astype(float)
scaled_factored_matrix = factored_matrix / np.sqrt(
    np.outer(factored_freqs_float, factored_freqs_float)
)

In [None]:
# Heatmap of the normalised co-activations in the annealed order. The values
# are divided by sqrt(freq_i * freq_j), so they are not raw counts.
fig, ax = plt.subplots(figsize=(12, 10))
image = ax.imshow(scaled_factored_matrix, interpolation="none")
fig.colorbar(image, ax=ax, label="scaled co-activation")
ax.set_title("factored co-activations")
ax.grid(False)
fig.tight_layout()
plt.show()

In [None]:
",".join([f"{int(x)}" for x in final_best_order])

In [None]:
",".join([f"{int(x)}" for x in np.argsort(np.diag(matrix))])