<a href="https://colab.research.google.com/github/jinjukang67/GrooVAE/blob/main/MusicVAE_full.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Generating 4-bar drum samples with the Groove MIDI Dataset

- **What is GrooVAE?**
> GrooVAE is a variant of MusicVAE for generating and controlling expressive drum performances.

### DATASET
- What is MIDI (Musical Instrument Digital Interface)?
> A MIDI file represents music symbolically, encoding notes as numbers or characters according to a fixed notation.
> In other words, MIDI stores music in a form that is easy for a computer to process.
- [Groove MIDI Dataset](https://magenta.tensorflow.org/datasets/groove)
> The dataset consists of 1,150 MIDI files and more than 22,000 measures of drumming.
> It can be loaded conveniently as a ``tf.data.Dataset`` through TFDS. <br>
> + [EXPANDED Groove MIDI](https://magenta.tensorflow.org/datasets/e-gmd)

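Additional notes <br>

A MIDI drum beat, whether played live or sequenced on a computer, can be broken down into two main components.
> 1. The Score (which drums are played and when, as written in Western music notation)
> 2. The Groove (how the drums are played: dynamics and timing)

In short, a drum beat is the combination of a score and a groove.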
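
The score/groove split can be sketched in a few lines of Python. This is only an illustration, not Magenta's actual data converter: the `Hit` class, the `split_score_and_groove` helper, and the 16th-note grid are assumptions made for this example.

```python
from dataclasses import dataclass

STEPS_PER_QUARTER = 4  # assumed 16th-note grid for this sketch


@dataclass
class Hit:
    time_beats: float  # onset time in quarter-note beats
    velocity: int      # MIDI velocity, 0-127


def split_score_and_groove(hits):
    """Split each hit into its grid step (score) and its expressive detail (groove)."""
    score, groove = [], []
    for hit in hits:
        exact_step = hit.time_beats * STEPS_PER_QUARTER
        step = round(exact_step)    # nearest grid step: the "score" part
        offset = exact_step - step  # micro-timing in fractions of a step: the "groove" part
        score.append(step)
        groove.append((hit.velocity, offset))
    return score, groove


hits = [Hit(0.0, 100), Hit(0.27, 60), Hit(0.5, 90)]
score, groove = split_score_and_groove(hits)
print(score)  # [0, 1, 2]
```

The second hit lands slightly after its grid step, so its offset is positive (about 0.08 of a step), while its velocity of 60 marks it as a softer note.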

# 0. Environment setup

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
!pip install -q pyfluidsynth
!git clone https://github.com/tensorflow/magenta.git

fatal: destination path 'magenta' already exists and is not an empty directory.


In [3]:
cd /content/magenta

/content/magenta


In [4]:
# Installing the dependencies
!pip install -e .

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/magenta
Installing collected packages: magenta
  Attempting uninstall: magenta
    Found existing installation: magenta 2.1.3
    Can't uninstall 'magenta'. No files were found to uninstall.
  Running setup.py develop for magenta
Successfully installed magenta-2.1.3


In [5]:
from google.colab import files
import os
import warnings
import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
import numpy as np
import tensorflow as tf

warnings.filterwarnings("ignore", category=DeprecationWarning)

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit


# 1. Data preprocessing

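To extract 4-bar samples, I preprocessed the '4bar-midionly' configuration of the Groove MIDI Dataset.

- MIDI data is stored as binary .midi files, so it has to be converted into vectors before it can be used for training.
- The data was then saved in TFRecord format so that the training script can read it.

**groove/4bar-midionly**
- Config description: the Groove dataset without audio, split into 4-bar chunks.
- Feature structure: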
```
FeaturesDict({
    'bpm': tf.int32,
    'drummer': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    'id': tf.string,
    'midi': tf.string,
    'style': FeaturesDict({
        'primary': ClassLabel(shape=(), dtype=tf.int64, num_classes=18),
        'secondary': tf.string,
    }),
    'time_signature': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    'type': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
})
```
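
As a rough sanity check on what a "4-bar chunk" means, the sketch below computes the number of grid steps and the duration of one chunk. It assumes 4/4 time and a 16th-note grid (four steps per quarter note); the helper names are made up for this example. Under those assumptions the step count matches the `max_seq_len` of 64 that appears in the model hyperparameters logged during training.

```python
# Assumed grid for this sketch: 4/4 time, four 16th-note steps per quarter note.
BARS = 4
QUARTERS_PER_BAR = 4
STEPS_PER_QUARTER = 4


def chunk_steps(bars=BARS):
    """Number of grid steps in a chunk of `bars` bars."""
    return bars * QUARTERS_PER_BAR * STEPS_PER_QUARTER


def chunk_seconds(bpm, bars=BARS):
    """Duration in seconds of a chunk of `bars` bars at the given tempo."""
    return bars * QUARTERS_PER_BAR * 60.0 / bpm


print(chunk_steps())       # 64
print(chunk_seconds(120))  # 8.0
```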

What is TFRecord?

> The ``TFRecord`` file format is a simple record-oriented binary format that many TensorFlow applications use for training data

- In short, it is a file format for storing data in binary form.
- TFRecord is generally recommended for its performance and ease of development.

In [6]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

print(tfds.__version__)

# Load the 4bar-midionly data as a tf.data.Dataset
dataset, info = tfds.load(
    name="groove/4bar-midionly",
    split=tfds.Split.TRAIN,
    with_info=True,
    try_gcs=True)

print(info) # Inspect the feature structure

4.0.1
tfds.core.DatasetInfo(
    name='groove',
    version=2.0.1,
    description='The Groove MIDI Dataset (GMD) is composed of 13.6 hours of aligned MIDI and
(synthesized) audio of human-performed, tempo-aligned expressive drumming
captured on a Roland TD-11 V-Drum electronic drum kit.',
    homepage='https://g.co/magenta/groove-dataset',
    features=FeaturesDict({
        'bpm': tf.int32,
        'drummer': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
        'id': tf.string,
        'midi': tf.string,
        'style': FeaturesDict({
            'primary': ClassLabel(shape=(), dtype=tf.int64, num_classes=18),
            'secondary': tf.string,
        }),
        'time_signature': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
        'type': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
    }),
    total_num_examples=21415,
    splits={
        'test': 2033,
        'train': 17261,
        'validation': 2121,
    },
    supervised_keys=None,
    citation="""

In [7]:
dataset = dataset.shuffle(1024).batch(32).prefetch(
    tf.data.experimental.AUTOTUNE)
for features in dataset.take(1):
  # Access the features you are interested in
  midi, genre = features["midi"], features["style"]["primary"]

In [8]:
print("Inspecting the raw MIDI data: \n",midi)

Inspecting the raw MIDI data: 
 tf.Tensor(
[b'MThd\x00\x00\x00\x06\x00\x01\x00\x02\x01\xe0MTrk\x00\x00\x00\x19\x00\xffQ\x03\x0b*;\x00\xffX\x04\x04\x02\x18\x08\x00\xffY\x02\x00\x00\x01\xff/\x00MTrk\x00\x00\x04w\x00\xff\x03\nMidi Drums\x00\xc9\x00\x00\xb9\x04*\x02\x04Zy\x04Z\x00\x99*.=\xb9\x04\x12\x05\x99*\x00\x08,A\x12\x16>\x02\xb9\x04Z.\x99,\x00\x12\x16\x00I(\x7f\x01\xb9\x04Z\x00\x99*h\t$98(\x00\x02*\x00\x08$\x00\x1e\xb9\x04Z\x00\x99\x16FC\x16\x00M\xb9\x04Z\x00\x99*\x18\x0f$54*\x00\x0e$\x00!(\x7f\x05\xb9\x04Z\x00\x99\x16\x7f>(\x00\x04\x16\x00=$)C$\x00\x1f$H\x0e\xb9\x04Z\x00\x99\x16\x7f4$\x00\x0e\x16\x00?&\x7fB&\x00\x1d\xb9\x04Z\x00\x99\x16\x7f\x08$7;\x16\x00\x07$\x00C\xb9\x04Z\x00\x99*L\x0b$(\x13\xb9\x04\x00%\x99*\x00\t,C\x01$\x00\x03\xb9\x04Z\x05\x99$L\x08\x1a\x7f\x16\xb9\x04\x00\x1b\x99,\x00\t$\x00\t\x1a\x00\t\xb9\x04\x00\x81(\x04\x02\x00\x99\x1a\x7f\x00&\x7f@\xb9\x04\x07\x03\x99\x1a\x00\x00&\x00W\xb9\x04\x00W\x04\x10\x01\x99(\x7f\x05$O\x05,/\x01\x1a\x7f\x14\xb9\x04Z#\x99(\x00\x05$\x00\x06\x1a\x00\x

Above, I inspected the raw form of the MIDI data through ``tf.data.Dataset``.
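
The byte string printed above is a standard MIDI file, so its first bytes can be decoded by hand. The sketch below copies only the beginning of that output and reads the header chunk and the first set-tempo event with the standard library. Searching for the raw `FF 51 03` pattern is a shortcut for illustration, not a proper MIDI parser.

```python
import struct

# First bytes of the example printed above (copied from the cell output).
raw = b'MThd\x00\x00\x00\x06\x00\x01\x00\x02\x01\xe0MTrk\x00\x00\x00\x19\x00\xffQ\x03\x0b*;'

# Header chunk: 4-byte tag, 4-byte length, then format, track count, and division.
tag, length, fmt, n_tracks, division = struct.unpack(">4sIHHH", raw[:14])
print(tag, length, fmt, n_tracks, division)  # b'MThd' 6 1 2 480

# Set-tempo meta event: FF 51 03 followed by microseconds per quarter note.
pos = raw.index(b'\xffQ\x03') + 3
us_per_quarter = int.from_bytes(raw[pos:pos + 3], "big")
bpm = round(60_000_000 / us_per_quarter, 2)
print(bpm)  # 82.0
```

So this example is a format-1 file with 2 tracks, 480 ticks per quarter note, and a tempo of about 82 BPM.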

Below, I downloaded the dataset and saved the MIDI data in TFRecord format as ``sequences.tfrecord``.

In [9]:
import pathlib


dataset_url = "https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip"
data_dir = tf.keras.utils.get_file(origin=dataset_url, 
                                   fname='/content/drive/MyDrive/Pozalabs_VAE/data/midionly_dataset', 
                                   extract=True)
data_dir = pathlib.Path(data_dir)

In [10]:
# The MIDI parser may print warnings when it meets a malformed MIDI file; these can be ignored.
# MIDI files that cannot be parsed are skipped.

!convert_dir_to_note_sequences \
  --input_dir=/content/drive/MyDrive/Pozalabs_VAE/data/midionly_dataset \
  --output_file=/content/drive/MyDrive/Pozalabs_VAE/data/sequences.tfrecord \
  --recursive

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Instructions for updating:
non-resource variables are not supported in the long term
INFO:tensorflow:Converting files in '/content/drive/MyDrive/Pozalabs_VAE/data/midionly_dataset/'.
I0630 02:23:16.769789 140680571557760 convert_dir_to_note_sequences.py:82] Converting files in '/content/drive/MyDrive/Pozalabs_VAE/data/midionly_dataset/'.
Traceback (most recent call last):
  File "/usr/local/bin/convert_dir_to_note_sequences", line 33, in <module>
    sys.exit(load_entry_point('magenta', 'console_scripts', 'convert_dir_to_note_seque

# 2. Model training

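- Why a ValueError can occur
> ValueError: if required flags are missing or invalid.

``music_vae_train.py`` raises this error when both ``--examples_path`` and ``--tfds_name`` are set, so each training command in this section passes only one of them.

[Excerpt from music_vae_train.py](https://github.com/magenta/magenta/blob/main/magenta/models/music_vae/music_vae_train.py)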
``` 
  if FLAGS.tfds_name:
    if FLAGS.examples_path:
      raise ValueError(
          'At most one of --examples_path and --tfds_name can be set.')
```
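
The same rule can be checked before launching a long training run. The helper below is a hypothetical sketch, not part of Magenta: `check_data_flags` and its return values are invented for this example, and the second error (no data source at all) is an assumption about what a caller would want rather than something taken from the excerpt.

```python
def check_data_flags(examples_path=None, tfds_name=None):
    """Hypothetical helper: require exactly one data source before training."""
    if examples_path and tfds_name:
        # Same condition as the excerpt from music_vae_train.py.
        raise ValueError("At most one of --examples_path and --tfds_name can be set.")
    if not examples_path and not tfds_name:
        # Assumption for this sketch: training needs some data source.
        raise ValueError("One of --examples_path or --tfds_name must be set.")
    return "tfrecord" if examples_path else "tfds"


print(check_data_flags(tfds_name="groove/full-midionly"))  # tfds
```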

In [11]:
os.chdir("/content/magenta/magenta/models/music_vae/")

## 2-1. Training on the manually preprocessed data

In [12]:
!python music_vae_train.py \
--config=groovae_4bar \
--run_dir=/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_4bar_tfr \
--mode=train \
--num_steps=20000 \
--examples_path=/content/drive/MyDrive/Pozalabs_VAE/data/sequences.tfrecord 

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Instructions for updating:
non-resource variables are not supported in the long term
2022-06-30 02:23:23.547287: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, GrooveLstmDecoder, and hparams:
{'max_seq_len': 64, 'z_size': 256, 'free_bits': 48, 'max_beta': 0.2, 'beta_rate': 0.0, 'batch_size': 512, 'grad_clip': 1.0, 'clip_m

> - No model checkpoint was written after step 0 and no loss values were logged, so training did not appear to be running properly. The manually preprocessed data was therefore not used for generation.
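
The conversion cell in section 1 ends in a truncated traceback, so one thing worth ruling out is that ``sequences.tfrecord`` was never written properly. One possible cause is that ``--input_dir`` did not point at a directory of extracted MIDI files. The check below is a hypothetical pre-flight helper for that case, not a confirmed diagnosis. It runs here against a temporary directory; in the notebook it would be called with the ``--input_dir`` path instead.

```python
import tempfile
from pathlib import Path


def count_midi_files(input_dir):
    """Hypothetical pre-flight check: count .mid/.midi files under input_dir."""
    root = Path(input_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")
    return sum(1 for p in root.rglob("*") if p.suffix.lower() in {".mid", ".midi"})


# Demonstration on a throwaway directory with one MIDI file and one other file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "a.mid").write_bytes(b"MThd")
    (Path(tmp) / "notes.txt").write_text("not midi")
    n_found = count_midi_files(tmp)

print(n_found)  # 1
```

A count of zero, or a `NotADirectoryError`, would suggest fixing the download and extraction step before training.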

## 2-2. Training directly with Magenta's built-in preprocessing
- [data.py](https://github.com/magenta/magenta/blob/main/magenta/models/music_vae/data.py)
- [preprocess_tfrecord.py](https://github.com/magenta/magenta/blob/main/magenta/models/music_vae/preprocess_tfrecord.py)
 

### groove/full-midionly
> The full Groove dataset

In [13]:
!python music_vae_train.py \
--config=groovae_4bar \
--run_dir=/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full \
--mode=train \
--num_steps=20000 \
--tfds_name=groove/full-midionly

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Instructions for updating:
non-resource variables are not supported in the long term
2022-06-30 02:23:49.053195: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, GrooveLstmDecoder, and hparams:
{'max_seq_len': 64, 'z_size': 256, 'free_bits': 48, 'max_beta': 0.2, 'beta_rate': 0.0, 'batch_size': 512, 'grad_clip': 1.0, 'clip_m

> The lowest logged loss was loss = 42.277954 (27.000 sec). The corresponding checkpoint is saved as ``model.ckpt-19251``, so that file is used for generation below.
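
Finding the lowest logged loss by eye is error-prone, so a small script can scan the training log instead. This is a sketch under assumptions: the regular expression assumes log lines of the form `global_step = N, loss = X (... sec)`, which may not match the actual log exactly, and the sample lines are made up for illustration.

```python
import re

# Assumed log format for this sketch; adjust the pattern to the real training log.
LOG_PATTERN = re.compile(r"global_step = (\d+), loss = ([0-9.]+)")


def best_logged_step(log_lines):
    """Return (step, loss) for the smallest logged loss, or None if nothing matches."""
    records = []
    for line in log_lines:
        match = LOG_PATTERN.search(line)
        if match:
            records.append((int(match.group(1)), float(match.group(2))))
    return min(records, key=lambda r: r[1]) if records else None


sample_log = [  # made-up lines, for illustration only
    "INFO:tensorflow:global_step = 100, loss = 310.5 (30.100 sec)",
    "INFO:tensorflow:global_step = 200, loss = 120.25 (29.800 sec)",
    "INFO:tensorflow:global_step = 300, loss = 150.0 (29.900 sec)",
]
best = best_logged_step(sample_log)
print(best)  # (200, 120.25)
```

The logged value is a training loss on a single batch, so it is a noisy guide. Checkpoints are also not necessarily written at the same steps that are logged, so the nearest saved checkpoint may have to stand in for the best step.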

# 3. Generation

## 3-1. Generation method 1

### groovae_full

In [26]:
def play(note_sequence):
  mm.play_sequence(note_sequence, synth=mm.fluidsynth)

def download(note_sequence, filename):
  mm.sequence_proto_to_midi_file(note_sequence, filename)
  files.download(filename)

print("Initializing Music VAE...")

config = 'groovae_4bar'
model_path = '/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/train/model.ckpt-19251'
num_music = 5

music_vae = TrainedModel(
      configs.CONFIG_MAP[config], 
      batch_size=num_music, 
      checkpoint_dir_or_path=model_path)

print('🌟Generation complete🌟(1)')

Initializing Music VAE...
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, GrooveLstmDecoder, and hparams:
{'max_seq_len': 64, 'z_size': 256, 'free_bits': 48, 'max_beta': 0.2, 'beta_rate': 0.0, 'batch_size': 5, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_learning_rate': 1e-05, 'conditional': True, 'dec_rnn_size': [256, 256], 'enc_rnn_size': [512], 'dropout_keep_prob': 0.3, 'sampling_schedule': 'constant', 'sampling_rate': 0.0, 'use_cudnn': False, 'residual_encoder': False, 'residual_decoder': False, 'control_preprocessing_rnn_size': [256]}

INFO:tensorflow:
Encoder Cells (bidirectional):
  units: [512]

INFO:tensorflow:
Decoder Cells:
  units: [256, 256]

INFO:tensorflow:Restoring parameters from /content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/train/model.ckpt-19251

🌟Generation complete🌟(1)


## 3-2. Generation method 2


### groovae_full


In [28]:
!python music_vae_generate.py \
--config=groovae_4bar \
--checkpoint_file=/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/train/model.ckpt-19251 \
--mode=sample \
--num_outputs=5 \
--output_dir=/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated

print('🌟Generation complete🌟(2)')

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Instructions for updating:
non-resource variables are not supported in the long term
INFO:tensorflow:Loading model...
I0630 04:20:19.106403 140229391525760 music_vae_generate.py:149] Loading model...
INFO:tensorflow:Building MusicVAE model with BidirectionalLstmEncoder, GrooveLstmDecoder, and hparams:
{'max_seq_len': 64, 'z_size': 256, 'free_bits': 48, 'max_beta': 0.2, 'beta_rate': 0.0, 'batch_size': 5, 'grad_clip': 1.0, 'clip_mode': 'global_norm', 'grad_norm_clip_to_zero': 10000, 'learning_rate': 0.001, 'decay_rate': 0.9999, 'min_

# 4. Playing the samples

## 4-1. Playback method 1
- The magenta library is built around NoteSequences, so I used the note_seq library to play back the samples.


In [27]:
# Hack to allow python to pick up the newly-installed fluidsynth lib.
# This is only needed for the hosted Colab environment.
import note_seq
import ctypes.util
orig_ctypes_util_find_library = ctypes.util.find_library
def proxy_find_library(lib):
  if lib == 'fluidsynth':
    return 'libfluidsynth.so.1'
  else:
    return orig_ctypes_util_find_library(lib)
ctypes.util.find_library = proxy_find_library


temperature = 0.5 #@param {type:"slider", min:0.1, max:1.5, step:0.1}
drums_samples = music_vae.sample(n=5, length=64, temperature=temperature)

for ns in drums_samples:
    note_seq.plot_sequence(ns)
    note_seq.play_sequence(ns, synth=note_seq.fluidsynth)

ValueError: ignored

## 4-2. Playback method 2
- [pretty_midi](https://craffel.github.io/pretty-midi/)
>- ``class pretty_midi.PrettyMIDI(midi_file=None, resolution=220, initial_tempo=120.0)``
>- A container for MIDI data in an easily manipulable format.

In [29]:
import pretty_midi
from glob import glob
from IPython import display


def display_audio(pm: pretty_midi.PrettyMIDI, seconds=50):
  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
  # Keep only the first `seconds` seconds (50 by default) of the rendered audio.
  waveform_short = waveform[:seconds*_SAMPLING_RATE]

  return display.Audio(waveform_short, rate=_SAMPLING_RATE)

In [30]:
i = 0
example_path = glob("/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated/*")

for path in example_path:
  print(path)

# Load the MIDI file as a PrettyMIDI object.
pm = pretty_midi.PrettyMIDI(example_path[i])

# Print an empirical estimate of its global tempo
print(pm.estimate_tempo())

seed = 777
tf.random.set_seed(seed)
np.random.seed(seed)

# Sampling rate for audio playback
# 44,100 samples per second is a common audio sample rate.
_SAMPLING_RATE = 44100

display_audio(pm)

/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated/groovae_4bar_sample_2022-06-30_042019-000-of-005.mid
/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated/groovae_4bar_sample_2022-06-30_042019-001-of-005.mid
/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated/groovae_4bar_sample_2022-06-30_042019-002-of-005.mid
/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated/groovae_4bar_sample_2022-06-30_042019-003-of-005.mid
/content/drive/MyDrive/Pozalabs_VAE/checkpoints/groovae_full/generated/groovae_4bar_sample_2022-06-30_042019-004-of-005.mid
195.97030752916223
