# Stratified 3-fold artist-filtered cross-validation split for the GTZAN Dataset

In [2]:
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
import numpy as np

In [5]:
#USAGE

# Path to the GTZAN meta file: one song per line, written as the song name,
# a tab, and then one of the labels. The filtering step further down reshapes
# the parsed list into 10 rows, so it expects the songs listed genre by genre
# with the same number of songs per genre.
meta = 'gtzan_labels.txt'
# Folder the fold files are written to.
out_dir = './gtzan_folds_filter_3f_strat/'

def parse_filelist(filename):
    """Read a tab-separated list of songs and labels.

    Parameters
    ----------
    filename : str
        Path to a text file with one entry per line: the song name, a tab,
        and the class label. Blank lines are ignored.

    Returns
    -------
    files : numpy.ndarray
        First column of every entry, in file order.
    classes : numpy.ndarray
        Second column of every entry, aligned with ``files``.

    Raises
    ------
    IndexError
        If a non-blank line has no tab-separated label.
    """
    with open(filename, 'r') as f:
        # Skip blank lines (e.g. an extra newline at the end of the file);
        # they have no label column and would fail on the lookup below.
        rows = [line.strip().split('\t') for line in f if line.strip()]

    files = np.array([row[0] for row in rows])
    classes = np.array([row[1] for row in rows])

    return files, classes

In [3]:
# Per-genre group tables.
#
# Every <genre>_groups array has one entry per song position (0-99) and starts
# as the identity (np.arange(100)), i.e. every song is its own group. Each
# element of the temporary `groups` list names song positions that must stay
# together (presumably recordings by the same artist, given the
# "artist-filtered" title -- confirm against the source of these lists); all
# of them are relabelled with the first position of that list (g[0]).
# Positions not mentioned in any list keep their own index as a singleton
# group.

###BLUES
blues_groups = np.arange(100)
groups = [range(0,12), range(12,29), range(29, 40), range(40,50), range(50,61), range(61,73),
             range(73,85), range(85,98), range(98,100)]
for g in groups:
    blues_groups[g]=g[0]

###CLASSICAL
# NOTE(review): position 37 falls between range(34,37) and range(38,42) and so
# stays a singleton group -- confirm this gap is intentional.
classical_groups = np.arange(100)
groups = [ range(0,11), range(11,30), range(30,34), range(34,37), range(38,42), range(63,68), range(68,76),
             range(82,89), range(89,100)]
for g in groups:
    classical_groups[g]=g[0]

###COUNTRY
country_groups = np.arange(100)
groups = [np.concatenate(([19,26], range(65,81))), range(50,65), range(81,94), range(94,100)]
for g in groups:
    country_groups[g]=g[0]

###DISCO
disco_groups = np.arange(100)
groups = [[1,21,38,78], [17,24,44,45], [22,28,30,35,37], [49,50,51,69,70,71,73,74], [67,68,72]]
for g in groups:
    disco_groups[g]=g[0]

###HIPHOP
hiphop_groups = np.arange(100)
groups = [[1,7,41,42], range(8,26), np.concatenate((range(46,52),[62,76])), range(55,62), range(81,99)]
for g in groups:
    hiphop_groups[g]=g[0]

###JAZZ
jazz_groups = np.arange(100)
groups = [ range(2,11), range(11, 25), range(25, 33), np.concatenate((range(33, 47), [51, 53, 55, 57, 58, 60, 62, 65], range(67,73)))]
for g in groups:
    jazz_groups[g]=g[0]

###METAL
metal_groups = np.arange(100)
groups = [ range(12,16), np.concatenate((range(40,46), range(61,67))), range(46,58), [58, 59, 60], [33, 38, 73, 75, 78, 83, 87],
             np.concatenate(([2,5,34], range(92,95))), range(95,100)]
for g in groups:
    metal_groups[g]=g[0]

###POP
pop_groups = np.arange(100)
groups = [ np.concatenate(([0], range(87,97))), np.concatenate(([2], range(97,100))), range(3,10), [11, 39, 40],
             range(15,39), np.concatenate((range(44,52), [80])), range(52,63), range(67, 74), np.concatenate((range(74,79), [82])),
             range(84,87)]
for g in groups:
    pop_groups[g]=g[0]

###REGGAE
reggae_groups = np.arange(100)
groups = [ np.concatenate((range(0,28), range(56,61))), np.concatenate((range(46,49), range(64,69),[71])),
             np.concatenate(([85], range(94,100))), [33,42,44,50,63], np.concatenate(([70], range(76,79)))]
for g in groups:
    reggae_groups[g]=g[0]

###ROCK
rock_groups = np.arange(100)
groups = [range(0,10), range(10, 16), range(16,27), np.concatenate((range(28,32), [33, 35, 37])),
             np.concatenate(([32], range(39,49))), range(49,57), range(57,64), range(64,71), range(71,79),
             range(79,86), range(91,100)]
for g in groups:
    rock_groups[g]=g[0]


In [None]:
#REMOVE REPEATED SONGS AND DISTORTIONS

#Sturm(2013) states that removing "repeated recording" songs, whatever that may be, does
#not affect the result much, since the artist filter takes care of most problems. Thus,
#we only removed repeated songs and the two most distorted ones (pop37 and reggae86)

X, Y = parse_filelist(meta)

# One row per genre; the row order must match the order in which the genres
# are appended to X_filt / Y_filt / all_groups below.
X = X.reshape((10, -1))
Y = Y.reshape((10, -1))

print(X.shape)

# Blues, classical and country are kept in full.
X_filt = [X[0], X[1], X[2]]
Y_filt = [Y[0], Y[1], Y[2]]

# For each remaining genre: `idxs` lists the song positions (within the genre)
# to drop, `m` is the matching boolean keep-mask, and the mask is applied to
# the group table as well as to the file names and labels so that all three
# stay aligned.

###DISCO
idxs = [51, 70, 60, 89, 74, 99]
m = np.ones(len(disco_groups), dtype=bool)
m[idxs]=False
disco_groups = disco_groups[m]
X_filt.append(X[3,m])
Y_filt.append(Y[3,m])

###HIPHOP
idxs = [45, 78]
m = np.ones(len(hiphop_groups), dtype=bool)
m[idxs]=False
hiphop_groups = hiphop_groups[m]
X_filt.append(X[4,m])
Y_filt.append(Y[4,m])

###JAZZ
idxs = [51, 53, 55,58,60,62,65,67,68,69,70,71,72]
m = np.ones(len(jazz_groups), dtype=bool)
m[idxs]=False
jazz_groups = jazz_groups[m]
X_filt.append(X[5,m])
Y_filt.append(Y[5,m])

###METAL (removed 58 from metal and kept it in rock)
idxs = [13,94,61,62,63,64,65,66,58]
m = np.ones(len(metal_groups), dtype=bool)
m[idxs]=False
metal_groups = metal_groups[m]
X_filt.append(X[6,m])
Y_filt.append(Y[6,m])

###POP
idxs = [22,31,46,80,57,60,59,71,90,37]
m = np.ones(len(pop_groups), dtype=bool)
m[idxs]=False
pop_groups = pop_groups[m]
X_filt.append(X[7,m])
Y_filt.append(Y[7,m])

###REGGAE
idxs=[65,56,57,60,58,69,74,81,82,91,92,86]
m = np.ones(len(reggae_groups), dtype=bool)
m[idxs]=False
reggae_groups = reggae_groups[m]
X_filt.append(X[8,m])
Y_filt.append(Y[8,m])

###ROCK
# Rock is kept in full.
X_filt.append(X[9])
Y_filt.append(Y[9])

# Group tables in the same genre order as X_filt / Y_filt.
# NOTE: this cell shrinks the *_groups arrays, so it must not be re-run
# without first re-running the cell that builds them.
all_groups = [ blues_groups, classical_groups, country_groups, disco_groups, hiphop_groups, jazz_groups,
                 metal_groups, pop_groups, reggae_groups, rock_groups]



In [7]:
# Genre names in the same order as the rows of X_filt; report how many songs
# of each genre are left after filtering.
classes = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
for genre, songs in zip(classes, X_filt):
    # Formatted explicitly so the output is "<genre> <count>" on both
    # Python 2 and Python 3.
    print('%s %d' % (genre, len(songs)))

blues 100
classical 100
country 100
disco 94
hiphop 98
jazz 87
metal 91
pop 90
reggae 88
rock 100


In [None]:
# Flatten the per-genre lists into single arrays of file names and labels.
X = np.concatenate(X_filt)
Y = np.concatenate(Y_filt)

print('%d %d' % (len(X), len(Y)))

print(len(all_groups))

# Replace the per-genre group labels (which repeat across genres) by ids that
# are unique over the whole dataset. all_ids is the pool of candidate ids; an
# entry set to 0 means "not available".
all_ids = np.arange(1000)

# Retire every value that is currently used as a per-genre label. A freshly
# assigned id can then never coincide with a label that is still waiting to be
# replaced during the in-place rewrite below.
for g in all_groups:
    all_ids[g] = 0

groups = []

for g in all_groups:
    # np.unique(g) is evaluated once, before g is modified in the loop body.
    for i in np.unique(g):
        # Smallest id that has not been handed out yet.
        n = np.nonzero(all_ids)[0][0]
        all_ids[n] = 0
        # NOTE: rewrites the per-genre array in place, so this cell must not
        # be re-run without rebuilding all_groups first.
        g[g==i] = n
    groups.extend(g)

# Shift the ids down by 100 so they start at 0, which is the numbering the
# hand-written fold lists use. NOTE(review): this assumes the first id handed
# out is 100, i.e. that every value 1-99 occurs as a per-genre label -- confirm.
groups = np.array(groups)-100
    
#print groups
    
#print all_ids

    

In [9]:
import collections
def flatten(x):
    """Recursively flatten nested iterables into one flat list.

    Parameters
    ----------
    x : object
        Anything. Iterables (lists, tuples, ranges, arrays, ...) are expanded
        recursively; text strings and non-iterable values are treated as
        single items.

    Returns
    -------
    list
        The leaf items in traversal order. A non-iterable ``x`` yields ``[x]``.
    """
    # Text strings are iterable, and every one-character string iterates to
    # itself, so descending into them would recurse forever. type(u'') is
    # `unicode` on Python 2 and `str` on Python 3.
    if isinstance(x, (str, type(u''))):
        return [x]
    # Probing with iter() instead of collections.Iterable keeps this working
    # on Python >= 3.10, where that alias no longer exists.
    try:
        children = iter(x)
    except TypeError:
        return [x]
    return [leaf for child in children for leaf in flatten(child)]
            
import os


def _collect_fold(group_spec):
    """Gather the songs that belong to one fold.

    group_spec is a (possibly nested) list of global group ids and ranges of
    ids. Returns (ids, songs, labels, song_groups): the flattened ids, the
    entries of X and Y whose group is one of those ids (taken id by id, in the
    order given), and the group id of every returned song.
    """
    ids = np.array(flatten(group_spec))
    members = [np.where(groups == gid)[0] for gid in ids]
    songs = np.concatenate([X[idx] for idx in members])
    labels = np.concatenate([Y[idx] for idx in members])
    song_groups = np.repeat(ids, [len(idx) for idx in members])
    return ids, songs, labels, song_groups


# Hand-picked assignment of the global group ids to the three folds. Every
# group appears in exactly one fold.
fold1_groups, fold1_X, fold1_Y, fold1_G = _collect_fold([
    1, 7, 8, #blues
    9, 10, 11, #classical
    range(46,65), range(66,81), #country
    98,99,range(100,115), range(116,123), #disco
    range(176,189), #hiphop
    227,229, #jazz
    range(269,293), #metal
    range(327,336), 350, #pop
    352, #reggae
    range(392,396), range(406,411)#rock
])
print('fold 1 len x, y =  %d %d %d' % (len(fold1_X), len(fold1_Y), len(fold1_G)))

fold2_groups, fold2_X, fold2_Y, fold2_G = _collect_fold([
    0,6,3, #blues
    range(12,37), #classical
    65, range(81,95), #country
    range(123,151), #disco
    range(189,209), #hiphop
    range(224,227), 228, range(230,240), #jazz
    range(293,304), range(305,307), 326, #metal
    range(336,340), 347, #pop
    range(353,371), #reggae
    397, 402, 403, 401#rock
])
print('fold 2 len x, y =  %d %d %d' % (len(fold2_X), len(fold2_Y), len(fold2_G)))

fold3_groups, fold3_X, fold3_Y, fold3_G = _collect_fold([
    2,4,5, #blues
    range(37,46), #classical
    95,96,97, #country
    115, range(151,176), #disco
    range(209,224), #hiphop
    range(240,269), #jazz
    range(307,326), 304, #metal
    range(340,347),348,349,351, #pop
    range(371,392), #reggae
    396, 398, 399, 400, 404, 405, 411#rock
])
print('fold 3 len x, y =  %d %d %d' % (len(fold3_X), len(fold3_Y), len(fold3_G)))

all_folds =[(fold1_X, fold1_Y, fold1_G), (fold2_X, fold2_Y, fold2_G), (fold3_X,fold3_Y, fold3_G)]

# Make sure the output folder exists before anything is written to it.
if out_dir and not os.path.isdir(out_dir):
    os.makedirs(out_dir)

# Each fold is used once as the evaluation set; the other folds, in their
# original order, form the training set. fn is the 1-based fold number used in
# the file names.
for fn, (testX, testY, testG) in enumerate(all_folds, 1):
    train_folds = [fold for k, fold in enumerate(all_folds, 1) if k != fn]
    trainX = np.concatenate([fold[0] for fold in train_folds])
    trainY = np.concatenate([fold[1] for fold in train_folds])
    trainG = np.concatenate([fold[2] for fold in train_folds])
    print(len(trainX))

    # Artist filter guarantee: no group may occur on both sides of the split.
    shared = set(trainG.tolist()) & set(testG.tolist())
    assert not shared, 'fold %d: groups in both train and test: %s' % (fn, sorted(shared))

    o_tr = ['%s\t%s\n' % (x, y) for x, y in zip(trainX, trainY)]
    o_ts = ['%s\n' % x for x in testX]
    o_ev = ['%s\t%s\n' % (x, y) for x, y in zip(testX, testY)]
    o_grp_tr = ['%s\t%d\n' % (x, g) for x, g in zip(trainX, trainG)]
    o_grp_ev = ['%s\t%d\n' % (x, g) for x, g in zip(testX, testG)]

    outputs = [
        ('f%d_train.txt', o_tr),
        ('f%d_test.txt', o_ts),
        ('f%d_evaluate.txt', o_ev),
        ('f%d_train_groups.txt', o_grp_tr),
        ('f%d_evaluate_groups.txt', o_grp_ev),
    ]
    for pattern, lines in outputs:
        with open(os.path.join(out_dir, pattern % fn), 'w') as f:
            f.writelines(lines)
    


fold 1 len x, y =  315 315 315
fold 2 len x, y =  316 316 316
fold 3 len x, y =  317 317 317
633
632
631


## stratified_group_split

In [None]:
import itertools
from itertools import chain, combinations
import numpy as np

def powerset(iterable):
    """Return an iterator over every subset of *iterable*, as tuples.

    Subsets come out ordered by size, smallest first:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = tuple(iterable)
    subsets_by_size = (combinations(pool, size) for size in range(len(pool) + 1))
    return chain.from_iterable(subsets_by_size)

def stratified_group_split(Y, groups, test_size=0.2, max_groups=16):
    """Split sample indices into train and test sets without splitting groups.

    For every class, a set of whole groups is moved to the test set such that
    their combined size is as close as possible to
    ``floor(floor(mean class size) * test_size)``. The target is the same for
    every class.

    Parameters
    ----------
    Y : array-like
        Class label of every sample.
    groups : array-like
        Group id of every sample, same length as ``Y``. Ids are matched over
        the whole array, so a group id should not be shared between classes.
    test_size : float
        Fraction of the average class size to put in the test set.
    max_groups : int
        Only the ``max_groups`` largest groups of a class are considered; the
        search enumerates every subset of them (minus the largest), so the cost
        grows as 2**(max_groups - 1). Smaller groups always stay in training.

    Returns
    -------
    tr_idx, ts_idx : list
        Sorted sample indices of the train and test sets.

    Raises
    ------
    AssertionError
        If ``Y`` and ``groups`` differ in length, or a class has fewer than
        two groups.
    """
    # Accept plain lists as well as arrays; the comparisons below rely on
    # elementwise numpy semantics.
    Y = np.asarray(Y)
    groups = np.asarray(groups)

    assert len(Y) == len(groups)

    all_classes = sorted(set(Y.tolist()))

    # Number of samples of each class, their (floored) average, and from that
    # the number of test samples wanted from each class.
    class_sizes = [len(np.where(Y == label)[0]) for label in all_classes]
    avg_size = int(np.floor(np.mean(class_sizes)))
    target = int(np.floor(avg_size * test_size))

    # Indexes of the test samples.
    ts_idx = []

    for label in all_classes:
        # Unique groups of this class (g) and how many samples each has (hg).
        g, hg = np.unique(groups[Y == label], return_counts=True)

        # Largest groups first; keep only max_groups so the subset search
        # below stays tractable.
        order = np.argsort(hg)[::-1][:max_groups]
        g = g[order]
        hg = hg[order]

        # Make sure that there are at least two groups per class.
        assert len(g) >= 2

        # Naive subset-sum: enumerate every subset of the group sizes, leaving
        # out the largest group, and keep the first one whose total is closest
        # to the target. A list comprehension is used for the sums because
        # map() returns an iterator on Python 3, which np.array() cannot
        # turn into a numeric array.
        candidates = list(powerset(hg[1:]))
        sums = np.array([sum(c) for c in candidates])
        best = candidates[np.argmin(np.abs(sums - target))]

        # `best` holds group SIZES, not group ids. Map each size back to a
        # group with that size, removing the group once used so that several
        # groups of equal size are not picked twice.
        chosen = []
        for size in best:
            j = np.where(hg == size)[0][0]
            chosen.append(g[j])
            hg = np.delete(hg, j)
            g = np.delete(g, j)

        # Positions in the input that belong to the selected groups.
        for k in chosen:
            ts_idx.extend(np.where(groups == k)[0])

    # The train samples are the complement of the test samples.
    tr_idx = set(range(len(Y))) - set(ts_idx)

    return sorted(tr_idx), sorted(ts_idx)
    

# Demo: re-split the training part of the written folds into train/validation
# with stratified_group_split and print the class counts of the held-out part.
# NOTE(review): range(1,3) covers folds 1 and 2 only, although three folds are
# written above -- confirm whether fold 3 was left out on purpose.
for i in range(1,3):

    with open('./gtzan_folds_filter_3f_strat/f%d_train_groups.txt' % i) as f:
        c = f.readlines()

    # The class label is taken from the file path: fifth '/'-separated
    # component, up to the first '_'. This depends on how the paths in the
    # meta file are laid out -- adjust the index if your paths differ.
    y = np.array([ s.split('\t')[0].split('/')[4].split('_')[0] for s in c])

    # Second column: the global group id of the song.
    g = np.array([ s.split('\t')[1].strip() for s in c]).astype(int)

    tr, ts = stratified_group_split(y,g,0.2)

    print(np.unique(y[ts], return_counts=True))
