In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
from torch_geometric_temporal.dataset import METRLADatasetLoader

In [None]:
def scale_metrla_data(data):
    """
    Standardize data over the time axis (axis 0), separately for every
    remaining index (e.g. node and feature).

    Args:
        data: torch.Tensor (any device; may require grad) or array-like of
            shape [T, ...], e.g. [T, N, F].

    Returns:
        scaled_data: float32 tensor, same shape as ``data``.
        mean: float32 tensor of shape ``data.shape[1:]`` (e.g. [N, F]).
        std: float32 tensor of shape ``data.shape[1:]``; population std
            (ddof=0) plus 1e-6 so constant series do not divide by zero.
    """
    if isinstance(data, torch.Tensor):
        # .numpy() raises for CUDA tensors and for tensors that require grad
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)

    # Mean and std over the time axis, per remaining index
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0) + 1e-6

    # Broadcast across time
    scaled = (data - mean) / std

    return (
        torch.as_tensor(scaled, dtype=torch.float32),
        torch.as_tensor(mean, dtype=torch.float32),
        torch.as_tensor(std, dtype=torch.float32),
    )


In [None]:
class MultiHeadGATLayer(nn.Module):
    """Multi-head graph attention over a dense adjacency matrix.

    Each head projects node features with its own ``nn.Linear`` and scores
    every pair (i, j) as ``a[:D] . h_i + a[D:] . h_j``. This is the GAT form
    ``a^T [h_i || h_j]``. Non-edges are masked, each row is softmax-normalized
    over j, and neighbour projections are aggregated. Heads are returned
    separately so the caller can fuse them.

    NOTE(review): unlike the original GAT there is no LeakyReLU on the scores
    and no self-loop is added. A node attends to itself only if adj[i, i] > 0.
    Rows with no neighbours attend uniformly to *all* nodes.

    Args:
        in_dim: input feature size per node.
        out_dim: output size D of every head.
        num_heads: number of independent attention heads.
    """

    def __init__(self, in_dim, out_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.attn_layers = nn.ModuleList([
            nn.Linear(in_dim, out_dim) for _ in range(num_heads)
        ])
        # One attention vector a (length 2 * out_dim) per head; initialized in reset_parameters
        self.attn_scores = nn.ParameterList([
            nn.Parameter(torch.empty(out_dim * 2)) for _ in range(num_heads)
        ])
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init of every attention vector (viewed as a 1 x 2D matrix)."""
        for param in self.attn_scores:
            nn.init.xavier_uniform_(param.view(1, -1))

    def forward(self, x, adj):
        """
        Args:
            x: [N, in_dim] node features.
            adj: [N, N] adjacency; entries > 0 are edges (row i attends over columns j).

        Returns:
            list of ``num_heads`` tensors, each [N, out_dim].
        """
        # The masks do not depend on the head, so compute them once
        mask = adj > 0
        isolated = ~mask.any(dim=1, keepdim=True)  # [N, 1]

        head_outputs = []
        for linear, a in zip(self.attn_layers, self.attn_scores):
            h = linear(x)  # [N, D]
            d = h.size(-1)
            # a^T [h_i || h_j] == a[:d].h_i + a[d:].h_j, built by broadcasting
            # instead of materialising an [N, N, 2D] concatenation.
            e = (h @ a[:d]).unsqueeze(1) + (h @ a[d:]).unsqueeze(0)  # [N, N]
            e = e.masked_fill(~mask, float('-inf'))
            # An all--inf row would softmax to NaN; fall back to uniform weights
            e = e.masked_fill(isolated, 0.0)
            attn = F.softmax(e, dim=1)
            head_outputs.append(attn @ h)  # [N, D]

        return head_outputs  # list of [N, D]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeVaryingKalmanFilter(nn.Module):
    """Kalman filter whose transition matrix A_t is generated from each observation.

    H, Q and R are learnable matrices shared across all steps. An MLP maps
    every observation z_t ([N, dim]) to a per-row transition matrix A_t
    ([N, dim, dim]). Each observation is scaled by a per-row sigmoid gate
    before the measurement update.

    Args:
        dim: state/observation size D.
        hidden_dim: hidden width of the A_t generator.
    """

    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        self.dim = dim

        # Shared learnable matrices
        self.H = nn.Parameter(torch.eye(dim))             # observation model
        self.Q = nn.Parameter(1e-3 * torch.eye(dim))      # process noise
        self.R = nn.Parameter(1e-3 * torch.eye(dim))      # observation noise

        # z_t -> flattened A_t
        self.A_gen = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim * dim),
        )

        # z_t -> scalar gate logit per row
        self.attn_proj = nn.Linear(dim, 1)

    def forward(self, z_seq):
        """
        Args:
            z_seq: non-empty sequence of T tensors, each [N, dim].

        Returns:
            [N, dim] state estimate after the last step.
        """
        n_rows, d = z_seq[0].shape
        dev = z_seq[0].device

        state = torch.zeros((n_rows, d), device=dev)
        cov = torch.eye(d, device=dev).unsqueeze(0).repeat(n_rows, 1, 1)

        # Quantities that do not change between steps
        eye = torch.eye(d, device=dev).unsqueeze(0)
        H_b = self.H.unsqueeze(0).expand(n_rows, d, d)
        H_bT = H_b.transpose(1, 2)

        for z in z_seq:
            A_t = self.A_gen(z).view(n_rows, d, d)

            # Predict
            state_pred = torch.bmm(A_t, state.unsqueeze(-1)).squeeze(-1)
            cov_pred = torch.bmm(A_t, torch.bmm(cov, A_t.transpose(1, 2))) + self.Q

            # Gate the observation
            z_gated = z * torch.sigmoid(self.attn_proj(z))

            # Update
            S = torch.bmm(H_b, torch.bmm(cov_pred, H_bT)) + self.R
            gain = torch.bmm(cov_pred, torch.bmm(H_bT, torch.linalg.inv(S)))
            residual = z_gated - torch.bmm(H_b, state_pred.unsqueeze(-1)).squeeze(-1)
            state = state_pred + torch.bmm(gain, residual.unsqueeze(-1)).squeeze(-1)
            cov = torch.bmm(eye - torch.bmm(gain, H_b), cov_pred)

        return state


In [None]:
class AttentiveKalmanFilter(nn.Module):
    """Linear Kalman filter with learnable A, H, Q, R and a gated observation.

    Every row of the [N, dim] observation is filtered independently with its
    own covariance. Before each update, the observation is multiplied by a
    per-row sigmoid weight computed from the observation itself.

    Args:
        dim: state/observation size D.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Learnable Kalman matrices
        self.A = nn.Parameter(torch.eye(dim))              # state transition
        self.H = nn.Parameter(torch.eye(dim))              # observation matrix
        self.Q = nn.Parameter(1e-3 * torch.eye(dim))       # process noise
        self.R = nn.Parameter(1e-3 * torch.eye(dim))       # observation noise

        # Observation gate
        self.attn_proj = nn.Linear(dim, 1)

    def forward(self, z_seq):
        """
        Args:
            z_seq: non-empty sequence of T tensors, each [N, dim].

        Returns:
            [N, dim] state estimate after the last step.
        """
        n_rows, d = z_seq[0].shape
        dev = z_seq[0].device

        state = torch.zeros((n_rows, d), device=dev)
        cov = torch.eye(d, device=dev).expand(n_rows, d, d)

        # The matrices are shared across steps, so batch them once
        A_b = self.A.unsqueeze(0).expand(n_rows, d, d)
        A_bT = A_b.transpose(1, 2)
        H_b = self.H.unsqueeze(0).expand(n_rows, d, d)
        H_bT = H_b.transpose(1, 2)
        eye = torch.eye(d, device=dev).unsqueeze(0)

        for z in z_seq:
            # Predict
            state_pred = state @ self.A.T
            cov_pred = A_b @ cov @ A_bT + self.Q

            # Gate the observation with a per-row weight in (0, 1)
            gate = torch.sigmoid(self.attn_proj(z))
            z_gated = z * gate

            # Update
            S = H_b @ cov_pred @ H_bT + self.R
            gain = cov_pred @ H_bT @ torch.linalg.inv(S)
            innovation = z_gated - (state_pred @ self.H.T)
            state = state_pred + torch.bmm(gain, innovation.unsqueeze(-1)).squeeze(-1)
            cov = (eye - gain @ H_b) @ cov_pred

        return state


In [None]:
class KalmanFusion(nn.Module):
    """Fuse a list of per-head embeddings with sequential Kalman measurement updates.

    Each head output is treated as a noisy observation of one shared latent
    state per row. There are no dynamics and no process noise between heads.
    Starting from x = 0 and P = I, every head applies a standard update with a
    fixed identity observation matrix H and a learnable observation noise R.

    Args:
        d_model: embedding size D of every head output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.R = nn.Parameter(torch.eye(d_model) * 0.1)  # learnable observation noise
        # Fixed observation matrix as a non-persistent buffer: it moves with
        # `.to(device)` but does not add a key to the state_dict.
        self.register_buffer("H", torch.eye(d_model), persistent=False)

    def forward(self, z_heads):
        """
        Args:
            z_heads: non-empty list of [N, D] tensors (one per attention head).

        Returns:
            [N, D] fused estimate.

        Raises:
            ValueError: if ``z_heads`` is empty.
        """
        if len(z_heads) == 0:
            raise ValueError("KalmanFusion needs at least one head output")

        x = torch.zeros_like(z_heads[0])  # [N, D]
        n_rows, d = x.shape
        eye = torch.eye(d, device=x.device)
        P = eye.unsqueeze(0).repeat(n_rows, 1, 1)  # [N, D, D]

        H = self.H.to(x.device)  # [D, D]
        R = self.R.to(x.device)  # [D, D]

        for z in z_heads:
            S = H @ P @ H.T + R                  # [N, D, D]
            K = P @ H.T @ torch.linalg.inv(S)    # [N, D, D]
            x_col = x.unsqueeze(-1)              # [N, D, 1]
            x = (x_col + K @ (z.unsqueeze(-1) - H @ x_col)).squeeze(-1)  # [N, D]
            P = (eye - K @ H) @ P                # [N, D, D]

        return x

In [None]:
class KalmanOverTime(nn.Module):
    """Linear Kalman filter run over a sequence of [N, D] observations.

    Each step predicts with a learnable transition A and process noise Q. It
    then updates with a fixed identity observation matrix H and a learnable
    observation noise R. Each of the N rows is filtered independently with its
    own D x D covariance. In KalmanGAT the "sequence" is the list of per-layer
    GAT outputs, so there the filter runs over network depth.

    Args:
        d_model: state/observation size D.
    """

    def __init__(self, d_model):
        super().__init__()
        self.Q = nn.Parameter(torch.eye(d_model) * 0.1)
        self.R = nn.Parameter(torch.eye(d_model) * 0.1)
        self.A = nn.Parameter(torch.eye(d_model))
        # Fixed observation matrix as a non-persistent buffer: it moves with
        # `.to(device)` but does not add a key to the state_dict.
        self.register_buffer("H", torch.eye(d_model), persistent=False)

    def forward(self, z_seq):
        """
        Args:
            z_seq: non-empty sequence of T tensors, each [N, D].

        Returns:
            [N, D] state estimate after the last step.

        Raises:
            ValueError: if ``z_seq`` is empty.
        """
        if len(z_seq) == 0:
            raise ValueError("KalmanOverTime needs at least one observation")

        x = torch.zeros_like(z_seq[0])  # [N, D]
        n_rows, d = x.shape
        eye = torch.eye(d, device=x.device)
        P = eye.unsqueeze(0).repeat(n_rows, 1, 1)  # [N, D, D]

        A = self.A.to(x.device)
        H = self.H.to(x.device)
        Q = self.Q.to(x.device)
        R = self.R.to(x.device)

        for z in z_seq:
            # Prediction step
            x_pred = x @ A.T               # [N, D]
            P_pred = A @ P @ A.T + Q       # [N, D, D]

            # Update step
            S = H @ P_pred @ H.T + R                   # [N, D, D]
            K = P_pred @ H.T @ torch.linalg.inv(S)     # [N, D, D]
            x_col = x_pred.unsqueeze(-1)               # [N, D, 1]
            x = (x_col + K @ (z.unsqueeze(-1) - H @ x_col)).squeeze(-1)  # [N, D]
            P = (eye - K @ H) @ P_pred                 # [N, D, D]

        return x  # [N, D]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FullyLearnedKalmanFilter(nn.Module):
    """Kalman filter whose A_t, H_t, Q_t and R_t are all generated from z_t.

    Small MLPs map each observation z_t ([N, dim]) to the following:
    - a per-row transition matrix A_t;
    - an observation matrix H_t;
    - diagonal noise covariances Q_t = diag(exp(.)) and R_t = diag(exp(.)).
    Observations are scaled by a per-row sigmoid gate before the update.

    Args:
        dim: state/observation size D.
        hidden_dim: hidden width of every generator MLP.
    """

    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim

        def mlp(out_features):
            return nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_features),
            )

        # Matrix generators (creation order matters for seeded initialisation)
        self.A_gen = mlp(dim * dim)   # -> flattened A_t
        self.H_gen = mlp(dim * dim)   # -> flattened H_t
        self.Q_gen = mlp(dim)         # -> log of diag(Q_t)
        self.R_gen = mlp(dim)         # -> log of diag(R_t)

        # Observation gate
        self.attn_proj = nn.Linear(dim, 1)

    def forward(self, z_seq):
        """
        Args:
            z_seq: non-empty sequence of T tensors, each [N, D].

        Returns:
            [N, D] state estimate after the last step.
        """
        n_rows, d = z_seq[0].shape
        dev = z_seq[0].device

        state = torch.zeros((n_rows, d), device=dev)
        cov = torch.eye(d, device=dev).unsqueeze(0).expand(n_rows, d, d)  # [N, D, D]
        identity = torch.eye(d, device=dev).unsqueeze(0).expand(n_rows, d, d)

        for z in z_seq:
            A_t = self.A_gen(z).view(n_rows, d, d)
            H_t = self.H_gen(z).view(n_rows, d, d)
            H_tT = H_t.transpose(1, 2)
            # exp makes the diagonal noise variances positive
            Q_t = torch.diag_embed(torch.exp(self.Q_gen(z)))
            R_t = torch.diag_embed(torch.exp(self.R_gen(z)))

            # Predict
            state_pred = torch.bmm(A_t, state.unsqueeze(-1)).squeeze(-1)
            cov_pred = torch.bmm(A_t, torch.bmm(cov, A_t.transpose(1, 2))) + Q_t

            # Gate the observation
            z_gated = z * torch.sigmoid(self.attn_proj(z))

            # Update
            S = torch.bmm(H_t, torch.bmm(cov_pred, H_tT)) + R_t
            gain = torch.bmm(cov_pred, torch.bmm(H_tT, torch.linalg.inv(S)))
            innovation = z_gated - torch.bmm(H_t, state_pred.unsqueeze(-1)).squeeze(-1)
            state = state_pred + torch.bmm(gain, innovation.unsqueeze(-1)).squeeze(-1)
            cov = torch.bmm(identity - torch.bmm(gain, H_t), cov_pred)

        return state


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention along the time axis.

    Nodes are treated as independent batch entries: attention mixes the T
    steps of each node separately. There is no residual connection or
    normalisation.

    Args:
        dim: model width D (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.query = nn.Linear(dim, dim)
        self.key   = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out   = nn.Linear(dim, dim)

    def forward(self, x):
        """x: [T, N, D] -> [T, N, D]."""
        steps, nodes, width = x.shape
        heads, hd = self.num_heads, self.head_dim

        def split_heads(t):
            # [T, N, D] -> [N, heads, T, head_dim]
            return t.view(steps, nodes, heads, hd).permute(1, 2, 0, 3)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        weights = F.softmax((q @ k.transpose(-2, -1)) / (hd ** 0.5), dim=-1)  # [N, H, T, T]
        context = (weights @ v).permute(2, 0, 1, 3).contiguous().view(steps, nodes, width)
        return self.out(context)

class MultiHeadAttentionDrop(nn.Module):
    """Multi-head self-attention over time with dropout, residual and LayerNorm.

    Same attention as MultiHeadAttention. Dropout is applied to the attention
    weights and to the projected output. The result is ``LayerNorm(out + x)``
    (Add & Norm).

    Args:
        dim: model width D (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        dropout: dropout probability.
    """

    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.query = nn.Linear(dim, dim)
        self.key   = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out   = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        x: [T, N, D] -> [T, N, D]
        """
        steps, nodes, width = x.shape
        heads, hd = self.num_heads, self.head_dim

        def split_heads(t):
            # [T, N, D] -> [N, heads, T, head_dim]
            return t.view(steps, nodes, heads, hd).permute(1, 2, 0, 3)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        weights = F.softmax((q @ k.transpose(-2, -1)) / (hd ** 0.5), dim=-1)
        weights = self.dropout(weights)

        context = (weights @ v).permute(2, 0, 1, 3).contiguous().view(steps, nodes, width)
        projected = self.dropout(self.out(context))

        # Add & Norm
        return self.norm(projected + x)


class FullyLearnedKalmanWithAttention(nn.Module):
    """Self-attention over the observation sequence, then a fully learned Kalman filter.

    The T observations are first mixed along the time axis with
    ``MultiHeadAttention``. Then, at every step, MLPs generate A_t, H_t and
    diagonal Q_t = diag(exp(.)), R_t = diag(exp(.)) from the attended
    observation, and a standard predict/update is applied per row.

    Args:
        dim: state/observation size D.
        hidden_dim: hidden width of the matrix-generating MLPs.
        num_heads: attention heads (must divide ``dim``).
    """

    def __init__(self, dim, hidden_dim=64, num_heads=4):
        super().__init__()
        self.dim = dim

        self.attn = MultiHeadAttention(dim, num_heads)

        self.A_gen = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim * dim)
        )
        self.H_gen = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim * dim)
        )
        self.Q_gen = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim)
        )
        self.R_gen = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, z_seq):
        """
        Args:
            z_seq: list of T tensors each [N, D], or a single [T, N, D] tensor.

        Returns:
            [N, D] state estimate after the last step.

        Raises:
            ValueError: if the sequence is empty.
        """
        if len(z_seq) == 0:
            raise ValueError("FullyLearnedKalmanWithAttention needs at least one observation")

        # Accept both layouts; torch.stack would reject an already-stacked tensor
        z_stack = z_seq if torch.is_tensor(z_seq) else torch.stack(list(z_seq), dim=0)  # [T, N, D]
        z_attn = self.attn(z_stack)  # [T, N, D]

        _, N, D = z_attn.shape
        device = z_attn.device
        x = torch.zeros((N, D), device=device)
        P = torch.eye(D, device=device).unsqueeze(0).expand(N, D, D)
        I = torch.eye(D, device=device).unsqueeze(0).expand(N, D, D)

        for z in z_attn:
            A_t = self.A_gen(z).view(N, D, D)
            H_t = self.H_gen(z).view(N, D, D)
            # exp makes the diagonal noise variances positive
            Q_t = torch.diag_embed(torch.exp(self.Q_gen(z)))
            R_t = torch.diag_embed(torch.exp(self.R_gen(z)))

            # Predict
            x_pred = torch.bmm(A_t, x.unsqueeze(-1)).squeeze(-1)
            P_pred = torch.bmm(A_t, torch.bmm(P, A_t.transpose(1, 2))) + Q_t

            # Update
            S = torch.bmm(H_t, torch.bmm(P_pred, H_t.transpose(1, 2))) + R_t
            K = torch.bmm(P_pred, torch.bmm(H_t.transpose(1, 2), torch.linalg.inv(S)))

            innovation = z - torch.bmm(H_t, x_pred.unsqueeze(-1)).squeeze(-1)
            x = x_pred + torch.bmm(K, innovation.unsqueeze(-1)).squeeze(-1)
            P = torch.bmm(I - torch.bmm(K, H_t), P_pred)

        return x  # Final estimate


In [None]:
class TransformerEncoderBlock(nn.Module):
    """Post-norm Transformer encoder block over the time axis.

    The block has two sub-layers:
    - a self-attention sub-layer that applies its own residual and LayerNorm;
    - a position-wise feed-forward sub-layer with residual and LayerNorm.

    Args:
        dim: model width D.
        num_heads: attention heads (must divide ``dim``).
        dropout: dropout probability for attention and feed-forward.
        ff_hidden_mult: feed-forward hidden width = ff_hidden_mult * dim.
    """

    def __init__(self, dim, num_heads, dropout=0.1, ff_hidden_mult=4):
        super().__init__()
        # MultiHeadAttentionDrop takes a dropout argument and applies Add & Norm
        # internally. The plain MultiHeadAttention(dim, num_heads) accepts only
        # two arguments, so calling it with dropout raised TypeError.
        self.attn = MultiHeadAttentionDrop(dim, num_heads, dropout)

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_hidden_mult * dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_hidden_mult * dim, dim),
            nn.Dropout(dropout)
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """x: [T, N, D] -> [T, N, D]."""
        x = self.attn(x)               # Multi-head self-attention with Add & Norm inside
        x = x + self.ff(x)             # FeedForward + Residual
        x = self.norm(x)               # Final LayerNorm
        return x


class FullyLearnedKalmanWithTransformer(nn.Module):
    """Transformer encoder over the observation sequence, then a learned Kalman filter.

    Args:
        dim: state/observation size D.
        num_heads: attention heads for the encoder block.
        dropout: encoder dropout probability.
    """

    def __init__(self, dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoderBlock(dim, num_heads, dropout)
        self.kalman_filter = FullyLearnedKalmanWithAttention(dim)

    def forward(self, z_seq):
        """
        Args:
            z_seq: [T, N, D] tensor, or a list of T tensors of shape [N, D]
                (KalmanGAT builds a list).

        Returns:
            [N, D] final state estimate.
        """
        # The encoder needs a stacked tensor
        z_stack = z_seq if torch.is_tensor(z_seq) else torch.stack(list(z_seq), dim=0)
        encoded = self.encoder(z_stack)  # [T, N, D]
        # FullyLearnedKalmanWithAttention stacks its input itself, so pass a list
        return self.kalman_filter(list(encoded.unbind(0)))  # [N, D]


In [None]:
class TransformerEncoderStack(nn.Module):
    """``num_layers`` TransformerEncoderBlocks applied one after another.

    Args:
        dim: model width D.
        num_heads: attention heads per block.
        num_layers: number of stacked blocks.
        dropout: dropout probability per block.
        ff_hidden_mult: feed-forward width multiplier per block.
    """

    def __init__(self, dim, num_heads, num_layers, dropout=0.1, ff_hidden_mult=4):
        super().__init__()
        blocks = [
            TransformerEncoderBlock(dim, num_heads, dropout, ff_hidden_mult)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the last output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

class KalmanWithTransformerStack(nn.Module):
    """Stacked Transformer encoder over the observation sequence, then a learned Kalman filter.

    Args:
        dim: state/observation size D.
        num_heads: attention heads per encoder block.
        num_layers: number of encoder blocks.
        dropout: encoder dropout probability.
    """

    def __init__(self, dim, num_heads=4, num_layers=3, dropout=0.1):
        super().__init__()
        self.encoder_stack = TransformerEncoderStack(dim, num_heads, num_layers, dropout)
        self.kalman_filter = FullyLearnedKalmanWithAttention(dim)

    def forward(self, z_seq):
        """
        Args:
            z_seq: [T, N, D] tensor, or a list of T tensors of shape [N, D]
                (KalmanGAT builds a list).

        Returns:
            [N, D] final state estimate.
        """
        # The encoder stack needs a stacked tensor
        z_stack = z_seq if torch.is_tensor(z_seq) else torch.stack(list(z_seq), dim=0)
        encoded = self.encoder_stack(z_stack)  # [T, N, D]
        # FullyLearnedKalmanWithAttention stacks its input itself, so pass a list
        return self.kalman_filter(list(encoded.unbind(0)))  # [N, D]


In [None]:
class KalmanGAT(nn.Module):
    """Stack of multi-head GAT layers with Kalman-based fusion.

    In every layer the per-head outputs are fused by ``KalmanFusion``, and the
    fused embedding feeds the next layer. The list of per-layer embeddings is
    then filtered by ``KalmanOverTime`` (so that filter runs over layer depth).
    Its output is projected to ``out_dim`` values per node. Other filters
    defined in this notebook (e.g. AttentiveKalmanFilter) can be swapped in
    for ``self.kalman_filter``.

    Args:
        in_dim: input feature size per node.
        hidden_dim: GAT head / embedding size.
        num_heads: attention heads per GAT layer.
        num_layers: number of GAT layers.
        out_dim: output values per node.
    """

    def __init__(self, in_dim, hidden_dim, num_heads, num_layers, out_dim):
        super().__init__()
        self.num_layers = num_layers
        # Only the first layer consumes raw features; later ones consume hidden_dim embeddings
        self.gat_layers = nn.ModuleList([
            MultiHeadGATLayer(hidden_dim if idx else in_dim, hidden_dim, num_heads)
            for idx in range(num_layers)
        ])
        self.kalman_fusion = KalmanFusion(hidden_dim)
        self.kalman_filter = KalmanOverTime(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, adj):
        """
        Args:
            x: [N, in_dim] node features.
            adj: [N, N] adjacency (entries > 0 are edges).

        Returns:
            [N, out_dim] per-node output.
        """
        layer_reps = []
        hidden = x
        for gat in self.gat_layers:
            hidden = self.kalman_fusion(gat(hidden, adj))  # [N, hidden_dim]
            layer_reps.append(hidden)
        return self.output_proj(self.kalman_filter(layer_reps))

In [None]:
# Prefer the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load METR-LA as a temporal signal of per-time-step graph snapshots
loader = METRLADatasetLoader()
dataset = loader.get_dataset()

In [None]:
dataset[0].x.size()

torch.Size([207, 2, 12])

In [None]:
# Stack all snapshots: x_all [T, N, F, L_in], y_all [T, N, L_out]
x_all = torch.stack([snap.x for snap in dataset])
y_all = torch.stack([snap.y for snap in dataset])

# Standardize inputs and targets separately and keep BOTH sets of statistics.
# Previously the second call overwrote the input statistics.
# NOTE(review): statistics are fitted on the full series (no train/test split here).
x_data, x_mean, x_std = scale_metrla_data(x_all)
y_data, y_mean, y_std = scale_metrla_data(y_all)
# Backward-compatible names: these previously ended up holding the target statistics,
# which are the ones needed to map predictions back to the original scale.
mean, std = y_mean, y_std

# Flatten the per-node feature/lag axes into one vector per node: [T, N, F * L_in]
x_data = x_data.view(x_data.size(0), x_data.size(1), -1)
y_data = y_data.view(y_data.size(0), y_data.size(1), -1)

# Dense 0/1 adjacency from the edge list (no self-loops are added)
edge_index = dataset.edge_index
num_nodes = x_data[0].size(0)
adj = torch.zeros((num_nodes, num_nodes))
adj[edge_index[0], edge_index[1]] = 1.0

In [None]:
# Hyperparameters for the KalmanGAT experiment
in_dim = x_data.shape[-1]       # flattened per-node input width
hidden_dim, out_dim = 64, 12    # out_dim must match the target width y_data.shape[-1]
num_heads, num_layers = 4, 3    # GAT heads per layer, number of GAT layers
epochs, lr = 2, 0.001

In [None]:
# Model, optimizer and loss for the KalmanGAT experiment
model = KalmanGAT(
    in_dim=in_dim,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    out_dim=out_dim,
).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.MSELoss()

In [None]:
# Training loop.
# Each METR-LA snapshot already pairs an input window (x_data[t]) with the
# forecast window that follows it (y_data[t]), so the target for x_data[t] is
# y_data[t]. num_layers is the GAT depth, not a time offset, so it must not
# shift the target index.
# TODO(review): there is no train/validation split in this experiment.

# The inputs do not change between epochs: sanity-check them once
if not torch.isfinite(x_data).all():
    print("NaN or Inf in input x!")
if not torch.isfinite(adj).all():
    print("NaN or Inf in input adj!")
if not torch.isfinite(y_data).all():
    print("NaN or Inf in labels!")

adj_on_device = adj.to(device)  # the graph is static; move it once

for epoch in range(epochs):
    model.train()
    losses = []
    for t in range(len(x_data)):
        x_t = x_data[t].to(device)      # [N, in_dim]
        y_true = y_data[t].to(device)   # [N, out_dim]
        output = model(x_t, adj_on_device)
        loss = criterion(output, y_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch {epoch + 1}, Loss: {np.mean(losses):.4f}")

In [None]:
# Quick shape report for the tensors prepared above
print("\n".join(
    f"{label} {value}"
    for label, value in (
        ("x_all.shape:", x_all.shape),
        ("x_data.shape:", x_data.shape),
        ("x_data[0].shape:", x_data[0].shape),
        ("adj.shape:", adj.shape),
        ("in_dim:", in_dim),
    )
))


x_all.shape: torch.Size([34249, 207, 2, 12])
x_data.shape: torch.Size([34249, 207, 24])
x_data[0].shape: torch.Size([207, 24])
adj.shape: torch.Size([207, 207])
in_dim: 24


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error
from torch_geometric_temporal.dataset import METRLADatasetLoader
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

# -----------------------------
# Configuration
# -----------------------------
# NOTE(review): this cell re-binds in_dim / hidden_dim / epochs that the
# KalmanGAT cells above also use; the values here apply to main() below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
seq_len, pred_len = 12, 12     # input window / forecast horizon (steps)
in_dim = 2                     # features per node
hidden_dim = 64
epochs = 10
batch_size = 64
learning_rate = 1e-3

# -----------------------------
# Normalize function
# -----------------------------
def scale_metrla_data(data):
    """Z-score a torch tensor along dim 0, keeping the reduced dimension.

    NOTE(review): this re-definition shadows the NumPy-based
    ``scale_metrla_data`` defined earlier in the notebook. It differs from
    that one in three ways:
    - it uses torch's default (unbiased, N-1) std;
    - it returns mean/std with keepdim=True;
    - it expects a torch.Tensor (it uses the ``dim=`` keyword).

    Returns:
        (scaled, mean, std); std already includes the +1e-6 guard.
    """
    centre = data.mean(dim=0, keepdim=True)
    spread = data.std(dim=0, keepdim=True) + 1e-6
    return (data - centre) / spread, centre, spread

# -----------------------------
# Dataset Class
# -----------------------------
class METRLADataset(Dataset):
    """Normalized (input, target) samples built from METR-LA snapshots.

    Two snapshot layouts are supported:

    * Pre-windowed snapshots, where ``s.x`` is [N, F, T_in] and ``s.y`` is
      [N, T_out]. The loader in this notebook yields this layout
      (``dataset[0].x.shape == [207, 2, 12]``). Each snapshot is already one
      sample, so item ``idx`` is snapshot ``idx`` rearranged as:
      - x: [seq_len, N, F], the last ``seq_len`` input steps;
      - y: [pred_len, N], the first ``pred_len`` target steps.
      Re-windowing such snapshots produced 5-D inputs and [pred_len, N, T_out]
      targets that no model could match.
    * Per-time-step snapshots, where ``s.x`` is [N, F]. A sliding window of
      ``seq_len`` steps is the input. The targets of the following
      ``pred_len`` snapshots are the output.

    Args:
        snapshot_list: sequence of objects with ``.x`` and ``.y`` (tensor or array).
        seq_len: input steps per sample.
        pred_len: forecast steps per sample.
        mean, std: normalisation statistics. If omitted, they are computed over
            dim 0 of the stacked features (``std`` then gets +1e-6).
    """

    def __init__(self, snapshot_list, seq_len=12, pred_len=12, mean=None, std=None):
        self.seq_len = seq_len
        self.pred_len = pred_len

        # as_tensor avoids the copy-construct warning when s.x / s.y are already tensors
        features = torch.stack([torch.as_tensor(s.x, dtype=torch.float32) for s in snapshot_list])
        self.mean = mean if mean is not None else features.mean(dim=0, keepdim=True)
        self.std = std if std is not None else features.std(dim=0, keepdim=True) + 1e-6
        self.features = (features - self.mean) / self.std  # Normalize
        self.targets = torch.stack([torch.as_tensor(s.y, dtype=torch.float32) for s in snapshot_list])
        # Snapshots that carry their own time axis ([N, F, T_in]) are already windowed
        self.prewindowed = self.features.dim() == 4

    def __len__(self):
        if self.prewindowed:
            return len(self.features)
        # Number of start indices with a full input window and a full target window
        return max(len(self.features) - self.seq_len - self.pred_len + 1, 0)

    def __getitem__(self, idx):
        if self.prewindowed:
            x = self.features[idx][..., -self.seq_len:].permute(2, 0, 1)   # [seq_len, N, F]
            y = self.targets[idx][..., :self.pred_len].transpose(0, 1)     # [pred_len, N]
            return x, y
        x = self.features[idx:idx + self.seq_len]  # [seq_len, N, F]
        y = self.targets[idx + self.seq_len:idx + self.seq_len + self.pred_len]  # [pred_len, N]
        return x, y

# -----------------------------
# Model Definitions (Same)
# -----------------------------
class KalmanNet(nn.Module):
    """GRU baseline over all nodes jointly.

    The node features of each time step are concatenated into one vector. The
    sequence runs through a GRU, and the last hidden state is mapped to
    ``pred_len`` values per node.

    Input  x: [B, T, N, F]. A 5-D [B, T, N, F, L] input is reduced by taking
              index 0 of the trailing axis (kept for backward compatibility).
    Output:   [B, pred_len, N].
    """

    def __init__(self, num_nodes, in_dim, hidden_dim, pred_len):
        super(KalmanNet, self).__init__()
        self.num_nodes = num_nodes
        self.pred_len = pred_len
        self.rnn = nn.GRU(num_nodes * in_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_nodes * pred_len)

    def forward(self, x):
        if x.dim() > 4:
            x = x[..., 0]
        batch, steps, nodes, feats = x.shape
        out, _ = self.rnn(x.reshape(batch, steps, nodes * feats))
        out = self.fc(out[:, -1])  # [B, N * pred_len]
        # Previously viewed as [B, T, N, -1], which gave [B, T, N, 1] instead of the horizon
        return out.view(batch, self.pred_len, nodes)

class DeepKF(KalmanNet):
    """Placeholder baseline: inherits everything from KalmanNet, so it is currently identical."""

class KalmanAttentionModel(nn.Module):
    """Per-node temporal self-attention baseline.

    Each node's series is processed independently with shared weights:
    - a linear embedding;
    - one post-norm encoder layer (nn.MultiheadAttention + feed-forward);
    - a linear head on the last time step that outputs ``pred_len`` values.

    Input  x: [B, T, N, F]. A 5-D [B, T, N, F, L] input is reduced by taking
              index 0 of the trailing axis (kept for backward compatibility).
    Output:   [B, pred_len, N].
    """

    def __init__(self, num_nodes, in_dim, hidden_dim, pred_len):
        super(KalmanAttentionModel, self).__init__()
        # num_nodes is accepted to share KalmanNet's signature; the weights are node-agnostic
        self.encoder = nn.Linear(in_dim, hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads=4, dropout=0.1, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.output = nn.Linear(hidden_dim, pred_len)

    def forward(self, x):
        if x.dim() > 4:
            x = x[..., 0]
        batch, steps, nodes, feats = x.shape
        # Treat every node as an independent sequence: [B * N, T, F], ordered (batch, node)
        seq = x.permute(0, 2, 1, 3).reshape(batch * nodes, steps, feats)
        h = self.encoder(seq)
        attn_out, _ = self.attn(h, h, h)
        h = self.norm1(h + attn_out)
        h = self.norm2(h + self.ff(h))
        pred = self.output(h[:, -1])  # [B * N, pred_len]
        # Un-flatten (batch, node) first, then move the horizon in front of the nodes.
        # The old view(B, T, N, -1) mixed the node and horizon axes.
        return pred.view(batch, nodes, -1).permute(0, 2, 1)  # [B, pred_len, N]

# -----------------------------
# Training and Evaluation
# -----------------------------
def train(model, dataloader, optimizer, criterion, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module mapping a batch ``x`` to predictions comparable with ``y``.
        dataloader: iterable of ``(x, y)`` tensor batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function ``criterion(y_pred, y)``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        Average of ``loss.item()`` over the batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    num_batches = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches

def evaluate(model, dataloader, device=None):
    """Compute MAE, RMSE and MAPE (%) of ``model`` over ``dataloader``.

    Args:
        model: module mapping a batch ``x`` to predictions shaped like ``y``.
        dataloader: iterable of ``(x, y)`` tensor batches.
        device: where inputs are moved. Defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        (mae, rmse, mape), computed over all elements of all batches.

    Raises:
        ValueError: if the stacked predictions and targets differ in shape.
            Previously this surfaced only as an opaque length mismatch from the
            metric library after flattening.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for x, y in dataloader:
            preds.append(model(x.to(device)).cpu().numpy())
            trues.append(y.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    if preds.shape != trues.shape:
        raise ValueError(
            f"prediction shape {preds.shape} does not match target shape {trues.shape}"
        )

    err = (trues - preds).astype(np.float64)
    mae = float(np.mean(np.abs(err)))
    rmse = float(np.sqrt(np.mean(err ** 2)))
    # NOTE(review): if targets are standardized (values near 0), MAPE is dominated
    # by near-zero denominators; consider computing it on un-scaled values.
    mape = np.mean(np.abs((trues - preds) / (trues + 1e-5))) * 100
    return mae, rmse, mape

# -----------------------------
# Main Function
# -----------------------------
def _make_loaders(dataset):
    """Split ``dataset`` chronologically 70/10/20 into (train, val, test) DataLoaders.

    Feature mean/std are fitted on the training split only and reused for
    validation and test, so no statistics leak from later data.
    """
    n = len(dataset)
    train_raw = dataset[:int(n * 0.7)]
    val_raw = dataset[int(n * 0.7):int(n * 0.8)]
    test_raw = dataset[int(n * 0.8):]

    # Snapshot features are already tensors; as_tensor avoids torch.tensor's copy warning
    train_feats = torch.stack([torch.as_tensor(s.x, dtype=torch.float32) for s in train_raw])
    mean = train_feats.mean(dim=0, keepdim=True)
    std = train_feats.std(dim=0, keepdim=True) + 1e-6
    del train_feats  # large; only its statistics are needed

    def make_loader(raw, shuffle):
        ds = METRLADataset(raw, seq_len, pred_len, mean, std)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    return make_loader(train_raw, True), make_loader(val_raw, False), make_loader(test_raw, False)


def _fit_and_evaluate(name, model, train_loader, val_loader, test_loader, criterion):
    """Train ``model`` for ``epochs`` epochs, tracking validation metrics, then test it.

    Returns a dict with the following entries:
    - per-epoch lists ``train_losses``, ``val_maes``, ``val_rmses``, ``val_mapes``;
    - scalars ``test_mae``, ``test_rmse``, ``test_mape``.
    """
    print(f"\nTraining {name}...")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    history = {'train_losses': [], 'val_maes': [], 'val_rmses': [], 'val_mapes': []}

    for epoch in range(epochs):
        loss = train(model, train_loader, optimizer, criterion)
        mae, rmse, mape = evaluate(model, val_loader)
        history['train_losses'].append(loss)
        history['val_maes'].append(mae)
        history['val_rmses'].append(rmse)
        history['val_mapes'].append(mape)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {loss:.4f}, Val MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.2f}%")

    test_mae, test_rmse, test_mape = evaluate(model, test_loader)
    print(f"Test MAE: {test_mae:.4f}, RMSE: {test_rmse:.4f}, MAPE: {test_mape:.2f}%")

    return {**history, 'test_mae': test_mae, 'test_rmse': test_rmse, 'test_mape': test_mape}


# Tracked per-epoch metrics and their y-axis labels
_METRIC_YLABELS = {
    'train_losses': 'MSE loss',
    'val_maes': 'MAE',
    'val_rmses': 'RMSE',
    'val_mapes': 'MAPE (%)',
}


def _plot_metrics(results):
    """Show one line plot per tracked metric (one line per model) and save it as ``<metric>.png``."""
    for metric, ylabel in _METRIC_YLABELS.items():
        fig, ax = plt.subplots()
        for name, res in results.items():
            ax.plot(res[metric], label=name)
        ax.set_title(metric.replace('_', ' ').title())
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)
        fig.savefig(f'{metric}.png')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures


def main():
    """Train and compare the baseline forecasters on METR-LA.

    The steps are:
    - split the snapshots chronologically 70/10/20;
    - fit feature normalisation on the training split;
    - train each model for ``epochs`` epochs, printing validation metrics each epoch;
    - print test metrics and save one PNG per tracked metric.
    """
    # Load METR-LA
    loader = METRLADatasetLoader()
    dataset = list(loader.get_dataset())  # list of graph snapshots
    num_nodes = dataset[0].x.shape[0]

    train_loader, val_loader, test_loader = _make_loaders(dataset)

    models = {
        'KalmanNet': KalmanNet(num_nodes, in_dim, hidden_dim, pred_len),
        'DeepKF': DeepKF(num_nodes, in_dim, hidden_dim, pred_len),
        'KalmanAttention': KalmanAttentionModel(num_nodes, in_dim, hidden_dim, pred_len)
    }

    criterion = nn.MSELoss()
    results = {}
    for name, model in models.items():
        results[name] = _fit_and_evaluate(name, model, train_loader, val_loader, test_loader, criterion)

    _plot_metrics(results)

    # Final results
    print("\nFinal Test Metrics:")
    for name, r in results.items():
        print(f"{name}: MAE={r['test_mae']:.4f}, RMSE={r['test_rmse']:.4f}, MAPE={r['test_mape']:.2f}%")

if __name__ == '__main__':
    main()


  train_feats = torch.stack([torch.tensor(s.x, dtype=torch.float32) for s in train_raw])
  features = torch.stack([torch.tensor(s.x, dtype=torch.float32) for s in snapshot_list])  # [T, N, F]
  self.targets = torch.stack([torch.tensor(s.y, dtype=torch.float32) for s in snapshot_list])  # [T, N]



Training KalmanNet...
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])


  return F.mse_loss(input, target, reduction=self.reduction)


x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
torch.Size([64, 12, 207, 1]) torch.Size([64, 12, 207, 12])
x.shape = tor

  return F.mse_loss(input, target, reduction=self.reduction)


x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.Size([64, 12, 207, 2, 12])
x.shape = torch.Size([64, 12, 207, 2])
x.shape = torch.

ValueError: Found input variables with inconsistent numbers of samples: [101377008, 8448084]