In [None]:
# %%
from __future__ import print_function, division

import os

import tensorflow.compat.v1 as tf1
tf1.disable_v2_behavior()

import tensorflow as tf

# Index of the physical GPU this notebook is pinned to.
GPU_INDEX = 3

physical_devices = tf.config.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[2], True)

# NOTE(review): this variable is set after TensorFlow has already enumerated
# the devices above, so it may only affect child processes -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_INDEX)
if len(physical_devices) > GPU_INDEX:
    tf.config.experimental.set_visible_devices(physical_devices[GPU_INDEX], 'GPU')
else:
    # Previously this raised IndexError on hosts with fewer GPUs; fall back to
    # TensorFlow's default device selection instead of aborting the notebook.
    print('GPU index {} not available ({} GPU(s) detected); using default devices.'.format(
        GPU_INDEX, len(physical_devices)))

os.environ['KERAS_BACKEND'] = 'tensorflow'

In [None]:



# %%
import sys
import time
import math
import json
import pickle
import random
import shutil
import logging
import warnings
import contextlib
import multiprocessing
import datetime

import numpy as np
import pandas as pd
import h5py

from tqdm import tqdm
from glob import glob
from os.path import join
from datetime import datetime, timedelta
from scipy import signal
from matplotlib.lines import Line2D

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import tensorflow
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers, callbacks
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler

from tensorflow.keras.layers import (
    Input, Dense, Dropout, Flatten, Reshape, Activation,
    Conv1D, MaxPooling1D, UpSampling1D, BatchNormalization,
    Add, concatenate
)

warnings.filterwarnings("ignore")
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

import faulthandler
faulthandler.enable()

# External utilities (as in your original code)
from EQCCT_P_utils import DataGenerator, _lr_schedule, data_reader

def recall(y_true, y_pred):
    """Keras metric: rounded true positives divided by rounded actual positives.

    Both counts are taken after clipping to [0, 1] and rounding; ``K.epsilon()``
    guards the division when there are no positives.
    """
    hit_mask = K.round(K.clip(y_true * y_pred, 0, 1))
    actual_mask = K.round(K.clip(y_true, 0, 1))
    n_hits = K.sum(hit_mask)
    n_actual = K.sum(actual_mask)
    return n_hits / (n_actual + K.epsilon())

def precision(y_true, y_pred):
    """Keras metric: rounded true positives divided by rounded predicted positives.

    Both counts are taken after clipping to [0, 1] and rounding; ``K.epsilon()``
    guards the division when nothing is predicted positive.
    """
    hit_mask = K.round(K.clip(y_true * y_pred, 0, 1))
    predicted_mask = K.round(K.clip(y_pred, 0, 1))
    n_hits = K.sum(hit_mask)
    n_predicted = K.sum(predicted_mask)
    return n_hits / (n_predicted + K.epsilon())

def f1(y_true, y_pred):
    """Keras metric: harmonic mean of ``precision`` and ``recall`` (epsilon-stabilised)."""
    prec = precision(y_true, y_pred)
    rec = recall(y_true, y_pred)
    ratio = (prec * rec) / (prec + rec + K.epsilon())
    return 2 * ratio

def wbceEdit(y_true, y_pred):
    """Loss: (1 - mean SSIM) plus mean squared error between target and prediction.

    SSIM is evaluated by ``tf.image.ssim`` with a dynamic range (``max_val``) of 1.0.
    """
    mse_term = K.mean(K.square(y_true - y_pred))
    mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
    ssim_term = 1 - mean_ssim
    return (ssim_term + mse_term)


## Foundation Model

In [None]:
# --- Foundation-model hyper-parameters -------------------------------------
# Upper end of the per-block stochastic-depth drop rates (see create_cct_model1).
stochastic_depth_rate = 0.1
# Toggle for the learned positional embedding added after the tokenizer.
positional_emb = False
# Number of Conv1D layers stacked inside CCTTokenizer1.
conv_layers = 1
num_classes = 1

input_shape = (375, 256)
# Channel width of the tokenizer output and of the transformer MLP layers.
projection_dim = 80
num_heads = 4
# Number of transformer blocks.
transformer_layers = 4
# Hidden sizes of the two Dense layers in each transformer MLP.
transformer_units = [projection_dim] * 2


def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Parameters
    ----------
    x : tensor
        Input to the first Dense layer.
    hidden_units : iterable of int
        Output width of each Dense layer, applied in order.
    dropout_rate : float
        Rate given to every Dropout layer.

    Returns
    -------
    tensor
        Output of the last Dropout layer (``x`` itself if ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out


class StochasticDepth(layers.Layer):
    """Randomly drops whole samples of a residual branch during training.

    Parameters
    ----------
    drop_prop : float
        Probability of dropping a sample's branch output (stored as ``drop_prob``).
    **kwargs
        Forwarded to ``layers.Layer``.
    """

    def __init__(self, drop_prop, **kwargs):
        super(StochasticDepth, self).__init__(**kwargs)
        self.drop_prob = drop_prop

    def get_config(self):
        """Return the layer config so models using this layer can be serialized."""
        config = super(StochasticDepth, self).get_config()
        # Key matches the constructor argument name so from_config() round-trips.
        config.update({'drop_prop': float(self.drop_prob)})
        return config

    def call(self, x, training=None):
        """Return ``x`` unchanged at inference; in training, zero random samples
        and rescale the kept ones by ``1 / keep_prob``."""
        if training:
            keep_prob = 1 - self.drop_prob
            # One Bernoulli draw per sample, broadcast over the remaining axes.
            # The static rank (len(x.shape)) is used instead of len(tf.shape(x)),
            # which is only defined for symbolic tensors under autograph.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            # floor() maps [keep_prob, 1 + keep_prob) onto {0, 1}.
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x


class CCTTokenizer1(layers.Layer):
    """Convolutional tokenizer: a small stack of bias-free ReLU Conv1D layers.

    Parameters
    ----------
    kernel_size : int, default=3
        Kernel size of every Conv1D layer.
    stride : int, default=1
        Stride of every Conv1D layer.
    padding, pooling_kernel_size, pooling_stride
        Accepted for interface compatibility; not used by this implementation
        (the convolutions always use ``padding="same"`` and no pooling is applied).
    num_conv_layers : int
        Number of Conv1D layers to stack.
    num_output_channels : sequence of int
        Output channels per Conv1D layer; only the first ``num_conv_layers``
        entries are read.
    positional_emb : bool
        Whether ``positional_embedding`` should build an embedding layer.
    **kwargs
        Forwarded to ``layers.Layer``.
    """

    def __init__(
        self,
        kernel_size=3,
        stride=1,
        padding=1,
        pooling_kernel_size=3,
        pooling_stride=(1,1,1,1),
        num_conv_layers=conv_layers,
        num_output_channels=[int(projection_dim)] * 8,
        positional_emb=positional_emb,
        **kwargs,
    ):
        super(CCTTokenizer1, self).__init__(**kwargs)

        self.conv_model = tf.keras.Sequential()
        for i in range(num_conv_layers):
            self.conv_model.add(
                layers.Conv1D(
                    num_output_channels[i],
                    kernel_size,
                    stride,
                    padding="same",
                    use_bias=False,
                    activation="relu",
                    kernel_initializer="he_normal",
                )
            )

        self.positional_emb = positional_emb

    def call(self, images):
        """Run the input through the convolutional stack and return the result."""
        # The former tf.reshape here produced a tensor that was never used, so
        # it only added dead ops to the graph; the conv output is returned as is.
        return self.conv_model(images)

    def positional_embedding(self, image_size):
        """Build a positional Embedding sized from a dummy forward pass.

        Returns ``(embed_layer, sequence_length)`` when positional embeddings are
        enabled, otherwise ``None``.
        """
        if self.positional_emb:
            # NOTE(review): the dummy input has a single channel; if the conv
            # stack was already built on inputs with more channels this call
            # presumably fails -- confirm before enabling positional_emb.
            dummy_inputs = tf.ones((1, image_size, 1))
            dummy_outputs = self.call(dummy_inputs)
            sequence_length = int(dummy_outputs.shape[1])
            projection_dim = int(dummy_outputs.shape[-1])

            embed_layer = layers.Embedding(input_dim=sequence_length, output_dim=projection_dim)
            return embed_layer, sequence_length
        else:
            return None


def create_cct_model1(inputs):
    """Tokenize ``inputs`` and pass them through the transformer encoder stack.

    Builds ``transformer_layers`` pre-norm blocks (multi-head self-attention and
    an MLP, each wrapped in stochastic depth and a residual Add) and returns the
    layer-normalised output.

    Parameters
    ----------
    inputs : tensor
        Feature sequence fed to ``CCTTokenizer1``.

    Returns
    -------
    tensor
        Encoded sequence after the final LayerNormalization.
    """
    cct_tokenizer = CCTTokenizer1()
    encoded_patches = cct_tokenizer(inputs)

    if positional_emb:
        # This branch used to reference an undefined ``image_size`` (NameError)
        # and asked the tokenizer to run a 1-channel dummy input. The embedding
        # is now sized directly from the static shape of the tokens, which
        # therefore must have a known length and width.
        seq_length = int(encoded_patches.shape[1])
        embed_dim = int(encoded_patches.shape[-1])
        pos_embed = layers.Embedding(input_dim=seq_length, output_dim=embed_dim)
        positions = tf.range(start=0, limit=seq_length, delta=1)
        position_embeddings = pos_embed(positions)
        encoded_patches += position_embeddings

    # Drop probability grows linearly from 0 to stochastic_depth_rate with depth.
    dpr = list(np.linspace(0, stochastic_depth_rate, transformer_layers))

    for i in range(transformer_layers):
        # Self-attention sub-block (pre-norm) with residual connection.
        x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)

        attention_output = StochasticDepth(dpr[i])(attention_output)
        x2 = layers.Add()([attention_output, encoded_patches])

        # MLP sub-block (pre-norm) with residual connection.
        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

        x3 = StochasticDepth(dpr[i])(x3)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
    return representation


def UNET(inputs, D1):
    """1-D U-Net whose bottleneck is the transformer from ``create_cct_model1``.

    The encoder applies five double-convolution stages separated by max pooling
    (factors 2, 2, 2, 2, 5); the decoder mirrors them with upsampling, a
    convolution and a channel-wise concatenation with an encoder tensor.

    Parameters
    ----------
    inputs : tensor
        Input sequence tensor.
    D1 : int
        Channel width of the first stage; deeper stages double it each time.

    Returns
    -------
    tensor
        Output of the last decoder stage, with ``D1`` channels.
    """
    # Widths are derived stepwise (int() after each doubling), as in D2..D5.
    widths = [D1]
    for _ in range(4):
        widths.append(int(widths[-1] * 2))
    _, D2, D3, D4, D5 = widths

    def conv(filters, kernel, tensor):
        """One ReLU Conv1D with 'same' padding and He-normal initialisation."""
        return Conv1D(filters, kernel, activation='relu', padding='same',
                      kernel_initializer='he_normal')(tensor)

    def double_conv(filters, tensor):
        """Two consecutive kernel-3 convolutions of equal width."""
        first = conv(filters, 3, tensor)
        return conv(filters, 3, first)

    # ---- Encoder ----
    conv1 = double_conv(D1, inputs)
    pool1 = MaxPooling1D(pool_size=2)(conv1)

    conv2 = double_conv(D2, pool1)
    pool2 = MaxPooling1D(pool_size=2)(conv2)

    conv3 = double_conv(D3, pool2)
    pool3 = MaxPooling1D(pool_size=2)(conv3)

    conv4 = double_conv(D4, pool3)
    pool4 = MaxPooling1D(pool_size=2)(conv4)

    conv44 = double_conv(D5, pool4)
    pool44 = MaxPooling1D(pool_size=5)(conv44)

    conv5 = double_conv(D5, pool44)

    # ---- Transformer bottleneck ----
    drop5 = create_cct_model1(conv5)

    # ---- Decoder ----
    # The deepest skip connection concatenates pool4 (not conv44).
    up66 = conv(D5, 3, UpSampling1D(size=5)(drop5))
    merge66 = concatenate([pool4, up66], axis=-1)
    conv66 = double_conv(D5, merge66)

    up6 = conv(D4, 3, UpSampling1D(size=2)(conv66))
    merge6 = concatenate([conv4, up6], axis=-1)
    conv6 = double_conv(D4, merge6)

    # The three shallowest up-convolutions use kernel size 2.
    up7 = conv(D3, 2, UpSampling1D(size=2)(conv6))
    merge7 = concatenate([conv3, up7], axis=-1)
    conv7 = double_conv(D3, merge7)

    up8 = conv(D2, 2, UpSampling1D(size=2)(conv7))
    merge8 = concatenate([conv2, up8], axis=-1)
    conv8 = double_conv(D2, merge8)

    up9 = conv(D1, 2, UpSampling1D(size=2)(conv8))
    merge9 = concatenate([conv1, up9], axis=-1)
    conv9 = double_conv(D1, merge9)

    return conv9


## EQCCT

In [None]:
def recall(y_true, y_pred):
    """Keras metric: rounded true positives over rounded actual positives.

    Counts are taken after clipping to [0, 1] and rounding; ``K.epsilon()``
    keeps the division defined when there are no positives.
    """
    hit_mask = K.round(K.clip(y_true * y_pred, 0, 1))
    actual_mask = K.round(K.clip(y_true, 0, 1))
    n_hits = K.sum(hit_mask)
    n_actual = K.sum(actual_mask)
    return n_hits / (n_actual + K.epsilon())

def precision(y_true, y_pred):
    """Keras metric: rounded true positives over rounded predicted positives.

    Counts are taken after clipping to [0, 1] and rounding; ``K.epsilon()``
    keeps the division defined when nothing is predicted positive.
    """
    hit_mask = K.round(K.clip(y_true * y_pred, 0, 1))
    predicted_mask = K.round(K.clip(y_pred, 0, 1))
    n_hits = K.sum(hit_mask)
    n_predicted = K.sum(predicted_mask)
    return n_hits / (n_predicted + K.epsilon())
    
def f1(y_true, y_pred):
    """Keras metric: harmonic mean of ``precision`` and ``recall`` (epsilon-stabilised)."""
    prec = precision(y_true, y_pred)
    rec = recall(y_true, y_pred)
    ratio = (prec * rec) / (prec + rec + K.epsilon())
    return 2 * ratio


import tensorflow as tf
def wbceEdit(y_true, y_pred):
    """Loss: (1 - mean SSIM) plus mean squared error between target and prediction.

    SSIM is evaluated by ``tf.image.ssim`` with a dynamic range (``max_val``) of 1.0.
    """
    mse_term = K.mean(K.square(y_true - y_pred))
    mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
    ssim_term = 1 - mean_ssim
    return (ssim_term + mse_term)


# --- EQCCT picker hyper-parameters ------------------------------------------
# Upper end of the per-block stochastic-depth drop rates.
stochastic_depth_rateX = 0.1
positional_embX = False

# Number of samples per input trace and the resulting (samples, channels) shape.
image_sizeX = 6000
input_shapeX = (image_sizeX, 3)

# Samples per patch and the number of non-overlapping patches per trace.
patch_sizeX = 40
num_patchesX = image_sizeX // patch_sizeX

# Width of the patch projection and of the transformer MLP layers.
projection_dimX = 40
num_headsX = 4
transformer_unitsX = [projection_dimX] * 2
# Number of transformer blocks.
transformer_layersX = 4

    
class PatchEncoderX(layers.Layer):
    """Linearly project patches and add a learned positional embedding.

    Parameters
    ----------
    num_patchesX : int
        Number of patches per sample (size of the position vocabulary).
    projection_dimX : int
        Output width of the Dense projection and of the embedding.
    **kwargs
        Forwarded to ``layers.Layer`` (e.g. ``name``), so that
        ``from_config(get_config())`` restores the base-layer settings.
    """

    def __init__(self, num_patchesX, projection_dimX, **kwargs):
        super(PatchEncoderX, self).__init__(**kwargs)
        self.num_patchesX = num_patchesX
        # Kept on the instance so get_config() reports this layer's own value
        # rather than whatever the module-level constant currently holds.
        self.projection_dimX = projection_dimX
        self.projection = layers.Dense(units=projection_dimX)
        self.position_embedding = layers.Embedding(
            input_dim=num_patchesX, output_dim=projection_dimX
        )

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = super().get_config().copy()
        config.update({
            'num_patchesX': self.num_patchesX,
            'projection_dimX': self.projection_dimX,
        })
        return config

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of positions 0..num_patchesX-1."""
        positions = tf.range(start=0, limit=self.num_patchesX, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
    
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepthX(layers.Layer):
    """Randomly drops whole samples of a residual branch during training.

    Parameters
    ----------
    drop_prop : float
        Probability of dropping a sample's branch output (stored as ``drop_prob``).
    **kwargs
        Forwarded to ``layers.Layer``.
    """

    def __init__(self, drop_prop, **kwargs):
        super(StochasticDepthX, self).__init__(**kwargs)
        self.drop_prob = drop_prop

    def get_config(self):
        """Return the layer config so models using this layer can be serialized."""
        config = super(StochasticDepthX, self).get_config()
        # Key matches the constructor argument name so from_config() round-trips.
        config.update({'drop_prop': float(self.drop_prob)})
        return config

    def call(self, x, training=None):
        """Return ``x`` unchanged at inference; in training, zero random samples
        and rescale the kept ones by ``1 / keep_prob``."""
        if training:
            keep_prob = 1 - self.drop_prob
            # One Bernoulli draw per sample, broadcast over the remaining axes.
            # The static rank (len(x.shape)) is used instead of len(tf.shape(x)),
            # which is only defined for symbolic tensors under autograph.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            # floor() maps [keep_prob, 1 + keep_prob) onto {0, 1}.
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

def convF1(inpt, D1, fil_ord, Dr):
    """Residual convolution block followed by a channel-changing convolution.

    Two Conv1D -> BatchNormalization -> GELU stages keep the input channel
    count; their output is added to ``inpt``. A third Conv1D -> BatchNorm -> GELU
    stage maps to ``D1`` channels and is followed by Dropout.

    Parameters
    ----------
    inpt : tensor
        Input tensor; its last axis is read as the channel count.
    D1 : int
        Number of output channels.
    fil_ord : int
        Kernel size of all three convolutions.
    Dr : float
        Dropout rate applied to the output.

    Returns
    -------
    tensor
        Block output with ``D1`` channels.
    """
    # The residual branch keeps the input width so the Add() below is valid.
    # (The previously computed but unused ``channel_axis`` local was removed.)
    filters = int(inpt.shape[-1])

    pre = Conv1D(filters, fil_ord, strides=(1), padding='same', kernel_initializer='he_normal')(inpt)
    pre = BatchNormalization()(pre)
    pre = Activation(tf.nn.gelu)(pre)

    inf = Conv1D(filters, fil_ord, strides=(1), padding='same', kernel_initializer='he_normal')(pre)
    inf = BatchNormalization()(inf)
    inf = Activation(tf.nn.gelu)(inf)
    # Skip connection around the two width-preserving convolutions.
    inf = Add()([inf, inpt])

    inf1 = Conv1D(D1, fil_ord, strides=(1), padding='same', kernel_initializer='he_normal')(inf)
    inf1 = BatchNormalization()(inf1)
    inf1 = Activation(tf.nn.gelu)(inf1)
    encode = Dropout(Dr)(inf1)

    return encode

def mlpX(x, hidden_units, dropout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Parameters
    ----------
    x : tensor
        Input to the first Dense layer.
    hidden_units : iterable of int
        Output width of each Dense layer, applied in order.
    dropout_rate : float
        Rate given to every Dropout layer.

    Returns
    -------
    tensor
        Output of the last Dropout layer (``x`` itself if ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out


class PatchesX(layers.Layer):
    """Cut the input into non-overlapping patches along its first spatial axis.

    Parameters
    ----------
    patch_sizeX : int
        Patch length (also used as the stride) along axis 1.
    **kwargs
        Forwarded to ``layers.Layer`` (e.g. ``name``), so that
        ``from_config(get_config())`` restores the base-layer settings.
    """

    def __init__(self, patch_sizeX, **kwargs):
        super(PatchesX, self).__init__(**kwargs)
        self.patch_sizeX = patch_sizeX

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = super().get_config().copy()
        config.update({
            'patch_sizeX': self.patch_sizeX,
        })
        return config

    def call(self, images):
        """Return patches reshaped to ``(batch, -1, patch_dims)``.

        ``images`` is passed to ``tf.image.extract_patches``, so it is expected
        to be a 4-D tensor.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_sizeX, 1, 1],
            strides=[1, self.patch_sizeX, 1, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flattened size of one patch, taken from the static shape.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
    
class PatchEncoderX(layers.Layer):
    """Linearly project patches and add a learned positional embedding.

    Parameters
    ----------
    num_patchesX : int
        Number of patches per sample (size of the position vocabulary).
    projection_dimX : int
        Output width of the Dense projection and of the embedding.
    **kwargs
        Forwarded to ``layers.Layer`` (e.g. ``name``), so that
        ``from_config(get_config())`` restores the base-layer settings.
    """

    def __init__(self, num_patchesX, projection_dimX, **kwargs):
        super(PatchEncoderX, self).__init__(**kwargs)
        self.num_patchesX = num_patchesX
        # Kept on the instance so get_config() reports this layer's own value
        # rather than whatever the module-level constant currently holds.
        self.projection_dimX = projection_dimX
        self.projection = layers.Dense(units=projection_dimX)
        self.position_embedding = layers.Embedding(
            input_dim=num_patchesX, output_dim=projection_dimX
        )

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = super().get_config().copy()
        config.update({
            'num_patchesX': self.num_patchesX,
            'projection_dimX': self.projection_dimX,
        })
        return config

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of positions 0..num_patchesX-1."""
        positions = tf.range(start=0, limit=self.num_patchesX, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
    
    


def create_cct_modelP(inputs, inp):
    """Build the P-picker head on top of foundation-model features.

    The foundation features are decoded by three ``convF1`` blocks, flattened to
    a single 6000-sample channel and concatenated with the raw input ``inp``.
    The result is widened by three more ``convF1`` blocks, split into patches,
    encoded, and passed through ``transformer_layersX`` transformer blocks.

    Parameters
    ----------
    inputs : tensor
        Intermediate feature tensor taken from the foundation model.
    inp : tensor
        Raw model input, concatenated with the decoded features.

    Returns
    -------
    tensor
        Layer-normalised output of the last transformer block.
    """
    # Decoder of the foundation encoder: three identical 80-channel blocks.
    decoded = inputs
    for _ in range(3):
        decoded = convF1(decoded, 80, 13, 0.1)
    decoded = Flatten()(decoded)
    decoded = Reshape((6000,1))(decoded)

    # Join the decoded features with the raw input along the channel axis.
    joined = concatenate([decoded, inp])

    # Progressive widening to 10, 20 and 40 channels.
    widened = joined
    for channels in (10, 20, 40):
        widened = convF1(widened, channels, 11, 0.1)

    # Add a singleton axis so the patch layer receives a 4-D tensor.
    as_image = layers.Reshape((6000,1,40))(widened)
    patch_tokens = PatchesX(patch_sizeX)(as_image)
    encoded_patches = PatchEncoderX(num_patchesX, projection_dimX)(patch_tokens)

    # Drop probability grows linearly from 0 to stochastic_depth_rateX with depth.
    drop_rates = list(np.linspace(0, stochastic_depth_rateX, transformer_layersX))

    for block_idx in range(transformer_layersX):
        # Self-attention sub-block (pre-norm) with stochastic depth and skip.
        normed = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attended = layers.MultiHeadAttention(
            num_heads=num_headsX, key_dim=projection_dimX, dropout=0.1
        )(normed, normed)
        attended = StochasticDepthX(drop_rates[block_idx])(attended)
        after_attention = layers.Add()([attended, encoded_patches])

        # MLP sub-block (pre-norm) with stochastic depth and skip.
        mlp_out = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        mlp_out = mlpX(mlp_out, hidden_units=transformer_unitsX, dropout_rate=0.1)
        mlp_out = StochasticDepthX(drop_rates[block_idx])(mlp_out)
        encoded_patches = layers.Add()([mlp_out, after_attention])

    return layers.LayerNormalization(epsilon=1e-6)(encoded_patches)




## Testing

In [None]:



import keras
from keras.models import load_model
from keras import backend as K
from keras.layers import Input
import tensorflow as tf
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import h5py
import time
import csv
import os
import shutil
import multiprocessing
from EQCCT_P_utils import DataGenerator, _lr_schedule, data_reader
import datetime
from tqdm import tqdm
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
from tensorflow.python.client import device_lib



def picker(args, yh3, yh3_std, spt=None):
    """
    Pick the most probable P arrival from a probability trace.

    Peaks of ``yh3`` at or above ``args['P_threshold']`` are candidate picks;
    the candidate with the highest probability is reported.

    Parameters
    ----------
    args : dict
        Input parameters; ``args['P_threshold']`` is the minimum peak height.

    yh3 : 1D array
        P arrival probabilities.

    yh3_std : 1D array
        P arrival standard deviations. Accepted for interface compatibility;
        it does not influence the returned values.

    spt : {int, None}, default=None
        Reference P arrival time in samples, used only to compute the pick error.

    Returns
    -------
    Ppickall : list
        One element: sample index of the selected pick, or None if no peak
        reached the threshold.

    perrorall : list
        One element: ``int(spt - pick)``, or None if there is no pick or ``spt``
        cannot be used (None, NaN, non-numeric).

    Pproball : list
        One element: probability of the selected pick (rounded to 3 decimals),
        or None if there is no pick.
    """
    peak_samples = _detect_peaks(yh3, mph=args['P_threshold'], mpd=1)

    # (sample index, probability) for every detected peak.
    candidates = [(int(sample), np.round(yh3[int(sample)], 3)) for sample in peak_samples]
    if not candidates:
        return [None], [None], [None]

    # np.argmax keeps the earliest candidate when several share the top probability.
    best = int(np.argmax(np.array([prob for _, prob in candidates])))
    best_sample, best_prob = candidates[best]

    # A missing or unusable reference arrival only invalidates the error; it
    # must not discard the pick itself (the former bare ``except`` did that).
    try:
        pick_error = int(spt - best_sample)
    except (TypeError, ValueError, OverflowError):
        pick_error = None

    # Report the probability as a float: int() truncated it to 0 or 1.
    return [best_sample], [pick_error], [float(best_prob)]

def _detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False):

    """
    
    Detect peaks in data based on their amplitude and other features.

    Parameters
    ----------
    x : 1D array_like
        data.
        
    mph : {None, number}, default=None
        detect peaks that are greater than minimum peak height.
        
    mpd : int, default=1
        detect peaks that are at least separated by minimum peak distance (in number of data).
        
    threshold : int, default=0
        detect peaks (valleys) that are greater (smaller) than `threshold in relation to their immediate neighbors.
        
    edge : str, default=rising
        for a flat peak, keep only the rising edge ('rising'), only the falling edge ('falling'), both edges ('both'), or don't detect a flat peak (None).
        
    kpsh : bool, default=False
        keep peaks with same height even if they are closer than `mpd`.
        
    valley : bool, default=False
        if True (1), detect valleys (local minima) instead of peaks.

    Returns
    ---------
    ind : 1D array_like
        indeces of the peaks in `x`.

    Modified from 
   ----------------
    .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb
    

    """

    x = np.atleast_1d(x).astype('float64')
    if x.size < 3:
        return np.array([], dtype=int)
    if valley:
        x = -x
    # find indices of all peaks
    dx = x[1:] - x[:-1]
    # handle NaN's
    indnan = np.where(np.isnan(x))[0]
    if indnan.size:
        x[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf
    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ['rising', 'both']:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ['falling', 'both']:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))
    # handle NaN's
    if ind.size and indnan.size:
        # NaN's and values close to NaN's cannot be peaks
        ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
    # first and last values of x cannot be peaks
    if ind.size and ind[0] == 0:
        ind = ind[1:]
    if ind.size and ind[-1] == x.size-1:
        ind = ind[:-1]
    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[x[ind] >= mph]
    # remove peaks - neighbors < threshold
    if ind.size and threshold > 0:
        dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
        ind = np.delete(ind, np.where(dx < threshold)[0])
    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        ind = ind[np.argsort(x[ind])][::-1]  # sort ind by peak height
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # keep peaks with the same height if kpsh is True
                idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
                    & (x[ind[i]] > x[ind] if kpsh else True)
                idel[i] = 0  # Keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    return ind


def generate_arrays_from_file(file_list, step):

    """

    Make a generator that cycles over a list of trace names in batches.

    Parameters
    ----------
    file_list : sequence
        Trace names (any sliceable sequence with ``len``, e.g. a list or array).

    step : int
        Batch size; must be positive.

    Yields
    ------
    chunck : sequence
        The next batch of at most ``step`` trace names. After the last batch the
        generator starts again from the beginning. An empty ``file_list`` yields
        nothing.

    Raises
    ------
    ValueError
        If ``step`` is not positive (raised when the first batch is requested).

    """

    if step <= 0:
        raise ValueError('step must be a positive integer, got {!r}'.format(step))

    n_items = len(file_list)
    if n_items == 0:
        # Nothing to cycle over; previously this spun forever without yielding.
        return

    while True:
        # Offsets are recomputed on every pass so later passes restart at 0
        # (the old running offset stayed at the end and produced empty batches).
        for start in range(0, n_items, step):
            chunck = file_list[start:start + step]
            yield chunck
class DataGeneratorTest(keras.utils.Sequence):

    """

    Keras generator with preprocessing. For testing.

    Parameters
    ----------
    list_IDs: list of str
        List of trace names.

    file_name: str
        Path to the input hdf5 file.

    dim: int
        Number of samples per trace (second axis of each batch array).

    batch_size: int, default=32
        Batch size.

    n_channels: int, default=3
        Number of channels.

    norm_mode: str, default=max
        The mode of normalization, 'max' or 'std'. A falsy value disables normalization.

    Returns
    --------
    Batches as a dictionary ``{'input': X}`` where ``X`` is the pre-processed
    waveform array of shape ``(batch_size, dim, n_channels)``.

    """

    def __init__(self,
                 list_IDs,
                 file_name,
                 dim,
                 batch_size=32,
                 n_channels=3,
                 norm_mode = 'max'):
        """Store the configuration and build the initial index order."""
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.file_name = file_name
        self.n_channels = n_channels
        self.on_epoch_end()
        self.norm_mode = norm_mode

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is not counted)."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Generate batch number ``index`` as ``{'input': X}``."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        X = self.__data_generation(list_IDs_temp)
        return ({'input': X})

    def on_epoch_end(self):
        """Reset the index order; traces are served in their original order."""
        self.indexes = np.arange(len(self.list_IDs))

    def normalize(self, data, mode = 'max'):
        """Demean ``data`` along axis 0, then scale it by its max or std per column.

        The array is modified in place and also returned.
        """
        # NOTE(review): the in-place operations presumably require a floating
        # point array -- confirm the dtype stored in the HDF5 file.
        data -= np.mean(data, axis=0, keepdims=True)
        if mode == 'max':
            max_data = np.max(data, axis=0, keepdims=True)
            assert(max_data.shape[-1] == data.shape[-1])
            # Avoid division by zero for flat channels.
            max_data[max_data == 0] = 1
            data /= max_data

        elif mode == 'std':
            std_data = np.std(data, axis=0, keepdims=True)
            assert(std_data.shape[-1] == data.shape[-1])
            # Avoid division by zero for flat channels.
            std_data[std_data == 0] = 1
            data /= std_data
        return data


    def __data_generation(self, list_IDs_temp):
        """Read the requested waveforms from the HDF5 file into one batch array."""
        X = np.zeros((self.batch_size, self.dim, self.n_channels))

        # The context manager closes the file even if a trace is missing or
        # malformed; previously the handle leaked when reading raised.
        with h5py.File(self.file_name, 'r') as fl:
            for i, ID in enumerate(list_IDs_temp):
                dataset = fl.get(str(ID))
                data = np.array(dataset['data'])

                if self.norm_mode:
                    data = self.normalize(data, self.norm_mode)

                X[i, :, :] = data

        return X
    
    
def tester1(input_hdf5=None,
           input_testset=None,
           input_model=None,
           output_name=None,
           detection_threshold=0.20,                
           P_threshold=0.1,
           S_threshold=0.1, 
           number_of_plots=100,
           estimate_uncertainty=True, 
           number_of_sampling=5,
           loss_weights=[0.05, 0.40, 0.55],
           loss_types=['mse'],
           input_dimention=(6000, 3),
           normalization_mode='std',
           mode='generator',
           batch_size=500,
           gpuid=None,
           gpu_limit=None):

    """
    
    Applies a trained model to a windowed waveform to perform both detection and picking at the same time.  


    Parameters
    ----------
    input_hdf5: str, default=None
        Path to an hdf5 file containing only one class of "data" with NumPy arrays containing 3 component waveforms each 1 min long.

    input_testset: npy, default=None
        Path to a NumPy file (automaticaly generated by the trainer) containing a list of trace names.        

    input_model: str, default=None
        Path to a trained model.
        
    output_dir: str, default=None
        Output directory that will be generated. 
        
    output_probabilities: bool, default=False
        If True, it will output probabilities and estimated uncertainties for each trace into an HDF file. 
       
          
    P_threshold: float, default=0.1
        A value which the P probabilities above it will be considered as P arrival.

               
    number_of_plots: float, default=10
        The number of plots for detected events outputed for each station data.
        
    estimate_uncertainty: bool, default=False
        If True uncertainties in the output probabilities will be estimated.  
        
    number_of_sampling: int, default=5
        Number of sampling for the uncertainty estimation. 
               
             
    input_dimention: tuple, default=(6000, 3)
        Loss types for P picking.          

    normalization_mode: str, default='std' 
        Mode of normalization for data preprocessing, 'max', maximum amplitude among three components, 'std', standard deviation.

    mode: str, default='generator'
        Mode of running. 'pre_load_generator' or 'generator'.
                      
    batch_size: int, default=500 
        Batch size. This wont affect the speed much but can affect the performance.

    gpuid: int, default=None
        Id of GPU used for the prediction. If using CPU set to None.
         
    gpu_limit: int, default=None
        Set the maximum percentage of memory usage for the GPU.
        

    """ 
              
         
    args = {
    "input_hdf5": input_hdf5,
    "input_testset": input_testset,
    "input_model": input_model,
    "output_name": output_name,
    "detection_threshold": detection_threshold,
    "P_threshold": P_threshold,
    "number_of_plots": number_of_plots,
    "estimate_uncertainty": estimate_uncertainty,
    "number_of_sampling": number_of_sampling,
    "input_dimention": input_dimention,
    "normalization_mode": normalization_mode,
    "mode": mode,
    "batch_size": batch_size,
    "gpuid": gpuid,
    "gpu_limit": gpu_limit
    }  

    
    if args['gpuid']:           
        os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(args['gpuid'])
        tf.Session(config=tf.ConfigProto(log_device_placement=True))
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = float(args['gpu_limit']) 
        K.tensorflow_backend.set_session(tf.Session(config=config))
    
    save_dir = os.path.join(os.getcwd(), str(args['output_name'])+'_outputs')
    save_figs = os.path.join(save_dir, 'figures')
 
    if os.path.isdir(save_dir):
        shutil.rmtree(save_dir)  
    os.makedirs(save_figs) 
 
    test = np.load(args['input_testset'])
    #test = test[0:10000]

    print('Loading the model ...', flush=True) 


    
 # Model CCT
    D1 = 5
    D2 = int(D1*2)
    D3 = int(D2*2)
    D4 = int(D3*2)
    D5 = int(D4*2)

    inp = Input(shape=input_shapeX,name="input")
    conv1 = UNET(inp,D1)
    out = Conv1D(D1,  3, strides =(1), padding='same',kernel_initializer='he_normal')(conv1)
    out = Conv1D(3,  3, strides =(1), padding='same',kernel_initializer='he_normal',name='picker_PP')(out)
    modeloriginal = Model(inp, out)

    modeloriginal.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['acc',f1,precision, recall])     
    modeloriginal.load_weights('../../../weights/UTrans_Foundation.h5')

    # Model CCT
    inputs = modeloriginal.layers[63].output  # layer that you want to connect your new FC layer to 


    featuresP = create_cct_modelP(inputs,inp)
    featuresP = Reshape((6000,1))(featuresP)

    logitp  = Conv1D(1,  3, strides =(1), padding='same',activation='sigmoid', kernel_initializer='he_normal',name='picker_P')(featuresP)

    model = Model(inputs=[modeloriginal.input], outputs=logitp)

    model.compile(optimizer='adam',
          loss='binary_crossentropy',
          metrics=['acc',f1,precision, recall]) 

        

    model.load_weights(args['input_model'])


    model.summary()  
        
    print('Loading is complete!', flush=True)  
    print('Testing ...', flush=True)    
    print('Writting results into: " ' + str(args['output_name'])+'_outputs'+' "', flush=True)
    
    start_training = time.time()          

    csvTst = open(os.path.join(save_dir,'X_test_results.csv'), 'w')          
    test_writer = csv.writer(csvTst, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    test_writer.writerow([
                          
                          
                          'p_arrival_sample', 
                          

                          
                          'P_pick',
                          'P_probability',
                          'P_error'
                          ])  
    csvTst.flush()        
        
    plt_n = 0
    list_generator = generate_arrays_from_file(test, args['batch_size']) 
    pred_PP_mean_all=[]
    pred_PP_std_all=[]
    sptall = []
    
    pbar_test = tqdm(total= int(np.ceil(len(test)/args['batch_size'])))            
    for _ in range(int(np.ceil(len(test) / args['batch_size']))):
        pbar_test.update()
        new_list = next(list_generator)

        if args['mode'].lower() == 'pre_load_generator':                
            params_test = {'dim': args['input_dimention'][0],
                           'batch_size': len(new_list),
                           'n_channels': args['input_dimention'][-1],
                           'norm_mode': args['normalization_mode']}  
            test_set={}


        
        else:       
            params_test = {'file_name': str(args['input_hdf5']), 
                           'dim': args['input_dimention'][0],
                           'batch_size': len(new_list),
                           'n_channels': args['input_dimention'][-1],
                           'norm_mode': args['normalization_mode']}     
    
            test_generator = DataGeneratorTest(new_list, **params_test)
            
            if args['estimate_uncertainty']:
                pred_PP = []
                for mc in range(args['number_of_sampling']):
                    predP = model.predict_generator(generator=test_generator)
                    pred_PP.append(predP)
                
    
                
                pred_PP = np.array(pred_PP).reshape(args['number_of_sampling'], len(new_list), params_test['dim'])
                pred_PP_mean = pred_PP.mean(axis=0)
                pred_PP_std = pred_PP.std(axis=0) 

            else:          
                pred_PP_mean = model.predict_generator(generator=test_generator)
                pred_PP_mean = pred_PP_mean.reshape(pred_PP_mean.shape[0], pred_PP_mean.shape[1]) 
                
                pred_PP_std = np.zeros((pred_PP_mean.shape))   
                
   
            test_set={}
            fl = h5py.File(args['input_hdf5'], 'r')
            for ID in new_list:
                dataset = fl.get(str(ID))
                test_set.update( {str(ID) : dataset})                 
            
            for ts in range(pred_PP_mean.shape[0]): 
                evi =  new_list[ts] 
                dataset = test_set[evi]  
                
                    
                try:
                    spt = int(dataset.attrs['p_arrival_sample']);
                except Exception:     
                    spt = None
                
                
                Ppick, perror, Pprob =  picker(args, pred_PP_mean[ts], pred_PP_std[ts], spt) 
                
                _output_writter_test(args, dataset, evi, test_writer, csvTst, Ppick, perror, Pprob)
                        
                pred_PP_mean_all.append(pred_PP_mean[ts])
                pred_PP_std_all.append(pred_PP_std[ts])
                sptall.append(spt)

    
    np.save('pred_PP_mean_all_FoundationEQCCT.npy',pred_PP_mean_all)
    np.save('pred_PP_std_all_FoundationEQCCT.npy',pred_PP_std_all)
    np.save('pall_FoundationEQCCT.npy',sptall)
    
    end_training = time.time()  
    delta = end_training - start_training
    hour = int(delta / 3600)
    delta -= hour * 3600
    minute = int(delta / 60)
    delta -= minute * 60
    seconds = delta     
                    
    with open(os.path.join(save_dir,'X_report.txt'), 'a') as the_file: 
        the_file.write('================== Overal Info =============================='+'\n')               
                
        the_file.write('input_hdf5: '+str(args['input_hdf5'])+'\n')            
        the_file.write('input_testset: '+str(args['input_testset'])+'\n')
        the_file.write('input_model: '+str(args['input_model'])+'\n')
        the_file.write('output_name: '+str(args['output_name']+'_outputs')+'\n')  
        the_file.write('================== Testing Parameters ======================='+'\n')  
        the_file.write('mode: '+str(args['mode'])+'\n')  
        the_file.write('finished the test in:  {} hours and {} minutes and {} seconds \n'.format(hour, minute, round(seconds, 3))) 
        the_file.write('batch_size: '+str(args['batch_size'])+'\n')
        the_file.write('total number of tests '+str(len(test))+'\n')
        the_file.write('gpuid: '+str(args['gpuid'])+'\n')
        the_file.write('gpu_limit: '+str(args['gpu_limit'])+'\n')             
        the_file.write('================== Other Parameters ========================='+'\n')            
        the_file.write('normalization_mode: '+str(args['normalization_mode'])+'\n')
        the_file.write('estimate uncertainty: '+str(args['estimate_uncertainty'])+'\n')
        the_file.write('number of Monte Carlo sampling: '+str(args['number_of_sampling'])+'\n')                       
        the_file.write('P_threshold: '+str(args['P_threshold'])+'\n')
        
    
    
def _output_writter_test(args, 
                        dataset, 
                        evi, 
                        output_writer, 
                        csvfile, 
                        Ppick,
                        perror,
                        Pprob
                        ):
    
    """ 
    
    Appends the P-picking result of one trace as a row of the test CSV file.

    The row is written as ``[p_arrival_sample, Ppick[0], Pprob[0], perror[0]]``,
    i.e. in the order of the header columns
    p_arrival_sample, P_pick, P_probability, P_error.

    Parameters
    ----------
    args: dic
        A dictionary containing all of the input parameters.
        Not read by this function; kept so the call signature is unchanged.
 
    dataset: hdf5 obj
        Dataset object of the trace. ``dataset.attrs['data_category']`` is
        required; ``dataset.attrs['p_arrival_sample']`` is read for every
        trace whose category is not 'noise'.

    evi: str
        Trace name. Not read by this function; kept so the call signature
        is unchanged.
              
    output_writer: obj
        For writing out the detection/picking results in the CSV file.
        Only its ``writerow`` method is used.
        
    csvfile: obj
        File object behind ``output_writer``; it is flushed after the row
        has been written.

    Ppick: indexable
        Picker output; only the first element is written (P_pick column).

    perror: indexable
        Picker output; only the first element is written (P_error column).

    Pprob: indexable
        Picker output; only the first element is written (P_probability
        column).
             
        
    Returns
    --------  
    None
        One row is appended through ``output_writer`` as a side effect.
    
        
    """        
    
    if dataset.attrs['data_category'] == 'noise':
        # Noise traces carry no labelled P arrival.
        p_arrival_sample = None
    else:
        try:
            p_arrival_sample = dataset.attrs['p_arrival_sample']
        except KeyError:
            # The calling test loop already falls back to None when the label
            # cannot be read; do the same here so a single unlabelled trace
            # does not abort the whole run after its prediction was computed.
            p_arrival_sample = None

    output_writer.writerow([ 
                            p_arrival_sample, 
                            Ppick[0], 
                            Pprob[0],
                            perror[0],
                            ]) 
    
    # Flush per row so partial results survive an interrupted run.
    csvfile.flush()


In [None]:
# Run the P-picker test on the held-out event list.
#
# `input_model`: Path to the trained model directory. tester1 hands this
#     value to model.load_weights().
#     NOTE(review): the value below ends in 'models/' -- confirm it resolves
#     to the trained weights.
# `input_hdf5`: Path to the full waveform HDF5 file.
# `output_name`: results are written to '<output_name>_outputs' under the
#     current working directory; an existing directory of that name is
#     removed first.

tester1(
    # data, model and output locations
    input_hdf5='/scratch/sadalyom/DataCollected',
    input_testset='test_Events.npy',
    input_model='test_trainer_P_FoundationEQCCT_outputs/models/',
    output_name='test_tester_FoundationEQCCT',
    # input shape and normalization
    input_dimention=(6000, 3),
    normalization_mode='std',
    # thresholds
    detection_threshold=0.1,
    P_threshold=0.1,
    # repeated-prediction uncertainty estimate (disabled for this run)
    estimate_uncertainty=False,
    number_of_sampling=5,
    # execution settings
    mode='generator',
    batch_size=1024,
    number_of_plots=3,
    gpuid=None,
    gpu_limit=None,
)