<a href="https://colab.research.google.com/github/hsyvy/ion-switching/blob/master/ion_switching_wavenet2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab environment development



In [0]:
# mount google drive
from google.colab import drive
drive.mount('/content/gdrive')

#imports

In [1]:
# imports
import os
import gc
import time

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('ggplot')

import warnings
warnings.simplefilter("ignore")
from typing import (List, NoReturn, Union, Tuple, Optional, 
                    Text, Generic, Callable, Dict)


from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from sklearn.model_selection import KFold, GroupKFold, GroupShuffleSplit, train_test_split

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
from tensorflow.keras import models, losses
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import Sequence, to_categorical, get_custom_objects
from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback, ReduceLROnPlateau, LearningRateScheduler
from tensorflow.keras.layers import  Conv1D, Activation, Input, Dense, Add, Multiply



from logging import getLogger, Formatter, StreamHandler, FileHandler, INFO
from contextlib import contextmanager

from tqdm import tqdm_notebook as tqdm

  import pandas.util.testing as tm


# Data

In [2]:
%%time
# clean data without feature engineering
train = pd.read_csv("gdrive/My Drive/data/ion-switching/train_clean.csv")
# test = pd.read_csv("gdrive/My Drive/data/ion-switching/test_clean.csv")

CPU times: user 885 ms, sys: 156 ms, total: 1.04 s
Wall time: 1.1 s


In [3]:
train.head()

Unnamed: 0,time,signal,open_channels
0,0.0001,-2.76,0
1,0.0002,-2.8557,0
2,0.0003,-2.4074,0
3,0.0004,-3.1404,0
4,0.0005,-3.1525,0


# Helper functions

## log manager

In [0]:
def init_logger():
    handler = StreamHandler()
    handler.setLevel(INFO)
    handler.setFormatter(Formatter(LOGFORMAT))
    fh_handler = FileHandler('{}.log'.format(MODELNAME))
    fh_handler.setFormatter(Formatter(LOGFORMAT))
    logger.setLevel(INFO)
    logger.addHandler(handler)
    logger.addHandler(fh_handler)

In [0]:
@contextmanager
def timer(name : Text):
    t0 = time.time()
    yield
    logger.info(f'[{name}] done in {time.time() - t0:.0f} s')

COMPETITION = 'ION-Switching'
logger = getLogger(COMPETITION)
LOGFORMAT = '%(asctime)s %(levelname)s %(message)s'
MODELNAME = 'WaveNet'

## Seed everything

In [0]:
def seed_everything(seed : int) -> NoReturn :
        
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    tf.random.set_seed(seed)

seed_everything(1)

## Data helper

In [0]:
def read_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    
    train = pd.read_csv('gdrive/My Drive/data/ion-switching/train_clean.csv', dtype={'time': np.float32, 'signal': np.float32, 'open_channels':np.int32})
    test  = pd.read_csv('gdrive/My Drive/data/ion-switching/test_clean.csv', dtype={'time': np.float32, 'signal': np.float32})
    
    return train, test

In [0]:
def batching(df : pd.DataFrame,
             batch_size : int) -> pd.DataFrame :
    
    df['group'] = df.groupby(df.index//batch_size, sort=False)['signal'].agg(['ngroup']).values
    df['group'] = df['group'].astype(np.uint16)
        
    return df

In [0]:
def lag_with_pct_change(df : pd.DataFrame,
                        shift_sizes : Optional[List]=[1, 2],
                        add_pct_change : Optional[bool]=False,
                        add_pct_change_lag : Optional[bool]=False) -> pd.DataFrame:
    
    for shift_size in shift_sizes:    
        df['signal_shift_pos_'+str(shift_size)] = df.groupby('group')['signal'].shift(shift_size).fillna(method='bfill')
        df['signal_shift_neg_'+str(shift_size)] = df.groupby('group')['signal'].shift(-1*shift_size).fillna(method='ffill')

    if add_pct_change:
        df['pct_change'] = df['signal'].pct_change()
        if add_pct_change_lag:
            for shift_size in shift_sizes:    
                df['pct_change_shift_pos_'+str(shift_size)] = df.groupby('group')['pct_change'].shift(shift_size).fillna(method='bfill')
                df['pct_change_shift_neg_'+str(shift_size)] = df.groupby('group')['pct_change'].shift(-1*shift_size).fillna(method='ffill')
    return df

In [0]:
def run_feat_enginnering(df : pd.DataFrame,
                         create_all_data_feats : bool,
                         batch_size : int) -> pd.DataFrame:
    
    df = batching(df, batch_size=batch_size)
    if create_all_data_feats:
        df = lag_with_pct_change(df, [1, 2, 3],  add_pct_change=False, add_pct_change_lag=False)
    df['signal_2'] = df['signal'] ** 2
    return df

In [0]:
def feature_selection(df : pd.DataFrame) -> Tuple[pd.DataFrame, List]:
    use_cols = [col for col in df.columns if col not in ['index','group', 'open_channels', 'time']]
    df = df.replace([np.inf, -np.inf], np.nan)
    for col in use_cols:
        col_mean = df[col].mean()
        df[col] = df[col].fillna(col_mean)

    gc.collect()
    return df, use_cols

In [0]:
def augment(X: np.array, y:np.array) -> Tuple[np.array, np.array]:
    
    X = np.vstack((X, np.flip(X, axis=1)))
    y = np.vstack((y, np.flip(y, axis=1)))
    
    return X, y

In [0]:
def normalize(train, test):
    
    train_input_mean = train.signal.mean()
    train_input_sigma = train.signal.std()
    train['signal'] = (train.signal-train_input_mean)/train_input_sigma
    test['signal'] = (test.signal-train_input_mean)/train_input_sigma

    return train, test

## Lr scheduler

In [0]:
def lr_schedule(epoch):
    if epoch < 10:
        lr = LR
    elif epoch < 15:
        lr = LR / 3
    elif epoch < 20:
        lr = LR / 6
    elif epoch < 75:
        lr = LR / 9
    elif epoch < 85:
        lr = LR / 12
    elif epoch < 100:
        lr = LR / 15
    else:
        lr = LR / 50
    return lr


## Mish activation function

In [0]:
class Mish(tf.keras.layers.Layer):

    def __init__(self, **kwargs):
        super(Mish, self).__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        return inputs * K.tanh(K.softplus(inputs))

    def get_config(self):
        base_config = super(Mish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
def mish(x):
	return tf.keras.layers.Lambda(lambda x: x*K.tanh(K.softplus(x)))(x)
 

get_custom_objects().update({'mish': Activation(mish)})

## Focal loss and F1 matric

In [0]:
def categorical_focal_loss(gamma=2.0, alpha=0.25):
    """
    Implementation of Focal Loss from the paper in multiclass classification
    Formula:
        loss = -alpha*((1-p)^gamma)*log(p)
    Parameters:
        alpha -- the same as wighting factor in balanced cross entropy
        gamma -- focusing parameter for modulating factor (1-p)
    Default value:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper
    """
    def focal_loss(y_true, y_pred):
        epsilon = K.epsilon()
        
        # Clip the prediction value
        y_pred = K.clip(y_pred, epsilon, 1.0-epsilon)
        # Calculate cross entropy
        cross_entropy = -y_true*K.log(y_pred)
        # Calculate weight that consists of  modulating factor and weighting factor
        weight = alpha * y_true * K.pow((1-y_pred), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses in mini_batch
        loss = K.sum(loss, axis=1)
        return loss
    
    return focal_loss

In [0]:
class MacroF1(Callback):
    def __init__(self, model, inputs, targets):
        self.model = model
        self.inputs = inputs
        self.targets = np.argmax(targets, axis=2).reshape(-1)

    def on_epoch_end(self, epoch, logs):
        pred = np.argmax(self.model.predict(self.inputs), axis=2).reshape(-1)
        score = f1_score(self.targets, pred, average="macro")
        print(f' F1Macro: {score:.5f}')    


# parameters

In [0]:
EPOCHS=9
NNBATCHSIZE=20
BATCHSIZE = 10000
SEED = 23
SELECT = True
LR = 0.001
fe_config = [
    (True, 10000),
]

# wavenet

In [0]:
def WaveNetResidualConv1D(num_filters, kernel_size, stacked_layer):
    def build_residual_block(l_input):
        resid_input = l_input
        for dilation_rate in [2**i for i in range(stacked_layer)]:
            l_sigmoid_conv1d = Conv1D(num_filters, kernel_size, 
                                      dilation_rate=dilation_rate,
                                      padding='same', 
                                      activation='sigmoid')(l_input)
            l_tanh_conv1d = Conv1D(num_filters, kernel_size, 
                                   dilation_rate=dilation_rate,
                                   padding='same', 
                                   activation='mish')(l_input)
            l_input = Multiply()([l_sigmoid_conv1d, l_tanh_conv1d])
            l_input = Conv1D(num_filters, 1, padding='same')(l_input)
            resid_input = Add()([resid_input ,l_input])
        return resid_input
    return build_residual_block

In [0]:
def Classifier(shape_):
    num_filters_ = 16
    kernel_size_ = 3    
    stacked_layers_ = [12, 8, 4, 1]
    inp = Input(shape=(shape_))
    x = Conv1D(num_filters_, 1, padding='same')(inp)
    x = WaveNetResidualConv1D(num_filters_, kernel_size_, stacked_layers_[0])(x)
    x = Conv1D(num_filters_*2, 1, padding='same')(x)
    x = WaveNetResidualConv1D(num_filters_*2, kernel_size_, stacked_layers_[1])(x)
    x = Conv1D(num_filters_*4, 1, padding='same')(x)
    x = WaveNetResidualConv1D(num_filters_*4, kernel_size_, stacked_layers_[2])(x)
    x = Conv1D(num_filters_*8, 1, padding='same')(x)
    x = WaveNetResidualConv1D(num_filters_*8, kernel_size_, stacked_layers_[3])(x)

    out = Dense(11, activation='softmax')(x)
    model = models.Model(inputs=[inp], outputs=[out])
    return model

# opt = Adam(lr=LR)
# opt = tfa.optimizers.SWA(opt)
# model = Classifier()
# model.compile(loss=losses.CategoricalCrossentropy(), optimizer=opt, metrics=["accuracy"])


## train

In [0]:
def model_train(train : pd.DataFrame,
                batch_col : Text,
                feats : List,
                nn_epochs : int,
                nn_batch_size : int) -> NoReturn:
    
    seed_everything(SEED)

    # tensorflow session 
    K.clear_session()
    config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)
    tf.compat.v1.keras.backend.set_session(sess)

    

    # train and val data split
    group = train['group']
    train_id = list()
    val_id = list()
    for i in range(0, 5000000, 500000):
        train_id.extend(list(range(i, i+450000)))
        val_id.extend(list(range(i+450000, i+500000)))
    train_id = np.asarray(train_id)
    val_id = np.asarray(val_id)
    train_val_ids = []
    train_val_ids.append(np.unique(group[train_id]))
    train_val_ids.append(np.unique(group[val_id]))
 

    tr = pd.concat([pd.get_dummies(train.open_channels), train[['group']]], axis=1)

    tr.columns = ['target_'+str(i) for i in range(11)] + ['group']
    target_cols = ['target_'+str(i) for i in range(11)]
    train_tr = np.array(list(tr.groupby('group').apply(lambda x: x[target_cols].values))).astype(np.float32)
    train = np.array(list(train.groupby('group').apply(lambda x: x[feats].values)))


    train_x, train_y = train[train_val_ids[0]], train_tr[train_val_ids[0]]
    valid_x, valid_y = train[train_val_ids[1]], train_tr[train_val_ids[1]]
        
    train_x, train_y = augment(train_x, train_y)

    gc.collect()

    shape_ = (None, train_x.shape[2])
    model = Classifier(shape_)
    opt = Adam(lr=LR)
    opt = tfa.optimizers.SWA(opt)
    model.compile(loss=categorical_focal_loss(), optimizer=opt, metrics=["accuracy"])

    cb_lr_schedule = LearningRateScheduler(lr_schedule)
    cb_prg = tfa.callbacks.TQDMProgressBar(leave_epoch_progress=False,
                                           leave_overall_progress=False, 
                                           show_epoch_progress=True,
                                           show_overall_progress=True)       
    model.fit(train_x,train_y,
              epochs=nn_epochs,
            #   callbacks=[cb_lr_schedule],
              batch_size=nn_batch_size, verbose=1,
              validation_data=(valid_x,valid_y))
    
    # predictions and performance matrics
    y_preds = model.predict(valid_x)
    y_true_ = np.argmax(valid_y, axis=2).reshape(-1)
    y_preds_ = np.argmax(y_preds, axis=2).reshape(-1)
    accuracy_score_ = accuracy_score(y_true_, y_preds_)
    precision_score_ = precision_score(y_true_, y_preds_, average = 'macro')
    recall_score_ = recall_score(y_true_, y_preds_, average = 'macro')
    f1_score_ = f1_score(y_true_, y_preds_, average = 'macro')
    cm = confusion_matrix(y_true_, y_preds_, labels=np.unique(y_true_))

    
    logger.info(f'Training completed')
    logger.info(f'accuracy score: {accuracy_score_:1.5f}')
    logger.info(f'precision score: {precision_score_:1.5f}')
    logger.info(f'recall score: {recall_score_:1.5f}')
    logger.info(f'macro f1 score: {f1_score_:1.5f}')
    logger.info(f'confusion matrix\n{cm}')

    return 


In [0]:
def run_all():
    not_feats_cols = ['time']
    target_col = ['open_channels']
    init_logger()
    with timer(f'Reading Data'):
        logger.info('Reading Data Started ...')
        train, test = read_data()
        train, test = normalize(train, test)    
        logger.info('Reading and Normalizing Data Completed ...')
    with timer(f'Creating Features'):
        logger.info('Feature Enginnering Started ...')
        for config in fe_config:
            train = run_feat_enginnering(train, create_all_data_feats=config[0], batch_size=config[1])
        train, feats = feature_selection(train)
        logger.info('Feature Enginnering Completed ...')

    with timer(f'Running Wavenet model'):
        model_train(train, batch_col='group', feats=feats,  nn_epochs=EPOCHS, nn_batch_size=NNBATCHSIZE)
        logger.info(f'Training completed ...')

In [23]:
run_all()

2020-06-03 17:29:09,332 INFO Reading Data Started ...
2020-06-03 17:29:10,775 INFO Reading and Normalizing Data Completed ...
2020-06-03 17:29:10,776 INFO [Reading Data] done in 1 s
2020-06-03 17:29:10,778 INFO Feature Enginnering Started ...
2020-06-03 17:29:13,690 INFO Feature Enginnering Completed ...
2020-06-03 17:29:13,691 INFO [Creating Features] done in 3 s


Epoch 1/9
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Epoch 2/9
Epoch 3/9
Epoch 4/9
Epoch 5/9
Epoch 6/9
Epoch 7/9
Epoch 8/9
Epoch 9/9


2020-06-03 17:31:58,629 INFO Training completed
2020-06-03 17:31:58,632 INFO accuracy score: 0.96763
2020-06-03 17:31:58,633 INFO precision score: 0.93484
2020-06-03 17:31:58,635 INFO recall score: 0.93676
2020-06-03 17:31:58,636 INFO macro f1 score: 0.93578
2020-06-03 17:31:58,638 INFO confusion matrix
[[123785    226      0      0      0      0      0      0      0      0
       0]
 [   315  98565    323      0      0      0      0      0      0      0
       0]
 [     0    480  54906    561      0      0      1      0      0      0
       0]
 [     0      3    864  65348    821      0      0      0      0      0
       0]
 [     0      0      7    488  38939    799      0      0      0      0
       0]
 [     0      0      1     28    622  25172    862      0      0      0
       0]
 [     0      0      0      0      2    755  16581   1574      0      0
       0]
 [     0      0      1      0      0      1   1231  23622   1590      0
       0]
 [     0      0      0      0      0   

# logs



2020-06-01 12:33:42,685 INFO Training fold 1 completed. macro f1 score : 0.93829

2020-06-01 13:19:43,528 INFO Training fold 2 completed. macro f1 score : 0.93830

2020-06-01 13:44:03,020 INFO Training fold 3 completed. macro f1 score : 0.93727

2020-06-01 14:08:18,136 INFO Training fold 4 completed. macro f1 score : 0.93784

2020-06-01 14:32:34,309 INFO Training fold 5 completed. macro f1 score : 0.93721

2020-06-01 14:32:38,380 INFO Training completed. oof macro f1 score : 0.93779



45/45 [==============================] - 16s 360ms/step - loss: 121.2756 - accuracy: 0.6884 - val_loss: 54.4157 - val_accuracy: 0.8288
Epoch 2/9
45/45 [==============================] - 13s 293ms/step - loss: 44.4338 - accuracy: 0.8666 - val_loss: 23.6422 - val_accuracy: 0.9151
Epoch 3/9
45/45 [==============================] - 13s 293ms/step - loss: 18.9142 - accuracy: 0.9330 - val_loss: 11.2660 - val_accuracy: 0.9481
Epoch 4/9
45/45 [==============================] - 13s 292ms/step - loss: 10.1861 - accuracy: 0.9550 - val_loss: 6.9352 - val_accuracy: 0.9636
Epoch 5/9
45/45 [==============================] - 13s 292ms/step - loss: 7.9574 - accuracy: 0.9612 - val_loss: 6.8401 - val_accuracy: 0.9637
Epoch 6/9
45/45 [==============================] - 13s 293ms/step - loss: 7.2419 - accuracy: 0.9629 - val_loss: 6.4256 - val_accuracy: 0.9643
Epoch 7/9
45/45 [==============================] - 13s 292ms/step - loss: 7.3751 - accuracy: 0.9621 - val_loss: 6.0553 - val_accuracy: 0.9661
Epoch 8/9
45/45 [==============================] - 13s 292ms/step - loss: 6.7684 - accuracy: 0.9643 - val_loss: 5.8179 - val_accuracy: 0.9672
Epoch 9/9
45/45 [==============================] - 13s 294ms/step - loss: 6.5601 - accuracy: 0.9646 - val_loss: 5.6661 - val_accuracy: 0.9676


2020-06-03 17:31:58,629 INFO Training completed

2020-06-03 17:31:58,632 INFO accuracy score: 0.96763

2020-06-03 17:31:58,633 INFO precision score: 0.93484

2020-06-03 17:31:58,635 INFO recall score: 0.93676

2020-06-03 17:31:58,636 INFO macro f1 score: 0.93578

2020-06-03 17:31:58,638 INFO confusion matrix

[[123785    226      0      0      0      0      0      0      0      0  0]

 [   315  98565    323      0      0      0      0      0      0      0  0]

 [     0    480  54906    561      0      0      1      0      0      0  0]

 [     0      3    864  65348    821      0      0      0      0      0  0]

 [     0      0      7    488  38939    799      0      0      0      0  0]

 [     0      0      1     28    622  25172    862      0      0      0  0]

 [     0      0      0      0      2    755  16581   1574      0      0  0]

 [     0      0      1      0      0      1   1231  23622   1590      0  0]

 [     0      0      0      0      0      0      0   1636  21700   1238  0]

 [     0      0      0      2      0      0      0      0    973  12126 426]

 [     0      0      0      0      0      0      0      0      0    357 3069]]

2020-06-03 17:31:58,643 INFO Training completed ...
2020-06-03 17:31:58,644 INFO [Running Wavenet model] done in 165 s

In [24]:
train, feats = prepare_data(fe_config)

2020-06-03 16:03:41,794 INFO Reading Data Started ...
2020-06-03 16:03:43,981 INFO Reading and Normalizing Data Completed ...
2020-06-03 16:03:43,982 INFO [Reading Data] done in 2 s
2020-06-03 16:03:43,983 INFO Feature Enginnering Started ...
2020-06-03 16:03:46,940 INFO Feature Enginnering Completed ...
2020-06-03 16:03:46,941 INFO [Creating Features] done in 3 s


In [0]:
def prepare_data(fe_config : List) -> Tuple[pd.DataFrame, List]:
    not_feats_cols = ['time']
    target_col = ['open_channels']
    init_logger()
    with timer(f'Reading Data'):
        logger.info('Reading Data Started ...')
        train, test = read_data()
        train, test = normalize(train, test)    
        logger.info('Reading and Normalizing Data Completed ...')
    with timer(f'Creating Features'):
        logger.info('Feature Enginnering Started ...')
        for config in fe_config:
            train = run_feat_enginnering(train, create_all_data_feats=config[0], batch_size=config[1])
        train, feats = feature_selection(train)
        logger.info('Feature Enginnering Completed ...')
    return train, feats



In [0]:
# train and val data split
group = train['group']
train_id = list()
val_id = list()
for i in range(0, 5000000, 500000):
    train_id.extend(list(range(i, i+450000)))
    val_id.extend(list(range(i+450000, i+500000)))
train_id = np.asarray(train_id)
val_id = np.asarray(val_id)
train_val_ids = []
train_val_ids.append(np.unique(group[train_id]))
train_val_ids.append(np.unique(group[val_id]))


In [0]:
def run_cv_model_by_batch(train : pd.DataFrame,
                          test : pd.DataFrame,
                          splits : int,
                          batch_col : Text,
                          feats : List,
                          nn_epochs : int,
                          nn_batch_size : int) -> NoReturn:
    
    seed_everything(SEED)
    K.clear_session()
    config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1,inter_op_parallelism_threads=1)
    sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=config)
    tf.compat.v1.keras.backend.set_session(sess)
    oof_ = np.zeros((len(train), 11))
    preds_ = np.zeros((len(test), 11))
    target = ['open_channels']
    group = train['group']
    kf = GroupKFold(n_splits=5)
    splits = [x for x in kf.split(train, train[target], group)]

    new_splits = []
    for sp in splits:
        new_split = []
        new_split.append(np.unique(group[sp[0]]))
        new_split.append(np.unique(group[sp[1]]))
        new_split.append(sp[1])    
        new_splits.append(new_split)
        
    tr = pd.concat([pd.get_dummies(train.open_channels), train[['group']]], axis=1)

    tr.columns = ['target_'+str(i) for i in range(11)] + ['group']
    target_cols = ['target_'+str(i) for i in range(11)]
    train_tr = np.array(list(tr.groupby('group').apply(lambda x: x[target_cols].values))).astype(np.float32)
    train = np.array(list(train.groupby('group').apply(lambda x: x[feats].values)))
    test = np.array(list(test.groupby('group').apply(lambda x: x[feats].values)))

    for n_fold, (tr_idx, val_idx, val_orig_idx) in enumerate(new_splits[0:], start=0):
        train_x, train_y = train[tr_idx], train_tr[tr_idx]
        valid_x, valid_y = train[val_idx], train_tr[val_idx]
        
        if n_fold < 2:
            train_x, train_y = augment(train_x, train_y)

        gc.collect()
        shape_ = (None, train_x.shape[2])
        # print("before model classifier")
        model = Classifier(shape_)
        # print("after model classifier")
        cb_lr_schedule = LearningRateScheduler(lr_schedule)
        cb_prg = tfa.callbacks.TQDMProgressBar(leave_epoch_progress=False,leave_overall_progress=False, show_epoch_progress=False,show_overall_progress=True)
        print("before model fit")
       
        model.fit(train_x,train_y,
                  epochs=nn_epochs,
                  callbacks=[cb_prg, cb_lr_schedule],
                  batch_size=nn_batch_size,verbose=0,
                  validation_data=(valid_x,valid_y))
        print("after model fit")
        preds_f = model.predict(valid_x)
        f1_score_ = f1_score(np.argmax(valid_y, axis=2).reshape(-1),  np.argmax(preds_f, axis=2).reshape(-1), average = 'macro')
        logger.info(f'Training fold {n_fold + 1} completed. macro f1 score : {f1_score_ :1.5f}')
        preds_f = preds_f.reshape(-1, preds_f.shape[-1])
        oof_[val_orig_idx,:] += preds_f
        te_preds = model.predict(test)
        te_preds = te_preds.reshape(-1, te_preds.shape[-1])           
        preds_ += te_preds / SPLITS
    f1_score_ =f1_score(np.argmax(train_tr, axis=2).reshape(-1),  np.argmax(oof_, axis=1), average = 'macro')
    logger.info(f'Training completed. oof macro f1 score : {f1_score_:1.5f}')
    # sample_submission['open_channels'] = np.argmax(preds_, axis=1).astype(int)
    # sample_submission.to_csv('submission.csv', index=False, float_format='%.4f')
    # display(sample_submission.head())
    np.save('oof.npy', oof_)
    np.save('preds.npy', preds_)

    return 
