In [1]:
import tensorflow as tf
import numpy as np

from Models import MusicTransformer
tf.__version__

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


'2.0.0-rc0'

In [2]:
IntervalDim = 100

VelocityDim = 32
VelocityOffset = IntervalDim

NoteOnDim = NoteOffDim = 128
NoteOnOffset = IntervalDim + VelocityDim
NoteOffOffset = IntervalDim + VelocityDim + NoteOnDim

EventDim = IntervalDim + VelocityDim + NoteOnDim + NoteOffDim # 388

EmbeddingDim = 512
Heads = 16 # number of heads

Max_seq = 650 # max_length
HeadDim = EmbeddingDim / Heads # head Dim
ContextDim = HeadDim * Heads # 512

In [3]:
# ref : https://github.com/scpark20/music-transformer/blob/master/music-transformer.ipynb
class RelativeGlobalAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(RelativeGlobalAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.headDim = d_model // num_heads
        self.contextDim = int(self.headDim * self.num_heads)
        self.eventDim = 388

        assert d_model % self.num_heads == 0

        self.wq = tf.keras.layers.Dense(self.headDim)
        self.wk = tf.keras.layers.Dense(self.headDim)
        self.wv = tf.keras.layers.Dense(self.headDim)
    
    def call(self, v, k, q):
        # [Heads, Batch, Time, HeadDim]
        q = tf.stack([self.wq(q) for _ in range(self.num_heads)])
        k = tf.stack([self.wk(k) for _ in range(self.num_heads)])
        v = tf.stack([self.wv(v) for _ in range(self.num_heads)])
        print("inputs")
        print("[Heads, Batch, Time, HeadDim]", q.shape)

        self.batch_size = q.shape[1]
        self.max_len = q.shape[2]
        
        #skewing
        # Heads, Time, HeadDim
        E = self.add_weight('E', shape=[self.num_heads, self.max_len, self.headDim]) 
        # [Heads, Batch * Time, HeadDim]
        Q_ = tf.reshape(q, [self.num_heads, self.batch_size * self.max_len, self.headDim])
        # [Heads, Batch * Time, Time]
        S = tf.matmul(Q_, E, transpose_b=True)
        # [Heads, Batch, Time, Time]
        S = tf.reshape(S, [self.num_heads, self.batch_size, self.max_len, self.max_len])
        # [Heads, Batch, Time, Time+1]
        S = tf.pad(S, ((0, 0), (0, 0), (0, 0), (1, 0)))
        # [Heads, Batch, Time+1, Time]
        S = tf.reshape(S, [self.num_heads, self.batch_size, self.max_len + 1, self.max_len])   
        # [Heads, Batch, Time, Time]
        S = S[:, :, 1:]
        # [Heads, Batch, Time, Time]
        attention = (tf.matmul(q, k, transpose_b=True) + S) / np.sqrt(self.headDim)
        # mask tf 2.0 == tf.linalg.band_part
        mask = tf.linalg.band_part(tf.ones([self.max_len, self.max_len]), -1, 0)
        attention = attention * mask - tf.cast(1e10, attention.dtype) * (1-mask)
        score = tf.nn.softmax(attention, axis=3)
        print("Score : ", score.shape)

        # [Heads, Batch, Time, HeadDim]
        context = tf.matmul(score, v)
        print("[Heads, Batch, Time, HeadDim] : ", context.shape)
        # [Batch, Time, Heads, HeadDim]
        context = tf.transpose(context, [1, 2, 0, 3])
        print("[Batch, Time, Heads, HeadDim] : ", context.shape)        
        # [Batch, Time, ContextDim]
        context = tf.reshape(context, [self.batch_size, self.max_len, self.num_heads * self.headDim])
        print("[Batch, Time, ContextDim] : ", context.shape)
        # [Batch, Time, ContextDim]
        context = tf.keras.layers.Dense(EmbeddingDim, activation='relu')(context)
        print("[Batch, Time, ContextDim] : ", context.shape)     
        # [Batch, Time, EventDim]
        logits = tf.keras.layers.Dense(EventDim)(context)

        return logits

In [4]:
temp_mha = RelativeGlobalAttention(d_model=EmbeddingDim, num_heads=Heads)
y = tf.random.uniform((1, Max_seq, EmbeddingDim))  # (batch_size, encoder_sequence, d_model)
out = temp_mha(y, k=y, q=y)
out.shape

inputs
[Heads, Batch, Time, HeadDim] (16, 1, 650, 32)
Score :  (16, 1, 650, 650)
[Heads, Batch, Time, HeadDim] :  (16, 1, 650, 32)
[Batch, Time, Heads, HeadDim] :  (1, 650, 16, 32)
[Batch, Time, ContextDim] :  (1, 650, 512)
[Batch, Time, ContextDim] :  (1, 650, 512)


TensorShape([1, 650, 388])

In [5]:
sample_transformer = MusicTransformer(
    num_layers=2, d_model=512, num_heads=8, dff=2048, 
    input_vocab_size=8500, target_vocab_size=8000, 
    pe_input=10000, pe_target=6000)

temp_input = tf.random.uniform((8, 650)) # batch length
temp_target = tf.random.uniform((8, 650))

fn_out, _ = sample_transformer(temp_input, temp_target, training=False, 
                               enc_padding_mask=None, 
                               look_ahead_mask=None,
                               dec_padding_mask=None)

fn_out.shape  # (batch_size, tar_seq_len, target_vocab_size)

TensorShape([8, 650, 8000])