In [1]:
import sys, os
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import pydub
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.nn as nn
from datasets import *
from torch.utils.data import DataLoader
from loading import load_model
import foundation as fd
from foundation.util import NS
#import torchaudio
#from torchaudio import transforms
import h5py as hf

os.environ["CUDA_VISIBLE_DEVICES"]="1"

  from ._conv import register_converters as _register_converters


In [2]:
def get_wav_pred(model, path):
    """Predict class probabilities for the whole waveform stored in one HDF5 file.

    Args:
        model: network forwarded to `get_pred`. Its `load_args` must provide
            `lbl_name` (name of the HDF5 attribute that holds the label) and
            `device`.
        path: path of an HDF5 file that contains a 'wav' dataset.

    Returns:
        Tuple `(pred, y)`: `pred` is the result of `get_pred` for the waveform
        viewed as (1, 1, -1); `y` is the label as a long tensor placed on
        `model.load_args.device`.
    """
    with hf.File(path, 'r') as f:
        y = f.attrs[model.load_args.lbl_name]
        # `Dataset.value` was deprecated and later removed from h5py (3.0);
        # indexing with an empty tuple reads the full dataset on any version.
        x = f['wav'][()]
    y = torch.tensor(int(y)).long().to(model.load_args.device)

    return get_pred(model, torch.from_numpy(x).view(1, 1, -1)), y

def get_pred(model, x):
    """Run `model` on `x` without tracking gradients and return probabilities.

    Args:
        model: callable network; the input is moved to
            `model.load_args.device` before the forward pass.
        x: input tensor; it is cast to float.

    Returns:
        Tensor shaped like `model(x)`, normalised with a softmax over the last
        dimension.
    """
    x = x.float().to(model.load_args.device)

    with torch.no_grad():
        out = model(x).detach()
        # torch.softmax is used rather than F.softmax: none of the visible
        # cells imports torch.nn.functional as F.
        pred = torch.softmax(out, -1)

    return pred

def test_naive(model, loaders):
    """Evaluate `model` on a collection of loaders that each hold one batch.

    For every loader the single batch `(x, y)` is fed through `get_pred`, the
    probabilities are averaged over dimension 0 and the classes are ranked.
    The path `loader.dataset.waves.paths[0]` is then filed under the outcome
    lists below.

    Args:
        model: network accepted by `get_pred`.
        loaders: sized iterable of loaders; each must have length 1 and yield
            an `(x, y)` pair, where `y[0]` is the label of the batch.

    Returns:
        Tuple `(success, failed_top, failed_3, failed_5)` of path lists: paths
        whose label is the top prediction, is not the top prediction, is not
        among the top 3, and is not among the top 5, respectively. All lists
        are empty when `loaders` is empty.
    """
    success = []
    failed_top = []
    failed_3 = []
    failed_5 = []

    N = len(loaders)
    if N == 0:
        # Nothing to evaluate; the summary rates below would divide by zero.
        return success, failed_top, failed_3, failed_5

    print_freq = max(1, N//5)
    for itr, loader in enumerate(loaders):

        assert len(loader) == 1
        x, y = next(iter(loader))
        path = loader.dataset.waves.paths[0]

        pred = get_pred(model, x)

        # Average over dim 0, then argsort ascending: the most likely class
        # ends up in the last position.
        ranking = pred.mean(0).sort(-1)[1].detach().cpu()
        y = y[0].detach().cpu().item()

        top5 = set(ranking[-5:].numpy())
        top3 = set(ranking[-3:].numpy())
        top = ranking[-1].item()

        if top == y:
            success.append(path)
        else:
            failed_top.append(path)
        if y not in top3:
            failed_3.append(path)
        if y not in top5:
            failed_5.append(path)

        if itr % print_freq == 0:
            done = itr + 1
            print('{}/{}: {:.4f} {:.4f} {:.4f} {:.4f}'.format(done, N, len(success)/done, len(failed_top)/done,
                                                              len(failed_3)/done, len(failed_5)/done))

    # NOTE: the first number printed is the success rate; only the remaining
    # three are error rates, despite the "(errors)" label.
    print('* Results: (errors) {:.4f} {:.4f} {:.4f} {:.4f}'.format(len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N))

    return success, failed_top, failed_3, failed_5



def test_all(model, loader):
    """Evaluate `model` on every batch of `loader`, one sample per batch.

    Each batch is a mapping with the keys 'x', 'y' and 'path'. The output of
    `get_pred` is averaged over dimension 1 and the classes of batch entry 0
    are ranked; `sample['path']` is then filed under the outcome lists below.

    Args:
        model: network accepted by `get_pred`.
        loader: sized iterable of batches. `sample['y']` must hold exactly one
            element because it is read with `.item()`.

    Returns:
        Tuple `(success, failed_top, failed_3, failed_5)` of collected
        `sample['path']` values: label is the top prediction, is not the top
        prediction, is not among the top 3, and is not among the top 5,
        respectively. All lists are empty when `loader` is empty.
    """
    success = []
    failed_top = []
    failed_3 = []
    failed_5 = []

    N = len(loader)
    if N == 0:
        # Nothing to evaluate; the summary rates below would divide by zero.
        return success, failed_top, failed_3, failed_5

    print_freq = max(1, N//5)
    for itr, sample in enumerate(loader):

        x, y = sample['x'], sample['y']
        # NOTE(review): depending on how the loader collates strings this may
        # be a list rather than a single path - confirm against the dataset.
        path = sample['path']

        pred = get_pred(model, x)

        # Average over dim 1 (presumably the time steps - confirm), then
        # argsort ascending so that the most likely class comes last.
        ranking = pred.mean(1).sort(-1)[1].detach().cpu()
        y = y.detach().cpu().item()

        top5 = set(ranking[0, -5:].numpy())
        top3 = set(ranking[0, -3:].numpy())
        top = ranking[0, -1].item()

        if top == y:
            success.append(path)
        else:
            failed_top.append(path)
        if y not in top3:
            failed_3.append(path)
        if y not in top5:
            failed_5.append(path)

        if itr % print_freq == 0:
            done = itr + 1
            print('{}/{}: {:.4f} {:.4f} {:.4f} {:.4f}'.format(done, N, len(success)/done, len(failed_top)/done,
                                                              len(failed_3)/done, len(failed_5)/done))

    # NOTE: the first number printed is the success rate; only the remaining
    # three are error rates, despite the "(errors)" label.
    print('* Results: (errors) {:.4f} {:.4f} {:.4f} {:.4f}'.format(len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N))

    return success, failed_top, failed_3, failed_5
    

In [3]:
# Directory that is listed for trained model runs and joined with checkpoint
# file names (a specific run folder used to be appended, see trailing comment).
root = 'trained_nets/' #crnn-9000-bn_18-12-05-032535/'
# File-name template of numbered checkpoints; filled in with `.format(number)`.
temp = 'checkpoint_{}.pth.tar'

In [4]:
# Root folder of the dataset; change this single line when the data lives
# somewhere else instead of editing every path.
DATA_ROOT = '/home/fleeb/workspace/ml_datasets/audio/yt'

# NOTE(review): torch.load unpickles these files - only load trusted files.
meta = torch.load(os.path.join(DATA_ROOT, 'meta', 'full_meta.pth.tar'))
split = torch.load(os.path.join(DATA_ROOT, 'split.pth.tar'))

# meta['gids'] / meta['mids'] map a category name to its integer id; invert
# them and list the names ordered by id.
gid2cat = {v:k for k,v in meta['gids'].items()}
mid2cat = {v:k for k,v in meta['mids'].items()}
gcats = [gid2cat[i] for i in range(len(gid2cat))]
mcats = [mid2cat[i] for i in range(len(mid2cat))]

# Order in which the per-split lists of `split` are labelled in later cells.
DS_names = ['train', 'val', 'test']
meta.keys(), split.keys(), gcats, mcats

(dict_keys(['gids', 'mids', 'tracks']),
 dict_keys(['paths', 'lens', 'genre-loss-weights', 'genre-lbls', 'mood-lbls', 'mood-loss-weights']),
 ['jazz-blues',
  'rock',
  'country',
  'dance-elec',
  'reggae',
  'pop',
  'alt-punk',
  'hiphop-rap',
  'holiday',
  'classical',
  'ambient',
  'rnb-soul',
  'cinematic',
  'childrens'],
 ['dark',
  'bright',
  'romantic',
  'angry',
  'sad',
  'happy',
  'calm',
  'funky',
  'dramatic',
  'inspirational'])

In [5]:
# Keep only the run folders of the "new" CRNN models found under `root`.
run_dirs = [p for p in os.listdir(root) if 'new' in p and '-crnn' in p]
# The run name is everything before the first underscore. partition() returns
# the whole folder name when there is no underscore, whereas slicing with
# find() == -1 would silently drop the last character.
names = [p.partition('_')[0] for p in run_dirs]
model_paths = [os.path.join(root, p) for p in run_dirs]
names

['new-crnn-6000-final',
 'new-mood-crnn-9000-final',
 'new-crnn-9000-final-bn',
 'new-mood-crnn-3000-final',
 'new-crnn-9000-final',
 'new-crnn-3000-final',
 'new-crnn-3000-bn',
 'new-mood-crnn-9000-bn']

In [4]:
# Gather the (train, val) average accuracy stored in checkpoints 1..20.
results = []
for i in range(20):
    number = i + 1
    path = os.path.join(root, temp.format(number))
    ckpt = torch.load(path)
    t, v = (ckpt[stats_key]['accuracy'].avg.item() for stats_key in ('train_stats', 'val_stats'))
    results.append((t, v))
    print('ckpt {}: t={:.3f}, v={:.3f}'.format(number, t, v))
# One row per checkpoint, columns: train accuracy, val accuracy.
results = np.array(results)
results.shape

ckpt 1: t=0.572, v=0.517
ckpt 2: t=0.857, v=0.633
ckpt 3: t=0.939, v=0.564
ckpt 4: t=0.965, v=0.658
ckpt 5: t=0.976, v=0.650
ckpt 6: t=0.981, v=0.672
ckpt 7: t=0.985, v=0.627
ckpt 8: t=0.987, v=0.641
ckpt 9: t=0.988, v=0.704
ckpt 10: t=0.989, v=0.611
ckpt 11: t=0.990, v=0.671
ckpt 12: t=0.991, v=0.668
ckpt 13: t=0.998, v=0.663
ckpt 14: t=0.999, v=0.683
ckpt 15: t=0.999, v=0.679
ckpt 16: t=0.999, v=0.669
ckpt 17: t=0.999, v=0.667
ckpt 18: t=0.999, v=0.666
ckpt 19: t=0.999, v=0.698
ckpt 20: t=0.999, v=0.641


(20, 2)

In [5]:
# Accuracy curves over the 20 checkpoints: column 0 is train, column 1 is val.
plt.figure()
epochs = np.arange(1, 21)
for column, curve_name in enumerate(['train', 'val']):
    plt.plot(epochs, results[:, column], label=curve_name)
plt.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f694cc7fb70>

In [13]:
def parameter_count(module):
    """Return the total number of scalar values in `module.parameters()`.

    `numel()` is used per parameter so the result is always a plain `int`.
    Multiplying the entries of `size()` with `np.prod` yields the float 1.0
    for a 0-dimensional parameter, which turned the whole count into a float.

    Args:
        module: object exposing `parameters()` that yields tensors
            (e.g. a `torch.nn.Module`).

    Returns:
        int: number of elements summed over all parameters (0 if there are none).
    """
    return sum(p.numel() for p in module.parameters())

In [14]:
# Show the total parameter count of `model`.
# NOTE(review): no visible cell assigns `model` - presumably it is produced
# with `load_model`; confirm and add the loading cell before this one.
parameter_count(model)

1358351

In [6]:
#model.parameter_count()

In [29]:
#traindata, valdata, testdata = ckpt['train_data'], ckpt['val_data'], ckpt['test_data']
#DS = traindata, valdata, testdata#NS(train=traindata, val=valdata, test=testdata)
#DS_names = ['train', 'val', 'test']


(dict_keys(['gids', 'mids', 'tracks']),
 dict_keys(['paths', 'lens', 'genre-loss-weights', 'genre-lbls', 'mood-lbls', 'mood-loss-weights']),
 ['jazz-blues',
  'rock',
  'country',
  'dance-elec',
  'reggae',
  'pop',
  'alt-punk',
  'hiphop-rap',
  'holiday',
  'classical',
  'ambient',
  'rnb-soul',
  'cinematic',
  'childrens'])

In [66]:
# Select the label set to analyse and build a lookup path -> (label, length).
mode = 'genre'

cats = {'genre': gcats}.get(mode, mcats)

lbls = split['{}-lbls'.format(mode)]
lens = split['lens']
print(list(map(len, lbls)))  # number of tracks per split

records = {}
for lbs, lns, paths in zip(lbls, lens, split['paths']):
    for p, lb, ln in zip(paths, lbs, lns):
        records[p] = (lb, ln)
len(records)

[1056, 126, 201]


1383

In [67]:
# Length-weighted class histogram per split (one row per split).
# Fixed edges centred on the integer class ids keep bin k == class k for every
# split. With `bins=lb.max()+1` the bins span [lb.min(), lb.max()] of each
# split, so a split lacking the lowest or highest class was misaligned (or
# produced a row of a different width, breaking vstack).
class_edges = np.arange(len(cats) + 1) - 0.5
hists = np.vstack([np.histogram(lb, weights=ln, bins=class_edges)[0] for ln, lb in zip(lens, lbls)])
# Normalise by the grand total so all bars together sum to 1.
nhists = hists/hists.sum()

In [68]:
# Stacked bar chart of the genre distribution, drawn into the first panel.
ax = ax1
# plt.sca makes `ax` the current axes; passing an Axes to plt.axes() for that
# purpose was deprecated in Matplotlib 2.2 in favour of plt.sca().
plt.sca(ax)
bar_positions = range(len(cats))
stack_bottom = np.zeros(len(cats))
for name, hist in zip(DS_names, nhists):
    # each split is stacked on top of the splits drawn before it
    ax.bar(bar_positions, hist, bottom=stack_bottom, align='center', label=name)
    stack_bottom = stack_bottom + hist
plt.xticks(bar_positions, cats, rotation=90)
plt.legend()
plt.title('Genre Data Distribution')
fig.tight_layout()

In [69]:
# Select the label set to analyse and build a lookup path -> (label, length).
mode = 'mood'

cats = {'genre': gcats}.get(mode, mcats)

lbls = split['{}-lbls'.format(mode)]
lens = split['lens']
print(list(map(len, lbls)))  # number of tracks per split

records = {}
for lbs, lns, paths in zip(lbls, lens, split['paths']):
    for p, lb, ln in zip(paths, lbs, lns):
        records[p] = (lb, ln)
len(records)

[1056, 126, 201]


1383

In [70]:
# Length-weighted class histogram per split (one row per split).
# Fixed edges centred on the integer class ids keep bin k == class k for every
# split. With `bins=lb.max()+1` the bins span [lb.min(), lb.max()] of each
# split, so a split lacking the lowest or highest class was misaligned (or
# produced a row of a different width, breaking vstack).
class_edges = np.arange(len(cats) + 1) - 0.5
hists = np.vstack([np.histogram(lb, weights=ln, bins=class_edges)[0] for ln, lb in zip(lens, lbls)])
# Normalise by the grand total so all bars together sum to 1.
nhists = hists/hists.sum()

In [71]:
# Stacked bar chart of the mood distribution, drawn into the second panel.
ax = ax2
# plt.sca makes `ax` the current axes; passing an Axes to plt.axes() for that
# purpose was deprecated in Matplotlib 2.2 in favour of plt.sca().
plt.sca(ax)
bar_positions = range(len(cats))
stack_bottom = np.zeros(len(cats))
for name, hist in zip(DS_names, nhists):
    # each split is stacked on top of the splits drawn before it
    ax.bar(bar_positions, hist, bottom=stack_bottom, align='center', label=name)
    stack_bottom = stack_bottom + hist
plt.xticks(bar_positions, cats, rotation=90)
#plt.legend()
plt.title('Mood Data Distribution')
fig.tight_layout()

In [65]:
# Figure with two stacked panels used by the distribution plots: the genre cell
# draws into `ax1`, the mood cell into `ax2`, and both call `fig.tight_layout()`,
# so this cell has to be executed before either of them.
fig, (ax1, ax2) = plt.subplots(2, figsize=(6,5))

<IPython.core.display.Javascript object>

In [72]:
#fig.savefig('results/data.png')
#fig.savefig('results/data.pdf')

In [25]:
# Dataset whose `paths` are inspected and evaluated in the following cells.
# NOTE(review): `testdata` is only assigned in a commented-out line of this
# notebook, so this cell raises NameError on a fresh kernel.
data = testdata
# Index into `data.paths` of the next sample to inspect.
sample_num = 0

NameError: name 'testdata' is not defined

In [26]:
# Inspect one sample: run the model on the whole waveform and show a
# sub-sampled image of its raw outputs next to the three top-ranked classes.
with hf.File(data.paths[sample_num], 'r') as f:
    print(f.attrs['name'])
    y = f.attrs[traindata.lbl_name]
    # `Dataset.value` was removed in h5py 3.0; `[()]` reads the full dataset.
    x = f['wav'][()]
x, y = torch.from_numpy(x).float().to(args.device).view(1,1,-1), torch.tensor(int(y)).long().to(args.device)
# Advance to the next file so that re-running the cell pages through `data`.
sample_num += 1

with torch.no_grad():
    out = model(x).detach()

# Mean probability over dim 1, ranked ascending; keep the three best classes.
# torch.softmax is used because no visible cell imports torch.nn.functional as F.
pred = torch.softmax(out,-1).mean(1).sort(-1)[1][0, -3:]
i = 0
N = 40
# Stride that keeps roughly N rows; at least 1, since a slice step of 0 is an
# error when the output has fewer than N-1 rows.
step = max(1, out.size(1) // (N-1))
plt.figure(figsize=(4,8))
plt.imshow(out[0].detach().cpu()[::step][i*N:(i+1)*N])
plt.axis('off')
# Title: true label vs. predicted classes, best first.
plt.title('{} vs {}'.format(y.cpu().numpy(), pred.cpu().numpy()[::-1]))
plt.tight_layout()
i += 1

AttributeError: module 'foundation.data' has no attribute 'paths'

In [148]:
# Fresh, independent path lists for the evaluation loop: top-1 hits, top-1
# misses, top-3 misses and top-5 misses.
success, failed_top, failed_3, failed_5 = [], [], [], []

In [149]:
# Evaluate the model on every file of `data`, one whole waveform at a time, and
# sort each path into the outcome lists (top-1 hit / top-1, top-3, top-5 miss).
N = len(data.paths)
print_freq = max(1, N//100)
for itr, path in enumerate(data.paths):

    with hf.File(path, 'r') as f:
        # int() yields a plain Python integer for the comparisons and set
        # lookups below, matching how the other cells read the label.
        y = int(f.attrs[traindata.lbl_name])
        # `Dataset.value` was removed in h5py 3.0; `[()]` reads the full dataset.
        x = f['wav'][()]
    x = torch.from_numpy(x).float().to(args.device).view(1,1,-1)

    with torch.no_grad():
        out = model(x).detach()

        # Mean probability over dim 1, argsorted ascending (best class last).
        # torch.softmax is used because no visible cell imports
        # torch.nn.functional as F.
        pred = torch.softmax(out,-1).mean(1).sort(-1)[1]

    top5 = set(pred[0, -5:].cpu().numpy())
    top3 = set(pred[0, -3:].cpu().numpy())
    top = pred[0,-1].item()

    if top == y:
        success.append(path)
    else:
        failed_top.append(path)
    if y not in top3:
        failed_3.append(path)
    if y not in top5:
        failed_5.append(path)

    if itr % print_freq == 0:
        print('{}/{}: {:.4f} {:.4f} {:.4f} {:.4f}'.format(itr+1, N, len(success)/(itr+1), len(failed_top)/(itr+1),
                                                          len(failed_3)/(itr+1), len(failed_5)/(itr+1)))

# Success rate followed by the top-1 / top-3 / top-5 error rates.
results = len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N
results

1/201: 1.0000 0.0000 0.0000 0.0000
3/201: 1.0000 0.0000 0.0000 0.0000
5/201: 1.0000 0.0000 0.0000 0.0000
7/201: 0.7143 0.2857 0.2857 0.2857
9/201: 0.7778 0.2222 0.2222 0.2222
11/201: 0.7273 0.2727 0.1818 0.1818
13/201: 0.6923 0.3077 0.1538 0.1538
15/201: 0.6667 0.3333 0.2000 0.2000
17/201: 0.5882 0.4118 0.2353 0.2353
19/201: 0.5789 0.4211 0.2105 0.2105
21/201: 0.5238 0.4762 0.1905 0.1905
23/201: 0.5652 0.4348 0.1739 0.1739
25/201: 0.6000 0.4000 0.1600 0.1600
27/201: 0.6296 0.3704 0.1481 0.1481
29/201: 0.6207 0.3793 0.1379 0.1379
31/201: 0.6129 0.3871 0.1613 0.1613
33/201: 0.6364 0.3636 0.1515 0.1515
35/201: 0.6571 0.3429 0.1429 0.1429
37/201: 0.6757 0.3243 0.1351 0.1351
39/201: 0.6923 0.3077 0.1282 0.1282
41/201: 0.6829 0.3171 0.1220 0.1220
43/201: 0.6512 0.3488 0.1395 0.1395
45/201: 0.6444 0.3556 0.1556 0.1333
47/201: 0.6383 0.3617 0.1489 0.1277
49/201: 0.6531 0.3469 0.1429 0.1224
51/201: 0.6471 0.3529 0.1373 0.1176
53/201: 0.6415 0.3585 0.1321 0.1132
55/201: 0.6364 0.3636 0.1455 0.12

(0.7114427860696517,
 0.2885572139303483,
 0.1044776119402985,
 0.0845771144278607)

In [110]:
# Path lists filled by the evaluation, in the order: top-1 hits, top-1 misses,
# top-3 misses, top-5 misses.
result_paths = success, failed_top, failed_3, failed_5
# Bar colour and legend label per outcome category of the stacked bar chart.
res_colors = ['#146b3a', '#f8b229', '#ea4630', '#bb2528']
res_names = ['Top', 'Top3', 'Top5', 'Failed']

In [143]:
# Turn the nested miss lists into disjoint outcome sets.
hit_1, miss_1, miss_3, miss_5 = (set(paths) for paths in result_paths)
G = hit_1                  # label is the top prediction
F5 = miss_3 - miss_5       # label missed the top 3 but is within the top 5
F3 = miss_1 - miss_3       # label missed the top spot but is within the top 3
F1 = miss_1 - (F3 | F5)    # remaining misses: label is not even in the top 5
#G, F1, F3, F5 = map(list, [G, F1, F3, F5])
len(G), len(F1), len(F3), len(F5)

(87, 10, 18, 11)

In [144]:
# Length-weighted class histogram per outcome, one row per outcome.
# The rows are paired with res_names = ['Top', 'Top3', 'Top5', 'Failed'] when
# plotted. After the set arithmetic F3 holds the top-3 hits, F5 the top-5 hits
# and F1 the complete failures, so the row order must be G, F3, F5, F1;
# stacking G, F1, F3, F5 labelled the complete failures as 'Top3'.
outcome_sets = [G, F3, F5, F1]
hists = np.vstack([
    np.histogram([records[p][0] for p in st], weights=[records[p][1] for p in st], bins=14, range=(0,13))[0]
    for st in outcome_sets
])
# Normalise by the grand total so all bars together sum to 1.
nhists = hists/hists.sum()

In [145]:
# Stacked bar chart of the prediction outcome per class.
fig, ax = plt.subplots()
# Number of bars is taken from the histogram width instead of a hard-coded 14,
# so the plot always matches the data it is given.
n_bars = nhists.shape[1]
bar_positions = range(n_bars)
stack_bottom = np.zeros(n_bars)
for name, c, hist in zip(res_names, res_colors, nhists):
    # each outcome is stacked on top of the outcomes drawn before it
    ax.bar(bar_positions, hist, bottom=stack_bottom, align='center', color=c, label=name)
    stack_bottom = stack_bottom + hist
plt.xticks(bar_positions, cats, rotation=90)
plt.legend()
plt.ylim(0, .37)
fig.tight_layout()

<IPython.core.display.Javascript object>

In [5]:
# Detached CPU copy of the first convolution layer's weights, for plotting.
first_conv = model.conv[0]
weights = first_conv.weight.detach().clone().cpu()
weights.shape

torch.Size([256, 1, 2205])

In [19]:
# Paging state for the filter plots: index of the first filter to show and the
# number of filters per figure.
i, N = 0, 10

In [32]:
# Plot the next N first-layer filters, one per row, titled with the filter index.
# squeeze=False keeps `axes` an array even for N == 1 (otherwise a bare Axes is
# returned and `.flat` fails).
fig, axes = plt.subplots(N, figsize=(4,8), squeeze=False)
# squeeze(1) removes only a size-1 dim 1; a bare squeeze() would also collapse
# the filter axis when a single filter is left, making the loop iterate over
# individual weights.
for j, (ax, w) in enumerate(zip(axes.flat, weights[i:].squeeze(1))):
    ax.plot(w.numpy())
    ax.set_title(str(i+j))
    ax.axis('off')
#fig.tight_layout()
# Advance the page so that re-running the cell shows the following N filters.
i += N

<IPython.core.display.Javascript object>