# Attention Architecture

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import music21

In [3]:
from fastai.text import *
from transformer_xl.default_txl import get_default_model
from enum import Enum
import torch
from fastai.text.models.awd_lstm import *
from fastai.text.models.transformer import *

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [4]:
from fastai_data import *

Failed to load FluidSynth. Must install if you want to convert to wav files.


In [5]:
import numpy as np
import torch.nn as nn

In [6]:
from ht_encode import *
import ht_encode

In [7]:
# bs=4
# bptt=64

bs=8
bptt=256

In [8]:
import fastai_data
fastai_data.Y_OFFSET=1

In [9]:
path = Path('data/midi/v7/midi_encode/np/hook_dur/')
ht_encode.config.continuous=False
data = LMNPDataBunch.load(path, bs=bs, bptt=bptt, cache_name='tmp/all')

In [10]:
# loss_func = CrossEntropy with parts

In [11]:
class TransformerEmbed(nn.Module):
    def __init__(self, vocab_map, vocab_sizes, d_model, emb_dim=100, embed_p:float=0.1, **kwargs):
        super().__init__()
        # note, octave, duration, instrument
#         assert(sum(emb_dims) == d_model)
        self.vocab_map = vocab_map
        embeddings = []
        
        out_d = 100
        for in_d in vocab_sizes:
            embeddings.append(nn.Embedding(in_d, out_d, padding_idx=PAD_IDX))
        self.embeddings = nn.ModuleList(embeddings[:-2])
        self.positions = nn.ModuleList(embeddings[-2:])
        self.drop_emb = nn.Dropout(embed_p)
        
    def forward(self, x):
        # batch x bptt x (n,o,d,i)
        embs = []
        pos = self.positions[0](x[:,:,iB]) + self.positions[1](x[:,:,iM])
        for i in range(x.shape[-1]-2):
            emb_idx = self.vocab_map[i]
            ex = self.embeddings[emb_idx](x[:,:,i])
            embs.append(ex+pos)
        emb = torch.cat(embs, dim=-1)
        return self.drop_emb(emb)

In [12]:
class TXLLinearDecoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, tie_encoder:nn.Module=None, bias:bool=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input):
        return self.decoder(input)

In [13]:
class TransformerDec(nn.Module):
    def __init__(self, txl_emb, vocab_map, output_p=0.0, out_bias=True, **kwargs):
        super().__init__()
        self.output_dp = RNNDropout(output_p)
        
        start_range = 0
        self.ranges = []
        decoders = []
        for i in range(len(vocab_map)-2):
            emb = txl_emb.embeddings[vocab_map[i]]
            decoder = TXLLinearDecoder(emb.num_embeddings, emb.embedding_dim, tie_encoder=emb, bias=out_bias)
            decoders.append(decoder)
            end_range = start_range+emb.embedding_dim
            self.ranges.append([start_range, end_range])
            start_range = end_range
            
        self.decoders = nn.ModuleList(decoders)
        
    def forward(self, input):
        raw_outputs, outputs = input
        output = self.output_dp(outputs[-1])
        res = []
        for dec,(start,end) in zip(self.decoders, self.ranges):
            res.append(dec(output[:,:,start:end]))
        return res, raw_outputs, outputs

In [14]:

class LMNPTransformerXL(nn.Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, encoder, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,
                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask
        if self.mem_len > 0: self.reset()
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, 
                      attn_cls=attn_cls) for k in range(n_layers)])
    
    def reset(self):
        "Reset the internal memory."
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]
#         self.hidden = [next(self.parameters()).data.new(0).cuda() for i in range(self.n_layers+1)]

    def _update_mems(self, hids):
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()
    
    def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden]
    
    def forward(self, x):
        bs,x_len,_ = x.size()
        inp = self.encoder(x)
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).byte()[None,None] if self.mask else None
        #[None,:,:None] for einsum implementation of attention
        hids = []
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0 : self._update_mems(hids)
        return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]

In [15]:
config = tfmerXL_lm_config

In [16]:
PAD_IDX=ht_encode.config.pad_idx+ht_encode.config.enc_offset
# must remember to add 1 to tensors

In [17]:
train_ids_file = path/'tmp/all/train_ids.npy'
all_ids = np.load(train_ids_file)
id_cat = np.concatenate(all_ids); id_cat.shape
max_vocab = id_cat.max(axis=0)
max_vocab = (max_vocab+1).tolist(); max_vocab

[15, 13, 132, 15, 15, 15, 15, 132, 7, 12]

In [18]:
idx2embidx = {
    0:0,1:1,2:2,
    3:0,4:0,5:0,6:0,7:2,
    8:3,
    9:4
}

In [19]:
idx_cat = [0,1,2,0,0,0,0,2,3,4]

In [20]:
max_vocab_sizes = [max_vocab[i] for i in [0,1,2,8,9]]

In [21]:
max_vocab_sizes

[15, 13, 132, 7, 12]

In [22]:
emb_dims = [100,100,100,100,100]

In [23]:
idx2dims = {k:(max_vocab_sizes[v],emb_dims[v]) for k,v in idx2embidx.items()}; idx2dims

{0: (15, 100),
 1: (13, 100),
 2: (132, 100),
 3: (15, 100),
 4: (15, 100),
 5: (15, 100),
 6: (15, 100),
 7: (132, 100),
 8: (7, 100),
 9: (12, 100)}

In [24]:
config['vocab_sizes'] = max_vocab_sizes
config['vocab_map'] = idx2embidx

In [25]:
config['emb_dims'] = emb_dims
# total_embs = sum([i[1] for i in idx2dims.values()])
# config['d_model'] = total_embs
config['d_model'] = 800

In [26]:
config['ctx_len'] = 0
config['mem_len'] = 256
config['d_inner'] = 1024

In [27]:
# config['mem_len'] = 0
# config['d_inner'] = 512

In [28]:
config

{'ctx_len': 0,
 'n_layers': 12,
 'n_heads': 10,
 'd_model': 800,
 'd_head': 41,
 'd_inner': 1024,
 'resid_p': 0.1,
 'attn_p': 0.1,
 'ff_p': 0.1,
 'embed_p': 0.1,
 'output_p': 0.1,
 'bias': False,
 'scale': True,
 'act': <Activation.ReLU: 1>,
 'double_drop': True,
 'tie_weights': True,
 'out_bias': True,
 'init': <function fastai.text.models.transformer.init_transformer(m)>,
 'mem_len': 256,
 'mask': True,
 'vocab_sizes': [15, 13, 132, 7, 12],
 'vocab_map': {0: 0, 1: 1, 2: 2, 3: 0, 4: 0, 5: 0, 6: 0, 7: 2, 8: 3, 9: 4},
 'emb_dims': [100, 100, 100, 100, 100]}

In [29]:
def get_language_model(config:dict=None, drop_mult:float=1.):
    "Create a language model from `arch` and its `config`, maybe `pretrained`."
    for k in config.keys(): 
        if k.endswith('_p'): config[k] *= drop_mult
    init = config.pop('init') if 'init' in config else None
    
    embed = TransformerEmbed(**config)
    txl = LMNPTransformerXL(embed, **config)
    decoder = TransformerDec(embed, **config)
    model = SequentialRNN(txl, decoder).cuda()
    
    return model if init is None else model.apply(init)


def language_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1., pretrained:bool=True,
                           **learn_kwargs) -> 'LanguageLearner':
    "Create a `Learner` with a language model from `data` and `arch`."
    model = get_language_model(config=config, drop_mult=drop_mult)
    learn = LanguageLearner(data, model, split_func=tfmer_lm_split, **learn_kwargs)
    return learn

In [30]:

class LMNPLoss(nn.Module):
    "Same as `func`, but flattens input and target."
    def __init__(self):
        super().__init__()
        self.fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX) 
        # not using func otherwise _loss_func_name2activ uses this attribute to get cross entropy loss

    def __repr__(self): return f"numpyenc loss of {self.fn}"

    def forward(self, inputs:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        losses = []
        for idx,input in enumerate(inputs):
            if idx in BIDX_ALL: continue
            t = target[:,:,idx]
            input = input.view(-1,input.shape[-1])
            losses.append(self.fn(input, t.view(-1)))
        return sum(losses)

In [31]:
import pdb

In [43]:
def lmnp_accuracy(inputs:Tensor, target:Tensor)->Rank0Tensor:
    "Compute accuracy with `targs` when `input` is bs * n_classes."
    target = target[:,:,:-2]
    inputs = [i.argmax(dim=-1).unsqueeze(dim=-1) for i in inputs]
    input_cat = torch.cat(inputs, dim=-1).cpu().numpy()
    target = target.cpu().numpy()
    acc = (input_cat==target).astype(float)
    acc[target==PAD_IDX] = np.nan
    return torch.tensor(np.nanmean(acc, dtype=float))

In [44]:
learn = language_model_learner(data, config, clip=0.25, loss_func=LMNPLoss(), metrics=[lmnp_accuracy])

In [45]:
learn.validate()

[135.32268, tensor(0.3754)]

In [None]:
x,y = data.one_batch(cpu=False)
# emb = learn.model[0]
# a = torch.ones_like(x)
# emb.encoder(x)

# out = x[0].cpu().numpy()
# song = dec_arr(out)
# song.to_stream().show()

In [93]:
learn.model.cuda()
learn.model[0].reset()
out = learn.model(x)

In [94]:
# # clip = 0.5
# learn.lr_find(num_it=300)
# learn.recorder.plot()

In [46]:
learn.fit_one_cycle(10, 4e-4)

epoch,train_loss,valid_loss,lmnp_accuracy,time
1,13.227628,12.115812,0.748762,03:57
2,10.386578,10.137149,0.780244,03:57
3,9.491807,9.47389,0.779867,03:57
4,9.210784,9.242332,0.789214,03:57
5,9.157971,8.92519,0.79323,03:56
6,8.676299,8.686626,0.797475,03:56
7,8.382814,8.56877,0.801655,03:56


KeyboardInterrupt: 

In [47]:
learn.save('first_run')

In [48]:
from fastai import basic_train

In [49]:
def predict_func(parts): return [F.softmax(p, dim=-1) for p in parts]

In [50]:
loss_func_name = camel2snake(learn.loss_func.__class__.__name__)
basic_train.loss_func_name2activ[loss_func_name] = predict_func
basic_train._loss_func2activ(learn.loss_func)

<function __main__.predict_func(parts)>

In [51]:
def predict(self, xb, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
            decoder=decode_spec_tokens):
    "Return the `n_words` that come after `text`."
    ds = self.data.single_dl.dataset
    self.model.reset()
    if xb.shape[0] > 1: xb = xb[0][None]
    yb = torch.ones_like(xb)
    new_idx = []
    for _ in range(n_words): #progress_bar(range(n_words), leave=False):
        timestep = []
        outputs = learn.pred_batch(batch=(xb,yb))
        for item in outputs: #progress_bar(range(n_words), leave=False):
            res = item[0][-1]
            if min_p is not None: 
                if (res >= min_p).float().sum() == 0:
                    warn(f"There is no item with probability >= {min_p}, try a lower value.")
                else: res[res < min_p] = 0.
            if temperature != 1.: res.pow_(1 / temperature)
            idx = torch.multinomial(res, 1).item()
            timestep.append(idx)
        new_idx.append(timestep)
        xb = xb.new_tensor(timestep).view(1,1,-1)
    return new_idx


In [52]:
xb,yb = learn.data.one_batch(cpu=False)

In [57]:
xb.shape

torch.Size([8, 256, 10])

In [53]:
out = predict(learn, xb, n_words=60)

IndexError: index 8 is out of bounds for dimension 2 with size 8

In [54]:
np.array(out)[:,iND]

NameError: name 'out' is not defined

In [55]:
xb.cpu().numpy().shape

(8, 256, 10)

In [56]:
song = dec_arr(np.array(out)); song

NameError: name 'out' is not defined

In [47]:
stream = song.to_stream()

In [48]:
stream.show('text')

{0.0} <music21.instrument.Piano Piano>
{0.0} <music21.stream.Part 0x7fe9db7b2748>
    {0.0} <music21.stream.Part 0x7fe9db7b26a0>
        {1.0} <music21.note.Note D>
        {1.5} <music21.note.Note D>
        {2.25} <music21.note.Note C>
        {2.5} <music21.note.Note C>
        {3.25} <music21.note.Note G>
        {3.5} <music21.note.Note C>
        {3.75} <music21.note.Note C>
        {4.25} <music21.note.Note C>
        {4.5} <music21.note.Note G>
        {4.75} <music21.note.Note C>
        {5.0} <music21.note.Note F>
        {5.25} <music21.note.Note C>
        {5.75} <music21.note.Note G>
        {6.25} <music21.note.Note C>
        {6.5} <music21.note.Note B>
        {6.75} <music21.note.Note C>
        {7.0} <music21.note.Note A>
        {7.25} <music21.note.Note C>
        {7.5} <music21.note.Note D>
        {7.75} <music21.note.Note B>
        {8.25} <music21.note.Note B>
        {8.5} <music21.note.Note E>
        {9.0} <music21.note.Note B>
        {9.5} <music21.note.Not

In [49]:
stream.show()

SubConverterFileIOException: png file of xml not found. Or file >999 pages?

In [50]:
stream.show('midi')