In [1]:
import json
from pathlib import Path
from functools import partial

In [2]:
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [4]:
import sys

# Make the repository root (two levels up from the working directory) importable
# so the local `spleeter` package is found. The path is resolved to an absolute
# one and guarded, so re-running this cell does not append duplicate entries.
repo_root = str(Path('..', '..').resolve())
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [5]:
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.separator import Separator
from spleeter.audio.convertor import to_stereo
from spleeter.model import model_fn
from spleeter.model.provider import ModelProvider
from spleeter.dataset import get_training_dataset, get_validation_dataset

In [6]:
# JSON training config, resolved relative to the notebook's working directory.
config_path = Path("config", "unet_config.json").as_posix()

In [7]:
# Load the training configuration into `params`. JSON text is UTF-8 by
# specification, so decode it explicitly instead of relying on the platform's
# locale encoding (e.g. cp1251 on a Russian-locale Windows machine).
with open(config_path, encoding="utf-8") as config_file:
    params = json.load(config_file)

#### Создаем `estimator`

In [8]:
def _create_estimator(params):
    """ Creates estimator.

    Previously hard-coded run settings can now be overridden through optional
    keys in ``params`` (defaults unchanged, shown in parentheses):
    ``per_process_gpu_memory_fraction`` (0.85), ``log_step_count_steps`` (100)
    and ``keep_checkpoint_max`` (5).

    :param params: TF params dict to build estimator from. Required keys:
        ``model_dir``, ``save_checkpoints_steps``, ``random_seed`` and
        ``save_summary_steps``. The whole dict is also forwarded to ``model_fn``.
    :returns: Built estimator.
    :raises KeyError: If a required key is missing from ``params``.
    """
    session_config = tf.compat.v1.ConfigProto()
    # Cap this process at a fraction of GPU memory, and let the allocator start
    # small and grow as needed (allow_growth) instead of reserving the cap up front.
    session_config.gpu_options.per_process_gpu_memory_fraction = params.get(
        'per_process_gpu_memory_fraction', 0.85)
    session_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        log_step_count_steps=params.get('log_step_count_steps', 100),
        keep_checkpoint_max=params.get('keep_checkpoint_max', 5))
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)

In [9]:
# Build the estimator from the loaded config; the bare name on the last line
# displays its repr as the cell output.
estimator = _create_estimator(params)
estimator

INFO:tensorflow:Using config: {'_model_dir': 'voice_model', '_tf_random_seed': 3, '_save_summary_steps': 50, '_save_checkpoints_steps': 300, '_save_checkpoints_secs': None, '_session_config': gpu_options {
  per_process_gpu_memory_fraction: 0.85
  allow_growth: true
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001E18B994198>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


<tensorflow_estimator.python.estimator.estimator.Estimator at 0x1e1fdb76e48>

#### Загрузчик тренировочных данных

In [10]:
def _create_train_spec(params, audio_adapter, audio_path):
    """ Build the ``TrainSpec`` consumed by ``tf.estimator.train_and_evaluate``.

    :param params: TF params dict; ``params['train_max_steps']`` bounds training.
    :param audio_adapter: Audio adapter passed through to ``get_training_dataset``.
    :param audio_path: Audio data path passed through to ``get_training_dataset``.
    :returns: Built train spec.
    """
    max_steps = params['train_max_steps']
    # Bind the dataset arguments now; the estimator invokes the callable later.
    dataset_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
    return tf.estimator.TrainSpec(input_fn=dataset_fn, max_steps=max_steps)

In [11]:
# Audio adapter handed to the dataset builders through the train/eval spec
# helpers below (presumably used to decode the audio files — TODO confirm).
audio_adapter = get_default_audio_adapter()

In [12]:
# Root of the prepared (transformed) diarisation audio, relative to the working directory.
audio_path = str(Path("..", "..", "data", "transformed", "diarisation", "nr"))

In [13]:
# Sanity check before building the input pipelines: the data directory must
# exist and contain at least two entries (NOTE(review): presumably one per
# audio file/speaker — TODO confirm the expected layout).
audio_dir = Path(audio_path)
assert audio_dir.is_dir(), f"Audio directory not found: {audio_dir.resolve()}"
assert sum(1 for _ in audio_dir.iterdir()) >= 2, (
    f"Expected at least 2 entries in {audio_dir.resolve()}")

In [14]:
# Training input pipeline built from `audio_path` using the settings in `params`.
train_spec = _create_train_spec(params, audio_adapter, audio_path)

#### Загрузчик валидационных данных

In [15]:
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Build the ``EvalSpec`` consumed by ``tf.estimator.train_and_evaluate``.

    Evaluation consumes the validation input until it is exhausted
    (``steps=None``), and re-evaluation is throttled so that it starts at most
    once every ``params['throttle_secs']`` seconds.

    :param params: TF params dict; ``params['throttle_secs']`` is required.
    :param audio_adapter: Audio adapter passed through to ``get_validation_dataset``.
    :param audio_path: Audio data path passed through to ``get_validation_dataset``.
    :returns: Built evaluation spec.
    """
    validation_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
    throttle = params['throttle_secs']
    return tf.estimator.EvalSpec(
        input_fn=validation_fn, steps=None, throttle_secs=throttle)

In [16]:
# Validation reads from the same `audio_path` root as training.
# NOTE(review): presumably the train/validation split is defined inside `params`
# (e.g. separate file lists) — TODO confirm; otherwise evaluation runs on training data.
evaluation_spec = _create_evaluation_spec(params, audio_adapter, audio_path)

#### Обучение
(вся магия — внутри)

In [17]:
# Train until params['train_max_steps']; evaluation is triggered by checkpoints
# (throttled by params['throttle_secs']), and checkpoints go to params['model_dir'].
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)

INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 300 or save_checkpoints_secs None.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Apply unet for first_speaker_spectrogram
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Apply unet for second_speaker_spectrogram
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done 

KeyboardInterrupt: 

#### Сохраним полученные веса

In [18]:
# NOTE(review): checkpoints are already written into params['model_dir'] by the
# estimator during training (save_checkpoints_steps). writeProbe presumably only
# drops a marker file so ModelProvider treats the directory as a ready model,
# rather than saving weights itself — TODO confirm.
ModelProvider.writeProbe(params['model_dir'])

Обучение остановлено вручную (см. `KeyboardInterrupt` выше): модель практически не обучалась — значения функции потерь менялись совсем незначительно (см. графики ниже).

![first_loss](https://i.imgur.com/zcPhyOs.png)

![second_loss](https://i.imgur.com/IFtounW.png)