# Attention Architecture

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import music21

In [3]:
# from fastai.text import *
from enum import Enum
import torch
from fastai.text.models.awd_lstm import *
from fastai.text.models.transformer import *

In [4]:
import numpy as np
import torch.nn as nn

In [5]:
np.set_printoptions(edgeitems=10, threshold=40, linewidth=200)

In [6]:
import sys
sys.path.insert(0, '../../')
from src.fastai_data import *
from src.encode_data import *
from src.serve import *

In [7]:
from src.lmnp_transformer import *

In [8]:
path = Path('../../data/midi/v14/piano_duet/')

## Single Stream Encoding

In [9]:
config = v14s_config(vocab); config

{'ctx_len': 150,
 'n_layers': 16,
 'n_heads': 8,
 'd_model': 256,
 'd_head': 32,
 'd_inner': 2048,
 'resid_p': 0.1,
 'attn_p': 0.1,
 'ff_p': 0.1,
 'embed_p': 0.1,
 'output_p': 0.1,
 'bias': False,
 'scale': True,
 'act': <Activation.GeLU: 3>,
 'double_drop': True,
 'tie_weights': True,
 'out_bias': True,
 'init': <function fastai.text.models.transformer.init_transformer(m)>,
 'mem_len': 512,
 'mask': True,
 'pad_idx': 1,
 'bos_idx': 0,
 'sep_idx': 8,
 'transpose_range': (0, 12),
 'note_range': (9, 138),
 'bs': 16,
 'bptt': 256,
 'vocab_size': 274}

In [10]:
config['mem_len'] = 0

## Fastai Learner

In [11]:
# def lm_tfm()
# def load_music_data_2(path, cache_name, vocab, transpose_range=(0,1), **kwargs):
#     transpose_tfm = partial(rand_transpose, note_range=vocab.note_range, rand_range=transpose_range)
#     single_tfm = partial(to_single_stream, vocab=vocab)
#     data = MusicDataBunch.load(path=path, cache_name=cache_name, **kwargs, 
#                               train_tfms=[single_tfm, transpose_tfm], valid_tfms=[single_tfm])
#     return data

In [12]:
data = load_music_data(path, cache_name='tmp/sample', vocab=vocab, y_offset=0, **config)

DLTFMS: None


In [13]:
xb,yb = next(iter(data.train_dl))

In [14]:
xb,yb = data.one_batch(cpu=False)

In [15]:
xb

tensor([[ 84, 143,  60,  ..., 140,  72, 140],
        [ 90, 143,  66,  ..., 140,  68, 140],
        [ 86, 143,  62,  ..., 140,  70, 140],
        ...,
        [ 86, 143,  62,  ..., 140,  72, 140],
        [ 86, 143,  62,  ..., 140,  73, 140],
        [ 84, 143,  60,  ..., 140,  72, 140]], device='cuda:0')

In [16]:
yb

tensor([[ 84, 143,  60,  ..., 140,  72, 140],
        [ 90, 143,  66,  ..., 140,  68, 140],
        [ 86, 143,  62,  ..., 140,  70, 140],
        ...,
        [ 86, 143,  62,  ..., 140,  72, 140],
        [ 86, 143,  62,  ..., 140,  73, 140],
        [ 84, 143,  60,  ..., 140,  72, 140]], device='cuda:0')

## LMNP

In [17]:
m_len = 0
x_len = 16 # bptt
seq_len = m_len+x_len
torch.triu(torch.ones(x_len, seq_len), diagonal=m_len).byte()[None,None].cpu().numpy()

array([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]]], dtype=uint8)

In [18]:
torch.triu(torch.ones(x_len, seq_len), diagonal=m_len+1).byte()[None,None].cpu().numpy()

array([[[[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]], dtype=uint8)

In [20]:
class MusicAttention(MultiHeadRelativeAttention):
    def __init__(self, n_heads:int, d_model:int, d_head:int, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True, residual:bool=True):
        super().__init__(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        self.residual = residual
        
    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        if self.residual: return self.ln(x + self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs))))
        return self.ln(self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs))))
#         return self.ln(x + self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs))))
    
class MusicTransformerXL(nn.Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MusicAttention,
                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0, **kwargs):
        super().__init__()
        self.encoder = nn.Embedding(vocab_sz, d_model)
        self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask
        self.init = False
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, 
                      attn_cls=attn_cls) for k in range(n_layers)])
        
        self.layers[0].mhra.residual = False
    
    def reset(self):
        "Reset the internal memory."
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def _update_mems(self, hids):
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()
    
    def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden]
    
    def forward(self, x):
        #The hidden state has to be initiliazed in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init: 
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.encoder(x)) #.mul_(self.d_model ** 0.5)
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
#         mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=m_len+1).byte()[None,None] if self.mask else None
        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=m_len).byte()[None,None] if self.mask else None
        if m_len == 0: mask[0,0] = 0
#         mask = None
        #[None,:,:None] for einsum implementation of attention
        hids = []
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0 : self._update_mems(hids)
        return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]


In [21]:
# learn.model[0].layers[0].mhra.residual

In [22]:
from fastai.text.learner import _model_meta

In [23]:
_model_meta[MusicTransformerXL] = _model_meta[TransformerXL]
_model_meta[MusicTransformerXL]['config_lm'] = config

In [24]:
config

{'ctx_len': 150,
 'n_layers': 16,
 'n_heads': 8,
 'd_model': 256,
 'd_head': 32,
 'd_inner': 2048,
 'resid_p': 0.1,
 'attn_p': 0.1,
 'ff_p': 0.1,
 'embed_p': 0.1,
 'output_p': 0.1,
 'bias': False,
 'scale': True,
 'act': <Activation.GeLU: 3>,
 'double_drop': True,
 'tie_weights': True,
 'out_bias': True,
 'init': <function fastai.text.models.transformer.init_transformer(m)>,
 'mem_len': 0,
 'mask': True,
 'pad_idx': 1,
 'bos_idx': 0,
 'sep_idx': 8,
 'transpose_range': (0, 12),
 'note_range': (9, 138),
 'bs': 16,
 'bptt': 256,
 'vocab_size': 274}

In [25]:
def get_music_model(arch:Callable, vocab_sz:int, config:dict=None, drop_mult:float=1.):
    "Create a language model from `arch` and its `config`, maybe `pretrained`."
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_lm'].copy())
    for k in config.keys(): 
        if k.endswith('_p'): config[k] *= drop_mult
#     tie_weights,output_p,out_bias = map(config.pop, ['tie_weights', 'output_p', 'out_bias'])
    tie_weights,output_p,out_bias = map(config.get, ['tie_weights', 'output_p', 'out_bias'])
    init = config.pop('init') if 'init' in config else None
    encoder = arch(vocab_sz, **config)
    enc = encoder.encoder if tie_weights else None
    decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=enc, bias=out_bias)
    model = SequentialRNN(encoder, decoder)
    return model if init is None else model.apply(init)


def music_model_learner(arch:Callable, data:DataBunch, config:dict=None, drop_mult:float=1., pretrained:bool=False,
                        pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':
    "Create a `Learner` with a language model from `data` and `arch`."
    model = get_music_model(arch, config['vocab_size'], config=config, drop_mult=drop_mult)
    
    meta = _model_meta[arch]
    learn = MusicLearner(data, model, split_func=meta['split_lm'], 
                         bos_idx=config['bos_idx'], sep_idx=config['sep_idx'],
                        **learn_kwargs)
    
    if pretrained:
        if 'url' not in meta: 
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta['url'], data=False)
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        learn.load_pretrained(*fnames)
        learn.freeze()
    if pretrained_fnames is not None:
        fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
        learn.load_pretrained(*fnames)
        learn.freeze()
    return learn

## Load

In [26]:
learn = music_model_learner(MusicTransformerXL, data, config, drop_mult=0)

Sep_idx: 8


In [27]:
xb,yb = data.one_batch(cpu=False)

In [28]:
xb.shape, yb.shape

(torch.Size([16, 256]), torch.Size([16, 256]))

In [29]:
learn.loss_func

FlattenedLoss of CrossEntropyLoss()

In [30]:
learn.validate()

[5.6648936, tensor(0.)]

In [31]:
# learn.lr_find()
# learn.recorder.plot()

In [32]:
# BERT - y_offset = 0, residual first, mask = no self 
learn.fit_one_cycle(3, 1e-4)

epoch,train_loss,valid_loss,accuracy,time
0,1.952899,1.515398,0.706836,01:15
1,0.258453,0.219606,0.99375,01:15
2,0.198986,0.179661,0.995068,01:15


In [None]:
# y_offset = 1, residual only, mask = self
learn.fit_one_cycle(3, 1e-4)

In [30]:
# y_offset = 1, disable residual for first layer only, mask = no self
learn.fit_one_cycle(3, 1e-4)

epoch,train_loss,valid_loss,accuracy,time
0,3.072333,3.064499,0.268164,01:58
1,1.899065,1.958865,0.48667,01:58
2,1.832989,1.840047,0.512012,01:58


In [30]:
# disable residual for first layer only
learn.fit_one_cycle(3, 1e-4)

epoch,train_loss,valid_loss,accuracy,time
0,2.668824,2.490198,0.426465,01:57
1,0.657023,0.509421,0.936182,01:57
2,0.399431,0.330187,0.955176,01:57


In [30]:
# bert - lm mask
learn.fit_one_cycle(3, 1e-4)

epoch,train_loss,valid_loss,accuracy,time
0,3.304801,3.487209,0.062207,01:00
1,3.249355,3.447737,0.062207,01:00
2,3.252849,3.417712,0.062207,01:00


In [29]:
# no mask
learn.fit_one_cycle(3, 1e-4)

epoch,train_loss,valid_loss,accuracy,time
0,2.02257,1.841922,0.615723,01:53
1,1.089564,0.795155,0.92417,01:53
2,0.343327,0.300019,0.975391,01:53


In [29]:
# normal mask
learn.fit_one_cycle(1, 1e-4)

epoch,train_loss,valid_loss,accuracy,time
0,2.361891,2.432884,0.429639,01:58
