# Model Size Checker

This notebook builds the various models in the standard EEG classification environment, measures their GPU usage, and checks an appropriate minibatch size for each model.

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import glob
import hydra
from omegaconf import OmegaConf
import pprint
import torch
from torchsummaryX import summary

# custom package
from run_train import check_device_env
from datasets.caueeg_script import build_dataset_for_train
from models.utils import count_parameters

In [9]:
# (free_bytes, total_bytes) for the current CUDA device, as returned by torch.cuda.mem_get_info.
gpu_mem = torch.cuda.mem_get_info()
pprint.pprint(gpu_mem)

(24470421504, 25769148416)


---

## Default settings and all models

In [None]:
# Shared Hydra overrides used for every model below:
#   train         -> `train=<name>` config group
#   data_cfg_file -> `data=<name>` config group
#   device        -> value passed as `+train.device`
train, data_cfg_file, device = 'no_wandb', 'task2', 'cuda:0'

In [None]:
def list_model_names(model_dir: str = './config/model') -> list:
    """Return the sorted model-config names found in ``model_dir``.

    Each ``*.yaml`` file contributes its file stem, which is the value used as the
    ``model=<name>`` Hydra override later on. Names containing 'base'
    (case-insensitive) are skipped; presumably these are shared base configs
    rather than standalone models.

    Sorting makes the order independent of the filesystem's glob order, so the
    model list and every result keyed by it are reproducible across machines.
    """
    stems = (os.path.splitext(os.path.basename(path))[0]
             for path in glob.glob(os.path.join(model_dir, '*.yaml')))
    return sorted(name for name in stems if 'base' not in name.lower())


model_names = list_model_names()
pprint.pprint(model_names)

---

## Checking Models

In [None]:
# Build every model listed in `model_names` with the same data/train settings and
# record its parameter count (via count_parameters) in `model_size_dict`.
SHOW_SUMMARY = False  # set True to also print a torchsummaryX summary per model

model_size_dict = {}

for model_name in model_names:
    with hydra.initialize(config_path='../config'):
        add_configs = [f"data={data_cfg_file}",
                       "data.input_norm=no",
                       "+data.age_mean=1.0",
                       "+data.age_std=1.0",
                       f"train={train}",
                       f"+train.device={device}",
                       f"model={model_name}"]

        cfg = hydra.compose(config_name='default', overrides=add_configs)

    # Flatten the data/train/model groups into one plain dict; on duplicate keys
    # the later group (model, then train) wins.
    config = {**OmegaConf.to_container(cfg.data),
              **OmegaConf.to_container(cfg.train),
              **OmegaConf.to_container(cfg.model)}

    check_device_env(config)

    # The loaders are not needed for counting parameters. build_dataset_for_train is
    # still called, presumably because it fills in data-dependent config entries
    # that model instantiation needs (TODO confirm before removing).
    train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config)

    # Drop the previous iteration's model before allocating the next one, so two
    # models are never held on the device at once and cached blocks are returned.
    model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = hydra.utils.instantiate(config).to(config['device'])
    model_size_dict[model_name] = count_parameters(model)

    if SHOW_SUMMARY:
        print('\n\n')
        x = config['preprocess_train'](next(iter(train_loader)))
        summary(model, x['signal'], x['age'])

pprint.pprint(model_size_dict)