In [1]:
import functools
import os
import shutil
from typing import Any, Dict, List, Optional

import clrs
import jax
import numpy as np
import requests
import tensorflow as tf
import networkx as nx
import matplotlib.pyplot as plt
import random

def visualize_graph_from_adjacency_matrix(adjacency_matrix, weight_matrix=None, seed=None):
    """
    Visualizes a graph with explicit arrows and labeled edge weights (adjacent).

    Node pairs connected in both directions with the same (rounded) weight are
    drawn once as a grey undirected edge. One-way edges, and two-way pairs whose
    weights differ, are drawn as black arrows so that no weight is hidden.
    Self-loops (diagonal entries) are ignored.

    Args:
        adjacency_matrix: Square 2-D adjacency matrix (array-like); a non-zero
            entry [i, j] means an edge i -> j.
        weight_matrix: Optional weight matrix (array-like) with the same shape
            as ``adjacency_matrix``. Defaults to all ones.
        seed: Optional seed forwarded to ``nx.spring_layout`` for a
            reproducible layout. ``None`` (default) gives a random layout.

    Raises:
        ValueError: If ``adjacency_matrix`` is not a 2-D square matrix, or if
            ``weight_matrix`` does not have the same shape.
    """
    adjacency_matrix = np.asarray(adjacency_matrix)
    # Check ndim first: a 1-D input would otherwise fail with IndexError on
    # shape[1], and a 3-D input would pass the square check by accident.
    if adjacency_matrix.ndim != 2 or adjacency_matrix.shape[0] != adjacency_matrix.shape[1]:
        raise ValueError("Adjacency matrix must be square.")
    num_nodes = adjacency_matrix.shape[0]

    if weight_matrix is None:
        weight_matrix = np.ones_like(adjacency_matrix)
    else:
        weight_matrix = np.asarray(weight_matrix)
        if weight_matrix.shape != adjacency_matrix.shape:
            raise ValueError("Weight matrix must have the same dimensions.")

    directed_graph = nx.DiGraph()
    undirected_graph = nx.Graph()
    directed_graph.add_nodes_from(range(num_nodes))
    undirected_graph.add_nodes_from(range(num_nodes))

    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j or adjacency_matrix[i, j] == 0:
                continue
            weight = round(weight_matrix[i, j], 2)
            has_reverse = adjacency_matrix[j, i] != 0
            if has_reverse and weight == round(weight_matrix[j, i], 2):
                # Symmetric pair: draw it once, when visiting the lower index.
                if i < j:
                    undirected_graph.add_edge(i, j, weight=weight)
            else:
                # One-way edge, or a two-way pair with differing weights: keep
                # each direction as its own arrow so both weights are shown.
                directed_graph.add_edge(i, j, weight=weight)

    # Lay out over all edges so that one-way edges also pull their endpoints
    # together (a layout over undirected edges alone scatters them).
    layout_graph = nx.Graph(undirected_graph)
    layout_graph.add_edges_from(directed_graph.edges())
    pos = nx.spring_layout(layout_graph, seed=seed)

    node_size = 500
    fig, ax = plt.subplots(figsize=(8, 6))

    # Undirected edges (no arrows).
    nx.draw_networkx_edges(undirected_graph, pos, ax=ax, edge_color='gray', width=2,
                           arrows=False, node_size=node_size)
    nx.draw_networkx_edge_labels(undirected_graph, pos, ax=ax,
                                 edge_labels=nx.get_edge_attributes(undirected_graph, 'weight'),
                                 label_pos=0.3, rotate=True)

    # Directed edges with explicit arrows. Passing node_size keeps each
    # arrowhead outside its target node instead of hidden underneath it.
    # label_pos=0.3 places opposite-direction labels at different points.
    nx.draw_networkx_edges(directed_graph, pos, ax=ax, edge_color='black', width=1,
                           arrowstyle='->', arrowsize=15, node_size=node_size)
    nx.draw_networkx_edge_labels(directed_graph, pos, ax=ax,
                                 edge_labels=nx.get_edge_attributes(directed_graph, 'weight'),
                                 label_pos=0.3, rotate=True)

    nx.draw_networkx_nodes(directed_graph, pos, ax=ax, node_color='skyblue', node_size=node_size)
    nx.draw_networkx_labels(directed_graph, pos, ax=ax)

    ax.set_title("Graph Visualization")
    ax.axis('off')
    plt.show()


# --- Experiment configuration -------------------------------------------------
# NOTE: NUM_SAMPLES is reassigned (to 200) in the data-generation cell below.
NUM_SAMPLES = 1000
encode_hints, decode_hints = True, True

# One NumPy seed drives everything: the JAX key is drawn from it.
rng = np.random.RandomState(42)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

# Processor: 'triplet_gmpnn' (see clrs.get_processor_factory for the options).
processor_factory = clrs.get_processor_factory(
    'triplet_gmpnn',
    use_ln=True,
    nb_triplet_fts=8,
    nb_heads=1,
)

# Keyword arguments for clrs.models.BaselineModel.
model_params = {
    'processor_factory': processor_factory,
    'hidden_dim': 128,
    'encode_hints': encode_hints,
    'decode_hints': decode_hints,
    'encoder_init': 'xavier_on_scalars',
    'use_lstm': False,
    'learning_rate': 0.001,
    'grad_clip_max_norm': 1.0,
    'checkpoint_path': 'checkpoints/CLRS30',
    'freeze_processor': False,
    'dropout_prob': 0.0,
    'hint_teacher_forcing': 0.0,
    'hint_repred_mode': 'soft',
    'nb_msg_passing_steps': 1,
}


2025-02-24 16:51:04.733937: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740415864.786883   35068 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740415864.801194   35068 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# --- Saving Data ---
def save_to_hdf5(data, filename):
    """Write a sequence of datapoints to an HDF5 file, one group per datapoint.

    Datapoint ``i`` is stored as group ``datapoint_{i}``. Each of its
    ``key -> array`` entries becomes a dataset of the same name inside that
    group, which is the layout ``HDF5Dataset`` reads back. The file is opened
    in mode ``'w'``, so an existing file at ``filename`` is overwritten. Missing
    parent directories are created.

    Args:
        data: Iterable of dicts mapping dataset names to array-likes.
        filename: Destination path of the ``.h5`` file.
    """
    import os

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with h5py.File(filename, 'w') as f:
        for i, datapoint in enumerate(data):
            group = f.create_group(f'datapoint_{i}')
            for key, value in datapoint.items():
                array = np.asarray(value)
                # gzip is an HDF5 filter and needs chunked storage. Scalar (0-d)
                # datasets cannot be chunked, and h5py rejects the request, so
                # only compress arrays with at least one dimension.
                compression = "gzip" if array.ndim > 0 else None
                group.create_dataset(key, data=array, compression=compression)

# --- PyTorch Dataset ---
class HDF5Dataset(Dataset):
    """Map-style PyTorch ``Dataset`` over an HDF5 file written by ``save_to_hdf5``.

    Each top-level group ``datapoint_{i}`` is one item. ``__getitem__`` returns
    a dict mapping every dataset name in that group to a ``torch.Tensor``.

    The h5py handle is opened lazily on first item access rather than in
    ``__init__``. It is also dropped when the dataset is pickled, so each
    DataLoader worker reopens its own handle. Call ``close()``, or use the
    dataset as a context manager, to release it.
    """

    def __init__(self, filename):
        self.filename = filename
        self.file = None  # opened lazily in __getitem__
        # Open briefly just to count the items, then close again.
        with h5py.File(self.filename, 'r') as f:
            self.length = len(f.keys())

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return item ``idx`` as ``{dataset_name: torch.Tensor}``.

        Negative indices count from the end. Raises ``IndexError`` when ``idx``
        is out of range, as the sequence protocol expects.
        """
        idx = int(idx)  # accept NumPy/torch integer scalars from samplers
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index out of range for dataset of length {self.length}")
        if self.file is None:
            self.file = h5py.File(self.filename, 'r')  # open on first access
        group = self.file[f'datapoint_{idx}']
        return {key: torch.from_numpy(np.array(group[key])) for key in group.keys()}

    def __getstate__(self):
        # Open h5py objects cannot be pickled (e.g. when DataLoader workers are
        # spawned). Drop the handle so each copy reopens it lazily.
        state = self.__dict__.copy()
        state['file'] = None
        return state

    def close(self):
        """Close the lazily opened file handle, if any. Safe to call repeatedly."""
        if self.file is not None:
            self.file.close()
            self.file = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


In [3]:
from tqdm import tqdm

# --- Sampler -------------------------------------------------------------------
LENGTH = 7         # graph size, passed as the sampler's `length`
NUM_SAMPLES = 200  # size of the evaluation set drawn below

sampler, spec = clrs.build_sampler(
    "bellman_ford",
    seed=rng.randint(2**32),
    num_samples=NUM_SAMPLES,
    length=LENGTH,
)

# --- Model -----------------------------------------------------------------------
# One sampled trajectory is passed as `dummy_trajectory` / to `init`
# (presumably to fix feature shapes).
dummy_traj = [sampler.next()]
model = clrs.models.BaselineModel(
    spec=[spec],
    dummy_trajectory=dummy_traj,
    get_inter=True,
    **model_params
)

all_features = [traj.features for traj in dummy_traj]
model.init(all_features, 43)

# Load the trained weights (all of them, not only the processor).
model.restore_model('best.pkl', only_load_processor=False)

# --- Model predictions ---------------------------------------------------------
feedback = sampler.next()
# Take the number of samples from the data itself instead of assuming it equals
# NUM_SAMPLES. outputs[0].data is indexed per sample below, so axis 0 is the batch.
batch_size = feedback.outputs[0].data.shape[0]
new_rng_key, rng_key = jax.random.split(rng_key)

preds, _, hist = model.predict(new_rng_key, feedback.features)

# --- Per-sample records ----------------------------------------------------------
# Convert each step's hidden states to NumPy once, instead of once per sample.
# Resulting index order: [step, sample, ...].
all_hiddens = np.stack([np.asarray(hist[i].hiddens) for i in range(LENGTH)])

# NOTE(review): the positional indices below (inputs[2], inputs[3], hints[3],
# hints[4], outputs[0]) depend on this CLRS fork's spec ordering. Presumably
# they are A (edge weights), adj, upd_pi, upd_d and pi. Confirm before changing
# the algorithm or the CLRS version.
data = []
for item in tqdm(range(batch_size)):
    # Per sample: (step, node, hidden) -> (step, hidden, node).
    hidden_states = all_hiddens[:, item].transpose((0, 2, 1))
    graph_adj = feedback.features.inputs[3].data[item]
    edge_weights = feedback.features.inputs[2].data[item]
    upd_pi = feedback.features.hints[3].data[:, item, :]
    upd_d = feedback.features.hints[4].data[:, item, :]
    gt_pi = feedback.outputs[0].data[item]
    # Copy so that each record owns its arrays instead of viewing the batch.
    datapoint = {
        'hidden_states': np.copy(hidden_states),
        'graph_adj': np.copy(graph_adj),
        'edge_weights': np.copy(edge_weights),
        'upd_pi': np.copy(upd_pi),
        'upd_d': np.copy(upd_d),
        'gt_pi': np.copy(gt_pi),
    }
    data.append(datapoint)


100%|███████████████████████████████████████| 200/200 [00:00<00:00, 1550.80it/s]


In [4]:
# Peek at the `upd_pi` record of the first sample. Rows follow the leading axis
# of hints[3] (presumably the algorithm's steps).
first_datapoint = data[0]
first_datapoint["upd_pi"]

array([[0., 1., 2., 3., 4., 5., 6.],
       [0., 1., 2., 3., 4., 5., 3.],
       [0., 1., 2., 3., 4., 6., 6.],
       [0., 1., 2., 3., 4., 5., 6.],
       [0., 1., 2., 3., 4., 5., 6.],
       [0., 1., 2., 3., 4., 5., 6.],
       [0., 1., 2., 3., 4., 5., 6.]])

In [5]:
# Write the records, then read them back through a DataLoader as a sanity check.
EVAL_H5_PATH = f'data/interp_data_{LENGTH}_eval.h5'
os.makedirs(os.path.dirname(EVAL_H5_PATH), exist_ok=True)
save_to_hdf5(data, EVAL_H5_PATH)

dataset = HDF5Dataset(EVAL_H5_PATH)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# try/finally ensures the HDF5 handle is closed even if a batch fails to load.
try:
    for batch in dataloader:
        # Each batch is a dict of tensors with a leading batch axis.
        print({key: value.shape for key, value in batch.items()})
finally:
    dataset.close()

{'edge_weights': torch.Size([4, 7, 7]), 'graph_adj': torch.Size([4, 7, 7]), 'gt_pi': torch.Size([4, 7]), 'hidden_states': torch.Size([4, 7, 128, 7]), 'upd_d': torch.Size([4, 7, 7]), 'upd_pi': torch.Size([4, 7, 7])}
{'edge_weights': torch.Size([4, 7, 7]), 'graph_adj': torch.Size([4, 7, 7]), 'gt_pi': torch.Size([4, 7]), 'hidden_states': torch.Size([4, 7, 128, 7]), 'upd_d': torch.Size([4, 7, 7]), 'upd_pi': torch.Size([4, 7, 7])}
{'edge_weights': torch.Size([4, 7, 7]), 'graph_adj': torch.Size([4, 7, 7]), 'gt_pi': torch.Size([4, 7]), 'hidden_states': torch.Size([4, 7, 128, 7]), 'upd_d': torch.Size([4, 7, 7]), 'upd_pi': torch.Size([4, 7, 7])}
{'edge_weights': torch.Size([4, 7, 7]), 'graph_adj': torch.Size([4, 7, 7]), 'gt_pi': torch.Size([4, 7]), 'hidden_states': torch.Size([4, 7, 128, 7]), 'upd_d': torch.Size([4, 7, 7]), 'upd_pi': torch.Size([4, 7, 7])}
{'edge_weights': torch.Size([4, 7, 7]), 'graph_adj': torch.Size([4, 7, 7]), 'gt_pi': torch.Size([4, 7]), 'hidden_states': torch.Size([4, 7, 