# Finetune Self-Supervision

- Fine-tune a deep network whose backbone was pretrained with a self-supervised learning (SSL) framework.

-----

## Load Packages

In [1]:
# Notebook bootstrap: enable automatic module reloading and move to the parent directory.
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
# mode 2: reload all modules before executing each cell
%autoreload 2
# Move the working directory one level up — presumably to the project root, so that the
# project imports and relative paths used in later cells resolve (TODO confirm).
# NOTE(review): this is relative to the current directory, so re-running this cell
# without restarting the kernel moves up one more level each time.
%cd ..

/home/imkbsz/workspace/eeg_analysis


In [2]:
# Load some packages
import os
import gc
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
import wandb
import pprint
import torch

# custom package
from run_train import check_device_env
from run_train import set_seed
from run_train import compose_dataset
from train.train_script import train_script
from models.utils import count_parameters

---

## Specify the dataset, model, and training settings

In [3]:
# --- pretrained checkpoint to start from ---
pre_model_path = 'local/checkpoint/'   # directory that holds the checkpoint folders
pre_model_name = 't39h5eja'            # name of the checkpoint folder to load
finetune = 'whole'                     # which part of the network is trained

# --- experiment tracking / hardware ---
project = 'caueeg-ssl-finetune'
use_wandb = True
device = 'cuda:0'

# --- data and optimisation schedule ---
crop_multiple = 8
total_samples = 1e6
reset_minibatch = False
search_lr = True
base_lr = 0.001
warmup_min = 300  # or None
# available: 'cosine_decay_with_warmup_half', 'linear_decay_with_warmup'
lr_scheduler_type = 'cosine_decay_with_warmup_half'

# --- augmentation ---
mixup = 0.3

In [4]:
# Report the PyTorch build and resolve the compute device,
# falling back to the CPU when CUDA cannot be used.
print('PyTorch version:', torch.__version__)

cuda_available = torch.cuda.is_available()
device = torch.device(device if cuda_available else 'cpu')

print('cuda is available.' if cuda_available else 'cuda is unavailable.')

PyTorch version: 1.11.0
cuda is available.


---

## Load and modify the pretrained network

In [5]:
# Load the pretrained checkpoint, rebuild the backbone and the SSL wrapper from the
# stored configuration, restore the SSL weights, and copy the backbone weights into `model`.

# load pretrained configurations
# NOTE(review): only the part of `pre_model_name` after the last comma is used as the
# checkpoint folder name — confirm whether a comma-separated form is still needed.
path = os.path.join(pre_model_path, pre_model_name.split(',')[-1], 'checkpoint.pt')
try:
    ckpt = torch.load(path, map_location=device)
    config = ckpt['config']
except Exception:
    # Stop here instead of continuing: otherwise the cell fails later with an unrelated
    # NameError, or silently reuses `ckpt`/`config` left over from an earlier run.
    print(f'- checkpoint cannot be opened: {path}')
    raise

# initiate the model
model = hydra.utils.instantiate(config).to(device)

# initiate SSL model and load model state
ssl_config = deepcopy(config)
ssl_config['_target_'] = ssl_config['_ssl_target_']
ssl_model = hydra.utils.instantiate(ssl_config, model).to(device)

# `ssl_model` is built unwrapped in this cell, so a checkpoint saved from a
# DataParallel/DistributedDataParallel model (every key prefixed with 'module.') must have
# that prefix removed first. The decision is made from the keys themselves: comparing
# ckpt["config"]["ddp"] with ssl_config["ddp"] is always true here, because `ssl_config`
# is a copy of that very configuration.
ssl_model_state = ckpt["ssl_model_state"]
if ssl_model_state and all(key.startswith("module.") for key in ssl_model_state):
    # strips the prefix in place
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(ssl_model_state, "module.")
ssl_model.load_state_dict(ssl_model_state)

# keep only the backbone weights; the SSL wrapper is no longer needed
model_state = deepcopy(ssl_model.backbone.state_dict())
del ssl_config, ssl_model

# load
model.load_state_dict(model_state)
pprint.pprint(config)

{'EKG': 'O',
 '_ssl_target_': 'models.ssl.byol.BYOL',
 '_target_': 'models.conformer.ConformerClassifier',
 'age_mean': tensor([71.1602], device='cuda:0'),
 'age_std': tensor([9.8829], device='cuda:0'),
 'awgn': 0.05,
 'awgn_age': 0.001,
 'base_lr': 0.1,
 'channel_dropout': 0.2,
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'criterion': 'cross-entropy',
 'crop_multiple': 2,
 'crop_timing_analysis': False,
 'cwd': '/home/imkbsz/workspace/eeg_analysis',
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/caueeg-dataset/',
 'ddp': False,
 'device': device(type='cuda'),
 'draw_result': True,
 'dropout': 0.1,
 'embedding_layer': 3,
 'encoder_dim': 512,
 'fc_stages': 3,
 'file_format': 'memmap',
 'in_channels': 40,
 'input_norm': 'datapoint',
 'iterations': 260417,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_with_warmup_half',
 'mgn': 0.05,
 'minibatch': 384,
 'minibatch_3

In [6]:
# define finetuning range
#   'fc_stage': train only parameters whose name contains 'fc_stage' or 'heads'
#   'whole'   : train every parameter of the pretrained model
#   'reset'   : re-instantiate the model from `config` (pretrained weights are not loaded
#               into the new instance) and train every parameter
if finetune == 'fc_stage':
    # freeze everything first, then re-enable gradients for the selected parameters
    model.requires_grad_(False)
    for name, param in model.named_parameters():
        if 'fc_stage' in name or 'heads' in name:
            param.requires_grad_(True)
        # report the final trainable flag of each parameter
        print(f"{str(param.requires_grad):^15}\t{name}")
elif finetune == 'whole':
    model.requires_grad_(True)
elif finetune == 'reset':
    model = hydra.utils.instantiate(config).to(device)
    model.requires_grad_(True)
else:
    raise NotImplementedError(
        f"Not implemented! Unknown finetune option {finetune!r}; "
        "expected 'fc_stage', 'whole', or 'reset'."
    )

# TODO: Need to think about the DropOut and Batch/LayerNorms statistics
# eval/train mode

In [7]:
# modify configuration
# Overwrite the pretraining configuration with the fine-tuning settings chosen above.
config['project'] = project
config['use_wandb'] = use_wandb
config['pre_model'] = pre_model_name
config['finetune'] = finetune
config['device'] = device

config['crop_multiple'] = crop_multiple
config['total_samples'] = total_samples
if reset_minibatch:
    # a default keeps this cell re-runnable once the key has already been removed
    config.pop('minibatch', None)
config['search_lr'] = search_lr
config['base_lr'] = base_lr
config['lr_scheduler_type'] = lr_scheduler_type

config["output_length"] = model.get_output_length()
config["num_params"] = count_parameters(model)
# NOTE(review): a falsy value (None or 0) leaves any existing 'warmup_min' entry untouched
if warmup_min:
    config["warmup_min"] = warmup_min

config['mixup'] = mixup

# remove unused keywords
# (done in a loop so the cell does not echo the value returned by the last pop)
for unused_key in (
    '_ssl_target_',
    'embedding_layer',
    'mlp_hidden_size',
    'projection_size',
    'warmup_steps',
):
    config.pop(unused_key, None)

---
## Train

In [8]:
# Prepare for training: adapt the configuration to this machine, free memory,
# fix the random seed, and build the data loaders.

# check the workstation environment and update some configurations
check_device_env(config)

# collect some garbage
gc.collect()
# The device cell falls back to the CPU when CUDA is unavailable, so the CUDA-only
# housekeeping is skipped in that case instead of raising.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# fix the seed for reproducibility (a negative seed value means not fixing)
set_seed(config, rank=None)

# compose dataset
train_loader, val_loader, test_loader, multicrop_test_loader = compose_dataset(config)

pprint.pprint(config)

{'EKG': 'O',
 '_target_': 'models.conformer.ConformerClassifier',
 'age_mean': tensor([71.1602], device='cuda:0'),
 'age_std': tensor([9.8829], device='cuda:0'),
 'awgn': 0.05,
 'awgn_age': 0.001,
 'base_lr': 0.001,
 'channel_dropout': 0.2,
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'criterion': 'cross-entropy',
 'crop_multiple': 8,
 'crop_timing_analysis': False,
 'cwd': '/home/imkbsz/workspace/eeg_analysis',
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/caueeg-dataset/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.1,
 'encoder_dim': 512,
 'fc_stages': 3,
 'file_format': 'memmap',
 'finetune': 'whole',
 'in_channels': 40,
 'input_norm': 'datapoint',
 'iterations': 260417,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_with_warmup_half',
 'mgn': 0.05,
 'minibatch': 384,
 'minibatch_3090': 384,
 'mixup': 0.3,
 'mod

In [None]:
# train
# Fine-tune `model` with the loaders built above and the preprocessing steps stored
# in the configuration; all arguments are passed positionally in the original order.
data_loaders = (train_loader, val_loader, test_loader, multicrop_test_loader)
preprocess_steps = (config["preprocess_train"], config["preprocess_test"])
train_script(config, model, *data_loaders, *preprocess_steps)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.



******************************    Configurations for Train    ******************************

{'EKG': 'O',
 '_target_': 'models.conformer.ConformerClassifier',
 'age_mean': tensor([71.1602], device='cuda:0'),
 'age_std': tensor([9.8829], device='cuda:0'),
 'awgn': 0.05,
 'awgn_age': 0.001,
 'base_lr': 0.001,
 'channel_dropout': 0.2,
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'criterion': 'cross-entropy',
 'crop_multiple': 8,
 'crop_timing_analysis': False,
 'cwd': '/home/imkbsz/workspace/eeg_analysis',
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/caueeg-dataset/',
 'ddp': False,
 'device': device(type='cuda', index=0),
 'draw_result': True,
 'dropout': 0.1,
 'encoder_dim': 512,
 'fc_stages': 3,
 'file_format': 'memmap',
 'finetune': 'whole',
 'in_channels': 40,
 'input_norm': 'datapoint',
 'iterations': 260417,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_w

[34m[1mwandb[0m: Currently logged in as: [33mipis-mjkim[0m. Use [1m`wandb login --relogin`[0m to force relogin
