In [1]:
import sys, os
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import pydub
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
import torch.nn as nn
from datasets import *
from torch.utils.data import DataLoader
from loading import load_model
import foundation as fd
from foundation.util import NS
#import torchaudio
#from torchaudio import transforms
import h5py as hf

os.environ["CUDA_VISIBLE_DEVICES"]="1"

  from ._conv import register_converters as _register_converters


In [53]:
def get_wav_pred(model, path):
    with hf.File(path, 'r') as f:
        y = f.attrs[model.load_args.lbl_name]
        x = f['wav'].value#[k:(k + traindata.seq_len)]
    x, y = x, torch.tensor(int(y)).long().to(model.load_args.device)
    
    return get_pred(model, torch.from_numpy(x).view(1,1,-1)), y

def get_pred(model, x):
    
    B = x.size(0)
    x = x.float().to(model.load_args.device)#.view(B,1,-1)
    
    with torch.no_grad():
        out = model(x).detach()
        pred = F.softmax(out,-1)#.mean(1).sort(-1)[1][0, -3:]
        
    return pred

def test_naive(model, loaders):
    
    success = []
    failed_top = []
    failed_3 = []
    failed_5 = []
    
    N = len(loaders)
    
    print_freq = max(1, len(loaders)//5)
    for itr, loader in enumerate(loaders):
        
        assert len(loader) == 1
        sample = next(iter(loader))
        
        x, y = sample
        path = loader.dataset.waves.paths[0]
        
        #print(x.size(), y.size())
        
        pred = get_pred(model,x)
        
        #print(y, pred.mean(1))
        
        pred, y = pred.mean(0).sort(-1)[1].detach().cpu(), y[0].detach().cpu().item()
        
        #print(pred.size())
        
        top5 = set(pred[-5:].numpy())
        top3 = set(pred[-3:].numpy())
        top = pred[-1].item()

        if top == y:
            success.append(path)
        else:
            failed_top.append(path)
        if y not in top3:
            failed_3.append(path)
        if y not in top5:
            failed_5.append(path)

        if itr % print_freq == 0:
            print('{}/{}: {:.4f} {:.4f} {:.4f} {:.4f}'.format(itr+1, N, len(success)/(itr+1), len(failed_top)/(itr+1), 
                                                              len(failed_3)/(itr+1), len(failed_5)/(itr+1)))
            

    results = len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N
    
    print('* Results: (errors) {:.4f} {:.4f} {:.4f} {:.4f}'.format(len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N))
    
    return success, failed_top, failed_3, failed_5



def test_all(model, loader):
    
    success = []
    failed_top = []
    failed_3 = []
    failed_5 = []
    
    N = len(loader)
    
    print_freq = max(1, len(loader)//5)
    for itr, sample in enumerate(loader):
        
        x, y = sample['x'], sample['y']
        path = sample['path']
        
        #print(x.size(), y.size())
        
        pred = get_pred(model,x)
        
        #print(y, pred.mean(1))
        
        pred, y = pred.mean(1).sort(-1)[1].detach().cpu(), y.detach().cpu().item()
        
        #print(pred.size())
        
        top5 = set(pred[0, -5:].numpy())
        top3 = set(pred[0, -3:].numpy())
        top = pred[0,-1].item()

        if top == y:
            success.append(path)
        else:
            failed_top.append(path)
        if y not in top3:
            failed_3.append(path)
        if y not in top5:
            failed_5.append(path)

        if itr % print_freq == 0:
            print('{}/{}: {:.4f} {:.4f} {:.4f} {:.4f}'.format(itr+1, N, len(success)/(itr+1), len(failed_top)/(itr+1), 
                                                              len(failed_3)/(itr+1), len(failed_5)/(itr+1)))
            

    results = len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N
    
    print('* Results: (errors) {:.4f} {:.4f} {:.4f} {:.4f}'.format(len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N))
    
    return success, failed_top, failed_3, failed_5
    

In [103]:
root = 'trained_nets/' #crnn-9000-bn_18-12-05-032535/'
temp = 'checkpoint_{}.pth.tar'

In [66]:
meta = torch.load('/home/fleeb/workspace/ml_datasets/audio/yt/meta/full_meta.pth.tar')
split = torch.load('/home/fleeb/workspace/ml_datasets/audio/yt/split.pth.tar')
gid2cat = {v:k for k,v in meta['gids'].items()}
mid2cat = {v:k for k,v in meta['mids'].items()}
gcats = [gid2cat[i] for i in range(len(gid2cat))]
mcats = [mid2cat[i] for i in range(len(mid2cat))]
DS_names = ['train', 'val', 'test']
meta.keys(), split.keys(), gcats, mcats

(dict_keys(['gids', 'mids', 'tracks']),
 dict_keys(['paths', 'lens', 'genre-loss-weights', 'genre-lbls', 'mood-lbls', 'mood-loss-weights']),
 ['jazz-blues',
  'rock',
  'country',
  'dance-elec',
  'reggae',
  'pop',
  'alt-punk',
  'hiphop-rap',
  'holiday',
  'classical',
  'ambient',
  'rnb-soul',
  'cinematic',
  'childrens'],
 ['dark',
  'bright',
  'romantic',
  'angry',
  'sad',
  'happy',
  'calm',
  'funky',
  'dramatic',
  'inspirational'])

(1383, 1383)

In [104]:
model_paths = [p for p in os.listdir(root) if 'new' in p and '-crnn' in p]
names = [p[:p.find('_')] for p in model_paths]
model_paths = [os.path.join(root, p) for p in model_paths]
names

['new-crnn-6000-final',
 'new-mood-crnn-9000-final',
 'new-crnn-9000-final-bn',
 'new-mood-crnn-3000-final',
 'new-crnn-9000-final',
 'new-crnn-3000-final',
 'new-crnn-3000-bn',
 'new-mood-crnn-9000-bn']

In [40]:
test_paths = split['paths'][-1]
len(test_paths)

201

In [41]:
# model = final_models['new-rnn-3000-final']
# dataset = 
# len(dataset)

In [51]:
final_models = {}
final_results = {}

In [54]:
for name, mpath in zip(names, model_paths):
    
    if name not in final_models:
        print('Loaded {}'.format(name))
        final_models[name] = load_model(mpath)
    
    if name not in final_results:
        
        model = final_models[name]
        
        dataset = Full_Yt_Dataset(test_paths, lbl_name=model.load_args.lbl_name, )
        
        if '-rnn' in name:
            dataset = MEL_Dataset(dataset, 
                          hop=model.load_args.mel_hop, ws=model.load_args.mel_ws, n_mels=model.load_args.mel_n)
        
        if 'naive' in name:
            
            loader = [MEL_Dataset(Yt_Dataset([test_path], seq_len=model.load_args.seq_len, hop=model.load_args.hop, lbl_name=model.load_args.lbl_name,),
                                 hop=model.load_args.mel_hop, ws=model.load_args.mel_ws, n_mels=model.load_args.mel_n)
                      for test_path in test_paths]
            
            loader = [DataLoader(ds, batch_size=len(ds)) for ds in loader]
            
            
            test_fn = test_naive
            
        else:
            test_fn = test_all
        
            loader = DataLoader(dataset, batch_size=1)
        
        final_results[name] = test_fn(final_models[name], loader)
    else:
        print('Skipping {}'.format(name))
        

1/201: 0.0000 1.0000 0.0000 0.0000
41/201: 0.5854 0.4146 0.1707 0.0488
81/201: 0.5926 0.4074 0.1481 0.0617
121/201: 0.5620 0.4380 0.1901 0.0826
161/201: 0.5652 0.4348 0.2050 0.0994
201/201: 0.5721 0.4279 0.1990 0.0945
* Results: (errors) 0.5721 0.4279 0.1990 0.0945
Loaded new-naive-3000-final-bn
Loaded trained_nets/new-naive-3000-final-bn_18-12-05-215932/best.pth.tar
(1, 296, 150) (64, 4, 2)
Saved params loaded
1/201: 0.0000 1.0000 0.0000 0.0000
41/201: 0.3171 0.6829 0.3171 0.2195
81/201: 0.4198 0.5802 0.2840 0.2099
121/201: 0.4050 0.5950 0.3140 0.2231
161/201: 0.3727 0.6273 0.3540 0.2422
201/201: 0.3731 0.6269 0.3433 0.2388
* Results: (errors) 0.3731 0.6269 0.3433 0.2388
Loaded new-mood-naive-3000-final
Loaded trained_nets/new-mood-naive-3000-final_18-12-06-134718/best.pth.tar
(1, 296, 150) (64, 4, 2)
Saved params loaded
1/201: 1.0000 0.0000 0.0000 0.0000
41/201: 0.4390 0.5610 0.2927 0.0976
81/201: 0.4691 0.5309 0.2222 0.0494
121/201: 0.4711 0.5289 0.2231 0.0826
161/201: 0.4596 0.5404

In [55]:
#torch.save(final_results,'results/naive.pth.tar')

In [69]:
final_results = torch.load('results/crnn.pth.tar')

In [112]:
# path = os.path.join(root, temp.format(20))
# ckpt = torch.load(path)
# ckpt.keys()
temp = 'checkpoint_{}.pth.tar'
model_i = 0

In [113]:
for model_i, (mroot, name) in enumerate(zip(model_paths, names)):
    results = []
    print('loading {}'.format(mroot))
    for i in range(50):
        try:
            path = os.path.join(mroot, temp.format(i+1))
            ckpt = torch.load(path)
        except:
            continue
        t, v = ckpt['train_stats']['accuracy'].avg.item(), ckpt['val_stats']['accuracy'].avg.item()
        results.append((t,v))
        print('ckpt {}: t={:.3f}, v={:.3f}'.format(i+1, t,v))
    results = np.array(results)
    ckpt = torch.load(os.path.join(mroot, 'best.pth.tar'))
    print('** res:{}, best: epoch={}, t={:.3f}, v={:.3f}'.format(results[:, 1].argmax() + 1, ckpt['best_epoch'], ckpt['train_stats']['accuracy'].avg.item(), ckpt['val_stats']['accuracy'].avg.item()))

loading trained_nets/new-crnn-6000-final_18-12-06-005836
ckpt 1: t=0.451, v=0.357
ckpt 2: t=0.751, v=0.423
ckpt 3: t=0.867, v=0.467
ckpt 4: t=0.918, v=0.448
ckpt 5: t=0.941, v=0.477
ckpt 6: t=0.953, v=0.508
ckpt 7: t=0.963, v=0.483
ckpt 8: t=0.968, v=0.531
ckpt 9: t=0.971, v=0.483
ckpt 10: t=0.975, v=0.502
ckpt 11: t=0.976, v=0.500
ckpt 12: t=0.977, v=0.484
ckpt 13: t=0.995, v=0.538
ckpt 14: t=0.997, v=0.543
ckpt 15: t=0.997, v=0.552
ckpt 16: t=0.997, v=0.535
ckpt 17: t=0.997, v=0.536
ckpt 18: t=0.997, v=0.543
ckpt 19: t=0.997, v=0.544
ckpt 20: t=0.998, v=0.531
ckpt 21: t=0.997, v=0.520
ckpt 22: t=0.997, v=0.526
ckpt 23: t=0.997, v=0.524
ckpt 24: t=0.998, v=0.529
ckpt 25: t=0.999, v=0.532
** res:15, best: epoch=15, t=0.997, v=0.552
loading trained_nets/new-mood-crnn-9000-final_18-12-05-193723
ckpt 1: t=0.431, v=0.276
ckpt 2: t=0.840, v=0.339
ckpt 3: t=0.952, v=0.335
ckpt 4: t=0.977, v=0.359
ckpt 5: t=0.984, v=0.373
ckpt 6: t=0.988, v=0.360
ckpt 7: t=0.989, v=0.368
ckpt 8: t=0.992, v=0.

In [114]:
plt.figure()
plt.plot(np.arange(20)+1, results[:,0], label='train')
plt.plot(np.arange(20)+1, results[:,1], label='val')
plt.legend()

<IPython.core.display.Javascript object>

ValueError: x and y must have same first dimension, but have shapes (20,) and (25,)

In [4]:
model = load_model(path)
args = model.load_args
print(model)

Loaded trained_nets/crnn-9000-bn_18-12-05-032535/checkpoint_20.pth.tar
Saved params loaded
Conv_RNN(
  (conv): Sequential(
    (0): Conv1d(1, 256, kernel_size=(2205,), stride=(2205,))
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (rec): RecNet(
    (rec): GRU(256, 256, num_layers=2, batch_first=True)
    (out_layer): Linear(in_features=256, out_features=14, bias=True)
  )
)


In [115]:
def parameter_count(module):
    s = sum([np.prod(d.size()) for d in module.parameters()])
    return s

In [116]:
for name, path in zip(names, model_paths):
    model = load_model(path)
    num = parameter_count(model)
    print('{} : {}'.format(name, num))

Loaded trained_nets/new-crnn-6000-final_18-12-06-005836/best.pth.tar
Saved params loaded
new-crnn-6000-final : 1357839
Loaded trained_nets/new-mood-crnn-9000-final_18-12-05-193723/best.pth.tar
Saved params loaded
new-mood-crnn-9000-final : 1356811
Loaded trained_nets/new-crnn-9000-final-bn_18-12-05-122649/best.pth.tar
Saved params loaded
new-crnn-9000-final-bn : 1358351
Loaded trained_nets/new-mood-crnn-3000-final_18-12-05-193723/best.pth.tar
Saved params loaded
new-mood-crnn-3000-final : 1356811
Loaded trained_nets/new-crnn-9000-final_18-12-06-135023/best.pth.tar
Saved params loaded
new-crnn-9000-final : 1357839
Loaded trained_nets/new-crnn-3000-final_18-12-05-124145/best.pth.tar
Saved params loaded
new-crnn-3000-final : 1357839
Loaded trained_nets/new-crnn-3000-bn_18-12-05-125153/best.pth.tar
Saved params loaded
new-crnn-3000-bn : 1358351
Loaded trained_nets/new-mood-crnn-9000-bn_18-12-06-010006/best.pth.tar
Saved params loaded
new-mood-crnn-9000-bn : 1357323


In [6]:
#model.parameter_count()

In [25]:
data = testdata
sample_num = 0

NameError: name 'testdata' is not defined

In [26]:
with hf.File(data.paths[sample_num], 'r') as f:
    print(f.attrs['name'])
    y = f.attrs[traindata.lbl_name]
    x = f['wav'].value#[k:(k + traindata.seq_len)]
x, y = torch.from_numpy(x).float().to(args.device).view(1,1,-1), torch.tensor(int(y)).long().to(args.device)
#print(x.size(), y.size())
sample_num += 1

with torch.no_grad():
    out = model(x).detach()
    
#print(out.size())
pred = F.softmax(out,-1).mean(1).sort(-1)[1][0, -3:]
i = 0
N = 40
step = out.size(1) // (N-1)
plt.figure(figsize=(4,8))
plt.imshow(out[0].detach().cpu()[::step][i*N:(i+1)*N])
plt.axis('off')
plt.title('{} vs {}'.format(y.cpu().numpy(), pred.cpu().numpy()[::-1]))
plt.tight_layout()
i += 1

AttributeError: module 'foundation.data' has no attribute 'paths'

In [148]:
success = []
failed_top = []
failed_3 = []
failed_5 = []

In [149]:
print_freq = max(1, len(data.paths)//100)
N = len(data.paths)
for itr, path in enumerate(data.paths):

    with hf.File(path, 'r') as f:
        #print(f.attrs['name'])
        y = f.attrs[traindata.lbl_name]
        x = f['wav'].value#[k:(k + traindata.seq_len)]
    x, y = torch.from_numpy(x).float().to(args.device).view(1,1,-1), y
    #print(x.size(), y.size())
    #sample_num += 1

    with torch.no_grad():
        out = model(x).detach()

        #print(out.size())
        pred = F.softmax(out,-1).mean(1).sort(-1)[1]

    top5 = set(pred[0, -5:].cpu().numpy())
    top3 = set(pred[0, -3:].cpu().numpy())
    top = pred[0,-1].item()
    
    if top == y:
        success.append(path)
    else:
        failed_top.append(path)
    if y not in top3:
        failed_3.append(path)
    if y not in top5:
        failed_5.append(path)
        
    if itr % print_freq == 0:
        print('{}/{}: {:.4f} {:.4f} {:.4f} {:.4f}'.format(itr+1, N, len(success)/(itr+1), len(failed_top)/(itr+1), 
                                                          len(failed_3)/(itr+1), len(failed_5)/(itr+1)))

results = len(success)/N, len(failed_top)/N, len(failed_3)/N, len(failed_5)/N
results

1/201: 1.0000 0.0000 0.0000 0.0000
3/201: 1.0000 0.0000 0.0000 0.0000
5/201: 1.0000 0.0000 0.0000 0.0000
7/201: 0.7143 0.2857 0.2857 0.2857
9/201: 0.7778 0.2222 0.2222 0.2222
11/201: 0.7273 0.2727 0.1818 0.1818
13/201: 0.6923 0.3077 0.1538 0.1538
15/201: 0.6667 0.3333 0.2000 0.2000
17/201: 0.5882 0.4118 0.2353 0.2353
19/201: 0.5789 0.4211 0.2105 0.2105
21/201: 0.5238 0.4762 0.1905 0.1905
23/201: 0.5652 0.4348 0.1739 0.1739
25/201: 0.6000 0.4000 0.1600 0.1600
27/201: 0.6296 0.3704 0.1481 0.1481
29/201: 0.6207 0.3793 0.1379 0.1379
31/201: 0.6129 0.3871 0.1613 0.1613
33/201: 0.6364 0.3636 0.1515 0.1515
35/201: 0.6571 0.3429 0.1429 0.1429
37/201: 0.6757 0.3243 0.1351 0.1351
39/201: 0.6923 0.3077 0.1282 0.1282
41/201: 0.6829 0.3171 0.1220 0.1220
43/201: 0.6512 0.3488 0.1395 0.1395
45/201: 0.6444 0.3556 0.1556 0.1333
47/201: 0.6383 0.3617 0.1489 0.1277
49/201: 0.6531 0.3469 0.1429 0.1224
51/201: 0.6471 0.3529 0.1373 0.1176
53/201: 0.6415 0.3585 0.1321 0.1132
55/201: 0.6364 0.3636 0.1455 0.12

(0.7114427860696517,
 0.2885572139303483,
 0.1044776119402985,
 0.0845771144278607)

In [110]:
result_paths = success, failed_top, failed_3, failed_5
res_colors = ['#146b3a', '#f8b229', '#ea4630', '#bb2528']
res_names = ['Top', 'Top3', 'Top5', 'Failed']

In [126]:
N = 201.

In [136]:
tys = ['crnn', 'rnn', 'naive']
tnames = ['Conv-RNN', 'MEL-RNN', 'MEL-Conv']
for tname, t in zip(tnames, tys):
    
    #print(t)
    results = torch.load('results/{}.pth.tar'.format(t))
    for name, (_, F1, F3, F5) in results.items():
        
        if 'mood' not in name:
            continue
            
        bn = 'bn' in name
        l = 3 if '3' in name else (6 if '6' in name else 9)
        
        bn = '\\cmark' if bn else '\\xmark'
        
        print('{} & {} & {} & {:2.3} & {:2.3} & {:2.3} \\\\'.format(tname, l, bn,  (1-len(F1)/N)*100, (1-len(F3)/N)*100, (1-len(F5)/N)*100, ))
        
        #print('{:<25} : {:2.3}   {:2.3}   {:2.3}'.format(name, (1-len(F1)/N)*100, (1-len(F3)/N)*100, (1-len(F5)/N)*100, ))

Conv-RNN & 9 & \xmark & 34.3 & 62.2 & 83.1 \\
Conv-RNN & 3 & \xmark & 35.8 & 71.6 & 85.6 \\
Conv-RNN & 9 & \cmark & 31.8 & 59.2 & 77.1 \\
MEL-RNN & 9 & \xmark & 33.8 & 66.7 & 86.1 \\
MEL-RNN & 3 & \xmark & 38.3 & 67.2 & 86.6 \\
MEL-Conv & 3 & \xmark & 42.8 & 72.1 & 90.0 \\


In [202]:
results = torch.load('results/naive.pth.tar'.format(t))
results.keys()

dict_keys(['new-naive-3000-final', 'new-naive-3000-final-bn', 'new-mood-naive-3000-final'])

In [203]:
grecords = {}
mrecords = {}
for info in zip(split['paths'], split['lens'], split['mood-lbls'], split['genre-lbls']):
    for p, ln, mlb, glb in zip(*info):
        grecords[p] = (glb, ln)
        mrecords[p] = (mlb, ln)
len(grecords), len(mrecords)

(1383, 1383)

In [222]:
np.hstack(split['lens']).sum() / 44100 / 3600 # hrs

64.1959588750315

In [213]:
#results['new-crnn-6000-final']
res_colors = ['#146b3a', '#f8b229', '#ea4630', '#bb2528']
res_names = ['Top', 'Top3', 'Top5', 'Failed']


name = 'new-crnn-9000-final'
name = 'new-mood-crnn-3000-final'
name = 'new-rnn-9000-final'
name = 'new-rnn-9000-final'
name = 'new-mood-rnn-3000'
name = 'new-naive-3000-final'
#name = 'new-mood-naive-3000-final'

result_paths = results[name]
cats = mcats if 'mood' in name else gcats
records = mrecords if 'mood' in name else grecords

In [214]:
#result_paths = [ [r[0] for r in rps] for rps in result_paths]

In [215]:
G, F1, F3, F5 = map(set, result_paths)
F5 = F3 - F5
F3 = F1 - F3
F1 = F1 - F3.union(F5)
#G, F1, F3, F5 = map(list, [G, F1, F3, F5])
len(G), len(F1), len(F3), len(F5)

(115, 19, 46, 21)

In [216]:
hists = np.vstack([np.histogram([records[p][0] for p in st], weights=[records[p][1] for p in st], bins=len(cats), range=(0,len(cats)-1))[0] for st in [G, F1, F3, F5]])
nhists = hists/hists.sum()

In [217]:
fig, ax = plt.subplots(figsize=(6,4))
x = np.zeros(len(cats))
for name, c, hist in zip(res_names, res_colors, nhists):
    #hist, edges = np.histogram(lb, weights=ln, bins=lb.max()+1)
    #hist = hist / hist.sum()
    
    ax.bar(range(len(cats)), hist, bottom=x, align='center', color=c, label=name)
    
    x = x + hist
plt.xticks(range(len(cats)), cats, rotation=90)
plt.legend()
#plt.title('')
plt.ylim(0, .37)
fig.tight_layout()

<IPython.core.display.Javascript object>

In [218]:
#fig.savefig('results/errs-naive-3000.png')
#fig.savefig('results/errs-naive-3000.pdf')

In [5]:
weights = model.conv[0].weight.clone().detach().cpu()
weights.size()

torch.Size([256, 1, 2205])

In [19]:
i = 0
N = 10

In [32]:
fig, axes = plt.subplots(N, figsize=(4,8))
for j, (ax, w) in enumerate(zip(axes.flat, weights[i:].squeeze())):
    ax.plot(w.numpy())
    ax.set_title(str(i+j))
    ax.axis('off')
#fig.tight_layout()
i += N

<IPython.core.display.Javascript object>