# Training main

TODO: port the interface from the Cyanite notebook for inspecting the tag distribution and creating datasets.


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.autograd import Variable

from frontend import Frontend_mine, Frontend_won
from backend import Backend
from data_loader import get_DataLoader

# Model configuration.
# Each list in "frontend_dict" has one entry per convolutional block (6 blocks);
# e.g. the printed model shows conv_block1 with 128 output channels, a (5, 5)
# kernel and a (3, 2) max-pool, matching index 0 of each list.
main_dict = {
    "frontend_dict": {
        "list_out_channels": [128, 128, 128, 256, 256, 256],
        "list_kernel_sizes": [(5, 5), (5, 5), (3, 3), (3, 3), (3, 3), (3, 3)],
        "list_pool_sizes": [(3, 2), (2, 2), (2, 2), (2, 1), (2, 1), (2, 1)],
        # NOTE(review): presumably selects average- instead of max-pooling for
        # that block -- confirm in Frontend_mine.
        "list_avgpool_flags": [False, False, False, False, False, True],
    },
    "backend_dict": {
        # Number of output classes; the model output is (batch, 50) below.
        "n_class": 50,
        # NOTE(review): the original "pass None to deactivate" comment is
        # ambiguous; presumably None disables the BERT stage ("bert_config")
        # and/or the recurrent stage ("recurrent_units") -- confirm in Backend.
        "bert_config": None,
        "recurrent_units": 2,  # pass None to deactivate
    },
}

# Seed the RNGs so weight initialisation and random test tensors are
# reproducible under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
class CRNNSA(nn.Module):
    """BERT-based convolutional recurrent neural network (frontend + backend).

    ``forward`` feeds the input through ``self.frontend`` and then through
    ``self.backend``. Code adopted from
    https://github.com/minzwon/sota-music-tagging-models/

    The model can be built in one of two ways:

    * from a config dict ``main_dict``: the frontend is
      ``Frontend_mine(main_dict["frontend_dict"])`` and the backend is
      ``Backend(main_dict)``. In this case ``frontend``/``backend`` are ignored.
    * from two pre-built modules passed as ``frontend`` and ``backend``.

    Raises:
        ValueError: if ``main_dict`` is None and either ``frontend`` or
            ``backend`` is missing.

    TODO: explore whether "rec_unit -> self_att -> rec_unit"
    would have worked better.
    """

    def __init__(self, main_dict=None, backend=None, frontend=None):
        super().__init__()

        if main_dict is not None:
            self.frontend = Frontend_mine(main_dict["frontend_dict"])
            self.backend = Backend(main_dict)
        elif frontend is not None and backend is not None:
            self.frontend = frontend
            self.backend = backend
        else:
            # Fail fast: otherwise the model would be built with a None
            # submodule and only break later, inside forward().
            raise ValueError(
                "CRNNSA needs either `main_dict` or both `frontend` and `backend`."
            )

    def forward(self, spec):
        """Run ``spec`` through the frontend, then the backend.

        Args:
            spec: input batch. In this notebook it is a 4-D tensor, e.g. of
                shape (1, 1, 96, 937) as produced by the validation loader.

        Returns:
            The backend output; (batch, n_class) in the runs shown here.
        """
        return self.backend(self.frontend(spec))

In [3]:
# Instantiate the model from the config dict and print its layer structure.
crnnsa = CRNNSA(main_dict=main_dict)
print(crnnsa)

CRNNSA(
  (frontend): Frontend_mine(
    (spec_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_block1): BlockChoi(
      (conv): Conv2dSame(1, 128, kernel_size=(5, 5), stride=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ELU(alpha=1.0)
      (pool): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv_block2): BlockChoi(
      (conv): Conv2dSame(128, 128, kernel_size=(5, 5), stride=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ELU(alpha=1.0)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv_block3): BlockChoi(
      (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (bn):

In [4]:
# Load one pre-computed example array and prepend two singleton axes, giving
# shape (1, 1, *spec_np.shape) -- presumably (batch, channel, ...), matching
# the loader batches below.
SAMPLE_SPEC_PATH = "/import/c4dm-datasets/rmri_self_att/msd/1/1/113801.npy"
spec_np = np.load(SAMPLE_SPEC_PATH)
spec = torch.FloatTensor(spec_np[None, None, :, :])

In [5]:
# Forward the single example through the full model and display the output size.
crnnsa(spec).size()

torch.Size([1, 50])

In [6]:
# Validation split, one example per batch (used for the shape checks below).
dataLoader = get_DataLoader(mode="valid", batch_size=1)

In [7]:
# Sanity-check shapes on the first validation batch.
# CRNNSA.forward is backend(frontend(x)), so the frontend features are computed
# once and reused instead of re-running the frontend inside crnnsa(x).
# no_grad: a shape check needs no autograd graph.
with torch.no_grad():
    for x, y in dataLoader:
        features = crnnsa.frontend(x)
        print(x.shape)
        print(y.shape)
        print(features.shape)
        print(crnnsa.backend(features).shape)
        break

torch.Size([1, 1, 96, 937])
torch.Size([1, 50])
torch.Size([1, 256, 117])
torch.Size([1, 50])


In [8]:
def count_trainable_params(module):
    """Return the number of trainable (``requires_grad``) parameters in ``module``."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# NOTE(review): seq2seq is an attribute of the backend; if it is a registered
# submodule, its parameters are already included in the backend total.
for label, part in (
    ("frontend", crnnsa.frontend),
    ("backend", crnnsa.backend),
    ("backend.seq2seq", crnnsa.backend.seq2seq),
):
    print(f"{label}: {count_trainable_params(part)}")

2038274
2447666
789504


# TODO:

Find out which dimensions of the `nn.TransformerEncoder` input represent the batch, sequence, and embedding axes.

In [11]:
# Toy run of nn.TransformerEncoder to inspect its input layout.
# NOTE(review): per the PyTorch docs, batch_first defaults to False, so the
# input is read as (seq, batch, d_model); the output keeps the input's shape.
SEQ_LEN, BATCH_SIZE, D_MODEL = 10, 32, 512

encoder_layer = nn.TransformerEncoderLayer(d_model=D_MODEL, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

src = torch.rand(SEQ_LEN, BATCH_SIZE, D_MODEL)
out = transformer_encoder(src)

In [10]:
# The encoder output keeps the input's shape.
out.size()

torch.Size([10, 32, 512])

In [None]:
class BertPooler(nn.Module):
    """Pool a sequence of hidden states into a single vector.

    The vector at index 0 of dimension 1 ("first token") is passed through a
    square linear layer and then an activation (``nn.Tanh`` unless one is
    supplied).

    Args:
        config: any object exposing ``hidden_size`` (input and output width
            of the dense layer).
        activation: optional module applied after the dense layer; defaults
            to ``nn.Tanh()``.
    """

    def __init__(self, config, activation=None):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)

        if activation is None:
            activation = nn.Tanh()
        self.activation = activation

    def forward(self, hidden_states):
        """Return ``activation(dense(hidden_states[:, 0]))``.

        NOTE(review): indexing ``[:, 0]`` assumes a batch-first layout
        (batch, seq, hidden); the TransformerEncoder default above is
        seq-first -- confirm the layout fed in here.
        """
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))