# %% [markdown]

 # EDA and EEGNeX baseline: response-time regression on contrast change detection

# %% [markdown]

 ## Import Libraries

In [None]:
import copy
from pathlib import Path
from typing import Optional

import torch
from braindecode.datasets import BaseConcatDataset
from braindecode.models import EEGNeX
from braindecode.preprocessing import (
    Preprocessor,
    create_windows_from_events,
    preprocess,
)
from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import (
    add_aux_anchors,
    add_extras_columns,
    annotate_trials_with_target,
    keep_only_recordings_with,
)
from joblib import Parallel, delayed
from matplotlib.pylab import plt
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# %% [markdown]

 ## Constants

In [None]:
# Local cache directory for the mini release of the challenge data.
# NOTE: this is a machine-specific absolute path; point it at your own cache
# location before running the notebook elsewhere.
MINI_DATASET_ROOT = Path("/media/varun/braininahat/datasets/eeg2025/mini/")
# Epoch length in seconds (used for trial annotation and the window size).
EPOCH_LEN_S = 2.0
# Sampling rate in Hz (the dataset notice states the data is downsampled to 100 Hz).
SFREQ = 100
# Annotation description that windows are anchored on.
ANCHOR = "stimulus_anchor"
# Windows start this many seconds after the anchor ...
SHIFT_AFTER_STIM = 0.5
# ... and span this many seconds.
WINDOW_LEN = 2.0

# Validation and test set fractions
VALID_FRAC = 0.1
TEST_FRAC = 0.1
# Random seed
SEED = 2025
# Seed torch's global RNG so that weight initialisation and DataLoader
# shuffling are reproducible under Restart & Run All (previously only the
# subject split was seeded).
torch.manual_seed(SEED)

# Subject IDs excluded from the train/valid/test split.
SUBJECTS_TO_REMOVE = [
    "NDARWV769JM7",
    "NDARME789TD2",
    "NDARUA442ZVF",
    "NDARJP304NK1",
    "NDARTY128YLU",
    "NDARDW550GU6",
    "NDARLD243KRE",
    "NDARUJ292JXV",
    "NDARBA381JGH",
]

# DataLoader settings.
BATCH_SIZE = 128
NUM_WORKERS = 8

# Optimisation settings.
LR = 1e-3
WEIGHT_DECAY = 1e-5
N_EPOCHS = 100
# Number of consecutive non-improving epochs tolerated before stopping.
EARLY_STOPPING_PATIENCE = 50

# Train on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# %% [markdown]

 ## Load Data

In [None]:
# Load the contrast-change-detection task from release R1 of the challenge
# data. `mini=True` selects the mini variant and `cache_dir` is where the
# files are cached locally.
dataset_ccd = EEGChallengeDataset(
    task="contrastChangeDetection",
    release="R1",
    mini=True,
    cache_dir=MINI_DATASET_ROOT,
)


[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


# %% [markdown]

 ## Explore Data

In [None]:
# %%

%matplotlib qt

# First recording of the dataset, for a quick visual check.
raw = dataset_ccd.datasets[0].raw

# Plot a copy rather than the live object: the Qt window stays open while
# later cells rewrite the annotations of the dataset's raws, and redrawing a
# figure bound to the mutated raw fails with
# KeyError: 'contrast_trial_start' (see the tracebacks captured further down).
fig = raw.copy().plot()

Using matplotlib as 2D backend.


# %% [markdown]

 ## Download all

In [None]:
# %%

# Touch `.raw` on every recording in parallel (one joblib task per recording,
# using all available cores) so that all recordings are fetched up front.
raw_fetch_tasks = (
    delayed(lambda d: d.raw)(recording) for recording in dataset_ccd.datasets
)
raws = Parallel(n_jobs=-1)(raw_fetch_tasks)

# %% [markdown]

 ## Braindecode init

# %% [markdown]

 ## Epoching

In [None]:
# %%

# Offline annotation steps applied to every recording's Raw object
# (`apply_on_array=False` passes the Raw itself to the function, not an array):
#   1. annotate_trials_with_target -- attaches the regression target taken from
#      "rt_from_stimulus", keeping only trials with both a stimulus and a response.
#   2. add_aux_anchors -- adds auxiliary anchor annotations; ANCHOR
#      ("stimulus_anchor") is the one used for windowing later on.
transformation_offline = [
    Preprocessor(
        annotate_trials_with_target,
        target_field="rt_from_stimulus",
        epoch_length=EPOCH_LEN_S,
        require_stimulus=True,
        require_response=True,
        apply_on_array=False,
    ),
    Preprocessor(add_aux_anchors, apply_on_array=False),
]
# NOTE: this rewrites the annotations of the raws held by `dataset_ccd`
# (the captured output shows the `raw.set_annotations` calls), so re-running
# this cell operates on already-annotated recordings.
preprocess(dataset_ccd, transformation_offline, n_jobs=1)

  raw.set_annotations(new_ann, verbose=False)
  raw.set_annotations(ann + aux, verbose=False)


<eegdash.dataset.dataset.EEGChallengeDataset at 0x739c8eb3c710>

Traceback (most recent call last):
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/matplotlib/cbook.py", line 361, in process
    func(*args, **kwargs)
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_mpl_figure.py", line 812, in _buttonpress
    self._redraw(annotations=True)
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_mpl_figure.py", line 2146, in _redraw
    super()._redraw(update_data, annotations)
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_figure.py", line 463, in _redraw
    self._draw_annotations()
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_mpl_figure.py", line 1408, in _draw_annotations
    segment_color = self.mne.annotation_segment_colors[descr]
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^
KeyError: np.str_('contrast_trial_start')


# %% [markdown]

 ## Filter for stimulus anchor presence

In [None]:
# %%

# Keep only the recordings that contain at least one ANCHOR annotation;
# presumably recordings without it would yield no windows -- confirm in eegdash docs.
dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

# %% [markdown]

 ## Window creation

In [None]:
# %%

# Cut one fixed-length window per ANCHOR annotation.
# With the current constants the trial span is +50 .. +250 samples relative to
# the anchor, the window is 200 samples long and the stride is 100 samples, so
# exactly one window fits per anchor (the metadata below shows
# i_window_in_trial == 0 and i_stop - i_start == 200).
# NOTE(review): the span uses WINDOW_LEN while the window size uses
# EPOCH_LEN_S; both are 2.0 today -- keep them in sync if either changes.
single_windows = create_windows_from_events(
    dataset,
    mapping={ANCHOR: 0},
    trial_start_offset_samples=int(SHIFT_AFTER_STIM * SFREQ),
    trial_stop_offset_samples=int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),
    window_size_samples=int(EPOCH_LEN_S * SFREQ),
    window_stride_samples=SFREQ,
    preload=True,
)

Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(

# %% [markdown]

 ## Add metadata

In [None]:
# %%

# Attach per-trial fields of the ANCHOR annotations to the windows' metadata.
# The listed keys show up as metadata columns in the inspection cell below;
# "target" is the regression label used for training.
# NOTE: `single_windows` is rebound to the result, so this cell's input is no
# longer available under that name afterwards.
single_windows = add_extras_columns(
    single_windows,
    dataset,
    desc=ANCHOR,
    keys=(
        "target",
        "rt_from_stimulus",
        "rt_from_trialstart",
        "stimulus_onset",
        "response_onset",
        "correct",
        "response_type",
    ),
)

# %% [markdown]

 ## Inspect metadata

In [None]:
# %%

# Peek at the first rows of the per-window metadata (window indices, target
# and the extra trial columns) to confirm the annotation step worked.
single_windows.get_metadata().head()

Unnamed: 0,i_window_in_trial,i_start_in_trial,i_stop_in_trial,target,rt_from_stimulus,rt_from_trialstart,stimulus_onset,response_onset,correct,response_type,...,thepresent,diaryofawimpykid,contrastchangedetection_1,contrastchangedetection_2,contrastchangedetection_3,surroundsupp_1,surroundsupp_2,seqlearning6target,seqlearning8target,symbolsearch
0,0,4278,4478,2.13,2.13,4.93,42.284,44.414,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
1,0,4798,4998,1.96,1.96,4.76,47.484,49.444,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
2,0,5478,5678,2.02,2.02,6.42,54.284,56.304,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
3,0,6318,6518,1.72,1.72,7.72,62.684,64.404,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
4,0,6838,7038,1.8,1.8,4.6,67.884,69.684,1,left_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available


# %% [markdown]

 ## Target inspection

In [None]:
# Histogram of the regression target (one value per window).
target_values = single_windows.get_metadata()["target"]

fig, ax = plt.subplots(figsize=(15, 5))
ax = target_values.plot.hist(bins=30, ax=ax, color="lightblue")
ax.set(
    xlabel="Response Time (s)",
    ylabel="Frequency",
    title="Distribution of Response Times",
)
plt.show()

# %% [markdown]

 ## List subjects (the split below is made per subject)

In [None]:
# %%

# Unique subject IDs found in the windowed dataset's description table;
# these are the units that get assigned to train/valid/test below.
subjects = single_windows.description["subject"].unique()
print(f"Number of subjects: {len(subjects)}")
print(f"Subjects: {subjects}")

Number of subjects: 20
Subjects: ['NDARAC904DMU' 'NDARAM704GKZ' 'NDARAP359UM6' 'NDARBD879MBX'
 'NDARBH024NH2' 'NDARBK082PDD' 'NDARCA153NKE' 'NDARCE721YB5'
 'NDARCJ594BWQ' 'NDARCN669XPR' 'NDARCW094JCG' 'NDARCZ947WU5'
 'NDARDH670PXH' 'NDARDL511UND' 'NDARDU986RBM' 'NDAREM731BYM'
 'NDAREN519BLJ' 'NDARFK610GY5' 'NDARFT581ZW5' 'NDARFW972KFQ']


# %% [markdown]

 ## Remove subjects

In [None]:
# %%

# Drop the excluded subject IDs, keeping the original order of the rest.
excluded_subjects = set(SUBJECTS_TO_REMOVE)
subjects = [
    subject_id for subject_id in subjects if subject_id not in excluded_subjects
]
print(f"Number of subjects: {len(subjects)}")
print(f"Subjects: {subjects}")

Number of subjects: 20
Subjects: ['NDARAC904DMU', 'NDARAM704GKZ', 'NDARAP359UM6', 'NDARBD879MBX', 'NDARBH024NH2', 'NDARBK082PDD', 'NDARCA153NKE', 'NDARCE721YB5', 'NDARCJ594BWQ', 'NDARCN669XPR', 'NDARCW094JCG', 'NDARCZ947WU5', 'NDARDH670PXH', 'NDARDL511UND', 'NDARDU986RBM', 'NDAREM731BYM', 'NDAREN519BLJ', 'NDARFK610GY5', 'NDARFT581ZW5', 'NDARFW972KFQ']


# %% [markdown]

 ## Train Test Split

In [None]:
# %%

# Step 1: hold out the combined validation + test subjects; the remaining
# subjects form the training set. Splitting by subject keeps all windows of a
# subject in the same split.
train_subj, valid_test_subject = train_test_split(
    subjects,
    test_size=(VALID_FRAC + TEST_FRAC),
    random_state=check_random_state(SEED),
    shuffle=True,
)

# Step 2: divide the held-out pool into validation and test.
# `test_size` is a fraction of the array passed in (the held-out pool), not of
# all subjects, so TEST_FRAC must be rescaled by the pool's share. Passing
# TEST_FRAC directly gave the test set only 10% of the pool instead of
# TEST_FRAC of all subjects.
valid_subj, test_subj = train_test_split(
    valid_test_subject,
    test_size=TEST_FRAC / (VALID_FRAC + TEST_FRAC),
    random_state=check_random_state(SEED + 1),
    shuffle=True,
)

# %% [markdown]

 ## Sanity check

In [None]:
# %%

# Every retained subject must land in exactly one split.
# Coverage: the three splits together contain all subjects ...
assert (set(valid_subj) | set(test_subj) | set(train_subj)) == set(subjects)
# ... and exclusivity: no subject is shared between splits, which would leak
# a subject's data from training into evaluation.
assert set(train_subj).isdisjoint(valid_subj), "train/valid subjects overlap"
assert set(train_subj).isdisjoint(test_subj), "train/test subjects overlap"
assert set(valid_subj).isdisjoint(test_subj), "valid/test subjects overlap"

# %% [markdown]

 ## Create train/valid/test splits for the windows

In [None]:
# %%

# Group the windows by subject, then route each subject's windows to the
# split its ID was assigned to.
subject_split = single_windows.split("subject")
train_parts, valid_parts, test_parts = [], [], []

for subject_id, subject_windows in subject_split.items():
    if subject_id in train_subj:
        train_parts.append(subject_windows)
    elif subject_id in valid_subj:
        valid_parts.append(subject_windows)
    elif subject_id in test_subj:
        test_parts.append(subject_windows)
    # Subjects that are in none of the three lists are left out entirely.

train_set = BaseConcatDataset(train_parts)
valid_set = BaseConcatDataset(valid_parts)
test_set = BaseConcatDataset(test_parts)

# %% [markdown]

 ## Create dataloaders

In [None]:
# %%

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}

train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

# %% [markdown]

 ## Build the model

In [None]:
# %%

# EEGNeX regressor with a single output (the response time).
# `n_times` is derived from the same constants as the window size so the two
# cannot drift apart. The model is moved to DEVICE because the training and
# validation functions move every batch there; leaving the model on the CPU
# fails as soon as CUDA is available.
model = EEGNeX(
    n_chans=129,
    n_outputs=1,
    n_times=int(EPOCH_LEN_S * SFREQ),
    sfreq=SFREQ,
).to(DEVICE)

# %% [markdown]

 ## Print model

In [None]:
# %%

# Print the model; the captured output below is a layer-by-layer summary
# (input/output shapes, parameter counts, kernel shapes).
print(model)

Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #                   Kernel Shape
EEGNeX (EEGNeX)                                              [1, 129, 200]             [1, 1]                    --                        --
├─Sequential (block_1): 1-1                                  [1, 129, 200]             [1, 8, 129, 200]          --                        --
│    └─Rearrange (0): 2-1                                    [1, 129, 200]             [1, 1, 129, 200]          --                        --
│    └─Conv2d (1): 2-2                                       [1, 1, 129, 200]          [1, 8, 129, 200]          512                       [1, 64]
│    └─BatchNorm2d (2): 2-3                                  [1, 8, 129, 200]          [1, 8, 129, 200]          16                        --
├─Sequential (block_2): 1-2                                  [1, 8, 129, 200]          [1, 32, 129, 200]         --                  

  return F.conv2d(


# %% [markdown]

 ## Train the model

# %% [markdown]

 ## Define training functions

In [None]:
# %%

def train_one_epoch(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    optimizer,
    scheduler: Optional[LRScheduler],
    epoch: int,
    device,
    print_batch_stats: bool = True,
):
    """Run one optimisation pass over ``dataloader`` and report regression metrics.

    Parameters
    ----------
    dataloader : DataLoader
        Yields batches whose first two elements are the inputs and the targets.
    model : Module
        Model to train; put into training mode here.
    loss_fn : callable
        Called as ``loss_fn(preds, y)``; must return a scalar tensor.
    optimizer : torch optimizer
        Stepped once per batch.
    scheduler : LRScheduler or None
        If given, stepped once at the end of the epoch.
    epoch : int
        Epoch number, used only in the progress-bar text.
    device : str or torch.device
        Device the batches are moved to.
    print_batch_stats : bool
        When False the progress bar is disabled.

    Returns
    -------
    tuple of float
        ``(avg_loss, rmse)``: the mean of the per-batch losses (NaN if the
        dataloader yields no batches) and the RMSE over all seen samples.
    """
    model.train()

    total_loss = 0.0
    sum_sq_err = 0.0
    n_samples = 0
    n_batches = len(dataloader)

    progress_bar = tqdm(
        enumerate(dataloader), total=n_batches, disable=not print_batch_stats
    )

    for batch_idx, batch in progress_bar:
        # Support datasets that may return (X, y) or (X, y, ...)
        X, y = batch[0], batch[1]
        X, y = X.to(device).float(), y.to(device).float()

        optimizer.zero_grad(set_to_none=True)
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Flatten to 1D for regression metrics and accumulate squared error
        preds_flat = preds.detach().view(-1)
        y_flat = y.detach().view(-1)
        sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
        n_samples += y_flat.numel()

        if print_batch_stats:
            running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
            progress_bar.set_description(
                f"Epoch {epoch}, Batch {batch_idx + 1}/{n_batches}, "
                f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}"
            )

    if scheduler is not None:
        scheduler.step()

    # Guard against an empty dataloader (same convention as valid_model)
    # instead of raising ZeroDivisionError.
    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
    return avg_loss, rmse

# %% [markdown]

 ## Define validation function

In [None]:
# %%

def valid_model(
    dataloader: DataLoader,
    model: Module,
    loss_fn,
    device,
    print_batch_stats: bool = True,
):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Parameters
    ----------
    dataloader : DataLoader
        Yields batches whose first two elements are the inputs and the targets.
    model : Module
        Model to evaluate; put into eval mode here.
    loss_fn : callable
        Called as ``loss_fn(preds, y)``; must return a scalar tensor.
    device : str or torch.device
        Device the batches are moved to.
    print_batch_stats : bool
        When False the progress bar is disabled (the final summary line is
        still printed).

    Returns
    -------
    tuple of float
        ``(avg_loss, rmse)``: the mean of the per-batch losses (NaN if the
        dataloader yields no batches) and the RMSE over all seen samples.
    """
    model.eval()
    total_loss = 0.0
    sum_sq_err = 0.0
    n_batches = len(dataloader)
    n_samples = 0

    iterator = tqdm(
        enumerate(dataloader), total=n_batches, disable=not print_batch_stats
    )

    # Evaluation needs no gradients: disabling autograd avoids building the
    # graph for every batch, which saves memory and time.
    with torch.no_grad():
        for batch_idx, batch in iterator:
            # Supports (X, y) or (X, y, ...)
            X, y = batch[0], batch[1]
            X, y = X.to(device).float(), y.to(device).float()

            preds = model(X)
            batch_loss = loss_fn(preds, y).item()
            total_loss += batch_loss

            preds_flat = preds.detach().view(-1)
            y_flat = y.detach().view(-1)
            sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item()
            n_samples += y_flat.numel()

            if print_batch_stats:
                running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5
                iterator.set_description(
                    f"Val Batch {batch_idx + 1}/{n_batches}, "
                    f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}"
                )

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5

    print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}\n")
    return avg_loss, rmse

# %% [markdown]

 ## Set up the optimizer, scheduler, loss and early stopping

In [None]:
# %%

# Regression objective, optimiser and a cosine learning-rate schedule that is
# stepped once per epoch.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=N_EPOCHS - 1,
)

# Early-stopping bookkeeping: an epoch counts as an improvement only if the
# RMSE drops by more than `min_delta`.
patience, min_delta = EARLY_STOPPING_PATIENCE, 1e-4
best_rmse = float("inf")
best_state, best_epoch = None, None
epochs_no_improve = 0

# %% [markdown]

 ## Run the training loop

In [None]:
# %%

for epoch in range(1, N_EPOCHS + 1):
    print(f"Epoch {epoch}/{N_EPOCHS}: ", end="")

    train_loss, train_rmse = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, DEVICE
    )
    # Model selection and early stopping must use the validation split.
    # Using the test split here would tune the stopping point on the test
    # subjects and make the final test score optimistic.
    val_loss, val_rmse = valid_model(valid_loader, model, loss_fn, DEVICE)

    print(
        f"Train RMSE: {train_rmse:.6f}, Average Train Loss: {train_loss:.6f}, Val RMSE: {val_rmse:.6f}, Average Val Loss: {val_loss:.6f}"
    )

    if val_rmse < best_rmse - min_delta:
        # New best: snapshot the weights (deep copy, since state_dict()
        # returns references to the live parameters).
        best_rmse = val_rmse
        best_state = copy.deepcopy(model.state_dict())
        best_epoch = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(
                f"Early stopping at epoch {epoch}. Best Val RMSE: {best_rmse:.6f} (epoch {best_epoch})"
            )
            break

# Roll the model back to the best checkpoint; otherwise everything downstream
# would use the weights of the last (possibly much worse) epoch.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights from epoch {best_epoch}")

Epoch 1/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 1.535728, Val Loss: 2.358461

Train RMSE: 1.430488, Average Train Loss: 2.037427, Val RMSE: 1.535728, Average Val Loss: 2.358461
Epoch 2/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 1.242760, Val Loss: 1.544452

Train RMSE: 1.199207, Average Train Loss: 1.363612, Val RMSE: 1.242760, Average Val Loss: 1.544452
Epoch 3/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.340037, Val Loss: 0.115625

Train RMSE: 0.668329, Average Train Loss: 0.439798, Val RMSE: 0.340037, Average Val Loss: 0.115625
Epoch 4/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.454136, Val Loss: 0.206240

Train RMSE: 0.539248, Average Train Loss: 0.304155, Val RMSE: 0.454136, Average Val Loss: 0.206240
Epoch 5/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.376810, Val Loss: 0.141986

Train RMSE: 0.512449, Average Train Loss: 0.262026, Val RMSE: 0.376810, Average Val Loss: 0.141986
Epoch 6/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.315413, Val Loss: 0.099485

Train RMSE: 0.501983, Average Train Loss: 0.266773, Val RMSE: 0.315413, Average Val Loss: 0.099485
Epoch 7/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.357039, Val Loss: 0.127477

Train RMSE: 0.501939, Average Train Loss: 0.249162, Val RMSE: 0.357039, Average Val Loss: 0.127477
Epoch 8/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.402764, Val Loss: 0.162219

Train RMSE: 0.493953, Average Train Loss: 0.243844, Val RMSE: 0.402764, Average Val Loss: 0.162219
Epoch 9/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.404349, Val Loss: 0.163498

Train RMSE: 0.486654, Average Train Loss: 0.233706, Val RMSE: 0.404349, Average Val Loss: 0.163498
Epoch 10/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.563467, Val Loss: 0.317496

Train RMSE: 0.490735, Average Train Loss: 0.249844, Val RMSE: 0.563467, Average Val Loss: 0.317496
Epoch 11/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.486134, Val Loss: 0.236326

Train RMSE: 0.490336, Average Train Loss: 0.237721, Val RMSE: 0.486134, Average Val Loss: 0.236326
Epoch 12/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.419649, Val Loss: 0.176105

Train RMSE: 0.492612, Average Train Loss: 0.255073, Val RMSE: 0.419649, Average Val Loss: 0.176105
Epoch 13/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.414988, Val Loss: 0.172215

Train RMSE: 0.485474, Average Train Loss: 0.230343, Val RMSE: 0.414988, Average Val Loss: 0.172215
Epoch 14/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.517923, Val Loss: 0.268245

Train RMSE: 0.492087, Average Train Loss: 0.236028, Val RMSE: 0.517923, Average Val Loss: 0.268245
Epoch 15/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.452176, Val Loss: 0.204463

Train RMSE: 0.474256, Average Train Loss: 0.221496, Val RMSE: 0.452176, Average Val Loss: 0.204463
Epoch 16/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.480172, Val Loss: 0.230565

Train RMSE: 0.467984, Average Train Loss: 0.223925, Val RMSE: 0.480172, Average Val Loss: 0.230565
Epoch 17/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.405453, Val Loss: 0.164392

Train RMSE: 0.480387, Average Train Loss: 0.229904, Val RMSE: 0.405453, Average Val Loss: 0.164392
Epoch 18/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.374446, Val Loss: 0.140210

Train RMSE: 0.477581, Average Train Loss: 0.227174, Val RMSE: 0.374446, Average Val Loss: 0.140210
Epoch 19/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.379359, Val Loss: 0.143913

Train RMSE: 0.475582, Average Train Loss: 0.220451, Val RMSE: 0.379359, Average Val Loss: 0.143913
Epoch 20/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.612212, Val Loss: 0.374803

Train RMSE: 0.472653, Average Train Loss: 0.225118, Val RMSE: 0.612212, Average Val Loss: 0.374803
Epoch 21/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.532236, Val Loss: 0.283275

Train RMSE: 0.480247, Average Train Loss: 0.223527, Val RMSE: 0.532236, Average Val Loss: 0.283275
Epoch 22/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.495780, Val Loss: 0.245797

Train RMSE: 0.470331, Average Train Loss: 0.232250, Val RMSE: 0.495780, Average Val Loss: 0.245797
Epoch 23/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.444948, Val Loss: 0.197979

Train RMSE: 0.460256, Average Train Loss: 0.211831, Val RMSE: 0.444948, Average Val Loss: 0.197979
Epoch 24/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.396772, Val Loss: 0.157428

Train RMSE: 0.468686, Average Train Loss: 0.224316, Val RMSE: 0.396772, Average Val Loss: 0.157428
Epoch 25/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.389421, Val Loss: 0.151649

Train RMSE: 0.458329, Average Train Loss: 0.215976, Val RMSE: 0.389421, Average Val Loss: 0.151649
Epoch 26/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.439770, Val Loss: 0.193398

Train RMSE: 0.472702, Average Train Loss: 0.212609, Val RMSE: 0.439770, Average Val Loss: 0.193398
Epoch 27/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.384590, Val Loss: 0.147909

Train RMSE: 0.464403, Average Train Loss: 0.218105, Val RMSE: 0.384590, Average Val Loss: 0.147909
Epoch 28/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.377542, Val Loss: 0.142538

Train RMSE: 0.461104, Average Train Loss: 0.226245, Val RMSE: 0.377542, Average Val Loss: 0.142538
Epoch 29/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.417108, Val Loss: 0.173979

Train RMSE: 0.462701, Average Train Loss: 0.205919, Val RMSE: 0.417108, Average Val Loss: 0.173979
Epoch 30/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.391320, Val Loss: 0.153131

Train RMSE: 0.455569, Average Train Loss: 0.210343, Val RMSE: 0.391320, Average Val Loss: 0.153131
Epoch 31/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.389624, Val Loss: 0.151807

Train RMSE: 0.460680, Average Train Loss: 0.214305, Val RMSE: 0.389624, Average Val Loss: 0.151807
Epoch 32/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.391469, Val Loss: 0.153248

Train RMSE: 0.452914, Average Train Loss: 0.200772, Val RMSE: 0.391469, Average Val Loss: 0.153248
Epoch 33/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.438515, Val Loss: 0.192296

Train RMSE: 0.452249, Average Train Loss: 0.218056, Val RMSE: 0.438515, Average Val Loss: 0.192296
Epoch 34/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.390416, Val Loss: 0.152425

Train RMSE: 0.471399, Average Train Loss: 0.221444, Val RMSE: 0.390416, Average Val Loss: 0.152425
Epoch 35/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.402305, Val Loss: 0.161850

Train RMSE: 0.451401, Average Train Loss: 0.208694, Val RMSE: 0.402305, Average Val Loss: 0.161850
Epoch 36/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.417157, Val Loss: 0.174020

Train RMSE: 0.457315, Average Train Loss: 0.208796, Val RMSE: 0.417157, Average Val Loss: 0.174020
Epoch 37/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.600167, Val Loss: 0.360200

Train RMSE: 0.457987, Average Train Loss: 0.213342, Val RMSE: 0.600167, Average Val Loss: 0.360200
Epoch 38/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.485234, Val Loss: 0.235453

Train RMSE: 0.455685, Average Train Loss: 0.206515, Val RMSE: 0.485234, Average Val Loss: 0.235453
Epoch 39/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.422324, Val Loss: 0.178358

Train RMSE: 0.444422, Average Train Loss: 0.204939, Val RMSE: 0.422324, Average Val Loss: 0.178358
Epoch 40/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.390537, Val Loss: 0.152519

Train RMSE: 0.444697, Average Train Loss: 0.199111, Val RMSE: 0.390537, Average Val Loss: 0.152519
Epoch 41/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.386936, Val Loss: 0.149719

Train RMSE: 0.455687, Average Train Loss: 0.203622, Val RMSE: 0.386936, Average Val Loss: 0.149719
Epoch 42/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.413519, Val Loss: 0.170998

Train RMSE: 0.445708, Average Train Loss: 0.206041, Val RMSE: 0.413519, Average Val Loss: 0.170998
Epoch 43/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.381284, Val Loss: 0.145377

Train RMSE: 0.444589, Average Train Loss: 0.203956, Val RMSE: 0.381284, Average Val Loss: 0.145377
Epoch 44/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.472875, Val Loss: 0.223611

Train RMSE: 0.437800, Average Train Loss: 0.187552, Val RMSE: 0.472875, Average Val Loss: 0.223611
Epoch 45/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.408507, Val Loss: 0.166878

Train RMSE: 0.445868, Average Train Loss: 0.193641, Val RMSE: 0.408507, Average Val Loss: 0.166878
Epoch 46/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.386841, Val Loss: 0.149646

Train RMSE: 0.446245, Average Train Loss: 0.196856, Val RMSE: 0.386841, Average Val Loss: 0.149646
Epoch 47/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.747620, Val Loss: 0.558936

Train RMSE: 0.439622, Average Train Loss: 0.188054, Val RMSE: 0.747620, Average Val Loss: 0.558936
Epoch 48/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.600675, Val Loss: 0.360811

Train RMSE: 0.450221, Average Train Loss: 0.198015, Val RMSE: 0.600675, Average Val Loss: 0.360811
Epoch 49/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.531241, Val Loss: 0.282217

Train RMSE: 0.439084, Average Train Loss: 0.189045, Val RMSE: 0.531241, Average Val Loss: 0.282217
Epoch 50/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.482627, Val Loss: 0.232929

Train RMSE: 0.448964, Average Train Loss: 0.206186, Val RMSE: 0.482627, Average Val Loss: 0.232929
Epoch 51/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.430889, Val Loss: 0.185665

Train RMSE: 0.444211, Average Train Loss: 0.196483, Val RMSE: 0.430889, Average Val Loss: 0.185665
Epoch 52/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.490224, Val Loss: 0.240319

Train RMSE: 0.439190, Average Train Loss: 0.203600, Val RMSE: 0.490224, Average Val Loss: 0.240319
Epoch 53/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.490220, Val Loss: 0.240315

Train RMSE: 0.435932, Average Train Loss: 0.183051, Val RMSE: 0.490220, Average Val Loss: 0.240315
Epoch 54/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.430103, Val Loss: 0.184989

Train RMSE: 0.444296, Average Train Loss: 0.202130, Val RMSE: 0.430103, Average Val Loss: 0.184989
Epoch 55/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.433123, Val Loss: 0.187595

Train RMSE: 0.439792, Average Train Loss: 0.188758, Val RMSE: 0.433123, Average Val Loss: 0.187595
Epoch 56/100: 

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Val RMSE: 0.632785, Val Loss: 0.400417

Train RMSE: 0.436125, Average Train Loss: 0.211392, Val RMSE: 0.632785, Average Val Loss: 0.400417
Early stopping at epoch 56. Best Val RMSE: 0.315413 (epoch 6)


Traceback (most recent call last):
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/matplotlib/cbook.py", line 361, in process
    func(*args, **kwargs)
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_mpl_figure.py", line 812, in _buttonpress
    self._redraw(annotations=True)
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_mpl_figure.py", line 2146, in _redraw
    super()._redraw(update_data, annotations)
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_figure.py", line 463, in _redraw
    self._draw_annotations()
  File "/home/varun/repos/cerebro/.venv/lib/python3.12/site-packages/mne/viz/_mpl_figure.py", line 1408, in _draw_annotations
    segment_color = self.mne.annotation_segment_colors[descr]
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^
KeyError: np.str_('contrast_trial_start')
Traceback (most recent call last):
  File "/home/varun/repos/cerebro/.venv/lib

# %% [markdown]

 ## Save the model

In [None]:
# %%

# `Path` is already imported in the imports cell at the top of the notebook.
weights_dir = Path("weights")
weights_dir.mkdir(exist_ok=True)
# Save the checkpoint with the best validation RMSE tracked by the training
# loop. Fall back to the current weights if no best state was recorded
# (for example, the loop never ran).
state_to_save = best_state if best_state is not None else model.state_dict()
torch.save(state_to_save, weights_dir / "weights_challenge_1.pt")
print("Model saved as 'weights/weights_challenge_1.pt'")

Model saved as 'weights/weights_challenge_1.pt'
