# Fine-tune a Masked Autoencoder (MAE)

- Fine-tune the deep network after it has been pretrained with the self-supervised learning framework.

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# (IPython magics: with `%autoreload 2`, edited project modules are reloaded before code runs)
%load_ext autoreload
%autoreload 2
# move one directory up -- presumably the project root, from which `run_train`, `train`
# and `models` (imported in the next cell) are importable; confirm if the notebook moves
%cd ..

In [None]:
# Load some packages
import os
import gc
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
import wandb
import pprint
import numpy as np
import torch
from tqdm.auto import tqdm
from collections import OrderedDict

import matplotlib
import matplotlib.pyplot as plt
import scienceplots

# custom package
from run_train import check_device_env
from run_train import set_seed
from run_train import compose_dataset
from run_train import generate_model
from train.ssl_train_script import ssl_train_script
from train.train_script import train_script
from models.utils import count_parameters

---

## Specify the dataset, model, and training settings

In [None]:
# Pretrained checkpoint to start from; it is read from
# <pre_model_path>/<pre_model_name>/checkpoint.pt further below.
pre_model_path, pre_model_name = 'local/checkpoint/', '2ew55ua4'

# Experiment tracking and compute target.
project = 'caueeg-mae'
use_wandb = False
device = 'cuda:0'

In [None]:
print('PyTorch version:', torch.__version__)

# Turn the requested device string into a torch.device, falling back to the CPU
# when CUDA is not available on this machine.
if torch.cuda.is_available():
    device = torch.device(device)
    print('cuda is available.')
else:
    device = torch.device('cpu')
    print('cuda is unavailable.')

---

## Load and modify the pretrained network

In [None]:
# Load the pretraining checkpoint and the configuration stored in it under 'config'.
# If `pre_model_name` holds several comma-separated run names, only the last one is loaded.
path = os.path.join(pre_model_path, pre_model_name.split(',')[-1], 'checkpoint.pt')
try:
    ckpt = torch.load(path, map_location=device)
except Exception as e:
    print(e)
    print(f'- checkpoint cannot be opened: {path}')
    # Re-raise: without the checkpoint, every later cell would fail with a misleading NameError.
    raise
config = ckpt['config']
pprint.pprint(config)

## Fine-tuning

In [None]:
# Pretrained weights, used further below to initialize the fine-tuning model.
# The original cell did `deepcopy(model)`, but no `model` exists at this point (it is only
# created when the fine-tuning model is built). On a re-run, it would silently copy the
# fine-tuning model instead. So the state dict is read straight from the loaded checkpoint.
# NOTE(review): the key under which the pretraining script stores its weights is assumed --
# confirm against train/ssl_train_script.py and trim the candidate list accordingly.
for state_key in ('model_state', 'ssl_model_state'):
    if state_key in ckpt:
        # Strip a leading 'module.' (added by DataParallel/DDP wrappers) so names line up
        # with an unwrapped model; names without that prefix are kept as-is.
        pre_model_state = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in ckpt[state_key].items()
        }
        break
else:
    raise KeyError(f'no model state dict found in checkpoint; available keys: {list(ckpt)}')

In [None]:
# Training configuration: override/extend the pretraining config for the fine-tuning run.

# Run bookkeeping: tracking project, wandb switch, source checkpoint, device.
config.update({
    'project': project,
    'use_wandb': use_wandb,
    'pre_model': pre_model_name,
    'device': device,
})

# Optimization budget and learning-rate schedule.
config.update({
    'total_samples': 5.0e+5,
    'search_lr': False,
    'base_lr': 1e-3,
    'lr_scheduler_type': 'cosine_decay_with_warmup_half',
    "warmup_min": 200,
})

# Model / fine-tuning options ("tuning_type" is either "finetune" or "fc_stage").
config.update({
    "tuning_type": "finetune",
    "layer_wise_lr": True,
    "out_dims": 3,
    "task": "dementia",
    "use_age": 'fc',
    "fc_stages": 3,
    "global_pool": True,
    "dropout": 0.3,
})

In [None]:
# Inspect the workstation and let check_device_env adjust the environment-dependent
# entries of `config` in place.
check_device_env(config)

# Build the four data loaders (train / validation / test / multi-crop test).
(
    train_loader,
    val_loader,
    test_loader,
    multicrop_test_loader,
) = compose_dataset(config)

pprint.pprint(config)

In [None]:
# Generate the fine-tuning model.
# Drop '.ssl' and '_pre' from the `_target_` path so that a different (presumably the plain,
# non-SSL) network class is instantiated than the one used for pretraining -- TODO confirm.
config["_target_"] = config["_target_"].replace('.ssl', '').replace('_pre', '')
model = generate_model(config).to(device)

# Transfer the pretrained weights. Every state-dict entry of the new model is overwritten with
# the pretrained tensor of the same name, except 'fc*' entries and '*pos_embed' entries, which
# keep the values generate_model() gave them.
# NOTE(review): pos_embed is presumably skipped because its shape can differ between
# pretraining and fine-tuning -- confirm.
model_state = model.state_dict()
transfer_keys = [
    k for k in model_state
    if not k.startswith('fc') and not k.endswith('pos_embed')
]
missing_keys = [k for k in transfer_keys if k not in pre_model_state]
if missing_keys:
    # Report every missing name at once rather than a bare KeyError on the first one.
    raise KeyError(
        f'{len(missing_keys)} parameter(s) not found in the pretrained state: {missing_keys}'
    )
for k in transfer_keys:
    model_state[k] = pre_model_state[k]

model.load_state_dict(model_state)

In [None]:
# Put the network into the requested fine-tuning regime, then store the count reported by
# count_parameters() in the config.
model.finetune_mode(config["tuning_type"])
config["num_params"] = count_parameters(model)

# List every parameter together with its requires_grad flag (i.e. whether it will be trained).
for param_name, p in model.named_parameters():
    print(f"{param_name:100}\t|\t{p.requires_grad}")

In [None]:
# Free memory still held by objects from earlier cells before training starts.
gc.collect()
# Guarded because `device` falls back to the CPU when CUDA is unavailable, and
# torch.cuda.synchronize() raises on a machine without a usable CUDA device.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# fix the seed for reproducibility (a negative seed value means not fixing)
set_seed(config, rank=None)

# train
train_script(
    config,
    model,
    train_loader,
    val_loader,
    test_loader,
    multicrop_test_loader,
    config["preprocess_train"],
    config["preprocess_test"],
)