# Setup

The code for the LLM Agent data generation can be found [here](https://codeshare.io/Q8qyBL)

## Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import os
import json
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import random_split

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import kneighbors_graph
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from scipy.spatial.distance import pdist, squareform

! pip install torch_geometric
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv, BatchNorm, LayerNorm, DenseGCNConv, global_mean_pool, dense_diff_pool
from torch_geometric.utils import to_dense_adj, to_dense_batch

In [None]:
# Mount Google Drive (Colab only); the data paths used below live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

## Setting Random Seeds for Reproducibility

In [None]:
# Seed NumPy, PyTorch and Python's `random` module with the same fixed value.
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Reading in Data

## Data Downloaded Directly from ARC

In [None]:
# Project folder on the mounted Drive; the ARC task JSON files are read from train_path.
root = "/content/drive/MyDrive/CS224W_Project/"
train_path = root + "ARC-800-tasks/training/"
val_path = root + "ARC-800-tasks/evaluation/" # not used for now

In [None]:
# Per-task containers; the loading loop appends one entry per task file.
# (Expected length is 400 — TODO confirm against the training folder.)

# inputs_all[t]: list of the training-example input grids of task t
# (usually 3-5 examples per task)
inputs_all = []

# outputs_all[t]: ground-truth output grids matching inputs_all[t] (the puzzle solutions)
outputs_all = []

# X_all[t]: the test input the LLM has to solve; it FOLLOWS THE SAME PATTERN
# as the examples in inputs_all[t]
X_all = []

# y_all[t]: ground-truth answer for X_all[t]; the LLM never sees it
y_all = []

In [None]:
# Read every task file and split it into demonstration pairs and the first test pair.
for task_file in os.listdir(train_path):
  with open(train_path + task_file, "r") as file:
    data = json.load(file)
  demos = data["train"]
  # Only the first test pair of each task is used.
  held_out = data["test"][0]
  inputs_all.append([pair["input"] for pair in demos])
  outputs_all.append([pair["output"] for pair in demos])
  X_all.append(held_out["input"])
  y_all.append(held_out["output"])

In [None]:
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
no_ticks = dict(bottom=False, left=False, labelbottom=False, labelleft=False)

# Columns 0-2: demonstration inputs (top row) and outputs (bottom row) of task 1.
for row, (examples, label) in enumerate([(inputs_all[1], "Input"), (outputs_all[1], "Output")]):
  for i, ax in enumerate(axes[row][:-1]):
    ax.imshow(examples[i], cmap="Greys")
    ax.tick_params(**no_ticks)
    ax.set_title(f"{label} {i+1}")

# Last column, top: the held-out test input of the same task.
axes[0,3].imshow(X_all[1], cmap="Greys")
axes[0,3].tick_params(**no_ticks)
axes[0,3].set_title("X")

# Last column, bottom: a hand-specified 10 x 10 pattern drawn in red.
grid = np.zeros((10, 10))
pattern = [(1, 3), (2, 2), (1, 4), (1, 5), (2, 6), (3, 6), (4, 5), (5, 4), (6, 4), (8, 4)]
pattern_rows, pattern_cols = zip(*pattern)
grid[list(pattern_rows), list(pattern_cols)] = 1

axes[1,3].imshow(grid, cmap="Reds")
axes[1,3].tick_params(**no_ticks)
axes[1,3].set_title("Y")
plt.tight_layout()
# plt.savefig("input_output_example.png", dpi=300)
plt.show()

In [None]:
# Show the first training input of the first loaded task on its own.
# (A stray `fig,` expression that only built and discarded a tuple was removed.)
plt.imshow(inputs_all[0][0], cmap="Greys")
plt.axis("off")
plt.show()


## Data Outputs from LLM Agents

Note: Code for LLM Agent data generation is in another notebook.



In [None]:
# List the agent-idea files and order them by the number found between the
# first and the last underscore of each file name.
# NOTE(review): assumes every name looks like <prefix>_<number>_<suffix>; a name
# with fewer than two underscores makes the sort key raise ValueError — confirm
# against the folder contents.
all_graphs = []
all_files = os.listdir(root + "AgentIdeas/")
all_files = sorted(all_files, key=lambda x: float(x[x.index("_")+1:x.rindex("_")]))

In [None]:
num_iterations = 6 # From Krish's data generation
root = "/content/drive/MyDrive/CS224W_Project/"
for agent_file in all_files:
  with open(root + "AgentIdeas/" + agent_file, "r") as file:
    data = json.load(file)
  records = data["embeddings_data"] # number of agents * num_iterations
  temp = []
  for j, record in enumerate(records):
    meta = record["metadata"]
    # Log which agent/idea/problem each embedding belongs to.
    print(meta["agent_persona"], "idea", j % num_iterations, "for problem",
          meta["problem_idx"])
    temp.append(record["embedding"])
  # One (ideas x embedding_dim) matrix per file.
  all_graphs.append(np.array(temp))
  print()

In [None]:
# Stack the per-file embedding matrices into one 3-D array and read its sizes.
data_array = np.array(all_graphs)
num_graphs, num_nodes, embedding_dim = data_array.shape

# Dimensions: [number of ARC problems, number of ideas (nodes), node feature dimension]
data_array.shape

# Graph Machine Learning

## Graph Construction

In [None]:
def create_kmeans_edges(node_features, n_clusters=5):
  """Connect every ordered pair of distinct nodes that K-means places in the same cluster.

  Returns a list of (source, target) tuples; each pair appears in both directions.
  """
  cluster_ids = KMeans(n_clusters=n_clusters).fit_predict(node_features)
  edges = []
  for src in range(len(node_features)):
    members = np.where(cluster_ids == cluster_ids[src])[0]
    for dst in members:
      if src != dst:
        edges.append((src, dst))
  return edges

def create_similarity_edges(node_features, threshold=0.7):
  """Return (i, j) pairs with i < j whose cosine similarity exceeds `threshold`."""
  sims = cosine_similarity(node_features)
  # Only the strict upper triangle is inspected, so each pair is emitted once.
  rows, cols = np.triu_indices_from(sims, k=1)
  keep = sims[rows, cols] > threshold
  return list(zip(rows[keep], cols[keep]))

def create_diffusion_edges(node_features, temp=0.1, random_prob=0.1):
  """Build edges from an exponential distance kernel plus a random sample of pairs.

  A pair (i, j), i < j, is kept when exp(-distance / temp) exceeds the mean of
  the whole kernel matrix. In addition, int(random_prob * number_of_pairs)
  pairs are drawn without replacement via np.random and added.
  """
  affinity = np.exp(-squareform(pdist(node_features)) / temp)
  cutoff = np.mean(affinity)
  rows, cols = np.triu_indices_from(affinity, k=1)
  strong = affinity[rows, cols] > cutoff
  edges = set(zip(rows[strong], cols[strong]))

  # Random extra pairs, sampled from all i < j combinations.
  n_nodes = len(node_features)
  n_random = int(random_prob * n_nodes * (n_nodes - 1) / 2)
  all_possible = np.array(list(zip(rows, cols)))
  picked = np.random.choice(all_possible.shape[0], size=n_random, replace=False)
  edges.update(map(tuple, all_possible[picked]))
  return list(edges)

def combine_edges(node_features, k=3, sim_threshold=0.82, temp=0.1, random_prob=0.1):
  """Return the de-duplicated union of the K-means, similarity and diffusion edges."""
  # The three builders run in this fixed order (two of them consume np.random).
  kmeans_edges = create_kmeans_edges(node_features, k)
  similarity_edges = create_similarity_edges(node_features, sim_threshold)
  diffusion_edges = create_diffusion_edges(node_features, temp, random_prob)
  return list(set().union(kmeans_edges, similarity_edges, diffusion_edges))

In [None]:
def create_graph_data_from_numpy(numpy_array, labels):
  """Turn a [graphs, nodes, features] array into a list of PyG `Data` objects.

  Edges come from `combine_edges`; `labels[i]` becomes the float target `y`
  of graph i.
  """
  num_graphs, _, _ = numpy_array.shape

  def build(idx):
    features = numpy_array[idx]
    pairs = combine_edges(features)
    return Data(
        x=torch.tensor(features, dtype=torch.float),
        # List of (src, dst) pairs -> [2, num_edges] layout.
        edge_index=torch.tensor(pairs, dtype=torch.long).t().contiguous(),
        y=torch.tensor(labels[idx], dtype=torch.float),
    )

  return [build(idx) for idx in range(num_graphs)]

In [None]:
# Build the dataset with the combined edge construction.
# The two cells below redefine create_graph_data_from_numpy and rebuild
# graph_dataset, so running the notebook top to bottom keeps only the last variant.
# NOTE(review): y_all is filled in os.listdir order while data_array follows the
# numerically sorted AgentIdeas files — confirm that entry i of both refers to the
# same ARC problem.
graph_dataset = create_graph_data_from_numpy(data_array, y_all)

In [None]:
# We can either construct the edges using a threshold cosine similarity

def create_graph_data_from_numpy(numpy_array, labels, threshold=0.5):
    """Build one PyG `Data` graph per slice of `numpy_array`.

    Two distinct nodes are connected (in both directions) when the cosine
    similarity of their feature rows is above `threshold`; `labels[i]` is
    stored as the float target `y` of graph i.
    """
    num_graphs, _, _ = numpy_array.shape
    dataset = []
    for idx in range(num_graphs):
        features = numpy_array[idx]
        # Boolean adjacency: similar enough, and never a self-loop.
        adjacency = cosine_similarity(features) > threshold
        np.fill_diagonal(adjacency, False)
        src, dst = np.nonzero(adjacency)
        graph_data = Data(
            x=torch.tensor(features, dtype=torch.float),
            edge_index=torch.tensor(np.array([src, dst]), dtype=torch.long),
        )
        graph_data.y = torch.tensor(labels[idx], dtype=torch.float)
        dataset.append(graph_data)
    return dataset

graph_dataset = create_graph_data_from_numpy(data_array, y_all, threshold=0.82)

In [None]:
# Or we can construct the edges using k-nearest neighbors

def create_graph_data_from_numpy(numpy_array, labels, k=3):
  """Build one PyG `Data` graph per slice of `numpy_array` from a k-NN graph.

  Each node gets an edge to each of its `k` nearest neighbours (self excluded);
  `labels[i]` is stored as the float target `y` of graph i.
  """
  num_graphs, _, _ = numpy_array.shape
  dataset = []
  for idx in range(num_graphs):
    features = numpy_array[idx]
    knn = kneighbors_graph(features, n_neighbors=k, mode="connectivity", include_self=False)
    # Non-zero entries of the sparse connectivity matrix are the edges.
    src, dst = knn.nonzero()
    graph_data = Data(
        x=torch.tensor(features, dtype=torch.float),
        edge_index=torch.tensor(np.array([src, dst]), dtype=torch.long),
    )
    graph_data.y = torch.tensor(labels[idx], dtype=torch.float)
    dataset.append(graph_data)
  return dataset

graph_dataset = create_graph_data_from_numpy(data_array, y_all, k=2)

In [None]:
# Visualize one graph
idx = np.random.randint(len(graph_dataset))
graph_data = graph_dataset[idx]
num_nodes = graph_data.x.shape[0]

# Mirror the PyG graph in networkx: one node per row of x, one edge per column of edge_index.
G = nx.Graph()
G.add_nodes_from((n, {"feature": graph_data.x[n].numpy()}) for n in range(num_nodes))

edge_index = graph_data.edge_index.numpy()
G.add_edges_from(zip(edge_index[0], edge_index[1]))

pos = nx.spring_layout(G, seed=0)
nx.draw(G, pos, with_labels=True)
# plt.savefig("example_graph", dpi=200)
plt.show()

In [None]:
def pad_example(matrix, desired_shape=(10,10)):
  """Centre `matrix` inside a zero-filled float array of `desired_shape`.

  When the spare rows/columns cannot be split evenly, the extra one goes to
  the bottom/right side.
  """
  rows, cols = matrix.shape[0], matrix.shape[1]
  target_rows, target_cols = desired_shape
  row_start = (target_rows - rows) // 2
  col_start = (target_cols - cols) // 2

  padded = np.zeros(desired_shape)
  padded[row_start:row_start + rows, col_start:col_start + cols] = matrix
  return padded

# LLM output embeddings are 10 x 10, but some ARC answers are larger in size
# so for simplicity's sake we just select the ones that are equal or smaller
# in size and pad them if needed
graph_dataset = [i for i in graph_dataset if i.y.shape[0] <= 10 and i.y.shape[1] <= 10]

# Normalize labels because LLM embeddings are normalized.
# Each label is padded to 10 x 10, flattened to [1, 100] and divided by its own maximum.
for i in graph_dataset:
  scale = torch.max(i.y)
  # An all-zero answer grid has maximum 0; dividing by it would turn the whole
  # label into NaN and poison the loss, so such grids are left as zeros.
  if scale == 0:
    scale = torch.ones_like(scale)
  i.y = torch.tensor(pad_example(i.y.numpy()), dtype=torch.float32).view(1,-1) / scale

In [None]:
# Random 80/20 train/validation split of the graphs, served in batches of 32.
# DataLoader here is the torch_geometric one imported at the top of the notebook.
train_ratio = 0.8
train_size = int(train_ratio * len(graph_dataset))
val_size = len(graph_dataset) - train_size
train_dataset, val_dataset = random_split(graph_dataset, [train_size, val_size])
# Only the training loader reshuffles between epochs.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

## Graph Convolutional Network and Graph Attention Network

In [None]:
class ResGCNBlock(nn.Module):
  """GCN layer followed by batch norm, wrapped in a residual connection and a ReLU."""

  def __init__(self, in_c, out_c):
    super().__init__()
    self.conv = GCNConv(in_c, out_c)
    self.norm = BatchNorm(out_c)
    # The skip path needs a linear projection only when the width changes.
    if in_c == out_c:
      self.shortcut = nn.Identity()
    else:
      self.shortcut = nn.Linear(in_c, out_c)

  def forward(self, x, edge_index):
    """Return relu(norm(conv(x)) + shortcut(x))."""
    residual = self.shortcut(x)
    h = self.norm(self.conv(x, edge_index))
    return F.relu(h + residual)

class GraphConvolutionalModel(nn.Module):
  """Stack of ResGCNBlocks, mean-pooled per graph and mapped to `output_dim` by an MLP."""

  def __init__(self, num_features, hidden_dim, output_dim, num_layers=3):
    super().__init__()

    # ResGCNBlock modules: one input block, (num_layers - 2) middle blocks and
    # one final block, i.e. never fewer than two blocks in total.
    input_widths = [num_features] + [hidden_dim] * (max(num_layers - 2, 0) + 1)
    self.layers = nn.ModuleList(
        [ResGCNBlock(width, hidden_dim) for width in input_widths]
    )

    # Fully connected tail
    self.fc = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, output_dim)
    )

  def forward(self, data):
    """Return one `output_dim` vector per graph in `data`."""
    h = data.x
    edge_index = data.edge_index
    batch = data.batch
    for block in self.layers:
      h = block(h, edge_index)
    # Average node states within each graph before the MLP head.
    pooled = global_mean_pool(h, batch)
    return self.fc(pooled)

# Input width and target width are both the last dimension of data_array.
num_node_features = data_array.shape[2] # 100
hidden_dim = 64
embedding_dim = data_array.shape[2] # 100

# Build the GCN model and print its module tree.
model = GraphConvolutionalModel(num_node_features, hidden_dim, embedding_dim)
print(model)

In [None]:
class MultiHeadGAT(nn.Module):
  """Two GAT layers, per-graph mean pooling and an MLP head."""

  def __init__(self, num_features, hidden_dim, output_dim, num_heads=4):
    super().__init__()
    # The second layer takes hidden_dim * num_heads inputs, matching the
    # multi-head output width of the first layer.
    self.gat1 = GATConv(num_features, hidden_dim, heads=num_heads)
    self.gat2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=1)
    self.fc = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, output_dim)
    )

  def forward(self, data):
    """Return one `output_dim` vector per graph in `data`."""
    edge_index = data.edge_index
    h = self.gat1(data.x, edge_index)
    h = F.relu(h)
    h = self.gat2(h, edge_index)
    h = F.relu(h)
    pooled = global_mean_pool(h, data.batch)
    return self.fc(pooled)

# Input width and target width are both the last dimension of data_array.
num_features = data_array.shape[2]
hidden_dim = 64
output_dim = data_array.shape[2]

# Rebinds `model`: when cells run in order, this GAT replaces the GCN built above.
model = MultiHeadGAT(num_features, hidden_dim, output_dim)

In [None]:
# Adam over the parameters of whichever `model` was built last.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_epoch(loader, model, optimizer):
  """Run one optimisation pass over `loader` and return the mean per-batch MSE.

  The model is put in training mode; each batch must expose its target as `.y`.
  """
  model.train()
  batch_losses = []
  for batch in loader:
    optimizer.zero_grad()
    loss = F.mse_loss(model(batch), batch.y)
    loss.backward()
    optimizer.step()
    batch_losses.append(loss.item())
  return sum(batch_losses) / len(loader)

def validate(loader, model):
  """Return the mean per-batch MSE of `model` over `loader` without updating it.

  The model is put in evaluation mode and gradients are disabled.
  """
  model.eval()
  batch_losses = []
  with torch.no_grad():
    for batch in loader:
      batch_losses.append(F.mse_loss(model(batch), batch.y).item())
  return sum(batch_losses) / len(loader)

# Train for a fixed number of epochs, printing the mean per-batch training and
# validation loss after each one.
num_epochs = 400
for epoch in range(1, num_epochs + 1):
  train_loss = train_epoch(train_loader, model, optimizer)
  val_loss = validate(val_loader, model)
  print(f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [None]:
# Collect all predictions (and matching labels) over the training loader.
# This is inference only, so the model is switched to eval mode explicitly and
# autograd is disabled instead of relying on whatever state the previous cell left.
# train_loader shuffles, so the row order changes between runs of this cell.

model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
  for i in train_loader:
    preds = model(i).cpu().numpy()
    # Convert labels to NumPy as well so both lists concatenate the same way.
    labels = i.y.cpu().numpy()
    all_preds.append(preds)
    all_labels.append(labels)
all_preds = np.concatenate(all_preds, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

In [None]:
# Visualizing Performance
def apply_fourier(img, radius=30):
  """Low-pass filter a 100-element image in the frequency domain.

  Args:
    img: array-like with 100 values; it is reshaped to 10 x 10.
    radius: half-width of the square window of centred frequencies to keep.
      With the default of 30 the window covers the whole 10 x 10 spectrum, so
      the result equals the magnitude of the input.

  Returns:
    10 x 10 float array: magnitude of the inverse transform of the masked spectrum.
  """
  grid = np.asarray(img).reshape(10, 10)
  fshift = np.fft.fftshift(np.fft.fft2(grid))

  rows, cols = grid.shape
  crow, ccol = rows // 2, cols // 2
  mask = np.zeros((rows, cols))
  # Clamp the window start at 0: a small negative start would be read as an
  # offset from the end of the axis and select the wrong rows/columns.
  mask[max(crow - radius, 0):crow + radius, max(ccol - radius, 0):ccol + radius] = 1

  # The unused log-magnitude spectrum was dropped; it only produced
  # divide-by-zero warnings for zero coefficients.
  filtered = np.fft.ifftshift(fshift * mask)
  return np.abs(np.fft.ifft2(filtered))

# Show predictions (top) above their ground truths (bottom), five per row.
# Rows are taken from all_preds / all_labels: the loop variables `preds` and
# `labels` only hold the last batch, which may contain fewer than 10 graphs.
num_examples = min(10, len(all_preds))
columns = 5
rows = (num_examples + columns - 1) // columns
fig, axes = plt.subplots(2 * rows, columns, figsize=(columns * 2, rows * 3))
for i in range(num_examples):
  row, col = divmod(i, columns)
  axes[2*row, col].imshow(apply_fourier(all_preds[i]), cmap="gray")
  axes[2*row, col].set_title(f"Predicted #{i+1}")
  axes[2*row, col].axis("off")
  axes[2*row+1, col].imshow(all_labels[i].reshape(10, 10), cmap="gray")
  axes[2*row+1, col].set_title(f"True #{i+1}")
  axes[2*row+1, col].axis("off")

# Hide the frames of any panels left unused.
for ax in axes.flat[num_examples:]:
  ax.axis("off")

plt.subplots_adjust(wspace=0.0, hspace=0.3)
# plt.savefig("test.png", dpi=300)
plt.show()

## Use PyTorch Hooks to Analyze GCN Activations

In [None]:
class GCNVisualizer(GraphConvolutionalModel):
  """GraphConvolutionalModel that records the output of every block on each forward pass.

  A forward hook on each entry of `self.layers` appends that block's output to
  `self.activations`, in layer order. Call `reset_activations()` before a
  forward pass to start from an empty list.
  """
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.activations = []
    # The parent class stores its blocks in `self.layers`; it has no `convs` attribute.
    for layer in self.layers:
      layer.register_forward_hook(self.save_activations)
  def save_activations(self, module, input, output):
    """Forward-hook callback: keep the block output."""
    self.activations.append(output)
  def reset_activations(self):
    """Forget everything recorded so far."""
    self.activations = []

# Keyword names must match GraphConvolutionalModel.__init__
# (num_features, hidden_dim, output_dim, num_layers).
# Note: this builds a newly initialised network; it does not reuse the weights of `model`.
visualizer = GCNVisualizer(num_features=100, hidden_dim=64, output_dim=100, num_layers=3)

### Check that GCNVisualizer works on one example, then plot its node activations averaged over the hidden dimension.

In [None]:
# Clear previously recorded activations, then run a single graph through the
# visualizer so that its activation list is filled for this example.
idx = 1
visualizer.reset_activations()
visualizer(graph_dataset[idx])

In [None]:
# Bar plot of the first recorded activation, averaged over axis 1 for every node.
first_layer = visualizer.activations[0].detach().numpy()
mean_activations = np.mean(first_layer, axis=1)
data = pd.DataFrame({
    "Node Index": range(len(mean_activations)),
    "Mean Activations": mean_activations,
})
# Long format so seaborn can colour by metric name.
data_melted = data.melt(id_vars=["Node Index"], value_vars=["Mean Activations"],
                        var_name="Metric", value_name="Value")
plt.figure(figsize=(10, 5))
sns.barplot(data=data_melted, x="Node Index", y="Value", hue="Metric", palette="muted")
plt.title("Mean of Node Activations")
plt.xlabel("Node Index")
plt.ylabel("Activation Value")
# plt.savefig("mean_activation_per_node.png", dpi=200)
plt.tight_layout()
plt.show()

In [None]:
# Box plot of the first recorded activation: one horizontal box per node.
# Rows of `data` are nodes; transposing makes each node a column, which is what
# seaborn draws one box for. (An unused melted copy of the frame was removed.)
data = pd.DataFrame(visualizer.activations[0].detach().numpy())
plt.figure(figsize=(10, 6))
sns.boxplot(data=data.T, orient="h", palette="muted", showmeans=True)

plt.title("Activation Distributions per Node")
plt.xlabel("Activation Value")
plt.ylabel("Node Index")
# plt.savefig("activation_per_node.png", dpi=200)
plt.show()

In [None]:
def plot_node_activations(activations, labels=None):
  """Project `activations` to 2-D with t-SNE and scatter-plot them.

  Points are coloured by `labels` when given, and a colour bar is added in that case.
  """
  embedded = TSNE(n_components=2, perplexity=29).fit_transform(activations)
  xs, ys = embedded[:, 0], embedded[:, 1]
  plt.figure(figsize=(8, 6))
  points = plt.scatter(xs, ys, c=labels, cmap='Spectral', alpha=0.7)
  if labels is not None:
    plt.colorbar(points)
  plt.title("Node Activations")
  plt.show()

# Colour nodes in five consecutive groups of six.
node_labels = [group for group in range(5) for _ in range(6)]
third_layer = visualizer.activations[2].detach().cpu().numpy()
plot_node_activations(third_layer, labels=node_labels)

# assign clusters as labels
kmeans = KMeans(n_clusters=5)
clusters = kmeans.fit_predict(third_layer)
plot_node_activations(third_layer, labels=clusters)

## Reasoning Framework
We can interpret the GCN activation values to estimate node importance (i.e., LLM agent importance), similarity between ARC problems, and so on.

In [None]:
def calculate_mean_activations(dataset, model=None):
  """Collect per-node mean activations of every recorded layer for each graph.

  Args:
    dataset: indexable collection of graphs accepted by the model.
    model: object with `reset_activations()`, an `activations` list and a
      call interface; defaults to the module-level `visualizer`.

  Returns:
    Dict mapping the 1-based layer number (as a string) to a list with one
    array per graph, each holding the mean over axis 1 of that layer's output.
  """
  if model is None:
    model = visualizer
  # one entry for each layer of the model (three total)
  all_mean_activations = {"1": [], "2": [], "3": []}
  for i in range(len(dataset)):
    model.reset_activations()
    _ = model(dataset[i])
    # Use the activation of layer j itself; reading index 0 for every j stored
    # the first layer's values under all three keys.
    for j, layer_output in enumerate(model.activations):
      mean_activations = np.mean(layer_output.detach().numpy(), axis=1)
      all_mean_activations.setdefault(str(j+1), []).append(mean_activations)
  return all_mean_activations

In [None]:
# Per-graph mean activations for the training and validation splits.
train_mean_activations = calculate_mean_activations(train_dataset)
val_mean_activations = calculate_mean_activations(val_dataset)

In [None]:
def plot_activations(activations, normalized=True, save=False, title=None):
  """Draw a clustered heat map of the entries stored under key "3".

  With `normalized=True` the map shows, per row, which entries equal the row
  maximum; otherwise it shows the raw values. When `save` is true the figure
  is written to the path given in `title`.
  """
  sns.set(font_scale=0.6)
  third_layer = np.array(activations["3"])
  if normalized:
    is_row_max = third_layer == np.max(third_layer, axis=1, keepdims=True)
    fig = sns.clustermap(is_row_max, cmap="Blues", figsize=(20,20))
  else:
    fig = sns.clustermap(third_layer)
  heatmap_ax = fig.ax_heatmap
  heatmap_ax.tick_params(axis='x', labelsize=10, width=0.5)  # Smaller x-axis labels
  heatmap_ax.set_xlabel("Node")
  heatmap_ax.set_ylabel("ARC Problem Number")
  if save:
    fig.fig.savefig(title, dpi=300, bbox_inches="tight")
  plt.show()

In [None]:
# Training split: per-row maximum indicator map, saved as a PNG.
plot_activations(train_mean_activations, normalized=True, save=True, title="train_max_activations.png")

In [None]:
# Validation split: per-row maximum indicator map, saved as a PNG.
plot_activations(val_mean_activations, normalized=True, save=True, title="val_max_activations.png")

In [None]:
# Training split: raw mean activations, saved as a PNG.
plot_activations(train_mean_activations, normalized=False, save=True, title="train_activations.png")

In [None]:
# Validation split: raw mean activations, saved as a PNG.
plot_activations(val_mean_activations, normalized=False, save=True, title="val_activations.png")

## DiffPool

In [None]:
# DiffPool needs a different data set structure and data loader
# Adjacency matrices need to be dense and node features need to have an extra dimension

class GraphDataset(torch.utils.data.Dataset):
  """Serve each stored graph as a (node features, dense adjacency, label) triple."""

  def __init__(self, graph_dataset):
    self.graphs_and_labels = graph_dataset

  def __len__(self):
    return len(self.graphs_and_labels)

  def __getitem__(self, idx):
    item = self.graphs_and_labels[idx]
    # Convert to dense format: build the adjacency for up to 30 nodes and drop
    # the leading axis that to_dense_adj adds.
    dense_adj = to_dense_adj(item.edge_index, batch=None, max_num_nodes=30)
    return item.x, dense_adj.squeeze(0), item.y

def collate(batch):
    """Stack (x, adj, label) samples into three batched tensors.

    Args:
        batch: sequence of (node features, adjacency matrix, label) triples
            whose corresponding tensors share a shape.

    Returns:
        (x_batch, adj_batch, labels), each with a new leading batch axis.
    """
    x_list, adj_list, label_list = zip(*batch)
    x_batch = torch.stack(x_list)  # Shape:[batch_size, num_nodes, feature_dim]
    adj_batch = torch.stack(adj_list)  # Shape: [batch_size, num_nodes, num_nodes]
    # Labels are multi-element tensors, so they must be stacked as well:
    # torch.tensor() cannot convert a tuple of such tensors, and casting to long
    # would truncate fractional label values.
    labels = torch.stack([torch.as_tensor(label) for label in label_list])
    return x_batch, adj_batch, labels

# Wrap the graphs for DiffPool and batch them 32 at a time.
# NOTE(review): DataLoader is the torch_geometric one imported at the top of the
# notebook; it may replace a user-supplied collate_fn with its own collater —
# confirm, and use torch.utils.data.DataLoader if `collate` is meant to run.
diffpool_dataset = GraphDataset(graph_dataset)
diffpool_loader = DataLoader(diffpool_dataset, batch_size=32, shuffle=True, collate_fn=collate)

In [None]:
class DiffPoolGCN(nn.Module):
  """Hierarchical graph encoder made of stacked DiffPool stages.

  Each stage embeds the nodes with a dense GCN layer, predicts a soft cluster
  assignment with a second dense GCN layer and coarsens the graph with
  `dense_diff_pool`. The mean node embedding after every stage is concatenated
  and mapped to `embedding_dim` by an MLP.

  Args:
    num_node_features: width of the input node features.
    hidden_dim: width of the node embeddings in every stage.
    embedding_dim: width of the returned graph embedding.
    num_clusters_list: number of clusters per stage; its length sets the number of stages.
  """
  def __init__(self, num_node_features, hidden_dim, embedding_dim, num_clusters_list):
    super(DiffPoolGCN, self).__init__()
    self.num_layers = len(num_clusters_list)
    self.hidden_dim = hidden_dim

    # GCNConv layers for message passing
    self.gcns_embed = nn.ModuleList([DenseGCNConv(num_node_features if i == 0 else hidden_dim, hidden_dim)
                                     for i in range(self.num_layers)])

    # GCNConv layers for cluster assignment
    self.gcns_assign = nn.ModuleList([DenseGCNConv(hidden_dim, num_clusters_list[i]) for i in range(self.num_layers)])
    # One pooled hidden_dim vector per stage feeds the MLP head.
    self.fc = nn.Sequential(
        nn.Linear(hidden_dim * self.num_layers, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, embedding_dim)
    )

  def forward(self, x_batch, adj_batch):
    """Encode a batch of dense graphs.

    Returns:
      (graph_embedding, total_loss, s_matrices): the graph embeddings, the sum
      of the entropy regularisers of all stages, and the soft assignment
      matrix (softmax over clusters) of every stage.
    """
    all_embeds = []
    s_matrices = []
    total_loss = 0
    for i in range(self.num_layers):
      x_batch = self.gcns_embed[i](x_batch, adj_batch).relu()

      # Cluster assignment
      assign_logits = self.gcns_assign[i](x_batch, adj_batch)
      # Keep the normalised assignment for inspection only.
      s_matrices.append(torch.softmax(assign_logits, dim=-1))

      # dense_diff_pool applies the softmax itself, so it receives the raw
      # logits; passing probabilities would apply softmax twice and flatten
      # the assignment. The link-prediction loss (third return) is not used.
      x_batch, adj_batch, _, entropy_reg = dense_diff_pool(x_batch, adj_batch, assign_logits)
      total_loss += entropy_reg

      # Graph-level representation
      all_embeds.append(torch.mean(x_batch, dim=1))
    graph_embedding = torch.cat(all_embeds, dim=-1)
    graph_embedding = self.fc(graph_embedding)
    return graph_embedding, total_loss, s_matrices

# Input width and target width are both the last dimension of data_array.
num_node_features = data_array.shape[2]
hidden_dim = 64
embedding_dim = data_array.shape[2]
# Two pooling stages: 15 clusters first, then 5.
num_clusters_list = [15, 5]

# Rebinds `model` to the DiffPool network.
model = DiffPoolGCN(num_node_features=num_node_features, hidden_dim=hidden_dim, embedding_dim=embedding_dim, num_clusters_list=num_clusters_list)

In [None]:
# Random 80/20 split of the dense dataset; this rebinds train_dataset and
# val_dataset, replacing the splits created for the GCN/GAT models.
train_ratio = 0.8
train_size = int(train_ratio * len(diffpool_dataset))
val_size = len(diffpool_dataset) - train_size
train_dataset, val_dataset = random_split(diffpool_dataset, [train_size, val_size])

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# The loops iterate the datasets directly, so every step processes a single
# graph; the reported losses are sums over graphs, not means.
for epoch in range(300):
  model.train()
  total_loss = 0
  for x_batch, adj_batch, labels in train_dataset:
    optimizer.zero_grad()
    graph_embeddings, diffpool_loss, _ = model(x_batch, adj_batch)
    objective_loss = F.mse_loss(graph_embeddings, labels.squeeze(1))
    # Regression objective plus a down-weighted DiffPool regulariser.
    loss = objective_loss + 0.1 * diffpool_loss
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

  model.eval()
  with torch.no_grad():
    val_loss = 0
    for x_batch, adj_batch, labels in val_dataset:
      graph_embeddings, diffpool_loss, _ = model(x_batch, adj_batch)
      objective_loss = F.mse_loss(graph_embeddings, labels.squeeze(1))
      loss = objective_loss + 0.1 * diffpool_loss
      # Accumulate the same combined loss as in training; summing only
      # diffpool_loss made the two printed numbers incomparable.
      val_loss += loss.item()

  print(f"Epoch {epoch + 1}, Train Loss: {total_loss:.4f}, Val Loss: {val_loss:.4f}")

In [None]:
# Take one batch from the loader and switch the model to evaluation mode.
x_batch, adj_batch, label_batch = next(iter(diffpool_loader))
model.eval()

# Forward pass to obtain the per-stage assignment matrices for inspection.
graph_embedding, total_loss, s_matrices = model(x_batch, adj_batch)

In [None]:
# Assignment matrices returned for the first pooling stage.
s_matrix = s_matrices[0]

# Heat map of the assignment of one graph in the batch (rows: nodes, columns: clusters).
# Index 5 requires the batch to hold at least six graphs.
graph_idx = 5
s_single = s_matrix[graph_idx].detach().cpu().numpy() # [num_nodes, num_clusters]
plt.figure(figsize=(10, 8))
sns.heatmap(s_single, annot=False)
plt.show()