Skip to content
Browse files

fix this travesty

  • Loading branch information...
2 parents a0bcf97 + 466bfe6 commit 024aed2e3d647289ce2f3b444b013ed571898bc2 Daniel Erenrich committed May 31, 2012
View
20 pa4/python-starter/message_features.py
@@ -7,17 +7,19 @@
SUBJECT_TAG = "Subject: "
+
+
MAX_TOKEN_LEN = 20
MIN_WORD_LEN = 3;
-NUM_RE = re.compile(r"([0-9]+)")
-WORD_RE = re.compile(r"([a-zA-Z'\-]+)")
-ALPHANUM_RE = re.compile(r"(\w+)")
-HYPERLINK_RE = re.compile(r"(http\:\/\/(\w+\.)+\w+)")
-EMAIL_RE = re.compile(r"([\w\-\.]+@[\w\-\.]+)")
-DELIMS_RE = re.compile(r"[\s\.()\"',-:;/\\?!@]+")
-
-class MessageFeatures:
-
+NUM_RE = re.compile(r"^([0-9]+)$")
+WORD_RE = re.compile(r"^([a-zA-Z'\-]+)$")
+ALPHANUM_RE = re.compile(r"^(\w+)$")
+HYPERLINK_RE = re.compile(r"^(http\:\/\/(\w+\.)+\w+)$")
+EMAIL_RE = re.compile(r"^([\w\-\.]+@[\w\-\.]+)$")
# Runs of these characters separate tokens.  The hyphen is escaped so that it
# is a literal delimiter: written unescaped between "," and ":" it forms a
# character range that also matches the digits 0-9.
DELIMS_RE = re.compile(r"[\s\.()\"',\-:;/\\?!@]+")
+
+
+class MessageFeatures:
def __init__(self, newsgroupnum, filename, stemmer, stopwords):
self.newsgroupnum = newsgroupnum
self.filename = filename
View
3 pa4/python-starter/message_iterators.py
@@ -1,12 +1,9 @@
-
from __future__ import print_function
-
from cPickle import Unpickler
import sys
class MessageIterator(object):
-
def __init__(self, inp_filename):
self.filename = inp_filename
with open(self.filename, 'rb') as inpfile:
View
128 pa4/python-starter/naive_bayes_classifier.py
@@ -41,6 +41,8 @@ def binomial(mi):
class_words[groupnum] = Counter()
doc = counter_add(m.subject, m.body)
for w in doc:
+ if doc[w] <= 0:
+ continue
if not (groupnum in dictionaries):
dictionaries[groupnum] = set()
dictionaries[groupnum].add(w)
@@ -170,6 +172,116 @@ def multinomial(mi):
output[groupnum] += str(predicted_class)+"\t"
for i in range(len(output)):
print(output[i][:-1])
+
+
+
def twcnb_init(mi, filter_func=None):
    """
    Build the per-class word tallies shared by the twcnb-related functions.

    Every message first has its subject and body merged with counter_add,
    and the merged result is stored back on the message's ``body``.  When
    ``filter_func`` is truthy it is then called on ``mi`` and its return
    value is what gets tallied.  Each word with a positive value in a
    message's body adds one to that word's tally for the message's
    newsgroup.

    Returns a dict mapping newsgroup number -> Counter of word -> tally.
    """
    class_words = {}
    for message in mi:
        class_words.setdefault(message.newsgroupnum, Counter())
        # as a hack just put everything in the message body
        message.body = counter_add(message.subject, message.body)

    if filter_func:
        # should modify mi in place, but whatever
        mi = filter_func(mi)

    for message in mi:
        label = message.newsgroupnum
        for word in message.body:
            if message.body[word] > 0:
                class_words[label][word] += 1
    return class_words
+
def cnb_filter(class_words):
    """
    Turn per-class word counts into complement log weights.

    For every class c and every word i of the whole vocabulary the weight
    is log((1 + count of i outside c) / (alpha + total count outside c)),
    where alpha is the vocabulary size.  Words that never occur in c itself
    still receive a weight, since their complement count is what matters.

    class_words -- dict mapping class -> mapping of word -> count
    Returns a new dict mapping class -> Counter of word -> weight; the
    input is left untouched.
    """
    # Count of every word summed over all classes.  The complement count
    # for a class is this total minus the class's own count, which avoids
    # re-scanning every other class for every word.
    totals = Counter()
    for counts in class_words.values():
        totals.update(counts)
    # we need to know how big the dictionary is
    alpha = len(totals)
    grand_total = sum(totals.values())

    cc_words = dict()
    # LINE 4 PAGE 7
    # NOTE(review): presumably refers to the complement naive Bayes
    # algorithm listing in the paper this assignment follows -- confirm.
    for c, counts in class_words.items():
        denom = float(alpha + grand_total - sum(counts.values()))
        weights = Counter()
        for i, total in totals.items():
            weights[i] = log((1 + total - counts.get(i, 0)) / denom)
        cc_words[c] = weights
    return cc_words
+
def wcnb_filter(class_words):
    """
    Scale each class's weights by the sum of their absolute values.

    class_words -- dict mapping class -> mapping of word -> weight
    The mappings are modified in place and the same dict is returned.
    A class with no words, or whose weights are all zero, is left as is.
    """
    for weights in class_words.values():
        weight_sum = float(sum(abs(value) for value in weights.values()))
        if weight_sum == 0.0:
            # Nothing to scale; dividing would raise ZeroDivisionError.
            continue
        for w in weights:
            weights[w] /= weight_sum
    return class_words
+
def twcnb_filter(mi):
    """
    Apply the two per-message count transforms marked EQUATION 1 and
    EQUATION 2 below.

    mi -- iterable of messages whose ``body`` maps word -> count
          (assumed to be a Counter or dict -- confirm against the caller)
    Every ``body`` is modified in place.  Returns the messages: ``mi``
    itself when it is already a list, otherwise a new list of them.
    """
    # The messages are counted and walked several times, so pin them down
    # once instead of relying on mi supporting len() and repeated iteration.
    messages = mi if isinstance(mi, list) else list(mi)

    # EQUATION 1
    # Dampen raw counts: count -> log(count + 1).  Non-positive counts are
    # treated as absent (0.0) rather than fed to log().
    for m in messages:
        for w in m.body:
            count = m.body[w]
            m.body[w] = log(count + 1) if count > 0 else 0.0

    # EQUATION 2
    # Scale each word by log(number of messages / number of messages that
    # contain the word).
    doc_freq = dict()
    for m in messages:
        for w in m.body:
            if m.body[w] > 0:
                doc_freq[w] = doc_freq.get(w, 0) + 1
    num = float(len(messages))  # float keeps the ratio exact on Python 2
    for m in messages:
        for w in m.body:
            if m.body[w] > 0:
                m.body[w] *= log(num / doc_freq[w])
    return messages
+
+
def advanced_nb(mi, weight_filter_func, message_filter_func=None,
                max_per_class=20):
    """
    Build per-class word weights from mi, classify messages by lowest
    score, and report the results on stderr.

    mi -- iterable of messages exposing newsgroupnum, subject and body; it
        is traversed more than once, so it must be re-iterable
    weight_filter_func -- called with the per-class counts produced by
        twcnb_init; must return a mapping class -> (word -> weight) whose
        inner mapping yields a number for any word (e.g. a Counter)
    message_filter_func -- optional; when given it is applied to mi first
        and its result is used in place of mi
    max_per_class -- only the first max_per_class messages of each
        newsgroup are classified (default 20, the previous fixed limit)

    Nothing is returned.  For every classified message the actual class,
    the predicted class and the running accuracy are printed to stderr, so
    the last accuracy line is the overall figure.
    """
    if message_filter_func is not None:
        mi = message_filter_func(mi)
    class_words = weight_filter_func(twcnb_init(mi))

    class_count = Counter()
    correct = 0
    total = 0
    for m in mi:
        groupnum = m.newsgroupnum
        if class_count[groupnum] >= max_per_class:
            continue
        class_count[groupnum] += 1
        # NOTE(review): twcnb_init stores subject+body back on m.body.  If
        # mi hands out the same message objects on this second pass, the
        # subject is counted twice here -- confirm against MessageIterator.
        doc = counter_add(m.body, m.subject)
        scores = []
        for c in class_words:
            weights = class_words[c]
            score = 0
            for w in doc:
                score += doc[w] * weights[w]
            scores.append((score, c))
        total += 1
        # The class with the lowest score wins; ties fall to the smaller
        # class value through tuple comparison.
        predicted_class = min(scores)[1]
        if predicted_class == groupnum:
            correct += 1
        print("actual class : " + str(groupnum), file=sys.stderr)
        print("predicted class : " + str(predicted_class), file=sys.stderr)
        # total is at least 1 here, so the division is always defined
        print("accuracy : " + str(float(correct) / total), file=sys.stderr)
+
def cnb(mi):
    """Run advanced_nb on mi with cnb_filter as the weight filter."""
    advanced_nb(mi, weight_filter_func=cnb_filter)
def wcnb(mi):
    """
    Run advanced_nb on mi with weights produced by cnb_filter and then
    passed through wcnb_filter.
    """
    def normalized_complement_weights(cw):
        complement_weights = cnb_filter(cw)
        return wcnb_filter(complement_weights)

    advanced_nb(mi, normalized_complement_weights)
+
def twcnb(mi):
    """Placeholder: not implemented yet; ignores mi and returns None."""
    return None
@@ -186,7 +298,9 @@ def output_probability(probs):
'binomial': binomial,
'binomial-chi2': binomial_chi2,
'multinomial': multinomial,
- 'twcnb': twcnb
+ 'twcnb': twcnb,
+ 'cnb' : cnb,
+ 'wcnb' : wcnb
# Add others here if you want
}
@@ -199,12 +313,12 @@ def main():
mi = MessageIterator(train)
- try:
- MODES[mode](mi)
- except KeyError:
- print("Unknown mode: {0}".format(mode),file=sys.stderr)
- print("Accepted modes are: {0}".format(MODES.keys()), file=sys.stderr)
- sys.exit(-1)
+ #try:
+ MODES[mode](mi)
+ #except KeyError:
+ # print("Unknown mode: {0}".format(mode),file=sys.stderr)
+ # print("Accepted modes are: {0}".format(MODES.keys()), file=sys.stderr)
+ # sys.exit(-1)
if __name__ == '__main__':
main()

0 comments on commit 024aed2

Please sign in to comment.
Something went wrong with that request. Please try again.