In [1]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

manual_seed()


# Data

In [2]:
from utils import load_data, load_edge

train_folds = load_data(True)
test_fold = load_data(False)[0]

In [3]:
import torch
from torch.utils.data import DataLoader, Dataset

class SimpleStockDataset(Dataset):
    def __init__(self, data, ws=128):
        self.data = data
        self.ws = ws
        self.samples = []
        
        self.n_tickers, self.n_days, self.n_features = self.data.shape
        
        for start in range(self.n_days - self.ws + 1):
            self.samples.append(start)
            
    def __len__(self):
        return len(self.samples)
      
    def __getitem__(self, idx):
        start = self.samples[idx]
        x = torch.tensor(self.data[:, start:start + self.ws], dtype=torch.float32)
        return x
    

# Model

### TemporalGraphRefiner

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Literal


class MyTemporalEncoderLayerBiLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int = 1, dropout: float = 0.0, **kwargs):
        super().__init__()
        
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            n_layers,
            batch_first=True,
            dropout=(dropout if n_layers > 1 else 0.0),
            bidirectional=True)

        self.proj = nn.Linear(2 * hidden_size, hidden_size)

    def forward(
        self, x: torch.Tensor):
        x, _ = self.lstm(x)
        
        x = self.proj(x)
        return x

class GraphChebMix(nn.Module):
    def __init__(self, Pdeg: int, d_latent: int, safety: float = 0.99, dropout=0.0):
        super().__init__()
        self.Pdeg = Pdeg
        self.coeffs = nn.Parameter(torch.randn(Pdeg + 1) * 0.05)
        self.safety = safety
        self.dropout = nn.Dropout(dropout)

    def forward(self, X: torch.Tensor, L: torch.Tensor):
        """
        X: (N, H, d_latent); per-time graph mixing.
        """
        Y = self.apply_cheb_seq(L, X, self.coeffs, safety=self.safety)
        return self.dropout(Y)

    @staticmethod
    def apply_cheb_seq(L, X, coeffs, safety=0.99):
        I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
        S = I_N - L

        s = coeffs.abs().sum().clamp_min(1e-12)
        scale = min(1.0, float(safety)) / float(s)
        c = coeffs * scale

        T0 = X
        if c.numel() == 1:
            return c[0] * T0
        T1 = torch.einsum('ij,jhd->ihd', S, X)
        Y  = c[0] * T0 + c[1] * T1
        for j in range(2, c.numel()):
            T2 = 2 * torch.einsum('ij,jhd->ihd', S, T1) - T0
            Y  = Y + c[j] * T2
            T0, T1 = T1, T2
        return Y
      

class TemporalGraphLayer(nn.Module):
    def __init__(self, d_latent: int = 128,
                 Pdeg: int = 2, dropout: float = 0.1):
        super().__init__()
        self.d_latent = d_latent
        self.temporal_encoder = MyTemporalEncoderLayerBiLSTM(d_latent, d_latent)
        self.graph_encoder = GraphChebMix(Pdeg=Pdeg, d_latent=d_latent, dropout=dropout)
        self.ln_temporal = nn.LayerNorm(d_latent)
        self.ln_graph = nn.LayerNorm(d_latent)
        self.fuse = nn.Sequential(nn.Linear(2*d_latent, d_latent))
        self.norm_out = nn.LayerNorm(d_latent)
        

    def forward(self, x, L):
        """
        x:   (N, H, d_latent)
        node_features: (N, d_node)
        L: from your OUGCN
        """
        nn.TransformerEncoderLayer
        x_res = x
        # temporal encoder
        ht = self.ln_temporal(self.temporal_encoder(x)) # (N,H,d_latent)

        # graph encoder
        hg = self.ln_graph(self.graph_encoder(ht, L)) # (N,H,d_latent)

        # fuse
        h = self.fuse(torch.cat([ht, hg], dim=-1))                    # (N,H,d_latent)

        out = self.norm_out(x_res + h)
        
        return out
      

class TemporalGraphRefiner(nn.Module):
    def __init__(self, n_feats: int, d_node: int, d_latent: int = 128,
                 num_layers: int = 2, 
                 Pdeg: int = 2, dropout: float = 0.1):
        super().__init__()
        self.d_latent = d_latent
        self.num_layers = num_layers
        self.fc_in = nn.Linear(n_feats + d_node, d_latent)
        self.temporal_graph_block = nn.ModuleList([
            TemporalGraphLayer(self.d_latent, Pdeg, dropout)
            for _ in range(self.num_layers)
        ])
        self.head = nn.Sequential(
            nn.Linear(d_latent, d_latent), nn.GELU(),
            nn.Linear(d_latent, n_feats)
        )

    def forward(self, x, node_features, L):
        """
        x:   (N, H, F)
        node_features: (N, d_node)
        L: from your OUGCN
        """
        N, H, F = x.shape
        device = x.device
        
        node_broadcast = node_features.unsqueeze(1).expand(N, H, node_features.size(-1))
        x = self.fc_in(torch.cat([x, node_broadcast], dim=-1))  # (N,H,d_latent)

        for _, block in enumerate(self.temporal_graph_block):
            x = block(x, L)
        
        out = self.head(x)
        return out

### RNNBackbone

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Literal

MyRNNArchitectureType = Literal['lstm']

class MyLSTMBackbone(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int, dropout: float, **kwargs):
        super().__init__()
        self.d_in = input_size
        self.d_latent = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.cells = nn.ModuleList([
            nn.LSTMCell(self.d_in if i == 0 else self.d_latent, self.d_latent)
            for i in range(self.n_layers)
        ])

    def init_states(self, X_0: torch.Tensor, H_0: Optional[torch.Tensor] = None) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        X_0: (N, d_in), H_0: (N, d_latent)
        """
        device = X_0.device
        dtype = X_0.dtype
        N = X_0.shape[0]

        h_list = [torch.zeros(N, self.d_latent, device=device, dtype=dtype) for _ in range(self.n_layers)]
        c_list = [torch.zeros(N, self.d_latent, device=device, dtype=dtype) for _ in range(self.n_layers)]

        if H_0 is not None:
            if H_0.shape != (N, self.d_latent):
                raise ValueError(f"H_0 must be (N, {self.d_latent}), got {tuple(H_0.shape)}")
            h_list[-1] = H_0.to(device=device, dtype=dtype)

        return h_list, c_list

    def forward(
        self,
        input: torch.Tensor,  # (N, d_in)
        hx: Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]] = None
    ) -> Tuple[torch.Tensor, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        if hx is None:
            hx = self.init_states(input)
        h_list, c_list = hx

        new_h, new_c = [], []
        inp = input
        for l, cell in enumerate(self.cells):
            h_t, c_t = cell(inp, (h_list[l], c_list[l]))
            if self.training and self.dropout > 0:
                h_t = F.dropout(h_t, p=self.dropout, training=True)
            new_h.append(h_t)
            new_c.append(c_t)
            inp = h_t

        h_top = new_h[-1]
        return h_top, (new_h, new_c)


class MyRNNCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int = 1, dropout: float = 0.0, **kwargs):
        super().__init__()
        arch_type = kwargs.get('arch_type', 'lstm')
        backbone_kwargs = kwargs.get('backbone_kwargs', {})
        assert arch_type in ['lstm'], f'Arch not supported: {arch_type}'

        self.rnn = MyLSTMBackbone(input_size, hidden_size, n_layers, dropout, **backbone_kwargs)

    @classmethod
    def make(cls, **kwargs) -> 'MyRNNCell':
        """
        Create Customized RNN Cell compatible with OUGCN.
        """
        default_arch_type: MyRNNArchitectureType = 'lstm'
        defaults = {
            'n_layers': 1,
            'dropout': 0.1,
            'arch_type': default_arch_type,
        }
        return MyRNNCell(**(defaults | kwargs))

    def init_states(self, X_0: torch.Tensor, H_0: Optional[torch.Tensor] = None):
        return self.rnn.init_states(X_0, H_0)

    def forward(
        self,
        input: torch.Tensor,
        hx: Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]] = None
    ):
        return self.rnn(input, hx)


### OUGCN

In [6]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------
# Graph utilities
# ------------------------------

def pearson_corr_matrix(series: torch.Tensor) -> torch.Tensor:
    """Compute Pearson correlation across time for N series.
    Args:
        series: (N, T) time-series (already aligned), not all-constant.
    Returns:
        corr: (N, N) in [-1,1]
    """
    N, T = series.shape
    x = series - series.mean(dim=1, keepdim=True)
    std = x.std(dim=1, unbiased=False, keepdim=True) + 1e-8
    x = x / std
    corr = (x @ x.t()) / T
    corr = corr.clamp(-1.0, 1.0)
    return corr


def build_adjacency_from_corr(corr: torch.Tensor,
                              keep_negative: bool = False,
                              knn: Optional[int] = None,
                              threshold: Optional[float] = None) -> torch.Tensor:
    """Turn correlation into nonnegative adjacency.
    - If keep_negative=False: use relu to zero-out negative correlations.
    - Optional: keep only top-k neighbors per node (excluding self) or apply threshold.
    Returns A with zero diagonal.
    """
    N = corr.shape[0]
    if not keep_negative:
        A = corr.relu()
    else:
        # shift to nonnegative range [0, 2] then rescale to [0,1]
        A = (corr + 1.0) / 2.0
        A = A.clamp(0.0, 1.0)
    eye = torch.eye(N, device=A.device, dtype=torch.bool)
    A = A.masked_fill(eye, 0.0)

    if threshold is not None:
        A = torch.where(A >= threshold, A, torch.zeros_like(A))

    if knn is not None and knn > 0 and knn < N:
        # retain top-k per row (excluding diagonal already 0)
        topk_vals, topk_idx = torch.topk(A, k=knn, dim=1)
        mask = torch.zeros_like(A, dtype=torch.bool)
        mask.scatter_(1, topk_idx, True)
        A = torch.where(mask, A, torch.zeros_like(A))
        # symmetrize by max
        A = torch.max(A, A.t())

    # final symmetry
    A = 0.5 * (A + A.t())
    A = A.masked_fill(eye, 0.0)
    return A

def normalized_adjacency(A):
    """Compute Atilde = Dtilde^{-1/2}(A+I)Dtilde^{-1/2}."""
    I = torch.eye(A.size(0), device=A.device, dtype=A.dtype)
    Ahat = A + I
    d = Ahat.sum(1) + 1e-8
    dinv = d.pow(-0.5)
    return dinv[:,None] * Ahat * dinv[None,:]

def normalized_laplacian(A):
    """Compute L = I - D^{-1/2} A D^{-1/2}."""
    d = A.sum(dim=1) + 1e-8
    dinv = d.pow(-0.5)
    S = dinv[:,None] * A * dinv[None,:]
    I = torch.eye(A.size(0), device=A.device, dtype=A.dtype)
    return I - S

def power_iteration_lmax_sym(M: torch.Tensor, n_iter: int = 25) -> float:
    """Estimate largest eigenvalue (spectral radius) of symmetric PSD matrix M."""
    N = M.shape[0]
    v = torch.randn(N, device=M.device, dtype=M.dtype)
    v = v / (v.norm() + 1e-8)
    for _ in range(n_iter):
        v = M @ v
        n = v.norm() + 1e-8
        v = v / n
    # Rayleigh quotient
    lmax = float((v @ (M @ v)).item())
    return max(lmax, 1e-12)
  
def spec_norm_2(A: torch.Tensor, n_iter: int = 25) -> torch.Tensor:
    """||A||_2 via power method on A^T A. Returns a scalar Tensor (has grad wrt A)."""
    AtA = A.T @ A
    v = torch.randn(AtA.shape[0], device=A.device, dtype=A.dtype)
    v = v / (v.norm() + 1e-8)
    for _ in range(n_iter):
        v = AtA @ v
        v = v / (v.norm() + 1e-8)
    # sqrt(v^T (A^T A) v) = ||A v||, but this Rayleigh ~ lambda_max; take sqrt.
    lam_max = v @ (AtA @ v)
    return lam_max.clamp_min(1e-12).sqrt()  # Tensor

def compute_poly(L, coeffs):
    I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
    return apply_poly_to_emb(L, I_N, coeffs)

def compute_cheb(L, coeffs, safety=0.99):
    I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
    K = apply_cheb_to_emb(L, I_N, coeffs, safety)
    
    # lam_max = spec_norm_2(K)
    # scale = min(1.0, safety / lam_max)
    # return K * scale
    
    return K

def apply_poly_to_emb(L, V, coeffs) -> torch.Tensor:
    """Compute K * V = sum_{i=0}^P coeffs_i * L^i * V.
    """
    out = coeffs[0] * V
    if coeffs.numel() == 1:
        return out
    LV = V
    for i in range(1, coeffs.numel()):
        LV = L @ LV          # O(N^2 d)
        out = out + coeffs[i] * LV
    return out

def apply_cheb_to_emb(L, V, coeffs, safety=0.99):
    """
    Tính y = sum_j c_j T_j(S) V
    S = I - L (phổ trong [-1,1]); V: (N,d) hoặc (N,F)
    """
    I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
    S = I_N - L

    c = clamp_l1(coeffs, safety)
    
    if c.numel() == 1:
        return c[0] * V
    T0 = V
    T1 = S @ V              # T1 = T_1(S)V
    y  = c[0] * T0 + c[1] * T1
    for j in range(2, c.numel()):
        T2 = 2 * (S @ T1) - T0
        y = y + c[j] * T2
        T0, T1 = T1, T2
    return y

def clamp_l1(coeffs, safety=0.99):
    s = coeffs.abs().sum().clamp_min(1e-12)
    scale = torch.minimum(torch.tensor(1.0, device=coeffs.device, dtype=coeffs.dtype), safety / s)
    return coeffs * scale

# ------------------------------
# Model definition
# ------------------------------

import math

class MyGCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None, bias=True, dropout=0):
        super(MyGCN, self).__init__()
        self.dropout = dropout
        self.activation = activation
        
        self.weight = nn.Parameter(torch.FloatTensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_feats))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def graph_convolve(self, x, adj):
        support = x @ self.weight
        output = adj @ support
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def forward(self, x, adj):
        x = self.graph_convolve(x, adj)
        if self.activation is not None:
            x = self.activation(x)
        if self.training and self.dropout > 0:
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x

class GatedCell(nn.Module):
    def __init__(self, d_latent):
        super().__init__()
        self.z_gate = nn.Linear(2*d_latent, d_latent)
        nn.init.constant_(self.z_gate.bias, -1.0)
        self.proj_A = nn.Sequential(
            nn.LayerNorm(d_latent),
            nn.Linear(d_latent, d_latent),
            nn.GELU(),
        )
        self.proj_B = nn.Sequential(
            nn.LayerNorm(d_latent),
            nn.Linear(d_latent, d_latent),
            nn.GELU(),
        )

    def forward(self, A, B):
        z = torch.sigmoid(self.z_gate(torch.cat([self.proj_A(A), self.proj_B(B)], dim=-1)))
        
        A = (1 - z) * A + z * B
        return A

class OUGCN_with_Refiner(nn.Module):
    def __init__(self, n_nodes: int, n_feats: int, args, node_emb=None):
        super().__init__()
        self.n_nodes = n_nodes
        self.rank_adj = args.rank_adj
        self.top_k_adj = args.top_k_adj
        self.n_feats = n_feats
        self.args = args
        d_latent = args.d_latent
        d_node = args.d_node
        in_dropout = args.in_dropout
        
        rnn_kwargs:dict = getattr(args, 'rnn_kwargs', {})
        
        rnn_arch_type = rnn_kwargs.setdefault('rnn_arch_type', 'lstm')
        rnn_n_layers = rnn_kwargs.setdefault('rnn_n_layers', 1)
        rnn_hidden_dim = rnn_kwargs.setdefault('rnn_hidden_dim', 128)
        rnn_dropout = rnn_kwargs.setdefault('rnn_dropout', 0.0)
        rnn_backbone_kwargs = rnn_kwargs.setdefault('rnn_backbone_kwargs', {})
        

        if node_emb is None:
            self.static_node_features = nn.Parameter(torch.randn(n_nodes, d_node) * 0.1)
        else:
            self.static_node_features = nn.Parameter(node_emb, requires_grad=False)
            d_node = self.static_node_features.shape[-1]
            
        
        self.fc_in = nn.Linear(n_feats + d_node, d_latent)
        self.gcn_in = MyGCN(d_latent, d_latent, F.relu, dropout=in_dropout)
        self.rnn_mean = MyRNNCell.make(
            input_size=n_feats, hidden_size=rnn_hidden_dim,
            n_layers=rnn_n_layers, dropout=rnn_dropout,
            arch_type=rnn_arch_type,
            backbone_kwargs=rnn_backbone_kwargs)
        self.gcn_mean = MyGCN(rnn_hidden_dim, d_latent, F.relu, dropout=in_dropout)
        self.fc_mean = nn.Linear(rnn_hidden_dim, d_latent)
        
        self.readout = nn.Sequential(
            nn.Linear(d_latent, d_latent), nn.GELU(),
            nn.Linear(d_latent, n_feats), nn.Tanh(),
        )
        self.res_scale = nn.Parameter(torch.ones(n_feats))
        
        self.kappa_H = nn.Parameter(torch.randn(args.Pdeg + 1) * 0.1)
        self.kappa_M = nn.Parameter(torch.randn(args.Pdeg + 1) * 0.1)
        
        self.H_mix_module = GatedCell(d_latent)
        self.M_mix_module = GatedCell(d_latent)
        
        self.fc_corr_features = nn.Sequential(
            nn.Linear(d_node, d_latent), nn.ReLU(),
            nn.Linear(d_latent, self.rank_adj)
        )
        # self.fc_node_features = nn.Sequential(
        #     nn.Linear(d_node, d_latent), nn.ReLU(),
        #     nn.Linear(d_latent, d_latent),
        # )
        
        self.refiner = TemporalGraphRefiner(n_feats, d_node, d_latent,
                                            **args.refiner_kwargs)
        
        self.corr_features = None
        self.node_features = None
        
    def build_graph(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build A (adj), Atilde (norm adj for GCN), and L (normalized Laplacian) from data X.
        Right now we build a data-independent static learnable adj matrix.
        X: (N, T, F)
        Returns: (A, Atilde, L) all in device/dtype of X.
        """
        # n_nodes, n_steps, n_feats = X.shape
        if self.corr_features is None:
            self.corr_features = self.fc_corr_features(self.static_node_features)
        self.corr_features = F.normalize(self.corr_features, dim=-1)
        A = self.corr_features @ self.corr_features.T
        # A.fill_diagonal_(0.0)
        A = build_adjacency_from_corr(A, keep_negative=False, knn=self.top_k_adj)
        Atilde = normalized_adjacency(A)
        L = normalized_laplacian(A)
        return A, Atilde, L
      

    def _compute_graph_ops(self, X: torch.Tensor):
        """Helper: build Atilde, L, K, identity."""
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        _, Atilde, L = self.build_graph(X)
        Atilde = Atilde.to(device)
        L = L.to(device)
        
        return Atilde, L
    
    def _step(self, X_t, H_t, H_filter, M_filter, gcn_filter, rnn_mean_latent=None):
        Z_t_hist, rnn_mean_latent = self.rnn_mean.forward(X_t, rnn_mean_latent)
        
        M_t = self.fc_mean(Z_t_hist) + self.gcn_mean(Z_t_hist, gcn_filter) # mean embedding
        
        H_cand = self.H_mix_module(H_t, H_filter @ H_t)
        M_cand = self.M_mix_module(M_t, M_filter @ M_t)
        H_next = H_cand + M_cand
        r_t = self.readout(H_next) * self.res_scale  # (N, F)
        
        return r_t, H_next, rnn_mean_latent

    
    
    def forecast(self, X: torch.Tensor,
                H_0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        y_pred, *_ = self.forward(X, H_0, horizon)
        return y_pred
    
    def forward(self, X: torch.Tensor,
                H_0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run deterministic inference.
        Args:
            X: (N, T, F) input features at each t
            H_0: (N, d) optional initial latent state (default zeros)
            Atilde, L: if precomputed; otherwise build from X
        Returns:
            y_pred: (N, T-1) predicted next-day log-return for each node
            H_all: (N, T, d) latent states (including H_0 as first)
        """
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        n_nodes, n_steps, n_feats = X.shape
        d_latent = self.args.d_latent
        
        self.corr_features = self.fc_corr_features(self.static_node_features)
        # self.node_features = self.fc_node_features(self.static_node_features)
        

        Atilde, L = self._compute_graph_ops(X)
        I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
        
        
        self.L = L

        X_0 = X[:, 0, :] # (N, F)
        if H_0 is None:
            x_in = self.fc_in(torch.cat([X_0, self.static_node_features], dim=-1)) # (N, F + d_node)
            H_t = self.gcn_in(x_in, Atilde)
        else:
            H_t = H_0.to(device)
        
        # H_t = H_t + self.node_features

        H_all = torch.zeros(n_nodes, n_steps + max(horizon, 0), d_latent, device=device, dtype=X.dtype)
        r_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)
        y_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)

        H_all[:, 0] = H_t
        
        H_filter = compute_cheb(L, self.kappa_H)
        M_filter = I_N - Atilde
        gcn_filter = compute_cheb(L, self.kappa_M)
        
        rnn_mean_latent = self.rnn_mean.init_states(X_0, H_t)
        
        for t in range(n_steps - 1):
            X_t = X[:, t, :] # (N, D)
            r_t, H_next, rnn_mean_latent = self._step(X_t, H_t, H_filter, M_filter, gcn_filter, rnn_mean_latent)

            H_all[:, t + 1] = H_next
            r_all[:, t] = r_t
            y_all[:, t] = X_t + r_t
            H_t = H_next
            
        if horizon > 0:
            X_t = X[:, -1, :]
            for s in range(horizon):
                r_t, H_next, rnn_mean_latent = self._step(X_t, H_t, H_filter, M_filter, gcn_filter, rnn_mean_latent)

                H_all[:, n_steps + s] = H_next
                r_all[:, n_steps - 1 + s] = r_t
                H_t = H_next
                
                y_all[:, n_steps - 1 + s] = X_t + r_t
                X_t = X_t + r_t
        
        r_refine = self.refiner.forward(y_all, self.static_node_features, L)
        y_refine = y_all + r_refine
        
        return y_refine, r_refine, y_all, r_all, H_all




    def forward_loss(
        self,
        X: torch.Tensor,
        H_0: Optional[torch.Tensor] = None,
        horizon: int = 0,
    ):
        """
        Tính loss dự báo one-step + rollout horizon (autoregressive) trên chuỗi đầu vào.

        Args:
            X: (N, T, F) chuỗi gốc (chứa full ground-truth đến T-1)
            H_0: (N, d) latent init (tuỳ chọn)
            horizon: số bước rollout ngoài quan sát cuối cùng
                    (nếu >0, ta cắt input để tránh nhìn thấy tương lai)
            reduction: 'mean' | 'sum' | 'none' cho F.mse_loss

        Returns:
            loss: scalar tensor
            y_pred: (N, T-1, F) dự báo X_{t+1} cho toàn bộ t=0..T-2
        """
        device = next(self.parameters()).device
        X = X.to(device)
        N, T, Fdim = X.shape

        if horizon < 0:
            raise ValueError("horizon must be >= 0")
        if horizon >= T:
            raise ValueError(f"horizon={horizon} must be < sequence length T={T}")

        # Cắt input nếu rollout > 0 để giữ đúng số target (T-1)
        X_in = X[:, : T - horizon, :] if horizon > 0 else X

        # forward() trả (N, (T-horizon)-1 + horizon, F) = (N, T-1, F)
        y_refine, _, y_ar, _, _ = self.forward(X_in, H_0=H_0, horizon=horizon)

        # Ground truth luôn là X_{1:T}
        target = X[:, 1:, :]  # (N, T-1, F)
        
        err_t = (y_ar - target).abs().mean(dim=-1).mean(dim=0)
        
        decay = 0.9
        coef_pre = 1.0
        coef_roll = 1.0
        
        len_pre = T - horizon - 1
        loss_pre = torch.tensor(0.0, device=device, dtype=err_t.dtype)
        loss_roll = torch.tensor(0.0, device=device, dtype=err_t.dtype)

        if len_pre > 0:
            idx = torch.arange(len_pre - 1, -1, -1, device=device, dtype=err_t.dtype)
            w = decay ** idx
            loss_pre = (w * err_t[:len_pre]).sum() / (w.sum() + 1e-12)

        if horizon > 0:
            # MSE đều cho đoạn rollout (chiều dài = horizon)
            loss_roll = err_t[len_pre:].mean()
        
        loss_refine = (y_refine - target).abs().mean()

        loss = loss_refine + coef_pre * loss_pre + coef_roll * loss_roll
        
        return loss

# Eval & Train

In [7]:
from sklearn.metrics import *

def eval_ensemble(args, model, training_data_np, testing_data_np, device='cuda', seq_lens=[64, 96, 128], verbose=False):

    labels = torch.tensor(testing_data_np).float().to(device)
    n_nodes, horizon, n_feats = labels.shape
    y_preds = np.zeros((len(seq_lens), n_nodes, horizon, n_feats-1)) # Bỏ Vol
    model.eval().to(device)
    for i, seq_len in enumerate(seq_lens):
        batch = torch.tensor(training_data_np[:, -seq_len:]).float().to(device)
        with torch.no_grad():
            y_all = model.forecast(batch, horizon=horizon)
        y_preds[i] = y_all[:, -horizon:, :n_feats-1].detach().cpu().numpy() # OHLC + Adj Close
    y_gt = testing_data_np[:, :, :n_feats-1].reshape(n_nodes, -1) # OHLC + Adj Close
    y_pred = y_preds.mean(axis=0).reshape(n_nodes, -1)
    
    if verbose:
        print("Max var:", np.var(np.expm1(y_preds), axis=0).max())
        print("Mean var:", np.var(np.expm1(y_preds), axis=0).mean())
    return {
        'rmse': root_mean_squared_error(y_gt, y_pred), 
        'raw_rmse': root_mean_squared_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'mae': mean_absolute_error(y_gt, y_pred), 
        'raw_mae': mean_absolute_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'r2': r2_score(y_gt.ravel(), y_pred.ravel()),
        'raw_r2': r2_score(np.expm1(y_gt).ravel(), np.expm1(y_pred).ravel()), 
    }

def eval(args, model, training_data_np, testing_data_np, device='cuda'):
    return eval_ensemble(args, model, training_data_np, testing_data_np, device, [args.seq_len])

In [8]:
from tqdm import tqdm
import time

class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.avg == 0:
            self.avg = val
            return
        self.avg = 0.95 * self.avg + 0.05 * val

import copy
def train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np):
    # if args.amp:
    #     from apex import amp
    global best_loss, best_model
    test_losses = []
    end = time.time()

    best_model = copy.deepcopy(model)
    step = 0
    for epoch in range(args.epochs):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        stats01 = AverageMeter()
        stats02 = AverageMeter()
        p_bar = tqdm(train_loader)
        for batch_idx, samples in enumerate(p_bar):
            step += 1
            model.train().to(args.device)
          
            samples = samples[0].float().to(args.device)
            data_time.update(time.time() - end)

            loss = model.forward_loss(samples, horizon=args.horizon)
            
            loss.backward()
            
            max_norm = 5.0
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
            
            optimizer.step()
            scheduler.step()

            losses.update(loss.item())
            stats01.update(0.0)
            # stats01.update(power_iteration_lmax_sym(poly_laplacian(model.L, F.softplus(model.kappa))))

            batch_time.update(time.time() - end)
            end = time.time()
            # mask_probs.update(mask.mean().item())
            p_bar.set_description(
                "Ep: {epoch}/{epochs:3}. LR: {lr:.3e}. "
                "Loss: {loss:.4f}. Stats01: {stats01:.4f}".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                lr=scheduler.get_last_lr()[0],
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                stats01=stats01.avg,
            ))
            p_bar.update()
            
            if (step + 1) % args.eval_steps == 0:
                test_model = model

                test_metrics = eval_ensemble(args, test_model, training_data_np, testing_data_np, args.device, [args.seq_len])
                print(test_metrics)
                test_loss = test_metrics['mae']

                is_best = test_loss < best_loss
                if test_loss < best_loss:
                    best_loss = test_loss
                    best_model = copy.deepcopy(test_model)


                test_losses.append(test_loss)
                print('Best loss: {:.3f}'.format(best_loss))
                print('Mean loss: {:.3f}\n'.format(
                    np.mean(test_losses[-20:])))


In [9]:
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

# Training

In [10]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

class Config:
    # Training
    epochs = 5
    eval_steps = 200
    lr = 1e-4
    wd = 1e-3
    warmup = 0
    
    n_nodes = 428
    n_feats = 6
    
    # Prediction
    seq_len = 96
    horizon = 32
    
    # OUGCN
    d_node: int = 32
    in_dropout: float = 0.0
    d_latent: int = 128
    Pdeg: int = 2                  # polynomial degree of Laplacian in K
    safety: float = 0.99           # stability safety factor
    device: str = "cuda"
    top_k_adj: int = 32
    rank_adj: int = 32
    
    # RNN kwargs
    rnn_kwargs = {
        'rnn_arch_type': 'lstm',
        'rnn_n_layers': 1,
        'rnn_dropout': 0.0,
        'rnn_hidden_dim': 128,
        'rnn_backbone_kwargs': { },
    }
    
    # Refiner kwargs
    refiner_kwargs = {
        'num_layers': 1,
        'Pdeg': 2,
        'dropout': 0.1,
    }
    seed = 42

args = Config()
manual_seed(args.seed)


In [11]:
node_emb_path="../input/static_node_emb.npy"
node_emb_matrix = torch.tensor(np.load(node_emb_path), dtype=torch.float32)


def train_fold(fold_idx):
    global training_data_np, testing_data_np
    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = np.log1p(train_folds[fold_idx][0][:, :, used_features])
    testing_data_np = np.log1p(train_folds[fold_idx][1][:, :, used_features])
    train_loader = DataLoader(SimpleStockDataset(training_data_np, args.seq_len + args.horizon), batch_size=1, shuffle=True)
    
    manual_seed(args.seed + fold_idx)
        
    model = OUGCN_with_Refiner(args.n_nodes, args.n_feats, args, node_emb=node_emb_matrix)

    from torch.optim import AdamW

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    total_steps = args.epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup,
        num_training_steps=total_steps,
    )
    
    print(f'Model size: {get_model_size(model) * 1e3:.2f}K')
    
    # Sanity check
    print(eval_ensemble(args, model, training_data_np, testing_data_np, args.device, [2, 4]))
    
    global best_loss, best_model
    best_loss = 9999
    best_model = copy.deepcopy(model)

    train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np)
    
    return best_model, best_loss

In [None]:
best_model, best_loss = train_fold(3)

Model size: 630.33K
{'rmse': 0.380309881486751, 'raw_rmse': 579.2165059396674, 'mae': 0.3428282398365135, 'raw_mae': 129.83877709089793, 'r2': 0.775560176849138, 'raw_r2': -6.986756850649582}


Ep: 1/  5. LR: 9.999e-05. Loss: 0.3201. Stats01: 0.0000:   6%|▌         | 219/3529 [02:14<36:04,  1.53it/s]

{'rmse': 0.10707334198143661, 'raw_rmse': 43.942521752764236, 'mae': 0.09157104603100927, 'raw_mae': 19.235482651270896, 'r2': 0.984412341705579, 'raw_r2': 0.9903683831264545}
Best loss: 0.092
Mean loss: 0.092



Ep: 1/  5. LR: 9.995e-05. Loss: 0.3168. Stats01: 0.0000:  12%|█▏        | 427/3529 [04:37<48:12,  1.07it/s]

{'rmse': 0.09551845733961495, 'raw_rmse': 42.956224472408515, 'mae': 0.07985732003660735, 'raw_mae': 17.670086399113746, 'r2': 0.9878442000839696, 'raw_r2': 0.9911661504642908}
Best loss: 0.080
Mean loss: 0.086



Ep: 1/  5. LR: 9.989e-05. Loss: 0.3016. Stats01: 0.0000:  18%|█▊        | 634/3529 [07:44<39:12,  1.23it/s]

{'rmse': 0.08099990546139797, 'raw_rmse': 33.2661863122052, 'mae': 0.06252817886883644, 'raw_mae': 13.569174981319957, 'r2': 0.9906760188461735, 'raw_r2': 0.9948468407728508}
Best loss: 0.063
Mean loss: 0.078



Ep: 1/  5. LR: 9.981e-05. Loss: 0.2841. Stats01: 0.0000:  24%|██▍       | 839/3529 [10:51<41:23,  1.08it/s]

{'rmse': 0.08456490343212174, 'raw_rmse': 49.92634166349647, 'mae': 0.0671057458812477, 'raw_mae': 16.530334689194426, 'r2': 0.9910245504742486, 'raw_r2': 0.9873858297467437}
Best loss: 0.063
Mean loss: 0.075



Ep: 1/  5. LR: 9.970e-05. Loss: 0.2577. Stats01: 0.0000:  30%|██▉       | 1044/3529 [13:02<24:16,  1.71it/s]

{'rmse': 0.08524057165902361, 'raw_rmse': 40.57053167094432, 'mae': 0.06767289194614172, 'raw_mae': 15.411100935116227, 'r2': 0.989941747099362, 'raw_r2': 0.9926780250971825}
Best loss: 0.063
Mean loss: 0.074



Ep: 1/  5. LR: 9.956e-05. Loss: 0.2647. Stats01: 0.0000:  35%|███▌      | 1248/3529 [15:03<22:57,  1.66it/s]

{'rmse': 0.09094366000676074, 'raw_rmse': 39.57723923564668, 'mae': 0.07542725985521002, 'raw_mae': 16.270665450531553, 'r2': 0.9891697921924792, 'raw_r2': 0.9928097515168822}
Best loss: 0.063
Mean loss: 0.074



Ep: 1/  5. LR: 9.941e-05. Loss: 0.2565. Stats01: 0.0000:  41%|████      | 1452/3529 [17:09<20:53,  1.66it/s]

{'rmse': 0.08055922669108534, 'raw_rmse': 31.248443355673334, 'mae': 0.06365124810490805, 'raw_mae': 13.61684144595778, 'r2': 0.9920060462139344, 'raw_r2': 0.9958988468824975}
Best loss: 0.063
Mean loss: 0.073



Ep: 1/  5. LR: 9.923e-05. Loss: 0.2548. Stats01: 0.0000:  47%|████▋     | 1656/3529 [19:11<15:57,  1.96it/s]

{'rmse': 0.07570289287901687, 'raw_rmse': 35.35045638815908, 'mae': 0.05669932030978189, 'raw_mae': 13.69355228792736, 'r2': 0.9928167537044831, 'raw_r2': 0.9943104147224644}
Best loss: 0.057
Mean loss: 0.071



Ep: 1/  5. LR: 9.914e-05. Loss: 0.2466. Stats01: 0.0000:  49%|████▉     | 1745/3529 [19:49<12:37,  2.36it/s]

In [None]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, args.seq_len)

In [None]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [32, 64, 96])

In [None]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [64, 96, 128])