# Training 

Notebooks to train models  
  
Examples are given for the EUV PCNNs of the paper  
  
Results are stored in PATH_RES/CONTINUING_FOLDER   
(or PATH_RES/NEW_FOLDER_NAME when creating a new result folder)  

The structure of a NEW_FOLDER_NAME is   
automatically set up by the function utilsTraining.setUpResultFolder,   
called here at the beginning of the cell 'Training'  
The content of those result folders is described in the doc config.py, below 'PATH_RES'  

# Config

In [None]:
COLAB = False

if COLAB : 
  configSetup = {
      'COLAB'           : 'True',
      'PATH_ROOT_DRIVE' : '/content/drive/MyDrive/Projects/Forecast',
      'PATH_ROOT_LOCAL' : '/content/session',
      'PATH_SUNDL'      : '/content/sundl',
      'PATH_PROJECT'    : '/content/sundl/notebooks/flare_limits_pcnn'
  }
  !git clone https://github.com/gfrancisco20/sundl.git
  import sys
  import re
  sys.path.append(configSetup['PATH_SUNDL'])
  sys.path.append(configSetup['PATH_PROJECT'])
  configFile = f'{configSetup["PATH_PROJECT"]}/config.py'
  with open(configFile, 'r') as file:
    content = file.read()
  for constant in configSetup.keys():
    content = re.sub(re.compile(f'{constant} = .*'), f'{constant} = \'{configSetup[constant]}\'', content)
  with open(configFile, 'w') as file:
    file.write(content)
   
from config import *
from sundl.utils.colab import mountDrive
if COLAB:
  # mouting drive content in session on colab
  mountDrive()

In [8]:
from sunpy.net import Fido, attrs as a

# Event type and time range of the query
event_type = "FL"
tstart = "2010/04/28"
tend = "2023/04/29"

# All constraints of the search, passed to Fido in this order
search_attrs = (
    a.Time(tstart, tend),
    a.hek.EventType(event_type),
    a.hek.FL.GOESCls > "C1.0",
    a.hek.OBS.Observatory == "GOES",
)
result = Fido.search(*search_attrs)

# Here we only show two columns, since more than 100 columns are normally returned.
print(result.show("hpc_bbox", "refs"))

# It's also possible to access the HEK results from the
# `~sunpy.net.fido_factory.UnifiedResponse` by name.
hek_results = result["hek"]

# Columns kept for the flare table displayed at the end of the cell
kept_columns = ("event_starttime", "event_peaktime",
                "event_endtime", "fl_goescls", "ar_noaanum")
filtered_results = hek_results[kept_columns]

filtered_results

# Libraries

In [None]:
from pathlib import Path
import dill as pickle

import time
from tqdm import tqdm
import gc

import numpy as np
import pandas as pd

import tensorflow as tf

# Data import

In [None]:
%%time
from sundl.utils.colab import mountDrive, ressourcesSetAndCheck, drive2local
############################
# SETUP
############################

# overwriting CLEAN_LOCAL :
CLEAN_LOCAL = False

if CLEAN_LOCAL:
  shutil.rmtree(PATH_ROOT_LOCAL)
  os.makedirs(PATH_ROOT_LOCAL)
  
# checking gpu and ram ressources
ressourcesSetAndCheck(MIXED_PREC)

############################
# DATA IMPORT
############################

FILES2TRANSFER = {'images' : (PATH_ROOT_DRIVE_DS/'Images',        # source
                              PATH_IMAGES,                        # dest
                              ['eq_hmi_448', 'eq_193x211x94_448'] # files
                              )
                  }

drive2local(FILES2TRANSFER)

# Hyperparameters

In [None]:
from sundl.metrics.tfmetrics import *


# ---------------------------------------------------------------------------
# Labels, windows and image geometry
# ---------------------------------------------------------------------------
# labelCol : 'mpf'   -> windows's SXR-max-peak-flux
#            'toteh' -> (hourly average of sum of flares' SXR-fluence)
labelCol     = 'mpf'
windowSizesH = [24]
EPOCHS       = 1 # 25
BATCH_SIZE   = 16
IMG_SIZE     = (224, 448, 3) # (512, 1024, 3) (224, 448, 3)
PTCH_SIZE    = (112, 112, 3) # (256, 256, 3) (112, 112, 3)

# ---------------------------------------------------------------------------
# Result folders
# ---------------------------------------------------------------------------
NEW_FOLDER_NAME   = None                 # New folder in which to store results
CONTINUING_FOLDER = 'Results_Paper_PCNN' # Existing folder in which to store results

# aggregation type : 'max' for the peak-flux label, 'sum' for any other label
agg = 'max' if labelCol == 'mpf' else 'sum'

weightByClass = True

# ---------------------------------------------------------------------------
# Cross-validation and sampling
# ---------------------------------------------------------------------------
CV_K      = 5 
VAL_SPLIT = None # --> not used if CV_K not none

SAMPLE_TRAIN = None # 0.95 
SAMPLE_VAL   = None

CACHE     = True
verbose   = 0

# ---------------------------------------------------------------------------
# Model checkpointing
# ---------------------------------------------------------------------------
SAVE_MODEL   = True
save_monitor = 'val_tss3'
save_mode    = 'max'
# one value per flare class
save_thdS    = {
    'C': 0.50,
    'M': 0.25,
    'X': 0.10,
}

RECOMPUTE_DATASET = True

# ---------------------------------------------------------------------------
# Metrics (instantiated in the listed order)
# ---------------------------------------------------------------------------
thresholds = [0.5]
metrics = [
    tf.keras.metrics.BinaryAccuracy(threshold=0.5, name='acc'),
    *[Tss(threshold=thd) for thd in thresholds],
    *[Hss(threshold=thd) for thd in thresholds],
    *[Mcc(threshold=thd) for thd in thresholds],
    *[F1(threshold=thd) for thd in thresholds],
    tf.keras.metrics.Precision(class_id = 1, name = 'precision'),
    tf.keras.metrics.Recall(class_id = 1, name = 'recall'),
    *[TP(threshold=thd) for thd in thresholds],
    *[FN(threshold=thd) for thd in thresholds],
    *[TN(threshold=thd) for thd in thresholds],
    *[FP(threshold=thd) for thd in thresholds],
    tf.keras.metrics.AUC(curve='ROC', name='auc_roc'),
    tf.keras.metrics.AUC(curve='PR', name='auc_pr'),
]
   
# ---------------------------------------------------------------------------
# different weights and penalisation strategies
# (one weight per activity class, keyed by strategy name)
# ---------------------------------------------------------------------------
WEIGHT_BY_CLASS = True   
weightCollection = {
    'EquiC'    : {'quiet': 0.25, 'B':0.25, 'C':0.167, 'M':0.167, 'X': 0.166},
    'EquiCnat' : {'quiet': 0.46, 'B':0.54, 'C':0.72, 'M':0.26, 'X': 0.03},
    'EquiM'    : {'quiet': 0.166, 'B':0.167, 'C':0.167, 'M':0.25, 'X': 0.25},
    'EquiMnat' : {'quiet': 0.28, 'B':0.32, 'C':0.40, 'M':0.91, 'X': 0.09},
    'ProgPos'  : {'quiet': 0.05, 'B':0.05, 'C':0.10, 'M':0.30, 'X': 0.50},
    'LowBC'    : {'quiet': 0.4, 'B':0.2, 'C':0.1, 'M':0.8, 'X': 0.8},
    'LowC'     : {'quiet': 0.4, 'B':0.4, 'C':0.1, 'M':0.8, 'X': 0.8},
    'LowC2'    : {'quiet': 0.2, 'B':0.2, 'C':0.1, 'M':0.8, 'X': 0.8},
}

# Models Definition

In [None]:
from notebooks.flare_limits_pcnn.utilsTraining import ModelInstantier2
from sundl.models.blueprints import build_pretrained_PatchCNN
from sundl.dataloader.sdocml import builDS_image_feature
from tensorflow.keras.optimizers import Adam, AdamW
from tensorflow.keras.losses import BinaryCrossentropy

num_classes = 2

# Dataset common parameters (the same dict object is handed to every model below)
ds_params = {
    'labelCol'    : labelCol,
    'num_classes' : num_classes,
    'img_size'    : IMG_SIZE,
}

# Models common parameters
tfModel = tf.keras.applications.efficientnet_v2.EfficientNetV2S
core_params = {
    'tfModel'             : tfModel,
    'pretainedWeight'     : True,
    'unfreeze_top_N'      : 'all',
    'num_classes'         : num_classes, # no use here
    'img_size'            : IMG_SIZE,
    'patches_size'        : PTCH_SIZE,
    'regression'          : False,
    'metrics'             : metrics,
    'includeInterPatches' : False,
    'loss'                : BinaryCrossentropy(label_smoothing = 0, name = 'loss'),
}


def _build_pcnn(flare_class):
  """Return the ModelInstantier2 of the patch-CNN for `flare_class` ('C' or 'M').

  Both models share `core_params` and `ds_params`; each call creates its own
  AdamW optimizer and differs only by its name and `cls` argument.
  """
  patch_params = {
      'patche_output_type' : 'pre_pred',
      'meth_patche_agg'    : agg,
      'shared_patcher'     : 'all',
      'optimizer'          : AdamW(learning_rate = 1e-5, #  amsgrad = True,
                                   weight_decay  = 1e-4),
  }
  return ModelInstantier2(
      buildModelFunction = build_pretrained_PatchCNN,
      buildModelParams = dict(**core_params, **patch_params),
      buildDsFunction = builDS_image_feature,
      buildDsParams = ds_params,
      name = f'{flare_class}+_{labelCol}',
      cls = flare_class,
      weightStrategy = 'ProgPos'
  )


# Models definition
# We just give as an example the final EUV models of the paper
PCNN_C = _build_pcnn('C')
PCNN_M = _build_pcnn('M')

# Training

In [None]:
from notebooks.flare_limits_pcnn.utilsTraining import setUpResultFolder, conditionalHyperParameters, trainConstantModel, printTrainingResults, saveTrainingResults
from sundl.utils.data import read_Dataframe_With_Dates, loadMinMaxDates
from sundl.utils.flare.windows import windowHistoryFromFlList

# Models to train : one (model instantiater, image channels, window size in hours)
# triplet per model and per window size
models = [(PCNN_C, ['0193x0211x0094'], h) for h in windowSizesH] + \
         [(PCNN_M, ['0193x0211x0094'], h) for h in windowSizesH] 

# Result folder set-up; `log` holds one row per model with its training status
log, resDir, modelDir, mtcDict = setUpResultFolder(
    models = models, 
    pathRes = PATH_RES,
    metrics = metrics,
    continuingFolder = CONTINUING_FOLDER, 
    newFolder = NEW_FOLDER_NAME,
    imgSize = IMG_SIZE,
    cv_K = CV_K,
    saveModel = SAVE_MODEL
    )

print('\nINITIAL STATUS : ')
display(log)
print('')
res = {}
best = None
eval = None # NOTE(review): shadows the builtin `eval`; kept in case other cells read it
bestCVCrossEpoch = None
dsTrain = None
dsVal = None
ct=0
verbose = 1
ct_dsBuilds = -1
minDate, maxDate = loadMinMaxDates(PATH_IMAGES)
flCatalog = read_Dataframe_With_Dates(PATH_FLCATALOG)
print('minDate : ', minDate)
print('maxDate : ', maxDate)

for modelInstantiater, channels, h in tqdm(models):
  ct_dsBuilds+=1
  
  # Hyperparameters depending on the model's class and on the window size
  save_thd, labelCol, binCls, classWeights, classTresholds, encoder = conditionalHyperParameters(modelInstantiater, h, save_thdS, weightCollection)
  modelInstantiater.buildDsParams['labelEncoder'] = encoder
  modelInstantiater.buildDsParams['classTresholds'] = classTresholds
  
  # Flare history per window : read from disk when available, recomputed otherwise
  if F_PATH_WINDOWS('mpf', h).exists():
    dfFlareHistory = read_Dataframe_With_Dates(F_PATH_WINDOWS('mpf', h))
  else:
    dfFlareHistory = windowHistoryFromFlList(flCatalog, window_h = h, timeRes_h = 2, minDate = minDate, maxDate = maxDate)
  
  # Cross-validation folds : the first CV_K (train, val) pairs of the pickled list
  # NOTE(review): pickle must only be used on trusted, project-generated files
  CV_FLD_PTH  = F_PATH_FOLDS(labelCol, h)
  with open(CV_FLD_PTH, 'rb') as f1:
    dfFoldsTrainVal = pickle.load(f1)[0:CV_K]
    
  full_name_comb = modelInstantiater.fullNameFunc(channels,h)
  if log.loc[full_name_comb]['status'] > 0:
    print(f'\n\n-----------------------------\n{full_name_comb} already successfuly trained\n')
  else:
    # status is -1 while the model is being trained, 1 once all folds are done (see below)
    log.loc[full_name_comb, 'status'] = -1
    log.to_csv(resDir + '/log.csv')
  
    model = None
    if RECOMPUTE_DATASET:
      dsTrain = None
      dsVal = None
  
    #===================================================
    # CROSS VALIDATION LOOP
    #===================================================
    duration = time.time()
    res[full_name_comb] = []
    # `kf` is the fold index : it names the fold's result file and model folder
    for kf, (df_train, df_val) in enumerate(tqdm(dfFoldsTrainVal,disable = False)):#not verbose):
      foldCsv = resDir+f'/training_folds/training_{full_name_comb}_fd{kf:0>3d}.csv'
      print(f'\n\n-----------------------------\nModel : {full_name_comb}')
      
      # Fold already trained in a previous run : reload its history and move on.
      # Nothing is recomputed or re-saved for it.
      if Path(foldCsv).exists():
        res[full_name_comb].append(pd.read_csv(foldCsv).set_index('epoch'))
        print(f'FOLD #{kf} ALREADY TRAINED')
        continue
      
      print(f'FOLD #{kf}')
      
      # MEMORY CLEANING
      tf.keras.backend.clear_session()
      tf.compat.v1.reset_default_graph()
      if model is not None: del model
      if RECOMPUTE_DATASET:
        if dsTrain is not None: del dsTrain
        if dsVal is not None: del dsVal
      gc.collect()
      
      # FOLDER FOR MODEL
      # defined even when SAVE_MODEL is False since it is passed to trainConstantModel
      modelDirSub = None
      if SAVE_MODEL:
        modelDirSub = modelDir + f'/{full_name_comb}'
        if CV_K is not None:
          modelDirSub = modelDirSub + f'_fd{kf:0>3d}'
          
      # DATASETS INSTANTIATION 
      dfSamples_train = df_train.copy()
      dfSamples_val = df_val.copy()
      if SAMPLE_TRAIN is not None:
        dfSamples_train = dfSamples_train.sample(frac = SAMPLE_TRAIN, random_state=49)
      if SAMPLE_VAL is not None:
        dfSamples_val = dfSamples_val.sample(frac = SAMPLE_VAL, random_state=49)
      if ct_dsBuilds==0 or RECOMPUTE_DATASET:
        pathDir = PATH_IMAGES if channels is not None else None
        dsTrain, _, missing_file_regexp, dfSamples_train_corr = modelInstantiater.build_DS(
            pathDir    = pathDir,
            channels   = channels,
            dfTimeseries = dfFlareHistory.copy(), 
            samples    = dfSamples_train.copy(), 
            batch_size = BATCH_SIZE,
            epochs     = EPOCHS,
            cache      = CACHE,
            shuffle    = True,
            weightByClass = WEIGHT_BY_CLASS,
            classWeights = classWeights,
        )
        print('')
        dsVal, _, missing_file_regexp_val, dfSamples_val_corr  = modelInstantiater.build_DS(
            pathDir    = pathDir,
            channels   = channels,
            dfTimeseries = dfFlareHistory.copy(),
            samples    = dfSamples_val.copy(),
            batch_size = BATCH_SIZE,
            epochs     = EPOCHS,
            cache      = CACHE,
            shuffle    = True,
            weightByClass = False,
            classWeights = None,
            typeDs = 'val'
        )
        print(f'{len(missing_file_regexp)} incomplete training dates')
        print(f'{len(missing_file_regexp_val)} incomplete val dates')
        
      # MODEL INSTANTIATION
      model = modelInstantiater()
      try:
        print(f'\nMODEL PARAMETERS #: {model.count_params()/1e6:.2f}M')
        trainable_params = tf.reduce_sum([tf.reduce_prod(p.shape) for p in model.trainable_variables])
        print("of which trainable #:", trainable_params)
      except Exception as err:
        # purely informative report : never block the training, but say why it failed
        print(f'Could not report the number of parameters : {err}')
      
      # CALLBACKS
      callbacks = []
      if SAVE_MODEL:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            modelDirSub,
            save_best_only = True,
            save_weights_only=False,
            monitor = save_monitor,
            verbose = 1,
            mode = save_mode,
            initial_value_threshold = save_thd)
        )
      
      # TRAINING
      if modelInstantiater.savedPredictionModel:
        # models where input = output (e.g. persistant models)
        historyData = trainConstantModel(dsTrain, dsVal, model, modelInstantiater, EPOCHS, weightByClass, SAVE_MODEL, modelDirSub)
      else:
        history = model.fit(dsTrain,#.take(1),
                            epochs=EPOCHS,
                            validation_data = dsVal,#.take(1),
                            callbacks = callbacks,
                            verbose = 1 #verbose
                            )
        historyData = history.history
        if SAVE_MODEL:
          pathConfigModel = modelDirSub + f'_config.pkl'
          modelInstantiater.saveConfig(pathConfigModel)
      # vectorizing metric results
      historyData = {m : np.array(historyData[m]) for m in historyData.keys()}
      
      # ADDITIONAL METRIC
      historyData['far'] = 1 - historyData['precision']
      historyData['val_far'] = 1 - historyData['val_precision']
        
      # PRINTING TRAINING HISTORY
      printTrainingResults(historyData)

      # SAVING FOLD RESULTS
      # Written under the index of the fold that was just trained, i.e. the same
      # file name that the 'already trained' check above looks for.
      res[full_name_comb].append(pd.DataFrame(historyData)) 
      res[full_name_comb][-1].index.names = ['epoch']
      num_inst = len(dfSamples_train)
      res[full_name_comb][-1]['num_train_inst'] = (res[full_name_comb][-1].index + 1) * num_inst
      res[full_name_comb][-1].to_csv(foldCsv,index=True)
       
      
      # END OF CV-LOOP
      #===================================================
      
    # SAVING GENERAL RESULTS
    res, best, bestCVCrossEpoch = saveTrainingResults(resDir, res, best, bestCVCrossEpoch, full_name_comb, CV_K)
    
    duration = time.time() - duration
    log.loc[full_name_comb, 'status'] = 1
    log.loc[full_name_comb, 'duration'] = f'{duration//3600:0>2.0f}h {duration//60%60:0>2.0f}m {duration%60:0>2.0f}s'
    log.to_csv(resDir + '/log.csv')
        