In [1]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

manual_seed()


# Data

In [2]:
from utils import load_data, load_edge

train_folds = load_data(True)
test_fold = load_data(False)[0]

In [3]:
import torch
from torch.utils.data import DataLoader, Dataset

class SimpleStockDataset(Dataset):
    def __init__(self, data, ws=128):
        self.data = data
        self.ws = ws
        self.samples = []
        
        self.n_tickers, self.n_days, self.n_features = self.data.shape
        
        for start in range(self.n_days - self.ws + 1):
            self.samples.append(start)
            
    def __len__(self):
        return len(self.samples)
      
    def __getitem__(self, idx):
        start = self.samples[idx]
        x = torch.tensor(self.data[:, start:start + self.ws], dtype=torch.float32)
        return x
    

# Model

### LSTM

In [36]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyRNN(nn.Module):
    def __init__(self, n_nodes: int, n_feats: int, args, node_emb=None):
        super().__init__()
        self.n_nodes = n_nodes
        self.rank_A = 32
        self.n_feats = n_feats
        self.args = args
        self.d_latent = getattr(args, "d_latent", 128)
        self.dropout = getattr(args, "dropout", 0.0)
        self.n_layers = getattr(args, "n_layers", 1)
        self.d_node = getattr(args, "d_node", 32)
        
        

        if node_emb is None:
            self.static_node_features = nn.Parameter(torch.randn(n_nodes, self.d_latent) * 0.1)
            self.fc_node_features = nn.Identity()
        else:
            self.static_node_features = nn.Parameter(node_emb, requires_grad=False)
            self.d_node = self.static_node_features.shape[-1]
            self.fc_node_features = nn.Linear(self.d_node, self.d_latent)
        
        self.cells = nn.ModuleList([
            nn.LSTMCell(n_feats if i == 0 else self.d_latent, self.d_latent)
            for i in range(self.n_layers)
        ])
        
        self.readout = nn.Sequential(
            nn.Linear(self.d_latent, self.d_latent), nn.ReLU(),
            nn.Linear(self.d_latent, n_feats)
        )
        
        self.enc_in = None

    def _init_states(
        self, X_0: torch.Tensor, H_0: Optional[torch.Tensor]
    ):
        """
        Trả về list (h_list, c_list) cho từng layer.
        - Nếu H_0 cung cấp (N, d), đặt h_top = H_0; ngược lại nếu enc_in != None thì h_top = enc(X_0); còn lại h=0.
        - c luôn = 0.
        """
        device = X_0.device
        dtype = X_0.dtype
        N = X_0.shape[0]

        h_list = []
        c_list = []
        for _ in range(self.n_layers):
            h_list.append(self.node_features)
            c_list.append(self.node_features)

        if H_0 is not None:
            h_list[-1] = H_0.to(device=device, dtype=dtype)
        elif self.enc_in is not None:
            h_list[-1] = self.enc_in(X_0)

        return h_list, c_list

    def _step(self, x_t: torch.Tensor, h_list, c_list):
        """
        Chạy 1 bước LSTM qua tất cả các layer.
        Input layer nhận x_t; các layer sau nhận h của layer trước ở cùng time-step.
        Trả về (h_list_new, c_list_new, h_top).
        """
        inp = x_t
        new_h, new_c = [], []
        for l, cell in enumerate(self.cells):
            h_t, c_t = cell(inp, (h_list[l], c_list[l]))
            new_h.append(h_t)
            new_c.append(c_t)
            inp = h_t
        h_top = new_h[-1]
        if self.training and self.dropout > 0:
            h_top = F.dropout(h_top, p=self.dropout, training=True)
        return new_h, new_c, h_top

    def forecast(self, X, H_0=None, horizon=0):
        y, *_ = self.forward(X, H_0, horizon)
        return y

    def forward(self, X: torch.Tensor,
                H_0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run deterministic inference.
        Args:
            X: (N, T, F) input features at each t
            H_0: (N, d) optional initial latent state (default zeros)
            Atilde, L: if precomputed; otherwise build from X
        Returns:
            y_pred: (N, T-1) predicted next-day log-return for each node
            H_all: (N, T, d) latent states (including H_0 as first)
        """
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        n_nodes, n_steps, n_feats = X.shape
        d_latent = self.d_latent
        
        self.node_features = self.fc_node_features(self.static_node_features)

        X_0 = X[:, 0, :] # (N, F)
        h_list, c_list = self._init_states(X_0, H_0)

        H_all = torch.zeros(n_nodes, n_steps + max(horizon, 0), d_latent, device=device, dtype=X.dtype)
        r_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)
        y_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)


        H_all[:, 0] = h_list[-1]
        for t in range(n_steps - 1):
            x_t = X[:, t, :]  # (N, F)
            h_list, c_list, h_top = self._step(x_t, h_list, c_list)
            r_t = self.readout(h_top)
            y_t = x_t + r_t

            H_all[:, t + 1] = h_top
            r_all[:, t] = r_t
            y_all[:, t] = y_t

        if horizon > 0:
            x_t = X[:, -1, :]
            for s in range(horizon):
                h_list, c_list, h_top = self._step(x_t, h_list, c_list)
                r_t = self.readout(h_top)
                y_t = x_t + r_t

                idx = (n_steps - 1) + s
                y_all[:, idx] = y_t
                r_all[:, idx] = r_t
                H_all[:, n_steps + s] = h_top

                x_t = y_t

        return y_all, r_all, H_all

    def forward_loss(
        self,
        X: torch.Tensor,
        H_0: Optional[torch.Tensor] = None,
        horizon: int = 0,
    ):
        """
        Tính loss dự báo one-step + rollout horizon (autoregressive) trên chuỗi đầu vào.

        Args:
            X: (N, T, F) chuỗi gốc (chứa full ground-truth đến T-1)
            H_0: (N, d) latent init (tuỳ chọn)
            horizon: số bước rollout ngoài quan sát cuối cùng
                    (nếu >0, ta cắt input để tránh nhìn thấy tương lai)
            reduction: 'mean' | 'sum' | 'none' cho F.mse_loss

        Returns:
            loss: scalar tensor
            y_pred: (N, T-1, F) dự báo X_{t+1} cho toàn bộ t=0..T-2
        """
        device = next(self.parameters()).device
        X = X.to(device)
        N, T, Fdim = X.shape

        if horizon < 0:
            raise ValueError("horizon must be >= 0")
        if horizon >= T:
            raise ValueError(f"horizon={horizon} must be < sequence length T={T}")

        # Cắt input nếu rollout > 0 để giữ đúng số target (T-1)
        X_in = X[:, : T - horizon, :] if horizon > 0 else X

        # forward() trả (N, (T-horizon)-1 + horizon, F) = (N, T-1, F)
        y_ar, _, _ = self.forward(X_in, H_0=H_0, horizon=horizon)

        # Ground truth luôn là X_{1:T}
        target = X[:, 1:, :]  # (N, T-1, F)
        
        err_t = (y_ar - target).abs().mean(dim=-1).mean(dim=0)
        
        decay = 0.9
        coef_pre = 1.0
        coef_roll = 1.0
        
        len_pre = T - horizon - 1
        loss_pre = torch.tensor(0.0, device=device, dtype=err_t.dtype)
        loss_roll = torch.tensor(0.0, device=device, dtype=err_t.dtype)

        if len_pre > 0:
            idx = torch.arange(len_pre - 1, -1, -1, device=device, dtype=err_t.dtype)
            w = decay ** idx
            loss_pre = (w * err_t[:len_pre]).sum() / (w.sum() + 1e-12)

        if horizon > 0:
            # MSE đều cho đoạn rollout (chiều dài = horizon)
            loss_roll = err_t[len_pre:].mean()

        loss = coef_pre * loss_pre + coef_roll * loss_roll
        
        return loss

# Eval & Train

In [45]:
from sklearn.metrics import *

def eval_ensemble(args, model, training_data_np, testing_data_np, device='cuda', seq_lens=[64, 96, 128], verbose=False):

    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = np.log1p(training_data_np[:, :, used_features])
    testing_data_np = np.log1p(testing_data_np[:, :, used_features])
    
    labels = torch.tensor(testing_data_np).float().to(device)
    n_nodes, horizon, n_feats = labels.shape
    y_preds = np.zeros((len(seq_lens), n_nodes, horizon, n_feats-1)) # Bỏ Vol
    model.eval().to(device)
    for i, seq_len in enumerate(seq_lens):
        batch = torch.tensor(training_data_np[:, -seq_len:]).float().to(device)
        with torch.no_grad():
            y_all = model.forecast(batch, horizon=horizon)
        y_preds[i] = y_all[:, -horizon:, :n_feats-1].detach().cpu().numpy() # OHLC + Adj Close
    y_gt = testing_data_np[:, :, :n_feats-1].reshape(n_nodes, -1) # OHLC + Adj Close
    y_pred = y_preds.mean(axis=0).reshape(n_nodes, -1)
    
    if verbose:
        print("Max var:", np.var(np.expm1(y_preds), axis=0).max())
        print("Mean var:", np.var(np.expm1(y_preds), axis=0).mean())
    return {
        'rmse': root_mean_squared_error(y_gt, y_pred), 
        'raw_rmse': root_mean_squared_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'mae': mean_absolute_error(y_gt, y_pred), 
        'raw_mae': mean_absolute_error(np.expm1(y_gt), np.expm1(y_pred)), 
        'r2': r2_score(y_gt.ravel(), y_pred.ravel()),
        'raw_r2': r2_score(np.expm1(y_gt).ravel(), np.expm1(y_pred).ravel()), 
    }

def eval(args, model, training_data_np, testing_data_np, device='cuda'):
    return eval_ensemble(args, model, training_data_np, testing_data_np, device, [args.seq_len])

In [46]:
from tqdm import tqdm
import time

class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.avg == 0:
            self.avg = val
            return
        self.avg = 0.95 * self.avg + 0.05 * val

import copy
def train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np):
    # if args.amp:
    #     from apex import amp
    global best_loss, best_model
    test_losses = []
    end = time.time()

    best_model = copy.deepcopy(model)
    step = 0
    for epoch in range(args.epochs):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        stats01 = AverageMeter()
        stats02 = AverageMeter()
        p_bar = tqdm(train_loader)
        for batch_idx, samples in enumerate(p_bar):
            step += 1
            model.train().to(args.device)
          
            samples = samples[0].float().to(args.device)
            data_time.update(time.time() - end)

            loss = model.forward_loss(samples, horizon=args.horizon)
            
            loss.backward()
            
            max_norm = 5.0
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
            
            optimizer.step()
            scheduler.step()

            losses.update(loss.item())
            stats01.update(0.0)
            # stats01.update(power_iteration_lmax_sym(poly_laplacian(model.L, F.softplus(model.kappa))))

            batch_time.update(time.time() - end)
            end = time.time()
            # mask_probs.update(mask.mean().item())
            p_bar.set_description(
                "Ep: {epoch}/{epochs:3}. LR: {lr:.3e}. "
                "Loss: {loss:.4f}. Stats01: {stats01:.4f}".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                lr=scheduler.get_last_lr()[0],
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
                stats01=stats01.avg,
            ))
            p_bar.update()
            
            if (step + 1) % args.eval_steps == 0:
                test_model = model

                test_metrics = eval_ensemble(args, test_model, training_data_np, testing_data_np, args.device, [args.seq_len])
                print(test_metrics)
                test_loss = test_metrics['mae']

                is_best = test_loss < best_loss
                if test_loss < best_loss:
                    best_loss = test_loss
                    best_model = copy.deepcopy(test_model)


                test_losses.append(test_loss)
                print('Best loss: {:.3f}'.format(best_loss))
                print('Mean loss: {:.3f}\n'.format(
                    np.mean(test_losses[-20:])))


In [39]:
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import math

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

# Training

In [40]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

class Config:
    # Training
    epochs = 5
    eval_steps = 200
    lr = 1e-4
    wd = 1e-3
    warmup = 0
    
    n_nodes = 428
    n_feats = 6
    
    # Prediction
    seq_len = 96
    horizon = 32
    
    # LSTM
    d_node: int = None
    d_latent: int = 128
    device: str = "cuda"
    n_layers = 1
    dropout = 0.0
    seed = 42

args = Config()
manual_seed(args.seed)


In [41]:
node_emb_path="../input/static_node_emb.npy"
node_emb_matrix = torch.tensor(np.load(node_emb_path), dtype=torch.float32)


def train_fold(fold_idx):
    global training_data_np, testing_data_np
    used_features = [0, 1, 2, 3, 4, 5]
    training_data_np = np.log1p(train_folds[fold_idx][0][:, :, used_features])
    testing_data_np = np.log1p(train_folds[fold_idx][1][:, :, used_features])
    train_loader = DataLoader(SimpleStockDataset(training_data_np, args.seq_len + args.horizon), batch_size=1, shuffle=True)
    
    manual_seed(args.seed + fold_idx)
        
    model = MyRNN(args.n_nodes, args.n_feats, args, node_emb=node_emb_matrix)

    from torch.optim import AdamW

    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    total_steps = args.epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup,
        num_training_steps=total_steps,
    )
    
    print(f'Model size: {get_model_size(model) * 1e3:.2f}K')
    
    # Sanity check
    print(eval_ensemble(args, model, training_data_np, testing_data_np, args.device, [2, 4]))
    
    global best_loss, best_model
    best_loss = 9999
    best_model = copy.deepcopy(model)

    train(args, train_loader, model, optimizer, scheduler, training_data_np, testing_data_np)
    
    return best_model, best_loss

In [42]:
best_model, best_loss = train_fold(3)

Model size: 89.35K
{'rmse': 0.46888273229933636, 'raw_rmse': 228.89217237270444, 'mae': 0.45586857612719023, 'raw_mae': 109.66308902441565, 'r2': 0.6031935891446447, 'raw_r2': 0.5531780766624103}


Ep: 1/  5. LR: 9.999e-05. Loss: 0.1940. Stats01: 0.0000:   6%|▋         | 224/3529 [00:20<05:27, 10.10it/s]

{'rmse': 0.07182815531119051, 'raw_rmse': 43.07875386723515, 'mae': 0.05268500006172838, 'raw_mae': 12.965463007756508, 'r2': 0.9934916353456135, 'raw_r2': 0.9896192413943248}
Best loss: 0.053
Mean loss: 0.053



Ep: 1/  5. LR: 9.995e-05. Loss: 0.1959. Stats01: 0.0000:  12%|█▏        | 432/3529 [00:41<05:53,  8.75it/s]

{'rmse': 0.06960421044365066, 'raw_rmse': 31.528973768120498, 'mae': 0.050292261478409754, 'raw_mae': 11.449900379297675, 'r2': 0.9938840674294743, 'raw_r2': 0.9956306745161979}
Best loss: 0.050
Mean loss: 0.051



Ep: 1/  5. LR: 9.989e-05. Loss: 0.2065. Stats01: 0.0000:  18%|█▊        | 638/3529 [01:02<05:05,  9.47it/s]

{'rmse': 0.0895116062239669, 'raw_rmse': 40.79189859017345, 'mae': 0.07385002403355073, 'raw_mae': 16.027142400354812, 'r2': 0.9895880771855532, 'raw_r2': 0.9921929731483442}
Best loss: 0.050
Mean loss: 0.059



Ep: 1/  5. LR: 9.981e-05. Loss: 0.1943. Stats01: 0.0000:  24%|██▍       | 842/3529 [01:23<04:31,  9.89it/s]

{'rmse': 0.06606322588027104, 'raw_rmse': 28.254643985471816, 'mae': 0.04611010075352794, 'raw_mae': 10.654177294619343, 'r2': 0.9946185041136574, 'raw_r2': 0.9964603961014625}
Best loss: 0.046
Mean loss: 0.056



Ep: 1/  5. LR: 9.970e-05. Loss: 0.1868. Stats01: 0.0000:  30%|██▉       | 1047/3529 [01:44<04:37,  8.94it/s]

{'rmse': 0.06757592217714181, 'raw_rmse': 28.548101193226664, 'mae': 0.047891619504766066, 'raw_mae': 11.021126162367064, 'r2': 0.9942503854910125, 'raw_r2': 0.996072881061458}
Best loss: 0.046
Mean loss: 0.054



Ep: 1/  5. LR: 9.956e-05. Loss: 0.1920. Stats01: 0.0000:  35%|███▌      | 1252/3529 [02:05<04:08,  9.16it/s]

{'rmse': 0.07281645929251852, 'raw_rmse': 25.47544471629287, 'mae': 0.05410289704748475, 'raw_mae': 11.017958745873706, 'r2': 0.9932830867414743, 'raw_r2': 0.9972830649436247}
Best loss: 0.046
Mean loss: 0.054



Ep: 1/  5. LR: 9.941e-05. Loss: 0.1931. Stats01: 0.0000:  41%|████▏     | 1456/3529 [02:26<03:37,  9.54it/s]

{'rmse': 0.07505662759708727, 'raw_rmse': 39.42729336443518, 'mae': 0.05542733154847641, 'raw_mae': 14.283787472894575, 'r2': 0.9925868943761603, 'raw_r2': 0.9917311476194499}
Best loss: 0.046
Mean loss: 0.054



Ep: 1/  5. LR: 9.922e-05. Loss: 0.1941. Stats01: 0.0000:  47%|████▋     | 1660/3529 [02:47<03:13,  9.66it/s]

{'rmse': 0.08569452364261577, 'raw_rmse': 37.20953734883838, 'mae': 0.06916233090228643, 'raw_mae': 15.99510777747693, 'r2': 0.9901264940284193, 'raw_r2': 0.9937051560155201}
Best loss: 0.046
Mean loss: 0.056



Ep: 1/  5. LR: 9.902e-05. Loss: 0.1896. Stats01: 0.0000:  53%|█████▎    | 1862/3529 [03:08<03:10,  8.75it/s]

{'rmse': 0.07537168787415256, 'raw_rmse': 32.40798133275076, 'mae': 0.056214175099790924, 'raw_mae': 12.97155915304745, 'r2': 0.9925159363227676, 'raw_r2': 0.9951352973517237}
Best loss: 0.046
Mean loss: 0.056



Ep: 1/  5. LR: 9.879e-05. Loss: 0.1987. Stats01: 0.0000:  59%|█████▊    | 2065/3529 [03:28<02:38,  9.26it/s]

{'rmse': 0.06539736744988048, 'raw_rmse': 25.731904883814607, 'mae': 0.04541857600750157, 'raw_mae': 10.207572381040961, 'r2': 0.9947569524121149, 'raw_r2': 0.9971349876194906}
Best loss: 0.045
Mean loss: 0.055



Ep: 1/  5. LR: 9.853e-05. Loss: 0.1903. Stats01: 0.0000:  64%|██████▍   | 2268/3529 [03:50<02:25,  8.64it/s]

{'rmse': 0.07216026754008478, 'raw_rmse': 28.898813224422533, 'mae': 0.05378889895218163, 'raw_mae': 11.799192815293818, 'r2': 0.9933534365656842, 'raw_r2': 0.9964415926291599}
Best loss: 0.045
Mean loss: 0.055



Ep: 1/  5. LR: 9.826e-05. Loss: 0.1845. Stats01: 0.0000:  70%|███████   | 2472/3529 [04:10<01:53,  9.28it/s]

{'rmse': 0.0679176449187125, 'raw_rmse': 27.04868781298881, 'mae': 0.04817647742500004, 'raw_mae': 11.139821469884213, 'r2': 0.9942392088038022, 'raw_r2': 0.9967892987905386}
Best loss: 0.045
Mean loss: 0.054



Ep: 1/  5. LR: 9.796e-05. Loss: 0.1807. Stats01: 0.0000:  76%|███████▌  | 2674/3529 [04:31<01:36,  8.82it/s]

{'rmse': 0.07420412681254615, 'raw_rmse': 32.01015164359165, 'mae': 0.055959130472865895, 'raw_mae': 12.360763907251894, 'r2': 0.9930199028812314, 'raw_r2': 0.9954179700374993}
Best loss: 0.045
Mean loss: 0.055



Ep: 1/  5. LR: 9.763e-05. Loss: 0.1882. Stats01: 0.0000:  82%|████████▏ | 2877/3529 [04:52<01:10,  9.28it/s]

{'rmse': 0.07215744238265644, 'raw_rmse': 34.714527384118874, 'mae': 0.053712357228393584, 'raw_mae': 12.205771444555886, 'r2': 0.9935105620678514, 'raw_r2': 0.9943843431402581}
Best loss: 0.045
Mean loss: 0.054



Ep: 1/  5. LR: 9.728e-05. Loss: 0.1838. Stats01: 0.0000:  87%|████████▋ | 3079/3529 [05:12<00:50,  8.92it/s]

{'rmse': 0.06570658249689662, 'raw_rmse': 31.155722802474674, 'mae': 0.04636818585702957, 'raw_mae': 11.10525020201129, 'r2': 0.994703300744449, 'raw_r2': 0.9956546885008333}
Best loss: 0.045
Mean loss: 0.054



Ep: 1/  5. LR: 9.691e-05. Loss: 0.1825. Stats01: 0.0000:  93%|█████████▎| 3283/3529 [05:33<00:23, 10.51it/s]

{'rmse': 0.07480293241231742, 'raw_rmse': 34.6666289406968, 'mae': 0.057218444861610186, 'raw_mae': 13.303649259301103, 'r2': 0.992950670703965, 'raw_r2': 0.9946181582422198}
Best loss: 0.045
Mean loss: 0.054



Ep: 1/  5. LR: 9.651e-05. Loss: 0.1810. Stats01: 0.0000:  99%|█████████▊| 3484/3529 [05:54<00:05,  8.80it/s]

{'rmse': 0.06607588682952018, 'raw_rmse': 27.39070608292895, 'mae': 0.04707362527295211, 'raw_mae': 10.571919422127756, 'r2': 0.9946676086220448, 'raw_r2': 0.9967712759971954}
Best loss: 0.045
Mean loss: 0.054



Ep: 1/  5. LR: 9.625e-05. Loss: 0.1790. Stats01: 0.0000: 100%|██████████| 3529/3529 [06:07<00:00,  9.60it/s]
Ep: 2/  5. LR: 9.609e-05. Loss: 0.1839. Stats01: 0.0000:   2%|▏         | 87/3529 [00:07<04:46, 12.00it/s]

{'rmse': 0.0681475728904752, 'raw_rmse': 30.559121598782355, 'mae': 0.04789303450735533, 'raw_mae': 11.450149859250086, 'r2': 0.9941534861345805, 'raw_r2': 0.9958635625507001}
Best loss: 0.045
Mean loss: 0.053



Ep: 2/  5. LR: 9.565e-05. Loss: 0.1869. Stats01: 0.0000:   9%|▊         | 301/3529 [00:27<05:29,  9.81it/s]

{'rmse': 0.06502119871411208, 'raw_rmse': 29.273921055567325, 'mae': 0.04458178709233409, 'raw_mae': 10.596522046716471, 'r2': 0.9947942449292665, 'raw_r2': 0.9960881726409626}
Best loss: 0.045
Mean loss: 0.053



Ep: 2/  5. LR: 9.519e-05. Loss: 0.1796. Stats01: 0.0000:  14%|█▍        | 507/3529 [00:48<05:11,  9.71it/s]

{'rmse': 0.06621418138511373, 'raw_rmse': 31.91945042779795, 'mae': 0.04604396728512461, 'raw_mae': 10.84715639838187, 'r2': 0.9945373487069629, 'raw_r2': 0.9954583826436344}
Best loss: 0.045
Mean loss: 0.053



Ep: 2/  5. LR: 9.469e-05. Loss: 0.1746. Stats01: 0.0000:  20%|██        | 713/3529 [01:08<04:58,  9.44it/s]

{'rmse': 0.06968503178409295, 'raw_rmse': 29.652518403361135, 'mae': 0.0509421463855416, 'raw_mae': 11.45874583972448, 'r2': 0.9939192357311579, 'raw_r2': 0.9961350178001126}
Best loss: 0.045
Mean loss: 0.053



Ep: 2/  5. LR: 9.418e-05. Loss: 0.1857. Stats01: 0.0000:  26%|██▌       | 918/3529 [01:28<04:30,  9.67it/s]

{'rmse': 0.06300497246845466, 'raw_rmse': 27.53208881306075, 'mae': 0.0423877932158334, 'raw_mae': 9.881130815071483, 'r2': 0.9951500132146133, 'raw_r2': 0.9967115233310084}
Best loss: 0.042
Mean loss: 0.052



Ep: 2/  5. LR: 9.365e-05. Loss: 0.1821. Stats01: 0.0000:  32%|███▏      | 1123/3529 [01:49<04:22,  9.15it/s]

{'rmse': 0.07457768952490366, 'raw_rmse': 30.84793352139596, 'mae': 0.05674307044950739, 'raw_mae': 12.561631498479516, 'r2': 0.9929243127493141, 'raw_r2': 0.9957756780725653}
Best loss: 0.042
Mean loss: 0.051



Ep: 2/  5. LR: 9.309e-05. Loss: 0.1831. Stats01: 0.0000:  38%|███▊      | 1328/3529 [02:09<03:53,  9.43it/s]

{'rmse': 0.06311708177860069, 'raw_rmse': 25.395147015521275, 'mae': 0.04280982926157225, 'raw_mae': 9.700467053804115, 'r2': 0.995172198028542, 'raw_r2': 0.997277765804382}
Best loss: 0.042
Mean loss: 0.051



Ep: 2/  5. LR: 9.251e-05. Loss: 0.1761. Stats01: 0.0000:  43%|████▎     | 1531/3529 [02:30<03:32,  9.41it/s]

{'rmse': 0.06406214422656203, 'raw_rmse': 26.653403823895246, 'mae': 0.04420211851499171, 'raw_mae': 9.88962752671171, 'r2': 0.99501077713638, 'raw_r2': 0.9969947402024087}
Best loss: 0.042
Mean loss: 0.051



Ep: 2/  5. LR: 9.191e-05. Loss: 0.1775. Stats01: 0.0000:  49%|████▉     | 1734/3529 [02:50<03:09,  9.48it/s]

{'rmse': 0.06790479834681379, 'raw_rmse': 29.545402622291004, 'mae': 0.04899542485119076, 'raw_mae': 10.846344701925892, 'r2': 0.9943241658011849, 'raw_r2': 0.9960656175677564}
Best loss: 0.042
Mean loss: 0.051



Ep: 2/  5. LR: 9.128e-05. Loss: 0.1797. Stats01: 0.0000:  55%|█████▍    | 1937/3529 [03:11<03:22,  7.85it/s]

{'rmse': 0.06280794077617606, 'raw_rmse': 26.00546573702209, 'mae': 0.04258383682861589, 'raw_mae': 9.640294886641422, 'r2': 0.9952207623248456, 'raw_r2': 0.9971295853466594}
Best loss: 0.042
Mean loss: 0.050



Ep: 2/  5. LR: 9.063e-05. Loss: 0.1748. Stats01: 0.0000:  61%|██████    | 2141/3529 [03:33<02:35,  8.92it/s]

{'rmse': 0.06815637236708481, 'raw_rmse': 26.1799408246687, 'mae': 0.049126134692770176, 'raw_mae': 10.478586361095287, 'r2': 0.9943000695612925, 'raw_r2': 0.9971336501731405}
Best loss: 0.042
Mean loss: 0.049



Ep: 2/  5. LR: 8.997e-05. Loss: 0.1752. Stats01: 0.0000:  66%|██████▋   | 2343/3529 [03:54<02:06,  9.36it/s]

{'rmse': 0.07266797325000798, 'raw_rmse': 28.630020068619817, 'mae': 0.054341267288550696, 'raw_mae': 11.26623801160211, 'r2': 0.9933415740676048, 'raw_r2': 0.9963871555023043}
Best loss: 0.042
Mean loss: 0.049



Ep: 2/  5. LR: 8.928e-05. Loss: 0.1846. Stats01: 0.0000:  72%|███████▏  | 2547/3529 [04:14<01:49,  9.00it/s]

{'rmse': 0.061466563637096766, 'raw_rmse': 24.25753784723518, 'mae': 0.04107032924894338, 'raw_mae': 9.174284662106144, 'r2': 0.9954565450267501, 'raw_r2': 0.9975637589464458}
Best loss: 0.041
Mean loss: 0.049



Ep: 2/  5. LR: 8.856e-05. Loss: 0.1733. Stats01: 0.0000:  78%|███████▊  | 2750/3529 [04:35<01:23,  9.37it/s]

{'rmse': 0.06316513281269423, 'raw_rmse': 23.879565938470297, 'mae': 0.04296508515973835, 'raw_mae': 9.453200503367887, 'r2': 0.9951673316234383, 'raw_r2': 0.9976708793288391}
Best loss: 0.041
Mean loss: 0.048



Ep: 2/  5. LR: 8.783e-05. Loss: 0.1752. Stats01: 0.0000:  84%|████████▎ | 2954/3529 [04:56<00:59,  9.59it/s]

{'rmse': 0.06235119581900114, 'raw_rmse': 24.21500012952924, 'mae': 0.041944089614274316, 'raw_mae': 9.44842400816471, 'r2': 0.9952900362642512, 'raw_r2': 0.9975933259135399}
Best loss: 0.041
Mean loss: 0.048



Ep: 2/  5. LR: 8.707e-05. Loss: 0.1786. Stats01: 0.0000:  89%|████████▉ | 3157/3529 [05:16<00:38,  9.55it/s]

{'rmse': 0.0700198748112538, 'raw_rmse': 28.15591944880646, 'mae': 0.05165585426761565, 'raw_mae': 11.167323570992814, 'r2': 0.9939368707529498, 'raw_r2': 0.9966104126197469}
Best loss: 0.041
Mean loss: 0.048



Ep: 2/  5. LR: 8.630e-05. Loss: 0.1741. Stats01: 0.0000:  95%|█████████▌| 3358/3529 [05:36<00:16, 10.23it/s]

{'rmse': 0.06185774227434617, 'raw_rmse': 25.09537421748471, 'mae': 0.04101960764237509, 'raw_mae': 9.523676764352846, 'r2': 0.9953550954033322, 'raw_r2': 0.9973401045582203}
Best loss: 0.041
Mean loss: 0.047



Ep: 2/  5. LR: 8.550e-05. Loss: 0.1759. Stats01: 0.0000: : 3562it [05:57,  9.65it/s]                        

{'rmse': 0.0634006216393449, 'raw_rmse': 25.842284526369525, 'mae': 0.043711072060538925, 'raw_mae': 9.827631706612655, 'r2': 0.9951240764319973, 'raw_r2': 0.9971311235893681}
Best loss: 0.041
Mean loss: 0.047



Ep: 2/  5. LR: 8.526e-05. Loss: 0.1839. Stats01: 0.0000: 100%|██████████| 3529/3529 [06:03<00:00,  9.72it/s]
Ep: 3/  5. LR: 8.468e-05. Loss: 0.1730. Stats01: 0.0000:   5%|▍         | 163/3529 [00:14<05:44,  9.76it/s]

{'rmse': 0.07009912163342712, 'raw_rmse': 26.914183973830603, 'mae': 0.0517832106173594, 'raw_mae': 11.032942298992698, 'r2': 0.9938718376641928, 'raw_r2': 0.996827847767005}
Best loss: 0.041
Mean loss: 0.047



Ep: 3/  5. LR: 8.384e-05. Loss: 0.1747. Stats01: 0.0000:  11%|█         | 374/3529 [00:35<05:31,  9.51it/s]

{'rmse': 0.06267996250793478, 'raw_rmse': 23.443896415205963, 'mae': 0.04259126515050941, 'raw_mae': 9.261732334346158, 'r2': 0.9952613375362, 'raw_r2': 0.9977517032417466}
Best loss: 0.041
Mean loss: 0.046



Ep: 3/  5. LR: 8.298e-05. Loss: 0.1758. Stats01: 0.0000:  16%|█▋        | 580/3529 [00:55<05:05,  9.66it/s]

{'rmse': 0.06584557957076827, 'raw_rmse': 25.88627578856479, 'mae': 0.04649663947692517, 'raw_mae': 10.190167907371954, 'r2': 0.9947164836530367, 'raw_r2': 0.997156218960028}
Best loss: 0.041
Mean loss: 0.046



Ep: 3/  5. LR: 8.210e-05. Loss: 0.1783. Stats01: 0.0000:  22%|██▏       | 785/3529 [01:15<04:52,  9.40it/s]

{'rmse': 0.06461851086344916, 'raw_rmse': 24.203944932463163, 'mae': 0.04497674947884259, 'raw_mae': 9.67876168938828, 'r2': 0.9949245489917745, 'raw_r2': 0.9975799893403595}
Best loss: 0.041
Mean loss: 0.046



Ep: 3/  5. LR: 8.120e-05. Loss: 0.1831. Stats01: 0.0000:  28%|██▊       | 990/3529 [01:36<04:34,  9.26it/s]

{'rmse': 0.06485066712090633, 'raw_rmse': 24.66091374777135, 'mae': 0.04541424277403055, 'raw_mae': 9.696084474329455, 'r2': 0.9949052120212655, 'raw_r2': 0.9974609835465584}
Best loss: 0.041
Mean loss: 0.046



Ep: 3/  5. LR: 8.029e-05. Loss: 0.1718. Stats01: 0.0000:  34%|███▍      | 1193/3529 [01:56<04:14,  9.16it/s]

{'rmse': 0.06323105253074687, 'raw_rmse': 24.771415760860013, 'mae': 0.04328212607732518, 'raw_mae': 9.490706505289943, 'r2': 0.9951611417732281, 'raw_r2': 0.9974471317288539}
Best loss: 0.041
Mean loss: 0.046



Ep: 3/  5. LR: 7.935e-05. Loss: 0.1788. Stats01: 0.0000:  40%|███▉      | 1397/3529 [02:17<03:55,  9.06it/s]

{'rmse': 0.060916649158969695, 'raw_rmse': 24.235027426849054, 'mae': 0.04033889980599072, 'raw_mae': 9.169913246939494, 'r2': 0.9955258602777388, 'raw_r2': 0.9975834038808692}
Best loss: 0.040
Mean loss: 0.046



Ep: 3/  5. LR: 7.838e-05. Loss: 0.1691. Stats01: 0.0000:  45%|████▌     | 1601/3529 [02:37<03:22,  9.50it/s]

{'rmse': 0.06513657721799594, 'raw_rmse': 25.223154832397103, 'mae': 0.04564869374685504, 'raw_mae': 9.985024596465918, 'r2': 0.9948099721814867, 'raw_r2': 0.9973172245931581}
Best loss: 0.040
Mean loss: 0.045



Ep: 3/  5. LR: 7.741e-05. Loss: 0.1714. Stats01: 0.0000:  51%|█████     | 1806/3529 [02:58<03:03,  9.37it/s]

{'rmse': 0.06711717847160813, 'raw_rmse': 25.828912746463253, 'mae': 0.048005639691883525, 'raw_mae': 10.218463560871182, 'r2': 0.9944842175656321, 'raw_r2': 0.9971968804579986}
Best loss: 0.040
Mean loss: 0.046



Ep: 3/  5. LR: 7.642e-05. Loss: 0.1803. Stats01: 0.0000:  57%|█████▋    | 2008/3529 [03:18<02:31, 10.05it/s]

{'rmse': 0.06720193008660434, 'raw_rmse': 25.6631338077466, 'mae': 0.048480299560992465, 'raw_mae': 10.372270459636441, 'r2': 0.9944906041222208, 'raw_r2': 0.997255786028432}
Best loss: 0.040
Mean loss: 0.046



Ep: 3/  5. LR: 7.540e-05. Loss: 0.1752. Stats01: 0.0000:  63%|██████▎   | 2211/3529 [03:39<02:10, 10.08it/s]

{'rmse': 0.06457468831243705, 'raw_rmse': 27.131062003456826, 'mae': 0.04475439157231946, 'raw_mae': 10.044422751317448, 'r2': 0.994919753743478, 'raw_r2': 0.9968717385483661}
Best loss: 0.040
Mean loss: 0.046



Ep: 3/  5. LR: 7.437e-05. Loss: 0.1710. Stats01: 0.0000:  68%|██████▊   | 2415/3529 [03:59<01:55,  9.64it/s]

{'rmse': 0.06570499173660063, 'raw_rmse': 26.671612552710368, 'mae': 0.046685202304187545, 'raw_mae': 10.299866792928016, 'r2': 0.9947456303375705, 'raw_r2': 0.9969937240443179}
Best loss: 0.040
Mean loss: 0.046



Ep: 3/  5. LR: 7.331e-05. Loss: 0.1833. Stats01: 0.0000:  74%|███████▍  | 2618/3529 [04:19<01:41,  9.01it/s]

{'rmse': 0.06110851750442208, 'raw_rmse': 24.697039062843064, 'mae': 0.04070837345372886, 'raw_mae': 9.227095394811165, 'r2': 0.9954965400552839, 'raw_r2': 0.9974655340346362}
Best loss: 0.040
Mean loss: 0.045



Ep: 3/  5. LR: 7.225e-05. Loss: 0.1800. Stats01: 0.0000:  80%|███████▉  | 2822/3529 [04:40<01:15,  9.33it/s]

{'rmse': 0.06573193885811818, 'raw_rmse': 27.021490709584807, 'mae': 0.04637362110231465, 'raw_mae': 10.270603784810207, 'r2': 0.9947228434493898, 'raw_r2': 0.9969084420931636}
Best loss: 0.040
Mean loss: 0.045



Ep: 3/  5. LR: 7.117e-05. Loss: 0.1718. Stats01: 0.0000:  86%|████████▌ | 3024/3529 [05:00<00:57,  8.76it/s]

{'rmse': 0.06406120774143344, 'raw_rmse': 28.605773524677677, 'mae': 0.04430228802052548, 'raw_mae': 10.108555404669737, 'r2': 0.9950149527879081, 'raw_r2': 0.9964336163563161}
Best loss: 0.040
Mean loss: 0.045



Ep: 3/  5. LR: 7.006e-05. Loss: 0.1749. Stats01: 0.0000:  91%|█████████▏| 3228/3529 [05:21<00:31,  9.57it/s]

{'rmse': 0.06331939232967457, 'raw_rmse': 24.751623554531548, 'mae': 0.043607426527724064, 'raw_mae': 9.566431503088435, 'r2': 0.9951554966806735, 'raw_r2': 0.9974713302676975}
Best loss: 0.040
Mean loss: 0.045



Ep: 3/  5. LR: 6.894e-05. Loss: 0.1766. Stats01: 0.0000:  97%|█████████▋| 3429/3529 [05:41<00:11,  9.05it/s]

{'rmse': 0.0628146943243282, 'raw_rmse': 26.047951467485653, 'mae': 0.04301828496849943, 'raw_mae': 9.637272655191842, 'r2': 0.9952338738344673, 'raw_r2': 0.9971523540746909}
Best loss: 0.040
Mean loss: 0.045



Ep: 3/  5. LR: 6.788e-05. Loss: 0.1771. Stats01: 0.0000: 100%|██████████| 3529/3529 [06:00<00:00,  9.78it/s]
Ep: 4/  5. LR: 6.780e-05. Loss: 0.1642. Stats01: 0.0000:   1%|          | 20/3529 [00:01<05:15, 11.12it/s]

{'rmse': 0.062165958432898905, 'raw_rmse': 26.322773214216294, 'mae': 0.04136134074650125, 'raw_mae': 9.656940394031194, 'r2': 0.9953046220923384, 'raw_r2': 0.9970571536486934}
Best loss: 0.040
Mean loss: 0.045



Ep: 4/  5. LR: 6.665e-05. Loss: 0.1702. Stats01: 0.0000:   7%|▋         | 237/3529 [00:21<05:20, 10.28it/s]

{'rmse': 0.06285477953331534, 'raw_rmse': 27.084066231087732, 'mae': 0.043011884494214773, 'raw_mae': 9.748692992835563, 'r2': 0.9952342991075767, 'raw_r2': 0.9969095679191158}
Best loss: 0.040
Mean loss: 0.045



Ep: 4/  5. LR: 6.548e-05. Loss: 0.1707. Stats01: 0.0000:  13%|█▎        | 445/3529 [00:42<04:38, 11.06it/s]

{'rmse': 0.06149555392906725, 'raw_rmse': 24.346487867747687, 'mae': 0.04106547445237625, 'raw_mae': 9.2219532427254, 'r2': 0.9954517671481279, 'raw_r2': 0.9975635445658072}
Best loss: 0.040
Mean loss: 0.045



Ep: 4/  5. LR: 6.430e-05. Loss: 0.1741. Stats01: 0.0000:  19%|█▊        | 653/3529 [01:02<05:11,  9.24it/s]

{'rmse': 0.06321489603944967, 'raw_rmse': 25.235005038659416, 'mae': 0.04341380676481602, 'raw_mae': 9.517542258840761, 'r2': 0.9951682818742267, 'raw_r2': 0.9973368096398167}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 6.310e-05. Loss: 0.1764. Stats01: 0.0000:  24%|██▍       | 858/3529 [01:23<04:48,  9.27it/s]

{'rmse': 0.064166616813629, 'raw_rmse': 25.82472537624818, 'mae': 0.04465255593802571, 'raw_mae': 9.83154926539773, 'r2': 0.9950083035104429, 'raw_r2': 0.9971922400929945}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 6.188e-05. Loss: 0.1721. Stats01: 0.0000:  30%|███       | 1061/3529 [01:43<04:04, 10.09it/s]

{'rmse': 0.06338319981934133, 'raw_rmse': 24.834320952321054, 'mae': 0.04360158153260276, 'raw_mae': 9.463369043796765, 'r2': 0.995152252611687, 'raw_r2': 0.9974316968892443}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 6.064e-05. Loss: 0.1727. Stats01: 0.0000:  36%|███▌      | 1266/3529 [02:04<04:04,  9.27it/s]

{'rmse': 0.06364060254171078, 'raw_rmse': 24.640617446642178, 'mae': 0.043806986976153636, 'raw_mae': 9.636655877836347, 'r2': 0.9950648629394548, 'raw_r2': 0.9974920493598375}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 5.940e-05. Loss: 0.1699. Stats01: 0.0000:  42%|████▏     | 1470/3529 [02:24<03:38,  9.43it/s]

{'rmse': 0.06312789153946975, 'raw_rmse': 25.65356626877332, 'mae': 0.04334210485845336, 'raw_mae': 9.566736975337179, 'r2': 0.9951944543186309, 'raw_r2': 0.9972443061340306}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 5.814e-05. Loss: 0.1701. Stats01: 0.0000:  47%|████▋     | 1673/3529 [02:45<03:17,  9.38it/s]

{'rmse': 0.06079595463354393, 'raw_rmse': 25.374858838387233, 'mae': 0.04010360244961854, 'raw_mae': 9.156889652823578, 'r2': 0.9955610743032062, 'raw_r2': 0.9973238622786018}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 5.687e-05. Loss: 0.1752. Stats01: 0.0000:  53%|█████▎    | 1877/3529 [03:05<02:42, 10.17it/s]

{'rmse': 0.06315678240379446, 'raw_rmse': 26.179702287143375, 'mae': 0.04332944020056591, 'raw_mae': 9.608582515148303, 'r2': 0.9951884510136505, 'raw_r2': 0.9971192558282624}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 5.557e-05. Loss: 0.1728. Stats01: 0.0000:  59%|█████▉    | 2081/3529 [03:26<02:29,  9.66it/s]

{'rmse': 0.061130209064855195, 'raw_rmse': 24.994496638217484, 'mae': 0.04057824050607045, 'raw_mae': 9.24414261194734, 'r2': 0.9955048965908828, 'raw_r2': 0.9974127192466395}
Best loss: 0.040
Mean loss: 0.044



Ep: 4/  5. LR: 5.427e-05. Loss: 0.1747. Stats01: 0.0000:  65%|██████▍   | 2284/3529 [03:46<02:06,  9.81it/s]

{'rmse': 0.061410865766198044, 'raw_rmse': 24.744205105846508, 'mae': 0.040877427629533426, 'raw_mae': 9.192351539384175, 'r2': 0.9954617314652903, 'raw_r2': 0.9974648990085461}
Best loss: 0.040
Mean loss: 0.043



Ep: 4/  5. LR: 5.296e-05. Loss: 0.1723. Stats01: 0.0000:  70%|███████   | 2485/3529 [04:07<02:00,  8.67it/s]

{'rmse': 0.06168495696552345, 'raw_rmse': 23.804729314157427, 'mae': 0.04133586117235438, 'raw_mae': 9.09518482858007, 'r2': 0.9954244082730787, 'raw_r2': 0.9976771718494477}
Best loss: 0.040
Mean loss: 0.043



Ep: 4/  5. LR: 5.163e-05. Loss: 0.1768. Stats01: 0.0000:  76%|███████▌  | 2689/3529 [04:27<01:27,  9.55it/s]

{'rmse': 0.06329374436911796, 'raw_rmse': 25.73139631356908, 'mae': 0.04364155269657608, 'raw_mae': 9.692603897505471, 'r2': 0.9951556717343135, 'raw_r2': 0.9972354609068715}
Best loss: 0.040
Mean loss: 0.043



Ep: 4/  5. LR: 5.030e-05. Loss: 0.1690. Stats01: 0.0000:  82%|████████▏ | 2892/3529 [04:47<01:08,  9.27it/s]

{'rmse': 0.06031128410400191, 'raw_rmse': 24.308592196810366, 'mae': 0.03947706519493493, 'raw_mae': 9.109006424725942, 'r2': 0.9956281071936414, 'raw_r2': 0.9975767837997832}
Best loss: 0.039
Mean loss: 0.043



Ep: 4/  5. LR: 4.894e-05. Loss: 0.1729. Stats01: 0.0000:  88%|████████▊ | 3094/3529 [05:08<00:46,  9.33it/s]

{'rmse': 0.06166111594783191, 'raw_rmse': 24.175460327952727, 'mae': 0.04128873499127106, 'raw_mae': 9.24031392584859, 'r2': 0.9954276266836846, 'raw_r2': 0.9976066921898586}
Best loss: 0.039
Mean loss: 0.043



Ep: 4/  5. LR: 4.758e-05. Loss: 0.1762. Stats01: 0.0000:  93%|█████████▎| 3297/3529 [05:28<00:24,  9.47it/s]

{'rmse': 0.06648241071893538, 'raw_rmse': 24.802011491251488, 'mae': 0.047422250994012774, 'raw_mae': 10.029052565790565, 'r2': 0.9946077804152686, 'raw_r2': 0.9974578968547851}
Best loss: 0.039
Mean loss: 0.043



Ep: 4/  5. LR: 4.620e-05. Loss: 0.1706. Stats01: 0.0000:  99%|█████████▉| 3500/3529 [05:49<00:03,  9.14it/s]

{'rmse': 0.06081899674786813, 'raw_rmse': 24.439428885448493, 'mae': 0.04016048268027159, 'raw_mae': 9.139370973396247, 'r2': 0.9955570853767413, 'raw_r2': 0.99754544741128}
Best loss: 0.039
Mean loss: 0.042



Ep: 4/  5. LR: 4.540e-05. Loss: 0.1652. Stats01: 0.0000: 100%|██████████| 3529/3529 [06:01<00:00,  9.77it/s]
Ep: 5/  5. LR: 4.481e-05. Loss: 0.1717. Stats01: 0.0000:   3%|▎         | 102/3529 [00:08<05:54,  9.67it/s]

{'rmse': 0.06262212091374957, 'raw_rmse': 25.40177832563032, 'mae': 0.042505349225569256, 'raw_mae': 9.438803615422282, 'r2': 0.9952667190194306, 'raw_r2': 0.9973014975116872}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 4.341e-05. Loss: 0.1726. Stats01: 0.0000:   9%|▉         | 312/3529 [00:29<05:37,  9.54it/s]

{'rmse': 0.06346389144311841, 'raw_rmse': 25.385730400625285, 'mae': 0.04384731098420688, 'raw_mae': 9.59385501960621, 'r2': 0.9951330039795506, 'raw_r2': 0.9973087400809059}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 4.201e-05. Loss: 0.1768. Stats01: 0.0000:  15%|█▍        | 520/3529 [00:49<05:26,  9.23it/s]

{'rmse': 0.061816156110641715, 'raw_rmse': 24.200928117894243, 'mae': 0.04153152298331495, 'raw_mae': 9.118045184308771, 'r2': 0.9954066406809312, 'raw_r2': 0.9975904526672074}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 4.059e-05. Loss: 0.1732. Stats01: 0.0000:  21%|██        | 725/3529 [01:09<03:48, 12.25it/s]

{'rmse': 0.06111262042821338, 'raw_rmse': 24.298292090622926, 'mae': 0.04069140530210883, 'raw_mae': 9.141767117645179, 'r2': 0.9955157402107448, 'raw_r2': 0.9975709666726247}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.915e-05. Loss: 0.1778. Stats01: 0.0000:  26%|██▋       | 933/3529 [01:30<04:39,  9.30it/s]

{'rmse': 0.06059098614267052, 'raw_rmse': 23.48130389846237, 'mae': 0.039913471061842905, 'raw_mae': 9.028656819047415, 'r2': 0.9955884181261185, 'raw_r2': 0.9977562818483026}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.772e-05. Loss: 0.1814. Stats01: 0.0000:  32%|███▏      | 1136/3529 [01:50<04:13,  9.42it/s]

{'rmse': 0.062073554566602324, 'raw_rmse': 24.859997801368035, 'mae': 0.041809171172939905, 'raw_mae': 9.251372694472195, 'r2': 0.9953635655674121, 'raw_r2': 0.9974332208569727}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.628e-05. Loss: 0.1701. Stats01: 0.0000:  38%|███▊      | 1341/3529 [02:10<03:52,  9.43it/s]

{'rmse': 0.0621775451851671, 'raw_rmse': 25.171021559115005, 'mae': 0.04210060529676112, 'raw_mae': 9.37730453214114, 'r2': 0.9953462195915881, 'raw_r2': 0.997350903587564}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.482e-05. Loss: 0.1708. Stats01: 0.0000:  44%|████▍     | 1544/3529 [02:31<03:29,  9.47it/s]

{'rmse': 0.060657821974485784, 'raw_rmse': 24.18284388609434, 'mae': 0.039721849536169196, 'raw_mae': 9.13734540430468, 'r2': 0.995576465517073, 'raw_r2': 0.9976039727394604}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.335e-05. Loss: 0.1668. Stats01: 0.0000:  50%|████▉     | 1747/3529 [02:51<03:39,  8.13it/s]

{'rmse': 0.0619656982831381, 'raw_rmse': 24.507213920534323, 'mae': 0.04162488875986261, 'raw_mae': 9.158525248645045, 'r2': 0.9953794737844432, 'raw_r2': 0.9975013044239647}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.187e-05. Loss: 0.1706. Stats01: 0.0000:  55%|█████▌    | 1952/3529 [03:12<02:54,  9.05it/s]

{'rmse': 0.061007340606340575, 'raw_rmse': 24.462791080010778, 'mae': 0.0403961961536134, 'raw_mae': 9.17465962977722, 'r2': 0.9955287137279283, 'raw_r2': 0.997544092395766}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 3.039e-05. Loss: 0.1723. Stats01: 0.0000:  61%|██████    | 2155/3529 [03:32<02:29,  9.18it/s]

{'rmse': 0.06448994164066574, 'raw_rmse': 25.817864892593377, 'mae': 0.045022310807333876, 'raw_mae': 9.819025006778633, 'r2': 0.9949548994719013, 'raw_r2': 0.9971590205942816}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 2.891e-05. Loss: 0.1813. Stats01: 0.0000:  67%|██████▋   | 2359/3529 [03:53<02:02,  9.58it/s]

{'rmse': 0.0626139029885198, 'raw_rmse': 26.867887772404075, 'mae': 0.04268565254657668, 'raw_mae': 9.599466976166669, 'r2': 0.9952767735460207, 'raw_r2': 0.9969198111954495}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 2.741e-05. Loss: 0.1676. Stats01: 0.0000:  73%|███████▎  | 2564/3529 [04:13<01:39,  9.73it/s]

{'rmse': 0.0610323320767118, 'raw_rmse': 25.940552269422895, 'mae': 0.04048082410905802, 'raw_mae': 9.220705535296254, 'r2': 0.9955232015385056, 'raw_r2': 0.9971615765106047}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 2.592e-05. Loss: 0.1614. Stats01: 0.0000:  78%|███████▊  | 2765/3529 [04:34<01:21,  9.35it/s]

{'rmse': 0.062284175894998434, 'raw_rmse': 25.187827831412427, 'mae': 0.04232364946422799, 'raw_mae': 9.427249436756064, 'r2': 0.9953291026950123, 'raw_r2': 0.9973404852125302}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 2.440e-05. Loss: 0.1731. Stats01: 0.0000:  84%|████████▍ | 2968/3529 [04:54<00:56,  9.91it/s]

{'rmse': 0.06079511357591352, 'raw_rmse': 23.270241442895415, 'mae': 0.04024489505258806, 'raw_mae': 9.012041019804085, 'r2': 0.9955642790664494, 'raw_r2': 0.997790380616076}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 2.290e-05. Loss: 0.1636. Stats01: 0.0000:  90%|████████▉ | 3170/3529 [05:14<00:39,  9.16it/s]

{'rmse': 0.060598306822876175, 'raw_rmse': 23.76077304673352, 'mae': 0.039789197793623145, 'raw_mae': 9.045799808574602, 'r2': 0.9955874475763122, 'raw_r2': 0.9976821248171278}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 2.138e-05. Loss: 0.1646. Stats01: 0.0000:  96%|█████████▌| 3372/3529 [05:35<00:16,  9.31it/s]

{'rmse': 0.06176268243058605, 'raw_rmse': 24.49956347040107, 'mae': 0.041565792828020254, 'raw_mae': 9.307965838945538, 'r2': 0.9954103700293022, 'raw_r2': 0.9975270534677507}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 1.985e-05. Loss: 0.1637. Stats01: 0.0000: : 3574it [05:55, 10.46it/s]                        

{'rmse': 0.06064967821959374, 'raw_rmse': 24.82553367428882, 'mae': 0.040173055591791074, 'raw_mae': 9.20318213095626, 'r2': 0.9955864115513692, 'raw_r2': 0.9974438350197694}
Best loss: 0.039
Mean loss: 0.042



Ep: 5/  5. LR: 1.951e-05. Loss: 0.1549. Stats01: 0.0000: 100%|██████████| 3529/3529 [06:00<00:00,  9.79it/s]


In [47]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [args.seq_len])

{'rmse': 0.05942646738274627,
 'raw_rmse': 31.23737607141963,
 'mae': 0.0428159117408845,
 'raw_mae': 10.791590338661385,
 'r2': 0.9959736732306238,
 'raw_r2': 0.9964505009281793}

In [48]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [32, 64, 96])

{'rmse': 0.059426540909795204,
 'raw_rmse': 31.235145484582624,
 'mae': 0.04281595703082084,
 'raw_mae': 10.791576876004012,
 'r2': 0.9959736616718091,
 'raw_r2': 0.9964511048969183}

In [49]:
data, labels = test_fold[0], test_fold[1]
eval_ensemble(args, best_model, data, labels, args.device, [64, 96, 128])

{'rmse': 0.05942647375817556,
 'raw_rmse': 31.237542590867935,
 'mae': 0.04281591624836971,
 'raw_mae': 10.791612406239738,
 'r2': 0.9959736722893658,
 'raw_r2': 0.9964504614533127}