In [1]:
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from architectures.front.conv import BlockChoi

In [3]:
# Smoke-test a single Choi conv block: 3 -> 16 channels, 3x3 kernel, 2x2 pooling.
# The trailing flag appears to select average pooling (the printed module below shows AvgPool2d).
m = BlockChoi(3, 16, (3, 3), (2, 2), True)
block_input = torch.rand((32, 3, 2, 1000))
m(block_input).shape

torch.Size([32, 16, 1, 500])

In [4]:
# Number of trainable scalars in the single BlockChoi block.
sum(p.numel() for p in filter(lambda q: q.requires_grad, m.parameters()))

480

In [5]:
# Text summary of the block's sub-modules (conv, bn, activation, pool, dropout).
print(str(m))

BlockChoi(
  (conv): Conv2dSame(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation): ELU(alpha=1.0)
  (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [6]:
from architectures.stacker import ConvStack

In [7]:
# One row per conv block: (out_channels, kernel_size, pool_size, use_avgpool).
conv_layer_specs = [
    (128, (3, 3), (3, 2), False),
    (128, (3, 3), (2, 2), False),
    (128, (3, 3), (2, 2), False),
    (256, (3, 3), (2, 2), False),
    (256, (3, 3), (2, 2), False),
    (256, (3, 3), (2, 2), True),
]
# Transpose the rows into the column-wise lists ConvStack expects.
out_chs, kernels, pools, avg_flags = (list(col) for col in zip(*conv_layer_specs))
stack_dict = {"list_out_channels": out_chs,
              "list_kernel_sizes": kernels,
              "list_pool_sizes": pools,
              "list_avgpool_flags": avg_flags}

conv_stack = ConvStack(stack_dict)
print(conv_stack)

ConvStack(
  (conv_block_0): BlockChoi(
    (conv): Conv2dSame(1, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ELU(alpha=1.0)
    (pool): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv_block_1): BlockChoi(
    (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ELU(alpha=1.0)
    (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv_block_2): BlockChoi(
    (conv): Conv2dSame(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ELU(alpha=1.0)
    (pool): MaxPool2d(kernel_size=(2

In [8]:
# Number of trainable scalars in the full six-block ConvStack.
sum(p.numel() for p in filter(lambda q: q.requires_grad, conv_stack.parameters()))

1774080

In [9]:
from architectures.self_att_won import *

In [10]:
class Frontend_won(nn.Module):
    '''
    CNN feature extractor from Won et al. 2019,
    "Toward interpretable music tagging with self-attention".

    Structure: a BatchNorm2d(1) over the whole input, then seven Res_2d blocks
    registered as ``layer1`` ... ``layer7``. Layers 1-3 use stride 2 on both
    spatial axes; layers 4-7 use stride (2, 1), i.e. stride 2 on axis 2 only.
    Axis 2 is squeezed from the result (a no-op unless it has size 1).

    Args:
        n_channels: width of layers 1-2; layers 3-7 use ``2 * n_channels``.
    '''
    # Attribute names of the residual blocks, in application order. Kept as
    # layer1..layer7 so parameter / state_dict names stay stable.
    _LAYER_NAMES = tuple(f"layer{i}" for i in range(1, 8))

    def __init__(self,
                 n_channels=128):
        super().__init__()

        self.spec_bn = nn.BatchNorm2d(1)

        # (in_channels, out_channels, stride) for each residual block.
        wide = n_channels * 2
        layer_specs = (
            (1, n_channels, 2),
            (n_channels, n_channels, 2),
            (n_channels, wide, 2),
            (wide, wide, (2, 1)),
            (wide, wide, (2, 1)),
            (wide, wide, (2, 1)),
            (wide, wide, (2, 1)),
        )
        for name, (c_in, c_out, stride) in zip(self._LAYER_NAMES, layer_specs):
            setattr(self, name, Res_2d(c_in, c_out, stride=stride))

    def forward(self, x):
        # TODO: permute the input axes here to perform frequency-wise batch
        # norm instead of a single BatchNorm2d(1) over the whole input.
        out = self.spec_bn(x)

        for name in self._LAYER_NAMES:
            out = getattr(self, name)(out)

        # Drop axis 2, which the stride-2 layers presumably reduce to size 1
        # for the expected input height (squeeze is a no-op otherwise).
        return out.squeeze(2)

In [11]:
class Backend(nn.Module):
    '''
    Won et al. 2019
    Toward interpretable music tagging with self-attention.
    Temporal summary with Transformer encoder.

    Pipeline: optional single-layer GRU -> prepend a fixed [CLS] vector ->
    2-layer BERT encoder -> BERT pooler -> dropout -> linear -> sigmoid.

    Args:
        recurrent: if True, run ``nn.GRU(256, 256)`` over the sequence before
            the Transformer encoder; if False the GRU is skipped entirely.
        n_class: number of output tags.

    Input:  (batch, 256, seq) features (channel axis second).
    Output: (batch, n_class) sigmoid scores.
    '''
    def __init__(self,recurrent=False,n_class=50):
        super().__init__()

        # Remembered so forward() knows whether a GRU was created.
        self.recurrent = recurrent
        if recurrent:
            self.gru = nn.GRU(256, 256, 1, batch_first=True) # input and output = (batch, seq, feature)

        # Transformer encoder
        bert_config = BertConfig(vocab_size=256,
                                 hidden_size=256,
                                 num_hidden_layers=2,
                                 num_attention_heads=8,
                                 intermediate_size=1024,
                                 hidden_act="gelu",
                                 hidden_dropout_prob=0.4,
                                 max_position_embeddings=700,
                                 attention_probs_dropout_prob=0.5)
        self.encoder = BertEncoder(bert_config)
        self.pooler = BertPooler(bert_config)
        self.vec_cls = self.get_cls(256)

        # Dense
        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(256, n_class)

    def get_cls(self, channel):
        '''
        Build the fixed (non-trainable) [CLS] vector.

        The values are those of ``np.random.seed(0); np.random.random((1, channel))``,
        but drawn from a private ``RandomState(0)`` so the global NumPy RNG is
        not reseeded as a side effect.

        Args:
            channel: length of the [CLS] vector.
        Returns:
            Float tensor of shape (64, 1, channel); all 64 rows are identical.
        '''
        single_cls = torch.Tensor(np.random.RandomState(0).random_sample((1, channel)))
        return single_cls.repeat(64, 1).unsqueeze(1)

    def append_cls(self, x):
        '''
        Prepend the [CLS] vector to every sequence in the batch.

        Args:
            x: tensor of shape (batch, seq, channel); channel must match vec_cls.
        Returns:
            Tensor of shape (batch, seq + 1, channel) with [CLS] at position 0.
        '''
        batch, _, _ = x.size()
        # All rows of vec_cls are identical, so broadcast row 0 to the batch:
        # this works for any batch size (slicing vec_cls[:batch] broke for batch > 64).
        cls_tok = self.vec_cls[:1].to(device=x.device, dtype=x.dtype).expand(batch, -1, -1)
        return torch.cat([cls_tok, x], dim=1)

    def forward(self, x):
        # (batch, 256, seq) -> (batch, seq, 256): sequence-major for GRU / encoder.
        x = x.permute(0, 2, 1)

        # The GRU only exists when constructed with recurrent=True.
        if self.recurrent:
            x, _ = self.gru(x)

        # Get [CLS] token
        x = self.append_cls(x)

        # Transformer encoder: keep the last layer's hidden states.
        # NOTE(review): assumes BertEncoder returns a sequence of per-layer
        # outputs, as the previous `_, x = ...` unpacking implied for 2 layers.
        x = self.encoder(x)[-1]

        x = self.pooler(x)

        # Dense
        x = self.dropout(x)
        x = self.dense(x)
        return torch.sigmoid(x)

In [12]:
# Won et al. CNN frontend and Transformer backend (with the optional GRU enabled).
frontend = Frontend_won(n_channels=128)
backend = Backend(recurrent=True, n_class=50)

In [13]:
# Module summaries of the frontend followed by the backend, one per line block.
print(frontend, backend, sep="\n")

Frontend_won(
  (spec_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Res_2d(
    (conv_1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer2): Res_2d(
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(128, eps=1e-05, momentum=0.1,

    class CRNNSA(nn.Module):
        """
        Choi frontend + gated recurrent unit + Won backend.

        Work in progress: only the constructor exists so far.
        TODO: create the frontend / GRU / backend submodules and add ``forward``.
        """
        def __init__(self, n_class=50):
            super(CRNNSA, self).__init__()
            # Number of output tags, kept for the classifier head still to be written.
            self.n_class = n_class

In [17]:
def _count_trainable(module):
    """Return the number of scalar parameters in `module` with requires_grad=True."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Trainable parameter counts: frontend first, then backend.
for net in (frontend, backend):
    print(_count_trainable(net))

8863490
2052914


In [18]:
# Load one precomputed spectrogram and add (batch, channel) axes so it can be
# fed to the 2-D frontends: shape (1, 1, rows, cols), float32.
# The dataset root is configurable via MSD_DATA_DIR instead of a hard-coded
# absolute path; the default is the original location.
MSD_DATA_DIR = Path(os.environ.get("MSD_DATA_DIR", "/import/c4dm-datasets/rmri_self_att/msd"))
spec = torch.tensor(
    np.load(MSD_DATA_DIR / "1" / "1" / "113801.npy")[np.newaxis, np.newaxis, :, :],
    dtype=torch.float32,
)

In [20]:
# Shape check only: no gradients needed, so skip building the autograd graph.
with torch.no_grad():
    frontend_features = frontend(spec)
frontend_features.shape

torch.Size([1, 256, 469])

In [21]:
# Shape check only: no gradients needed, so skip building the autograd graph.
with torch.no_grad():
    conv_stack_features = conv_stack(spec)
conv_stack_features.shape

torch.Size([1, 256, 58])

In [23]:
# End-to-end shape check of ConvStack -> Backend on a random batch of 32 inputs.
# No gradients are needed, so avoid holding activations for autograd.
with torch.no_grad():
    conv_stack_scores = backend(conv_stack(torch.rand((32, 1, 128, 1292))))
conv_stack_scores.shape

torch.Size([32, 50])

In [26]:
# End-to-end shape check of Frontend_won -> Backend on a random batch of 32 inputs.
# No gradients are needed, so avoid holding activations for autograd.
with torch.no_grad():
    frontend_scores = backend(frontend(torch.rand((32, 1, 128, 1292))))
frontend_scores.shape

torch.Size([32, 50])

# Reference training code: https://github.com/minzwon/sota-music-tagging-models/tree/master/training

# TODO: apply frequency-wise batch normalisation as the first frontend layer.
