In [None]:
import functools
import os
import shutil
from typing import Any, Dict, List, Optional

import clrs
import jax
import numpy as np
import requests
import tensorflow as tf
import networkx as nx
import matplotlib.pyplot as plt
import random

def visualize_graph_from_adjacency_matrix(adjacency_matrix, weight_matrix=None):
    """
    Visualizes a graph with explicit arrows and labeled edge weights (adjacent).

    Args:
        adjacency_matrix: Adjacency matrix (NumPy array).
        weight_matrix: Optional weight matrix (NumPy array).
    """

    adjacency_matrix = np.array(adjacency_matrix)
    if adjacency_matrix.shape[0] != adjacency_matrix.shape[1]:
        raise ValueError("Adjacency matrix must be square.")
    num_nodes = adjacency_matrix.shape[0]

    if weight_matrix is None:
        weight_matrix = np.ones_like(adjacency_matrix)
    else:
        weight_matrix = np.array(weight_matrix)
        if weight_matrix.shape != adjacency_matrix.shape:
            raise ValueError("Weight matrix must have the same dimensions.")

    directed_graph = nx.DiGraph()
    undirected_graph = nx.Graph()

    for i in range(num_nodes):
        directed_graph.add_node(i)
        undirected_graph.add_node(i)

    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j:
                if adjacency_matrix[i, j] != 0:
                    weight = round(weight_matrix[i, j], 2)
                    if adjacency_matrix[j, i] != 0:
                        if i < j:
                            undirected_graph.add_edge(i, j, weight=weight)
                    else:
                        directed_graph.add_edge(i, j, weight=weight)

    pos = nx.spring_layout(undirected_graph)  # Layout based on undirected

    plt.figure(figsize=(8, 6))

    # Draw undirected edges (no arrows)
    nx.draw_networkx_edges(undirected_graph, pos, edge_color='gray', width=2, arrows=False)
    edge_labels_undirected = nx.get_edge_attributes(undirected_graph, 'weight')
    # Use label_pos and rotate for adjacent labels
    nx.draw_networkx_edge_labels(undirected_graph, pos, edge_labels=edge_labels_undirected,
                                 label_pos=0.3, rotate=True)

    # Draw directed edges with explicit arrows
    nx.draw_networkx_edges(directed_graph, pos, edge_color='black', width=1,
                           arrowstyle='->', arrowsize=15)
    edge_labels_directed = nx.get_edge_attributes(directed_graph, 'weight')
    # Use label_pos and rotate for adjacent labels
    nx.draw_networkx_edge_labels(directed_graph, pos, edge_labels=edge_labels_directed,
                                 label_pos=0.3, rotate=True)

    nx.draw_networkx_nodes(directed_graph, pos, node_color='skyblue', node_size=500)
    nx.draw_networkx_labels(directed_graph, pos)

    plt.title("Graph Visualization")
    plt.axis('off')
    plt.show()


NUM_SAMPLES = 1000
encode_hints = True
decode_hints = True

rng = np.random.RandomState(42)
rng_key = jax.random.PRNGKey(rng.randint(2**32, dtype=np.int64))

processor_factory = clrs.get_processor_factory(
    'triplet_gmpnn',
    use_ln=True,
    nb_triplet_fts=8,
    nb_heads=1,
    )
model_params = dict(
    processor_factory=processor_factory,
    hidden_dim=128,
    encode_hints=encode_hints,
    decode_hints=decode_hints,
    encoder_init='xavier_on_scalars',
    use_lstm=False,
    learning_rate=0.001,
    grad_clip_max_norm=1.0,
    checkpoint_path='checkpoints/CLRS30',
    freeze_processor=False,
    dropout_prob=0.0,
    hint_teacher_forcing=0.0,
    hint_repred_mode='soft',
    nb_msg_passing_steps=1,
    )


2025-03-05 16:37:56.856422: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741192676.907784  146064 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741192676.921515  146064 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from interp.dataset import HDF5Dataset

# --- Saving Data ---
def save_to_hdf5(data, filename):
    with h5py.File(filename, 'w') as f:
        for i, datapoint in enumerate(data):
            group = f.create_group(f'datapoint_{i}')  # Create a group for each datapoint
            for key, array in datapoint.items():
                # Store each array as a dataset within the group
                group.create_dataset(key, data=array, compression="gzip") # Optional compression



In [14]:
from tqdm import tqdm
# Set up multiple lengths and samples per length
LENGTHS = [8]
# LENGTHS = [4, 7, 11, 13, 16]
# LENGTHS = [20, 25, 30, 35, 40, 45, 50, 55, 60, 64]
SAMPLES_PER_LENGTH = 500
alg = "bellman_ford"

data = []
for length in LENGTHS:
    # Create sampler for this length
    sampler, spec = clrs.build_sampler(
        alg,
        seed=rng.randint(2**32, dtype=np.int64),
        num_samples=SAMPLES_PER_LENGTH,
        length=length,
    )

    # Get dummy trajectory and initialize model
    dummy_traj = [sampler.next()]
    model = clrs.models.BaselineModel(
        spec=[spec],
        dummy_trajectory=dummy_traj,
        get_inter=True,
        **model_params
    )

    all_features = [f.features for f in dummy_traj]
    model.init(all_features, 42)
    model.restore_model(f'best_{alg}.pkl', only_load_processor=False)

    # Get predictions for this length
    feedback = sampler.next()
    new_rng_key, rng_key = jax.random.split(rng_key)
    preds, _, hist = model.predict(new_rng_key, feedback.features)

    # Create data for all samples of this length
    for item in tqdm(range(SAMPLES_PER_LENGTH)):
        hidden_states = np.stack([hist[i].hiddens[item] for i in range(length)]).transpose((0,2,1))
        graph_adj = feedback.features.inputs[3].data[item] # (D, D)
        edge_weights = feedback.features.inputs[2].data[item] # (D, D)
        upd_pi = feedback.features.hints[3].data[:,item,:] # (T, D)
        upd_d = feedback.features.hints[4].data[:,item,:] # (T, D)
        gt_pi = feedback.outputs[0].data[item] # (D)
        start_node = feedback.features.inputs[1].data[item] # (D)
        datapoint = {
            'hidden_states': np.copy(hidden_states),
            'graph_adj': np.copy(graph_adj),
            'edge_weights': np.copy(edge_weights), 
            'upd_pi': np.copy(upd_pi),
            'upd_d': np.copy(upd_d),
            'gt_pi': np.copy(gt_pi),
            'start_node': np.copy(start_node),
        }
        data.append(datapoint)


100%|██████████| 500/500 [00:00<00:00, 725.67it/s]


In [15]:
data[-1]["hidden_states"].shape

(8, 128, 8)

In [16]:
from interp.dataset import custom_collate
save_to_hdf5(data, 'data/interp_data_8_eval.h5')

dataset = HDF5Dataset('data/interp_data_8_eval.h5')
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=custom_collate)

for batch in dataloader:
    # Now 'batch' is a dictionary of tensors
    print({key: value.shape for key, value in batch.items()})

dataset.close() # Important to close open file!

{'batch': torch.Size([32]), 'num_graphs': torch.Size([]), 'num_nodes_per_graph': torch.Size([4]), 'all_cumsum': torch.Size([5]), 'edge_weights': torch.Size([32, 32]), 'graph_adj': torch.Size([32, 32]), 'gt_pi': torch.Size([32]), 'hidden_states': torch.Size([32, 128, 32]), 'timesteps_per_graph': torch.Size([4]), 'all_cumsum_timesteps': torch.Size([5]), 'start_node': torch.Size([32]), 'upd_d': torch.Size([32, 32]), 'upd_pi': torch.Size([32, 32])}
{'batch': torch.Size([32]), 'num_graphs': torch.Size([]), 'num_nodes_per_graph': torch.Size([4]), 'all_cumsum': torch.Size([5]), 'edge_weights': torch.Size([32, 32]), 'graph_adj': torch.Size([32, 32]), 'gt_pi': torch.Size([32]), 'hidden_states': torch.Size([32, 128, 32]), 'timesteps_per_graph': torch.Size([4]), 'all_cumsum_timesteps': torch.Size([5]), 'start_node': torch.Size([32]), 'upd_d': torch.Size([32, 32]), 'upd_pi': torch.Size([32, 32])}
{'batch': torch.Size([32]), 'num_graphs': torch.Size([]), 'num_nodes_per_graph': torch.Size([4]), 'all