# In [1]:
## environ

import itertools
import logging
import os
from logging import handlers

import colorlog
import yaml

# Reference template (kept as an unassigned string literal, so it has no runtime
# effect) for the logger section loaded by get_logger() via get_config('logger').
# It appears to mirror the keys get_logger reads: name, level, datefmt, handlers,
# per-handler class/param/level/formatter_class/formatter, and formatters.
# Translation of the Chinese notes inside:
#   处理器集合 = "handler collection"
#   输出信息的最低级别 = "minimum level of messages to output"
#   输出到文件 = "output to file"
#   备份份数 = "number of backup files kept"
#   文件编码 = "file encoding"
#   日志格式集合 = "formatter collection"
'''
LOG_CONFIG = {
    'name' : 'default_log',
    'handlers': ['console' , 'file'],
    'level': 'DEBUG',
    'datefmt' : '%y-%m-%d %H:%M:%S',
    # 处理器集合
    'console': {
        'level': 'INFO',  # 输出信息的最低级别
        'class': 'logging.StreamHandler',
        'param' : {},
        'formatter_class' : '_LevelColorFormatter', 
        # 'colorlog.ColoredFormatter' , '_LevelFormatter' , 'logging.Formatter'
        'formatter': 'levelcolor',  # 'color' , 'level' , 'standard'
    },
    # 输出到文件
    'file': {
        'level': 'DEBUG',
        'class': 'logging.handlers.TimedRotatingFileHandler',
        'param' : {
            'filename' : './logs/nn_fac_log.log',
            'when' : 'D',
            'backupCount': 5,  # 备份份数
            'encoding': 'utf-8',  # 文件编码
        },
        'formatter_class' : '_LevelFormatter',
        'formatter': 'level', 
    },
    # 日志格式集合
    'formatters': {
        # 标准输出格式 , omit part : 'TRD:%(threadName)-10s|LVL:%(levelno)s|'
        'standard': {
            'fmt': '%(asctime)s|MOD:%(module)-12s|: %(message)s',
        },
        'level' : {
            'fmt': '%(asctime)s|MOD:%(module)-12s|: %(message)s',
            'level_fmts' : {
                'DEBUG' : '%(message)s',
                'INFO' : '%(message)s',
            },
        },
        'color' : {
            'fmt': '%(log_color)s%(asctime)s|MOD:%(module)-12s|%(reset_log_color)s: %(message_log_color)s%(message)s',
            'log_colors' : {
                'DEBUG':'bold,white,bg_cyan',
                'INFO':'bold,white,bg_green',
                'WARNING':'bold,white,bg_blue',
                'ERROR':'bold,white,bg_purple',
                'CRITICAL':'bold,white,bg_red',
            },
            'secondary_log_colors' : {
                'reset': {
                    'DEBUG':'reset',
                    'INFO':'reset',
                    'WARNING':'reset',
                    'ERROR':'reset',
                    'CRITICAL':'reset',
                },
                'message': {
                    'DEBUG':'cyan',
                    'INFO':'green',
                    'WARNING':'bold,blue',
                    'ERROR':'bold,purple',
                    'CRITICAL':'bold,red',
                },
            },
        },
        'levelcolor' : {
            #'fmt': '%(log_color)s%(asctime)s|MOD:%(module)-12s|TRD:%(threadName)-12s|LVL:%(levelno)s|%(reset_log_color)s: %(message_log_color)s%(message)s',
            'fmt': '%(log_color)s%(asctime)s|MOD:%(module)-12s|%(reset_log_color)s: %(message_log_color)s%(message)s',
            'level_fmts' : {
                'DEBUG' : '%(message_log_color)s%(message)s',
                'INFO' : '%(message_log_color)s%(message)s',
            },
            'log_colors' : {
                'DEBUG':'bold,white,bg_cyan',
                'INFO':'bold,white,bg_green',
                'WARNING':'bold,white,bg_blue',
                'ERROR':'bold,white,bg_purple',
                'CRITICAL':'bold,white,bg_red',
            },
            'secondary_log_colors' : {
                'reset': {
                    'DEBUG':'reset',
                    'INFO':'reset',
                    'WARNING':'reset',
                    'ERROR':'reset',
                    'CRITICAL':'reset',
                },
                'message': {
                    'DEBUG':'cyan',
                    'INFO':'green',
                    'WARNING':'bold,blue',
                    'ERROR':'bold,purple',
                    'CRITICAL':'bold,red',
                },
            },
        },
    },
}
'''

class _LevelFormatter(logging.Formatter):
    def __init__(self, fmt=None, datefmt=None, level_fmts={}):
        self._level_formatters = {}
        for level, format in level_fmts.items():
            # Could optionally support level names too
            self._level_formatters[getattr(logging , level)] = logging.Formatter(fmt=format, datefmt=datefmt)
        # self._fmt will be the default format
        super(_LevelFormatter, self).__init__(fmt=fmt, datefmt=datefmt)

    def format(self, record):
        if record.levelno in self._level_formatters:
            return self._level_formatters[record.levelno].format(record)
        return super(_LevelFormatter, self).format(record)
    
class _LevelColorFormatter(colorlog.ColoredFormatter):
    """colorlog.ColoredFormatter that can use a different (colored) format string per level.

    Args:
        fmt: default format string, used for any level not listed in ``level_fmts``.
        datefmt: date format shared by all formatters.
        log_colors: passed to colorlog for the default and every per-level formatter;
            ``None`` is treated as ``{}`` (the former default).
        level_fmts: mapping of level name (e.g. ``'DEBUG'``) or numeric level to the
            format string used for records of exactly that level.
        secondary_log_colors: passed to colorlog like ``log_colors``; ``None`` -> ``{}``.
    """
    def __init__(self, fmt=None, datefmt=None, log_colors=None, level_fmts=None, secondary_log_colors=None):
        # Map None back to {} so behaviour matches the old (mutable) `{}` defaults;
        # NOTE(review): colorlog may treat None differently from {} (presumably its own
        # default palette), hence we never forward None.
        log_colors = {} if log_colors is None else log_colors
        secondary_log_colors = {} if secondary_log_colors is None else secondary_log_colors
        self._level_formatters = {}
        for level, level_fmt in (level_fmts or {}).items():
            levelno = level if isinstance(level, int) else getattr(logging, level)
            self._level_formatters[levelno] = colorlog.ColoredFormatter(
                fmt=level_fmt, datefmt=datefmt,
                log_colors=log_colors, secondary_log_colors=secondary_log_colors)
        # the base class holds the default format
        super().__init__(fmt=fmt, datefmt=datefmt, log_colors=log_colors,
                         secondary_log_colors=secondary_log_colors)

    def format(self, record):
        """Format with the per-level formatter for ``record.levelno`` if any, else the default."""
        level_formatter = self._level_formatters.get(record.levelno)
        if level_formatter is not None:
            return level_formatter.format(record)
        return super().format(record)

    
def _resolve_dotted_name(dotted):
    """Resolve a dotted name such as 'logging.handlers.TimedRotatingFileHandler' or
    '_LevelFormatter' against this module's globals (the namespace the config names refer to)."""
    head, *attrs = dotted.split('.')
    obj = globals()[head]
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj


def get_logger(test_output = False):
    """Build (or rebuild) the logger described by ``get_config('logger')``.

    The config supplies ``name``, ``level``, ``datefmt``, the list of ``handlers``
    and, per handler name, its ``class``/``param``/``level``/``formatter_class``/
    ``formatter`` plus the shared ``formatters`` table.

    Args:
        test_output: if True, emit one message at every level after setup.

    Returns:
        The configured ``logging.Logger``. Handlers from a previous call are removed
        and closed first, so calling this repeatedly does not duplicate output.
    """
    config_logger = get_config('logger')
    log = logging.getLogger(config_logger['name'])
    log.setLevel(getattr(logging, config_logger['level']))

    while log.handlers:
        old_handler = log.handlers[-1]
        log.removeHandler(old_handler)
        old_handler.close()   # releases file handles held by file-based handlers

    # Explicit lookups replace the former exec()-built locals, which stop working
    # on Python 3.13+ where locals() returns a fresh snapshot on every call.
    for hdname in config_logger['handlers']:
        hd_config = config_logger[hdname]
        hd_param = hd_config['param']
        log_dir = os.path.dirname(hd_param.get('filename') or '')
        if log_dir:
            os.makedirs(log_dir, exist_ok = True)
        handler = _resolve_dotted_name(hd_config['class'])(**hd_param)
        formatter_class = _resolve_dotted_name(hd_config['formatter_class'])
        formatter = formatter_class(datefmt=config_logger['datefmt'],
                                    **config_logger['formatters'][hd_config['formatter']])
        handler.setLevel(getattr(logging, hd_config['level']))
        handler.setFormatter(formatter)
        log.addHandler(handler)

    if test_output:
        log.debug('This is the DEBUG    message...')
        log.info('This is the INFO     message...')
        log.warning('This is the WARNING  message...')
        log.error('This is the ERROR    message...')
        log.critical('This is the CRITICAL message...')
    return log

def _apply_train_special_config(cfg):
    """Fold SPECIAL_CONFIG overrides of a train config into ``cfg`` in place.

    * SHORTTEST: if ``cfg['SHORTTEST']`` is truthy, the override dict is merged into ``cfg``.
    * TRANSFORMER: for 'Transformer' models (or 'GeneralRNN' whose ``type_rnn`` mentions
      'transformer') its ``trainer`` entries are merged into ``cfg['TRAIN_PARAM']['trainer']``.
    Each consumed override is deleted from SPECIAL_CONFIG afterwards.
    """
    if isinstance(cfg.get('SPECIAL_CONFIG'), dict) and 'SHORTTEST' in cfg['SPECIAL_CONFIG']:
        if cfg.get('SHORTTEST'):
            cfg.update(cfg['SPECIAL_CONFIG']['SHORTTEST'])
        del cfg['SPECIAL_CONFIG']['SHORTTEST']
    if isinstance(cfg.get('SPECIAL_CONFIG'), dict) and 'TRANSFORMER' in cfg['SPECIAL_CONFIG']:
        is_transformer = (cfg['MODEL_MODULE'] == 'Transformer' or
                          (cfg['MODEL_MODULE'] in ['GeneralRNN'] and
                           'transformer' in cfg['MODEL_PARAM']['type_rnn']))
        if is_transformer:
            cfg['TRAIN_PARAM']['trainer'].update(cfg['SPECIAL_CONFIG']['TRANSFORMER']['trainer'])
        del cfg['SPECIAL_CONFIG']['TRANSFORMER']


def get_config(config_files = ('data_type' , 'train')):
    """Load ``./configs/config_<name>.yaml`` for each name and merge them into one dict.

    Args:
        config_files: a config name or an iterable of names; later files override
            earlier keys (plain ``dict.update``).

    Returns:
        The merged dict. An empty YAML file contributes nothing. The 'train' config
        additionally has its SPECIAL_CONFIG overrides applied.
    """
    config_dict = dict()
    if isinstance(config_files , str): config_files = [config_files]
    for cfg_name in config_files:
        with open(f'./configs/config_{cfg_name}.yaml' , 'r' , encoding = 'utf-8') as f:
            # FullLoader is only acceptable because these are trusted local config files.
            cfg = yaml.load(f , Loader = yaml.FullLoader) or {}
        if cfg_name == 'train':
            _apply_train_special_config(cfg)
        config_dict.update(cfg)
    return config_dict
## function

import numpy as np
import pandas as pd
import torch
import time , os , shutil , pprint , psutil
from scipy import stats
from pytimedinput import timedInput

def emphasize_header(header=''):
    """Print ``header`` centred inside a 100-column banner of asterisks, framed by blank lines."""
    blank_row = '{: ^100}'.format('')
    star_row = '{:*^100}'.format('')
    title_row = '{:*^100}'.format('    ' + header + '    ')
    for row in (blank_row, star_row, title_row, star_row, blank_row):
        print(row)
        
def tensor_nancount(x, dim=None, keepdim=False):
    """Count the non-NaN entries of tensor ``x`` (over ``dim`` if given, else overall)."""
    not_nan = ~x.isnan()
    return not_nan.sum(dim = dim , keepdim = keepdim)

def tensor_nanmean(x, dim=None, keepdim=False):
    """Mean of ``x`` ignoring NaNs, over ``dim`` (all elements if None).

    Uses ``Tensor.nanmean`` when it works; otherwise (older torch without it, or a dtype
    it rejects such as integers) falls back to ``nansum / non-NaN count``.
    """
    try:
        return x.nanmean(dim = dim , keepdim = keepdim)
    except (AttributeError, TypeError, RuntimeError):
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit are not swallowed.
        return x.nansum(dim = dim , keepdim = keepdim) / tensor_nancount(x , dim = dim , keepdim = keepdim)

def tensor_nanstd(x, dim=None, correction=1 , keepdim=False):
    """Standard deviation of ``x`` ignoring NaNs.

    Args:
        x: input tensor.
        dim: dimension to reduce; None reduces over all elements.
        correction: subtracted from the non-NaN count in the denominator
            (1 = sample std, 0 = population std). Now honoured for ``dim=None`` too,
            which previously always used population std via numpy.
        keepdim: keep reduced dims (for ``dim=None`` the result has one size-1 dim per
            input dim).

    Returns:
        Tensor on the same device/dtype as ``x`` (no numpy round-trip, so CUDA
        tensors work).
    """
    if dim is None:
        flat_std = tensor_nanstd(x.flatten(), dim=0, correction=correction)
        return flat_std.reshape([1] * x.dim()) if keepdim else flat_std
    count = (~x.isnan()).sum(dim = dim , keepdim = True)
    mean = x.nansum(dim = dim , keepdim = True) / count
    dof = count - correction
    # NaN entries contribute NaN squares, which nansum drops.
    return ((x - mean).square() / dof).nansum(dim = dim , keepdim = keepdim).sqrt()

def tensor_standardize_and_weight(x, dim=None):
    """NaN-aware z-score of ``x`` plus median-based weights.

    Args:
        x: input tensor.
        dim: dimension along which mean/std/median are taken (None = all elements).

    Returns:
        (z, w): ``z = (x - nanmean) / (population nanstd + 1e-4)``; ``w`` is 2 where
        ``z`` is at or above its NaN-ignoring median and 1 elsewhere (including NaNs).
        If every entry of ``x`` is NaN, ``(x, x)`` is returned unchanged.
    """
    if x.isnan().all().item():
        return x , x
    x = (x - tensor_nanmean(x,dim=dim,keepdim=True)) / (tensor_nanstd(x,dim=dim,correction=0,keepdim=True) + 1e-4)
    w = torch.ones_like(x)
    if dim is None:
        median = x.nanmedian()
    else:
        try:
            median = x.nanmedian(dim = dim , keepdim = True).values
        except (TypeError, RuntimeError):
            # nanmedian takes a single dim; e.g. a tuple of dims falls back to the
            # global median, as the former bare except did.
            median = x.nanmedian()
    w[x >= median] = 2
    return x, w

def standardize_x(x , dim=None):
    """NaN-aware z-score: ``(x - nanmean) / (nanstd + 1e-4)``.

    With ``dim`` given (and ``x`` more than 1-D), every 1-D slice along ``dim`` is
    standardized independently. An all-NaN input is returned as-is.
    """
    if np.all(np.isnan(x)):
        return x
    if dim is None or len(x.shape) == 1:
        return (x - np.nanmean(x)) / (np.nanstd(x) + 1e-4)
    # Move `dim` to the front so each column of the 2-D view is one slice along `dim`.
    swapped = np.swapaxes(x , 0 , dim)
    columns = swapped.reshape(x.shape[dim] , -1) * 1.
    for col in range(columns.shape[-1]):
        columns[:, col] = standardize_x(columns[:, col])
    return np.swapaxes(columns.reshape(swapped.shape) , 0 , dim)

def standardize_and_weight(x , dim=None):
    """NaN-aware z-score of ``x`` plus median-based weights (numpy counterpart of
    ``tensor_standardize_and_weight``).

    Args:
        x: numpy array.
        dim: axis along which each 1-D slice is standardized independently
            (None, or a 1-D ``x``, standardizes all elements together).

    Returns:
        (z, w): ``z = (x - nanmean) / (nanstd + 1e-4)``; ``w`` is 2 where ``z`` is at or
        above its nanmedian, else 1. For an all-NaN input (or an all-NaN slice) ``z`` is
        the input unchanged and ``w`` is all-NaN. Previously this case raised
        UnboundLocalError because ``w`` was never assigned.
    """
    if np.all(np.isnan(x)):
        return x , np.full(np.shape(x) , np.nan)
    if dim is None or len(x.shape) == 1:
        x = (x - np.nanmean(x)) / (np.nanstd(x) + 1e-4)
        w = np.ones_like(x)
        w[x >= np.nanmedian(x)] = 2.
        return x , w
    # Move `dim` to the front so each column of the 2-D view is one slice along `dim`.
    swapped = np.swapaxes(x , 0 , dim)
    y = swapped.reshape(x.shape[dim] , -1) * 1.
    w = np.ones_like(y)
    for i in range(y.shape[-1]):
        y[:, i] , w[:, i] = standardize_and_weight(y[:, i])
    x = np.swapaxes(y.reshape(swapped.shape) , 0 , dim)
    w = np.swapaxes(w.reshape(swapped.shape) , 0 , dim)
    return x , w

def multi_bin_label(x , n = 10):
    """Label ``x`` by its NaN-ignoring quantile bin.

    The n equal-probability bins get symmetric odd labels ``2*i - n + 1``
    (e.g. n=10 -> -9, -7, ..., 9). The lowest bin takes everything below its upper
    edge and the highest everything at or above its lower edge.

    Returns:
        (y, w): labels and weights ``|y|``. NaN entries keep label/weight 0. Previously
        a single NaN made every quantile NaN and collapsed all labels to 0.
    """
    y , w = np.zeros_like(x) , np.zeros_like(x)
    # Compute all bin edges once instead of two quantile calls per bin.
    edges = np.nanquantile(x , np.arange(n + 1) / n)
    for i in range(n):
        low , high = edges[i] , edges[i + 1]
        label = 2 * i - n + 1
        if i == n-1:
            y[(x >= low)] = label
        elif i == 0:
            y[(x < high)] = label
        else:
            y[(x >= low) & (x < high)] = label
    w[:] = np.abs(y)
    return y, w


def bin_label(x):
    """Binary label at the NaN-ignoring median: ``y = 1`` where ``x >= median`` else 0;
    weight ``w = y + 1``. Both keep ``x``'s dtype."""
    at_or_above = x >= np.nanmedian(x)
    y = at_or_above.astype(x.dtype)
    w = y + np.ones_like(y)
    return y, w

def tensor_rank(x):
    """0-based ascending ranks of a 1-D tensor (ties ordered as ``argsort`` orders them).

    The result has ``x``'s dtype and device. Previously the rank values were a CPU
    float32 ``arange``, so ``index_copy_`` failed for float64, integer or CUDA inputs.
    """
    assert x.dim() == 1 , x.dim()
    ranks = torch.arange(len(x) , dtype = x.dtype , device = x.device)
    return torch.zeros_like(x).index_copy_(0 , x.argsort() , ranks)
def rank_weight(x):
    """Rank-based weights for a 1-D tensor, normalized to sum to 1.

    The raw weight is ``0.5 ** (2 * (n - 1 - rank) / (n - 1))``: 1 for the largest
    element, 0.25 for the smallest, geometric in between. A single element now gets
    weight 1; previously the 0/0 exponent produced NaN.
    """
    r = tensor_rank(x)
    n = r.numel()
    # max(n - 1, 1) only matters for n <= 1 (exponent then 0); identical otherwise.
    w = torch.pow(0.5 , ((n - 1 - r) * 2 / max(n - 1 , 1)))
    return w / w.sum()
def nd_rank(x , dim = None):
    """Ascending ranks of ``x`` (see ``tensor_rank``) taken independently along ``dim``,
    or over all elements when ``dim`` is None. Output has ``x``'s shape."""
    if dim is None:
        return tensor_rank(x.flatten()).reshape(x.shape)
    moved = torch.zeros_like(x).copy_(x).transpose(-1 , dim)
    moved_shape = moved.shape
    rows = moved.reshape(-1 , moved_shape[-1])
    for idx in range(rows.shape[0]):
        rows[idx] = tensor_rank(rows[idx])
    return rows.reshape(*moved_shape).transpose(-1 , dim)
def nd_rank_weight(x , dim = None):
    """Rank-based weights (see ``rank_weight``) computed independently along ``dim``,
    or over all elements when ``dim`` is None. Output has ``x``'s shape and dtype."""
    if dim is None:
        return rank_weight(x.flatten()).reshape(x.shape)
    swapped = torch.zeros_like(x).copy_(x).transpose(-1 , dim)
    swapped_shape = swapped.shape
    flat = swapped.reshape(-1 , swapped_shape[-1])
    row = 0
    while row < len(flat):
        flat[row] = rank_weight(flat[row])
        row += 1
    return flat.reshape(*swapped_shape).transpose(-1 , dim)
def nd_minus_mean(x , w , dim = None):
    """Centre ``x`` by subtracting ``mean(w * x)`` along ``dim`` (kept for broadcasting)."""
    weighted_centre = (w * x).mean(dim = dim , keepdim = True)
    return x - weighted_centre

def pearson(x, y , w = None, dim = None , **kwargs):
    """Weighted Pearson correlation of ``x`` and ``y`` along ``dim`` (all elements if None).

    ``w`` (optional) is rescaled to mean 1 along ``dim``; 1e-4 is added to each
    standard deviation to avoid division by zero.
    """
    if w is None:
        w = 1.
    else:
        n = w.numel() if dim is None else w.size(dim=dim)
        w = w / w.sum(dim=dim,keepdim=True) * n
    x_c = nd_minus_mean(x , w , dim)
    y_c = nd_minus_mean(y , w , dim)
    covariance = (w * x_c * y_c).mean(dim = dim)
    std_x = (w * x_c.square()).mean(dim=dim).sqrt() + 1e-4
    std_y = (w * y_c.square()).mean(dim=dim).sqrt() + 1e-4
    return covariance / std_x / std_y
    
def ccc(x , y , w = None, dim = None , **kwargs):
    """Weighted Lin concordance correlation coefficient along ``dim`` (all elements if None).

    CCC = 2*cov(x, y) / (var_x + var_y + (mean_x - mean_y)**2), with weights ``w``
    rescaled to mean 1. The denominator is computed as ``E_w[(x - y)**2] + 2*cov``.
    Previously the squared difference used the centred x and y, which cancelled the
    (mean_x - mean_y)**2 term. A pure location shift then still scored as perfect
    agreement.
    """
    w = 1. if w is None else w / w.sum(dim=dim,keepdim=True) * (w.numel() if dim is None else w.size(dim=dim))
    x1 , y1 = nd_minus_mean(x , w , dim) , nd_minus_mean(y , w , dim)
    cov_xy = (w * x1 * y1).mean(dim=dim)
    # uncentred difference: includes the mean-shift penalty
    mse_xy = (w * (x - y).square()).mean(dim=dim)
    return (2 * cov_xy) / (mse_xy + 2 * cov_xy + 1e-4)

def mse(x , y , w = None, dim = None , reduction='mean' , **kwargs):
    """Weighted squared error of ``x`` vs ``y`` along ``dim``.

    ``w`` (optional) is rescaled to mean 1 along ``dim``. ``reduction='mean'`` averages;
    any other value sums.
    """
    if w is not None:
        n = w.numel() if dim is None else w.size(dim=dim)
        w = w / w.sum(dim=dim,keepdim=True) * n
    else:
        w = 1.
    squared = w * (x - y).square()
    if reduction == 'mean':
        return torch.mean(squared , dim=dim)
    return torch.sum(squared , dim=dim)

def spearman(x , y , w = None , dim = None , **kwargs):
    """Spearman correlation: the (optionally weighted) Pearson correlation of the ranks."""
    ranked_x = nd_rank(x , dim = dim)
    ranked_y = nd_rank(y , dim = dim)
    return pearson(ranked_x , ranked_y , w , dim , **kwargs)

def wpearson(x, y , dim = None , **kwargs):
    """Pearson correlation weighted by ``nd_rank_weight`` of the target ``y``."""
    return pearson(x , y , nd_rank_weight(y , dim = dim) , dim)

def wccc(x , y , dim = None , **kwargs):
    """Concordance correlation weighted by ``nd_rank_weight`` of the target ``y``."""
    return ccc(x , y , nd_rank_weight(y , dim = dim) , dim)

def wmse(x , y , dim = None , reduction='mean' , **kwargs):
    """Squared error weighted by ``nd_rank_weight`` of the target ``y``."""
    rank_w = nd_rank_weight(y , dim = dim)
    return mse(x , y , w = rank_w , dim = dim , reduction = reduction)

def wspearman(x , y , dim = None , **kwargs):
    """Spearman correlation weighted by ``nd_rank_weight`` of the target ``y``."""
    rank_w = nd_rank_weight(y , dim = dim)
    return spearman(x , y , w = rank_w , dim = dim)

def np_rankic(x , y , w = None , dim = None):
    """Spearman rank correlation of ``x`` and ``y`` via scipy.

    ``w`` and ``dim`` are accepted for signature compatibility but ignored.
    """
    statistic , _pvalue = stats.spearmanr(x , y)
    return statistic

def transpose_qkv(X,num_heads):
    """Split the last dim of ``X`` into ``num_heads`` heads and fold the heads into dim 0.

    Assumes ``X`` is (batch, seq, hidden) as in multi-head attention (TODO confirm).
    The result is (batch * num_heads, seq, hidden // num_heads), ordered batch-major.
    """
    batch , seq = X.shape[0] , X.shape[1]
    heads_first = X.reshape(batch , seq , num_heads , -1).permute(0 , 2 , 1 , 3)
    return heads_first.reshape(-1 , heads_first.shape[2] , heads_first.shape[3])

def transpose_output(X,num_heads):
    """Inverse of ``transpose_qkv``: (batch * num_heads, seq, d) -> (batch, seq, num_heads * d),
    concatenating the heads along the last dim."""
    split_heads = X.reshape(-1 , num_heads , X.shape[1] , X.shape[2])
    seq_first = split_heads.permute(0 , 2 , 1 , 3)
    return seq_first.reshape(seq_first.shape[0] , seq_first.shape[1] , -1)

def np_nanrankic(x , y):
    """Spearman rank correlation over the positions where both ``x`` and ``y`` are non-NaN.

    Returns NaN when fewer than two valid pairs remain or when the correlation cannot
    be computed.
    """
    assert len(x) == len(y)
    pairwise_nonnan = ~(np.isnan(x) | np.isnan(y))
    if pairwise_nonnan.sum() < 2:
        return np.nan
    try:
        return np_rankic(x[pairwise_nonnan] , y[pairwise_nonnan])
    except Exception:
        # Best effort as before, but no longer swallows KeyboardInterrupt/SystemExit.
        return np.nan

def np_nanrankic_2d(x , y , dim = 0):
    """NaN-aware rank IC of each column (``dim == 0``) or each row (otherwise) of two
    same-shaped 2-D arrays. Returns a list with one value per column or row."""
    assert type(x) is type(y)
    assert x.shape == y.shape
    if dim == 0:
        count , take = x.shape[1] , (lambda arr , k: arr[:,k])
    else:
        count , take = x.shape[0] , (lambda arr , k: arr[k,:])
    return [np_nanrankic(take(x , k) , take(y , k)) for k in range(count)]
    
def ask_for_confirmation(prompt ='' , timeout = 10 , recurrent = 1 , proceed_condition = lambda x:True , print_function = print):
    """Ask the user up to ``recurrent`` times and collect the answers.

    Args:
        prompt: first prompt text; later rounds ask 'Really?' and then 'Really again?'.
        timeout: seconds to wait via ``timedInput``. If ``timeout <= 0`` or
            ``timedInput`` raises, a blocking ``input()`` is used instead.
        recurrent: maximum number of rounds.
        proceed_condition: predicate on each answer; asking stops at the first answer
            for which it is falsy.
        print_function: used to echo each answer.

    Returns:
        (answers, conditions): the raw answers and the matching ``proceed_condition``
        results, one entry per round actually asked.
    """
    assert isinstance(prompt , str)
    userText_list , userText_cond = [] , []
    for t in range(recurrent):
        if t == 0:
            _prompt = prompt 
        elif t == 1:
            _prompt = 'Really?'
        else:
            _prompt = 'Really again?'

        userText, timedOut = None , None
        if timeout > 0:
            try:
                userText, timedOut = timedInput(f'{_prompt} (in {timeout} seconds): ' , timeout = timeout)
            except Exception:
                # Best effort: fall back to input() below. Catching Exception (instead of
                # a bare except) lets Ctrl-C / SystemExit abort the prompt as expected.
                pass
        if userText is None : 
            userText, timedOut = input(f'{_prompt} : ') , False
        (_timeout , _sofar) = ('Time Out! ' , 'so far') if timedOut else ('' , '')
        print_function(f'{_timeout}User-input {_sofar} is : [{userText}].')
        userText_list.append(userText)
        userText_cond.append(proceed_condition(userText))
        if not userText_cond[-1]: 
            break
    return userText_list , userText_cond

def total_memory(unit = 1e9):
    """Resident set size (RSS) of the current process divided by ``unit`` (1e9 -> GB)."""
    current = psutil.Process(os.getpid())
    rss = current.memory_info().rss
    return rss / unit

def match_values(arr , values , ambiguous = 0):
    """For each entry of ``values``, return the index of a matching element in ``arr``.

    Args:
        arr: 1-D array (or sequence) to search in; it need not be sorted.
        values: array (or sequence) of values to look up.
        ambiguous: 0 -> exact matches only; otherwise any value <= max(arr) maps to the
            first element of ``arr`` (in sorted order) that is >= the value.

    Returns:
        Integer array shaped like ``values``; ``len(arr)`` marks "no match".
    """
    arr , values = np.asarray(arr) , np.asarray(values)
    index = np.full(values.shape , len(arr))
    if len(arr) == 0:
        # nothing can match; previously max([]) raised for ambiguous != 0
        return index
    sorter = np.argsort(arr)
    if ambiguous == 0:
        matched = np.isin(values , arr)
    else:
        matched = values <= arr.max()
    index[matched] = sorter[np.searchsorted(arr , values[matched] , sorter = sorter)]
    return index

def merge_data_2d(data_tuple , row_tuple , col_tuple , row_all = None , col_all = None):
    """Merge several labelled 2-D blocks into one NaN-filled matrix.

    Args:
        data_tuple: tuple of 2-D arrays; ``data_tuple[i]`` has shape
            ``(len(row_tuple[i]), len(col_tuple[i]))``.
        row_tuple / col_tuple: tuples of row / column labels for each block.
        row_all / col_all: target label order; defaults to the sorted union.

    Returns:
        (data_all, row_all, col_all). Cells not covered by any block are NaN; later
        blocks overwrite earlier ones. If none of the three inputs is a tuple they are
        returned unchanged.

    Raises:
        Exception: if only some of the three inputs are tuples.
        ValueError: if a label is missing from a supplied ``row_all``/``col_all``.
    """
    inputs = (data_tuple , row_tuple , col_tuple)
    if not any(isinstance(inp , tuple) for inp in inputs):
        return data_tuple , row_tuple , col_tuple
    if not all(isinstance(inp , tuple) for inp in inputs):
        raise Exception('Not All of data_tuple , row_tuple , col_tuple are tuple instance!')

    assert len(data_tuple) == len(row_tuple) == len(col_tuple)
    for data , rows , cols in zip(data_tuple , row_tuple , col_tuple):
        assert data.shape == (len(rows) , len(cols))

    def _first_positions(labels):
        # label -> first index, matching list.index() semantics, in O(n) total
        positions = {}
        for pos , label in enumerate(labels):
            positions.setdefault(label , pos)
        return positions

    def _lookup(positions , label):
        try:
            return positions[label]
        except KeyError:
            raise ValueError(f'{label!r} is not in list') from None

    row_all = sorted(set().union(*row_tuple)) if row_all is None else row_all
    col_all = sorted(set().union(*col_tuple)) if col_all is None else col_all
    row_pos , col_pos = _first_positions(row_all) , _first_positions(col_all)

    data_all = np.full((len(row_all) , len(col_all)) , np.nan)
    for data , rows , cols in zip(data_tuple , row_tuple , col_tuple):
        row_index = [_lookup(row_pos , r) for r in rows]
        col_index = [_lookup(col_pos , c) for c in cols]
        data_all[np.ix_(row_index , col_index)] = data
    return data_all , row_all , col_all

def rmdir(d , remake_dir = False):
    """
    Remove a directory (str / PathLike) or a list/tuple of directories, ignoring ones
    that do not exist. If ``remake_dir`` is True, recreate each as an empty directory,
    creating missing parents.

    Raises:
        TypeError: ``d`` is not a path or a list/tuple of paths. This is a subclass of
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    if isinstance(d , (list , tuple)):
        dirs = list(d)
    elif isinstance(d , (str , os.PathLike)):
        dirs = [d]
    else:
        raise TypeError(f'rmdir expects a path or a list/tuple of paths, got {type(d).__name__}: {d!s}')
    for path in dirs:
        if os.path.exists(path):
            shutil.rmtree(path)
    if remake_dir:
        # makedirs for both forms; the single-path form used os.mkdir and failed
        # when the parent directory did not exist.
        for path in dirs:
            os.makedirs(path , exist_ok = True)
        
def list_converge(l , n = None , eps = None):
    """
    True if the last ``n`` elements of ``l`` span a range (max - min) strictly below ``eps``.

    ``n`` defaults to ``len(l)`` and ``eps`` to 0. With the default ``eps`` the result is
    therefore always False, since a range is never negative. Returns False when ``l`` has
    fewer than ``n`` elements or the window is empty (``n <= 0`` or empty ``l``).
    Previously an empty list raised ValueError, and ``n == 0`` silently checked the
    whole list because ``l[-0:]`` is ``l[:]``.
    """
    n = len(l) if n is None else n
    eps = 0 if eps is None else eps
    if n <= 0 or len(l) < n:
        return False
    window = l[-n:]
    return (max(window) - min(window)) < eps

def pretty_print_dict(dictionary , width = 140 , sort_dicts = False):
    """Pretty-print ``dictionary`` to stdout (indent 1), keeping insertion order unless
    ``sort_dicts`` is True."""
    printer = pprint.PrettyPrinter(indent = 1 , width = width , sort_dicts = sort_dicts)
    printer.pprint(dictionary)
## my_utils
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
import gc
from torch.utils.data.dataset import IterableDataset , Dataset
from mpl_toolkits import mplot3d
from copy import deepcopy 

class lr_cosine_scheduler:
    """Linear warm-up followed by a quarter-cosine anneal of each param group's lr.

    For each param group with base lr ``b``:

    * at construction the lr is set to ``b / initial_lr_div``;
    * steps 1..warmup_stage rise linearly to ``b``;
    * the next anneal_stage steps follow ``final + (b - final) * cos(phase)`` with the
      phase going to pi/2;
    * after that the lr stays at ``final = b / final_lr_div``.

    Call ``step()`` once per scheduling period. Fixes over the previous version:
    ``optimizer`` was never stored (``step`` raised AttributeError); zero-length
    stages no longer divide by zero; ``state_dict`` returns a copy without the
    optimizer; ``load_state_dict`` was added.
    """
    def __init__(self , optimizer , warmup_stage = 10 , anneal_stage = 40 , initial_lr_div = 10 , final_lr_div = 1e4):
        self.optimizer = optimizer
        self.warmup_stage = warmup_stage
        self.anneal_stage = anneal_stage
        self.base_lrs = [x['lr'] for x in optimizer.param_groups]
        self.initial_lr = [x / initial_lr_div for x in self.base_lrs]
        self.final_lr = [x / final_lr_div for x in self.base_lrs]
        self.last_epoch = 0
        self._step_count = 1
        self._update_phases()
        self._last_lr = self.initial_lr
        # Make the optimizer agree with get_last_lr() before the first step.
        # NOTE(review): base_lrs is read from the optimizer, so constructing a second
        # scheduler on the same optimizer would start from the already-divided lr.
        self._apply_lr(self._last_lr)

    def get_last_lr(self):
        """Return the learning rates most recently applied by this scheduler."""
        return self._last_lr

    def state_dict(self):
        """Return the scheduler state as a new dict (the optimizer is excluded)."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self , state_dict):
        """Restore state produced by ``state_dict``."""
        self.__dict__.update(state_dict)

    def _update_phases(self):
        # Guard the divisions so warmup_stage / anneal_stage may be 0.
        self._linear_phase = self._step_count / self.warmup_stage if self.warmup_stage > 0 else 1.
        self._cos_phase = (math.pi / 2 * (self._step_count - self.warmup_stage) / self.anneal_stage
                           if self.anneal_stage > 0 else math.pi / 2)

    def _apply_lr(self , lrs):
        for lr , param_group in zip(lrs , self.optimizer.param_groups):
            param_group['lr'] = lr

    def step(self):
        """Advance one period and write the new learning rates into the optimizer."""
        self.last_epoch += 1
        if self._step_count <= self.warmup_stage:
            self._last_lr = [y+(x-y)*self._linear_phase for x,y in zip(self.base_lrs,self.initial_lr)]
        elif self._step_count <= self.warmup_stage + self.anneal_stage:
            self._last_lr = [y+(x-y)*math.cos(self._cos_phase) for x,y in zip(self.base_lrs,self.final_lr)]
        else:
            self._last_lr = self.final_lr
        self._apply_lr(self._last_lr)
        self._step_count += 1
        self._update_phases()
        
class Mydataset(Dataset):
    """Map-style dataset whose i-th item is ``(data1[i], label[i])``; length is ``len(data1)``."""
    def __init__(self, data1 , label) -> None:
        super().__init__()
        self.data1 , self.label = data1 , label

    def __len__(self):
        return len(self.data1)

    def __getitem__(self , ii):
        sample = self.data1[ii]
        target = self.label[ii]
        return sample , target

class MyIterdataset(IterableDataset):
    """Iterable dataset yielding ``(data1[i], label[i])`` for i in ``range(len(data1))``."""
    def __init__(self, data1 , label) -> None:
        super().__init__()
        self.data1 , self.label = data1 , label

    def __len__(self):
        return len(self.data1)

    def __iter__(self):
        yield from ((self.data1[idx] , self.label[idx]) for idx in range(len(self.data1)))
            
class Mydataloader_basic:
    """Thin wrapper around ``torch.utils.data.DataLoader`` over ``Mydataset(x_set, y_set)``.

    ``batch_num`` is ceil(len(y_set) / batch_size). The ``batch_num`` argument is
    accepted but ignored.
    """
    def __init__(self, x_set , y_set , batch_size = 1, num_worker = 0, set_name = '', batch_num = None):
        self.dataset = Mydataset(x_set, y_set)
        self.batch_size , self.num_worker = batch_size , num_worker
        self.dataloader = torch.utils.data.DataLoader(self.dataset , batch_size = batch_size , num_workers = num_worker)
        self.set_name = set_name
        self.batch_num = math.ceil(len(y_set) / batch_size)

    def __iter__(self):
        yield from self.dataloader

class Mydataloader_saved:
    """Iterate over batches previously saved as ``{batch_folder}/{set_name}.{i}.pt``
    for i in ``range(batch_num)``, loading each lazily with ``torch.load``."""
    def __init__(self, set_name , batch_num , batch_folder):
        self.set_name = set_name
        self.batch_num = batch_num
        self.batch_folder = batch_folder
        self.batch_path = [f'{self.batch_folder}/{self.set_name}.{idx}.pt' for idx in range(self.batch_num)]

    def __iter__(self):
        yield from (torch.load(self.batch_path[idx]) for idx in range(self.batch_num))
                
class DesireBatchSampler(torch.utils.data.Sampler):
    """Batch sampler yielding consecutive batches of caller-specified sizes.

    Args:
        sampler: any iterable of indices with ``len`` (a list, ``range``, or a torch
            Sampler). It is iterated, so it no longer needs to be subscriptable.
        batch_size_list: requested batch sizes (non-negative ints); zero-size
            entries are skipped.
        drop_res: if False and the sizes sum to less than ``len(sampler)``, one extra
            batch holds the remaining indices.

    If the sampler runs out, the batch in progress is yielded partially filled and
    iteration stops. Fixes over the previous version: the last batch was yielded
    twice; running out of indices raised IndexError; an empty or all-zero size list
    raised UnboundLocalError; ``__len__`` counted zero-size batches.
    """
    def __init__(self, sampler , batch_size_list , drop_res = True):
        self.sampler = sampler
        self.batch_size_list = np.array(batch_size_list).astype(int)
        assert (self.batch_size_list >= 0).all()
        self.drop_res = drop_res

    def _sizes(self):
        """Requested sizes, plus the residual batch when it is kept."""
        total = int(self.batch_size_list.sum())
        if (not self.drop_res) and total < len(self.sampler):
            return np.append(self.batch_size_list , len(self.sampler) - total)
        return self.batch_size_list

    def __iter__(self):
        indices = iter(self.sampler)
        for size in self._sizes():
            size = int(size)
            if size == 0:
                continue
            batch = list(itertools.islice(indices , size))
            if batch:
                yield batch
            if len(batch) < size:
                return  # sampler exhausted

    def __len__(self):
        remaining , count = len(self.sampler) , 0
        for size in self._sizes():
            if remaining <= 0:
                break
            if size > 0:
                count += 1
                remaining -= size
        return count
        
class multiloss_calculator:
    """Combine several per-task losses into one scalar using a selectable weighting scheme.

    Example::

        import torch
        ml = multiloss_calculator('dwa')
        ml.reset_multi_type(2 , tau = 2)                   # 2 tasks
        total = ml.calculate_multi_loss(torch.tensor([0.5 , 0.8]) , {})

        multiloss_calculator().view_plot('dwa')           # also 'ruw' , 'gls' , 'rws'

    ``multi_type`` keys (see ``multi_class_dict``): 'ewa', 'hybrid', 'dwa', 'ruw',
    'gls', 'rws'. 'dwa' and 'hybrid' require ``tau`` in ``reset_multi_type``;
    'ruw' and 'hybrid' accept an optional ``phi`` (default None). 'ruw' and 'hybrid'
    read ``mt_param['alpha']``, presumably a learnable tensor with one entry per task.
    """
    def __init__(self , multi_type = None):
        self.multi_type = multi_type
        
    def reset_multi_type(self, num_task , **kwargs):
        """(Re)create the weighting object for ``num_task`` tasks; kwargs go to its ``reset``."""
        self.num_task   = num_task
        self.multi_class = self.multi_class_dict()[self.multi_type](num_task , **kwargs)
    
    def calculate_multi_loss(self , losses , mt_param , **kwargs):
        """Return the combined loss for ``losses``; call ``reset_multi_type`` first."""
        return self.multi_class.forward(losses , mt_param)
    
    # Legacy API kept for reference only: the string literal below is never executed.
    """
    def reset_loss_function(self, loss_type):
        self.loss_type = loss_type
        if isinstance(self.loss_type , (list,tuple)):
            # various loss function version, tasks can be of the same output but different loss function
            self.loss_functions = [loss_function(k) for k in self.loss_type]
        else:
            self.loss_functions = [loss_function(self.loss_type) for _ in range(self.num_task)]
            
    def losses(self , y , x , **kwargs):
        return torch.tensor([f(y[i] , x[i] , **kwargs) for i,f in enumerate(self.loss_functions)])
    
    def calculate_multi_loss(self , y , x , **kwargs):
        sub_losses = self.losses(y , x , **kwargs)
        multi_loss = self.multi_class.total_loss(sub_losses)
        return multi_loss , sub_losses
    """
    def multi_class_dict(self):
        """Map each ``multi_type`` key to its weighting class."""
        return {
            'ewa':self.EWA,
            'hybrid':self.Hybrid,
            'dwa':self.DWA,
            'ruw':self.RUW,
            'gls':self.GLS,
            'rws':self.RWS,
        }
    
    class _base_class():
        """
        Base weighting scheme: total = sum(losses * weight) + penalty, with weight 1 and
        penalty 0 by default. Every ``forward`` call records (detached) losses, weights
        and penalty so history-based schemes (DWA) can use them.
        """
        def __init__(self , num_task , **kwargs):
            self.num_task = num_task
            self.record_num = 0 
            self.record_losses = []
            self.record_weight = []
            self.record_penalty = []
            self.kwargs = kwargs
            self.reset(**kwargs)
        def reset(self , **kwargs):
            pass
        def record(self , losses , weight , penalty):
            self.record_num += 1
            self.record_losses.append(losses.detach() if isinstance(losses,torch.Tensor) else losses)
            self.record_weight.append(weight.detach() if isinstance(weight,torch.Tensor) else weight)
            self.record_penalty.append(penalty.detach() if isinstance(penalty,torch.Tensor) else penalty)
        def forward(self , losses , mt_param , **kwargs):
            weight , penalty = self.weight(losses , mt_param) , self.penalty(losses , mt_param)
            self.record(losses , weight , penalty)
            return self.total_loss(losses , weight , penalty)
        def weight(self , losses , mt_param):
            return torch.ones_like(losses)
        def penalty(self , losses , mt_param): 
            return 0.
        def total_loss(self , losses , weight , penalty):
            return (losses * weight).sum() + penalty
    
    class EWA(_base_class):
        """
        Equal weight average
        """
        def __init__(self , num_task , **kwargs):
            super().__init__(num_task , **kwargs)
    
    class Hybrid(_base_class):
        """
        Hybrid of DWA and RUW: DWA weights plus 1/alpha**2, with the RUW penalty.
        """
        def __init__(self , num_task , **kwargs):
            super().__init__(num_task , **kwargs)
        def reset(self , **kwargs):
            self.tau = kwargs['tau']
            self.phi = kwargs.get('phi')   # None disables the phi term
        def weight(self , losses , mt_param):
            if self.record_num < 2:
                weight = torch.ones_like(losses)
            else:
                weight = (self.record_losses[-1] / self.record_losses[-2] / self.tau).exp()
                weight = weight / weight.sum() * weight.numel()
            return weight + 1 / mt_param['alpha'].square()
        def penalty(self , losses , mt_param): 
            penalty = (mt_param['alpha'].log().square()+1).log().sum()
            if self.phi is not None: 
                penalty = penalty + (self.phi - mt_param['alpha'].log().abs().sum()).abs()
            return penalty
    
    class DWA(_base_class):
        """
        dynamic weight average
        https://arxiv.org/pdf/1803.10704.pdf
        https://github.com/lorenmt/mtan/tree/master/im2im_pred
        """
        def __init__(self , num_task , **kwargs):
            super().__init__(num_task , **kwargs)
        def reset(self , **kwargs):
            self.tau = kwargs['tau']
        def weight(self , losses , mt_param):
            if self.record_num < 2:
                weight = torch.ones_like(losses)
            else:
                # ratio of the two most recent recorded losses, softmax-like with temperature tau
                weight = (self.record_losses[-1] / self.record_losses[-2] / self.tau).exp()
                weight = weight / weight.sum() * weight.numel()
            return weight
        
    class RUW(_base_class):
        """
        Revised Uncertainty Weighting (RUW) Loss
        https://arxiv.org/pdf/2206.11049v2.pdf (RUW + DWA)
        """
        def __init__(self , num_task , **kwargs):
            super().__init__(num_task , **kwargs)
        def reset(self , **kwargs):
            self.phi = kwargs.get('phi')   # None disables the phi term
        def weight(self , losses , mt_param):
            return 1 / mt_param['alpha'].square()
        def penalty(self , losses , mt_param): 
            penalty = (mt_param['alpha'].log().square()+1).log().sum()
            if self.phi is not None: 
                penalty = penalty + (self.phi - mt_param['alpha'].log().abs().sum()).abs()
            return penalty

    class GLS(_base_class):
        """
        geometric loss strategy , Chennupati etc.(2019): weighted geometric mean of losses
        """
        def __init__(self , num_task , **kwargs):
            super().__init__(num_task , **kwargs)
        def total_loss(self , losses , weight , penalty):
            return losses.pow(weight).prod().pow(1/weight.sum()) + penalty
    
    class RWS(_base_class):
        """
        random weight loss, RW , Lin etc.(2021): softmax of uniform random weights each call
        https://arxiv.org/pdf/2111.10603.pdf
        """
        def __init__(self , num_task , **kwargs):
            super().__init__(num_task , **kwargs)
        def weight(self , losses , mt_param): 
            return torch.nn.functional.softmax(torch.rand_like(losses),-1)

    def view_plot(self , multi_type = 'ruw'):
        """Show an illustrative matplotlib figure for 2 tasks: 'ruw', 'gls', 'rws' or 'dwa'."""
        num_task = 2
        if multi_type == 'ruw':
            # Fixed random sub-losses; the surface is the RUW total loss over a grid of alphas.
            losses = torch.rand(num_task)
            ruw = self.RUW(num_task , phi = None)
            grid = np.linspace(0.2, 10, 40)
            s1, s2 = np.meshgrid(grid, grid)
            def _ruw_total(a1 , a2):
                mt_param = {'alpha' : torch.tensor([a1 , a2])}
                return ruw.total_loss(losses , ruw.weight(losses , mt_param) , ruw.penalty(losses , mt_param)).item()
            l = np.array([[_ruw_total(s1[i,j] , s2[i,j]) for j in range(s1.shape[1])] for i in range(s1.shape[0])])
            fig,ax = plt.figure(),plt.axes(projection='3d')
            ax.plot_surface(s1, s2, l, cmap='viridis')
            ax.set_xlabel('alpha-1')
            ax.set_ylabel('alpha-2')
            ax.set_zlabel('loss')
            ax.set_title(f'RUW Loss vs alpha ({num_task}-D)')
        elif multi_type == 'gls':
            ls = torch.tensor(np.repeat(np.linspace(0.2, 10, 40),num_task).reshape(-1,num_task))
            fig,ax = plt.figure(),plt.axes(projection='3d')
            s1, s2 = np.meshgrid(ls[:,0].numpy(), ls[:,1].numpy())
            l = torch.stack([torch.stack([torch.tensor([s1[i,j],s2[i,j]]).prod().sqrt() for j in range(s1.shape[1])]) for i in range(s1.shape[0])]).numpy()
            ax.plot_surface(s1, s2, l, cmap='viridis')
            ax.set_xlabel('loss-1')
            ax.set_ylabel('loss-2')
            ax.set_zlabel('gls_loss')
            ax.set_title(f'GLS Loss vs sub-Loss ({num_task}-D)')
        elif multi_type == 'rws':
            ls = torch.tensor(np.repeat(np.linspace(0.2, 10, 40),num_task).reshape(-1,num_task))
            fig,ax = plt.figure(),plt.axes(projection='3d')
            s1, s2 = np.meshgrid(ls[:,0].numpy(), ls[:,1].numpy())
            l = torch.stack([torch.stack([(torch.tensor([s1[i,j],s2[i,j]])*torch.nn.functional.softmax(torch.rand(num_task),-1)).sum() for j in range(s1.shape[1])]) 
                             for i in range(s1.shape[0])]).numpy()
            ax.plot_surface(s1, s2, l, cmap='viridis')
            ax.set_xlabel('loss-1')
            ax.set_ylabel('loss-2')
            ax.set_zlabel('rws_loss')
            ax.set_title(f'RWS Loss vs sub-Loss ({num_task}-D)')
        elif multi_type == 'dwa':
            nepoch = 100
            s = np.arange(nepoch)
            l1 = 1 / (4+s) + 0.1 + np.random.rand(nepoch) *0.05
            l2 = 1 / (4+2*s) + 0.15 + np.random.rand(nepoch) *0.03
            tau = 2
            w1 = np.exp(np.concatenate((np.array([1,1]),l1[2:]/l1[1:-1]))/tau)
            # each task's weight uses its own loss ratio (was l2 / l1)
            w2 = np.exp(np.concatenate((np.array([1,1]),l2[2:]/l2[1:-1]))/tau)
            w1 , w2 = num_task * w1 / (w1+w2) , num_task * w2 / (w1+w2)
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
            ax1.plot(s, l1, color='blue', label='task1')
            ax1.plot(s, l2, color='red', label='task2')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Loss for Epoch')
            ax1.legend()
            ax2.plot(s, w1, color='blue', label='task1')
            ax2.plot(s, w2, color='red', label='task2')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Weight')
            ax2.set_title('Weight for Epoch')
            ax2.legend()
        else:
            print(f'Unknow multi_type : {multi_type}')
            
        plt.show()
        
class versatile_storage():
    """Path-keyed store for objects (typically model state dicts), on disk or in memory.

    In ``'disk'`` mode every object is written with ``torch.save``.  In ``'mem'`` mode
    objects are deep-copied into ``self.mem_disk`` unless a call passes ``to_disk=True``.
    Every saved path is tracked in ``self.file_record`` and under a named group in
    ``self.file_group`` so that whole groups can be deleted at once.
    """
    def __init__(self , default = 'disk'):
        assert default in ['disk' , 'mem']
        self.default = default
        self.mem_disk = dict()      # path -> deep-copied object (objects saved in memory)
        self.file_record = list()   # every saved path (becomes a sorted numpy str array)
        self.file_group = dict()    # group name -> sorted numpy str array of its paths

    @staticmethod
    def _str_array(paths):
        """Return *paths* (any iterable of path strings) as a 1-D numpy str array.

        Converting explicitly keeps numpy set operations (union1d / intersect1d /
        setdiff1d) on str dtype even for empty inputs, which would otherwise be float64.
        """
        return np.array(list(paths) , dtype = str)

    def save(self , obj , paths , to_disk = False , group = 'default'):
        """Store *obj* under every path in *paths* and record each path in *group*."""
        for p in self._pathlist(paths):
            self._saveone(obj , p , self.default == 'disk' or to_disk)
            self._addrecord(p , group)

    def load(self , path , from_disk = False):
        """Return the object stored at *path* (read from disk in 'disk' mode or if *from_disk*)."""
        return torch.load(path) if self.default == 'disk' or from_disk else self.mem_disk[path]

    def _pathlist(self , p):
        """Normalize *p* (None, one path string, or an iterable of paths) to an iterable."""
        if p is None: return []
        return [p] if isinstance(p , str) else p

    def _saveone(self , obj , p , to_disk = False):
        if to_disk:
            torch.save(obj , p)
        else:
            # deep copy: later in-place changes to obj must not alter the stored snapshot
            self.mem_disk[p] = deepcopy(obj)

    def _addrecord(self , p , group):
        new_path = self._str_array([p])
        self.file_record = np.union1d(self._str_array(self.file_record) , new_path)
        self.file_group[group] = np.union1d(self._str_array(self.file_group.get(group , [])) , new_path)

    def save_model_state(self , model , paths , to_disk = False , group = 'default'):
        """Store ``model.state_dict()``; in-memory saves snapshot a CPU copy of the model."""
        sd = model.state_dict() if (self.default == 'disk' or to_disk) else deepcopy(model).cpu().state_dict()
        self.save(sd , paths , to_disk , group)

    def load_model_state(self , model , path , from_disk = False):
        """Load the state dict stored at *path* into *model* and return *model*."""
        sd = self.load(path , from_disk)
        model.load_state_dict(sd)
        return model

    def valid_paths(self , paths):
        """Return the recorded subset of *paths* as a sorted list of str."""
        return np.intersect1d(self._str_array(self._pathlist(paths)) , self._str_array(self.file_record)).tolist()

    def del_path(self , *args):
        """Delete every path in each argument (disk files or in-memory entries) and unrecord them."""
        for paths in args:
            path_list = self._str_array(self._pathlist(paths))
            if self.default == 'disk':
                for p in path_list:
                    if os.path.exists(p): os.remove(str(p))
            else:
                stored = self._str_array(self.mem_disk.keys())
                for p in np.intersect1d(path_list , stored):
                    del self.mem_disk[str(p)]
            self.file_record = np.setdiff1d(self._str_array(self.file_record) , path_list)
        gc.collect()

    def del_group(self , clear_groups = None):
        """Delete all paths of the given group name(s) and forget those groups.

        ``clear_groups`` may be None (no-op), one group name, or an iterable of names.
        """
        for g in self._pathlist(clear_groups):
            paths = self.file_group.get(g)
            if paths is not None:
                self.del_path(paths)
                del self.file_group[g]
## mymodel
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
from copy import deepcopy

class mod_tcn_block(nn.Module):
    def __init__(self, input_dim , output_dim , dilation, dropout=0.0 , kernel_size=3):
        super().__init__()
        padding = (kernel_size-1) * dilation
        self.conv1 = weight_norm(nn.Conv1d(input_dim , output_dim, kernel_size, padding=padding, dilation=dilation))
        self.conv2 = weight_norm(nn.Conv1d(output_dim, output_dim, kernel_size, padding=padding, dilation=dilation))
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        
        self.net = nn.Sequential(self.conv1, self._chomp(padding), nn.ReLU(), nn.Dropout(dropout), 
                                 self.conv2, self._chomp(padding), nn.ReLU(), nn.Dropout(dropout))
        
        if input_dim != output_dim:
            self.residual = nn.Conv1d(input_dim , output_dim, 1)
            self.residual.weight.data.normal_(0, 0.01)
        else:
            self.residual = nn.Sequential()
        self.relu = nn.ReLU()

    def forward(self, inputs):
        output = self.net(inputs)
        output = self.relu(output + self.residual(inputs))
        return output
    
    class _chomp(nn.Module):
        def __init__(self, padding):
            super().__init__()
            self.padding = padding
        def forward(self, x):
            return x[:, :, :-self.padding] # .contiguous()

class mod_tcn(nn.Module):
    """Stack of residual TCN blocks for batch-first sequences.

    Block ``i`` uses dilation ``2**i``; at least two blocks are always built.  Only the
    first block maps ``input_dim`` channels, the rest keep ``output_dim``.
    """
    def __init__(self, input_dim , output_dim , dropout=0.0 , num_layers = 2 , kernel_size = 3):
        super().__init__()
        depth = max(2 , num_layers)
        in_dims = [input_dim] + [output_dim] * (depth - 1)
        blocks = [
            mod_tcn_block(d_in, output_dim, dilation=2 ** level, dropout=dropout, kernel_size=kernel_size)
            for level, d_in in enumerate(in_dims)
        ]
        self.net = nn.Sequential(*blocks)

    def forward(self, inputs):
        # Conv1d wants channels before time: swap the last two axes in and back out
        channels_first = inputs.permute(0, 2, 1)
        return self.net(channels_first).permute(0, 2, 1)
    
class mod_transformer(nn.Module):
    """Transformer encoder for batch-first sequences.

    Inputs go through Linear(input_dim, output_dim) + Tanh, sinusoidal position encodings,
    then ``max(2, num_layers)`` encoder layers with 8 heads and a 4x feed-forward width.
    """
    def __init__(self , input_dim , output_dim , dropout=0.0 , num_layers = 2):
        super().__init__()
        n_heads = 8
        assert output_dim % n_heads == 0
        self.fc_in = nn.Sequential(nn.Linear(input_dim, output_dim), nn.Tanh())
        self.pos_enc = PositionalEncoding(output_dim, dropout=dropout)
        layer = nn.TransformerEncoderLayer(output_dim, n_heads, dim_feedforward=4 * output_dim,
                                           dropout=dropout, batch_first=True)
        self.trans = nn.TransformerEncoder(layer, max(2, num_layers))

    def forward(self, inputs):
        projected = self.fc_in(inputs)
        return self.trans(self.pos_enc(projected))

class mod_lstm(nn.Module):
    """Batch-first LSTM (capped at 3 layers) returning the top layer's output at every step."""
    def __init__(self , input_dim , output_dim , dropout=0.0 , num_layers = 2):
        super().__init__()
        depth = min(3, num_layers)
        self.lstm = nn.LSTM(input_dim, output_dim, num_layers=depth, dropout=dropout, batch_first=True)

    def forward(self, inputs):
        step_outputs, _state = self.lstm(inputs)
        return step_outputs

class mod_gru(nn.Module):
    """Batch-first GRU (capped at 3 layers) returning the top layer's output at every step."""
    def __init__(self , input_dim , output_dim , dropout=0.0 , num_layers = 2):
        super().__init__()
        depth = min(3, num_layers)
        self.gru = nn.GRU(input_dim, output_dim, num_layers=depth, dropout=dropout, batch_first=True)

    def forward(self, inputs):
        step_outputs, _state = self.gru(inputs)
        return step_outputs
    
class mod_ewlinear(nn.Module):
    """Parameter-free "equal-weight linear" layer: the mean along ``dim``."""
    def __init__(self, dim = -1 , keepdim = True):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, inputs):
        return torch.mean(inputs, dim=self.dim, keepdim=self.keepdim)
    
class mod_parallel(nn.Module):
    """Run ``num_mod`` independent deep copies of ``sub_mod`` side by side.

    feedforward=True  : copy ``i`` receives ``inputs[i]``; otherwise every copy gets ``inputs``.
    concat_output=True: outputs are concatenated on the last axis (position-wise when each
                        copy returns a list/tuple); otherwise the tuple of outputs is returned.
    """
    def __init__(self, sub_mod , num_mod , feedforward = True , concat_output = False):
        super().__init__()
        self.mod_list = nn.ModuleList(deepcopy(sub_mod) for _ in range(num_mod))
        self.feedforward = feedforward
        self.concat_output = concat_output

    def forward(self, inputs):
        results = []
        for idx, module in enumerate(self.mod_list):
            results.append(module(inputs[idx]) if self.feedforward else module(inputs))
        results = tuple(results)
        if not self.concat_output:
            return results
        if isinstance(results[0], (list, tuple)):
            n_parts = len(results[0])
            return tuple(torch.cat([res[k] for res in results], dim=-1) for k in range(n_parts))
        return torch.cat(results, dim=-1)

class rnn_univariate(nn.Module):
    """Single-input sequence model: shared encoder -> one decoder and one mapping per output.

    Every constructor argument (plus extra ``**kwargs``) is collected in ``self.kwargs`` and
    forwarded to ``uni_rnn_encoder``, ``uni_rnn_decoder`` and ``uni_rnn_mapping``.
    ``forward`` returns ``(output, hidden)``: the (bat_size, num_output) predictions and the
    first output's decoded hidden vector.
    """
    def __init__(
        self,
        input_dim ,
        hidden_dim: int = 2**5,
        rnn_layers: int = 2,
        mlp_layers: int = 2,
        dropout:  float = 0.1,
        fc_att:    bool = False,
        fc_in:     bool = False,
        type_rnn:   str = 'gru',
        type_act:   str = 'LeakyReLU',
        num_output: int = 1 ,
        dec_mlp_dim:int = None,
        output_as_factors: bool = True,
        hidden_as_factors: bool = False,
        **kwargs
    ):
        super().__init__()
        self.num_output = num_output
        self.kwargs = kwargs
        self.kwargs.update(dict(
            input_dim=input_dim, hidden_dim=hidden_dim, rnn_layers=rnn_layers, mlp_layers=mlp_layers, dropout=dropout,
            fc_att=fc_att, fc_in=fc_in, type_rnn=type_rnn, type_act=type_act, num_output=num_output, dec_mlp_dim=dec_mlp_dim,
            output_as_factors=output_as_factors, hidden_as_factors=hidden_as_factors,
        ))
        shared = self.kwargs
        self.encoder = mod_parallel(uni_rnn_encoder(**shared), num_mod=1, feedforward=False, concat_output=True)
        self.decoder = mod_parallel(uni_rnn_decoder(**shared), num_mod=num_output, feedforward=False, concat_output=False)
        self.mapping = mod_parallel(uni_rnn_mapping(**shared), num_mod=num_output, feedforward=True, concat_output=True)
        self.set_multiloss_params()

    def forward(self, inputs):
        # inputs : (bat_size, seq, input_dim)
        encoded = self.encoder(inputs)    # (bat_size, hidden_dim)
        decoded = self.decoder(encoded)   # tuple of num_output tensors, each (bat_size, hidden_dim)
        predictions = self.mapping(decoded)   # (bat_size, num_output)
        return predictions, decoded[0]

    def set_multiloss_params(self):
        # one learnable value per output, initialised uniformly in [1e-4, 1 + 1e-4)
        alpha = torch.rand(self.num_output) + 1e-4
        self.multiloss_alpha = torch.nn.Parameter(alpha.requires_grad_())

    def get_multiloss_params(self):
        return {'alpha': self.multiloss_alpha}
    
class rnn_multivariate(nn.Module):
    """Multi-input sequence model: per-stream encoders/decoders fused into each output.

    ``input_dim`` is a single width (one stream, ``uni_rnn_*`` modules) or a list/tuple of
    widths (several streams, ``multi_rnn_*`` modules).  With ``ordered_param_group`` and
    several streams, :meth:`training_round` trains one stream's sub-network per round.
    ``forward`` returns ``(output, hidden)``: the (bat_size, num_output) predictions and the
    first output's decoded hidden vector.
    """
    def __init__(
        self,
        input_dim ,
        hidden_dim: int = 2**5,
        rnn_layers: int = 2,
        mlp_layers: int = 2,
        dropout:  float = 0.1,
        fc_att:    bool = False,
        fc_in:     bool = False,
        type_rnn:   str = 'gru',
        type_act:   str = 'LeakyReLU',
        num_output: int = 1 ,
        rnn_att:   bool = False,
        num_heads:  int = None,
        dec_mlp_dim:int = None,
        ordered_param_group: bool = False,
        output_as_factors:   bool = True,
        hidden_as_factors:   bool = False,
        **kwargs,
    ):
        super().__init__()
        self.num_output = num_output
        self.num_rnn = len(input_dim) if isinstance(input_dim , (list,tuple)) else 1
        self.ordered_param_group = ordered_param_group
        self.kwargs = kwargs
        self.kwargs.update({'input_dim':input_dim,'hidden_dim':hidden_dim,'rnn_layers':rnn_layers,'mlp_layers':mlp_layers,'dropout':dropout,
                            'fc_att':fc_att,'fc_in':fc_in,'type_rnn':type_rnn,'type_act':type_act,'num_output':num_output,'dec_mlp_dim':dec_mlp_dim,
                            'rnn_att':rnn_att,'num_heads':num_heads,'num_rnn':self.num_rnn,
                            'ordered_param_group':ordered_param_group,'output_as_factors':output_as_factors,'hidden_as_factors':hidden_as_factors,
                           })
        mod_encoder = multi_rnn_encoder if self.num_rnn > 1 else uni_rnn_encoder
        mod_decoder = multi_rnn_decoder if self.num_rnn > 1 else uni_rnn_decoder
        mod_mapping = multi_rnn_mapping if self.num_rnn > 1 else uni_rnn_mapping

        self.encoder = mod_parallel(mod_encoder(**self.kwargs) , num_mod = 1 , feedforward = False , concat_output = True)
        self.decoder = mod_parallel(mod_decoder(**self.kwargs) , num_mod = num_output , feedforward = False , concat_output = False)
        self.mapping = mod_parallel(mod_mapping(**self.kwargs) , num_mod = num_output , feedforward = True , concat_output = True)

        self.set_multiloss_params()
        self.set_param_groups()

    def forward(self, inputs):
        # inputs : tuple of (bat_size, seq, input_dim[i_rnn]), len is num_rnn
        hidden = self.encoder(inputs) # tuple of (bat_size, hidden_dim), len num_rnn (a single tensor when num_rnn == 1)
        hidden = self.decoder(hidden) # tuple of num_output tensors: (bat_size, hidden_dim), or (bat_size, num_rnn) with ordered_param_group
        output = self.mapping(hidden) # (bat_size, num_output)
        return output , hidden[0]

    def max_round(self):
        """Number of parameter groups, i.e. of staged training rounds."""
        return len(self.param_groups)

    def set_param_groups(self):
        """Build ``self.param_groups`` for :meth:`training_round`.

        With ``ordered_param_group`` and several streams, group ``i`` holds every parameter
        except those whose name contains ``enc_list.{j}.`` or ``dec_list.{j}.`` for j != i.
        Otherwise there is one group containing all parameters.
        """
        self.param_groups = []
        if self.ordered_param_group and self.num_rnn > 1:
            for i in range(self.num_rnn):
                excluded = [f'{prefix}.{j}.' for j in range(self.num_rnn) if j != i for prefix in ('enc_list' , 'dec_list')]
                group = [param for name , param in self.named_parameters() if not any(s in name for s in excluded)]
                assert len(group) > 0
                self.param_groups.append(group)
        else:
            self.param_groups.append(list(self.parameters()))

    def training_round(self , round_num):
        """Set ``requires_grad`` for training round ``round_num``.

        Rounds ``0 .. max_round()-1`` train only ``param_groups[round_num]`` (all other
        parameters frozen); any later round trains every parameter.
        """
        train_all = round_num >= self.max_round()
        for par in self.parameters():
            par.requires_grad_(train_all)
        if not train_all:
            # guarded: indexing param_groups past its end would raise IndexError
            for par in self.param_groups[round_num]:
                par.requires_grad_(True)

    def set_multiloss_params(self):
        # one learnable value per output, initialised uniformly in [1e-4, 1 + 1e-4)
        self.multiloss_alpha = torch.nn.Parameter((torch.rand(self.num_output) + 1e-4).requires_grad_())

    def get_multiloss_params(self):
        return {'alpha':self.multiloss_alpha}
    
class MyGRU(rnn_univariate):
    """Single-output ``rnn_univariate`` with a GRU backbone (``type_rnn``/``num_output`` are fixed)."""
    def __init__(self , input_dim , type_rnn = 'gru' , num_output = 1 , **kwargs):
        forced = {'type_rnn' : 'gru' , 'num_output' : 1}
        super().__init__(input_dim , **kwargs , **forced)
        
class MyLSTM(rnn_univariate):
    """Single-output ``rnn_univariate`` with an LSTM backbone (``type_rnn``/``num_output`` are fixed)."""
    def __init__(self , input_dim , type_rnn = 'lstm' , num_output = 1 , **kwargs):
        forced = {'type_rnn' : 'lstm' , 'num_output' : 1}
        super().__init__(input_dim , **kwargs , **forced)
        
class MyTransformer(rnn_univariate):
    """Single-output ``rnn_univariate`` with a transformer backbone (``type_rnn``/``num_output`` are fixed)."""
    def __init__(self , input_dim , type_rnn = 'transformer' , num_output = 1 , **kwargs):
        forced = {'type_rnn' : 'transformer' , 'num_output' : 1}
        super().__init__(input_dim , **kwargs , **forced)
        
class MyTCN(rnn_univariate):
    """Single-output ``rnn_univariate`` with a TCN backbone (``type_rnn``/``num_output`` are fixed)."""
    def __init__(self , input_dim , type_rnn = 'tcn' , num_output = 1 , **kwargs):
        forced = {'type_rnn' : 'tcn' , 'num_output' : 1}
        super().__init__(input_dim , **kwargs , **forced)
        
class MynTaskRNN(rnn_univariate):
    """``rnn_univariate`` with a configurable number of outputs (tasks)."""
    def __init__(self , input_dim , num_output = 1 , **kwargs):
        kwargs['num_output'] = num_output
        super().__init__(input_dim , **kwargs)

class MyGeneralRNN(rnn_multivariate):
    """Alias of ``rnn_multivariate``; every argument is forwarded unchanged."""
    def __init__(self , input_dim , **kwargs):
        super(MyGeneralRNN , self).__init__(input_dim = input_dim , **kwargs)

class uni_rnn_encoder(nn.Module):
    """Encode a (bat_size, seq, input_dim) sequence into a (bat_size, hidden_dim) vector.

    Pipeline: optional Linear+Tanh input projection (``fc_in``) -> backbone chosen by
    ``type_rnn`` -> time-wise attention pooling (``fc_att``) or the last time step.
    For ``type_rnn == 'transformer'`` both ``fc_in`` and ``fc_att`` are forced off.
    """
    def __init__(self,input_dim,hidden_dim,rnn_layers,dropout,fc_att,fc_in,type_rnn,**kwargs):
        super().__init__()
        self.mod_rnn = {'transformer':mod_transformer,'lstm':mod_lstm,'gru':mod_gru,'tcn':mod_tcn,}[type_rnn]
        # mod_transformer has its own Linear+Tanh input layer
        if type_rnn == 'transformer': fc_in , fc_att = False , False

        self.rnn_kwargs = {'input_dim':hidden_dim if fc_in else input_dim, 'output_dim':hidden_dim,'num_layers':rnn_layers, 'dropout':dropout}
        # only the TCN backbone takes a kernel size
        if type_rnn == 'tcn' and 'kernel_size' in kwargs: self.rnn_kwargs['kernel_size'] = kwargs['kernel_size']

        self.fc_in = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.Tanh()) if fc_in else nn.Sequential()
        self.fc_rnn = self.mod_rnn(**self.rnn_kwargs)
        self.fc_enc_att = TimeWiseAttention(hidden_dim,hidden_dim,dropout=dropout) if fc_att else None

    def forward(self, inputs):
        # inputs.shape : (bat_size, seq, input_dim)
        # output.shape : (bat_size, hidden_dim)
        output = self.fc_in(inputs)   # an empty Sequential passes inputs through unchanged
        output = self.fc_rnn(output)  # (bat_size, seq, hidden_dim)
        # explicit None check: an nn.Module's truthiness depends on __len__ for containers
        output = self.fc_enc_att(output) if self.fc_enc_att is not None else output[:,-1]
        return output
    
class uni_rnn_decoder(nn.Module):
    """MLP head on a (bat_size, hidden_dim) vector.

    ``mlp_layers`` blocks of Linear -> ``type_act`` -> Dropout (width ``dec_mlp_dim`` or
    ``hidden_dim``), followed by a Linear output layer of width 1 (``map_to_one``) or
    ``hidden_dim``, batch-normalized when ``hidden_as_factors``.
    """
    def __init__(self,hidden_dim,dec_mlp_dim,mlp_layers,dropout,type_act,hidden_as_factors,map_to_one=False,**kwargs):
        super().__init__()
        assert type_act in ['LeakyReLU' , 'ReLU']
        self.mod_act = getattr(nn , type_act)
        self.fc_dec_mlp = nn.Sequential()
        mlp_dim = dec_mlp_dim if dec_mlp_dim else hidden_dim
        for i in range(mlp_layers):
            self.fc_dec_mlp.append(nn.Sequential(nn.Linear(hidden_dim if i == 0 else mlp_dim , mlp_dim), self.mod_act(), nn.Dropout(dropout)))
        # with no MLP layers the output layer receives the raw hidden_dim vector
        last_dim = mlp_dim if mlp_layers > 0 else hidden_dim
        out_dim = 1 if map_to_one else hidden_dim
        if hidden_as_factors:
            self.fc_hid_out = nn.Sequential(nn.Linear(last_dim , out_dim) , nn.BatchNorm1d(out_dim))
        else:
            self.fc_hid_out = nn.Linear(last_dim , out_dim)

    def forward(self, inputs):
        # inputs.shape : (bat_size, hidden_dim)
        # output.shape : (bat_size, 1 if map_to_one else hidden_dim)
        output = self.fc_dec_mlp(inputs)
        output = self.fc_hid_out(output)
        return output
    
class uni_rnn_mapping(nn.Module):
    """Map a (bat_size, hidden_dim) vector to one (bat_size, 1) output.

    hidden_as_factors : use the feature mean (``mod_ewlinear``) instead of Linear(hidden_dim, 1).
    output_as_factors : append BatchNorm1d(1) to the result.
    """
    def __init__(self,hidden_dim,output_as_factors,hidden_as_factors,**kwargs):
        super().__init__()
        head = mod_ewlinear() if hidden_as_factors else nn.Linear(hidden_dim, 1)
        layers = [head, nn.BatchNorm1d(1)] if output_as_factors else [head]
        self.fc_map_out = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.fc_map_out(inputs)

class multi_rnn_encoder(nn.Module):
    """One ``uni_rnn_encoder`` per input stream; ``input_dim`` lists the stream widths."""
    def __init__(self,input_dim,hidden_dim,**kwargs):
        super().__init__()
        encoders = [uni_rnn_encoder(width, hidden_dim, **kwargs) for width in input_dim]
        self.enc_list = nn.ModuleList(encoders)

    def forward(self, inputs):
        # inputs : tuple of (bat_size, seq_i, input_dim[i]); the seq lengths may differ
        # returns: list with one (bat_size, hidden_dim) tensor per stream
        return [encoder(stream) for stream, encoder in zip(inputs, self.enc_list)]
    
class multi_rnn_decoder(nn.Module):
    """Decode every stream's hidden vector, optionally mix streams, then fuse them.

    ordered_param_group : each stream's decoder emits one value; the ``num_rnn`` values are
                          batch-normalized -> (bat_size, num_rnn).
    otherwise           : per-stream (bat_size, hidden_dim) outputs (passed through
                          ``ModuleWiseAttention`` when ``rnn_att``) are concatenated and
                          projected to (bat_size, hidden_dim), batch-normalized when
                          ``hidden_as_factors``.
    """
    def __init__(self, hidden_dim,num_rnn,rnn_att,ordered_param_group,hidden_as_factors,**kwargs):
        super().__init__()
        self.dec_list = nn.ModuleList(
            uni_rnn_decoder(hidden_dim, hidden_as_factors=False, map_to_one=ordered_param_group, **kwargs)
            for _ in range(num_rnn)
        )
        # an empty Sequential hands the list of stream outputs back unchanged
        self.fc_mod_att = nn.Sequential()
        if ordered_param_group:
            self.fc_hid_out = nn.BatchNorm1d(num_rnn)
            return
        if rnn_att:
            self.fc_mod_att = ModuleWiseAttention(hidden_dim, num_rnn, num_heads=kwargs['num_heads'],
                                                  dropout=kwargs['dropout'], seperate_output=True)
        fuse = nn.Linear(num_rnn * hidden_dim, hidden_dim)
        self.fc_hid_out = nn.Sequential(fuse, nn.BatchNorm1d(hidden_dim)) if hidden_as_factors else fuse

    def forward(self, inputs):
        # inputs : sequence of num_rnn tensors, each (bat_size, hidden_dim)
        decoded = [decoder(hidden) for hidden, decoder in zip(inputs, self.dec_list)]
        mixed = self.fc_mod_att(decoded)
        return self.fc_hid_out(torch.cat(mixed, dim=-1))
    
class multi_rnn_mapping(nn.Module):
    """Map the fused multi-stream vector to one (bat_size, 1) output.

    The feature mean (``mod_ewlinear``) is used when ``ordered_param_group`` or
    ``hidden_as_factors`` is set, a Linear(hidden_dim, 1) otherwise; ``output_as_factors``
    appends BatchNorm1d(1).
    """
    def __init__(self,hidden_dim,num_rnn,ordered_param_group,output_as_factors, hidden_as_factors,**kwargs):
        super().__init__()
        use_mean = ordered_param_group or hidden_as_factors
        modules = [mod_ewlinear() if use_mean else nn.Linear(hidden_dim, 1)]
        if output_as_factors:
            modules.append(nn.BatchNorm1d(1))
        self.fc_map_out = nn.Sequential(*modules)

    def forward(self, inputs):
        # inputs : (bat_size, hidden_dim), or (bat_size, num_rnn) with ordered_param_group
        return self.fc_map_out(inputs)
    
class TimeWiseAttention(nn.Module):
    """Attention pooling over the time axis of a (batch, seq_len, input_dim) tensor.

    Steps are projected to ``att_dim``, scored, softmax-normalized over ``seq_len``, and
    summed with those weights; the pooled vector is concatenated with the last projected
    step and mapped to ``output_dim``.  Output shape: (batch, output_dim).
    """
    def __init__(self , input_dim, output_dim=None, att_dim = None, dropout = 0.0):
        super().__init__()
        if output_dim is None: output_dim = input_dim
        if att_dim is None: att_dim = output_dim
        self.fc_in = nn.Linear(input_dim, att_dim)
        # normalize over the time axis (dim=1); dim=0 would mix different samples of the batch
        self.att_net = nn.Sequential(nn.Dropout(dropout),nn.Tanh(),nn.Linear(att_dim,1,bias=False),nn.Softmax(dim=1))
        self.fc_out = nn.Linear(2*att_dim,output_dim)

    def forward(self, inputs):
        inputs = self.fc_in(inputs)            # [batch, seq_len, att_dim]
        att_score = self.att_net(inputs)       # [batch, seq_len, 1], sums to 1 over seq_len
        output = torch.mul(inputs, att_score).sum(dim=1)
        output = torch.cat((inputs[:, -1], output), dim=1)
        return self.fc_out(output)
    
class ModuleWiseAttention(nn.Module):
    """Residual self-attention across streams: each stream's vector is one token.

    Args:
        input_dim: width shared by all streams, or a list/tuple with one width per stream.
        mod_num: number of streams (must equal ``len(input_dim)`` for a list/tuple).
        att_dim: common attention width (default: the largest input width).
        num_heads: attention heads (default: ``att_dim // 8``, at least 1, lowered until it
            divides ``att_dim``).
        dropout: attention dropout.
        seperate_output: return a tuple with one tensor per stream instead of the stacked tensor.
    """
    def __init__(self , input_dim , mod_num = None , att_dim = None , num_heads = None , dropout=0.0 , seperate_output = True):
        super().__init__()
        if isinstance(input_dim , (list,tuple)):
            assert mod_num == len(input_dim)
        else:
            input_dim = [input_dim for _ in range(mod_num)]

        att_dim = max(input_dim) if att_dim is None else att_dim
        if num_heads is None:
            # nn.MultiheadAttention needs num_heads >= 1 dividing att_dim; att_dim // 8 is
            # 0 for att_dim < 8 and need not divide att_dim (e.g. 44 // 8 == 5)
            num_heads = max(1 , att_dim // 8)
            while att_dim % num_heads: num_heads -= 1

        self.in_fc = nn.ModuleList([nn.Linear(inp_d , att_dim) for inp_d in input_dim])
        self.task_mha = nn.MultiheadAttention(att_dim, num_heads = num_heads, batch_first=True , dropout = dropout)
        self.seperate_output = seperate_output

    def forward(self, inputs):
        # project every stream to att_dim and stack them as tokens: (..., mod_num, att_dim)
        hidden = torch.stack([f(x) for x,f in zip(inputs,self.in_fc)],dim=-2)
        hidden = self.task_mha(hidden , hidden , hidden)[0] + hidden
        if self.seperate_output:
            return tuple([hidden.select(-2,i) for i in range(hidden.shape[-2])])
        else:
            return hidden
        
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first inputs, then apply dropout.

    Even feature columns get ``sin(pos / 10000**(2i/d))``, odd columns the matching cosine;
    sequences up to ``max_len`` steps are supported.
    """
    def __init__(self, input_dim, dropout=0.0, max_len=1000,**kwargs):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.seq_len = max_len
        P = torch.zeros(1 , self.seq_len, input_dim)
        X = torch.arange(self.seq_len, dtype=torch.float).reshape(-1,1) / torch.pow(10000,torch.arange(0, input_dim, 2 ,dtype=torch.float) / input_dim)
        P[:, :, 0::2] = torch.sin(X)
        # for odd input_dim there is one fewer odd column than even column
        P[:, :, 1::2] = torch.cos(X[:,:input_dim//2])
        # non-persistent buffer: moves with .to()/.cuda()/dtype casts of the module, so no
        # per-call device copy is needed, and it stays out of state_dict (checkpoints unchanged)
        self.register_buffer('P', P, persistent=False)

    def forward(self, inputs):
        return self.dropout(inputs + self.P[:,:inputs.shape[1],:].to(inputs.device))

class SampleWiseTranformer(nn.Module):
    """Transformer encoder across samples: the whole batch is one sequence, one token per sample.

    2-D inputs (n_samples, hidden_dim) are used directly; higher-rank inputs are pooled to
    (n_samples, hidden_dim) by ``TimeWiseAttention`` first.  Samples containing NaN are
    treated as padding (masked) and their NaNs replaced with 0.
    """
    def __init__(self , hidden_dim , ffn_dim = None , num_heads = 8 , encoder_layers = 2 , dropout=0.0):
        super().__init__()
        assert hidden_dim % num_heads == 0
        ffn_dim = 4 * hidden_dim if ffn_dim is None else ffn_dim
        self.fc_att = TimeWiseAttention(hidden_dim,hidden_dim)
        enc_layer = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=ffn_dim , dropout=dropout , batch_first=True)
        self.trans = nn.TransformerEncoder(enc_layer , encoder_layers)

    def forward(self, inputs , pad_mask = None):
        """Encode *inputs*; *pad_mask* (optional, (1, n_samples), nonzero = ignore) is OR-ed with the NaN mask."""
        if inputs.isnan().any():
            nan_mask = self.pad_mask_nan(inputs)
            pad_mask = nan_mask if pad_mask is None else nan_mask | pad_mask.bool()
            inputs = inputs.nan_to_num()
        hidden = inputs if inputs.dim() == 2 else self.fc_att(inputs)
        # add a leading batch axis of 1 so the samples form a single sequence
        return self.trans(hidden.unsqueeze(0) , src_key_padding_mask = pad_mask).squeeze(0)

    def pad_mask_rand(self , inputs , mask_ratio = 0.1):
        """Random (1, n_samples) padding mask, True with probability ``mask_ratio``."""
        return (torch.rand(1,inputs.shape[0]) < mask_ratio).to(inputs.device)

    def pad_mask_nan(self , inputs):
        """(1, n_samples) mask that is True for every sample containing a NaN."""
        return inputs.sum(dim = tuple(range(1 , inputs.dim()))).isnan().unsqueeze(0)

class TimeWiseTranformer(nn.Module):
    """Transformer encoder over the time axis of batch-first (batch, seq, input_dim) inputs.

    Inputs are projected to ``hidden_dim`` with Linear + Tanh (the same input layer as
    ``mod_transformer``), given sinusoidal position encodings and encoded.
    Output shape: (batch, seq, hidden_dim).
    """
    def __init__(self , input_dim , hidden_dim , ffn_dim = None , num_heads = 8 , encoder_layers = 2 , dropout=0.0):
        super().__init__()
        assert hidden_dim % num_heads == 0
        ffn_dim = 4 * hidden_dim if ffn_dim is None else ffn_dim
        # project to the model width so input_dim may differ from hidden_dim
        self.fc_in = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh())
        self.pos_enc = PositionalEncoding(hidden_dim,dropout=dropout)
        enc_layer = nn.TransformerEncoderLayer(hidden_dim , num_heads, dim_feedforward=ffn_dim , dropout=dropout , batch_first=True)
        self.trans = nn.TransformerEncoder(enc_layer , encoder_layers)

    def forward(self, inputs):
        hidden = self.fc_in(inputs)
        hidden = self.pos_enc(hidden)
        return self.trans(hidden)
    
## gen_data
import torch , h5py
import numpy as np
import pandas as pd
import os, shutil , gc , copy , time
import yaml
# from globalvars import *

# bars per trading day for each data frequency
NBARS      = {'day' : 1 , '15m' : 16 ,}
BEFORE_DAY = 20170101   # default cutoff: cal_norm_param only uses columns before this date
STEP_DAY   = 5          # default day stride between windows in cal_norm_param
DATATYPE   = get_config('data_type')['DATATYPE']

# h5 file prefixes handled by update_trading_data / prepare_model_data
update_files = ['day_trading_data' , 'day_ylabels_data' , '15m_trading_data']
# (row index, column index) dataset names, keyed by the prefix before the first '_'
data_index_dict = {
    'day' : ('SecID' , 'TradeDate') ,
    '15m' : ('SecID' , 'TradeDateTime') ,
    '30m' : ('SecID' , 'TradeDateTime') ,
    'gp'  : ('SecID' , 'TradeDate') ,
}

dir_nas    = None # f'/root/autodl-nas'
dir_data   = './data'
dir_update = f'{dir_data}/update_data'

path_ydata = f'{dir_data}/Ys.npz'

def path_xdata(x):
    """Path of the feature npz for data type *x*."""
    return f'{dir_data}/Xs_{x}.npz'

path_norm_param = f'{dir_data}/norm_param.pt'

logger = get_logger()

def fetch_update_from_nas():
    """Move update files from ``dir_nas`` into ``dir_update``.

    Every file whose name starts with one of ``update_files`` is copied and then deleted
    from the source.  Does nothing when ``dir_nas`` is None.
    """
    if dir_nas is None:
        return
    os.makedirs(dir_update , exist_ok=True)
    moved = []
    for prefix in update_files:
        matches = [name for name in os.listdir(dir_nas) if name.startswith(prefix)]
        for name in matches:
            src = f'{dir_nas}/{name}'
            shutil.copy(src, f'{dir_update}/{name}')
            os.remove(src)   # copy + delete == move
            moved.append(name)
    if moved:
        print('{:s} copy file finished!'.format(', '.join(moved)))

def _merge_h5_updates(target_path , source_paths , row_var , col_var , label):
    """Merge every dataset of the *source_paths* h5 files into *target_path* in place.

    Datasets other than the (row_var, col_var) index datasets are combined with
    ``merge_data_2d`` and rewritten gzip-compressed; the index datasets are then replaced
    by the merged index.  All files are closed even if merging fails.

    Returns:
        (row_all, col_all) merged index arrays, or (None, None) when the target holds no
        data datasets (its index datasets are then left untouched).
    """
    from contextlib import ExitStack
    row_all , col_all = None , None
    with ExitStack() as stack:
        target_file = stack.enter_context(h5py.File(target_path , mode='r+'))
        source_files = [stack.enter_context(h5py.File(f , mode='r')) for f in source_paths]
        all_files = [target_file] + source_files

        row_tuple = tuple(f.get(row_var)[:] for f in all_files)
        col_tuple = tuple(f.get(col_var)[:] for f in all_files)

        for k in sorted(target_file.keys() - [row_var , col_var]):
            t0 = time.time()
            data = tuple(f.get(k)[:] for f in all_files)
            data , row_all , col_all = merge_data_2d(data , row_tuple , col_tuple , row_all , col_all)
            row_all , col_all = np.array(row_all).astype(int) , np.array(col_all).astype(int)

            del target_file[k]
            target_file.create_dataset(k , data = data , compression="gzip")
            print(f'{label} -> {k} cost {(time.time() - t0):.2f}')

        # never delete the index datasets unless a merged index exists to replace them
        if row_all is not None:
            for var , index in ((row_var , row_all) , (col_var , col_all)):
                del target_file[var]
                target_file.create_dataset(var , data = index , compression="gzip")
    return row_all , col_all

def update_trading_data(remove_update_file = True):
    """Merge pending update files from ``dir_update`` into ``dir_data/<prefix>.h5``.

    For each prefix in ``update_files``: when the target h5 does not exist, the first
    update file (sorted by name) is copied in as the target; the remaining update files
    are merged into it.  With *remove_update_file*, every update file that was used is
    deleted afterwards.  Returns without doing anything if ``dir_update`` does not exist.
    """
    fetch_update_from_nas()
    target_dir = dir_data
    source_dir = dir_update
    if not os.path.isdir(source_dir): return
    for file_starter in update_files:
        row_var , col_var = data_index_dict[file_starter.split('_')[0]]

        target_path = f'{target_dir}/{file_starter}.h5'
        update_paths = sorted(f'{source_dir}/{f}' for f in os.listdir(source_dir) if f.startswith(file_starter))
        if not update_paths: continue

        merge_paths = update_paths
        if not os.path.exists(target_path):
            shutil.copy(update_paths[0] , target_path)
            merge_paths = update_paths[1:]

        row_all , col_all = None , None
        if merge_paths:
            row_all , col_all = _merge_h5_updates(target_path , merge_paths , row_var , col_var , file_starter)

        # also removes the file that seeded a new target, so it is not merged again next time
        if remove_update_file:
            for path in update_paths: os.remove(path)
        if row_all is not None:
            print(f'Update {file_starter} Finished! From {min(col_all)} to {max(col_all)} , of {len(row_all)} stocks')
    return

def prepare_model_data():
    """Convert each ``dir_data/<prefix>.h5`` in ``update_files`` into a modelling npz.

    'ylabels' files yield Y10Delay / Y5Delay (saved to ``path_ydata``); 'trading' files
    yield OHLC, volume and VWAP features (saved to ``path_xdata(<data type>)``).  Arrays
    are stored as (stocks, days, bars per day, features).

    Raises:
        Exception: for an unknown feature type in a file prefix.
        AssertionError: when an intraday file does not have ``NBARS`` bars for every day
            or the array does not match its row/col indices.
    """
    source_dir = dir_data

    for file_starter in update_files:
        print(f'Preparing {file_starter} data...')
        model_data_type , feature_type = file_starter.split('_')[0] , file_starter.split('_')[1]
        row_var , col_var = data_index_dict[model_data_type]

        # `with` closes the h5 file even when validation below fails
        with h5py.File(f'{source_dir}/{file_starter}.h5' , mode = 'r') as source_file:
            row , col = source_file.get(row_var)[:] , source_file.get(col_var)[:]

            if feature_type == 'ylabels':
                feat = ['Y10Delay' , 'Y5Delay']
                file_path = path_ydata
            elif feature_type == 'trading':
                feat = ['OpenPrice','HighPrice','LowPrice','ClosePrice','TradeVolume','VWPrice']
                file_path = path_xdata(model_data_type)
            else:
                raise Exception(f'KeyError : {feature_type}')

            # (feat, row, col) -> (row, col, feat)
            arr = np.array([source_file.get(k)[:] for k in feat]).transpose(1,2,0)

        if col_var == 'TradeDateTime':
            nbars = NBARS[model_data_type]
            # NOTE(review): col // 100 is assumed to map each bar's timestamp to its trading
            # day key (format of TradeDateTime not visible here) - confirm
            col = col // 100
            # every day key must occur exactly nbars times (np.unique: O(n log n))
            _ , counts = np.unique(col , return_counts = True)
            assert (counts == nbars).all()
            col = col[::nbars]
            arr = arr.reshape(arr.shape[0] , -1 , nbars , arr.shape[2])
        else:
            arr = arr.reshape(arr.shape[0] , -1 , 1 , arr.shape[2])

        assert (arr.shape[0] , arr.shape[1]) == (len(row) , len(col))

        save_data_file(file_path , row , col , feat , arr)
        print(f'arr shape : {arr.shape} , row shape : {row.shape} , col shape : {col.shape}')
    return

def save_data_file(file_path , row , col , feat , arr):
    """Save ``arr`` with its row/col/feature labels as a compressed npz at *file_path*.

    ``arr`` is stored 4-D as (len(row), len(col), bars, len(feat)):
    3-D input (row, col, feat) gets a bars axis of size 1, and 2-D input (row, col) becomes
    a single bar with a single feature.  4-D input is stored as is.

    Raises:
        Exception: for 1-D ``arr``.
        AssertionError: when the labels do not match the shape of ``arr``.
    """
    if arr.ndim == 3:
        # insert the bars axis before the feature axis (the feature axis is index 2 here)
        arr = arr.reshape(arr.shape[0],arr.shape[1],1,arr.shape[2])
    elif arr.ndim == 2:
        arr = arr.reshape(arr.shape[0],arr.shape[1],1,1)
    elif arr.ndim == 1:
        raise Exception(f'DimError: shape is {str(arr.shape)}')
    assert (arr.shape[0] , arr.shape[1] , arr.shape[-1]) == (len(row) , len(col) , len(feat))
    np.savez_compressed(file_path , row = row , col = col , feat = feat , arr = arr)

def cal_norm_param(maxday = 60 , before_day = BEFORE_DAY , step_day = STEP_DAY):
    """Compute per-bar mean/std of price-relative trading windows and save them.

    For each data type in ``DATATYPE['trade']`` whose feature npz exists:

    * keep the columns (days) before ``before_day``;
    * build windows of ``maxday`` days starting every ``step_day`` days;
    * divide each window by the last bar of its last day;
    * drop windows containing NaN or a zero divisor;
    * record mean and std per (bar position, feature).

    Saves ``{data_type: {'avg': (maxday*nbars, feat), 'std': (maxday*nbars, feat)}}`` to
    ``path_norm_param``.
    """
    norm_param = {}
    for model_data_type in DATATYPE['trade']:
        if not os.path.exists(path_xdata(model_data_type)): continue
        logger.error(f'[{model_data_type}] Data avg and std generation start!')
        t0 = time.time()
        # `with` closes the npz archive once the needed arrays are in memory
        with np.load(path_xdata(model_data_type)) as x_dict:
            col_data = np.array(x_dict['col'] , dtype = int)
            beg_col_id = (col_data < before_day).sum()
            x = torch.tensor(np.array(x_dict['arr'])[:, :beg_col_id, :]).to(dtype = torch.float)

        print(f'Loading {model_data_type} trading data finished, cost {time.time() - t0:.2f} Secs')
        nbars = NBARS[model_data_type]
        stock_n , day_len , _ , feat_dim = x.shape
        step_len = day_len // step_day
        bars_len = maxday * nbars
        # NaN-pad the day axis on the right so every window below stays in range
        padd_len = (0,0,0,0,0,max(0 , maxday - step_day),0,0)

        x = torch.nn.functional.pad(x,padd_len,value=np.nan)
        avg_x = torch.zeros(bars_len , feat_dim)
        std_x = torch.zeros(bars_len , feat_dim)

        # Window k covers days [k*step_day, k*step_day + maxday).  Slicing day i of every
        # window with stop i + step_len*step_day yields exactly step_len windows; a stop of
        # i + day_len yields one extra window when day_len % step_day != 0.
        stop = step_len * step_day
        x_div = x[:, (maxday - 1):(maxday - 1 + stop):step_day, -1:].clone()   # (stock_n, step_len, 1, feat_dim)
        print(x_div.shape)

        nan_sample = (x_div == 0).sum(dim = (-2,-1)) > 0
        for i in range(maxday):
            nan_sample |= x[:, i:(i + stop):step_day].reshape(stock_n , step_len , -1).isnan().any(dim = -1)

        for i in range(maxday):
            # (valid windows, nbars, feat_dim)
            vijs = (x[:, i:(i + stop):step_day] / x_div)[~nan_sample]
            avg_x[i*nbars:(i+1)*nbars] = vijs.mean(dim = 0)
            std_x[i*nbars:(i+1)*nbars] = vijs.std(dim = 0)
        assert avg_x.isnan().sum() + std_x.isnan().sum() == 0

        norm_param.update({model_data_type : {'avg' : avg_x , 'std' : std_x}})
        del x
        gc.collect()

    torch.save(norm_param , path_norm_param)

def load_trading_data(model_data_type , precision = 'float'):
    """
    Load label (y) data and the feature (x) data of every type in `model_data_type`,
    aligned on common rows and columns.

    Args:
        model_data_type: '+'-joined data type names (e.g. 'day'). Types listed in
            DATATYPE['trade'] are aligned on exactly matching columns; any other type is
            treated as factor data and mapped, for each column, to its latest column
            that is <= the requested one.
        precision: name of a torch dtype attribute (e.g. 'float'); every returned tensor
            (including the nested norm_param dicts) is cast to it.

    Returns:
        (x_data, y_data, norm_param, (row, col)):
        x_data maps each data type to its aligned tensor; y_data is the aligned label
        tensor, NaN-padded for trade columns later than the last label column;
        norm_param is the object stored at path_norm_param; row / col are the aligned
        index arrays.

    Raises:
        IndexError: if a factor type has no column <= some requested column.
    """
    import contextlib

    tensor_precision = getattr(torch , precision)
    def set_precision(data):
        # recursively cast tensors nested inside dicts / lists / tuples
        if isinstance(data , dict):
            return {k:set_precision(v) for k,v in data.items()}
        elif isinstance(data , (list,tuple)):
            return type(data)(map(set_precision , data))
        else:
            return data.to(tensor_precision)

    read_index = lambda x:(np.array(x['row'],dtype=int),np.array(x['col'],dtype=int))
    read_data  = lambda x:torch.tensor(x['arr']).detach()
    # positions in x of the values shared with y (both assumed unique)
    i_exact  = lambda x,y:np.intersect1d(x , y , assume_unique=True , return_indices = True)[1]
    # for every value of y, the last position in x whose value is <= it (x assumed ascending)
    i_latest = lambda x,y:np.array([np.where(x<=i)[0][-1] for i in y])

    data_type_list = model_data_type.split('+')
    with contextlib.ExitStack() as stack:
        # npz archives are read lazily; keep them open only while arrays are extracted
        y_file = stack.enter_context(np.load(path_ydata))
        x_file = {mdt:stack.enter_context(np.load(path_xdata(mdt))) for mdt in data_type_list}

        # aligned row,col
        yr , yc = read_index(y_file)
        x_index = {mdt:read_index(f) for mdt,f in x_file.items()}

        row , xc_trade , xc_factor = yr , None , None
        for mdt , (xr , xc) in x_index.items():
            row = np.intersect1d(row , xr)
            if mdt in DATATYPE['trade']:
                xc_trade = xc if xc_trade is None else np.intersect1d(xc_trade , xc)
            else:
                xc_factor = xc if xc_factor is None else np.union1d(xc_factor , xc)

        col = xc_factor if xc_trade is None else xc_trade
        # explicit None check: the truth value of a multi-element array is ambiguous
        if xc_factor is not None:
            col = col[col >= xc_factor.min()]
        # columns beyond the last label column are kept and receive NaN labels
        col , xc_tail = np.intersect1d(col , yc) , col[col > yc.max()]

        y_data = read_data(y_file)[i_exact(yr,row),:][:,i_exact(yc,col)]
        y_data = set_precision(torch.nn.functional.pad(y_data , (0,0,0,0,0,len(xc_tail),0,0) , value=np.nan))
        col = np.concatenate((col , xc_tail))

        x_data = {}
        for mdt , (xr , xc) in x_index.items():
            i0 = i_exact(xr , row)
            i1 = i_exact(xc , col) if mdt in DATATYPE['trade'] else i_latest(xc , col)
            x_data[mdt] = set_precision(read_data(x_file[mdt])[i0,:][:,i1])

    # norm_param
    norm_param = {k:set_precision(v) for k,v in torch.load(path_norm_param).items()}

    # check
    assert all([d.shape[0] == y_data.shape[0] == len(row) for mdt,d in x_data.items()])
    assert all([d.shape[1] == y_data.shape[1] == len(col) for mdt,d in x_data.items()])

    return x_data , y_data , norm_param , (row , col)

"""
if __name__ == '__main__':
    t1 = time.time()
    logger.critical('Data loading start!')
        
    update_trading_data()
    prepare_model_data()
    cal_norm_param()
    
    t2 = time.time()
    logger.critical('Data loading Finished! Cost {:.2f} Seconds'.format(t2-t1))
"""


"\nif __name__ == '__main__':\n    t1 = time.time()\n    logger.critical('Data loading start!')\n        \n    update_trading_data()\n    prepare_model_data()\n    cal_norm_param()\n    \n    t2 = time.time()\n    logger.critical('Data loading Finished! Cost {:.2f} Seconds'.format(t2-t1))\n"

In [2]:
from scripts.data_util.ModelData import ModelData2
from scripts.data_util.ModelData import ModelData as ModelData3
from scripts.util.trainer import trainer_parser , train_config


class ModelData():
    """
    Hold the data needed for training / testing and build dataloaders on demand.

    1. Parameters: read from ShareNames / config (input_step , test_step , labels_n , feat_dims)
    2. Data: x_data , y_data , norm_param , index_stock , index_date (taken from ModelData2)
    3. Dataloaders: yield x , y of training / test samples, created anew for each model date

    NOTE(review): the train path is switched by the module-level debug flags
    `debug_type`, `debug_method1` and `starting_point`, which must be defined before
    this class is used.
    """
    def __init__(self):
        args = {'process':1,'rawname':1,'resume':0,'anchoring':0}
        try:
            parser = trainer_parser(args).parse_args()
        except SystemExit:
            # parse_args exits via SystemExit on unrecognised command-line arguments
            # (e.g. those a Jupyter kernel is launched with); fall back to no arguments
            parser = trainer_parser(args).parse_args(args=[])

        # raw data, norms and index all come from the ModelData2 implementation
        config2 = train_config(parser = parser , do_process=True , config_files=['train2'])
        self.data2 = ModelData2(['day'] , config2 , debug_type=debug_type)
        self.x_data = self.data2.x_data
        self.y_data , self.norm_param = self.data2.y_data , self.data2.norms
        self.index_stock , self.index_date = self.data2.index
        self.stock_n , self.all_day_len = self.y_data.shape[:2]
        # keep all label columns only if some model has more than one output, else the first
        self.labels_n = self.y_data.shape[-1] if any(smp['num_output'] > 1 for smp in ShareNames.model_params) else 1
        self.feat_dims = {mdt:v.shape[-1] for mdt,v in self.x_data.items()}
        # single data type -> int input_dim; several -> tuple in data_type_list order
        if len(ShareNames.data_type_list) > 1:
            input_dim = tuple(self.feat_dims[mdt] for mdt in ShareNames.data_type_list)
        else:
            input_dim = self.feat_dims[ShareNames.data_type_list[0]]
        for smp in ShareNames.model_params:
            smp.update({'input_dim':input_dim})

        self.input_step = config['INPUT_STEP_DAY']
        self.test_step  = config['TEST_STEP_DAY']

        ShareNames.model_date_list = self.index_date[(self.index_date >= config['BEG_DATE']) & (self.index_date <= config['END_DATE'])][::config['INTERVAL']]
        ShareNames.test_full_dates = self.index_date[(self.index_date > config['BEG_DATE']) & (self.index_date <= config['END_DATE'])]
        self.reset_dataloaders()

    def reset_dataloaders(self):
        """
        Reset dataloaders and dataloader_param, then release cached memory.
        """
        self.dataloaders = {}
        self.dataloader_param = ()
        self._release_memory()

    @staticmethod
    def _release_memory():
        """Run the garbage collector and release cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    def _resolve_seqlens(self , seqlens):
        """
        Set self.seq0 (longest sequence length) and self.seq (sequence length per data type).
        Data types missing from `seqlens` default to 1; a non-positive length n means seq0 + n.
        """
        seqlens = {mdt:(seqlens[mdt] if mdt in seqlens.keys() else 1) for mdt in ShareNames.data_type_list}
        self.seq0 = max(seqlens.values())
        self.seq = {mdt:(self.seq0 + n if n <= 0 else n) for mdt , n in seqlens.items()}

    def new_train_dataloader(self , model_date , seqlens):
        """
        Create train/valid dataloaders for `model_date`.

        With debug_method1 == 0 the dataloaders built by ModelData2 are reused as-is.
        Otherwise the pipeline runs in five stages
        (0: slice x/y, 1: non-nan mask, 2: train/valid split, 3: y data, 4: dataloaders);
        stage k is computed here when starting_point <= k and copied from ModelData2's
        intermediate results (its tmp_* attributes / dataloaders) otherwise.
        """
        assert ShareNames.process_name in ['train' , 'instance']
        self.dataloader_param = (model_date , seqlens)

        if debug_method1 == 0:
            if debug_type == 0:
                self.data2.new_train_dataloader(model_date , seqlens)
            else:
                self.data2.create_dataloader(ShareNames.process_name , 'train' , model_date , seqlens)
            self.dataloaders['train'] = self.data2.dataloaders['train']
            self.dataloaders['valid'] = self.data2.dataloaders['valid']
            return

        self.data2.new_train_dataloader(model_date , seqlens)

        if starting_point <= 0:
            self.i_train , self.i_valid , self.ii_train , self.ii_valid = None , None , None , None
            self.y_train , self.y_valid , self.train_nonnan_sample = None , None , None
            self._release_memory()

            self._resolve_seqlens(seqlens)
            model_date_col = (self.index_date < model_date).sum()
            # window of INPUT_SPAN columns ending 15 columns before model_date
            d0 , d1 = max(0 , model_date_col - 15 - config['INPUT_SPAN']) , max(0 , model_date_col - 15)
            self.day_len  = d1 - d0
            self.step_len = self.day_len // self.input_step
            self.lstepped = np.arange(0 , self.day_len , self.input_step)[:self.step_len]

            # append seq0 - input_step NaN steps to axis 1 (assuming 4-D tensors) so that
            # index lstepped + seq0 - 1 stays in range
            data_func = lambda x:torch.nn.functional.pad(x[:,d0:d1] , (0,0,0,0,0,self.seq0-self.input_step,0,0) , value=np.nan)
            x = {k:data_func(v) for k,v in self.x_data.items()}
            y = data_func(self.y_data).squeeze(2)[:,:,:self.labels_n]
        else:
            self.seq0 = self.data2.seq0
            self.seq = self.data2.seq
            self.day_len  = self.data2.day_len
            self.step_len = self.data2.step_len
            self.lstepped = self.data2.lstepped
            x , y = self.data2.tmp_x , self.data2.tmp_y

        if starting_point <= 1:
            self._train_nonnan_sample(x , y)
        else:
            self.train_nonnan_sample = self.data2.tmp_train_nonnan_sample

        if starting_point <= 2:
            self._train_tv_split()
        else:
            self.ii_train , self.ii_valid = self.data2.tmp_ii_train , self.data2.tmp_ii_valid

        if starting_point <= 3:
            self._train_y_data(y)
        else:
            self.i_train , self.i_valid = self.data2.tmp_i_train , self.data2.tmp_i_valid
            self.y_train , self.y_valid = self.data2.tmp_y_train , self.data2.tmp_y_valid

        if starting_point <= 4:
            self._train_dataloader(x)
        else:
            self.dataloaders['train'] = self.data2.dataloaders['train']
            self.dataloaders['valid'] = self.data2.dataloaders['valid']

        # drop the intermediates before releasing memory
        x , y = None , None
        self.i_train , self.i_valid , self.ii_train , self.ii_valid = None , None , None , None
        self.y_train , self.y_valid , self.train_nonnan_sample = None , None , None
        self._release_memory()

    def new_test_dataloader(self , model_date , seqlens):
        """
        Create the test dataloader for the dates after `model_date` up to and including the
        next model date (config['END_DATE'] + 1 for the last model date).
        """
        assert ShareNames.process_name in ['test' , 'instance']
        self.dataloader_param = (model_date , seqlens)

        self.x_test , self.y_test = None , None
        self._release_memory()

        self._resolve_seqlens(seqlens)

        if model_date == ShareNames.model_date_list[-1]:
            next_model_date = config['END_DATE'] + 1
        else:
            next_model_date = ShareNames.model_date_list[ShareNames.model_date_list > model_date][0]
        # 'instance' evaluates every date, 'test' every TEST_STEP_DAY-th date
        _step = (1 if ShareNames.process_name == 'instance' else self.test_step)
        _dates_list = ShareNames.test_full_dates[::_step]
        self.model_test_dates = _dates_list[(_dates_list > model_date) & (_dates_list <= next_model_date)]
        d0 , d1 = np.where(self.index_date == self.model_test_dates[0])[0][0] , np.where(self.index_date == self.model_test_dates[-1])[0][0] + 1
        self.day_len  = d1 - d0
        # ceil(day_len / _step)
        self.step_len = (self.day_len // _step) + (0 if self.day_len % _step == 0 else 1)
        self.lstepped = np.arange(0 , self.day_len , _step)[:self.step_len]

        # start seq0 - 1 columns early so the first test date has a full look-back window
        data_func = lambda x:x[:,d0 - self.seq0 + 1:d1]
        x = {k:data_func(v) for k,v in self.x_data.items()}
        y = data_func(self.y_data).squeeze(2)[:,:,:self.labels_n]

        self._test_nonnan_sample(x , y)
        self._test_y_data(y)
        self._test_dataloader(x)
        x , y = None , None
        self.x_test = None
        self._release_memory()

    def _nonnan_sample(self , x , y):
        """
        Return a boolean (stock_n , step_len) mask of usable samples: the label at the
        sample's end column (lstepped + seq0 - 1) has no NaN, no input value in the
        look-back window of each data type is NaN, and for trade data the last entry
        along axis 2 at the end column (the divisor used by _norm_x) has no zero.
        """
        nansamp = y[:,self.lstepped + self.seq0 - 1].isnan().sum(-1)
        for mdt in ShareNames.data_type_list:
            for i in range(self.seq[mdt]):
                nansamp += x[mdt][:,(self.seq0 - self.seq[mdt] + i):][:,self.lstepped].isnan().sum((2,3))
            if mdt in config['DATATYPE']['trade']:
                nansamp += (x[mdt][:,self.lstepped + self.seq0 - 1][:,:,-1] == 0).sum(-1)
        return nansamp == 0

    def _train_nonnan_sample(self , x , y):
        """
        Set self.train_nonnan_sample (shape stock_n * step_len) for the train window.
        """
        self.train_nonnan_sample = self._nonnan_sample(x , y)

    def _train_tv_split(self):
        """
        Set self.ii_train / self.ii_valid: indices (within 0 : stock_n * step_len - 1) of the
        flattened non-nan samples assigned to train / valid, each shuffled.
        Split is random when train_params['dataloader']['random_tv_split'] is set, otherwise
        the chronologically earliest `train_ratio` share of samples goes to train.
        """
        ii_stock_wise = np.arange(self.stock_n * self.step_len)[self.train_nonnan_sample.flatten()]
        ii_time_wise  = np.arange(self.stock_n * self.step_len).reshape(self.step_len , self.stock_n).transpose().flatten()[ii_stock_wise]
        train_samples = int(len(ii_stock_wise) * ShareNames.train_params['dataloader']['train_ratio'])
        random.seed(ShareNames.train_params['dataloader']['random_seed'])
        if ShareNames.train_params['dataloader']['random_tv_split']:
            random.shuffle(ii_stock_wise)
            ii_train , ii_valid = ii_stock_wise[:train_samples] , ii_stock_wise[train_samples:]
        else:
            # mark the `train_samples` smallest time-wise positions; this also works when
            # every sample goes to train (indexing sorted(...)[train_samples] would not)
            early_samples = np.zeros(len(ii_time_wise) , dtype = bool)
            early_samples[np.argsort(ii_time_wise)[:train_samples]] = True
            ii_train , ii_valid = ii_stock_wise[early_samples] , ii_stock_wise[~early_samples]
        random.shuffle(ii_train)
        random.shuffle(ii_valid)
        self.ii_train , self.ii_valid = ii_train , ii_valid

    def _train_y_data(self , y):
        """
        Set the (stock_i , date_i) positions and the standardized labels of train/valid samples.
        """
        # init i (row , col position) and y (labels) matrix
        i_tv = torch.zeros(self.stock_n , self.step_len , 2 , dtype = int) # i_row (sec) , i_col_x (end)
        i_tv[:,:,0] = torch.tensor(np.arange(self.stock_n , dtype = int)).reshape(-1,1)
        i_tv[:,:,1] = torch.tensor(self.lstepped + self.seq0 - 1)
        i_tv = i_tv.reshape(-1,i_tv.shape[-1])
        self.i_train , self.i_valid = (i_tv[self.ii_train] , i_tv[self.ii_valid])

        y_tv = torch.zeros(self.stock_n , self.step_len , self.labels_n)
        y_tv[:] = y[:,self.lstepped + self.seq0 - 1].nan_to_num(0)
        y_tv[self.train_nonnan_sample == 0] = np.nan
        y_tv , w_tv = tensor_standardize_and_weight(y_tv , dim = 0)
        y_tv , w_tv = y_tv.reshape(-1,y_tv.shape[-1]) , w_tv.reshape(-1,w_tv.shape[-1])
        self.y_train , self.y_valid = (y_tv[self.ii_train] , y_tv[self.ii_valid])

    def _train_dataloader(self , x):
        """
        Set self.dataloaders['train'] / ['valid'].
        The oneshot in-memory branch is currently disabled (`and False`); batches are saved
        through storage_loader under ShareNames.batch_dir[set_name] and loaded lazily.
        """
        if ShareNames.model_data_type == 'day' and False:
            mdt = 'day'
            x_tv = self._norm_x(torch.cat([x[mdt][:,self.lstepped + i] for i in range(self.seq[mdt])] , dim=2) , mdt)
            x_tv = x_tv.reshape(-1 , self.seq[mdt] , self.feat_dims[mdt])
            x_train , x_valid = x_tv[self.ii_train] , x_tv[self.ii_valid]
            num_worker = min(os.cpu_count() , ShareNames.compt_params['num_worker'])
            self.dataloaders['train'] = self.dataloader_oneshot((x_train , self.y_train) , num_worker , ShareNames.compt_params['cuda_first'])
            self.dataloaders['valid'] = self.dataloader_oneshot((x_valid , self.y_valid) , num_worker , ShareNames.compt_params['cuda_first'])
        else:
            storage_loader.del_group('train')
            set_iter = [('train' , self.i_train , self.y_train) , ('valid' , self.i_valid , self.y_valid)]
            for set_name , set_i , set_y in set_iter:
                batch_sampler = torch.utils.data.BatchSampler(range(len(set_i)) , ShareNames.batch_size , drop_last = False)
                batch_file_list = []
                for batch_num , batch_pos in enumerate(batch_sampler):
                    batch_file_list.append(ShareNames.batch_dir[set_name] + f'/{set_name}.{batch_num}.pt')
                    i0 , i1 , batch_y , batch_x = set_i[batch_pos , 0] , set_i[batch_pos , 1] , set_y[batch_pos] , []
                    for mdt in ShareNames.data_type_list:
                        batch_x.append(self._norm_x(torch.cat([x[mdt][i0,i1+i+1-self.seq[mdt]] for i in range(self.seq[mdt])],dim=1),mdt))
                    batch_x = batch_x[0] if len(batch_x) == 1 else tuple(batch_x)
                    storage_loader.save((batch_x, batch_y), batch_file_list[-1] , group = 'train')
                self.dataloaders[set_name] = self.dataloader_saved(batch_file_list)

    def _test_nonnan_sample(self , x , y):
        """
        Set self.test_nonnan_sample (shape stock_n * step_len) for the test window.
        """
        self.test_nonnan_sample = self._nonnan_sample(x , y)

    def _test_y_data(self , y):
        """
        Set self.y_test: standardized test labels, NaN where the sample is unusable.
        """
        y_test = torch.zeros(self.stock_n , self.step_len , self.labels_n)
        y_test[:] = y[:,self.lstepped + self.seq0 - 1].nan_to_num(0)
        y_test[self.test_nonnan_sample == 0] = np.nan
        self.y_test , _ = tensor_standardize_and_weight(y_test , dim = 0)

    def _test_dataloader(self , x):
        """
        Set self.dataloaders['test'] with one batch per stepped test date (usable rows only).
        The oneshot in-memory branch is currently disabled (`and False`); batches are saved
        through storage_loader under ShareNames.batch_dir['test'] and loaded lazily.
        """
        if ShareNames.model_data_type == 'day' and False:
            mdt = 'day'
            x_test = self._norm_x(torch.cat([x[mdt][:,i+self.lstepped] for i in range(self.seq[mdt])],dim=2) , mdt)
            self.dataloaders['test'] = self.dataloader_oneshot((x_test , self.y_test) , 0 , ShareNames.compt_params['cuda_first'] , 1) # iter over col(date)
        else:
            storage_loader.del_group('test')
            batch_sampler = [(np.where(self.test_nonnan_sample[:,i])[0] , self.lstepped[i]) for i in range(self.step_len)]
            batch_file_list = []
            for batch_num , batch_pos in enumerate(batch_sampler):
                batch_file_list.append(ShareNames.batch_dir['test'] + f'/test.{batch_num}.pt')
                i0 , i1 , batch_y , batch_x = batch_pos[0] , batch_pos[1] + self.seq0 - 1 , self.y_test[batch_pos[0] , batch_num] , []
                for mdt in ShareNames.data_type_list:
                    batch_x.append(self._norm_x(torch.cat([x[mdt][i0,i1+i+1-self.seq[mdt]] for i in range(self.seq[mdt])],dim=1),mdt))
                batch_x = batch_x[0] if len(batch_x) == 1 else tuple(batch_x)
                storage_loader.save((batch_x, batch_y), batch_file_list[-1] , group = 'test')
            self.dataloaders['test'] = self.dataloader_saved(batch_file_list)

    def _norm_x(self , x , key):
        """
        Return x normalized IN PLACE (the argument tensor is modified).
        For trade data types:
        1. divide by the last entry along axis -2 (sequence-normalization)
        2. subtract the stored avg and divide by the stored std of the matching tail rows
        Other data types are returned unchanged.
        """
        if key in config['DATATYPE']['trade']:
            x /= x.select(-2,-1).unsqueeze(-2) + 1e-6
            x -= self.norm_param[key]['avg'][-x.shape[-2]:]
            x /= self.norm_param[key]['std'][-x.shape[-2]:] + 1e-6
        return x

    class dataloader_oneshot:
        """
        In-memory dataloader: either a torch DataLoader over the whole dataset, or (when
        batch_by_axis is given) one batch per index along that axis, keeping rows whose
        labels contain no NaN.
        """
        def __init__(self, data , num_worker = 0 , cuda_first = True , batch_by_axis = None):
            if cuda_first: data = cuda(data)
            self.batch_by_axis = batch_by_axis
            if self.batch_by_axis is None:
                self.dataset = Mydataset(*data)
                self.dataloader = torch.utils.data.DataLoader(self.dataset , batch_size = ShareNames.batch_size , num_workers = (1 - cuda_first)*num_worker)
            else:
                self.x , self.y = data

        def __iter__(self):
            if self.batch_by_axis is None:
                for batch_data in self.dataloader:
                    yield cuda(batch_data)
            else:
                for batch_i in range(self.y.shape[self.batch_by_axis]):
                    x , y = self.x.select(self.batch_by_axis , batch_i) , self.y.select(self.batch_by_axis , batch_i)
                    if y.dim() == 1:
                        valid_row = y.isnan() == 0
                    elif y.dim() == 2:
                        valid_row = y.isnan().sum(-1) == 0
                    else:
                        valid_row = y.isnan().sum(list(range(y.dim()))[1:]) == 0
                    batch_data = (x[valid_row] , y[valid_row])
                    yield cuda(batch_data)

    class dataloader_saved:
        """
        Dataloader over saved batch files: yields storage_loader.load(file) for each file in order.
        """
        def __init__(self, batch_file_list):
            self.batch_file_list = batch_file_list
        def __iter__(self):
            for batch_file in self.batch_file_list:
                yield cuda(storage_loader.load(batch_file))

In [3]:
# Debug switches read as module-level globals by the ModelData class.
# debug_type: forwarded as ModelData2(..., debug_type=debug_type) in ModelData.__init__;
#   when debug_method1 == 0 it also picks data2.new_train_dataloader (0) or
#   data2.create_dataloader (any other value) in ModelData.new_train_dataloader.
debug_type = 0

# debug_method1: 0 -> ModelData.new_train_dataloader reuses data2's train/valid dataloaders;
#   any other value -> the train pipeline is (partly) rebuilt inside ModelData, see starting_point.
debug_method1 = 1
# starting_point: in ModelData.new_train_dataloader, stage k is recomputed locally when
#   starting_point <= k and copied from data2's tmp_* results / dataloaders otherwise
#   (0: slice x/y, 1: non-nan mask, 2: train/valid split, 3: y data, 4: dataloaders).
#   Author's notes on the outcome observed for each value:
starting_point = 5
# 0: good steady
# 1:
# 2:
# 3: good steady , stand alone
# 4: good !!
# 5: no good

In [4]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023-06-27 21:05
# @Author : Mathew Jin
# @File : run_model.py
# chmod +x run_model.py
# ./run_model.py --process=0 --rawname=1 --resume=0 --anchoring=0
'''
1.TRA
https://arxiv.org/pdf/2106.12950.pdf
https://github.com/microsoft/qlib/blob/main/examples/benchmarks/TRA/src/model.py
1.1 HIST
https://arxiv.org/pdf/2110.13716.pdf
https://github.com/Wentao-Xu/HIST
2.Lightgbm
https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/plot_example.py
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.plot_tree.html
3.other factors
'''
import argparse
import torch
import torch.nn as nn
import numpy as np
import itertools , random , os, shutil , gc , time ,h5py

from torch.optim.swa_utils import AveragedModel , update_bn
from tqdm import tqdm
from scipy import stats
from copy import deepcopy

# from globalvars import *

# from audtorch.metrics.functional import *

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Computation device: the first GPU when CUDA is visible, the CPU otherwise.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
TIME_RECODER = False  # (sic) name kept unchanged for existing references

# Logger and configuration from the project helpers get_logger / get_config.
logger = get_logger()
config = get_config()

# NOTE(review): autograd anomaly detection is a debugging aid that slows training
# noticeably; consider disabling it for production runs.
torch.autograd.set_detect_anomaly(True)

# Two independent storage handles of the configured backend: storage_loader holds the
# batch files written/read by the dataloaders, storage_model presumably model files.
storage_model , storage_loader = (versatile_storage(config['STORAGE_TYPE']) for _ in range(2))

class ShareNames_conctroller():
    """
    Fill the shared `ShareNames` namespace from `config` and resolve run options.

    1. Assign variables into the shared namespace.
    2. Decide which processes to run : 0 : train & test (default) , 1 : train only , 2 : test only , 3 : copy to instance only
    3. Decide whether model_name / model_base_path change when old model directories exist
    """
    def __init__(self):
        self.assign_variables(if_process = True , if_rawname = True)

    def assign_variables(self , if_process = False , if_rawname = False):
        """
        Copy settings from `config` into ShareNames.

        Args:
            if_process: also resolve ShareNames.process_queue (may prompt the user).
            if_rawname: also resolve model_name / model_base_path / resume_training
                (may prompt the user and create or delete directories under ./model).
        """
        ShareNames.max_epoch       = config['MAX_EPOCH']
        ShareNames.batch_size      = config['BATCH_SIZE']
        ShareNames.precision       = config['PRECISION']

        ShareNames.model_module    = config['MODEL_MODULE']
        ShareNames.model_data_type = config['MODEL_DATATYPE'][ShareNames.model_module]
        ShareNames.model_nickname  = config['MODEL_NICKNAME']

        ShareNames.model_num_list  = list(range(config['MODEL_NUM']))
        ShareNames.data_type_list  = ShareNames.model_data_type.split('+')

        ShareNames.model_name      = self._model_name()
        ShareNames.model_base_path = f'./model/{ShareNames.model_name}'
        ShareNames.instance_path   = f'./instance/{ShareNames.model_name}'
        ShareNames.batch_dir       = {k:f'./data/{k}_batch_path' for k in ['train' , 'valid' , 'test']}

        if if_process  : self._process_confirmation()
        if if_rawname  : self._rawname_confirmation()

        ShareNames.train_params = deepcopy(config['TRAIN_PARAM'])
        ShareNames.compt_params = deepcopy(config['COMPT_PARAM'])
        ShareNames.raw_model_params = deepcopy(config['MODEL_PARAM'])
        ShareNames.model_params = self._load_model_param()
        ShareNames.output_types = ShareNames.train_params['output_types']

    def _model_name(self):
        """Join module, data type and nickname (skipping None) with '_'."""
        name_element = [
            ShareNames.model_module ,
            ShareNames.model_data_type ,
            ShareNames.model_nickname
        ]
        return '_'.join([x for x in name_element if x is not None])

    def _load_model_param(self):
        """
        Return the list of per-model parameter dicts, one per model_num.

        Reads {model_base_path}/model_params.pt when it exists; otherwise builds the list
        from raw_model_params, taking element `mm % len(v)` of every list-valued entry for
        model mm and adding that model's 'path'. The built list is not saved here.
        """
        try:
            model_params = torch.load(f'{ShareNames.model_base_path}/model_params.pt')
        except FileNotFoundError:
            model_params = []
            for mm in ShareNames.model_num_list:
                dict_mm = {'path':f'{ShareNames.model_base_path}/{mm}'}
                dict_mm.update({k:(v[mm % len(v)] if isinstance(v,list) else v) for k,v in ShareNames.raw_model_params.items()})
                model_params.append(dict_mm)
        return model_params

    def _process_confirmation(self):
        """
        Set ShareNames.process_queue from ShareNames.process, prompting the user when it is negative.

        Raises:
            Exception: for an unrecognised choice.
        """
        if ShareNames.process < 0:
            logger.critical(f'What process would you want to run? 0: all (default), 1: train only, 2: test only , 3: copy to instance')
            promt_text = f'[0,all] , [1,train] , [2,test] , [3,instance]: '
            _text , _cond = ask_for_confirmation(promt_text , proceed_condition = lambda x:False)
            key = _text[0]
        else:
            key = str(ShareNames.process)

        if key in ['' , '0' , 'all']:
            ShareNames.process_queue = ['data' , 'train' , 'test' , 'instance']
        elif key in ['1' , 'train']:
            ShareNames.process_queue = ['data' , 'train']
        elif key in ['2' , 'test']:
            ShareNames.process_queue = ['data' , 'test' , 'instance']
        elif key in ['3' , 'instance']:
            ShareNames.process_queue = ['data' , 'instance']
        else:
            raise Exception(f'Error input : {key}')
        logger.critical('Process Queue : {:s}'.format(' + '.join(map(lambda x:(x[0].upper() + x[1:]),ShareNames.process_queue))))

    @staticmethod
    def _model_dir_entries():
        """Names inside './model', or an empty list when that directory does not exist yet."""
        return os.listdir('./model') if os.path.isdir('./model') else []

    @classmethod
    def _numbered_suffixes(cls , model_name):
        """Sorted integer suffixes N of the directories './model/{model_name}.N' (non-integer suffixes ignored)."""
        prefix = model_name + '.'
        return sorted(int(d[len(prefix):]) for d in cls._model_dir_entries()
                      if d.startswith(prefix) and d[len(prefix):].isdecimal())

    def _rawname_confirmation(self , recurrent = 1):
        """
        Confirm model_name / model_base_path when directories named after the model exist.
        If training: ask whether to resume; when not resuming, either delete the old
            directories ('name' and 'name.N') or continue under the next 'name.N'.
        If testing only: ask whether to use the raw directory or the latest 'name.N' (default).
        Negative ShareNames.rawname / ShareNames.resume mean "ask"; otherwise > 0 means yes.
        Sets ShareNames.resume_training.
        """
        if_rawname = None if (ShareNames.rawname < 0) else (ShareNames.rawname > 0)
        if_resume  = None if (ShareNames.resume < 0)  else (ShareNames.resume > 0)

        if 'train' in ShareNames.process_queue:
            if not os.path.exists(ShareNames.model_base_path):
                if_rawname = True
                if_resume = False

            if if_resume is None:
                logger.critical(f'[{ShareNames.model_base_path}] exists, input [yes] to resume training, or start a new one!')
                promt_text = f'Confirm resume training [{ShareNames.model_name}]? [yes/no] : '
                _text , _cond = ask_for_confirmation(promt_text ,  recurrent = recurrent)
                if_resume = all([_t.lower() in ['' , 'yes' , 'y'] for _t in _text])

            if if_resume:
                logger.critical(f'Resume training {ShareNames.model_name}!')
                file_appendix = self._numbered_suffixes(ShareNames.model_name)
                if if_rawname is None and len(file_appendix) > 0:
                    logger.critical(f'Multiple model path of {ShareNames.model_name} exists, input [yes] to confirm using the raw one, or [no] the latest!')
                    promt_text = f'Use the raw one? [yes/no] : '
                    _text , _cond = ask_for_confirmation(promt_text ,  recurrent = recurrent)
                    if_rawname = all([_t.lower() in ['' , 'yes' , 'y'] for _t in _text])

                if if_rawname or len(file_appendix) == 0:
                    logger.critical(f'model_name is still {ShareNames.model_name}!')
                else:
                    ShareNames.model_name = f'{ShareNames.model_name}.{file_appendix[-1]}'
                    ShareNames.model_base_path = f'./model/{ShareNames.model_name}'
                    logger.critical(f'model_name is now {ShareNames.model_name}!')
            else:
                if if_rawname is None:
                    logger.critical(f'[{ShareNames.model_base_path}] exists, input [yes] to confirm deletion, or a new directory will be made!')
                    promt_text = f'Confirm Deletion of all old directories with model name [{ShareNames.model_name}]? [yes/no] : '
                    _text , _cond = ask_for_confirmation(promt_text ,  recurrent = recurrent)
                    if_rawname = all([_t.lower() in ['' , 'yes' , 'y'] for _t in _text])

                if if_rawname:
                    # only the raw directory and its 'name.N' copies: a bare prefix match would
                    # also delete other models such as 'name_v2'
                    prefix = ShareNames.model_name + '.'
                    rmdir([f'./model/{d}' for d in self._model_dir_entries() if d == ShareNames.model_name or d.startswith(prefix)])
                    logger.critical(f'Directories of [{ShareNames.model_name}] deletion Confirmed!')
                else:
                    next_suffix = max([1] + [n + 1 for n in self._numbered_suffixes(ShareNames.model_name)])
                    ShareNames.model_name += '.' + str(next_suffix)
                    ShareNames.model_base_path = f'./model/{ShareNames.model_name}'
                    logger.critical(f'A new directory [{ShareNames.model_name}] will be made!')

                os.makedirs(ShareNames.model_base_path, exist_ok = True)
                for mm in ShareNames.model_num_list:
                    os.makedirs(f'{ShareNames.model_base_path}/{mm}' , exist_ok = True)
                for copy_filename in ['configs/config_train.yaml']:
                    shutil.copyfile(f'./{copy_filename}', f'{ShareNames.model_base_path}/{os.path.basename(copy_filename)}')

        elif 'test' in ShareNames.process_queue:
            file_appendix = self._numbered_suffixes(ShareNames.model_name)
            if if_rawname is None and len(file_appendix) > 0:
                logger.critical(f'Multiple model path of {ShareNames.model_name} exists, input [yes] to confirm using the raw one, or [no] the latest!')
                promt_text = f'Use the raw one? [yes/no] : '
                _text , _cond = ask_for_confirmation(promt_text ,  recurrent = recurrent)
                if_rawname = all([_t.lower() in ['' , 'yes' , 'y'] for _t in _text])

            if if_rawname or len(file_appendix) == 0:
                logger.critical(f'model_name is still {ShareNames.model_name}!')
            else:
                ShareNames.model_name = f'{ShareNames.model_name}.{file_appendix[-1]}'
                ShareNames.model_base_path = f'./model/{ShareNames.model_name}'
                logger.critical(f'model_name is now {ShareNames.model_name}!')

        ShareNames.resume_training = if_resume
                
class model_controller():
    """
    A class to control the whole process of training , includes:
    1. Display controls: tqdm , once , step
    2. Parameters: train_params , compt_params , model_data_type
    3. Data : class of train_data
    3. loop status: model , round , attempt , epoch
    4. file path: model , lastround , transfer(last model date)
    5. text: model , round , attempt , epoch , exit , stat , time , trainer
    """
    def __init__(self , **kwargs):
        """
        Record the construction time and set up display options, the
        per-stage timing store and the ShareNames controller.

        kwargs are accepted for interface compatibility but are not used.
        """
        self.init_time = time.time()
        verbosity = config['VERBOSITY']
        # 'step' is the epoch interval between epoch-level log lines; it
        # shrinks as verbosity grows (index capped at 5).
        step_by_level = [10,5,5,3,3,1]
        self.display = dict(
            tqdm = bool(verbosity >= 10) ,
            once = bool(verbosity <= 2) ,
            step = step_by_level[min(verbosity // 2 , 5)] ,
        )
        self.process_time = {}
        self.shared_ctrl = ShareNames_conctroller()
        
    def main_process(self):
        """
        Run every stage listed in ShareNames.process_queue, in order.

        Each stage name (e.g. data / train / test / instance) is activated via
        SetProcessName and dispatched to the matching model_process_<name>
        method; afterwards rmdir is called on ShareNames.batch_dir's paths
        with remake_dir = True (presumably clearing and recreating them).
        """
        for stage in ShareNames.process_queue:
            self.SetProcessName(stage)
            stage_runner = getattr(self , f'model_process_{stage.lower()}')
            stage_runner()
            rmdir(list(ShareNames.batch_dir.values()) , remake_dir = True)
    
    def SetProcessName(self , key = 'data'):
        """
        Switch the controller into stage `key` ('data' , 'train' , 'test' or 'instance').

        Resets the per-stage model counter, resets the data loaders when data
        is already loaded, and builds the callables the stage needs:
            train    : f_loss , f_metric and f_penalty (only penalties with weight > 0)
            test     : f_metric , and clears ic_by_date / ic_by_model
            instance : clears ic_by_date / ic_by_model

        Raises:
            Exception: if `key` is not one of the known stages.
        """
        stage = key.lower()
        ShareNames.process_name = stage
        self.model_count = 0
        if 'data' in vars(self):
            self.data.reset_dataloaders()

        if stage == 'data':
            return
        if stage == 'train':
            criterion = ShareNames.train_params['criterion']
            self.f_loss    = loss_function(criterion['loss'])
            self.f_metric  = metric_function(criterion['metric'])
            # [callable , weight] lists: the weight is mutated later (ModelPreparation)
            self.f_penalty = {name : [penalty_function(name) , weight]
                              for name , weight in criterion['penalty'].items() if weight > 0.}
        elif stage == 'test':
            self.f_metric    = metric_function(ShareNames.train_params['criterion']['metric'])
            self.ic_by_date  = None
            self.ic_by_model = None
        elif stage == 'instance':
            self.ic_by_date  = None
            self.ic_by_model = None
        else:
            raise Exception(f'KeyError : {key}')
        
    def model_process_data(self):
        """
        Stage 'data': build the ModelData object and log the loading time.
        """
        self.data_time = time.time()
        logger.critical('Start Process [Load Data]!')
        self.data = ModelData()
        elapsed = time.time() - self.data_time
        logger.critical(f'Finish Process [Load Data]! Cost {elapsed:.1f}Secs')
        
    def model_process_train(self):
        """
        Stage 'train': train one model per (model_date , model_num) pair.

        ShareNames.model_params is first saved to
        <model_base_path>/model_params.pt. Each pair then gets its paths and
        parameters prepared (ModelPreparation) and runs the
        round / attempt / epoch loop (TrainModel). Total hours and minutes per
        trained model are logged at the end.
        """
        self.train_time = time.time()
        logger.critical('Start Process [Train Model]!')
        logger.error('Start Training Models!')
        torch.save(ShareNames.model_params , f'{ShareNames.model_base_path}/model_params.pt')
        for date , num in self.ModelIter():
            self.model_date , self.model_num = date , num
            self.ModelPreparation('train')
            self.TrainModel()
        elapsed = time.time() - self.train_time
        per_model = elapsed / 60 / max(self.model_count , 1)
        logger.critical(f'Finish Process [Train Model]! Cost {elapsed / 3600:.1f} Hours, {per_model:.1f} Min/Training')
    
    def model_process_test(self):
        """
        Stage 'test': forecast every (model_date , model_num) pair and log a
        table of mean IC per model date, then write the IC summaries (ModelResult).

        The table has one column per (model_num , output_type) pair, model-major.
        """
        self.test_time = time.time()
        logger.critical('Start Process [Test Model]!')
        logger.warning('Each Model Date Testing Mean Rank_ic:')
        n_outputs = len(ShareNames.output_types)
        n_models  = len(ShareNames.model_num_list)
        self.test_result_model_num   = np.repeat(ShareNames.model_num_list , n_outputs)
        self.test_result_output_type = np.tile(ShareNames.output_types , n_models)
        n_cols = len(self.test_result_model_num)
        logger.info('{: <11s}'.format('Models') + ('{: >8d}' * n_cols).format(*self.test_result_model_num))
        logger.info('{: <11s}'.format('Output') + ('{: >8s}' * n_cols).format(*self.test_result_output_type))
        for date , num in self.ModelIter():
            self.model_date , self.model_num = date , num
            self.ModelPreparation('test')
            self.TestModel()
        self.ModelResult()
        elapsed = time.time() - self.test_time
        logger.critical('Finish Process [Test Model]! Cost {:.1f} Secs'.format(elapsed))
        
    def model_process_instance(self):
        """
        Stage 'instance': optionally copy the model directory to
        ShareNames.instance_path and forecast from that copy.

        A negative ShareNames.anchoring asks the user interactively; otherwise a
        positive value means copy and zero means skip. If the instance path
        already exists nothing is copied and the stage stops early.
        """
        if ShareNames.anchoring < 0:
            logger.critical('Do you want to copy the model to instance?')
            _text , _cond = ask_for_confirmation('[yes/else no]: ' , timeout = -1)
            anchoring = all([_t.lower() in ['yes','y'] for _t in _text])
        else:
            anchoring = ShareNames.anchoring > 0

        if not anchoring:
            logger.critical('Will not copy to instance!')
            return

        self.instance_time = time.time()
        logger.critical('Start Process [Copy to Instance]!')
        if os.path.exists(ShareNames.instance_path):
            logger.critical(f'Old instance {ShareNames.instance_path} exists , remove manually first to override!')
            logger.critical(f'The command can be "rm -r {ShareNames.instance_path}"')
            return
        shutil.copytree(ShareNames.model_base_path , ShareNames.instance_path)

        logger.warning('Copy from model to instance finished , Start going forward')
        self.InstanceStart()
        for date , num in self.ModelIter():
            self.model_date , self.model_num = date , num
            self.ModelPreparation('instance')
            self.TestModel()
            self.StorePreds()
        self.ModelResult()
        elapsed = time.time() - self.instance_time
        logger.critical('Finish Process [Copy to Instance]! Cost {:.1f} Secs'.format(elapsed))
        
    def print_vars(self):
        print(vars(self))

    def ModelIter(self):
        """
        Iterate over the (model_date , model_num) pairs to process.

        Normally this is the full product of ShareNames.model_date_list and
        ShareNames.model_num_list. When resuming training, pairs whose
        checkpoint file already exists are skipped. Training restarts one pair
        before the first missing checkpoint, or at the last pair when all
        exist, so the most recent model is redone.

        Returns:
            an iterator of (model_date , model_num) tuples.
        """
        model_iter = itertools.product(ShareNames.model_date_list , ShareNames.model_num_list)
        if not (ShareNames.resume_training and (ShareNames.process_name == 'train')):
            return model_iter

        # Must be a numpy bool array: slice-assigning a scalar and the
        # element-wise mask below do not work on a plain list.
        # NOTE(review): this checks `{md}.pt`, while ModelPreparation names
        # outputs `{md}.{output_type}.pt`; confirm which file marks a finished model.
        models_trained = np.array([os.path.exists(f'{ShareNames.model_base_path}/{mn}/{md}.pt')
                                   for md , mn in model_iter] , dtype = bool)
        logger.debug(f'models_trained : {models_trained}')
        if models_trained.size == 0 or not models_trained[0]:
            models_trained[:] = False
        else:
            resume_point = -1 if models_trained.all() else (np.flatnonzero(~models_trained)[0] - 1)
            models_trained[resume_point:] = False
        # mask entry True = still to be trained (assumes FilteredIterator keeps items whose mask is truthy)
        return FilteredIterator(itertools.product(ShareNames.model_date_list , ShareNames.model_num_list) ,
                                iter(~models_trained))
    
    def ModelPreparation(self , process , last_n = 30 , best_n = 5):
        """
        Set self.Param and self.path for the current (model_date , model_num).

        Args:
            process : 'train' , 'test' or 'instance'.
            last_n  : number of rolling last-epoch checkpoint slots (used by 'swalast').
            best_n  : number of best-epoch checkpoint slots (used by 'swabest').

        self.path holds:
            <output_type>             : '<path>/<model_date>.<output_type>.pt'
            src_model.<output_type>   : (initially empty) checkpoints to build that output from
            lastn                     : rolling slots, when 'swalast' is requested
            bestn / bestn_ic          : top-k slots and their ICs, when 'swabest' is requested
            transfer                  : previous model date's best checkpoint, when transfer is on
        """
        assert process in ['train' , 'test' , 'instance']
        _start_time = time.time()
        param = ShareNames.model_params[self.model_num]

        # Weight of the hidden-orthogonality penalty: 1 when the model uses its
        # hidden states as factors, else 0. f_penalty is only created by
        # SetProcessName('train'), so it may be absent (e.g. an instance-only run).
        f_penalty = getattr(self , 'f_penalty' , {})
        if process in ['train' , 'instance'] and 'hidden_orthogonality' in f_penalty:
            f_penalty['hidden_orthogonality'][1] = 1 * (param.get('hidden_as_factors') == True)

        path_prefix = '{}/{}'.format(param.get('path') , self.model_date)
        path = {k : f'{path_prefix}.{k}.pt' for k in ShareNames.output_types} #['best','swalast','swabest']
        path.update({f'src_model.{k}' : [] for k in ShareNames.output_types})
        if 'swalast' in ShareNames.output_types:
            path['lastn'] = [f'{path_prefix}.lastn.{i}.pt' for i in range(last_n)]
        if 'swabest' in ShareNames.output_types:
            path['bestn'] = [f'{path_prefix}.bestn.{i}.pt' for i in range(best_n)]
            path['bestn_ic'] = [-10000. for i in range(best_n)]

        if ShareNames.train_params['transfer'] and self.model_date > ShareNames.model_date_list[0]:
            previous_date = max([d for d in ShareNames.model_date_list if d < self.model_date])
            path['transfer'] = '{}/{}.best.pt'.format(param.get('path') , previous_date)

        self.Param = param
        self.path = path
        self.time_recoder(_start_time , ['ModelPreparation' , process])
    
    def TrainModel(self):
        """
        Train the current (model_date , model_num) model.

        Each pass starts the next epoch / attempt / round (NewLoop),
        (re)initialises the trainer if needed, runs one epoch and updates
        self.cond['loop_status']. The loop stops once the status is 'model'.
        Garbage and the CUDA cache are released afterwards.
        """
        self.TrainModelStart()
        while not self.cond.get('loop_status') == 'model':
            for step in (self.NewLoop , self.TrainerInit , self.TrainEpoch , self.LoopCondition):
                step()
        self.TrainModelEnd()
        gc.collect()
        torch.cuda.empty_cache()
    
    def TestModel(self):
        """Forecast the test window of the current model, then release garbage and the CUDA cache."""
        for step in (self.TestModelStart , self.Forecast , self.TestModelEnd):
            step()
        gc.collect()
        torch.cuda.empty_cache()
        
    def TrainModelStart(self):
        """
        Prepare to train the current (model_date , model_num).

        Resets all loop variables, refills the nan-loss retry budget and sets
        the model label used in log lines. If the loaded dataloader was not
        built for (model_date , seqlens), a new train dataloader is built and
        its load time logged.
        """
        _start_time = time.time()

        self._init_variables('model')
        self.nanloss_life = ShareNames.train_params['trainer']['nanloss']['retry']
        self.text['model'] = '{:s} #{:d} @{:4d}'.format(ShareNames.model_name , self.model_num , self.model_date)

        wanted = (self.model_date , self.Param['seqlens'])
        if self.data.dataloader_param != wanted:
            self.data.new_train_dataloader(*wanted)
            self.time[1] = time.time()
            self.printer('train_dataloader')

        self.time_recoder(_start_time , ['TrainModelStart'])
            
    def TrainModelEnd(self):
        """
        Finish training the current model.

        Deletes the intermediate checkpoints ('rounds' , 'lastn' , 'bestn'),
        counts the model when in the 'train' stage, stamps the end time and
        logs the model summary line.
        """
        _start_time = time.time()

        intermediate = [self.path.get(k) for k in ('rounds' , 'lastn' , 'bestn')]
        storage_model.del_path(*intermediate)
        if ShareNames.process_name == 'train':
            self.model_count += 1
        self.time[2] = time.time()
        self.printer('model_end')

        self.time_recoder(_start_time , ['TrainModelEnd'])
        
    def NewLoop(self):
        """
        Begin the next epoch, resetting state and advancing the attempt / round
        counters as required by self.cond['loop_status'].
        """
        _start_time = time.time()

        self._init_variables(self.cond.get('loop_status'))
        self.epoch_i += 1
        self.epoch_all += 1

        status = self.cond.get('loop_status')
        if status in ('attempt' , 'round'):
            self.attempt_i += 1
            self.text['attempt'] = 'FirstBite' if self.attempt_i == 0 else f'Retrain#{self.attempt_i}'
        if status == 'round':
            self.round_i += 1
            self.text['round'] = f'Round{self.round_i:2d}'

        self.time_recoder(_start_time , ['NewLoop'])
        
    def TrainerInit(self):
        """
        Build the net, optimizer, scheduler and multiloss helper at the start of
        a new attempt or round. When merely continuing with another epoch it
        does nothing.

        The net comes from load_model('train'). max_round is the net's own
        max_round() when the net defines one, else 1.
        """
        if self.cond.get('loop_status') == 'epoch':
            return
        _start_time = time.time()

        self.net = self.load_model('train')
        if 'max_round' in self.net.__dir__():
            self.max_round = self.net.max_round()
        else:
            self.max_round = 1
        self.optimizer = self.load_optimizer()
        self.scheduler = self.load_scheduler()
        self.multiloss = self.load_multiloss()

        self.time_recoder(_start_time , ['TrainerInit'])
        
    def TrainEpoch(self):
        """
        Run one training epoch followed by one validation pass.

        Training: forward, loss (plus penalties), backward, optional gradient
        value clipping and an optimizer step per batch. If the summed epoch loss
        is nan, _deal_nanloss is called and validation is skipped.
        Validation: forward passes under torch.no_grad() (no autograd graph is
        needed because nothing is back-propagated) collecting the metric.
        Epoch means are appended to self.loss_list / self.ic_list, the current
        learning rate to self.lr_list, then the scheduler is stepped and
        possibly reset.
        """
        _start_time = time.time()
        loss_train , loss_valid , ic_train , ic_valid = [] , [] , [] , []
        clip_value = ShareNames.train_params['trainer']['gradient'].get('clip_value')

        if self.display.get('tqdm'):
            iter_train , iter_valid = tqdm(self.data.dataloaders['train']) , tqdm(self.data.dataloaders['valid'])
            disp_train = lambda x:iter_train.set_description(f'Ep#{self.epoch_i:3d} train loss:{np.mean(x):.5f}')
            disp_valid = lambda x:iter_valid.set_description(f'Ep#{self.epoch_i:3d} valid ic:{np.mean(x):.5f}')
        else:
            iter_train , iter_valid = self.data.dataloaders['train'] , self.data.dataloaders['valid']
            disp_train = disp_valid = lambda x:0

        self.time_recoder(_start_time , ['TrainEpoch' , 'assign_loader'])
        _start_time = time.time()

        self.net.train()
        _start_time_1 = time.time()
        for x , y in iter_train:
            self.time_recoder(_start_time_1 , ['TrainEpoch' , 'train' , 'fetch'])
            self.optimizer.zero_grad()
            _start_time_1 = time.time()
            pred , hidden = self.net(x)
            self.time_recoder(_start_time_1 , ['TrainEpoch' , 'train' , 'forward'])
            _start_time_1 = time.time()
            loss , metric = self._loss_and_metric(y , pred , 'train' , hidden = hidden)
            self.time_recoder(_start_time_1 , ['TrainEpoch' , 'train' , 'loss'])
            _start_time_1 = time.time()
            loss.backward()
            self.time_recoder(_start_time_1 , ['TrainEpoch' , 'train' , 'backward'])
            if clip_value is not None:
                nn.utils.clip_grad_value_(self.net.parameters() , clip_value = clip_value)
            self.optimizer.step()

            loss_train.append(loss.item())
            ic_train.append(metric)
            disp_train(loss_train)
            _start_time_1 = time.time()
        if np.isnan(sum(loss_train)):
            return self._deal_nanloss()
        self.loss_list['train'].append(np.mean(loss_train))
        self.ic_list['train'].append(np.mean(ic_train))

        self.time_recoder(_start_time , ['TrainEpoch' , 'train_epochs'])
        _start_time = time.time()

        self.net.eval()
        _start_time_1 = time.time()
        with torch.no_grad():
            for x , y in iter_valid:
                self.time_recoder(_start_time_1 , ['TrainEpoch' , 'valid' , 'fetch'])
                _start_time_1 = time.time()
                pred , _ = self.net(x)
                self.time_recoder(_start_time_1 , ['TrainEpoch' , 'valid' , 'forward'])
                _start_time_1 = time.time()
                loss , metric = self._loss_and_metric(y , pred , 'valid')
                self.time_recoder(_start_time_1 , ['TrainEpoch' , 'valid' , 'loss'])
                loss_valid.append(loss)
                ic_valid.append(metric)
                disp_valid(ic_valid)
                _start_time_1 = time.time()
        self.loss_list['valid'].append(np.mean(loss_valid))
        self.ic_list['valid'].append(np.mean(ic_valid))
        self.lr_list.append(self.scheduler.get_last_lr()[0])
        self.scheduler.step()
        self.reset_scheduler()

        self.time_recoder(_start_time , ['TrainEpoch' , 'valid_epochs'])

    def LoopCondition(self):
        """
        Decide what the training loop does after the epoch that just finished.

        nan loss  : reset model-level variables (keeping the model label) and
                    restart from a fresh round (loop_status = 'round').
        otherwise : update best-IC trackers and save the checkpoints the
                    requested outputs need ('best' , rolling 'lastn' for
                    'swalast' , top-k 'bestn' for 'swabest'). Then evaluate the
                    terminate conditions:
                      - stopped before min epoch and attempts left -> 'attempt'
                      - rounds left                                -> 'round' (save 'best')
                      - otherwise                                  -> 'model' (save all outputs)
                    and 'epoch' when no terminate condition holds.
        """
        _start_time = time.time()

        if self.cond['nan_loss']:
            logger.error(f'Initialize a new model to retrain! Lives remaining {self.nanloss_life}')
            # _init_variables('model') clears self.text; keep the model label so
            # later log lines (and repeated nan-loss messages) still name the model.
            model_text = self.text['model']
            self._init_variables('model')
            self.text['model'] = model_text
            self.cond['loop_status'] = 'round'
            return

        valid_ic = self.ic_list['valid'][-1]

        save_targets = []
        if valid_ic > self.ic_attempt_best:
            self.epoch_attempt_best = self.epoch_i
            self.ic_attempt_best = valid_ic

        if valid_ic > self.ic_round_best:
            self.ic_round_best = valid_ic
            self.path['src_model.best'] = [self.path['best']]
            save_targets.append(self.path['best'])

        if 'swalast' in ShareNames.output_types:
            # rotate the rolling slots: the oldest slot moves to the end and stores this epoch
            self.path['lastn'] = self.path['lastn'][1:] + self.path['lastn'][:1]
            save_targets.append(self.path['lastn'][-1])

            # slots aligned with the most recent validation ICs (oldest first)
            p_valid = self.path['lastn'][-len(self.ic_list['valid']):]
            arg_max = np.argmax(self.ic_list['valid'][-len(p_valid):])
            # up to 5 checkpoints spaced `step` epochs apart around the best recent epoch
            step = min(5 , len(p_valid) // 3)
            candidates = step * np.arange(-5 , 3) + arg_max
            arg_swa = candidates[(candidates >= 0) & (candidates < len(p_valid))][-5:]
            self.path['src_model.swalast'] = [p_valid[i] for i in arg_swa]

        if 'swabest' in ShareNames.output_types:
            # replace the worst of the top-k slots if this epoch beats it
            arg_min = np.argmin(self.path['bestn_ic'])
            if valid_ic > self.path['bestn_ic'][arg_min]:
                self.path['bestn_ic'][arg_min] = valid_ic
                save_targets.append(self.path['bestn'][arg_min])
                if self.path['bestn'][arg_min] not in self.path['src_model.swabest']:
                    self.path['src_model.swabest'].append(self.path['bestn'][arg_min])

        storage_model.save_model_state(self.net , save_targets)
        self.printer('epoch_step')
        self.time_recoder(_start_time , ['LoopCondition' , 'assess'])
        _start_time = time.time()

        terminate_params = ShareNames.train_params['terminate'].get('overall' if self.max_round <= 1 else 'round')
        self.cond['terminate'] = {k : self._terminate_cond(k , v) for k , v in terminate_params.items()}
        if any(self.cond.get('terminate').values()):
            exit_text = {
                'max_epoch'      : 'Max Epoch' ,
                'early_stop'     : 'EarlyStop' ,
                'tv_converge'    : 'T&V Convg' ,
                'train_converge' : 'Tra Convg' ,
                'valid_converge' : 'Val Convg' ,
            }
            first_hit = [k for k , v in self.cond.get('terminate').items() if v][0]
            self.text['exit'] = exit_text[first_hit]
            retrain = ShareNames.train_params['trainer']['retrain']
            min_epoch = retrain.get('min_epoch' if self.max_round <= 1 else 'min_epoch_round')
            if (self.epoch_i < min_epoch - 1 and
                self.attempt_i < retrain['attempts'] - 1):
                self.cond['loop_status'] = 'attempt'
                self.printer('new_attempt')
            elif self.round_i < self.max_round - 1:
                self.cond['loop_status'] = 'round'
                self.save_model('best')
                self.printer('new_round')
            else:
                self.cond['loop_status'] = 'model'
                self.save_model(ShareNames.output_types)
        else:
            self.cond['loop_status'] = 'epoch'

        # _start_time was set right after the 'assess' record, so this measures
        # the terminate-condition / status decision above.
        self.time_recoder(_start_time , ['LoopCondition' , 'confirm_status'])
        
            
    def TestModelStart(self):
        """
        Prepare to test the current (model_date , model_num).

        Resets loop variables and rebuilds the test dataloader when the loaded
        one was not built for (model_date , seqlens). For model_num == 0 a zero
        block (one row per test date) is added to self.ic_by_date and a zero row
        to self.ic_by_model. Both have one column per (model_num , output_type) pair.
        """
        self._init_variables('model')
        wanted = (self.model_date , self.Param['seqlens'])
        if self.data.dataloader_param != wanted:
            self.data.new_test_dataloader(*wanted)

        if self.model_num != 0:
            return
        n_cols = len(self.test_result_model_num)
        ic_date_0  = np.zeros((len(self.data.model_test_dates) , n_cols))
        ic_model_0 = np.zeros((1 , n_cols))
        if self.ic_by_date is None:
            self.ic_by_date = ic_date_0
        else:
            self.ic_by_date = np.concatenate([self.ic_by_date , ic_date_0])
        if self.ic_by_model is None:
            self.ic_by_model = ic_model_0
        else:
            self.ic_by_model = np.concatenate([self.ic_by_model , ic_model_0])
                
    def Forecast(self):
        """
        Predict the test window of the current (model_date , model_num) with
        each output type's model, and record per-date metrics.

        If the 'best' checkpoint path does not exist, TrainModel() is run first.

        Results:
            self.y_pred     : numpy array (stock_n , n_test_dates , n_output_types)
                              of first-output-column predictions; cells never
                              predicted stay nan.
            self.ic_by_date : its last n_test_dates rows, column
                              model_num * n_output_types + oi, receive the
                              per-date metric (nan replaced by 0).
        """
        if not os.path.exists(self.path['best']): self.TrainModel()

        #self.y_pred = cuda(torch.zeros(self.data.stock_n,len(self.data.model_test_dates),self.data.labels_n,len(ShareNames.output_types)).fill_(np.nan))
        # nan-filled prediction buffer on the device returned by cuda()
        self.y_pred = cuda(torch.zeros(self.data.stock_n,len(self.data.model_test_dates),len(ShareNames.output_types)).fill_(np.nan))
        # one model per output type (e.g. best / swalast / swabest)
        for oi , okey in enumerate(ShareNames.output_types):
            self.net = self.load_model('test' , okey)
            self.net.eval()

            if self.display.get('tqdm'):
                iter_test = tqdm(self.data.dataloaders['test'])
                disp_test = lambda x:iter_test.set_description(f'Date#{x[0]:3d} :{np.mean(x[1]):.5f}')
            else:
                iter_test = self.data.dataloaders['test']
                disp_test = lambda x:0

            m_test = []
            with torch.no_grad():
                # NOTE(review): item i is written to column i of y_pred, i.e. this
                # assumes the test loader yields one test date per item, in
                # model_test_dates order; TODO confirm.
                for i , (x , y) in enumerate(iter_test):
                    # rows of the stocks flagged in test_nonnan_sample for date i;
                    # presumably y / x hold exactly these stocks in this order
                    stock_pos = np.where(self.data.test_nonnan_sample[:,i])[0]
                    # forward the date's samples in sub-batches of ShareNames.batch_size indices
                    for batch_j in torch.utils.data.DataLoader(np.arange(len(y)) , batch_size = ShareNames.batch_size):
                        x_j = tuple([xx[batch_j] for xx in x]) if isinstance(x , tuple) else x[batch_j]
                        output , _ = self.net(x_j)
                        # keep only the first output column
                        self.y_pred[stock_pos[batch_j],i,oi] = output.select(-1,0).detach()
                    # metric of the first label column against this date's predictions
                    metric = self.f_metric(y.select(-1,0) , self.y_pred[stock_pos,i,oi]).item()
                    # release cached GPU memory every 20 dates
                    if (i + 1) % 20 == 0 : torch.cuda.empty_cache()
                    m_test.append(metric)
                    disp_test((i , m_test))

            self.ic_by_date[-len(self.data.model_test_dates):,self.model_num*len(ShareNames.output_types) + oi] = torch.tensor(m_test).nan_to_num(0).cpu().numpy()
        self.y_pred = self.y_pred.cpu().numpy()
        
    def TestModelEnd(self):
        """
        After the last model number of a model date has been tested, average
        the per-date ICs of that date's test window (ignoring nan) into the
        last row of self.ic_by_model and log that row.
        """
        if self.model_num != ShareNames.model_num_list[-1]:
            return
        n_dates = len(self.data.model_test_dates)
        self.ic_by_model[-1,:] = np.nanmean(self.ic_by_date[-n_dates:] , axis = 0)
        row_fmt = '{:>8.4f}' * len(self.test_result_model_num)
        logger.info('{: <11d}'.format(self.model_date) + row_fmt.format(*self.ic_by_model[-1,:]))
        
    def StorePreds(self):
        """
        Collect predictions during the instance stage and, after the last model
        number of a date, write them to <instance_path>/<model_name>.h5.

        For each test date, the HDF5 group named after the date is (re)written with:
            arr : predictions (n_kept_stocks , n_columns); stocks whose
                  predictions are all nan are dropped
            row : the matching entries of self.data.index_stock
            col : column labels '<model_num>.<output_type>', stored as
                  fixed-length ASCII bytes
        """
        assert ShareNames.process_name == 'instance'
        if self.model_num == 0:
            self.y_pred_models = []
            gc.collect()
        self.y_pred_models.append(self.y_pred)
        if self.model_num != ShareNames.model_num_list[-1]:
            return

        # each y_pred is (stock , date , output_type) -> (date , stock , model*output_type)
        self.y_pred_models = np.concatenate(self.y_pred_models , axis = -1).transpose(1,0,2)
        # h5py cannot store numpy unicode ('<U') arrays, so encode the labels as bytes.
        # The labels are the same for every date, so build them once.
        col = np.array([f'{mn}.{o}' for mn , o in zip(self.test_result_model_num , self.test_result_output_type)] ,
                       dtype = 'S')
        h5_path = f'{ShareNames.instance_path}/{ShareNames.model_name}.h5'
        # 'a' : read/write when the file exists, create it otherwise
        with h5py.File(h5_path , mode = 'a') as f:
            for di , date in enumerate(self.data.model_test_dates):
                arr = self.y_pred_models[di]
                keep = ~np.isnan(arr).all(axis = 1)
                arr , row = arr[keep] , self.data.index_stock[keep]
                key = str(date)
                if key in f:
                    del f[key]
                g = f.create_group(key)
                g.create_dataset('arr' , data = arr , compression = 'gzip')
                g.create_dataset('row' , data = row , compression = 'gzip')
                g.create_dataset('col' , data = col , compression = 'gzip')
  
    def ModelResult(self):
        """
        Write IC results to CSV files and log summary statistics.

        1. For each model number, <model path>/<model_name>_ic_by_date_<model_num>.csv
           holds the per-date rank IC and its cumulative (nan-ignoring) sum per output type.
        2. <model_base_path>/<model_name>_ic_by_model.csv holds one mean-IC row per
           model date plus AllTimeAvg / AllTimeSum / Std / TValue / AnnIR rows.
           The summary rows are also logged.
        """
        n_outputs = len(ShareNames.output_types)
        # dates covered by ic_by_date: every date for 'instance', every test_step-th otherwise
        date_step = 1 if ShareNames.process_name == 'instance' else self.data.test_step
        dates = ShareNames.test_full_dates[::date_step]
        date_index = [f'{d[:4]}-{d[4:6]}-{d[6:]}' for d in dates.astype(str)]
        for model_num in ShareNames.model_num_list:
            columns = {'dates' : dates}
            for oi , okey in enumerate(ShareNames.output_types):
                ic = self.ic_by_date[:,model_num*n_outputs + oi]
                columns[f'rank_ic.{okey}'] = ic
                columns[f'cum_ic.{okey}'] = np.nancumsum(ic)
            frame = pd.DataFrame(columns , index = date_index)
            frame.to_csv(ShareNames.model_params[model_num]['path'] + f'/{ShareNames.model_name}_ic_by_date_{model_num}.csv')

        ic_mean   = np.nanmean(self.ic_by_date , axis = 0)
        ic_sum    = np.nansum(self.ic_by_date , axis = 0)
        ic_std    = np.nanstd(self.ic_by_date , axis = 0)
        ic_tvalue = ic_mean / ic_std * (len(self.ic_by_date)**0.5)
        # annualised IR with hard-coded 240 periods per year and a 10-day prediction horizon
        ic_annir  = ic_mean / ic_std * ((240 / 10)**0.5)

        summary_keys = ['AllTimeAvg' , 'AllTimeSum' , 'Std'      , 'TValue'   , 'AnnIR']
        summary_fmts = ['{:>8.4f}'   , '{:>8.2f}'   , '{:>8.4f}' , '{:>8.2f}' , '{:>8.4f}']
        summary_vals = (ic_mean , ic_sum , ic_std , ic_tvalue , ic_annir)
        col_names = [f'{mn}.{o}' for mn , o in zip(self.test_result_model_num , self.test_result_output_type)]
        table = pd.DataFrame(np.concatenate([self.ic_by_model , np.stack(summary_vals)]) ,
                             index = [str(d) for d in ShareNames.model_date_list] + summary_keys ,
                             columns = col_names)
        table.to_csv(f'{ShareNames.model_base_path}/{ShareNames.model_name}_ic_by_model.csv')

        n_cols = len(self.test_result_model_num)
        for key , fmt , vals in zip(summary_keys , summary_fmts , summary_vals):
            logger.info('{: <11s}'.format(key) + (fmt * n_cols).format(*vals))
    
    def InstanceStart(self):
        """
        Switch global settings over to the copied instance.

        Executes <instance_path>/globalvars.py in this method's scope, has the
        ShareNames controller re-assign its variables, and points every
        model's 'path' at <instance_path>/<index>.
        """
        # Read with a context manager so the file handle is closed promptly.
        # NOTE(review): exec runs arbitrary code from the instance directory;
        # only use trusted instance folders.
        with open(f'{ShareNames.instance_path}/globalvars.py') as _globalvars_file:
            _globalvars_source = _globalvars_file.read()
        exec(_globalvars_source)
        self.shared_ctrl.assign_variables()
        for mm in range(len(ShareNames.model_params)):
            ShareNames.model_params[mm].update({'path':f'{ShareNames.instance_path}/{mm}'})
    
    def printer(self , key):
        """
        Log the status message identified by `key`.

        Keys: 'model_specifics' , 'model_end' , 'epoch_step' , 'reset_learn_rate' ,
        'new_attempt' , 'new_round' , 'train_dataloader'.
        Several messages go to INFO while detailed printing is on and to DEBUG
        otherwise. Detailed printing is on when display['once'] is falsy or
        model_count has not exceeded max(model_num_list).

        Raises:
            Exception: for an unknown key.
        """
        _detail_print = (self.display.get('once') == 0 or self.model_count <= max(ShareNames.model_num_list))
        if key == 'model_specifics':
            logger.warning('Model Parameters:')
            logger.info(f'Basic Parameters : ')
            print(f'STORAGE [{config["STORAGE_TYPE"]}] | DEVICE [{DEVICE}] | PRECISION [{ShareNames.precision}] | BATCH_SIZE [{ShareNames.batch_size}].')
            print(f'NAME [{ShareNames.model_name}] | MODULE [{ShareNames.model_module}] | DATATYPE [{ShareNames.model_data_type}] | MODEL_NUM [{len(ShareNames.model_num_list)}].')
            print(f'BEG_DATE [{config["BEG_DATE"]}] | END_DATE [{ShareNames.test_full_dates[-1]}] | ' +
                  f'INTERVAL [{config["INTERVAL"]}] | INPUT_STEP_DAY [{config["INPUT_STEP_DAY"]}] | TEST_STEP_DAY [{config["TEST_STEP_DAY"]}].')
            logger.info(f'MODEL_PARAM : ')
            pretty_print_dict(ShareNames.raw_model_params)
            logger.info(f'TRAIN_PARAM : ')
            pretty_print_dict(ShareNames.train_params)
            logger.info(f'COMPT_PARAM : ')
            pretty_print_dict(ShareNames.compt_params)
        elif key == 'model_end':
            self.text['epoch'] = 'Ep#{:3d}'.format(self.epoch_all)
            self.text['stat']  = 'Train{: .4f} Valid{: .4f} BestVal{: .4f}'.format(self.ic_list['train'][-1],self.ic_list['valid'][-1],self.ic_round_best)
            self.text['time']  = 'Cost{:5.1f}Min,{:5.1f}Sec/Ep'.format((self.time[2]-self.time[0])/60 , (self.time[2]-self.time[1])/(self.epoch_all+1))
            sdout = self.text['model'] + '|' + self.text['round'] + ' ' + self.text['attempt'] + ' ' +\
            self.text['epoch'] + ' ' + self.text['exit'] + '|' + self.text['stat'] + '|' + self.text['time']
            logger.warning(sdout)
        elif key == 'epoch_step':
            self.text['trainer'] = 'loss {: .5f}, train{: .5f}, valid{: .5f}, max{: .4f}, best{: .4f}, lr{:.1e}'.format(
                self.loss_list['train'][-1] , self.ic_list['train'][-1] , self.ic_list['valid'][-1] , self.ic_attempt_best , self.ic_round_best , self.lr_list[-1])
            # only every display['step']-th epoch is reported
            if self.epoch_i % self.display.get('step') == 0:
                sdout = ' '.join([self.text['attempt'],'Ep#{:3d}'.format(self.epoch_i),':', self.text['trainer']])
                logger.info(sdout) if _detail_print else logger.debug(sdout)
        elif key == 'reset_learn_rate':
            speedup = ShareNames.train_params['trainer']['learn_rate']['reset']['speedup2x']
            # third placeholder carries the optional speed-up note (previously dropped by format)
            sdout = 'Reset learn rate and scheduler at the end of epoch {} , effective at epoch {}{}'.format(
                self.epoch_i , self.epoch_i+1 , ', and will speedup2x' * speedup)
            logger.info(sdout) if _detail_print else logger.debug(sdout)
        elif key == 'new_attempt':
            sdout = ' '.join([self.text['attempt'],'Epoch #{:3d}'.format(self.epoch_i),':',self.text['trainer'],', Next attempt goes!'])
            logger.info(sdout) if _detail_print else logger.debug(sdout)
        elif key == 'new_round':
            sdout = self.text['round'] + ' ' + self.text['exit'] + ': ' + self.text['trainer'] + ', Next round goes!'
            logger.info(sdout) if _detail_print else logger.debug(sdout)
        elif key == 'train_dataloader':
            sdout = ' '.join([self.text['model'],'LoadData Cost {:>6.1f}Secs'.format(self.time[1]-self.time[0])])
            logger.info(sdout) if _detail_print else logger.debug(sdout)
        else:
            raise Exception(f'KeyError : {key}')
            
    def _init_variables(self , key = 'model'):
        """
        Reset variables of 'model' , 'round' , 'attempt' start
        """
        if key == 'epoch' : return
        assert key in ['model' , 'round' , 'attempt'] , f'KeyError : {key}'

        self.epoch_i = -1
        self.epoch_attempt_best = -1
        self.ic_attempt_best = -10000.
        self.loss_list = {'train' : [] , 'valid' : []}
        self.ic_list   = {'train' : [] , 'valid' : []}
        self.lr_list   = []
        
        if key in ['model' , 'round']:
            self.attempt_i = -1
            self.ic_round_best = -10000.
        
        if key in ['model']:
            self.round_i = -1
            self.epoch_all = -1
            self.time = np.ones(10) * time.time()
            self.text = {k : '' for k in ['model','round','attempt','epoch','exit','stat','time','trainer']}
            self.cond = {'terminate' : {} , 'nan_loss' : False , 'loop_status' : 'round'}
            
    def _loss_and_metric(self, labels , pred , key , **kwargs):
        """
        Compute the loss and the scalar metric for one batch.

        Args:
            labels : target tensor. If its shape differs from ``pred`` it must agree on
                     every dim but the last; only the first ``pred.shape[-1]`` label
                     columns are kept.
            pred   : model output tensor.
            key    : 'train' or 'valid'.
            kwargs : forwarded to every penalty function in ``self.f_penalty``.

        Returns:
            (loss, metric)
            - 'train': ``loss`` is ``self.f_loss`` on the first output column (or, when
              ``self.Param['num_output'] > 1``, the per-output losses combined by
              ``self.multiloss``) plus ``sum(weight * penalty(**kwargs))``.
            - 'valid': ``loss`` is ``0.``.
            ``metric`` is ``self.f_metric(...).item()`` on the first column in both modes.
        """
        assert key in ['train' , 'valid'] , key
        if labels.shape != pred.shape:
            # more label columns than model outputs: drop the surplus trailing columns
            assert labels.shape[:-1] == pred.shape[:-1] , (labels.shape , pred.shape)
            labels = labels[..., :pred.shape[-1]]

        # the metric (and the single-output loss) only look at the first output column
        first_label = labels.select(-1, 0)
        first_pred = pred.select(-1, 0)

        if key != 'train':
            return 0., self.f_metric(first_label, first_pred).item()

        n_output = self.Param['num_output']
        if n_output > 1:
            per_output = self.f_loss(labels, pred, dim=0)[:n_output]
            loss = self.multiloss.calculate_multi_loss(per_output, self.net.get_multiloss_params())
        else:
            loss = self.f_loss(first_label, first_pred)
        metric = self.f_metric(first_label, first_pred).item()
        penalty = sum([weight * fn(**kwargs) for fn, weight in self.f_penalty.values()])
        return loss + penalty, metric
    
    def _deal_nanloss(self):
        """
        React to a NaN loss in the current epoch.

        Always logs an error. If ``self.nanloss_life`` is still positive, one unit of
        life is consumed and ``self.cond['nan_loss']`` is set to True; otherwise a
        generic Exception is raised.
        """
        logger.error(f'{self.text["model"]} Attempt{self.attempt_i}, epoch{self.epoch_i} got nan loss!')
        if self.nanloss_life <= 0:
            raise Exception('Nan loss life exhausted, possible gradient explosion/vanish!')
        self.nanloss_life -= 1
        self.cond['nan_loss'] = True
    
    def _terminate_cond(self , key , arg):
        """
        Return whether the termination criterion ``key`` is met at the current epoch.

        key / arg:
            'early_stop'     : arg = patience; True once ``epoch_i - epoch_attempt_best >= arg``.
            'train_converge' : arg = dict with 'min_epoch'/'eps'; ``list_converge`` on the train loss history.
            'valid_converge' : arg = dict with 'min_epoch'/'eps'; ``list_converge`` on the valid IC history.
            'tv_converge'    : both of the above (the valid check is skipped if train fails).
            'max_epoch'      : arg = epoch cap, further capped by ``ShareNames.max_epoch``;
                               True from epoch index ``min(arg, cap) - 1`` onwards.
        Raises:
            Exception for any other key.
        """
        if key == 'early_stop':
            return self.epoch_i - self.epoch_attempt_best >= arg
        if key == 'max_epoch':
            return self.epoch_i >= min(arg , ShareNames.max_epoch) - 1
        if key not in ('train_converge' , 'valid_converge' , 'tv_converge'):
            raise Exception(f'KeyError : {key}')

        min_epoch , eps = arg.get('min_epoch') , arg.get('eps')
        train_done = lambda : list_converge(self.loss_list['train'] , min_epoch , eps)
        valid_done = lambda : list_converge(self.ic_list['valid'] , min_epoch , eps)
        if key == 'train_converge':
            return train_done()
        if key == 'valid_converge':
            return valid_done()
        return (train_done() and valid_done())
    
    def save_model(self , key = 'best'):
        """
        Persist model weights through ``storage_model``.

        Args:
            key : 'best', 'swalast', 'swabest', or a list/tuple of these (saved one by one).
                'best'         : load ``self.path['best']``; when this is not the last round
                                 (``round_i < max_round - 1``) also save it to the per-round path
                                 ``self.path['rounds'][round_i]`` (the list is created on first use);
                                 finally save it to ``self.path['best']`` with ``to_disk = True``.
                'swalast'/'swabest' : average the existing checkpoints listed in
                                 ``self.path['src_model.<key>']`` with ``self.swa_model`` and save the
                                 result to ``self.path[key]`` with ``to_disk = True``.
        Raises:
            Exception : none of the source checkpoints for an SWA key exists.
        """
        assert isinstance(key , (list,tuple,str))
        if isinstance(key , (list,tuple)):
            # each element records its own 'save_model' timing; timing the wrapper call as
            # well would double-count the same work
            for k in key:
                self.save_model(k)
            return

        _start_time = time.time()
        assert key in ['best' , 'swalast' , 'swabest']
        if key == 'best':
            model_state = storage_model.load(self.path['best'])
            if self.round_i < self.max_round - 1:
                if 'rounds' not in self.path:
                    self.path['rounds'] = ['{}/{}.round.{}.pt'.format(self.Param.get('path') , self.model_date , r)
                                           for r in range(self.max_round - 1)]
                storage_model.save(model_state , self.path['rounds'][self.round_i])
            storage_model.save(model_state , self.path['best'] , to_disk = True)
        else:
            src_paths = self.path[f'src_model.{key}']
            p_exists = storage_model.valid_paths(src_paths)
            if len(p_exists) == 0:
                # .get: a missing diagnostic entry must not hide this error behind a KeyError
                raise Exception(f'Model Error : no existing source model for [{key}] , '
                                f'bestn={self.path.get("bestn")} , bestn_ic={self.path.get("bestn_ic")} , '
                                f'src_model={src_paths}')
            model = self.swa_model(p_exists)
            storage_model.save_model_state(model , self.path[key] , to_disk = True)
        self.time_recoder(_start_time , ['save_model'])
    
    def load_model(self , process , key = 'best'):
        """
        Build a fresh ``My<ShareNames.model_module>`` network and load weights into it.

        Args:
            process : 'train' or 'test'.
                'train' : load the previous round's checkpoint when ``round_i > 0``, else
                          ``self.path['transfer']`` if that key exists; if neither applies,
                          or the file does not exist, the fresh initialisation is kept.
                          Then calls ``net.training_round(self.round_i)`` if the network has it.
                'test'  : load ``self.path[key]``.
            key     : checkpoint key used in 'test' mode.
        Returns:
            the network after ``cuda`` (moved to DEVICE).
        """
        assert process in ['train' , 'test']
        _start_time = time.time()
        net = globals()[f'My{ShareNames.model_module}'](**self.Param)
        if process == 'train':
            if self.round_i > 0:
                model_path = self.path['rounds'][self.round_i-1]
            elif 'transfer' in self.path:
                model_path = self.path['transfer']
            else:
                # no warm-start source. None rather than an int: os.path.exists treats
                # ints as file descriptors, so a numeric sentinel could test True
                model_path = None
            if model_path is not None and os.path.exists(model_path):
                net = storage_model.load_model_state(net , model_path , from_disk = True)
            if hasattr(net , 'training_round'):
                net.training_round(self.round_i)
        else:
            net = storage_model.load_model_state(net , self.path[key] , from_disk = True)
        net = cuda(net)
        self.time_recoder(_start_time , ['load_model'])
        return net
    
    def swa_model(self , model_path_list = None):
        """
        Build a weight-averaged network from a list of checkpoints.

        Each path is loaded into one scratch ``My<ShareNames.model_module>`` network,
        and that network is folded into a ``torch.optim.swa_utils.AveragedModel``. The
        averaged model is moved with ``cuda`` and ``update_bn`` is run over
        ``self.data.dataloaders['train']``.

        Args:
            model_path_list : checkpoint paths to average (None means an empty list).
        Returns:
            the averaged inner network (``swa_net.module``).
        """
        # None instead of a shared mutable [] default
        if model_path_list is None:
            model_path_list = []
        net = globals()[f'My{ShareNames.model_module}'](**self.Param)
        swa_net = AveragedModel(net)
        for p in model_path_list:
            swa_net.update_parameters(storage_model.load_model_state(net , p))
        swa_net = cuda(swa_net)
        update_bn(self.data.dataloaders['train'] , swa_net)
        return swa_net.module
    
    def load_optimizer(self , new_opt_kwargs = None , new_lr_kwargs = None):
        """
        Create the optimizer for ``self.net`` from ``ShareNames.train_params['trainer']``.

        Args:
            new_opt_kwargs : optional overrides, shallow-merged into a deep copy of the 'optimizer' config.
            new_lr_kwargs  : optional overrides, shallow-merged into a deep copy of the 'learn_rate' config.

        The base learning rate is ``base * ratio['attempt'][a] * ratio['round'][r]``. For
        ``attempt_i``/``round_i`` >= 0, ``a``/``r`` is that index, or the last entry when the
        ratio list is shorter. If ``self.path`` has a 'transfer' key, parameters with
        'encoder' among the first three dotted name parts form a second group with
        lr scaled by ``ratio['transfer']``. Their weights with more than one dim are
        re-initialised in place with ``nn.init.xavier_uniform_``.

        Side effect: stores the parameter groups in ``self.net_param_gourps``.
        Returns:
            ``torch.optim.Adam`` or ``torch.optim.SGD`` built with ``opt_kwargs['param']``.
        Raises:
            KeyError : unsupported optimizer name.
        """
        trainer_cfg = ShareNames.train_params['trainer']

        opt_kwargs = trainer_cfg['optimizer']
        if new_opt_kwargs is not None:
            opt_kwargs = deepcopy(opt_kwargs)
            opt_kwargs.update(new_opt_kwargs)

        lr_kwargs = trainer_cfg['learn_rate']
        if new_lr_kwargs is not None:
            lr_kwargs = deepcopy(lr_kwargs)
            lr_kwargs.update(new_lr_kwargs)

        # value at idx, or the last value when idx runs past the end of seq
        pick = lambda seq , idx : seq[:idx + 1][-1]
        ratio = lr_kwargs['ratio']
        base_lr = lr_kwargs['base'] * pick(ratio['attempt'] , self.attempt_i) * pick(ratio['round'] , self.round_i)

        if 'transfer' in self.path:
            p_enc , p_dec = [] , []
            for name , param in self.net.named_parameters():
                if 'encoder' in name.split('.')[:3]:
                    p_enc.append(param if param.dim() <= 1 else nn.init.xavier_uniform_(param))
                else:
                    p_dec.append(param)
            enc_lr = base_lr * ratio['transfer']
            self.net_param_gourps = [{'params': p_dec , 'lr': base_lr , 'lr_param' : base_lr},
                                     {'params': p_enc , 'lr': enc_lr , 'lr_param': enc_lr}]
        else:
            self.net_param_gourps = [{'params': list(self.net.parameters()) , 'lr' : base_lr , 'lr_param' : base_lr}]

        opt_name = opt_kwargs['name']
        if opt_name == 'Adam':
            opt_class = torch.optim.Adam
        elif opt_name == 'SGD':
            opt_class = torch.optim.SGD
        else:
            raise KeyError(opt_name)
        return opt_class(self.net_param_gourps , **opt_kwargs['param'])
    
    def load_scheduler(self , new_shd_kwargs = None):
        """
        Create the learning-rate scheduler for ``self.optimizer``.

        Args:
            new_shd_kwargs : optional overrides, shallow-merged into a deep copy of
                ``ShareNames.train_params['trainer']['scheduler']``.

        Supported ``name`` values (``param`` is passed through as keyword arguments):
            'cos'   : ``lr_cosine_scheduler`` (project helper)
            'step'  : ``torch.optim.lr_scheduler.StepLR``
            'cycle' : ``torch.optim.lr_scheduler.CyclicLR`` with mode 'triangular2',
                      ``cycle_momentum=False`` and ``max_lr`` taken from each param group's 'lr_param'.
        Raises:
            ValueError : unknown scheduler name. Previously this surfaced as an
                         UnboundLocalError on ``scheduler``.
        """
        if new_shd_kwargs is None:
            shd_kwargs = ShareNames.train_params['trainer']['scheduler']
        else:
            shd_kwargs = deepcopy(ShareNames.train_params['trainer']['scheduler'])
            shd_kwargs.update(new_shd_kwargs)

        name = shd_kwargs['name']
        if name == 'cos':
            return lr_cosine_scheduler(self.optimizer, **shd_kwargs['param'])
        if name == 'step':
            return torch.optim.lr_scheduler.StepLR(self.optimizer, **shd_kwargs['param'])
        if name == 'cycle':
            max_lr = [pg['lr_param'] for pg in self.optimizer.param_groups]
            return torch.optim.lr_scheduler.CyclicLR(self.optimizer, max_lr=max_lr, cycle_momentum=False,
                                                     mode='triangular2', **shd_kwargs['param'])
        raise ValueError(f'Unknown scheduler name : {name}')
    
    def reset_scheduler(self):
        """
        Periodically restore the learning rate and rebuild the scheduler.

        Config: ``ShareNames.train_params['trainer']['learn_rate']['reset']`` with keys
            num_reset     : maximum number of resets (<= 0 disables resetting),
            trigger       : first reset when ``epoch_i + 1 == trigger``, then every ``trigger`` epochs,
            speedup2x     : if truthy, the reset interval and the scheduler's step/stage sizes are halved,
            recover_level : on reset each group's lr becomes ``lr_param * recover_level``.

        When a reset fires, this mutates ``self.optimizer.param_groups``, replaces
        ``self.scheduler`` via ``self.load_scheduler`` and calls ``self.printer('reset_learn_rate')``.
        """
        rst_kwargs = ShareNames.train_params['trainer']['learn_rate']['reset']
        epoch_num = self.epoch_i + 1
        trigger = rst_kwargs['trigger']
        if rst_kwargs['num_reset'] <= 0 or epoch_num < trigger: return

        # clamp to >= 1: trigger // 2 (or a trigger of 0) would be a zero interval and
        # the modulo / floor division below would raise ZeroDivisionError
        trigger_intvl = max(1 , trigger // 2 if rst_kwargs['speedup2x'] else trigger)
        if (epoch_num - trigger) % trigger_intvl != 0: return

        trigger_times = ((epoch_num - trigger) // trigger_intvl) + 1
        if trigger_times > rst_kwargs['num_reset']: return

        # confirm reset : change back optimizer learn rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = param_group['lr_param'] * rst_kwargs['recover_level']

        # confirm reset : reassign scheduler
        if rst_kwargs['speedup2x']:
            shd_kwargs = deepcopy(ShareNames.train_params['trainer']['scheduler'])
            shd_param = shd_kwargs['param']
            for k in ['step_size' , 'warmup_stage' , 'anneal_stage' , 'step_size_up' , 'step_size_down']:
                if k not in shd_param: continue
                # keep positive sizes >= 1 (a step size of 0 breaks StepLR / CyclicLR);
                # zero-sized stages stay zero
                shd_param[k] = max(1 , shd_param[k] // 2) if shd_param[k] >= 1 else shd_param[k] // 2
        else:
            shd_kwargs = None
        self.scheduler = self.load_scheduler(shd_kwargs)
        self.printer('reset_learn_rate')
        
    def load_multiloss(self):
        """
        Build the multi-task loss combiner, if the network has several outputs.

        Returns:
            a ``multiloss_calculator`` configured from ``ShareNames.train_params['multitask']``
            (its 'type' and the matching entry of 'param_dict') when
            ``self.Param['num_output'] > 1``; otherwise None.
        """
        n_output = self.Param['num_output']
        if n_output <= 1:
            return None
        mt_cfg = ShareNames.train_params['multitask']
        calculator = multiloss_calculator(multi_type = mt_cfg['type'])
        calculator.reset_multi_type(n_output , **mt_cfg['param_dict'][calculator.multi_type])
        return calculator

    def time_recoder(self , start_time , keys , init_length = 100):
        """
        Record the time elapsed since ``start_time`` under a key. Does nothing unless TIME_RECODER is truthy.

        Args:
            start_time  : a ``time.time()`` value taken when the timed section started.
            keys        : the key as a string, or a list/tuple of parts joined with '/'.
                          (A plain string previously raised UnboundLocalError.)
            init_length : growth step of each key's duration buffer.

        ``self.process_time[key]`` holds {'value': np.ndarray of durations,
        'index': last filled slot, 'length': current capacity}.
        """
        if not TIME_RECODER: return
        k = '/'.join(keys) if isinstance(keys , (list , tuple)) else keys
        d = self.process_time.get(k)
        if d is None:
            d = self.process_time[k] = {
                'value' : np.zeros(init_length) ,
                'index' : -1 ,
                'length' : init_length ,
            }
        d['index'] += 1
        if d['index'] >= d['length']:
            # buffer full: grow it by another init_length slots
            d['value'] = np.append(d['value'] , np.zeros(init_length))
            d['length'] += init_length
        d['value'][d['index']] = time.time() - start_time
    
    def print_time_recorder(self):
        """
        Print a summary table of the timings kept in ``self.process_time``. Does nothing unless TIME_RECODER is truthy.

        Columns: keys, num_calls, total_time, avg_time. Rows are sorted by total_time, descending.
        """
        if not TIME_RECODER: return
        keys = list(self.process_time.keys())
        records = [self.process_time[k] for k in keys]
        tb = pd.DataFrame({
            'keys'       : keys ,
            'num_calls'  : [rec['index'] + 1 for rec in records] ,
            'total_time' : [rec['value'].sum() for rec in records] ,
        })
        tb['avg_time'] = tb['total_time'] / tb['num_calls']
        print(tb.sort_values(by=['total_time'] , ascending=False))
                
def cuda(x):
    """
    Move ``x`` to DEVICE with ``.to(DEVICE)``.

    Lists and tuples are handled recursively, and the container type is preserved.
    """
    if not isinstance(x , (list , tuple)):
        return x.to(DEVICE)
    return type(x)(cuda(item) for item in x)
    
class FilteredIterator:
    """
    Iterator over the items of ``iterable`` for which ``condition`` holds.

    ``condition`` is either a predicate called on each item, or an iterable of
    truth values consumed in lockstep with ``iterable`` (one value per item).
    Iteration stops as soon as either source is exhausted.
    """
    def __init__(self, iterable, condition):
        # iter() lets plain iterables (lists, tuples, strings) be passed; previously
        # next() was called on them directly and raised TypeError. iter() on an
        # iterator returns the iterator itself, so existing callers are unaffected.
        self.iterable = iter(iterable)
        self.condition = condition if callable(condition) else iter(condition)

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            item = next(self.iterable)
            cond = self.condition(item) if callable(self.condition) else next(self.condition)
            if cond: return item

def loss_function(key):
    """
    Return the loss callable for ``key``, built from the module-level function of that name.

    For 'mse' the wrapped function's value is returned unchanged. For 'pearson' and
    'ccc' the wrapper returns ``torch.exp(-v)``; exp(-v) decreases in v, so
    minimising this loss maximises the underlying score.
    """
    from functools import wraps
    assert key in ('mse' , 'pearson' , 'ccc') , key
    func = globals()[key]

    @wraps(func)
    def wrapper(*args, **kwargs):
        v = func(*args, **kwargs)
        return v if key == 'mse' else torch.exp(-v)
    return wrapper

def metric_function(key):
    """
    Return the metric callable for ``key``, built from the module-level function of that name.

    The wrapper evaluates the function under ``torch.no_grad()``. For 'mse' it
    negates the value, presumably so that larger is better for every metric.
    """
    from functools import wraps
    assert key in ('mse' , 'pearson' , 'ccc' , 'spearman') , key
    func = globals()[key]

    @wraps(func)
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            v = func(*args, **kwargs)
        return -v if key == 'mse' else v
    return wrapper
    
def penalty_function(key):
    """
    Return the penalty callable named ``key``.

    Every penalty takes keyword arguments only and ignores the ones it does not use.
        'none'                 : always returns 0.
        'hidden_orthogonality' : reads ``hidden``, a tensor or a list/tuple of tensors
                                 concatenated on the last dim. It returns the sum of squared
                                 pairwise correlations between the columns (assumes a 2-D
                                 samples x features tensor — TODO confirm with callers).
                                 NaN correlations count as 0.
    Raises:
        KeyError : unknown penalty key.
    """
    def _cat_tensor(x):
        return torch.cat(x, dim=-1) if isinstance(x, (tuple, list)) else x

    def _none(**kwargs):
        return 0.

    def _hidden_orthogonality(**kwargs):
        hidden = _cat_tensor(kwargs.get('hidden'))
        # corrcoef of hidden.T = correlation between hidden's columns; triu(1) keeps each pair once
        return hidden.T.corrcoef().triu(1).nan_to_num().square().sum()

    # explicit dispatch instead of locals(), so only real penalties can be selected
    penalties = {'none' : _none , 'hidden_orthogonality' : _hidden_orthogonality}
    if key not in penalties:
        raise KeyError(f'unknown penalty : {key}')
    return penalties[key]

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument("--process",     type=int, default=1)
    parser.add_argument("--rawname",     type=int, default=1)
    parser.add_argument("--resume",      type=int, default=0)
    parser.add_argument("--anchoring",   type=int, default=0)
    # parse_known_args: honour real command-line flags when run as a script, while
    # ignoring the extra '-f <kernel.json>' argument injected by a Jupyter kernel
    # (parse_args([]) avoided that error but silently dropped every CLI flag)
    ShareNames , _unknown_args = parser.parse_known_args()

    Controller = model_controller()
    Controller.main_process()
    Controller.print_time_recorder()


[1m[37m[41m24-02-03 23:33:34|MOD:30757165    |[0m: [1m[31mProcess Queue : Data + Train[0m
[1m[37m[41m24-02-03 23:33:34|MOD:30757165    |[0m: [1m[31mDirectories of [GeneralRNN_day_SHORTTEST] deletion Confirmed![0m
[1m[37m[41m24-02-03 23:33:34|MOD:30757165    |[0m: [1m[31mStart Process [Load Data]![0m
usage: ipykernel_launcher.py [-h] [--process PROCESS] [--rawname RAWNAME]
                             [--resume RESUME] [--anchoring ANCHORING]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/mengkjin/.local/share/jupyter/runtime/kernel-v2-7625epp6ZNGIKf8Q.json


--Process Queue : Data + Train
--Start Training New!
--Model_name is set to LSTM_day_SHORTTEST!


[1m[37m[41m24-02-03 23:33:50|MOD:30757165    |[0m: [1m[31mFinish Process [Load Data]! Cost 15.5Secs[0m
[1m[37m[41m24-02-03 23:33:50|MOD:30757165    |[0m: [1m[31mStart Process [Train Model]![0m
[1m[37m[45m24-02-03 23:33:50|MOD:30757165    |[0m: [1m[35mStart Training Models![0m
[32mGeneralRNN_day_SHORTTEST #0 @20170103 LoadData Cost    1.4Secs[0m
[32mFirstBite Ep#  0 : loss  0.99814, train 0.00190, valid 0.01278, max 0.0128, best 0.0128, lr1.0e-07[0m
[32mFirstBite Ep#  5 : loss  0.86911, train 0.14036, valid 0.16530, max 0.1653, best 0.1653, lr3.8e-03[0m
[1m[37m[44m24-02-03 23:34:10|MOD:30757165    |[0m: [1m[34mGeneralRNN_day_SHORTTEST #0 @20170103|Round 0 FirstBite Ep#  9 Max Epoch|Train 0.1749 Valid 0.1996 BestVal 0.1996|Cost  0.3Min,  1.8Sec/Ep[0m
[32mGeneralRNN_day_SHORTTEST #0 @20170704 LoadData Cost    1.4Secs[0m
[32mFirstBite Ep#  0 : loss  1.00161, train-0.00156, valid-0.01343, max-0.0134, best-0.0134, lr1.0e-07[0m
[32mFirstBite Ep#  5 : los