In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from training.frontend import Frontend_mine, Frontend_won
from training.backend import Backend

In [2]:
# Per-block configuration for the custom front-end: six conv blocks,
# each list holding one entry per block.
stack_dict = dict(
    # Output channels per block: three blocks of 128, then three of 256.
    list_out_channels=[128] * 3 + [256] * 3,
    # Every block uses a 3x3 convolution kernel.
    list_kernel_sizes=[(3, 3)] * 6,
    # Pooling window per block: (3, 2) for the first block, (2, 2) afterwards.
    list_pool_sizes=[(3, 2)] + [(2, 2)] * 5,
    # Only the last block sets the avg-pool flag.
    # NOTE(review): presumably this swaps max pooling for average pooling in
    # that block -- confirm in Frontend_mine.
    list_avgpool_flags=[False] * 5 + [True],
)

frontend_mine = Frontend_mine(stack_dict)
print(frontend_mine)

# Report the number of trainable parameters (those with requires_grad set).
mine_trainable = [w for w in frontend_mine.parameters() if w.requires_grad]
print(sum(w.numel() for w in mine_trainable))

Frontend_mine(
  (freq_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_block1): BlockChoi(
    (conv): Conv2dSame(1, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ELU(alpha=1.0)
    (pool): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv_block2): BlockChoi(
    (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ELU(alpha=1.0)
    (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv_block3): BlockChoi(
    (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, tr

In [3]:
# Instantiate the second front-end with its default constructor arguments
# and print its module tree.
frontend_won = Frontend_won()
print(frontend_won)

# Report the number of trainable parameters (those with requires_grad set).
won_trainable = [w for w in frontend_won.parameters() if w.requires_grad]
print(sum(w.numel() for w in won_trainable))

Frontend_won(
  (spec_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Res_2d(
    (conv_1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer2): Res_2d(
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(128, eps=1e-05, momentum=0.1,

In [4]:
# Instantiate the back-end with recurrent=True and print its module tree
# (the printed tree lists a GRU `seq2seq` module alongside a BertEncoder).
backend = Backend(recurrent=True)
print(backend)

# Report the number of trainable parameters (those with requires_grad set).
backend_trainable = [w for w in backend.parameters() if w.requires_grad]
print(sum(w.numel() for w in backend_trainable))

Backend(
  (seq2seq): GRU(256, 256, num_layers=2, batch_first=True)
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=256, out_features=256, bias=True)
            (key): Linear(in_features=256, out_features=256, bias=True)
            (value): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.5, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=256, out_features=256, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.4, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=256, out_features=1024, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=1024, out_features=256, bias=True)
          (LayerNorm): BertLayerNorm()
        

In [5]:
# Load one pre-computed .npy array (presumably a spectrogram, going by the
# variable name) and prepend two singleton axes before converting it to a
# float tensor. The array is expected to be 2-D, giving a 4-D tensor.
# NOTE(review): the two new axes are presumably batch and channel -- confirm
# against the front-ends' expected input layout.
spec_path = "/import/c4dm-datasets/rmri_self_att/msd/1/1/113801.npy"
spec_array = np.load(spec_path)
spec = torch.FloatTensor(spec_array[np.newaxis, np.newaxis, :, :])

In [6]:
# Shape check for the Frontend_won -> Backend pipeline.
# The front-end output is computed once and reused as the back-end input, so
# the front-end forward pass is not repeated. Autograd is disabled because only
# the output shapes are inspected, so no graph needs to be recorded.
# NOTE(review): the modules are not switched to eval() here; if they are in
# training mode, dropout is still active. The printed shapes are unaffected.
with torch.no_grad():
    won_features = frontend_won(spec)
    won_output = backend(won_features)

print(won_features.shape)
print(won_output.shape)

torch.Size([1, 256, 469])
torch.Size([1, 50])


In [7]:
# Shape check for the Frontend_mine -> Backend pipeline.
# The front-end output is computed once and reused as the back-end input, so
# the front-end forward pass is not repeated. Autograd is disabled because only
# the output shapes are inspected, so no graph needs to be recorded.
# NOTE(review): the modules are not switched to eval() here; if they are in
# training mode, dropout is still active. The printed shapes are unaffected.
with torch.no_grad():
    mine_features = frontend_mine(spec)
    mine_output = backend(mine_features)

print(mine_features.shape)
print(mine_output.shape)

torch.Size([1, 256, 58])
torch.Size([1, 50])


# https://github.com/minzwon/sota-music-tagging-models/tree/master/training