In [1]:
#default_exp multitask_transformer.learner

In [2]:
#export
from fastai.basics import *
from musicautobot.vocab import *
from musicautobot.numpy_encode import SAMPLE_FREQ
from musicautobot.utils.top_k_top_p import top_k_top_p
from musicautobot.utils.midifile import is_empty_midi
from musicautobot.music_transformer.transform import *
from musicautobot.music_transformer.learner import filter_invalid_indexes
from musicautobot.multitask_transformer.model import MTTTransformer
from musicautobot.multitask_transformer.dataloader import *
from musicautobot.multitask_transformer.transform import *

In [3]:
from typing import Dict

In [11]:
from fastai.losses import CrossEntropyLossFlat

In [4]:
#export
class MultitaskLearner(Learner):
    "`Learner` with sampling-based prediction helpers for the next-word ('lm'), mask ('msk') and seq2seq tasks."

    def _sample_token(self, logits, prev_idx, vocab, temperatures, repeat_count,
                      top_k, top_p, banned_idxs=None):
        """Sample one token id from `logits`; shared by every `predict_*` method.

        `temperatures[0]` is used when `prev_idx` is a duration or pad token,
        `temperatures[1]` otherwise. `banned_idxs` (optional list of token ids)
        are filtered out before `filter_invalid_indexes` and `top_k_top_p` run.
        Returns `(idx, repeat_count)`: the sampled id and the updated counter.
        """
        # Temperature
        # Use first temperatures value if last prediction was duration
        temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
        # The penalty is 0 while `repeat_count` <= 3 and grows logarithmically afterwards
        repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
        temperature += repeat_penalty
        if temperature != 1.: logits = logits / temperature

        # Filter
        filter_value = -float('Inf')
        if banned_idxs: logits[banned_idxs] = filter_value
        logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
        logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

        # Sample
        probs = F.softmax(logits, dim=-1)
        idx = torch.multinomial(probs, 1).item()

        # Update repeat count: two or fewer candidates left counts as a "repeat",
        # anything more halves the counter
        num_choices = len(probs.nonzero().view(-1))
        if num_choices <= 2: repeat_count += 1
        else: repeat_count = repeat_count // 2

        return idx, repeat_count

    def predict_nw(self, item:MusicItem, n_words:int=128,
                     temperatures:tuple=(1.0,1.0), min_bars=4,
                     top_k=30, top_p=0.6):
        """Sample up to `n_words` tokens that continue `item`.

        `temperatures` is a pair (see `_sample_token`). BOS cannot be sampled
        until more than `min_bars` bars have been generated. Generation stops
        early when BOS is sampled, or when more than 80% of `n_words` has been
        used and the position lands on a multiple of 4 bars.
        Returns `(pred, full)`: the new tokens as a `MusicItem`, and `item` with
        `pred` appended.
        """
        self.model.reset()
        new_idx = []
        # `dls` (not `data`) is the attribute the callers in this file read the vocab from
        vocab = self.dls.vocab
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        y = torch.tensor([0])

        start_pos = last_pos
        repeat_count = 0

        for i in progress_bar(range(n_words), leave=True):
            # NOTE(review): `pred_batch` is not defined in this file; confirm the base
            # `Learner` in use provides it.
            batch = { 'lm': { 'x': x[None], 'pos': pos[None] } }, y
            logits = self.pred_batch(batch=batch)['lm'][-1][-1]

            prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx

            # bar = 16 beats
            too_early_for_bos = ((last_pos - start_pos) // 16) <= min_bars
            banned_idxs = [vocab.bos_idx] if too_early_for_bos else None

            idx, repeat_count = self._sample_token(logits, prev_idx, vocab, temperatures, repeat_count,
                                                   top_k, top_p, banned_idxs=banned_idxs)

            if prev_idx==vocab.sep_idx:
                # A token that follows the separator is a duration: advance the position by it
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration

                abs_bar = last_pos // 16
                # Stops before `idx` is appended, so this duration token is not part of the result
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break

            if idx==vocab.bos_idx:
                print('Predicted BOS token. Returning prediction...')
                break

            new_idx.append(idx)
            # Only the newest token is fed on the next step
            x = x.new_tensor([idx])
            pos = pos.new_tensor([last_pos])

        pred = MusicItem(np.array(new_idx), vocab)
        full = item.append(pred)
        return pred, full

    def predict_mask(self, masked_item:MusicItem,
                    temperatures:tuple=(1.0,1.0),
                    top_k=20, top_p=0.8):
        """Fill every `vocab.mask_idx` position of `masked_item`, in order, with a sampled token.

        BOS, the separator and EOS are never sampled. Each filled token is written
        back into the input before the next masked position is predicted.
        Returns a new `MusicItem` built from the filled sequence.
        """
        x = masked_item.to_tensor()
        pos = masked_item.get_pos_tensor()
        y = torch.tensor([0])
        # `dls` (not `data`) is the attribute the callers in this file read the vocab from
        vocab = self.dls.vocab
        self.model.reset()
        mask_idxs = (x == vocab.mask_idx).nonzero().view(-1)

        # Don't allow any special tokens (as we are only removing notes and durations)
        special_idxs = [vocab.bos_idx, vocab.sep_idx, vocab.stoi[EOS]]

        repeat_count = 0

        for midx in progress_bar(mask_idxs, leave=True):
            prev_idx = x[midx-1]

            # Using original positions, otherwise model gets too off track

            # Next Word
            # NOTE(review): `pred_batch` is not defined in this file; confirm the base
            # `Learner` in use provides it.
            logits = self.pred_batch(batch=({ 'msk': { 'x': x[None], 'pos': pos[None] } }, y) )['msk'][0][midx]

            idx, repeat_count = self._sample_token(logits, prev_idx, vocab, temperatures, repeat_count,
                                                   top_k, top_p, banned_idxs=special_idxs)

            x[midx] = idx

        return MusicItem(x.cpu().numpy(), vocab)

    def predict_s2s(self, input_item:MusicItem, target_item:MusicItem, n_words:int=256,
                        temperatures:tuple=(1.0,1.0), top_k=30, top_p=0.8,
                        use_memory=True):
        """Extend `target_item` token by token, conditioned on the encoded `input_item`.

        Stops after `n_words` tokens, when BOS or EOS is sampled, or when the
        target position passes the input's last position plus `SAMPLE_FREQ * 4`.
        With `use_memory=True` only the newest token is fed to the decoder each
        step; otherwise the model is reset and the whole target is fed again.
        Returns the extended target as a `MusicItem`.
        """
        # `dls` (not `data`) is the attribute the callers in this file read the vocab from
        vocab = self.dls.vocab

        # Input doesn't change. We can reuse the encoder output on each prediction
        with torch.no_grad():
            inp, inp_pos = input_item.to_tensor(), input_item.get_pos_tensor()
            x_enc = self.model.encoder(inp[None], inp_pos[None])

        # target
        targ = target_item.data.tolist()
        targ_pos = target_item.position.tolist()
        # Same empty-target fallback as `prev_idx` below
        last_pos = targ_pos[-1] if len(targ_pos) else 0
        self.model.reset()

        repeat_count = 0

        max_pos = input_item.position[-1] + SAMPLE_FREQ * 4 # Only predict until both tracks/parts have the same length
        x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)

        for i in progress_bar(range(n_words), leave=True):
            # Predict
            with torch.no_grad():
                dec = self.model.decoder(x[None], pos[None], x_enc)
                logits = self.model.head(dec)[-1, -1]

            prev_idx = targ[-1] if len(targ) else vocab.pad_idx
            idx, repeat_count = self._sample_token(logits, prev_idx, vocab, temperatures, repeat_count,
                                                   top_k, top_p)

            # Membership test instead of `a == x | a == y`: `|` binds tighter than `==`,
            # so the old expression compared against `(bos_idx | idx)` and could miss BOS.
            if idx in (vocab.bos_idx, vocab.stoi[EOS]):
                print('Predicting BOS/EOS')
                break

            if prev_idx == vocab.sep_idx:
                # A token that follows the separator is a duration: advance the position by it
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration
                if last_pos > max_pos:
                    print('Predicted past counter-part length. Returning early')
                    break

            targ_pos.append(last_pos)
            targ.append(idx)

            if use_memory:
                # Relying on memory for kv. Only need last prediction index
                x, pos = inp.new_tensor([targ[-1]]), inp_pos.new_tensor([targ_pos[-1]])
            else:
                # Reset memory after each prediction, since we feeding the whole sequence every time
                self.model.reset()
                x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)

        return MusicItem(np.array(targ), vocab)
    

In [5]:
#export
# High level prediction functions from midi file
def nw_predict_from_midi(learn, midi=None, n_words=400, 
                      temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    """Continue the music in `midi` with `learn.predict_nw` and return seed + prediction.

    An empty midi (per `is_empty_midi`) starts from `MusicItem.empty`. When
    `seed_len` is given the seed is trimmed with `trim_to_beat(seed_len)` first.
    Extra `kwargs` are forwarded to `predict_nw`.
    """
    vocab = learn.dls.vocab

    if is_empty_midi(midi):
        seed = MusicItem.empty(vocab)
    else:
        seed = MusicItem.from_file(midi, vocab)

    if seed_len is not None:
        seed = seed.trim_to_beat(seed_len)

    # `predict_nw` returns (new tokens only, seed + new tokens); only the latter is wanted here
    _, full = learn.predict_nw(seed, n_words=n_words, temperatures=temperatures,
                               top_k=top_k, top_p=top_p, **kwargs)
    return full

def s2s_predict_from_midi(learn, midi=None, n_words=200, 
                      temperatures=(1.0,1.0), top_k=24, top_p=0.7, seed_len=None, pred_melody=True, **kwargs):
    """Predict one part of a two-part midi from the other with `learn.predict_s2s`.

    With `pred_melody=True` the chords are the input and the melody is the
    target; otherwise the roles are swapped. Returns a `MultitrackItem` whose
    first argument is the melody part and second the chords part.
    """
    tracks = MultitrackItem.from_file(midi, learn.dls.vocab)
    melody, chords = tracks.melody, tracks.chords

    if pred_melody:
        inp, targ = chords, melody
    else:
        inp, targ = melody, chords

    # if seed_len is passed, cutoff sequence so we can predict the rest
    if seed_len is not None:
        targ = targ.trim_to_beat(seed_len)
    targ = targ.remove_eos()

    pred = learn.predict_s2s(inp, targ, n_words=n_words, temperatures=temperatures,
                             top_k=top_k, top_p=top_p, **kwargs)

    # The predicted part takes the slot of the part it replaced
    if pred_melody:
        return MultitrackItem(pred, inp)
    return MultitrackItem(inp, pred)

def mask_predict_from_midi(learn, midi=None, predict_notes=True,
                           temperatures=(1.0,1.0), top_k=30, top_p=0.7, section=None, **kwargs):
    """Mask part of `midi` and let `learn.predict_mask` fill it back in.

    `predict_notes=True` masks with `mask_pitch(section)`, otherwise with
    `mask_duration(section)`. Extra `kwargs` are forwarded to `predict_mask`.
    """
    item = MusicItem.from_file(midi, learn.dls.vocab)

    if predict_notes:
        masked_item = item.mask_pitch(section)
    else:
        masked_item = item.mask_duration(section)

    return learn.predict_mask(masked_item, temperatures=temperatures,
                              top_k=top_k, top_p=top_p, **kwargs)

# LOSS AND METRICS

class MultiLoss():
    "Loss mult - Mask, NextWord, Seq2Seq: sums one flattened cross-entropy loss per task present in the targets."
    def __init__(self, ignore_index=None):
        """Build the shared per-task loss.

        `ignore_index` is the target id to exclude from the loss; when it is
        `None` the underlying loss keeps its own default instead of receiving `None`.
        """
        # Imported here so the exported module does not rely on a notebook-only import cell.
        # `CrossEntropyFlat` is not defined or imported anywhere in this file.
        from fastai.losses import CrossEntropyLossFlat
        loss_kwargs = {} if ignore_index is None else {'ignore_index': ignore_index}
        self.loss = CrossEntropyLossFlat(**loss_kwargs)

    def __call__(self, inputs:Dict[str,Tensor], targets:Dict[str,Tensor]):
        "Return the sum over every key in `targets` of `loss(inputs[key], targets[key])`."
        # Iterate over `targets` (not `inputs`): extra model outputs without a target are ignored
        return sum(self.loss(inputs[key], target) for key,target in targets.items())
    
def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx):
    """Accuracy of `input.argmax(-1)` against `targ`, counting only positions where `targ != pad_idx`.

    Returns `None` when either `input` or `targ` is `None`.
    """
    if input is None or targ is None:
        return None

    bs = targ.shape[0]
    preds = input.argmax(dim=-1).view(bs, -1)
    labels = targ.view(bs, -1)

    # Padding positions are dropped from both sides before comparing
    keep = labels != pad_idx
    hits = preds[keep] == labels[keep]
    return hits.float().mean()

def acc_index(inputs, targets, key, pad_idx):
    "Pad-ignoring accuracy for the task stored under `key`; a missing key is passed on as `None`."
    pred, targ = inputs.get(key), targets.get(key)
    return acc_ignore_pad(pred, targ, pad_idx)
    
def mask_acc(inputs, targets, pad_idx):
    "Pad-ignoring accuracy for the 'msk' task."
    return acc_index(inputs, targets, 'msk', pad_idx)

def lm_acc(inputs, targets, pad_idx):
    "Pad-ignoring accuracy for the 'lm' task."
    return acc_index(inputs, targets, 'lm', pad_idx)

def c2m_acc(inputs, targets, pad_idx):
    "Pad-ignoring accuracy for the 'c2m' task."
    return acc_index(inputs, targets, 'c2m', pad_idx)

def m2c_acc(inputs, targets, pad_idx):
    "Pad-ignoring accuracy for the 'm2c' task."
    return acc_index(inputs, targets, 'm2c', pad_idx)



# Cell
class AvgMultiMetric(AvgMetric):
    "Batch-size-weighted average of `func`, skipping batches for which `func` returns `None`."
    def accumulate(self, learn):
        "Add the current batch's metric value to the running totals."
        score = self.func(learn.pred, *learn.yb)
        score = learn.to_detach(score)
        # A `None` score leaves both running totals untouched for this batch
        if score is None:
            return
        n = find_bs(learn.yb)
        self.total += score*n
        self.count += n

    
    
# class AvgMultiMetric(AvgMetric):
#     "Updated fastai.AverageMetric to support multi task metrics."
#     def on_batch_end(self, last_output, last_target, **kwargs):
#         "Update metric computation with `last_output` and `last_target`."
#         if not is_listy(last_target): last_target=[last_target]
#         val = self.func(last_output, *last_target)
#         if val is None: return
#         self.count += first_el(last_target).size(0)
#         if self.world:
#             val = val.clone()
#             dist.all_reduce(val, op=dist.ReduceOp.SUM)
#             val /= self.world
#         self.val += first_el(last_target).size(0) * val.detach().cpu()

#     def on_epoch_end(self, last_metrics, **kwargs):
#         "Set the final result in `last_metrics`."
#         if self.count == 0: return add_metrics(last_metrics, 0)
#         return add_metrics(last_metrics, self.val/self.count)


In [6]:
#export
# MODEL LOADING
class MTTrainer(Callback):
    """`Callback` that resets the model, sets the encoder's `mask_steps` and optionally rotates dataloaders each epoch.

    NOTE(review): the `on_epoch_begin`/`on_epoch_end` hook names, `super().__init__(learn)`
    and `get_model` are used as-is; confirm they match the callback API of the fastai
    version this module runs against.
    """
    def __init__(self, learn:Learner, dataloaders=None, starting_mask_window=1):
        super().__init__(learn)
        self.dataloaders = dataloaders          # optional sequence of dataloaders to cycle through
        self.mw_start = starting_mask_window    # offset added to the epoch counter for `mask_steps`
        self.count = 1                          # epoch counter, incremented in `on_epoch_end`

    def on_epoch_begin(self, **kwargs):
        "Reset the hidden state of the model and update the encoder's `mask_steps`."
        net = get_model(self.learn.model)
        net.reset()
        window = self.count + self.mw_start
        # NOTE(review): `max` keeps `mask_steps` at 100 or more; if 100 was meant
        # as an upper cap on a growing window this should be `min` - confirm.
        net.encoder.mask_steps = max(window, 100)

    def on_epoch_end(self, last_metrics, **kwargs):
        "Switch `learn.dls` to the next entry of `dataloaders` (when given) and advance the epoch counter."
        loaders = self.dataloaders
        if loaders is not None:
            self.learn.dls = loaders[self.count % len(loaders)]
        self.count += 1

In [7]:
# Root of the data directory (relative to this notebook)
base_path = Path('../../data/v20')

# Location of your midi files
midi_path = base_path / 'midi_sources' / 'hooktheory'

# Locations of the encoded datasets (seq2seq and language-model encodings)
s2s_path = base_path / 's2s_encode' / 'hooktheory'
lm_path = base_path / 'piano_duet' / 'hooktheory'

In [8]:
# Build the vocabulary and collect at most 400 `.npy` files (searched recursively) per dataset.
vocab = MusicVocab.create()
s2s_files = get_files(s2s_path, '.npy', recurse=True)[:400]
lm_files = get_files(lm_path, '.npy', recurse=True)[:400]

# Sanity check. The saved output of this cell is (0, 0): no files were found at these paths.
len(s2s_files), len(lm_files)

(0, 0)

In [9]:
from musicautobot.music_transformer.all import *
from fastai.text.all import *

In [10]:
# Build the seq2seq and language-model dataloaders from the file lists collected above.
# NOTE(review): the saved output of this cell is "TypeError: 'NoneType' object is not iterable";
# the file lists were empty in that run (see the (0, 0) output above), which may be the cause - confirm.
bs,seq_len = 16,512
# seq2seq pipeline: file -> S2SFileTfm -> mtt2tensor (given `seq_len`)
s2s_tfms = [S2SFileTfm(vocab), partial(mtt2tensor, seq_len=seq_len)]
# Fixed seed so the train/valid split is reproducible
splits = RandomSplitter(seed=42)(range(len(s2s_files)))
s2s_dsets = Datasets(s2s_files, [s2s_tfms], splits=splits)

s2s_dls = s2s_dsets.dataloaders(
    bs=bs, seq_len=seq_len,
    after_batch=[M2CTransform()]
)


# language-model pipeline: file -> MusicItemTfm -> rand_transpose -> mi2tensor
lm_tfms = [MusicItemTfm(vocab), rand_transpose, mi2tensor]
# `splits` is reassigned here for the lm file list (same seed)
splits = RandomSplitter(seed=42)(range(len(lm_files)))
lm_dsets = Datasets(lm_files, [lm_tfms], splits=splits, dl_type=LMDataLoader)

# Unlike `s2s_dls`, the lm dataloaders are moved to the GPU with `.cuda()`
lm_dls = lm_dsets.dataloaders(bs=bs, seq_len=seq_len, cache=bs*4,
                        after_batch=[MaskLMTransform(vocab)],
                       ).cuda()

TypeError: 'NoneType' object is not iterable

In [38]:
#export
def multitask_model_learner(data:DataLoaders, config:dict=None, drop_mult:float=1., 
                            pretrained_path:Path=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `MultitaskLearner` wrapping an `MTTTransformer` built for `data.vocab`.

    When `pretrained_path` is given, the file is read with `torch.load`; its
    `'config'` entry is used if `config` is `None`, and `learn.load(pretrained_path)`
    is called on the new learner. `learn_kwargs` are forwarded to `MultitaskLearner`.
    NOTE(review): the return annotation says 'LanguageLearner' but the object
    built here is a `MultitaskLearner`.
    """
    vocab = data.vocab
    pad_idx = vocab.pad_idx

    if pretrained_path:
        # The checkpoint is read here only to recover the saved config
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None:
            config = state['config']

    model = MTTTransformer.from_config(len(vocab), config=config, drop_mult=drop_mult, pad_idx=pad_idx)

    # One pad-ignoring accuracy metric per task
    metrics = []
    for acc_fn in (mask_acc, lm_acc, c2m_acc, m2c_acc):
        metrics.append(AvgMultiMetric(partial(acc_fn, pad_idx=pad_idx)))

    learn = MultitaskLearner(data, model,
                             loss_func=MultiLoss(ignore_index=pad_idx),
                             metrics=metrics, **learn_kwargs)

    if pretrained_path:
        learn.load(pretrained_path)

    return learn

In [39]:

# class FlattenedLoss():
#     "Same as `func`, but flattens input and target."
#     def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):
#         self.func,self.axis,self.floatify,self.is_2d = func(*args,**kwargs),axis,floatify,is_2d
#         functools.update_wrapper(self, self.func)

#     def __repr__(self): return f"FlattenedLoss of {self.func}"
#     @property
#     def reduction(self): return self.func.reduction
#     @reduction.setter
#     def reduction(self, v): self.func.reduction = v

#     @property
#     def weight(self): return self.func.weight
#     @weight.setter
#     def weight(self, v): self.func.weight = v

#     def __call__(self, input:Tensor, target:Tensor, **kwargs):
#         input = input.transpose(self.axis,-1).contiguous()
#         target = target.transpose(self.axis,-1).contiguous()
#         if self.floatify: target = target.float()
#         input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
#         return self.func.__call__(input, target.view(-1), **kwargs)

# def CrossEntropyFlat(*args, axis:int=-1, **kwargs):
#     "Same as `nn.CrossEntropyLoss`, but flattens input and target."
#     return FlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

In [40]:
import musicautobot.utils.fastai_transformer as ft

# Create Model: start from a copy of the default LM config and adapt it for the multitask model
config = ft.tfmer_lm_config.copy()
config.update(mem_len=512, enc_layers=4, dec_layers=4, bias=True)
# The encoder/decoder layer counts above replace the single `n_layers` entry
del config['n_layers']

learn = multitask_model_learner(lm_dls, config.copy())
# learn.to_fp16(dynamic=True) # Enable for mixed precision

In [41]:
from typing import Callable

In [262]:
# DataLoading
class StackedDLs():
    """Bundle several dataloader groups behind one `train`/`valid` pair of `StackedDL`s.

    `path`, `device` and `vocab` are taken from the first entry of `dls`.
    Any other attribute is forwarded to the wrapped groups (see `__getattr__`).
    """
    def __init__(self, dls, num_it=100):
        self.dls = dls
        self.train = StackedDL([d.train for d in dls], num_it)
        self.valid = StackedDL([d.valid for d in dls], num_it)
        self.path = dls[0].path
        self.device = dls[0].device
        self.vocab = dls[0].vocab
        self.empty_val = False
        self.n_inp = 1

    def __getitem__(self, i):
        "Index 0 is the stacked train loader, index 1 the stacked valid loader."
        return [self.train, self.valid][i]

    def __getattr__(self, attr):
        """Forward an unknown attribute, as a method call, to every wrapped group that has it.

        The returned callable invokes `attr` on each such group and returns the
        first group's result. Raises `AttributeError` if no group has `attr`.
        """
        # `__getattr__` also runs on instances whose `__dict__` is still empty (e.g. while
        # `copy`/`pickle` rebuild the object). Reading `self.dls` there would re-enter this
        # method forever, so look it up in `__dict__` directly. Dunder probes such as
        # `__setstate__` are protocol lookups and must not be forwarded either.
        wrapped = self.__dict__.get('dls')
        if wrapped is None or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(attr)

        def redirected(*args, **kwargs):
            vals = [getattr(dl, attr)(*args, **kwargs) for dl in self.dls if hasattr(dl, attr)]
            if vals: return vals[0]

        if any(hasattr(dl, attr) for dl in wrapped): return redirected
        raise AttributeError(attr)
    
class StackedDL():
    """Iterate over several dataloaders in turns of up to `num_it` batches each.

    Any attribute not defined here is forwarded to the wrapped loaders (see `__getattr__`).
    """
    def __init__(self, dls, num_it=100):
        self.dls = dls
        self.num_it = num_it
        self.dl_idx = -1     # index (into the still-active iterators) of the loader currently being read
        self.n_inp = 1

    def __len__(self)->int: return sum(len(dl) for dl in self.dls)

    def __getattr__(self, attr):
        """Forward an unknown attribute, as a method call, to every wrapped loader that has it.

        The returned callable invokes `attr` on each such loader and returns the
        first loader's result. Raises `AttributeError` if no loader has `attr`.
        """
        # `__getattr__` also runs on instances whose `__dict__` is still empty (e.g. while
        # `copy`/`pickle` rebuild the object). Reading `self.dls` there would re-enter this
        # method forever, so look it up in `__dict__` directly. Dunder probes such as
        # `__setstate__` are protocol lookups and must not be forwarded either.
        wrapped = self.__dict__.get('dls')
        if wrapped is None or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(attr)

        def redirected(*args, **kwargs):
            vals = [getattr(dl, attr)(*args, **kwargs) for dl in self.dls if hasattr(dl, attr)]
            if vals: return vals[0]

        if any(hasattr(dl, attr) for dl in wrapped): return redirected
        raise AttributeError(attr)

    def __iter__(self):
        "Yield batches from each loader in turn, `num_it` at a time, until all are exhausted."
        iters = [iter(dl) for dl in self.dls]
        self.dl_idx = -1
        # NOTE(review): with `num_it` < 1 and at least one loader, the inner loop never
        # runs and this `while` never ends - confirm callers always pass `num_it` >= 1.
        while len(iters):
            self.dl_idx = (self.dl_idx+1) % len(iters)
            for _ in range(self.num_it):
                try:
                    yield next(iters[self.dl_idx])
                except StopIteration:
                    # Drop the exhausted loader; the rest keep taking turns
                    del iters[self.dl_idx]
                    break
#         raise StopIteration

#     def new(self, **kwargs):
#         "Create a new copy of `self` with `kwargs` replacing current values."
#         new_dls = [dl.new(**kwargs) for dl in self.dls]
#         return StackedDataloader(new_dls, self.num_it)


In [263]:
??DataLoader

In [264]:
# Combine the seq2seq and language-model dataloaders into one stacked set (default `num_it`).
sdls = StackedDLs([s2s_dls, lm_dls])

In [265]:
# Train on the stacked dataloaders instead of the `lm_dls` the learner was created with.
learn.dls = sdls

In [266]:
# Grab one language-model training batch and split it into inputs / targets at `n_inp`.
b = lm_dls.train.one_batch()
i = lm_dls.train.n_inp
xb,yb = b[:i],b[i:]

In [267]:
# Same split for the stacked loaders. `one_batch` is not defined on `StackedDLs`, so it goes
# through `StackedDLs.__getattr__`: it is called on every wrapped group that has it and the
# first group's result is returned. `sdls.n_inp` is the constant 1 set in `StackedDLs.__init__`.
b = sdls.one_batch()
i = sdls.n_inp
xb,yb = b[:i],b[i:]

In [268]:
# Inspect the learner's `dl` attribute (no output was saved for this cell).
learn.dl

In [269]:
# Re-assign the stacked dataloaders (repeats the assignment made in cell In [265]).
learn.dls = sdls

In [270]:
# Smoke test: train for one epoch on the stacked dataloaders.
learn.fit(1)

epoch,train_loss,valid_loss,mask_acc,lm_acc,c2m_acc,m2c_acc,time
0,6.339056,6.43009,0.186382,0.230989,0.24602,0.118107,00:26


In [271]:
#hide
# Run nbdev's notebook-to-script export; the saved output lists each notebook as "Converted ...".
from nbdev.export import notebook2script
notebook2script(recursive=True)

Converted config.ipynb.
Converted Train-before_cleanup.ipynb.
Converted Train.ipynb.
Converted dataloader.ipynb.
Converted learner.ipynb.
Converted model.ipynb.
Converted transform.ipynb.
Converted Train.ipynb.
Converted dataloader.ipynb.
Converted learner.ipynb.
Converted model.ipynb.
Converted Train-Scratch.ipynb.
Converted dataloader-reference.ipynb.
Converted dataloader-v1.ipynb.
Converted transform.ipynb.
Converted numpy_encode.ipynb.
Converted attention_mask.ipynb.
Converted env_setup.ipynb.
Converted fastai_transformer.ipynb.
Converted file_processing.ipynb.
Converted lamb.ipynb.
Converted midifile.ipynb.
Converted stacked_dataloader.ipynb.
Converted top_k_top_p.ipynb.
Converted vocab.ipynb.
