In [3]:
#default_exp multitask_transformer.dataloader

In [4]:
#export
from fastai.basics import *

from musicautobot.multitask_transformer.transform import *
from musicautobot.music_transformer.dataloader import *

In [7]:
from musicautobot.vocab import *

base_path = Path('../../data/v20')

# Location of your midi files
midi_path = base_path/'midi_sources/hooktheory'

# Location to save dataset
s2s_path = base_path/'s2s_encode/hooktheory'
lm_path = base_path/'piano_duet/hooktheory'

# Number of files taken from each dataset (small subset for quick experiments).
N_FILES = 400

vocab = MusicVocab.create()
# get_files is not guaranteed to return files in a sorted order (it follows the
# directory walk), so sort before truncating to pick the same subset on every run.
s2s_files = sorted(get_files(s2s_path, '.npy', recurse=True))[:N_FILES]
lm_files = sorted(get_files(lm_path, '.npy', recurse=True))[:N_FILES]

len(s2s_files), len(lm_files)

(400, 400)

### 2a. Create NextWord/Mask Dataset

In [10]:
#export

# DATALOADING AND TRANSFORMATIONS
# These transforms happen on batch

def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
    """BERT-style masking of a token batch `b = (x, y)`.

    Only tokens `t` with `mask_range[0] <= t < mask_range[1]` are candidates.
    Each candidate is selected with probability `p`. Of the selected tokens,
    80% are replaced by `mask_idx` in `x`, 10% by a random token drawn from
    `mask_range`, and 10% are left unchanged.

    In `y`, every position that was not selected (including all out-of-range
    tokens) is set to `pad_idx` so it can be removed from loss/acc metrics.

    Returns new `(x, y)` tensors; the tensors in `b` are not modified.
    """
    x, y = b
    x, y = x.clone(), y.clone()
    rand = torch.rand(x.shape, device=x.device)

    # Explicit candidate mask. The previous approach forced rand=1.0 on
    # out-of-range tokens, which only excluded them while p < 1. For p >= 1
    # they became loss targets and could even be replaced in x.
    in_range = (x >= mask_range[0]) & (x < mask_range[1])
    selected = in_range & (rand <= p)

    # Unselected positions are padded out of the targets.
    y[~selected] = pad_idx
    # 80% of selected -> mask token.
    x[selected & (rand <= p*.8)] = mask_idx
    # 10% of selected -> random in-range token (the remaining 10% stay unchanged).
    wrong_word = selected & (rand > p*.8) & (rand <= p*.9)
    x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
    return x, y

def mask_lm_tfm_default(b, vocab, mask_p=0.3):
    "Run `mask_lm_tfm` over `vocab.npenc_range`, using the vocab's mask/pad indices."
    special = dict(mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx)
    return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_p=mask_p, **special)

def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
    """Run `mask_lm_tfm` restricted to one token family per call.

    With probability 0.5 only `vocab.dur_range` is masked, otherwise only
    `vocab.note_range`.
    """
    if np.random.rand() < 0.5:
        mask_range = vocab.dur_range
    else:
        mask_range = vocab.note_range
    return mask_lm_tfm(b, mask_range=mask_range, mask_p=mask_p,
                       mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx)

def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
    """Build masked-LM ('msk') and next-word ('lm') inputs/targets from one batch.

    `b` is an `(x, y)` pair whose last axis holds `[token, position]`.
    Returns `(x_dict, y_dict)`:
      * 'msk': `mask_tfm` applied to the target tokens `y[...,0]`
      * 'lm' : input tokens `x[...,0]`, target tokens `y[...,0]`
    Both tasks receive the target positions `y[...,1]`.
    """
    x, y = b
    # Input positions are indexed but currently unused.
    tokens_in, _pos_in = x[..., 0], x[..., 1]
    tokens_out, pos_out = y[..., 0], y[..., 1]

    # Masking the target tokens rather than the inputs, just in case we ever
    # do sequential s2s training.
    masked_in, masked_tgt = mask_tfm((tokens_out, tokens_out), mask_range=mask_range,
                                     mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)

    # NOTE(review): 'lm' is fed pos_out (positions aligned with the targets),
    # not x[...,1]; confirm this offset is intended.
    inputs = {
        'msk': {'x': masked_in, 'pos': pos_out},
        'lm': {'x': tokens_in, 'pos': pos_out},
    }
    targets = {'msk': masked_tgt, 'lm': tokens_out}
    return inputs, targets


class MaskLMTransform(ItemTransform):
    """Batch transform producing masked-LM and next-word inputs/targets.

    Masks tokens within `vocab.npenc_range` with probability `mask_p`
    (default 0.5), using `vocab.mask_idx` and `vocab.pad_idx`.
    """
    def __init__(self, vocab, mask_p=0.5):
        self.vocab, self.mask_p = vocab, mask_p

    def encodes(self, b):
        v = self.vocab
        return mask_lm_tfm(b, mask_range=v.npenc_range, mask_idx=v.mask_idx,
                           pad_idx=v.pad_idx, mask_p=self.mask_p)

In [14]:
from fastai.text.all import LMDataLoader

In [15]:
# Per-file pipeline for the LM dataset: MusicItem -> rand_transpose -> tensor.
tfms = [MusicItemTfm(vocab), rand_transpose, mi2tensor]
n_lm_files = len(lm_files)
splits = RandomSplitter(seed=42)(range(n_lm_files))
dsets = Datasets(lm_files, [tfms], splits=splits)

In [16]:
bs, seq_len = 16, 512

# Masking runs on whole batches (after_batch), with MaskLMTransform's default mask_p.
batch_tfms = [MaskLMTransform(vocab)]
dls = (dsets
       .dataloaders(dl_type=LMDataLoader, bs=bs, seq_len=seq_len,
                    cache=bs*4, after_batch=batch_tfms)
       .cuda())

In [134]:
xb, yb = dls.one_batch()
# The masked-LM input should hold one row of seq_len tokens per sample.
assert xb['msk']['x'].shape == (bs, seq_len), xb['msk']['x'].shape
xb['msk']['x'].shape

torch.Size([16, 512])

In [28]:
#export

# Sequence 2 Sequence Translate

class S2SFileTfm(Transform):
    """Load a saved melody/chord `.npy` pair and wrap it as a `MultitrackItem`.

    NOTE(review): `allow_pickle=True` lets `np.load` unpickle arbitrary objects;
    only point this at trusted dataset files.
    """
    def __init__(self, vocab):
        self.vocab = vocab

    def encodes(self, item):
        melody_enc, chord_enc = np.load(item, allow_pickle=True)
        return MultitrackItem.from_npenc_parts(melody_enc, chord_enc, vocab=self.vocab)

class Midi2MultitrackTfm(Transform):
    """Converts midi files to multitrack items.

    Any exception raised while converting a file is reported as a warning that
    names the file, and `None` is returned instead of raising. This lets one bad
    file not abort a whole run; callers must filter out `None` results.
    """
    def __init__(self, vocab):
        self.vocab = vocab

    def encodes(self, midi_file):
        import warnings
        try:
            return MultitrackItem.from_file(midi_file, vocab=self.vocab)
        except Exception as e:
            # Include the path: the bare exception message does not say which file failed.
            warnings.warn(f'Midi2MultitrackTfm: could not convert {midi_file}: {e!r}')
            return None

def mtt2tensor(mtt, seq_len):
    """Pad a multitrack item to `seq_len + 1` and flatten its index encoding.

    The extra step leaves room for the one-token input/target shift done later.
    Returns `(melody_idx, melody_pos, chord_idx, chord_pos)`.
    """
    melody, chords = mtt.pad_to(seq_len + 1).to_idx()
    mel_idx, mel_pos = melody
    chd_idx, chd_pos = chords
    return mel_idx, mel_pos, chd_idx, chd_pos

def melody_chord_tfm(b):
    """Build chord->melody ('c2m') and melody->chord ('m2c') translation batches.

    `b` is `(melody, melody_pos, chord, chord_pos)`, each indexed `[batch, step]`.
    Inputs drop the last step and targets drop the first (next-token prediction).
    The encoder side of one task is the decoder side of the other.
    """
    melody, melody_pos, chord, chord_pos = b

    mel_in, mel_tgt, mel_pos_in = melody[:, :-1], melody[:, 1:], melody_pos[:, :-1]
    chd_in, chd_tgt, chd_pos_in = chord[:, :-1], chord[:, 1:], chord_pos[:, :-1]

    def _pair(enc, enc_pos, dec, dec_pos):
        # Encoder/decoder input dict in the layout the model expects.
        return dict(enc=enc, enc_pos=enc_pos, dec=dec, dec_pos=dec_pos)

    x_dict = {
        'c2m': _pair(chd_in, chd_pos_in, mel_in, mel_pos_in),
        'm2c': _pair(mel_in, mel_pos_in, chd_in, chd_pos_in),
    }
    y_dict = {'c2m': mel_tgt, 'm2c': chd_tgt}
    return x_dict, y_dict

class M2CTransform(ItemTransform):
    "Batch transform: feed the collated `(melody, melody_pos, chord, chord_pos)` batch to `melody_chord_tfm`."
    def encodes(self, b):
        # The 4-tuple sits at index 0 of the batch (presumably because the
        # Datasets has a single pipeline, so each item is a 1-tuple); any
        # further elements are ignored.
        parts = b[0]
        return melody_chord_tfm(parts)

In [29]:
bs, seq_len = 2, 8
# Per-file pipeline: load the melody/chord .npy pair, then pad to seq_len+1 and index.
tfms = [S2SFileTfm(vocab), partial(mtt2tensor, seq_len=seq_len)]
s2s_splitter = RandomSplitter(seed=42)
splits = s2s_splitter(range(len(s2s_files)))
dsets = Datasets(s2s_files, [tfms], splits=splits)

In [30]:
batch_tfms = [M2CTransform()]
# NOTE(review): no dl_type is given, so confirm the default DataLoader actually
# uses `seq_len`; sequence length is already fixed by mtt2tensor(seq_len=...) above.
s2s_dl_kwargs = dict(bs=bs, seq_len=seq_len, after_batch=batch_tfms)
dls = dsets.dataloaders(**s2s_dl_kwargs)


In [31]:
xb, yb = dls.one_batch()
# Every encoder/decoder tensor should be (bs, seq_len): items are padded to
# seq_len+1 and melody_chord_tfm drops one step for the next-token shift.
s2s_shapes = {(task, key): tuple(t.shape) for task, d in xb.items() for key, t in d.items()}
assert all(s == (bs, seq_len) for s in s2s_shapes.values()), s2s_shapes
xb

{'c2m': {'enc': tensor([[  5,   1,   8, 169,  59, 145,  56, 145],
          [  5,   1,  69, 145,  66, 145,  62, 145]], device='cuda:0'),
  'enc_pos': tensor([[ 0,  0,  0,  0, 32, 32, 32, 32],
          [ 0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0'),
  'dec': tensor([[  6,   1,  76, 139,   8, 139,  76, 139],
          [  6,   1,  74, 139,   8, 139,  78, 139]], device='cuda:0'),
  'dec_pos': tensor([[0, 0, 0, 0, 0, 0, 2, 2],
          [0, 0, 0, 0, 0, 0, 2, 2]], device='cuda:0')},
 'm2c': {'enc': tensor([[  6,   1,  76, 139,   8, 139,  76, 139],
          [  6,   1,  74, 139,   8, 139,  78, 139]], device='cuda:0'),
  'enc_pos': tensor([[0, 0, 0, 0, 0, 0, 2, 2],
          [0, 0, 0, 0, 0, 0, 2, 2]], device='cuda:0'),
  'dec': tensor([[  5,   1,   8, 169,  59, 145,  56, 145],
          [  5,   1,  69, 145,  66, 145,  62, 145]], device='cuda:0'),
  'dec_pos': tensor([[ 0,  0,  0,  0, 32, 32, 32, 32],
          [ 0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')}}

In [33]:
#hide
# Regenerate the library modules from the `#export` cells.
# recursive=True converts every notebook found under the project, not only this
# one, as the "Converted ..." output of this cell shows.
from nbdev.export import notebook2script
notebook2script(recursive=True)

Converted config.ipynb.
Converted Train-before_cleanup.ipynb.
Converted Train.ipynb.
Converted dataloader.ipynb.
Converted learner.ipynb.
Converted model.ipynb.
Converted transform.ipynb.
Converted Train.ipynb.
Converted dataloader.ipynb.
Converted learner.ipynb.
Converted model.ipynb.
Converted Train-Scratch.ipynb.
Converted dataloader-reference.ipynb.
Converted dataloader-v1.ipynb.
Converted transform.ipynb.
Converted numpy_encode.ipynb.
Converted attention_mask.ipynb.
Converted env_setup.ipynb.
Converted fastai_transformer.ipynb.
Converted file_processing.ipynb.
Converted lamb.ipynb.
Converted midifile.ipynb.
Converted stacked_dataloader.ipynb.
Converted top_k_top_p.ipynb.
Converted vocab.ipynb.
