In [7]:
import numpy as np
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck
import torch.nn.functional as F
import pandas as pd
import torch.utils.data as data
import argparse
import sys

In [8]:
class AdaptedResNet(ResNet):
    """Two-class ResNet-18 whose first convolution does not downsample.

    The network is torchvision's ``ResNet`` built from ``BasicBlock`` with the
    ResNet-18 stage layout ``[2, 2, 2, 2]`` and a 2-way classification head.
    Only ``conv1`` is replaced: a 7x7, stride-1, padding-3 convolution from 3 to
    64 channels. All other parent layers are left untouched.
    """

    def __init__(self):
        super().__init__(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=2)
        # Swapped in after the parent constructor has run, so this layer is not
        # affected by any weight initialisation the parent constructor performs.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=7,
            stride=1,
            padding=3,
            bias=False,
        )

class GenreClassifier(nn.Module):
    """Classify a flat feature vector by reshaping it into a 3x32x32 "image".

    A learnable linear layer maps each input vector to 3 * 32 * 32 = 3072
    values. These are viewed as a 3-channel 32x32 tensor and fed to
    ``AdaptedResNet``, whose logits are returned.

    Args:
        in_features: length of each input feature vector. Defaults to 2998,
            the width this notebook's model was built for.
    """

    def __init__(self, in_features=2998):
        super(GenreClassifier, self).__init__()
        # Keep creation order (fc first, then the ResNet) so seeded weight
        # initialisation is identical to earlier runs.
        self.fc = nn.Linear(in_features, 3 * 32 * 32)
        self.net = AdaptedResNet()

    def forward(self, x):
        """Return class logits for ``x``.

        ``x`` is presumably ``(batch, in_features)``. A single unbatched
        ``(in_features,)`` vector also works and yields a batch of one.
        """
        x = self.fc(x)
        # -1 infers the batch size from the 3072-wide projection.
        x = x.view(-1, 3, 32, 32)
        x = self.net(x)
        return x

In [9]:
def train(args, model, device, train_loader, optimizer, epoch, loss_fn):
    """Train ``model`` for one epoch.

    Each batch from ``train_loader`` goes through these steps:
    1. Move inputs and targets to ``device``.
    2. Run the forward pass.
    3. Compute ``loss_fn(output, target)``.
    4. Back-propagate and take one optimizer step.

    Every ``args.log_interval`` batches, starting with the first, a progress
    line is printed. It shows the batch loss and the learning rate of the
    first parameter group. With ``args.dry_run`` set, the epoch ends right
    after the first progress line.

    Args:
        args: namespace providing ``log_interval`` (int) and ``dry_run`` (bool).
        model: module to train; it is put into training mode here.
        device: device the batches are moved to.
        train_loader: iterable of ``(inputs, targets)`` batches that supports
            ``len()`` and exposes ``.dataset``. Both are read only for logging.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress line.
        loss_fn: callable returning a scalar loss tensor.
    """
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        # Only report on steps that land on the logging interval.
        if step % args.log_interval:
            continue

        seen = step * len(inputs)
        progress = 100. * step / len(train_loader)
        current_lr = optimizer.param_groups[0]['lr']
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tlr: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset),
            progress, batch_loss.item(), current_lr))
        if args.dry_run:
            break

In [17]:
def test(model, device, test_loader, loss_fn):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        acc))
    return acc

In [18]:
class BeatDataset(data.Dataset):
    """Map-style dataset of pre-extracted feature arrays with binary genre labels.

    Loads ``train_data.npy`` / ``train_labels.npy`` (``train=True``) or
    ``test_data.npy`` / ``test_labels.npy`` (``train=False``) from ``root``.
    A label equal to the string ``'Jazz'`` becomes class 0; every other label
    becomes class 1.

    Args:
        train: select the training files if true, otherwise the test files.
        root: directory containing the ``.npy`` files. Defaults to the current
            working directory, matching the original behaviour.
    """

    def __init__(self, train, root='.'):
        if train:
            data_file, labels_file = 'train_data.npy', 'train_labels.npy'
        else:
            data_file, labels_file = 'test_data.npy', 'test_labels.npy'
        self.dataset = np.load(os.path.join(root, data_file))
        raw_labels = np.load(os.path.join(root, labels_file))
        # Binary target: 'Jazz' -> 0, anything else -> 1; int64 for CrossEntropyLoss.
        self.labels = np.array([(0 if l == 'Jazz' else 1) for l in raw_labels]).astype('int64')

    def __getitem__(self, index):
        """Return ``(features, label)`` tensors for sample ``index``."""
        # from_numpy shares memory with the loaded array and keeps its dtype.
        features = torch.from_numpy(self.dataset[index])
        label = torch.tensor(self.labels[index])
        return (features, label)

    def __len__(self):
        return len(self.labels)

In [21]:
# Command-line configuration. `-f/--fff` is a catch-all option, presumably for the
# `-f <connection file>` argument the Jupyter kernel is started with, so that
# `parse_args(sys.argv[1:])` does not fail when this cell runs inside a notebook.
parser = argparse.ArgumentParser(description='PyTorch beat genre classifier (Jazz vs. other)')
parser.add_argument("-f", "--fff", help="dummy argument", default="1")
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-model', type=str, default='first',
                    help='checkpoint name; the best model is saved to <name>.pt (default: first)')
args = parser.parse_args(sys.argv[1:])
use_cuda = not args.no_cuda and torch.cuda.is_available()

# Seed torch's global RNG before the model is built, so that weight
# initialisation is reproducible.
torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
test_kwargs = {'batch_size': args.test_batch_size, 'shuffle': False}
if use_cuda:
    # These equal DataLoader's defaults (no worker processes, no pinned memory).
    # They are kept here as the place to tune CUDA-specific loading.
    cuda_kwargs = {'num_workers': 0,
                   'pin_memory': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

trainset = BeatDataset(train=True)
testset = BeatDataset(train=False)

train_loader = torch.utils.data.DataLoader(trainset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(testset, **test_kwargs)

loss_fn = nn.CrossEntropyLoss()

model = GenreClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, betas=(0.5, 0.999), weight_decay=5e-4)

# Multiply the learning rate by 0.99 after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

# NOTE(review): the checkpoint is selected using accuracy on the *test* split,
# so `best_acc` is an optimistic estimate. Hold out a validation split for
# model selection, and report test accuracy once, on the selected checkpoint.
best_acc = 0
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch, loss_fn)
    acc = test(model, device, test_loader, loss_fn)
    scheduler.step()

    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), '{}.pt'.format(args.save_model))
        print('Saved model to {}.pt'.format(args.save_model))

    if args.dry_run:
        # --dry-run is documented as a single quick pass: stop after the first epoch
        # instead of running every remaining epoch on one batch.
        break
print('Best Accuracy: {}'.format(best_acc))


Test set: Average loss: 0.0048, Accuracy: 85/179 (47.49%)

Saved model to first.pt

Test set: Average loss: 0.0066, Accuracy: 85/179 (47.49%)


Test set: Average loss: 0.0082, Accuracy: 85/179 (47.49%)


Test set: Average loss: 0.0060, Accuracy: 88/179 (49.16%)

Saved model to first.pt

Test set: Average loss: 0.0036, Accuracy: 118/179 (65.92%)

Saved model to first.pt

Test set: Average loss: 0.0032, Accuracy: 126/179 (70.39%)

Saved model to first.pt

Test set: Average loss: 0.0031, Accuracy: 130/179 (72.63%)

Saved model to first.pt

Test set: Average loss: 0.0031, Accuracy: 132/179 (73.74%)

Saved model to first.pt

Test set: Average loss: 0.0031, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0032, Accuracy: 130/179 (72.63%)


Test set: Average loss: 0.0032, Accuracy: 130/179 (72.63%)


Test set: Average loss: 0.0032, Accuracy: 130/179 (72.63%)


Test set: Average loss: 0.0032, Accuracy: 130/179 (72.63%)


Test set: Average loss: 0.0033, Accuracy: 130/179 (72.63%)


Test 


Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0040, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0038, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0038, Accuracy: 133/179 (74.30%)

Saved model to first.pt

Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0040, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0040, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0040, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0039, Accuracy: 130/179 (72.63%)


Test set: Average loss: 0.0039, Accuracy: 132/179 (73.74%)




Test set: Average loss: 0.0042, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0041, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0042, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0042, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0042, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0041, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0041, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0042, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0042, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0041, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0041, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0043, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0043, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0042, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0042, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0043, Accuracy: 131/179 (73.18%)


Test set: Average loss:


Test set: Average loss: 0.0044, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0043, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0043, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0042, Accuracy: 134/179 (74.86%)


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0044, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0043, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0043, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0044, Accuracy: 131/179 (73.18%)


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0043, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0043, Accuracy: 133/179 (74.30%)


Test set: Average loss:


Test set: Average loss: 0.0044, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0045, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0045, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0045, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0045, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0045, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0045, Accuracy: 132/179 (73.74%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0046, Accuracy: 130/179 (72.63%)


Test set: Average loss: 0.0044, Accuracy: 133/179 (74.30%)


Test set: Average loss: 0.0047, Accuracy: 128/179 (71.51%)


Test set: Average loss: