In [None]:
import os
import os.path as osp
import sys
from collections import OrderedDict
sys.path.append(osp.abspath('..'))

import numpy as np
import torch
import torch.cuda as cuda
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm_notebook as tqdm

from datasets.gtzan import GTZAN_MELSPEC as GTZAN

%load_ext autoreload
%autoreload 2

In [2]:
# Random seeds for torch (CPU), every CUDA device, and numpy.
# NOTE(review): DataLoader worker processes and cuDNN kernel selection can still make
# runs non-bit-reproducible; seeding alone does not guarantee determinism.
torch.manual_seed(1234)
cuda.manual_seed_all(1234)
np.random.seed(1234)

# Training / cross-validation hyper-parameters.
SEGMENTS = 10                      # `min_segments` argument passed to the GTZAN dataset
BATCH_SIZE = 8                     # clips per DataLoader batch
EPOCHS = 100                       # upper bound on epochs per fold (early stopping may end sooner)
LR = 1e-3                          # initial Adam learning rate
NUM_CLASSES = 10                   # number of output classes
NUM_KFOLD = 10                     # number of StratifiedKFold splits
NUM_REDUCE_LR_PATIENCE = 3         # ReduceLROnPlateau patience, in epochs
NUM_EARLY_STOPPING_PATIENCE = 10   # epochs without training-loss improvement before stopping

# Use the first GPU when present; fall back to CPU so the notebook still runs without CUDA.
DEVICE = torch.device('cuda:0' if cuda.is_available() else 'cpu')

In [3]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into a single one.

    A tensor of shape ``(N, d1, d2, ...)`` is returned as ``(N, d1 * d2 * ...)``.
    ``Tensor.view`` is used, so no data is copied and the input must be viewable
    (e.g. contiguous).
    """

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
    

class NET(nn.Module):
    """VGG16-BN convolutional trunk followed by a three-layer MLP classifier.

    Args:
        num_classes: number of output logits.
        in_features: length of the flattened feature map produced by
            ``self.features`` for the expected input size. The default (4096)
            keeps the original hard-coded value; it only matches inputs whose
            spatial size, after the VGG trunk's downsampling, flattens to 4096
            values. TODO confirm against the mel-spectrogram shape from the
            dataset before changing input sizes.
    """

    def __init__(self, num_classes=10, in_features=4096):
        super().__init__()
        # Convolutional part of torchvision's VGG16 with batch norm; constructed
        # without requesting pretrained weights.
        self.features = models.vgg16_bn().features
        self.classifier = nn.Sequential(
            Flatten(),
            nn.Linear(in_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        """Map a batch of images ``x`` to ``(batch, num_classes)`` logits."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [4]:
def run_epoch(net, dataloader, criterion, phase, optimizer=None):
    """Run one pass over ``dataloader`` and return epoch-level metrics.

    Each batch is ``(inputs, labels)``. ``inputs`` is unpacked as
    ``(num_samples, num_segments, num_freqs, num_frames)`` and ``labels`` holds
    one class index per sample. Every segment is classified on its own, with the
    single channel replicated to 3 for the VGG trunk. The sample-level prediction
    is the argmax of the summed per-segment softmax probabilities.

    Args:
        net: model mapping ``(N, 3, num_freqs, num_frames)`` to ``(N, C)`` logits.
        dataloader: iterable (with ``len``) of ``(inputs, labels)`` batches.
        criterion: loss applied to per-segment logits and per-segment labels.
        phase: ``'train'`` enables gradients and optimizer steps; any other
            value evaluates with gradients disabled and the model in eval mode.
        optimizer: optimizer stepped during training. When omitted, the
            notebook-level global ``optimizer`` is used (original behavior).

    Returns:
        ``(loss, acc, seg_acc)``: mean per-segment loss, sample-level accuracy
        and segment-level accuracy over the whole epoch. All three are ``nan``
        when ``dataloader`` yields no batches.

    Raises:
        RuntimeError: if ``phase == 'train'`` and no optimizer is available.
    """
    training = phase == 'train'
    net.train(training)  # train(False) is exactly net.eval()

    if training and optimizer is None:
        # Backward compatibility: callers historically relied on a global `optimizer`.
        optimizer = globals().get('optimizer')
        if optimizer is None:
            raise RuntimeError("run_epoch(phase='train') requires an optimizer")

    running_samples = 0
    running_segments = 0
    running_loss = 0.0
    running_corrects = 0
    running_seg_corrects = 0
    # Defined up front so an empty dataloader returns NaNs instead of UnboundLocalError.
    epoch_loss = epoch_acc = epoch_seg_acc = float('nan')

    with tqdm(dataloader, total=len(dataloader)) as progress:
        for inputs, labels in progress:
            num_samples, num_segments, num_freqs, num_frames = inputs.shape
            batch_segments = num_samples * num_segments
            inputs = inputs.float().to(DEVICE)
            sample_labels = labels.to(DEVICE)

            # Fold segments into the batch axis, sample-major: (S*G, 1, F, T) -> (S*G, 3, F, T).
            inputs = inputs.view(batch_segments, 1, num_freqs, num_frames).expand(-1, 3, -1, -1)
            # Repeat each sample's label once per segment, in the same order as `inputs`.
            seg_labels = sample_labels.unsqueeze(1).expand(num_samples, num_segments).reshape(-1)

            with torch.set_grad_enabled(training):
                outputs = net(inputs)
                batch_loss = criterion(outputs, seg_labels)
                if training:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    optimizer.step()

            running_samples += num_samples
            running_segments += batch_segments
            # criterion returns a per-segment mean; re-weight so the epoch mean is exact.
            running_loss += batch_loss.item() * batch_segments

            with torch.no_grad():  # metrics never need an autograd graph
                # Sample-level prediction: sum class probabilities over a clip's segments.
                probs = F.softmax(outputs.reshape(num_samples, num_segments, -1), dim=2)
                _, preds = torch.max(probs.sum(dim=1), 1)
                running_corrects += torch.sum(preds == sample_labels).item()

                _, seg_preds = torch.max(outputs, 1)
                running_seg_corrects += torch.sum(seg_preds == seg_labels).item()

            epoch_loss = running_loss / running_segments
            epoch_acc = running_corrects / running_samples
            epoch_seg_acc = running_seg_corrects / running_segments

            progress.set_postfix(OrderedDict(
                phase=phase,
                loss='{:.4f}'.format(epoch_loss),
                acc='{:.2%}'.format(epoch_acc),
                seg_acc='{:.2%}'.format(epoch_seg_acc)))

    return epoch_loss, epoch_acc, epoch_seg_acc

In [5]:
# Stratified k-fold cross-validation. For each fold: train a fresh NET, evaluate on the
# held-out fold every epoch, keep the weights with the best held-out accuracy, and append
# that model's metrics to `cv_results`.
# NOTE(review): the held-out fold is used both to select the best epoch and to report the
# score, so the reported test accuracy is optimistically biased.
# Kept at module level on purpose: run_epoch relies on the module-level name `optimizer`.
criterion = nn.CrossEntropyLoss()
cv_results = []
dataset = GTZAN(phase='all', min_segments=SEGMENTS)
# Per-fold containers whose X/Y are swapped in below; only the training copy is built with
# `randomized=True` (presumably random segment sampling — TODO confirm in datasets/gtzan.py).
train_set = GTZAN(phase='all', min_segments=SEGMENTS, randomized=True)
test_set = GTZAN(phase='all', min_segments=SEGMENTS)

checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save below fails if the directory is missing

skf = StratifiedKFold(NUM_KFOLD, shuffle=True, random_state=1234)
for kfold, (train_index, test_index) in enumerate(skf.split(dataset.X, dataset.Y)):
    train_set.X, train_set.Y = dataset.X[train_index], dataset.Y[train_index]
    test_set.X, test_set.Y = dataset.X[test_index], dataset.Y[test_index]
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=9, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=9, pin_memory=True)

    dataloaders = {
        'train': train_loader,
        'test': test_loader,
    }

    net = NET().to(DEVICE)
    optimizer = optim.Adam(net.parameters(), lr=LR)

    # LR schedule and early stopping both track the *training* loss.
    reduce_lr = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=NUM_REDUCE_LR_PATIENCE)

    early_stopping_patience = NUM_EARLY_STOPPING_PATIENCE
    best_loss = float('inf')

    best_net = NET().to(DEVICE)
    best_test_acc = 0
    best_info = {}
    tb_writer = SummaryWriter('runs/vgg16_fold_%d' % kfold)

    try:
        with tqdm(range(EPOCHS), total=EPOCHS) as epoch_progress:
            for epoch in epoch_progress:
                # --- train phase ---
                train_loss, train_acc, train_seg_acc = run_epoch(
                    net, dataloaders['train'], criterion, 'train')
                reduce_lr.step(train_loss)
                if train_loss > best_loss:
                    early_stopping_patience -= 1
                else:
                    early_stopping_patience = NUM_EARLY_STOPPING_PATIENCE
                best_loss = min(best_loss, train_loss)
                early_stopping = early_stopping_patience <= 0

                # --- test phase ---
                # Always evaluated, including on the epoch that triggers early stopping, so
                # the last trained weights are scored and no -1 sentinels reach TensorBoard.
                test_loss, test_acc, test_seg_acc = run_epoch(
                    net, dataloaders['test'], criterion, 'test')
                if test_acc > best_test_acc:
                    best_net.load_state_dict(net.state_dict())
                    best_test_acc = test_acc

                    # Re-score the selected weights on the training fold in eval mode
                    # (dropout off, batch-norm statistics frozen) for a comparable train metric.
                    train_loss_, train_acc_, train_seg_acc_ = run_epoch(
                        best_net, dataloaders['train'], criterion, 'test')

                    best_info = {
                        'train_loss': train_loss_,
                        'train_acc': train_acc_,
                        'train_seg_acc': train_seg_acc_,
                        'test_loss': test_loss,
                        'test_acc': test_acc,
                        'test_seg_acc': test_seg_acc,
                    }

                    checkpoint_path = osp.join(
                        checkpoint_dir, 'best_vgg16_baseline_params_cv_{}.pth'.format(kfold))
                    with open(checkpoint_path, 'wb') as f:
                        torch.save(best_net.state_dict(), f)

                for tag, value in (('train_loss', train_loss),
                                   ('train_acc', train_acc),
                                   ('train_seg_acc', train_seg_acc),
                                   ('test_loss', test_loss),
                                   ('test_acc', test_acc),
                                   ('test_seg_acc', test_seg_acc)):
                    tb_writer.add_scalar(tag, value, epoch)

                epoch_progress.set_postfix(OrderedDict(
                    train_loss='{:.4f}'.format(train_loss),
                    train_acc='{:.2%}'.format(train_acc),
                    train_seg_acc='{:.2%}'.format(train_seg_acc),
                    test_loss='{:.4f}'.format(test_loss),
                    test_acc='{:.2%}'.format(test_acc),
                    test_seg_acc='{:.2%}'.format(test_seg_acc)))

                if early_stopping:
                    break  # leaving the `with` block closes the progress bar
    finally:
        # Flush and release the event file even if training is interrupted.
        tb_writer.close()

    # Collect cross-validate summaries
    cv_results.append(best_info)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=225), HTML(value='')))

HBox(children=(IntProgress(value=0, max=25), HTML(value='')))

HBox(children=(IntProgress(value=0, max=225), HTML(value='')))

HBox(children=(IntProgress(value=0, max=225), HTML(value='')))

Process Process-34:
Process Process-31:
Process Process-36:
Process Process-35:
Process Process-32:
Process Process-33:
Process Process-28:
Process Process-30:
Process Process-29:
Traceback (most recent call last):
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/youchen/miniconda3/lib/python3.6/multiproce

  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/youchen/miniconda3/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
KeyboardInterrupt
  File "/home/youchen/miniconda3/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/youchen/MusicNew/datasets/gtzan.py", line 167, in <listcomp>
    x = np.array([features.get_melspectrogram(xr) for xr in x])
  File "/home/youchen/miniconda3/lib/python3.6/selectors.p




KeyboardInterrupt: 

In [None]:
# Report each fold's best-model metrics, then the mean over folds.
# Folds whose entry is empty (no epoch ever beat 0% test accuracy) are skipped, and an empty
# `cv_results` (e.g. after an interrupted run) is reported instead of dividing by zero.
cv_metrics = [
    # (printed label, key in best_info, value format)
    ('train loss', 'train_loss', '{:.4f}'),
    ('train acc', 'train_acc', '{:.2%}'),
    ('train seg acc', 'train_seg_acc', '{:.2%}'),
    ('test loss', 'test_loss', '{:.4f}'),
    ('test acc', 'test_acc', '{:.2%}'),
    ('test seg acc', 'test_seg_acc', '{:.2%}'),
]
completed_folds = [(kfold, result) for kfold, result in enumerate(cv_results) if result]

for kfold, result in completed_folds:
    print('Fold {}, '.format(kfold) + ', '.join(
        '{}: {}'.format(label, fmt.format(result[key])) for label, key, fmt in cv_metrics))

if completed_folds:
    print('{}-fold cross-validation'.format(len(completed_folds)))
    for label, key, fmt in cv_metrics:
        mean_value = sum(result[key] for _, result in completed_folds) / len(completed_folds)
        print('{}: {}'.format(label, fmt.format(mean_value)))
else:
    print('No completed cross-validation folds to summarize.')