# Train Networks

- Train SoftMax or Multi-BCE classifier for the EEG diagnosis classification
    - CAUEEG-task1 benchmark: Classification of **Normal**, **MCI**, and **Dementia** symptoms
    - CAUEEG-task2 benchmark: Classification of **Normal** and **Abnormal** symptoms
- PyTorch DDP is used for multi-node/GPU environment
- `Weights and Biases` sweep is used for hyperparameter search

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Load some packages
import os
import json
from copy import deepcopy
import gc
import time

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import pprint
import wandb

# custom package
from datasets.caueeg_dataset import *
from datasets.caueeg_script import *
from models import *
from train import *

In [3]:
# Report the framework version, then pick the compute device for the whole notebook.
print('PyTorch version:', torch.__version__)

if torch.cuda.is_available():
    # A CUDA device is visible: use it and report the name of device 0.
    device = torch.device('cuda')
    print('cuda is available:', torch.cuda.get_device_name(0))
else:
    # No CUDA device: fall back to the CPU.
    device = torch.device('cpu')
    print('cuda is unavailable.')

PyTorch version: 1.11.0+cu113
cuda is available: NVIDIA GeForce RTX 3090


-----

## Set the default configuration for building the dataset

In [4]:
# Default options handed to build_dataset_for_train().
# Lengths are expressed in samples; the "10 seconds" notes imply 200 samples per second.
cfg_data = {
    'device': device,
    'task': 'task1',
    'dataset_path': r'local/dataset/02_Curated_Data_220419/',
    'file_format': 'memmap',      # 'feather', 'memmap'
    'load_event': False,
    'latency': 200 * 10,          # 10 seconds
    'seq_length': 200 * 10,       # 10 seconds
    'crop_multiple': 4,
    'test_crop_multiple': 8,
    'EKG': 'O',
    'photic': 'X',
    'input_norm': 'dataset',      # 'dataset', 'datapoint', 'no'
    'awgn': 5e-2,
    'mgn': 1e-4,
    'awgn_age': 5e-2,
}

# Choose the minibatch size from the detected GPU model.
# The CUDA query is guarded so this cell also runs on CPU-only machines
# (torch.cuda.get_device_name raises when no CUDA device is present).
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''

if '3090' in gpu_name:
    cfg_data['minibatch'] = 256
elif '2080' in gpu_name:
    cfg_data['minibatch'] = 128
elif '1070' in gpu_name:
    cfg_data['minibatch'] = 96
else:
    # Unlisted GPU or CPU: make sure the key exists so later cells that read
    # cfg_data['minibatch'] do not fail with a KeyError. An already-present
    # value is left untouched.
    cfg_data.setdefault('minibatch', 64)

In [5]:
# Build the dataset once on a throwaway copy of the config so that cfg_data itself
# stays free of the entries the builder adds (preprocessing steps, label maps, ...).
cfg_data_temp = deepcopy(cfg_data)
loaders_temp = build_dataset_for_train(cfg_data_temp, verbose=True)

# Run one batch of the first returned loader through the training preprocessing
# to read off the number of input channels; the class count comes from the label map.
batch_temp = next(iter(loaders_temp[0]))
batch_temp = cfg_data_temp['preprocess_train'](batch_temp)
in_channels = batch_temp['signal'].shape[1]
out_dims = len(cfg_data_temp['class_label_to_name'])

# Drop the temporaries; only cfg_data_temp, in_channels and out_dims are needed later.
del loaders_temp, batch_temp

transform: Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=4, latency=2000, return_timing=False)
    EegDropChannels(drop_index=[20])
    EegToTensor()
)

----------------------------------------------------------------------------------------------------

transform_multicrop: Compose(
    EegRandomCrop(crop_length=2000, length_limit=10000000, multiple=8, latency=2000, return_timing=False)
    EegDropChannels(drop_index=[20])
    EegToTensor()
)

----------------------------------------------------------------------------------------------------


task config:
{'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'task_description': 'Classification of [Normal], [MCI], and [Dementia] '
                     'symptoms.',
 'task_name': 'CAUEEG-task1 benchmark'}

 ---------------------------------------------------------------------------------------------------- 

train_dataset[0].keys():
dict_ke

In [6]:
# Show the temporary config after build_dataset_for_train() has run on it,
# including the entries that were not part of the cfg_data defaults.
pprint.pprint(cfg_data_temp, width=250)

{'EKG': 'O',
 'age_mean': tensor([70.7924]),
 'age_std': tensor([9.7796]),
 'awgn': 0.05,
 'awgn_age': 0.05,
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'crop_multiple': 4,
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'device': device(type='cuda'),
 'file_format': 'memmap',
 'input_norm': 'dataset',
 'latency': 2000,
 'load_event': False,
 'mgn': 0.0001,
 'minibatch': 256,
 'photic': 'X',
 'preprocess_test': Sequential(
  (0): EegToDevice(device=cuda)
  (1): EegNormalizeAge(mean=tensor([70.7924]),std=tensor([9.7796]),eps=1e-08)
  (2): EegNormalizeMeanStd(mean=tensor([-0.0131, -0.0537, -0.0270, -0.0130,  0.0671,  0.1447,  0.0241, -0.0065,
           0.0358,  0.0454,  0.0042, -0.0076,  0.0100, -0.0491,  0.1222, -0.0020,
          -0.0455, -0.0090, -0.0075, -0.0007]),std=tensor([44.3636, 19.8341, 11.1687, 11.5708, 15.1467, 46.4805, 19.3917, 10.3994,
          11.396

-----

## Define Network Models

In [7]:
# STFT settings derived from the default sequence length (seq_len_2d is used by the 2D model configs).
n_fft, hop_length, seq_len_2d = calculate_stft_params(seq_length=cfg_data['seq_length'])

# Keyword arguments shared by every model instantiated below for inspection.
cfg_temp_model = dict(
    in_channels=in_channels,
    out_dims=out_dims,
    seq_length=cfg_data['seq_length'],
    stft_params=dict(
        n_fft=n_fft,
        hop_length=hop_length,
        complex_mode='as_real',  # 'as_real', 'power', 'remove'
    ),
)

# Model configs appended by the following cells; the sweep picks one by index.
cfg_model_pool = []

### 1D Tiny CNN

In [8]:
# cfg_model = {}
# cfg_model['model'] = '1D-Tiny-CNN'
# cfg_model['generator'] = TinyCNN1D
# cfg_model['fc_stages'] = 1
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'average'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model, **cfg_temp_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

### M5 model

In [9]:
# cfg_model = {}
# cfg_model['model'] = '1D-M5'
# cfg_model['generator'] = M5
# cfg_model['fc_stages'] = 1
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'average'
# cfg_model['base_channels'] = 256
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model, **cfg_temp_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

### 1D VGG model

In [10]:
# Configuration of the '1D-VGG-19' entry of the model pool.
cfg_model = {
    'model': '1D-VGG-19',
    'generator': VGG1D,
    'fc_stages': 3,
    'batch_norm': True,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

'Model config:'
{'LR': None,
 'activation': 'relu',
 'base_channels': 64,
 'batch_norm': True,
 'fc_stages': 3,
 'final_pool': 'average',
 'generator': <class 'models.vgg_1d.VGG1D'>,
 'model': '1D-VGG-19',
 'use_age': 'fc'}

----------------------------------------------------------------------------------------------------



TypeError: __init__() missing 3 required positional arguments: 'in_channels', 'out_dims', and 'seq_length'

### 1D ResNet variants

In [None]:
# Configuration of the '1D-ResNet-18' entry of the model pool.
cfg_model = {
    'model': '1D-ResNet-18',
    'generator': ResNet1D,
    'block': BasicBlock1D,
    'conv_layers': [2, 2, 2, 2],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

In [None]:
# Configuration of the '1D-ResNet-50' entry of the model pool.
cfg_model = {
    'model': '1D-ResNet-50',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [3, 4, 6, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

In [None]:
# cfg_model = {}
# cfg_model['model'] = '1D-ResNet-101'
# cfg_model['generator'] = ResNet1D
# cfg_model['block'] = BottleneckBlock1D
# cfg_model['conv_layers'] = [3, 4, 23, 3]
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'average'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')
    
# model = cfg_model['generator'](**cfg_model, **cfg_temp_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

### 1D ResNeXt variants

In [None]:
# Configuration of the '1D-ResNeXt-50' entry of the model pool
# (same generator/block as 1D-ResNet-50, plus the 'groups' / 'width_per_group' settings).
cfg_model = {
    'model': '1D-ResNeXt-50',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [3, 4, 6, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'groups': 32,
    'width_per_group': 4,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

In [None]:
# Configuration of the '1D-ResNeXt-101' entry of the model pool.
cfg_model = {
    'model': '1D-ResNeXt-101',
    'generator': ResNet1D,
    'block': BottleneckBlock1D,
    'conv_layers': [3, 4, 23, 3],
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'groups': 32,
    'width_per_group': 8,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

### 2D VGG

In [None]:
# Configuration of the '2D-VGG-19' entry of the model pool
# (receives seq_len_2d computed by calculate_stft_params).
cfg_model = {
    'model': '2D-VGG-19',
    'generator': VGG2D,
    'seq_len_2d': seq_len_2d,
    'fc_stages': 3,
    'batch_norm': True,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

### 2D ResNet variants

In [None]:
# cfg_model = {}
# cfg_model['model'] = '2D-ResNet-18'
# cfg_model['generator'] = ResNet2D
# cfg_model['block'] = BasicBlock2D
# cfg_model['conv_layers'] = [2, 2, 2, 2]
# cfg_model['seq_len_2d'] = seq_len_2d
# cfg_model['fc_stages'] = 3
# cfg_model['use_age'] = 'fc'
# cfg_model['final_pool'] = 'average'
# cfg_model['base_channels'] = 64
# cfg_model['LR'] = None
# cfg_model['activation'] = 'relu'

# pprint.pprint('Model config:')
# pprint.pprint(cfg_model)
# print('\n' + '-' * 100 + '\n')

# model = cfg_model['generator'](**cfg_model, **cfg_temp_model).to(device, dtype=torch.float32)
# print(model)
# print('\n' + '-' * 100 + '\n')

# del model
# cfg_model_pool.append(cfg_model)

In [None]:
# Configuration of the '2D-ResNet-50' entry of the model pool.
cfg_model = {
    'model': '2D-ResNet-50',
    'generator': ResNet2D,
    'block': Bottleneck2D,
    'conv_layers': [3, 4, 6, 3],
    'seq_len_2d': seq_len_2d,
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

### 2D ResNeXt variants

In [None]:
# Configuration of the '2D-ResNeXt-50' entry of the model pool.
cfg_model = {
    'model': '2D-ResNeXt-50',
    'generator': ResNet2D,
    'block': Bottleneck2D,
    'conv_layers': [3, 4, 6, 3],
    'seq_len_2d': seq_len_2d,
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'groups': 32,
    'width_per_group': 4,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

In [None]:
# Configuration of the '2D-ResNeXt-101' entry of the model pool.
cfg_model = {
    'model': '2D-ResNeXt-101',
    'generator': ResNet2D,
    'block': Bottleneck2D,
    'conv_layers': [3, 4, 23, 3],
    'seq_len_2d': seq_len_2d,
    'fc_stages': 3,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 64,
    'groups': 32,
    'width_per_group': 8,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

### CNN-Transformer

In [None]:
# Configuration of the '1D-CNN-Transformer' entry of the model pool.
cfg_model = {
    'model': '1D-CNN-Transformer',
    'generator': CNNTransformer,
    'seq_len_2d': seq_len_2d,
    'fc_stages': 2,
    'use_age': 'fc',
    'final_pool': 'average',
    'base_channels': 256,
    'n_encoders': 8,
    'n_heads': 8,
    'dropout': 0.2,
    'LR': None,
    'activation': 'relu',
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

### Vision Transformer (ViT-B-16)

In [None]:
# Configuration of the '2D-ViT-B-16' entry of the model pool.
cfg_model = {
    'model': '2D-ViT-B-16',
    'generator': vit_b_16,
    'seq_len_2d': seq_len_2d,
    'fc_stages': 2,
    'use_age': 'conv',
    'dropout': 0.1,
    'attention_dropout': 0.1,
    'LR': None,
    'minibatch': 64,  # ViT requires enormous memory to train
}

pprint.pprint('Model config:')
pprint.pprint(cfg_model)
print('\n' + '-' * 100 + '\n')

# Instantiate the network once to print its layout; only the config is kept in the pool.
model = cfg_model['generator'](**cfg_model, **cfg_temp_model)
model = model.to(device, dtype=torch.float32)
print(model)
print('\n' + '-' * 100 + '\n')

cfg_model_pool.append(cfg_model)
del model

#### Summarize the loaded models

In [None]:
# Print every registered model config, separated by a dashed rule.
# Note: the loop rebinds the notebook-level name `cfg_model` to each pool entry in turn.
for cfg_model in cfg_model_pool:
    pprint.pp(cfg_model, width=150)
    print('\n' + '-' * 100 + '\n')

-----

## Default Configurations for Training

In [None]:
# Default training options; sweep parameters with the same key take precedence at run time.
# The iteration count is scaled so that iterations * minibatch equals 300000 * 64.
cfg_train = {'iterations': (300000 * 64) // cfg_data['minibatch']}

# Warm up for 5% of the run, but never for fewer than 3000 steps.
cfg_train['warmup_steps'] = max(round(cfg_train['iterations'] * 0.05), 3000)

cfg_train.update({
    'num_history': 500,
    'lr_scheduler_type': 'constant_with_decay',
    'weight_decay': 1e-2,
    'mixup': 0.0,                  # 0.0 for no usage
    'criterion': 'cross-entropy',  # 'cross-entropy', 'multi-bce'
    'device': device,
    'save_model': True,
    'draw_result': True,
    'watch_model': True,
})

-----

## Train

In [None]:
# Sweep search space for the dataset options.
sweep_data = {
    'seq_length': {
        'values': [
            # 200 *  5, #  5 sec
            # 200 * 10, # 10 sec
            200 * 20,  # 20 sec
            # 200 * 30, # 30 sec
        ],
    },
    'EKG': {'values': ['O', 'X']},
    'photic': {'values': ['O', 'X']},
    'awgn': {'distribution': 'uniform', 'min': 0, 'max': 0.3},
    'mgn': {'distribution': 'uniform', 'min': 0, 'max': 0.1},
    'awgn_age': {'distribution': 'uniform', 'min': 0, 'max': 0.3},
}

In [None]:
# Sweep search space for the model options.
sweep_model = {
    # one index per config registered in cfg_model_pool
    'model_index': {'values': list(range(len(cfg_model_pool)))},
    'fc_stages': {'distribution': 'int_uniform', 'min': 2, 'max': 4},
    'use_age': {'values': ['fc', 'conv']},  # 'fc', 'conv', 'no'
    'dropout': {'values': [0, 0.1, 0.2, 0.3]},
    'activation': {'values': ['relu', 'gelu', 'mish']},
}

In [None]:
# Sweep search space for the training options.
sweep_train = {
    'lr_scheduler_type': {
        'values': [
            'constant_with_decay',
            'constant_with_twice_decay',
            'transformer_style',
            'cosine_decay_with_warmup_half',
            'cosine_decay_with_warmup_one_and_half',
            # 'cosine_decay_with_warmup_two_and_half',
            'linear_decay_with_warmup',
        ],
    },
    'search_multiplier': {'values': [1.0, 1.2, 2.0]},
    'weight_decay': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-1},
    'mixup': {'values': [0, 0.1, 0.2, 0.3]},
    'criterion': {'values': ['cross-entropy', 'multi-bce']},
}

In [None]:
# Assemble the random-search sweep from the three search spaces and register it with W&B.
sweep_config = dict(
    entity="ipis-mjkim",
    name="my-sweep",
    method="random",
    parameters={**sweep_data, **sweep_model, **sweep_train},
)

# The project name is derived from the selected benchmark task.
sweep_id = wandb.sweep(sweep_config, project=f"caueeg-{cfg_data['task']}")

In [None]:
def setup_ddp(rank, world_size, master_addr='localhost', master_port='12355', backend='gloo'):
    """Initialize the default torch.distributed process group for this process.

    Args:
        rank: Rank of the calling process within the group.
        world_size: Total number of processes in the group.
        master_addr: Rendezvous address exported as ``MASTER_ADDR``.
            Defaults to ``'localhost'``, the value that was previously hard-coded.
        master_port: Rendezvous port exported as ``MASTER_PORT`` (converted to ``str``).
            Defaults to ``'12355'``, the value that was previously hard-coded.
        backend: Backend name passed to ``init_process_group``. Defaults to ``'gloo'``.
            NOTE(review): 'gloo' is kept for backward compatibility; the PyTorch docs
            recommend 'nccl' for distributed GPU training, so confirm the choice.
    """
    # init_process_group reads the rendezvous endpoint from these environment variables
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)

    # initialize the process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup_ddp():
    """Destroy the default process group if one is currently initialized.

    Calling this without an initialized process group (for example after a failed
    setup, or a second time) is a no-op instead of raising, so it is safe to use
    in ``finally`` blocks.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [None]:
def train_ddp(rank, world_size):
    """Run one W&B sweep trial inside a DDP worker process.

    The function sets up the process group, starts a wandb run, merges the notebook
    defaults (``cfg_data``, the selected entry of ``cfg_model_pool``, ``cfg_train``)
    with the values chosen by the sweep, builds the data loaders, trains via
    ``train_with_wandb`` and finally tears the process group down again.

    Args:
        rank: Rank of this process; forwarded to ``setup_ddp``.
        world_size: Total number of processes; forwarded to ``setup_ddp``.

    NOTE(review): every rank calls ``wandb.init()``; confirm that one run per rank
    is what the sweep is supposed to record.
    """
    # initialize DDP (previously called an undefined `setup`, which raised NameError)
    setup_ddp(rank, world_size)

    try:
        # initialize the wandb log
        wandb_run = wandb.init()
        wandb.run.name = wandb.run.id

        with wandb_run:
            # collect some garbage before allocating the next model
            gc.collect()
            if torch.cuda.is_available():
                # CUDA-only housekeeping; synchronize() raises on CPU-only hosts
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # load default config; keys chosen by the sweep are skipped here
            config = {}
            cfg_model = cfg_model_pool[wandb.config.model_index]
            for k, v in {**cfg_data, **cfg_model, **cfg_train}.items():
                if k not in wandb.config:
                    config[k] = v

            # add the values selected by the wandb sweep; defaults are taken from the
            # local dicts rather than wandb.config so that callables are not
            # type-converted to str
            for k, v in wandb.config.items():
                if k not in config:
                    config[k] = v

            # build dataset
            train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config)

            # train the model
            model = train_with_wandb(config, train_loader, val_loader, test_loader, multicrop_test_loader,
                                     config['preprocess_train'], config['preprocess_test'])

            # release memory
            del model
            del train_loader, val_loader
            del test_loader, multicrop_test_loader
            del config
    finally:
        # always release the process group, even when the trial fails
        cleanup_ddp()

    time.sleep(60)

In [None]:
# def train():
#     # initialize the wandb log
#     wandb_run = wandb.init()
#     wandb.run.name = wandb.run.id

#     with wandb_run:
#         # collect some garbages
#         gc.collect()
#         torch.cuda.empty_cache()
#         torch.cuda.synchronize()
    
#         # load default config
#         config = {}
#         cfg_model = cfg_model_pool[wandb.config.model_index]
#         for k, v in {**cfg_data,  **cfg_model, **cfg_train}.items():
#             if k not in wandb.config:
#                 config[k] = v
                
#         # load the selected configurations from wandb sweep with preventing callables from type-conversion to str
#         for k, v in wandb.config.items():
#             if k not in config:
#                 config[k] = v

#         # build dataset
#         train_loader, val_loader, test_loader, multicrop_test_loader = build_dataset_for_train(config)
        
#         # train the model
#         model = train_with_wandb(config, train_loader, val_loader, test_loader, multicrop_test_loader, 
#                                  config['preprocess_train'], config['preprocess_test'])
        
#         # release memory
#         del model
#         del train_loader, val_loader
#         del test_loader, multicrop_test_loader
#         del config
        
#     time.sleep(60)

In [None]:
# wandb.agent(sweep_id, function=train, count=1)