In [1]:
import torch
import numpy as np
import os
import timm
import yaml
import gc
import pandas as pd
import pytorch_lightning as pl


In [2]:
import sys
sys.path.append('../')
sys.path.append('../hms_pipeline')

from hms_pipeline.trainer import train_model

In [3]:
print(torch.cuda.is_available())

True


In [4]:
CONFIG_PATH = "../configs/hms-configs.yaml"

In [5]:
CONFIG = dict(
    trainer=dict(
        max_epochs=12,
        min_epochs=5,
        enable_progress_bar=True,
        devices=1,
        deterministic=False,
        precision="16",
    ),

    seed =42,
    
    patience=3,
    train_bs=8,
    valid_bs=8,
    workers=8,
    wandb_project="hms_v6_2",
    output_dir="../models-hmsv6-2/hmsv6-convnexts-3imgs-w020-kl1-lrf",
    fold_dir =  "/home/maxc/workspace/kaggle-hms/folds",
    train_folds=[0],
    # train_folds = [1],
    # train_folds = [1, 2, 3, 4],
    # train_folds = [0, 1],
    # train_folds = [2, 3, 4],
    # train_folds = [0, 1, 2, 3, 4],
    # train_folds = [2],

    ##############################
    # Dataset parameters
    ##############################
    # train_dataset_group="eeg_id",
    train_dataset_group="unique_label",
    val_dataset_group="eeg_id",

    raw_eeg_dir = "/home/maxc/workspace/kaggle-hms/data/v6/raw_eegs_sosbandclip_2500",
    eeg_spec_dir = "/home/maxc/workspace/kaggle-hms/data/v6/eeg_specs_h100w250gf5fft1024wl200lc05_norm_const",
    long_spec_dir= "/home/maxc/workspace/kaggle-hms/data/v6/long_specs",
    

    l7_weight = 0.2, # extra inverse weight for votes less than 7


    inverse_kl_weight=True,
    # inverse_kl_weight=False,
    # inverse_kl_mean = [ 0.174031, 0.112700,0.090854, 0.071484,0.136408, 0.414523],
    inverse_kl_mean = [1/6, 1/6, 1/6, 1/6, 1/6, 1/6],
    kl_multiplier = 1.0,
    vote_weight = False,


    # Options to control the presence of each image type
    # img_types=["raw_eeg", "eeg_spec", "long_spec"],    
    # img_types=["full_raw_eeg", "raw_eeg", "long_spec"],
    # img_types = ["long_spec", "raw_eeg"]
    img_types = ["long_spec", "eeg_spec", "full_raw_eeg", "raw_eeg"],


    # The order of each area in the image
    signals=["LL", "RL", "LP", "RP"],
    lrflip_signals=["RL", "LL", "RP", "LP"],

    # # new order
    # signals = ["LL", "LP", "RL", "RP"],
    # lrflip_signals = ["RL", "RP", "LL", "LP"],

    # # old order
    # signals = ["LL", "LP", "RP", "RL"],
    # lrflip_signals = ["RL", "RP", "LP", "LL"],

  
    img_size=[512, 512],

    long_spec_ratio = 1.0, # The percentage of long spectrogram of 50s in the image
    raw_full_gap = 16, # gap between full raw eeg and raw eeg
    center_timespan = 10, # timespan of center on eeg spectrogram

    sub_img_size={
        "eeg_spec": [288, 200],
        "long_spec": [288, 300],
        "raw_eeg": [128, 500],
        "full_raw_eeg": [64, 500],
    },


    # img_size = [448, 448],
    # sub_img_size={
    #         "eeg_spec": [256, 224],
    #         "long_spec": [256, 224],
    #         "raw_eeg": [128, 448],
    #         "full_raw_eeg": [64, 448],
    # },
   
    # Options to control the vertical flip of sub images
    sub_img_vflips={
        "eeg_spec": [False, False, False, False],
        # "eeg_spec": [True, False, True, False],
        "long_spec": [False, False, False, False],
        # "long_spec": [False, True, False, False],
        "raw_eeg": [False, False, False, False],
        "raw_eeg": [False, False, True, True],
    },

    train_aug_probs={
        "lrflip_prob": 0.5, # left right flip
        "fbflip_prob": 0.5,  # front back flip
        "mask_prob": 0.5, # probablity of masking
        "keep_center_ratio": 0.2, # ratio of keeping center on eeg spectrogram
        "hflip_prob": 0.5, # probablity of horizontal flip on spectrogram
        "blur_prob":0.0, # probablity of blurring spectrogram
        "roll_prob": 0.5, # probablity of rolling raw eeg
        "neg_eeg_prob": 0.5, # probablity of negative raw eeg
        "contrast_prob": 0.5, # probablity of contrast adjustment on raw eeg
        "fuse_prob": 0.0, # probablity of fusing spectrogram
        "block_prob": 0.5, # probablity of blocking raw eeg channels
        "noise_prob": 0.0,  # probablity of adding noise to entire image
        "mask_iter": 3, # iteration of masking
        "mask_size_ratio": 0.2, # h/w ratio of each mask size
        "num_block_ch": 4, # number of blocked channels
        "dummy_votes_prob": 0.0, # probablity of adding dummy votes
        "num_dummy_votes": 1, # number of dummy votes
    },  # sofar best aug

    val_aug_probs={
        "lrflip_prob": 0.0,
        "fbflip_prob": 0.0,
        "mask_prob": 0.0,
        "keep_center_ratio": 0.0,
        "hflip_prob": 0.0,
        "blur_prob": 0.0,
        "roll_prob": 0.0,
        "neg_eeg_prob": 0.0,
        "contrast_prob": 0.0,
        "fuse_prob": 0.0,
        "block_prob": 0.0,
        "noise_prob": 0.0,
        "mask_iter": 5,
        "mask_size_ratio": 0.1,
        "num_block_ch": 4,
        "dummy_votes_prob": 0.0,
        "num_dummy_votes": 1,
    },

    ##############################
    # Model parameters
    ##############################
    # model_type = "SpecVitModel",
    model_type="SpecModel",

    dropout=0.5,
    # global_pool = ["max", "avg"],
    global_pool=["avg"],
    # global_pool=["gem"],
    hidden_size = 8,

    # spec_backbone = "efficientnet_b0",
    # spec_backbone=  "efficientnet_b2",
    # spec_backbone = "maxxvitv2_nano_rw_256",
    # spec_backbone="convnext_tiny.fb_in22k",
    # spec_backbone = "convnextv2_tiny.fcmae_ft_in22k_in1k",
    # spec_backbone = "convnext_base.fb_in22k_ft_in1k",
    # spec_backbone="tf_efficientnetv2_m.in21k_ft_in1k",
    spec_backbone= "convnext_small.fb_in22k",
    # spec_backbone = "maxvit_tiny_tf_512",
    # spec_backbone = "maxvit_small_tf_512",
    # spec_backbone = "maxxvitv2_nano_rw_256.sw_in1k",
    # spec_backbone = "maxvit_small_tf_512",
    # spec_backbone = "tf_efficientnetv2_s.in21k",
    # spec_backbone = "tf_efficientnetv2_m.in21k",
    # spec_backbone = "swinv2_tiny_window8_256.ms_in1k",
    
    ## SpecVitModel Params
    vit_model="vit_small_patch16_224",
    # vit_model = "vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    # vit_model = "vit_base_patch16_224.augreg2_in21k_ft_in1k",
    # vit_model="swinv2_small_window16_256.ms_in1k",
    # vit_model= "swinv2_tiny_window16_256.ms_in1k",
    # vit_model = "swinv2_tiny_window8_256.ms_in1k",
    # vit_model = "maxvit_tiny_tf_224.in1k",
    # vit_model = "maxvit_tiny_tf_224",
    # vit_model = "efficientnet_b0",
    feature_layer=[-1],
    # global_pool = "avg",
    # global_pool = "max",
    # global_pool = "avgmax",
    optimizer_params=dict(
        lr = 1e-4,
        # weight_decay = 0,
    ),
    use_ema = True,
    ema_decay = 0.999,
    scheduler=dict(
        name="CosineAnnealingLR",
        params=dict(
            CosineAnnealingLR=dict(T_max=10, eta_min=1e-6),
        ),
    ),
)

with open(CONFIG_PATH, "w") as f:
    yaml.dump(CONFIG, f, sort_keys=False)

In [6]:
torch.set_float32_matmul_precision("medium")

config_path = CONFIG_PATH

ckpt_path = None

train_model(config_path, seed=CONFIG["seed"] ,ckpt_path=ckpt_path)

Seed set to 42


{'trainer': {'max_epochs': 12, 'min_epochs': 5, 'enable_progress_bar': True, 'devices': 1, 'deterministic': False, 'precision': '16'}, 'seed': 42, 'patience': 3, 'train_bs': 8, 'valid_bs': 8, 'workers': 8, 'wandb_project': 'hms_v6_2', 'output_dir': '../models-hmsv6-2/hmsv6-convnexts-3imgs-w020-kl1-lrf', 'fold_dir': '/home/maxc/workspace/kaggle-hms/folds', 'train_folds': [0], 'train_dataset_group': 'unique_label', 'val_dataset_group': 'eeg_id', 'raw_eeg_dir': '/home/maxc/workspace/kaggle-hms/data/v6/raw_eegs_sosbandclip_2500', 'eeg_spec_dir': '/home/maxc/workspace/kaggle-hms/data/v6/eeg_specs_h100w250gf5fft1024wl200lc05_norm_const', 'long_spec_dir': '/home/maxc/workspace/kaggle-hms/data/v6/long_specs', 'l7_weight': 0.2, 'inverse_kl_weight': True, 'inverse_kl_mean': [0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666], 'kl_multiplier': 1.0, 'vote_weight': False, 'img_types': ['long_spec', 'eeg_spec', 'full_raw_eeg'

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmaxc303[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/maxc/miniconda3/envs/hms/lib/python3.10/site-packages/lightning_fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Using SpecModel
Using EMA model with decay 0.999


/home/maxc/miniconda3/envs/hms/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/maxc/workspace/kaggle-hms/models-hmsv6-2/hmsv6-convnexts-3imgs-w020-kl1-lrf exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type          | Params
--------------------------------------------
0 | model     | SpecModel     | 49.5 M
1 | kl_loss   | KLDivLoss     | 0     
2 | ema_model | AveragedModel | 49.5 M
--------------------------------------------
98.9 M    Trainable params
0         Non-trainable params
98.9 M    Total params
395.662   Total estimated model params size (MB)


torch.Size([1, 6])


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/maxc/miniconda3/envs/hms/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


(106800, 7) (16, 7) (16, 7)
(106800, 7) (6, 7) (6, 7)
Score:  1.6747184724475055 G10 Score:  1.1537897069518372


Training: |          | 0/? [00:00<?, ?it/s]

/home/maxc/miniconda3/envs/hms/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


VBox(children=(Label(value='0.593 MB of 0.610 MB uploaded\r'), FloatProgress(value=0.9724268124401271, max=1.0…