# U-Trans + ViT  
## California STEAD Example (Event magnitude)

This example demonstrates training and testing an **earthquake magnitude model** using:

- **U-Trans foundation backbone**
- **ViT regression model**
- **California subset of the STEAD dataset**

---

#  Dataset

## STEAD (Stanford Earthquake Dataset)

**STEAD** (Stanford Earthquake Dataset) is a large-scale seismic waveform dataset containing:

- 3-component waveform recordings  
- Event metadata (including magnitude)

In this example:

- Only the **California region subset** is used.
- Waveform length is fixed to **6000 samples**.
- Input shape is **(6000, 3)** representing three-component seismic data.
- Ground-truth event magnitude is used as regression targets.

---

#  Purpose of This Example

This pipeline demonstrates how to:

- Train a magnitude regression model on California STEAD traces  
- Predict earthquake magnitude  
- Evaluate magnitude prediction performance  
- Test the trained model on a held-out test set  

This setup reproduces the California STEAD experiment using the **U-Trans + ViT architecture** for event magnitude estimation.

---


In [None]:
# %%
from __future__ import print_function, division

import os

import tensorflow.compat.v1 as tf1
tf1.disable_v2_behavior()

import tensorflow as tf

# Restrict TensorFlow to the first visible GPU when one exists; on a CPU-only
# machine the list is empty and nothing needs to be selected.
physical_devices = tf.config.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[2], True)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if physical_devices:
    # Guarded: indexing an empty device list raised IndexError without a GPU.
    tf.config.experimental.set_visible_devices(physical_devices[0], 'GPU')

os.environ['KERAS_BACKEND'] = 'tensorflow'

In [None]:



# %%
import sys
import time
import math
import json
import pickle
import random
import shutil
import logging
import warnings
import contextlib
import multiprocessing
import datetime
import csv
import numpy as np
import pandas as pd
import h5py

from tqdm import tqdm
from glob import glob
from os.path import join
from datetime import datetime, timedelta
from scipy import signal
from matplotlib.lines import Line2D

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import tensorflow
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers, callbacks
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler

from tensorflow.keras.layers import (
    Input, Dense, Dropout, Flatten, Reshape, Activation,
    Conv1D, MaxPooling1D, UpSampling1D, BatchNormalization,
    Add, concatenate, DepthwiseConv1D
)

warnings.filterwarnings("ignore")
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

import faulthandler
faulthandler.enable()

# External utilities (as in your original code)
from EqT_utils_Mag import DataGenerator, _lr_schedule

def recall(y_true, y_pred):
    """Recall metric built from Keras backend ops.

    Both the product ``y_true * y_pred`` and ``y_true`` are clipped to
    [0, 1] and rounded before being summed; the ratio of the two sums is
    returned, with ``K.epsilon()`` added to the denominator to avoid a
    division by zero.
    """
    matched = K.round(K.clip(y_true * y_pred, 0, 1))
    n_matched = K.sum(matched)
    actual = K.round(K.clip(y_true, 0, 1))
    n_actual = K.sum(actual)
    return n_matched / (n_actual + K.epsilon())

def precision(y_true, y_pred):
    """Precision metric built from Keras backend ops.

    The product ``y_true * y_pred`` and ``y_pred`` are clipped to [0, 1]
    and rounded before being summed; the ratio of the two sums is returned,
    with ``K.epsilon()`` added to the denominator to avoid a division by
    zero.
    """
    matched = K.round(K.clip(y_true * y_pred, 0, 1))
    n_matched = K.sum(matched)
    flagged = K.round(K.clip(y_pred, 0, 1))
    n_flagged = K.sum(flagged)
    return n_matched / (n_flagged + K.epsilon())

def f1(y_true, y_pred):
    """F1 score: twice the product of precision and recall over their sum.

    ``K.epsilon()`` is added to the denominator so the result stays finite
    when both precision and recall are zero.
    """
    prec = precision(y_true, y_pred)
    rec = recall(y_true, y_pred)
    ratio = (prec * rec) / (prec + rec + K.epsilon())
    return 2 * ratio



## Foundation Model

In [None]:
# Hyper-parameters of the transformer placed at the U-Net bottleneck.
conv_layers = 1                # default number of Conv1D layers in CCTTokenizer1
positional_emb = False         # when True, create_cct_model1 adds position embeddings
projection_dim = 80            # tokenizer channel count and attention key_dim
num_heads = 4                  # attention heads per transformer block
transformer_layers = 4         # number of transformer blocks
transformer_units = [projection_dim] * 2   # hidden widths of the per-block MLP
stochastic_depth_rate = 0.1    # upper end of the per-block drop-probability ramp
num_classes = 1


def convF1(inpt, D1, fil_ord, Dr):
    """Residual 1-D convolution block followed by a projection to ``D1`` channels.

    Parameters
    ----------
    inpt: tensor
        Input tensor; its last dimension is taken as the channel count.
    D1: int
        Number of output channels of the final convolution.
    fil_ord: int
        Kernel size used by all three convolutions.
    Dr: float
        Rate given to the final Dropout layer.

    Returns
    -------
    encode: tensor
        Output of the final Dropout layer.

    """
    # The two residual convolutions keep the input channel count so that the
    # result can be added back onto ``inpt``.
    filters = int(inpt.shape[-1])

    pre = Conv1D(filters, fil_ord, strides=(1), padding='same', kernel_initializer='he_normal')(inpt)
    pre = BatchNormalization()(pre)
    pre = Activation('linear')(pre)

    inf = Conv1D(filters, fil_ord, strides=(1), padding='same', kernel_initializer='he_normal')(pre)
    inf = BatchNormalization()(inf)
    inf = Activation('linear')(inf)
    inf = Add()([inf, inpt])  # residual connection

    inf1 = Conv1D(D1, fil_ord, strides=(1), padding='same', kernel_initializer='he_normal')(inf)
    inf1 = BatchNormalization()(inf1)
    inf1 = Activation('linear')(inf1)
    # The layer is called with training=False, i.e. always in inference mode.
    encode = Dropout(Dr)(inf1, training=False)

    return encode

def mlp(x, hidden_units, dropout_rate):
    """Stack one GELU Dense layer plus Dropout per entry of ``hidden_units``.

    Returns the output tensor of the last Dropout layer (or ``x`` unchanged
    when ``hidden_units`` is empty).
    """
    hidden = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(hidden)
        hidden = layers.Dropout(dropout_rate)(hidden)
    return hidden


class StochasticDepth(layers.Layer):
    """Randomly zeroes whole samples of a residual branch during training.

    Parameters
    ----------
    drop_prop: float
        Probability of dropping a sample's branch output (stored as
        ``self.drop_prob``).

    """

    def __init__(self, drop_prop, **kwargs):
        super(StochasticDepth, self).__init__(**kwargs)
        self.drop_prob = drop_prop

    def get_config(self):
        # Expose the constructor argument so the layer can be rebuilt from its
        # config, like the other custom layers in this file. float() turns a
        # NumPy scalar (e.g. from np.linspace) into a plain Python float.
        config = super(StochasticDepth, self).get_config().copy()
        config.update({
            'drop_prop': float(self.drop_prob),
        })
        return config

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_prob
            # One random draw per sample, broadcast over the remaining axes.
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            # floor() gives 1 with probability keep_prob and 0 otherwise.
            random_tensor = tf.floor(random_tensor)
            # Kept samples are rescaled by 1 / keep_prob.
            return (x / keep_prob) * random_tensor
        return x


class CCTTokenizer1(layers.Layer):
    """Convolutional tokenizer: a small stack of bias-free ReLU Conv1D layers.

    Parameters
    ----------
    kernel_size: int, default=3
        Kernel size of every Conv1D layer.
    stride: int, default=1
        Stride of every Conv1D layer.
    padding, pooling_kernel_size, pooling_stride:
        Accepted for signature compatibility; not used by this class.
    num_conv_layers: int
        Number of Conv1D layers to stack.
    num_output_channels: list of int
        Channel count per Conv1D layer; only the first ``num_conv_layers``
        entries are read. The list is never modified.
    positional_emb: bool
        Whether ``positional_embedding`` should return an embedding layer.

    """

    def __init__(
        self,
        kernel_size=3,
        stride=1,
        padding=1,
        pooling_kernel_size=3,
        pooling_stride=(1,1,1,1),
        num_conv_layers=conv_layers,
        num_output_channels=[int(projection_dim)] * 8,
        positional_emb=positional_emb,
        **kwargs,
    ):
        super(CCTTokenizer1, self).__init__(**kwargs)

        self.conv_model = tf.keras.Sequential()
        for i in range(num_conv_layers):
            self.conv_model.add(
                layers.Conv1D(
                    num_output_channels[i],
                    kernel_size,
                    stride,
                    padding="same",
                    use_bias=False,
                    activation="relu",
                    kernel_initializer="he_normal",
                )
            )

        self.positional_emb = positional_emb

    def call(self, images):
        # The convolution output is returned as-is; the previous version also
        # built a reshape op whose result was never used.
        return self.conv_model(images)

    def positional_embedding(self, image_size):
        """Return ``(embedding_layer, sequence_length)`` or ``None`` when disabled.

        The sequence length and embedding width are read from the output of a
        dummy forward pass on a tensor of shape ``(1, image_size, 1)``.
        """
        if self.positional_emb:
            # NOTE(review): the dummy input has a single channel; this looks
            # incompatible with a conv stack already built for a different
            # channel count - confirm before enabling positional embeddings.
            dummy_inputs = tf.ones((1, image_size, 1))
            dummy_outputs = self.call(dummy_inputs)
            sequence_length = int(dummy_outputs.shape[1])
            projection_dim = int(dummy_outputs.shape[-1])

            embed_layer = layers.Embedding(input_dim=sequence_length, output_dim=projection_dim)
            return embed_layer, sequence_length
        else:
            return None


def create_cct_model1(inputs):
    """Apply the convolutional tokenizer and a stack of transformer blocks.

    Parameters
    ----------
    inputs: tensor
        Sequence tensor fed to ``CCTTokenizer1``.

    Returns
    -------
    representation: tensor
        Layer-normalised output of the last transformer block.

    """
    cct_tokenizer = CCTTokenizer1()
    encoded_patches = cct_tokenizer(inputs)

    if positional_emb:
        # The embedding size is read from the tokenizer output itself. The
        # previous code referenced an undefined name (`image_size`) here and
        # raised NameError as soon as positional embeddings were enabled.
        seq_length = int(encoded_patches.shape[1])
        pos_embed = layers.Embedding(
            input_dim=seq_length, output_dim=int(encoded_patches.shape[-1])
        )
        positions = tf.range(start=0, limit=seq_length, delta=1)
        position_embeddings = pos_embed(positions)
        encoded_patches += position_embeddings

    # Drop probability grows linearly from 0 to stochastic_depth_rate with depth.
    dpr = list(np.linspace(0, stochastic_depth_rate, transformer_layers))

    for i in range(transformer_layers):
        # Pre-norm self-attention with a residual connection.
        x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)

        attention_output = StochasticDepth(dpr[i])(attention_output)
        x2 = layers.Add()([attention_output, encoded_patches])

        # Pre-norm MLP with a residual connection.
        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

        x3 = StochasticDepth(dpr[i])(x3)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
    return representation


def UNET(inputs, D1):
    """Build the 1-D U-Net with a transformer at its bottleneck.

    Parameters
    ----------
    inputs: tensor
        Input sequence tensor.
    D1: int
        Channel count of the first encoder stage; each deeper stage doubles it
        (D2 = 2*D1 ... D5 = 16*D1).

    Returns
    -------
    conv9: tensor
        Output of the last decoder stage, with D1 channels.

    Notes
    -----
    The encoder pools by 2, 2, 2, 2 and 5, and the decoder upsamples by the
    same factors in reverse order. The concatenations presumably require the
    sequence length to be divisible by 80 - TODO confirm.

    """
    D2 = int(D1 * 2)
    D3 = int(D2 * 2)
    D4 = int(D3 * 2)
    D5 = int(D4 * 2)

    # ---- Encoder: two ReLU convolutions per stage, then max pooling ----
    conv1 = Conv1D(D1, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv1D(D1, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling1D(pool_size=(2))(conv1)

    conv2 = Conv1D(D2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv1D(D2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling1D(pool_size=(2))(conv2)

    conv3 = Conv1D(D3, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv1D(D3, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling1D(pool_size=(2))(conv3)

    conv4 = Conv1D(D4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv1D(D4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    pool4 = MaxPooling1D(pool_size=(2))(conv4)

    # Extra encoder stage; this one pools by 5 instead of 2.
    conv44 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv44 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv44)
    pool44 = MaxPooling1D(pool_size=(5))(conv44)

    # ---- Bottleneck: two convolutions followed by the transformer stack ----
    conv5 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool44)
    conv5 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)

    drop5 = create_cct_model1(conv5)

    # ---- Decoder: upsample, convolve, concatenate the skip tensor, convolve twice ----
    # The first skip connection uses pool4 (the input of the conv44 stage), not conv44.
    up66 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling1D(size=(5))(drop5))
    merge66 = concatenate([pool4, up66], axis=-1)
    conv66 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge66)
    conv66 = Conv1D(D5, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv66)

    up6 = Conv1D(D4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling1D(size=(2))(conv66))
    merge6 = concatenate([conv4, up6], axis=-1)
    conv6 = Conv1D(D4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = Conv1D(D4, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    # From here on the up-convolutions use kernel size 2 instead of 3.
    up7 = Conv1D(D3, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling1D(size=(2))(conv6))
    merge7 = concatenate([conv3, up7], axis=-1)
    conv7 = Conv1D(D3, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv1D(D3, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    up8 = Conv1D(D2, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling1D(size=(2))(conv7))
    merge8 = concatenate([conv2, up8], axis=-1)
    conv8 = Conv1D(D2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv1D(D2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    up9 = Conv1D(D1, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling1D(size=(2))(conv8))
    merge9 = concatenate([conv1, up9], axis=-1)
    conv9 = Conv1D(D1, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = Conv1D(D1, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)

    return conv9


## ViT

In [None]:
# General settings.
monte_carlo_sampling = 50
drop_rate = 0.2            # Dropout rate in the convolutional front end of the ViT head
input_shape = (6000, 3)    # shape of the model input layer

# ViT head hyper-parameters (names carry an "X" suffix to keep them apart
# from the foundation-model settings).
num_classes = 1
input_shapeX = (75, 32)
image_sizeX = 75           # sequence length that is cut into patches
patch_sizeX = 5            # number of consecutive steps per patch
num_patchesX = image_sizeX // patch_sizeX
projection_dimX = 100      # patch embedding width and attention key_dim
num_headsX = 4
# Hidden widths of the MLP inside each transformer block.
transformer_unitsX = [2 * projection_dimX, projection_dimX]
transformer_layersX = 4

def mlpX(x, hidden_units, dropout_rate):
    """Stack one GELU Dense layer plus Dropout per entry of ``hidden_units``.

    Returns the output tensor of the last Dropout layer (or ``x`` unchanged
    when ``hidden_units`` is empty).
    """
    out = x
    for n_units in hidden_units:
        out = layers.Dense(n_units, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(out)
    return out


class PatchesX(layers.Layer):
    """Cut a 4-D tensor into non-overlapping patches along its second axis.

    Parameters
    ----------
    patch_sizeX: int
        Patch length (and stride) along axis 1.
    **kwargs:
        Standard Keras layer arguments (``name``, ``dtype``, ...). They are
        forwarded to the base class so that a layer rebuilt from
        ``get_config()`` keeps them; previously they were silently dropped.

    """

    def __init__(self, patch_sizeX, **kwargs):
        super(PatchesX, self).__init__(**kwargs)
        self.patch_sizeX = patch_sizeX

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'patch_sizeX': self.patch_sizeX,
        })
        return config

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_sizeX, 1, 1],
            strides=[1, self.patch_sizeX, 1, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the patch grid into one axis: (batch, n_patches, patch_dims).
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
    

class PatchEncoderX(layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Parameters
    ----------
    num_patchesX: int
        Number of patches per sample (size of the position vocabulary).
    projection_dimX: int
        Width of the projection and of the position embedding.
    **kwargs:
        Standard Keras layer arguments, forwarded to the base class.

    """

    def __init__(self, num_patchesX, projection_dimX, **kwargs):
        super(PatchEncoderX, self).__init__(**kwargs)
        self.num_patchesX = num_patchesX
        # Kept on the instance so get_config() reports the value this layer was
        # built with rather than whatever the module-level global holds.
        self.projection_dimX = projection_dimX
        self.projection = layers.Dense(units=projection_dimX)
        self.position_embedding = layers.Embedding(
            input_dim=num_patchesX, output_dim=projection_dimX
        )

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_patchesX': self.num_patchesX,
            'projection_dimX': self.projection_dimX,
        })
        return config

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patchesX, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
    
def create_vit_classifier(inputs,inp):
    """Build the ViT regression head on top of foundation-model features.

    Parameters
    ----------
    inputs: tensor
        Feature tensor taken from the foundation model. After three
        ``convF1`` blocks it is flattened and reshaped to ``(6000, 1)``, so it
        must hold 6000 values per sample at that point.
    inp: tensor
        Model input tensor, concatenated with the reshaped features along the
        last axis.

    Returns
    -------
    features: tensor
        Output of the final MLP (hidden widths 1000 and 500).

    """
    filters = [32, 64, 96, 128, 256]

    # Refine the foundation features, then lay them out as one extra channel
    # next to the raw input.
    x = inputs
    for _ in range(3):
        x = convF1(x, 80, 13, 0.1)
    x = Flatten()(x)
    x = Reshape((6000,1))(x)
    x = concatenate([x,inp])

    # Convolutional front end: (filter count, pooling factor) per stage.
    # Every Dropout is called with training=False, i.e. in inference mode.
    stage_specs = [
        (filters[1], 2),
        (filters[1], 2),
        (filters[0], 2),
        (filters[0], 2),
        (filters[0], 5),
    ]
    e = x
    for n_filters, pool_factor in stage_specs:
        e = Conv1D(n_filters, 3, padding = 'same')(e)
        e = Dropout(drop_rate)(e, training=False)
        e = MaxPooling1D(pool_factor, padding='same')(e)

    # Add a singleton axis so the patch extractor sees a 4-D tensor.
    inputreshaped = layers.Reshape((75,1,32))(e)
    patches = PatchesX(patch_sizeX)(inputreshaped)
    tokens = PatchEncoderX(num_patchesX, projection_dimX)(patches)

    # Transformer encoder: pre-norm attention and MLP, each with a residual.
    for _ in range(transformer_layersX):
        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_headsX, key_dim=projection_dimX, dropout=0.1
        )(normed, normed)
        residual = layers.Add()([attended, tokens])

        mlp_out = layers.LayerNormalization(epsilon=1e-6)(residual)
        mlp_out = mlpX(mlp_out, hidden_units=transformer_unitsX, dropout_rate=0.1)
        tokens = layers.Add()([mlp_out, residual])

    # Flatten all tokens into a single vector per sample.
    representation = layers.LayerNormalization(epsilon=1e-6)(tokens)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)

    features = mlpX(representation, hidden_units=[1000,500], dropout_rate=0.5)

    return features

## Training

In [None]:

def _configure_gpu(gpuid, gpu_limit):
    """Install a TF1 session configured for the requested GPU.

    Does nothing when ``gpuid`` is None. ``gpu_limit`` is applied as the
    per-process GPU memory fraction only when it is given.
    """
    if gpuid is None:
        return
    # NOTE(review): TensorFlow has presumably enumerated its devices by the
    # time this runs, so the environment variable may have no effect - confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(gpuid)
    config = tf1.ConfigProto()
    config.gpu_options.allow_growth = True
    if gpu_limit is not None:
        config.gpu_options.per_process_gpu_memory_fraction = float(gpu_limit)
    tf1.keras.backend.set_session(tf1.Session(config=config))


def trainer1(input_hdf5=None,
            output_name=None,                
            input_dimention=(6000, 3),
            shuffle=True, 
            label_type='gaussian',
            normalization_mode='std',
            augmentation=True,
            add_event_r=0.6,
            shift_event_r=0.99,
            add_noise_r=0.3, 
            drop_channel_r=0.5,
            add_gap_r=0.2,
            scale_amplitude_r=None,
            pre_emphasis=False,                
            mode='generator',
            batch_size=200,
            epochs=200, 
            monitor='val_loss',
            patience=12,
            multi_gpu=False,
            number_of_gpus=4,
            gpuid=None,
            gpu_limit=None,
            use_multiprocessing=True):
        
    """
    
    Build the U-Trans + ViT magnitude model and train it.  
    
    The pretrained U-Trans foundation weights are loaded into the U-Net, the
    ViT head is attached to one of its internal layers, and the combined
    model is trained with an MSE loss using data generators.
    
    Parameters
    ----------
    input_hdf5: str, default=None
        Path to the waveform data; passed to the data generators as ``file_name``.

    output_name: str, default=None
        Prefix of the output directory (``<output_name>_outputs``) and of the
        checkpoint file names.
        
    input_dimention: tuple, default=(6000, 3)
        Waveform shape; its first and last entries are passed to the data
        generators as ``dim`` and ``n_channels``.

    shuffle: bool, default=True
        Passed to the training generator.

    label_type: str, default='gaussian'
        Labeling type. 'gaussian', 'triangle', or 'box'. 

    normalization_mode: str, default='std'
        Mode of normalization for data preprocessing, 'max': maximum amplitude among three components, 'std', standard deviation. 

    augmentation: bool, default=True
        If True, data will be augmented simultaneously during the training.

    add_event_r: float, default=0.6
        Rate of augmentation for adding a secondary event randomly into the empty part of a trace.

    shift_event_r: float, default=0.99
        Rate of augmentation for randomly shifting the event within a trace.
      
    add_noise_r: float, default=0.3 
        Rate of augmentation for adding Gaussian noise with different SNR into a trace.       
        
    drop_channel_r: float, default=0.5 
        Rate of augmentation for randomly dropping one of the channels.

    add_gap_r: float, default=0.2 
        Add an interval with zeros into the waveform representing filled gaps.       
        
    scale_amplitude_r: float, default=None
        Rate of augmentation for randomly scaling the trace. 
              
    pre_emphasis: bool, default=False
        If True, waveforms will be pre-emphasized.  
          
    mode: str, default='generator'
        Mode of running. Only 'generator' is implemented here; any other
        value raises ValueError.
         
    batch_size: int, default=200
        Batch size.
          
    epochs: int, default=200
        The number of epochs.
          
    monitor: str, default='val_loss'
        The measure used for monitoring.
           
    patience: int, default=12
        The number of epochs without any improvement in the monitoring measure to automatically stop the training.          
        
    multi_gpu: bool, default=False
        Stored with the other arguments; not used by this function.
           
    number_of_gpus: int, default=4
        Stored with the other arguments; not used by this function.
           
    gpuid: int, default=None
        Id of the GPU to use. If None, no GPU session is configured. 
         
    gpu_limit: float, default=None
        Maximum fraction of GPU memory to use; applied only when ``gpuid`` is set.
        
    use_multiprocessing: bool, default=True
        Passed to ``fit_generator``.

    Raises
    ------
    ValueError
        If ``mode`` is not 'generator'.
        
    """     

    args = {
    "input_hdf5": input_hdf5,
    "output_name": output_name,
    "input_dimention": input_dimention,
    "shuffle": shuffle,
    "label_type": label_type,
    "normalization_mode": normalization_mode,
    "augmentation": augmentation,
    "add_event_r": add_event_r,
    "shift_event_r": shift_event_r,
    "add_noise_r": add_noise_r,
    "add_gap_r": add_gap_r,
    "drop_channel_r": drop_channel_r,
    "scale_amplitude_r": scale_amplitude_r,
    "pre_emphasis": pre_emphasis,
    "mode": mode,
    "batch_size": batch_size,
    "epochs": epochs,
    "monitor": monitor,
    "patience": patience,           
    "multi_gpu": multi_gpu,
    "number_of_gpus": number_of_gpus,           
    "gpuid": gpuid,
    "gpu_limit": gpu_limit,
    "use_multiprocessing": use_multiprocessing
    }
                       
    def train(args):
        """ 
        
        Performs the training.
    
        Parameters
        ----------
        args : dic
            A dictionary object containing all of the input parameters. 

        Returns
        -------
        history: 
            Object returned by ``fit_generator``.  
            
        model: 
            Trained model.
            
        start_training: float
            Training start time (``time.time()``). 
            
        end_training: float
            Training end time (``time.time()``). 
            
        save_dir: str
            Path to the output directory. 
            
        save_models: str
            Path to the folder for saving the models.  
            
        training size: int
            Number of training samples.
            
        validation size: int
            Number of validation samples.  
            
        """    

        # Fail before anything is deleted or built: previously an unsupported
        # mode only surfaced as an UnboundLocalError after the output directory
        # had been wiped and the model constructed.
        if args['mode'] != 'generator':
            raise ValueError(
                "Unsupported mode {!r}: only 'generator' is implemented.".format(args['mode']))

        save_dir, save_models = _make_dir(args['output_name'])
        training, validation = _split(args, save_dir)
        callback_list = _make_callback(args, save_models)

        # The session must be in place before the model is built and the
        # pretrained weights are loaded; swapping it afterwards would leave
        # the new session without those weights.
        _configure_gpu(args['gpuid'], args['gpu_limit'])

        # Rebuild the foundation network exactly as it was pretrained so the
        # saved weights can be loaded into it.
        D1 = 5
        inp = Input(shape=input_shape,name="input")
        conv1 = UNET(inp,D1)
        out = Conv1D(D1,  3, strides =(1), padding='same',kernel_initializer='he_normal')(conv1)
        out = Conv1D(3,  3, strides =(1), padding='same',kernel_initializer='he_normal',name='picker_PP')(out)
        modeloriginal = Model(inp, out)

        modeloriginal.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['acc',f1,precision, recall])     
        modeloriginal.load_weights('../../../weights/UTrans_Foundation.h5')

        # Attach the ViT head to an internal layer of the foundation model.
        inputs = modeloriginal.layers[63].output

        features = create_vit_classifier(inputs,inp)

        e = Dense(1)(features)
        o = Activation('linear', name='output_layer')(e)
        model = Model(inputs=modeloriginal.input, outputs=o)
        
        # `learning_rate` replaces the deprecated `lr` keyword.
        Adm = tensorflow.optimizers.Adam(learning_rate=1e-4)
        model.compile(optimizer=Adm,
                  loss='mse',
                  metrics=['mse'])     

        model.summary()
            
        start_training = time.time()                  
            
        # 'drop_channe_r' is spelled this way on purpose: it is the keyword
        # handed to DataGenerator.
        params_training = {'file_name': str(args['input_hdf5']), 
                          'dim': args['input_dimention'][0],
                          'batch_size': args['batch_size'],
                          'n_channels': args['input_dimention'][-1],
                          'shuffle': args['shuffle'],  
                          'norm_mode': args['normalization_mode'],
                          'label_type': args['label_type'],                          
                          'augmentation': args['augmentation'],
                          'add_event_r': args['add_event_r'], 
                          'add_gap_r': args['add_gap_r'],  
                          'shift_event_r': args['shift_event_r'],                            
                          'add_noise_r': args['add_noise_r'], 
                          'drop_channe_r': args['drop_channel_r'],
                          'scale_amplitude_r': args['scale_amplitude_r'],
                          'pre_emphasis': args['pre_emphasis']}    
                    
        # Validation data is neither shuffled nor augmented.
        params_validation = {'file_name': str(args['input_hdf5']),  
                             'dim': args['input_dimention'][0],
                             'batch_size': args['batch_size'],
                             'n_channels': args['input_dimention'][-1],
                             'shuffle': False,  
                             'norm_mode': args['normalization_mode'],
                             'augmentation': False}         

        training_generator = DataGenerator(training, **params_training)
        validation_generator = DataGenerator(validation, **params_validation) 

        print('Started training in generator mode ...') 
        print(training_generator)
        history = model.fit_generator(generator=training_generator,
                                      validation_data=validation_generator,
                                      use_multiprocessing=args['use_multiprocessing'],
                                      workers=multiprocessing.cpu_count(),    
                                      callbacks=callback_list, 
                                      epochs=args['epochs'],verbose=1)
        end_training = time.time()  
        
        return history, model, start_training, end_training, save_dir, save_models, len(training), len(validation)
                  
    history, model, start_training, end_training, save_dir, save_models, training_size, validation_size=train(args)  





def _make_dir(output_name):
    
    """ 
    
    Make the output directories.

    An existing ``<output_name>_outputs`` directory in the current working
    directory is deleted and recreated.

    Parameters
    ----------
    output_name: str
        Name of the output directory.
                   
    Returns
    -------   
    save_dir: str
        Full path to the output directory.
        
    save_models: str
        Full path to the model directory. 

    Raises
    ------
    ValueError
        If ``output_name`` is None.
        
    """   
    
    if output_name is None:
        # Returning None here used to surface later as an opaque tuple
        # unpacking error in the caller.
        raise ValueError('Please specify output_name!')

    save_dir = os.path.join(os.getcwd(), str(output_name)+'_outputs')
    save_models = os.path.join(save_dir, 'models')      
    if os.path.isdir(save_dir):
        shutil.rmtree(save_dir)  
    os.makedirs(save_models)
    return save_dir, save_models



def _build_model(args): 
    
    """ 
    
    Build and compile a picker model with a sigmoid ``picker_S`` output.

    NOTE(review): this function calls ``create_cct_modelS``, which is not
    defined in the visible part of this file, and the only call to
    ``_build_model`` in ``trainer1`` is commented out. It looks like legacy
    code - confirm before using it.

    Parameters
    ----------
    args: dic
        A dictionary containing all of the input parameters (not read here). 
               
    Returns
    -------   
    model: 
        Compiled model.
        
    """       
    


    # Model EQCCT
    inputs = layers.Input(shape=input_shape,name='input')
    
    # NOTE(review): create_cct_modelS is not defined in the visible source.
    featuresS = create_cct_modelS(inputs)
    featuresS = Reshape((6000,1))(featuresS)

    logits  = Conv1D(1,  15, strides =(1), padding='same',activation='sigmoid', kernel_initializer='he_normal',name='picker_S')(featuresS)
    model = Model(inputs=[inputs], outputs=[logits])    
    

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['acc',f1,precision, recall])    
    
  
    return model  
    

def _split(args, save_dir):
    
    """ 
    
    Load the precomputed training and validation trace lists.

    The lists are read from ``train_Events.npy`` and ``valid_Events.npy`` in
    the current working directory; no splitting is done here. The test list
    is not needed for training and is no longer loaded, so a missing
    ``test_Events.npy`` does not stop a training run.

    Parameters
    ----------
    args: dic
        A dictionary containing all of the input parameters (not read here). 
        
    save_dir: str
       Path to the output directory (not read here). 
              
    Returns
    -------   
    training: numpy.ndarray
        Contents of ``train_Events.npy``. 
    validation: numpy.ndarray
        Contents of ``valid_Events.npy``. 
                
    """       
    
    # Loading the IDs.
    training = np.load('train_Events.npy')
    validation = np.load('valid_Events.npy')

    return training, validation 



def _make_callback(args, save_models):
    
    """ 
    
    Assemble the Keras callbacks used during training.

    Parameters
    ----------
    args: dic
        A dictionary containing all of the input parameters; the keys
        'output_name', 'monitor' and 'patience' are read. 
        
    save_models: str
       Directory in which the checkpoint files are written. 
              
    Returns
    -------   
    callbacks: list
        Checkpoint, LR-on-plateau reducer, LR scheduler and early stopping,
        in that order. 
        
    """    
    
    monitored = args['monitor']
    patience = args['patience']

    # One weights-only file per improving epoch, named after the run.
    checkpoint_path = os.path.join(
        save_models, str(args['output_name'])+'_{epoch:03d}.h5')
    checkpoint = ModelCheckpoint(filepath=checkpoint_path,
                                 monitor=monitored,
                                 mode='auto',
                                 verbose=1,
                                 save_best_only=True,
                                 save_weights_only=True)

    # The reducer uses a patience two epochs shorter than early stopping.
    lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                                   cooldown=0,
                                   patience=patience-2,
                                   min_lr=0.5e-6)

    lr_scheduler = LearningRateScheduler(_lr_schedule)

    early_stopping_monitor = EarlyStopping(monitor=monitored,
                                           patience=patience)

    return [checkpoint, lr_reducer, lr_scheduler, early_stopping_monitor]
 
    


def _pre_loading(args, training, validation):
    
    """ 
    
    Look up every training and validation trace in the HDF5 file and wrap
    them in pre-load generators.

    Parameters
    ----------
    args: dic
        A dictionary containing all of the input parameters. 
        
    training: list of str
        List of trace names for the training set. 
        
    validation: list of str
        List of trace names for the validation set. 
              
    Returns
    -------   
    training_generator: obj
        Generator for the training set. 
        
    validation_generator: obj
        Generator for the validation set. 
        
    """   

    # NOTE(review): PreLoadGenerator is not imported anywhere in the visible
    # file; it is presumably provided by EqT_utils_Mag next to DataGenerator -
    # confirm.
    from EqT_utils_Mag import PreLoadGenerator

    # The file handle is left open: the objects returned by fl.get are handed
    # to the generators.
    fl = h5py.File(args['input_hdf5'], 'r')   

    def _collect(trace_ids):
        # Every ID is looked up under its own name. Previously an ID that did
        # not end in 'EV' or 'NO' silently reused the previous trace's dataset
        # (or raised on the first iteration).
        collected = {}
        pbar = tqdm(total=len(trace_ids))
        for ID in trace_ids:
            pbar.update()
            collected[str(ID)] = fl.get(str(ID))
        return collected

    print('Loading the training data into the memory ...')
    training_set = _collect(training)

    print('Loading the validation data into the memory ...', flush=True)            
    validation_set = _collect(validation)
   
    # 'drop_channe_r' is spelled this way on purpose: it is the keyword
    # handed to the generator.
    params_training = {'dim':args['input_dimention'][0],
                       'batch_size': args['batch_size'],
                       'n_channels': args['input_dimention'][-1],
                       'shuffle': args['shuffle'],  
                       'norm_mode': args['normalization_mode'],
                       'label_type': args['label_type'],
                       'augmentation': args['augmentation'],
                       'add_event_r': args['add_event_r'], 
                       'add_gap_r': args['add_gap_r'],                         
                       'shift_event_r': args['shift_event_r'],  
                       'add_noise_r': args['add_noise_r'], 
                       'drop_channe_r': args['drop_channel_r'],
                       'scale_amplitude_r': args['scale_amplitude_r'],
                       'pre_emphasis': args['pre_emphasis']}  

    # Validation data is neither shuffled nor augmented.
    params_validation = {'dim': args['input_dimention'][0],
                         'batch_size': args['batch_size'],
                         'n_channels': args['input_dimention'][-1],
                         'shuffle': False,  
                         'norm_mode': args['normalization_mode'],
                         'augmentation': False}  
    
    training_generator = PreLoadGenerator(training, training_set, **params_training)  
    validation_generator = PreLoadGenerator(validation, validation_set, **params_validation) 
    
    return training_generator, validation_generator  




In [None]:
# Train the magnitude model. Adding secondary events, inserting gaps and
# amplitude scaling are switched off (None) for this run.
trainer1(
    # data and output
    input_hdf5='/scratch/sadalyom/DataCollected',
    output_name='test_trainer_FoundationVit_Mag',
    mode='generator',
    # labelling and preprocessing
    label_type='triangle',
    normalization_mode='std',
    pre_emphasis=False,
    # augmentation
    shuffle=True,
    augmentation=True,
    shift_event_r=0.99,
    add_noise_r=0.5,
    drop_channel_r=0.1,
    add_event_r=None,
    add_gap_r=None,
    scale_amplitude_r=None,
    # optimisation
    batch_size=40,
    epochs=50,
    patience=10,
    # hardware
    gpuid=None,
    gpu_limit=None,
)