# Finetune Self-Supervision

- Fine-tune the deep network after it has been pretrained with the self-supervised learning framework.

-----

## Load Packages

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Change the working directory to the parent folder (presumably the project
# root, so that `run_train`, `train` and `models` can be imported -- TODO confirm).
# NOTE(review): the path is relative, so every re-run of this cell moves up one
# more directory; run it only once per kernel session.
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [21]:
# Load some packages
import os
import gc
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
import wandb
import pprint
import torch

# custom package
from run_train import check_device_env
from run_train import set_seed
from run_train import compose_dataset
from train.train_script import train_script
import models
from models.utils import count_parameters

---

## Specify the dataset, model, and training settings

In [3]:
# --- pretrained checkpoints ---------------------------------------------------
# Each run name is a sub-directory of `model_path` containing a 'checkpoint.pt'.
model_path = 'local/checkpoint/'
model_a_name, model_b_name = 'l8524nml', 'xci5svkl'

# --- experiment bookkeeping ---------------------------------------------------
project = 'noname'
use_wandb = True
device = 'cuda'  # requested device name; resolved to a torch.device in a later cell

# --- data and optimisation schedule -------------------------------------------
crop_multiple = 8
total_samples = 3_000_000.0
minibatch = 256
search_lr = True
base_lr = 0.001
warmup_min = 300  # use None to leave 'warmup_min' out of the config
lr_scheduler_type = 'cosine_decay_with_warmup_half'  # alternative: 'linear_decay_with_warmup'

# --- loss and augmentation ----------------------------------------------------
criterion = 'cross-entropy'
mixup = awgn = mwgn = 0.0

In [4]:
# Report the framework version, then turn the requested device name into a
# torch.device, falling back to the CPU when CUDA is not available.
print('PyTorch version:', torch.__version__)
device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')

print('cuda is available.' if torch.cuda.is_available() else 'cuda is unavailable.')

PyTorch version: 2.0.0+cu117
cuda is available.


---

## Load and modify the pretrained network

In [14]:
def initiate_model(config, device, model_state=None):
    """Build the network described by a checkpoint config and load its weights.

    Args:
        config: Checkpoint configuration (dict-like). It is modified in place:
            a string 'generator' / 'block' entry is replaced by the attribute
            of the `models` package named by its last dotted component, and
            (when 'generator' is already an object) a block class is replaced
            by its short name 'bottleneck' or 'basic'.
        device: Device the instantiated model is moved to.
        model_state: State dict to load into the model. When omitted, the
            'model_state' entry of the notebook-level `ckpt` variable is used,
            so existing two-argument calls keep working.

    Returns:
        The instantiated model with the weights loaded.
    """
    # `OrderedDict` is not imported at notebook level, so bring it in here.
    from collections import OrderedDict

    if model_state is None:
        # Backward-compatible fallback to the most recently loaded checkpoint.
        # Passing `model_state` explicitly avoids depending on cell run order.
        model_state = ckpt['model_state']

    # initiate the model
    if '_target_' in config:
        model = hydra.utils.instantiate(config).to(device)
    elif type(config['generator']) is str:
        # Resolve dotted names such as 'models.resnet_2d.ResNet2D' to classes.
        config['generator'] = getattr(models, config['generator'].split('.')[-1])
        if 'block' in config:
            config['block'] = getattr(models, config['block'].split('.')[-1])
        model = config['generator'](**config).to(device)
    else:
        if 'block' in config:
            # Replace a stored block class by its short string name.
            if config['block'] == models.resnet_1d.BottleneckBlock1D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_2d.Bottleneck2D:
                config['block'] = 'bottleneck'
            elif config['block'] == models.resnet_1d.BasicBlock1D:
                config['block'] = 'basic'
            elif config['block'] == models.resnet_2d.BasicBlock2D:
                config['block'] = 'basic'

        model = config['generator'](**config).to(device)

    if config.get('ddp', False):
        # Remove the 'module.' prefix added by DataParallel/DistributedDataParallel.
        # Only keys that really carry the prefix are renamed; other keys are
        # kept unchanged instead of losing their first seven characters.
        prefix = 'module.'
        model_state = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in model_state.items()
        )

    model.load_state_dict(model_state)
    return model

In [25]:
# load pretrained configurations
# The run directory is the text after the last comma of `model_a_name`.
path = os.path.join(model_path, model_a_name.split(',')[-1], 'checkpoint.pt')
try:
    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    ckpt = torch.load(path, map_location=device)
    config = ckpt['config']
except Exception as e:
    print(e)
    print(f'- checkpoint cannot be opened: {path}')
    # Stop here: continuing would fail on an undefined `ckpt`/`config`, or
    # silently reuse whatever an earlier cell left in those variables.
    raise

# initiate the model
# `initiate_model` modifies the config it receives, so work on a copy.
config_a = deepcopy(config)
model_a = initiate_model(config_a, device)
pprint.pprint(config_a)

{'EKG': 'O',
 '_target_': 'models.resnet_1d.ResNet1D',
 'activation': 'gelu',
 'age_mean': tensor([71.1417], device='cuda:0'),
 'age_std': tensor([9.7264], device='cuda:0'),
 'awgn': 0.004872735559634612,
 'awgn_age': 0.03583361229344302,
 'base_channels': 64,
 'base_lr': 0.00046936536527944847,
 'block': 'basic',
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'conv_layers': [2, 2, 2, 2],
 'criterion': 'multi-bce',
 'crop_multiple': 4,
 'crop_timing_analysis': False,
 'cwd': '/home/minjae/Desktop/eeg_analysis',
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda'),
 'draw_result': True,
 'dropout': 0.3,
 'fc_stages': 3,
 'file_format': 'memmap',
 'in_channels': 21,
 'input_norm': 'dataset',
 'iterations': 390625,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_with_warmup_half',
 'mgn': 0.09575622309480

In [26]:
# load pretrained configurations
# The run directory is the text after the last comma of `model_b_name`.
path = os.path.join(model_path, model_b_name.split(',')[-1], 'checkpoint.pt')
try:
    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    ckpt = torch.load(path, map_location=device)
    config = ckpt['config']
except Exception as e:
    print(e)
    print(f'- checkpoint cannot be opened: {path}')
    # Stop here: `ckpt` and `config` may still hold a previously loaded
    # checkpoint, and continuing would silently build this model from it.
    raise

# initiate the model
# `initiate_model` modifies the config it receives, so work on a copy.
config_b = deepcopy(config)
model_b = initiate_model(config_b, device)
pprint.pprint(config_b)

{'EKG': 'O',
 'LR': 0.0010669676460233542,
 'activation': 'relu',
 'age_mean': tensor([71.2325], device='cuda:0'),
 'age_std': tensor([9.7353], device='cuda:0'),
 'awgn': 0.07461481243334823,
 'awgn_age': 0.027784886639797713,
 'base_channels': 64,
 'block': 'basic',
 'class_label_to_name': ['Normal', 'MCI', 'Dementia'],
 'class_name_to_label': {'Dementia': 2, 'MCI': 1, 'Normal': 0},
 'conv_layers': [2, 2, 2, 2],
 'criterion': 'cross-entropy',
 'crop_multiple': 4,
 'dataset_name': 'CAUEEG dataset',
 'dataset_path': 'local/dataset/02_Curated_Data_220419/',
 'ddp': False,
 'device': device(type='cuda'),
 'draw_result': True,
 'dropout': 0.4969637310464512,
 'fc_stages': 2,
 'file_format': 'memmap',
 'final_pool': 'average',
 'generator': <class 'models.resnet_2d.ResNet2D'>,
 'history_interval': 333,
 'in_channels': 40,
 'input_norm': 'dataset',
 'iterations': 166666,
 'latency': 2000,
 'load_event': False,
 'lr_scheduler_type': 'cosine_decay_with_warmup_one_and_half',
 'mgn': 0.032280199

In [None]:
# define concatnet

class ConCatNet(torch.nn.Module):
    """Container module holding the two pretrained networks.

    Both networks are stored as attributes of a `torch.nn.Module`, so they are
    registered as sub-modules and their parameters are reachable through
    `parameters()` and `state_dict()`.

    NOTE(review): how the two networks are combined is not defined yet;
    `compute_feature_embedding` raises NotImplementedError until it is written.
    """

    def __init__(self, model_a, model_b):
        """Store the two networks.

        Args:
            model_a: First pretrained network.
            model_b: Second pretrained network.
        """
        super().__init__()  # required before assigning sub-modules
        self.model_a = model_a
        self.model_b = model_b

    def compute_feature_embedding(self, x, age, target_from_last: int = 0):
        """Compute the combined feature embedding (not implemented yet).

        Args:
            x: Input signal batch.
            age: Age input accompanying `x`.
            target_from_last: Presumably selects how many layers before the
                output the embedding is taken from -- TODO confirm.

        Raises:
            NotImplementedError: Always, until the fusion logic is written.
        """
        raise NotImplementedError(
            'ConCatNet.compute_feature_embedding is not implemented yet.'
        )

    def forward(self, x, age):
        """Return the output of `compute_feature_embedding` for `x` and `age`."""
        x = self.compute_feature_embedding(x, age)
        # return F.log_softmax(x, dim=1)
        return x

In [None]:
# generate
# Build the fine-tuning configuration from the settings chosen above.
config = {}

config['project'] = project
config['use_wandb'] = use_wandb
config['model_a_name'] = model_a_name
config['model_a'] = model_a
config['model_b_name'] = model_b_name
config['model_b'] = model_b
config['device'] = device

config['crop_multiple'] = crop_multiple
config['total_samples'] = total_samples
# NOTE(review): `reset_minibatch` is not assigned in any cell shown in this
# notebook, so this line raises NameError on a fresh kernel -- define it in the
# settings cell or remove this check.
if reset_minibatch:
    # `config` starts empty, so 'minibatch' may be absent; tolerate that like
    # the other clean-up pops below instead of raising KeyError.
    config.pop('minibatch', None)
config['search_lr'] = search_lr
config['base_lr'] = base_lr
config['lr_scheduler_type'] = lr_scheduler_type

# NOTE(review): `model` is not assigned in any cell shown in this notebook
# (only `model_a` and `model_b` are) -- presumably it should be the combined
# network; confirm and create it before running this cell.
config["num_params"] = count_parameters(model)
if warmup_min:
    config["warmup_min"] = warmup_min

config['mixup'] = mixup

# remove unused keywords
config.pop('_ssl_target_', None)
config.pop('embedding_layer', None)
config.pop('mlp_hidden_size', None)
config.pop('projection_size', None)
config.pop('warmup_steps', None)

---
## Train

In [None]:
# check the workstation environment and update some configurations
check_device_env(config)

# collect some garbage
gc.collect()
# The CUDA calls are only valid when a CUDA device exists; this notebook also
# supports the CPU fallback, where torch.cuda.synchronize() would raise.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# fix the seed for reproducibility (a negative seed value means not fixing)
set_seed(config, rank=None)

# compose dataset
train_loader, val_loader, test_loader, multicrop_test_loader = compose_dataset(config)

pprint.pprint(config)

In [None]:
# train
# Positional arguments, in order: config, model, the four data loaders, and
# the train/test preprocessing entries stored in the config.
train_script(
    config,
    model,
    *(train_loader, val_loader, test_loader, multicrop_test_loader),
    *(config["preprocess_train"], config["preprocess_test"]),
)