In [1]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

manual_seed()


# Data

In [2]:
from utils import load_data, load_edge

train_folds = load_data(True)
test_fold = load_data(False)[0]

In [3]:
import torch
from torch.utils.data import DataLoader, Dataset

class SimpleStockDataset(Dataset):
    def __init__(self, data, ws=128):
        self.data = data
        self.ws = ws
        self.samples = []
        
        self.n_tickers, self.n_days, self.n_features = self.data.shape
        
        for start in range(self.n_days - self.ws + 1):
            self.samples.append(start)
            
    def __len__(self):
        return len(self.samples)
      
    def __getitem__(self, idx):
        start = self.samples[idx]
        x = torch.tensor(self.data[:, start:start + self.ws], dtype=torch.float32)
        return x
    

# Model

### LSTM

In [4]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyRNN(nn.Module):
    def __init__(self, n_nodes: int, n_feats: int, args, node_emb=None):
        super().__init__()
        self.n_nodes = n_nodes
        self.rank_A = 32
        self.n_feats = n_feats
        self.args = args
        self.d_latent = getattr(args, "d_latent", 128)
        self.dropout = getattr(args, "dropout", 0.0)
        self.n_layers = getattr(args, "n_layers", 1)
        self.d_node = getattr(args, "d_node", 32)
        
        

        if node_emb is None:
            self.static_node_features = nn.Parameter(torch.randn(n_nodes, self.d_latent) * 0.1)
            self.fc_node_features = nn.Identity()
        else:
            self.static_node_features = nn.Parameter(node_emb, requires_grad=False)
            self.d_node = self.static_node_features.shape[-1]
            self.fc_node_features = nn.Linear(self.d_node, self.d_latent)
        
        self.cells = nn.ModuleList([
            nn.LSTMCell(n_feats if i == 0 else self.d_latent, self.d_latent)
            for i in range(self.n_layers)
        ])
        
        self.readout = nn.Sequential(
            nn.Linear(self.d_latent, self.d_latent), nn.ReLU(),
            nn.Linear(self.d_latent, n_feats)
        )
        
        self.enc_in = None

    def _init_states(
        self, X_0: torch.Tensor, H_0: Optional[torch.Tensor]
    ):
        """
        Trả về list (h_list, c_list) cho từng layer.
        - Nếu H_0 cung cấp (N, d), đặt h_top = H_0; ngược lại nếu enc_in != None thì h_top = enc(X_0); còn lại h=0.
        - c luôn = 0.
        """
        device = X_0.device
        dtype = X_0.dtype
        N = X_0.shape[0]

        h_list = []
        c_list = []
        for _ in range(self.n_layers):
            h_list.append(self.node_features.clone())
            c_list.append(self.node_features.clone())

        if H_0 is not None:
            h_list[-1] = H_0.to(device=device, dtype=dtype)
        elif self.enc_in is not None:
            h_list[-1] = self.enc_in(X_0)

        return h_list, c_list

    def _step(self, x_t: torch.Tensor, h_list, c_list):
        """
        Chạy 1 bước LSTM qua tất cả các layer.
        Input layer nhận x_t; các layer sau nhận h của layer trước ở cùng time-step.
        Trả về (h_list_new, c_list_new, h_top).
        """
        inp = x_t
        new_h, new_c = [], []
        for l, cell in enumerate(self.cells):
            h_t, c_t = cell(inp, (h_list[l], c_list[l]))
            new_h.append(h_t)
            new_c.append(c_t)
            inp = h_t
        h_top = new_h[-1]
        if self.training and self.dropout > 0:
            h_top = F.dropout(h_top, p=self.dropout, training=True)
        return new_h, new_c, h_top

    def forecast(self, X, H_0=None, horizon=0):
        y, *_ = self.forward(X, H_0, horizon)
        return y

    def forward(self, X: torch.Tensor,
                H_0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run deterministic inference.
        Args:
            X: (N, T, F) input features at each t
            H_0: (N, d) optional initial latent state (default zeros)
            Atilde, L: if precomputed; otherwise build from X
        Returns:
            y_pred: (N, T-1) predicted next-day log-return for each node
            H_all: (N, T, d) latent states (including H_0 as first)
        """
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        n_nodes, n_steps, n_feats = X.shape
        d_latent = self.d_latent
        
        self.node_features = self.fc_node_features(self.static_node_features)

        X_0 = X[:, 0, :] # (N, F)
        h_list, c_list = self._init_states(X_0, H_0)

        H_all = torch.zeros(n_nodes, n_steps + max(horizon, 0), d_latent, device=device, dtype=X.dtype)
        r_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)
        y_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)


        H_all[:, 0] = h_list[-1]
        for t in range(n_steps - 1):
            x_t = X[:, t, :]  # (N, F)
            h_list, c_list, h_top = self._step(x_t, h_list, c_list)
            r_t = self.readout(h_top)
            y_t = x_t + r_t

            H_all[:, t + 1] = h_top
            r_all[:, t] = r_t
            y_all[:, t] = y_t

        if horizon > 0:
            x_t = X[:, -1, :]
            for s in range(horizon):
                h_list, c_list, h_top = self._step(x_t, h_list, c_list)
                r_t = self.readout(h_top)
                y_t = x_t + r_t

                idx = (n_steps - 1) + s
                y_all[:, idx] = y_t
                r_all[:, idx] = r_t
                H_all[:, n_steps + s] = h_top

                x_t = y_t

        return y_all, r_all, H_all

    def forward_loss(
        self,
        X: torch.Tensor,
        H_0: Optional[torch.Tensor] = None,
        horizon: int = 0,
    ):
        """
        Tính loss dự báo one-step + rollout horizon (autoregressive) trên chuỗi đầu vào.

        Args:
            X: (N, T, F) chuỗi gốc (chứa full ground-truth đến T-1)
            H_0: (N, d) latent init (tuỳ chọn)
            horizon: số bước rollout ngoài quan sát cuối cùng
                    (nếu >0, ta cắt input để tránh nhìn thấy tương lai)
            reduction: 'mean' | 'sum' | 'none' cho F.mse_loss

        Returns:
            loss: scalar tensor
            y_pred: (N, T-1, F) dự báo X_{t+1} cho toàn bộ t=0..T-2
        """
        device = next(self.parameters()).device
        X = X.to(device)
        N, T, Fdim = X.shape

        if horizon < 0:
            raise ValueError("horizon must be >= 0")
        if horizon >= T:
            raise ValueError(f"horizon={horizon} must be < sequence length T={T}")

        # Cắt input nếu rollout > 0 để giữ đúng số target (T-1)
        X_in = X[:, : T - horizon, :] if horizon > 0 else X

        # forward() trả (N, (T-horizon)-1 + horizon, F) = (N, T-1, F)
        y_ar, _, _ = self.forward(X_in, H_0=H_0, horizon=horizon)

        # Ground truth luôn là X_{1:T}
        target = X[:, 1:, :]  # (N, T-1, F)
        
        err_t = (y_ar - target).abs().mean(dim=-1).mean(dim=0)
        
        decay = 0.9
        coef_pre = 1.0
        coef_roll = 1.0
        
        len_pre = T - horizon - 1
        loss_pre = torch.tensor(0.0, device=device, dtype=err_t.dtype)
        loss_roll = torch.tensor(0.0, device=device, dtype=err_t.dtype)

        if len_pre > 0:
            idx = torch.arange(len_pre - 1, -1, -1, device=device, dtype=err_t.dtype)
            w = decay ** idx
            loss_pre = (w * err_t[:len_pre]).sum() / (w.sum() + 1e-12)

        if horizon > 0:
            # MSE đều cho đoạn rollout (chiều dài = horizon)
            loss_roll = err_t[len_pre:].mean()

        loss = coef_pre * loss_pre + coef_roll * loss_roll
        
        return loss

# Eval & Train

In [5]:
from sklearn.metrics import *

def eval_ensemble(args, model, training_data_np, testing_data_np, device='cuda', seq_lens=[64, 96, 128], verbose=False):

    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = np.log1p(training_data_np[:, :, used_features])
    testing_data_np = np.log1p(testing_data_np[:, :, used_features])
    
    labels = torch.tensor(testing_data_np).float().to(device)
    n_nodes, horizon, n_feats = labels.shape
    y_preds = np.zeros((len(seq_lens), n_nodes, horizon, n_feats-1)) # Bỏ Vol
    model.eval().to(device)
    for i, seq_len in enumerate(seq_lens):
        batch = torch.tensor(training_data_np[:, -seq_len:]).float().to(device)
        with torch.no_grad():
            y_all = model.forecast(batch, horizon=horizon)
        y_preds[i] = y_all[:, -horizon:, :n_feats-1].detach().cpu().numpy() # OHLC + Adj Close
    y_gt = testing_data_np[:, :, :n_feats-1].reshape(n_nodes, -1) # OHLC + Adj Close
    y_pred = y_preds.mean(axis=0).reshape(n_nodes, -1)
    
    if verbose:
        print("Max var:", np.var(np.expm1(y_preds), axis=0).max())
        print("Mean var:", np.var(np.expm1(y_preds), axis=0).mean())
    return {
        'rmse': root_mean_squared_error(y_gt, y_pred), 
        'raw_rmse': root_mean_squared_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'mae': mean_absolute_error(y_gt, y_pred), 
        'raw_mae': mean_absolute_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'r2': r2_score(y_gt.ravel(), y_pred.ravel()),
        'raw_r2': r2_score(np.expm1(y_gt).ravel(), np.expm1(y_pred).ravel()), 
    }

def eval(args, model, training_data_np, testing_data_np, device='cuda'):
    return eval_ensemble(args, model, training_data_np, testing_data_np, device, [args.seq_len])

In [6]:
from tqdm import tqdm
import time

class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.avg == 0:
            self.avg = val
            return
        self.avg = 0.95 * self.avg + 0.05 * val

import copy
def train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np):
    # if args.amp:
    #     from apex import amp
    global best_loss, best_model
    test_losses = []
    end = time.time()

    best_model = copy.deepcopy(model)
    step = 0
    for epoch in range(args.epochs):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        stats01 = AverageMeter()
        stats02 = AverageMeter()
        p_bar = tqdm(train_loader)
        for batch_idx, samples in enumerate(p_bar):
            step += 1
            model.train().to(args.device)
          
            samples = samples[0].float().to(args.device)
            data_time.update(time.time() - end)

            loss = model.forward_loss(samples, horizon=args.horizon)
            
            loss.backward()
            
            max_norm = 5.0
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
            
            optimizer.step()
            scheduler.step()

            losses.update(loss.item())
            stats01.update(0.0)
            # stats01.update(power_iteration_lmax_sym(poly_laplacian(model.L, F.softplus(model.kappa))))

            batch_time.update(time.time() - end)
            end = time.time()
            # mask_probs.update(mask.mean().item())
            p_bar.set_description(
                "Ep: {epoch}/{epochs:3}. LR: {lr:.3e}. "
                "Loss: {loss:.4f}. Stats01: {stats01:.4f}".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                lr=scheduler.get_last_lr()[0],
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                stats01=stats01.avg,
            ))
            p_bar.update()
            
            if (step + 1) % args.eval_steps == 0:
                test_model = model

                test_metrics = eval_ensemble(args, test_model, training_data_np, testing_data_np, args.device, [args.seq_len])
                print(test_metrics)
                test_loss = test_metrics['mae']

                is_best = test_loss < best_loss
                if test_loss < best_loss:
                    best_loss = test_loss
                    best_model = copy.deepcopy(test_model)


                test_losses.append(test_loss)
                print('Best loss: {:.3f}'.format(best_loss))
                print('Mean loss: {:.3f}\n'.format(
                    np.mean(test_losses[-20:])))


In [7]:
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import math

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

# Training

In [8]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

class Config:
    # Training
    epochs = 5
    eval_steps = 200
    lr = 1e-4
    wd = 1e-3
    warmup = 0
    
    n_nodes = 428
    n_feats = 6
    
    # Prediction
    seq_len = 96
    horizon = 32
    
    # LSTM
    d_node: int = None
    d_latent: int = 256
    device: str = "cuda"
    n_layers = 2
    dropout = 0.0
    seed = 42

args = Config()
manual_seed(args.seed)


In [9]:
node_emb_path="../input/static_node_emb.npy"
node_emb_matrix = torch.tensor(np.load(node_emb_path), dtype=torch.float32)


def train_fold(fold_idx):
    global training_data_np, testing_data_np
    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = train_folds[fold_idx][0][:, :, used_features]
    testing_data_np = train_folds[fold_idx][1][:, :, used_features]
    transformed_training_data_np = np.log1p(training_data_np)
    transformed_testing_data_np = np.log1p(testing_data_np)
    train_loader = DataLoader(SimpleStockDataset(transformed_training_data_np, args.seq_len + args.horizon), batch_size=1, shuffle=True)
    
    manual_seed(args.seed + fold_idx)
        
    model = MyRNN(args.n_nodes, args.n_feats, args, node_emb=node_emb_matrix)

    from torch.optim import AdamW

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    total_steps = args.epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup,
        num_training_steps=total_steps,
    )
    
    print(f'Model size: {get_model_size(model) * 1e3:.2f}K')
    
    # Sanity check
    print(eval_ensemble(args, model, training_data_np, testing_data_np, args.device, [2, 4]))
    
    global best_loss, best_model
    best_loss = 9999
    best_model = copy.deepcopy(model)

    train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np)
    
    return best_model, best_loss

In [10]:
best_model, best_loss = train_fold(3)

Model size: 868.87K
{'rmse': 0.526685557191528, 'raw_rmse': 358.97996240197466, 'mae': 0.52032424938255, 'raw_mae': 140.37756601242643, 'r2': 0.5457457876254339, 'raw_r2': 0.2382510587744412}


Ep: 1/  5. LR: 9.999e-05. Loss: 0.1968. Stats01: 0.0000:   6%|▌         | 219/3529 [00:51<18:43,  2.95it/s]

{'rmse': 0.07011977558370015, 'raw_rmse': 37.02251639268137, 'mae': 0.05099499445774277, 'raw_mae': 11.628148036166866, 'r2': 0.9934892686186625, 'raw_r2': 0.9935715864602711}
Best loss: 0.051
Mean loss: 0.051



Ep: 1/  5. LR: 9.995e-05. Loss: 0.1922. Stats01: 0.0000:  12%|█▏        | 427/3529 [01:58<17:20,  2.98it/s]

{'rmse': 0.07045433135045637, 'raw_rmse': 36.775431247762945, 'mae': 0.05215796826471413, 'raw_mae': 12.469212926188277, 'r2': 0.993728828773165, 'raw_r2': 0.9933478591116139}
Best loss: 0.051
Mean loss: 0.052



Ep: 1/  5. LR: 9.989e-05. Loss: 0.1859. Stats01: 0.0000:  18%|█▊        | 634/3529 [03:05<15:43,  3.07it/s]

{'rmse': 0.08248266371951614, 'raw_rmse': 46.2718140559178, 'mae': 0.06660596860616175, 'raw_mae': 15.669379600031263, 'r2': 0.9912192197447912, 'raw_r2': 0.9897429063888705}
Best loss: 0.051
Mean loss: 0.057



Ep: 1/  5. LR: 9.981e-05. Loss: 0.1810. Stats01: 0.0000:  24%|██▍       | 839/3529 [04:13<15:08,  2.96it/s]

{'rmse': 0.06982414940907204, 'raw_rmse': 34.097247111432495, 'mae': 0.050848532811842255, 'raw_mae': 12.12268675919503, 'r2': 0.9939218699460247, 'raw_r2': 0.9944620712736619}
Best loss: 0.051
Mean loss: 0.055



Ep: 1/  5. LR: 9.970e-05. Loss: 0.1800. Stats01: 0.0000:  30%|██▉       | 1044/3529 [05:20<13:46,  3.01it/s]

{'rmse': 0.06366462987249115, 'raw_rmse': 23.627357194783574, 'mae': 0.043760005337181634, 'raw_mae': 9.424491539206832, 'r2': 0.9950733598352693, 'raw_r2': 0.9977211015290087}
Best loss: 0.044
Mean loss: 0.053



Ep: 1/  5. LR: 9.956e-05. Loss: 0.1793. Stats01: 0.0000:  35%|███▌      | 1248/3529 [06:28<12:55,  2.94it/s]

{'rmse': 0.06464189721620203, 'raw_rmse': 29.140132680017683, 'mae': 0.04483558644853959, 'raw_mae': 10.500957189811011, 'r2': 0.9948502209245312, 'raw_r2': 0.996215652609144}
Best loss: 0.044
Mean loss: 0.052



Ep: 1/  5. LR: 9.941e-05. Loss: 0.1774. Stats01: 0.0000:  41%|████      | 1452/3529 [07:37<11:34,  2.99it/s]

{'rmse': 0.06275695795415678, 'raw_rmse': 27.409930278517496, 'mae': 0.0418943181132982, 'raw_mae': 9.994930132435876, 'r2': 0.9951930805635031, 'raw_r2': 0.9967631569549087}
Best loss: 0.042
Mean loss: 0.050



Ep: 1/  5. LR: 9.923e-05. Loss: 0.1793. Stats01: 0.0000:  47%|████▋     | 1656/3529 [08:45<10:14,  3.05it/s]

{'rmse': 0.06341422470674797, 'raw_rmse': 24.23012991284356, 'mae': 0.04328738982579832, 'raw_mae': 9.39109850116495, 'r2': 0.9951326142827361, 'raw_r2': 0.9975669708978976}
Best loss: 0.042
Mean loss: 0.049



Ep: 1/  5. LR: 9.902e-05. Loss: 0.1741. Stats01: 0.0000:  53%|█████▎    | 1859/3529 [09:54<09:33,  2.91it/s]

{'rmse': 0.06147218429903503, 'raw_rmse': 24.520426588548702, 'mae': 0.040823555282228, 'raw_mae': 9.261652815022709, 'r2': 0.9954496277374826, 'raw_r2': 0.9975243294943197}
Best loss: 0.041
Mean loss: 0.048



Ep: 1/  5. LR: 9.879e-05. Loss: 0.1737. Stats01: 0.0000:  58%|█████▊    | 2062/3529 [11:02<08:24,  2.91it/s]

{'rmse': 0.06432061349578824, 'raw_rmse': 28.188130835317324, 'mae': 0.04445902892365521, 'raw_mae': 10.296697113173222, 'r2': 0.9949664027003831, 'raw_r2': 0.9965779094113746}
Best loss: 0.041
Mean loss: 0.048



Ep: 1/  5. LR: 9.854e-05. Loss: 0.1713. Stats01: 0.0000:  64%|██████▍   | 2265/3529 [12:10<07:02,  2.99it/s]

{'rmse': 0.06362286818322385, 'raw_rmse': 23.94581378291462, 'mae': 0.04380499223542828, 'raw_mae': 9.418082654252498, 'r2': 0.9951049227425417, 'raw_r2': 0.9976342944965947}
Best loss: 0.041
Mean loss: 0.048



Ep: 1/  5. LR: 9.826e-05. Loss: 0.1792. Stats01: 0.0000:  70%|██████▉   | 2469/3529 [12:53<04:02,  4.36it/s]

{'rmse': 0.06196792061757559, 'raw_rmse': 24.810571259676482, 'mae': 0.04147479481283957, 'raw_mae': 9.311617514825079, 'r2': 0.9953710977066965, 'raw_r2': 0.9974450802381818}
Best loss: 0.041
Mean loss: 0.047



Ep: 1/  5. LR: 9.796e-05. Loss: 0.1741. Stats01: 0.0000:  76%|███████▌  | 2672/3529 [13:33<03:14,  4.40it/s]

{'rmse': 0.06416689614910526, 'raw_rmse': 26.58399943310809, 'mae': 0.04428335921916043, 'raw_mae': 10.040902358832971, 'r2': 0.9949795236717054, 'raw_r2': 0.9970060413490819}
Best loss: 0.041
Mean loss: 0.047



Ep: 1/  5. LR: 9.763e-05. Loss: 0.1831. Stats01: 0.0000:  81%|████████▏ | 2875/3529 [14:12<02:28,  4.41it/s]

{'rmse': 0.06152610819698643, 'raw_rmse': 24.525967439380775, 'mae': 0.04091526361647791, 'raw_mae': 9.319494109549092, 'r2': 0.9954356933825003, 'raw_r2': 0.9975230854513132}
Best loss: 0.041
Mean loss: 0.046



Ep: 1/  5. LR: 9.728e-05. Loss: 0.1714. Stats01: 0.0000:  87%|████████▋ | 3077/3529 [14:52<01:42,  4.41it/s]

{'rmse': 0.06828446943434152, 'raw_rmse': 30.19108617053557, 'mae': 0.049864636814258574, 'raw_mae': 11.12104107769387, 'r2': 0.9942346376768108, 'raw_r2': 0.9960058735018081}
Best loss: 0.041
Mean loss: 0.047



Ep: 1/  5. LR: 9.691e-05. Loss: 0.1738. Stats01: 0.0000:  93%|█████████▎| 3280/3529 [15:31<00:54,  4.56it/s]

{'rmse': 0.062330009063039336, 'raw_rmse': 24.202592632380018, 'mae': 0.04207871848375639, 'raw_mae': 9.333743357095091, 'r2': 0.9953074807706648, 'raw_r2': 0.9975828702768628}
Best loss: 0.041
Mean loss: 0.046



Ep: 1/  5. LR: 9.651e-05. Loss: 0.1745. Stats01: 0.0000:  99%|█████████▊| 3482/3529 [16:11<00:10,  4.42it/s]

{'rmse': 0.061726184723320913, 'raw_rmse': 24.7634377087457, 'mae': 0.041027258222769473, 'raw_mae': 9.430548897359984, 'r2': 0.9953886843945582, 'raw_r2': 0.9974535099996718}
Best loss: 0.041
Mean loss: 0.046



Ep: 1/  5. LR: 9.625e-05. Loss: 0.1722. Stats01: 0.0000: 100%|██████████| 3529/3529 [16:37<00:00,  3.54it/s]
Ep: 2/  5. LR: 9.609e-05. Loss: 0.1854. Stats01: 0.0000:   2%|▏         | 83/3529 [00:14<12:11,  4.71it/s]

{'rmse': 0.06304689040241312, 'raw_rmse': 24.854262756032405, 'mae': 0.04320472613421565, 'raw_mae': 9.449657760687188, 'r2': 0.9951966435947533, 'raw_r2': 0.9974098535162637}
Best loss: 0.041
Mean loss: 0.046



Ep: 2/  5. LR: 9.565e-05. Loss: 0.1763. Stats01: 0.0000:   8%|▊         | 294/3529 [00:53<11:49,  4.56it/s]

{'rmse': 0.062215555517011925, 'raw_rmse': 24.360522982511572, 'mae': 0.04214932832151803, 'raw_mae': 9.314115165612083, 'r2': 0.9953324342321209, 'raw_r2': 0.9975518615261296}
Best loss: 0.041
Mean loss: 0.046



Ep: 2/  5. LR: 9.519e-05. Loss: 0.1742. Stats01: 0.0000:  14%|█▍        | 501/3529 [01:32<08:53,  5.67it/s]

{'rmse': 0.06312956090917787, 'raw_rmse': 24.187383699045686, 'mae': 0.04298435595641012, 'raw_mae': 9.387617530159531, 'r2': 0.9951718768854069, 'raw_r2': 0.9975890014665603}
Best loss: 0.041
Mean loss: 0.046



Ep: 2/  5. LR: 9.470e-05. Loss: 0.1767. Stats01: 0.0000:  20%|██        | 708/3529 [02:11<09:59,  4.71it/s]

{'rmse': 0.06111642730685005, 'raw_rmse': 24.022758065611114, 'mae': 0.04064777011900165, 'raw_mae': 9.103441005450644, 'r2': 0.9955150012924986, 'raw_r2': 0.9976402736921546}
Best loss: 0.041
Mean loss: 0.045



Ep: 2/  5. LR: 9.418e-05. Loss: 0.1737. Stats01: 0.0000:  26%|██▌       | 913/3529 [02:50<09:46,  4.46it/s]

{'rmse': 0.06178218223951091, 'raw_rmse': 24.124943160344827, 'mae': 0.04124728457813137, 'raw_mae': 9.179306963523041, 'r2': 0.9954034967179518, 'raw_r2': 0.9976124748588028}
Best loss: 0.041
Mean loss: 0.045



Ep: 2/  5. LR: 9.365e-05. Loss: 0.1728. Stats01: 0.0000:  32%|███▏      | 1117/3529 [03:28<08:40,  4.64it/s]

{'rmse': 0.06301087282638816, 'raw_rmse': 23.947509942805723, 'mae': 0.04313863296263593, 'raw_mae': 9.345191704683767, 'r2': 0.9952090058731035, 'raw_r2': 0.9976449151259}
Best loss: 0.041
Mean loss: 0.043



Ep: 2/  5. LR: 9.309e-05. Loss: 0.1739. Stats01: 0.0000:  37%|███▋      | 1321/3529 [04:07<07:56,  4.63it/s]

{'rmse': 0.060489708932976526, 'raw_rmse': 23.579455542487736, 'mae': 0.03966052730330415, 'raw_mae': 9.002171380883892, 'r2': 0.9956058892864376, 'raw_r2': 0.9977401469761714}
Best loss: 0.040
Mean loss: 0.043



Ep: 2/  5. LR: 9.251e-05. Loss: 0.1732. Stats01: 0.0000:  43%|████▎     | 1525/3529 [04:44<07:04,  4.72it/s]

{'rmse': 0.061050795912525095, 'raw_rmse': 24.108093702911376, 'mae': 0.040412311394040924, 'raw_mae': 9.17150429076285, 'r2': 0.9955106368255383, 'raw_r2': 0.9976227982660469}
Best loss: 0.040
Mean loss: 0.043



Ep: 2/  5. LR: 9.191e-05. Loss: 0.1771. Stats01: 0.0000:  49%|████▉     | 1729/3529 [05:22<06:45,  4.44it/s]

{'rmse': 0.06089044282665034, 'raw_rmse': 23.874589277581105, 'mae': 0.04030471456741931, 'raw_mae': 9.117628030405289, 'r2': 0.9955405221992681, 'raw_r2': 0.9976718541675889}
Best loss: 0.040
Mean loss: 0.042



Ep: 2/  5. LR: 9.128e-05. Loss: 0.1719. Stats01: 0.0000:  55%|█████▍    | 1932/3529 [06:01<05:49,  4.58it/s]

{'rmse': 0.06024006827251474, 'raw_rmse': 23.684323261941593, 'mae': 0.03931262757447299, 'raw_mae': 8.971368329054513, 'r2': 0.9956447647479634, 'raw_r2': 0.9977198356395257}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 9.064e-05. Loss: 0.1791. Stats01: 0.0000:  60%|██████    | 2135/3529 [06:39<04:57,  4.68it/s]

{'rmse': 0.06210602684118411, 'raw_rmse': 24.44263134620046, 'mae': 0.04201771467235016, 'raw_mae': 9.309875009372297, 'r2': 0.9953524154394117, 'raw_r2': 0.9975348407600763}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.997e-05. Loss: 0.1707. Stats01: 0.0000:  66%|██████▋   | 2338/3529 [07:16<04:21,  4.55it/s]

{'rmse': 0.06255746225786289, 'raw_rmse': 24.93872479962948, 'mae': 0.04241930514873009, 'raw_mae': 9.430481441622772, 'r2': 0.9952790911492643, 'raw_r2': 0.9974185940230215}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.928e-05. Loss: 0.1744. Stats01: 0.0000:  72%|███████▏  | 2541/3529 [07:54<03:34,  4.61it/s]

{'rmse': 0.06050172052999731, 'raw_rmse': 23.417027686868376, 'mae': 0.03972881720229687, 'raw_mae': 8.915925690072154, 'r2': 0.9956094970991681, 'raw_r2': 0.9977700776377902}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.857e-05. Loss: 0.1749. Stats01: 0.0000:  78%|███████▊  | 2743/3529 [08:32<02:31,  5.20it/s]

{'rmse': 0.06040426128453392, 'raw_rmse': 23.42494205760126, 'mae': 0.03962390661149469, 'raw_mae': 8.926088959990723, 'r2': 0.9956192719323518, 'raw_r2': 0.99776237349039}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.783e-05. Loss: 0.1705. Stats01: 0.0000:  84%|████████▎ | 2947/3529 [09:11<02:09,  4.49it/s]

{'rmse': 0.06168429452349643, 'raw_rmse': 24.09224885057537, 'mae': 0.041155242756696554, 'raw_mae': 9.187806377132443, 'r2': 0.9954143428933407, 'raw_r2': 0.9976119698448583}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.707e-05. Loss: 0.1773. Stats01: 0.0000:  89%|████████▉ | 3149/3529 [09:50<01:24,  4.52it/s]

{'rmse': 0.06290077213171953, 'raw_rmse': 25.373582019529053, 'mae': 0.04286878376080836, 'raw_mae': 9.537863248118532, 'r2': 0.9952243034746576, 'raw_r2': 0.9973205525988856}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.630e-05. Loss: 0.1672. Stats01: 0.0000:  95%|█████████▍| 3352/3529 [10:29<00:39,  4.48it/s]

{'rmse': 0.0610215361783884, 'raw_rmse': 23.820775127406314, 'mae': 0.040326019451132, 'raw_mae': 9.045031466658045, 'r2': 0.9955232577650045, 'raw_r2': 0.9976728105787459}
Best loss: 0.039
Mean loss: 0.042



Ep: 2/  5. LR: 8.550e-05. Loss: 0.1618. Stats01: 0.0000: : 3554it [11:09,  4.36it/s]                        

{'rmse': 0.06169335300409308, 'raw_rmse': 24.460045174390086, 'mae': 0.041160111890988726, 'raw_mae': 9.289823360595237, 'r2': 0.9954104545512389, 'raw_r2': 0.9975265276350906}
Best loss: 0.039
Mean loss: 0.041



Ep: 2/  5. LR: 8.526e-05. Loss: 0.1735. Stats01: 0.0000: 100%|██████████| 3529/3529 [11:21<00:00,  5.18it/s]
Ep: 3/  5. LR: 8.468e-05. Loss: 0.1704. Stats01: 0.0000:   5%|▍         | 159/3529 [00:28<11:50,  4.75it/s]

{'rmse': 0.06159267123437039, 'raw_rmse': 23.781223332295205, 'mae': 0.04104777428582843, 'raw_mae': 9.126363688404696, 'r2': 0.9954272172640223, 'raw_r2': 0.9976832682233157}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 8.384e-05. Loss: 0.1757. Stats01: 0.0000:  10%|█         | 368/3529 [01:08<11:55,  4.42it/s]

{'rmse': 0.06038045069428087, 'raw_rmse': 23.35638982450719, 'mae': 0.0395391553244796, 'raw_mae': 8.922526016513904, 'r2': 0.9956233474499118, 'raw_r2': 0.9977872091338083}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 8.298e-05. Loss: 0.1687. Stats01: 0.0000:  16%|█▋        | 575/3529 [01:47<10:38,  4.63it/s]

{'rmse': 0.06088206172905238, 'raw_rmse': 24.969216839391084, 'mae': 0.039990545707070335, 'raw_mae': 9.1211752048071, 'r2': 0.9955478653978652, 'raw_r2': 0.9973887029692867}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 8.210e-05. Loss: 0.1673. Stats01: 0.0000:  22%|██▏       | 780/3529 [02:25<09:45,  4.70it/s]

{'rmse': 0.061463809004185625, 'raw_rmse': 24.837141254202482, 'mae': 0.040778041750034105, 'raw_mae': 9.378162319470105, 'r2': 0.9954380053543145, 'raw_r2': 0.9974585484391233}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 8.120e-05. Loss: 0.1732. Stats01: 0.0000:  28%|██▊       | 985/3529 [03:03<09:10,  4.62it/s]

{'rmse': 0.06147274062564401, 'raw_rmse': 24.13062165953456, 'mae': 0.041021937813805365, 'raw_mae': 9.180998164883782, 'r2': 0.9954554616685866, 'raw_r2': 0.9976053596146719}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 8.029e-05. Loss: 0.1706. Stats01: 0.0000:  34%|███▎      | 1190/3529 [03:40<08:17,  4.71it/s]

{'rmse': 0.06090869273284771, 'raw_rmse': 25.658092712512524, 'mae': 0.04007966256489574, 'raw_mae': 9.264904284969948, 'r2': 0.9955390014433338, 'raw_r2': 0.9972100068680974}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.935e-05. Loss: 0.1693. Stats01: 0.0000:  40%|███▉      | 1394/3529 [04:19<07:35,  4.69it/s]

{'rmse': 0.06171717224486795, 'raw_rmse': 29.30617762912351, 'mae': 0.04134474369293447, 'raw_mae': 9.551521482882746, 'r2': 0.9954210247751422, 'raw_r2': 0.9960314584302151}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.839e-05. Loss: 0.1661. Stats01: 0.0000:  45%|████▌     | 1598/3529 [04:56<05:15,  6.11it/s]

{'rmse': 0.06097235448446566, 'raw_rmse': 24.75127159732178, 'mae': 0.04030101732560693, 'raw_mae': 9.132311672744727, 'r2': 0.995538661471773, 'raw_r2': 0.9974766353642067}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.741e-05. Loss: 0.1634. Stats01: 0.0000:  51%|█████     | 1801/3529 [05:34<06:13,  4.63it/s]

{'rmse': 0.06219296147919238, 'raw_rmse': 31.916136001465027, 'mae': 0.041762856678391524, 'raw_mae': 10.09806727513858, 'r2': 0.9953355754721622, 'raw_r2': 0.9952818057453215}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.642e-05. Loss: 0.1627. Stats01: 0.0000:  57%|█████▋    | 2004/3529 [06:12<05:23,  4.71it/s]

{'rmse': 0.06145347304407352, 'raw_rmse': 26.681653215910185, 'mae': 0.04105747098596704, 'raw_mae': 9.310715329809337, 'r2': 0.995463905148839, 'raw_r2': 0.996917812514705}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.540e-05. Loss: 0.1614. Stats01: 0.0000:  63%|██████▎   | 2207/3529 [06:50<04:42,  4.68it/s]

{'rmse': 0.06078385098358259, 'raw_rmse': 23.913661442716442, 'mae': 0.040152934355339606, 'raw_mae': 9.041585274557987, 'r2': 0.9955604151337932, 'raw_r2': 0.9976547904089217}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.437e-05. Loss: 0.1619. Stats01: 0.0000:  68%|██████▊   | 2410/3529 [07:28<04:07,  4.52it/s]

{'rmse': 0.06079533218078779, 'raw_rmse': 23.57155471357603, 'mae': 0.04026002809546815, 'raw_mae': 8.978740035809738, 'r2': 0.9955630003033205, 'raw_r2': 0.9977288460338304}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.332e-05. Loss: 0.1550. Stats01: 0.0000:  74%|███████▍  | 2613/3529 [08:06<03:14,  4.71it/s]

{'rmse': 0.06165468852976447, 'raw_rmse': 23.376186980855362, 'mae': 0.041290906147719096, 'raw_mae': 9.059356250446466, 'r2': 0.9954295614141357, 'raw_r2': 0.997775588191438}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.225e-05. Loss: 0.1533. Stats01: 0.0000:  80%|███████▉  | 2816/3529 [08:44<02:36,  4.57it/s]

{'rmse': 0.06130434629418724, 'raw_rmse': 24.099818966801248, 'mae': 0.04065358795821664, 'raw_mae': 9.182221922260545, 'r2': 0.9954807541141244, 'raw_r2': 0.9976341642860783}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.117e-05. Loss: 0.1464. Stats01: 0.0000:  86%|████████▌ | 3019/3529 [09:22<01:50,  4.63it/s]

{'rmse': 0.06049618252473981, 'raw_rmse': 23.422108161807603, 'mae': 0.039658910259521486, 'raw_mae': 8.978703094858446, 'r2': 0.9956064426326425, 'raw_r2': 0.9977730585507413}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 7.006e-05. Loss: 0.1508. Stats01: 0.0000:  91%|█████████▏| 3221/3529 [10:00<01:05,  4.72it/s]

{'rmse': 0.06273849641541064, 'raw_rmse': 24.965717128908096, 'mae': 0.04286051090046133, 'raw_mae': 9.470771400235813, 'r2': 0.9952546639348764, 'raw_r2': 0.9974179770262155}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 6.894e-05. Loss: 0.1467. Stats01: 0.0000:  97%|█████████▋| 3424/3529 [10:37<00:22,  4.71it/s]

{'rmse': 0.060810621157123323, 'raw_rmse': 23.624569433142778, 'mae': 0.040028737011792215, 'raw_mae': 9.016327337345615, 'r2': 0.9955559329814543, 'raw_r2': 0.997728728795757}
Best loss: 0.039
Mean loss: 0.041



Ep: 3/  5. LR: 6.788e-05. Loss: 0.1680. Stats01: 0.0000: 100%|██████████| 3529/3529 [11:13<00:00,  5.24it/s]
Ep: 4/  5. LR: 6.781e-05. Loss: 0.1537. Stats01: 0.0000:   1%|          | 18/3529 [00:02<09:56,  5.88it/s]

{'rmse': 0.060842415076880436, 'raw_rmse': 23.458582485722935, 'mae': 0.040147601774333797, 'raw_mae': 8.937618110702203, 'r2': 0.9955593026951857, 'raw_r2': 0.9977621285641575}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 6.665e-05. Loss: 0.1495. Stats01: 0.0000:   7%|▋         | 234/3529 [00:40<10:12,  5.38it/s]

{'rmse': 0.06084481063392924, 'raw_rmse': 23.381823060272833, 'mae': 0.040294389902307895, 'raw_mae': 8.977709398007592, 'r2': 0.9955568369056407, 'raw_r2': 0.99777815236396}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 6.548e-05. Loss: 0.1468. Stats01: 0.0000:  13%|█▎        | 442/3529 [01:19<10:44,  4.79it/s]

{'rmse': 0.06136664330876476, 'raw_rmse': 23.8150806816414, 'mae': 0.040907354793328056, 'raw_mae': 9.074477131895538, 'r2': 0.9954691774928215, 'raw_r2': 0.9976733273671164}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 6.430e-05. Loss: 0.1504. Stats01: 0.0000:  18%|█▊        | 648/3529 [01:58<10:19,  4.65it/s]

{'rmse': 0.06088143386056666, 'raw_rmse': 23.669823946329586, 'mae': 0.040396312040284044, 'raw_mae': 8.979943615366215, 'r2': 0.9955480621782985, 'raw_r2': 0.9977053435459263}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 6.310e-05. Loss: 0.1488. Stats01: 0.0000:  24%|██▍       | 853/3529 [02:36<09:27,  4.72it/s]

{'rmse': 0.060685340675591955, 'raw_rmse': 23.655805926986535, 'mae': 0.03997430112848439, 'raw_mae': 9.046360752217398, 'r2': 0.9955730509722968, 'raw_r2': 0.9977194578095422}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 6.188e-05. Loss: 0.1483. Stats01: 0.0000:  30%|██▉       | 1058/3529 [03:14<08:58,  4.59it/s]

{'rmse': 0.0609080462584154, 'raw_rmse': 23.7689644504024, 'mae': 0.04034752447444726, 'raw_mae': 9.025252466861811, 'r2': 0.9955480330777225, 'raw_r2': 0.9976927081652475}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 6.065e-05. Loss: 0.1445. Stats01: 0.0000:  36%|███▌      | 1262/3529 [03:52<08:10,  4.63it/s]

{'rmse': 0.06246601927341428, 'raw_rmse': 24.47382734195868, 'mae': 0.042388236244521316, 'raw_mae': 9.369608466777574, 'r2': 0.9952936965226286, 'raw_r2': 0.9975330952281872}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.940e-05. Loss: 0.1508. Stats01: 0.0000:  42%|████▏     | 1466/3529 [04:30<07:31,  4.57it/s]

{'rmse': 0.06088352534800324, 'raw_rmse': 23.97983349012676, 'mae': 0.040083758629146454, 'raw_mae': 9.100867532379102, 'r2': 0.9955458781928311, 'raw_r2': 0.9976466693984479}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.814e-05. Loss: 0.1467. Stats01: 0.0000:  47%|████▋     | 1670/3529 [05:08<06:44,  4.60it/s]

{'rmse': 0.06166804505501992, 'raw_rmse': 23.548299901783516, 'mae': 0.04133980173662383, 'raw_mae': 9.082917379710059, 'r2': 0.9954240344504424, 'raw_r2': 0.9977382636550667}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.687e-05. Loss: 0.1469. Stats01: 0.0000:  53%|█████▎    | 1873/3529 [05:46<05:58,  4.61it/s]

{'rmse': 0.061964132907596225, 'raw_rmse': 23.724261900214334, 'mae': 0.041739308877865755, 'raw_mae': 9.137051862934445, 'r2': 0.9953846314745808, 'raw_r2': 0.997700155739075}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.558e-05. Loss: 0.1473. Stats01: 0.0000:  59%|█████▉    | 2076/3529 [06:24<05:13,  4.64it/s]

{'rmse': 0.06308499838889983, 'raw_rmse': 24.24191147135997, 'mae': 0.04324763454215799, 'raw_mae': 9.386584727184829, 'r2': 0.995198222771229, 'raw_r2': 0.9975805991064881}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.428e-05. Loss: 0.1530. Stats01: 0.0000:  65%|██████▍   | 2280/3529 [07:03<03:45,  5.53it/s]

{'rmse': 0.06065566241916097, 'raw_rmse': 24.029643678001197, 'mae': 0.03965713804848559, 'raw_mae': 9.128432628534053, 'r2': 0.9955753052399606, 'raw_r2': 0.9976433260342757}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.296e-05. Loss: 0.1551. Stats01: 0.0000:  70%|███████   | 2482/3529 [07:42<04:16,  4.09it/s]

{'rmse': 0.06058815021931019, 'raw_rmse': 23.242229414300112, 'mae': 0.03985432867981239, 'raw_mae': 8.905560031200508, 'r2': 0.995597638973871, 'raw_r2': 0.9978107701800294}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.164e-05. Loss: 0.1488. Stats01: 0.0000:  76%|███████▌  | 2685/3529 [08:21<03:06,  4.53it/s]

{'rmse': 0.06184960294772435, 'raw_rmse': 23.999313773502987, 'mae': 0.041517019383300734, 'raw_mae': 9.15399600207415, 'r2': 0.995399703318769, 'raw_r2': 0.9976409918461228}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 5.030e-05. Loss: 0.1490. Stats01: 0.0000:  82%|████████▏ | 2888/3529 [08:59<02:22,  4.51it/s]

{'rmse': 0.0612831245851907, 'raw_rmse': 23.480655919594025, 'mae': 0.04089104829633216, 'raw_mae': 9.037964376749688, 'r2': 0.995484203642583, 'raw_r2': 0.9977519144097263}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.894e-05. Loss: 0.1489. Stats01: 0.0000:  88%|████████▊ | 3091/3529 [09:37<01:35,  4.61it/s]

{'rmse': 0.061639267354480595, 'raw_rmse': 23.66998948347947, 'mae': 0.04131452931795142, 'raw_mae': 9.116828692185585, 'r2': 0.9954355037957969, 'raw_r2': 0.9977126993464902}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.758e-05. Loss: 0.1473. Stats01: 0.0000:  93%|█████████▎| 3293/3529 [10:16<00:53,  4.45it/s]

{'rmse': 0.06272339106539317, 'raw_rmse': 24.64647550193337, 'mae': 0.0428300231792218, 'raw_mae': 9.445568923251187, 'r2': 0.9952570396558726, 'raw_r2': 0.9974978616544501}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.620e-05. Loss: 0.1505. Stats01: 0.0000:  99%|█████████▉| 3496/3529 [10:55<00:07,  4.54it/s]

{'rmse': 0.06024467747250315, 'raw_rmse': 23.794086696975537, 'mae': 0.039226465788304486, 'raw_mae': 9.023921827062413, 'r2': 0.9956377511923159, 'raw_r2': 0.9976974200305981}
Best loss: 0.039
Mean loss: 0.041



Ep: 4/  5. LR: 4.540e-05. Loss: 0.1499. Stats01: 0.0000: 100%|██████████| 3529/3529 [11:17<00:00,  5.21it/s]
Ep: 5/  5. LR: 4.482e-05. Loss: 0.1508. Stats01: 0.0000:   3%|▎         | 97/3529 [00:16<12:38,  4.53it/s]

{'rmse': 0.06139721145423023, 'raw_rmse': 24.009141130759854, 'mae': 0.04099189366035496, 'raw_mae': 9.155987825309246, 'r2': 0.9954715906797035, 'raw_r2': 0.9976492254431081}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 4.342e-05. Loss: 0.1470. Stats01: 0.0000:   9%|▊         | 307/3529 [00:55<09:51,  5.45it/s]

{'rmse': 0.060736958365285705, 'raw_rmse': 23.47945387533652, 'mae': 0.040085573698454995, 'raw_mae': 8.965962615134714, 'r2': 0.9955702381467345, 'raw_r2': 0.9977588641548988}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 4.201e-05. Loss: 0.1519. Stats01: 0.0000:  15%|█▍        | 515/3529 [01:34<11:43,  4.28it/s]

{'rmse': 0.06132937484612308, 'raw_rmse': 23.96048474421718, 'mae': 0.04083349926104877, 'raw_mae': 9.09556116871783, 'r2': 0.9954822788204358, 'raw_r2': 0.9976545515444645}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 4.059e-05. Loss: 0.1508. Stats01: 0.0000:  20%|██        | 721/3529 [02:14<10:55,  4.28it/s]

{'rmse': 0.06065924787138635, 'raw_rmse': 23.494829552917945, 'mae': 0.03992788669905327, 'raw_mae': 8.968577569603887, 'r2': 0.9955849801606278, 'raw_r2': 0.9977567584507469}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.916e-05. Loss: 0.1527. Stats01: 0.0000:  26%|██▌       | 926/3529 [02:53<10:11,  4.26it/s]

{'rmse': 0.06033870552710087, 'raw_rmse': 23.22351772435729, 'mae': 0.03948376599915894, 'raw_mae': 8.901754258663807, 'r2': 0.9956343237745144, 'raw_r2': 0.9978134591237243}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.772e-05. Loss: 0.1521. Stats01: 0.0000:  32%|███▏      | 1131/3529 [03:33<07:50,  5.09it/s]

{'rmse': 0.060517098664459285, 'raw_rmse': 23.800163637810094, 'mae': 0.03979776098244496, 'raw_mae': 8.997834058963658, 'r2': 0.9956048418611526, 'raw_r2': 0.9976932144633872}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.628e-05. Loss: 0.1453. Stats01: 0.0000:  38%|███▊      | 1335/3529 [04:13<08:06,  4.51it/s]

{'rmse': 0.06033269502662217, 'raw_rmse': 23.765732789872892, 'mae': 0.039365502543929086, 'raw_mae': 9.048520883401071, 'r2': 0.9956249680227114, 'raw_r2': 0.997702690148928}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.482e-05. Loss: 0.1493. Stats01: 0.0000:  44%|████▎     | 1538/3529 [04:53<07:37,  4.35it/s]

{'rmse': 0.060600983140893065, 'raw_rmse': 23.151267152152045, 'mae': 0.03989168796545729, 'raw_mae': 8.89818612100162, 'r2': 0.9955969673393965, 'raw_r2': 0.9978281372107505}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.335e-05. Loss: 0.1480. Stats01: 0.0000:  49%|████▉     | 1742/3529 [05:32<06:35,  4.52it/s]

{'rmse': 0.06035681481930726, 'raw_rmse': 23.460083318061127, 'mae': 0.0394775657835021, 'raw_mae': 8.960568081603872, 'r2': 0.9956292893691229, 'raw_r2': 0.9977652315187279}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.188e-05. Loss: 0.1459. Stats01: 0.0000:  55%|█████▌    | 1945/3529 [06:13<06:08,  4.29it/s]

{'rmse': 0.0622574044467698, 'raw_rmse': 23.706735596125686, 'mae': 0.04216525649076626, 'raw_mae': 9.198712532494888, 'r2': 0.9953366921313207, 'raw_r2': 0.9977073058221967}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 3.040e-05. Loss: 0.1493. Stats01: 0.0000:  61%|██████    | 2149/3529 [06:53<04:42,  4.88it/s]

{'rmse': 0.060784094239131614, 'raw_rmse': 23.506126475474833, 'mae': 0.040082914467981186, 'raw_mae': 8.98366538846795, 'r2': 0.9955656241200421, 'raw_r2': 0.9977484821445859}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.891e-05. Loss: 0.1518. Stats01: 0.0000:  67%|██████▋   | 2352/3529 [07:34<04:12,  4.67it/s]

{'rmse': 0.062249868853407296, 'raw_rmse': 24.231328468474345, 'mae': 0.04218599838184246, 'raw_mae': 9.265673436849648, 'r2': 0.9953367531793297, 'raw_r2': 0.9975824413639695}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.742e-05. Loss: 0.1494. Stats01: 0.0000:  72%|███████▏  | 2554/3529 [08:14<03:41,  4.40it/s]

{'rmse': 0.06028885318054654, 'raw_rmse': 23.545071499624196, 'mae': 0.03941669586170345, 'raw_mae': 8.953738916495027, 'r2': 0.995638432339367, 'raw_r2': 0.9977499967093929}
Best loss: 0.039
Mean loss: 0.040



Ep: 5/  5. LR: 2.592e-05. Loss: 0.1512. Stats01: 0.0000:  78%|███████▊  | 2757/3529 [08:54<03:12,  4.00it/s]

{'rmse': 0.06175215044535258, 'raw_rmse': 23.66591121405165, 'mae': 0.041452253760460994, 'raw_mae': 9.101542818867742, 'r2': 0.9954207728519803, 'raw_r2': 0.9977125305912545}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.441e-05. Loss: 0.1480. Stats01: 0.0000:  84%|████████▍ | 2960/3529 [09:35<02:23,  3.97it/s]

{'rmse': 0.06101077091235794, 'raw_rmse': 23.042449952341588, 'mae': 0.04047321336583531, 'raw_mae': 8.932727847995963, 'r2': 0.9955309517461093, 'raw_r2': 0.9978503784320162}
Best loss: 0.039
Mean loss: 0.041



Ep: 5/  5. LR: 2.290e-05. Loss: 0.1464. Stats01: 0.0000:  90%|████████▉ | 3163/3529 [10:15<01:11,  5.11it/s]

{'rmse': 0.06034464127297432, 'raw_rmse': 23.28224368990248, 'mae': 0.03953489148515561, 'raw_mae': 8.913194524065725, 'r2': 0.9956314953157523, 'raw_r2': 0.9978072609586317}
Best loss: 0.039
Mean loss: 0.040



Ep: 5/  5. LR: 2.138e-05. Loss: 0.1478. Stats01: 0.0000:  95%|█████████▌| 3365/3529 [10:55<00:36,  4.44it/s]

{'rmse': 0.06115983806344767, 'raw_rmse': 23.324419924046822, 'mae': 0.04066168403942654, 'raw_mae': 8.990140150823585, 'r2': 0.995508665243521, 'raw_r2': 0.9977891107932272}
Best loss: 0.039
Mean loss: 0.040



Ep: 5/  5. LR: 1.985e-05. Loss: 0.1485. Stats01: 0.0000: : 3567it [11:35,  4.44it/s]                        

{'rmse': 0.06056229420831543, 'raw_rmse': 23.247154831373056, 'mae': 0.039937126165006664, 'raw_mae': 8.896663438596937, 'r2': 0.9956000246659883, 'raw_r2': 0.9978002346163533}
Best loss: 0.039
Mean loss: 0.040



Ep: 5/  5. LR: 1.951e-05. Loss: 0.1463. Stats01: 0.0000: 100%|██████████| 3529/3529 [11:44<00:00,  5.01it/s]


In [11]:
best_loss

0.039226465788304486

In [12]:
data, labels = train_folds[2][0], train_folds[2][1]
eval_ensemble(args, best_model, data, labels, args.device, [args.seq_len])

{'rmse': 0.0472313437872839,
 'raw_rmse': 17.395436049034412,
 'mae': 0.03443736801579683,
 'raw_mae': 7.459272183469157,
 'r2': 0.997348623413033,
 'raw_r2': 0.9985920600647545}

In [13]:
data, labels = train_folds[3][0], train_folds[3][1]
eval_ensemble(args, best_model, data, labels, args.device, [args.seq_len])

{'rmse': 0.06024467747250315,
 'raw_rmse': 23.794086696975537,
 'mae': 0.039226465788304486,
 'raw_mae': 9.023921827062413,
 'r2': 0.9956377511923159,
 'raw_r2': 0.9976974200305981}

In [14]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [args.seq_len])

{'rmse': 0.05925776160656928,
 'raw_rmse': 29.420537020318868,
 'mae': 0.04265291584034904,
 'raw_mae': 10.661486556444133,
 'r2': 0.9959974936382877,
 'raw_r2': 0.9968750435953792}

In [15]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [32, 64, 96])

{'rmse': 0.059266801908728205,
 'raw_rmse': 29.375503162923227,
 'mae': 0.042657423766604745,
 'raw_mae': 10.660737938296819,
 'r2': 0.9959959232576772,
 'raw_r2': 0.996883588096744}

In [16]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [64, 96, 128])

{'rmse': 0.05925928017253247,
 'raw_rmse': 29.410658576827814,
 'mae': 0.04265375605845798,
 'raw_mae': 10.66109554258588,
 'r2': 0.9959972308138646,
 'raw_r2': 0.996876998234142}

In [17]:
data, labels = test_fold[0], test_fold[1]
model = copy.deepcopy(best_model)
used_features = [0, 1, 2, 3, 4, 5]
data = np.log1p(data[:, :, used_features])
labels = np.log1p(labels[:, :, used_features])

seq_lens = [64, 96, 128]

labels = torch.tensor(labels).float().to(args.device)
n_nodes, horizon, n_feats = labels.shape
y_preds = np.zeros((len(seq_lens), n_nodes, horizon, n_feats-1)) # Bỏ Vol
model.eval().to(args.device)
for i, seq_len in enumerate(seq_lens):
    batch = torch.tensor(data[:, -seq_len:]).float().to(args.device)
    with torch.no_grad():
        y_all = model.forecast(batch, horizon=horizon)
    y_preds[i] = y_all[:, -horizon:, :n_feats-1].detach().cpu().numpy() # OHLC + Adj Close
y_gt = labels[:, :, :n_feats-1].reshape(n_nodes, -1).detach().cpu().numpy() # OHLC + Adj Close
y_pred = y_preds.mean(axis=0).reshape(n_nodes, -1)


In [18]:
mean_absolute_error(np.expm1(y_gt), np.expm1(y_pred))

10.661093666877838