In [1]:
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
import os
import glob
import re

test_data_path = r"D:\Documents\datasets\AIST4010\muse\wavs\000xQL6tZNLJzIrtIgxqSl.npy"
checkpoint_path = r"Wavegram_Cnn14_mAP=0.389.pth"

In [2]:
def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class ConvPreWavBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvPreWavBlock, self).__init__()
        
        self.conv1 = nn.Conv1d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=3, stride=1,
                              padding=1, bias=False)
                              
        self.conv2 = nn.Conv1d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=3, stride=1, dilation=2, 
                              padding=2, bias=False)
                              
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        x = F.max_pool1d(x, kernel_size=pool_size)
        
        return x

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x   

In [3]:
class Wavegram_Cnn14(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 
        fmax, classes_num):
        
        super(Wavegram_Cnn14, self).__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.pre_conv0 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, stride=5, padding=5, bias=False)
        self.pre_bn0 = nn.BatchNorm1d(64)
        self.pre_block1 = ConvPreWavBlock(64, 64)
        self.pre_block2 = ConvPreWavBlock(64, 128)
        self.pre_block3 = ConvPreWavBlock(128, 128)
        self.pre_block4 = ConvBlock(in_channels=4, out_channels=64)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
        
        self.init_weight()

    def init_weight(self):
        init_layer(self.pre_conv0)
        init_bn(self.pre_bn0)
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)
 
    def forward(self, input, mixup_lambda=None):
        """
        Input: (batch_size, data_length)"""

        # Wavegram
        a1 = F.relu_(self.pre_bn0(self.pre_conv0(input[:, None, :])))
        a1 = self.pre_block1(a1, pool_size=4)
        a1 = self.pre_block2(a1, pool_size=4)
        a1 = self.pre_block3(a1, pool_size=4)
        a1 = a1.reshape((a1.shape[0], -1, 32, a1.shape[-1])).transpose(2, 3)
        a1 = self.pre_block4(a1, pool_size=(2, 1))

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            a1 = do_mixup(a1, mixup_lambda)
        
        x = a1
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)
        
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))
        
        output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return output_dict

In [4]:
class WaveNet(nn.Module):
    def __init__(self, out_dim, fe_dim=2048, 
                 sr=22050, wsize=520, hsize=320, mel_bins=128, fmin=50, fmax=8000, 
                 fcs=[], dropout=0.2, act=nn.ReLU, init=nn.init.kaiming_normal_):
        super().__init__()
        self.wavecnn = Wavegram_Cnn14(sample_rate=sr, window_size=wsize, 
                                      hop_size=hsize, mel_bins=mel_bins, 
                                      fmin=fmin, fmax=fmax, classes_num=527)
        self.wavecnn.fc_audioset.require_grad = False
        checkpoint = torch.load(r"Wavegram_Cnn14_mAP=0.389.pth", 
                                map_location="cuda")
        self.wavecnn.load_state_dict(checkpoint['model'])
        fcs = [fe_dim] + fcs + [out_dim]
        fc_layers = []
        for idx in range(1, len(fcs)):
            if idx != 1:
                fc_layers.append(nn.Dropout(dropout))
            fc_layers.append(nn.Linear(fcs[idx-1], fcs[idx]))
            if act and idx != (len(fcs) - 1):
                fc_layers.append(act())
        if init:
            for layer in fc_layers:
                if type(layer) == nn.Linear:
                    init(layer.weight)
        self.classifier = nn.Sequential(*fc_layers)

    def forward(self, x, mixup_lambda=None):
        spectrum_features = self.wavecnn(x, mixup_lambda)['embedding']
        return self.classifier(spectrum_features)

In [5]:
sr = 22050
wsize, hsize, mel_bins = 520, 320, 128
fmin, fmax = 50, 8000
fcs, dropout, act = [2048], 0.5, nn.ReLU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = WaveNet(1, 2048, sr=sr, wsize=wsize, hsize=hsize, mel_bins=mel_bins, 
                fmin=fmin, fmax=fmax, fcs=fcs, dropout=dropout, act=act).half().to(device)

In [6]:
from torch.utils.data import Dataset, DataLoader

data_dir = r"D:\Documents\datasets\AIST4010\muse"
wav_dir = os.path.join(data_dir, "wavs")
songs_fp = os.path.join(data_dir, "extracted_data.csv")

def get_labels(ids, fp=songs_fp):
    songs_data = pd.read_csv(fp)
    songs_data.set_index("spotify_id", inplace=True)
    labels = songs_data.loc[ids, ["valence_tags", "arousal_tags", "dominance_tags"]].values
    return labels

class LazyWavDataset(Dataset):
    def __init__(self, dir_path, labels, re_pattern, 
                 transform=None, extension='.npy', 
                 pad_len=661500, sample_size=8500):
        super().__init__()
        self.files = glob.glob(os.path.join(dir_path, '*'+extension))
        self.files.sort(key=lambda fp: re.match(re_pattern, fp).group(1))
        self.ids = [re.match(re_pattern, fp).group(1) for fp in self.files]
        self.labels = torch.from_numpy(labels)
        if sample_size:
            self.files = self.files[:8500]
            self.labels = self.labels[:8500]
        self.transform = transform
        self.pad_len = pad_len
    
    def __getitem__(self, item):
        file = self.files[item]
        song_id = self.ids[item]
        wav = np.load(file)
        pad_len = self.pad_len
        if pad_len:
            if len(wav) >= pad_len:
                wav = wav[:pad_len]
            else:
                wav = np.pad(wav, (0, pad_len - len(wav)), 
                             mode='constant', constant_values=(0, 0))
        file = torch.from_numpy(wav)
        label = self.labels[item]
        if self.transform:
            file = self.transform(file)
        return song_id, file, label
    
    def __len__(self):
        return len(self.files)

In [7]:
npy_fps = glob.glob(os.path.join(wav_dir, '*.npy'))
re_pattern = r".*\\([^\\\.]*)\.npy"
npy_fps.sort(key=lambda fp: re.match(re_pattern, fp).group(1))
song_ids = [re.match(re_pattern, fp).group(1) for fp in npy_fps]
labels = get_labels(song_ids)[:, 0] / 9

train_size = 8500
ds = LazyWavDataset(wav_dir, labels, re_pattern, sample_size=train_size)
loader = DataLoader(ds, batch_size=32)

In [8]:
LR, MOMENTUM, DECAY = 1e-3, 0.9, 1e-3
criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=DECAY)
optimizer = torch.optim.SGD(model.parameters(), lr=LR, 
                            momentum=MOMENTUM, weight_decay=DECAY)

EPOCHS = 10
for epoch in range(EPOCHS):
    running_loss = 0
    for ids, wavs, labels in loader:
        optimizer.zero_grad()
        wavs, labels = wavs.half().to(device), labels.reshape(-1, 1).half().to(device)
#         input()
        outputs = model(wavs, None)
#         input()
#         print(outputs.shape, wavs.shape, labels.shape)
#         input()
#         break
        loss = criterion(outputs, labels)
        running_loss += loss.item() * len(ids)
        loss.backward()
        optimizer.step()
    epoch_loss = running_loss / len(loader.dataset)
    print(f"loss: {epoch_loss:.5f}")

loss: 0.17147
loss: 0.06327
loss: 0.04927
loss: 0.04261
loss: 0.04078
loss: 0.03924
loss: 0.03784
loss: 0.03665
loss: 0.03703
loss: 0.03545


In [11]:
criterion(outputs, labels)

tensor(33.6250, device='cuda:0', dtype=torch.float16,
       grad_fn=<MseLossBackward0>)

In [14]:
import sys
sys.getsizeof(wavs.storage())

21168056

In [9]:
import time