In [1]:
import os
import sys
sys.path.append('../..')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets.gtzan import GTZAN_MELSPEC as GTZAN

In [2]:
# Build the dataset from the project-local GTZAN mel-spectrogram class
# (imported above as ``GTZAN``) and wrap it in a loader yielding batches of 3.
dataset = GTZAN(phase='all', min_segments=10)
dataloader = DataLoader(dataset=dataset, batch_size=3)

In [3]:
class BasicConv2d(nn.Sequential):
    """Batch-norm -> ReLU -> convolution stack.

    Args:
        num_in_channels: channel count of the input; also the number of
            features normalised by the leading ``BatchNorm2d``.
        num_out_channels: channel count produced by the convolution.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``padding``).
    """

    def __init__(self, num_in_channels, num_out_channels, **kwargs):
        super().__init__()

        # Registration order defines the execution order of the Sequential.
        stages = (
            ('bn', nn.BatchNorm2d(num_in_channels, eps=1e-3)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(num_in_channels, num_out_channels, **kwargs)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    
class InceptionModule(nn.Module):
    """Four parallel branches whose outputs are concatenated on the channel axis.

    Every branch ends in a ``BasicConv2d`` producing ``num_out_channels``
    channels, so the module outputs ``4 * num_out_channels`` channels.
    All non-1x1 kernels have height 1, i.e. they only span the last axis, and
    their padding keeps the spatial size of the input unchanged.

    Branches (in output order):
        1. 1x1 convolution.
        2. 1x1 convolution followed by a (1, 5) convolution.
        3. 1x1 convolution followed by a (1, 3) convolution.
        4. (1, 3) max pooling (stride 1) followed by a 1x1 convolution.

    Args:
        num_in_channels: channel count of the input tensor.
        num_out_channels: channel count produced by each branch.
    """

    def __init__(self, num_in_channels, num_out_channels):
        super().__init__()

        self.branch1x1 = BasicConv2d(num_in_channels, num_out_channels,
                                     kernel_size=1)

        self.branch5x5_1 = BasicConv2d(num_in_channels, num_out_channels,
                                       kernel_size=1)
        self.branch5x5_2 = BasicConv2d(num_out_channels, num_out_channels,
                                       kernel_size=(1, 5), padding=(0, 2))

        self.branch3x3_1 = BasicConv2d(num_in_channels, num_out_channels,
                                       kernel_size=1)
        self.branch3x3_2 = BasicConv2d(num_out_channels, num_out_channels,
                                       kernel_size=(1, 3), padding=(0, 1))

        self.branch_pool = BasicConv2d(num_in_channels, num_out_channels,
                                       kernel_size=1)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along ``dim=1``."""
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        # Use ``nn.functional`` explicitly: the notebook imports ``torch.nn as nn``
        # but never defines an ``F`` alias, so ``F.max_pool2d`` raises NameError.
        branch_pool = nn.functional.max_pool2d(x, kernel_size=(1, 3),
                                               stride=1, padding=(0, 1))
        branch_pool = self.branch_pool(branch_pool)

        # Concatenation order is 1x1, 5x5, 3x3, pool.
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, dim=1)


class DenseModule(nn.Module):
    """Chain of ``InceptionModule`` layers with dense (concatenating) connections.

    Each layer receives the running feature map and its output is concatenated
    onto that map along ``dim=1``; an ``InceptionModule`` emits
    ``4 * num_out_channels`` channels, so the input width of layer ``i`` is
    ``num_in_channels + i * 4 * num_out_channels``.

    Args:
        num_dense_modules: number of inception layers to stack.
        num_in_channels: channel count of the tensor fed to the first layer.
        num_out_channels: per-branch channel count of every inception layer.
    """

    def __init__(self, num_dense_modules, num_in_channels, num_out_channels):
        super().__init__()

        growth = 4 * num_out_channels
        for idx in range(num_dense_modules):
            layer_in_channels = num_in_channels + idx * growth
            self.add_module(
                'dense_layer_%d' % idx,
                InceptionModule(layer_in_channels, num_out_channels))

    def forward(self, x):
        """Apply the layers in registration order, growing ``x`` on ``dim=1``."""
        for layer in self.children():
            x = torch.cat([x, layer(x)], dim=1)
        return x


class TransitionModule(nn.Sequential):
    """Batch-norm -> ReLU -> 1x1 convolution -> (1, 2) average pooling.

    The convolution maps ``num_in_channels`` to ``num_out_channels``; the
    pooling uses kernel and stride ``(1, 2)``, so only the last axis is
    downsampled.

    Args:
        num_in_channels: channel count of the input tensor.
        num_out_channels: channel count after the 1x1 convolution.
    """

    def __init__(self, num_in_channels, num_out_channels):
        super().__init__()

        # Assigning a module as an attribute registers it under that name,
        # in assignment order, exactly like ``add_module``.
        self.bn = nn.BatchNorm2d(num_in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_in_channels, num_out_channels, kernel_size=1)
        self.avg_pool = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2))
            

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``."""
        leading = x.size(0)
        return x.view(leading, -1)


class DenseInception(nn.Module):
    """Densely connected inception network with a linear classification head.

    Stages:
        * ``preprocessing``: (1, 3) convolution, batch norm, ReLU and
          (1, 4) max pooling.
        * ``features``: one ``DenseModule`` followed by a ``TransitionModule``
          that maps back to ``num_out_channels`` channels.
        * ``classifier``: batch norm, ReLU, adaptive average pooling to
          ``(128, 1)``, flatten, then a linear layer to ``num_classes``.

    Args:
        num_in_channels: channel count of the network input.
        num_out_channels: base width used by every stage.
        num_dense_modules: number of inception layers in the dense module.
        num_classes: size of the output vector (default 10).
    """

    def __init__(self, num_in_channels, num_out_channels, num_dense_modules, num_classes=10):
        super().__init__()

        width = num_out_channels

        stem_conv = nn.Conv2d(num_in_channels, width,
                              kernel_size=(1, 3), padding=(0, 1))
        self.preprocessing = nn.Sequential(
            stem_conv,
            nn.BatchNorm2d(width, eps=0.001),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(1, 4)),
        )

        # Each dense layer appends 4 * width channels to the running feature map.
        dense_out_channels = width + num_dense_modules * 4 * width
        dense_block = DenseModule(num_dense_modules, width, width)
        transition = TransitionModule(dense_out_channels, width)
        self.features = nn.Sequential(dense_block, transition)

        # The pooled map is (128, 1) per channel, hence width * 128 features.
        self.classifier = nn.Sequential(
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((128, 1)),
            Flatten(),
            nn.Linear(width * 128, num_classes),
        )

    def forward(self, x):
        """Apply preprocessing, features and classifier in that order."""
        return self.classifier(self.features(self.preprocessing(x)))


In [4]:
# Single-channel input, base width 32, three inception layers; num_classes keeps its default.
net = DenseInception(num_in_channels=1, num_out_channels=32, num_dense_modules=3)

In [5]:
# Random tensor of shape (3, 10).
# NOTE(review): this is 2-D, while the network starts with an nn.Conv2d;
# the cell looks unfinished — confirm the intended input shape.
inputs = torch.randn(3, 10)