<a href="https://colab.research.google.com/github/nhnain/eegchallenge/blob/main/Challenge1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#1. Preprocessing

In [1]:
!pip install braindecode
!pip install eegdash

Collecting braindecode
  Downloading braindecode-1.2.0-py3-none-any.whl.metadata (7.1 kB)
Collecting mne>=1.10.0 (from braindecode)
  Downloading mne-1.10.1-py3-none-any.whl.metadata (20 kB)
Collecting mne_bids>=0.16 (from braindecode)
  Downloading mne_bids-0.17.0-py3-none-any.whl.metadata (7.3 kB)
Collecting skorch>=1.2.0 (from braindecode)
  Downloading skorch-1.2.0-py3-none-any.whl.metadata (11 kB)
Collecting torchinfo~=1.8 (from braindecode)
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting wfdb (from braindecode)
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting linear_attention_transformer (from braindecode)
  Downloading linear_attention_transformer-0.19.1-py3-none-any.whl.metadata (787 bytes)
Collecting docstring_inheritance (from braindecode)
  Downloading docstring_inheritance-2.2.2-py3-none-any.whl.metadata (11 kB)
Collecting axial-positional-embedding (from linear_attention_transformer->braindecode)
  Downloading axial_position

In [2]:
from pathlib import Path
import math
import os
import random
from joblib import Parallel, delayed

import torch
from torch.utils.data import DataLoader
from torch import optim
from torch.nn.functional import l1_loss
from braindecode.preprocessing import preprocess, Preprocessor, create_fixed_length_windows, create_windows_from_events
from braindecode.datasets.base import EEGWindowsDataset, BaseConcatDataset, BaseDataset
from braindecode.models import EEGNeX
from eegdash import EEGChallengeDataset
from typing import Optional
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm
from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import annotate_trials_with_target, add_aux_anchors, add_extras_columns, keep_only_recordings_with

In [3]:
# Identify whether a CUDA-enabled GPU is available
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg ='CUDA-enabled GPU found. Training should be faster.'
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting \'T4 GPU\'\nunder \'Hardware accelerator\'."
    )
print(msg)

CUDA-enabled GPU found. Training should be faster.


In [4]:
release_list = ['R1', 'R2', 'R3', 'R4', 'R5', 'R6',
                'R7', 'R8', 'R9', 'R10', 'R11']
train_set = []
valid_set = []
test_set = []

DATA_DIR = Path('data')
DATA_DIR.mkdir(parents=True, exist_ok=True)


In [5]:
for release in tqdm(release_list):
  dataset_ccd = EEGChallengeDataset(
      task = 'contrastChangeDetection',
      release = release,
      cache_dir = DATA_DIR,
      mini = True
  )
  raws = Parallel(n_jobs=-1)(
      delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets
  )

  EPOCH_LEN_S = 2.0
  SFREQ = 100

  transformation_offline = [
      Preprocessor(
          annotate_trials_with_target,
          target_field = 'rt_from_stimulus', epoch_length =  EPOCH_LEN_S,
          require_stimulus = True,
          require_response = True,
          apply_on_array = False,
      ),
      Preprocessor(add_aux_anchors, apply_on_array=False),
  ]
  preprocess (dataset_ccd, transformation_offline, n_jobs=1)

  ANCHOR = 'stimulus_anchor'
  SHIFT_AFTER_STIM = 0.5
  WINDOW_LEN = 2.0

  dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

  single_windows = create_windows_from_events(
      dataset,
      mapping={ANCHOR: 0},
      trial_start_offset_samples = int(SHIFT_AFTER_STIM * SFREQ),
      trial_stop_offset_samples = int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),
      window_size_samples = int(EPOCH_LEN_S * SFREQ),
      window_stride_samples = SFREQ,
      preload = True,
  )

  single_windows = add_extras_columns(
      single_windows,
      dataset,
      desc = ANCHOR, keys = ("target", "rt_from_stimulus", "rt_from_trialstart",
            "stimulus_onset", "response_onset", "correct", "response_type"))

  meta_information = single_windows.get_metadata()

  subjects = meta_information['subject'].unique()
  sub_rm = ["NDARWV769JM7", "NDARME789TD2", "NDARUA442ZVF", "NDARJP304NK1",
            "NDARTY128YLU", "NDARDW550GU6", "NDARLD243KRE", "NDARUJ292JXV", "NDARBA381JGH"]
  subjects = [s for s in subjects if s not in sub_rm]

  subject_split = single_windows.split('subject')

  if release != 'R5':
    for s in subject_split:
      train_set.append(subject_split[s])
  else:
    for s in subject_split:
      valid_set.append(subject_split[s])


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(

100%|██████████| 11/11 [37:56<00:00, 206.98s/it]


In [6]:
train_set = BaseConcatDataset(train_set)
valid_set = BaseConcatDataset(valid_set)

print(f'Train set: {len(train_set)}')
print(f'Valid set: {len(valid_set)}')

Train set: 12805
Valid set: 1214


#Build model

In [7]:
batch_size = 128
num_workers = 2

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [8]:
from braindecode.models.util import models_dict

names = sorted(models_dict)
w = max(len(n) for n in names)

for i in range(0, len(names), 3):
    row = names[i:i+3]
    print("  ".join(f"{n:<{w}}" for n in row))

ATCNet                  AttentionBaseNet        AttnSleep             
BDTCN                   BIOT                    CTNet                 
ContraWR                Deep4Net                DeepSleepNet          
EEGConformer            EEGITNet                EEGInceptionERP       
EEGInceptionMI          EEGMiner                EEGNeX                
EEGNet                  EEGSimpleConv           EEGTCNet              
FBCNet                  FBLightConvNet          FBMSNet               
IFNet                   Labram                  MSVTNet               
SCCNet                  SPARCNet                ShallowFBCSPNet       
SignalJEPA              SignalJEPA_Contextual   SignalJEPA_PostLocal  
SignalJEPA_PreLocal     SincShallowNet          SleepStagerBlanco2020 
SleepStagerChambon2018  SyncNet                 TIDNet                
TSception               USleep                


In [9]:
model = EEGNeX(n_chans=129,
               n_outputs=1,
               n_times=200,
               sfreq=100)

#Model training

In [10]:
lr = 1e-3
weight_decay = 1e-5
n_epochs = 60
early_stopping_patience = 50

In [11]:
def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,):
  model.train()

  total_loss = 0.0
  sum_sq_err = 0.0
  n_samples = 0

  progress_bar = tqdm(
      enumerate(dataloader),
      total=len(dataloader),
      disable=not print_batch_stats)
  for batch_idx, batch in progress_bar:
    X, y = batch[0], batch[1]
    X, y = X.to(device).float(), y.to(device).float()

    optimizer.zero_grad(set_to_none=True)
    preds = model(X)
    loss = loss_fn(preds, y)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    preds_flat = preds.detach().view(-1)
    y_flat = y.detach().view(-1)
    sum_sq_err += torch.sum((preds_flat - y_flat)**2).item()
    n_samples += y_flat.numel()

    if print_batch_stats:
      running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
      progress_bar.set_description(
          f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)},"
          f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}"
      )
  if scheduler is not None:
    scheduler.step()

  avg_loss = total_loss/len(dataloader)
  rmse = (sum_sq_err/max(n_samples, 1))** 0.5
  return avg_loss, rmse

In [12]:
with torch.inference_mode():
  def valid_model(
      dataloader: DataLoader,
      model: Module,
      loss_fn,
      device,
      print_batch_stats: bool = True):

    model.eval()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_batches = len(dataloader)
    n_samples = 0.0

    iterator = tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        disable=not print_batch_stats)

    for batch_idx, batch in iterator:
      X, y = batch[0], batch[1]
      X, y = X.to(device).float(), y.to(device).float()

      preds = model(X)
      batch_loss = loss_fn(preds, y).item()
      total_loss += batch_loss

      preds_flat = preds.detach().view(-1)
      y_flat = y.detach().view(-1)
      sum_sq_err += torch.sum((preds_flat - y_flat)**2).item()
      n_samples += y_flat.numel()

      if print_batch_stats:
        running_rmse = (sum_sq_err/max(n_samples,1)) ** 0.5
        iterator.set_description(
            f"Val Batch {batch_idx + 1}/{len(dataloader)},"
            f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
        )

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5

    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}")
    return avg_loss, rmse

In [13]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = n_epochs -1)
loss_fn = torch.nn.MSELoss()

patience = 5
min_delta = 1e-4
best_rmse = float('inf')
epochs_no_improve = 0
best_state, best_epoch = None, None

model.to(device)
for epoch in range(1, n_epochs+1):
  print(f'Epoch {epoch}/{n_epochs}: ', end='')

  train_loss, train_rmse = train_one_epoch(
      train_loader, model, loss_fn, optimizer, scheduler, epoch, device
  )
  valid_loss, valid_rmse = valid_model(
      valid_loader, model, loss_fn, device
  )
  print(
      f'Train RMSE: {train_rmse:.6f}, '
      f'Average Train Loss: {train_loss:.6f}, '
      f'Val RMSE: {valid_rmse:.6f}, '
      f'Average Valid Loss: {valid_loss:.6f}\n'
  )
if best_state is not None:
  model.load_state_dict(best_state)

Epoch 1/100: 

  return F.conv2d(
Epoch 1, Batch 101/101,Loss: 0.100190, RMSE: 0.835295: 100%|██████████| 101/101 [01:01<00:00,  1.65it/s]
Val Batch 10/10,Loss: 0.386997, RMSE: 0.472473: 100%|██████████| 10/10 [00:04<00:00,  2.46it/s]

Val RMSE: 0.472473, Val Loss: 0.231675
Train RMSE: 0.835295, Average Train Loss: 0.692033, Val RMSE: 0.472473, Average Valid Loss: 0.231675

Epoch 2/100: 


Epoch 2, Batch 101/101,Loss: 0.393672, RMSE: 0.481925: 100%|██████████| 101/101 [00:57<00:00,  1.75it/s]
Val Batch 10/10,Loss: 0.406114, RMSE: 0.565627: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.565627, Val Loss: 0.324377
Train RMSE: 0.481925, Average Train Loss: 0.233787, Val RMSE: 0.565627, Average Valid Loss: 0.324377

Epoch 3/100: 


Epoch 3, Batch 101/101,Loss: 0.261439, RMSE: 0.466932: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.342528, RMSE: 0.450396: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]

Val RMSE: 0.450396, Val Loss: 0.210058
Train RMSE: 0.466932, Average Train Loss: 0.218438, Val RMSE: 0.450396, Average Valid Loss: 0.210058

Epoch 4/100: 


Epoch 4, Batch 101/101,Loss: 0.342430, RMSE: 0.457216: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.288806, RMSE: 0.421368: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.421368, Val Loss: 0.183288
Train RMSE: 0.457216, Average Train Loss: 0.210316, Val RMSE: 0.421368, Average Valid Loss: 0.183288

Epoch 5/100: 


Epoch 5, Batch 101/101,Loss: 0.084066, RMSE: 0.453886: 100%|██████████| 101/101 [00:58<00:00,  1.74it/s]
Val Batch 10/10,Loss: 0.288473, RMSE: 0.421260: 100%|██████████| 10/10 [00:04<00:00,  2.46it/s]

Val RMSE: 0.421260, Val Loss: 0.183184
Train RMSE: 0.453886, Average Train Loss: 0.204853, Val RMSE: 0.421260, Average Valid Loss: 0.183184

Epoch 6/100: 


Epoch 6, Batch 101/101,Loss: 0.070230, RMSE: 0.456445: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.260011, RMSE: 0.402845: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.402845, Val Loss: 0.167323
Train RMSE: 0.456445, Average Train Loss: 0.207028, Val RMSE: 0.402845, Average Valid Loss: 0.167323

Epoch 7/100: 


Epoch 7, Batch 101/101,Loss: 0.382325, RMSE: 0.450983: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.269894, RMSE: 0.410211: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]

Val RMSE: 0.410211, Val Loss: 0.173513
Train RMSE: 0.450983, Average Train Loss: 0.205088, Val RMSE: 0.410211, Average Valid Loss: 0.173513

Epoch 8/100: 


Epoch 8, Batch 101/101,Loss: 0.314686, RMSE: 0.448365: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.277724, RMSE: 0.416971: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]

Val RMSE: 0.416971, Val Loss: 0.179220
Train RMSE: 0.448365, Average Train Loss: 0.202113, Val RMSE: 0.416971, Average Valid Loss: 0.179220

Epoch 9/100: 


Epoch 9, Batch 101/101,Loss: 0.121968, RMSE: 0.449241: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.286668, RMSE: 0.423491: 100%|██████████| 10/10 [00:04<00:00,  2.06it/s]

Val RMSE: 0.423491, Val Loss: 0.184879
Train RMSE: 0.449241, Average Train Loss: 0.201058, Val RMSE: 0.423491, Average Valid Loss: 0.184879

Epoch 10/100: 


Epoch 10, Batch 101/101,Loss: 0.099054, RMSE: 0.448572: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.212647, RMSE: 0.385667: 100%|██████████| 10/10 [00:03<00:00,  2.51it/s]

Val RMSE: 0.385667, Val Loss: 0.152035
Train RMSE: 0.448572, Average Train Loss: 0.200245, Val RMSE: 0.385667, Average Valid Loss: 0.152035

Epoch 11/100: 


Epoch 11, Batch 101/101,Loss: 0.101418, RMSE: 0.444538: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.229645, RMSE: 0.393662: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.393662, Val Loss: 0.158820
Train RMSE: 0.444538, Average Train Loss: 0.196698, Val RMSE: 0.393662, Average Valid Loss: 0.158820

Epoch 12/100: 


Epoch 12, Batch 101/101,Loss: 0.148758, RMSE: 0.444823: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.347513, RMSE: 0.456223: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]

Val RMSE: 0.456223, Val Loss: 0.215325
Train RMSE: 0.444823, Average Train Loss: 0.197400, Val RMSE: 0.456223, Average Valid Loss: 0.215325

Epoch 13/100: 


Epoch 13, Batch 101/101,Loss: 0.109162, RMSE: 0.441627: 100%|██████████| 101/101 [00:56<00:00,  1.77it/s]
Val Batch 10/10,Loss: 0.245746, RMSE: 0.403887: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.403887, Val Loss: 0.167385
Train RMSE: 0.441627, Average Train Loss: 0.194218, Val RMSE: 0.403887, Average Valid Loss: 0.167385

Epoch 14/100: 


Epoch 14, Batch 101/101,Loss: 0.125944, RMSE: 0.443607: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.364153, RMSE: 0.463417: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Val RMSE: 0.463417, Val Loss: 0.222459
Train RMSE: 0.443607, Average Train Loss: 0.196113, Val RMSE: 0.463417, Average Valid Loss: 0.222459

Epoch 15/100: 

Epoch 15, Batch 101/101,Loss: 0.423374, RMSE: 0.440861: 100%|██████████| 101/101 [01:01<00:00,  1.65it/s]
Val Batch 10/10,Loss: 0.291044, RMSE: 0.433685: 100%|██████████| 10/10 [00:04<00:00,  2.14it/s]

Val RMSE: 0.433685, Val Loss: 0.193392
Train RMSE: 0.440861, Average Train Loss: 0.196537, Val RMSE: 0.433685, Average Valid Loss: 0.193392

Epoch 16/100: 


Epoch 16, Batch 101/101,Loss: 0.182535, RMSE: 0.441643: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.353683, RMSE: 0.452055: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]

Val RMSE: 0.452055, Val Loss: 0.212053
Train RMSE: 0.441643, Average Train Loss: 0.194930, Val RMSE: 0.452055, Average Valid Loss: 0.212053

Epoch 17/100: 


Epoch 17, Batch 101/101,Loss: 0.273051, RMSE: 0.438752: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.260981, RMSE: 0.409749: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.409749, Val Loss: 0.172694
Train RMSE: 0.438752, Average Train Loss: 0.193270, Val RMSE: 0.409749, Average Valid Loss: 0.172694

Epoch 18/100: 


Epoch 18, Batch 101/101,Loss: 0.324881, RMSE: 0.438041: 100%|██████████| 101/101 [00:57<00:00,  1.74it/s]
Val Batch 10/10,Loss: 0.256564, RMSE: 0.412481: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]

Val RMSE: 0.412481, Val Loss: 0.174597
Train RMSE: 0.438041, Average Train Loss: 0.193145, Val RMSE: 0.412481, Average Valid Loss: 0.174597

Epoch 19/100: 


Epoch 19, Batch 101/101,Loss: 0.157246, RMSE: 0.437105: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.268211, RMSE: 0.412867: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s]

Val RMSE: 0.412867, Val Loss: 0.175499
Train RMSE: 0.437105, Average Train Loss: 0.190739, Val RMSE: 0.412867, Average Valid Loss: 0.175499

Epoch 20/100: 


Epoch 20, Batch 101/101,Loss: 0.276435, RMSE: 0.434603: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.440093, RMSE: 0.509441: 100%|██████████| 10/10 [00:04<00:00,  2.39it/s]

Val RMSE: 0.509441, Val Loss: 0.268840
Train RMSE: 0.434603, Average Train Loss: 0.189713, Val RMSE: 0.509441, Average Valid Loss: 0.268840

Epoch 21/100: 


Epoch 21, Batch 101/101,Loss: 0.111145, RMSE: 0.433625: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.371257, RMSE: 0.469409: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Val RMSE: 0.469409, Val Loss: 0.228126
Train RMSE: 0.433625, Average Train Loss: 0.187299, Val RMSE: 0.469409, Average Valid Loss: 0.228126

Epoch 22/100: 

Epoch 22, Batch 101/101,Loss: 0.163692, RMSE: 0.430763: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.223247, RMSE: 0.390566: 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]

Val RMSE: 0.390566, Val Loss: 0.156187
Train RMSE: 0.430763, Average Train Loss: 0.185349, Val RMSE: 0.390566, Average Valid Loss: 0.156187

Epoch 23/100: 


Epoch 23, Batch 101/101,Loss: 0.215593, RMSE: 0.432701: 100%|██████████| 101/101 [00:57<00:00,  1.76it/s]
Val Batch 10/10,Loss: 0.393787, RMSE: 0.481112: 100%|██████████| 10/10 [00:04<00:00,  2.13it/s]

Val RMSE: 0.481112, Val Loss: 0.239838
Train RMSE: 0.432701, Average Train Loss: 0.187500, Val RMSE: 0.481112, Average Valid Loss: 0.239838

Epoch 24/100: 


Epoch 24, Batch 101/101,Loss: 0.116605, RMSE: 0.430252: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.268863, RMSE: 0.412591: 100%|██████████| 10/10 [00:03<00:00,  2.50it/s]

Val RMSE: 0.412591, Val Loss: 0.175317
Train RMSE: 0.430252, Average Train Loss: 0.184465, Val RMSE: 0.412591, Average Valid Loss: 0.175317

Epoch 25/100: 


Epoch 25, Batch 101/101,Loss: 0.563872, RMSE: 0.426966: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.350069, RMSE: 0.458831: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.458831, Val Loss: 0.217721
Train RMSE: 0.426966, Average Train Loss: 0.185930, Val RMSE: 0.458831, Average Valid Loss: 0.217721

Epoch 26/100: 


Epoch 26, Batch 101/101,Loss: 0.458809, RMSE: 0.427916: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.352659, RMSE: 0.455296: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Val RMSE: 0.455296, Val Loss: 0.214790
Train RMSE: 0.427916, Average Train Loss: 0.185736, Val RMSE: 0.455296, Average Valid Loss: 0.214790

Epoch 27/100: 

Epoch 27, Batch 101/101,Loss: 0.546035, RMSE: 0.427827: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.310504, RMSE: 0.437184: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.437184, Val Loss: 0.197285
Train RMSE: 0.427827, Average Train Loss: 0.186490, Val RMSE: 0.437184, Average Valid Loss: 0.197285

Epoch 28/100: 


Epoch 28, Batch 101/101,Loss: 0.179219, RMSE: 0.423687: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.379700, RMSE: 0.468707: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]

Val RMSE: 0.468707, Val Loss: 0.227937
Train RMSE: 0.423687, Average Train Loss: 0.179508, Val RMSE: 0.468707, Average Valid Loss: 0.227937

Epoch 29/100: 


Epoch 29, Batch 101/101,Loss: 0.210751, RMSE: 0.421954: 100%|██████████| 101/101 [00:57<00:00,  1.77it/s]
Val Batch 10/10,Loss: 0.189174, RMSE: 0.382269: 100%|██████████| 10/10 [00:04<00:00,  2.06it/s]

Val RMSE: 0.382269, Val Loss: 0.148349
Train RMSE: 0.421954, Average Train Loss: 0.178357, Val RMSE: 0.382269, Average Valid Loss: 0.148349

Epoch 30/100: 


Epoch 30, Batch 101/101,Loss: 0.068024, RMSE: 0.423878: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.228555, RMSE: 0.389884: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]

Val RMSE: 0.389884, Val Loss: 0.155956
Train RMSE: 0.423878, Average Train Loss: 0.178611, Val RMSE: 0.389884, Average Valid Loss: 0.155956

Epoch 31/100: 


Epoch 31, Batch 101/101,Loss: 0.128527, RMSE: 0.419782: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.447855, RMSE: 0.515496: 100%|██████████| 10/10 [00:03<00:00,  2.51it/s]

Val RMSE: 0.515496, Val Loss: 0.275126
Train RMSE: 0.419782, Average Train Loss: 0.175763, Val RMSE: 0.515496, Average Valid Loss: 0.275126

Epoch 32/100: 


Epoch 32, Batch 101/101,Loss: 0.085398, RMSE: 0.420635: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.435587, RMSE: 0.504735: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]

Val RMSE: 0.504735, Val Loss: 0.264081
Train RMSE: 0.420635, Average Train Loss: 0.176063, Val RMSE: 0.504735, Average Valid Loss: 0.264081

Epoch 33/100: 


Epoch 33, Batch 101/101,Loss: 0.622974, RMSE: 0.416674: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.295548, RMSE: 0.431739: 100%|██████████| 10/10 [00:04<00:00,  2.37it/s]

Val RMSE: 0.431739, Val Loss: 0.192027
Train RMSE: 0.416674, Average Train Loss: 0.177892, Val RMSE: 0.431739, Average Valid Loss: 0.192027

Epoch 34/100: 


Epoch 34, Batch 101/101,Loss: 0.117349, RMSE: 0.417024: 100%|██████████| 101/101 [00:58<00:00,  1.74it/s]
Val Batch 10/10,Loss: 0.313773, RMSE: 0.439183: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.439183, Val Loss: 0.199115
Train RMSE: 0.417024, Average Train Loss: 0.173371, Val RMSE: 0.439183, Average Valid Loss: 0.199115

Epoch 35/100: 


Epoch 35, Batch 101/101,Loss: 0.055493, RMSE: 0.415480: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.180787, RMSE: 0.388449: 100%|██████████| 10/10 [00:04<00:00,  2.21it/s]

Val RMSE: 0.388449, Val Loss: 0.152434
Train RMSE: 0.415480, Average Train Loss: 0.171510, Val RMSE: 0.388449, Average Valid Loss: 0.152434

Epoch 36/100: 


Epoch 36, Batch 101/101,Loss: 0.033707, RMSE: 0.413262: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.195160, RMSE: 0.386086: 100%|██████████| 10/10 [00:04<00:00,  2.32it/s]

Val RMSE: 0.386086, Val Loss: 0.151439
Train RMSE: 0.413262, Average Train Loss: 0.169481, Val RMSE: 0.386086, Average Valid Loss: 0.151439

Epoch 37/100: 


Epoch 37, Batch 101/101,Loss: 0.041410, RMSE: 0.413638: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.436170, RMSE: 0.508610: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]


Val RMSE: 0.508610, Val Loss: 0.267836
Train RMSE: 0.413638, Average Train Loss: 0.169862, Val RMSE: 0.508610, Average Valid Loss: 0.267836

Epoch 38/100: 

Epoch 38, Batch 101/101,Loss: 0.161555, RMSE: 0.415231: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.240601, RMSE: 0.400645: 100%|██████████| 10/10 [00:04<00:00,  2.24it/s]

Val RMSE: 0.400645, Val Loss: 0.164646
Train RMSE: 0.415231, Average Train Loss: 0.172314, Val RMSE: 0.400645, Average Valid Loss: 0.164646

Epoch 39/100: 


Epoch 39, Batch 101/101,Loss: 0.348121, RMSE: 0.411009: 100%|██████████| 101/101 [00:57<00:00,  1.76it/s]
Val Batch 10/10,Loss: 0.207497, RMSE: 0.385085: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.385085, Val Loss: 0.151343
Train RMSE: 0.411009, Average Train Loss: 0.170633, Val RMSE: 0.385085, Average Valid Loss: 0.151343

Epoch 40/100: 


Epoch 40, Batch 101/101,Loss: 0.093213, RMSE: 0.414814: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.258955, RMSE: 0.406485: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.406485, Val Loss: 0.170063
Train RMSE: 0.414814, Average Train Loss: 0.171320, Val RMSE: 0.406485, Average Valid Loss: 0.170063

Epoch 41/100: 


Epoch 41, Batch 101/101,Loss: 0.386820, RMSE: 0.412514: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.328005, RMSE: 0.450351: 100%|██████████| 10/10 [00:04<00:00,  2.43it/s]

Val RMSE: 0.450351, Val Loss: 0.209271
Train RMSE: 0.412514, Average Train Loss: 0.172229, Val RMSE: 0.450351, Average Valid Loss: 0.209271

Epoch 42/100: 


Epoch 42, Batch 101/101,Loss: 0.085966, RMSE: 0.408837: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.198991, RMSE: 0.395818: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]

Val RMSE: 0.395818, Val Loss: 0.158854
Train RMSE: 0.408837, Average Train Loss: 0.166376, Val RMSE: 0.395818, Average Valid Loss: 0.158854

Epoch 43/100: 


Epoch 43, Batch 101/101,Loss: 0.351644, RMSE: 0.409057: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.283186, RMSE: 0.422844: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.422844, Val Loss: 0.184180
Train RMSE: 0.409057, Average Train Loss: 0.169081, Val RMSE: 0.422844, Average Valid Loss: 0.184180

Epoch 44/100: 


Epoch 44, Batch 101/101,Loss: 0.015487, RMSE: 0.407295: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.279677, RMSE: 0.428880: 100%|██████████| 10/10 [00:05<00:00,  1.93it/s]

Val RMSE: 0.428880, Val Loss: 0.188875
Train RMSE: 0.407295, Average Train Loss: 0.164459, Val RMSE: 0.428880, Average Valid Loss: 0.188875

Epoch 45/100: 


Epoch 45, Batch 101/101,Loss: 0.355826, RMSE: 0.406647: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.359943, RMSE: 0.472012: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Val RMSE: 0.472012, Val Loss: 0.229867
Train RMSE: 0.406647, Average Train Loss: 0.167174, Val RMSE: 0.472012, Average Valid Loss: 0.229867

Epoch 46/100: 

Epoch 46, Batch 101/101,Loss: 0.358052, RMSE: 0.405777: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.204168, RMSE: 0.395825: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.395825, Val Loss: 0.159126
Train RMSE: 0.405777, Average Train Loss: 0.166495, Val RMSE: 0.395825, Average Valid Loss: 0.159126

Epoch 47/100: 


Epoch 47, Batch 101/101,Loss: 0.102576, RMSE: 0.404621: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.252656, RMSE: 0.413466: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]


Val RMSE: 0.413466, Val Loss: 0.175167
Train RMSE: 0.404621, Average Train Loss: 0.163137, Val RMSE: 0.413466, Average Valid Loss: 0.175167

Epoch 48/100: 

Epoch 48, Batch 101/101,Loss: 0.050206, RMSE: 0.403430: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.262419, RMSE: 0.415678: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.415678, Val Loss: 0.177410
Train RMSE: 0.403430, Average Train Loss: 0.161685, Val RMSE: 0.415678, Average Valid Loss: 0.177410

Epoch 49/100: 


Epoch 49, Batch 101/101,Loss: 0.027055, RMSE: 0.401609: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.182109, RMSE: 0.393175: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.393175, Val Loss: 0.156006
Train RMSE: 0.401609, Average Train Loss: 0.160013, Val RMSE: 0.393175, Average Valid Loss: 0.156006

Epoch 50/100: 


Epoch 50, Batch 101/101,Loss: 0.446765, RMSE: 0.402685: 100%|██████████| 101/101 [00:57<00:00,  1.74it/s]
Val Batch 10/10,Loss: 0.277647, RMSE: 0.425970: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Val RMSE: 0.425970, Val Loss: 0.186411
Train RMSE: 0.402685, Average Train Loss: 0.164863, Val RMSE: 0.425970, Average Valid Loss: 0.186411

Epoch 51/100: 

Epoch 51, Batch 101/101,Loss: 0.130932, RMSE: 0.399925: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.230451, RMSE: 0.396431: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.396431, Val Loss: 0.160936
Train RMSE: 0.399925, Average Train Loss: 0.159664, Val RMSE: 0.396431, Average Valid Loss: 0.160936

Epoch 52/100: 


Epoch 52, Batch 101/101,Loss: 0.072555, RMSE: 0.401081: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.191205, RMSE: 0.377790: 100%|██████████| 10/10 [00:04<00:00,  2.45it/s]


Val RMSE: 0.377790, Val Loss: 0.145225
Train RMSE: 0.401081, Average Train Loss: 0.160025, Val RMSE: 0.377790, Average Valid Loss: 0.145225

Epoch 53/100: 

Epoch 53, Batch 101/101,Loss: 0.126664, RMSE: 0.398691: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.316937, RMSE: 0.435534: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]

Val RMSE: 0.435534, Val Loss: 0.196251
Train RMSE: 0.398691, Average Train Loss: 0.158647, Val RMSE: 0.435534, Average Valid Loss: 0.196251

Epoch 54/100: 


Epoch 54, Batch 101/101,Loss: 0.091607, RMSE: 0.399099: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.211649, RMSE: 0.393407: 100%|██████████| 10/10 [00:04<00:00,  2.18it/s]

Val RMSE: 0.393407, Val Loss: 0.157702
Train RMSE: 0.399099, Average Train Loss: 0.158636, Val RMSE: 0.393407, Average Valid Loss: 0.157702

Epoch 55/100: 


Epoch 55, Batch 101/101,Loss: 0.188939, RMSE: 0.400125: 100%|██████████| 101/101 [00:57<00:00,  1.76it/s]
Val Batch 10/10,Loss: 0.226677, RMSE: 0.396587: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.396587, Val Loss: 0.160859
Train RMSE: 0.400125, Average Train Loss: 0.160374, Val RMSE: 0.396587, Average Valid Loss: 0.160859

Epoch 56/100: 


Epoch 56, Batch 101/101,Loss: 0.102022, RMSE: 0.397662: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.304902, RMSE: 0.436116: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Val RMSE: 0.436116, Val Loss: 0.196112
Train RMSE: 0.397662, Average Train Loss: 0.157601, Val RMSE: 0.436116, Average Valid Loss: 0.196112

Epoch 57/100: 

Epoch 57, Batch 101/101,Loss: 0.080622, RMSE: 0.396163: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.195988, RMSE: 0.391765: 100%|██████████| 10/10 [00:04<00:00,  2.06it/s]

Val RMSE: 0.391765, Val Loss: 0.155672
Train RMSE: 0.396163, Average Train Loss: 0.156219, Val RMSE: 0.391765, Average Valid Loss: 0.155672

Epoch 58/100: 


Epoch 58, Batch 101/101,Loss: 0.208809, RMSE: 0.396675: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.176019, RMSE: 0.406527: 100%|██████████| 10/10 [00:03<00:00,  2.52it/s]

Val RMSE: 0.406527, Val Loss: 0.165819
Train RMSE: 0.396675, Average Train Loss: 0.157841, Val RMSE: 0.406527, Average Valid Loss: 0.165819

Epoch 59/100: 


Epoch 59, Batch 101/101,Loss: 0.191971, RMSE: 0.396191: 100%|██████████| 101/101 [00:55<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.219524, RMSE: 0.414962: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]

Val RMSE: 0.414962, Val Loss: 0.174634
Train RMSE: 0.396191, Average Train Loss: 0.157300, Val RMSE: 0.414962, Average Valid Loss: 0.174634

Epoch 60/100: 


Epoch 60, Batch 101/101,Loss: 0.124531, RMSE: 0.396778: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.284849, RMSE: 0.431199: 100%|██████████| 10/10 [00:04<00:00,  2.10it/s]

Val RMSE: 0.431199, Val Loss: 0.191033
Train RMSE: 0.396778, Average Train Loss: 0.157120, Val RMSE: 0.431199, Average Valid Loss: 0.191033

Epoch 61/100: 


Epoch 61, Batch 101/101,Loss: 0.097632, RMSE: 0.396647: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.206450, RMSE: 0.395542: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.395542, Val Loss: 0.159032
Train RMSE: 0.396647, Average Train Loss: 0.156761, Val RMSE: 0.395542, Average Valid Loss: 0.159032

Epoch 62/100: 


Epoch 62, Batch 101/101,Loss: 0.033207, RMSE: 0.394931: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.176978, RMSE: 0.391315: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.391315, Val Loss: 0.154357
Train RMSE: 0.394931, Average Train Loss: 0.154803, Val RMSE: 0.391315, Average Valid Loss: 0.154357

Epoch 63/100: 


Epoch 63, Batch 101/101,Loss: 0.274138, RMSE: 0.394670: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.251629, RMSE: 0.422771: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.422771, Val Loss: 0.182494
Train RMSE: 0.394670, Average Train Loss: 0.156891, Val RMSE: 0.422771, Average Valid Loss: 0.182494

Epoch 64/100: 


Epoch 64, Batch 101/101,Loss: 0.031601, RMSE: 0.393857: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.169198, RMSE: 0.389829: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.389829, Val Loss: 0.152855
Train RMSE: 0.393857, Average Train Loss: 0.153948, Val RMSE: 0.389829, Average Valid Loss: 0.152855

Epoch 65/100: 


Epoch 65, Batch 101/101,Loss: 0.061762, RMSE: 0.393026: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.271142, RMSE: 0.418636: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Val RMSE: 0.418636, Val Loss: 0.180201
Train RMSE: 0.393026, Average Train Loss: 0.153588, Val RMSE: 0.418636, Average Valid Loss: 0.180201

Epoch 66/100: 

Epoch 66, Batch 101/101,Loss: 0.380559, RMSE: 0.393028: 100%|██████████| 101/101 [00:57<00:00,  1.76it/s]
Val Batch 10/10,Loss: 0.195547, RMSE: 0.404703: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]

Val RMSE: 0.404703, Val Loss: 0.165423
Train RMSE: 0.393028, Average Train Loss: 0.156622, Val RMSE: 0.404703, Average Valid Loss: 0.165423

Epoch 67/100: 


Epoch 67, Batch 101/101,Loss: 0.056741, RMSE: 0.393112: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.168771, RMSE: 0.388891: 100%|██████████| 10/10 [00:03<00:00,  2.50it/s]


Val RMSE: 0.388891, Val Loss: 0.152140
Train RMSE: 0.393112, Average Train Loss: 0.153607, Val RMSE: 0.388891, Average Valid Loss: 0.152140

Epoch 68/100: 

Epoch 68, Batch 101/101,Loss: 0.087547, RMSE: 0.391942: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.273240, RMSE: 0.424678: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]


Val RMSE: 0.424678, Val Loss: 0.185141
Train RMSE: 0.391942, Average Train Loss: 0.152990, Val RMSE: 0.424678, Average Valid Loss: 0.185141

Epoch 69/100: 

Epoch 69, Batch 101/101,Loss: 0.454780, RMSE: 0.391972: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.231470, RMSE: 0.401447: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Val RMSE: 0.401447, Val Loss: 0.164785
Train RMSE: 0.391972, Average Train Loss: 0.156507, Val RMSE: 0.401447, Average Valid Loss: 0.164785

Epoch 70/100: 

Epoch 70, Batch 101/101,Loss: 0.149365, RMSE: 0.391313: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.267865, RMSE: 0.428347: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.428347, Val Loss: 0.187832
Train RMSE: 0.391313, Average Train Loss: 0.153090, Val RMSE: 0.428347, Average Valid Loss: 0.187832

Epoch 71/100: 


Epoch 71, Batch 101/101,Loss: 0.256177, RMSE: 0.392125: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.171135, RMSE: 0.387168: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.387168, Val Loss: 0.150994
Train RMSE: 0.392125, Average Train Loss: 0.154736, Val RMSE: 0.387168, Average Valid Loss: 0.150994

Epoch 72/100: 


Epoch 72, Batch 101/101,Loss: 0.272833, RMSE: 0.390611: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.162235, RMSE: 0.398511: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.398511, Val Loss: 0.158988
Train RMSE: 0.390611, Average Train Loss: 0.153721, Val RMSE: 0.398511, Average Valid Loss: 0.158988

Epoch 73/100: 


Epoch 73, Batch 101/101,Loss: 0.192054, RMSE: 0.392116: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.247733, RMSE: 0.410906: 100%|██████████| 10/10 [00:04<00:00,  2.46it/s]

Val RMSE: 0.410906, Val Loss: 0.172911
Train RMSE: 0.392116, Average Train Loss: 0.154120, Val RMSE: 0.410906, Average Valid Loss: 0.172911

Epoch 74/100: 


Epoch 74, Batch 101/101,Loss: 0.076775, RMSE: 0.389840: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.255465, RMSE: 0.417569: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]


Val RMSE: 0.417569, Val Loss: 0.178545
Train RMSE: 0.389840, Average Train Loss: 0.151260, Val RMSE: 0.417569, Average Valid Loss: 0.178545

Epoch 75/100: 

Epoch 75, Batch 101/101,Loss: 0.225343, RMSE: 0.390248: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.269468, RMSE: 0.419948: 100%|██████████| 10/10 [00:04<00:00,  2.08it/s]


Val RMSE: 0.419948, Val Loss: 0.181158
Train RMSE: 0.390248, Average Train Loss: 0.152988, Val RMSE: 0.419948, Average Valid Loss: 0.181158

Epoch 76/100: 

Epoch 76, Batch 101/101,Loss: 0.281701, RMSE: 0.389601: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.227541, RMSE: 0.410247: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.410247, Val Loss: 0.171357
Train RMSE: 0.389601, Average Train Loss: 0.153025, Val RMSE: 0.410247, Average Valid Loss: 0.171357

Epoch 77/100: 


Epoch 77, Batch 101/101,Loss: 0.125580, RMSE: 0.390788: 100%|██████████| 101/101 [00:58<00:00,  1.74it/s]
Val Batch 10/10,Loss: 0.278047, RMSE: 0.431856: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.431856, Val Loss: 0.191220
Train RMSE: 0.390788, Average Train Loss: 0.152457, Val RMSE: 0.431856, Average Valid Loss: 0.191220

Epoch 78/100: 


Epoch 78, Batch 101/101,Loss: 0.188754, RMSE: 0.389185: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.161752, RMSE: 0.402932: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]


Val RMSE: 0.402932, Val Loss: 0.162323
Train RMSE: 0.389185, Average Train Loss: 0.151820, Val RMSE: 0.402932, Average Valid Loss: 0.162323

Epoch 79/100: 

Epoch 79, Batch 101/101,Loss: 0.096406, RMSE: 0.390406: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.268146, RMSE: 0.426968: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]


Val RMSE: 0.426968, Val Loss: 0.186728
Train RMSE: 0.390406, Average Train Loss: 0.151884, Val RMSE: 0.426968, Average Valid Loss: 0.186728

Epoch 80/100: 

Epoch 80, Batch 101/101,Loss: 0.053354, RMSE: 0.388721: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.161950, RMSE: 0.411668: 100%|██████████| 10/10 [00:04<00:00,  2.43it/s]


Val RMSE: 0.411668, Val Loss: 0.169083
Train RMSE: 0.388721, Average Train Loss: 0.150174, Val RMSE: 0.411668, Average Valid Loss: 0.169083

Epoch 81/100: 

Epoch 81, Batch 101/101,Loss: 0.080630, RMSE: 0.388013: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.271506, RMSE: 0.422878: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]

Val RMSE: 0.422878, Val Loss: 0.183604
Train RMSE: 0.388013, Average Train Loss: 0.149889, Val RMSE: 0.422878, Average Valid Loss: 0.183604

Epoch 82/100: 


Epoch 82, Batch 101/101,Loss: 0.254798, RMSE: 0.388262: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.177430, RMSE: 0.417062: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]


Val RMSE: 0.417062, Val Loss: 0.174120
Train RMSE: 0.388262, Average Train Loss: 0.151738, Val RMSE: 0.417062, Average Valid Loss: 0.174120

Epoch 83/100: 

Epoch 83, Batch 101/101,Loss: 0.133728, RMSE: 0.389080: 100%|██████████| 101/101 [00:56<00:00,  1.78it/s]
Val Batch 10/10,Loss: 0.240091, RMSE: 0.426130: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]


Val RMSE: 0.426130, Val Loss: 0.184604
Train RMSE: 0.389080, Average Train Loss: 0.151215, Val RMSE: 0.426130, Average Valid Loss: 0.184604

Epoch 84/100: 

Epoch 84, Batch 101/101,Loss: 0.087896, RMSE: 0.387515: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.192607, RMSE: 0.408323: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]


Val RMSE: 0.408323, Val Loss: 0.168062
Train RMSE: 0.387515, Average Train Loss: 0.149576, Val RMSE: 0.408323, Average Valid Loss: 0.168062

Epoch 85/100: 

Epoch 85, Batch 101/101,Loss: 0.101545, RMSE: 0.388821: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.258313, RMSE: 0.424682: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]


Val RMSE: 0.424682, Val Loss: 0.184375
Train RMSE: 0.388821, Average Train Loss: 0.150709, Val RMSE: 0.424682, Average Valid Loss: 0.184375

Epoch 86/100: 

Epoch 86, Batch 101/101,Loss: 0.095700, RMSE: 0.387650: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.175613, RMSE: 0.393683: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.393683, Val Loss: 0.156050
Train RMSE: 0.387650, Average Train Loss: 0.149753, Val RMSE: 0.393683, Average Valid Loss: 0.156050

Epoch 87/100: 


Epoch 87, Batch 101/101,Loss: 0.037362, RMSE: 0.389073: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.216987, RMSE: 0.407151: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]

Val RMSE: 0.407151, Val Loss: 0.168413
Train RMSE: 0.389073, Average Train Loss: 0.150293, Val RMSE: 0.407151, Average Valid Loss: 0.168413

Epoch 88/100: 


Epoch 88, Batch 101/101,Loss: 0.140477, RMSE: 0.388243: 100%|██████████| 101/101 [00:57<00:00,  1.76it/s]
Val Batch 10/10,Loss: 0.175577, RMSE: 0.397708: 100%|██████████| 10/10 [00:04<00:00,  2.50it/s]

Val RMSE: 0.397708, Val Loss: 0.159069
Train RMSE: 0.388243, Average Train Loss: 0.150635, Val RMSE: 0.397708, Average Valid Loss: 0.159069

Epoch 89/100: 


Epoch 89, Batch 101/101,Loss: 0.411254, RMSE: 0.387650: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.202800, RMSE: 0.400675: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]


Val RMSE: 0.400675, Val Loss: 0.162719
Train RMSE: 0.387650, Average Train Loss: 0.152755, Val RMSE: 0.400675, Average Valid Loss: 0.162719

Epoch 90/100: 

Epoch 90, Batch 101/101,Loss: 0.113755, RMSE: 0.387115: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.177232, RMSE: 0.398692: 100%|██████████| 10/10 [00:04<00:00,  2.09it/s]

Val RMSE: 0.398692, Val Loss: 0.159898
Train RMSE: 0.387115, Average Train Loss: 0.149514, Val RMSE: 0.398692, Average Valid Loss: 0.159898

Epoch 91/100: 


Epoch 91, Batch 101/101,Loss: 0.033758, RMSE: 0.387164: 100%|██████████| 101/101 [00:55<00:00,  1.83it/s]
Val Batch 10/10,Loss: 0.171922, RMSE: 0.394162: 100%|██████████| 10/10 [00:03<00:00,  2.50it/s]

Val RMSE: 0.394162, Val Loss: 0.156218
Train RMSE: 0.387164, Average Train Loss: 0.148791, Val RMSE: 0.394162, Average Valid Loss: 0.156218

Epoch 92/100: 


Epoch 92, Batch 101/101,Loss: 0.157000, RMSE: 0.387791: 100%|██████████| 101/101 [00:56<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.177936, RMSE: 0.421978: 100%|██████████| 10/10 [00:04<00:00,  2.34it/s]

Val RMSE: 0.421978, Val Loss: 0.178059
Train RMSE: 0.387791, Average Train Loss: 0.150444, Val RMSE: 0.421978, Average Valid Loss: 0.178059

Epoch 93/100: 


Epoch 93, Batch 101/101,Loss: 0.160111, RMSE: 0.387354: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.184518, RMSE: 0.406839: 100%|██████████| 10/10 [00:05<00:00,  1.74it/s]

Val RMSE: 0.406839, Val Loss: 0.166497
Train RMSE: 0.387354, Average Train Loss: 0.150139, Val RMSE: 0.406839, Average Valid Loss: 0.166497

Epoch 94/100: 


Epoch 94, Batch 101/101,Loss: 0.118841, RMSE: 0.386198: 100%|██████████| 101/101 [00:55<00:00,  1.80it/s]
Val Batch 10/10,Loss: 0.204185, RMSE: 0.408579: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]

Val RMSE: 0.408579, Val Loss: 0.168857
Train RMSE: 0.386198, Average Train Loss: 0.148861, Val RMSE: 0.408579, Average Valid Loss: 0.168857

Epoch 95/100: 


Epoch 95, Batch 101/101,Loss: 0.029222, RMSE: 0.387414: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.218749, RMSE: 0.416401: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.416401, Val Loss: 0.175729
Train RMSE: 0.387414, Average Train Loss: 0.148940, Val RMSE: 0.416401, Average Valid Loss: 0.175729

Epoch 96/100: 


Epoch 96, Batch 101/101,Loss: 0.105903, RMSE: 0.386352: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.268720, RMSE: 0.427651: 100%|██████████| 10/10 [00:04<00:00,  2.05it/s]

Val RMSE: 0.427651, Val Loss: 0.187311
Train RMSE: 0.386352, Average Train Loss: 0.148855, Val RMSE: 0.427651, Average Valid Loss: 0.187311

Epoch 97/100: 


Epoch 97, Batch 101/101,Loss: 0.089364, RMSE: 0.387741: 100%|██████████| 101/101 [00:55<00:00,  1.81it/s]
Val Batch 10/10,Loss: 0.254157, RMSE: 0.419233: 100%|██████████| 10/10 [00:04<00:00,  2.48it/s]

Val RMSE: 0.419233, Val Loss: 0.179799
Train RMSE: 0.387741, Average Train Loss: 0.149763, Val RMSE: 0.419233, Average Valid Loss: 0.179799

Epoch 98/100: 


Epoch 98, Batch 101/101,Loss: 0.367050, RMSE: 0.387167: 100%|██████████| 101/101 [00:56<00:00,  1.79it/s]
Val Batch 10/10,Loss: 0.178964, RMSE: 0.402174: 100%|██████████| 10/10 [00:04<00:00,  2.45it/s]

Val RMSE: 0.402174, Val Loss: 0.162632
Train RMSE: 0.387167, Average Train Loss: 0.151965, Val RMSE: 0.402174, Average Valid Loss: 0.162632

Epoch 99/100: 


Epoch 99, Batch 101/101,Loss: 0.082547, RMSE: 0.386822: 100%|██████████| 101/101 [00:57<00:00,  1.76it/s]
Val Batch 10/10,Loss: 0.174456, RMSE: 0.413538: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]

Val RMSE: 0.413538, Val Loss: 0.171191
Train RMSE: 0.386822, Average Train Loss: 0.148993, Val RMSE: 0.413538, Average Valid Loss: 0.171191

Epoch 100/100: 


Epoch 100, Batch 101/101,Loss: 0.427852, RMSE: 0.386252: 100%|██████████| 101/101 [00:55<00:00,  1.82it/s]
Val Batch 10/10,Loss: 0.179654, RMSE: 0.390877: 100%|██████████| 10/10 [00:04<00:00,  2.42it/s]

Val RMSE: 0.390877, Val Loss: 0.154171
Train RMSE: 0.386252, Average Train Loss: 0.151842, Val RMSE: 0.390877, Average Valid Loss: 0.154171






In [15]:
torch.save(model.state_dict(), 'model_weights_challenge_1.pt')

In [17]:
print(model.state_dict())

OrderedDict({'block_1.1.weight': tensor([[[[-7.5011e-02,  1.0098e-01, -1.9283e-02,  1.3473e-01,  1.3163e-01,
            1.3963e-01,  1.4651e-01,  1.0248e-02, -7.8975e-02, -4.8278e-02,
            1.4744e-02,  7.3645e-02, -2.3366e-02,  3.4903e-02,  7.8428e-02,
           -1.3044e-02, -4.8524e-02,  3.8590e-02,  1.8508e-02, -7.1009e-03,
           -1.0633e-01,  5.7422e-02, -3.1136e-02, -5.1034e-02,  4.0973e-02,
           -5.9263e-02,  5.1305e-02, -5.2801e-02,  4.8022e-02,  3.3509e-02,
            4.7258e-02, -5.4056e-02,  2.7368e-02, -1.9561e-02,  8.9956e-02,
            9.6563e-02, -2.4960e-02,  8.8679e-02, -4.7450e-04, -6.8953e-02,
            6.8167e-02,  4.8882e-02,  9.2060e-02, -2.1475e-02,  7.8569e-02,
           -1.0743e-01, -9.2860e-02,  3.9631e-03, -3.3669e-02, -2.3431e-02,
           -1.0935e-01, -5.2011e-02,  2.1382e-02, -5.9367e-03,  7.0558e-02,
           -2.0812e-02,  1.2929e-01,  1.3423e-01,  1.6437e-01,  1.5494e-01,
            9.1586e-02, -2.8003e-02, -2.3535e-02,  1.42