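This is an implementation of a CRNN for music classification [[ref]](https://arxiv.org/pdf/1609.04243.pdf), which consists of 1D CNNs to capture temporal features from a spectrogram, and then RNNs for summarisation of these features. Note that the `k1c2` class defined below is a convolutional network followed by fully connected layers; it contains no recurrent layers.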

In [1]:
import torch.nn as nn

In [6]:
# Input image size in pixels (height, width).
# NOTE(review): presumably mel bins x time frames of the spectrogram - confirm.
img_height, img_width = 96, 1366

class k1c2(nn.Module):
    def __init__(self, num_classes:int, img_shape:list):
        """
        Convnet for music classification: 4 conv layers then 2 FC layers,
        followed by a linear output layer with ``num_classes`` units.

        Args:
            num_classes: number of output units of the final linear layer.
            img_shape: input image size as ``[width, height]``. It is used to
                size the first fully connected layer, so inputs passed to
                ``forward`` must have this spatial size.

        Raises:
            ValueError: if the image is too small to survive the pooling
                stages (a spatial dimension would shrink to zero).
        """

        super().__init__()

        self.img_width = img_shape[0]
        self.img_height = img_shape[1]

        net = nn.Sequential()

        # Spatial size (height, width) of the feature map after the layers
        # added so far; needed to size the first fully connected layer.
        feat_size = [self.img_height, self.img_width]

        def add_CNN_layer(num, size_in, size_out, kernel_size=3, padding=1, bn=True, pool_size=2):
            # Honour the requested kernel size (it used to be hard-coded to 3).
            net.add_module(
                f"conv{num}",
                nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=1, padding=padding),
            )
            # Stride-1 convolution: out = in + 2 * padding - kernel_size + 1
            feat_size[:] = [s + 2 * padding - kernel_size + 1 for s in feat_size]
            if bn:
                net.add_module(f"bn{num}", nn.BatchNorm2d(size_out))
            if pool_size:
                net.add_module(f"pool{num}", nn.MaxPool2d(kernel_size=pool_size))
                # MaxPool2d defaults to stride == kernel_size with floor rounding
                feat_size[:] = [s // pool_size for s in feat_size]

        def add_FC_layer(num, size_in, size_out, dropout=0.5):
            net.add_module(f"fc{num}", nn.Linear(size_in, size_out))
            if dropout:
                net.add_module(f"dropout{num}", nn.Dropout(dropout))

        # for 2.5e5 parameter network
        # NOTE(review): no activation functions are inserted between layers;
        # confirm whether that is intended.
        add_CNN_layer(0, 1, 33)
        add_CNN_layer(1, 33, 33)
        add_CNN_layer(2, 33, 66)
        add_CNN_layer(3, 66, 66)
        # The former fifth conv layer (100 -> 133 channels) matched neither the
        # 66 channels produced above nor the 66 inputs of the FC layers below,
        # and contradicted the "4 conv layers" design, so it is not added.
        conv_channels = 66

        if min(feat_size) < 1:
            raise ValueError(
                f"img_shape {list(img_shape)} is too small for the conv/pool stack"
            )

        # Flatten (N, C, H, W) -> (N, C*H*W) so the linear layers can consume it.
        net.add_module("flatten", nn.Flatten())
        flat_features = conv_channels * feat_size[0] * feat_size[1]

        # FC layers keep their original names (index 4 is intentionally unused).
        add_FC_layer(5, flat_features, 66)
        add_FC_layer(6, 66, 66)
        # No dropout on the output layer: it would randomly zero class scores.
        add_FC_layer(7, 66, num_classes, dropout=0)

        self.net = net

    def forward(self, input):
        """
        Run the network on a batch of single-channel images.

        Args:
            input: tensor of shape ``(batch, 1, height, width)`` whose spatial
                size matches the ``img_shape`` given to the constructor.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        return self.net(input)
    