In [1]:
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2"
sys.path.append('/root/code')

from definitions import LOG_DIR, WEIGHT_DIR
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
import time
import datetime
import numpy
from utils.contrastiveLoss import ContrastiveLoss
from models.SCNN18_Flatten import SCNN18_Flatten
from models.BYOL import BYOL
from models.SCNN18_random_aug import SCNN18_random_aug
import utils.dataset as dataset
import logging
from logging import handlers

In [2]:
# Logger used by the cells below.
# NOTE(review): on Python >= 3.9 getLogger('root') is documented to return the
# root logger itself; on older versions it is a separate logger named 'root' —
# confirm which behaviour this notebook relies on.
LOG = logging.getLogger('root')


def initLog(debug=False):
    """Set up console and rotating-file logging for this notebook.

    Installs two handlers through ``logging.basicConfig``: a stream handler
    and a ``RotatingFileHandler`` writing to ``BYOL.log`` (UTF-8, rotation
    threshold of 100 * 1024 * 1024 bytes, 3 backups). TensorFlow's own logger
    is restricted to ``ERROR``.

    Args:
        debug: When true, ``LOG`` is set to ``DEBUG``; otherwise to ``INFO``.
    """
    console_handler = logging.StreamHandler()
    # NOTE(review): CPython's RotatingFileHandler appears to force append mode
    # whenever maxBytes > 0, so the "w" below may not truncate the file — confirm.
    file_handler = handlers.RotatingFileHandler(
        'BYOL.log',
        mode="w",
        maxBytes=1024 * 1024 * 100,
        backupCount=3,
        encoding="utf-8",
    )

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
        datefmt='%Y-%m-%d %H:%M',
        handlers=[console_handler, file_handler],
    )

    if debug:
        LOG.setLevel(logging.DEBUG)
    else:
        LOG.setLevel(logging.INFO)

    # Keep TensorFlow's own log output down to errors only.
    tf.get_logger().setLevel('ERROR')

In [3]:
initLog()

In [4]:
def get_optimizer(optimizer, lr):
    """Build a ``tf.optimizers`` optimizer from its case-insensitive name.

    Args:
        optimizer: One of ``'adadelta'``, ``'adagrad'``, ``'adam'``,
            ``'adamax'``, ``'sgd'`` or ``'rmsprop'`` (any letter case).
        lr: Learning rate. ``0`` or ``None`` means "use the optimizer's own
            default learning rate".

    Returns:
        A newly constructed optimizer instance.

    Raises:
        ValueError: If ``optimizer`` is not one of the supported names.
    """
    # Accepted name -> attribute name on tf.optimizers. The class is looked up
    # only after validation so that an unsupported name fails fast with a
    # clear message instead of silently returning None.
    class_names = {
        'adadelta': 'Adadelta',
        'adagrad': 'Adagrad',
        'adam': 'Adam',
        'adamax': 'Adamax',
        'sgd': 'SGD',
        'rmsprop': 'RMSprop',
    }

    key = optimizer.lower()
    if key not in class_names:
        supported = ', '.join(sorted(class_names))
        raise ValueError(
            f"Unsupported optimizer {optimizer!r}; expected one of: {supported}"
        )

    optimizer_cls = getattr(tf.optimizers, class_names[key])
    if lr is None or lr == 0:
        return optimizer_cls()
    return optimizer_cls(learning_rate=lr)

In [5]:
# Hyper-parameters and model dimensions shared by the cells below.

# Optimisation settings.
lr = 1e-5         # learning rate passed to get_optimizer
max_epochs = 50   # number of iterations of the outer training loop
batch_size = 64   # batch size applied in generate_batched_data

# Model dimensions.
sample_size = [32000, 1]
input_shape = (*sample_size,)  # same values as sample_size, as a tuple for BYOL(input_shape=...)
output_shape = 2               # not referenced by the visible cells

In [6]:
# Dataset files handed to generate_batched_data; each variable is a list so
# that several files can be supplied per pipeline.
origin_path = [
    './SCNN18_0.1second/SCNN-BYOL-Jamendo-FMA-train_30s.h5',
]
random_augment_path = [
    './SCNN18_0.1second/SCNN-BYOL-Jamendo-FMA-train_30s-random_augment.h5',
]

In [7]:
def generate_batched_data(ds_path):
    """Load unlabeled datasets, merge them, and return a shuffled, batched pipeline.

    Each path is read with ``dataset.get_dataset_without_label`` and turned
    into a ``tf.data.Dataset`` via ``from_tensor_slices``; the per-file
    datasets are concatenated in the order given.

    Args:
        ds_path: Iterable of dataset file paths.

    Returns:
        A ``(dataset_length, batched_ds)`` tuple, where ``dataset_length`` is
        the number of elements before batching and ``batched_ds`` is the
        merged dataset shuffled with a buffer covering every element
        (``seed=8``, reshuffled on each iteration) and batched with the
        module-level ``batch_size``.

    Raises:
        ValueError: If ``ds_path`` yields no paths or the merged dataset
            contains no elements.
    """
    slices_ds = None
    for path in ds_path:
        part_ds = tf.data.Dataset.from_tensor_slices(
            dataset.get_dataset_without_label(path)
        )
        # concatenate() returns a new dataset instead of modifying the
        # receiver, so the result has to be rebound; otherwise every file
        # after the first is silently dropped.
        slices_ds = part_ds if slices_ds is None else slices_ds.concatenate(part_ds)

    if slices_ds is None:
        raise ValueError("ds_path must contain at least one dataset path")

    # Ask the dataset for its size rather than iterating over every element.
    dataset_length = int(slices_ds.cardinality().numpy())
    if dataset_length <= 0:
        raise ValueError(f"No elements found in datasets: {list(ds_path)!r}")

    shuffled_ds = slices_ds.shuffle(
        dataset_length, seed=8, reshuffle_each_iteration=True
    )
    return dataset_length, shuffled_ds.batch(batch_size)

In [8]:
# Build one pipeline per view: the 'origin' files and the 'random augment' files.
origin_length, origin_ds = generate_batched_data(origin_path)
augment_length, random_augment_ds = generate_batched_data(random_augment_path)

# The training loop zips the two pipelines batch by batch; zip() stops at the
# shorter one without any warning, so a size mismatch is reported here instead.
if origin_length != augment_length:
    raise ValueError(
        f"Dataset size mismatch: origin has {origin_length} elements, "
        f"random-augment has {augment_length}"
    )

ds_length = origin_length

In [9]:
# Mirror the model across three GPUs (CUDA_VISIBLE_DEVICES exposes "0,1,2").
strategy = tf.distribute.MirroredStrategy(devices=[f'/gpu:{i}' for i in range(3)])
with strategy.scope():
    optimizer = get_optimizer('adam', lr)

    # Create the BYOL model.
    byol_model = BYOL(input_shape=input_shape)
    byol_model.compile(optimizer=optimizer)

    def distributed_train_step(inputs_1, inputs_2):
        """Run one optimisation step on a pair of batches and return the loss.

        Calls the model on both inputs to obtain the loss, applies the
        gradients to the trainable variables, then updates the target network.
        """
        with tf.GradientTape() as tape:
            loss = byol_model(inputs_1, inputs_2)
        grads = tape.gradient(loss, byol_model.trainable_variables)
        byol_model.optimizer.apply_gradients(zip(grads, byol_model.trainable_variables))
        byol_model.update_target_network()
        return loss

    # Ceiling division: a final partial batch still counts as a step.
    steps_per_epoch = (ds_length + batch_size - 1) // batch_size

    # Distributed training loop.
    # NOTE(review): the two pipelines are shuffled independently inside
    # generate_batched_data (same seed); pairing batches with zip() relies on
    # both yielding elements in the same order — confirm.
    # NOTE(review): the batches are plain tensors rather than elements of a
    # distributed dataset, so it looks like every replica receives the same
    # full batch — confirm this is intended.
    for epoch in range(max_epochs):
        step = 0
        mean_loss = None
        for o_batch, ra_batch in zip(origin_ds, random_augment_ds):
            step += 1
            start_time = time.time()

            inputs_1 = tf.convert_to_tensor(o_batch, dtype=tf.float32)
            inputs_2 = tf.convert_to_tensor(ra_batch, dtype=tf.float32)

            per_replica_losses = strategy.run(distributed_train_step, args=(inputs_1, inputs_2))
            mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)

            end_time = time.time()

            if step % 10 == 0:
                print(f"Epoch {epoch + 1}, Step {step}/{steps_per_epoch}, Loss: {mean_loss.numpy()}, taken_time: {end_time - start_time:.3f}")

        if mean_loss is None:
            # Without this guard the log call below would fail on an undefined
            # loss, hiding the real problem.
            raise RuntimeError(f"Epoch {epoch + 1} produced no batches; check the input datasets")

        # Logs the loss of the last step of the epoch.
        LOG.info(f"Epoch {epoch + 1}, Loss: {mean_loss.numpy()}")

    # Save weights; the file name records the epoch count and learning rate
    # that were actually used instead of hard-coded values.
    weights_path = os.path.join(
        WEIGHT_DIR,
        f"{datetime.date.today()}_SCNN_BYOL_Jamendo_FMA_pretrain_{max_epochs}epochs_lr={lr}.h5",
    )
    byol_model.save_weights(weights_path)
    print(f"Training complete. Weights saved to '{weights_path}'.")

Epoch 1, Step 10/44, Loss: 2.3254237174987793, taken_time: 1.939
Epoch 1, Step 20/44, Loss: 1.5065809488296509, taken_time: 1.911
Epoch 1, Step 30/44, Loss: 1.3159046173095703, taken_time: 1.951
Epoch 1, Step 40/44, Loss: 0.9632480144500732, taken_time: 1.971


2024-07-28 08:58 INFO Epoch 1, Loss: 0.7915165424346924


Epoch 2, Step 10/44, Loss: 0.5248597860336304, taken_time: 1.997
Epoch 2, Step 20/44, Loss: 0.38934096693992615, taken_time: 1.988
Epoch 2, Step 30/44, Loss: 0.25765368342399597, taken_time: 1.977
Epoch 2, Step 40/44, Loss: 0.18287543952465057, taken_time: 2.028


2024-07-28 08:59 INFO Epoch 2, Loss: 0.1660401076078415


Epoch 3, Step 10/44, Loss: 0.13221341371536255, taken_time: 1.997
Epoch 3, Step 20/44, Loss: 0.11530232429504395, taken_time: 2.029
Epoch 3, Step 30/44, Loss: 0.09465350955724716, taken_time: 2.020
Epoch 3, Step 40/44, Loss: 0.08176670223474503, taken_time: 2.012


2024-07-28 09:01 INFO Epoch 3, Loss: 0.07427093386650085


Epoch 4, Step 10/44, Loss: 0.06582700461149216, taken_time: 2.118
Epoch 4, Step 20/44, Loss: 0.052933961153030396, taken_time: 2.000
Epoch 4, Step 30/44, Loss: 0.043033625930547714, taken_time: 2.022
Epoch 4, Step 40/44, Loss: 0.03772173449397087, taken_time: 1.999


2024-07-28 09:02 INFO Epoch 4, Loss: 0.03431796282529831


Epoch 5, Step 10/44, Loss: 0.02774573303759098, taken_time: 2.011
Epoch 5, Step 20/44, Loss: 0.02371051162481308, taken_time: 2.026
Epoch 5, Step 30/44, Loss: 0.02520514465868473, taken_time: 2.020
Epoch 5, Step 40/44, Loss: 0.016442513093352318, taken_time: 1.988


2024-07-28 09:04 INFO Epoch 5, Loss: 0.02073695883154869


Epoch 6, Step 10/44, Loss: 0.013771362602710724, taken_time: 2.020
Epoch 6, Step 20/44, Loss: 0.010887444019317627, taken_time: 2.032
Epoch 6, Step 30/44, Loss: 0.00982043705880642, taken_time: 2.061
Epoch 6, Step 40/44, Loss: 0.00872274860739708, taken_time: 2.051


2024-07-28 09:06 INFO Epoch 6, Loss: 0.00668737618252635


Epoch 7, Step 10/44, Loss: 0.008398612961173058, taken_time: 2.012
Epoch 7, Step 20/44, Loss: 0.006927995011210442, taken_time: 2.141
Epoch 7, Step 30/44, Loss: 0.005621111020445824, taken_time: 1.996
Epoch 7, Step 40/44, Loss: 0.00573376752436161, taken_time: 2.030


2024-07-28 09:07 INFO Epoch 7, Loss: 0.009922795929014683


Epoch 8, Step 10/44, Loss: 0.004646088927984238, taken_time: 2.016
Epoch 8, Step 20/44, Loss: 0.005237055942416191, taken_time: 2.036
Epoch 8, Step 30/44, Loss: 0.0044832900166511536, taken_time: 2.013
Epoch 8, Step 40/44, Loss: 0.004887795075774193, taken_time: 2.047


2024-07-28 09:09 INFO Epoch 8, Loss: 0.0041834148578345776


Epoch 9, Step 10/44, Loss: 0.003984477370977402, taken_time: 2.017
Epoch 9, Step 20/44, Loss: 0.004437068477272987, taken_time: 2.046
Epoch 9, Step 30/44, Loss: 0.004511326551437378, taken_time: 2.013
Epoch 9, Step 40/44, Loss: 0.010744448751211166, taken_time: 2.027


2024-07-28 09:10 INFO Epoch 9, Loss: 0.004004312679171562


Epoch 10, Step 10/44, Loss: 0.005248710513114929, taken_time: 1.988
Epoch 10, Step 20/44, Loss: 0.0033644847571849823, taken_time: 1.996
Epoch 10, Step 30/44, Loss: 0.007612314075231552, taken_time: 2.008
Epoch 10, Step 40/44, Loss: 0.005281249061226845, taken_time: 2.033


2024-07-28 09:12 INFO Epoch 10, Loss: 0.016676535829901695


Epoch 11, Step 10/44, Loss: 0.004226997494697571, taken_time: 2.004
Epoch 11, Step 20/44, Loss: 0.004490358754992485, taken_time: 1.981
Epoch 11, Step 30/44, Loss: 0.01619715616106987, taken_time: 2.029
Epoch 11, Step 40/44, Loss: 0.007200848311185837, taken_time: 2.028


2024-07-28 09:13 INFO Epoch 11, Loss: 0.0035244524478912354


Epoch 12, Step 10/44, Loss: 0.003546411171555519, taken_time: 2.021
Epoch 12, Step 20/44, Loss: 0.004097850993275642, taken_time: 2.044
Epoch 12, Step 30/44, Loss: 0.006105830892920494, taken_time: 2.001
Epoch 12, Step 40/44, Loss: 0.004001434892416, taken_time: 1.997


2024-07-28 09:15 INFO Epoch 12, Loss: 0.0056589641608297825


Epoch 13, Step 10/44, Loss: 0.0033616162836551666, taken_time: 2.041
Epoch 13, Step 20/44, Loss: 0.003063010051846504, taken_time: 1.996
Epoch 13, Step 30/44, Loss: 0.0032680686563253403, taken_time: 2.019
Epoch 13, Step 40/44, Loss: 0.012718193233013153, taken_time: 2.030


2024-07-28 09:16 INFO Epoch 13, Loss: 0.0019677679520100355


Epoch 14, Step 10/44, Loss: 0.0024947505444288254, taken_time: 2.033
Epoch 14, Step 20/44, Loss: 0.0031710397452116013, taken_time: 1.991
Epoch 14, Step 30/44, Loss: 0.003969894722104073, taken_time: 2.009
Epoch 14, Step 40/44, Loss: 0.00527336448431015, taken_time: 2.043


2024-07-28 09:18 INFO Epoch 14, Loss: 0.013064119964838028


Epoch 15, Step 10/44, Loss: 0.0020658429712057114, taken_time: 1.993
Epoch 15, Step 20/44, Loss: 0.004872506484389305, taken_time: 2.015
Epoch 15, Step 30/44, Loss: 0.0024974700063467026, taken_time: 2.027
Epoch 15, Step 40/44, Loss: 0.004214102402329445, taken_time: 2.013


2024-07-28 09:19 INFO Epoch 15, Loss: 0.0041819978505373


Epoch 16, Step 10/44, Loss: 0.005100153386592865, taken_time: 1.993
Epoch 16, Step 20/44, Loss: 0.0034996457397937775, taken_time: 2.018
Epoch 16, Step 30/44, Loss: 0.003884941339492798, taken_time: 2.012
Epoch 16, Step 40/44, Loss: 0.0022864602506160736, taken_time: 2.039


2024-07-28 09:21 INFO Epoch 16, Loss: 0.0029837919864803553


Epoch 17, Step 10/44, Loss: 0.006997862830758095, taken_time: 2.018
Epoch 17, Step 20/44, Loss: 0.0025948844850063324, taken_time: 2.001
Epoch 17, Step 30/44, Loss: 0.004802355542778969, taken_time: 1.955
Epoch 17, Step 40/44, Loss: 0.003034450113773346, taken_time: 2.001


2024-07-28 09:22 INFO Epoch 17, Loss: 0.0015456014079973102


Epoch 18, Step 10/44, Loss: 0.003286018967628479, taken_time: 2.029
Epoch 18, Step 20/44, Loss: 0.0033307913690805435, taken_time: 2.002
Epoch 18, Step 30/44, Loss: 0.004223860800266266, taken_time: 2.001
Epoch 18, Step 40/44, Loss: 0.0023806504905223846, taken_time: 2.086


2024-07-28 09:24 INFO Epoch 18, Loss: 0.002742479322478175


Epoch 19, Step 10/44, Loss: 0.0026124492287635803, taken_time: 2.020
Epoch 19, Step 20/44, Loss: 0.0029780808836221695, taken_time: 1.996
Epoch 19, Step 30/44, Loss: 0.004676520824432373, taken_time: 2.036
Epoch 19, Step 40/44, Loss: 0.0024651233106851578, taken_time: 2.003


2024-07-28 09:25 INFO Epoch 19, Loss: 0.0024517145939171314


Epoch 20, Step 10/44, Loss: 0.0033669471740722656, taken_time: 2.008
Epoch 20, Step 20/44, Loss: 0.001755300909280777, taken_time: 2.025
Epoch 20, Step 30/44, Loss: 0.0021529849618673325, taken_time: 1.988
Epoch 20, Step 40/44, Loss: 0.0019757691770792007, taken_time: 1.995


2024-07-28 09:27 INFO Epoch 20, Loss: 0.005986303091049194


Epoch 21, Step 10/44, Loss: 0.004196407273411751, taken_time: 2.046
Epoch 21, Step 20/44, Loss: 0.0029821693897247314, taken_time: 2.024
Epoch 21, Step 30/44, Loss: 0.0017596036195755005, taken_time: 1.977
Epoch 21, Step 40/44, Loss: 0.0018246155232191086, taken_time: 2.006


2024-07-28 09:28 INFO Epoch 21, Loss: 0.002084016799926758


Epoch 22, Step 10/44, Loss: 0.004443913698196411, taken_time: 1.988
Epoch 22, Step 20/44, Loss: 0.002640245482325554, taken_time: 2.024
Epoch 22, Step 30/44, Loss: 0.0014508049935102463, taken_time: 2.029
Epoch 22, Step 40/44, Loss: 0.002331947907805443, taken_time: 2.013


2024-07-28 09:30 INFO Epoch 22, Loss: 0.0012070635566487908


Epoch 23, Step 10/44, Loss: 0.001993713900446892, taken_time: 2.026
Epoch 23, Step 20/44, Loss: 0.0016725119203329086, taken_time: 2.033
Epoch 23, Step 30/44, Loss: 0.001106930896639824, taken_time: 2.001
Epoch 23, Step 40/44, Loss: 0.006323406472802162, taken_time: 1.985


2024-07-28 09:31 INFO Epoch 23, Loss: 0.0013262364082038403


Epoch 24, Step 10/44, Loss: 0.002574719488620758, taken_time: 2.008
Epoch 24, Step 20/44, Loss: 0.004569802433252335, taken_time: 2.008
Epoch 24, Step 30/44, Loss: 0.0018283762037754059, taken_time: 1.991
Epoch 24, Step 40/44, Loss: 0.002413123846054077, taken_time: 2.152


2024-07-28 09:33 INFO Epoch 24, Loss: 0.0011979705886915326


Epoch 25, Step 10/44, Loss: 0.0025834720581769943, taken_time: 1.992
Epoch 25, Step 20/44, Loss: 0.0019489489495754242, taken_time: 1.983
Epoch 25, Step 30/44, Loss: 0.002011200413107872, taken_time: 2.039
Epoch 25, Step 40/44, Loss: 0.001188386231660843, taken_time: 2.028


2024-07-28 09:34 INFO Epoch 25, Loss: 0.016505790874361992


Epoch 26, Step 10/44, Loss: 0.0014971811324357986, taken_time: 2.025
Epoch 26, Step 20/44, Loss: 0.0022588148713111877, taken_time: 2.033
Epoch 26, Step 30/44, Loss: 0.0012771710753440857, taken_time: 2.016
Epoch 26, Step 40/44, Loss: 0.0021602604538202286, taken_time: 1.977


2024-07-28 09:36 INFO Epoch 26, Loss: 0.0010934637393802404


Epoch 27, Step 10/44, Loss: 0.0011897198855876923, taken_time: 2.014
Epoch 27, Step 20/44, Loss: 0.0013011433184146881, taken_time: 2.016
Epoch 27, Step 30/44, Loss: 0.0015771836042404175, taken_time: 1.995
Epoch 27, Step 40/44, Loss: 0.0012476686388254166, taken_time: 2.016


2024-07-28 09:37 INFO Epoch 27, Loss: 0.001397384679876268


Epoch 28, Step 10/44, Loss: 0.004447953775525093, taken_time: 1.998
Epoch 28, Step 20/44, Loss: 0.0007919147610664368, taken_time: 2.010
Epoch 28, Step 30/44, Loss: 0.0023900140076875687, taken_time: 2.005
Epoch 28, Step 40/44, Loss: 0.0015964116901159286, taken_time: 2.036


2024-07-28 09:39 INFO Epoch 28, Loss: 0.002911842428147793


Epoch 29, Step 10/44, Loss: 0.0008081626147031784, taken_time: 2.017
Epoch 29, Step 20/44, Loss: 0.001990118995308876, taken_time: 2.008
Epoch 29, Step 30/44, Loss: 0.0009456612169742584, taken_time: 2.032
Epoch 29, Step 40/44, Loss: 0.001205654814839363, taken_time: 1.987


2024-07-28 09:40 INFO Epoch 29, Loss: 0.0011838740902021527


Epoch 30, Step 10/44, Loss: 0.0027642417699098587, taken_time: 1.988
Epoch 30, Step 20/44, Loss: 0.0015776418149471283, taken_time: 2.023
Epoch 30, Step 30/44, Loss: 0.0010992325842380524, taken_time: 2.007
Epoch 30, Step 40/44, Loss: 0.002368580549955368, taken_time: 1.997


2024-07-28 09:42 INFO Epoch 30, Loss: 0.0021518999710679054


Training complete. Weights saved to 'BYOL_SCNN_pretrain.h5'.
