# Finetune Masked-AutoEncoder

- Fine-tune the deep network after it has been pretrained with the self-supervised learning framework.

-----

## Load Packages

In [None]:
# for auto-reloading external modules, so edits to imported project files
# take effect without restarting the kernel
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# move the working directory one level up; presumably the notebook lives in a
# sub-directory and the project-local imports below need the parent folder — TODO confirm
%cd ..

In [None]:
# Load some packages
import os
import gc
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
import wandb
import pprint
import numpy as np
import torch
from tqdm.auto import tqdm
from collections import OrderedDict

import matplotlib
import matplotlib.pyplot as plt
import scienceplots
from matplotlib.colors import Normalize
import matplotlib.transforms as mtransforms
from matplotlib.patches import FancyBboxPatch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredDirectionArrows

# custom package
from run_train import check_device_env
from run_train import set_seed
from run_train import compose_dataset
from run_train import generate_model
from train.ssl_train_script import ssl_train_script
from train.train_script import train_script
from models.utils import count_parameters

---

## Specify the dataset, model, and training settings

In [None]:
# Directory that contains one sub-folder per pretrained run.
pre_model_path = 'local/checkpoint/'
# Name of the pretrained run; '<pre_model_path>/<name>/checkpoint.pt' is loaded below
# (if the name contains commas, only the last comma-separated part is used).
pre_model_name = '2023-1110-2304'

# Copied into config['use_wandb'] before training.
use_wandb = False
# Copied into config['project'] before training.
project = 'caueeg-mae'
# Preferred device string; replaced by 'cpu' below when CUDA is unavailable.
device = 'cuda:0'

In [None]:
# Report the PyTorch version and resolve the device: the requested one when
# CUDA can be used, otherwise the CPU.
print('PyTorch version:', torch.__version__)
cuda_ready = torch.cuda.is_available()
device = torch.device(device if cuda_ready else 'cpu')

print('cuda is available.' if cuda_ready else 'cuda is unavailable.')

---

## Load and modify the pretrained network

In [None]:
# load pretrained configurations
# (only the last comma-separated part of pre_model_name is used as the run folder)
path = os.path.join(pre_model_path, pre_model_name.split(',')[-1], 'checkpoint.pt')
try:
    # NOTE(review): torch.load unpickles the file, so only open checkpoints from a
    # trusted source. Recent PyTorch releases default to weights_only=True, which
    # may reject a checkpoint that stores arbitrary Python objects — confirm
    # against the installed version.
    ckpt = torch.load(path, map_location=device)
    config = ckpt['config']
except Exception as e:
    print(e)
    print(f'- checkpoint cannot be opened: {path}')
    # Stop here: continuing would either fail with a NameError on `config` or,
    # worse, silently reuse a stale `config` left over from an earlier run.
    raise
pprint.pprint(config)

In [None]:
# generate the model
model = generate_model(config).to(device)
model.load_state_dict(model.state_dict())

---
## Visualize

In [None]:
# check the workstation environment and update some configurations
check_device_env(config)

# compose dataset
train_loader, val_loader, test_loader, multicrop_test_loader = compose_dataset(config)

pprint.pprint(config)

In [None]:
# Short channel labels: the text before the first '-' of every header entry.
signal_header = [entry.partition('-')[0] for entry in config["signal_header"]]
# Sampling rate taken from config['resample'], 200 when the key is absent.
fps = config.get('resample', 200)

In [None]:
def from_as_real_to_complex(signal):
    """Rebuild a complex tensor from stacked real and imaginary channels.

    The first half of dimension 1 is used as the real part and the second half
    as the imaginary part.

    Args:
        signal: Real tensor of shape (N, 2 * C, H, W).

    Returns:
        Complex tensor of shape (N, C, H, W) on the same device as ``signal``.
        float32 / float64 inputs give complex64 / complex128; any other dtype
        is converted to float32 first.

    Raises:
        ValueError: If ``signal`` is not 4-dimensional or its channel count is odd.
    """
    if signal.ndim != 4:
        raise ValueError(f"expected a 4-D tensor (N, 2C, H, W), got shape {tuple(signal.shape)}")
    if signal.shape[1] % 2 != 0:
        raise ValueError(f"channel dimension must be even, got {signal.shape[1]}")

    C = signal.shape[1] // 2
    real, imag = signal[:, :C], signal[:, C:]

    # torch.complex keeps the device and precision of its inputs; fall back to
    # float32 for dtypes outside float32/float64 so the result is always
    # complex64 or complex128.
    if real.dtype not in (torch.float32, torch.float64):
        real, imag = real.float(), imag.float()

    return torch.complex(real, imag)

In [None]:
def draw_stft(img, pred, mask, config, index=0, log_scale=False, save_fig=None):
    """Plot per-channel magnitude images of one sample, mixing input and reconstruction.

    ``img`` and ``pred`` are converted with ``from_as_real_to_complex`` and
    reduced to magnitudes. The displayed image is assembled patch by patch:
    patches whose mask value is above 0.5 are copied from ``img``, all other
    patches from ``pred``. Rows/columns beyond the last full patch stay zero.

    Args:
        img: Batched input in the layout expected by ``from_as_real_to_complex``.
        pred: Batched reconstruction in the same layout as ``img``.
        mask: Per-sample patch mask; ``mask[index]`` must hold
            ``(H // patch_size) * (W // patch_size)`` values.
        config: Mapping with ``"patch_size"``, ``"seq_length"``,
            ``"signal_header"`` and optionally ``"resample"`` (default 200).
        index: Index of the sample inside the batch to draw.
        log_scale: If True, show ``log(magnitude + 1e-8)`` instead of the magnitude.
        save_fig: Optional file path; when given, the figure is also written
            there with ``fig.savefig`` before it is shown.
    """
    img = from_as_real_to_complex(img)[index].abs().detach().cpu().numpy()
    pred = from_as_real_to_complex(pred)[index].abs().detach().cpu().numpy()

    # NOTE(review): EKG and Photic channels are assumed to be excluded already
    # by the data pipeline; nothing is dropped here — TODO confirm.
    C, H, W = img.shape
    p = config["patch_size"]
    h = H // p
    w = W // p

    patch_mask = mask[index]
    if torch.is_tensor(patch_mask):
        patch_mask = patch_mask.detach().cpu().numpy()
    # NOTE(review): mask > 0.5 selects the *input* patch and everything else
    # the prediction; verify this matches the model's mask convention.
    take_input = np.asarray(patch_mask).reshape(h, w) > 0.5
    # Expand the (h, w) patch mask to pixel resolution (h * p, w * p).
    take_input = np.repeat(np.repeat(take_input, p, axis=0), p, axis=1)

    signal_f = np.zeros_like(img)
    signal_f[:, :h * p, :w * p] = np.where(
        take_input, img[:, :h * p, :w * p], pred[:, :h * p, :w * p]
    )

    # Axis scaling is derived from the config so the function does not depend
    # on a module-level variable.
    fps = config.get('resample', 200)
    duration = config['seq_length'] / fps
    nyquist = fps / 2.0

    columns = 7
    rows = (C + columns - 1) // columns
    # squeeze=False keeps `ax` two-dimensional even when there is a single row.
    fig, ax = plt.subplots(rows, columns,
                           figsize=(22.0, 9.5), constrained_layout=True,
                           squeeze=False)

    for k, axis in enumerate(ax.flat):
        if k >= C:
            # unused grid cells stay empty
            axis.axis('off')
            continue

        axis.imshow(np.log(signal_f[k] + 1e-8) if log_scale else signal_f[k],
                    interpolation='nearest',
                    extent=[0, duration, 0, nyquist],
                    aspect=duration / nyquist)
        axis.set_title(config['signal_header'][k].split('-')[0],
                       fontsize=18, fontweight='bold', color='darkred')
        axis.set_xlabel('Time (s)', fontsize=13)
        axis.set_ylabel('Frequency (Hz)', fontsize=13)

    fig.suptitle('Time-Frequency Representation', fontsize=20, fontweight='semibold')

    if save_fig:
        fig.savefig(save_fig)

    plt.show()
    fig.clear()
    plt.close(fig)

In [None]:
# Reconstruct the first batch of the chosen split with the pretrained model
# and draw the first sample of that batch.
loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

# Inference only: switch to eval mode so that dropout is disabled and any
# normalization layers do not update their running statistics (torch.no_grad
# alone does not prevent that).
model.eval()

with torch.no_grad():
    for target_dataset in tqdm(["val"], desc="Dataset", leave=False):
        if target_dataset not in loaders:
            raise ValueError(
                f"unknown target_dataset {target_dataset!r}; expected one of {sorted(loaders)}"
            )
        loader = loaders[target_dataset]

        for sample_batched in tqdm(loader, total=len(loader), desc='Batch', leave=False):
            # the preprocessing callable works on the batch dict in place
            config["preprocess_test"](sample_batched)
            img = sample_batched["signal"]

            pred, mask = model.mask_and_reconstruct(img, sample_batched["age"], config["mask_ratio"])
            pred_img = model.unpatchify(pred)

            draw_stft(img, pred_img, mask, config, index=0, log_scale=False, save_fig=None)
            break  # a single batch is enough for the visualization
        break  # only the first listed split is visualized

## Finetuning

In [None]:
# Keep an independent copy of the pretrained model: `model` is rebound to the
# finetuning network later, and the copy's state dict is the source of the
# weights transferred into it.
pre_model = deepcopy(model)
pre_model_state = pre_model.state_dict()

In [None]:
# training configuration
config['project'] = project
config['use_wandb'] = use_wandb
config['pre_model'] = pre_model_name
config['device'] = device

config['total_samples'] = 5.0e+5
config['search_lr'] = False
config['base_lr'] = 1e-3
config['lr_scheduler_type'] = 'cosine_decay_with_warmup_half'

config["warmup_min"] = 200   

# model
config["tuning_type"] = "finetune"  # "finetune", "fc_stage"
config["layer_wise_lr"] = True

config["out_dims"] = 3
config["task"] = "dementia"
config["use_age"] = 'fc'
config["fc_stages"] = 3
config["global_pool"] = True
config["dropout"] = 0.3

In [None]:
# check the workstation environment and update some configurations
check_device_env(config)

# compose dataset
train_loader, val_loader, test_loader, multicrop_test_loader = compose_dataset(config)

pprint.pprint(config)

In [None]:
# generate the model
config["_target_"] = config["_target_"].replace('.ssl', '').replace('_pre', '')
model = generate_model(config).to(device)

# load the model
model_state = model.state_dict()
for k, v in model_state.items():
    if not k.startswith('fc') and not k.endswith("pos_embed"):
        model_state[k] = pre_model_state[k]

model.load_state_dict(model_state)

In [None]:
# Apply the configured tuning type and record the parameter count in the config.
model.finetune_mode(config["tuning_type"])
config["num_params"] = count_parameters(model)

# List every parameter together with its requires_grad flag.
for param_name, tensor in model.named_parameters():
    trainable = tensor.requires_grad
    print(f"{param_name:100}\t|\t{trainable}")

In [None]:
# collect some garbage
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

# fix the seed for reproducibility (a negative seed value means not fixing)
set_seed(config, rank=None)

# train
train_script(
    config,
    model,
    train_loader,
    val_loader,
    test_loader,
    multicrop_test_loader,
    config["preprocess_train"],
    config["preprocess_test"],
)