<a href="https://colab.research.google.com/github/chanyub/MusicVAE/blob/main/musicVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Setting

- Google Colab 환경에서 진행하였습니다.

In [2]:
# Mount Google Drive at /content/drive; the data and output paths used later
# in this notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import pandas as pd

In [None]:
# Install Magenta, pinned to version 2.1.0.
!pip install magenta==2.1.0

In [5]:
import magenta

# 1. Preprocessing

- mid -> note_seq -> tfrecord

In [6]:
# Folder on the mounted Drive that is converted to a .tfrecord below, and the
# info CSV that sits directly inside it.
data_root = '/content/drive/MyDrive/MusicVAE/groove'
csv_file = '{}/info.csv'.format(data_root)

In [7]:
# Despite the "root" in its name, this is the path of the single output
# .tfrecord file, not a directory.
tfrec_root = '/content/drive/MyDrive/MusicVAE/tfrec/example.tfrecord'

In [8]:
from magenta.scripts.convert_dir_to_note_sequences import convert_directory

In [None]:
convert_directory(data_root,tfrec_root,recursive=True)
# .mid -> note_seq -> .tfrecord
# Converts the .mid files into a .tfrecord stored at tfrec_root, i.e. a form
# that can be used for training.

## Config

- 학습을 위해 모델을 포함한 Config 정의
- train_examples_path 변수에 학습용으로 변환한 .tfrecord의 path를 넣어줌
- Config 설정 참고: https://github.com/magenta/magenta/blob/main/magenta/models/music_vae/configs.py

In [10]:
import collections

class Config(collections.namedtuple('Config', (
    'model',
    'hparams',
    'note_sequence_augmenter',
    'data_converter',
    'train_examples_path',
    'eval_examples_path',
    'tfds_name',
))):
  """Named tuple bundling a model with its hparams, converter and data paths.

  Every field defaults to None, so a Config can be created from any subset of
  keyword arguments.
  """

  def values(self):
    """Return the fields as a dictionary mapping field name to value."""
    return self._asdict()


# Give every field a default of None so that all constructor arguments are
# optional.
Config.__new__.__defaults__ = tuple(None for _ in Config._fields)


def update_config(config, update_dict):
  """Return a new Config equal to `config` with some fields overridden.

  Args:
    config: The Config to start from; it is not modified.
    update_dict: Field names and the replacement values to apply.

  Returns:
    A new Config built from the merged field values.
  """
  fields = config.values()
  fields.update(update_dict)
  return Config(**fields)


# Registry of configurations keyed by name; entries are added further down.
CONFIG_MAP = {}

In [11]:
from magenta.common import merge_hparams
from magenta.contrib import training as contrib_training
from magenta.models.music_vae import MusicVAE
from magenta.models.music_vae import lstm_models
from magenta.models.music_vae import data

HParams = contrib_training.HParams

# GrooVAE configs
# Register the 4-bar GrooVAE configuration used for training below.
CONFIG_MAP['groovae_4bar'] = Config(
    model=MusicVAE(lstm_models.BidirectionalLstmEncoder(),
                   lstm_models.GrooveLstmDecoder()), # Define the MusicVAE model
    # Default LSTM hparams merged with the overrides listed here.
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),
        HParams(
            batch_size=512,
            max_seq_len=16 * 4,  # 4 bars w/ 16 steps per bar
            z_size=256,
            enc_rnn_size=[512],
            dec_rnn_size=[256, 256],
            max_beta=0.2,
            free_bits=48,
            dropout_keep_prob=0.3,
        )),
    note_sequence_augmenter=None,
    data_converter=data.GrooveConverter(
        split_bars=4, steps_per_quarter=4, quarters_per_bar=4,
        max_tensors_per_notesequence=20,
        pitch_classes=data.ROLAND_DRUM_PITCH_CLASSES,
        inference_pitch_classes=data.REDUCED_DRUM_PITCH_CLASSES),
    # The TFDS dataset name is left disabled; examples are read from the
    # .tfrecord path below (the same path as tfrec_root) instead.
    # tfds_name='groove/4bar-midionly',
    train_examples_path='/content/drive/MyDrive/MusicVAE/tfrec/example.tfrecord',
)

### 인코더: BidirectionalLSTM

- 단방향으로만 정보를 전달하던 LSTM을 개선한 양방향 LSTM
- longer-term context 정보 전달에 좀 더 강점이 있다.

### 디코더: GrooveLSTM

- 타격(hit) 여부는 베르누이 분포에서 샘플링하고, 세기(velocity)와 타이밍 오프셋(offset)은 연속값으로 출력한다.

## Train Function

In [15]:
# Copyright 2022 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MusicVAE training script."""
import os

from magenta.models.music_vae import configs
from magenta.models.music_vae import data
import tensorflow.compat.v1 as tf
import tf_slim

# Should not be called from within the graph to avoid redundant summaries.
def _trial_summary(hparams, examples_path, output_dir):
  """Write the examples path and an hparams table as text summaries.

  Args:
    hparams: Object whose `values()` returns a dict of hyperparameters.
    examples_path: String stored under the 'examples_path' summary.
    output_dir: Directory the summary FileWriter writes into.
  """
  path_summary = tf.summary.text(
      'examples_path',
      tf.constant(examples_path, name='examples_path'),
      collections=[])

  # Render the hyperparameters as a two-column markdown table, sorted by key.
  hparam_values = hparams.values()
  rows = [
      '| {} | {} |'.format(str(name), str(hparam_values[name]))
      for name in sorted(hparam_values)
  ]
  table_text = '| Key | Value |\n| :--- | :--- |\n' + '\n'.join(rows) + '\n'

  table_summary = tf.summary.text(
      'hparams',
      tf.constant(table_text, name='hparams'),
      collections=[])

  # Evaluate both summaries in a throwaway session and write them out.
  with tf.Session() as sess:
    writer = tf.summary.FileWriter(output_dir, graph=sess.graph)
    for summary_op in (path_summary, table_summary):
      writer.add_summary(summary_op.eval())
    writer.close()


def _get_input_tensors(dataset, config):
  """Pull one batch of tensors from `dataset` and pin their static shapes.

  Args:
    dataset: Dataset whose elements are (input, output, control, length)
      tuples.
    config: Config supplying `hparams.batch_size` and the data converter's
      input, output and control depths.

  Returns:
    A dict with keys 'input_sequence', 'output_sequence', 'control_sequence'
    and 'sequence_length'. 'control_sequence' is None when the converter's
    control depth is falsy.
  """
  converter = config.data_converter
  batch_size = config.hparams.batch_size

  inputs, outputs, controls, lengths = (
      tf.data.make_one_shot_iterator(dataset).get_next())

  # Fix the batch and depth dimensions; the middle dimension stays unknown.
  inputs.set_shape([batch_size, None, converter.input_depth])
  outputs.set_shape([batch_size, None, converter.output_depth])
  if converter.control_depth:
    controls.set_shape([batch_size, None, converter.control_depth])
  else:
    controls = None
  lengths.set_shape([batch_size] + lengths.shape[1:].as_list())

  return dict(
      input_sequence=inputs,
      output_sequence=outputs,
      control_sequence=controls,
      sequence_length=lengths)


def train(train_dir,
          config,
          dataset_fn,
          checkpoints_to_keep=5,
          keep_checkpoint_every_n_hours=1,
          num_steps=None,
          master='',
          num_sync_workers=0,
          num_ps_tasks=0,
          task=0):
  """Build the training graph for `config.model` and run the train loop.

  Args:
    train_dir: Directory that is created if missing and used as the `logdir`
      of `tf_slim.training.train`.
    config: Config providing `model`, `hparams` and `data_converter`.
    dataset_fn: Zero-argument callable returning the dataset that training
      batches are read from.
    checkpoints_to_keep: Passed to the Saver as `max_to_keep`.
    keep_checkpoint_every_n_hours: Passed to the Saver under the same name.
    num_steps: If truthy, a StopAtStepHook with `last_step=num_steps` is
      added; otherwise no stop hook is installed.
    master: Passed through to `tf_slim.training.train`.
    num_sync_workers: If truthy, the optimizer is wrapped in a
      SyncReplicasOptimizer with this many workers.
    num_ps_tasks: Passed to `tf.train.replica_device_setter`.
    task: Task index; task 0 is treated as the chief.

  Raises:
    ValueError: If `config.hparams.clip_mode` is neither 'value' nor
      'global_norm'.
  """
  tf.gfile.MakeDirs(train_dir)
  is_chief = (task == 0)
  # The trial summary is currently disabled.
  # if is_chief:
  #   _trial_summary(
  #       config.hparams, config.train_examples_path or config.tfds_name,
  #       train_dir)

  # Everything below is built in a fresh graph rather than the default one.
  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(
        num_ps_tasks, merge_devices=True)):

      model = config.model
      model.build(config.hparams,
                  config.data_converter.output_depth,
                  is_training=True)

      # The value returned by model.train is used as the optimizer below.
      optimizer = model.train(**_get_input_tensors(dataset_fn(), config))

      hooks = []
      if num_sync_workers:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            num_sync_workers)
        hooks.append(optimizer.make_session_run_hook(is_chief))

      # Split the (gradient, variable) pairs so the gradients can be clipped
      # separately and re-paired with their variables afterwards.
      grads, var_list = list(zip(*optimizer.compute_gradients(model.loss)))
      global_norm = tf.global_norm(grads)
      tf.summary.scalar('global_norm', global_norm)

      if config.hparams.clip_mode == 'value':
        # Clip each gradient element-wise to [-grad_clip, grad_clip].
        g = config.hparams.grad_clip
        clipped_grads = [tf.clip_by_value(grad, -g, g) for grad in grads]
      elif config.hparams.clip_mode == 'global_norm':
        # While the global norm is below grad_norm_clip_to_zero, clip by
        # global norm to grad_clip; otherwise replace every gradient with
        # zeros.
        clipped_grads = tf.cond(
            global_norm < config.hparams.grad_norm_clip_to_zero,
            lambda: tf.clip_by_global_norm(  # pylint:disable=g-long-lambda
                grads, config.hparams.grad_clip, use_norm=global_norm)[0],
            lambda: [tf.zeros(tf.shape(g)) for g in grads])
      else:
        raise ValueError(
            'Unknown clip_mode: {}'.format(config.hparams.clip_mode))
      train_op = optimizer.apply_gradients(
          list(zip(clipped_grads, var_list)),
          global_step=model.global_step,
          name='train_step')

      logging_dict = {'global_step': model.global_step,
                      'loss': model.loss}

      # Log the global step and loss every 100 iterations.
      hooks.append(tf.train.LoggingTensorHook(logging_dict, every_n_iter=100))
      if num_steps:
        hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

      scaffold = tf.train.Scaffold(
          saver=tf.train.Saver(
              max_to_keep=checkpoints_to_keep,
              keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours))
      # Checkpoints are saved every 60 seconds.
      tf_slim.training.train(
          train_op=train_op,
          logdir=train_dir,
          scaffold=scaffold,
          hooks=hooks,
          save_checkpoint_secs=60,
          master=master,
          is_chief=is_chief)


def run(config_map,
        tf_file_reader=tf.data.TFRecordDataset,
        file_reader=tf.python_io.tf_record_iterator,
        is_training=True,
        config_name='groovae_4bar',
        train_dir='/content/drive/MyDrive/MusicVAE/train',
        num_steps=None):
  """Look up a configuration and start training it.

  Args:
    config_map: Dictionary mapping configuration name to Config object.
    tf_file_reader: The tf.data.Dataset class to use for reading files.
    file_reader: The Python reader to use for reading files. Currently
      unused; kept so existing callers keep working.
    is_training: If truthy, run the training loop. Otherwise only "EVAL" is
      printed, because evaluation is not implemented here.
    config_name: Key of the configuration in `config_map` to train.
    train_dir: Directory handed to `train` as its `train_dir`.
    num_steps: Global step at which training stops (a step count, not a
      number of epochs). None installs no step limit.

  Raises:
    KeyError: If `config_name` is not a key of `config_map`.
  """
  config = config_map[config_name]

  def dataset_fn():
    # Pass the mode through so the dataset is built for the requested mode
    # rather than always for training.
    return data.get_dataset(
        config,
        tf_file_reader=tf_file_reader,
        is_training=is_training,
        cache_dataset=True)

  if is_training:
    # The remaining `train` options (checkpoint retention, master, sync
    # workers, parameter-server tasks, task index) keep their defaults.
    train(
        train_dir,
        config=config,
        dataset_fn=dataset_fn,
        num_steps=num_steps)
  else:
    # Evaluation is not ported to this notebook. The upstream Magenta script
    # counts the eval examples, derives the number of batches from
    # config.hparams.batch_size and calls `evaluate` with an eval directory.
    print("EVAL")


# def main(unused_argv):
#   tf.logging.set_verbosity(FLAGS.log)
#   run(configs.CONFIG_MAP)


# def console_entry_point():
#   tf.disable_v2_behavior()
#   tf.app.run(main)


# if __name__ == '__main__':
#   console_entry_point()

# 2. Start Train

In [None]:
# Train the 'groovae_4bar' configuration. No step limit is set, so this
# presumably keeps running until the cell is interrupted.
run(CONFIG_MAP)

# 3. Generate sequence

In [19]:
# GrooVAE configs
# Register the 'groovae_4bar' configuration again with a newly constructed
# MusicVAE instance; the values match the training configuration above.
# NOTE(review): presumably rebuilt because the model instance used during
# training has already been built in the training graph - confirm.
CONFIG_MAP['groovae_4bar'] = Config(
    model=MusicVAE(lstm_models.BidirectionalLstmEncoder(),
                   lstm_models.GrooveLstmDecoder()), # Define the MusicVAE model
    # Default LSTM hparams merged with the overrides listed here.
    hparams=merge_hparams(
        lstm_models.get_default_hparams(),
        HParams(
            batch_size=512,
            max_seq_len=16 * 4,  # 4 bars w/ 16 steps per bar
            z_size=256,
            enc_rnn_size=[512],
            dec_rnn_size=[256, 256],
            max_beta=0.2,
            free_bits=48,
            dropout_keep_prob=0.3,
        )),
    note_sequence_augmenter=None,
    data_converter=data.GrooveConverter(
        split_bars=4, steps_per_quarter=4, quarters_per_bar=4,
        max_tensors_per_notesequence=20,
        pitch_classes=data.ROLAND_DRUM_PITCH_CLASSES,
        inference_pitch_classes=data.REDUCED_DRUM_PITCH_CLASSES),
    # tfds_name='groove/4bar-midionly',
    train_examples_path='/content/drive/MyDrive/MusicVAE/tfrec/example.tfrecord',
)

In [20]:
from magenta.models.music_vae.trained_model import TrainedModel

In [None]:
# Wrap the configuration and the checkpoint(s) under the training directory
# in a TrainedModel, using a batch size of 1.
model = TrainedModel(
    config=CONFIG_MAP['groovae_4bar'],
    batch_size=1,
    checkpoint_dir_or_path='/content/drive/MyDrive/MusicVAE/train')

In [22]:
# Draw one sample of length 16*4 (the same value as max_seq_len in the
# config) at temperature 0.5.
generated_sequence = model.sample(n=1, length=16*4, temperature=0.5)

In [None]:
# Display the first generated sequence as the cell output.
generated_sequence[0]

In [24]:
import note_seq

In [25]:
# Save the first generated sequence as a MIDI file named 'sample.mid'.
note_seq.sequence_proto_to_midi_file(generated_sequence[0], 'sample.mid')