In [33]:
import json
from pathlib import Path
from functools import partial

In [34]:
import tensorflow as tf

In [35]:
# Verbose TF logging while debugging the training pipeline. Use the
# tf.compat.v1 namespace (the notebook already relies on it for ConfigProto
# when building the estimator) so this call also works on TF 2.x, where the
# top-level `tf.logging` module no longer exists.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)

In [36]:
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.separator import Separator
from spleeter.audio.convertor import to_stereo
from spleeter.model import model_fn
from spleeter.model.provider import ModelProvider
from spleeter.dataset import get_training_dataset, get_validation_dataset

In [37]:
# Training configuration (JSON), relative to the notebook's working directory.
config_path = "/".join(("config", "voice_config.json"))

In [38]:
# Load the training params. JSON is UTF-8 by specification (RFC 8259), so pass
# the encoding explicitly: the default depends on the OS locale (e.g. cp1251 or
# cp1252 on Windows) and would garble any non-ASCII text in the config.
with open(config_path, encoding="utf-8") as config_file:
    params = json.load(config_file)

#### Создаем `estimator`

In [39]:
def _create_estimator(params, gpu_memory_fraction=0.45,
                      log_step_count_steps=10, keep_checkpoint_max=2):
    """Create the ``tf.estimator.Estimator`` that trains spleeter's ``model_fn``.

    :param params: Training params loaded from the JSON config. The keys
        ``model_dir``, ``save_checkpoints_steps``, ``random_seed`` and
        ``save_summary_steps`` are read here. The whole dict is also passed
        to ``model_fn`` as the estimator ``params``.
    :param gpu_memory_fraction: Value for the session's
        ``per_process_gpu_memory_fraction`` GPU option, i.e. the share of GPU
        memory this process may allocate (default 0.45).
    :param log_step_count_steps: Log the global step and loss every this many
        training steps (default 10).
    :param keep_checkpoint_max: Number of most recent checkpoints to keep in
        ``model_dir`` (default 2).
    :returns: Built estimator.
    """
    session_config = tf.compat.v1.ConfigProto()
    # Cap GPU memory rather than letting TF map (nearly) all of it up front.
    session_config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        log_step_count_steps=log_step_count_steps,
        keep_checkpoint_max=keep_checkpoint_max)
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)

In [40]:
# Build the estimator from the loaded config; the trailing expression displays it.
estimator = _create_estimator(params=params)
estimator

INFO:tensorflow:Using config: {'_model_dir': 'voice_model', '_tf_random_seed': 3, '_save_summary_steps': 5, '_save_checkpoints_steps': 30, '_save_checkpoints_secs': None, '_session_config': gpu_options {
  per_process_gpu_memory_fraction: 0.45
}
, '_keep_checkpoint_max': 2, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 10, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000020974778348>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


<tensorflow_estimator.python.estimator.estimator.Estimator at 0x209747785c8>

#### Загрузчик тренировочных данных

In [41]:
def _create_train_spec(params, audio_adapter, audio_path):
    """Build the ``TrainSpec`` used by ``tf.estimator.train_and_evaluate``.

    :param params: Training params. ``train_max_steps`` bounds the number of
        training steps, and the whole dict is handed to
        ``get_training_dataset``.
    :param audio_adapter: Audio adapter forwarded to ``get_training_dataset``.
    :param audio_path: Root directory of the audio data, forwarded to
        ``get_training_dataset``.
    :returns: Built train spec.
    """
    # TrainSpec takes a callable input_fn: pre-bind the dataset builder's
    # arguments so it can be invoked without any.
    training_input_fn = partial(
        get_training_dataset, params, audio_adapter, audio_path)
    max_steps = params['train_max_steps']
    return tf.estimator.TrainSpec(input_fn=training_input_fn, max_steps=max_steps)

In [42]:
# spleeter's default audio adapter; passed to both the training and the
# validation dataset builders below.
audio_adapter = get_default_audio_adapter()

In [43]:
# Dataset root, two levels above the notebook's working directory.
audio_path = str(Path("..", "..", "data", "nr"))

In [44]:
# Fail fast, with a readable message, if the dataset root is missing or
# (nearly) empty, instead of failing later inside the input pipeline.
# NOTE(review): the minimum of 2 entries is inherited from the original check;
# what exactly must be present is not visible here. TODO confirm.
assert Path(audio_path).is_dir(), f"audio_path {audio_path!r} is not a directory"
n_audio_entries = sum(1 for _ in Path(audio_path).iterdir())
assert n_audio_entries >= 2, (
    f"expected at least 2 entries in {audio_path!r}, found {n_audio_entries}")

In [45]:
# Training input pipeline + step budget for train_and_evaluate.
train_spec = _create_train_spec(
    params=params, audio_adapter=audio_adapter, audio_path=audio_path)

#### Загрузчик валидационных данных

In [46]:
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """Build the ``EvalSpec`` used by ``tf.estimator.train_and_evaluate``.

    Each evaluation consumes the validation input until it is exhausted
    (``steps=None``). An evaluation is skipped unless the previous one started
    at least ``params['throttle_secs']`` seconds earlier.

    :param params: Training params. ``throttle_secs`` is read here, and the
        whole dict is handed to ``get_validation_dataset``.
    :param audio_adapter: Audio adapter forwarded to ``get_validation_dataset``.
    :param audio_path: Root directory of the audio data, forwarded to
        ``get_validation_dataset``.
    :returns: Built evaluation spec.
    """
    validation_input_fn = partial(
        get_validation_dataset, params, audio_adapter, audio_path)
    min_secs_between_evals = params['throttle_secs']
    return tf.estimator.EvalSpec(
        input_fn=validation_input_fn,
        steps=None,  # evaluate until the validation input runs out
        throttle_secs=min_secs_between_evals)

In [47]:
# Validation input pipeline + evaluation throttling for train_and_evaluate.
evaluation_spec = _create_evaluation_spec(
    params=params, audio_adapter=audio_adapter, audio_path=audio_path)

#### Обучение
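`train_and_evaluate` чередует обучение и валидацию: валидация запускается после каждого сохранённого чекпоинта (см. `save_checkpoints_steps`), но не чаще, чем раз в `throttle_secs` секунд.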

In [48]:
# NOTE(review): the recorded run of this cell failed with
# "TypeError: Value passed to parameter 'x' has DataType uint8 ...".
# The cause is not visible in this notebook. TODO investigate before relying
# on the saved model.
tf.estimator.train_and_evaluate(
    estimator=estimator,
    train_spec=train_spec,
    eval_spec=evaluation_spec)

INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 30 or save_checkpoints_secs None.


TypeError: Value passed to parameter 'x' has DataType uint8 not in list of allowed values: bfloat16, float16, float32, float64, int32, int64

#### Сохраним полученные веса

In [None]:
# The trained weights are the checkpoints the estimator writes into
# params['model_dir'] during training (RunConfig.save_checkpoints_steps).
# NOTE(review): writeProbe presumably only drops spleeter's probe marker file
# into that directory so ModelProvider treats it as an available model.
# TODO confirm against spleeter's ModelProvider.
ModelProvider.writeProbe(
    params['model_dir']
)