# Training main:

TODO: reuse the interface from the Cyanite notebook to inspect the tag distribution and to create the datasets.


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.autograd import Variable

from frontend import Frontend_mine, Frontend_won
from backend import Backend
from data_loader import get_DataLoader

# Model configuration consumed by CRNNSA.
#   frontend_dict: per-convolutional-block settings; entry i configures
#                  conv_block{i+1} (six blocks in total, see the printed model).
#   backend_dict:  settings for the backend (50 output classes/tags).
main_dict = dict(
    frontend_dict=dict(
        list_out_channels=[128, 128, 128, 256, 256, 256],
        list_kernel_sizes=[(5, 5), (5, 5), (3, 3), (3, 3), (3, 3), (3, 3)],
        list_pool_sizes=[(3, 2), (2, 2), (2, 2), (2, 1), (2, 1), (2, 1)],
        # Only the last block has its average-pool flag enabled.
        list_avgpool_flags=[False] * 5 + [True],
    ),
    backend_dict=dict(
        n_class=50,
        # NOTE(review): the original note said "pass None to deactivate";
        # presumably None disables the corresponding backend component
        # (bert_config and/or recurrent_units) — TODO confirm against Backend.
        bert_config=None,
        recurrent_units=2,
    ),
)

In [2]:
class CRNNSA(nn.Module):
    """Composite music-tagging model computing ``backend(frontend(spec))``.

    Described in the original code as a BERT-based Convolutional Recurrent
    Neural Network. Code adapted from
    https://github.com/minzwon/sota-music-tagging-models/

    TODO: explore whether "rec_unit -> self_att -> rec_unit"
    would have worked better.
    """

    def __init__(self, main_dict=None, backend=None, frontend=None):
        """Build the model from a config dict or from ready-made sub-modules.

        Args:
            main_dict: Configuration dictionary. When given, the frontend is
                ``Frontend_mine(main_dict["frontend_dict"])`` and the backend
                is ``Backend(main_dict)``; ``backend``/``frontend`` are ignored.
            backend: Module applied to the frontend's output; used only when
                ``main_dict`` is None.
            frontend: Module applied to the raw input; used only when
                ``main_dict`` is None.
        """
        super().__init__()

        if main_dict is None:
            # Injection path: adopt the caller's modules as-is. Nothing is
            # validated here, so either one may be None.
            self.frontend = frontend
            self.backend = backend
            return

        self.frontend = Frontend_mine(main_dict["frontend_dict"])
        self.backend = Backend(main_dict)

    def forward(self, spec):
        """Run ``spec`` through the frontend, then the backend.

        In this notebook ``spec`` is a 4-D batch such as (1, 1, 96, 937) and
        the result is (batch, n_class); actual shapes are determined by the
        frontend/backend sub-modules.
        """
        features = self.frontend(spec)
        return self.backend(features)

In [3]:
# Instantiate the model from the configuration dictionary and show its layers.
crnnsa = CRNNSA(main_dict=main_dict)
print(crnnsa)

CRNNSA(
  (frontend): Frontend_mine(
    (spec_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_block1): BlockChoi(
      (conv): Conv2dSame(1, 128, kernel_size=(5, 5), stride=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
      (pool): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv_block2): BlockChoi(
      (conv): Conv2dSame(128, 128, kernel_size=(5, 5), stride=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (conv_block3): BlockChoi(
      (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (bn): BatchNorm2d(128

In [4]:
# Load one stored example from disk and prepend batch and channel axes,
# producing a float32 tensor (assumes the stored array is 2-D, which makes
# the result 4-D).
spec = torch.as_tensor(
    np.load("/import/c4dm-datasets/rmri_self_att/msd/1/1/113801.npy")[None, None, :, :],
    dtype=torch.float32,
)

In [5]:
# Output shape of a full forward pass on the single loaded example.
crnnsa(spec).size()

torch.Size([1, 50])

In [6]:
# Loader with batch size 1; mode="valid" presumably selects the validation split.
dataLoader = get_DataLoader(mode="valid", batch_size=1)

In [7]:
# Smoke-test the data pipeline and the model on the first batch only.
# NOTE(review): no .eval() / torch.no_grad() is visible before this loop, so
# the model runs in training mode: dropout is active and the BatchNorm layers
# (track_running_stats=True) update their running statistics. Consider
# crnnsa.eval() + torch.no_grad() if this is purely a shape check.
for x,y in dataLoader:
    print(x.shape)  # input batch
    print(y.shape)  # target batch
    print(crnnsa.frontend(x).shape)  # frontend output on its own
    print(crnnsa(x).shape)  # full model output (re-runs the frontend)
    break  # only the first batch is inspected

torch.Size([1, 1, 96, 937])
torch.Size([1, 50])
torch.Size([1, 256, 117])
torch.Size([1, 50])


In [8]:
# Trainable-parameter counts, printed in order: frontend, whole backend, and
# the backend's seq2seq sub-module.
# NOTE(review): if seq2seq is a registered sub-module of backend, its count is
# already included in the backend total rather than being additional.
print(sum(w.numel() for w in crnnsa.frontend.parameters() if w.requires_grad))
print(sum(w.numel() for w in crnnsa.backend.parameters() if w.requires_grad))
print(sum(w.numel() for w in crnnsa.backend.seq2seq.parameters() if w.requires_grad))

2038274
2447666
789504


In [9]:
# Seed PyTorch's global RNG (for reproducibility), then draw one (1, 256)
# tensor of uniform samples.
# NOTE(review): the model was constructed in an earlier cell, before this
# seed was set, so its weight initialisation is not covered by this seed.
torch.manual_seed(seed=0)
torch.rand(1, 256)

tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964, 0.4556,
         0.6323, 0.3489, 0.4017, 0.0223, 0.1689, 0.2939, 0.5185, 0.6977, 0.8000,
         0.1610, 0.2823, 0.6816, 0.9152, 0.3971, 0.8742, 0.4194, 0.5529, 0.9527,
         0.0362, 0.1852, 0.3734, 0.3051, 0.9320, 0.1759, 0.2698, 0.1507, 0.0317,
         0.2081, 0.9298, 0.7231, 0.7423, 0.5263, 0.2437, 0.5846, 0.0332, 0.1387,
         0.2422, 0.8155, 0.7932, 0.2783, 0.4820, 0.8198, 0.9971, 0.6984, 0.5675,
         0.8352, 0.2056, 0.5932, 0.1123, 0.1535, 0.2417, 0.7262, 0.7011, 0.2038,
         0.6511, 0.7745, 0.4369, 0.5191, 0.6159, 0.8102, 0.9801, 0.1147, 0.3168,
         0.6965, 0.9143, 0.9351, 0.9412, 0.5995, 0.0652, 0.5460, 0.1872, 0.0340,
         0.9442, 0.8802, 0.0012, 0.5936, 0.4158, 0.4177, 0.2711, 0.6923, 0.2038,
         0.6833, 0.7529, 0.8579, 0.6870, 0.0051, 0.1757, 0.7497, 0.6047, 0.1100,
         0.2121, 0.9704, 0.8369, 0.2820, 0.3742, 0.0237, 0.4910, 0.1235, 0.1143,
         0.4725, 0.5751, 0.2