## install dependencies

System and Python dependencies required to import and use the `magenta` and `music_vae` libraries.

In [None]:
# System packages: the FluidSynth library and a General MIDI soundfont, plus
# build tools and the ALSA/JACK development headers.
# NOTE(review): presumably FluidSynth + soundfont back the `mm.fluidsynth`
# playback used later, and the headers are needed to build audio/MIDI Python
# packages pulled in by magenta — confirm.
!apt-get update -qq && apt-get install -qq libfluidsynth2 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
# Python bindings for FluidSynth.
!pip install -q pyfluidsynth

## install my **forked** magenta package

In [None]:
# Editable (-e), upgraded (-U), quiet (-q) install of the forked magenta package
# straight from its Git repository.
# NOTE(review): the URL is not pinned to a commit or tag, so a re-run may pick
# up a different revision of the fork — consider appending `@<commit>`.
!pip install -qU -e git+https://github.com/lukysummer/magenta.git#egg=magenta

**MAKE SURE TO RESTART THE RUNTIME AFTER RUNNING THE ABOVE CELL**, so that the newly installed forked `magenta` package is the one that gets imported.

## Train with forked package and new config `groovae_4bar_MusicVAE`

In [None]:
# Train MusicVAE with the `groovae_4bar_MusicVAE` config on the TFDS dataset
# `groove/4bar-midionly`, using `groove_music_vae/` as the run directory and
# overriding the learning rate to 0.0005 through --hparams.
# Reference for the command-line flags:
# https://github.com/magenta/magenta/tree/main/magenta/models/music_vae#training-your-own-musicvae
!music_vae_train \
--config=groovae_4bar_MusicVAE \
--run_dir=groove_music_vae/ \
--mode=train \
--tfds_name=groove/4bar-midionly \
--hparams=learning_rate=0.0005

## Sample with trained checkpoints



### Sample with checkpoints at different steps to compare

To check whether the generated samples improve as training progresses.

In [1]:
import os

train_ckpt_dir = "groove_music_vae/train"  # directory where all checkpoints during training are saved.
sample_save_dir = "generated_samples"      # directory to save generated samples

# Create the directory to save generated samples.
# `exist_ok=True` keeps the cell safe to re-run without a separate existence
# check, and `makedirs` also creates missing parent directories should
# `sample_save_dir` ever be changed to a nested path.
os.makedirs(sample_save_dir, exist_ok=True)

In [5]:
# Move everything from `test/` into the training checkpoint directory.
# NOTE(review): no visible cell in this notebook creates `test/`, so on a fresh
# runtime this command depends on files placed there manually (presumably
# uploaded checkpoints) — confirm and document where they come from.
!mv test/* groove_music_vae/train/

In [21]:
import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
from typing import List

def generate(ckpt_path: str, # checkpoint path to use for sampling
             config_name: str, # name of the config to use for sampling
             n_samples: int, # Number of samples to generate
             n_16th_notes: int, # Number of 16th-note steps to generate (= number of bars * 16)
             temperature: float, # sampling temperature; lower values give less random output
             save_sample: bool = False, # whether or not to save generated sample(s)
             sample_save_dir: str = "", # directory for generated sample(s) if save_sample=True
             ):
  """Sample drum sequences from a MusicVAE checkpoint, play them and optionally save them.

  Args:
    ckpt_path: Checkpoint (or checkpoint directory) handed to `TrainedModel`.
    config_name: Key into `configs.CONFIG_MAP` selecting the model config.
    n_samples: Number of note sequences to sample.
    n_16th_notes: Length of each sample in 16th-note steps (16 per bar).
    temperature: Sampling temperature forwarded to `TrainedModel.sample`.
    save_sample: When True, every sample is also written as a MIDI file.
    sample_save_dir: Target directory for the MIDI files; it is created if it
      does not exist. An empty string means the current working directory.

  Saved files are named `<label>_sample_<i>.mid`. `<label>` is `step_<N>` when
  the checkpoint name ends in `ckpt-<N>` (e.g. `model.ckpt-1345`), otherwise
  the checkpoint's base name. The label is derived from `ckpt_path` so the
  function does not depend on any notebook-level variable.
  """
  # Create a TrainedModel instance.
  loaded_model = TrainedModel(configs.CONFIG_MAP[config_name], batch_size=4, checkpoint_dir_or_path=ckpt_path)

  # Generate `n_samples` note sequences with the current checkpoint.
  drum_samples = loaded_model.sample(n=n_samples,
                                     length=n_16th_notes,
                                     temperature=temperature)

  file_label = ""
  if save_sample:
    # Build the file-name label from the checkpoint itself.
    ckpt_name = os.path.basename(os.path.normpath(ckpt_path))
    _, marker, step = ckpt_name.rpartition("ckpt-")
    file_label = f"step_{step}" if marker and step.isdigit() else ckpt_name

    # Make sure the output directory exists before writing into it.
    if sample_save_dir:
      os.makedirs(sample_save_dir, exist_ok=True)

  # For each of the `n_samples` generated note sequences,
  for i, ns in enumerate(drum_samples):
    # Play generated note sequence.
    mm.play_sequence(ns, synth=mm.fluidsynth)

    # Only write a MIDI file when the caller asked for it.
    if save_sample:
      midi_out_path = os.path.join(sample_save_dir, f"{file_label}_sample_{i}.mid")
      mm.sequence_proto_to_midi_file(ns, midi_out_path)

In [None]:
# Training steps whose checkpoints are sampled for comparison.
steps_to_test = [1345, 18257, 27518, 36656, 47527, 57773]

# Sampling settings shared by every checkpoint: 4 samples of 16*4 sixteenth-note
# steps each, played back and saved to `sample_save_dir`.
sampling_options = dict(config_name='groovae_4bar_MusicVAE',
                        n_samples=4,
                        n_16th_notes=16*4,
                        temperature=0.5,
                        save_sample=True,
                        sample_save_dir=sample_save_dir)

for test_step in steps_to_test:
  # Banner separating the output of consecutive checkpoints.
  print("----------------------------------",
        f"Now testing checkpoint at step {test_step}",
        sep="\n")

  # Path of the checkpoint saved at this training step.
  ckpt_path = os.path.join(train_ckpt_dir, f"model.ckpt-{test_step}")

  # Sample from this checkpoint.
  generate(ckpt_path=ckpt_path, **sampling_options)

In [12]:
# Zip and download generated files
# `-r` recurses into the `generated_samples` directory; `-qq` keeps zip quiet.
!zip -qq -r generated_samples.zip generated_samples
# `google.colab` is imported here, so this cell only works on a Colab runtime.
from google.colab import files
# Hand the archive to the Colab download helper.
files.download("generated_samples.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Try sampling with other models for comparison

This section uses the pre-trained checkpoints available from the [MusicVAE GitHub repository](https://github.com/magenta/magenta/tree/master/magenta/models/music_vae).

In [None]:
# Root location (gs:// URL) of the pre-trained MusicVAE checkpoints.
BASE_DIR = "gs://download.magenta.tensorflow.org/models/music_vae/colab2"

# Example: the pre-trained checkpoint trained with 2-bar drums with 9 classes.
# NOTE(review): 16*4 steps are requested while the config name says 2bar —
# confirm this model is meant to be sampled at that length.
pretrained_ckpt_path = f"{BASE_DIR}/checkpoints/drums_2bar_small.lokl.ckpt"

generate(
    ckpt_path=pretrained_ckpt_path,
    config_name='cat-drums_2bar_small',
    n_samples=4,
    n_16th_notes=16*4,
    temperature=0.5,
)