In [1]:
import soundata
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from IPython.display import Audio

import librosa
import torchlibrosa
import tqdm
import models

In [2]:
class MixAugmentDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, signal_categories, noise_categories, folds=[1,2,3], augment_type=None, sample_rate=22050, 
                    window_size=2048, mel_bins=128, hop_size=512, fmin=50, fmax=11025):
        self.augment_type = augment_type
        self.sample_rate = sample_rate
        
        x = []
        y = []
        noise = []
        clips = esc50.load_clips()
        
        for clip_id in tqdm.tqdm(dataset.clip_ids):
            clip = clips[clip_id]
            if clip.fold in folds:
                waveform = clip.audio[0]
                original_sample_rate = clip.audio[1]
                
                waveform = librosa.resample(
                    waveform, 
                    original_sample_rate, 
                    sample_rate,
                    res_type='kaiser_fast'
                )
                
                self.sample_rate = clip.audio[1]
                
                if clip.category in noise_categories:
                    noise.append(waveform)
                elif clip.category in signal_categories:
                    x.append(waveform)
                    # set label to the index of the signal category
                    label = np.where(signal_categories == clip.category)[0][0]
                    y.append(label)
                    
        self.x = np.array(x)
        self.y = np.array(y)
        self.noise = np.array(noise)
        
        # feature extractors   
        def logmel_extractor(z):
            return librosa.feature.melspectrogram(
                y          = z, 
                sr         = self.sample_rate,
                n_mels     = mel_bins,
                n_fft      = window_size, 
                hop_length = hop_size, 
                win_length = None, 
                window     = 'hann', 
                center     = True, 
                pad_mode   = 'reflect', 
                power      = 2.0,
                fmin       = fmin,
                fmax       = fmax,
                #ref        = 1.0,
                #amin       = 1e-10,
                #top_db     = None
            )
        
        self.logmel_extractor = logmel_extractor
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        
        if self.augment_type == 'waveform':
            p = np.random.uniform(0.1, 0.3)
            idx = np.random.randint(0, len(self.noise))
            aug = self.noise[idx]
            x_aug = (x * (1-p)) + (aug * p)
            
            # extract logmel after augmentation
            x_aug = self.logmel_extractor(x_aug)
            x_aug = np.transpose(x_aug, (1, 0))
            x_aug = np.expand_dims(x_aug, 0)
            
        elif self.augment_type == 'logmel':
            p = np.random.uniform(0.1, 0.3)
            idx = np.random.randint(0, len(self.noise))
            aug = self.noise[idx]
            
            # extract logmel before augmentation
            x = self.logmel_extractor(x)
            x = np.transpose(x, (1, 0))
            x = np.expand_dims(x, 0)
            
            aug = self.logmel_extractor(aug)
            aug = np.transpose(aug, (1, 0))
            aug = np.expand_dims(aug, 0)
            
            x_aug = (x * (1-p)) + (aug * p)
        else:
            # extract logmel without augmentation
            x_aug = self.logmel_extractor(x)
            x_aug = np.transpose(x_aug, (1, 0))
            x_aug = np.expand_dims(x_aug, 0)
        return x_aug, y

In [3]:
def random_signal_and_noise_categories(dataset, nb_signal_categories):
    categories = []
    clips = dataset.load_clips()
    clip_ids = dataset.clip_ids
    for clip_id in clip_ids:
        categories.append(clips[clip_id].category)
    categories = list(set(categories))

    signal_categories = np.random.choice(categories, nb_signal_categories, replace=False)
    noise_categories = list(set(categories).difference(set(signal_categories)))
    
    return signal_categories, noise_categories

In [4]:
def train(model, optimizer, loss_function, train_loader):
    model.train()
    
    running_loss = 0
    count = 0
    for (x, y) in tqdm.tqdm(train_loader):
        x = x.cuda()
        y = y.type(torch.LongTensor).cuda()
        
        optimizer.zero_grad()
        
        y_pred = model(x)['clipwise_output']
        loss = loss_function(y_pred, y)
        loss.backward()
        
        optimizer.step()
    
        running_loss += loss.item()
        count += 1
    return running_loss / count

In [5]:
def evaluate(model, loss_function, loader):
    model.eval()
    
    count = 0
    running_acc = 0
    running_loss = 0
    
    ys = []
    ys_pred_probs = []
    for (x, y) in loader:
        x = x.cuda()
        y = y.type(torch.LongTensor).cuda()
        
        y_pred = model(x)['clipwise_output']
        loss = loss_function(y_pred, y)
        running_loss += loss.item()
        
        y_pred_prob = y_pred.detach().cpu().numpy()
        y_pred = np.argmax(y_pred_prob, axis=1)
        y      = y.detach().cpu().numpy().astype(np.int)

        running_acc += np.mean(y==y_pred)
        
        count+=1
        
        ys.append(y)
        ys_pred_probs.append(y_pred_prob)
    
    return running_loss / count, running_acc / count, np.concatenate(ys), np.concatenate(ys_pred_probs)

In [None]:
# setup writer
writer = SummaryWriter(log_dir='log_dir/no_augmentation')

# setup model
sample_rate = 22050
window_size = 2048
hop_size = 512
mel_bins = 128
fmin = 50
fmax = sample_rate // 2
classes_num = 10
feature_type = 'logmel'

model = models.Cnn6(sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num, feature_type)
model = model.cuda()

best_model = models.Cnn6(sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num, feature_type)
best_model = best_model.cuda()

# setup training
learning_rate = 1e-4
patience = 100
epochs = 1000

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
loss_function = torch.nn.CrossEntropyLoss()

augment_type = None

# setup datasets
esc50 = soundata.initialize('esc50')
train_folds = [1,2,3]
valid_folds = [4]
test_folds = [5]

signal_categories, noise_categories = random_signal_and_noise_categories(esc50, classes_num)

train_dataset = MixAugmentDataset(esc50, signal_categories, noise_categories, folds=train_folds, 
                                 augment_type=augment_type, sample_rate=sample_rate,
                                 window_size=window_size, mel_bins=mel_bins, hop_size=hop_size,
                                 fmin=fmin, fmax=fmax)
valid_dataset = MixAugmentDataset(esc50, signal_categories, noise_categories, folds=valid_folds, 
                                 augment_type=augment_type, sample_rate=sample_rate,
                                 window_size=window_size, mel_bins=mel_bins, hop_size=hop_size,
                                 fmin=fmin, fmax=fmax)
# test_dataset = MixAugmentDataset(esc50, signal_categories, noise_categories, folds=test_folds, 
#                                 augment=augment, sample_rate=sample_rate,
#                                 window_size=window_size, mel_bins=mel_bins, hop_size=hop_size,
#                                 fmin=fmin, fmax=fmax)


train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=8)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=16, shuffle=False, num_workers=8)
#test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8)


best_valid_loss = np.inf
best_epoch = 0
epoch = 0
not_converged = True
while not_converged:
    print("Epoch: {}".format(epoch))
    train_loss = train(model, optimizer, loss_function, train_loader)
    valid_loss, valid_acc, _, _ = evaluate(model, loss_function, valid_loader)
    print("valid loss: {}, acc: {}".format(valid_loss, valid_acc))
    writer.add_scalar('loss/train', train_loss, epoch)
    writer.add_scalar('loss/valid', valid_loss, epoch)
    writer.add_scalar('acc/valid', valid_acc, epoch)

    epoch += 1

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_epoch = epoch
        best_model.load_state_dict(model.state_dict())

    # convergence criterion
    if epoch - best_epoch >= patience or epoch > epochs:
        not_converged = False

100%|██████████| 2000/2000 [00:35<00:00, 56.84it/s]
100%|██████████| 2000/2000 [00:11<00:00, 171.76it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

Epoch: 0


100%|██████████| 15/15 [00:01<00:00, 14.83it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2728600025177004, acc: 0.15
Epoch: 1


100%|██████████| 15/15 [00:00<00:00, 18.93it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2112403869628907, acc: 0.175
Epoch: 2


100%|██████████| 15/15 [00:00<00:00, 18.64it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1711857318878174, acc: 0.2
Epoch: 3


100%|██████████| 15/15 [00:00<00:00, 19.06it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1911591053009034, acc: 0.15
Epoch: 4


100%|██████████| 15/15 [00:00<00:00, 17.86it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.192969465255737, acc: 0.175
Epoch: 5


100%|██████████| 15/15 [00:00<00:00, 18.41it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2193193435668945, acc: 0.1875
Epoch: 6


100%|██████████| 15/15 [00:00<00:00, 19.17it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2349711894989013, acc: 0.15
Epoch: 7


100%|██████████| 15/15 [00:00<00:00, 18.32it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.189618635177612, acc: 0.1625
Epoch: 8


100%|██████████| 15/15 [00:00<00:00, 18.79it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2182164669036863, acc: 0.1625
Epoch: 9


100%|██████████| 15/15 [00:00<00:00, 18.69it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.203043270111084, acc: 0.175
Epoch: 10


100%|██████████| 15/15 [00:00<00:00, 18.56it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2104072093963625, acc: 0.2
Epoch: 11


100%|██████████| 15/15 [00:00<00:00, 18.89it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.216559886932373, acc: 0.1875
Epoch: 12


100%|██████████| 15/15 [00:00<00:00, 18.51it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1887277126312257, acc: 0.125
Epoch: 13


100%|██████████| 15/15 [00:00<00:00, 18.80it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2198331356048584, acc: 0.2
Epoch: 14


100%|██████████| 15/15 [00:00<00:00, 18.93it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.141807746887207, acc: 0.25
Epoch: 15


100%|██████████| 15/15 [00:00<00:00, 18.41it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.192142963409424, acc: 0.175
Epoch: 16


100%|██████████| 15/15 [00:00<00:00, 18.51it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1690470218658446, acc: 0.225
Epoch: 17


100%|██████████| 15/15 [00:00<00:00, 18.62it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.229466199874878, acc: 0.1625
Epoch: 18


100%|██████████| 15/15 [00:00<00:00, 18.52it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1450403213500975, acc: 0.2
Epoch: 19


100%|██████████| 15/15 [00:00<00:00, 18.88it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.2095767498016357, acc: 0.175
Epoch: 20


100%|██████████| 15/15 [00:00<00:00, 18.89it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.145741605758667, acc: 0.225
Epoch: 21


100%|██████████| 15/15 [00:00<00:00, 19.11it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.167265844345093, acc: 0.225
Epoch: 22


100%|██████████| 15/15 [00:00<00:00, 18.09it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1488459587097166, acc: 0.2
Epoch: 23


100%|██████████| 15/15 [00:00<00:00, 18.76it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1081613063812257, acc: 0.2625
Epoch: 24


100%|██████████| 15/15 [00:00<00:00, 18.81it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.124941825866699, acc: 0.3
Epoch: 25


100%|██████████| 15/15 [00:00<00:00, 18.54it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1008806943893434, acc: 0.2625
Epoch: 26


100%|██████████| 15/15 [00:00<00:00, 19.21it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1931944370269774, acc: 0.225
Epoch: 27


100%|██████████| 15/15 [00:00<00:00, 18.63it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.11301646232605, acc: 0.2375
Epoch: 28


100%|██████████| 15/15 [00:00<00:00, 18.92it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.056243419647217, acc: 0.3125
Epoch: 29


100%|██████████| 15/15 [00:00<00:00, 18.98it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.0967803478240965, acc: 0.25
Epoch: 30


100%|██████████| 15/15 [00:00<00:00, 18.95it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.131458377838135, acc: 0.2625
Epoch: 31


100%|██████████| 15/15 [00:00<00:00, 18.54it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.070169973373413, acc: 0.275
Epoch: 32


100%|██████████| 15/15 [00:00<00:00, 18.69it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1517130851745607, acc: 0.25
Epoch: 33


100%|██████████| 15/15 [00:00<00:00, 18.57it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.0704947471618653, acc: 0.325
Epoch: 34


100%|██████████| 15/15 [00:00<00:00, 18.44it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1654253005981445, acc: 0.2625
Epoch: 35


100%|██████████| 15/15 [00:00<00:00, 18.46it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1614270210266113, acc: 0.225
Epoch: 36


100%|██████████| 15/15 [00:00<00:00, 18.61it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.183646631240845, acc: 0.1625
Epoch: 37


100%|██████████| 15/15 [00:00<00:00, 18.77it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.1265883922576903, acc: 0.275
Epoch: 38


100%|██████████| 15/15 [00:00<00:00, 18.46it/s]
  0%|          | 0/15 [00:00<?, ?it/s]

valid loss: 2.0673206090927123, acc: 0.325
Epoch: 39
