# Training main:

TODO: port from the cyanite notebook the interface for inspecting the tag distribution and creating datasets.


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.autograd import Variable

from frontend import Frontend_mine, Frontend_won
from backend import Backend, Backend2
from data_loader import get_DataLoader

# Model configuration consumed by CRNNSA (Frontend_mine + Backend2).
# frontend_dict: one entry per convolutional block (six blocks).
# backend_dict: settings for the classification head.
main_dict = {
    "frontend_dict": {
        "list_out_channels": [128] * 3 + [256] * 3,
        "list_kernel_sizes": [(3, 3)] * 6,
        "list_pool_sizes": [(3, 2), (2, 2), (2, 2), (2, 1), (2, 1), (2, 1)],
        "list_avgpool_flags": [False] * 5 + [True],
    },
    "backend_dict": {
        "n_class": 50,
        "bert_config": None,
        # NOTE(review): the original trailing comment said "pass None to
        # deactivate"; presumably it refers to this key, but the Backend2
        # defined in this notebook passes recurrent_units directly to nn.GRU
        # as num_layers, so None is not accepted there — confirm against
        # backend.Backend2.
        "recurrent_units": 2,
    },
}

In [2]:
class CRNNSA(nn.Module):
    """Convolutional-recurrent tagging model computing ``backend(frontend(spec))``.

    The sub-modules are either built from a configuration dict or passed in
    pre-built.

    TODO: explore whether "rec_unit -> self_att -> rec_unit"
    would have worked better.

    Args:
        main_dict: if given, builds ``Frontend_mine(main_dict["frontend_dict"])``
            and ``Backend2(main_dict)``; ``frontend`` and ``backend`` are then
            ignored.
        backend: module applied to the frontend output (used only when
            ``main_dict`` is None).
        frontend: module applied to the input spectrogram (used only when
            ``main_dict`` is None).

    Raises:
        ValueError: if ``main_dict`` is None and ``frontend`` or ``backend``
            is missing.
    """

    # Code adopted from https://github.com/minzwon/sota-music-tagging-models/
    def __init__(self, main_dict=None, backend=None, frontend=None):
        super().__init__()

        if main_dict is not None:
            frontend = Frontend_mine(main_dict["frontend_dict"])
            backend = Backend2(main_dict)
        elif frontend is None or backend is None:
            # Fail here rather than with an opaque "'NoneType' object is not
            # callable" on the first forward pass.
            raise ValueError(
                "CRNNSA needs either main_dict or both frontend and backend modules"
            )

        self.frontend = frontend
        self.backend = backend

    def forward(self, spec):
        """Run the frontend on ``spec`` and feed its output to the backend."""
        return self.backend(self.frontend(spec))

In [3]:
# Build the model from the configuration dict and show its module tree.
crnnsa = CRNNSA(main_dict=main_dict)
print(crnnsa)

CRNNSA(
  (frontend): Frontend_mine(
    (spec_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_block1): BlockChoi(
      (conv): Conv2dSame(1, 128, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ELU(alpha=1.0)
      (pool): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv_block2): BlockChoi(
      (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ELU(alpha=1.0)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv_block3): BlockChoi(
      (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (bn):

In [4]:
# Excerpt length converted from 5 seconds to spectrogram frames: presumably
# 16 kHz audio with a hop of 256 samples (5 * 16000 / 256 = 312.5 -> 312).
input_length = int(5 * 16000 / 256)

In [5]:
# Load one precomputed 2-D spectrogram, keep at most the first `input_length`
# frames and add batch and channel axes -> shape (1, 1, H, <=input_length).
# NOTE(review): absolute cluster path; it only resolves on that filesystem.
spec = torch.FloatTensor(
    np.load("/import/c4dm-datasets/rmri_self_att/msd/1/1/113801.npy")[None, None, :, :input_length]
)

In [6]:
# Forward the single (1, 1, H, W) excerpt through the model and display the
# output shape.
# NOTE(review): the displayed result is torch.Size([50]), not [1, 50] —
# presumably the backend's bare `x.squeeze()` (as in the Backend2 defined in
# this notebook) also drops the batch axis when batch size is 1.
crnnsa(spec).shape

torch.Size([50])

In [7]:
# Validation loader; input_length is presumably in seconds here (the batch
# shapes printed by the next cell have 312 frames, matching 5 s).
dataLoader = get_DataLoader(mode="valid", batch_size=64, input_length=5)

In [8]:
# Take one batch from the validation loader and print tensor shapes at each
# stage: input, targets, frontend output and full-model output; then stop.
# `x` stays bound afterwards and is reused by later cells.
# NOTE(review): no crnnsa.eval() call is visible, so these forward passes run
# in training mode (Dropout active, BatchNorm running stats updated) —
# consider crnnsa.eval() / torch.no_grad() for pure shape inspection.
for x,y in dataLoader:
    print(x.shape)
    print(y.shape)
    print(crnnsa.frontend(x).shape)
    print(crnnsa(x).shape)
    break

torch.Size([64, 1, 96, 312])
torch.Size([64, 50])
torch.Size([64, 256, 39])
torch.Size([64, 50])


# Compare with the output shape of Won's frontend (Frontend_won)

In [9]:
# Reference frontend (Frontend_won) applied to the same batch `x` to compare
# its output shape with crnnsa.frontend.
# The variable name is misspelled ("frondend"), but later cells reference it
# by this name.
frondendWon = Frontend_won()
frondendWon(x).shape

torch.Size([64, 256, 39])

In [10]:
# Number of trainable parameters in the reference frontend.
print(sum(p.numel() for p in frondendWon.parameters() if p.requires_grad))

8863490


In [11]:
# Trainable parameter counts: this model's frontend, its whole backend, and
# the backend's recurrent sub-module (`seq2seq`) alone.
print(sum(p.numel() for p in crnnsa.frontend.parameters() if p.requires_grad))
print(sum(p.numel() for p in crnnsa.backend.parameters() if p.requires_grad))
print(sum(p.numel() for p in crnnsa.backend.seq2seq.parameters() if p.requires_grad))

1774082
1065522
789504


# TODO:

In [12]:
class Backend2(nn.Module):
    """GRU + multi-head-attention classification head for frontend features.

    A GRU encodes the frontend feature sequence. The last GRU layer's final
    hidden state serves as a single query for multi-head attention over all
    GRU outputs. The attended vector is mapped to per-class probabilities.

    Args:
        main_dict: configuration dict. It reads:
            - ``main_dict["frontend_dict"]["list_out_channels"][-1]``: feature
              size of the frontend output. It is used as the GRU input/hidden
              size and as the attention embedding size.
            - ``main_dict["backend_dict"]["recurrent_units"]``: number of
              stacked GRU layers.
            - ``main_dict["backend_dict"]["n_class"]``: number of outputs.
            - ``main_dict["backend_dict"]["n_heads"]`` (optional, default 8):
              number of attention heads; must divide the feature size.
        bert_config: unused; kept for signature compatibility.
    """

    def __init__(self, main_dict, bert_config=None):
        super().__init__()

        backend_dict = main_dict["backend_dict"]
        self.frontend_out_channels = main_dict["frontend_dict"]["list_out_channels"][-1]

        # Default nn.GRU layout: input and output are (seq_len, batch, features).
        self.seq2seq = nn.GRU(
            self.frontend_out_channels,
            self.frontend_out_channels,
            backend_dict["recurrent_units"],  # num_layers
        )

        self.multihead_attn = nn.MultiheadAttention(
            self.frontend_out_channels,
            backend_dict.get("n_heads", 8),  # number of heads
        )

        # Dense
        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(self.frontend_out_channels, backend_dict["n_class"])

    def forward(self, x):
        """Map frontend features to per-class probabilities.

        Args:
            x: frontend output of shape (batch, features, sequence).

        Returns:
            Tensor of shape (batch, n_class) with sigmoid-activated values.
            The batch axis is kept even when batch == 1.
        """
        # (batch, features, sequence) -> (sequence, batch, features) for GRU/attention.
        x = x.permute(2, 0, 1)

        # see https://discuss.pytorch.org/t/dataparallel-issue-with-flatten-parameter/8282
        self.seq2seq.flatten_parameters()
        # outputs: (seq, batch, feat); hidden: (num_layers, batch, feat)
        outputs, hidden = self.seq2seq(x)

        # Final hidden state of the last GRU layer, used as a length-1 query.
        query = hidden[-1].unsqueeze(0)  # (1, batch, feat)
        x, _ = self.multihead_attn(query, outputs, outputs)  # (Q, K, V) -> (1, batch, feat)

        # Drop only the query-length axis; a bare squeeze() would also remove
        # the batch axis when batch == 1.
        x = self.dropout(x.squeeze(0))
        x = self.dense(x)
        return torch.sigmoid(x)

In [13]:
# Instantiate the Backend2 class redefined in the previous cell (crnnsa was
# built earlier, with whatever Backend2 was in scope at that time).
backend = Backend2(main_dict=main_dict)

In [14]:
# Frontend features for the batch `x`, kept to feed the standalone backend.
frontend_out = crnnsa.frontend(x)
frontend_out.shape

torch.Size([64, 256, 39])

In [15]:
# Output shape of the standalone backend on the frontend features.
backend(frontend_out).shape

torch.Size([64, 50])