In [None]:
from __future__ import annotations

import json
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv

# Print tensors in plain decimal notation and record the torch version for reproducibility.
torch.set_printoptions(sci_mode=False)
print(f'torch: {torch.__version__}')


In [None]:
# Input locations, resolved relative to the notebook's working directory
# (both point into the `Test-2/` project folder).
PROCESSED_DIR = Path('../Test-2/data/processed/9f751fe859f07b5c')  # cached tensors + loaders
CKPT_PATH = Path('../Test-2/data/models/iter4_gat.pth')  # trained GAT weights

# The processed artefacts are mandatory; the checkpoint is only reported here because
# the model cell checks CKPT_PATH.exists() before loading it.
assert PROCESSED_DIR.exists(), f'Missing processed dir: {PROCESSED_DIR.resolve()}'
print(f'processed dir: {PROCESSED_DIR.resolve()}')
print(f'checkpoint exists: {CKPT_PATH.exists()} | {CKPT_PATH.resolve()}')

In [None]:
# Load the preprocessed graph tensors and the pickled train/val/test DataLoaders.
# NOTE: weights_only=False unpickles arbitrary Python objects (the loaders are not plain
# tensors, so weights_only=True would refuse them). Only load files you trust, because
# unpickling can execute code.
def _load_processed(name):
    """Load ``PROCESSED_DIR/<name>.pt`` with full unpickling enabled."""
    return torch.load(PROCESSED_DIR / f'{name}.pt', weights_only=False)

edge_index = _load_processed('edge_index')
edge_attr = _load_processed('edge_attr')
static_features = _load_processed('static_features')
sensor_mask = _load_processed('sensor_mask')
train_loader, val_loader, test_loader = (
    _load_processed(f'{split}_loader') for split in ('train', 'val', 'test')
)

# Sensor name -> integer id map; the inverse is used to label nodes in every report below.
name_to_id = json.loads((PROCESSED_DIR / 'sensor_name_to_id_map.json').read_text())
id_to_name = {int(sensor_id): sensor_name for sensor_name, sensor_id in name_to_id.items()}

for label, tensor in (
    ('edge_index:', edge_index),
    ('edge_attr :', edge_attr),
    ('static_features:', static_features),
):
    print(label, tuple(tensor.shape))
print('sensor_mask:', tuple(sensor_mask.shape), '| dtype:', sensor_mask.dtype)
print('train batches:', len(train_loader), 'val:', len(val_loader), 'test:', len(test_loader))


In [None]:
# Target de-normalisation constants copied from the training notebook. Normalised
# targets are mapped back with y_raw = y * SCALER_SIGMA + SCALER_MU before the ZINB
# loss, so forward() can be called exactly as during training when loss values are wanted.
SCALER_MU, SCALER_SIGMA = 14.323774337768555, 34.9963493347168


In [None]:
# Select NUM_DAYS random whole days from the test set for temporal analysis.
# Assumes consecutive test samples are 15 minutes apart, i.e. 96 samples per day.
import random

SAMPLES_PER_DAY = 96
NUM_DAYS = 3
BATCH_SIZE = 16  # batch size of the hybrid loaders built below

# Count samples exactly as the hybrid test loader will iterate them. prepare_hybrid_loader
# concatenates every batch of test_loader and re-batches with drop_last=True, so a
# trailing partial batch is discarded. `len(test_loader) * BATCH_SIZE` over-counts
# whenever the last batch is partial or test_loader uses a different batch size, which
# could select a day whose samples are never produced.
n_test_windows = sum(y_test.shape[0] for _, y_test in test_loader)
total_test_samples = (n_test_windows // BATCH_SIZE) * BATCH_SIZE
total_days = total_test_samples // SAMPLES_PER_DAY

print(f"Total test samples: {total_test_samples}")
print(f"Total days in test set: {total_days}")
print(f"Samples per day: {SAMPLES_PER_DAY}")

assert total_days > 0, (
    f'Test set too short for a single day: {total_test_samples} usable samples '
    f'< {SAMPLES_PER_DAY} samples per day'
)

# Randomly select days (seeded for reproducibility).
random.seed(42)
selected_days = sorted(random.sample(range(total_days), min(NUM_DAYS, total_days)))
selected_day_ranges = [(day * SAMPLES_PER_DAY, (day + 1) * SAMPLES_PER_DAY) for day in selected_days]

print(f"\nSelected days: {selected_days}")
print(f"Sample ranges: {selected_day_ranges}")

## Sample Selection for Temporal Analysis

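Select a reproducible random subset of whole days (`NUM_DAYS`, seeded with 42) from the test set for temporal attention analysis. A day is assumed to be 96 consecutive 15-minute samples.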


In [None]:
# Same data-preparation helper as in iter2_gat.ipynb, so model inputs are built the same way.
def prepare_hybrid_loader(loader, batch_size: int, drop_last: bool = True):
    """Re-batch ``loader`` into ``(X_spatial, X_temporal, y_target)`` triples for the GAT.

    Every batch of ``loader`` is split into the tensors the model expects. The pieces
    are concatenated along dim 0 and wrapped in a new, non-shuffled DataLoader.

    Args:
        loader: iterable yielding ``(X, y)`` 4-D tensor pairs. The last feature slot
            of X holds the raw series and the others hold aggregated statistics
            (per the original comments: 9 stats + 12 raw timesteps; dim 1 is
            presumably time — TODO confirm against the preprocessing notebook).
        batch_size: batch size of the returned loader.
        drop_last: drop a trailing partial batch (default True, the previous behavior).

    Returns:
        DataLoader over a TensorDataset with the following tensors:
          * X_spatial = cat(X[:, 0:1, :, :-1], X[:, :, :, -1:].permute(0, 3, 2, 1))
            along dim 3, i.e. the aggregated stats followed by the raw series.
          * X_temporal = X[:, :, :, -1:] (unused by this model, kept for structure).
          * y_target = y[:, 0:1, :, :].

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    spatial_parts, temporal_parts, target_parts = [], [], []
    # Single pass, so the original batches are not all held alongside their pieces.
    for X, y in loader:
        raw = X[:, :, :, -1:]          # raw series, one value per dim-1 index
        agg = X[:, 0:1, :, :-1]        # aggregated stats taken from dim-1 index 0
        spatial_parts.append(torch.cat([agg, raw.permute(0, 3, 2, 1)], dim=3))
        temporal_parts.append(raw)
        target_parts.append(y[:, 0:1, :, :])

    if not spatial_parts:
        raise ValueError('prepare_hybrid_loader: loader yielded no batches')

    dataset = torch.utils.data.TensorDataset(
        torch.cat(spatial_parts, dim=0),
        torch.cat(temporal_parts, dim=0),
        torch.cat(target_parts, dim=0),
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last
    )

# Re-batch every split with BATCH_SIZE, the value the sample/day bookkeeping assumes,
# rather than a separately hard-coded 16 that could drift out of sync.
train_h = prepare_hybrid_loader(train_loader, batch_size=BATCH_SIZE)
val_h = prepare_hybrid_loader(val_loader, batch_size=BATCH_SIZE)
test_h = prepare_hybrid_loader(test_loader, batch_size=BATCH_SIZE)

# Peek at one test batch to confirm the shapes fed to the model.
X_spatial, X_temporal, y = next(iter(test_h))
print('X_spatial:', tuple(X_spatial.shape), 'X_temporal:', tuple(X_temporal.shape), 'y:', tuple(y.shape))

In [None]:
# Model (state_dict-compatible with Test-2/iter2_gat.ipynb)
def _finite_stats(name, t: torch.Tensor | None) -> bool:
    if t is None:
        print(f'[DEBUG] {name}: None')
        return False
    is_finite = torch.isfinite(t)
    if not is_finite.all():
        n_nan = torch.isnan(t).sum().item()
        n_inf = torch.isinf(t).sum().item()
        print(f'[NON-FINITE] {name}: nan={n_nan}, inf={n_inf}, shape={tuple(t.shape)}')
        return False
    return True

def _check(name, t: torch.Tensor | None):
    _finite_stats(name, t)

class DynamicNodeGATZINB(nn.Module):
    """Two-layer GATv2 encoder with zero-inflated negative binomial (ZINB) output heads.

    Every sample in a batch is encoded independently on the same graph. Per-node dynamic
    features, a missing-value indicator for each of them, and static node features are
    concatenated and passed through two GATv2 layers. Each layer is followed by the shared
    LayerNorm, ELU and dropout. The result is then mapped to per-node ZINB parameters
    (mu, theta, pi).

    The graph itself (``edge_index``, ``edge_attr``, static features) is not stored on the
    module; it is passed to ``forward`` on every call. Submodule names must stay unchanged
    so that saved state_dicts keep loading.
    """

    def __init__(self, dynamic_node_dim, static_node_dim, edge_dim, n_embd, n_heads, dropout_rate):
        super().__init__()

        # forward() appends one missing-value indicator per dynamic feature, doubling its width.
        dynamic_input_dim = dynamic_node_dim * 2
        gat1_input_channels = dynamic_input_dim + static_node_dim

        # concat=False: attention heads are averaged, so both layers emit n_embd channels.
        self.gat1 = GATv2Conv(
            in_channels=gat1_input_channels,
            out_channels=n_embd,
            edge_dim=edge_dim,
            heads=n_heads,
            concat=False,
            dropout=dropout_rate,
        )
        self.gat2 = GATv2Conv(
            in_channels=n_embd,
            out_channels=n_embd,
            edge_dim=edge_dim,
            heads=n_heads,
            concat=False,
            dropout=dropout_rate,
        )
        # One LayerNorm instance is applied after BOTH GAT layers (shared weights).
        self.norm = nn.LayerNorm(n_embd)

        self.mu_head = nn.Linear(n_embd, 1)
        self.theta_head = nn.Linear(n_embd, 1)
        self.pi_head = nn.Linear(n_embd, 1)

        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout_rate)

        # Deliberately NO register_buffer calls: persistent buffers would add state_dict
        # keys (presumably absent from the existing checkpoints).

    def _apply_gat(self, layer, x, edge_index, edge_attr, attn, key, return_attention):
        """Run one GATv2 layer; if ``return_attention``, append ``(edge_index, alpha)`` to ``attn[key]``."""
        if return_attention:
            out, (ei, alpha) = layer(x, edge_index, edge_attr, return_attention_weights=True)
            attn.setdefault(key, []).append((ei, alpha))
            return out
        return layer(x, edge_index, edge_attr)

    def forward(self, X_batch, targets, node_mask, edge_index, edge_attr, static_node_features, return_attention=False):
        """Predict per-node ZINB parameters and, when ``targets`` is given, the losses.

        Args:
            X_batch: 4-D batch; only index 0 of dim 1 is used, giving per-sample
                ``[num_nodes, dynamic_node_dim]`` features. NaN marks missing values.
            targets: integer counts or None. They are indexed with the same boolean mask as
                ``mu``, so they must have ``mu``'s shape ``[B, 1, N, 1]``.
            node_mask: optional boolean mask ANDed with ``~isnan(targets)`` in every loss.
            edge_index, edge_attr: graph connectivity and edge features passed to both layers.
            static_node_features: ``[num_nodes, static_node_dim]`` features shared by all samples.
            return_attention: if True, collect per-sample attention for both layers.

        Returns:
            ``(preds, zinb_nll, mse, huber, extra)``.
              * ``preds = mu * (1 - pi)`` is the ZINB mean.
              * The three losses are None when ``targets`` is None.
              * ``extra`` holds mu/theta/pi/valid_sum.
              * ``extra['attn']`` is present when ``return_attention`` is True; it maps
                'layer1'/'layer2' to a list with one ``(edge_index, alpha)`` per sample.
        """
        mu_list, theta_list, pi_list = [], [], []
        attn = {}

        # Missing values become 0 and are flagged by an explicit indicator channel.
        missing_mask = torch.isnan(X_batch)
        imputed_X = torch.nan_to_num(X_batch, nan=0.0)
        combined_input = torch.cat([imputed_X, missing_mask.float()], dim=-1)

        Xb = combined_input[:, 0, :, :]

        for b in range(Xb.shape[0]):
            xb = self.dropout(torch.cat([Xb[b], static_node_features], dim=-1))

            # The graph comes from the arguments: the module stores no edge tensors.
            xb = self._apply_gat(self.gat1, xb, edge_index, edge_attr, attn, 'layer1', return_attention)
            xb = self.dropout(self.elu(self.norm(xb)))

            xb = self._apply_gat(self.gat2, xb, edge_index, edge_attr, attn, 'layer2', return_attention)
            xb = self.dropout(self.elu(self.norm(xb)))

            # Small offsets / clamping keep the ZINB log-terms finite.
            mu_b = F.softplus(self.mu_head(xb)) + 1e-6
            theta_b = F.softplus(self.theta_head(xb)) + 1e-6
            pi_b = torch.clamp(torch.sigmoid(self.pi_head(xb)), min=1e-6, max=1 - 1e-6)

            mu_list.append(mu_b)
            theta_list.append(theta_b)
            pi_list.append(pi_b)

        # [N, 1] per sample -> [B, 1, N, 1]
        mu = torch.stack(mu_list, dim=0).unsqueeze(1)
        theta = torch.stack(theta_list, dim=0).unsqueeze(1)
        pi = torch.stack(pi_list, dim=0).unsqueeze(1)

        preds = mu * (1 - pi)

        if targets is None:
            zinb_nll_loss = None
            mse_loss = None
            huber_loss = None
            valid_sum = torch.tensor(0.0, device=preds.device)
        else:
            zinb_nll_loss, valid_sum = self.zinb_nll_loss(mu, theta, pi, targets, node_mask)
            mse_loss, _ = self.mse_loss(preds, targets, node_mask)
            huber_loss, _ = self.huber_loss(preds, targets, node_mask)

        extra = {'mu': mu, 'theta': theta, 'pi': pi, 'valid_sum': valid_sum}
        if return_attention:
            extra['attn'] = attn

        return preds, zinb_nll_loss, mse_loss, huber_loss, extra

    def zinb_nll_loss(self, mu, theta, pi, targets, node_mask):
        """Mean ZINB negative log-likelihood over valid (non-NaN, masked-in) entries.

        ``mu`` is the NB mean, ``theta`` the NB size (inverse-dispersion) parameter and
        ``pi`` the zero-inflation probability. Returns ``(nll, n_valid)``; when no entry
        is valid it returns ``(0.0, 0.0)`` tensors.
        """
        eps = 1e-8
        nan_mask = ~torch.isnan(targets)
        valid_mask = nan_mask & node_mask if node_mask is not None else nan_mask
        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=mu.device, requires_grad=True), torch.tensor(0.0, device=mu.device)

        mu_valid = mu[valid_mask]
        theta_valid = theta[valid_mask]
        pi_valid = pi[valid_mask]
        # Integer count targets are cast to mu's float dtype so lgamma/log act on floats.
        targets_valid = targets[valid_mask].to(mu.dtype)

        theta_mu = theta_valid + mu_valid
        nb_log_prob = (
            torch.lgamma(theta_valid + targets_valid + eps)
            - torch.lgamma(theta_valid + eps)
            - torch.lgamma(targets_valid + 1)
            + theta_valid * torch.log(theta_valid + eps)
            - theta_valid * torch.log(theta_mu + eps)
            + targets_valid * torch.log(mu_valid + eps)
            - targets_valid * torch.log(theta_mu + eps)
        )

        # Zeros can come from the inflation component or the NB itself; non-zeros only from the NB.
        zero_mask = (targets_valid < eps).float()
        nb_zero_prob = theta_valid * torch.log(theta_valid / (theta_mu + eps))
        zero_log_prob = torch.log(pi_valid + (1 - pi_valid) * torch.exp(nb_zero_prob) + eps)
        non_zero_log_prob = torch.log(1 - pi_valid + eps) + nb_log_prob
        log_prob = zero_mask * zero_log_prob + (1 - zero_mask) * non_zero_log_prob
        nll = -log_prob.mean()
        return nll, valid_mask.sum()

    def mse_loss(self, predictions, targets, node_mask):
        """Mean squared error over valid entries; returns ``(mse, n_valid)`` (``n_valid`` is 0 if none)."""
        nan_mask = ~torch.isnan(targets)
        valid_mask = nan_mask & node_mask if node_mask is not None else nan_mask
        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), 0
        preds_valid = predictions[valid_mask]
        targets_valid = targets[valid_mask]
        mse_loss = ((targets_valid - preds_valid) ** 2).mean()
        return mse_loss, valid_mask.sum()

    def huber_loss(self, predictions, targets, node_mask, delta=1.0):
        """
        Huber loss (smooth L1 loss) - less sensitive to outliers than MSE.

        Args:
            predictions: Model predictions
            targets: Ground truth values
            node_mask: Mask for valid nodes
            delta: Threshold at which to switch from quadratic to linear loss

        Returns:
            ``(huber, n_valid)``; ``n_valid`` is 0 when no entry is valid.
        """
        nan_mask = ~torch.isnan(targets)
        valid_mask = nan_mask & node_mask if node_mask is not None else nan_mask
        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), 0

        preds_valid = predictions[valid_mask]
        targets_valid = targets[valid_mask]

        # Quadratic inside |diff| < delta, linear outside.
        diff = torch.abs(targets_valid - preds_valid)
        huber = torch.where(
            diff < delta,
            0.5 * diff ** 2,
            delta * (diff - 0.5 * delta)
        )
        huber_loss = huber.mean()

        return huber_loss, valid_mask.sum()

In [None]:
# Sanity-check input widths against the architecture the checkpoint was trained with.
# GAT1 receives 2 * dynamic features (values + missing-value indicators) + static features.
EXPECTED_DYNAMIC_DIM = 21  # 9 aggregated stats + 12 raw timesteps (see prepare_hybrid_loader)
EXPECTED_STATIC_DIM = 13
expected_gat1_in = 2 * EXPECTED_DYNAMIC_DIM + EXPECTED_STATIC_DIM  # 55

dynamic_dim = X_spatial.shape[-1]
static_dim = static_features.shape[1]
gat1_in = 2 * dynamic_dim + static_dim

print("Training expectations:")
print(f"  Dynamic features: {EXPECTED_DYNAMIC_DIM}")
print(f"  Static features: Should be {EXPECTED_STATIC_DIM}")
print(f"\nTest batch data:")
print(f"  Dynamics features shape: {X_spatial.shape}")
print(f"  Static features shape: {static_features.shape}")
print(f"  Expected input to GAT1: {gat1_in}")
print(f"  Model GAT1 expects: {expected_gat1_in}")

assert dynamic_dim == EXPECTED_DYNAMIC_DIM, f'dynamic feature width {dynamic_dim} != {EXPECTED_DYNAMIC_DIM}'
assert static_dim == EXPECTED_STATIC_DIM, f'static feature width {static_dim} != {EXPECTED_STATIC_DIM}'

In [None]:
# Instantiate + load checkpoint
device = torch.device('cpu')

n_embd = 32
n_heads = 4
dropout = 0.1

# Graph tensors are passed to forward(), so the constructor only needs layer sizes.
model = DynamicNodeGATZINB(
    dynamic_node_dim=21,
    static_node_dim=static_features.shape[1],
    edge_dim=edge_attr.shape[1],
    n_embd=n_embd,
    n_heads=n_heads,
    dropout_rate=dropout,
).to(device)

# load_state_dict still raises on tensor shape mismatches with strict=False, but it silently
# skips missing/unexpected keys. A skipped key leaves that layer randomly initialised, which
# would invalidate every result below, so both cases are surfaced as warnings.
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
if CKPT_PATH.exists():
    state = torch.load(CKPT_PATH, map_location=device, weights_only=False)
    missing, unexpected = model.load_state_dict(state, strict=False)
    print('Loaded checkpoint:', CKPT_PATH)
    if missing:
        warnings.warn(f'Checkpoint is missing {len(missing)} key(s); these stay randomly initialised: {missing}')
    if unexpected:
        warnings.warn(f'Checkpoint has {len(unexpected)} unexpected key(s) that were ignored: {unexpected}')
else:
    warnings.warn(f'Checkpoint not found at {CKPT_PATH.resolve()}; using a RANDOMLY INITIALISED model.')

model.eval()

In [None]:
# Attention aggregation utilities
from collections import defaultdict
import pandas as pd
import numpy as np

def _alpha_mean_per_edge(alpha: torch.Tensor) -> torch.Tensor:
    """Return per-edge attention scalar by averaging over heads."""
    # Expected shapes: [E, heads] or [E] (rare)
    if alpha.dim() == 2:
        return alpha.mean(dim=1)
    return alpha.view(alpha.shape[0], -1).mean(dim=1)

@torch.no_grad()
def collect_attention_stats(loader, max_batches=10, layer='layer2', device='cpu'):
    """Average GATv2 attention per directed edge over the first batches of ``loader``.

    Runs the global ``model`` in eval mode with ``return_attention=True``. For every
    sample of every visited batch, it adds the head-averaged attention of each edge
    returned by ``layer`` to a running sum.

    Args:
        loader: iterable of ``(X_spatial, X_temporal, y)`` batches.
        max_batches: number of batches to visit; ``None`` visits the whole loader.
        layer: ``'layer1'`` or ``'layer2'``.
        device: device the batch and graph tensors are moved to.

    Returns:
        DataFrame with columns src, dst, attn_mean, count, src_name, dst_name, sorted by
        attn_mean then count (both descending). When nothing was collected it is empty
        but still has those columns, so column-based filtering keeps working.
    """
    columns = ['src', 'dst', 'attn_mean', 'count', 'src_name', 'dst_name']
    device = torch.device(device)
    model.eval()

    # Graph tensors must live on the same device as the batch.
    graph_edge_index = edge_index.to(device)
    graph_edge_attr = edge_attr.to(device)
    graph_static = static_features.to(device)

    sums = defaultdict(float)
    counts = defaultdict(int)

    for n_batches, (X_spatial, _, _) in enumerate(loader, start=1):
        # targets=None: attention does not depend on the targets, so skip the losses.
        # Computing them would also be wrong, because NaN targets cast to long are no
        # longer NaN and the loss mask cannot exclude them.
        _, _, _, _, extra = model(
            X_batch=X_spatial.to(device),
            targets=None,
            node_mask=None,
            edge_index=graph_edge_index,
            edge_attr=graph_edge_attr,
            static_node_features=graph_static,
            return_attention=True,
        )

        # One (edge_index, alpha) pair per sample, because the model loops over the batch.
        for ei, alpha in extra.get('attn', {}).get(layer, []):
            w = _alpha_mean_per_edge(alpha.detach().cpu())
            ei = ei.detach().cpu().to(torch.long)
            for s, d, ww in zip(ei[0].tolist(), ei[1].tolist(), w.tolist()):
                key = (int(s), int(d))
                sums[key] += float(ww)
                counts[key] += 1

        if max_batches is not None and n_batches >= max_batches:
            break

    rows = [
        {
            'src': s,
            'dst': d,
            'attn_mean': total / max(counts[(s, d)], 1),
            'count': counts[(s, d)],
            'src_name': id_to_name.get(s, str(s)),
            'dst_name': id_to_name.get(d, str(d)),
        }
        for (s, d), total in sums.items()
    ]
    df = pd.DataFrame(rows, columns=columns)
    if len(df) == 0:
        return df
    return df.sort_values(['attn_mean', 'count'], ascending=[False, False]).reset_index(drop=True)

@torch.no_grad()
def collect_temporal_attention(loader, sample_indices, layer='layer2', device='cpu'):
    """
    Collect attention scores for specific samples (temporal snapshots).

    Samples are numbered by their position in ``loader`` iteration order. The model is
    only run on batches that contain at least one requested sample, and iteration stops
    once the largest requested index has been passed.

    Args:
        loader: DataLoader yielding (X_spatial, X_temporal, y) batches
        sample_indices: Iterable of sample indices to extract (duplicates are ignored)
        layer: Which attention layer to extract ('layer1' or 'layer2')
        device: Device to run on

    Returns:
        List of dicts ``{'sample_idx': int, 'edges': {"src_dst": edge record}}``, one per
        requested sample that the loader actually produced, in iteration order. It is
        empty when ``sample_indices`` is empty.
    """
    wanted = {int(i) for i in sample_indices}  # set: O(1) membership checks
    if not wanted:
        return []
    last_wanted = max(wanted)

    device = torch.device(device)
    model.eval()

    # Graph tensors must live on the same device as the batch.
    graph_edge_index = edge_index.to(device)
    graph_edge_attr = edge_attr.to(device)
    graph_static = static_features.to(device)

    temporal_data = []
    batch_start = 0

    for X_spatial, _, _ in loader:
        batch_size = X_spatial.shape[0]
        batch_end = batch_start + batch_size

        # Skip the (expensive) forward pass when no requested sample is in this batch.
        if any(batch_start <= idx < batch_end for idx in wanted):
            # targets=None: attention does not depend on the targets, so no losses are needed.
            _, _, _, _, extra = model(
                X_batch=X_spatial.to(device),
                targets=None,
                node_mask=None,
                edge_index=graph_edge_index,
                edge_attr=graph_edge_attr,
                static_node_features=graph_static,
                return_attention=True,
            )
            pairs = extra.get('attn', {}).get(layer, [])

            # One (edge_index, alpha) pair per sample in the batch.
            for batch_idx, (ei, alpha) in enumerate(pairs):
                sample_idx = batch_start + batch_idx
                if sample_idx not in wanted:
                    continue

                w = _alpha_mean_per_edge(alpha.detach().cpu())
                ei = ei.detach().cpu().to(torch.long)

                edges = {}
                for s, d, ww in zip(ei[0].tolist(), ei[1].tolist(), w.tolist()):
                    edges[f"{int(s)}_{int(d)}"] = {
                        'source': int(s),
                        'target': int(d),
                        'source_name': id_to_name.get(int(s), str(s)),
                        'target_name': id_to_name.get(int(d), str(d)),
                        'score': float(ww)
                    }

                temporal_data.append({
                    'sample_idx': sample_idx,
                    'edges': edges
                })

        batch_start = batch_end

        # Stop once every requested index lies in an already-visited batch.
        if batch_start > last_wanted:
            break

    return temporal_data

def top_incoming(df: pd.DataFrame, node_idx: int, k: int = 15) -> pd.DataFrame:
    """Return the ``k`` edges pointing INTO ``node_idx`` with the highest ``attn_mean``.

    The result is re-indexed from 0.
    """
    target = int(node_idx)
    incoming = df.loc[df['dst'] == target]
    ranked = incoming.sort_values('attn_mean', ascending=False)
    return ranked.head(k).reset_index(drop=True)

def top_outgoing(df: pd.DataFrame, node_idx: int, k: int = 15) -> pd.DataFrame:
    """Return the ``k`` edges leaving ``node_idx`` with the highest ``attn_mean``.

    The result is re-indexed from 0.
    """
    source = int(node_idx)
    outgoing = df.loc[df['src'] == source]
    ranked = outgoing.sort_values('attn_mean', ascending=False)
    return ranked.head(k).reset_index(drop=True)

print('Ready to collect attention statistics.')

In [None]:
# Average attention over the first MAX_BATCHES test batches for a static (time-averaged) view.
MAX_BATCHES = 25

attn_l1, attn_l2 = (
    collect_attention_stats(test_h, max_batches=MAX_BATCHES, layer=layer_key, device=device)
    for layer_key in ('layer1', 'layer2')
)
print('layer1 edges:', len(attn_l1), '| layer2 edges:', len(attn_l2))

# Strongest edges overall, layer 1 then layer 2.
for attn_frame in (attn_l1, attn_l2):
    display(attn_frame.head(20))

# Inspect one node's neighbourhood (0 is fine; any sensor id works).
node_idx = 0
print('Node:', node_idx, id_to_name.get(int(node_idx), str(node_idx)))

for heading, ranker in (('\nTop incoming (layer2):', top_incoming), ('\nTop outgoing (layer2):', top_outgoing)):
    print(heading)
    display(ranker(attn_l2, node_idx=node_idx, k=20))

In [None]:
# Collect per-sample attention snapshots for the selected days, taking every
# SAMPLE_INTERVAL-th sample (1-hour intervals if samples are 15 minutes apart).
SAMPLE_INTERVAL = 4

selected_sample_indices = [
    idx
    for start, end in selected_day_ranges
    for idx in range(start, end, SAMPLE_INTERVAL)
]
assert selected_sample_indices, 'No samples selected: check selected_day_ranges and SAMPLE_INTERVAL'

print(f"Collecting attention for {len(selected_sample_indices)} samples")
print(f"Sample indices range: {min(selected_sample_indices)} to {max(selected_sample_indices)}")

# Collect temporal attention for both layers
print("\nCollecting Layer 1 temporal attention...")
temporal_l1 = collect_temporal_attention(test_h, selected_sample_indices, layer='layer1', device=device)
print(f"Collected {len(temporal_l1)} temporal snapshots for layer1")

print("\nCollecting Layer 2 temporal attention...")
temporal_l2 = collect_temporal_attention(test_h, selected_sample_indices, layer='layer2', device=device)
print(f"Collected {len(temporal_l2)} temporal snapshots for layer2")

# Sort by sample index
temporal_l1 = sorted(temporal_l1, key=lambda snap: snap['sample_idx'])
temporal_l2 = sorted(temporal_l2, key=lambda snap: snap['sample_idx'])

# Selected indices beyond what the loader produced are silently absent; make that visible.
for layer_key, snaps in (('layer1', temporal_l1), ('layer2', temporal_l2)):
    n_missing = len(set(selected_sample_indices)) - len(snaps)
    if n_missing:
        print(f"WARNING: {n_missing} selected samples were not found for {layer_key}")

## Collect Temporal Attention Scores

Collect attention scores at regularly spaced time points within each selected day (every `SAMPLE_INTERVAL` samples) for temporal visualization.


In [None]:
# Evaluate model with all metrics (ZINB NLL, MSE, and Huber Loss)
@torch.no_grad()
def detailed_evaluation(model, data_loader, device, split_name="Val"):
    """Evaluate ``model`` on ``data_loader``; print ZINB NLL, MSE, Huber loss and node validity.

    Targets are un-normalised with SCALER_MU / SCALER_SIGMA and rounded to integer counts
    for the ZINB likelihood. Missing (NaN) targets are excluded through ``node_mask``.
    This is necessary because, once cast to an integer dtype, NaN is no longer NaN, so
    the model's own ``isnan`` check on the integer targets cannot detect missing values.
    Losses are averaged per batch (mean of batch means).

    Args:
        model: a DynamicNodeGATZINB-compatible model.
        data_loader: iterable of ``(X_spatial, X_temporal, y)`` batches.
        device: device for the batch and graph tensors.
        split_name: label used in the printed report.

    Returns:
        ``(all_preds, all_params)``: per-batch prediction tensors on CPU, and per-batch
        ``extra`` dicts (mu/theta/pi/valid_sum; attention weights are dropped). The
        tuple is returned even when no batch was evaluated.
    """
    device = torch.device(device)
    model.eval()

    # Graph tensors must live on the same device as the batch.
    graph_edge_index = edge_index.to(device)
    graph_edge_attr = edge_attr.to(device)
    graph_static = static_features.to(device)

    num_nodes = None
    node_valid_counts = None
    node_total_counts = None

    total_zinb_nll = 0.0
    total_mse = 0.0
    total_huber = 0.0
    num_batches = 0

    all_preds = []
    all_params = []

    for X_spatial, _, y_batch in data_loader:
        X_batch = X_spatial.to(device)
        y_batch = y_batch.to(device)

        if num_nodes is None:
            num_nodes = y_batch.shape[2]
            node_valid_counts = torch.zeros(num_nodes)
            node_total_counts = torch.zeros(num_nodes)

        # Observed entries, taken from the float targets BEFORE the integer cast below.
        valid_mask = ~torch.isnan(y_batch)

        # Unnormalize the target variable: y_raw = (y_batch_normalized * sigma) + mu
        y_raw = (y_batch * SCALER_SIGMA) + SCALER_MU

        # Round to integer counts (required by the ZINB loss). NaNs are zero-filled first
        # because casting NaN to an integer is undefined; node_mask excludes those entries.
        y_raw_int = torch.round(torch.nan_to_num(y_raw, nan=0.0)).long()

        preds, zinb_nll, mse, huber, params = model(
            X_batch=X_batch,
            targets=y_raw_int,
            node_mask=valid_mask,
            edge_index=graph_edge_index,
            edge_attr=graph_edge_attr,
            static_node_features=graph_static,
            return_attention=True,
        )

        all_preds.append(preds.cpu())
        # Attention weights are dropped: keeping them for every batch of a split would
        # grow memory without contributing to the returned results.
        all_params.append({key: value for key, value in params.items() if key != 'attn'})

        if zinb_nll is not None:
            total_zinb_nll += zinb_nll.item()
            total_mse += mse.item()
            total_huber += huber.item()
            num_batches += 1

            node_valid_counts += valid_mask.sum(dim=(0, 1, 3)).cpu()
            # Each node has one entry per (dim 0, dim 1, dim 3) position of y_batch,
            # matching the dimensions summed for the valid count above.
            node_total_counts += valid_mask.shape[0] * valid_mask.shape[1] * valid_mask.shape[3]

    if num_batches == 0:
        print(f'\n{split_name}: no batches evaluated.')
        return all_preds, all_params

    avg_zinb_nll = total_zinb_nll / num_batches
    avg_mse = total_mse / num_batches
    avg_huber = total_huber / num_batches

    print(f'\n{split_name} Set Metrics:')
    print(f'  ZINB NLL:   {avg_zinb_nll:.4f}')
    print(f'  MSE:        {avg_mse:.4f}')
    print(f'  Huber Loss: {avg_huber:.4f}')
    print(f'  Batches:    {num_batches}')

    # Per-node statistics
    node_valid_pct = (node_valid_counts / node_total_counts * 100)
    print(f"\n  Node Validity Statistics:")
    print(f"    Min: {node_valid_pct.min():.1f}%")
    print(f"    Max: {node_valid_pct.max():.1f}%")
    print(f"    Mean: {node_valid_pct.mean():.1f}%")
    print(f"    Nodes with 100% valid: {(node_valid_pct == 100).sum().item()}/{num_nodes}")
    print(f"    Nodes with <50% valid: {(node_valid_pct < 50).sum().item()}/{num_nodes}")

    return all_preds, all_params

# Evaluate every split in order: train, validation, test.
_eval_splits = (('Train', train_h), ('Validation', val_h), ('Test', test_h))
train_metrics, val_metrics, test_metrics = (
    detailed_evaluation(model, split_loader, device=device, split_name=split_label)
    for split_label, split_loader in _eval_splits
)

In [None]:
# Save averaged and temporal attention scores to JSON for the visualisation front-end.
def _attention_frame_to_json(attn_df):
    """Map an aggregated attention DataFrame to ``{"src_dst": edge record}`` in row order."""
    edge_records = {}
    for _, row in attn_df.iterrows():
        edge_records[f"{row['src']}_{row['dst']}"] = {
            'source': int(row['src']),
            'target': int(row['dst']),
            'source_name': row['src_name'],
            'target_name': row['dst_name'],
            'score': float(row['attn_mean'])
        }
    return edge_records

output_data = {
    'layer1': _attention_frame_to_json(attn_l1),
    'layer2': _attention_frame_to_json(attn_l2),
    'nodes': id_to_name,
    'temporal': {
        'layer1': temporal_l1,
        'layer2': temporal_l2,
        'samples_per_day': SAMPLES_PER_DAY,
        'sample_interval': SAMPLE_INTERVAL,
        'selected_days': selected_days
    }
}

with open('attention_scores.json', 'w') as out_file:
    json.dump(output_data, out_file, indent=2)

print(f'Attention scores saved to: attention_scores.json')
print(f'Layer 1 edges (averaged): {len(output_data["layer1"])}')
print(f'Layer 2 edges (averaged): {len(output_data["layer2"])}')
print(f'Temporal snapshots layer 1: {len(temporal_l1)}')
print(f'Temporal snapshots layer 2: {len(temporal_l2)}')
print(f'Nodes: {len(output_data["nodes"])}')

In [None]:
# Visualize attention scores on graph
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

def plot_attention_graph(attn_df, layer_name='Layer 2', top_k=100, figsize=(20, 16)):
    """
    Plot a graph visualization with edges colored by attention scores.

    Parameters:
    - attn_df: DataFrame with columns src, dst, attn_mean. Node labels come from the
      global id_to_name. The first top_k rows are treated as the top edges, so rows
      should already be sorted by attention (collect_attention_stats sorts them).
    - layer_name: Name for the plot title
    - top_k: Maximum number of edges to display (to avoid clutter)
    - figsize: Figure size tuple

    Returns:
    - (fig, ax, G): the matplotlib Figure and Axes, and the networkx DiGraph that was drawn.

    Raises:
    - ValueError: if attn_df has no rows (there is nothing to draw or colour-scale).
    """
    df_plot = attn_df.head(top_k)
    if df_plot.empty:
        raise ValueError(f'{layer_name}: no attention edges to plot')

    # Directed graph over the nodes touched by the selected edges.
    G = nx.DiGraph()
    for node_id in set(df_plot['src'].tolist() + df_plot['dst'].tolist()):
        G.add_node(node_id, label=id_to_name.get(int(node_id), str(node_id)))
    for _, row in df_plot.iterrows():
        G.add_edge(row['src'], row['dst'], weight=row['attn_mean'])

    edge_weights = df_plot['attn_mean'].tolist()
    vmin, vmax = min(edge_weights), max(edge_weights)

    fig, ax = plt.subplots(figsize=figsize)

    # Spring layout with a fixed seed so repeated runs produce the same picture.
    pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    norm = Normalize(vmin=vmin, vmax=vmax)
    cmap = plt.cm.YlOrRd  # Yellow to Red colormap

    nx.draw_networkx_nodes(
        G, pos,
        node_color='lightblue',
        node_size=800,
        alpha=0.9,
        ax=ax
    )

    # Edge colours are read back from the graph so they stay aligned with G.edges() order.
    edges = list(G.edges())
    colors = [G[u][v]['weight'] for u, v in edges]

    nx.draw_networkx_edges(
        G, pos,
        edgelist=edges,
        edge_color=colors,
        edge_cmap=cmap,
        edge_vmin=vmin,
        edge_vmax=vmax,
        width=2,
        alpha=0.6,
        arrows=True,
        arrowsize=15,
        arrowstyle='->',
        connectionstyle='arc3,rad=0.1',
        ax=ax
    )

    nx.draw_networkx_labels(
        G, pos,
        labels=nx.get_node_attributes(G, 'label'),
        font_size=8,
        font_weight='bold',
        ax=ax
    )

    sm = ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Attention Score', rotation=270, labelpad=20, fontsize=12)

    # Report the number of edges actually drawn, which may be fewer than top_k.
    ax.set_title(f'{layer_name} Attention Scores (Top {len(df_plot)} Edges)',
                 fontsize=16, fontweight='bold', pad=20)
    ax.axis('off')
    fig.tight_layout()

    return fig, ax, G

# Plot both layers and save each figure next to the notebook.
print('Plotting Layer 1 attention...')
fig1, ax1, G1 = plot_attention_graph(attn_l1, layer_name='Layer 1', top_k=100)
fig1.savefig('attention_layer1_graph.png', dpi=150, bbox_inches='tight')
plt.show()

print('\nPlotting Layer 2 attention...')
fig2, ax2, G2 = plot_attention_graph(attn_l2, layer_name='Layer 2', top_k=100)
fig2.savefig('attention_layer2_graph.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\nGraphs saved as attention_layer1_graph.png and attention_layer2_graph.png')
for graph_label, graph in (('Layer 1', G1), ('Layer 2', G2)):
    print(f'{graph_label}: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges')

In [None]:
def format_params(params):
    """Concatenate per-batch ZINB parameters into whole-split tensors.

    Args:
        params: list of per-batch dicts holding tensors under 'mu', 'theta' and 'pi'
            (e.g. the second element returned by detailed_evaluation); other keys
            are ignored.

    Returns:
        Tuple ``(mu, theta, pi)``, each concatenated along dim 0.
    """
    mu, theta, pi = (
        torch.cat([batch[param_name] for batch in params], dim=0)
        for param_name in ("mu", "theta", "pi")
    )
    return mu, theta, pi