In [1]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

manual_seed()


# Data

In [2]:
from utils import load_data, load_edge

train_folds = load_data(True)
test_fold = load_data(False)[0]

In [3]:
import torch
from torch.utils.data import DataLoader, Dataset

class SimpleStockDataset(Dataset):
    def __init__(self, data, ws=128):
        self.data = data
        self.ws = ws
        self.samples = []
        
        self.n_tickers, self.n_days, self.n_features = self.data.shape
        
        for start in range(self.n_days - self.ws + 1):
            self.samples.append(start)
            
    def __len__(self):
        return len(self.samples)
      
    def __getitem__(self, idx):
        start = self.samples[idx]
        x = torch.tensor(self.data[:, start:start + self.ws], dtype=torch.float32)
        return x
    

# Model

### TemporalGraphRefiner

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Literal


class MyTemporalEncoderLayerBiLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int = 1, dropout: float = 0.0, **kwargs):
        super().__init__()
        
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            n_layers,
            batch_first=True,
            dropout=(dropout if n_layers > 1 else 0.0),
            bidirectional=True)

        self.proj = nn.Linear(2 * hidden_size, hidden_size)

    def forward(
        self, x: torch.Tensor):
        x, _ = self.lstm(x)
        
        x = self.proj(x)
        return x

class GraphChebMix(nn.Module):
    def __init__(self, Pdeg: int, d_latent: int, safety: float = 0.99, dropout=0.0):
        super().__init__()
        self.Pdeg = Pdeg
        self.coeffs = nn.Parameter(torch.randn(Pdeg + 1) * 0.05)
        self.safety = safety
        self.dropout = nn.Dropout(dropout)

    def forward(self, X: torch.Tensor, L: torch.Tensor):
        """
        X: (N, H, d_latent); per-time graph mixing.
        """
        Y = self.apply_cheb_seq(L, X, self.coeffs, safety=self.safety)
        return self.dropout(Y)

    @staticmethod
    def apply_cheb_seq(L, X, coeffs, safety=0.99):
        I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
        S = I_N - L

        s = coeffs.abs().sum().clamp_min(1e-12)
        scale = min(1.0, float(safety)) / float(s)
        c = coeffs * scale

        T0 = X
        if c.numel() == 1:
            return c[0] * T0
        T1 = torch.einsum('ij,jhd->ihd', S, X)
        Y  = c[0] * T0 + c[1] * T1
        for j in range(2, c.numel()):
            T2 = 2 * torch.einsum('ij,jhd->ihd', S, T1) - T0
            Y  = Y + c[j] * T2
            T0, T1 = T1, T2
        return Y
      

class TemporalGraphLayer(nn.Module):
    def __init__(self, d_latent: int = 128,
                 Pdeg: int = 2, dropout: float = 0.1):
        super().__init__()
        self.d_latent = d_latent
        self.temporal_encoder = MyTemporalEncoderLayerBiLSTM(d_latent, d_latent)
        self.graph_encoder = GraphChebMix(Pdeg=Pdeg, d_latent=d_latent, dropout=dropout)
        self.ln_temporal = nn.LayerNorm(d_latent)
        self.ln_graph = nn.LayerNorm(d_latent)
        self.fuse = nn.Sequential(nn.Linear(2*d_latent, d_latent))
        self.norm_out = nn.LayerNorm(d_latent)
        

    def forward(self, x, L):
        """
        x:   (N, H, d_latent)
        node_features: (N, d_node)
        L: from your OUGCN
        """
        nn.TransformerEncoderLayer
        x_res = x
        # temporal encoder
        ht = self.ln_temporal(self.temporal_encoder(x)) # (N,H,d_latent)

        # graph encoder
        hg = self.ln_graph(self.graph_encoder(ht, L)) # (N,H,d_latent)

        # fuse
        h = self.fuse(torch.cat([ht, hg], dim=-1))                    # (N,H,d_latent)

        out = self.norm_out(x_res + h)
        
        return out
      

class TemporalGraphRefiner(nn.Module):
    def __init__(self, n_feats: int, d_node: int, d_latent: int = 128,
                 num_layers: int = 2, 
                 Pdeg: int = 2, dropout: float = 0.1):
        super().__init__()
        self.d_latent = d_latent
        self.num_layers = num_layers
        self.fc_in = nn.Linear(n_feats + d_node, d_latent)
        self.temporal_graph_block = nn.ModuleList([
            TemporalGraphLayer(self.d_latent, Pdeg, dropout)
            for _ in range(self.num_layers)
        ])
        self.head = nn.Sequential(
            nn.Linear(d_latent, d_latent), nn.GELU(),
            nn.Linear(d_latent, n_feats)
        )

    def forward(self, x, node_features, L):
        """
        x:   (N, H, F)
        node_features: (N, d_node)
        L: from your OUGCN
        """
        N, H, F = x.shape
        device = x.device
        
        node_broadcast = node_features.unsqueeze(1).expand(N, H, node_features.size(-1))
        x = self.fc_in(torch.cat([x, node_broadcast], dim=-1))  # (N,H,d_latent)

        for _, block in enumerate(self.temporal_graph_block):
            x = block(x, L)
        
        out = self.head(x)
        return out

### RNNBackbone

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Literal

MyRNNArchitectureType = Literal['lstm']

class MyLSTMBackbone(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int, dropout: float, **kwargs):
        super().__init__()
        self.d_in = input_size
        self.d_latent = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.cells = nn.ModuleList([
            nn.LSTMCell(self.d_in if i == 0 else self.d_latent, self.d_latent)
            for i in range(self.n_layers)
        ])

    def init_states(self, X_0: torch.Tensor, H_0: Optional[torch.Tensor] = None) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        X_0: (N, d_in), H_0: (N, d_latent)
        """
        device = X_0.device
        dtype = X_0.dtype
        N = X_0.shape[0]

        h_list = [torch.zeros(N, self.d_latent, device=device, dtype=dtype) for _ in range(self.n_layers)]
        c_list = [torch.zeros(N, self.d_latent, device=device, dtype=dtype) for _ in range(self.n_layers)]

        if H_0 is not None:
            if H_0.shape != (N, self.d_latent):
                raise ValueError(f"H_0 must be (N, {self.d_latent}), got {tuple(H_0.shape)}")
            h_list[-1] = H_0.to(device=device, dtype=dtype)

        return h_list, c_list

    def forward(
        self,
        input: torch.Tensor,  # (N, d_in)
        hx: Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]] = None
    ) -> Tuple[torch.Tensor, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
        if hx is None:
            hx = self.init_states(input)
        h_list, c_list = hx

        new_h, new_c = [], []
        inp = input
        for l, cell in enumerate(self.cells):
            h_t, c_t = cell(inp, (h_list[l], c_list[l]))
            if self.training and self.dropout > 0:
                h_t = F.dropout(h_t, p=self.dropout, training=True)
            new_h.append(h_t)
            new_c.append(c_t)
            inp = h_t

        h_top = new_h[-1]
        return h_top, (new_h, new_c)


class MyRNNCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int = 1, dropout: float = 0.0, **kwargs):
        super().__init__()
        arch_type = kwargs.get('arch_type', 'lstm')
        backbone_kwargs = kwargs.get('backbone_kwargs', {})
        assert arch_type in ['lstm'], f'Arch not supported: {arch_type}'

        self.rnn = MyLSTMBackbone(input_size, hidden_size, n_layers, dropout, **backbone_kwargs)

    @classmethod
    def make(cls, **kwargs) -> 'MyRNNCell':
        """
        Create Customized RNN Cell compatible with OUGCN.
        """
        default_arch_type: MyRNNArchitectureType = 'lstm'
        defaults = {
            'n_layers': 1,
            'dropout': 0.1,
            'arch_type': default_arch_type,
        }
        return MyRNNCell(**(defaults | kwargs))

    def init_states(self, X_0: torch.Tensor, H_0: Optional[torch.Tensor] = None):
        return self.rnn.init_states(X_0, H_0)

    def forward(
        self,
        input: torch.Tensor,
        hx: Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]] = None
    ):
        return self.rnn(input, hx)


### OUGCN

In [6]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------
# Graph utilities
# ------------------------------

def pearson_corr_matrix(series: torch.Tensor) -> torch.Tensor:
    """Compute Pearson correlation across time for N series.
    Args:
        series: (N, T) time-series (already aligned), not all-constant.
    Returns:
        corr: (N, N) in [-1,1]
    """
    N, T = series.shape
    x = series - series.mean(dim=1, keepdim=True)
    std = x.std(dim=1, unbiased=False, keepdim=True) + 1e-8
    x = x / std
    corr = (x @ x.t()) / T
    corr = corr.clamp(-1.0, 1.0)
    return corr


def build_adjacency_from_corr(corr: torch.Tensor,
                              keep_negative: bool = False,
                              knn: Optional[int] = None,
                              threshold: Optional[float] = None) -> torch.Tensor:
    """Turn correlation into nonnegative adjacency.
    - If keep_negative=False: use relu to zero-out negative correlations.
    - Optional: keep only top-k neighbors per node (excluding self) or apply threshold.
    Returns A with zero diagonal.
    """
    N = corr.shape[0]
    if not keep_negative:
        A = corr.relu()
    else:
        # shift to nonnegative range [0, 2] then rescale to [0,1]
        A = (corr + 1.0) / 2.0
        A = A.clamp(0.0, 1.0)
    eye = torch.eye(N, device=A.device, dtype=torch.bool)
    A = A.masked_fill(eye, 0.0)

    if threshold is not None:
        A = torch.where(A >= threshold, A, torch.zeros_like(A))

    if knn is not None and knn > 0 and knn < N:
        # retain top-k per row (excluding diagonal already 0)
        topk_vals, topk_idx = torch.topk(A, k=knn, dim=1)
        mask = torch.zeros_like(A, dtype=torch.bool)
        mask.scatter_(1, topk_idx, True)
        A = torch.where(mask, A, torch.zeros_like(A))
        # symmetrize by max
        A = torch.max(A, A.t())

    # final symmetry
    A = 0.5 * (A + A.t())
    A = A.masked_fill(eye, 0.0)
    return A

def normalized_adjacency(A):
    """Compute Atilde = Dtilde^{-1/2}(A+I)Dtilde^{-1/2}."""
    I = torch.eye(A.size(0), device=A.device, dtype=A.dtype)
    Ahat = A + I
    d = Ahat.sum(1) + 1e-8
    dinv = d.pow(-0.5)
    return dinv[:,None] * Ahat * dinv[None,:]

def normalized_laplacian(A):
    """Compute L = I - D^{-1/2} A D^{-1/2}."""
    d = A.sum(dim=1) + 1e-8
    dinv = d.pow(-0.5)
    S = dinv[:,None] * A * dinv[None,:]
    I = torch.eye(A.size(0), device=A.device, dtype=A.dtype)
    return I - S

def power_iteration_lmax_sym(M: torch.Tensor, n_iter: int = 25) -> float:
    """Estimate largest eigenvalue (spectral radius) of symmetric PSD matrix M."""
    N = M.shape[0]
    v = torch.randn(N, device=M.device, dtype=M.dtype)
    v = v / (v.norm() + 1e-8)
    for _ in range(n_iter):
        v = M @ v
        n = v.norm() + 1e-8
        v = v / n
    # Rayleigh quotient
    lmax = float((v @ (M @ v)).item())
    return max(lmax, 1e-12)
  
def spec_norm_2(A: torch.Tensor, n_iter: int = 25) -> torch.Tensor:
    """||A||_2 via power method on A^T A. Returns a scalar Tensor (has grad wrt A)."""
    AtA = A.T @ A
    v = torch.randn(AtA.shape[0], device=A.device, dtype=A.dtype)
    v = v / (v.norm() + 1e-8)
    for _ in range(n_iter):
        v = AtA @ v
        v = v / (v.norm() + 1e-8)
    # sqrt(v^T (A^T A) v) = ||A v||, but this Rayleigh ~ lambda_max; take sqrt.
    lam_max = v @ (AtA @ v)
    return lam_max.clamp_min(1e-12).sqrt()  # Tensor

def compute_poly(L, coeffs):
    I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
    return apply_poly_to_emb(L, I_N, coeffs)

def compute_cheb(L, coeffs, safety=0.99):
    I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
    K = apply_cheb_to_emb(L, I_N, coeffs, safety)
    
    # lam_max = spec_norm_2(K)
    # scale = min(1.0, safety / lam_max)
    # return K * scale
    
    return K

def apply_poly_to_emb(L, V, coeffs) -> torch.Tensor:
    """Compute K * V = sum_{i=0}^P coeffs_i * L^i * V.
    """
    out = coeffs[0] * V
    if coeffs.numel() == 1:
        return out
    LV = V
    for i in range(1, coeffs.numel()):
        LV = L @ LV          # O(N^2 d)
        out = out + coeffs[i] * LV
    return out

def apply_cheb_to_emb(L, V, coeffs, safety=0.99):
    """
    Tính y = sum_j c_j T_j(S) V
    S = I - L (phổ trong [-1,1]); V: (N,d) hoặc (N,F)
    """
    I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
    S = I_N - L

    c = clamp_l1(coeffs, safety)
    
    if c.numel() == 1:
        return c[0] * V
    T0 = V
    T1 = S @ V              # T1 = T_1(S)V
    y  = c[0] * T0 + c[1] * T1
    for j in range(2, c.numel()):
        T2 = 2 * (S @ T1) - T0
        y = y + c[j] * T2
        T0, T1 = T1, T2
    return y

def clamp_l1(coeffs, safety=0.99):
    s = coeffs.abs().sum().clamp_min(1e-12)
    scale = torch.minimum(torch.tensor(1.0, device=coeffs.device, dtype=coeffs.dtype), safety / s)
    return coeffs * scale

# ------------------------------
# Model definition
# ------------------------------

import math

class MyGCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None, bias=True, dropout=0):
        super(MyGCN, self).__init__()
        self.dropout = dropout
        self.activation = activation
        
        self.weight = nn.Parameter(torch.FloatTensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_feats))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def graph_convolve(self, x, adj):
        support = x @ self.weight
        output = adj @ support
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def forward(self, x, adj):
        x = self.graph_convolve(x, adj)
        if self.activation is not None:
            x = self.activation(x)
        if self.training and self.dropout > 0:
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x

class GatedCell(nn.Module):
    def __init__(self, d_latent):
        super().__init__()
        self.z_gate = nn.Linear(2*d_latent, d_latent)
        nn.init.constant_(self.z_gate.bias, -1.0)
        self.proj_A = nn.Sequential(
            nn.LayerNorm(d_latent),
            nn.Linear(d_latent, d_latent),
            nn.GELU(),
        )
        self.proj_B = nn.Sequential(
            nn.LayerNorm(d_latent),
            nn.Linear(d_latent, d_latent),
            nn.GELU(),
        )

    def forward(self, A, B):
        z = torch.sigmoid(self.z_gate(torch.cat([self.proj_A(A), self.proj_B(B)], dim=-1)))
        
        A = (1 - z) * A + z * B
        return A

class OUGCN_with_Refiner(nn.Module):
    def __init__(self, n_nodes: int, n_feats: int, args, node_emb=None):
        super().__init__()
        self.n_nodes = n_nodes
        self.rank_adj = args.rank_adj
        self.top_k_adj = args.top_k_adj
        self.n_feats = n_feats
        self.args = args
        d_latent = args.d_latent
        d_node = args.d_node
        in_dropout = args.in_dropout
        
        rnn_kwargs:dict = getattr(args, 'rnn_kwargs', {})
        
        rnn_arch_type = rnn_kwargs.setdefault('rnn_arch_type', 'lstm')
        rnn_n_layers = rnn_kwargs.setdefault('rnn_n_layers', 1)
        rnn_hidden_dim = rnn_kwargs.setdefault('rnn_hidden_dim', 128)
        rnn_dropout = rnn_kwargs.setdefault('rnn_dropout', 0.0)
        rnn_backbone_kwargs = rnn_kwargs.setdefault('rnn_backbone_kwargs', {})
        

        if node_emb is None:
            self.static_node_features = nn.Parameter(torch.randn(n_nodes, d_node) * 0.1)
        else:
            self.static_node_features = nn.Parameter(node_emb, requires_grad=False)
            d_node = self.static_node_features.shape[-1]
            
        
        self.fc_in = nn.Linear(n_feats + d_node, d_latent)
        self.gcn_in = MyGCN(d_latent, d_latent, F.relu, dropout=in_dropout)
        self.rnn_mean = MyRNNCell.make(
            input_size=n_feats, hidden_size=rnn_hidden_dim,
            n_layers=rnn_n_layers, dropout=rnn_dropout,
            arch_type=rnn_arch_type,
            backbone_kwargs=rnn_backbone_kwargs)
        self.gcn_mean = MyGCN(rnn_hidden_dim, d_latent, F.relu, dropout=in_dropout)
        self.fc_mean = nn.Linear(rnn_hidden_dim, d_latent)
        
        self.readout = nn.Sequential(
            nn.Linear(d_latent, d_latent), nn.GELU(),
            nn.Linear(d_latent, n_feats), nn.Tanh(),
        )
        self.res_scale = nn.Parameter(torch.ones(n_feats))
        
        self.kappa_H = nn.Parameter(torch.randn(args.Pdeg + 1) * 0.1)
        self.kappa_M = nn.Parameter(torch.randn(args.Pdeg + 1) * 0.1)
        
        self.H_mix_module = GatedCell(d_latent)
        self.M_mix_module = GatedCell(d_latent)
        
        self.fc_corr_features = nn.Sequential(
            nn.Linear(d_node, d_latent), nn.ReLU(),
            nn.Linear(d_latent, self.rank_adj)
        )
        # self.fc_node_features = nn.Sequential(
        #     nn.Linear(d_node, d_latent), nn.ReLU(),
        #     nn.Linear(d_latent, d_latent),
        # )
        
        self.refiner = TemporalGraphRefiner(n_feats, d_node, d_latent,
                                            **args.refiner_kwargs)
        
        self.corr_features = None
        self.node_features = None
        
    def build_graph(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build A (adj), Atilde (norm adj for GCN), and L (normalized Laplacian) from data X.
        Right now we build a data-independent static learnable adj matrix.
        X: (N, T, F)
        Returns: (A, Atilde, L) all in device/dtype of X.
        """
        # n_nodes, n_steps, n_feats = X.shape
        if self.corr_features is None:
            self.corr_features = self.fc_corr_features(self.static_node_features)
        self.corr_features = F.normalize(self.corr_features, dim=-1)
        A = self.corr_features @ self.corr_features.T
        # A.fill_diagonal_(0.0)
        A = build_adjacency_from_corr(A, keep_negative=False, knn=self.top_k_adj)
        Atilde = normalized_adjacency(A)
        L = normalized_laplacian(A)
        return A, Atilde, L
      

    def _compute_graph_ops(self, X: torch.Tensor):
        """Helper: build Atilde, L, K, identity."""
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        _, Atilde, L = self.build_graph(X)
        Atilde = Atilde.to(device)
        L = L.to(device)
        
        return Atilde, L
    
    def _step(self, X_t, H_t, H_filter, M_filter, gcn_filter, rnn_mean_latent=None):
        Z_t_hist, rnn_mean_latent = self.rnn_mean.forward(X_t, rnn_mean_latent)
        
        M_t = self.fc_mean(Z_t_hist) + self.gcn_mean(Z_t_hist, gcn_filter) # mean embedding
        
        H_cand = self.H_mix_module(H_t, H_filter @ H_t)
        M_cand = self.M_mix_module(M_t, M_filter @ M_t)
        H_next = H_cand + M_cand
        r_t = self.readout(H_next) * self.res_scale  # (N, F)
        
        return r_t, H_next, rnn_mean_latent

    
    
    def forecast(self, X: torch.Tensor,
                H_0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        y_pred, *_ = self.forward(X, H_0, horizon)
        return y_pred
    
    def forward(self, X: torch.Tensor,
                H_0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run deterministic inference.
        Args:
            X: (N, T, F) input features at each t
            H_0: (N, d) optional initial latent state (default zeros)
            Atilde, L: if precomputed; otherwise build from X
        Returns:
            y_pred: (N, T-1) predicted next-day log-return for each node
            H_all: (N, T, d) latent states (including H_0 as first)
        """
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        n_nodes, n_steps, n_feats = X.shape
        d_latent = self.args.d_latent
        
        self.corr_features = self.fc_corr_features(self.static_node_features)
        # self.node_features = self.fc_node_features(self.static_node_features)
        

        Atilde, L = self._compute_graph_ops(X)
        I_N = torch.eye(L.shape[0], device=L.device, dtype=L.dtype)
        
        
        self.L = L

        X_0 = X[:, 0, :] # (N, F)
        if H_0 is None:
            x_in = self.fc_in(torch.cat([X_0, self.static_node_features], dim=-1)) # (N, F + d_node)
            H_t = self.gcn_in(x_in, Atilde)
        else:
            H_t = H_0.to(device)
        
        # H_t = H_t + self.node_features

        H_all = torch.zeros(n_nodes, n_steps + max(horizon, 0), d_latent, device=device, dtype=X.dtype)
        r_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)
        y_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)

        H_all[:, 0] = H_t
        
        H_filter = compute_cheb(L, self.kappa_H)
        M_filter = I_N - Atilde
        gcn_filter = compute_cheb(L, self.kappa_M)
        
        rnn_mean_latent = self.rnn_mean.init_states(X_0, H_t)
        
        for t in range(n_steps - 1):
            X_t = X[:, t, :] # (N, D)
            r_t, H_next, rnn_mean_latent = self._step(X_t, H_t, H_filter, M_filter, gcn_filter, rnn_mean_latent)

            H_all[:, t + 1] = H_next
            r_all[:, t] = r_t
            y_all[:, t] = X_t + r_t
            H_t = H_next
            
        if horizon > 0:
            X_t = X[:, -1, :]
            for s in range(horizon):
                r_t, H_next, rnn_mean_latent = self._step(X_t, H_t, H_filter, M_filter, gcn_filter, rnn_mean_latent)

                H_all[:, n_steps + s] = H_next
                r_all[:, n_steps - 1 + s] = r_t
                H_t = H_next
                
                y_all[:, n_steps - 1 + s] = X_t + r_t
                X_t = X_t + r_t
        
        r_refine = self.refiner.forward(y_all, self.static_node_features, L)
        y_refine = y_all + r_refine
        
        return y_refine, r_refine, y_all, r_all, H_all




    def forward_loss(
        self,
        X: torch.Tensor,
        H_0: Optional[torch.Tensor] = None,
        horizon: int = 0,
    ):
        """
        Tính loss dự báo one-step + rollout horizon (autoregressive) trên chuỗi đầu vào.

        Args:
            X: (N, T, F) chuỗi gốc (chứa full ground-truth đến T-1)
            H_0: (N, d) latent init (tuỳ chọn)
            horizon: số bước rollout ngoài quan sát cuối cùng
                    (nếu >0, ta cắt input để tránh nhìn thấy tương lai)
            reduction: 'mean' | 'sum' | 'none' cho F.mse_loss

        Returns:
            loss: scalar tensor
            y_pred: (N, T-1, F) dự báo X_{t+1} cho toàn bộ t=0..T-2
        """
        device = next(self.parameters()).device
        X = X.to(device)
        N, T, Fdim = X.shape

        if horizon < 0:
            raise ValueError("horizon must be >= 0")
        if horizon >= T:
            raise ValueError(f"horizon={horizon} must be < sequence length T={T}")

        # Cắt input nếu rollout > 0 để giữ đúng số target (T-1)
        X_in = X[:, : T - horizon, :] if horizon > 0 else X

        # forward() trả (N, (T-horizon)-1 + horizon, F) = (N, T-1, F)
        y_refine, _, y_ar, _, _ = self.forward(X_in, H_0=H_0, horizon=horizon)

        # Ground truth luôn là X_{1:T}
        target = X[:, 1:, :]  # (N, T-1, F)
        
        err_t = (y_ar - target).abs().mean(dim=-1).mean(dim=0)
        
        decay = 0.9
        coef_pre = 1.0
        coef_roll = 1.0
        
        len_pre = T - horizon - 1
        loss_pre = torch.tensor(0.0, device=device, dtype=err_t.dtype)
        loss_roll = torch.tensor(0.0, device=device, dtype=err_t.dtype)

        if len_pre > 0:
            idx = torch.arange(len_pre - 1, -1, -1, device=device, dtype=err_t.dtype)
            w = decay ** idx
            loss_pre = (w * err_t[:len_pre]).sum() / (w.sum() + 1e-12)

        if horizon > 0:
            # MSE đều cho đoạn rollout (chiều dài = horizon)
            loss_roll = err_t[len_pre:].mean()
        
        loss_refine = (y_refine - target).abs().mean()

        loss = loss_refine + coef_pre * loss_pre + coef_roll * loss_roll
        
        return loss

# Eval & Train

In [14]:
from sklearn.metrics import *

def eval_ensemble(args, model, training_data_np, testing_data_np, device='cuda', seq_lens=[64, 96, 128], verbose=False):

    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = np.log1p(training_data_np[:, :, used_features])
    testing_data_np = np.log1p(testing_data_np[:, :, used_features])
    
    labels = torch.tensor(testing_data_np).float().to(device)
    n_nodes, horizon, n_feats = labels.shape
    y_preds = np.zeros((len(seq_lens), n_nodes, horizon, n_feats-1)) # Bỏ Vol
    model.eval().to(device)
    for i, seq_len in enumerate(seq_lens):
        batch = torch.tensor(training_data_np[:, -seq_len:]).float().to(device)
        with torch.no_grad():
            y_all = model.forecast(batch, horizon=horizon)
        y_preds[i] = y_all[:, -horizon:, :n_feats-1].detach().cpu().numpy() # OHLC + Adj Close
    y_gt = testing_data_np[:, :, :n_feats-1].reshape(n_nodes, -1) # OHLC + Adj Close
    y_pred = y_preds.mean(axis=0).reshape(n_nodes, -1)
    
    if verbose:
        print("Max var:", np.var(np.expm1(y_preds), axis=0).max())
        print("Mean var:", np.var(np.expm1(y_preds), axis=0).mean())
    return {
        'rmse': root_mean_squared_error(y_gt, y_pred), 
        'raw_rmse': root_mean_squared_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'mae': mean_absolute_error(y_gt, y_pred), 
        'raw_mae': mean_absolute_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'r2': r2_score(y_gt.ravel(), y_pred.ravel()),
        'raw_r2': r2_score(np.expm1(y_gt).ravel(), np.expm1(y_pred).ravel()), 
    }

def eval(args, model, training_data_np, testing_data_np, device='cuda'):
    return eval_ensemble(args, model, training_data_np, testing_data_np, device, [args.seq_len])

In [8]:
from tqdm import tqdm
import time

class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.avg == 0:
            self.avg = val
            return
        self.avg = 0.95 * self.avg + 0.05 * val

import copy
def train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np):
    # if args.amp:
    #     from apex import amp
    global best_loss, best_model
    test_losses = []
    end = time.time()

    best_model = copy.deepcopy(model)
    step = 0
    for epoch in range(args.epochs):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        stats01 = AverageMeter()
        stats02 = AverageMeter()
        p_bar = tqdm(train_loader)
        for batch_idx, samples in enumerate(p_bar):
            step += 1
            model.train().to(args.device)
          
            samples = samples[0].float().to(args.device)
            data_time.update(time.time() - end)

            loss = model.forward_loss(samples, horizon=args.horizon)
            
            loss.backward()
            
            max_norm = 5.0
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
            
            optimizer.step()
            scheduler.step()

            losses.update(loss.item())
            stats01.update(0.0)
            # stats01.update(power_iteration_lmax_sym(poly_laplacian(model.L, F.softplus(model.kappa))))

            batch_time.update(time.time() - end)
            end = time.time()
            # mask_probs.update(mask.mean().item())
            p_bar.set_description(
                "Ep: {epoch}/{epochs:3}. LR: {lr:.3e}. "
                "Loss: {loss:.4f}. Stats01: {stats01:.4f}".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                lr=scheduler.get_last_lr()[0],
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                stats01=stats01.avg,
            ))
            p_bar.update()
            
            if (step + 1) % args.eval_steps == 0:
                test_model = model

                test_metrics = eval_ensemble(args, test_model, training_data_np, testing_data_np, args.device, [args.seq_len])
                print(test_metrics)
                test_loss = test_metrics['mae']

                is_best = test_loss < best_loss
                if test_loss < best_loss:
                    best_loss = test_loss
                    best_model = copy.deepcopy(test_model)


                test_losses.append(test_loss)
                print('Best loss: {:.3f}'.format(best_loss))
                print('Mean loss: {:.3f}\n'.format(
                    np.mean(test_losses[-20:])))


In [9]:
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

# Training

In [10]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

class Config:
    # Training
    epochs = 5
    eval_steps = 200
    lr = 1e-4
    wd = 1e-3
    warmup = 0
    
    n_nodes = 428
    n_feats = 6
    
    # Prediction
    seq_len = 96
    horizon = 32
    
    # OUGCN
    d_node: int = 32
    in_dropout: float = 0.0
    d_latent: int = 128
    Pdeg: int = 2                  # polynomial degree of Laplacian in K
    safety: float = 0.99           # stability safety factor
    device: str = "cuda"
    top_k_adj: int = 32
    rank_adj: int = 32
    
    # RNN kwargs
    rnn_kwargs = {
        'rnn_arch_type': 'lstm',
        'rnn_n_layers': 1,
        'rnn_dropout': 0.0,
        'rnn_hidden_dim': 128,
        'rnn_backbone_kwargs': { },
    }
    
    # Refiner kwargs
    refiner_kwargs = {
        'num_layers': 1,
        'Pdeg': 2,
        'dropout': 0.1,
    }
    seed = 42

args = Config()
manual_seed(args.seed)


In [11]:
node_emb_path="../input/static_node_emb.npy"
node_emb_matrix = torch.tensor(np.load(node_emb_path), dtype=torch.float32)


def train_fold(fold_idx):
    global training_data_np, testing_data_np
    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = np.log1p(train_folds[fold_idx][0][:, :, used_features])
    testing_data_np = np.log1p(train_folds[fold_idx][1][:, :, used_features])
    train_loader = DataLoader(SimpleStockDataset(training_data_np, args.seq_len + args.horizon), batch_size=1, shuffle=True)
    
    manual_seed(args.seed + fold_idx)
        
    model = OUGCN_with_Refiner(args.n_nodes, args.n_feats, args, node_emb=node_emb_matrix)

    from torch.optim import AdamW

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    total_steps = args.epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup,
        num_training_steps=total_steps,
    )
    
    print(f'Model size: {get_model_size(model) * 1e3:.2f}K')
    
    # Sanity check
    print(eval_ensemble(args, model, training_data_np, testing_data_np, args.device, [2, 4]))
    
    global best_loss, best_model
    best_loss = 9999
    best_model = copy.deepcopy(model)

    train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np)
    
    return best_model, best_loss

In [12]:
best_model, best_loss = train_fold(3)

Model size: 630.33K
{'rmse': 0.380309881486751, 'raw_rmse': 579.2165059396674, 'mae': 0.3428282398365135, 'raw_mae': 129.83877709089793, 'r2': 0.775560176849138, 'raw_r2': -6.986756850649582}


Ep: 1/  5. LR: 9.999e-05. Loss: 0.3201. Stats01: 0.0000:   6%|▌         | 219/3529 [02:14<36:04,  1.53it/s]

{'rmse': 0.10707334198143661, 'raw_rmse': 43.942521752764236, 'mae': 0.09157104603100927, 'raw_mae': 19.235482651270896, 'r2': 0.984412341705579, 'raw_r2': 0.9903683831264545}
Best loss: 0.092
Mean loss: 0.092



Ep: 1/  5. LR: 9.995e-05. Loss: 0.3168. Stats01: 0.0000:  12%|█▏        | 427/3529 [04:37<48:12,  1.07it/s]

{'rmse': 0.09551845733961495, 'raw_rmse': 42.956224472408515, 'mae': 0.07985732003660735, 'raw_mae': 17.670086399113746, 'r2': 0.9878442000839696, 'raw_r2': 0.9911661504642908}
Best loss: 0.080
Mean loss: 0.086



Ep: 1/  5. LR: 9.989e-05. Loss: 0.3016. Stats01: 0.0000:  18%|█▊        | 634/3529 [07:44<39:12,  1.23it/s]

{'rmse': 0.08099990546139797, 'raw_rmse': 33.2661863122052, 'mae': 0.06252817886883644, 'raw_mae': 13.569174981319957, 'r2': 0.9906760188461735, 'raw_r2': 0.9948468407728508}
Best loss: 0.063
Mean loss: 0.078



Ep: 1/  5. LR: 9.981e-05. Loss: 0.2841. Stats01: 0.0000:  24%|██▍       | 839/3529 [10:51<41:23,  1.08it/s]

{'rmse': 0.08456490343212174, 'raw_rmse': 49.92634166349647, 'mae': 0.0671057458812477, 'raw_mae': 16.530334689194426, 'r2': 0.9910245504742486, 'raw_r2': 0.9873858297467437}
Best loss: 0.063
Mean loss: 0.075



Ep: 1/  5. LR: 9.970e-05. Loss: 0.2577. Stats01: 0.0000:  30%|██▉       | 1044/3529 [13:02<24:16,  1.71it/s]

{'rmse': 0.08524057165902361, 'raw_rmse': 40.57053167094432, 'mae': 0.06767289194614172, 'raw_mae': 15.411100935116227, 'r2': 0.989941747099362, 'raw_r2': 0.9926780250971825}
Best loss: 0.063
Mean loss: 0.074



Ep: 1/  5. LR: 9.956e-05. Loss: 0.2647. Stats01: 0.0000:  35%|███▌      | 1248/3529 [15:03<22:57,  1.66it/s]

{'rmse': 0.09094366000676074, 'raw_rmse': 39.57723923564668, 'mae': 0.07542725985521002, 'raw_mae': 16.270665450531553, 'r2': 0.9891697921924792, 'raw_r2': 0.9928097515168822}
Best loss: 0.063
Mean loss: 0.074



Ep: 1/  5. LR: 9.941e-05. Loss: 0.2565. Stats01: 0.0000:  41%|████      | 1452/3529 [17:09<20:53,  1.66it/s]

{'rmse': 0.08055922669108534, 'raw_rmse': 31.248443355673334, 'mae': 0.06365124810490805, 'raw_mae': 13.61684144595778, 'r2': 0.9920060462139344, 'raw_r2': 0.9958988468824975}
Best loss: 0.063
Mean loss: 0.073



Ep: 1/  5. LR: 9.923e-05. Loss: 0.2548. Stats01: 0.0000:  47%|████▋     | 1656/3529 [19:11<15:57,  1.96it/s]

{'rmse': 0.07570289287901687, 'raw_rmse': 35.35045638815908, 'mae': 0.05669932030978189, 'raw_mae': 13.69355228792736, 'r2': 0.9928167537044831, 'raw_r2': 0.9943104147224644}
Best loss: 0.057
Mean loss: 0.071



Ep: 1/  5. LR: 9.902e-05. Loss: 0.2525. Stats01: 0.0000:  53%|█████▎    | 1859/3529 [20:35<10:37,  2.62it/s]

{'rmse': 0.08005921902901356, 'raw_rmse': 38.05587132453037, 'mae': 0.06181343452692814, 'raw_mae': 15.1410943189822, 'r2': 0.9917530681535538, 'raw_r2': 0.9932329875317361}
Best loss: 0.057
Mean loss: 0.070



Ep: 1/  5. LR: 9.879e-05. Loss: 0.2528. Stats01: 0.0000:  58%|█████▊    | 2062/3529 [21:51<09:05,  2.69it/s]

{'rmse': 0.08168296434498619, 'raw_rmse': 34.43049290726285, 'mae': 0.06475566434186744, 'raw_mae': 13.763093221730387, 'r2': 0.9914631747490419, 'raw_r2': 0.994659063644358}
Best loss: 0.057
Mean loss: 0.069



Ep: 1/  5. LR: 9.854e-05. Loss: 0.2430. Stats01: 0.0000:  64%|██████▍   | 2265/3529 [23:05<08:03,  2.61it/s]

{'rmse': 0.06819730201955225, 'raw_rmse': 30.387762329695946, 'mae': 0.048148852271544765, 'raw_mae': 11.55273355966306, 'r2': 0.994307755490576, 'raw_r2': 0.9960795840563312}
Best loss: 0.048
Mean loss: 0.067



Ep: 1/  5. LR: 9.826e-05. Loss: 0.2346. Stats01: 0.0000:  70%|██████▉   | 2468/3529 [24:20<06:10,  2.86it/s]

{'rmse': 0.0683189156552521, 'raw_rmse': 31.00928358635717, 'mae': 0.04874954553041807, 'raw_mae': 11.571269449595885, 'r2': 0.9943701281333612, 'raw_r2': 0.9959865640167348}
Best loss: 0.048
Mean loss: 0.066



Ep: 1/  5. LR: 9.796e-05. Loss: 0.2288. Stats01: 0.0000:  76%|███████▌  | 2671/3529 [25:36<06:16,  2.28it/s]

{'rmse': 0.06908171595487672, 'raw_rmse': 27.435232711696454, 'mae': 0.050051448487294446, 'raw_mae': 10.909601313553745, 'r2': 0.9940658842138047, 'raw_r2': 0.9968265579039534}
Best loss: 0.048
Mean loss: 0.064



Ep: 1/  5. LR: 9.763e-05. Loss: 0.2411. Stats01: 0.0000:  81%|████████▏ | 2874/3529 [26:51<03:59,  2.73it/s]

{'rmse': 0.06268322629462868, 'raw_rmse': 24.145620001471997, 'mae': 0.04225775915488704, 'raw_mae': 9.46333882119759, 'r2': 0.9952085456703283, 'raw_r2': 0.9976153060461708}
Best loss: 0.042
Mean loss: 0.063



Ep: 1/  5. LR: 9.728e-05. Loss: 0.2299. Stats01: 0.0000:  87%|████████▋ | 3076/3529 [28:04<02:50,  2.66it/s]

{'rmse': 0.07091956739459968, 'raw_rmse': 29.18379479451576, 'mae': 0.05254021158094157, 'raw_mae': 11.458154786826121, 'r2': 0.9939375241794307, 'raw_r2': 0.996371666140348}
Best loss: 0.042
Mean loss: 0.062



Ep: 1/  5. LR: 9.691e-05. Loss: 0.2334. Stats01: 0.0000:  93%|█████████▎| 3279/3529 [29:17<01:31,  2.74it/s]

{'rmse': 0.06784772338760514, 'raw_rmse': 26.09471316135411, 'mae': 0.04930131923530237, 'raw_mae': 10.445645915927532, 'r2': 0.9943650895509516, 'raw_r2': 0.9970944598233477}
Best loss: 0.042
Mean loss: 0.061



Ep: 1/  5. LR: 9.652e-05. Loss: 0.2285. Stats01: 0.0000:  99%|█████████▊| 3481/3529 [31:06<00:28,  1.68it/s]

{'rmse': 0.06364333362041939, 'raw_rmse': 27.43092496480894, 'mae': 0.043276120206755725, 'raw_mae': 10.148611469073428, 'r2': 0.9950938315385451, 'raw_r2': 0.9966859709662272}
Best loss: 0.042
Mean loss: 0.060



Ep: 1/  5. LR: 9.625e-05. Loss: 0.2310. Stats01: 0.0000: 100%|██████████| 3529/3529 [32:25<00:00,  1.81it/s]
Ep: 2/  5. LR: 9.610e-05. Loss: 0.2374. Stats01: 0.0000:   2%|▏         | 82/3529 [00:42<30:08,  1.91it/s]

{'rmse': 0.06351557614675085, 'raw_rmse': 25.51855483452781, 'mae': 0.043586907225662684, 'raw_mae': 9.81726235108789, 'r2': 0.9951565594668645, 'raw_r2': 0.9973188702778734}
Best loss: 0.042
Mean loss: 0.059



Ep: 2/  5. LR: 9.565e-05. Loss: 0.2233. Stats01: 0.0000:   8%|▊         | 293/3529 [02:42<32:06,  1.68it/s]

{'rmse': 0.06545612705098239, 'raw_rmse': 26.251370202768065, 'mae': 0.04604500320880332, 'raw_mae': 10.231830863306996, 'r2': 0.9948325589409062, 'raw_r2': 0.997154896809997}
Best loss: 0.042
Mean loss: 0.059



Ep: 2/  5. LR: 9.519e-05. Loss: 0.2240. Stats01: 0.0000:  14%|█▍        | 501/3529 [04:42<27:13,  1.85it/s]

{'rmse': 0.06743763589508735, 'raw_rmse': 25.65401883728691, 'mae': 0.048642875389139736, 'raw_mae': 10.597848659380626, 'r2': 0.994543211036594, 'raw_r2': 0.9972953975464813}
Best loss: 0.042
Mean loss: 0.058



Ep: 2/  5. LR: 9.470e-05. Loss: 0.2134. Stats01: 0.0000:  20%|██        | 707/3529 [06:42<25:49,  1.82it/s]

{'rmse': 0.06602776332434596, 'raw_rmse': 27.114482463550534, 'mae': 0.04677288527096173, 'raw_mae': 10.420747412046172, 'r2': 0.9947054536393211, 'raw_r2': 0.9968930309794943}
Best loss: 0.042
Mean loss: 0.056



Ep: 2/  5. LR: 9.419e-05. Loss: 0.2230. Stats01: 0.0000:  26%|██▌       | 912/3529 [08:43<25:07,  1.74it/s]

{'rmse': 0.06426247003618266, 'raw_rmse': 27.19684210282654, 'mae': 0.044715821065258544, 'raw_mae': 10.197797333811685, 'r2': 0.9950491864614814, 'raw_r2': 0.9969338556093239}
Best loss: 0.042
Mean loss: 0.054



Ep: 2/  5. LR: 9.365e-05. Loss: 0.2162. Stats01: 0.0000:  32%|███▏      | 1116/3529 [10:42<23:45,  1.69it/s]

{'rmse': 0.06487649855911072, 'raw_rmse': 26.188328906448415, 'mae': 0.04560495540510056, 'raw_mae': 9.958073891978941, 'r2': 0.9949406448372057, 'raw_r2': 0.9971416717376018}
Best loss: 0.042
Mean loss: 0.053



Ep: 2/  5. LR: 9.309e-05. Loss: 0.2159. Stats01: 0.0000:  37%|███▋      | 1320/3529 [12:41<22:13,  1.66it/s]

{'rmse': 0.06399415069236114, 'raw_rmse': 23.967643448271172, 'mae': 0.044058512309272053, 'raw_mae': 9.581902786237716, 'r2': 0.995017250448533, 'raw_r2': 0.9975979038217526}
Best loss: 0.042
Mean loss: 0.052



Ep: 2/  5. LR: 9.251e-05. Loss: 0.2155. Stats01: 0.0000:  43%|████▎     | 1524/3529 [14:41<19:20,  1.73it/s]

{'rmse': 0.06583959623158427, 'raw_rmse': 28.8906459455353, 'mae': 0.04613842236005105, 'raw_mae': 10.780193665970174, 'r2': 0.9947857893486957, 'raw_r2': 0.9964814581182099}
Best loss: 0.042
Mean loss: 0.051



Ep: 2/  5. LR: 9.191e-05. Loss: 0.2288. Stats01: 0.0000:  49%|████▉     | 1728/3529 [16:40<17:32,  1.71it/s]

{'rmse': 0.06492081083059678, 'raw_rmse': 26.270400738194038, 'mae': 0.0454053153927422, 'raw_mae': 10.067063134110484, 'r2': 0.9948818577346853, 'raw_r2': 0.9970846640711205}
Best loss: 0.042
Mean loss: 0.050



Ep: 2/  5. LR: 9.129e-05. Loss: 0.2218. Stats01: 0.0000:  55%|█████▍    | 1931/3529 [18:39<15:48,  1.68it/s]

{'rmse': 0.06475804810409241, 'raw_rmse': 24.579120220927194, 'mae': 0.04482318221337671, 'raw_mae': 9.80748296999421, 'r2': 0.9950105658829087, 'raw_r2': 0.9975478528631317}
Best loss: 0.042
Mean loss: 0.049



Ep: 2/  5. LR: 9.064e-05. Loss: 0.2088. Stats01: 0.0000:  60%|██████    | 2134/3529 [20:38<13:33,  1.71it/s]

{'rmse': 0.06303512891765517, 'raw_rmse': 25.90440281342561, 'mae': 0.042439666902457235, 'raw_mae': 9.951043944956886, 'r2': 0.9951974541966951, 'raw_r2': 0.9971879881390819}
Best loss: 0.042
Mean loss: 0.048



Ep: 2/  5. LR: 8.997e-05. Loss: 0.2123. Stats01: 0.0000:  66%|██████▌   | 2337/3529 [22:38<11:52,  1.67it/s]

{'rmse': 0.06596912442239178, 'raw_rmse': 26.855125183410323, 'mae': 0.04654430103974664, 'raw_mae': 10.288904679830958, 'r2': 0.9947198379610859, 'raw_r2': 0.9969387594955343}
Best loss: 0.042
Mean loss: 0.047



Ep: 2/  5. LR: 8.928e-05. Loss: 0.2197. Stats01: 0.0000:  72%|███████▏  | 2540/3529 [24:38<09:54,  1.66it/s]

{'rmse': 0.06159936973078773, 'raw_rmse': 24.29618365919352, 'mae': 0.04094627187449216, 'raw_mae': 9.300656594919626, 'r2': 0.99547104180766, 'raw_r2': 0.9976131523748498}
Best loss: 0.041
Mean loss: 0.046



Ep: 2/  5. LR: 8.857e-05. Loss: 0.2107. Stats01: 0.0000:  78%|███████▊  | 2743/3529 [26:37<07:45,  1.69it/s]

{'rmse': 0.0610933691564088, 'raw_rmse': 24.867260167487412, 'mae': 0.040517179007530324, 'raw_mae': 9.198081969450175, 'r2': 0.9955033770544683, 'raw_r2': 0.9974742005360432}
Best loss: 0.041
Mean loss: 0.046



Ep: 2/  5. LR: 8.783e-05. Loss: 0.2074. Stats01: 0.0000:  83%|████████▎ | 2946/3529 [28:06<03:20,  2.90it/s]

{'rmse': 0.06327029240475768, 'raw_rmse': 25.39072603377994, 'mae': 0.04317254933733441, 'raw_mae': 9.774033591002434, 'r2': 0.9951783948715424, 'raw_r2': 0.9973290390331897}
Best loss: 0.041
Mean loss: 0.045



Ep: 2/  5. LR: 8.708e-05. Loss: 0.2145. Stats01: 0.0000:  89%|████████▉ | 3148/3529 [29:16<02:13,  2.85it/s]

{'rmse': 0.06764990527704269, 'raw_rmse': 27.588464913205225, 'mae': 0.04883791381816727, 'raw_mae': 10.649109857055878, 'r2': 0.9944805357128272, 'raw_r2': 0.9968148028836672}
Best loss: 0.041
Mean loss: 0.045



Ep: 2/  5. LR: 8.630e-05. Loss: 0.2093. Stats01: 0.0000:  95%|█████████▍| 3351/3529 [30:25<01:01,  2.88it/s]

{'rmse': 0.06393132470603416, 'raw_rmse': 25.46686671778081, 'mae': 0.04414216695314485, 'raw_mae': 9.791373102603462, 'r2': 0.9950960217921707, 'raw_r2': 0.9973255447210184}
Best loss: 0.041
Mean loss: 0.045



Ep: 2/  5. LR: 8.550e-05. Loss: 0.2112. Stats01: 0.0000: : 3553it [31:34,  2.88it/s]                        

{'rmse': 0.06577031048072976, 'raw_rmse': 26.82776050008454, 'mae': 0.04618431879310074, 'raw_mae': 10.158534320603628, 'r2': 0.9948307227853025, 'raw_r2': 0.9970940048466644}
Best loss: 0.041
Mean loss: 0.045



Ep: 2/  5. LR: 8.526e-05. Loss: 0.2069. Stats01: 0.0000: 100%|██████████| 3529/3529 [31:54<00:00,  1.84it/s]
Ep: 3/  5. LR: 8.469e-05. Loss: 0.2122. Stats01: 0.0000:   4%|▍         | 158/3529 [00:48<17:24,  3.23it/s]

{'rmse': 0.060209911893805736, 'raw_rmse': 23.416768061650977, 'mae': 0.03941983768828995, 'raw_mae': 8.986242174692988, 'r2': 0.9956352507468043, 'raw_r2': 0.9977420660686344}
Best loss: 0.039
Mean loss: 0.045



Ep: 3/  5. LR: 8.385e-05. Loss: 0.2119. Stats01: 0.0000:  10%|█         | 367/3529 [01:57<17:57,  2.93it/s]

{'rmse': 0.06530506363155264, 'raw_rmse': 27.6385944918017, 'mae': 0.045661969803434996, 'raw_mae': 10.302005148868018, 'r2': 0.9949008157396768, 'raw_r2': 0.9968331405620127}
Best loss: 0.039
Mean loss: 0.045



Ep: 3/  5. LR: 8.299e-05. Loss: 0.2118. Stats01: 0.0000:  16%|█▋        | 574/3529 [03:06<16:47,  2.93it/s]

{'rmse': 0.06195460733698544, 'raw_rmse': 23.276746783299767, 'mae': 0.04160712559272451, 'raw_mae': 9.209858358210255, 'r2': 0.9954215437325235, 'raw_r2': 0.9978168441372262}
Best loss: 0.039
Mean loss: 0.045



Ep: 3/  5. LR: 8.211e-05. Loss: 0.2042. Stats01: 0.0000:  22%|██▏       | 779/3529 [04:15<15:38,  2.93it/s]

{'rmse': 0.061564666258535246, 'raw_rmse': 23.69308567805731, 'mae': 0.04117529194954563, 'raw_mae': 9.156745447648584, 'r2': 0.9954623605665833, 'raw_r2': 0.9977399189615482}
Best loss: 0.039
Mean loss: 0.044



Ep: 3/  5. LR: 8.121e-05. Loss: 0.2129. Stats01: 0.0000:  28%|██▊       | 984/3529 [05:24<14:23,  2.95it/s]

{'rmse': 0.06432280895141727, 'raw_rmse': 23.914600383211507, 'mae': 0.04448748463477696, 'raw_mae': 9.635776563740803, 'r2': 0.9950878734945877, 'raw_r2': 0.997686885512267}
Best loss: 0.039
Mean loss: 0.044



Ep: 3/  5. LR: 8.029e-05. Loss: 0.2125. Stats01: 0.0000:  34%|███▎      | 1189/3529 [06:33<13:30,  2.89it/s]

{'rmse': 0.06077637873847408, 'raw_rmse': 23.443043673927292, 'mae': 0.040211361437779314, 'raw_mae': 8.979509771466901, 'r2': 0.9955651740759521, 'raw_r2': 0.9977459913600141}
Best loss: 0.039
Mean loss: 0.044



Ep: 3/  5. LR: 7.935e-05. Loss: 0.2059. Stats01: 0.0000:  39%|███▉      | 1393/3529 [07:43<12:21,  2.88it/s]

{'rmse': 0.06330665298822534, 'raw_rmse': 23.96242651984666, 'mae': 0.04345713313306927, 'raw_mae': 9.475557771101556, 'r2': 0.9951796735299892, 'raw_r2': 0.9976216371149771}
Best loss: 0.039
Mean loss: 0.044



Ep: 3/  5. LR: 7.839e-05. Loss: 0.2041. Stats01: 0.0000:  45%|████▌     | 1597/3529 [08:52<08:47,  3.66it/s]

{'rmse': 0.06570275779603182, 'raw_rmse': 27.141556235566444, 'mae': 0.04613949061985703, 'raw_mae': 10.47966454424214, 'r2': 0.9948253086281797, 'raw_r2': 0.9969798913979496}
Best loss: 0.039
Mean loss: 0.044



Ep: 3/  5. LR: 7.742e-05. Loss: 0.2116. Stats01: 0.0000:  51%|█████     | 1800/3529 [10:01<10:00,  2.88it/s]

{'rmse': 0.06328044201433426, 'raw_rmse': 25.34038184307984, 'mae': 0.04291543079224529, 'raw_mae': 9.62474859897042, 'r2': 0.9952206581274956, 'raw_r2': 0.9973842448966957}
Best loss: 0.039
Mean loss: 0.044



Ep: 3/  5. LR: 7.642e-05. Loss: 0.2003. Stats01: 0.0000:  57%|█████▋    | 2003/3529 [11:11<08:39,  2.94it/s]

{'rmse': 0.06191947831996331, 'raw_rmse': 23.87479148497513, 'mae': 0.041637430279805696, 'raw_mae': 9.221012503251787, 'r2': 0.9953987704300455, 'raw_r2': 0.9976750002368603}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 7.541e-05. Loss: 0.1950. Stats01: 0.0000:  63%|██████▎   | 2206/3529 [12:20<07:43,  2.85it/s]

{'rmse': 0.06024001752731339, 'raw_rmse': 22.990830147895736, 'mae': 0.03928743540381634, 'raw_mae': 8.923371479463178, 'r2': 0.9956370510431481, 'raw_r2': 0.9978623629752083}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 7.438e-05. Loss: 0.2042. Stats01: 0.0000:  68%|██████▊   | 2409/3529 [13:30<06:20,  2.95it/s]

{'rmse': 0.062264280524391716, 'raw_rmse': 24.61275730372577, 'mae': 0.04222389377221862, 'raw_mae': 9.312661642165388, 'r2': 0.9953169909070372, 'raw_r2': 0.9974745399220621}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 7.332e-05. Loss: 0.2034. Stats01: 0.0000:  74%|███████▍  | 2612/3529 [14:39<05:19,  2.87it/s]

{'rmse': 0.062337850349389747, 'raw_rmse': 24.74353753094578, 'mae': 0.0422596566419202, 'raw_mae': 9.477557841062275, 'r2': 0.9953468514792946, 'raw_r2': 0.9975297163038659}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 7.226e-05. Loss: 0.2010. Stats01: 0.0000:  80%|███████▉  | 2815/3529 [15:48<04:03,  2.93it/s]

{'rmse': 0.06427859165128784, 'raw_rmse': 25.623378186960114, 'mae': 0.04482093523921743, 'raw_mae': 9.797936753880547, 'r2': 0.9950615441950733, 'raw_r2': 0.9973191948075248}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 7.117e-05. Loss: 0.2155. Stats01: 0.0000:  86%|████████▌ | 3018/3529 [16:57<02:58,  2.87it/s]

{'rmse': 0.06434237584990046, 'raw_rmse': 24.62344658478556, 'mae': 0.04458541265317739, 'raw_mae': 9.890306118733818, 'r2': 0.9950843600760177, 'raw_r2': 0.9975431429226592}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 7.007e-05. Loss: 0.2026. Stats01: 0.0000:  91%|█████████ | 3220/3529 [18:07<01:49,  2.82it/s]

{'rmse': 0.061423706837213626, 'raw_rmse': 24.79389522510861, 'mae': 0.0406972083411852, 'raw_mae': 9.502426003240098, 'r2': 0.9954509898437165, 'raw_r2': 0.9974774378035864}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 6.895e-05. Loss: 0.2165. Stats01: 0.0000:  97%|█████████▋| 3423/3529 [19:16<00:35,  2.95it/s]

{'rmse': 0.06353563214264299, 'raw_rmse': 23.582198994934075, 'mae': 0.043721703409691316, 'raw_mae': 9.437254761435776, 'r2': 0.9951849142269544, 'raw_r2': 0.9977421015984111}
Best loss: 0.039
Mean loss: 0.043



Ep: 3/  5. LR: 6.788e-05. Loss: 0.1968. Stats01: 0.0000: 100%|██████████| 3529/3529 [20:20<00:00,  2.89it/s]
Ep: 4/  5. LR: 6.781e-05. Loss: 0.2021. Stats01: 0.0000:   0%|          | 17/3529 [00:04<14:04,  4.16it/s]

{'rmse': 0.06174225411135385, 'raw_rmse': 23.795768182065547, 'mae': 0.04130975244665015, 'raw_mae': 9.211913350706347, 'r2': 0.9954528860105516, 'raw_r2': 0.9977076937441544}
Best loss: 0.039
Mean loss: 0.043



Ep: 4/  5. LR: 6.666e-05. Loss: 0.1982. Stats01: 0.0000:   7%|▋         | 233/3529 [01:13<14:27,  3.80it/s]

{'rmse': 0.06167405180387255, 'raw_rmse': 22.86951449238331, 'mae': 0.04122486308016873, 'raw_mae': 9.043414834306496, 'r2': 0.9954640428838852, 'raw_r2': 0.9979042606647301}
Best loss: 0.039
Mean loss: 0.043



Ep: 4/  5. LR: 6.549e-05. Loss: 0.2053. Stats01: 0.0000:  12%|█▏        | 441/3529 [02:22<16:26,  3.13it/s]

{'rmse': 0.06301525601043761, 'raw_rmse': 24.066854616610225, 'mae': 0.043156231928602296, 'raw_mae': 9.414517132639595, 'r2': 0.9952589929399781, 'raw_r2': 0.9976385796757761}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 6.430e-05. Loss: 0.2022. Stats01: 0.0000:  18%|█▊        | 647/3529 [03:31<16:39,  2.88it/s]

{'rmse': 0.06218584405940501, 'raw_rmse': 23.96386992386893, 'mae': 0.041917885358400624, 'raw_mae': 9.330978106412553, 'r2': 0.9953741350091856, 'raw_r2': 0.9976715220188406}
Best loss: 0.039
Mean loss: 0.043



Ep: 4/  5. LR: 6.310e-05. Loss: 0.1937. Stats01: 0.0000:  24%|██▍       | 852/3529 [04:40<15:07,  2.95it/s]

{'rmse': 0.061909028589182374, 'raw_rmse': 23.513693499874794, 'mae': 0.04156618828020336, 'raw_mae': 9.190340559998289, 'r2': 0.9954318056499956, 'raw_r2': 0.9977744349556252}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 6.189e-05. Loss: 0.2007. Stats01: 0.0000:  30%|██▉       | 1057/3529 [05:49<13:54,  2.96it/s]

{'rmse': 0.0607440668353967, 'raw_rmse': 23.926853203098226, 'mae': 0.03980319359030482, 'raw_mae': 9.237777112296827, 'r2': 0.9955614319822569, 'raw_r2': 0.9976840511010582}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 6.066e-05. Loss: 0.1973. Stats01: 0.0000:  36%|███▌      | 1261/3529 [06:57<13:01,  2.90it/s]

{'rmse': 0.06067816145112281, 'raw_rmse': 23.448944844405855, 'mae': 0.04003026412576769, 'raw_mae': 8.974018560290862, 'r2': 0.9955844379479923, 'raw_r2': 0.9977739920150215}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.941e-05. Loss: 0.1932. Stats01: 0.0000:  42%|████▏     | 1465/3529 [08:06<11:45,  2.93it/s]

{'rmse': 0.06065520081590775, 'raw_rmse': 23.126285955826706, 'mae': 0.03990728089207964, 'raw_mae': 8.966914725251803, 'r2': 0.9955991358741164, 'raw_r2': 0.997850379841403}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.815e-05. Loss: 0.2015. Stats01: 0.0000:  47%|████▋     | 1669/3529 [09:15<10:33,  2.94it/s]

{'rmse': 0.06194382027937731, 'raw_rmse': 23.986096220640466, 'mae': 0.04174531906106316, 'raw_mae': 9.328836726493359, 'r2': 0.9954293573800651, 'raw_r2': 0.997685434273461}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.687e-05. Loss: 0.1975. Stats01: 0.0000:  53%|█████▎    | 1872/3529 [10:24<09:30,  2.90it/s]

{'rmse': 0.061853949414176115, 'raw_rmse': 24.464516482823537, 'mae': 0.04143902747226208, 'raw_mae': 9.29875782098789, 'r2': 0.9954187696916322, 'raw_r2': 0.997547026382178}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.559e-05. Loss: 0.1948. Stats01: 0.0000:  59%|█████▉    | 2075/3529 [11:33<08:27,  2.86it/s]

{'rmse': 0.06144031544942015, 'raw_rmse': 23.07139076762953, 'mae': 0.041094130309245545, 'raw_mae': 9.02435918404969, 'r2': 0.9954896396303208, 'raw_r2': 0.9978557146235221}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.428e-05. Loss: 0.1947. Stats01: 0.0000:  65%|██████▍   | 2279/3529 [12:42<05:42,  3.65it/s]

{'rmse': 0.06046621850670241, 'raw_rmse': 22.687888623703415, 'mae': 0.03978377544596897, 'raw_mae': 8.866650865850621, 'r2': 0.9956223647170426, 'raw_r2': 0.9979155396313373}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.297e-05. Loss: 0.1945. Stats01: 0.0000:  70%|███████   | 2481/3529 [13:51<06:01,  2.90it/s]

{'rmse': 0.061302763170063335, 'raw_rmse': 23.33794744196101, 'mae': 0.0410048216530292, 'raw_mae': 9.017925562325962, 'r2': 0.995497584102999, 'raw_r2': 0.9977955686479524}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.164e-05. Loss: 0.1929. Stats01: 0.0000:  76%|███████▌  | 2684/3529 [15:00<04:47,  2.94it/s]

{'rmse': 0.061429747605946605, 'raw_rmse': 22.931782275894285, 'mae': 0.04104594313585758, 'raw_mae': 9.020374794841779, 'r2': 0.9954790353913792, 'raw_r2': 0.9978761266693691}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 5.030e-05. Loss: 0.2007. Stats01: 0.0000:  82%|████████▏ | 2887/3529 [16:09<03:40,  2.91it/s]

{'rmse': 0.06135315940789392, 'raw_rmse': 23.641654313812523, 'mae': 0.04088880872520133, 'raw_mae': 9.083495844122403, 'r2': 0.9954953203660458, 'raw_r2': 0.9977300760020202}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 4.895e-05. Loss: 0.1883. Stats01: 0.0000:  88%|████████▊ | 3090/3529 [17:19<02:24,  3.04it/s]

{'rmse': 0.060259447638782854, 'raw_rmse': 22.763921043482767, 'mae': 0.03952965456019184, 'raw_mae': 8.86468232947878, 'r2': 0.9956381715669421, 'raw_r2': 0.9979016367464661}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.759e-05. Loss: 0.2002. Stats01: 0.0000:  93%|█████████▎| 3292/3529 [18:28<01:21,  2.92it/s]

{'rmse': 0.06314311001313282, 'raw_rmse': 25.062827251847143, 'mae': 0.04335894527191716, 'raw_mae': 9.463407561030062, 'r2': 0.9952640725933426, 'raw_r2': 0.997456503069185}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.621e-05. Loss: 0.1977. Stats01: 0.0000:  99%|█████████▉| 3495/3529 [19:36<00:11,  3.03it/s]

{'rmse': 0.06174724094994275, 'raw_rmse': 22.742889232346705, 'mae': 0.04155098076310156, 'raw_mae': 9.088293210119652, 'r2': 0.9954853239509583, 'raw_r2': 0.997966416246402}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.540e-05. Loss: 0.1910. Stats01: 0.0000: 100%|██████████| 3529/3529 [20:17<00:00,  2.90it/s]
Ep: 5/  5. LR: 4.482e-05. Loss: 0.1882. Stats01: 0.0000:   3%|▎         | 96/3529 [00:28<18:07,  3.16it/s]

{'rmse': 0.06183348028578926, 'raw_rmse': 23.582974871289725, 'mae': 0.041492978202569676, 'raw_mae': 9.169627661032658, 'r2': 0.995435352334347, 'raw_r2': 0.9977521707710955}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 4.342e-05. Loss: 0.2025. Stats01: 0.0000:   9%|▊         | 307/3529 [01:37<17:32,  3.06it/s]

{'rmse': 0.0606165422296335, 'raw_rmse': 23.204757987616876, 'mae': 0.040058611147977596, 'raw_mae': 8.918229792331244, 'r2': 0.9955854736515287, 'raw_r2': 0.9978002114332915}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 4.202e-05. Loss: 0.1941. Stats01: 0.0000:  15%|█▍        | 514/3529 [02:46<17:11,  2.92it/s]

{'rmse': 0.060951960856928056, 'raw_rmse': 23.144469730801042, 'mae': 0.04044513789661137, 'raw_mae': 8.991662937247826, 'r2': 0.9955388490945836, 'raw_r2': 0.9978177869656073}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 4.060e-05. Loss: 0.1913. Stats01: 0.0000:  20%|██        | 720/3529 [03:54<16:09,  2.90it/s]

{'rmse': 0.06116795138207682, 'raw_rmse': 23.71094133779235, 'mae': 0.0404544420550034, 'raw_mae': 9.153013806911458, 'r2': 0.9955349323175343, 'raw_r2': 0.9977352875968928}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.917e-05. Loss: 0.1912. Stats01: 0.0000:  26%|██▌       | 925/3529 [05:04<14:51,  2.92it/s]

{'rmse': 0.06007816147003955, 'raw_rmse': 22.44387848038696, 'mae': 0.03942147681487161, 'raw_mae': 8.827656683241802, 'r2': 0.9956619050075604, 'raw_r2': 0.9979406528703525}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.773e-05. Loss: 0.1958. Stats01: 0.0000:  32%|███▏      | 1130/3529 [06:14<10:43,  3.73it/s]

{'rmse': 0.06129286442530055, 'raw_rmse': 23.4850070233405, 'mae': 0.040759817801042805, 'raw_mae': 9.09230375385016, 'r2': 0.9955042079722414, 'raw_r2': 0.9977588214531106}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.628e-05. Loss: 0.1905. Stats01: 0.0000:  38%|███▊      | 1334/3529 [07:23<12:12,  3.00it/s]

{'rmse': 0.06115971940761028, 'raw_rmse': 22.744299155611813, 'mae': 0.04067276955497418, 'raw_mae': 8.973430489028026, 'r2': 0.9955176241386868, 'raw_r2': 0.9979237876673531}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.483e-05. Loss: 0.1863. Stats01: 0.0000:  44%|████▎     | 1537/3529 [08:32<11:33,  2.87it/s]

{'rmse': 0.060551433122104446, 'raw_rmse': 22.833873520696557, 'mae': 0.039857363230296664, 'raw_mae': 8.879040451615387, 'r2': 0.9956025398357953, 'raw_r2': 0.9978959763755862}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.336e-05. Loss: 0.1862. Stats01: 0.0000:  49%|████▉     | 1741/3529 [09:42<10:16,  2.90it/s]

{'rmse': 0.06080802158183705, 'raw_rmse': 22.450706437953894, 'mae': 0.04021759619074875, 'raw_mae': 8.857452490859071, 'r2': 0.9955701015218394, 'raw_r2': 0.9979743094521727}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.189e-05. Loss: 0.1936. Stats01: 0.0000:  55%|█████▌    | 1944/3529 [10:51<09:07,  2.90it/s]

{'rmse': 0.06027920749340942, 'raw_rmse': 22.87945762993693, 'mae': 0.03953669251519566, 'raw_mae': 8.866923620147627, 'r2': 0.9956474349056189, 'raw_r2': 0.9978880602930779}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.041e-05. Loss: 0.1919. Stats01: 0.0000:  61%|██████    | 2148/3529 [12:00<06:47,  3.39it/s]

{'rmse': 0.06146035855513302, 'raw_rmse': 22.94237534695467, 'mae': 0.04109017334935127, 'raw_mae': 9.056744302154108, 'r2': 0.9954771244819629, 'raw_r2': 0.9978717282767272}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.892e-05. Loss: 0.1885. Stats01: 0.0000:  67%|██████▋   | 2351/3529 [13:10<06:14,  3.15it/s]

{'rmse': 0.06175272628498514, 'raw_rmse': 23.327660704786016, 'mae': 0.04160559949395584, 'raw_mae': 9.15561962238587, 'r2': 0.9954351792975865, 'raw_r2': 0.9978030419805194}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.743e-05. Loss: 0.1957. Stats01: 0.0000:  72%|███████▏  | 2553/3529 [14:20<05:38,  2.88it/s]

{'rmse': 0.06214643942018843, 'raw_rmse': 24.313893200608987, 'mae': 0.042083929722621864, 'raw_mae': 9.272185882744884, 'r2': 0.9953987310262146, 'raw_r2': 0.9976075731740203}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.593e-05. Loss: 0.1867. Stats01: 0.0000:  78%|███████▊  | 2756/3529 [15:30<04:23,  2.93it/s]

{'rmse': 0.0601087608885091, 'raw_rmse': 22.873094080152228, 'mae': 0.039353770948938924, 'raw_mae': 8.887618972206244, 'r2': 0.9956398200804802, 'raw_r2': 0.9978664593875438}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.442e-05. Loss: 0.1891. Stats01: 0.0000:  84%|████████▍ | 2959/3529 [16:39<03:19,  2.85it/s]

{'rmse': 0.0613876019694867, 'raw_rmse': 23.097690011249973, 'mae': 0.04109188874181451, 'raw_mae': 9.016372635961519, 'r2': 0.9954946141132758, 'raw_r2': 0.9978393546455055}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.290e-05. Loss: 0.1884. Stats01: 0.0000:  90%|████████▉ | 3162/3529 [17:49<01:37,  3.78it/s]

{'rmse': 0.06054193440170199, 'raw_rmse': 22.715989195932607, 'mae': 0.039974858973833974, 'raw_mae': 8.849714885885303, 'r2': 0.9956066811691276, 'raw_r2': 0.9979078076187528}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.138e-05. Loss: 0.1901. Stats01: 0.0000:  95%|█████████▌| 3364/3529 [18:58<00:57,  2.86it/s]

{'rmse': 0.060956687276683796, 'raw_rmse': 22.946781955097716, 'mae': 0.040474125556018564, 'raw_mae': 8.93738502757644, 'r2': 0.995548036239072, 'raw_r2': 0.9978610648075709}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 1.986e-05. Loss: 0.1872. Stats01: 0.0000: : 3566it [20:07,  2.90it/s]                        

{'rmse': 0.06089760425634423, 'raw_rmse': 23.046491518789182, 'mae': 0.04038734056244652, 'raw_mae': 8.94984586887213, 'r2': 0.9955575926450944, 'raw_r2': 0.9978499637623406}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 1.951e-05. Loss: 0.1865. Stats01: 0.0000: 100%|██████████| 3529/3529 [20:24<00:00,  2.88it/s]


In [16]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [args.seq_len])

{'rmse': 0.05892767556592564,
 'raw_rmse': 29.64633468760002,
 'mae': 0.042399002905476824,
 'raw_mae': 10.570864551249285,
 'r2': 0.996062632975289,
 'raw_r2': 0.9968318397948741}

In [17]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [32, 64, 96])

{'rmse': 0.05891323298691096,
 'raw_rmse': 29.813343638095866,
 'mae': 0.04238265815226493,
 'raw_mae': 10.580564301318097,
 'r2': 0.9960644855150232,
 'raw_r2': 0.9967942889802358}

In [18]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [64, 96, 128])

{'rmse': 0.05892493875372378,
 'raw_rmse': 29.666546527718307,
 'mae': 0.04239628013934455,
 'raw_mae': 10.571961484432132,
 'r2': 0.9960629972134678,
 'raw_r2': 0.9968272838327025}