<a href="https://colab.research.google.com/github/nhnain/eegchallenge/blob/main/Challenge1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#1. Preprocessing

In [1]:
!pip install braindecode
!pip install eegdash

Collecting braindecode
  Downloading braindecode-1.2.0-py3-none-any.whl.metadata (7.1 kB)
Collecting mne>=1.10.0 (from braindecode)
  Downloading mne-1.10.1-py3-none-any.whl.metadata (20 kB)
Collecting mne_bids>=0.16 (from braindecode)
  Downloading mne_bids-0.17.0-py3-none-any.whl.metadata (7.3 kB)
Collecting skorch>=1.2.0 (from braindecode)
  Downloading skorch-1.2.0-py3-none-any.whl.metadata (11 kB)
Collecting torchinfo~=1.8 (from braindecode)
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting wfdb (from braindecode)
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting linear_attention_transformer (from braindecode)
  Downloading linear_attention_transformer-0.19.1-py3-none-any.whl.metadata (787 bytes)
Collecting docstring_inheritance (from braindecode)
  Downloading docstring_inheritance-2.2.2-py3-none-any.whl.metadata (11 kB)
Collecting axial-positional-embedding (from linear_attention_transformer->braindecode)
  Downloading axial_position

In [8]:
from pathlib import Path
import math
import os
import random
from joblib import Parallel, delayed

import torch
from torch.utils.data import DataLoader
from torch import optim
from torch.nn.functional import l1_loss
from braindecode.preprocessing import preprocess, Preprocessor, create_fixed_length_windows, create_windows_from_events
from braindecode.datasets.base import EEGWindowsDataset, BaseConcatDataset, BaseDataset
from braindecode.models import EEGNeX
from eegdash import EEGChallengeDataset
from typing import Optional
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm
from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import annotate_trials_with_target, add_aux_anchors, add_extras_columns, keep_only_recordings_with

In [9]:
# Identify whether a CUDA-enabled GPU is available
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg ='CUDA-enabled GPU found. Training should be faster.'
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting \'T4 GPU\'\nunder \'Hardware accelerator\'."
    )
print(msg)

CUDA-enabled GPU found. Training should be faster.


In [10]:
release_list = ['R1', 'R2', 'R3', 'R4', 'R5', 'R6',
                'R7', 'R8', 'R9', 'R10', 'R11']
train_set = []
valid_set = []
test_set = []

DATA_DIR = Path('data')
DATA_DIR.mkdir(parents=True, exist_ok=True)


In [12]:
for release in tqdm(release_list):
  dataset_ccd = EEGChallengeDataset(
      task = 'contrastChangeDetection',
      release = release,
      cache_dir = DATA_DIR,
      mini = True
  )
  raws = Parallel(n_jobs=-1)(
      delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets
  )

  EPOCH_LEN_S = 2.0
  SFREQ = 100

  transformation_offline = [
      Preprocessor(
          annotate_trials_with_target,
          target_field = 'rt_from_stimulus', epoch_length =  EPOCH_LEN_S,
          require_stimulus = True,
          require_response = True,
          apply_on_array = False,
      ),
      Preprocessor(add_aux_anchors, apply_on_array=False),
  ]
  preprocess (dataset_ccd, transformation_offline, n_jobs=1)

  ANCHOR = 'stimulus_anchor'
  SHIFT_AFTER_STIM = 0.5
  WINDOW_LEN = 2.0

  dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

  single_windows = create_windows_from_events(
      dataset,
      mapping={ANCHOR: 0},
      trial_start_offset_samples = int(SHIFT_AFTER_STIM * SFREQ),
      trial_stop_offset_samples = int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),
      window_size_samples = int(EPOCH_LEN_S * SFREQ),
      window_stride_samples = SFREQ,
      preload = True,
  )

  single_windows = add_extras_columns(
      single_windows,
      dataset,
      desc = ANCHOR, keys = ("target", "rt_from_stimulus", "rt_from_trialstart",
            "stimulus_onset", "response_onset", "correct", "response_type"))

  meta_information = single_windows.get_metadata()

  subjects = meta_information['subject'].unique()
  sub_rm = ["NDARWV769JM7", "NDARME789TD2", "NDARUA442ZVF", "NDARJP304NK1",
            "NDARTY128YLU", "NDARDW550GU6", "NDARLD243KRE", "NDARUJ292JXV", "NDARBA381JGH"]
  subjects = [s for s in subjects if s not in sub_rm]

  subject_split = single_windows.split('subject')

  if release != 'R5':
    for s in subject_split:
      train_set.append(subject_split[s])
  else:
    for s in subject_split:
      valid_set.append(subject_split[s])


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(

100%|██████████| 11/11 [34:50<00:00, 190.07s/it]


In [13]:
train_set = BaseConcatDataset(train_set)
valid_set = BaseConcatDataset(valid_set)

print(f'Train set: {len(train_set)}')
print(f'Valid set: {len(valid_set)}')

Train set: 12805
Valid set: 1214


#Build model

In [21]:
batch_size = 128
num_workers = 1

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [22]:
from braindecode.models.util import models_dict

names = sorted(models_dict)
w = max(len(n) for n in names)

for i in range(0, len(names), 3):
    row = names[i:i+3]
    print("  ".join(f"{n:<{w}}" for n in row))

ATCNet                  AttentionBaseNet        AttnSleep             
BDTCN                   BIOT                    CTNet                 
ContraWR                Deep4Net                DeepSleepNet          
EEGConformer            EEGITNet                EEGInceptionERP       
EEGInceptionMI          EEGMiner                EEGNeX                
EEGNet                  EEGSimpleConv           EEGTCNet              
FBCNet                  FBLightConvNet          FBMSNet               
IFNet                   Labram                  MSVTNet               
SCCNet                  SPARCNet                ShallowFBCSPNet       
SignalJEPA              SignalJEPA_Contextual   SignalJEPA_PostLocal  
SignalJEPA_PreLocal     SincShallowNet          SleepStagerBlanco2020 
SleepStagerChambon2018  SyncNet                 TIDNet                
TSception               USleep                


In [24]:
model = EEGNeX(n_chans=129,
               n_outputs=1,
               n_times=200,
               sfreq=100)

#Model training

In [25]:
lr = 1e-3
weight_decay = 1e-5
n_epochs = 100
early_stopping_patience = 50

In [34]:
def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,):
  model.train()

  total_loss = 0.0
  sum_sq_err = 0.0
  n_samples = 0

  progress_bar = tqdm(
      enumerate(dataloader),
      total=len(dataloader),
      disable=not print_batch_stats)
  for batch_idx, batch in progress_bar:
    X, y = batch[0], batch[1]
    X, y = X.to(device).float(), y.to(device).float()

    optimizer.zero_grad(set_to_none=True)
    preds = model(X)
    loss = loss_fn(preds, y)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    preds_flat = preds.detach().view(-1)
    y_flat = y.detach().view(-1)
    sum_sq_err += torch.sum((preds_flat - y_flat)**2).item()
    n_samples += y_flat.numel()

    if print_batch_stats:
      running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
      progress_bar.set_description(
          f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)},"
          f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}"
      )
  if scheduler is not None:
    scheduler.step()

  avg_loss = total_loss/len(dataloader)
  rmse = (sum_sq_err/max(n_samples, 1))** 0.5
  return avg_loss, rmse

In [35]:
with torch.inference_mode():
  def valid_model(
      dataloader: DataLoader,
      model: Module,
      loss_fn,
      device,
      print_batch_stats: bool = True):

    model.eval()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_batches = len(dataloader)
    n_samples = 0.0

    iterator = tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        disable=not print_batch_stats)

    for batch_idx, batch in iterator:
      X, y = batch[0], batch[1]
      X, y = X.to(device).float(), y.to(device).float()

      preds = model(X)
      batch_loss = loss_fn(preds, y).item()
      total_loss += batch_loss

      preds_flat = preds.detach().view(-1)
      y_flat = y.detach().view(-1)
      sum_sq_err += torch.sum((preds_flat - y_flat)**2).item()
      n_samples += y_flat.numel()

      if print_batch_stats:
        running_rmse = (sum_sq_err/max(n_samples,1)) ** 0.5
        iterator.set_description(
            f"Val Batch {batch_idx + 1}/{len(dataloader)},"
            f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
        )

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5

    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}")
    return avg_loss, rmse

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = n_epochs -1)
loss_fn = torch.nn.MSELoss()

patience = 5
min_delta = 1e-4
best_rmse = float('inf')
epochs_no_improve = 0
best_state, best_epoch = None, None

model.to(device)
for epoch in range(1, n_epochs+1):
  print(f'Epoch {epoch}/{n_epochs}: ', end='')

  train_loss, train_rmse = train_one_epoch(
      train_loader, model, loss_fn, optimizer, scheduler, epoch, device
  )
  valid_loss, valid_rmse = valid_model(
      valid_loader, model, loss_fn, device
  )
  print(
      f'Train RMSE: {train_rmse:.6f}, '
      f'Average Train Loss: {train_loss:.6f}, '
      f'Val RMSE: {valid_rmse:.6f}, '
      f'Average Valid Loss: {valid_loss:.6f}\n'
  )
if best_state is not None:
  model.load_state_dict(best_state)

Epoch 1/100: 

Epoch 1, Batch 101/101,Loss: 0.184482, RMSE: 0.455785: 100%|██████████| 101/101 [00:49<00:00,  2.02it/s]
Val Batch 10/10,Loss: 0.267905, RMSE: 0.405653: 100%|██████████| 10/10 [00:03<00:00,  2.98it/s]

Val RMSE: 0.405653, Val Loss: 0.169883

Train RMSE: 0.455785, Average Train Loss: 0.207519, Val RMSE: 0.405653, Average Valid Loss: 0.169883

Epoch 2/100: 


Epoch 2, Batch 101/101,Loss: 0.158314, RMSE: 0.448486: 100%|██████████| 101/101 [00:49<00:00,  2.03it/s]
Val Batch 10/10,Loss: 0.268793, RMSE: 0.413763: 100%|██████████| 10/10 [00:03<00:00,  2.96it/s]

Val RMSE: 0.413763, Val Loss: 0.176232

Train RMSE: 0.448486, Average Train Loss: 0.200732, Val RMSE: 0.413763, Average Valid Loss: 0.176232

Epoch 3/100: 


Epoch 3, Batch 101/101,Loss: 0.268585, RMSE: 0.445947: 100%|██████████| 101/101 [00:49<00:00,  2.02it/s]
Val Batch 10/10,Loss: 0.247082, RMSE: 0.394169: 100%|██████████| 10/10 [00:03<00:00,  2.52it/s]

Val RMSE: 0.394169, Val Loss: 0.160098

Train RMSE: 0.445947, Average Train Loss: 0.199532, Val RMSE: 0.394169, Average Valid Loss: 0.160098

Epoch 4/100: 


Epoch 4, Batch 101/101,Loss: 0.203967, RMSE: 0.440531: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.283080, RMSE: 0.416578: 100%|██████████| 10/10 [00:04<00:00,  2.39it/s]

Val RMSE: 0.416578, Val Loss: 0.179185

Train RMSE: 0.440531, Average Train Loss: 0.194162, Val RMSE: 0.416578, Average Valid Loss: 0.179185

Epoch 5/100: 


Epoch 5, Batch 101/101,Loss: 0.096945, RMSE: 0.439175: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.327549, RMSE: 0.445448: 100%|██████████| 10/10 [00:03<00:00,  3.00it/s]

Val RMSE: 0.445448, Val Loss: 0.205082

Train RMSE: 0.439175, Average Train Loss: 0.191962, Val RMSE: 0.445448, Average Valid Loss: 0.205082

Epoch 6/100: 


Epoch 6, Batch 101/101,Loss: 0.096096, RMSE: 0.432810: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.173509, RMSE: 0.410534: 100%|██████████| 10/10 [00:03<00:00,  3.02it/s]

Val RMSE: 0.410534, Val Loss: 0.168794

Train RMSE: 0.432810, Average Train Loss: 0.186457, Val RMSE: 0.410534, Average Valid Loss: 0.168794

Epoch 7/100: 


Epoch 7, Batch 101/101,Loss: 0.308443, RMSE: 0.434887: 100%|██████████| 101/101 [00:49<00:00,  2.03it/s]
Val Batch 10/10,Loss: 0.311770, RMSE: 0.439666: 100%|██████████| 10/10 [00:03<00:00,  2.98it/s]

Val RMSE: 0.439666, Val Loss: 0.199415

Train RMSE: 0.434887, Average Train Loss: 0.190262, Val RMSE: 0.439666, Average Valid Loss: 0.199415

Epoch 8/100: 


Epoch 8, Batch 101/101,Loss: 0.101653, RMSE: 0.431008: 100%|██████████| 101/101 [00:49<00:00,  2.03it/s]
Val Batch 10/10,Loss: 0.238786, RMSE: 0.416638: 100%|██████████| 10/10 [00:03<00:00,  2.95it/s]

Val RMSE: 0.416638, Val Loss: 0.176949

Train RMSE: 0.431008, Average Train Loss: 0.184967, Val RMSE: 0.416638, Average Valid Loss: 0.176949

Epoch 9/100: 


Epoch 9, Batch 101/101,Loss: 0.174943, RMSE: 0.430016: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.322919, RMSE: 0.437641: 100%|██████████| 10/10 [00:04<00:00,  2.43it/s]

Val RMSE: 0.437641, Val Loss: 0.198305

Train RMSE: 0.430016, Average Train Loss: 0.184819, Val RMSE: 0.437641, Average Valid Loss: 0.198305

Epoch 10/100: 


Epoch 10, Batch 101/101,Loss: 0.005644, RMSE: 0.425000: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.263331, RMSE: 0.418479: 100%|██████████| 10/10 [00:04<00:00,  2.38it/s]

Val RMSE: 0.418479, Val Loss: 0.179673

Train RMSE: 0.425000, Average Train Loss: 0.178960, Val RMSE: 0.418479, Average Valid Loss: 0.179673

Epoch 11/100: 


Epoch 11, Batch 101/101,Loss: 0.298239, RMSE: 0.426737: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.335156, RMSE: 0.451187: 100%|██████████| 10/10 [00:03<00:00,  3.01it/s]

Val RMSE: 0.451187, Val Loss: 0.210355

Train RMSE: 0.426737, Average Train Loss: 0.183209, Val RMSE: 0.451187, Average Valid Loss: 0.210355

Epoch 12/100: 


Epoch 12, Batch 101/101,Loss: 0.186346, RMSE: 0.423823: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.183383, RMSE: 0.393433: 100%|██████████| 10/10 [00:03<00:00,  2.93it/s]

Val RMSE: 0.393433, Val Loss: 0.156264

Train RMSE: 0.423823, Average Train Loss: 0.179690, Val RMSE: 0.393433, Average Valid Loss: 0.156264

Epoch 13/100: 


Epoch 13, Batch 101/101,Loss: 0.726453, RMSE: 0.422737: 100%|██████████| 101/101 [00:49<00:00,  2.03it/s]
Val Batch 10/10,Loss: 0.224362, RMSE: 0.418283: 100%|██████████| 10/10 [00:03<00:00,  2.95it/s]

Val RMSE: 0.418283, Val Loss: 0.177508

Train RMSE: 0.422737, Average Train Loss: 0.183918, Val RMSE: 0.418283, Average Valid Loss: 0.177508

Epoch 14/100: 


Epoch 14, Batch 101/101,Loss: 0.124685, RMSE: 0.420769: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.332202, RMSE: 0.446549: 100%|██████████| 10/10 [00:03<00:00,  3.00it/s]

Val RMSE: 0.446549, Val Loss: 0.206254

Train RMSE: 0.420769, Average Train Loss: 0.176548, Val RMSE: 0.446549, Average Valid Loss: 0.206254

Epoch 15/100: 


Epoch 15, Batch 101/101,Loss: 0.590182, RMSE: 0.419737: 100%|██████████| 101/101 [00:49<00:00,  2.03it/s]
Val Batch 10/10,Loss: 0.271161, RMSE: 0.416932: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s]

Val RMSE: 0.416932, Val Loss: 0.178850

Train RMSE: 0.419737, Average Train Loss: 0.180118, Val RMSE: 0.416932, Average Valid Loss: 0.178850

Epoch 16/100: 


Epoch 16, Batch 101/101,Loss: 0.103534, RMSE: 0.417469: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.230421, RMSE: 0.434488: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.434488, Val Loss: 0.190927

Train RMSE: 0.417469, Average Train Loss: 0.173607, Val RMSE: 0.434488, Average Valid Loss: 0.190927

Epoch 17/100: 


Epoch 17, Batch 101/101,Loss: 0.367536, RMSE: 0.415456: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.333714, RMSE: 0.453358: 100%|██████████| 10/10 [00:03<00:00,  2.91it/s]

Val RMSE: 0.453358, Val Loss: 0.212143

Train RMSE: 0.415456, Average Train Loss: 0.174459, Val RMSE: 0.453358, Average Valid Loss: 0.212143

Epoch 18/100: 


Epoch 18, Batch 101/101,Loss: 0.196691, RMSE: 0.414274: 100%|██████████| 101/101 [00:49<00:00,  2.03it/s]
Val Batch 10/10,Loss: 0.217745, RMSE: 0.420542: 100%|██████████| 10/10 [00:03<00:00,  2.96it/s]

Val RMSE: 0.420542, Val Loss: 0.178964

Train RMSE: 0.414274, Average Train Loss: 0.171862, Val RMSE: 0.420542, Average Valid Loss: 0.178964

Epoch 19/100: 


Epoch 19, Batch 101/101,Loss: 0.254773, RMSE: 0.413639: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.308400, RMSE: 0.432493: 100%|██████████| 10/10 [00:03<00:00,  3.01it/s]

Val RMSE: 0.432493, Val Loss: 0.193307

Train RMSE: 0.413639, Average Train Loss: 0.171893, Val RMSE: 0.432493, Average Valid Loss: 0.193307

Epoch 20/100: 


Epoch 20, Batch 101/101,Loss: 0.091229, RMSE: 0.413669: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.256197, RMSE: 0.414064: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.414064, Val Loss: 0.175819

Train RMSE: 0.413669, Average Train Loss: 0.170362, Val RMSE: 0.414064, Average Valid Loss: 0.175819

Epoch 21/100: 


Epoch 21, Batch 101/101,Loss: 0.094020, RMSE: 0.413666: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.269285, RMSE: 0.412989: 100%|██████████| 10/10 [00:04<00:00,  2.44it/s]

Val RMSE: 0.412989, Val Loss: 0.175650

Train RMSE: 0.413666, Average Train Loss: 0.170386, Val RMSE: 0.412989, Average Valid Loss: 0.175650

Epoch 22/100: 


Epoch 22, Batch 101/101,Loss: 0.060953, RMSE: 0.408646: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.244124, RMSE: 0.410281: 100%|██████████| 10/10 [00:03<00:00,  3.01it/s]

Val RMSE: 0.410281, Val Loss: 0.172238

Train RMSE: 0.408646, Average Train Loss: 0.165982, Val RMSE: 0.410281, Average Valid Loss: 0.172238

Epoch 23/100: 


Epoch 23, Batch 101/101,Loss: 0.046396, RMSE: 0.407486: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.271846, RMSE: 0.415987: 100%|██████████| 10/10 [00:03<00:00,  3.01it/s]

Val RMSE: 0.415987, Val Loss: 0.178140

Train RMSE: 0.407486, Average Train Loss: 0.164907, Val RMSE: 0.415987, Average Valid Loss: 0.178140

Epoch 24/100: 


Epoch 24, Batch 101/101,Loss: 0.134998, RMSE: 0.407936: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.270400, RMSE: 0.428374: 100%|██████████| 10/10 [00:03<00:00,  2.99it/s]

Val RMSE: 0.428374, Val Loss: 0.187985

Train RMSE: 0.407936, Average Train Loss: 0.166113, Val RMSE: 0.428374, Average Valid Loss: 0.187985

Epoch 25/100: 


Epoch 25, Batch 101/101,Loss: 0.025919, RMSE: 0.409091: 100%|██████████| 101/101 [00:48<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.178835, RMSE: 0.413517: 100%|██████████| 10/10 [00:03<00:00,  3.08it/s]

Val RMSE: 0.413517, Val Loss: 0.171401

Train RMSE: 0.409091, Average Train Loss: 0.166010, Val RMSE: 0.413517, Average Valid Loss: 0.171401

Epoch 26/100: 


Epoch 26, Batch 101/101,Loss: 0.211720, RMSE: 0.405718: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.282438, RMSE: 0.422061: 100%|██████████| 10/10 [00:03<00:00,  2.78it/s]

Val RMSE: 0.422061, Val Loss: 0.183513

Train RMSE: 0.405718, Average Train Loss: 0.165055, Val RMSE: 0.422061, Average Valid Loss: 0.183513

Epoch 27/100: 


Epoch 27, Batch 101/101,Loss: 0.138450, RMSE: 0.404643: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.159856, RMSE: 0.423167: 100%|██████████| 10/10 [00:04<00:00,  2.43it/s]

Val RMSE: 0.423167, Val Loss: 0.178079

Train RMSE: 0.404643, Average Train Loss: 0.163496, Val RMSE: 0.423167, Average Valid Loss: 0.178079

Epoch 28/100: 


Epoch 28, Batch 101/101,Loss: 0.071854, RMSE: 0.403819: 100%|██████████| 101/101 [00:48<00:00,  2.10it/s]
Val Batch 10/10,Loss: 0.276083, RMSE: 0.415649: 100%|██████████| 10/10 [00:03<00:00,  2.60it/s]

Val RMSE: 0.415649, Val Loss: 0.178092

Train RMSE: 0.403819, Average Train Loss: 0.162202, Val RMSE: 0.415649, Average Valid Loss: 0.178092

Epoch 29/100: 


Epoch 29, Batch 101/101,Loss: 0.169458, RMSE: 0.402997: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.195000, RMSE: 0.426086: 100%|██████████| 10/10 [00:03<00:00,  3.01it/s]

Val RMSE: 0.426086, Val Loss: 0.182243

Train RMSE: 0.402997, Average Train Loss: 0.162474, Val RMSE: 0.426086, Average Valid Loss: 0.182243

Epoch 30/100: 


Epoch 30, Batch 101/101,Loss: 0.090492, RMSE: 0.402474: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.228581, RMSE: 0.405917: 100%|██████████| 10/10 [00:03<00:00,  3.04it/s]

Val RMSE: 0.405917, Val Loss: 0.168059

Train RMSE: 0.402474, Average Train Loss: 0.161305, Val RMSE: 0.405917, Average Valid Loss: 0.168059

Epoch 31/100: 


Epoch 31, Batch 101/101,Loss: 0.074509, RMSE: 0.400151: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.152430, RMSE: 0.399589: 100%|██████████| 10/10 [00:03<00:00,  3.03it/s]

Val RMSE: 0.399589, Val Loss: 0.159298

Train RMSE: 0.400151, Average Train Loss: 0.159306, Val RMSE: 0.399589, Average Valid Loss: 0.159298

Epoch 32/100: 


Epoch 32, Batch 101/101,Loss: 0.085964, RMSE: 0.400535: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.165261, RMSE: 0.399607: 100%|██████████| 10/10 [00:03<00:00,  3.08it/s]

Val RMSE: 0.399607, Val Loss: 0.159973

Train RMSE: 0.400535, Average Train Loss: 0.159720, Val RMSE: 0.399607, Average Valid Loss: 0.159973

Epoch 33/100: 


Epoch 33, Batch 101/101,Loss: 0.475071, RMSE: 0.400609: 100%|██████████| 101/101 [00:48<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.263850, RMSE: 0.417393: 100%|██████████| 10/10 [00:03<00:00,  3.05it/s]

Val RMSE: 0.417393, Val Loss: 0.178838

Train RMSE: 0.400609, Average Train Loss: 0.163481, Val RMSE: 0.417393, Average Valid Loss: 0.178838

Epoch 34/100: 


Epoch 34, Batch 101/101,Loss: 0.712397, RMSE: 0.400117: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.253400, RMSE: 0.416147: 100%|██████████| 10/10 [00:04<00:00,  2.39it/s]

Val RMSE: 0.416147, Val Loss: 0.177315

Train RMSE: 0.400117, Average Train Loss: 0.165349, Val RMSE: 0.416147, Average Valid Loss: 0.177315

Epoch 35/100: 


Epoch 35, Batch 101/101,Loss: 0.101023, RMSE: 0.400386: 100%|██████████| 101/101 [00:49<00:00,  2.02it/s]
Val Batch 10/10,Loss: 0.222864, RMSE: 0.403369: 100%|██████████| 10/10 [00:03<00:00,  2.64it/s]

Val RMSE: 0.403369, Val Loss: 0.165808

Train RMSE: 0.400386, Average Train Loss: 0.159745, Val RMSE: 0.403369, Average Valid Loss: 0.165808

Epoch 36/100: 


Epoch 36, Batch 101/101,Loss: 0.050287, RMSE: 0.396174: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.249914, RMSE: 0.404315: 100%|██████████| 10/10 [00:03<00:00,  3.00it/s]

Val RMSE: 0.404315, Val Loss: 0.167928

Train RMSE: 0.396174, Average Train Loss: 0.155939, Val RMSE: 0.404315, Average Valid Loss: 0.167928

Epoch 37/100: 


Epoch 37, Batch 101/101,Loss: 0.152162, RMSE: 0.397608: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.249674, RMSE: 0.406021: 100%|██████████| 10/10 [00:03<00:00,  2.99it/s]

Val RMSE: 0.406021, Val Loss: 0.169227

Train RMSE: 0.397608, Average Train Loss: 0.158036, Val RMSE: 0.406021, Average Valid Loss: 0.169227

Epoch 38/100: 


Epoch 38, Batch 101/101,Loss: 0.089539, RMSE: 0.395505: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.185420, RMSE: 0.406679: 100%|██████████| 10/10 [00:03<00:00,  3.08it/s]

Val RMSE: 0.406679, Val Loss: 0.166421

Train RMSE: 0.395505, Average Train Loss: 0.155788, Val RMSE: 0.406679, Average Valid Loss: 0.166421

Epoch 39/100: 


Epoch 39, Batch 101/101,Loss: 0.060698, RMSE: 0.394682: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.182011, RMSE: 0.438348: 100%|██████████| 10/10 [00:03<00:00,  3.04it/s]

Val RMSE: 0.438348, Val Loss: 0.191626

Train RMSE: 0.394682, Average Train Loss: 0.154869, Val RMSE: 0.438348, Average Valid Loss: 0.191626

Epoch 40/100: 


Epoch 40, Batch 101/101,Loss: 0.494765, RMSE: 0.395331: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.174618, RMSE: 0.407994: 100%|██████████| 10/10 [00:03<00:00,  2.78it/s]

Val RMSE: 0.407994, Val Loss: 0.166880

Train RMSE: 0.395331, Average Train Loss: 0.159507, Val RMSE: 0.407994, Average Valid Loss: 0.166880

Epoch 41/100: 


Epoch 41, Batch 101/101,Loss: 0.212578, RMSE: 0.394299: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.183645, RMSE: 0.415621: 100%|██████████| 10/10 [00:04<00:00,  2.36it/s]

Val RMSE: 0.415621, Val Loss: 0.173303

Train RMSE: 0.394299, Average Train Loss: 0.156015, Val RMSE: 0.415621, Average Valid Loss: 0.173303

Epoch 42/100: 


Epoch 42, Batch 101/101,Loss: 0.048815, RMSE: 0.393788: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.219703, RMSE: 0.430442: 100%|██████████| 10/10 [00:03<00:00,  2.92it/s]

Val RMSE: 0.430442, Val Loss: 0.187055

Train RMSE: 0.393788, Average Train Loss: 0.154058, Val RMSE: 0.430442, Average Valid Loss: 0.187055

Epoch 43/100: 


Epoch 43, Batch 101/101,Loss: 0.155854, RMSE: 0.393819: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.167734, RMSE: 0.419380: 100%|██████████| 10/10 [00:03<00:00,  3.06it/s]

Val RMSE: 0.419380, Val Loss: 0.175459

Train RMSE: 0.393819, Average Train Loss: 0.155101, Val RMSE: 0.419380, Average Valid Loss: 0.175459

Epoch 44/100: 


Epoch 44, Batch 101/101,Loss: 0.385067, RMSE: 0.392559: 100%|██████████| 101/101 [00:48<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.223836, RMSE: 0.394885: 100%|██████████| 10/10 [00:03<00:00,  3.03it/s]

Val RMSE: 0.394885, Val Loss: 0.159436

Train RMSE: 0.392559, Average Train Loss: 0.156300, Val RMSE: 0.394885, Average Valid Loss: 0.159436

Epoch 45/100: 


Epoch 45, Batch 101/101,Loss: 0.207232, RMSE: 0.392481: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.156314, RMSE: 0.410527: 100%|██████████| 10/10 [00:03<00:00,  3.05it/s]

Val RMSE: 0.410527, Val Loss: 0.167902

Train RMSE: 0.392481, Average Train Loss: 0.154548, Val RMSE: 0.410527, Average Valid Loss: 0.167902

Epoch 46/100: 


Epoch 46, Batch 101/101,Loss: 0.186849, RMSE: 0.393486: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.168790, RMSE: 0.384271: 100%|██████████| 10/10 [00:03<00:00,  3.09it/s]

Val RMSE: 0.384271, Val Loss: 0.148754

Train RMSE: 0.393486, Average Train Loss: 0.155135, Val RMSE: 0.384271, Average Valid Loss: 0.148754

Epoch 47/100: 


Epoch 47, Batch 101/101,Loss: 0.051473, RMSE: 0.390415: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.220744, RMSE: 0.400029: 100%|██████████| 10/10 [00:03<00:00,  3.03it/s]

Val RMSE: 0.400029, Val Loss: 0.163154

Train RMSE: 0.390415, Average Train Loss: 0.151464, Val RMSE: 0.400029, Average Valid Loss: 0.163154

Epoch 48/100: 


Epoch 48, Batch 101/101,Loss: 0.061833, RMSE: 0.391968: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.159157, RMSE: 0.397358: 100%|██████████| 10/10 [00:04<00:00,  2.47it/s]

Val RMSE: 0.397358, Val Loss: 0.157959

Train RMSE: 0.391968, Average Train Loss: 0.152766, Val RMSE: 0.397358, Average Valid Loss: 0.157959

Epoch 49/100: 


Epoch 49, Batch 101/101,Loss: 0.127211, RMSE: 0.391979: 100%|██████████| 101/101 [00:47<00:00,  2.11it/s]
Val Batch 10/10,Loss: 0.158706, RMSE: 0.402503: 100%|██████████| 10/10 [00:04<00:00,  2.42it/s]

Val RMSE: 0.402503, Val Loss: 0.161838

Train RMSE: 0.391979, Average Train Loss: 0.153396, Val RMSE: 0.402503, Average Valid Loss: 0.161838

Epoch 50/100: 


Epoch 50, Batch 101/101,Loss: 0.074350, RMSE: 0.391530: 100%|██████████| 101/101 [00:48<00:00,  2.09it/s]
Val Batch 10/10,Loss: 0.180871, RMSE: 0.397879: 100%|██████████| 10/10 [00:03<00:00,  2.87it/s]

Val RMSE: 0.397879, Val Loss: 0.159471

Train RMSE: 0.391530, Average Train Loss: 0.152545, Val RMSE: 0.397879, Average Valid Loss: 0.159471

Epoch 51/100: 


Epoch 51, Batch 101/101,Loss: 0.030474, RMSE: 0.389853: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.175392, RMSE: 0.405949: 100%|██████████| 10/10 [00:03<00:00,  3.11it/s]

Val RMSE: 0.405949, Val Loss: 0.165341

Train RMSE: 0.389853, Average Train Loss: 0.150829, Val RMSE: 0.405949, Average Valid Loss: 0.165341

Epoch 52/100: 


Epoch 52, Batch 101/101,Loss: 0.546034, RMSE: 0.390220: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.199513, RMSE: 0.385575: 100%|██████████| 10/10 [00:03<00:00,  3.06it/s]

Val RMSE: 0.385575, Val Loss: 0.151290

Train RMSE: 0.390220, Average Train Loss: 0.156018, Val RMSE: 0.385575, Average Valid Loss: 0.151290

Epoch 53/100: 


Epoch 53, Batch 101/101,Loss: 0.134419, RMSE: 0.390792: 100%|██████████| 101/101 [00:48<00:00,  2.07it/s]
Val Batch 10/10,Loss: 0.210520, RMSE: 0.389785: 100%|██████████| 10/10 [00:03<00:00,  3.10it/s]

Val RMSE: 0.389785, Val Loss: 0.154953

Train RMSE: 0.390792, Average Train Loss: 0.152544, Val RMSE: 0.389785, Average Valid Loss: 0.154953

Epoch 54/100: 


Epoch 54, Batch 101/101,Loss: 0.351615, RMSE: 0.389260: 100%|██████████| 101/101 [00:48<00:00,  2.08it/s]
Val Batch 10/10,Loss: 0.194611, RMSE: 0.382998: 100%|██████████| 10/10 [00:03<00:00,  3.03it/s]

Val RMSE: 0.382998, Val Loss: 0.149159

Train RMSE: 0.389260, Average Train Loss: 0.153427, Val RMSE: 0.382998, Average Valid Loss: 0.149159

Epoch 55/100: 


Epoch 55, Batch 101/101,Loss: 0.133474, RMSE: 0.388871: 100%|██████████| 101/101 [00:50<00:00,  2.01it/s]
Val Batch 10/10,Loss: 0.146856, RMSE: 0.404064: 100%|██████████| 10/10 [00:03<00:00,  2.94it/s]

Val RMSE: 0.404064, Val Loss: 0.162422

Train RMSE: 0.388871, Average Train Loss: 0.151052, Val RMSE: 0.404064, Average Valid Loss: 0.162422

Epoch 56/100: 


Epoch 56, Batch 101/101,Loss: 0.081448, RMSE: 0.388360: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.186636, RMSE: 0.390181: 100%|██████████| 10/10 [00:04<00:00,  2.41it/s]

Val RMSE: 0.390181, Val Loss: 0.154015

Train RMSE: 0.388360, Average Train Loss: 0.150163, Val RMSE: 0.390181, Average Valid Loss: 0.154015

Epoch 57/100: 


Epoch 57, Batch 101/101,Loss: 0.080013, RMSE: 0.388301: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.215254, RMSE: 0.394359: 100%|██████████| 10/10 [00:03<00:00,  2.55it/s]

Val RMSE: 0.394359, Val Loss: 0.158599

Train RMSE: 0.388301, Average Train Loss: 0.150104, Val RMSE: 0.394359, Average Valid Loss: 0.158599

Epoch 58/100: 


Epoch 58, Batch 101/101,Loss: 0.094067, RMSE: 0.388106: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.167926, RMSE: 0.405164: 100%|██████████| 10/10 [00:03<00:00,  3.04it/s]

Val RMSE: 0.405164, Val Loss: 0.164352

Train RMSE: 0.388106, Average Train Loss: 0.150088, Val RMSE: 0.405164, Average Valid Loss: 0.164352

Epoch 59/100: 


Epoch 59, Batch 101/101,Loss: 0.317447, RMSE: 0.386852: 100%|██████████| 101/101 [00:49<00:00,  2.02it/s]
Val Batch 10/10,Loss: 0.178087, RMSE: 0.383056: 100%|██████████| 10/10 [00:03<00:00,  3.06it/s]

Val RMSE: 0.383056, Val Loss: 0.148349

Train RMSE: 0.386852, Average Train Loss: 0.151251, Val RMSE: 0.383056, Average Valid Loss: 0.148349

Epoch 60/100: 


Epoch 60, Batch 101/101,Loss: 0.049300, RMSE: 0.388810: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.153590, RMSE: 0.406688: 100%|██████████| 10/10 [00:03<00:00,  3.07it/s]

Val RMSE: 0.406688, Val Loss: 0.164786

Train RMSE: 0.388810, Average Train Loss: 0.150204, Val RMSE: 0.406688, Average Valid Loss: 0.164786

Epoch 61/100: 


Epoch 61, Batch 101/101,Loss: 0.095338, RMSE: 0.386009: 100%|██████████| 101/101 [00:49<00:00,  2.04it/s]
Val Batch 10/10,Loss: 0.161963, RMSE: 0.425200: 100%|██████████| 10/10 [00:03<00:00,  3.08it/s]

Val RMSE: 0.425200, Val Loss: 0.179824

Train RMSE: 0.386009, Average Train Loss: 0.148492, Val RMSE: 0.425200, Average Valid Loss: 0.179824

Epoch 62/100: 


Epoch 62, Batch 101/101,Loss: 0.709278, RMSE: 0.387186: 100%|██████████| 101/101 [00:49<00:00,  2.06it/s]
Val Batch 10/10,Loss: 0.176945, RMSE: 0.385901: 100%|██████████| 10/10 [00:03<00:00,  2.58it/s]

Val RMSE: 0.385901, Val Loss: 0.150365

Train RMSE: 0.387186, Average Train Loss: 0.155235, Val RMSE: 0.385901, Average Valid Loss: 0.150365

Epoch 63/100: 


Epoch 63, Batch 101/101,Loss: 0.171952, RMSE: 0.388535: 100%|██████████| 101/101 [00:48<00:00,  2.09it/s]
Val Batch 10/10,Loss: 0.210979, RMSE: 0.395534: 100%|██████████| 10/10 [00:04<00:00,  2.44it/s]

Val RMSE: 0.395534, Val Loss: 0.159259

Train RMSE: 0.388535, Average Train Loss: 0.151159, Val RMSE: 0.395534, Average Valid Loss: 0.159259

Epoch 64/100: 


Epoch 64, Batch 101/101,Loss: 0.155732, RMSE: 0.385901: 100%|██████████| 101/101 [00:48<00:00,  2.09it/s]
Val Batch 10/10,Loss: 0.204168, RMSE: 0.393368: 100%|██████████| 10/10 [00:03<00:00,  2.69it/s]

Val RMSE: 0.393368, Val Loss: 0.157287

Train RMSE: 0.385901, Average Train Loss: 0.148984, Val RMSE: 0.393368, Average Valid Loss: 0.157287

Epoch 65/100: 


Epoch 65, Batch 101/101,Loss: 0.141731, RMSE: 0.384779: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.161905, RMSE: 0.390996: 100%|██████████| 10/10 [00:03<00:00,  3.03it/s]

Val RMSE: 0.390996, Val Loss: 0.153343

Train RMSE: 0.384779, Average Train Loss: 0.147994, Val RMSE: 0.390996, Average Valid Loss: 0.153343

Epoch 66/100: 


Epoch 66, Batch 101/101,Loss: 0.086074, RMSE: 0.385922: 100%|██████████| 101/101 [00:49<00:00,  2.05it/s]
Val Batch 10/10,Loss: 0.161922, RMSE: 0.405263: 100%|██████████| 10/10 [00:03<00:00,  3.02it/s]

Val RMSE: 0.405263, Val Loss: 0.164119

Train RMSE: 0.385922, Average Train Loss: 0.148337, Val RMSE: 0.405263, Average Valid Loss: 0.164119

Epoch 67/100: 


Epoch 67, Batch 81/101,Loss: 0.177325, RMSE: 0.385006:  80%|████████  | 81/101 [00:39<00:09,  2.21it/s]

In [None]:
torch.save(model.state_dict(), 'model.pth')