In [None]:
!pip install -q loguru
!pip install -q mlcrate
!pip install -q omegaconf
!pip install -q segmentation_models
!pip install -q iterative-stratification

In [None]:
import os, glob
import sys
import random
import math
import numpy as np
import pandas as pd
import tqdm
from collections import OrderedDict

from omegaconf import OmegaConf
from loguru import logger
import mlcrate as mlc

from sklearn.model_selection import StratifiedKFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

---

In [None]:
# Experiment configuration: every tunable knob lives here. The plain dict is wrapped
# in an OmegaConf object so later cells can use attribute access (config.a.b).
config = dict(
    expr_name='expr007',
    save_root='./',

    data_root='../input/imet-2020-fgvc7/',

    # 'gpu' pins CUDA_VISIBLE_DEVICES to `id`, 'cpu' sets it to '-1',
    # 'tpu' selects a TPUStrategy in the TensorFlow setup cell.
    device=dict(
        name='gpu',
        id='0',
    ),

    num_epochs=28,
    batch_size=32,  # per replica; multiplied by the replica count in the setup cell
    run_fold1_only=True,

    # NOTE(review): total_epochs (20) is smaller than num_epochs (28); check how the
    # scheduler behaves for epochs beyond total_epochs.
    scheduler=dict(
        name='CosineScheduler',
        params=dict(
            init_lr=1e-3,
            min_lr=1e-5,
            total_epochs=20,
        ),
    ),

    # Base image sizes / augmentation; overridden per stage by `stages.transforms`.
    transform=dict(
        val_size=(320, 320),
        train_size=(320, 320),
        aug_name='rand_augment',
        aug_params=dict(
            num_augments=2,
            magnitude=5,
        ),
    ),

    # `base_func` is an expression string that is eval()'d when models are built.
    model=dict(
        base_func="cs.Classifiers.get('resnet34')[0]",
        base_weights='imagenet',
        dropout_rate=0.1,
    ),

    # Progressive resizing: at each epoch listed in `epochs`, the entry with the same
    # index in `transforms` / `models` is merged into `transform` / `model`.
    stages=dict(
        epochs=[0, 6, 14, 20],
        transforms=[
            {'train_size': (side, side), 'aug_params': {'num_augments': 2, 'magnitude': mag}}
            for side, mag in [(224, 2), (256, 3), (288, 4), (320, 5)]
        ],
        models=[{'dropout_rate': rate} for rate in (0.0, 0.0, 0.1, 0.2)],
    ),

    loss=dict(
        name='tf.keras.losses.BinaryCrossentropy',
        params=dict(reduction='none'),  # per-sample losses; averaged in the train step
        weight_decay=1e-5,
    ),

    optimizer=dict(
        name='tf.keras.optimizers.Adam',
        params=dict(),
    ),

    seed=8888,
    num_classes=3474,
    num_folds=5,
)

config = OmegaConf.create(config)
# Derived output directory: <save_root>/<expr_name>.
OmegaConf.update(config, 'save_path', os.path.join(config.save_root, config.expr_name), merge=True)

In [None]:
# Restrict the CUDA devices TensorFlow may see (TensorFlow is imported in the next cell).
# Any other device name, e.g. 'tpu', leaves the environment variable untouched.
if config.device.name == 'cpu':
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide every GPU
elif config.device.name == 'gpu':
    os.environ['CUDA_VISIBLE_DEVICES'] = config.device.id

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import classification_models.tfkeras as cs
import efficientnet.tfkeras as eff


def seed_everything(seed):
    """Seed Python's `random`, NumPy's global RNG and TensorFlow's global seed.

    Also exports PYTHONHASHSEED. Note that the interpreter reads that variable only
    at startup, so it affects child processes started afterwards, not this one.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)

seed_everything(config.seed)


# Choose the distribution strategy: TPUStrategy when config.device.name == 'tpu',
# otherwise TensorFlow's default strategy.
if config.device.name == 'tpu':
    from kaggle_datasets import KaggleDatasets  # NOTE(review): not used in the visible cells
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable API; tf.distribute.experimental.TPUStrategy is the deprecated alias.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

# Linear scaling: global batch size and both learning-rate bounds grow with the replica count.
# NOTE: not idempotent - re-running this cell in the same kernel scales the values again.
OmegaConf.update(config, 'batch_size', config.batch_size * strategy.num_replicas_in_sync, merge=True)
OmegaConf.update(config, 'scheduler.params.init_lr', config.scheduler.params.init_lr * strategy.num_replicas_in_sync, merge=True)
OmegaConf.update(config, 'scheduler.params.min_lr', config.scheduler.params.min_lr * strategy.num_replicas_in_sync, merge=True)

In [None]:
'''
Train/Val List - CUSTOM
'''

df = pd.read_csv(os.path.join(config.data_root, 'train.csv'))
# One image path per row of train.csv: <data_root>/train-320/<id>.jpg
fpath_list = np.array(os.path.join(config.data_root, 'train-320/') + df.id + '.jpg')
# NOTE(review): assumes labels.npy rows follow train.csv row order - confirm.
label_list = np.load(os.path.join(config.data_root, 'labels.npy'))
# Precomputed fold id per sample, stored per (num_folds, seed) combination.
fold_assignments = np.load(os.path.join(config.data_root, f'fold-assignment_K-{config.num_folds}_seed-{config.seed}.npy'))

# All three arrays are sliced with the same boolean fold masks, so they must align row for row.
assert len(fpath_list) == len(label_list) == len(fold_assignments), (
    f'misaligned inputs: {len(fpath_list)} paths, {len(label_list)} labels, '
    f'{len(fold_assignments)} fold ids'
)


def get_train_val_list(fold_idx):
    """Split the module-level path/label arrays into train and validation parts.

    Samples whose `fold_assignments` entry equals `fold_idx` form the validation
    set; every other sample goes to the training set.

    Args:
        fold_idx: fold id held out for validation.

    Returns:
        ((train_fpaths, train_labels), (val_fpaths, val_labels))
    """
    in_val = fold_assignments == fold_idx
    in_train = ~in_val
    return (
        (fpath_list[in_train], label_list[in_train]),
        (fpath_list[in_val], label_list[in_val]),
    )

In [None]:
'''
Transform
'''
def load_image(fine_size, aug_func):
    """Build a tf.data map function that reads, resizes and optionally augments an image.

    Args:
        fine_size: (height, width) tuple the image is resized to.
        aug_func: callable applied to the resized uint8 image, or None to skip it.

    Returns:
        A function (fpath, label) -> (float32 image scaled to [0, 1], label).
    """
    @tf.function
    def load_image_(fpath, label):
        # channels=3 forces RGB output, so grayscale JPEGs still satisfy the
        # (H, W, 3) shape asserted below instead of failing ensure_shape.
        img = tf.image.decode_jpeg(tf.io.read_file(fpath), channels=3)
        img = tf.cast(tf.image.resize(img, fine_size), tf.uint8)
        img = tf.ensure_shape(img, fine_size + (3,))

        if aug_func is not None:
            img = aug_func(img)

        img = tf.cast(img, tf.float32) / 255.
        return img, label
    return load_image_

def augment():
    """Return a simple flip + photometric-jitter augmentation for uint8 images.

    The returned function mirrors the image left/right at random, then jitters
    contrast, saturation and brightness, in that order.
    """
    @tf.function
    def augment_(img):
        flipped = tf.image.random_flip_left_right(img)
        contrasted = tf.image.random_contrast(flipped, 0.6, 1.2)
        saturated = tf.image.random_saturation(contrasted, 0.6, 1.2)
        return tf.image.random_brightness(saturated, 0.2)
    return augment_

def rand_augment(num_augments, magnitude):
    """Build an augmentation function backed by RandAugment.

    Args:
        num_augments: number of randomly chosen operations applied per image.
        magnitude: strength level, interpreted relative to RandAugment.max_level
            (10 by default).

    Returns:
        A function uint8 image -> augmented uint8 image.
    """
    ra = RandAugment()
    @tf.function
    def rand_augment_(img):
        # Tensors are immutable: apply() returns a new tensor, which must be what
        # we return. Previously the result was dropped and the input came back unchanged.
        augmented = ra.apply(img, num_augments=num_augments, magnitude=magnitude)
        # Restore the static shape, which tf.cond branches inside apply() may lose.
        return tf.ensure_shape(augmented, img.shape)
    return rand_augment_

'''
Rand Augment
'''
class RandAugment:
    """TensorFlow RandAugment for a single uint8 image of shape (H, W, 3).

    `apply` draws `num_augments` operations uniformly at random from
    `self.operations` and applies them one after another. Every operation takes
    `(image, level)` and returns a uint8 image; `level` is divided by `max_level`
    to obtain a relative strength.

    Args:
        max_level: level corresponding to an operation's full strength.
        cutout_const: cutout mask size at full strength.
        translate_const: translation in pixels at full strength.
        replace_value: fill value for pixels exposed by shear/translate and cutout.
    """
    def __init__(self, max_level=10, cutout_const=80, translate_const=100, replace_value=0):
        self.max_level = max_level
        self.cutout_const = cutout_const
        self.translate_const = translate_const
        self.replace = replace_value
        # Candidate operations; `apply` selects among them by index.
        self.operations = [
            self.identity,
            self.autocontrast,
            self.equalize,
            self.invert,
            self.rotate,
            self.posterize,
            self.solarize,
            self.solarize_add,
            self.color,
            self.contrast,
            self.brightness,
            self.sharpness,
            self.shear_x,
            self.shear_y,
            self.translate_x,
            self.translate_y,
            self.cutout,
        ]

    @tf.function
    def apply(self, uint8_image, num_augments=2, magnitude=5, constant=True):
        """Apply `num_augments` randomly selected operations (drawn with replacement).

        Args:
            uint8_image: image to augment.
            num_augments: number of operations to draw.
            magnitude: level given to each operation (upper bound when not constant).
            constant: if True use `magnitude` directly; otherwise draw the level
                uniformly from [0, magnitude) for each drawn operation.

        Returns:
            The augmented uint8 image.
        """
        for _ in range(num_augments):
            op_to_select = tf.random.uniform((), 0, len(self.operations), dtype=tf.int32)
            if constant:
                level = tf.cast(magnitude, tf.float32)
            else:
                level = tf.random.uniform((), 0., tf.cast(magnitude, tf.float32), dtype=tf.float32)
            # Only the branch whose index matches op_to_select transforms the image;
            # all other branches pass it through unchanged.
            for i, operation in enumerate(self.operations):
                uint8_image = tf.cond(
                    tf.equal(i, op_to_select),
                    lambda: operation(uint8_image, level),
                    lambda: uint8_image
                )
        return uint8_image

    @tf.function
    def _rotate_level_to_arg(self, rel_level):
        """Rotation angle in degrees: rel_level * 30 with a random sign."""
        level = rel_level * 30.
        level = level if tf.random.normal(()) > 0. else -level
        return level

    @tf.function
    def _enhance_level_to_arg(self, rel_level):
        """Blend factor rel_level * 1.8 + 0.1 (0.1..1.9 for rel_level in [0, 1])."""
        return rel_level * 1.8 + 0.1

    @tf.function
    def _shear_level_to_arg(self, rel_level):
        """Shear coefficient rel_level * 0.3 with a random sign."""
        level = rel_level * 0.3
        level = level if tf.random.normal(()) > 0. else -level
        return level

    @tf.function
    def _translate_level_to_arg(self, rel_level, translate_const):
        """Translation in pixels: rel_level * translate_const with a random sign."""
        level = rel_level * tf.cast(translate_const, tf.float32)
        level = level if tf.random.normal(()) > 0. else -level
        return level

    # -------------------------------------------------------------------------------------

    @tf.function
    def identity(self, image, level):
        """Return the image unchanged."""
        return image

    @tf.function
    def autocontrast(self, image, level):
        """Stretch each channel independently so its min maps to 0 and its max to 255."""
        @tf.function
        def scale_channel(image):
            """Scale the 2D image using the autocontrast rule."""
            # A possibly cheaper version can be done using cumsum/unique_with_counts
            # over the histogram values, rather than iterating over the entire image.
            # to compute mins and maxes.
            lo = tf.cast(tf.reduce_min(image), tf.float32)
            hi = tf.cast(tf.reduce_max(image), tf.float32)
            # Scale the image, making the lowest value 0 and the highest value 255.
            def scale_values(im):
                scale = 255.0 / (hi - lo)
                offset = -lo * scale
                im = tf.cast(im, tf.float32) * scale + offset
                im = tf.clip_by_value(im, 0.0, 255.0)
                return tf.cast(im, tf.uint8)
            # Flat channels (hi == lo) are left as they are to avoid dividing by zero.
            result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
            return result
        s1 = scale_channel(image[:, :, 0])
        s2 = scale_channel(image[:, :, 1])
        s3 = scale_channel(image[:, :, 2])
        image = tf.stack([s1, s2, s3], axis=-1)
        return image

    @tf.function
    def equalize(self, image, level):
        """Histogram-equalize the image (level is ignored)."""
        return tfa.image.equalize(image)

    @tf.function
    def invert(self, image, level):
        """Bitwise-invert every pixel (level is ignored)."""
        return tf.bitwise.invert(image)

    @tf.function
    def rotate(self, image, level):
        """Rotate by up to 30 degrees in either direction."""
        degree = self._rotate_level_to_arg(level/self.max_level)
        radian = degree * math.pi / 180.0
        return tfa.image.rotate(image, radian)

    @tf.function
    def posterize(self, image, level):
        """Keep only the top `bits` bits of each pixel, bits = int(rel_level * 4)."""
        # NOTE(review): for level < max_level / 4, bits is 0 and the shift equals the
        # full 8-bit width - confirm the intended result for low magnitudes.
        bits = tf.cast((level/self.max_level) * 4, tf.uint8)
        shift = 8 - bits
        return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)

    @tf.function
    def solarize(self, image, level):
        """Invert pixels at or above a threshold of rel_level * 256."""
        # NOTE(review): at level == max_level the threshold 256 does not fit in uint8.
        threshold = tf.cast((level/self.max_level) * 256, tf.uint8)
        return tf.where(image < threshold, image, 255 - image)

    @tf.function
    def solarize_add(self, image, level):
        """Add rel_level * 110 to pixels below 128, clipping to [0, 255]."""
        addition = tf.cast((level/self.max_level) * 110, tf.int64)
        added_image = tf.cast(image, tf.int64) + addition
        added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
        return tf.where(image < 128, added_image, image)

    @tf.function
    def color(self, image, level):
        """Blend between the grayscale version and the original image (PIL Color)."""
        factor = self._enhance_level_to_arg(level/self.max_level)
        degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
        return tf.cast(tfa.image.blend(tf.cast(degenerate, tf.float32), tf.cast(image, tf.float32), factor), tf.uint8)

    @tf.function
    def contrast(self, image, level):
        """Equivalent of PIL Contrast: blend with a flat image at the mean gray level."""
        factor = self._enhance_level_to_arg(level/self.max_level)
        degenerate = tf.image.rgb_to_grayscale(image)
        # Mean grayscale intensity. The previous histogram code computed
        # sum(hist) / 256, which is pixel_count / 256 rather than a mean, and
        # saturated the blend target to 255 for any realistic image size.
        mean = tf.reduce_mean(tf.cast(degenerate, tf.float32))
        # Flat image of that mean value, used as the blending degenerate target.
        degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
        degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
        degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
        return tf.cast(tfa.image.blend(tf.cast(degenerate, tf.float32), tf.cast(image, tf.float32), factor), tf.uint8)

    @tf.function
    def brightness(self, image, level):
        """Blend between a black image and the original image (PIL Brightness)."""
        factor = self._enhance_level_to_arg(level/self.max_level)
        degenerate = tf.zeros_like(image)
        return tf.cast(tfa.image.blend(tf.cast(degenerate, tf.float32), tf.cast(image, tf.float32), factor), tf.uint8)

    @tf.function
    def sharpness(self, image, level):
        """Implements Sharpness function from PIL using TF ops."""
        factor = self._enhance_level_to_arg(level/self.max_level)
        orig_image = image
        image = tf.cast(image, tf.float32)
        # Make image 4D for conv operation.
        image = tf.expand_dims(image, 0)
        # SMOOTH PIL Kernel.
        kernel = tf.constant(
          [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
          shape=[3, 3, 1, 1]
        ) / 13.
        # Tile across channel dimension.
        kernel = tf.tile(kernel, [1, 1, 3, 1])
        strides = [1, 1, 1, 1]
        with tf.device('/cpu:0'):
            # Some augmentation that uses depth-wise conv will cause crashing when
            # training on GPU. See (b/156242594) for details.
            degenerate = tf.nn.depthwise_conv2d(
                image, kernel, strides, padding='VALID', dilations=[1, 1]
            )
        degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
        degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
        # For the borders of the resulting image, fill in the values of the
        # original image.
        mask = tf.ones_like(degenerate)
        padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
        padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
        result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
        # Blend the final result.
        return tf.cast(tfa.image.blend(tf.cast(result, tf.float32), tf.cast(orig_image, tf.float32), factor), tf.uint8)

    @tf.function
    def shear_x(self, image, level):
        """Horizontal shear, filling exposed pixels with `self.replace`."""
        level = self._shear_level_to_arg(level/self.max_level)
        return tfa.image.shear_x(image, level, self.replace)

    @tf.function
    def shear_y(self, image, level):
        """Vertical shear, filling exposed pixels with `self.replace`."""
        level = self._shear_level_to_arg(level/self.max_level)
        return tfa.image.shear_y(image, level, self.replace)

    @tf.function
    def cutout(self, image, level):
        """Cut out a random patch of size rel_level * cutout_const, filled with `self.replace`."""
        mask_size = int((level/self.max_level) * self.cutout_const)
        return tfa.image.random_cutout(tf.expand_dims(image, 0), mask_size, self.replace)[0]

    @tf.function
    def translate_x(self, image, level):
        """Horizontal translation by up to translate_const pixels."""
        pixels = self._translate_level_to_arg(level/self.max_level, self.translate_const)
        return tfa.image.translate_xy(image, [-pixels, 0], self.replace)

    @tf.function
    def translate_y(self, image, level):
        """Vertical translation by up to translate_const pixels."""
        pixels = self._translate_level_to_arg(level/self.max_level, self.translate_const)
        return tfa.image.translate_xy(image, [0, -pixels], self.replace)

In [None]:
'''
DataLoader
'''
def get_ds(train_data, val_data, batch_size, train_size, val_size, aug_func):
    """Build the training and validation tf.data pipelines.

    Args:
        train_data: (file_paths, labels) for training.
        val_data: (file_paths, labels) for validation.
        batch_size: samples per batch in both pipelines.
        train_size: (height, width) training images are resized to.
        val_size: (height, width) validation images are resized to.
        aug_func: augmentation applied to training images only (may be None).

    Returns:
        (train_ds, val_ds). train_ds shuffles over the whole training set, drops
        the final partial batch and repeats forever. val_ds keeps a deterministic
        order and the final partial batch.
    """
    autotune = tf.data.AUTOTUNE

    train_ds = tf.data.Dataset.from_tensor_slices(train_data)
    train_ds = train_ds.shuffle(len(train_data[0]))
    train_ds = train_ds.map(load_image(train_size, aug_func), num_parallel_calls=autotune, deterministic=False)
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    train_ds = train_ds.repeat(-1).prefetch(1)

    val_ds = (
        tf.data.Dataset.from_tensor_slices(val_data)
        .map(load_image(val_size, None), num_parallel_calls=autotune, deterministic=True)
        .batch(batch_size, drop_remainder=False)
        .prefetch(1)
    )
    return train_ds, val_ds

In [None]:
'''
Model - CUSTOM
'''
def build_model(input_shape, num_classes, dropout_rate, weight_decay, base_func, base_weights):
    """Build a backbone -> global average pool -> dropout -> sigmoid Dense classifier.

    Args:
        input_shape: (height, width, channels) of the model input.
        num_classes: number of sigmoid outputs.
        dropout_rate: dropout rate applied to the pooled features.
        weight_decay: L2 factor attached as `kernel_regularizer` to every layer that
            exposes one, including layers nested inside the backbone; 0 disables it.
        base_func: backbone constructor called with input_shape/include_top/weights.
        base_weights: forwarded to base_func's `weights` argument.

    Returns:
        A tf.keras Model mapping images to per-class probabilities.
    """
    import os, tempfile

    def _set_kernel_regularizer(layers, weight_reg):
        # Recurse into nested models. The backbone is a single Model layer inside the
        # outer model, so visiting only model.layers would regularize just the head.
        for layer in layers:
            sub_layers = getattr(layer, 'layers', None)
            if sub_layers:
                _set_kernel_regularizer(sub_layers, weight_reg)
            if hasattr(layer, 'kernel_regularizer'):
                setattr(layer, 'kernel_regularizer', tf.keras.regularizers.l2(weight_reg))

    def add_regularization(model, weight_reg):
        # Setting the attribute on already-built layers does not rebuild them, so
        # serialize the config (which now carries the regularizers), rebuild the
        # model from it and restore the original weights.
        custom_objects = {}
        _set_kernel_regularizer(model.layers, weight_reg)
        model_json = model.to_json()
        tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
        model.save_weights(tmp_weights_path)
        model = tf.keras.models.model_from_json(model_json, custom_objects=custom_objects)
        model.load_weights(tmp_weights_path, by_name=True)
        return model

    base_model = base_func(input_shape=input_shape, include_top=False, weights=base_weights)

    ip = tf.keras.layers.Input(input_shape)
    h = base_model(ip)
    h = tf.keras.layers.GlobalAveragePooling2D()(h)
    h = tf.keras.layers.Dropout(dropout_rate)(h)
    h = tf.keras.layers.Dense(num_classes, activation=tf.nn.sigmoid)(h)
    model = tf.keras.models.Model(ip, h)

    if weight_decay > 0.:
        model = add_regularization(model, weight_decay)
    return model


def model_regularizer_loss(model):
    """Sum kernel and bias regularization penalties, recursing into nested models.

    Args:
        model: object with a `.layers` list; sub-layers exposing a non-empty
            `.layers` are visited recursively.

    Returns:
        The summed penalty (the int 0 when no layer carries a regularizer).
    """
    penalty = 0
    for layer in model.layers:
        nested = getattr(layer, 'layers', None)
        if nested:
            penalty += model_regularizer_loss(layer)
        for reg_attr, weight_attr in (('kernel_regularizer', 'kernel'), ('bias_regularizer', 'bias')):
            regularizer = getattr(layer, reg_attr, None)
            if regularizer:
                penalty += regularizer(getattr(layer, weight_attr))
    return penalty

In [None]:
'''
Metrics - CUSTOM
'''
class Metrics:
    """Running loss and sample-averaged F2 score, plus a per-epoch history table.

    Targets and predictions are binarised with `threshold`. Per sample,
    F2 = 5 * P * R / (4 * P + R), where P or R is 0 when its denominator is 0
    and the score is 0 when 4 * P + R is 0.
    """
    def __init__(self, threshold=0.2):
        self.threshold = threshold
        self.scores = [tf.keras.metrics.Mean()]  # one running mean per reported score
        self.loss = tf.keras.metrics.Mean()
        self.header = ['loss'] + [f'score_{idx}' for idx in range(1, len(self.scores) + 1)]
        self.df = pd.DataFrame(columns=self.header)

    def update_state(self, y_trues, y_preds, loss):
        """Accumulate per-sample F2 scores and losses for one batch."""
        true_mask = y_trues > self.threshold
        pred_mask = y_preds > self.threshold

        hits = tf.reduce_sum(tf.cast(true_mask & pred_mask, tf.float32), axis=1)
        n_true = tf.reduce_sum(tf.cast(true_mask, tf.float32), axis=1)
        n_pred = tf.reduce_sum(tf.cast(pred_mask, tf.float32), axis=1)

        recall = tf.where(n_true > 0., hits / n_true, 0.)
        precision = tf.where(n_pred > 0., hits / n_pred, 0.)

        weighted = 4. * precision + recall
        safe_weighted = tf.where(weighted > 0., weighted, 1.)
        self.scores[0].update_state(5. * precision * recall / safe_weighted)
        self.loss.update_state(loss)

    def get_loss(self):
        """Current running mean loss as a NumPy scalar."""
        return self.loss.result().numpy()

    def get_scores(self):
        """Current running mean F2 score, wrapped in a one-element list."""
        return [self.scores[0].result().numpy()]

    def reset_state(self):
        """Clear the running score and loss accumulators."""
        for running_mean in (self.scores[0], self.loss):
            running_mean.reset_states()

    def on_epoch_end(self, e):
        """Append the current loss and scores to the history under index `e`."""
        self.df.loc[e] = [self.get_loss(), *self.get_scores()]

    def get_latest(self):
        """Return (epoch index, [loss, scores...]) of the most recent history row."""
        latest_epoch = self.df.index.tolist()[-1]
        return latest_epoch, self.df.iloc[-1, :].tolist()

In [None]:
'''
Scheduler
'''

class CosineScheduler:
    """Cosine-annealed learning rate from `init_lr` down to `min_lr` over `total_epochs`.

    lr(e) = min_lr + (init_lr - min_lr) * 0.5 * (1 + cos(pi * e / total_epochs)).
    Epochs beyond `total_epochs` are clamped, so the rate stays at `min_lr`
    instead of climbing back up along the cosine.
    """
    def __init__(self, init_lr, min_lr, total_epochs):
        self.curr_lr = init_lr
        self.init_lr = init_lr
        self.min_lr = min_lr
        self.total_epochs = total_epochs

    def get_next_lr(self, next_epoch, curr_loss, curr_score):
        """Return (and store in `curr_lr`) the learning rate for `next_epoch`.

        `curr_loss` and `curr_score` are accepted for interface parity with
        ReduceOnPlateauScheduler and are ignored.
        """
        # Without the clamp, cos() turns upward once next_epoch > total_epochs (e.g. when
        # training runs longer than the schedule), which would raise the LR again.
        epoch = min(next_epoch, self.total_epochs)
        self.curr_lr = (self.init_lr - self.min_lr) * 0.5 * (np.cos(np.pi * epoch / self.total_epochs) + 1.) + self.min_lr
        return self.curr_lr

    
class ReduceOnPlateauScheduler:
    """Multiply the learning rate by `decay` when validation stops improving.

    An epoch counts as an improvement when the loss reaches a new minimum or the
    score reaches a new maximum. After more than `max_patience` consecutive
    non-improving epochs, the rate is decayed (never below `min_lr`) and the
    patience counter restarts.
    """
    def __init__(self, init_lr, min_lr, decay, max_patience):
        self.curr_lr = init_lr
        self.min_lr, self.decay = min_lr, decay
        self.max_patience = max_patience
        self.patience = 0
        # Initial "best" values that the first real loss / any positive score beat.
        self.best_loss, self.best_score = 10000., 0.

    def get_next_lr(self, next_epoch, curr_loss, curr_score):
        """Update the plateau state with this epoch's metrics and return the rate."""
        improved = False
        if curr_loss < self.best_loss:
            self.best_loss = curr_loss
            improved = True
        if curr_score > self.best_score:
            self.best_score = curr_score
            improved = True

        self.patience = 0 if improved else self.patience + 1

        if self.patience > self.max_patience:
            self.patience = 0
            self.curr_lr = max(self.curr_lr * self.decay, self.min_lr)

        return self.curr_lr

In [None]:
'''
Train/Val Proc on epoch
'''

def get_proc_on_batch():
    """Create fresh tf.function train/validation step functions.

    Returns:
        (train_on_batch, val_on_batch). Both return (targets, predictions,
        per-sample loss as produced by `criterion`).
    """
    @tf.function
    def train_on_batch(inputs, model, criterion, optimizer):
        x, y = inputs
        with tf.GradientTape() as tape:
            p = model(x, training=True)
            loss = criterion(y, p)
            total_loss = tf.reduce_mean(loss) + model_regularizer_loss(model)
        # Differentiate the reduced objective: previously the gradient was taken of the
        # un-reduced `loss`, so the regularization penalty never reached the weights.
        grads = tape.gradient(total_loss, sources=model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return y, p, loss

    @tf.function
    def val_on_batch(inputs, model, criterion):
        x, y = inputs
        p = model(x, training=False)
        loss = criterion(y, p)
        return y, p, loss

    return train_on_batch, val_on_batch


def proc_on_epoch(training, proc_on_batch, dataset, metrics, criterion, model, epoch, total_epoch, iters_per_epoch=None, optimizer=None):
    """Run one training or validation epoch and record its metrics.

    Args:
        training: True to run training steps (requires `iters_per_epoch` and
            `optimizer`), False to run validation.
        proc_on_batch: step function from get_proc_on_batch().
        dataset: iterable of (x, y) batches; for training it may repeat forever.
        metrics: Metrics instance; reset first, then closed with on_epoch_end(epoch + 1).
        criterion: loss function passed to the step.
        model: model passed to the step.
        epoch: zero-based epoch index (displayed 1-based).
        total_epoch: total number of epochs, for the progress-bar label.
        iters_per_epoch: number of training batches per epoch.
        optimizer: optimizer passed to the training step.

    Returns:
        (metrics, model, optimizer) when training, otherwise metrics.
    """
    import itertools

    metrics.reset_state()
    if training:
        assert iters_per_epoch is not None
        assert optimizer is not None
        phase = 'train'
        # Take exactly iters_per_epoch batches. The old `iter_i > iters_per_epoch`
        # check ran one extra step and pulled a further batch only to discard it.
        batches = itertools.islice(dataset, iters_per_epoch)
    else:
        phase = 'valid'
        batches = dataset

    with tqdm.tqdm(batches, total=iters_per_epoch, ncols=0, desc=f'{phase} {(1+epoch)}/{total_epoch}') as tq:
        for inputs in tq:
            # NOTE(review): hard-coded GPU placement; CPU/TPU runs rely on TF placing ops elsewhere.
            with tf.device('/gpu:0'):
                if training:
                    y, p, loss = proc_on_batch(inputs, model, criterion, optimizer)
                else:
                    y, p, loss = proc_on_batch(inputs, model, criterion)

            metrics.update_state(y, p, loss)
            tq.set_postfix(OrderedDict(
                loss = metrics.get_loss(),
                scores = metrics.get_scores()
            ))

    metrics.on_epoch_end(epoch+1)
    if training:
        return metrics, model, optimizer
    return metrics

---

In [None]:
def run_fold(fold_idx, config):
    """Train and validate one cross-validation fold with staged progressive resizing.

    At each epoch listed in `config.stages.epochs`:
    - the stage's transform/model settings are merged into `config`;
    - datasets, step functions and models are rebuilt;
    - from the second stage on, the training model loads latest-<fold>.h5.

    Every epoch, the validation model receives the training weights and is evaluated.
    Metrics go to log-<fold>.csv, and latest-/best_loss-/best_score-<fold>.h5 are
    written under `config.save_path`.

    Args:
        fold_idx: fold held out for validation.
        config: OmegaConf experiment config. NOTE: `config.transform` and
            `config.model` are overwritten in place as stages advance.
    """
    # NOTE(review): class names are eval()'d from the config; only use trusted configs.
    scheduler = eval(config.scheduler.name)(**config.scheduler.params)
    optimizer = eval(config.optimizer.name)(**config.optimizer.params)
    criterion = eval(config.loss.name)(**config.loss.params)

    # The scheduler is only consulted after each epoch, so seed epoch 0 with its initial
    # rate. Otherwise epoch 0 runs at the optimizer's constructor default and ignores
    # scheduler.params.init_lr, including its replica scaling.
    initial_lr = getattr(scheduler, 'curr_lr', None)
    if initial_lr is not None:
        optimizer.learning_rate = initial_lr

    (train_fpaths, train_labels), (val_fpaths, val_labels) = get_train_val_list(fold_idx)
    train_iters_per_epoch = len(train_fpaths) // config.batch_size

    train_metrics = Metrics()
    val_metrics = Metrics()

    best_score = 0.
    best_loss = 10000.

    for e in range(config.num_epochs):

        if e in config.stages.epochs:
            # Stage setting: dataset, step functions, train/val models.
            stage_idx = config.stages.epochs.index(e)
            config.transform = OmegaConf.merge(config.transform, config.stages.transforms[stage_idx])
            config.model = OmegaConf.merge(config.model, config.stages.models[stage_idx])
            logger.info(f'New Stage {stage_idx}')

            train_ds, val_ds = get_ds(
                (train_fpaths, train_labels), (val_fpaths, val_labels), config.batch_size,
                tuple(config.transform.train_size), tuple(config.transform.val_size),
                eval(config.transform.aug_name)(**config.transform.aug_params),
            )
            # Fresh tf.functions so the traced graphs bind to the newly built models.
            train_on_batch, val_on_batch = get_proc_on_batch()

            with strategy.scope():
                train_model = build_model(
                    tuple(config.transform.train_size) + (3,), config.num_classes, weight_decay=config.loss.weight_decay,
                    dropout_rate=config.model.dropout_rate,
                    base_func=eval(config.model.base_func),
                    base_weights=config.model.base_weights
                )
                if stage_idx > 0:
                    # Continue from the previous epoch's checkpoint at the new input size.
                    train_model.load_weights(os.path.join(config.save_path, f'latest-{fold_idx}.h5'))

                val_model = build_model(
                    tuple(config.transform.val_size) + (3,), config.num_classes, weight_decay=config.loss.weight_decay,
                    dropout_rate=config.model.dropout_rate,
                    base_func=eval(config.model.base_func),
                    base_weights=config.model.base_weights
                )

        if e == 0:
            csv_logger = mlc.LinewiseCSVWriter(
                os.path.join(config.save_path, f'log-{fold_idx}.csv'),
                header=['epoch'] + [f'train_{h}' for h in train_metrics.header] + [f'val_{h}' for h in val_metrics.header]
            )
            timer = mlc.time.Timer()
            logger.info(f'fold-{fold_idx} start')

        # train
        timer.add('train')
        train_metrics, train_model, optimizer = proc_on_epoch(
            True, train_on_batch, train_ds, train_metrics, criterion, train_model, e, config.num_epochs,
            iters_per_epoch=train_iters_per_epoch, optimizer=optimizer
        )
        train_elapsed = timer.fsince('train')

        # val: the validation model shares weights with the training model.
        val_model.set_weights(train_model.get_weights())
        timer.add('val')
        val_metrics = proc_on_epoch(
            False, val_on_batch, val_ds, val_metrics, criterion, val_model, e, config.num_epochs,
        )
        val_elapsed = timer.fsince('val')

        # log
        logger.info(f'epoch: {e+1} train: {train_elapsed} val: {val_elapsed}')
        logger.info(f'train: {train_metrics.get_latest()[1]}')
        logger.info(f'val: {val_metrics.get_latest()[1]}')
        csv_logger.write([e+1] + train_metrics.get_latest()[1] + val_metrics.get_latest()[1])

        # save
        val_model.save(os.path.join(config.save_path, f'latest-{fold_idx}.h5'))

        val_loss = val_metrics.get_loss()
        if val_loss < best_loss:
            logger.info(f'loss got improved: {best_loss:.4f} to {val_loss:.4f}')
            best_loss = val_loss
            val_model.save(os.path.join(config.save_path, f'best_loss-{fold_idx}.h5'))

        val_score = val_metrics.get_scores()[0]
        if val_score > best_score:
            logger.info(f'score got improved: {best_score:.4f} to {val_score:.4f}')
            best_score = val_score
            val_model.save(os.path.join(config.save_path, f'best_score-{fold_idx}.h5'))

        # lr update for the next epoch
        curr_lr = optimizer.learning_rate.numpy()
        next_lr = scheduler.get_next_lr(e+1, val_loss, val_score)
        optimizer.learning_rate = next_lr
        if curr_lr != next_lr:
            logger.info(f'    lr {curr_lr} -> {next_lr}')

    logger.info(f'fold-{fold_idx} end')

In [None]:
def main(config):
    """Run the experiment end to end.

    Creates the output directory, logs to <save_path>/log.txt, saves a YAML snapshot
    of `config`, then trains folds in order. Only fold 0 is trained when
    `config.run_fold1_only` is set.
    """
    os.makedirs(config.save_path, exist_ok=True)
    # Keep the sink id so the log file is released when main() returns. Without this,
    # each re-run in the same kernel adds another sink and duplicates every log line.
    log_sink_id = logger.add(os.path.join(config.save_path, 'log.txt'), mode='w')
    try:
        OmegaConf.save(config, os.path.join(config.save_path, f'{config.expr_name}.yaml'))

        for fold_idx in range(config.num_folds):
            run_fold(fold_idx, config)

            if config.run_fold1_only:
                break
    finally:
        logger.remove(log_sink_id)

In [None]:
# Entry point: launch training with the config built above.
# In a Jupyter/IPython kernel __name__ is '__main__', so running this cell starts the run.
if __name__ == '__main__':
    main(config)