In [1]:
import numpy as np
from collections import defaultdict, Counter
import os.path
import os
import preprocessing as pp
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import datetime
from operator import itemgetter

In [2]:
def weight_normalized_cnb(complement_probs_normalized, vectorized_text, prior_probs):
    '''
    :param complement_probs: dictionary where key = label and values = dictionary where
                            keys = words and values = (# of times word w appears in docs
                            NOT labeled l)/(# of words in documents NOT labeled l)
    :param idf: dictionary where keys = words and values = (total # docs)/(# of docs in 
                which we see that word)
    :param vectorized_text: words from text that are in valid_words
    :param prior_probs: dictionary where keys = labels and values = the probability
                        of seeing that label in the dataset
    '''
    labels = []
    freq = Counter(vectorized_text)
    for label in prior_probs.keys():
        conditional = 0.0
        for word in freq.keys():
            conditional += (freq[word] * complement_probs_normalized[label][word])
        labels.append((label, np.exp(conditional)))
    return sorted(labels, key=itemgetter(1))

In [3]:
def complement_naive_bayes(complement_probs, vectorized_text, prior_probs):
    '''
    :param complement_probs: dictionary where key = label and values = dictionary where
                            keys = words and values = (# of times word w appears in docs
                            NOT labeled l)/(# of words in documents NOT labeled l)
    :param vectorized_text: words from text that are in valid_words
    :param prior_probs: dictionary where keys = labels and values = the probability
                        of seeing that label in the dataset
    '''
    labels = []
    doc_denom = 0
    freq = Counter(vectorized_text)
    '''
    for word in freq.keys():
        for label in prior_probs.keys():
            doc_denom += (np.log(prior_probs[label]) + (freq[word]/len(vectorized_text) * complement_probs[label][word]))
    print(doc_denom)
    '''
    for label in prior_probs.keys():
        conditional = 0.0
        for word in freq.keys():
            conditional += (freq[word] * complement_probs[label][word])
        labels.append((label, conditional))
    return sorted(labels, key=itemgetter(1))

In [4]:
def multinomial_naive_bayes(conditional_probs, vectorized_text, prior_probs):
    '''
    :param conditional_probs: dictionary where keys = labels and values = dictionary where
                    keys = words and values = P(x|Y)
    :param vectorized_text: words from text that are in valid_words
    :param prior_probs: dictionary where keys = labels and values = the probability
                        of seeing that label in the dataset
    '''
    labels = []
    freq = Counter(vectorized_text)
    for label in prior_probs.keys():
        conditional = 0.0
        for word in vectorized_text:
            if conditional_probs[label][word] != 0.0:
                conditional += (freq[word] * conditional_probs[label][word])
        labels.append((label, np.exp(conditional)))
    return sorted(labels, key=itemgetter(1), reverse=True)

In [5]:
def compute_prior_probabilities(number_labels):
    '''
    This function will compute the prior probabilities
    P(y) = probability of seeing a label with a sample. 
    Note: since many samples have multiple labels, these prior
    probabilites will sum to > 1
    :param number_labels: dictionary where keys = number of training sample
                            and value = the list of labels associated with it
    :return: a dictionary where keys = the label and value = probability of seeing
            that label in the document list
    '''
    prior_probs = defaultdict(float)
    i = 0
    for num, labels in number_labels.items():
        for l in labels:
            if not prior_probs[l]:
                prior_probs[l] = 1
            else:
                prior_probs[l] += 1
        i += 1
    for label, freq in prior_probs.items():
        prior_probs[label] /= i
    return prior_probs

In [6]:
def rename_files(dir_path):
    '''
    Utility function designed to rename all files in any directory
    to a .txt file so they can be read from
    :param dir_path: directory of the files to be renamed
    '''
    for file in os.listdir(dir_path):
        filepath = dir_path + '\\' + file 
        os.rename(filepath, filepath+".txt")

In [7]:
def get_parameters(dir_path, valid_words, number_labels, label_list):
    '''
    This function will iterate over the documents and compute the frequencies of 
    the words by label and in total
    :param dir_path: a path to the directory containing all the training samples
    :param valid_words: a dictionary where the keys are all the unique, valid
                        terms are present in the text file
    :param number_labels: dictionary where keys = document # and values = the set of 
                            labels associated with those labels
    :param label_list: list of all the unique labels
    :return: a dictionary where keys = labels and values = dictionary where keys 
            = words and values = the frequencies of that word in documents with that label
            AND 
            a dictionary where keys = words and values = the total # of occurrences of
            that word
            AND
            a dictionary where keys = words and values = the idf score for that word
            AND 
            a dictionary where keys = labels and values = the total # of words associated with that label
            AND
            the total # of valid words in the entire corpus
    '''
    words_by_doc_num = defaultdict()
    idf = {word: 0.0 for word in valid_words}
    total_num_words = 0
    total_word_count_by_label = {label: 0 for label in label_list}    
    i = 0
    for file in os.listdir(dir_path):
        with open(dir_path + '\\' + file, "r") as f:
            content = f.read()
            num = int(file[0:len(file) - 4])
            labels = number_labels[num]
            words = nltk.word_tokenize(content)
            new_words = [word.lower() for word in words]
            new_words = [word.lower() for word in new_words if word in valid_words]
            total_num_words += len(new_words)
            freq = Counter(new_words)
            words_by_doc_num[num] = freq
            for word in freq.keys():
                idf[word] += 1
            for l in labels:
                total_word_count_by_label[l] += len(new_words)
            i += 1
    for word in idf.keys():
        idf[word] = 1 + np.log(i/(idf[word]+1))      
    frequencies = {label: {word: 0.0 for word in valid_words} for label in label_list} 
    total_frequencies = {word: 0 for word in valid_words}
    for num in words_by_doc_num.keys():
        freq = words_by_doc_num[num]
        labels = number_labels[num]
        normalization_term = np.sqrt(sum([score**2 for word, score in freq.items()]))
        for l in labels:
            total_word_count_by_label[l] += len(new_words)
            for word in freq.keys():
                term_to_add = freq[word] * idf[word]
                frequencies[l][word] += (term_to_add/normalization_term)
                total_frequencies[word] += (term_to_add/normalization_term)
    return [frequencies, total_frequencies, idf, total_word_count_by_label, total_num_words]

In [8]:
if __name__ == '__main__':
    dir_path = "C:\\Users\\ksing\\OneDrive\\Documents\\TextClassifiers\\training"
    stop_words = set(stopwords.words('english'))
    valid_words = pp.get_valid_words(dir_path, stop_words)
    number_labels_training, number_labels_test = pp.add_labels_to_samples("cats2.txt")
    prior_probs = compute_prior_probabilities(number_labels_training)
    
    parameters = get_parameters(dir_path, valid_words, number_labels_training, prior_probs.keys())
    
    frequencies = parameters[0]
    total_frequencies = parameters[1]
    idf = parameters[2] 
    total_word_count_by_label = parameters[3]
    total_num_words = parameters[4]

In [9]:
    conditional_probs = {label: {word: 0.0 for word in valid_words} for label in prior_probs.keys()}
    complement_probs = {label: {word: 0.0 for word in valid_words} for label in prior_probs.keys()}
    for label, vector in conditional_probs.items():
        denom = total_num_words - total_word_count_by_label[label] + len(valid_words.keys())
        for word in vector.keys():
            mod_cond_freq = frequencies[label][word] + 1
            mod_comp_freq = (total_frequencies[word] - frequencies[label][word]) + 1
            conditional_probs[label][word] = np.log(mod_cond_freq/(total_word_count_by_label[label] + len(valid_words.keys())))
            complement_probs[label][word] = np.log(mod_comp_freq/denom)

In [10]:
    complement_probs_normalized = {label: {word: complement_probs[label][word] for word in valid_words} 
                                   for label in prior_probs.keys()}
    conditional_probs_normalized = {label :{word: 0.0 for word in valid_words} for label in prior_probs.keys()}
    normalize_terms = {label: 0.0 for label in prior_probs.items()}
    for label, vector in complement_probs.items():
        normalize_term_1 = np.sqrt(sum([(complement_probs_normalized[label][word]**2) for word in valid_words]))
        normalize_term_2 = np.sqrt(sum([(conditional_probs[label][word]**2) for word in valid_words]))
        normalize_terms[label] = normalize_term_1
        for word in vector.keys():
            complement_probs_normalized[label][word] /= normalize_term_1
            conditional_probs_normalized[label][word] = conditional_probs[label][word] / normalize_term_2

In [13]:
    # Removing the stemmer actually improves accuracy on test set, who knew
    successes, earned, bottom_5,i = 0, 0, 0, 0
    computed_label_set = defaultdict(list)
    dir_path = "C:\\Users\\ksing\\OneDrive\\Documents\\TextClassifiers\\test"
    for file in os.listdir(dir_path):
        filepath = dir_path + '\\' + file 
        num = int(file[0:len(file) - 4])
        text = pp.vectorize_text(valid_words, filepath)
        computed_labels = complement_naive_bayes(complement_probs, text, prior_probs)
        # computed_labels = multinomial_naive_bayes(conditional_probs_normalized, text, prior_probs)
        # computed_labels = weight_normalized_cnb(complement_probs_normalized, text, prior_probs)
        suc, e, b5 = pp.accuracy_model(num, number_labels_test, computed_labels)
        computed_label_set[num] = [x for x,y in computed_labels]
        # MNB with doc length normalization, IDF: 86.10% accuracy (2599.288708513709), 1548 "Earn" labels
        # CNB with doc length normalization, IDF: 90.02% accuracy(2717.798340548341), 1130 "Earn" labels
        # However, this approach results in conditional terms that don't make much sense for precision or recall
        # WCNB with doc length normalization, IDF: 87.72% accuracy (2648.141955266955), 1542 "Earn" labels
        
        # Perhaps the reason that TF doesn't lead to improvements with this is because we already stripped out the 
        # stop words, which would be affected the most by this technique

        successes += suc
        earned += e
        bottom_5 += b5
        i += 1
    print(successes, earned, bottom_5, i)

14828
['grain'] [('cpi', -547.915172220033), ('grain', -547.8270632903467), ('barley', -547.7815549340163), ('cotton', -547.7576181453837), ('orange', -547.7463093455947), ('rice', -547.7317820380794), ('tea', -547.6685358667639), ('coconut-oil', -547.6676793243474), ('ipi', -547.6557951010985), ('rapeseed', -547.6444534392339)]
14832
['rubber', 'tin', 'sugar', 'corn', 'rice', 'grain', 'trade'] [('bop', -793.0357725400696), ('tin', -792.4608009482948), ('rubber', -792.4011846893133), ('cpi', -791.9316475459617), ('sugar', -791.8930977503874), ('money-supply', -791.6954189681139), ('rice', -791.6023850085486), ('trade', -791.5542576638858), ('retail', -791.5095871321384), ('orange', -791.5059822570007)]
14843
['acq'] [('nickel', -3255.6208783136117), ('iron-steel', -3252.6679580714103), ('castor-oil', -3252.3909189310602), ('nkr', -3252.3759721930533), ('palladium', -3252.3660853818938), ('cotton-oil', -3252.3633024871574), ('jet', -3252.3617817897316), ('sun-meal', -3252.353654810688),

['interest'] [('money-fx', -118.80116394751182), ('interest', -118.04029291263329), ('money-supply', -116.35305924990976), ('rapeseed', -115.7242209861712), ('dlr', -115.57391422870352), ('dmk', -115.5052048695269), ('instal-debt', -115.49978947583308), ('cpu', -115.49586585921766), ('nkr', -115.49311157442969), ('castor-oil', -115.49304583258385)]
15106
['soybean', 'oilseed', 'corn', 'grain'] [('grain', -1253.8795442107248), ('ship', -1253.5283012073437), ('cotton', -1252.5450437806558), ('orange', -1252.2555292536656), ('gas', -1252.2473289465922), ('iron-steel', -1252.168796837838), ('alum', -1252.0601691683687), ('soybean', -1252.0517434330884), ('coffee', -1251.9797646862608), ('corn', -1251.964567577317)]
15107
['earn'] [('acq', -652.3225055334261), ('earn', -648.8735746582719), ('cpu', -643.0028108139642), ('nkr', -643.0009091833617), ('castor-oil', -642.9979386562234), ('cotton-oil', -642.9969375157159), ('housing', -642.9951678952491), ('sun-meal', -642.9938257959193), ('groun

15400
['yen'] [('money-fx', -928.5471809683773), ('dlr', -923.5118917734427), ('yen', -922.0691874194506), ('jobs', -920.6753455651709), ('ipi', -920.2509174814853), ('lei', -920.1846586573856), ('dmk', -920.16364484113), ('tin', -920.1586018032737), ('strategic-metal', -920.1449131458351), ('cotton-oil', -920.1116821375356)]
15410
['interest'] [('money-fx', -674.2300187487473), ('interest', -673.0714397596213), ('money-supply', -669.7888127930815), ('dlr', -669.3730348891995), ('lead', -669.2510022131905), ('zinc', -669.1734331025024), ('dmk', -669.1585185237248), ('palladium', -669.1432552369088), ('nickel', -669.1152814215172), ('lei', -669.1103809776705)]
15416
['nat-gas'] [('crude', -720.8600083225422), ('nat-gas', -720.2775959649186), ('jobs', -719.6173071695018), ('yen', -719.6165772685092), ('lead', -719.5459410009178), ('zinc', -719.5160068889005), ('rubber', -719.4012507682781), ('dmk', -719.3303252139199), ('rand', -719.2942294068314), ('propane', -719.2816296818349)]
15417


15617
['interest'] [('money-fx', -2134.7026962657474), ('dlr', -2124.092752177632), ('interest', -2123.88469723117), ('yen', -2121.589475203633), ('dmk', -2120.575205034378), ('trade', -2120.4809696761295), ('cotton-oil', -2120.4450010110686), ('castor-oil', -2120.440253176531), ('nkr', -2120.437686937687), ('coconut', -2120.4352296474817)]
15620
['potato'] [('carcass', -1433.1907045222151), ('potato', -1433.168420085469), ('livestock', -1432.7995320785578), ('soy-meal', -1432.7311109776033), ('meal-feed', -1432.5112822010353), ('hog', -1432.0321190926109), ('heat', -1431.7691372202437), ('barley', -1431.607580510223), ('l-cattle', -1431.540792803597), ('oat', -1431.5388049889075)]
15623
['iron-steel'] [('acq', -278.09237949319225), ('iron-steel', -276.8773675546887), ('gold', -275.9730545704464), ('pet-chem', -275.9112780009531), ('lead', -275.8458091225653), ('alum', -275.8431473378871), ('zinc', -275.8431281522279), ('nickel', -275.80834511389446), ('jet', -275.80144059539816), ('pa

15875
['oat', 'barley', 'sorghum', 'rice', 'cotton', 'soy-meal', 'soy-oil', 'soybean', 'wheat', 'corn', 'veg-oil', 'meal-feed', 'oilseed', 'grain'] [('grain', -5424.191663524967), ('corn', -5415.517741020027), ('oilseed', -5408.818841866369), ('soybean', -5407.668164732658), ('wheat', -5407.620861492053), ('barley', -5406.401009349206), ('sorghum', -5405.904866190712), ('rice', -5405.665951668645), ('cotton', -5403.958078962129), ('meal-feed', -5403.561084631985)]
15877
['corn', 'grain'] [('earn', -357.1158609215777), ('grain', -354.7664483043973), ('corn', -354.132257921751), ('wheat', -353.2248485167278), ('oilseed', -352.7145622563176), ('barley', -352.6111964959041), ('soybean', -352.5988268223852), ('strategic-metal', -352.4870699182584), ('cotton', -352.4788508632874), ('rice', -352.4780310338962)]
15890
['cotton', 'soy-meal', 'soy-oil', 'soybean', 'corn', 'wheat', 'veg-oil', 'meal-feed', 'oilseed', 'grain'] [('grain', -319.1076062663357), ('corn', -318.91894449639017), ('oilseed

16080
['rape-oil', 'soy-oil', 'veg-oil'] [('veg-oil', -422.5935479886771), ('oilseed', -421.8501479313691), ('rapeseed', -421.6525823985716), ('palm-oil', -421.64985710747305), ('soybean', -421.45384288513264), ('rape-oil', -421.38467344779554), ('meal-feed', -421.36772479816557), ('rice', -421.3227916080968), ('sun-oil', -421.3079080447079), ('soy-oil', -421.30467010712414)]
16083
['money-supply', 'interest'] [('money-fx', -447.18419830211155), ('interest', -445.20389255361886), ('dlr', -444.6933799519684), ('money-supply', -444.35486162459955), ('cocoa', -444.22240330397443), ('rapeseed', -444.12382607073886), ('yen', -444.1049885486133), ('dmk', -444.03592548329664), ('cpi', -444.0200158549225), ('castor-oil', -443.9967309594948)]
16099
['rice', 'corn', 'grain'] [('grain', -1119.7132569905154), ('corn', -1116.854894893374), ('oilseed', -1116.8468967349431), ('wheat', -1116.8298870239632), ('rice', -1116.6954003714682), ('coconut', -1116.3693705831463), ('sugar', -1115.9981208690876)

16289
['lead'] [('copper', -218.0195832599956), ('zinc', -214.28655520795039), ('lead', -214.28132246853787), ('silver', -214.1814863537856), ('alum', -214.13875657564392), ('veg-oil', -214.1281452488685), ('palm-oil', -214.10984428931332), ('heat', -214.07558119766944), ('sun-oil', -214.03826651751822), ('rape-oil', -214.00374156534744)]
16301
['earn'] [('cocoa', 0.0), ('sorghum', 0.0), ('oat', 0.0), ('barley', 0.0), ('corn', 0.0), ('wheat', 0.0), ('grain', 0.0), ('sunseed', 0.0), ('oilseed', 0.0), ('soybean', 0.0)]
16320
['alum', 'acq'] [('alum', -494.7711335050912), ('jobs', -490.86512474048754), ('zinc', -489.5020181570905), ('lead', -489.44478517438426), ('rape-oil', -489.41693647001676), ('lin-oil', -489.4137025510497), ('nkr', -489.4121822268928), ('castor-oil', -489.4118880365148), ('cotton-oil', -489.4112963777254), ('sun-meal', -489.4097989492871)]
16331
['money-fx'] [('interest', -405.2861615417331), ('money-fx', -404.2044248161763), ('money-supply', -398.47148423358186), ('

16852
['interest'] [('money-fx', -1457.4081988671219), ('trade', -1457.2773050922224), ('dlr', -1456.5540207527781), ('interest', -1456.0507688470364), ('cpi', -1455.1054668328422), ('lei', -1454.9542499306006), ('jobs', -1454.870399382885), ('dmk', -1454.8675976872073), ('nkr', -1454.8546604953287), ('castor-oil', -1454.8468797972528)]
16854
['wpi'] [('cpi', -304.46750852748676), ('wpi', -303.91024268599426), ('ipi', -302.96176415108323), ('jobs', -302.4654829360582), ('lei', -302.1242645685364), ('money-supply', -302.00044500889106), ('retail', -301.9924456770571), ('bop', -301.9056394564697), ('gnp', -301.88508998715133), ('housing', -301.84221222721203)]
17036
['money-supply'] [('interest', -3024.4810483752935), ('money-supply', -3021.7367711946777), ('money-fx', -3018.2313469741325), ('reserves', -3006.600513665394), ('housing', -3005.775180696128), ('instal-debt', -3005.6647857216562), ('rye', -3005.6484132345786), ('nkr', -3005.6461750232916), ('castor-oil', -3005.646078908498),

17767
['barley', 'grain'] [('grain', -116.43067964887467), ('crude', -116.18310018061597), ('barley', -115.82019347645482), ('corn', -115.65944142511579), ('wheat', -115.57890611424006), ('palm-oil', -115.07495541484616), ('veg-oil', -115.01309245190072), ('rapeseed', -115.01151846669066), ('soy-oil', -114.9395770782763), ('meal-feed', -114.93813854211076)]
17783
['strategic-metal', 'zinc', 'lead'] [('lead', -1342.2302979686692), ('alum', -1341.881084475198), ('gold', -1338.0119952229957), ('ship', -1337.4333099284586), ('copper', -1337.207040457919), ('zinc', -1337.1522003261884), ('cpi', -1336.8870790104543), ('silver', -1336.3619589400237), ('iron-steel', -1336.228812945438), ('meal-feed', -1335.8948269727548)]
17805
['strategic-metal', 'zinc', 'lead'] [('lead', -1337.333928924605), ('alum', -1337.0013237778373), ('gold', -1333.1005776493298), ('ship', -1332.5931314394634), ('copper', -1332.31288503809), ('zinc', -1332.2567766792167), ('cpi', -1331.99620380908), ('silver', -1331.465

18253
['nat-gas'] [('acq', -404.4028195277669), ('nat-gas', -399.6170733822655), ('heat', -398.85769701877047), ('oat', -398.8328942178743), ('rye', -398.8198014306063), ('gas', -398.80253206087093), ('fuel', -398.78470150858817), ('propane', -398.77735037421894), ('dmk', -398.70109569054176), ('coconut', -398.69912396102615)]
18263
['barley', 'grain'] [('grain', -551.3251909934802), ('wheat', -549.3965544782794), ('barley', -547.3981428161032), ('livestock', -545.3969438879641), ('corn', -545.298393045693), ('meal-feed', -544.8442841487717), ('veg-oil', -544.7740630350402), ('palm-oil', -544.7629934963635), ('rice', -544.6612796038909), ('rapeseed', -544.6360123719179)]
18276
['carcass'] [('grain', -459.87932081928506), ('livestock', -458.848306246848), ('wheat', -458.78154405078095), ('barley', -456.7427146989171), ('orange', -456.65486347244473), ('carcass', -456.6476810243897), ('meal-feed', -456.63236073223595), ('oilseed', -456.0765041852371), ('lumber', -456.01423699783436), ('g

18614
['corn', 'grain'] [('grain', -2519.996088574547), ('wheat', -2517.5950416064934), ('corn', -2515.106028503529), ('cotton', -2514.8930661443055), ('oilseed', -2514.2874462710492), ('soybean', -2514.248046846884), ('rice', -2514.196525645602), ('carcass', -2514.0476392259548), ('groundnut', -2513.756763205458), ('oat', -2513.730011941176)]
18625
['nzdlr', 'money-fx'] [('interest', -2005.1393667323784), ('bop', -2004.91654866043), ('gnp', -2004.2394675331352), ('jobs', -2003.8250777592054), ('dlr', -2003.8127794839631), ('rand', -2003.7933488485935), ('nzdlr', -2003.7735831749135), ('castor-oil', -2003.6436040623707), ('nkr', -2003.6320537525207), ('naphtha', -2003.623294003662)]
18642
['sugar', 'rice', 'grain'] [('sugar', -1617.6831123573081), ('rice', -1609.7464090072867), ('palm-oil', -1608.871421400763), ('veg-oil', -1608.7266787010467), ('copra-cake', -1608.586357929358), ('cotton', -1608.5626075145199), ('palmkernel', -1608.4084798365984), ('dmk', -1608.398828634652), ('rape-o

19055
['grain', 'ship'] [('ship', -787.062557474597), ('tin', -767.8169342624965), ('soy-oil', -767.2119651448792), ('rape-oil', -767.1401712812759), ('sun-oil', -767.1301000687811), ('copper', -767.129476142481), ('zinc', -767.1211835265783), ('castor-oil', -767.1184561835065), ('lin-oil', -767.1127559865856), ('rye', -767.1116957979422)]
19061
['gnp', 'bop', 'yen', 'dlr', 'money-fx'] [('gnp', -998.0494658666621), ('dlr', -993.7651546166876), ('money-fx', -993.1485293092784), ('trade', -992.6488372194673), ('yen', -991.9872325542657), ('bop', -991.3778236898434), ('lei', -991.2996342936249), ('dmk', -991.1813397465322), ('income', -991.1547489910138), ('castor-oil', -991.1529763994645)]
19075
['bop', 'gnp'] [('reserves', -1059.4046094249293), ('gnp', -1059.3238549055736), ('gold', -1059.1382033428806), ('bop', -1059.1324436586488), ('money-supply', -1058.8217691585414), ('rand', -1057.5898443627323), ('lei', -1057.5436818839448), ('strategic-metal', -1057.413985851048), ('housing', -1

19537
['nkr', 'money-fx'] [('money-fx', -178.40991313004687), ('interest', -177.19251169160933), ('dlr', -176.59807403838911), ('reserves', -176.05433753232813), ('wpi', -175.8791769157699), ('dmk', -175.86814730032518), ('yen', -175.7762727388721), ('rape-oil', -175.76516943855066), ('lei', -175.7592076672676), ('nkr', -175.75529971876432)]
19541
['nkr', 'money-fx'] [('money-fx', -1018.6838248137478), ('dlr', -1013.0494253046425), ('interest', -1011.6915073017265), ('dmk', -1011.16586825142), ('lei', -1011.1573809914844), ('nkr', -1010.7546654198015), ('yen', -1010.2838629095139), ('reserves', -1010.1117707821857), ('wpi', -1010.08827030511), ('cpi', -1010.0864349219551)]
19549
['rapeseed', 'oilseed', 'wheat', 'grain'] [('cpi', -526.3905041964374), ('livestock', -526.3288468795945), ('carcass', -525.9416613677893), ('veg-oil', -524.9209510406122), ('jet', -524.888516056077), ('cotton', -524.8589924625016), ('barley', -524.8507113799662), ('wpi', -524.8474556780792), ('naphtha', -524.8

20127
['reserves'] [('earn', -134.68678689576768), ('reserves', -133.750236537287), ('money-fx', -133.58717538647866), ('money-supply', -133.46277360551915), ('dlr', -132.9295521414783), ('dmk', -132.92594426973457), ('interest', -132.8785511931584), ('bop', -132.85601954725374), ('nkr', -132.85033284319658), ('lei', -132.83699970487714)]
20208
['rapeseed', 'sunseed', 'soybean', 'oilseed', 'soy-meal', 'meal-feed'] [('meal-feed', -996.1661211913206), ('soy-meal', -995.4963288150274), ('rapeseed', -995.1112746741017), ('oilseed', -995.069137957235), ('veg-oil', -994.172836490822), ('soy-oil', -994.1155193526312), ('barley', -994.062678634529), ('corn', -994.0075836158738), ('rape-oil', -993.8798323622023), ('soybean', -993.8752296614992)]
20214
['earn'] [('cocoa', 0.0), ('sorghum', 0.0), ('oat', 0.0), ('barley', 0.0), ('corn', 0.0), ('wheat', 0.0), ('grain', 0.0), ('sunseed', 0.0), ('oilseed', 0.0), ('soybean', 0.0)]
20232
['veg-oil'] [('wheat', -1477.869667505192), ('grain', -1477.77210

20757
['trade'] [('orange', -945.9254809007226), ('trade', -945.8899980953064), ('bop', -945.8547985672996), ('coffee', -944.7878465418186), ('reserves', -944.5364137627471), ('lumber', -944.1845327926817), ('money-supply', -944.177359820488), ('instal-debt', -944.1122850397554), ('castor-oil', -944.110746466786), ('nkr', -944.1051835633973)]
20764
['interest', 'money-fx'] [('money-fx', -3734.4399906562894), ('dlr', -3728.5832986619125), ('coconut', -3726.7435945835164), ('yen', -3726.7410815884105), ('nkr', -3726.646152788629), ('jet', -3726.6301045599344), ('castor-oil', -3726.5917177298998), ('palladium', -3726.5845467930426), ('nzdlr', -3726.5815057582663), ('rye', -3726.580326020777)]
20773
['earn'] [('acq', -963.5950411896754), ('zinc', -940.0330738608941), ('rye', -939.4964687971067), ('cotton-oil', -939.4788640288432), ('nkr', -939.4632686496666), ('castor-oil', -939.4592529098287), ('sun-meal', -939.4554314316543), ('palladium', -939.4525208666043), ('lin-oil', -939.4465035099

21542
['dmk', 'yen', 'dlr', 'money-fx'] [('money-fx', -3225.235944286968), ('dlr', -3209.067211907503), ('interest', -3202.1451941297782), ('yen', -3201.753625746401), ('dmk', -3199.9008331047676), ('nkr', -3199.4539372944027), ('rand', -3199.412280317357), ('castor-oil', -3199.407842875593), ('cotton-oil', -3199.3983230103404), ('sun-meal', -3199.388948775206)]
21561
['nzdlr', 'money-fx'] [('gnp', -1430.5173925545014), ('cpi', -1430.1978000600127), ('dlr', -1429.3673937946269), ('bop', -1429.0339882535043), ('nzdlr', -1429.0116403999039), ('interest', -1428.9571814104374), ('lumber', -1428.9149477044853), ('rand', -1428.9028411403892), ('dmk', -1428.899839262267), ('nkr', -1428.8988864870958)]
21565
['sugar'] [('trade', -1052.8036089128584), ('sugar', -1050.2558101564346), ('cpi', -1049.3194605727801), ('copper', -1049.2498159275979), ('coconut', -1049.072314808006), ('nzdlr', -1048.6270617084053), ('strategic-metal', -1048.6194494937463), ('jet', -1048.5794838916397), ('sun-oil', -10

In [14]:
    precision, recall = pp.compute_precision_recall(computed_label_set, number_labels_test, prior_probs.keys())
    print(precision, recall)

0.8704939919893191 0.7139728427507666
