<a href="https://colab.research.google.com/github/brianellis1997/Generative_Music_Notation_and_Sound/blob/main/User_Input.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Parent Paper Code Run
### Compose and Embellish: Well-Structured Piano Performance Generation via A Two-Stage Approach
#### 17 Sep 2022  ·  Shih-Lun Wu, Yi-Hsuan Yang

### Team 9: Brian Ellis

In [None]:
# Clone repositories
!git clone https://github.com/slSeanWU/Compose_and_Embellish
!git clone https://github.com/brianellis1997/Generative_Music_Notation_and_Sound.git

# Install libraries
!pip install -r requirements.txt

# Install pre-trained transformers (15 min runtime)
!pip install git+https://github.com/cifkao/fast-transformers.git@39e726864d1a279c9719d33a95868a4ea2fb5ac5
!git clone https://huggingface.co/slseanwu/compose-and-embellish-pop1k7
!pip install miditoolkit

Cloning into 'Compose_and_Embellish'...
remote: Enumerating objects: 74, done.[K
remote: Counting objects: 100% (74/74), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 74 (delta 28), reused 38 (delta 7), pack-reused 0[K
Receiving objects: 100% (74/74), 52.25 KiB | 3.73 MiB/s, done.
Resolving deltas: 100% (28/28), done.
Cloning into 'Generative_Music_Notation_and_Sound'...
remote: Enumerating objects: 36, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (36/36), done.[K
Receiving objects: 100% (36/36), 23.38 KiB | 7.79 MiB/s, done.
remote: Total 36 (delta 14), reused 0 (delta 0), pack-reused 0[K
Resolving deltas: 100% (14/14), done.
[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
[0mCollecting git+https://github.com/cifkao/fast-transformers.git@39e726864d1a279c9719d33a95868a4ea2fb5ac5
  Cloning https://github.com/cifkao/fast-transformers.git (to revi

# Create Model

## Embedding

In [None]:
import math
import torch
from torch import nn
import torch.nn.functional as F

def generate_causal_mask(seq_len, device):
    """Build an additive causal mask of shape (seq_len, seq_len).

    Entry [i, j] is 0.0 when j <= i and -inf when j > i.

    Args:
        seq_len: side length of the square mask.
        device: device on which the mask is created.

    Returns:
        A float32 tensor that does not require gradients.
    """
    # True on and below the diagonal, i.e. for the entries that stay at 0.0.
    visible = torch.tril(torch.ones(seq_len, seq_len, device=device)).bool()

    causal_mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.float)
    causal_mask = causal_mask.masked_fill(~visible, float('-inf'))
    causal_mask.requires_grad = False
    return causal_mask

def generate_bidirectional_pad_mask(max_seqlen, batch_lens):
    """Build a boolean mask of shape (len(batch_lens), max_seqlen).

    Row i is False for the first batch_lens[i] positions and True afterwards.

    Args:
        max_seqlen: number of columns of the mask.
        batch_lens: one length per row.

    Returns:
        A bool tensor created on the default device.
    """
    pad_mask = torch.zeros(len(batch_lens), max_seqlen, dtype=torch.bool)
    # Iterating a tensor yields row views, so the slice assignment writes
    # straight into pad_mask.
    for row, valid_len in zip(pad_mask, batch_lens):
        row[valid_len:] = True
    return pad_mask

def weight_init_normal(weight, normal_std):
    """Fill `weight` in place with samples from N(0, normal_std**2)."""
    nn.init.normal_(weight, mean=0.0, std=normal_std)

def weight_init_orthogonal(weight, gain):
    """Fill `weight` in place with an orthogonal matrix scaled by `gain`."""
    nn.init.orthogonal_(weight, gain=gain)

def bias_init(bias):
    """Set every element of `bias` to 0.0 in place."""
    nn.init.constant_(bias, val=0.0)

def weights_init(m):
    """Initialise the parameters of module `m` based on its class name.

    Dispatch is by substring match on ``m.__class__.__name__``; the first
    matching branch wins, and modules that match none are left untouched:

    * ``Linear``: weight ~ N(0, 0.01), bias = 0
    * ``Embedding``: weight ~ N(0, 0.01)
    * ``LayerNorm``: weight ~ N(1, 0.01), bias = 0
    * ``ProjectedAdaptiveLogSoftmax``: cluster weight ~ N(0, 0.01),
      cluster bias = 0, each non-None output projection ~ N(0, 0.02)
    * ``TXLDecoder``: r_emb / r_w_bias / r_r_bias ~ N(0, 0.01), r_bias = 0
    * ``LSTM``: parameters with >= 2 dims are orthogonal (gain 0.01),
      all others are set to 0

    Every attribute is skipped when it is missing or None, so modules built
    without affine parameters (for example
    ``nn.LayerNorm(..., elementwise_affine=False)``) no longer raise.

    Args:
        m: the module to initialise; shaped to be used as the callback of
            ``nn.Module.apply``.
    """
    classname = m.__class__.__name__

    def present(attr):
        # True only when the attribute exists and holds an actual value.
        return getattr(m, attr, None) is not None

    if 'Linear' in classname:
        if present('weight'):
            weight_init_normal(m.weight, 0.01)
        if present('bias'):
            bias_init(m.bias)
    elif 'Embedding' in classname:
        if present('weight'):
            weight_init_normal(m.weight, 0.01)
    elif 'LayerNorm' in classname:
        if present('weight'):
            nn.init.normal_(m.weight, 1.0, 0.01)
        if present('bias'):
            bias_init(m.bias)
    elif 'ProjectedAdaptiveLogSoftmax' in classname:
        if present('cluster_weight'):
            weight_init_normal(m.cluster_weight, 0.01)
        if present('cluster_bias'):
            bias_init(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for i in range(len(m.out_projs)):
                if m.out_projs[i] is not None:
                    weight_init_normal(m.out_projs[i], 0.02)
    elif 'TXLDecoder' in classname:
        for attr in ('r_emb', 'r_w_bias', 'r_r_bias'):
            if present(attr):
                weight_init_normal(getattr(m, attr), 0.01)
        if present('r_bias'):
            bias_init(m.r_bias)
    elif 'LSTM' in classname:
        for param in m.parameters():
            if len(param.shape) >= 2:  # weights
                weight_init_orthogonal(param, 0.01)
            else:                      # biases
                bias_init(param)


class SinusoidalPE(nn.Module):
    """Fixed sinusoidal positional encoding table.

    The table is precomputed for positions ``0 .. max_pos - 1`` and stored in
    the buffer ``pe`` with shape (max_pos, 1, d_embed). Even feature columns
    hold sines, odd feature columns hold cosines.

    Args:
        d_embed: encoding width. Odd widths are supported; the last column is
            then a sine without a matching cosine.
        max_pos: number of positions to precompute.
    """

    def __init__(self, d_embed, max_pos=20480):
        super(SinusoidalPE, self).__init__()
        self.d_embed = d_embed
        self.max_pos = max_pos

        pe = torch.zeros(max_pos, d_embed)
        position = torch.arange(0, max_pos, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_embed there is one more even column than odd columns,
        # so only the first d_embed // 2 frequencies get a cosine column.
        pe[:, 1::2] = torch.cos(angles[:, :d_embed // 2])

        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, seq_len, bsz=None):
        """Return the encodings of the first `seq_len` positions.

        Args:
            seq_len: number of leading positions to return.
            bsz: when given, the singleton middle axis is expanded to `bsz`.

        Returns:
            Tensor of shape (seq_len, 1, d_embed), or (seq_len, bsz, d_embed)
            when `bsz` is given.
        """
        pos_encoding = self.pe[:seq_len, :]

        if bsz is not None:
            pos_encoding = pos_encoding.expand(seq_len, bsz, -1)

        return pos_encoding

class WordEmbedding(nn.Module):
    """Token embedding with an optional bias-free projection and a fixed scale.

    Args:
        n_token: vocabulary size of the lookup table.
        d_embed: width of the lookup table.
        d_proj: output width; a linear projection (no bias) is created only
            when it differs from `d_embed`.
        emb_scale: exponent; outputs are multiplied by ``d_proj ** emb_scale``.
        pad_idx: padding index of the lookup table; defaults to ``n_token - 1``.
    """

    def __init__(self, n_token, d_embed, d_proj, emb_scale=0.5, pad_idx=None):
        super().__init__()

        self.n_token, self.d_embed, self.d_proj = n_token, d_embed, d_proj
        # Note: the stored attribute is the multiplier, not the exponent.
        self.emb_scale = d_proj ** emb_scale

        padding_row = n_token - 1 if pad_idx is None else pad_idx
        self.emb_lookup = nn.Embedding(n_token, d_embed, padding_idx=padding_row)

        needs_projection = d_proj != d_embed
        self.emb_proj = nn.Linear(d_embed, d_proj, bias=False) if needs_projection else None

    def forward(self, inp_tokens):
        """Embed `inp_tokens`, project if configured, and scale the result in place."""
        embedded = self.emb_lookup(inp_tokens)
        if self.emb_proj is not None:
            embedded = self.emb_proj(embedded)

        embedded.mul_(self.emb_scale)
        return embedded

class OctaveAwarePitchEmbedding(nn.Module):
    """Embed tokens as the concatenation of an octave and a chroma embedding.

    Each vocabulary index is translated to an (octave, chroma) pair. Events
    whose string contains ``'Note_Pitch'`` are decomposed from their pitch
    number; every other event, plus one extra padding index appended after the
    vocabulary, maps to the padding rows of both tables.

    Args:
        n_octave: number of octave rows (one extra padding row is added).
        d_embed: combined width; each table has width ``d_embed // 2``.
        d_proj: output width; a bias-free projection is created only when it
            differs from `d_embed`.
        idx2event: mapping from integer token index to event string. It is
            read but never modified.
        emb_scale: exponent; outputs are multiplied by ``d_proj ** emb_scale``.
        n_chroma: number of chroma rows (one extra padding row is added).
        min_pitch: pitch number subtracted before the octave/chroma split.
    """

    def __init__(self, n_octave, d_embed, d_proj, idx2event,
                 emb_scale=0.5, n_chroma=12, min_pitch=12
        ):
        super(OctaveAwarePitchEmbedding, self).__init__()

        self.n_octave = n_octave
        self.n_chroma = n_chroma
        self.min_pitch = min_pitch

        self.d_embed = d_embed
        self.d_proj = d_proj
        # Note: the stored attribute is the multiplier, not the exponent.
        self.emb_scale = d_proj ** emb_scale

        # The last row of each table is its padding row.
        self.octave_emb_lookup = nn.Embedding(
            n_octave + 1, d_embed // 2, padding_idx=n_octave
        )
        self.chroma_emb_lookup = nn.Embedding(
            n_chroma + 1, d_embed // 2, padding_idx=n_chroma
        )

        if d_proj != d_embed:
            self.emb_proj = nn.Linear(d_embed, d_proj, bias=False)
        else:
            self.emb_proj = None

        self.octave_translate_dict, self.chroma_translate_dict =\
            self.make_idx_translate_dicts(idx2event)

    def make_idx_translate_dicts(self, idx2event):
        """Build the token-index -> octave-row and -> chroma-row dictionaries.

        Args:
            idx2event: mapping from integer token index to event string.

        Returns:
            ``(octave_dict, chroma_dict)``; both also contain an entry for the
            padding index ``len(idx2event)``.
        """
        # Work on a copy: writing the padding entry into the caller's dict
        # would grow the shared vocabulary every time a model is constructed
        # (for example when the notebook cell is re-run).
        idx2event = dict(idx2event)
        idx2event[ len(idx2event) ] = 'PAD_None'

        octave_dict = dict()
        chroma_dict = dict()
        for idx, ev in idx2event.items():
            if 'Note_Pitch' not in ev:
                octave_dict[idx] = self.n_octave
                chroma_dict[idx] = self.n_chroma
            else:
                # The pitch number is the text after the last underscore.
                pitch = int(ev.split('_')[-1])
                pitch -= self.min_pitch
                octave_dict[idx] = pitch // self.n_chroma
                chroma_dict[idx] = pitch % self.n_chroma

        return octave_dict, chroma_dict

    def forward(self, inp_tokens):
        """Embed `inp_tokens` and return the scaled (optionally projected) result.

        The index translation runs element-wise in Python on the CPU; the
        translated indices are then moved back to the input's device.
        """
        orig_device = inp_tokens.device

        # Transfer to the CPU once; each translation then mutates its own copy
        # so the caller's tensor is never modified.
        cpu_tokens = inp_tokens.cpu()
        octave_tokens = cpu_tokens.clone().apply_(self.octave_translate_dict.get)
        chroma_tokens = cpu_tokens.clone().apply_(self.chroma_translate_dict.get)
        octave_tokens = octave_tokens.to(orig_device)
        chroma_tokens = chroma_tokens.to(orig_device)

        octave_emb = self.octave_emb_lookup(octave_tokens)
        chroma_emb = self.chroma_emb_lookup(chroma_tokens)
        inp_emb = torch.cat([octave_emb, chroma_emb], dim=-1)

        if self.emb_proj is not None:
            inp_emb = self.emb_proj(inp_emb)

        return inp_emb.mul_(self.emb_scale)

def get_min_max_pitch_idx(idx2event):
    """Return the smallest and largest index whose event contains 'Note_Pitch'.

    The minimum is capped from above by ``len(idx2event)`` and the maximum is
    floored at 0, so a vocabulary without pitch events yields
    ``(len(idx2event), 0)``.
    """
    pitch_indices = [idx for idx, event in idx2event.items() if 'Note_Pitch' in event]

    lowest = min([len(idx2event)] + pitch_indices)
    highest = max([0] + pitch_indices)
    return lowest, highest

## Decoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# from .transformer_helpers import WordEmbedding


class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding computed on the fly from a sequence of positions.

    Args:
        demb: embedding width; ``inv_freq`` holds one inverse frequency for
            every second feature index.
    """

    def __init__(self, demb):
        super().__init__()
        self.demb = demb

        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer('inv_freq', 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """Embed the 1-D tensor `pos_seq`.

        The sine features come first and the cosine features second along the
        last axis. A singleton axis is inserted at dim 1 and expanded to `bsz`
        when `bsz` is given.
        """
        angles = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat((angles.sin(), angles.cos()), dim=-1).unsqueeze(1)

        if bsz is None:
            return pos_emb
        return pos_emb.expand(-1, bsz, -1)



class PositionwiseFF(nn.Module):
    """Two-layer feed-forward block with a residual connection and LayerNorm.

    Args:
        d_model: input and output width.
        d_inner: hidden width.
        dropout: dropout probability applied after the activation and after
            the second linear layer.
        pre_lnorm: when truthy, LayerNorm is applied to the input of the
            feed-forward stack; otherwise it is applied after the residual sum.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()

        self.d_model, self.d_inner, self.dropout = d_model, d_inner, dropout
        self.pre_lnorm = pre_lnorm

        ff_stack = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*ff_stack)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp):
        """Apply the block to `inp` and return a tensor of the same shape."""
        if self.pre_lnorm:
            # normalise -> feed-forward -> residual
            return self.CoreNet(self.layer_norm(inp)) + inp

        # feed-forward -> residual -> normalise
        return self.layer_norm(inp + self.CoreNet(inp))


class MultiHeadCrossAttn(nn.Module):
    """Multi-head attention from a query sequence `h` onto a context `c`.

    Args:
        n_head: number of attention heads.
        d_model: model width of both `h` and `c`.
        d_head: width of each head.
        dropout: dropout probability for positional embeddings and the output.
        dropatt: dropout probability for the attention probabilities.
        pre_lnorm: when truthy, LayerNorm is applied to `c` before attention
            and the residual sum is returned un-normalised; otherwise the
            residual sum is normalised.
        **kwargs: accepted and ignored.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, **kwargs):
        super().__init__()

        self.n_head, self.d_model, self.d_head = n_head, d_model, d_head
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm
        self.scale = 1 / (d_head ** 0.5)

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, h, c, attn_mask=None, h_pos_embed=None, c_pos_embed=None):
        """Attend from `h` to `c`.

        Args:
            h: queries, viewed as [qlen x bsz x d_model].
            c: context, viewed as [klen x bsz x d_model].
            attn_mask: optional mask; True entries are filled with -inf. A 2-D
                mask is indexed as ``[None, :, :, None]`` and a 3-D mask as
                ``[:, :, :, None]`` against the [qlen x klen x bsz x n_head]
                scores. It is skipped entirely when it has no True entry.
            h_pos_embed: optional tensor added (after dropout) to `h` for the
                query projection only.
            c_pos_embed: optional tensor added (after dropout) to `c` for the
                key/value projection.

        Returns:
            Tensor with the leading two dims of `h` and width d_model.
        """
        if self.pre_lnorm:
            c = self.layer_norm(c)

        # Dropout is drawn for h first and c second.
        query_in = h if h_pos_embed is None else h + self.drop(h_pos_embed)
        context_in = c if c_pos_embed is None else c + self.drop(c_pos_embed)

        queries = self.q_net(query_in)
        keys, values = torch.chunk(self.kv_net(context_in), 2, -1)

        queries = queries.view(h.size(0), h.size(1), self.n_head, self.d_head)
        keys = keys.view(c.size(0), c.size(1), self.n_head, self.d_head)
        values = values.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', queries, keys)
        attn_score.mul_(self.scale)

        if attn_mask is not None and attn_mask.any().item():
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask.unsqueeze(0).unsqueeze(-1), -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask.unsqueeze(-1), -float('inf'))

        # Softmax over the key axis, then renormalise what survives dropout.
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))
        attn_prob = attn_prob / (attn_prob.sum(dim=1, keepdim=True) + 1e-8)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, values)
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))

        if self.pre_lnorm:
            return h + attn_out
        return self.layer_norm(h + attn_out)


class MultiHeadAttn(nn.Module):
    """Multi-head self-attention with optional memory prepended to the keys.

    Args:
        n_head: number of attention heads.
        d_model: model width.
        d_head: width of each head.
        dropout: dropout probability for the output projection.
        dropatt: dropout probability for the attention probabilities.
        pre_lnorm: when truthy, LayerNorm is applied to the key/value input
            and the residual sum is returned un-normalised; otherwise the
            residual sum is normalised.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False):
        super().__init__()

        self.n_head, self.d_model, self.d_head = n_head, d_model, d_head
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm
        self.scale = 1 / (d_head ** 0.5)

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, h, attn_mask=None, mems=None):
        """Attend from `h` to ``cat([mems, h])`` (or to `h` alone).

        Args:
            h: input, viewed as [hlen x bsz x d_model].
            attn_mask: optional mask; True entries are filled with -inf. A 2-D
                mask is indexed as ``[None, :, :, None]`` and a 3-D mask as
                ``[:, :, :, None]`` against the [qlen x klen x bsz x n_head]
                scores. It is skipped entirely when it has no True entry.
            mems: optional tensor concatenated in front of `h` along dim 0 to
                form the key/value input.

        Returns:
            Tensor with the leading two dims of `h` and width d_model.
        """
        context = h if mems is None else torch.cat([mems, h], 0)
        if self.pre_lnorm:
            # Only the key/value input is normalised; queries use raw h.
            context = self.layer_norm(context)

        queries = self.q_net(h)
        keys, values = torch.chunk(self.kv_net(context), 2, -1)

        queries = queries.view(h.size(0), h.size(1), self.n_head, self.d_head)
        keys = keys.view(context.size(0), context.size(1), self.n_head, self.d_head)
        values = values.view(context.size(0), context.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', queries, keys)
        attn_score.mul_(self.scale)

        if attn_mask is not None and attn_mask.any().item():
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask.unsqueeze(0).unsqueeze(-1), -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask.unsqueeze(-1), -float('inf'))

        # Softmax over the key axis; probabilities are not renormalised after dropout.
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, values)
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))

        if self.pre_lnorm:
            return h + attn_out
        return self.layer_norm(h + attn_out)


class RelMultiHeadAttn(nn.Module):
    """Base class for relative-position multi-head attention variants.

    It owns the shared sub-modules (fused QKV projection, output projection,
    dropouts, LayerNorm) and the shifting helpers; subclasses implement
    `forward`.

    Args:
        n_head: number of attention heads.
        d_model: model width.
        d_head: width of each head.
        dropout: dropout probability for the output projection.
        dropatt: dropout probability for the attention probabilities.
        tgt_len, ext_len, mem_len: accepted but not used by this class.
        pre_lnorm: stored for subclasses to choose pre- or post-LayerNorm.
        **kwargs: accepted and ignored.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, **kwargs):
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def _parallelogram_mask(self, h, w, left=False):
        """Return an (h, w) uint8 mask built from a triu/tril pair.

        With ``m = min(h, w)``, the top-left m x m corner is upper-triangular
        and the bottom-right m x m corner is lower-triangular. The mask is
        flipped along dim 0 unless `left` is True. It is created on the CPU.
        """
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """Pad `x` with ``qlen - 1`` zero columns along dim 1 and gather by `mask`.

        Args:
            x: 4-D tensor; dim 0 must be expandable to `qlen`.
            qlen, klen: leading dims of the result.
            mask: 2-D mask selecting exactly ``qlen * klen`` positions of the
                padded (dim 0, dim 1) grid; any non-zero entry counts as True.
            left: pad on the left and flip the mask along dim 1 when True,
                pad on the right otherwise.

        Returns:
            Tensor of shape (qlen, klen, x.size(2), x.size(3)).
        """
        # A zero-width pad (qlen <= 1) keeps the rank consistent with x, so no
        # special case is needed for single-step decoding.
        zero_pad = torch.zeros((x.size(0), max(qlen - 1, 0), x.size(2), x.size(3)),
                               device=x.device, dtype=x.dtype)

        # masked_select needs a boolean mask on the same device as x;
        # _parallelogram_mask produces a uint8 CPU tensor.
        mask = mask.to(device=x.device, dtype=torch.bool)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:,:,None,None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        """Apply the pad-and-reshape relative shift to `x` along dims 0 and 1.

        A zero column is prepended on dim 1, the first two dims are
        reinterpreted as (x.size(1) + 1, x.size(0)), the first row is dropped,
        and the result is viewed back to the shape of `x`.

        Args:
            x: tensor with at least two dims; 4-D when `zero_triu` is True.
            zero_triu: additionally multiply by a lower-triangular mask with
                diagonal offset ``x.size(1) - x.size(0)``.

        Returns:
            Tensor with the shape, dtype and device of `x`.
        """
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Build the mask next to x so the product neither fails on a
            # device mismatch nor silently promotes the dtype.
            ones = torch.ones((x.size(0), x.size(1)), device=x.device)
            keep = torch.tril(ones, x.size(1) - x.size(0)).to(x.dtype)
            x = x * keep[:,:,None,None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        """Abstract; subclasses provide the attention computation."""
        raise NotImplementedError


class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative attention whose positional keys are projected from `r` by `r_net`.

    The constructor forwards all arguments to `RelMultiHeadAttn` and adds the
    bias-free positional projection `r_net`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None, return_avg_attn=False):
        """Compute relative-position self-attention over `w` (and `mems`).

        Args:
            w: input, viewed as [qlen x bsz x d_model].
            r: positional input; ``r.size(0)`` is taken as rlen.
            r_w_bias: bias added to the queries for the content term (AC).
            r_r_bias: bias added to the queries for the position term (BD).
            attn_mask: optional mask; True entries are filled with -inf. A 2-D
                mask is indexed as ``[None, :, :, None]`` and a 3-D mask as
                ``[:, :, :, None]``. It is skipped when it has no True entry.
            mems: optional tensor concatenated in front of `w` along dim 0.
            return_avg_attn: also return the pre-dropout attention
                probabilities averaged over the head axis.

        Returns:
            The output tensor, or ``(output, avg_attn_prob)`` when
            `return_avg_attn` is truthy.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys/values see the memory as well; queries only cover w itself.
        context = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            context = self.layer_norm(context)

        w_heads = self.qkv_net(context)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head

        #### attention score = content term (AC) + shifted position term (BD)
        content_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head
        position_q = w_head_q + r_r_bias
        AC = torch.einsum('ibnd,jbnd->ijbn', content_q, w_head_k)
        BD = self._rel_shift(torch.einsum('ibnd,jnd->ijbn', position_q, r_head_k))

        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### mask (out of place, computed in float32 and cast back)
        if attn_mask is not None and attn_mask.any().item():
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
            elif mask_rank == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)

        #### probabilities: softmax over dim 1, dropout, renormalise
        attn_prob = F.softmax(attn_score, dim=1)
        if return_avg_attn:
            avg_attn_prob = attn_prob.mean(dim=-1)
        attn_prob = self.dropatt(attn_prob)
        attn_prob = attn_prob / (attn_prob.sum(dim=1, keepdim=True) + 1e-8)

        #### weighted sum of values, heads merged back into one axis
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))

        if self.pre_lnorm:
            output = w + attn_out
        else:
            output = self.layer_norm(w + attn_out)

        if return_avg_attn:
            return output, avg_attn_prob
        return output


class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative attention that receives its positional terms as arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        """Compute relative-position self-attention over `w` (and `mems`).

        Args:
            w: input, viewed as [qlen x bsz x d_model].
            r_emb: [klen, n_head, d_head], used for term B.
            r_w_bias: [n_head, d_head], used for term C.
            r_bias: [klen, n_head], used for term D.
            attn_mask: optional mask; True entries are filled with -inf. A 2-D
                mask is indexed as ``[None, :, :, None]`` and a 3-D mask as
                ``[:, :, :, None]``. It is skipped when it has no True entry.
            mems: optional tensor concatenated in front of `w` along dim 0.

        Returns:
            Tensor with the leading two dims of `w` and width d_model.
        """
        qlen, bsz = w.size(0), w.size(1)

        # Keys/values see the memory as well; queries only cover w itself.
        context = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            context = self.layer_norm(context)

        w_head_q, w_head_k, w_head_v = torch.chunk(self.qkv_net(context), 3, dim=-1)
        if mems is not None:
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        # Match the positional terms to klen: repeat their first row in front
        # when they are too short, keep their last klen rows otherwise.
        if klen > r_emb.size(0):
            emb_fill = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
            r_emb = torch.cat([emb_fill, r_emb], 0)
            bias_fill = r_bias[0:1].expand(klen - r_bias.size(0), -1)
            r_bias = torch.cat([bias_fill, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### attention score = content term (AC) + shifted position term (BD)
        content_q = w_head_q + r_w_bias[None]                                   # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', content_q, w_head_k)               # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', w_head_q, r_emb)                    # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                              # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### mask (in place)
        if attn_mask is not None and attn_mask.any().item():
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))

        #### probabilities: softmax over dim 1, dropout, renormalise
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))
        attn_prob = attn_prob / (attn_prob.sum(dim=1, keepdim=True) + 1e-8)

        #### weighted sum of values, heads merged back into one axis
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))

        if self.pre_lnorm:
            return w + attn_out
        return self.layer_norm(w + attn_out)

class DecoderLayer(nn.Module):
    """Decoder layer: `MultiHeadAttn` followed by `PositionwiseFF`.

    Extra keyword arguments are forwarded to `MultiHeadAttn`; the value of
    ``kwargs.get('pre_lnorm')`` (None when absent) is also passed to the
    feed-forward block.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super().__init__()

        ff_pre_lnorm = kwargs.get('pre_lnorm')
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=ff_pre_lnorm)

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
        """Run attention over `dec_inp`, then the feed-forward block."""
        attended = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems)
        return self.pos_ff(attended)

class RelLearnableDecoderLayer(nn.Module):
    """Decoder layer: `RelLearnableMultiHeadAttn` followed by `PositionwiseFF`.

    Extra keyword arguments are forwarded to the attention module; the value
    of ``kwargs.get('pre_lnorm')`` (None when absent) is also passed to the
    feed-forward block.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super().__init__()

        ff_pre_lnorm = kwargs.get('pre_lnorm')
        self.dec_attn = RelLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=ff_pre_lnorm)

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
        """Run relative attention over `dec_inp`, then the feed-forward block."""
        attended = self.dec_attn(
            dec_inp, r_emb, r_w_bias, r_bias,
            attn_mask=dec_attn_mask, mems=mems)
        return self.pos_ff(attended)

class RelPartialLearnableDecoderLayer(nn.Module):
    """Decoder layer: relative self-attention, optional cross-attention, feed-forward.

    Extra keyword arguments are forwarded to the attention modules. A
    `MultiHeadCrossAttn` is created only when ``use_cross_attn`` is passed and
    is exactly ``True``; otherwise ``cross_attn`` is None.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super().__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

        wants_cross_attn = kwargs.get('use_cross_attn') is True
        self.cross_attn = (
            MultiHeadCrossAttn(n_head, d_model, d_head, dropout, **kwargs)
            if wants_cross_attn else None
        )

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None,
                cross_latent=None, dec_cross_pos_emb=None, latent_cross_pos_emb=None,
                cross_attn_mask=None, return_avg_attn=False):
        """Run the layer on `dec_inp`.

        Cross-attention is applied only when the layer has a cross-attention
        module and `cross_latent` is given; missing positional embeddings are
        replaced by zeros shaped like `dec_inp` / `cross_latent`.

        Returns:
            The output tensor, or ``(output, avg_attn_prob)`` when
            `return_avg_attn` is truthy.
        """
        attn_result = self.dec_attn(
            dec_inp, r, r_w_bias, r_r_bias,
            attn_mask=dec_attn_mask, mems=mems,
            return_avg_attn=bool(return_avg_attn))
        if return_avg_attn:
            output, avg_attn_prob = attn_result
        else:
            output = attn_result

        if self.cross_attn is not None and cross_latent is not None:
            if dec_cross_pos_emb is None:
                dec_cross_pos_emb = torch.zeros_like(dec_inp)
            if latent_cross_pos_emb is None:
                latent_cross_pos_emb = torch.zeros_like(cross_latent)

            output = self.cross_attn.forward(
                output, cross_latent,
                attn_mask=cross_attn_mask,
                h_pos_embed=dec_cross_pos_emb,
                c_pos_embed=latent_cross_pos_emb,
            )

        output = self.pos_ff(output)

        if return_avg_attn:
            return output, avg_attn_prob
        return output

class SegmentEmbeddingProj(nn.Module):
    """Bias-free linear projection(s) for segment embeddings, with a fixed scale.

    Args:
        d_in: input width.
        d_out: output width.
        n_layer: number of projections to create when they are not tied.
        tie_seg_emb_projs: when truthy, a single projection is created and
            used for every layer.
        scale: factor the projected output is multiplied by.
    """

    def __init__(self, d_in, d_out, n_layer=None, tie_seg_emb_projs=True, scale=1.):
        super().__init__()
        self.d_in, self.d_out = d_in, d_out
        self.tie_seg_emb_projs = tie_seg_emb_projs

        n_proj = 1 if tie_seg_emb_projs else n_layer
        self.emb_proj = nn.ModuleList(
            [nn.Linear(d_in, d_out, bias=False) for _ in range(n_proj)]
        )

        self.scale = scale
        print ('[seg emb scale]', scale)

    def forward(self, inp, layer=None):
        """Project `inp` and scale the result in place.

        Projection 0 is used when `layer` is None or the projections are tied;
        otherwise the projection at index `layer` is used.
        """
        use_shared = layer is None or self.tie_seg_emb_projs
        projection = self.emb_proj[0] if use_shared else self.emb_proj[layer]

        projected = projection(inp)
        return projected.mul_(self.scale)


class OptimusTXLDecoder(nn.Module):
    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, d_segment_emb,
                 dropout, dropatt, pre_lnorm=False, use_segment_emb=True,
                 tgt_len=None, ext_len=None, mem_len=None,
                 same_length=False, attn_type=0, clamp_len=-1,
                 tie_seg_emb_projs=True, in_attn_cond=True,
                 use_cross_attn=False, cross_len=192, seg_proj_scale=1.
        ):
        super(OptimusTXLDecoder, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        self.drop = nn.Dropout(dropout)
        self.n_layer = n_layer
        self.d_segment_emb = d_segment_emb

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.pre_lnorm = pre_lnorm
        self.use_segment_emb = use_segment_emb

        self.tie_seg_emb_projs = tie_seg_emb_projs
        self.in_attn_cond = in_attn_cond

        if self.use_segment_emb:
            self.seg_proj_scale = seg_proj_scale
            self.seg_emb_projs = SegmentEmbeddingProj(
                                    d_segment_emb, d_model, n_layer, tie_seg_emb_projs,
                                    scale=self.seg_proj_scale
                                )
        else:
            self.seg_emb_projs = None

        self.use_cross_attn = use_cross_attn
        if self.use_cross_attn:
            self.cross_len = cross_len
            self.cross_pos_emb = WordEmbedding(
                                    cross_len, d_model, d_model, emb_scale=0.2
                                )

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0: # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        use_cross_attn=use_cross_attn)
                )
        elif attn_type == 1: # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type in [2, 3]: # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        """Set `sample_softmax` to -1 on this instance.

        NOTE(review): nothing else in this class reads `sample_softmax`;
        presumably kept for compatibility with older checkpoints/callers -- confirm.
        """
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1: # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                    self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2: # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self, batchsize=None):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                if batchsize is None:
                    empty = torch.empty(0, dtype=param.dtype, device=param.device)
                else:
                    empty = torch.empty(0, batchsize, self.d_model, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen, dec_seg_len=None):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []

            if dec_seg_len is None:
                end_idx = mlen + max(0, qlen - 0 - self.ext_len)
                beg_idx = max(0, end_idx - self.mem_len)
                for i in range(len(hids)):
                    cat = torch.cat([mems[i], hids[i]], dim=0)
                    new_mems.append(cat[beg_idx:end_idx].detach())

            else:      # different len for each sample in batch, `ext_len != 0` is not supported
                assert dec_seg_len.size(0) == hids[0].size(1)
                batchsize = hids[0].size(1)

                for i in range(len(hids)):
                    new_layer_mem = []
                    for samp_idx in range(batchsize):
                        samp_len = dec_seg_len[samp_idx]
                        old_samp_mem = mems[i][:, samp_idx, :]
                        new_samp_mem = hids[i][:samp_len, samp_idx, :]
                        cat = torch.cat([old_samp_mem, new_samp_mem], dim=0)
                        end_idx, beg_idx = cat.size(0), max(0, cat.size(0) - self.mem_len)
                        new_layer_mem.append(cat[beg_idx:end_idx].detach())

                    max_new_mlen = max([cat.size(0) for cat in new_layer_mem])
                    for samp_idx in range(batchsize):
                        samp_new_mlen = new_layer_mem[ samp_idx ].size(0)
                        if samp_new_mlen < max_new_mlen:
                            new_layer_mem[samp_idx] = torch.cat([
                                torch.zeros(max_new_mlen - samp_new_mlen, mems[i].size(-1), dtype=mems[i].dtype, device=mems[i].device).detach(),
                                new_layer_mem[samp_idx]
                            ], dim=0)
                    new_mems.append(torch.stack(new_layer_mem, dim=1).detach())

        return new_mems

    def _forward(self, dec_input, segment_emb, mems=None, dec_seg_len=None,
                 cross_latent=None, cross_attn_mask=None,
                 dec_cross_pos_emb=None, latent_cross_pos_emb=None, return_avg_attn=False):
        qlen, bsz, _ = dec_input.size()
        # print ('[debug] reached inner _forward()')

        if isinstance(mems, tuple) and len(mems) == 1:
            mems = mems[0]
            assert len(mems) == self.n_layer + 1

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = dec_input.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
        else:
            dec_attn_mask = torch.triu(
                dec_input.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]

        hids = []
        if return_avg_attn:
            all_layer_avg_attn_probs = []
        if self.use_segment_emb:
            layer_seg_emb = self.seg_emb_projs(segment_emb, layer=0)
        else:
            layer_seg_emb = torch.zeros_like(dec_input, device=dec_input.device)

        if self.use_cross_attn and cross_latent is not None:
            layer_cross_latent = self.drop(
                                    self.seg_emb_projs(cross_latent, layer=0)
                                 )
        else:
            layer_cross_latent = None
        # print ('[cross pos embs]', dec_cross_pos_emb.mean(), latent_cross_pos_emb.mean())

        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=dec_input.device,
                                   dtype=dec_input.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(dec_input)
            # print ('[layer 0] inp: {:.3f} (+/- {:.3f}) | segemb: {:.3f} (+/- {:.3f})'.format(
            #     core_out.mean().item(), core_out.std().item(),
            #     layer_seg_emb[ layer_seg_emb != 0. ].mean().item(), layer_seg_emb[ layer_seg_emb != 0. ].std().item()
            # ))
            core_out += self.drop(layer_seg_emb)
            pos_emb = self.drop(pos_emb)
            hids.append(core_out)

            for i, layer in enumerate(self.layers):
                # print ('[cross latent]', layer_cross_latent.mean())
                mems_i = None if mems is None else mems[i]
                if not return_avg_attn:
                    core_out = layer(
                                    core_out, pos_emb, self.r_w_bias,
                                    self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i,
                                    cross_latent=layer_cross_latent,
                                    dec_cross_pos_emb=dec_cross_pos_emb,
                                    latent_cross_pos_emb=latent_cross_pos_emb,
                                    cross_attn_mask=cross_attn_mask,
                                    return_avg_attn=False
                                )
                else:
                    core_out, layer_avg_attn_prob = layer(
                                    core_out, pos_emb, self.r_w_bias,
                                    self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i,
                                    cross_latent=layer_cross_latent,
                                    dec_cross_pos_emb=dec_cross_pos_emb,
                                    latent_cross_pos_emb=latent_cross_pos_emb,
                                    cross_attn_mask=cross_attn_mask,
                                    return_avg_attn=True
                                )
                    all_layer_avg_attn_probs.append(layer_avg_attn_prob)
                    # print ('[avg attn probs]', all_layer_avg_attn_probs[-1].size())

                if i != len(self.layers) - 1 and self.in_attn_cond and self.use_segment_emb:
                    layer_seg_emb = self.seg_emb_projs(segment_emb, layer=i+1)
                    core_out += self.drop(layer_seg_emb)
                    if self.use_cross_attn:
                        layer_cross_latent = self.drop(
                                        self.seg_emb_projs(cross_latent, layer=i+1)
                                    )

                hids.append(core_out)
                # print ('[layer {}] inp: {:.3f} (+/- {:.3f})'.format(
                #     i+1, core_out.mean().item(), core_out.std().item()
                # ))
                # print ('[layer {}] inp: {:.3f} (+/- {:.3f}) | segemb: {:.3f} (+/- {:.3f})'.format(
                #     i + 1,
                #     core_out.mean().item(), core_out.std().item(),
                #     layer_seg_emb[ layer_seg_emb != 0. ].mean().item(), layer_seg_emb[ layer_seg_emb != 0. ].std().item()
                # ))

        elif self.attn_type == 1: # learnable
            core_out = self.drop(dec_input)
            core_out += self.drop(layer_seg_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len :]
                    r_bias = self.r_bias[i][-self.clamp_len :]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                if i != len(self.layers) - 1 and self.in_attn_cond and self.use_segment_emb:
                    layer_seg_emb = self.seg_emb_projs(segment_emb, layer=i+1)
                    core_out += self.drop(layer_seg_emb)
                hids.append(core_out)

        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=dec_input.device,
                                   dtype=dec_input.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(dec_input + pos_emb[-qlen:])
            core_out += self.drop(layer_seg_emb)
            hids.append(core_out)

            for i, layer in enumerate(self.layers):
                layer_seg_emb = self.seg_emb_projs(segment_emb, layer=i)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                if i != len(self.layers) - 1 and self.in_attn_cond and self.use_segment_emb:
                    print ('shouldn\'t be here !!!')
                    layer_seg_emb = self.seg_emb_projs(segment_emb, layer=i+1)
                    core_out += self.drop(layer_seg_emb)
                hids.append(core_out)

        elif self.attn_type == 3:
            core_out = self.drop(dec_input)
            core_out += self.drop(layer_seg_emb)
            hids.append(core_out)

            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                if i != len(self.layers) - 1 and self.in_attn_cond and self.use_segment_emb:
                    layer_seg_emb = self.seg_emb_projs(segment_emb, layer=i+1)
                    core_out += self.drop(layer_seg_emb)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen, dec_seg_len=dec_seg_len)

        if not return_avg_attn:
            return core_out, new_mems
        else:
            return core_out, new_mems, all_layer_avg_attn_probs

    def forward(self, dec_input, segment_emb, *mems, dec_seg_len=None, cross_latent=None,
                cross_attn_mask=None, dec_cross_pos_seq=None, latent_cross_pos_seq=None, return_avg_attn=False):
        if not mems: mems = self.init_mems(batchsize=dec_input.size(1) if dec_seg_len is not None else None)

        if self.use_cross_attn is True and dec_cross_pos_seq is not None and latent_cross_pos_seq is not None:
            dec_cross_pos_emb = self.cross_pos_emb(dec_cross_pos_seq)
            latent_cross_pos_emb = self.cross_pos_emb(latent_cross_pos_seq)
            # print ('[cross pos embs]', dec_cross_pos_emb.size(), latent_cross_pos_emb.size())
        else:
            dec_cross_pos_emb = latent_cross_pos_emb = None

        if not return_avg_attn:
            dec_out, new_mems = self._forward(
                                    dec_input, segment_emb, mems=mems,
                                    dec_seg_len=dec_seg_len,
                                    cross_latent=cross_latent,
                                    cross_attn_mask=cross_attn_mask,
                                    dec_cross_pos_emb=dec_cross_pos_emb,
                                    latent_cross_pos_emb=latent_cross_pos_emb,
                                    return_avg_attn=False
                                )
        else:
            dec_out, new_mems, avg_attn_probs = self._forward(
                                                    dec_input, segment_emb, mems=mems,
                                                    dec_seg_len=dec_seg_len,
                                                    cross_latent=cross_latent,
                                                    cross_attn_mask=cross_attn_mask,
                                                    dec_cross_pos_emb=dec_cross_pos_emb,
                                                    latent_cross_pos_emb=latent_cross_pos_emb,
                                                    return_avg_attn=True
                                                )

        if new_mems is None and not return_avg_attn:
            return [dec_out]
        elif new_mems is not None and not return_avg_attn:
            return [dec_out] + new_mems
        else:
            return [dec_out] + new_mems, avg_attn_probs

if __name__ == '__main__':
    # Smoke test: build a 12-layer decoder and push ten random batches through
    # it, feeding each call's returned memories into the next call.
    # NOTE: in a notebook `__name__` is '__main__', so this runs whenever the
    # cell is executed.
    device = 'cpu'

    tgt_len, mem_len, ext_len = 128, 600, 0
    bsz, d_model, d_segment_emb = 4, 512, 64

    model = OptimusTXLDecoder(n_layer=12, n_head=8, d_segment_emb=d_segment_emb,
                    d_model=d_model, d_head=64, d_inner=2048,
                    dropout=0.1, dropatt=0.1, pre_lnorm=True, tie_seg_emb_projs=False,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len).to(device)

    print(sum(p.numel() for p in model.parameters()))

    mems = tuple()
    for idx in range(10):
        # Inputs are created on the model's device and sized from the settings
        # above, so changing `device` / `tgt_len` keeps the test consistent.
        inp = torch.randn(tgt_len, bsz, d_model, device=device)
        segment_emb = torch.randn(tgt_len, bsz, d_segment_emb, device=device)
        print('batch {}'.format(idx))
        out = model(inp, segment_emb, *mems)
        mems = out[1:]
        print ('[dec out]', out[0].size())
        print ('[mem layer 0]', mems[0].size())

## Transformer

In [None]:
# Define transformer architecture
import sys

import torch
from torch import nn
import torch.nn.functional as F

# from .optimus_txl_decoder import OptimusTXLDecoder

# from .transformer_helpers import (
#   WordEmbedding,
#   weights_init
# )

class PlainTransformer(nn.Module):
  def __init__(self, d_word_embed, vocab_size,
               dec_n_layer, dec_n_head, dec_d_model, dec_d_ff, dec_mem_len, dec_tgt_len,
               dec_dropout=0.1, dec_activation='relu',
               pad_index=None, pre_lnorm=False,
  ):
    """Build the token embedding, the decoder stack and the output projection.

    The decoder is an `OptimusTXLDecoder` created without segment embeddings,
    with `ext_len=0`, per-head size `dec_d_model // dec_n_head`, and
    `dec_dropout` used for both of its dropout rates. `pad_index` defaults to
    the last vocabulary index. All sub-modules are finally passed through
    `weights_init`. `dec_activation` is only stored.
    """
    super().__init__()

    # Plain hyper-parameter bookkeeping.
    self.d_word_embed = d_word_embed
    self.vocab_size = vocab_size
    self.dec_n_layer, self.dec_n_head = dec_n_layer, dec_n_head
    self.dec_d_model, self.dec_d_ff = dec_d_model, dec_d_ff
    self.dec_dropout = dec_dropout
    self.dec_activation = dec_activation
    self.dec_mem_len, self.dec_tgt_len = dec_mem_len, dec_tgt_len
    self.pad_index = self.vocab_size - 1 if pad_index is None else pad_index

    # Sub-modules (registration order: word_emb, emb_dropout, decoder, dec_out_proj).
    self.word_emb = WordEmbedding(vocab_size, d_word_embed, dec_d_model)
    self.emb_dropout = nn.Dropout(dec_dropout)

    head_dim = dec_d_model // dec_n_head
    self.decoder = OptimusTXLDecoder(
      dec_n_layer, dec_n_head, dec_d_model, head_dim, dec_d_ff,
      None, dec_dropout, dec_dropout,
      tgt_len=dec_tgt_len, mem_len=dec_mem_len, ext_len=0,
      pre_lnorm=pre_lnorm, use_segment_emb=False
    )
    self.dec_out_proj = nn.Linear(dec_d_model, vocab_size)

    self.apply(weights_init)

  def generate(self, dec_input, dec_mems):
    dec_word_emb = self.word_emb(dec_input)
    dec_input = self.emb_dropout(dec_word_emb)
    dec_out = self.decoder(dec_input, None, *dec_mems)
    dec_logits = self.dec_out_proj(dec_out[0])[-1, 0, :]
    new_dec_mems = dec_out[1:]

    return dec_logits, new_dec_mems

  def forward(self, dec_input, dec_mems, dec_seg_len=None, return_avg_attn=False):
    dec_word_emb = self.word_emb(dec_input)
    dec_input = self.emb_dropout(dec_word_emb)
    # print ('[debug] in model forward()')

    if not return_avg_attn:
      dec_out = self.decoder(dec_input, None, *dec_mems, dec_seg_len=dec_seg_len)
      dec_logits = self.dec_out_proj(dec_out[0])
      new_dec_mems = dec_out[1:]
      return dec_logits, new_dec_mems

    else:
      dec_out, avg_attn_probs = self.decoder(
                                  dec_input, None, *dec_mems,
                                  dec_seg_len=dec_seg_len,
                                  return_avg_attn=True
                                )
      dec_logits = self.dec_out_proj(dec_out[0])
      new_dec_mems = dec_out[1:]

      return dec_logits, new_dec_mems, avg_attn_probs

  def compute_loss(self, dec_logits, dec_tgt, reduction='mean'):
    ce_loss = F.cross_entropy(
                    dec_logits.view(-1, dec_logits.size(-1)),
                    dec_tgt.contiguous().view(-1),
                    ignore_index=self.pad_index,
                    reduction=reduction
                  )

    return {
      'ce_loss': ce_loss,
      'total_loss': ce_loss
    }

# Generate Melody

In [None]:
import sys
import os
import random
import pickle
# sys.path.append('./model/')
sys.path.append('./')

import yaml
import torch
import numpy as np

# from plain_transformer import PlainTransformer
from convert2midi import skyline_event_to_midi, TempoEvent
from utils import pickle_load
from inference_utils import generate_plain_xl

# Script-style arguments: <config_path> <out_dir> [n_pieces].
# Inside a Jupyter/Colab kernel `sys.argv[1:]` holds the kernel's own flags
# (e.g. `-f <connection file>`), not these arguments, so fall back to the
# values used by the "Compose" cell of this notebook in that case.
if len(sys.argv) > 2 and not sys.argv[1].startswith('-'):
  config_path = sys.argv[1]
  out_dir = sys.argv[2]
  n_pieces = int(sys.argv[3]) if len(sys.argv) > 3 else 20
else:
  config_path = 'stage01_compose/config/pop1k7_finetune.yaml'
  out_dir = 'generation/stage01'
  n_pieces = 1

# Close the config file once it has been parsed.
with open(config_path, 'r') as config_file:
  config = yaml.load(config_file, Loader=yaml.FullLoader)
ckpt_dir = config['output']['ckpt_dir']

# Default temperature so this cell runs on a fresh kernel; the generation cell
# replaces it with the value entered through `user_input()`.
temp = 1.2
top_p = 0.97
max_dec_len = 2400
print ('[nucleus parameters] t = {}, p = {}'.format(temp, top_p))

# `torch.cuda.device(...)` is a context manager and does nothing unless used
# in a `with` block; select the configured GPU explicitly instead.
_cfg_device = torch.device(config['device'])
if _cfg_device.type == 'cuda' and _cfg_device.index is not None and torch.cuda.is_available():
  torch.cuda.set_device(_cfg_device)

# for generation w/ melody prompts
use_prompt = False
prompt_bars = 8


def read_vocab(vocab_file):
  """Load the (event2idx, idx2event) pair from `vocab_file` and append a padding token.

  'PAD_None' is mapped to the index one past the loaded vocabulary in
  `event2idx` (`idx2event` is returned as loaded). The returned `vocab_size`
  counts the padding token.
  """
  event2idx, idx2event = pickle_load(vocab_file)

  pad_token = len(event2idx)
  event2idx['PAD_None'] = pad_token

  return event2idx, idx2event, pad_token + 1


def dump_midi(words, idx2event, output_midi_path=None,
              rfreq_cls=None, polyph_cls=None, output_event_path=None,
              return_tempo=False, enforce_tempo_val=None):
  """Convert token indices to events, optionally dump them as text, and render MIDI.

  Args:
    words: iterable of token indices.
    idx2event: mapping from token index to event string.
    output_midi_path: forwarded to `skyline_event_to_midi`.
    rfreq_cls, polyph_cls: optional values written as header lines of the event dump.
    output_event_path: when given, the events are written there, one per line.
    return_tempo: when true, return element [1] of what `skyline_event_to_midi`
      returns (called with `return_tempo=True`); takes precedence over
      `enforce_tempo_val`.
    enforce_tempo_val: when given (and `return_tempo` is false), passed on
      together with `enforce_tempo=True`.

  Returns:
    The tempo value described above when `return_tempo` is true, otherwise None.
  """
  events = [idx2event[w] for w in words]

  if output_event_path is not None:
    # `with` guarantees the dump is flushed and the handle closed; the file was
    # previously opened and never closed.
    with open(output_event_path, 'w') as f:
      if rfreq_cls is not None:
        f.write('[rhymfreq] ')
        f.write(str(rfreq_cls))
        f.write('\n')
      if polyph_cls is not None:
        f.write('[polyph  ] ')
        f.write(str(polyph_cls))
        f.write('\n')
        f.write('======================================================================\n')
      print (*events, sep='\n', file=f)

  if return_tempo:
    return skyline_event_to_midi(events, output_midi_path=output_midi_path, return_tempo=True)[1]
  elif enforce_tempo_val is not None:
    skyline_event_to_midi(events, output_midi_path=output_midi_path, enforce_tempo=True, enforce_tempo_val=enforce_tempo_val)
  else:
    skyline_event_to_midi(events, output_midi_path=output_midi_path)


def get_leadsheet_prompt(data_dir, piece, prompt_n_bars):
  """Load `<data_dir>/<piece>.pkl` and cut out a prompt of `prompt_n_bars` bars.

  The prompt holds every event up to and including the event at
  `bar_pos[prompt_n_bars]`, formatted as '<name>_<value>'. It must therefore
  contain exactly `prompt_n_bars + 1` 'Bar_None' events (asserted).

  Returns:
    (prompt_evs, target_bars), where `target_bars` is `len(bar_pos)`.
  """
  piece_path = os.path.join(data_dir, piece + '.pkl')
  bar_pos, evs = pickle_load(piece_path)

  cut = bar_pos[prompt_n_bars] + 1
  prompt_evs = ['{}_{}'.format(ev['name'], ev['value']) for ev in evs[:cut]]

  n_bar_events = prompt_evs.count('Bar_None')
  assert n_bar_events == prompt_n_bars + 1

  return prompt_evs, len(bar_pos)

def user_input():
    """Ask for the maximum number of bars (int) and the sampling temperature (float).

    Returns:
        (max_bars, temp) as a tuple. Invalid numbers raise ValueError from
        `int` / `float`.
    """
    bars_answer = input('Please input max bars: ')
    max_bars = int(bars_answer)

    temp_answer = input('Please input temperature (randomness): ')
    temp = float(temp_answer)

    return max_bars, temp

In [None]:
# Generation

# The whole cell lives at top level with consistent indentation (the previous
# mix of 0/4/2-space indents was an IndentationError).
os.makedirs(out_dir, exist_ok=True)

max_bars, temp = user_input()
event2idx, idx2event, vocab_size = \
  read_vocab(config['data']['vocab_path'])

prompts = []
if use_prompt:
  prompt_pieces = pickle_load(config['data']['val_split'])
  prompt_pieces = [x for x in prompt_pieces if os.path.exists(
    os.path.join(config['data']['data_dir'], x + '.pkl')
  )]
  if len(prompt_pieces) > n_pieces:
    prompt_pieces = random.sample(prompt_pieces, n_pieces)

  # Record which pieces were sampled; `with` closes the file handle.
  with open(os.path.join(out_dir, 'sampled_pieces.pkl'), 'wb') as sampled_file:
    pickle.dump(prompt_pieces, sampled_file)

  prompts = [
    get_leadsheet_prompt(config['data']['data_dir'], piece, prompt_bars)
    for piece in prompt_pieces
  ]


mconf = config['model']
model = PlainTransformer(
          mconf['d_word_embed'],
          vocab_size,
          mconf['decoder']['n_layer'],
          mconf['decoder']['n_head'],
          mconf['decoder']['d_model'],
          mconf['decoder']['d_ff'],
          mconf['decoder']['tgt_len'],
          mconf['decoder']['tgt_len'],
          dec_dropout=mconf['decoder']['dropout'],
          pre_lnorm=mconf['pre_lnorm']
        ).cuda()
print ('[info] # params:', sum(p.numel() for p in model.parameters() if p.requires_grad))

pretrained_dict = torch.load(config['inference_param_path'], map_location='cpu')
model.load_state_dict( pretrained_dict )
model.eval()

generated_pieces = 0
# With prompts, there cannot be more outputs than available prompts.
total_pieces = min(n_pieces, len(prompts)) if use_prompt else n_pieces
gen_times = []

while generated_pieces < total_pieces:
  piece_id = generated_pieces + 1

  out_name = 'samp_{:02d}'.format(piece_id)
  if os.path.exists(os.path.join(out_dir, out_name + '.mid')):
    print ('[info] {} exists, skipping ...'.format(out_name))
    # Advance to the next piece; without this the loop never terminated.
    generated_pieces += 1
    continue

  if not use_prompt:
    tempo_range = range(65, 165, 3)
    tempo = random.choice(
      tempo_range
    )
    orig_tempos = [
      TempoEvent(tempo, 0, 0)
    ]
    print ('[global tempo]', orig_tempos[0].tempo)
  else:
    # Index prompts by position (the old code indexed the list with the
    # leftover piece-name string from the loading loop).
    prompt_evs, target_bars = prompts[generated_pieces]
    orig_tempos = [
      TempoEvent(int(prompt_evs[0].split('_')[-1]), 0, 0)
    ]

  print (' -- generating leadsheet #{} of {}'.format(
    generated_pieces + 1, total_pieces
  ))


  if not use_prompt:
    gen_words, t_sec = generate_plain_xl(
                          model,
                          event2idx, idx2event,
                          max_events=max_dec_len, max_bars=max_bars,
                          primer=['Tempo_{}'.format(orig_tempos[0].tempo), 'Bar_None'],
                          temp=temp, top_p=top_p
                        )
  else:
    gen_words, t_sec = generate_plain_xl(
                          model,
                          event2idx, idx2event,
                          max_events=max_dec_len, max_bars=target_bars,
                          primer=prompt_evs,
                          temp=temp, top_p=top_p,
                          prompt_bars=prompt_bars
                        )

  # Rejected samples are retried for the same piece index.
  if gen_words is None: # model failed repeatedly
    continue
  if len(gen_words) >= max_dec_len:
    continue
  if len( np.where( np.array(gen_words) == event2idx[ 'Bar_None' ] )[0] ) >= max_bars:
    continue

  dump_midi(
    gen_words, idx2event,
    os.path.join(out_dir, out_name + '.mid'),
    output_event_path=os.path.join(out_dir, out_name + '.txt'),
    enforce_tempo_val=orig_tempos
  )

  gen_times.append(t_sec)
  generated_pieces += 1

if gen_times:
  print ('[info] finished generating {} pieces, avg. time: {:.2f} +/- {:.2f} secs.'.format(
    len(gen_times), np.mean(gen_times), np.std(gen_times)
  ))
else:
  # Every requested output already existed, so there are no timings to report.
  print ('[info] no new pieces generated ({} outputs already existed).'.format(generated_pieces))

In [None]:
# Compose

# Generating a leadsheet
!python3 stage01_compose/inference.py \
  stage01_compose/config/pop1k7_finetune.yaml \
  generation/stage01 \
  1   # Generate one leadsheet

[nucleus parameters] t = 1.2, p = 0.97
[info] # params: 41331059
[global tempo] 143
 -- generating leadsheet #1 of 1
[info] generated 1 bars, #events = 16
[info] generated 2 bars, #events = 37
[info] generated 3 bars, #events = 59
[info] generated 4 bars, #events = 66
[info] generated 5 bars, #events = 72
[info] generated 6 bars, #events = 91
[info] generated 7 bars, #events = 108
[info] generated 8 bars, #events = 127
[info] generated 9 bars, #events = 145
[info] generated 10 bars, #events = 166
[info] generated 11 bars, #events = 184
[info] generated 12 bars, #events = 195
[info] generated 13 bars, #events = 205
[info] generated 14 bars, #events = 224
[info] generated 15 bars, #events = 243
[info] generated 16 bars, #events = 262
[info] generated 17 bars, #events = 279
[info] generated 18 bars, #events = 303
[info] generated 19 bars, #events = 322
[info] generated 20 bars, #events = 345
[info] generated 21 bars, #events = 371
[info] generated 22 bars, #events = 391
[info] generated 2

In [None]:
# Embellish

# We will embellish our generated leadsheet
!python3 stage02_embellish/inference.py \
  stage02_embellish/config/pop1k7_default.yaml \
  generation/stage01 \
  generation/stage02

[preparing data] now at #0
[info] model init completed
[info] model loaded
[# pieces] 1
The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2426.)
  Q, _ = torch.qr(block)
[info] generated 1 bars, #events = 86
[info] generated 2 bars, #events = 220
[info] generated 3 bars, #events = 366
[info] generated 4 bars, #events = 440
[info] generated 5 bars, #events = 522
[info] generated 6 bars, #events = 665
[info] generated 7 bars, #events = 788
[info] generated 8 bars, #events = 910
[info] generated 9 bars, #events = 1016
[info] generated 10 bars, #events = 1138
[info] generated 11 bars, #events = 1259
[info] generated 12 bars, #events = 1346
[info] generated 13 bars, #events = 1413
[info] generated 14 bars, #events = 1523
[info] generated 15 bars, #events = 1630
[info] generated 16 ba

In [None]:
# installs and imports to convert MIDI into audio
!pip install pretty_midi
!wget https://www.dropbox.com/s/4x27l49kxcwamp5/GeneralUser_GS_1.471.zip
!unzip GeneralUser_GS_1.471.zip
!apt install -y fluidsynth
from pretty_midi import PrettyMIDI
from IPython.display import Audio
from scipy.io.wavfile import write
import librosa

--2024-01-21 21:39:07--  https://www.dropbox.com/s/4x27l49kxcwamp5/GeneralUser_G5_1.471.zip
Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.18, 2620:100:6016:18::a27d:112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/4x27l49kxcwamp5/GeneralUser_G5_1.471.zip [following]
--2024-01-21 21:39:07--  https://www.dropbox.com/s/raw/4x27l49kxcwamp5/GeneralUser_G5_1.471.zip
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc404e5d06cdd4b619aa0029e6ee.dl.dropboxusercontent.com/cd/0/inline/CLx7uzsqAm39h1G4nKBBF8UcgFfAiYYP3N-9eT9FgE_o93ytGFcUWU5h5zZvC1YjBho6mPW9pV2VeIb1y_-4hVwWrlAt6pcRcCyrky1SfwbW8uERWxQfG_4QOMf2dgDuEts/file# [following]
--2024-01-21 21:39:08--  https://uc404e5d06cdd4b619aa0029e6ee.dl.dropboxusercontent.com/cd/0/inline/CLx7uzsqAm39h1G4nKBBF8UcgFfAiYYP3N-9eT9FgE_o93ytGFcUWU5h5zZvC1YjBho6mPW9pV2VeIb1y_-4hVw

In [None]:
##########
# LISTEN #
##########

# render the first stage
!fluidsynth -ni GeneralUser\ GS\ 1.471/GeneralUser\ GS\ v1.471.sf2 generation/stage01/samp_01.mid -F first_stage.wav -r 44100

# render the second stage
!fluidsynth -ni GeneralUser\ GS\ 1.471/GeneralUser\ GS\ v1.471.sf2 generation/stage02/samp_01_2stage_samp01.mid -F second_stage.wav -r 44100

# # uncomment if you want to hear the melody
# # generated in the first stage
# # hear the first stage
# x,sr=librosa.load('first_stage.wav')
# Audio(x,rate=sr)

# hear the second stage
x,sr=librosa.load('second_stage.wav')
Audio(x,rate=sr)