In [1]:
# (Re)load IPython's autoreload extension and select mode 2 so that edits to
# imported modules are picked up automatically before each cell is executed.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os

# The cells below import from the `src` package and use `data/...` relative
# paths, both of which are expected to resolve from the project root two levels
# above this notebook. Only change directory while `src` is not yet visible:
# an unconditional relative chdir is not idempotent, so re-running this cell
# would otherwise keep climbing the directory tree and break later cells.
if not os.path.isdir('src'):
    os.chdir('../../')

In [3]:
from src.numpy_encode import *
from src.utils.file_processing import process_all, process_file
from src.config import *
from src.music_transformer import *
from src.multitask_transformer import *
from src.utils.stacked_dataloader import StackedDataBunch

In [4]:
from fastai.text import *

## MultitaskTransformer Training

Multitask Training is an extension of [MusicTransformer](../music_transformer/Train.ipynb).

Instead of a basic language model that predicts the next word...

We train on multiple tasks
* [Next Word](../music_transformer/Train.ipynb)
* [Bert Mask](https://arxiv.org/abs/1810.04805)
* [Sequence to Sequence Translation](http://jalammar.github.io/illustrated-transformer/)

This gives a more generalized model and also lets you do some really cool [predictions](Generate.ipynb)

## End to end training pipeline 

1. Create and encode dataset
2. Initialize Transformer Model
3. Train
4. Predict

In [5]:
# Directory containing the source MIDI files
midi_path = Path('data/midi/examples')
# Directory where the encoded dataset is written
data_path = Path('data/numpy')

# Create both directories up front (no error if they already exist)
midi_path.mkdir(parents=True, exist_ok=True)
data_path.mkdir(parents=True, exist_ok=True)

# File names used when saving / reloading the two datasets
data_save_name = 'musicitem_data_save.pkl'
s2s_data_save_name = 'multiitem_data_save.pkl'

## 1. Gather midi dataset

Make sure all your midi data is in the `musicautobot/data/midi` directory

Here's a pretty good dataset with lots of midi data:  
https://www.reddit.com/r/datasets/comments/3akhxy/the_largest_midi_collection_on_the_internet/

Download the folder and unzip it to `data/midi`

## 2. Create dataset from MIDI files

In [6]:
# Collect the '.mid' files under midi_path (searching sub-folders too),
# then show how many were found.
midi_files = get_files(midi_path, '.mid', recurse=True)
len(midi_files)

18

### 2a. Create NextWord/Mask Dataset

In [30]:
# Run the MIDI files through Midi2ItemProcessor and build a DataBunch whose
# batches go through mask_lm_tfm. bptt and bs are very small here, presumably
# so the sample batch shown afterwards stays readable.
processors = [Midi2ItemProcessor()]
data = MusicDataBunch.from_files(
    midi_files,
    data_path,
    processors=processors,
    encode_position=True,
    dl_tfms=mask_lm_tfm,
    bptt=5,
    bs=2,
)

# Persist the encoded dataset so it can be reloaded with load_data later
data.save(data_save_name)

In [33]:
# Pull a single batch and display its input half
xb, yb = data.one_batch()
xb

{'msk': {'x': tensor([[  4, 145,  61, 145,   4],
          [ 64, 145,  61, 145,   4]]), 'pos': tensor([[8, 8, 8, 8, 8],
          [8, 8, 8, 8, 8]])}, 'lm': {'x': tensor([[139,  64, 145,  61, 145],
          [139,  64, 145,  61, 145]]), 'pos': tensor([[8, 8, 8, 8, 8],
          [8, 8, 8, 8, 8]])}}

Key:
* 'msk' = masked input
* 'lm' = next word input
* 'pos' = timestepped positional encoding. This is in addition to relative positional encoding

Note: MultitaskTransformer trains on both the masked input ('msk') and next word input ('lm') at the same time.

The encoder is trained on the 'msk' data, while the decoder is trained on 'lm' data.



### 2b. Create sequence to sequence dataset

In [34]:
# Run the MIDI files through Midi2MultitrackProcessor and build the
# sequence-to-sequence DataBunch; batches go through melody_chord_tfm.
processors = [Midi2MultitrackProcessor()]
s2s_data = MusicDataBunch.from_files(
    midi_files,
    data_path,
    processors=processors,
    preloader_cls=S2SPreloader,
    list_cls=S2SItemList,
    dl_tfms=melody_chord_tfm,
    bptt=5,
    bs=2,
)

# Persist the encoded dataset so it can be reloaded with load_data later
s2s_data.save(s2s_data_save_name)

Structure

In [35]:
# Pull a single sequence-to-sequence batch and display its input half
xb, yb = s2s_data.one_batch()
xb

{'c2m': {'enc': tensor([[  5,   1,  61, 145,  59],
          [  5,   1,  57, 153,  53]]), 'enc_pos': tensor([[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]), 'dec': tensor([[  6,   1,  85, 143,   8],
          [  6,   1,  77, 139,   8]]), 'dec_pos': tensor([[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]])}, 'm2c': {'enc': tensor([[  6,   1,  85, 143,   8],
          [  6,   1,  77, 139,   8]]), 'enc_pos': tensor([[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]), 'dec': tensor([[  5,   1,  61, 145,  59],
          [  5,   1,  57, 153,  53]]), 'dec_pos': tensor([[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]])}}

Key:
* 'c2m' = chord2melody translation
 * enc = chord
 * dec = melody
* 'm2c' = melody2chord translation
 * enc = melody
 * dec = chord
* 'pos' = timestepped positional encoding. Gives the model a better reference when translating

Note: MultitaskTransformer trains both translations ('m2c' and 'c2m') at the same time. 

## 3. Initialize Model

In [36]:
# Batch size and sequence length used for training
batch_size = 2
bptt = 128

# Reload the saved next-word / mask dataset
lm_data = load_data(
    data_path,
    data_save_name,
    bs=batch_size,
    bptt=bptt,
    encode_position=True,
    dl_tfms=mask_lm_tfm,
)

# Reload the saved sequence-to-sequence dataset, at half the batch size
s2s_data = load_data(
    data_path,
    s2s_data_save_name,
    bs=batch_size // 2,
    bptt=bptt,
    preloader_cls=S2SPreloader,
    dl_tfms=melody_chord_tfm,
)

# Stack the two DataBunches so every task is trained in the same run
data = StackedDataBunch([lm_data, s2s_data])

In [17]:
# Create Model
# Build the multitask configuration. The former trailing `; config` was removed:
# it was not the last expression of the cell, so IPython never displayed it and
# it had no effect.
config = multitask_config()

# Build the learner from the stacked DataBunch. A copy of the config is passed,
# presumably so that the original `config` object stays unmodified.
learn = multitask_model_learner(data, config.copy())

In [None]:
# Display the model held by the learner (its repr becomes the cell output).
learn.model

## 4. Train

In [None]:
# Train with the learner's one-cycle schedule; the argument 4 is presumably the
# number of epochs — confirm against the fastai version in use.
learn.fit_one_cycle(4)

In [None]:
# Save the learner's state under the name 'example'. The on-disk location is
# chosen by the learner (presumably its model directory — verify).
learn.save('example')

## Predict

---
See [Generate.ipynb](Generate.ipynb) to use a pretrained model and generate better predictions

---

In [None]:
# Seed MIDI file for the prediction cells that follow.
# (Any entry of get_files(midi_path, '.mid', recurse=True) could be used instead.)
midi_file = Path('data/single_bar_example.mid')
midi_file

In [None]:
# Next-word prediction seeded from the MIDI file chosen above.
# Pass the `midi_file` path variable, not the string literal 'midi_file',
# which would be treated as a (non-existent) file name.
next_word = nw_predict_from_midi(learn, midi_file, n_words=20, seed_len=8)
next_word.show()

In [None]:
# Sequence-to-sequence prediction seeded from midi_file, with pred_melody=True
pred_melody = s2s_predict_from_midi(
    learn, midi_file, n_words=20, seed_len=4, pred_melody=True
)
pred_melody.show()

In [None]:
# Mask-based prediction seeded from midi_file, with predict_notes=True
pred_notes = mask_predict_from_midi(
    learn, midi_file, n_words=20, predict_notes=True
)
pred_notes.show()