In [3]:
# Re-import edited modules automatically before each cell runs, so changes to the
# musicautobot package are picked up without restarting the kernel.
# `reload_ext` (rather than `load_ext`) is safe to run more than once.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
# Standard-library OS helpers (working directory, environment variables).
import os

In [7]:
# Work from the repository root so the `musicautobot` package and the relative
# `data/...` paths used below resolve. Set MUSICAUTOBOT_ROOT to point elsewhere;
# otherwise assume this notebook sits two levels below the root.
# The guard keeps the cell idempotent: re-running it does not keep walking up the tree.
if not os.path.isdir('musicautobot'):
    os.chdir(os.environ.get('MUSICAUTOBOT_ROOT', '../../'))
os.getcwd()

/Users/arpitha/Documents/295B/musicautobot


In [8]:
# Project modules. NOTE(review): these star imports presumably supply most names used
# later (MusicDataBunch, Midi2ItemProcessor, Midi2MultitrackProcessor, S2SPreloader,
# S2SItemList, mask_lm_tfm_pitchdur, melody_chord_tfm, multitask_config,
# multitask_model_learner, *_predict_from_midi). Confirm which module defines each name,
# then replace the stars with explicit imports.
# process_all / process_file appear unused in this notebook.
from musicautobot.numpy_encode import *
from musicautobot.utils.file_processing import process_all, process_file
from musicautobot.config import *
from musicautobot.music_transformer import *
from musicautobot.multitask_transformer import *
from musicautobot.utils.stacked_dataloader import StackedDataBunch

In [9]:
# fastai (v1) text API. NOTE(review): this star import is presumably also where `Path`,
# `get_files` and `load_data` used below come from. Confirm before making the imports explicit.
from fastai.text import *

## MultitaskTransformer Training

Multitask Training is an extension of [MusicTransformer](../music_transformer/Train.ipynb).

Instead of training a basic language model that only predicts the next word...

We train on multiple tasks
* [Next Word](../music_transformer/Train.ipynb)
* [Bert Mask](https://arxiv.org/abs/1810.04805)
* [Sequence to Sequence Translation](http://jalammar.github.io/illustrated-transformer/)

This gives a more generalized model and also lets you do some really cool [predictions](Generate.ipynb)

## End to end training pipeline 

1. Create and encode dataset
2. Initialize Transformer Model
3. Train
4. Predict

In [10]:
# Input: folder searched (recursively) for MIDI files
midi_path = Path('data/midi/lmd_dataset')
# Output: folder the encoded datasets are saved to and reloaded from
data_path = Path('data/numpy')

# Create both folders up front; harmless if they already exist.
for folder in (midi_path, data_path):
    folder.mkdir(parents=True, exist_ok=True)

# File names (inside data_path) for the saved NextWord/Mask and seq2seq datasets
data_save_name = 'musicitem_data_save.pkl'
s2s_data_save_name = 'multiitem_data_save.pkl'

## 1. Gather midi dataset

Make sure all your MIDI data is in the `data/midi/lmd_dataset` directory (the `midi_path` set above, relative to the repository root)

Here's a pretty good dataset with lots of midi data:  
https://www.reddit.com/r/datasets/comments/3akhxy/the_largest_midi_collection_on_the_internet/

Download the folder and unzip it to `data/midi/lmd_dataset`

## 2. Create dataset from MIDI files

In [13]:
# Collect MIDI files recursively under midi_path; both '.mid' and '.midi' are common extensions.
midi_files = get_files(midi_path, ['.mid', '.midi'], recurse=True)
# Fail here with a clear message instead of deep inside DataBunch creation.
assert midi_files, f'No MIDI files found under {midi_path} - see the download instructions above'
len(midi_files)

20

### 2a. Create NextWord/Mask Dataset

In [14]:
# Build the DataBunch used for the masked ('msk') and next-word ('lm') tasks (see the batch below)
# and save it so the training section can reload it with load_data(data_path, data_save_name, ...).
# bptt=5 / bs=2 only keep the inspected batch tiny; the training section reloads with bptt=128.
# NOTE(review): Midi2ItemProcessor presumably encodes each MIDI file into a MusicItem. Confirm.
processors = [Midi2ItemProcessor()]
data = MusicDataBunch.from_files(midi_files, data_path, processors=processors, 
                                 encode_position=True, dl_tfms=mask_lm_tfm_pitchdur, 
                                 bptt=5, bs=2)
data.save(data_save_name)

  return np.array(a, dtype=dtype, **kwargs)


In [15]:
# Pull one (inputs, targets) batch and show the inputs to see how the tasks are laid out.
xb, yb = data.one_batch()
xb

{'msk': {'x': tensor([[  8, 141,   4, 138,   8],
          [  8, 141,   4, 138,   8]]),
  'pos': tensor([[16, 16, 20, 20, 20],
          [16, 16, 20, 20, 20]])},
 'lm': {'x': tensor([[138,   8, 141,  50, 138],
          [138,   8, 141,  50, 138]]),
  'pos': tensor([[16, 16, 20, 20, 20],
          [16, 16, 20, 20, 20]])}}

Key:
* 'msk' = masked input
* 'lm' = next word input
* 'pos' = timestepped positional encoding. This is in addition to relative positional encoding

Note: MultitaskTransformer trains on both the masked input ('msk') and next word input ('lm') at the same time.

The encoder is trained on the 'msk' data, while the decoder is trained on 'lm' data.



### 2b. Create sequence to sequence dataset

In [16]:
# Build the sequence-to-sequence (melody <-> chords) DataBunch and save it for reloading later.
# Files that do not contain exactly two tracks cannot be used here; a warning is printed
# for each one (see the output below).
processors = [Midi2MultitrackProcessor()]
s2s_data = MusicDataBunch.from_files(midi_files, data_path, processors=processors, 
                                     preloader_cls=S2SPreloader, list_cls=S2SItemList,
                                     dl_tfms=melody_chord_tfm,
                                     bptt=5, bs=2)
s2s_data.save(s2s_data_save_name)



  def to_idx(self): return np.array((self.melody.to_idx(), self.chords.to_idx()))


Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks
Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks


Batch structure:

In [17]:
# Pull one seq2seq batch and display the inputs for both translation directions.
xb, yb = s2s_data.one_batch()
xb

{'c2m': {'enc': tensor([[  5,   1,   8, 169, 107],
          [  5,   1,   8, 155,  65]]),
  'enc_pos': tensor([[ 0,  0,  0,  0, 32],
          [ 0,  0,  0,  0, 18]]),
  'dec': tensor([[  6,   1,   8, 153,  56],
          [  6,   1,   8, 147,  70]]),
  'dec_pos': tensor([[ 0,  0,  0,  0, 16],
          [ 0,  0,  0,  0, 10]])},
 'm2c': {'enc': tensor([[  6,   1,   8, 153,  56],
          [  6,   1,   8, 147,  70]]),
  'enc_pos': tensor([[ 0,  0,  0,  0, 16],
          [ 0,  0,  0,  0, 10]]),
  'dec': tensor([[  5,   1,   8, 169, 107],
          [  5,   1,   8, 155,  65]]),
  'dec_pos': tensor([[ 0,  0,  0,  0, 32],
          [ 0,  0,  0,  0, 18]])}}

Key:
* 'c2m' = chord2melody translation
 * enc = chord
 * dec = melody
* 'm2c' = melody2chord translation
 * enc = melody
 * dec = chord
* 'pos' = timestepped positional encoding. Gives the model a better reference when translating

Note: MultitaskTransformer trains both translations ('m2c' and 'c2m') at the same time. 

## 3. Initialize Model

In [None]:
# Load Data
# Reload both saved datasets with training-sized sequences.
# Note: this rebinds `s2s_data` (and `data` below) from section 2.
batch_size = 2
bptt = 128

lm_data = load_data(data_path, data_save_name,
                    bs=batch_size, bptt=bptt, encode_position=True,
                    dl_tfms=mask_lm_tfm_pitchdur)

# The seq2seq loader uses half the batch size. NOTE(review): presumably because each s2s
# batch carries two tasks ('c2m' and 'm2c'). Clamp to 1 so batch_size=1 does not yield bs=0.
s2s_data = load_data(data_path, s2s_data_save_name,
                     bs=max(1, batch_size // 2), bptt=bptt,
                     preloader_cls=S2SPreloader, dl_tfms=melody_chord_tfm)

# Combine both dataloaders so we can train multiple tasks at the same time
data = StackedDataBunch([lm_data, s2s_data])

In [None]:
# Create Model
config = multitask_config()

# Pass a copy so top-level changes the learner makes to its config do not leak back into `config`.
learn = multitask_model_learner(data, config.copy())
# learn.to_fp16(dynamic=True) # Enable for mixed precision

# Show the model hyperparameters (only a cell's last expression is displayed by Jupyter).
config

In [None]:
# Display the model architecture.
learn.model

## 4. Train

In [None]:
# Train with fastai's one-cycle learning-rate schedule.
n_epochs = 4
learn.fit_one_cycle(n_epochs)

In [None]:
# Save the trained weights under the name 'example'.
# NOTE(review): fastai presumably writes this to the learner's models directory. Confirm the location.
learn.save('example')

## 5. Predict

---
See [Generate.ipynb](Generate.ipynb) to use a pretrained model and generate better predictions

---

In [None]:
# Example MIDI file used as the seed for the prediction demos below.
midi_file = Path('data/midi/notebook_examples/single_bar_example.mid')
# Fail early with a clear message if the example file is missing.
assert midi_file.exists(), f'Example file not found: {midi_file}'
midi_file

In [None]:
# Next-word task: extend the seed from midi_file by 20 predicted tokens, then render the result.
next_word = nw_predict_from_midi(learn, midi_file, n_words=20, seed_len=8)
next_word.show()

In [None]:
# Seq2seq task with pred_melody=True (presumably chord -> melody), then render the result.
pred_melody = s2s_predict_from_midi(learn, midi_file, n_words=20, seed_len=4, pred_melody=True)
pred_melody.show()

In [None]:
# Mask task with predict_notes=True (presumably re-predicts masked notes), then render the result.
pred_notes = mask_predict_from_midi(learn, midi_file, predict_notes=True)
pred_notes.show()