# Topic Modeling

## Load data

In [205]:
import os
import re

In [206]:
data_dir = 'emails'

In [207]:
messages = []

for file in os.listdir(data_dir):
    file_path = os.path.join(data_dir, file)
    with open(file_path, encoding='latin-1') as f:
        messages.append(f.read())

In [208]:
len(messages)

11313

In [209]:
messages[0]

'From: spl2@po.cwru.edu (Sam Lubchansky)\nSubject: Re: Joe Robbie Stadium "NOT FOR BASEBALL"\nArticle-I.D.: po.spl2.114.734131045\nOrganization: Case Western Reserve University\nLines: 27\nNNTP-Posting-Host: b61644.student.cwru.edu\n\nIn article <1993Apr6.025027.4846@oswego.Oswego.EDU> iacs3650@Oswego.EDU (Kevin Mundstock) writes:\n>From: iacs3650@Oswego.EDU (Kevin Mundstock)\n>Subject: Joe Robbie Stadium "NOT FOR BASEBALL"\n>Date: 6 Apr 93 02:50:27 GMT\n>Did anyone notice the words "NOT FOR BASEBALL" printed on the picture\n>of Joe Robbie Stadium in the Opening Day season preview section in USA\n>Today? Any reason given for this?\n>\n\nI would assume that the words (I saw the picture) indicated that those \nSEATS will not be available for baseball games.  If you look at the picture \nof the diamond in the stadium, in relation to the areas marked "NOT FOR \nBASEBALL", those seats just look terrible for watching baseball.   Now, if \nthey should happen to reach the post-season, I would 

In [210]:
def strip_newsgroup_header(text):
    _before, _blankline, after = text.partition('\n\n')
    return after


_QUOTE_RE = re.compile(r'(writes in|writes:|wrote:|says:|said:'
                       r'|^In article|^Quoted from|^\||^>)')


def strip_newsgroup_quoting(text):
    good_lines = [line for line in text.split('\n')
                  if not _QUOTE_RE.search(line)]
    return '\n'.join(good_lines)


def strip_newsgroup_footer(text):
    lines = text.strip().split('\n')
    for line_num in range(len(lines) - 1, -1, -1):
        line = lines[line_num]
        if line.strip().strip('-') == '':
            break

    if line_num > 0:
        return '\n'.join(lines[:line_num])
    else:
        return text

In [211]:
def clean_email(text):
    text = strip_newsgroup_header(text)
    text = strip_newsgroup_quoting(text)
    return strip_newsgroup_footer(text)

In [212]:
cleaned_messages = [clean_email(text) for text in messages]

In [263]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.pipeline import Pipeline

In [264]:
from nltk import word_tokenize
from nltk.stem import WordNetLemmatizer

class LemmaTokenizer:
    def __init__(self):
        self.wnl = WordNetLemmatizer()
    def __call__(self, doc):
        return [self.wnl.lemmatize(t) for t in word_tokenize(doc)]

In [265]:
num_components = 5

In [266]:
lda_pipeline = Pipeline([
    ('vect', CountVectorizer(stop_words='english', tokenizer=LemmaTokenizer())),
    ('lda', LatentDirichletAllocation(n_components=num_components, random_state=0, max_iter=10))
])

In [None]:
lda_pipeline = lda_pipeline.fit(cleaned_messages)

  'stop_words.' % sorted(inconsistent))


In [252]:
lda_pipeline.transform([cleaned_messages[0]])

array([[0.13028119, 0.00304363, 0.00319958, 0.00306087, 0.86041472]])

Test out on already categorized data.

In [253]:
train_dir = os.path.join('20news-bydate', '20news-bydate-train')

In [254]:
topic_groups = [
    ['comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x'],
    ['rec.sport.baseball', 'rec.sport.hockey'],
    ['rec.autos', 'rec.motorcycles'],
    ['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'],
    ['alt.atheism', 'soc.religion.christian', 'talk.religion.misc'],
    ['talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc']
]

In [255]:
for topics in topic_groups:
    pred_topics = [0] * num_components
    
    for topic in topics:
        topic_dir = os.path.join(train_dir, topic)

        for file in os.listdir(topic_dir):
            file_name = os.path.join(topic_dir, file)

            with open(file_name, encoding='latin-1') as f:
                text = f.read()

            topic_dist = lda_pipeline.transform([clean_email(text)])[0]
            for idx, prob in enumerate(topic_dist):
                if prob > 0.5:   
                    pred_topics[idx] += 1
    
    print(topics, pred_topics)

['comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x'] [2471, 26, 15, 17, 295]
['rec.sport.baseball', 'rec.sport.hockey'] [321, 26, 4, 2, 786]
['rec.autos', 'rec.motorcycles'] [265, 6, 5, 1, 838]
['sci.crypt', 'sci.electronics', 'sci.med', 'sci.space'] [595, 10, 6, 8, 1662]
['alt.atheism', 'soc.religion.christian', 'talk.religion.misc'] [30, 9, 3, 1, 1396]
['talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc'] [46, 7, 5, 3, 1489]
