In [14]:
import string
import numpy as np

In [18]:
vocab = list(string.ascii_lowercase)

In [221]:
def make_lyrics(vocab, low, high, num_songs, genre, prob):
    return [np.random.choice(vocab, 
                             np.random.randint(low, high), 
                             replace = True,
                             p = prob) for _ in range(num_songs)]

In [222]:
popprob = np.array(list(range(1, 27))) / sum(list(range(1, 27)))
rapprob = popprob[::-1]
pop = make_lyrics(vocab, 100, 400, 100, "pop", popprob)
rap = make_lyrics(vocab, 100, 400, 100, "rap", rapprob)

In [223]:
def hellinger_distance(dist1, dist2):
    
    num = 0
    
    for element in dist1.keys():
        
        num += (np.sqrt(dist1[element]) - np.sqrt(dist2[element])) ** 2
        
    num = (1 / np.sqrt(2)) * np.sqrt(num)
    
    return 1 - num

In [195]:
def kl_divergence(dist1, dist2):
    
    num = 0
    
    for element in dist1.keys():
        if (dist2[element] == 0) and (dist1[element] == 0):
            continue
        num -= dist1[element] * np.log(dist2[element] / dist1[element])
        
    return num

In [196]:
def get_p_genre(x):
    return {genre:len(x[genre]) for genre in x}

In [256]:
def get_word_distribution(d, type_ = "train"):
    """
    if train: return word distribution
    if test: return each testing example's word distribution.
    """
    
    from collections import Counter
    from itertools import chain
    
    if type_ == "train":
        ans = Counter(list(chain.from_iterable(d)))
        norm = sum(ans.values())
        for element in ans:
            ans[element] /= norm

    elif type_ == "test":
        
        ans = []
        for song in d:
            ans1 = Counter(list(chain.from_iterable(song)))
            norm = sum(ans1.values())
            for element in ans1:
                ans1[element] /= norm 
            ans.append(ans1)
            
            
    return ans

In [258]:
popdist = get_word_distribution(pop)
rapdist = get_word_distribution(rap)

In [264]:
dists = {"pop": popdist, "rap": rapdist}

### Make testing data

In [259]:
poptest = make_lyrics(vocab, 100, 400, 100, "pop", popprob)
raptest = make_lyrics(vocab, 100, 400, 100, "rap", rapprob)

In [262]:
poptest2 = get_word_distribution(poptest, "test")
raptest2 = get_word_distribution(raptest, "test")

In [284]:
def classify(data, dists, popprop, rapprop):
    p_genre = {"pop": popprop, "rap": rapprop}
    results = []
    for song in data:
        distance = {}
        for dist in dists:
            distance.update({dist: hellinger_distance(song, dists[dist]) * p_genre[dist]})
        #print(distance)
        results.append(max(distance.items(), key = lambda x: x[1]))
    return results

In [285]:
classify(raptest2, dists, .5, .5)

[('rap', 0.4351242125069986),
 ('rap', 0.4490287380253664),
 ('rap', 0.45484155288819605),
 ('rap', 0.44066608941916113),
 ('rap', 0.46113144976683973),
 ('rap', 0.45405666600937045),
 ('rap', 0.4482152290004212),
 ('rap', 0.45424254340173115),
 ('rap', 0.4491323467620527),
 ('rap', 0.4438252259269708),
 ('rap', 0.4485650780617269),
 ('rap', 0.4432961710394042),
 ('rap', 0.45487781934769167),
 ('rap', 0.41828193434798877),
 ('rap', 0.43582176716342536),
 ('rap', 0.42950211493431384),
 ('rap', 0.43304547993131165),
 ('rap', 0.4253836283675927),
 ('rap', 0.455586482420802),
 ('rap', 0.43830533988510134),
 ('rap', 0.44723808482438765),
 ('rap', 0.46350439795806353),
 ('rap', 0.42905616612187286),
 ('rap', 0.4571593469722843),
 ('rap', 0.44517502757134897),
 ('rap', 0.4515025899614402),
 ('rap', 0.45650833333164986),
 ('rap', 0.44632357252118116),
 ('rap', 0.43881655592728275),
 ('rap', 0.45449840945118214),
 ('rap', 0.42609449528347665),
 ('rap', 0.46243180877508017),
 ('rap', 0.440780339

In [286]:
classify(poptest2, dists, .5, .5)

[('pop', 0.4323473959990164),
 ('pop', 0.4439754774436474),
 ('pop', 0.4503772662437322),
 ('pop', 0.4499180225940682),
 ('pop', 0.4515195707936297),
 ('pop', 0.45382328546653794),
 ('pop', 0.4352443617056324),
 ('pop', 0.42940367460459494),
 ('pop', 0.44667828104420604),
 ('pop', 0.44031758593713),
 ('pop', 0.42902918351977426),
 ('pop', 0.4236125422207174),
 ('pop', 0.45400431326752083),
 ('pop', 0.44698513998651834),
 ('pop', 0.45177615802295606),
 ('pop', 0.45019886913320833),
 ('pop', 0.4585059085127261),
 ('pop', 0.42175285463710765),
 ('pop', 0.42554366221083406),
 ('pop', 0.4483736671321149),
 ('pop', 0.43158344920400726),
 ('pop', 0.44175531735793905),
 ('pop', 0.46346698473083436),
 ('pop', 0.46277985180358727),
 ('pop', 0.4327407543061326),
 ('pop', 0.45088290992817504),
 ('pop', 0.4521496419581152),
 ('pop', 0.45543424742493926),
 ('pop', 0.45682822410292706),
 ('pop', 0.45240417791766613),
 ('pop', 0.44691582373120076),
 ('pop', 0.42070970199182667),
 ('pop', 0.45926552994