In [2]:
import polars as pl
import numpy as np

# Load the data
df = pl.read_csv('./connectome_graph.csv')

# Extract arrays
source_nodes = df[df.columns[0]].to_numpy().astype(np.int64)
target_nodes = df[df.columns[1]].to_numpy().astype(np.int64)
edge_weights = df[df.columns[2]].to_numpy().astype(np.int64)


In [3]:
# Get unique node IDs and map to indices
unique_nodes = np.unique(np.concatenate((source_nodes, target_nodes)))
node_id_to_index = {node_id: idx for idx, node_id in enumerate(unique_nodes)}
index_to_node_id = {idx: node_id for node_id, idx in node_id_to_index.items()}

# Map node IDs to indices in edge lists
source_indices = np.array([node_id_to_index[node_id] for node_id in source_nodes])
target_indices = np.array([node_id_to_index[node_id] for node_id in target_nodes])

In [4]:
import jax.numpy as jnp
import jax
from jax import random

# Convert to JAX arrays
source_indices = jnp.array(source_indices)
target_indices = jnp.array(target_indices)
edge_weights = jnp.array(edge_weights)

# Compute maximum edge weight
total_edge_weight = jnp.max(edge_weights)

# Normalize edge weights
edge_weights = edge_weights / total_edge_weight


2024-10-29 10:47:37.016780: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [15]:
import numpy as np
import pandas as pd
import jax.numpy as jnp
def calculate_node_forward(source_orders, target_orders, edge_weights):
    forward_edges = source_orders < target_orders

    # Calculate total forward edge weight
    forward_edge_weight = jnp.sum(edge_weights * forward_edges)
    print(forward_edge_weight)
    # Calculate total edge weight (for normalization)
    total_edge_weight = jnp.sum(edge_weights)

    # Calculate percentage of forward edge weight
    percentage_forward = 100 * (forward_edge_weight / total_edge_weight)
    return percentage_forward
    # print(f"Baseline Percentage of forward edge weight baseline: {percentage_forward:.2f}%")
import pandas as pd
baseline_ordering = pd.read_csv('./ordered_nodes_84.54927062988281_brute.csv')

node_ids = baseline_ordering['Node ID'].to_numpy()
orders = jnp.arange(len(node_ids))

# Create a sorting index
sort_idx = np.argsort(node_ids)

# Sort both arrays using this index
sorted_node_ids = node_ids[sort_idx]
sorted_orders = orders[sort_idx]

In [16]:
node_order = jnp.zeros(orders.shape[0])
node_order = node_order.at[sort_idx].set(jnp.arange(orders.shape[0]))

In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
num_nodes = len(unique_nodes)
key = random.PRNGKey(0)
# base_positions = np.load('./positions_84.09818529671391_999.npy')
base_positions = (sort_idx / jnp.max(sort_idx))
# base_positions = random.uniform(key, shape=(num_nodes,))
positions = base_positions
sorted_indices = jnp.argsort(positions)

# Create a mapping from node index to order in the sequence
node_order = jnp.zeros(num_nodes, dtype=int)
node_order = node_order.at[sorted_indices].set(jnp.arange(num_nodes))

edge_directions = node_order[target_indices] - node_order[source_indices]

forward_edges = edge_directions > 0

total_forward_weight_initial = jnp.sum(edge_weights * forward_edges)

total_edge_weight = jnp.sum(edge_weights)
original_total_edge_weights = total_edge_weight
# Compute the percentage of forward edge weight
percentage_forward_initial = 100 * float(total_forward_weight_initial) / (total_edge_weight)

# Convert to JAX arrays
source_indices = jnp.array(source_indices)
target_indices = jnp.array(target_indices)
edge_weights = jnp.array(edge_weights)

# Compute maximum edge weight
mean_edge_weight = jnp.max(edge_weights)

# Normalize edge weights
edge_weights = edge_weights / mean_edge_weight

# Print the results
print(f"Total Forward Edge Weight (Initial): {total_forward_weight_initial}")
print(f"Percentage of Forward Edge Weight (Initial): {percentage_forward_initial:.2f}%")
best_metric=0

Total Forward Edge Weight (Initial): 14734.474609375
Percentage of Forward Edge Weight (Initial): 84.55%


In [20]:
import jax
import jax.numpy as jnp

def monte_carlo_node_ordering(source_indices, target_indices, node_order, edge_weights, num_iterations=200000000, temp=1.0):
    num_nodes = node_order.shape[0]
    num_edges = source_indices.shape[0]

    source_indices = jnp.array(source_indices)
    target_indices = jnp.array(target_indices)
    node_order = node_order.astype(float)

    # Function to compute the forward score
    def calculate_forward_score(node_order):
        source_order = node_order[source_indices]
        target_order = node_order[target_indices]
        forward_edges = source_order < target_order
        return jnp.sum(edge_weights * forward_edges)

    # Initial score
    current_score = calculate_forward_score(node_order)

    # Initial PRNGKey
    key = jax.random.PRNGKey(0)

    def monte_carlo_step(state):
        node_order, current_score, iteration, temp, key = state

        # Split the key for reproducibility
        key, subkey_i, subkey_j, subkey_accept = jax.random.split(key, 4)

        # Sample two random nodes to swap
        i = jax.random.randint(subkey_i, (), 0, num_nodes)
        j = jax.random.randint(subkey_j, (), 0, num_nodes)

        # Ensure that i != j using jax.lax.while_loop
        def cond_fun(val):
            _, j = val
            return j == i

        def body_fun(val):
            key_j, _ = val
            key_j, subkey_new_j = jax.random.split(key_j)
            new_j = jax.random.randint(subkey_new_j, (), 0, num_nodes)
            return key_j, new_j

        key_j, j = jax.lax.while_loop(cond_fun, body_fun, (key, j))

        # Update key with the latest key_j
        key = key_j

        # Swap positions of node i and node j
        new_node_order = node_order.at[i].set(node_order[j])
        new_node_order = new_node_order.at[j].set(node_order[i])

        # Compute new score
        new_score = calculate_forward_score(new_node_order)

        # Acceptance probability (simulated annealing)
        delta = new_score - current_score
        accept_prob = jnp.exp(delta / temp)  # Probabilistically accept worse solutions based on temperature

        random_value = jax.random.uniform(subkey_accept, ())

        # Update state based on acceptance criterion
        accept_swap = (delta > 0) #| (random_value < accept_prob)

        node_order = jax.lax.select(accept_swap, new_node_order, node_order)
        current_score = jax.lax.select(accept_swap, new_score, current_score)

        return node_order, current_score, iteration + 1, temp, key

    # Initial state for the Monte Carlo loop
    state = (node_order, current_score, 0, temp, key)

    # Loop for a given number of iterations
    def cond_fun(state):
        node_order, _, iteration, _, _ = state
        def check_metric_fn(_):
            source_order = node_order[source_indices]
            target_order = node_order[target_indices]
            metric = calculate_node_forward(source_order, target_order, edge_weights)
            jax.debug.print("Iteration {iteration}, metric {metric}, score {current_score}", iteration=iteration, metric=metric, current_score=current_score)
            return None

        # Conditionally run the metric check
        jax.lax.cond(iteration % 100000 == 0, check_metric_fn, lambda _: None, operand=None)

        return iteration < num_iterations

    state = jax.lax.while_loop(cond_fun, monte_carlo_step, state)
    final_node_order, final_score, _, _, _ = state

    return final_node_order, final_score


In [22]:
print("START")
node_order, best_metric = monte_carlo_node_ordering(source_indices, target_indices, node_order, edge_weights, temp=1e-3)

source_order = node_order[source_indices]
target_order = node_order[target_indices]
metric = calculate_node_forward(source_order, target_order, edge_weights)
ordered_node_ids = [index_to_node_id[int(idx)] for idx in node_order]

ordered_nodes_df = pd.DataFrame({"Node ID": ordered_node_ids, "Order": jnp.arange(node_order.shape[0])})
ordered_nodes_df.to_csv(f"./ordered_nodes_{metric}_brute.csv", index=False)
print("FINISHED", metric)

START
Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
Iteration 0, metric 84.54927062988281, score 14734.474609375
Iteration 100000, metric 84.54927062988281, score 14734.474609375
Iteration 200000, metric 84.54927062988281, score 14734.474609375
Iteration 300000, metric 84.54927062988281, score 14734.474609375
Iteration 400000, metric 84.54927062988281, score 14734.474609375
Iteration 500000, metric 84.54927062988281, score 14734.474609375
Iteration 600000, metric 84.54927062988281, score 14734.474609375
Iteration 700000, metric 84.54927062988281, score 14734.474609375
Iteration 800000, metric 84.54927062988281, score 14734.474609375
Iteration 900000, metric 84.54927062988281, score 14734.474609375
Iteration 1000000, metric 84.54928588867188, score 14734.474609375
Iteration 1100000, metric 84.54928588867188, score 14734.474609375
Iteration 1200000, metric 84.54928588867188, score 14734.474609375
Iteration 1300000, metric 84.54928588867188, score 14734.474609375
Iterat

In [29]:
node_order_8514 = jnp.array(node_order)
source_order = node_order_8514[source_indices]
target_order = node_order_8514[target_indices]
metric = calculate_node_forward(source_order, target_order, edge_weights)
print(metric)

14734.475
84.54927


In [31]:
# sorted_indices = jnp.argsort(positions)
node_order = jnp.zeros(num_nodes)
node_order = node_order.at[node_order_8514.astype(int)].set(jnp.arange(num_nodes))

In [32]:
# Save the ordering to a CSV file
import pandas as pd
ordered_node_ids = [index_to_node_id[int(idx)] for idx in node_order]

ordered_nodes_df = pd.DataFrame({"Node ID": ordered_node_ids, "Order": jnp.arange(node_order.shape[0])})
ordered_nodes_df.to_csv(f"./ordered_nodes_{metric}_brute.csv", index=False)