In [3]:
import argparse
import random

import numpy as np
import torch

import wandb

from datasets import faced_dataset, seedv_dataset, physio_dataset, shu_dataset, isruc_dataset, chb_dataset, \
    speech_dataset, mumtaz_dataset, seedvig_dataset, stress_dataset, tuev_dataset, tuab_dataset, bciciv2a_dataset
from datasets import pearl_kfold_dataset as pearl_dataset
from finetune_trainer_kfold import Trainer
from models import model_for_faced, model_for_seedv, model_for_physio, model_for_shu, model_for_isruc, model_for_chb, \
    model_for_speech, model_for_mumtaz, model_for_seedvig, model_for_stress, model_for_tuev, model_for_tuab, \
    model_for_bciciv2a, model_for_pearl


PEARL_NUM_FOLDS = 5  # number of fold_<i> sub-directories expected under --datasets_dir for PEARL


def _str2bool(value):
    """argparse ``type=`` converter that parses a textual boolean.

    ``type=bool`` must not be used for flags: ``bool('False')`` is ``True``, so
    ``--frozen False`` would silently enable freezing.

    Args:
        value: the raw command-line string (or an actual bool).

    Returns:
        The parsed bool.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main(argv=None):
    """Parse arguments and fine-tune/evaluate on the PEARL downstream dataset.

    Each selected fold gets its own wandb run, a freshly seeded model and the
    data found in ``<datasets_dir>/fold_<i>``. The (acc, pr_auc, roc_auc) values
    returned by ``Trainer.train_for_binaryclass`` are averaged over the folds
    that were actually run, and the averages are printed.

    Args:
        argv: list of argument strings to parse, e.g. ``['--num_folds', '-1']``.
            ``None`` (the default) means "no arguments", NOT ``sys.argv``:
            inside Jupyter ``sys.argv`` holds the kernel's own flags, which
            argparse would reject.

    Raises:
        SystemExit: via ``parser.error`` on invalid arguments, an unsupported
            ``--downstream_dataset``, or an out-of-range ``--num_folds``.
    """
    if argv is None:
        argv = []

    parser = argparse.ArgumentParser(description='Big model downstream')
    parser.add_argument('--project_name', type=str, default='EEGFoundationModel', help='project_name')
    parser.add_argument('--modality_mode', type=str, default='multi_attend',
                        help='modality_mode (mono (only EEG), multi (EEG + other modalities))')
    parser.add_argument('--num_folds', type=int, default=0,
                        help='fold to test, -1 means all folds (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=3407, help='random seed (default: %(default)s)')
    parser.add_argument('--cuda', type=int, default=0, help='cuda number (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs (default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size for training (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=3e-4, help='learning rate (default: %(default)s)')
    parser.add_argument('--weight_decay', type=float, default=5e-2, help='weight decay (default: %(default)s)')
    parser.add_argument('--optimizer', type=str, default='AdamW', help='optimizer (AdamW, SGD)')
    parser.add_argument('--clip_value', type=float, default=1, help='clip_value')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--classifier', type=str, default='all_patch_reps',
                        help='[all_patch_reps, all_patch_reps_twolayer, '
                             'all_patch_reps_onelayer, avgpooling_patch_reps]')
    # all_patch_reps: use all patch features with a three-layer classifier;
    # all_patch_reps_twolayer: use all patch features with a two-layer classifier;
    # all_patch_reps_onelayer: use all patch features with a one-layer classifier;
    # avgpooling_patch_reps: use average pooling for patch features;

    # ############ Downstream dataset settings ############
    parser.add_argument('--downstream_dataset', type=str, default='PEARL',
                        help='[FACED, SEED-V, PhysioNet-MI, SHU-MI, ISRUC, CHB-MIT, BCIC2020-3, Mumtaz2016, '
                             'SEED-VIG, MentalArithmetic, TUEV, TUAB, BCIC-IV-2a, PEARL]')
    parser.add_argument('--datasets_dir', type=str,
                        default='/mnt/disk1/aiotlab/namth/EEGFoundationModel/datasets/pearl_30s',
                        help='datasets_dir')
    parser.add_argument('--num_of_classes', type=int, default=2, help='number of classes')
    parser.add_argument('--model_dir', type=str, default='./data/wjq/models_weights/Big/BigFaced', help='model_dir')
    # ############ Downstream dataset settings ############

    parser.add_argument('--num_workers', type=int, default=16, help='num_workers')
    parser.add_argument('--label_smoothing', type=float, default=0.1, help='label_smoothing')
    # Boolean flags use _str2bool so that e.g. "--frozen False" really means False.
    parser.add_argument('--multi_lr', type=_str2bool, default=False,
                        help='multi_lr')  # set different learning rates for different modules
    parser.add_argument('--frozen', type=_str2bool,
                        default=False, help='frozen')
    parser.add_argument('--use_pretrained_weights', type=_str2bool,
                        default=True, help='use_pretrained_weights')
    parser.add_argument('--foundation_dir', type=str,
                        default='pretrained_weights/pretrained_108.pth',
                        help='foundation_dir')

    params = parser.parse_args(args=argv)

    group = 'Reproduce'

    print(params)

    # Only the PEARL branch is implemented below; fail loudly instead of
    # silently doing nothing for the other listed datasets.
    if params.downstream_dataset != 'PEARL':
        parser.error(f"downstream_dataset {params.downstream_dataset!r} is not supported here; "
                     f"only 'PEARL' is implemented")
    if params.num_folds != -1 and not 0 <= params.num_folds < PEARL_NUM_FOLDS:
        parser.error(f'--num_folds must be -1 or in [0, {PEARL_NUM_FOLDS - 1}], got {params.num_folds}')

    base_datasets_dir = params.datasets_dir
    # NOTE(review): --cuda is not applied here (torch.cuda.set_device was
    # commented out), so the model goes to the default CUDA device via .cuda().
    # Confirm whether Trainer reads params.cuda.

    if params.num_folds == -1:
        folds = list(range(PEARL_NUM_FOLDS))
    else:
        folds = [params.num_folds]

    fold_metrics = []
    for fold in folds:
        # Set before wandb.init so the logged config records the fold's actual
        # data directory; presumably LoadDataset reads params.datasets_dir.
        params.datasets_dir = f'{base_datasets_dir}/fold_{fold}'
        run = wandb.init(project=params.project_name,
                         group=group,
                         name=f'data_{params.downstream_dataset}_seed_{params.seed}_fold_{fold}',
                         config=vars(params))
        try:
            setup_seed(params.seed)
            load_dataset = pearl_dataset.LoadDataset(params)
            data_loader = load_dataset.get_data_loader()
            model = model_for_pearl.Model(params).cuda()
            trainer = Trainer(params, data_loader, model)
            acc_, pr_auc_, roc_auc_ = trainer.train_for_binaryclass()
        except BaseException:
            # Mark the run as failed instead of leaving it active.
            run.finish(exit_code=1)
            raise
        run.finish()
        fold_metrics.append((acc_, pr_auc_, roc_auc_))

    # Mean of each metric over the folds that actually ran (1 fold or all).
    acc, pr_auc, roc_auc = (sum(column) / len(fold_metrics) for column in zip(*fold_metrics))
    print(f'Cross-validation results averaged over fold(s) {folds}: '
          f'acc {acc}, pr_auc {pr_auc}, roc_auc {roc_auc}')

    print('Done!!!!!')


def setup_seed(seed):
    """Seed every RNG used in training so that a run can be reproduced.

    Seeds torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module, and configures cuDNN for deterministic behavior.

    Args:
        seed: integer seed shared by all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone is not enough: with benchmark=True cuDNN
    # autotunes and may pick a different algorithm on each run.
    torch.backends.cudnn.benchmark = False

In [4]:

# Authenticate without embedding a secret in the notebook: with no key,
# wandb.login() uses WANDB_API_KEY or stored netrc credentials and only
# prompts if neither exists. The API key that used to be hard-coded here is
# exposed in version history and should be revoked.
wandb.login()
main()
# main() finishes each fold's run itself; this is only a safety net in case a
# run is still active.
wandb.finish()


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/aiotlab/.netrc


Namespace(project_name='EEGFoundationModel', modality_mode='multi_attend', num_folds=0, seed=3407, cuda=0, epochs=20, batch_size=32, lr=0.0003, weight_decay=0.05, optimizer='AdamW', clip_value=1, dropout=0.1, classifier='all_patch_reps', downstream_dataset='PEARL', datasets_dir='/mnt/disk1/aiotlab/namth/EEGFoundationModel/datasets/pearl_30s', num_of_classes=2, model_dir='./data/wjq/models_weights/Big/BigFaced', num_workers=16, label_smoothing=0.1, multi_lr=False, frozen=False, use_pretrained_weights=True, foundation_dir='pretrained_weights/pretrained_108.pth')


503 114 135
752
Model(
  (backbone): CBraMod(
    (patch_embedding): PatchEmbedding(
      (positional_encoding): Sequential(
        (0): Conv2d(200, 200, kernel_size=(19, 7), stride=(1, 1), padding=(9, 3), groups=200)
      )
      (proj_in): Sequential(
        (0): Conv2d(1, 25, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24))
        (1): GroupNorm(5, 25, eps=1e-05, affine=True)
        (2): GELU(approximate='none')
        (3): Conv2d(25, 25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
        (4): GroupNorm(5, 25, eps=1e-05, affine=True)
        (5): GELU(approximate='none')
        (6): Conv2d(25, 25, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
        (7): GroupNorm(5, 25, eps=1e-05, affine=True)
        (8): GELU(approximate='none')
      )
      (spectral_proj): Sequential(
        (0): Linear(in_features=101, out_features=200, bias=True)
        (1): Dropout(p=0.1, inplace=False)
      )
    )
    (encoder): TransformerEncoder(
      (layers): ModuleList(

100%|██████████| 16/16 [00:04<00:00,  3.39it/s]


Epoch 1 : Training Loss: 0.69835, LR: 0.00030, Time elapsed 0.08 mins


100%|██████████| 4/4 [00:00<00:00, 16.07it/s]


val Evaluation: acc: 0.50000, pr_auc: 0.83481, roc_auc: 0.83190
[[56  0]
 [58  0]]
roc_auc increasing....saving weights !! 
Val Evaluation: acc: 0.50000, pr_auc: 0.83481, roc_auc: 0.83190


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch 2 : Training Loss: 0.69649, LR: 0.00029, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 18.22it/s]


val Evaluation: acc: 0.50000, pr_auc: 0.79766, roc_auc: 0.77371
[[56  0]
 [58  0]]


100%|██████████| 16/16 [00:03<00:00,  5.02it/s]


Epoch 3 : Training Loss: 0.58607, LR: 0.00028, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.60it/s]


val Evaluation: acc: 0.66964, pr_auc: 0.69311, roc_auc: 0.62962
[[19 37]
 [ 0 58]]


100%|██████████| 16/16 [00:03<00:00,  5.01it/s]


Epoch 4 : Training Loss: 0.17661, LR: 0.00027, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.47it/s]


val Evaluation: acc: 0.83036, pr_auc: 1.00000, roc_auc: 1.00000
[[37 19]
 [ 0 58]]
roc_auc increasing....saving weights !! 
Val Evaluation: acc: 0.83036, pr_auc: 1.00000, roc_auc: 1.00000


100%|██████████| 16/16 [00:03<00:00,  5.01it/s]


Epoch 5 : Training Loss: 0.15887, LR: 0.00026, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.51it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.79645, roc_auc: 0.79464
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.99it/s]


Epoch 6 : Training Loss: 0.12568, LR: 0.00024, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.48it/s]


val Evaluation: acc: 0.83036, pr_auc: 0.90346, roc_auc: 0.88608
[[37 19]
 [ 0 58]]


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch 7 : Training Loss: 0.06651, LR: 0.00022, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.09it/s]


val Evaluation: acc: 0.83036, pr_auc: 0.95451, roc_auc: 0.95166
[[37 19]
 [ 0 58]]


100%|██████████| 16/16 [00:03<00:00,  4.99it/s]


Epoch 8 : Training Loss: 0.03116, LR: 0.00020, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.50it/s]


val Evaluation: acc: 0.48553, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [40 18]]


100%|██████████| 16/16 [00:03<00:00,  5.00it/s]


Epoch 9 : Training Loss: 0.02336, LR: 0.00017, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.49it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.83402, roc_auc: 0.82266
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.98it/s]


Epoch 10 : Training Loss: 0.02316, LR: 0.00015, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.51it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.90060, roc_auc: 0.88239
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.98it/s]


Epoch 11 : Training Loss: 0.01542, LR: 0.00013, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.47it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.98it/s]


Epoch 12 : Training Loss: 0.00029, LR: 0.00010, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.47it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.97it/s]


Epoch 13 : Training Loss: 0.00111, LR: 0.00008, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.42it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.96it/s]


Epoch 14 : Training Loss: 0.00026, LR: 0.00006, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.34it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.96it/s]


Epoch 15 : Training Loss: 0.00022, LR: 0.00004, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.44it/s]


val Evaluation: acc: 0.63208, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [23 35]]


100%|██████████| 16/16 [00:03<00:00,  4.95it/s]


Epoch 16 : Training Loss: 0.00195, LR: 0.00003, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.36it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.95it/s]


Epoch 17 : Training Loss: 0.00010, LR: 0.00002, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.18it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.93it/s]


Epoch 18 : Training Loss: 0.00603, LR: 0.00001, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.39it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.93it/s]


Epoch 19 : Training Loss: 0.00009, LR: 0.00000, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.25it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]


100%|██████████| 16/16 [00:03<00:00,  4.92it/s]


Epoch 20 : Training Loss: 0.00564, LR: 0.00000, Time elapsed 0.05 mins


100%|██████████| 4/4 [00:00<00:00, 17.27it/s]


val Evaluation: acc: 0.65794, pr_auc: 0.75990, roc_auc: 0.76601
[[37 19]
 [20 38]]
***************************Test************************


100%|██████████| 5/5 [00:00<00:00, 17.44it/s]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


***************************Test results************************
Test Evaluation: acc: 0.74667, pr_auc: 0.77444, roc_auc: 0.79200
[[37 38]
 [ 0 60]]
model save in ./data/wjq/models_weights/Big/BigFaced/epoch20_acc_0.74667_pr_0.77444_roc_0.79200.pth


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,███▇▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁
test/acc,▁
test/pr_auc,▁
test/roc_auc,▁
time_min,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,██▇▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,20.0
lr,0.0
test/acc,0.74667
test/pr_auc,0.77444
test/roc_auc,0.792
time_min,0.05424
train/loss,0.00564


5-fold cross validation results: acc 0.7466666666666667, pr_auc 0.7744440317647967, roc_auc 0.792
Done!!!!!
