In [1]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd


def manual_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # if you are suing GPU
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_model_size(model):
	total_size = sum(param.numel() for param in model.parameters() if param.requires_grad)
	return total_size / 1e6

class Config:
    epochs = 5
    lr = 1e-4
    warmup = 0
    
    n_nodes = 255
    n_feats = 6
    
    eval_steps = 256
    
    seq_len = 128
    horizon = 32
  
    # myrnn
    d_latent: int = 32             # latent channels d
    delta_t: float = 0.5                # time step
    device: str = "cuda"
    
    seed = 42

args = Config()
manual_seed(args.seed)


In [2]:
import pandas as pd
import numpy as np


def make_tensor(df, feature_cols=None):
    if feature_cols is None:
        feature_cols = df.columns.drop(['Date', 'Ticker']).tolist()

    # 1) ép kiểu & sort
    df = df.copy()
    df['Date'] = pd.to_datetime(df['Date'])
    df = df.sort_values(['Ticker','Date'])

    # (tuỳ chọn) xử lý missing: forward-fill theo từng ticker
    # df[feature_cols] = df.groupby('Ticker', group_keys=False)[feature_cols].ffill()

    # 2) xác định thứ tự ticker & ngày
    tickers = sorted(df['Ticker'].unique())
    dates   = sorted(df['Date'].unique())

    # 3) pivot từng feature rồi stack vào trục cuối
    mats = []
    for feat in feature_cols:
        wide = (
            df.pivot_table(index='Ticker', columns='Date', values=feat, aggfunc='first')
              .reindex(index=tickers, columns=dates)
        )
        mats.append(wide.values)  # shape: (n_tickers, n_days)

    # 4) stack thành tensor: (n_tickers, n_days, n_features)
    X = np.stack(mats, axis=-1).astype('float32')

    return X, tickers, dates, feature_cols
  

def get_data_splits(path='../input/sp500_2009_2025/sp500_prices_clean.csv'):
    df = pd.read_csv(path)
    
    raw_data, tickers, dates, feature_cols = make_tensor(df)

    data = np.log(1 + raw_data)
    training_data = data[:, :-30, :]
    testing_data = data[:, -30:, :]
    
    return training_data, testing_data

In [3]:
import torch
from torch.utils.data import DataLoader, Dataset

class StockDataset(Dataset):
    def __init__(self, data, ws=128):
        self.data = data
        self.ws = ws
        self.samples = []
        
        self.n_tickers, self.n_days, self.n_features = self.data.shape
        
        for start in range(self.n_days - self.ws + 1):
            self.samples.append(start)
            
    def __len__(self):
        return len(self.samples)
      
    def __getitem__(self, idx):
        start = self.samples[idx]
        x = torch.tensor(self.data[:, start:start + self.ws], dtype=torch.float32)
        return x
    

In [4]:
training_data_np, testing_data_np = get_data_splits()
train_dataset = StockDataset(training_data_np, ws=128)
train_dataloader = DataLoader(train_dataset, batch_size=1)

# Model

In [5]:

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyRNN(nn.Module):
    def __init__(self, n_nodes: int, n_feats: int, args: Config):
        super().__init__()
        self.n_nodes = n_nodes
        self.rank_A = 32
        self.n_feats = n_feats
        self.args = args
        self.d_latent = getattr(args, "d_latent", 32)
        self.dropout = 0.0
        self.n_layers = 1
        
        
        self.cells = nn.ModuleList([
            nn.LSTMCell(n_feats if i == 0 else self.d_latent, self.d_latent)
            for i in range(self.n_layers)
        ])
        
        self.readout = nn.Sequential(
            nn.Linear(self.d_latent, self.d_latent), nn.ReLU(),
            nn.Linear(self.d_latent, n_feats)
        )
        
        self.enc_in = None
        self.static_node_features = nn.Parameter(torch.randn(n_nodes, self.d_latent) * 0.1)

    def _init_states(
        self, X0: torch.Tensor, H0: Optional[torch.Tensor]
    ):
        """
        Trả về list (h_list, c_list) cho từng layer.
        - Nếu H0 cung cấp (N, d), đặt h_top = H0; ngược lại nếu enc_in != None thì h_top = enc(X0); còn lại h=0.
        - c luôn = 0.
        """
        device = X0.device
        dtype = X0.dtype
        N = X0.shape[0]

        h_list = []
        c_list = []
        for _ in range(self.n_layers):
            h_list.append(self.static_node_features.clone())
            c_list.append(self.static_node_features.clone())

        if H0 is not None:
            h_list[-1] = H0.to(device=device, dtype=dtype)
        elif self.enc_in is not None:
            h_list[-1] = self.enc_in(X0)

        return h_list, c_list

    def _step(self, x_t: torch.Tensor, h_list, c_list):
        """
        Chạy 1 bước LSTM qua tất cả các layer.
        Input layer nhận x_t; các layer sau nhận h của layer trước ở cùng time-step.
        Trả về (h_list_new, c_list_new, h_top).
        """
        inp = x_t
        new_h, new_c = [], []
        for l, cell in enumerate(self.cells):
            h_t, c_t = cell(inp, (h_list[l], c_list[l]))
            new_h.append(h_t)
            new_c.append(c_t)
            inp = h_t
        h_top = new_h[-1]
        if self.training and self.dropout > 0:
            h_top = F.dropout(h_top, p=self.dropout, training=True)
        return new_h, new_c, h_top


    def forward(
        self,
        X: torch.Tensor,
        H0: Optional[torch.Tensor] = None,
        horizon: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Deterministic inference.
        Args:
            X: (N, T, F)
            H0: (N, d_latent) optional initial hidden for top layer
            horizon: số bước rollout sau quan sát cuối
        Returns:
            y_all: (N, T-1 + horizon, F)  # dự báo X_{t+1}
            r_all: (N, T-1 + horizon, F)  # residual
            H_all: (N, T + horizon, d)    # hidden top-layer (bao gồm bước khởi tạo tại index 0)
        """
        device = next(self.parameters()).device
        X = X.to(device)
        N, T, Fdim = X.shape

        # Khởi tạo trạng thái
        X0 = X[:, 0, :]
        h_list, c_list = self._init_states(X0, H0)

        # Khởi tạo output tensors
        ext = max(horizon, 0)
        H_all = torch.zeros(N, T + ext, self.d_latent, device=device, dtype=X.dtype)
        y_all = torch.zeros(N, (T - 1) + ext, Fdim, device=device, dtype=X.dtype)
        r_all = torch.zeros_like(y_all)

        # Lưu hidden ban đầu (tương tự H_t trước khi update lần đầu)
        H_all[:, 0] = h_list[-1]

        # Quan sát: t = 0..T-2 (dự báo X_{t+1})
        for t in range(T - 1):
            x_t = X[:, t, :]  # (N, F)
            h_list, c_list, h_top = self._step(x_t, h_list, c_list)
            r_t = self.readout(h_top)
            y_t = x_t + r_t

            H_all[:, t + 1] = h_top
            r_all[:, t] = r_t
            y_all[:, t] = y_t

        # Rollout autoregressive nếu cần
        if horizon > 0:
            x_t = X[:, -1, :]
            for s in range(horizon):
                h_list, c_list, h_top = self._step(x_t, h_list, c_list)
                r_t = self.readout(h_top)
                y_t = x_t + r_t

                idx = (T - 1) + s
                y_all[:, idx] = y_t
                r_all[:, idx] = r_t
                H_all[:, T + s] = h_top

                x_t = y_t  # autoregressive

        return y_all, r_all, H_all

    def forward(self, X: torch.Tensor,
                H0: Optional[torch.Tensor] = None,
                horizon: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run deterministic inference.
        Args:
            X: (N, T, F) input features at each t
            H0: (N, d) optional initial latent state (default zeros)
            Atilde, L: if precomputed; otherwise build from X
        Returns:
            y_pred: (N, T-1) predicted next-day log-return for each node
            H_all: (N, T, d) latent states (including H0 as first)
        """
        device = self.args.device
        X = X.to(next(self.parameters()).device)
        n_nodes, n_steps, n_feats = X.shape
        d_latent = getattr(self.args, "d_latent", 64)
        

        X0 = X[:, 0, :] # (N, F)
        h_list, c_list = self._init_states(X0, H0)

        H_all = torch.zeros(n_nodes, n_steps + max(horizon, 0), d_latent, device=device, dtype=X.dtype)
        r_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)
        y_all = torch.zeros(n_nodes, n_steps - 1 + max(horizon, 0), n_feats, device=device, dtype=X.dtype)


        # Lưu hidden ban đầu (tương tự H_t trước khi update lần đầu)
        H_all[:, 0] = h_list[-1]
        for t in range(n_steps - 1):
            x_t = X[:, t, :]  # (N, F)
            h_list, c_list, h_top = self._step(x_t, h_list, c_list)
            r_t = self.readout(h_top)
            y_t = x_t + r_t

            H_all[:, t + 1] = h_top
            r_all[:, t] = r_t
            y_all[:, t] = y_t

        # Rollout autoregressive nếu cần
        if horizon > 0:
            x_t = X[:, -1, :]
            for s in range(horizon):
                h_list, c_list, h_top = self._step(x_t, h_list, c_list)
                r_t = self.readout(h_top)
                y_t = x_t + r_t

                idx = (n_steps - 1) + s
                y_all[:, idx] = y_t
                r_all[:, idx] = r_t
                H_all[:, n_steps + s] = h_top

                x_t = y_t  # autoregressive

        return y_all, r_all, H_all

    def forward_loss(
        self,
        X: torch.Tensor,
        H0: Optional[torch.Tensor] = None,
        horizon: int = 0,
        reduction: str = "mean",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tính loss dự báo one-step + rollout horizon (autoregressive) trên chuỗi đầu vào.

        Args:
            X: (N, T, F) chuỗi gốc (chứa full ground-truth đến T-1)
            H0: (N, d) latent init (tuỳ chọn)
            horizon: số bước rollout ngoài quan sát cuối cùng
                    (nếu >0, ta cắt input để tránh nhìn thấy tương lai)
            reduction: 'mean' | 'sum' | 'none' cho F.mse_loss

        Returns:
            loss: scalar tensor
            y_pred: (N, T-1, F) dự báo X_{t+1} cho toàn bộ t=0..T-2
        """
        device = next(self.parameters()).device
        X = X.to(device)
        N, T, Fdim = X.shape

        if horizon < 0:
            raise ValueError("horizon must be >= 0")
        if horizon >= T:
            raise ValueError(f"horizon={horizon} must be < sequence length T={T}")

        # Cắt input nếu rollout > 0 để giữ đúng số target (T-1)
        X_in = X[:, : T - horizon, :] if horizon > 0 else X

        # forward() trả (N, (T-horizon)-1 + horizon, F) = (N, T-1, F)
        y_pred, _, _ = self.forward(X_in, H0=H0, horizon=horizon)

        # Ground truth luôn là X_{1:T}
        target = X[:, 1:, :]  # (N, T-1, F)

        loss = F.mse_loss(y_pred, target, reduction=reduction)
        return loss

# Eval & Train

In [6]:
from sklearn.metrics import *

def eval(args, model, training_data_np, testing_data_np, device='cuda'):
    batch = torch.tensor(training_data_np[:, -args.seq_len:]).float().to(device)
    labels = torch.tensor(testing_data_np).float().to(device)
    horizon = labels.shape[1]
    model.eval().to(device)
    with torch.no_grad():
        y_all, r_all, H_all = model(batch, horizon=horizon)
    y_pred = y_all[:, -horizon:, 0].detach().cpu().numpy() # Adj Close
    y_gt = testing_data_np[:, :, 0] # Adj Close
    return {
        'mse': mean_squared_error(y_gt, y_pred), 
        'raw_mse': mean_squared_error(np.exp(y_gt), np.exp(y_pred)), 
        'mae': mean_absolute_error(y_gt, y_pred), 
        'raw_mae': mean_absolute_error(np.exp(y_gt), np.exp(y_pred)), 
        'r2': r2_score(y_gt, y_pred),
        'raw_r2': r2_score(np.exp(y_gt), np.exp(y_pred)), 
    }

In [7]:
myrnn = MyRNN(255, 6, args)
eval(args, myrnn, training_data_np, testing_data_np, args.device)

{'mse': 5.4592733,
 'raw_mse': 109073970.0,
 'mae': 2.0061402,
 'raw_mae': 2685.9885,
 'r2': -5.5758819580078125,
 'raw_r2': -1417.062255859375}

In [8]:
from tqdm import tqdm
import time

class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.avg == 0:
            self.avg = val
            return
        self.avg = 0.95 * self.avg + 0.05 * val

def train(args, train_loader, model: myrnn, optimizer, scheduler):
    # if args.amp:
    #     from apex import amp
    global best_loss
    test_losses = []
    end = time.time()

    step = 0
    for epoch in range(args.epochs):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        stats01 = AverageMeter()
        stats02 = AverageMeter()
        p_bar = tqdm(train_loader)
        for batch_idx, samples in enumerate(p_bar):
            step += 1
            model.train().to(args.device)
          
            samples = samples[0].float().to(args.device)
            data_time.update(time.time() - end)

            loss = model.forward_loss(samples, horizon=args.horizon)
            
            loss.backward()
            optimizer.step()
            scheduler.step()

            losses.update(loss.item())

            batch_time.update(time.time() - end)
            end = time.time()
            # mask_probs.update(mask.mean().item())
            p_bar.set_description(
                "Ep: {epoch}/{epochs:3}. LR: {lr:.3e}. "
                "Loss: {loss:.4f}.".format(
                epoch=epoch + 1,
                epochs=args.epochs,
                lr=scheduler.get_last_lr()[0],
                data=data_time.avg,
                bt=batch_time.avg,
                loss=losses.avg,
            ))
            p_bar.update()
            
            if (step + 1) % args.eval_steps == 0:
                test_model = model

                test_metrics = eval(args, model, training_data_np, testing_data_np, args.device)
                print(test_metrics)
                test_loss = test_metrics['mse']

                is_best = test_loss < best_loss
                best_loss = min(test_loss, best_loss)


                test_losses.append(test_loss)
                print('Best loss: {:.3f}'.format(best_loss))
                print('Mean loss: {:.3f}\n'.format(
                    np.mean(test_losses[-20:])))


In [9]:
import math
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

In [10]:
training_data_np, testing_data_np = get_data_splits()
train_loader = DataLoader(StockDataset(training_data_np, args.seq_len), batch_size=1, shuffle=True)

myrnn = MyRNN(args.n_nodes, args.n_feats, args)

from torch.optim import Adam

optimizer = Adam(myrnn.parameters(), lr=args.lr)
total_steps = args.epochs * len(train_loader)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args.warmup,
    num_training_steps=total_steps,
)

In [11]:
print(f'{get_model_size(myrnn) * 1e3:.2f}K')

14.53K


In [12]:
best_loss = 9999

train(args, train_loader, myrnn, optimizer, scheduler)

Ep: 1/  5. LR: 9.998e-05. Loss: 0.0887.:   7%|▋         | 290/3999 [00:16<03:08, 19.71it/s]

{'mse': 0.012425347, 'raw_mse': 3213.3289, 'mae': 0.081988014, 'raw_mae': 18.606659, 'r2': 0.9850150942802429, 'raw_r2': 0.9576401114463806}
Best loss: 0.012
Mean loss: 0.012



Ep: 1/  5. LR: 9.994e-05. Loss: 0.0490.:  14%|█▍        | 557/3999 [00:34<04:43, 12.15it/s]

{'mse': 0.00756226, 'raw_mse': 1689.2867, 'mae': 0.06391531, 'raw_mae': 14.509393, 'r2': 0.9908967614173889, 'raw_r2': 0.9778246283531189}
Best loss: 0.008
Mean loss: 0.010



Ep: 1/  5. LR: 9.986e-05. Loss: 0.0609.:  21%|██        | 825/3999 [00:51<03:41, 14.35it/s]

{'mse': 0.09301702, 'raw_mse': 5648.8013, 'mae': 0.26063624, 'raw_mae': 39.852787, 'r2': 0.8879784941673279, 'raw_r2': 0.9256159067153931}
Best loss: 0.008
Mean loss: 0.038



Ep: 1/  5. LR: 9.975e-05. Loss: 0.0758.:  27%|██▋       | 1089/3999 [01:08<03:17, 14.72it/s]

{'mse': 0.55382043, 'raw_mse': 322355.88, 'mae': 0.6449068, 'raw_mae': 216.51666, 'r2': 0.3331925570964813, 'raw_r2': -3.2280080318450928}
Best loss: 0.008
Mean loss: 0.167



Ep: 1/  5. LR: 9.961e-05. Loss: 0.0757.:  34%|███▍      | 1354/3999 [01:26<02:37, 16.75it/s]

{'mse': 0.06562092, 'raw_mse': 13003.927, 'mae': 0.21701322, 'raw_mae': 51.312397, 'r2': 0.9210073351860046, 'raw_r2': 0.8283640146255493}
Best loss: 0.008
Mean loss: 0.146



Ep: 1/  5. LR: 9.944e-05. Loss: 0.0696.:  40%|████      | 1616/3999 [01:43<02:35, 15.31it/s]

{'mse': 0.067327864, 'raw_mse': 2438.8323, 'mae': 0.21829136, 'raw_mae': 30.319096, 'r2': 0.9189265966415405, 'raw_r2': 0.96790611743927}
Best loss: 0.008
Mean loss: 0.133



Ep: 1/  5. LR: 9.924e-05. Loss: 0.0558.:  47%|████▋     | 1878/3999 [01:59<02:39, 13.32it/s]

{'mse': 0.005381763, 'raw_mse': 673.49164, 'mae': 0.056596212, 'raw_mae': 11.196816, 'r2': 0.9935301542282104, 'raw_r2': 0.9910067319869995}
Best loss: 0.005
Mean loss: 0.115



Ep: 1/  5. LR: 9.901e-05. Loss: 0.0764.:  54%|█████▎    | 2140/3999 [02:16<02:10, 14.25it/s]

{'mse': 0.013541828, 'raw_mse': 963.4665, 'mae': 0.092348546, 'raw_mae': 15.566584, 'r2': 0.9836933612823486, 'raw_r2': 0.9873573184013367}
Best loss: 0.005
Mean loss: 0.102



Ep: 1/  5. LR: 9.875e-05. Loss: 0.0791.:  60%|██████    | 2402/3999 [02:33<01:46, 14.97it/s]

{'mse': 0.10808301, 'raw_mse': 7462.816, 'mae': 0.28338817, 'raw_mae': 43.850025, 'r2': 0.8698536157608032, 'raw_r2': 0.9016465544700623}
Best loss: 0.005
Mean loss: 0.103



Ep: 1/  5. LR: 9.845e-05. Loss: 0.0645.:  67%|██████▋   | 2663/3999 [02:51<01:36, 13.88it/s]

{'mse': 0.0036003368, 'raw_mse': 220.3935, 'mae': 0.041162465, 'raw_mae': 6.9516, 'r2': 0.9956684112548828, 'raw_r2': 0.9970970749855042}
Best loss: 0.004
Mean loss: 0.093



Ep: 1/  5. LR: 9.813e-05. Loss: 0.1119.:  73%|███████▎  | 2925/3999 [03:08<01:10, 15.15it/s]

{'mse': 0.25307846, 'raw_mse': 55545.984, 'mae': 0.43718734, 'raw_mae': 112.03955, 'r2': 0.6953317523002625, 'raw_r2': 0.26886123418807983}
Best loss: 0.004
Mean loss: 0.108



Ep: 1/  5. LR: 9.778e-05. Loss: 0.1349.:  80%|███████▉  | 3185/3999 [03:25<00:53, 15.35it/s]

{'mse': 0.4814615, 'raw_mse': 24259.645, 'mae': 0.60387295, 'raw_mae': 78.70506, 'r2': 0.42028385400772095, 'raw_r2': 0.6792904138565063}
Best loss: 0.004
Mean loss: 0.139



Ep: 1/  5. LR: 9.739e-05. Loss: 0.0893.:  86%|████████▌ | 3445/3999 [03:40<00:34, 15.88it/s]

{'mse': 0.019722965, 'raw_mse': 1590.6619, 'mae': 0.11639328, 'raw_mae': 22.762089, 'r2': 0.9762671589851379, 'raw_r2': 0.9789208769798279}
Best loss: 0.004
Mean loss: 0.130



Ep: 1/  5. LR: 9.698e-05. Loss: 0.1201.:  93%|█████████▎| 3706/3999 [03:57<00:18, 15.51it/s]

{'mse': 0.007733538, 'raw_mse': 1901.0477, 'mae': 0.06535445, 'raw_mae': 12.672098, 'r2': 0.9906889796257019, 'raw_r2': 0.9750630259513855}
Best loss: 0.004
Mean loss: 0.121



Ep: 1/  5. LR: 9.653e-05. Loss: 0.0831.:  99%|█████████▉| 3967/3999 [04:12<00:01, 16.37it/s]

{'mse': 0.0715648, 'raw_mse': 8757.043, 'mae': 0.22956583, 'raw_mae': 49.987785, 'r2': 0.9138584136962891, 'raw_r2': 0.8843795657157898}
Best loss: 0.004
Mean loss: 0.118



Ep: 1/  5. LR: 9.625e-05. Loss: 0.1940.: 100%|██████████| 3999/3999 [04:22<00:00, 15.24it/s]
Ep: 2/  5. LR: 9.606e-05. Loss: 0.1895.:   3%|▎         | 118/3999 [00:06<04:10, 15.51it/s]

{'mse': 0.86849177, 'raw_mse': 34501.45, 'mae': 0.8123272, 'raw_mae': 94.95246, 'r2': -0.04570375382900238, 'raw_r2': 0.5433076024055481}
Best loss: 0.004
Mean loss: 0.165



Ep: 2/  5. LR: 9.555e-05. Loss: 0.2087.:  10%|▉         | 393/3999 [00:21<03:51, 15.58it/s]

{'mse': 0.52329206, 'raw_mse': 156440.8, 'mae': 0.6303539, 'raw_mae': 187.41048, 'r2': 0.37001436948776245, 'raw_r2': -1.0549712181091309}
Best loss: 0.004
Mean loss: 0.186



Ep: 2/  5. LR: 9.502e-05. Loss: 0.1411.:  17%|█▋        | 661/3999 [00:38<03:46, 14.77it/s]

{'mse': 0.40597752, 'raw_mse': 21970.506, 'mae': 0.55570674, 'raw_mae': 74.31111, 'r2': 0.5112131237983704, 'raw_r2': 0.7096188068389893}
Best loss: 0.004
Mean loss: 0.198



Ep: 2/  5. LR: 9.446e-05. Loss: 0.1262.:  23%|██▎       | 926/3999 [00:55<03:28, 14.77it/s]

{'mse': 0.008745095, 'raw_mse': 970.03424, 'mae': 0.07331383, 'raw_mae': 15.299097, 'r2': 0.9894790053367615, 'raw_r2': 0.9870964288711548}
Best loss: 0.004
Mean loss: 0.188



Ep: 2/  5. LR: 9.387e-05. Loss: 0.1300.:  30%|██▉       | 1190/3999 [01:12<03:32, 13.19it/s]

{'mse': 0.696258, 'raw_mse': 275317.9, 'mae': 0.7274636, 'raw_mae': 236.3179, 'r2': 0.16177190840244293, 'raw_r2': -2.6129987239837646}
Best loss: 0.004
Mean loss: 0.213



Ep: 2/  5. LR: 9.325e-05. Loss: 0.0736.:  36%|███▋      | 1451/3999 [01:29<02:57, 14.36it/s]

{'mse': 0.20291829, 'raw_mse': 12171.97, 'mae': 0.39040917, 'raw_mae': 56.606453, 'r2': 0.7556638717651367, 'raw_r2': 0.8393834829330444}
Best loss: 0.004
Mean loss: 0.223



Ep: 2/  5. LR: 9.259e-05. Loss: 0.1357.:  43%|████▎     | 1717/3999 [01:46<02:33, 14.90it/s]

{'mse': 0.89530164, 'raw_mse': 33700.36, 'mae': 0.8245933, 'raw_mae': 95.43213, 'r2': -0.07799351215362549, 'raw_r2': 0.5539625883102417}
Best loss: 0.004
Mean loss: 0.267



Ep: 2/  5. LR: 9.192e-05. Loss: 0.0650.:  49%|████▉     | 1978/3999 [02:03<02:17, 14.68it/s]

{'mse': 0.027190676, 'raw_mse': 1948.4696, 'mae': 0.1368073, 'raw_mae': 22.44281, 'r2': 0.967257559299469, 'raw_r2': 0.9744001030921936}
Best loss: 0.004
Mean loss: 0.264



Ep: 2/  5. LR: 9.121e-05. Loss: 0.0603.:  56%|█████▌    | 2241/3999 [02:19<01:56, 15.03it/s]

{'mse': 0.2510552, 'raw_mse': 61684.184, 'mae': 0.4353998, 'raw_mae': 112.8463, 'r2': 0.6977683901786804, 'raw_r2': 0.18812669813632965}
Best loss: 0.004
Mean loss: 0.249



Ep: 2/  5. LR: 9.047e-05. Loss: 0.0808.:  63%|██████▎   | 2500/3999 [02:36<01:40, 14.88it/s]

{'mse': 0.33785397, 'raw_mse': 94390.39, 'mae': 0.5057379, 'raw_mae': 138.20291, 'r2': 0.5932685136795044, 'raw_r2': -0.24132700264453888}
Best loss: 0.004
Mean loss: 0.262



Ep: 2/  5. LR: 8.971e-05. Loss: 0.0844.:  69%|██████▉   | 2763/3999 [02:52<01:18, 15.84it/s]

{'mse': 0.119292185, 'raw_mse': 24870.738, 'mae': 0.29845372, 'raw_mae': 70.53333, 'r2': 0.856400191783905, 'raw_r2': 0.6720959544181824}
Best loss: 0.004
Mean loss: 0.265



Ep: 2/  5. LR: 8.892e-05. Loss: 0.0408.:  76%|███████▌  | 3021/3999 [03:08<01:08, 14.34it/s]

{'mse': 0.0030605257, 'raw_mse': 439.77164, 'mae': 0.040253606, 'raw_mae': 8.019064, 'r2': 0.996321439743042, 'raw_r2': 0.9941125512123108}
Best loss: 0.003
Mean loss: 0.265



Ep: 2/  5. LR: 8.810e-05. Loss: 0.0629.:  82%|████████▏ | 3285/3999 [03:25<00:49, 14.50it/s]

{'mse': 0.12265113, 'raw_mse': 7790.3296, 'mae': 0.3021482, 'raw_mae': 45.824528, 'r2': 0.8523115515708923, 'raw_r2': 0.8973146080970764}
Best loss: 0.003
Mean loss: 0.270



Ep: 2/  5. LR: 8.725e-05. Loss: 0.0782.:  89%|████████▊ | 3544/3999 [03:42<00:29, 15.18it/s]

{'mse': 0.2364427, 'raw_mse': 13459.532, 'mae': 0.4218528, 'raw_mae': 60.062588, 'r2': 0.7152984738349915, 'raw_r2': 0.8223581314086914}
Best loss: 0.003
Mean loss: 0.277



Ep: 2/  5. LR: 8.638e-05. Loss: 0.0554.:  95%|█████████▌| 3804/3999 [03:59<00:13, 14.43it/s]

{'mse': 0.061551407, 'raw_mse': 4114.205, 'mae': 0.21164845, 'raw_mae': 33.554863, 'r2': 0.9258819222450256, 'raw_r2': 0.9458448886871338}
Best loss: 0.003
Mean loss: 0.280



Ep: 2/  5. LR: 8.548e-05. Loss: 0.0682.: : 4065it [04:16, 14.48it/s]                        

{'mse': 0.04020236, 'raw_mse': 7372.8174, 'mae': 0.17000198, 'raw_mae': 37.245594, 'r2': 0.9516142010688782, 'raw_r2': 0.9025403261184692}
Best loss: 0.003
Mean loss: 0.269



Ep: 2/  5. LR: 8.526e-05. Loss: 0.0723.: 100%|██████████| 3999/3999 [04:20<00:00, 15.35it/s]
Ep: 3/  5. LR: 8.456e-05. Loss: 0.0678.:   6%|▌         | 224/3999 [00:12<03:11, 19.67it/s]

{'mse': 0.21460022, 'raw_mse': 51619.44, 'mae': 0.40220085, 'raw_mae': 102.01278, 'r2': 0.741657555103302, 'raw_r2': 0.32036930322647095}
Best loss: 0.003
Mean loss: 0.256



Ep: 3/  5. LR: 8.360e-05. Loss: 0.0628.:  12%|█▏        | 494/3999 [00:30<03:50, 15.18it/s]

{'mse': 0.10137351, 'raw_mse': 20527.746, 'mae': 0.2746128, 'raw_mae': 63.913715, 'r2': 0.8779724836349487, 'raw_r2': 0.729253888130188}
Best loss: 0.003
Mean loss: 0.260



Ep: 3/  5. LR: 8.263e-05. Loss: 0.0682.:  19%|█▉        | 760/3999 [00:48<03:54, 13.81it/s]

{'mse': 0.022902614, 'raw_mse': 1464.0039, 'mae': 0.12440506, 'raw_mae': 20.289396, 'r2': 0.9724212884902954, 'raw_r2': 0.9807667136192322}
Best loss: 0.003
Mean loss: 0.261



Ep: 3/  5. LR: 8.162e-05. Loss: 0.0730.:  26%|██▌       | 1025/3999 [01:05<03:24, 14.53it/s]

{'mse': 0.2376158, 'raw_mse': 13512.017, 'mae': 0.42291054, 'raw_mae': 60.179214, 'r2': 0.713886022567749, 'raw_r2': 0.8216636180877686}
Best loss: 0.003
Mean loss: 0.269



Ep: 3/  5. LR: 8.059e-05. Loss: 0.0601.:  32%|███▏      | 1290/3999 [01:22<03:04, 14.71it/s]

{'mse': 0.18391758, 'raw_mse': 10996.616, 'mae': 0.37141955, 'raw_mae': 54.319645, 'r2': 0.7785419821739197, 'raw_r2': 0.8549363017082214}
Best loss: 0.003
Mean loss: 0.235



Ep: 3/  5. LR: 7.953e-05. Loss: 0.0662.:  39%|███▉      | 1553/3999 [01:39<02:40, 15.23it/s]

{'mse': 0.0028520168, 'raw_mse': 370.61688, 'mae': 0.038367473, 'raw_mae': 7.4981456, 'r2': 0.9965720772743225, 'raw_r2': 0.9950330257415771}
Best loss: 0.003
Mean loss: 0.209



Ep: 3/  5. LR: 7.845e-05. Loss: 0.0721.:  45%|████▌     | 1815/3999 [01:56<02:53, 12.61it/s]

{'mse': 0.16770004, 'raw_mse': 37688.797, 'mae': 0.35495514, 'raw_mae': 87.115036, 'r2': 0.7981215715408325, 'raw_r2': 0.5034906268119812}
Best loss: 0.003
Mean loss: 0.197



Ep: 3/  5. LR: 7.736e-05. Loss: 0.0698.:  52%|█████▏    | 2076/3999 [02:15<02:30, 12.79it/s]

{'mse': 0.19610913, 'raw_mse': 45949.87, 'mae': 0.38426802, 'raw_mae': 96.2423, 'r2': 0.7639192938804626, 'raw_r2': 0.3948844373226166}
Best loss: 0.003
Mean loss: 0.206



Ep: 3/  5. LR: 7.622e-05. Loss: 0.0604.:  58%|█████▊    | 2339/3999 [02:31<01:50, 15.06it/s]

{'mse': 0.005116391, 'raw_mse': 897.514, 'mae': 0.05451274, 'raw_mae': 11.416514, 'r2': 0.9938481450080872, 'raw_r2': 0.9880391359329224}
Best loss: 0.003
Mean loss: 0.171



Ep: 3/  5. LR: 7.507e-05. Loss: 0.0677.:  65%|██████▌   | 2600/3999 [02:48<01:27, 16.06it/s]

{'mse': 0.148735, 'raw_mse': 9205.501, 'mae': 0.33340234, 'raw_mae': 49.74813, 'r2': 0.8209043741226196, 'raw_r2': 0.878615140914917}
Best loss: 0.003
Mean loss: 0.169



Ep: 3/  5. LR: 7.390e-05. Loss: 0.0760.:  72%|███████▏  | 2861/3999 [03:04<01:15, 15.07it/s]

{'mse': 0.27249724, 'raw_mse': 15026.519, 'mae': 0.45323643, 'raw_mae': 63.461422, 'r2': 0.6718865633010864, 'raw_r2': 0.8016204237937927}
Best loss: 0.003
Mean loss: 0.138



Ep: 3/  5. LR: 7.270e-05. Loss: 0.0636.:  78%|███████▊  | 3122/3999 [03:20<00:57, 15.13it/s]

{'mse': 0.052552916, 'raw_mse': 3521.4033, 'mae': 0.19481987, 'raw_mae': 31.11367, 'r2': 0.9367173314094543, 'raw_r2': 0.9536627531051636}
Best loss: 0.003
Mean loss: 0.139



Ep: 3/  5. LR: 7.148e-05. Loss: 0.0509.:  85%|████████▍ | 3381/3999 [03:37<00:39, 15.49it/s]

{'mse': 0.05259118, 'raw_mse': 9825.899, 'mae': 0.19565597, 'raw_mae': 43.449226, 'r2': 0.9367001056671143, 'raw_r2': 0.8702022433280945}
Best loss: 0.003
Mean loss: 0.129



Ep: 3/  5. LR: 7.024e-05. Loss: 0.0750.:  91%|█████████ | 3643/3999 [03:53<00:22, 15.91it/s]

{'mse': 0.22157818, 'raw_mse': 53822.027, 'mae': 0.40876368, 'raw_mae': 104.16102, 'r2': 0.7332566380500793, 'raw_r2': 0.2914249002933502}
Best loss: 0.003
Mean loss: 0.123



Ep: 3/  5. LR: 6.898e-05. Loss: 0.0704.:  98%|█████████▊| 3904/3999 [04:09<00:06, 15.57it/s]

{'mse': 0.11108928, 'raw_mse': 22856.31, 'mae': 0.28778535, 'raw_mae': 67.54434, 'r2': 0.8662755489349365, 'raw_r2': 0.6986052989959717}
Best loss: 0.003
Mean loss: 0.123



Ep: 3/  5. LR: 6.788e-05. Loss: 0.0584.: 100%|██████████| 3999/3999 [04:23<00:00, 15.19it/s]
Ep: 4/  5. LR: 6.769e-05. Loss: 0.0583.:   1%|▏         | 51/3999 [00:02<03:36, 18.20it/s]

{'mse': 0.009976147, 'raw_mse': 548.832, 'mae': 0.076575994, 'raw_mae': 12.352418, 'r2': 0.9879884719848633, 'raw_r2': 0.9927908182144165}
Best loss: 0.003
Mean loss: 0.123



Ep: 4/  5. LR: 6.639e-05. Loss: 0.0499.:   8%|▊         | 329/3999 [00:18<03:20, 18.34it/s]

{'mse': 0.18328391, 'raw_mse': 10965.4375, 'mae': 0.3707693, 'raw_mae': 54.243244, 'r2': 0.7793050408363342, 'raw_r2': 0.8553487658500671}
Best loss: 0.003
Mean loss: 0.126



Ep: 4/  5. LR: 6.506e-05. Loss: 0.0773.:  15%|█▍        | 597/3999 [00:35<03:42, 15.30it/s]

{'mse': 0.23048958, 'raw_mse': 13191.693, 'mae': 0.41644362, 'raw_mae': 59.463497, 'r2': 0.7224662899971008, 'raw_r2': 0.8259019255638123}
Best loss: 0.003
Mean loss: 0.126



Ep: 4/  5. LR: 6.371e-05. Loss: 0.0730.:  22%|██▏       | 862/3999 [00:51<03:26, 15.17it/s]

{'mse': 0.03534584, 'raw_mse': 2345.9214, 'mae': 0.15777346, 'raw_mae': 25.54586, 'r2': 0.9574372172355652, 'raw_r2': 0.9691554307937622}
Best loss: 0.003
Mean loss: 0.125



Ep: 4/  5. LR: 6.235e-05. Loss: 0.0472.:  28%|██▊       | 1126/3999 [01:07<03:28, 13.79it/s]

{'mse': 0.04574977, 'raw_mse': 8457.967, 'mae': 0.18192026, 'raw_mae': 40.101936, 'r2': 0.9449360966682434, 'raw_r2': 0.8882331848144531}
Best loss: 0.003
Mean loss: 0.125



Ep: 4/  5. LR: 6.096e-05. Loss: 0.0570.:  35%|███▍      | 1391/3999 [01:23<02:48, 15.47it/s]

{'mse': 0.18866144, 'raw_mse': 43731.805, 'mae': 0.37680373, 'raw_mae': 93.88298, 'r2': 0.7728857398033142, 'raw_r2': 0.42404112219810486}
Best loss: 0.003
Mean loss: 0.124



Ep: 4/  5. LR: 5.956e-05. Loss: 0.0740.:  41%|████▏     | 1653/3999 [01:40<02:37, 14.88it/s]

{'mse': 0.12763916, 'raw_mse': 26968.818, 'mae': 0.30893344, 'raw_mae': 73.512215, 'r2': 0.8463512659072876, 'raw_r2': 0.6444894075393677}
Best loss: 0.003
Mean loss: 0.125



Ep: 4/  5. LR: 5.814e-05. Loss: 0.0625.:  48%|████▊     | 1914/3999 [01:56<02:19, 14.95it/s]

{'mse': 0.002974103, 'raw_mse': 413.44498, 'mae': 0.039508563, 'raw_mae': 7.8189526, 'r2': 0.9964253902435303, 'raw_r2': 0.994462788105011}
Best loss: 0.003
Mean loss: 0.124



Ep: 4/  5. LR: 5.670e-05. Loss: 0.0467.:  54%|█████▍    | 2176/3999 [02:12<02:03, 14.77it/s]

{'mse': 0.10184745, 'raw_mse': 6600.755, 'mae': 0.27469075, 'raw_mae': 42.24879, 'r2': 0.8773611187934875, 'raw_r2': 0.9130266308784485}
Best loss: 0.003
Mean loss: 0.117



Ep: 4/  5. LR: 5.524e-05. Loss: 0.0517.:  61%|██████    | 2439/3999 [02:28<01:43, 15.11it/s]

{'mse': 0.21968834, 'raw_mse': 12698.776, 'mae': 0.4064455, 'raw_mae': 58.345726, 'r2': 0.7354716658592224, 'raw_r2': 0.8324232697486877}
Best loss: 0.003
Mean loss: 0.119



Ep: 4/  5. LR: 5.376e-05. Loss: 0.0739.:  67%|██████▋   | 2698/3999 [02:45<01:28, 14.68it/s]

{'mse': 0.11907162, 'raw_mse': 7589.833, 'mae': 0.29760396, 'raw_mae': 45.241352, 'r2': 0.8566215634346008, 'raw_r2': 0.8999632596969604}
Best loss: 0.003
Mean loss: 0.125



Ep: 4/  5. LR: 5.227e-05. Loss: 0.0664.:  74%|███████▍  | 2960/3999 [03:01<01:09, 14.88it/s]

{'mse': 0.0032028467, 'raw_mse': 203.00021, 'mae': 0.03859107, 'raw_mae': 6.60028, 'r2': 0.9961474537849426, 'raw_r2': 0.9972837567329407}
Best loss: 0.003
Mean loss: 0.116



Ep: 4/  5. LR: 5.076e-05. Loss: 0.0433.:  81%|████████  | 3220/3999 [03:17<00:49, 15.60it/s]

{'mse': 0.0734714, 'raw_mse': 14204.085, 'mae': 0.2327077, 'raw_mae': 52.785587, 'r2': 0.9115630984306335, 'raw_r2': 0.8125182390213013}
Best loss: 0.003
Mean loss: 0.110



Ep: 4/  5. LR: 4.924e-05. Loss: 0.0481.:  87%|████████▋ | 3479/3999 [03:33<00:35, 14.67it/s]

{'mse': 0.1613063, 'raw_mse': 35905.805, 'mae': 0.3480179, 'raw_mae': 85.00879, 'r2': 0.8058190941810608, 'raw_r2': 0.5269365310668945}
Best loss: 0.003
Mean loss: 0.118



Ep: 4/  5. LR: 4.770e-05. Loss: 0.0629.:  94%|█████████▎| 3741/3999 [03:50<00:16, 15.72it/s]

{'mse': 0.09357888, 'raw_mse': 18708.385, 'mae': 0.26357204, 'raw_mae': 60.92076, 'r2': 0.8873562216758728, 'raw_r2': 0.7532044649124146}
Best loss: 0.003
Mean loss: 0.115



Ep: 4/  5. LR: 4.614e-05. Loss: 0.0604.: : 4002it [04:06, 14.54it/s]                        

{'mse': 0.0035735278, 'raw_mse': 572.0342, 'mae': 0.04421712, 'raw_mae': 9.025156, 'r2': 0.9957044720649719, 'raw_r2': 0.9923551678657532}
Best loss: 0.003
Mean loss: 0.102



Ep: 4/  5. LR: 4.540e-05. Loss: 0.0526.: 100%|██████████| 3999/3999 [04:14<00:00, 15.72it/s]
Ep: 5/  5. LR: 4.458e-05. Loss: 0.0460.:   4%|▍         | 157/3999 [00:08<04:13, 15.15it/s]

{'mse': 0.05997242, 'raw_mse': 4011.18, 'mae': 0.20879261, 'raw_mae': 33.144344, 'r2': 0.927783191204071, 'raw_r2': 0.9472036957740784}
Best loss: 0.003
Mean loss: 0.102



Ep: 5/  5. LR: 4.299e-05. Loss: 0.0479.:  11%|█         | 429/3999 [00:24<04:00, 14.85it/s]

{'mse': 0.15937683, 'raw_mse': 9760.5, 'mae': 0.3453431, 'raw_mae': 51.207207, 'r2': 0.80809086561203, 'raw_r2': 0.8712790608406067}
Best loss: 0.003
Mean loss: 0.108



Ep: 5/  5. LR: 4.140e-05. Loss: 0.0573.:  17%|█▋        | 697/3999 [00:40<03:30, 15.67it/s]

{'mse': 0.14732806, 'raw_mse': 9131.181, 'mae': 0.33179152, 'raw_mae': 49.549633, 'r2': 0.8225985765457153, 'raw_r2': 0.8795972466468811}
Best loss: 0.003
Mean loss: 0.104



Ep: 5/  5. LR: 3.980e-05. Loss: 0.0534.:  24%|██▍       | 961/3999 [00:56<03:24, 14.83it/s]

{'mse': 0.0380865, 'raw_mse': 2536.8162, 'mae': 0.16422804, 'raw_mae': 26.535639, 'r2': 0.9541370868682861, 'raw_r2': 0.9666405916213989}
Best loss: 0.003
Mean loss: 0.100



Ep: 5/  5. LR: 3.817e-05. Loss: 0.0519.:  31%|███       | 1225/3999 [01:13<02:56, 15.73it/s]

{'mse': 0.006129879, 'raw_mse': 1092.9849, 'mae': 0.060495656, 'raw_mae': 12.729605, 'r2': 0.9926286339759827, 'raw_r2': 0.9854502081871033}
Best loss: 0.003
Mean loss: 0.100



Ep: 5/  5. LR: 3.654e-05. Loss: 0.0505.:  37%|███▋      | 1489/3999 [01:29<02:50, 14.68it/s]

{'mse': 0.06546619, 'raw_mse': 12489.746, 'mae': 0.21922724, 'raw_mae': 49.33696, 'r2': 0.9212004542350769, 'raw_r2': 0.8351010680198669}
Best loss: 0.003
Mean loss: 0.094



Ep: 5/  5. LR: 3.489e-05. Loss: 0.0475.:  44%|████▍     | 1752/3999 [01:46<02:31, 14.82it/s]

{'mse': 0.11700288, 'raw_mse': 24303.768, 'mae': 0.295515, 'raw_mae': 69.70561, 'r2': 0.8591562509536743, 'raw_r2': 0.6795567274093628}
Best loss: 0.003
Mean loss: 0.088



Ep: 5/  5. LR: 3.325e-05. Loss: 0.0443.:  50%|█████     | 2012/3999 [02:02<02:21, 14.03it/s]

{'mse': 0.085047364, 'raw_mse': 16762.88, 'mae': 0.25093687, 'raw_mae': 57.549187, 'r2': 0.8976271748542786, 'raw_r2': 0.7788199186325073}
Best loss: 0.003
Mean loss: 0.091



Ep: 5/  5. LR: 3.158e-05. Loss: 0.0450.:  57%|█████▋    | 2275/3999 [02:18<02:02, 14.03it/s]

{'mse': 0.016160266, 'raw_mse': 2905.638, 'mae': 0.10441278, 'raw_mae': 22.2507, 'r2': 0.9805556535720825, 'raw_r2': 0.961485743522644}
Best loss: 0.003
Mean loss: 0.089



Ep: 5/  5. LR: 2.990e-05. Loss: 0.0465.:  63%|██████▎   | 2538/3999 [02:34<01:38, 14.80it/s]

{'mse': 0.009054039, 'raw_mse': 487.2837, 'mae': 0.07212295, 'raw_mae': 11.599024, 'r2': 0.9890990853309631, 'raw_r2': 0.9935970306396484}
Best loss: 0.003
Mean loss: 0.080



Ep: 5/  5. LR: 2.822e-05. Loss: 0.0483.:  70%|██████▉   | 2798/3999 [02:51<01:16, 15.60it/s]

{'mse': 0.057416994, 'raw_mse': 3843.7905, 'mae': 0.20408681, 'raw_mae': 32.464684, 'r2': 0.9308602809906006, 'raw_r2': 0.9494113922119141}
Best loss: 0.003
Mean loss: 0.077



Ep: 5/  5. LR: 2.652e-05. Loss: 0.0463.:  76%|███████▋  | 3059/3999 [03:07<01:04, 14.56it/s]

{'mse': 0.10663202, 'raw_mse': 6879.452, 'mae': 0.28124252, 'raw_mae': 43.11324, 'r2': 0.8716000318527222, 'raw_r2': 0.9093460440635681}
Best loss: 0.003
Mean loss: 0.082



Ep: 5/  5. LR: 2.482e-05. Loss: 0.0456.:  83%|████████▎ | 3319/3999 [03:23<00:44, 15.34it/s]

{'mse': 0.10930622, 'raw_mse': 7034.0186, 'mae': 0.28483903, 'raw_mae': 43.58491, 'r2': 0.8683801889419556, 'raw_r2': 0.9073047637939453}
Best loss: 0.003
Mean loss: 0.083



Ep: 5/  5. LR: 2.312e-05. Loss: 0.0434.:  89%|████████▉ | 3579/3999 [03:39<00:27, 15.52it/s]

{'mse': 0.060602076, 'raw_mse': 4052.387, 'mae': 0.2099361, 'raw_mae': 33.30898, 'r2': 0.9270250797271729, 'raw_r2': 0.9466602802276611}
Best loss: 0.003
Mean loss: 0.075



Ep: 5/  5. LR: 2.141e-05. Loss: 0.0392.:  96%|█████████▌| 3838/3999 [03:56<00:10, 15.44it/s]

{'mse': 0.013779279, 'raw_mse': 812.1551, 'mae': 0.092964314, 'raw_mae': 15.121838, 'r2': 0.9834080934524536, 'raw_r2': 0.9893359541893005}
Best loss: 0.003
Mean loss: 0.069



Ep: 5/  5. LR: 1.968e-05. Loss: 0.0380.: : 4100it [04:12, 14.31it/s]                        

{'mse': 0.00332369, 'raw_mse': 511.16443, 'mae': 0.04234987, 'raw_mae': 8.560571, 'r2': 0.996005117893219, 'raw_r2': 0.993163526058197}
Best loss: 0.003
Mean loss: 0.069



Ep: 5/  5. LR: 1.951e-05. Loss: 0.0387.: 100%|██████████| 3999/3999 [04:14<00:00, 15.73it/s]
