In [26]:
import numpy as np
import pathlib
import os

In [27]:
# Build "<parent of notebook dir>/data/js-fakes-16thSeparated.npz" in one variadic join.
# NOTE(review): `_dh` is presumably IPython's directory-history list, so this
# cell only works inside an IPython/Jupyter kernel — confirm before reuse in a script.
input_path = os.path.join(
    pathlib.Path(globals()['_dh'][0]).parent,
    "data",
    "js-fakes-16thSeparated.npz",
)

# allow_pickle=True lets NumPy unpickle object arrays, which can execute
# arbitrary code: only load archives from a trusted source.
jsf = np.load(input_path, encoding='latin1', allow_pickle=True)

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

import matplotlib.pyplot as plt

In [68]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block with two attention passes over a 4-D input.

    The input is laid out as ``[L, B, F, N]`` with ``F = num_heads_1 * emb_f``
    and ``N = num_heads_2 * emb_notes``. Both attention layers are
    ``nn.MultiheadAttention`` modules with the default ``batch_first=False``,
    so they attend along the leading ``L`` axis:

    1. ``mha_f`` uses the ``F`` axis as embedding; the ``N`` axis is folded
       into the attention batch (``[L, B * N, F]``).
    2. ``mha_l`` uses the ``N`` axis as embedding; the ``F`` axis is folded
       into the attention batch (``[L, B * F, N]``).

    Each attention pass is preceded by a layer norm and wrapped in a residual
    connection. A final layer norm and an MLP acting on the ``N`` axis follow.

    Every layer norm normalises over the ``(F, N)`` plane of one sample at one
    position. The batch axis is never part of ``normalized_shape``, so samples
    do not influence each other and any batch size is accepted. As a
    consequence, layer-norm parameter shapes no longer depend on ``bs``.

    Args:
        num_heads_1: Number of heads of the attention over the ``F`` axis.
        num_heads_2: Number of heads of the attention over the ``N`` axis.
        bs: Batch size. Accepted only for backward compatibility with existing
            callers; the block no longer depends on it.
        emb_notes: Per-head width of the ``N`` axis.
        emb_f: Per-head width of the ``F`` axis.
    """

    def __init__(self, num_heads_1, num_heads_2, bs, emb_notes, emb_f):
        super().__init__()
        self.dim_f = num_heads_1 * emb_f
        self.dim_notes = num_heads_2 * emb_notes

        # ln1 is applied to the transposed view [L, B, N, F], hence the (N, F) order.
        self.ln1 = nn.LayerNorm([self.dim_notes, self.dim_f])
        self.mha_f = nn.MultiheadAttention(self.dim_f, num_heads_1, dropout=0.25)
        self.ln2 = nn.LayerNorm([self.dim_f, self.dim_notes])
        self.mha_l = nn.MultiheadAttention(self.dim_notes, num_heads_2, dropout=0.25)
        self.ln3 = nn.LayerNorm([self.dim_f, self.dim_notes])

        self.mlp = nn.Sequential(
            nn.Linear(self.dim_notes, self.dim_notes),
            nn.LayerNorm([self.dim_f, self.dim_notes]),
            nn.ELU(),
            nn.Linear(self.dim_notes, self.dim_notes),
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``[L, B, num_heads_1 * emb_f, num_heads_2 * emb_notes]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D or its last two axes do not match
                the widths the block was built with.
        """
        if x.dim() != 4 or tuple(x.shape[2:]) != (self.dim_f, self.dim_notes):
            raise ValueError(
                "expected input of shape [L, B, {}, {}], got {}".format(
                    self.dim_f, self.dim_notes, tuple(x.shape)
                )
            )
        seq_len, batch, dim_f, dim_notes = x.shape

        # Attention with F as the embedding axis: [L, B, F, N] -> [L, B * N, F].
        norm_1 = self.ln1(x.transpose(2, 3)).reshape(seq_len, batch * dim_notes, dim_f)
        attn_1 = self.mha_f(norm_1, norm_1, norm_1)[0]  # [0] is the attention output
        # Undo the fold and add the first residual in the original layout.
        x = x + attn_1.reshape(seq_len, batch, dim_notes, dim_f).transpose(2, 3)

        # Attention with N as the embedding axis: [L, B, F, N] -> [L, B * F, N].
        norm_2 = self.ln2(x).reshape(seq_len, batch * dim_f, dim_notes)
        attn_2 = self.mha_l(norm_2, norm_2, norm_2)[0]
        # The second residual builds on the output of the first sub-layer, not
        # on the block input, so the first attention pass stays on the skip path.
        x = x + attn_2.reshape(seq_len, batch, dim_f, dim_notes)

        return self.mlp(self.ln3(x))

In [69]:
# Hyper-parameters for a quick smoke test of TransformerBlock.
bs, L = 4, 150                   # batch size, sequence length
emb_f, emb_notes = 4, 32         # per-head widths of the two embedded axes
num_heads_1 = num_heads_2 = 4    # head counts of the two attention layers

trans_block = TransformerBlock(
    num_heads_1=num_heads_1,
    num_heads_2=num_heads_2,
    bs=bs,
    emb_notes=emb_notes,
    emb_f=emb_f,
)

# Random input laid out as [L, bs, Emb_f, Emb_notes].
rand_input = torch.rand(L, bs, num_heads_1 * emb_f, num_heads_2 * emb_notes)
