# This is code for pre-training SGNN with ABCD data

In [1]:
#GPU number
GPU = 0

num_workers = 4

#Random Seed
seed = 42 

#Target set to 'p' since we're predicting the p-factor. 
target_name = 'p' 

In [2]:
import sklearn, torch
sklearn.__version__, torch.__version__

('0.24.1', '1.10.1+cu102')

In [3]:
#Index for hyperparameter selection.
#Hyperparameter sets were selected with Paramgrid, then one specific hyperparameter was chosen based on this index.
temp_sel_idx = 0 

#Number of cross-validation. 5 as default
n_cv = 5

## 1. Parameter setting

In [4]:
import numpy as np

outer_cv_part = np.arange(0, n_cv)
print("Selected Fold: {}".format(outer_cv_part))

select_fold = [0]
print("Selected Fold: {}".format(select_fold))

Selected Fold: [0 1 2 3 4]
Selected Fold: [0]


In [17]:
from sklearn.model_selection import ParameterGrid

#Number of nodes in SGNN. 1024 as default
ext_cand = [1024]
prd_cand = [1024]
dsc_cand = [1024]

#Dropout rate of SGNN
dropout_ext_cand = [0.9]
dropout_prd_cand = [0.9]
dropout_dsc_cand = [0.0]

#Activation function / Optimizer of SGNN
act_func_name = "elu"
optimizer_name = "nag"

#Batch size / Learning rate / LR patience for scheduling / LR scaling ratio / Epochs / Freezing epochs
batch_size_cand = [32]
lr_cand = [5e-05]
lr_patience_cand = [5]
lr_factor_cand = [0.5]
epochs_cand = [10]
pretrain_epoch_cand = [20]

#Hoyer's sparsity candidates
hsp_ext_cand = [0.95] #[0.975, 0.95, 0.9, 0.8]
hsp_prd_cand = [0.3]
hsp_dsc_cand = [0.3]

#Lambda for gradient reversal layer
lambda_cand = [0.01]

#L2 norm
l2_param_cand = [5e-02]


param_cand = {
    "ext": ext_cand, "prd": prd_cand,"dsc":dsc_cand,  
    "dropout_ext": dropout_ext_cand, "dropout_prd": dropout_prd_cand,
    "dropout_dsc": dropout_dsc_cand,"batch_size": batch_size_cand,
    "lr": lr_cand, "epochs": epochs_cand, "pretrain_epoch": pretrain_epoch_cand,
    "lr_patience": lr_patience_cand, "lr_factor": lr_factor_cand,
    "hsp_ext": hsp_ext_cand, "hsp_prd": hsp_prd_cand,"hsp_dsc": hsp_dsc_cand,
    "l2_param": l2_param_cand,"lambda_": lambda_cand
}

In [6]:
import os
import gc
import time
import math
import pickle
import random
import itertools
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image
import imageio
import seaborn as sns
import scipy.stats as stats
from decimal import Decimal
from datetime import datetime as dt
from pytz import timezone
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler, MinMaxScaler,QuantileTransformer,RobustScaler,PowerTransformer
from sklearn.preprocessing import OneHotEncoder

In [7]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.optim as optim
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import ReduceLROnPlateau 
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

In [8]:
import warnings
warnings.filterwarnings('ignore')

In [9]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)

In [10]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [11]:
def seed_everything(seed=seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(seed)

In [12]:
nowtime = dt.now(timezone("Asia/Seoul")); year = str(nowtime.year)[2:]
month = '0{}'.format(nowtime.month) if nowtime.month < 10 else str(nowtime.month)
day = '0{}'.format(nowtime.day) if nowtime.day < 10 else str(nowtime.day)
hour = '0{}'.format(nowtime.hour) if nowtime.hour < 10 else str(nowtime.hour)
minute = '0{}'.format(nowtime.minute) if nowtime.minute < 10 else str(nowtime.minute)
sec = '0{}'.format(nowtime.second) if nowtime.second < 10 else str(nowtime.second)
msec = str(nowtime.microsecond)[:2]

## 2. Input / Target setting & Fold division

In [13]:
#Input data: vectorized RSFC
data_path = "/users/hjw/data/ABCD/npz_files/rsfc_p_site_scanner_si_ge.npz"
data = np.load(data_path, allow_pickle=True)
X = data["X"]

#Target data: factor score / scanner type
targets_all = np.load("/data4/SNU/data/ABCD_CFA_5factor.npz",allow_pickle=True) 
target_fs = targets_all[target_name]
target_scn = targets_all['scn']
y = np.hstack([target_fs.reshape(-1,1),target_scn.reshape(-1,1)])
print(X.max(),X.min())
print(X.shape, y.shape)

7.254328727722168 -7.254328727722168
(6905, 61776) (6905, 2)


In [14]:
# X_wfm = make_wfm(stats.zscore(X.mean(0)))
# sns.set(style="white", font_scale=1.5)
# plt.figure(figsize=(20,15))
# sns.heatmap(X_wfm,vmax=1.96,vmin=-1.96,cmap="RdBu_r")

In [15]:
p_factor_idx = 0
scanner_idx = 1

In [16]:
seed_everything(seed)

from sklearn.model_selection import KFold, ShuffleSplit

outer_n_splits = n_cv

outer_train_folds_idx = []
outer_val_folds_idx = []
outer_test_folds_idx = []

outer_skf = ShuffleSplit(
    n_splits=outer_n_splits, test_size=0.20, random_state=seed)


if 'Full' in select_fold:
    n_cv = 1
    outer_train_folds_idx.append(np.arange(len(X)))
    outer_val_folds_idx.append(np.array([0,1]))
    outer_test_folds_idx.append(np.array([0,1]))

else:
    for n_outer, (outer_train_idx, outer_test1_idx) in enumerate(outer_skf.split(X, y)):
        outer_train_folds_idx.append(outer_train_idx)
#         outer_test_folds_idx.append(outer_test_idx)
        outer_val_idx = np.random.choice(outer_test1_idx,size=len(outer_test1_idx)//2,replace=False)
        outer_test_idx = np.array([i for i in outer_test1_idx if i not in outer_val_idx])
        outer_val_folds_idx.append(outer_val_idx)
        outer_test_folds_idx.append(outer_test_idx)

len(outer_train_folds_idx),len(outer_train_folds_idx[0]), len(outer_val_folds_idx[0]), len(outer_test_folds_idx[0])

(5, 5524, 690, 691)

## 3. Real-time visualization code

In [18]:
num_ROIs = tot_rois = 352
threshold = 1.96

# Preparing draw feature map
parcels = pd.read_excel("/users/hjw/data/ABCD/Parcels/Parcels.xlsx", engine="openpyxl")
networks = list(parcels["Community"]) + 19 * ["Subcortex"]

networks_df = pd.DataFrame(networks, columns=["network"])
networks_df[networks_df["network"] == "Auditory"] = "AUD"
networks_df[networks_df["network"] == "Visual"] = "VIS"
networks_df[networks_df["network"] == "VentralAttn"] = "VAN"
networks_df[networks_df["network"] == "Subcortex"] = "SCN"
networks_df[networks_df["network"] == "Salience"] = "SAL"
networks_df[networks_df["network"] == "SMmouth"] = "SMM"
networks_df[networks_df["network"] == "SMhand"] = "SMH"
networks_df[networks_df["network"] == "RetrosplenialTemporal"] = "RSP"
networks_df[networks_df["network"] == "None"] = "NONE"
networks_df[networks_df["network"] == "FrontoParietal"] = "FPN"
networks_df[networks_df["network"] == "DorsalAttn"] = "DAN"
networks_df[networks_df["network"] == "Default"] = "DMN"
networks_df[networks_df["network"] == "CinguloParietal"] = "CPAR"
networks_df[networks_df["network"] == "CinguloOperc"] = "CON"

networks = np.array(networks_df["network"]).astype("str")
network_label = np.unique(networks, return_index=True)[0].astype("str")
orig_network_path = "/users/hjw/data/ABCD/npz_files/Gordon_network_labels.npz"
orig_networks = np.load(orig_network_path)["networks"]
print(orig_networks.shape)

# Set order of network label 
new_orig_network_idx_order = [
    'AUD', 'DAN', 'CPAR', 'NONE', 
    'SAL', 'VIS', 'RSP', 'DMN',
    'SMM', 'CON', 'SCN', 
    'SMH', 'VAN', 'FPN'
]

sorted_order = sorted(new_orig_network_idx_order)
sorted_order.remove("NONE")
sorted_order.append("NONE")
new_orig_network_idx_order = sorted_order
print(new_orig_network_idx_order)

new_orig_network_order = {
    key:value for (key, value) in 
    zip(new_orig_network_idx_order, (np.arange(len(new_orig_network_idx_order))))
}


ref_orig_wfm = np.zeros((tot_rois, tot_rois))
print(ref_orig_wfm.shape)
ref_orig_wfm = pd.DataFrame(ref_orig_wfm, index=orig_networks, columns=orig_networks)
ref_orig_wfm = ref_orig_wfm.sort_index(key=lambda x: x.map(new_orig_network_order), 
                                       axis=0)
ref_orig_wfm = ref_orig_wfm.sort_index(key=lambda x: x.map(new_orig_network_order), 
                                       axis=1)    

network_unq = new_orig_network_idx_order
sorted_networks = np.array(ref_orig_wfm.columns, dtype=np.str)


sorted_networks_df = pd.DataFrame(np.unique(sorted_networks, return_index=True)).T
sorted_networks_df.columns = ["networks", "n"]
sorted_networks_df = sorted_networks_df.sort_values(
    by="networks", key=lambda x: x.map(new_orig_network_order)
)

start_network_idx = np.array(sorted_networks_df.n)
next_network_idx = np.hstack((start_network_idx[1:], 352))

network_mid_idx = np.array((start_network_idx + next_network_idx) / 2, dtype=np.int)

(352,)
['AUD', 'CON', 'CPAR', 'DAN', 'DMN', 'FPN', 'RSP', 'SAL', 'SCN', 'SMH', 'SMM', 'VAN', 'VIS', 'NONE']
(352, 352)


In [19]:
def make_wfm(vec):
    wfm = np.zeros((tot_rois, tot_rois))
    iu_non_di_idx = np.mask_indices(tot_rois, np.triu, 1)
    wfm[iu_non_di_idx] = vec
    il_idx = np.tril_indices(tot_rois, -1)
    wfm[il_idx] = wfm.T[il_idx]
    wfm_df = pd.DataFrame(wfm, index=orig_networks, columns=orig_networks)
    wfm_df = wfm_df.sort_index(key=lambda x: x.map(new_orig_network_order), axis=0)
    wfm_df = wfm_df.sort_index(key=lambda x: x.map(new_orig_network_order), axis=1)
    
    return wfm_df

In [20]:
#Model, Epoch as input, save weight feature map

def visualize_wfm(trained_model,epoch,epoch_gap = 10,threshold=False,mode = 'sum'):
    
    if epoch%epoch_gap !=0:
        return
    
    wfm_save_dir = outer_save_dir+'/WFM/'
    if not os.path.isdir(wfm_save_dir):
        os.mkdir(wfm_save_dir)
    
    w_ext = []
    w_reg = []
    w_dsc = []

    ext_hidden1 = trained_model.ext_1.weight.shape[0]
    prd_hidden1 = trained_model.prd_1.weight.shape[0]

    w = []

    for name,params in trained_model.named_parameters():
        if "weight" in name and "bn" not in name:
            w.append(params)

    w_ext_1 = w[0].detach().cpu().numpy().T
    w_reg_1 = w[1].detach().cpu().numpy().T
    w_reg_2 = w[2].detach().cpu().numpy().T

    temp_w_ext = w_ext_1
    temp_w_reg = np.matmul(np.matmul(temp_w_ext, w_reg_1), w_reg_2)

    w_ext.append(temp_w_ext) #ext x pred1 (352x1024)
    w_reg.append(temp_w_reg) #ext x pred1 x pred2 (352x1)

#     np.savez(wfm_save_dir+"/wfm_reg_epoch{}".format(epoch), X=w_reg)
    ### ROI level interpretation - HJD ###
    wfm_rois = w_reg[0].reshape(-1,)
    wfm_rois[np.where(np.abs(stats.zscore(wfm_rois))<3.091)[0]] = 0
    wfm_rois_df = make_wfm(wfm_rois)
    
    sns.set(style="white", font_scale=3)
    fig, ax = plt.subplots(figsize=(32, 32))
    cbar_kws = dict(use_gridspec=False, shrink=0.85, location="right")

    sns.heatmap(
        wfm_rois_df, square=True, cmap="RdBu_r", center=0,  
        ax=ax, cbar_kws=cbar_kws
    )
    ax.set_title("99.9% significant ROIs")
    ax.set_xticks(network_mid_idx)
    ax.set_yticks(network_mid_idx)

    ax.set_xticklabels(sorted_order, rotation=90, fontsize=45, ha="center")
    ax.set_yticklabels(sorted_order, rotation='horizontal', fontsize=45)

    for network_pos in next_network_idx:
        plt.axvline(network_pos, linewidth=1.5, color="black", ymin=0, ymax=network_pos)
        plt.axhline(network_pos, linewidth=1.5, color="black", xmin=0, xmax=network_pos)

    cbar = ax.collections[0].colorbar
    cbar.set_label("$WF$", fontsize=60, labelpad=50)    
    plt.savefig(wfm_save_dir+f"/WFM_ROIs_epoch{epoch}.jpg")
    plt.close(fig)
    ######
    
    wfm_df = make_wfm(stats.zscore(w_reg[0].reshape(-1,)))

    cols = rows = wfm_df.columns.values
#     new_orig_network_idx_order = np.unique(cols)
    avg_df_1 = pd.DataFrame(columns=cols, index=rows)
    avg_df_2 = pd.DataFrame(columns=new_orig_network_idx_order, index=new_orig_network_idx_order)
    avg_df_3 = pd.DataFrame(columns=new_orig_network_idx_order, index=new_orig_network_idx_order)

    for i, temp_row in enumerate(new_orig_network_idx_order):
        for j, temp_col in enumerate(new_orig_network_idx_order):
            temp_row_ids = np.where(rows == temp_row)[0]
            temp_col_ids = np.where(cols == temp_col)[0]
            temp_mat = wfm_df.iloc[temp_row_ids, temp_col_ids]
            temp_vec = temp_mat.values.ravel()
            n_temp_vec = np.sum(np.absolute(temp_mat.values) > 1.96)
            if mode == 'sum':
                sum_val = temp_vec.sum()
                n_val = n_temp_vec
            elif mode == 'avg':
                if temp_row == temp_col:
                    sum_val = temp_vec.sum() / (len(temp_row_ids) * (len(temp_row_ids) - 1))
                    n_val = n_temp_vec / (len(temp_row_ids) * (len(temp_row_ids) - 1))
                else:
                    sum_val = temp_vec.sum() / (len(temp_row_ids) * (len(temp_col_ids)))
                    n_val = n_temp_vec / (len(temp_row_ids) * (len(temp_col_ids)))
            avg_df_1.iloc[temp_row_ids, temp_col_ids] = sum_val
            avg_df_2.iloc[i, j] = sum_val
            avg_df_3.iloc[i, j] = n_val

    if threshold==False:
        avg_mat = np.array(avg_df_2.values, dtype=np.float)
        avg_mat = (avg_mat - avg_mat.mean()) / avg_mat.std()
        avg_mat_df = pd.DataFrame(avg_mat, columns=new_orig_network_idx_order, index=new_orig_network_idx_order)
        avg_mat_df = avg_mat_df.sort_index(key=lambda x: x.map(new_orig_network_order), axis=0)
        avg_mat_df = avg_mat_df.sort_index(key=lambda x: x.map(new_orig_network_order), axis=1)
        thr_zero = 0
        avg_mat_df[(avg_mat_df < thr_zero) & (avg_mat_df > -thr_zero)] = 0

        sns.set(style="white", font_scale=4.5)
        fig, ax = plt.subplots(figsize=(32, 32))
        threshold = np.abs(avg_mat_df.values).max()
        threshold = 5.5
        cbar_kws = dict(
            use_gridspec=False, shrink=0.85, location="right", label="Correlation ($r$)")
        sns.heatmap(
            avg_mat_df, square=True, cmap="vlag", ax=ax, 
            mask=np.triu(np.ones(avg_mat_df.shape), 1).astype(np.bool),
            vmax=threshold, vmin=-threshold, 
            linewidths=5, 
            cbar=True, cbar_kws=cbar_kws,
            fmt=".2f", annot=True, annot_kws={"fontsize": 32}
        )
        cbar = ax.collections[0].colorbar
        cbar.set_label("$WF$", fontsize=60, labelpad=50)
        ax.set_xticklabels(new_orig_network_order, rotation=90, fontsize=60)
        ax.set_yticklabels(new_orig_network_order, rotation=0, fontsize=60)
        plt.title(f"Epoch {epoch}/{epochs}")
        plt.savefig(wfm_save_dir+f"/{mode}_WFM_epoch{epoch}.jpg")
        plt.close(fig)
    #     plt.show()

    if threshold==True:
        avg_mat = np.array(avg_df_2.values, dtype=np.float)
        avg_mat = (avg_mat - avg_mat.mean()) / avg_mat.std()
        avg_mat_df = pd.DataFrame(avg_mat, columns=new_orig_network_idx_order, index=new_orig_network_idx_order)
        avg_mat_df = avg_mat_df.sort_index(key=lambda x: x.map(new_orig_network_order), axis=0)
        avg_mat_df = avg_mat_df.sort_index(key=lambda x: x.map(new_orig_network_order), axis=1)

        annot = np.vectorize(lambda x: '' if np.absolute(x) < 3.091 else str(round(x, 2)))(avg_mat_df.to_numpy())
        # annot = np.vectorize(lambda x: '' if np.absolute(x) < 2.58 else str(round(x, 2)))(avg_mat_df.to_numpy())
        # annot = np.vectorize(lambda x: '' if np.absolute(x) < 1.96 else str(round(x, 2)))(avg_mat_df.to_numpy())

        thr_zero = 0
        avg_mat_df[(avg_mat_df < thr_zero) & (avg_mat_df > -thr_zero)] = 0

        sns.set(style="white", font_scale=5)
        plt.rcParams['mathtext.fontset'] = 'custom'
        plt.rcParams['mathtext.it'] = 'Arial:italic'
        plt.rcParams['mathtext.rm'] = 'Arial'

        fig, ax = plt.subplots(figsize=(32, 32))
        threshold = np.abs(avg_mat_df.values).max()
        threshold = 7
        cbar_kws = dict(
            use_gridspec=False, shrink=0.85, location="right", label="Correlation ($r$)")
        sns.heatmap(
            avg_mat_df, square=True, cmap="vlag",
            ax=ax,
            mask=np.triu(np.ones(avg_mat_df.shape), 1).astype(np.bool),
            vmax=threshold, vmin=-threshold, 
            linewidths=5, 
            cbar=True, cbar_kws=cbar_kws,
            fmt="", annot=annot, annot_kws={"fontsize": 36}
        )
        cbar = ax.collections[0].colorbar
        cbar.set_label("$WF$", fontsize=60, labelpad=50)
        ax.set_xticklabels(new_orig_network_order, rotation=90, fontsize=60)
        ax.set_yticklabels(new_orig_network_order, rotation=0, fontsize=60)
        plt.title(f"Epoch {epoch}/{epochs}")
        plt.savefig(wfm_save_dir+f"/{mode}_WFM_epoch{epoch}.jpg")
        plt.close(fig)
    #     plt.show()
    
    return

In [21]:
def plot_prediction_result(epoch,pearsonr,valid_prediction,valid_true):
    
    pred_save_dir = outer_save_dir+'/Prediction/'
    if not os.path.isdir(pred_save_dir):
        os.mkdir(pred_save_dir)
    
    sns.set(style="darkgrid", font_scale=2)
    plt.figure(figsize=(10,7))
    fig = sns.scatterplot(valid_true.flatten(), valid_prediction.flatten())
    plt.title(f"Epoch {epoch}/{epochs}, Pearson's r : %.3f" % pearsonr)
    plt.xlabel("True $p$-factor")
    plt.ylabel("Predicted $p$-factor")
    plt.tight_layout()
    plt.ylim(valid_prediction.min(),valid_prediction.max())
    plt.savefig(pred_save_dir+f"/Prediction_epoch{epoch}.png")
    plt.close()
    return 

## 4. Make dataset & Build model

In [22]:
#Hyperparameters for training
mode = "max"
min_lr = 1e-08
lr_alpha = -1.5
lr_beta = 1.7

momentum = 0.90
l1_param = 0
early_stopping_patience = 20

input_dim = X.shape[1]
n_classes = len(np.unique(y[:, scanner_idx]))
output_prd_dim = 1
output_dsc_dim = n_classes

#Hyperparameter for sparsity control
wsc_flag = [0, 0, 0]
beta_lr = [1e-3, 2e-3, 2e-3] #1e-04, 1e-03, 1e-03
max_beta = [1e-2, 2e-02, 2e-02] #1e-2, 5e-02, 5e-02
n_wsc = wsc_flag.count(1)

In [23]:
# Make dataset
class MakeDataset(Dataset): 
    def __init__(self, X_data, y_data):
        self.X_data = X_data
        self.y_data = y_data
        
    def __len__(self):
        return len(self.X_data)
    
    def __getitem__(self, idx): 
        X_data = torch.from_numpy(self.X_data[idx]).type(torch.FloatTensor)
        y_data = torch.from_numpy(self.y_data[idx]).type(torch.FloatTensor)

        return X_data, y_data

In [24]:
#Gradient reversal function (Ganin,2015)
class GradRevFunc(Function):

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        dx = lambda_ * grads.neg()
        return dx, None
    
class GradRev(torch.nn.Module):
    def __init__(self, lambda_=0.0):
        super(GradRev, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradRevFunc.apply(x, self.lambda_)

In [25]:
#Build SGNN model
class SGNN(nn.Module):
    def __init__(self, ext_hidden, prd_hidden, dsc_hidden,
                 dropout_ext, dropout_prd, dropout_dsc, act_func_name, lambda_):
        super(SGNN, self).__init__()
        self.ext_1 = nn.Linear(input_dim, ext_hidden)
        self.ext_bn_1 = nn.BatchNorm1d(ext_hidden)
        
        self.prd_1 = nn.Linear(ext_hidden, prd_hidden)
        self.prd_bn_1 = nn.BatchNorm1d(prd_hidden)
        self.prd_2 = nn.Linear(prd_hidden, output_prd_dim)
        
        self.dsc_1 = nn.Linear(ext_hidden, dsc_hidden)
        self.dsc_bn_1 = nn.BatchNorm1d(dsc_hidden)
        self.dsc_2 = nn.Linear(dsc_hidden, output_dsc_dim)

        self.dropout_ext = nn.Dropout(p=dropout_ext)
        self.dropout_prd = nn.Dropout(p=dropout_prd)
        self.dropout_dsc = nn.Dropout(p=dropout_dsc)
        
        self.act_func = get_act_func(act_func_name)
        self.GradRev = GradRev(lambda_)
        self.weights_init()
    
    def forward(self, x):
        x_ftr = self.ext_1(x)
        x_ftr = self.ext_bn_1(x_ftr)
        x_ftr = self.act_func(x_ftr)
        x_ftr = self.dropout_ext(x_ftr)
        
        x_prd = self.prd_1(x_ftr)
        x_prd = self.prd_bn_1(x_prd)
        x_prd = self.act_func(x_prd)
        x_prd = self.dropout_prd(x_prd)
        x_prd = self.prd_2(x_prd)
        
        x_rev = self.GradRev(x_ftr)
        x_dsc = self.dsc_1(x_rev)
        x_dsc = self.dsc_bn_1(x_dsc)
        x_dsc = self.act_func(x_dsc)
        x_dsc = self.dropout_dsc(x_dsc)
        x_dsc = self.dsc_2(x_dsc)
        
        return x_prd, x_dsc
    
    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                nn.init.normal_(m.bias, std=0.01)

In [26]:
def get_optimizer(model, opt_name, learning_rate=None, l2_param=None):
    lower_opt_name = opt_name.lower()
    if lower_opt_name == 'momentum':
        return optim.SGD(model.parameters(), lr=learning_rate, 
                         momentum=momentum, weight_decay=l2_param)
    elif lower_opt_name == 'nag':
        return optim.SGD(model.parameters(), lr=learning_rate, 
                         momentum=momentum, weight_decay=l2_param, nesterov=True)
    elif lower_opt_name == 'adam':
        return optim.Adam(model.parameters(), lr=learning_rate, 
                          weight_decay=l2_param)
    elif lower_opt_name == 'sparseadam':
        return optim.SparseAdam(model.parameters(), lr=learning_rate,
                       betas=(0.9, 0.999), eps=1e-08, maximize=False)
    elif lower_opt_name == 'radam':
        return optim.RAdam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
                           eps=1e-08, weight_decay=l2_param)
    elif lower_opt_name == 'nadam':
        return optim.NAdam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08,
                  weight_decay=l2_param, momentum_decay=0.004)
    elif lower_opt_name == 'adamax':
        return optim.Adamax(model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
                            eps=1e-08, weight_decay=l2_param)
    else:
        sys.exit("Illegal arguement for optimizer type")

In [27]:
def get_act_func(act_func_name):
    act_func_name = act_func_name.lower()
    if act_func_name == 'relu':
        return nn.ReLU()
    elif act_func_name == 'prelu':
        return nn.PReLU()
    elif act_func_name == 'elu':
        return nn.ELU()
    elif act_func_name == 'silu':
        return nn.SiLU()
    elif act_func_name == 'leakyrelu':
        return nn.LeakyReLU()
    elif act_func_name == 'tanh':
        return nn.Tanh()
    elif act_func_name == 'selu':
        return nn.SELU()
    elif act_func_name == 'gelu':
        return nn.GELU()
    else:
        sys.exit("Illegal arguement for activation function type")

In [28]:
def init_hsp(n_wsc, epochs):
    hsp_val = torch.zeros(n_wsc)
    beta_val = torch.clone(hsp_val)
    hsp_list = torch.zeros((n_wsc, epochs))
    beta_list = torch.zeros((n_wsc, epochs))
    
    return hsp_val, beta_val, hsp_list, beta_list

In [29]:
# Weight sparsity control with Hoyer's sparsness (Layer wise)
def calc_hsp(w, beta, max_beta, beta_lr, tg_hsp):
    
    # Get value of weight
    [dim, n_nodes] = w.shape
    num_elements = dim * n_nodes
    norm_ratio = torch.norm(w.detach(), 1) / torch.norm(w.detach(), 2)

    # Calculate hoyer's sparsity level
    num = math.sqrt(num_elements) - norm_ratio
    den = math.sqrt(num_elements) - 1
    hsp = torch.tensor(num / den).to(device)

    # Update beta
    beta = beta.clone() + beta_lr * torch.sign(torch.tensor(tg_hsp).to(device) - hsp)
    
    # Trim value
    beta = 0 if beta < 0 else beta
    beta = max_beta if beta > max_beta else beta

    return [hsp, beta]

In [30]:
def calc_l1(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp):
    l1_reg = None
    layer_idx = 0
    wsc_idx = 0

    for name, param in model.named_parameters():
        if "weight" in name and "bn" not in name:
            if "ext" in name or "prd_1" in name or "dsc_1" in name:
                temp_w = param
                
                if wsc_flag[layer_idx] != 0:
                    hsp_val[wsc_idx], beta_val[wsc_idx] = calc_hsp(
                        temp_w, beta_val[wsc_idx], max_beta[wsc_idx], 
                        beta_lr[wsc_idx], tg_hsp[wsc_idx]
                    )
                    hsp_list[wsc_idx, epoch - 1] = hsp_val[wsc_idx]
                    beta_list[wsc_idx, epoch - 1] = beta_val[wsc_idx]
                    layer_reg = torch.norm(temp_w, 1) * beta_val[wsc_idx].clone()
                    wsc_idx += 1
                else:
                    layer_reg = torch.norm(temp_w, 1) * l1_param

                if l1_reg is None:
                    l1_reg = layer_reg
                else:
                    l1_reg = l1_reg + layer_reg
                layer_idx += 1
        
    return l1_reg

In [31]:
def calc_pearsonr(x, y):
    x_mean = torch.mean(x.detach())
    y_mean = torch.mean(y.detach())
    xx = x.sub(x_mean)
    yy = y.sub(y_mean)
    num = xx.dot(yy)
    den = torch.norm(xx, 2) * torch.norm(yy, 2)
    corr = num / den
    return corr.item()

In [32]:
def calc_mae(x, y):
    return torch.abs(x - y).mean().data

In [33]:
def train(model, epoch, train_loader, optimizer, criterion_prd, criterion_dsc, 
          hsp_val, beta_val, hsp_list, beta_list, tg_hsp, lambda_, X_train, y_train):
    seed_everything(seed)
    model.train()
    prd_loss = 0
    dsc_loss = 0
    dsc_acc = 0
    cost = 0
    total = 0
    correct = 0
    y_train_true = []
    y_train_pred = []
    
    for batch_idx, (input, target) in enumerate(train_loader):
        optimizer.zero_grad(set_to_none=True)
        input, target = input.to(device), target.to(device)
        output_prd, output_dsc = model(input)
        
        # 1. p-factor predictor loss
        target_prd = target[:, p_factor_idx].view(-1, 1)
        running_prd_loss = criterion_prd(output_prd, target_prd)
        l1_norm = calc_l1(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp)

        if epoch > pretrain_epoch:
            # Undersampling for scanner-generalization
            target_dsc = target[:, scanner_idx].long().view(-1)
            scnr_smp = target[:, scanner_idx].detach().cpu().numpy()
            n_minor = (scnr_smp == 0).sum()
            n_major = len(scnr_smp) - n_minor
            if n_minor != 0 and n_major != 0:
                minor_idx = np.where(scnr_smp == 0)[0]
                major_idx = np.where(scnr_smp != 0)[0]
                major_smp_idx = np.random.choice(major_idx, size=n_minor, replace=True)
                smp_idx = np.concatenate((minor_idx.astype(np.int), major_smp_idx.astype(np.int)))
                running_dsc_loss = criterion_dsc(output_dsc[smp_idx], target_dsc[smp_idx])
                dsc_loss += running_dsc_loss.detach()
            else:
                running_dsc_loss = 0
                dsc_loss += 0

            # Total Loss
            running_loss = running_dsc_loss + running_prd_loss + l1_norm.clone()
            
        else:
            running_loss = running_prd_loss + l1_norm.clone()

        cost = running_loss
        cost.backward()
        optimizer.step()
        
        prd_loss += running_prd_loss.detach()
        total += output_prd.size(0)
        true_batch = torch.flatten(target_prd.detach())
        pred_batch = torch.flatten(output_prd.detach())
        y_train_true.append(true_batch)
        y_train_pred.append(pred_batch)
        
    X_train = torch.from_numpy(X_train).type(torch.FloatTensor).to(device)
    _, output_dsc = model(X_train)
    _, scnr_pred = torch.max(output_dsc.data, 1)
    scnr_pred = scnr_pred.detach().cpu().numpy().ravel()
    scnr_true = y_train[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)
    
    prd_loss /= total
    dsc_loss /= total
    y_train_true = torch.flatten(torch.stack(y_train_true))
    y_train_pred = torch.flatten(torch.stack(y_train_pred))
    train_corr = calc_pearsonr(y_train_true, y_train_pred)
    train_mae = calc_mae(y_train_true, y_train_pred).detach().cpu().numpy()
    torch.cuda.empty_cache()
    
#     plot_prediction_result(epoch,train_corr,y_train_pred.detach().cpu().numpy(),
#                            y_train_true.detach().cpu().numpy())

    return prd_loss, dsc_loss, dsc_acc, train_corr, train_mae

In [34]:
def valid(model, epoch, val_loader, criterion_prd, criterion_dsc, X_val, y_val):
    seed_everything(seed)
    model.eval()
    prd_loss = 0
    total = 0
    y_val_true = []
    y_val_pred = []
    
    with torch.no_grad():
        for input, target in val_loader:
            input, target = input.to(device), target.to(device)
            output_prd, output_clf = model(input)
            target_prd = target[:, p_factor_idx].view(-1, 1)
            running_prd_loss = criterion_prd(output_prd, target_prd)
            prd_loss += running_prd_loss.detach()
            total += output_prd.size(0)
            true_batch = torch.flatten(target_prd.detach())
            pred_batch = torch.flatten(output_prd.detach())
            y_val_true.append(true_batch)
            y_val_pred.append(pred_batch)

    X_val = torch.from_numpy(X_val).type(torch.FloatTensor).to(device)
    _, output_dsc = model(X_val)
    _, scnr_pred = torch.max(output_dsc.data, 1)
    scnr_pred = scnr_pred.detach().cpu().numpy().ravel()
    scnr_true = y_val[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    y_val_true = torch.flatten(torch.stack(y_val_true))
    y_val_pred = torch.flatten(torch.stack(y_val_pred))
    val_corr = calc_pearsonr(y_val_true, y_val_pred)
    val_mae = calc_mae(y_val_true, y_val_pred).detach().cpu().numpy()
    torch.cuda.empty_cache()
    
#     plot_prediction_result(epoch,val_corr,y_val_pred.detach().cpu().numpy(),
#                            y_val_true.detach().cpu().numpy())

    return prd_loss, val_corr, val_mae, dsc_acc

In [35]:
def test(model, epoch, test_loader, criterion_prd, criterion_dsc, X_test, y_test):
    seed_everything(seed)
    model.eval()
    prd_loss = 0
    total = 0
    y_test_true = []
    y_test_pred = []
    
    with torch.no_grad():
        for input, target in test_loader:
            input, target = input.to(device), target.to(device)
            output_prd, output_clf = model(input)
            target_prd = target[:, p_factor_idx].view(-1, 1)
            running_prd_loss = criterion_prd(output_prd, target_prd)
            prd_loss += running_prd_loss.detach()
            total += output_prd.size(0)
            true_batch = torch.flatten(target_prd.detach())
            pred_batch = torch.flatten(output_prd.detach())
            y_test_true.append(true_batch)
            y_test_pred.append(pred_batch)

    X_test = torch.from_numpy(X_test).type(torch.FloatTensor).to(device)
    _, output_dsc = model(X_test)
    _, scnr_pred = torch.max(output_dsc.data, 1)
    scnr_pred = scnr_pred.detach().cpu().numpy().ravel()
    scnr_true = y_test[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    y_test_true = torch.flatten(torch.stack(y_test_true))
    y_test_pred = torch.flatten(torch.stack(y_test_pred))
    test_corr = calc_pearsonr(y_test_true, y_test_pred)
    test_mae = calc_mae(y_test_true, y_test_pred).detach().cpu().numpy()
    torch.cuda.empty_cache()
    
    plot_prediction_result(epoch,test_corr,y_test_pred.detach().cpu().numpy(),
                           y_test_true.detach().cpu().numpy())

    return prd_loss, test_corr, test_mae, dsc_acc

In [36]:
class EarlyStopping:
    """주어진 patience 이후로 validation loss가 개선되지 않으면 학습을 조기 중지"""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): validation loss가 개선된 후 기다리는 기간
                            Default: 7
            verbose (bool): True일 경우 각 validation loss의 개선 사항 메세지 출력
                            Default: False
            delta (float): 개선되었다고 인정되는 monitered quantity의 최소 변화
                            Default: 0
            path (str): checkpoint저장 경로
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''validation loss가 감소하면 모델을 저장한다.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [37]:
def plot_learning_curves(
    save_dir, epochs, train_loss, val_loss, test_loss,
    train_corr, val_corr, test_corr,train_acc_scd, val_acc_scd,test_acc_scd,
    lr, plot_hsp_list, plot_beta_list, tg_hsp):
    
    sns.set(style="darkgrid", font_scale=2)
    fig, ax = plt.subplots(2, 3, figsize=(20, 10))
    ax = ax.flat
    lw = 4
    last_epoch = epochs
    
    #Loss plot
    ax[0].plot(train_loss[:last_epoch], label='train loss', lw=lw, color="g")
    ax[0].tick_params(axis='y',labelcolor='g')
    ax_01 = ax[0].twinx()

    ax_01.plot(val_loss[:last_epoch], label='val loss', lw=lw, color="orange")
    ax_01.tick_params(axis='y',labelcolor='orange')
    ax_01.plot(test_loss[:last_epoch], label='test loss', lw=lw, color="r")
    ax_01.tick_params(axis='y',labelcolor='r')
    
    ax[0].legend()
    ax[0].set_title("Loss Plot", pad=10)

    #Corr plot
    ax[1].plot(train_corr[:last_epoch], label='train corr', lw=lw, color="g")
    ax[1].tick_params(axis='y',labelcolor='g')
    ax_02 = ax[1].twinx()

    ax_02.plot(val_corr[:last_epoch], label='val corr', lw=lw, color="orange")
    ax_02.tick_params(axis='y',labelcolor='orange')
    ax_02.plot(test_corr[:last_epoch], label='test corr', lw=lw, color="r")
    ax_02.tick_params(axis='y',labelcolor='r')
    ax[1].legend()
    ax[1].set_title("Correlation Plot", pad=10)

    plot_hsp_list, plot_beta_list = np.array(plot_hsp_list).T, np.array(plot_beta_list).T
    
    #hsp plot
    for idx, n_layer in enumerate(indices):
        ax[3].plot(plot_hsp_list[idx], label='layer{}'.format(n_layer), lw=lw)
        ax[4].plot(plot_beta_list[idx], 
                   label='layer{}'.format(n_layer), lw=lw)
        ax[3].legend(); ax[4].legend()
        ax[3].set_title("HSP plot [{:.3f}/{:.3f}]"
                        .format(plot_hsp_list[0, -1], tg_hsp[0][0]), pad=10)
        ax[4].set_title("Beta plot", pad=10)
    
    #Scanner acc plot
    ax[2].plot(train_acc_scd[:last_epoch], label='train acc', lw=lw, color="g")
    ax[2].plot(val_acc_scd[:last_epoch], label='val acc', lw=lw, color="orange")
    ax[2].plot(test_acc_scd[:last_epoch], label='test acc', lw=lw, color="r")
    ax[2].set_title("Acc plot - scanner")
    ax[2].legend()
    
    #Learning rate plot
    ax[5].plot(lr[:last_epoch],label='Learning rate',lw=lw,color='m')
    ax[5].set_title("Learning rate")
    
    fig.tight_layout()
    fig.savefig("{}/Learning_curves.png".format(save_dir))
    
    plt.close(fig)

In [38]:
def run_outer_fold(n_outer_cv=0, outer_save_dir=None, sel_tg_hsp=None,sel_lambda=None):
    seed_everything(seed)
    outer_cv_list = []
    
    # Outer fold
    print("\n===================================", end=" ")
    print("Outer Fold [{}/{}]".format(n_outer_cv + 1, outer_n_splits), end=" ")
    print("===================================")
    
    outer_start_fold_time = time.time()
    outer_train_idx = outer_train_folds_idx[n_outer_cv]
    outer_val_idx = outer_val_folds_idx[n_outer_cv]
    outer_test_idx = outer_test_folds_idx[n_outer_cv]

    X_train, y_train = X[outer_train_idx], y[outer_train_idx]
    X_val, y_val = X[outer_val_idx], y[outer_val_idx]
    X_test, y_test = X[outer_test_idx], y[outer_test_idx]
    
    X_train = stats.zscore(X_train, axis=1)
    X_val = stats.zscore(X_val, axis=1)
    X_test = stats.zscore(X_test, axis=1)
        
    outer_train_dataset = MakeDataset(X_train, y_train)
    outer_val_dataset = MakeDataset(X_val, y_val)
    outer_test_dataset = MakeDataset(X_test, y_test)
    
    outer_train_loader = DataLoader(
        outer_train_dataset, batch_size=batch_size, pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True)
    outer_val_loader = DataLoader(
        outer_val_dataset, batch_size=len(y_val), pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True)
    outer_test_loader = DataLoader(
        outer_test_dataset, batch_size=len(y_test), pin_memory=True,
        shuffle=True, num_workers=num_workers, drop_last=True)
        
    # Assign model 
    model = SGNN(
        ext_hidden, dsc_hidden, prd_hidden, 
        dropout_ext, dropout_prd, dropout_dsc, act_func_name, sel_lambda
    ).to(device)
    optimizer = get_optimizer(model, optimizer_name, learning_rate, l2_param)
    lr_factor = lr_alpha * sel_tg_hsp[0][0] + lr_beta
    scheduler = ReduceLROnPlateau(
        optimizer, mode=mode, patience=lr_patience, min_lr=min_lr, factor=lr_factor
    )
    cosine_scheduler = CosineAnnealingWarmRestarts(optimizer,20,eta_min = min_lr)
    early_stopping = EarlyStopping(patience=early_stopping_patience,
                                   path=outer_save_dir + "/model_fold_" + str(n_outer_cv + 1) + ".pt")
    es_switch_count = 0
    es_switch = False
    criterion_prd = nn.MSELoss()
    criterion_dsc = nn.CrossEntropyLoss() # nn.BCELoss()
              
    # list to save learning parameters
    outer_train_loss = []
    outer_val_loss=[]
    outer_test_loss = []
    outer_train_corr = []
    outer_val_corr = []
    outer_test_corr = []
    outer_train_acc = []
    outer_val_acc = []
    outer_test_acc = []
    outer_train_mae = []
    outer_val_mae = []
    outer_test_mae = []
    outer_lr = []
    outer_hsp_list = []
    outer_beta_list = []

    hsp_val, beta_val, hsp_list, beta_list = init_hsp(n_wsc, epochs)
        
    for epoch in range(1, epochs + 1):
        train_prd_loss, train_dsc_loss, train_acc, train_corr, train_mae = train(
            model, epoch, outer_train_loader, 
            optimizer, criterion_prd, criterion_dsc, 
            hsp_val, beta_val, hsp_list, beta_list, sel_tg_hsp, sel_lambda,
            X_train, y_train
        )
        val_prd_loss, val_corr, val_mae, val_acc = valid(
            model, epoch, outer_val_loader, criterion_prd, criterion_dsc, X_val, y_val
        )
        test_prd_loss, test_corr, test_mae, test_acc = test(
            model, epoch, outer_test_loader, criterion_prd, criterion_dsc, X_test, y_test
        )

        lr = optimizer.param_groups[0]['lr']
        
        outer_train_loss.append([train_prd_loss, train_dsc_loss])
        outer_train_mae.append(train_mae)
        outer_train_corr.append(train_corr)
        outer_train_acc.append(train_acc)
        outer_val_loss.append([val_prd_loss, []])
        outer_val_mae.append(val_mae)
        outer_val_corr.append(val_corr)
        outer_val_acc.append(val_acc)
        outer_test_loss.append([test_prd_loss, []])
        outer_test_mae.append(test_mae)
        outer_test_corr.append(test_corr)
        outer_test_acc.append(test_acc)
        outer_lr.append(lr)
        outer_hsp_list.append(list(hsp_val.clone()))
        outer_beta_list.append(list(beta_val.clone()))

        if epoch % print_epoch == 0:
            print("\nEpoch [{:d}/{:d}]".format(epoch, epochs), end=" ")
            print("Train corr: {:.4f}, Test corr: {:.4f}, Train loss: {:.4f}, Test loss: {:.4f}"
                  .format(train_corr, test_corr, train_prd_loss, test_prd_loss))
            for i in range(len(wsc_flag)):
                if wsc_flag[i] != 0:
                    print("Layer {:d}: [{:.4f}/{:.4f}]".
                          format( i + 1, hsp_val[i], sel_tg_hsp[i][0]), end=" ")
            print("Train acc: {:.2f}".format(train_acc))


            plot_learning_curves(
                outer_save_dir, epochs,
                np.array(outer_train_mae), np.array(outer_val_mae), np.array(outer_test_mae),  
                outer_train_corr,outer_val_corr, outer_test_corr,
                outer_train_acc,outer_val_acc,outer_test_acc,
                outer_lr, outer_hsp_list, outer_beta_list, sel_tg_hsp
            )

            
        ###Debugging code_Tracing WFM during training -HJD###
        if epoch % 5 ==0:
            visualize_wfm(model,epoch,epoch_gap = 5,threshold=False,mode='sum')
            visualize_wfm(model,epoch,epoch_gap = 5,threshold=False,mode='avg')
        ######
        if 'Full' in select_fold:
            if len(hsp_val)>0:
                if hsp_val[0] < temp_param['hsp_ext']:
                    scheduler.step(hsp_val[0])
        else:
            if len(hsp_val)>0:
                if hsp_val[0] < temp_param['hsp_ext']:
                    scheduler.step(hsp_val[0])
                if hsp_val[0] >= temp_param['hsp_ext'] and hsp_val[1] >= temp_param['hsp_prd'] and hsp_val[2] >= temp_param['hsp_dsc'] and epoch > pretrain_epoch:
                    es_switch_count +=1
                    if es_switch_count>5:
                        es_switch = True
                        cosine_scheduler.step()
                if es_switch:
                    early_stopping(val_prd_loss,model)
                    if early_stopping.early_stop:
                        print("Early stopping")
                        break
            
            
#     train_prd_loss, train_corr, train_mae, train_acc = test(
#         model, epoch, outer_train_loader, criterion_prd, criterion_dsc, X_train, y_train
#     )

    if 'Full' in select_fold:
        torch.save(model.state_dict(), 
                   outer_save_dir + "/model_fold_" + str(n_outer_cv + 1) + ".pt")
    
    torch.cuda.empty_cache()
    gc.collect()


    outer_cv_list.append([train_corr, val_corr, test_corr, train_mae, val_mae, test_mae,
                          train_acc, val_acc, test_acc])
            
    outer_cv_df = pd.DataFrame(
        np.array(outer_cv_list), 
        columns=["train_corr", "valid_corr","test_corr",
                 "train_mae", "valid_mae", "test_mae","train_acc", "valid_acc","test_acc"]
    )
    outer_cv_df.to_csv("{}/outer_cv.csv".format(outer_save_dir))

    outer_tot_time = time.time() - outer_start_fold_time
    print("\nExecution Time for Fold: {:.2f} mins".format(outer_tot_time / 60))
    
    return train_corr, test_corr, train_mae, test_mae, outer_train_acc, outer_train_corr, outer_hsp_list

## 5. Settle parameters, output path & Train model

In [39]:
param_grid = list(ParameterGrid(param_cand))

temp_param = param_grid[temp_sel_idx]

ext_hidden = temp_param["ext"]
prd_hidden = temp_param["prd"]
dsc_hidden = temp_param["dsc"]

dropout_ext = temp_param["dropout_ext"]
dropout_prd = temp_param["dropout_prd"]
dropout_dsc = temp_param["dropout_dsc"]


batch_size = temp_param["batch_size"]
learning_rate = temp_param["lr"]
epochs = temp_param["epochs"]
l2_param = temp_param["l2_param"]
pretrain_epoch = temp_param["pretrain_epoch"]

lr_patience = temp_param["lr_patience"]
# lr_factor = temp_param["lr_factor"]

In [41]:
save_folder = f"{target_name}/Hsp:[{temp_param['hsp_ext']*wsc_flag[0]},{temp_param['hsp_prd']*wsc_flag[1]},{temp_param['hsp_dsc']*wsc_flag[2]}]_Maxb:{max_beta}_Betalr:{beta_lr}_LR:[{temp_param['lr']}]_Act:[{act_func_name}]_Opt:[{optimizer_name}]_DO:[{temp_param['dropout_ext']},{temp_param['dropout_prd']},{temp_param['dropout_dsc']}]_lambda:[{temp_param['lambda_']}]_seed[{seed}]"
save_path = "/users/hjd/IG_my_study/SNUH/data/temp/"
output_folder = os.path.join(save_path, save_folder)
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

elif "Full" in select_fold:
    if "Outer_fold_Full" in os.listdir(output_folder):
        print(output_folder)
        raise Exception("Result already exist. Check directory!")
        
elif len(os.listdir(output_folder))>0:
    for fold in [f'Outer_fold_{i+1}' for i in select_fold]:
        if fold in os.listdir(output_folder):
            print(output_folder)
            raise Exception("Result already exist. Check directory!")
    
print(output_folder)

/users/hjd/IG_my_study/SNUH/data/temp/p/Hsp:[0.0,0.0,0.0]_Maxb:[0.01, 0.02, 0.02]_Betalr:[0.001, 0.002, 0.002]_LR:[5e-05]_Act:[elu]_Opt:[nag]_DO:[0.9,0.9,0.0]_lambda:[0.01]_seed[42]


In [42]:
print_epoch = 5

In [43]:
seed_everything(seed)
code_start_time = time.time()

print(output_folder)

outer_cv = []

for n_outer_cv in outer_cv_part:
    
    if "Full" in select_fold:
        pass
    elif n_outer_cv not in select_fold:
        continue
        
    print("\n===================================", end=" ")
    print("Outer Fold [{}/{}]".format(n_outer_cv + 1, outer_n_splits), end=" ")
    print("===================================")

    outer_save_dir = "{}/Outer_fold_{}".format(output_folder, n_outer_cv + 1)
    if "Full" in select_fold:
        print("Training Full subjects!")
        outer_save_dir = "{}/Outer_fold_Full".format(output_folder)
    os.makedirs(outer_save_dir, exist_ok=True)

    sel_idx = temp_sel_idx
    sel_param = param_grid[sel_idx]
    sel_hsp = []
    sel_lambda = sel_param["lambda_"]
    print("Selected param:", end=" ")
    for x in sel_param:
        if "hsp" in x: 
            print("{}: {}".format(x, sel_param[x]), end=" ")
            sel_hsp.append(sel_param[x])
    
    # Outer Fold
    hsp_cand_1 = [sel_param["hsp_ext"]]
    hsp_cand_2 = [sel_param["hsp_prd"]]
    hsp_cand_3 = [sel_param["hsp_dsc"]]

    indices = [i + 1 for i, x in enumerate(wsc_flag) if x == 1]
    hsp_cand_list = list(itertools.product(hsp_cand_1, hsp_cand_2, hsp_cand_3))
    hsp_cand_list = [list(i) for i in hsp_cand_list]
    hsp_cand = [hsp_cand_1, hsp_cand_2, hsp_cand_3]
    sel_tg_hsp = hsp_cand

    (outer_train_corr, outer_test_corr, outer_train_mae, 
    outer_test_mae, outer_train_acc, outer_test_acc, outer_hsp_list) = run_outer_fold(
        n_outer_cv, outer_save_dir, sel_tg_hsp, sel_lambda
    )
    
    outer_cv.append([sel_hsp, outer_train_corr, outer_test_corr, outer_train_mae, outer_test_mae])
    
    print("\nOuter Fold [{}/{}]: train corr: {:.4f}, valid corr: {:.4f}"
          .format(n_outer_cv + 1, outer_n_splits, outer_train_corr, outer_test_corr))
    
    ###Debugging code_Save WFM - HJD###f
    path_avg = [i for i in os.listdir(outer_save_dir+'/WFM') if 'jpg' in i and 'avg' in i]
    path_sum = [i for i in os.listdir(outer_save_dir+'/WFM') if 'jpg' in i and 'sum' in i]
#     path2 = [i for i in path if int(i.split('epoch')[1].split('.jpg')[0])<=100]
    imgs_avg = [Image.open(outer_save_dir+'/WFM/'+i) for i in path_avg]
    imgs_sum = [Image.open(outer_save_dir+'/WFM/'+i) for i in path_sum]
    imageio.mimsave(outer_save_dir+'/WM_changes_avg.gif', imgs_avg, fps=10)
    imageio.mimsave(outer_save_dir+'/WM_changes_sum.gif', imgs_sum, fps=10)

    ###Debugging code_Save scatterplot - HJD###
    path = [i for i in os.listdir(outer_save_dir+'/Prediction') if 'png' in i]
    imgs = [Image.open(outer_save_dir+'/Prediction/'+i) for i in path]
    imageio.mimsave(outer_save_dir+'/Prediction_changes.gif', imgs, fps=10)
    ######

/users/hjd/IG_my_study/SNUH/data/temp/p/Hsp:[0.0,0.0,0.0]_Maxb:[0.01, 0.02, 0.02]_Betalr:[0.001, 0.002, 0.002]_LR:[5e-05]_Act:[elu]_Opt:[nag]_DO:[0.9,0.9,0.0]_lambda:[0.01]_seed[42]

Selected param: hsp_dsc: 0.3 hsp_ext: 0.95 hsp_prd: 0.3 

Epoch [5/10] Train corr: 0.5830, Test corr: 0.1610, Train loss: 0.0065, Test loss: 0.3013
Train acc: 0.51

Epoch [10/10] Train corr: 0.8657, Test corr: 0.1504, Train loss: 0.0025, Test loss: 0.3017
Train acc: 0.51

Execution Time for Fold: 2.77 mins

Outer Fold [1/5]: train corr: 0.8657, valid corr: 0.1504


In [44]:
code_tot_time = time.time() - code_start_time 
print("Execution Time for the training: {:.2f} hours".format(code_tot_time / 60 / 60))

Execution Time for the training: 0.05 hours


In [45]:
df = pd.DataFrame(outer_cv, columns=["hsp", "train_corr", "test_corr","train_mae","test_mae"])
train_avg = np.array([x for x in df["train_corr"].values]).mean()
test_avg = np.array([x for x in df["test_corr"].values]).mean()
print("Train: {:.4f}, Test: {:.4f}".format(train_avg, test_avg))

Train: 0.8657, Test: 0.1504


In [46]:
tmp_param_grid = list(ParameterGrid(param_cand))[temp_sel_idx]
for name in tmp_param_grid:
    print(name, tmp_param_grid[name])
    
import json
with open(outer_save_dir+'/params.json', 'w') as fp:
    json.dump(tmp_param_grid, fp)

batch_size 32
dropout_dsc 0.0
dropout_ext 0.9
dropout_prd 0.9
dsc 1024
epochs 10
ext 1024
hsp_dsc 0.3
hsp_ext 0.95
hsp_prd 0.3
l2_param 0.05
lambda_ 0.01
lr 5e-05
lr_factor 0.5
lr_patience 5
prd 1024
pretrain_epoch 20


In [47]:
df['lr'] = learning_rate
df['beta_lr'] = [beta_lr]*len(df)
df['max_beta'] = [max_beta]*len(df)
df['l2_param'] = l2_param
df['batch_size'] = batch_size
df['act_func'] = act_func_name
df['optimizer'] = optimizer_name
df['momentum'] = momentum
if 'Full' in select_fold:
    df['train_idx'] = [outer_train_folds_idx[0]]
    df['test_idx'] = [outer_test_folds_idx[0]]
    df.to_csv(outer_save_dir+f"/result_df_{select_fold}.csv", sep='\t',index=None)
else:
    df['train_idx'] = [outer_train_folds_idx[i] for i in select_fold]
    df['test_idx'] = [outer_test_folds_idx[i] for i in select_fold]
    df.to_csv(outer_save_dir+f"/result_df_{np.array(select_fold)+1}.csv", sep='\t',index=None)