# Attention Architecture

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import music21

In [3]:
# from fastai.text import *
from enum import Enum
import torch
from fastai.text.models.awd_lstm import *
from fastai.text.models.transformer import *

In [4]:
import numpy as np
import torch.nn as nn

In [5]:
np.set_printoptions(edgeitems=10, threshold=40, linewidth=200)

In [6]:
import sys
sys.path.insert(0, '../../')
from src.fastai_data import *
from src.encode_data import *
from src.serve import *

In [7]:
from src.lmnp_transformer import *

In [8]:
path = Path('../../data/midi/v15/piano_duet/')

## Single Stream Encoding

In [9]:
config = v15s_config(vocab); config

{'ctx_len': 150,
 'n_layers': 16,
 'n_heads': 8,
 'd_model': 256,
 'd_head': 32,
 'd_inner': 2048,
 'resid_p': 0.1,
 'attn_p': 0.1,
 'ff_p': 0.1,
 'embed_p': 0.1,
 'output_p': 0.1,
 'bias': False,
 'scale': True,
 'act': <Activation.GeLU: 3>,
 'double_drop': True,
 'tie_weights': True,
 'out_bias': True,
 'init': <function fastai.text.models.transformer.init_transformer(m)>,
 'mem_len': 512,
 'mask': True,
 'pad_idx': 1,
 'bos_idx': 0,
 'sep_idx': 8,
 'transpose_range': (0, 12),
 'note_range': (9, 138),
 'bs': 16,
 'bptt': 256,
 'vocab_size': 274}

## Fastai Learner

In [10]:
dl_tfms = [mask_tfm, next_sentence_tfm]

In [11]:
data = load_music_data(path, cache_name='tmp/sample', vocab=vocab, y_offset=0, dl_tfms=dl_tfms, **config)

DLTFMS: [<function mask_tfm at 0x7f9916b08d90>, <function next_sentence_tfm at 0x7f9916b08d08>]


In [12]:
xb,yb = data.one_batch(cpu=False)

In [13]:
xb

tensor([[155,  68, 155,  ...,   8,   4,  52],
        [155,   4, 155,  ...,   4, 147,  52],
        [  4,  68, 155,  ...,   4, 147,  55],
        ...,
        [155,  72, 155,  ...,   4, 147,  57],
        [155,  68, 155,  ...,   8, 147,  53],
        [155,  68,  34,  ...,   4, 147,  52]], device='cuda:0')

In [14]:
yb

[tensor([[  1,   1,   1,  ...,   1, 147,   1],
         [  1,  71,   1,  ...,   8,   1,  52],
         [155,   1,   1,  ...,   8,   1,   1],
         ...,
         [  1,   1,   1,  ...,   8,   1,   1],
         [  1,   1,   1,  ...,   1,   1,  53],
         [  1,  68, 155,  ...,   8,   1,   1]], device='cuda:0'),
 tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')]

## LMNP

In [15]:
# m_len = 0
# x_len = 16 # bptt
# seq_len = m_len+x_len
# torch.triu(torch.ones(x_len, seq_len), diagonal=m_len).byte()[None,None].cpu().numpy()
# torch.triu(torch.ones(x_len, seq_len), diagonal=m_len+1).byte()[None,None].cpu().numpy()

In [16]:
import torch.nn as nn

In [17]:
class TransformerEmbedding(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_sz:int, emb_sz:int, inp_p:float=0.):
        super().__init__()
        self.emb_sz = emb_sz
        self.embed = embedding(vocab_sz, emb_sz)
        self.pos_enc = PositionalEncoding(emb_sz)
        self.drop = nn.Dropout(inp_p)
    
    def forward(self, inp): 
        pos = torch.arange(0, inp.size(1), device=inp.device).float()
        return self.drop(self.embed(inp) * math.sqrt(self.emb_sz) + self.pos_enc(pos))

In [90]:
class MusicTransformer(nn.Module):
    "Transformer model: https://arxiv.org/abs/1706.03762."
    def __init__(self, embed:nn.Module, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadAttention,
                 mask:bool=True, **kwargs):
        super().__init__()
        self.mask = mask
        self.encoder = embed
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, 
                      attn_cls=attn_cls) for k in range(n_layers)])
    
    def reset(self): pass
    
    def forward(self, x):
        bs, x_len, _ = x.size()
        mask = window_mask(x_len, x.device) if self.mask else None
#         mask = torch.triu(torch.(x_len, x_len), diagonal=1).byte()[None,None] if self.mask else None
        #[None,:,:None] for einsum implementation of attention
        for layer in self.layers: x = layer(x, mask=mask)
        return ([x],[x]) #For the LinearDecoder


In [91]:
# learn.model[0].layers[0].mhra.residual

In [92]:
from fastai.text.learner import _model_meta

In [93]:
config['mem_len'] = 0
config['mask'] = False

In [94]:
config

{'ctx_len': 150,
 'n_layers': 16,
 'n_heads': 8,
 'd_model': 256,
 'd_head': 32,
 'd_inner': 2048,
 'resid_p': 0.0,
 'attn_p': 0.0,
 'ff_p': 0.0,
 'embed_p': 0.0,
 'output_p': 0.0,
 'bias': False,
 'scale': True,
 'act': <Activation.GeLU: 3>,
 'double_drop': True,
 'tie_weights': True,
 'out_bias': True,
 'mem_len': 0,
 'mask': False,
 'pad_idx': 1,
 'bos_idx': 0,
 'sep_idx': 8,
 'transpose_range': (0, 12),
 'note_range': (9, 138),
 'bs': 16,
 'bptt': 256,
 'vocab_size': 274}

In [95]:
_model_meta[MusicTransformer] = _model_meta[Transformer]
_model_meta[MusicTransformer]['config_lm'] = config

In [96]:
class BertHead(nn.Module):
    def __init__(self, embed, encoder, mask_decoder, ns_decoder, s2s_decoder):
        super().__init__()
        self.embed = embed
        self.encoder = encoder
        self.mask_decoder = mask_decoder
        self.ns_decoder = ns_decoder
        self.s2s_decoder = s2s_decoder
        
    def forward(self, x, y=None):
        x_emb = self.embed(x)
        x_enc = self.encoder(x_emb)
        
        if y is None: # mask, and next sentence task
            return self.mask_decoder(x_enc), self.ns_decoder(x_enc)
        
        y_emb = self.embed(y)
        return self.mask_decoder(x_enc), self.s2s_decoder(x_enc, y_emb)
    
    def __getitem__(self, idx):
        return [self.encoder, self.mask_decoder, self.ns_decoder, self.s2s_decoder][idx]
        
    "A sequential module that passes the reset call to its children."
    def reset(self): pass

In [97]:
class S2SDecoderBlock(nn.Module):
    "Decoder block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, double_drop:bool=True, **kwargs):
        super().__init__()
        self.mha1 = MultiHeadAttention(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        self.mha2 = MultiHeadAttention(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        self.ff   = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)
    
    def forward(self, x:Tensor, enc:Tensor, mask_in:Tensor=None, mask_out:Tensor=None): 
        y = self.mha1(x, x, x, mask_out)
        return self.ff(self.mha2(y, enc, enc, mask=mask_in))

In [98]:
def window_mask(x_len, device, m_len=0, size=(1,1)):
    mem_mask = np.zeros((x_len,m_len))
    tri_mask = np.triu(np.ones((x_len//win_size+1,x_len//win_size+1)),k=k)
    window_mask = tri_mask.repeat(win_size,axis=0).repeat(win_size,axis=1)[:x_len,:x_len]
    np_mask = np.concatenate((mem_mask, window_mask), axis=1)
    mask = torch.tensor(np_mask, device=device).byte()[None,None]
    return mask
    
def rand_window_mask(x_len,m_len,device,max_size=3,p=0.2,is_eval=False):
    if is_eval or m_len == 0 or np.random.rand() >= p: 
        win_size,k = (1,1)
    else: win_size,k = (np.random.randint(0,max_size)+1,0)
    return window_mask(x_len, device, m_len, size=(win_size,k))

In [99]:
def get_output_mask(inp, pad_idx:int=1):
    return torch.triu(inp.new_ones(inp.size(1),inp.size(1)), diagonal=1)[None,None].byte()

In [100]:
class S2SDecoder(nn.Module):
    def __init__(self, embed, n_hid, vocab_sz, n_layers, **kwargs):
        super().__init__()
        self.decoder = nn.ModuleList([S2SDecoderBlock(**kwargs) for _ in range(n_layers)])
        self.head = MusicLinearDecoder(n_hid, vocab_sz, tie_encoder=embed.embed, **kwargs)
        
#         self.pad_idx = pad_idx
        
    def forward(self, inp, out):
        x_len = inp.shape[-1]
#         mask_out = rand_window_mask(x_len, 0, inp.device, is_eval=not self.training)
        mask_out = window_mask(x_len, inp.device)
    
        out = self.embed(out)
        for dec_block in self.decoder: out = dec_block(out, inp, mask_in, mask_out)
        return self.head(out)


In [101]:

class MusicLinearDecoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, bias:bool=True, **kwargs):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        raw_outputs, outputs = input
        output = self.output_dp(outputs[-1])
        decoded = self.decoder(output)
        return decoded, raw_outputs, outputs


In [102]:
def get_music_model(vocab_sz:int, config:dict=None, drop_mult:float=1.):
    "Create a language model from `arch` and its `config`, maybe `pretrained`."
    for k in config.keys(): 
        if k.endswith('_p'): config[k] *= drop_mult
#     tie_weights,output_p,out_bias = map(config.pop, ['tie_weights', 'output_p', 'out_bias'])
    tie_weights,output_p,out_bias = map(config.get, ['tie_weights', 'output_p', 'out_bias'])
    init = config.pop('init') if 'init' in config else None
    n_hid = config['d_model']
    embed = TransformerEmbedding(vocab_sz, n_hid, inp_p=config['embed_p'])
    encoder = MusicTransformer(embed=embed.embed, **config)
    mask_decoder = MusicLinearDecoder(n_hid, vocab_sz, output_p, tie_encoder=embed.embed, bias=out_bias)
    ns_decoder = MusicLinearDecoder(n_hid, 4, output_p, tie_encoder=None, bias=out_bias)
    s2s_decoder = S2SDecoder(embed, n_hid, vocab_sz, **config)
    model = BertHead(embed, encoder, mask_decoder, ns_decoder, s2s_decoder)
    return model if init is None else model.apply(init)


def music_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1., pretrained:bool=False,
                        pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':
    "Create a `Learner` with a language model from `data` and `arch`."
    model = get_music_model(config['vocab_size'], config=config, drop_mult=drop_mult)
    
    meta = _model_meta[Transformer]
    learn = MusicLearner(data, model, split_func=meta['split_lm'], 
                         bos_idx=config['bos_idx'], sep_idx=config['sep_idx'],
                        **learn_kwargs)
    
    if pretrained:
        if 'url' not in meta: 
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta['url'], data=False)
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        learn.load_pretrained(*fnames)
        learn.freeze()
    if pretrained_fnames is not None:
        fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
        learn.load_pretrained(*fnames)
        learn.freeze()
    return learn

## Load

In [103]:
learn = music_model_learner(data, config, drop_mult=0)

Sep_idx: 8


In [104]:
xb,yb = data.one_batch(cpu=False)

In [105]:
# xb.shape, yb.shape

In [106]:
class BertLoss():
    def __init__(self, mask_loss, sent_loss):
        self.mask_loss = mask_loss
        self.sent_loss = sent_loss
        
    def __call__(self, input:Tensor, target:Tensor, target_sen:Tensor, **kwargs)->Rank0Tensor:
        m = self.mask_loss.__call__(input[0], target, **kwargs)
        s = self.sent_loss.__call__(input[1], target_sen, **kwargs)
        return m + s

In [107]:
class BertTrainer(LearnerCallback):
    "`Callback` that regroups lr adjustment to seq_len, AR and TAR."
    def __init__(self, learn:Learner):
        super().__init__(learn)
        
    def on_loss_begin(self, last_output:Tuple[Tensor,Tensor,Tensor], **kwargs):
        "Save the extra outputs for later and only returns the true output."
        return {'last_output': (last_output[0][0], last_output[1][0]) }


In [108]:
type(learn.callbacks)

list

In [109]:
# learn.callbacks = [BertTrainer(learn, alpha=2, beta=1)]
learn.callbacks = [BertTrainer(learn)]

In [110]:
learn.loss_func = BertLoss(CrossEntropyFlat(ignore_index=vocab.pad_idx), CrossEntropyFlat())

In [111]:
learn.model(xb)

((tensor([[[-0.0333, -0.2630, -0.0281,  ..., -0.1831, -0.0589,  0.1242],
           [-0.1453, -0.2368, -0.0192,  ..., -0.1638, -0.1066,  0.1291],
           [-0.1413, -0.3135, -0.0474,  ..., -0.0977, -0.0727,  0.1028],
           ...,
           [-0.1811, -0.1423,  0.0825,  ..., -0.1925, -0.0855,  0.0012],
           [-0.2093, -0.1797,  0.0943,  ..., -0.1873, -0.0806,  0.0069],
           [-0.2064, -0.1918,  0.0996,  ..., -0.1923, -0.0564,  0.0511]],
  
          [[-0.0359, -0.2567, -0.0225,  ..., -0.2064, -0.0432,  0.1188],
           [-0.1053, -0.2352, -0.0572,  ..., -0.1237, -0.1362,  0.1502],
           [-0.1404, -0.3143, -0.0448,  ..., -0.0981, -0.0744,  0.1048],
           ...,
           [-0.2232, -0.1973,  0.0741,  ..., -0.2047, -0.0783,  0.0179],
           [-0.2075, -0.1832,  0.0969,  ..., -0.1880, -0.0828,  0.0105],
           [-0.2437, -0.2103,  0.1251,  ..., -0.1699, -0.0869,  0.0407]],
  
          [[-0.0312, -0.2696, -0.0255,  ..., -0.1817, -0.0611,  0.1243],
           

In [112]:
def mask_acc(input:Tensor, t1:Tensor, t2:Tensor)->Rank0Tensor:
    n = t1.shape[0]
    input = input[0].argmax(dim=-1).view(n,-1)
    t1 = t1.view(n,-1)
    mask = t1 != vocab.pad_idx
    return (input[mask]==t1[mask]).float().mean()

def ns_acc(input:Tensor, t1:Tensor, t2:Tensor)->Rank0Tensor:
    return accuracy(input[1], t2[0])

In [113]:
learn.metrics = [mask_acc, ns_acc]

In [114]:
learn.validate()

[8.0582905, tensor(0.0147), tensor(0.0989)]

In [115]:
# learn.lr_find()
# learn.recorder.plot()

In [116]:
learn.fit_one_cycle(3, 1e-4)

epoch,train_loss,valid_loss,mask_acc,ns_acc,time
0,4.212763,4.351722,0.2464,0.366042,01:55
1,4.13758,4.07846,0.278793,0.398475,01:55
2,3.940586,3.869962,0.314668,0.481689,01:54


In [118]:
learn.model

BertHead(
  (embed): TransformerEmbedding(
    (embed): Embedding(274, 256)
    (pos_enc): PositionalEncoding()
    (drop): Dropout(p=0.0)
  )
  (encoder): MusicTransformer(
    (encoder): Embedding(274, 256)
    (layers): ModuleList(
      (0): DecoderLayer(
        (mhra): MultiHeadAttention(
          (attention): Linear(in_features=256, out_features=768, bias=False)
          (out): Linear(in_features=256, out_features=256, bias=False)
          (drop_att): Dropout(p=0.0)
          (drop_res): Dropout(p=0.0)
          (ln): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
        )
        (ff): SequentialEx(
          (layers): ModuleList(
            (0): Linear(in_features=256, out_features=2048, bias=True)
            (1): GeLU()
            (2): Dropout(p=0.0)
            (3): Linear(in_features=2048, out_features=256, bias=True)
            (4): Dropout(p=0.0)
            (5): MergeLayer()
            (6): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_af