In [None]:
import logging
import gzip
import os
import re
import shutil
import sqlite3
from email import policy
from email.parser import HeaderParser
import numpy as np
from pathlib import Path
from typing import Dict, List

import pandas as pd
from matplotlib import pyplot
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer, ENGLISH_STOP_WORDS
from tqdm.notebook import tqdm

#### Things that won't change

In [None]:
# To be removed from message bodies
# cleanup_message_body() matches this notice literally (via re.escape), so the text
# must stay exactly as it appears in the data
LICENSE = '''EDRM Enron Email Data Set has been produced in EML, PST and NSF format by ZL Technologies, Inc. This Data Set is licensed under a Creative Commons Attribution 3.0 United States License <http://creativecommons.org/licenses/by/3.0/us/> . To provide attribution, please cite to "ZL Technologies, Inc. (http://www.zlti.com)."'''

#### Things that can change

In [None]:
# Study sample size
# read_messages_from_db() stops reading once this many usable messages have been kept
# (it also sets the total of the progress bar)
TOTAL_MESSAGES = 10_000

# Batch size when reading messages from the database
# i.e. the number of rows requested by each cursor.fetchmany() call
MESSAGE_BATCH_SIZE = 100

In [None]:
# Input file for this study (relative path, resolved against the current working directory)
# Generated from PST files, e.g: `ratom report -pvm /tmp/libratom/test_data/RevisedEDRMv1_Complete/ -o edrm_subset.sqlite3`
# A `.gz` suffix marks a gzip-compressed database; it is decompressed in the "Study" section before being read
DB_FILE = Path("data/edrm_subset.sqlite3.gz")

#### Utility functions

In [None]:
def cleanup_message_body(body: str) -> str:
    """
    Removes non content from message body

    Strips the license notice (``LICENSE``) together with the rows of asterisks
    that surround it, then trims leading and trailing whitespace.

    Args:
        body: Raw message body. ``None`` and the empty string are accepted.

    Returns:
        The cleaned body, or an empty string when there is nothing to clean.
    """

    # A missing body (e.g. a NULL database value) has nothing to clean up
    if not body:
        return ""

    # Escaped so that the notice is matched literally, not as a regex
    license_pattern = re.escape(LICENSE)

    # Also remove separators around license
    separator = re.escape('***********')

    # Raw f-string: in a regular string literal "\s" is an invalid escape sequence
    # No re.MULTILINE flag: it only affects ^ and $, which this pattern does not use
    body = re.sub(rf"{separator}\s*{license_pattern}\s*{separator}", "", body)

    return body.strip()

In [None]:
def read_messages_from_db(db_file: str) -> Dict[str, List]:
    """
    Returns a dict of lists to feed to a dataframe

    Reads (headers, body) rows from the `message` table and keeps the messages that
    have a sender, a recipient and a non-empty body after cleanup. Reading stops as
    soon as TOTAL_MESSAGES messages have been kept; when the database holds fewer
    usable messages, all of them are returned.

    Args:
        db_file: Path of the (uncompressed) SQLite database file.

    Returns:
        A dict with the keys 'from', 'to', 'subject' and 'body'. Each key maps to a
        list holding one entry per kept message, so all lists have the same length.
    """

    header_parser = HeaderParser(policy=policy.default)

    # The keys here will become dataframe column names
    # Use bracket notation with 'from' and other python reserved keywords, e.g: df['from']
    messages = {
        'from': [],
        'to': [],
        'subject': [],
        'body': [],
    }

    # Using a sqlite3 connection as a context manager only manages a transaction,
    # it does not close the connection: close it explicitly instead
    conn = sqlite3.connect(db_file)
    try:
        cursor = conn.cursor()

        # Our base DB query
        cursor.execute("select headers, body from message")

        with tqdm(desc='messages', unit='msg', total=TOTAL_MESSAGES) as pbar:

            # Fetch batches of messages (fetchmany returns [] once the rows are exhausted)
            for batch in iter(lambda: cursor.fetchmany(MESSAGE_BATCH_SIZE), []):

                # Iterate over messages until we have the right amount of useful messages
                for raw_headers, body in batch:
                    try:
                        # Parse the headers
                        headers = header_parser.parsestr(raw_headers)
                        sender = headers['from'] or headers['sender'] or headers['return-path'] or headers['reply-to']
                        recipient = headers['to']

                        # Skip this message if we don't have both a sender and a recipient
                        if not (sender and recipient):
                            continue

                        # Sanitize the body
                        body = cleanup_message_body(body)

                        # Skip message if body is empty
                        if not body:
                            continue

                        # Add message contents to partial results
                        messages['from'].append(sender)
                        messages['to'].append(recipient)
                        messages['subject'].append(headers['subject'] or '')  # Possibly blank
                        messages['body'].append(body)

                        # Update progress
                        pbar.update()

                        # Have we reached our sample size?
                        if pbar.n >= pbar.total:
                            return messages

                    except Exception as exc:
                        # Best effort: report the message that could not be processed and move on
                        print(exc)
    finally:
        conn.close()

    # The database ran out of rows before the sample size was reached:
    # return what was collected rather than falling off the end (which returns None)
    return messages
                    

In [None]:
def top_tfidf_feats(row, features, top_n=20):
    """
    Returns the `top_n` best scoring features of one row of tf-idf scores

    `features[i]` is the name that goes with the score `row[i]`. The result is a
    dataframe with the columns 'features' and 'score', highest score first.
    """
    # Ascending sort, reversed, then cut down to the requested number of entries
    ranked_ids = np.argsort(row)[::-1][:top_n]

    pairs = []
    for feat_id in ranked_ids:
        pairs.append((features[feat_id], row[feat_id]))

    return pd.DataFrame(pairs, columns=['features', 'score'])

def top_feats_in_doc(X, features, row_id, top_n=25):
    """
    Returns the `top_n` best scoring features of a single document

    `X` is the document-term matrix (its rows must expose `.toarray()`) and
    `row_id` selects the document's row in it.
    """
    doc_scores = X[row_id].toarray()

    # Drop the length-1 row axis so that scores are addressed by feature index only
    return top_tfidf_feats(np.squeeze(doc_scores), features, top_n)

def top_mean_feats(X, features, grp_ids=None, min_tfidf=0.1, top_n=25):
    """
    Returns the `top_n` features with the highest mean tf-idf score over a group of documents

    Args:
        X: Document-term matrix that supports row indexing and `.toarray()`.
        features: Feature names, indexed by column position in `X`.
        grp_ids: Row indices of the documents to average over (a list, a NumPy
            array or the tuple returned by `np.where`). Only `None` selects every
            document; an empty index selects no documents.
        min_tfidf: Scores below this threshold count as 0 in the mean.
        top_n: Number of features to return.

    Returns:
        A dataframe with the columns 'features' and 'score', best feature first.
    """
    # Test against None explicitly rather than relying on truthiness: the truth value
    # of a multi-element NumPy array raises ValueError, and array([0]) is falsy even
    # though it selects row 0
    if grp_ids is None:
        D = X.toarray()
    else:
        D = X[grp_ids].toarray()

    # D is a dense array produced by toarray(), so X itself is left untouched
    D[D < min_tfidf] = 0
    tfidf_means = np.mean(D, axis=0)
    return top_tfidf_feats(tfidf_means, features, top_n)




def top_feats_per_cluster(X, y, features, min_tfidf=0.1, top_n=25):
    """
    Returns one top-features dataframe per cluster

    `y` holds the cluster label of each row of `X`. Every returned dataframe comes
    from `top_mean_feats` and carries its cluster in a `label` attribute.
    """
    per_cluster = []

    for cluster in np.unique(y):
        # Rows of X that belong to this cluster
        members = np.where(y == cluster)
        cluster_df = top_mean_feats(X, features, members, min_tfidf=min_tfidf, top_n=top_n)

        # Plain attribute (not a column), read back by plot_tfidf_classfeats_h
        cluster_df.label = cluster
        per_cluster.append(cluster_df)

    return per_cluster


def plot_tfidf_classfeats_h(dfs):
    """
    Plots the top features of each cluster as side-by-side horizontal bar charts

    Args:
        dfs: List of dataframes, one per cluster, each with the columns 'features'
            and 'score' and a `label` attribute (as built by `top_feats_per_cluster`).
            An empty list plots nothing.
    """
    # Nothing to plot (also avoids indexing into an empty list)
    if not dfs:
        return

    fig = pyplot.figure(figsize=(12, 9), facecolor="w")
    for i, df in enumerate(dfs):
        # Bar positions are derived from this dataframe, so the charts do not have
        # to share the row count of the first one
        y_pos = np.arange(len(df))

        ax = fig.add_subplot(1, len(dfs), i+1)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_frame_on(False)
        ax.get_xaxis().tick_bottom()
        ax.get_yaxis().tick_left()
        ax.set_xlabel("Tf-Idf Score", labelpad=16, fontsize=14)
        ax.set_title("cluster = " + str(df.label), fontsize=16)
        ax.ticklabel_format(axis='x', style='sci', scilimits=(-2,2))
        ax.barh(y_pos, df.score, align='center', color='#7530FF')
        ax.set_yticks(y_pos)
        # One unit of padding below the first bar and above the last one
        ax.set_ylim([-1, len(df)])
        ax.set_yticklabels(df.features)

    # The layout applies to the whole figure: set it once, after all subplots exist
    fig.subplots_adjust(bottom=0.09, right=0.97, left=0.15, top=0.95, wspace=0.52)
    pyplot.show()

#### Study

In [None]:
# Decompress DB file
# A gzipped database is expanded next to the original, under the same name minus `.gz`
if DB_FILE.suffix != '.gz':
    # Not compressed: use the file as is
    unzipped_db_file = str(DB_FILE)
else:
    out_path = DB_FILE.with_suffix('')

    with gzip.open(str(DB_FILE), 'rb') as f_in:
        with open(out_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)

    unzipped_db_file = str(out_path)

In [None]:
# Extract messages from DB file
# Reads at most TOTAL_MESSAGES usable messages; a progress bar is shown while reading
messages = read_messages_from_db(unzipped_db_file)

In [None]:
# Load messages into dataframe
# Each key of `messages` ('from', 'to', 'subject', 'body') becomes a column
df = pd.DataFrame(messages)
# Bare expression on the last line so that the notebook displays the dataframe
df

In [None]:
# Quick visual check on the message bodies
# Only the first 100 are shown, each one below an 80 character ruler
for body in df['body'].head(100):
    print('=' * 80, body, sep='\n')

In [None]:
# Here max_df is 50% of total documents, min_df is 2 documents
vect = TfidfVectorizer(stop_words='english', max_df=0.50, min_df=2)

# Get document-term matrix
X = vect.fit_transform(df.body)

# Number of unique terms
# print(len(vect.vocabulary_))

# Visualize document-term matrix
# toarray() returns a plain ndarray; todense() returns an np.matrix, which recent
# scikit-learn releases refuse as input
# Note that this holds the whole matrix in memory in dense form
X_dense = X.toarray()
coords = PCA(n_components=2).fit_transform(X_dense)
pyplot.scatter(coords[:, 0], coords[:, 1], c='m')
pyplot.show()

In [None]:
# Pick a message
# Used below both as a row of X and as an index into df.body, so it must be
# smaller than the number of messages that were read
msg_index = 100

In [None]:
# Feature (term) names, in the column order of X
# get_feature_names() is gone from recent scikit-learn releases, replaced by
# get_feature_names_out(); keep the old call as a fallback for older installs
if hasattr(vect, 'get_feature_names_out'):
    # Converted to a list to keep the type the old method returned
    features = list(vect.get_feature_names_out())
else:
    features = vect.get_feature_names()
print(top_feats_in_doc(X, features, msg_index, 25))

In [None]:
print(df.body[msg_index])

In [None]:
print(top_mean_feats(X, features, None, 0.1, 25))

In [None]:
n_clusters = 3
clf = KMeans(n_clusters=n_clusters, 
            max_iter=100, 
            init='k-means++', 
            n_init=1)
labels = clf.fit_predict(X)

In [None]:
X_dense = X.todense()
pca = PCA(n_components=2).fit(X_dense)
coords = pca.transform(X_dense)

In [None]:
# Lets plot it again, but this time we add some color to it.
# This array needs to be at least the length of the n_clusters.
label_colors = ["#2AB0E9", "#2BAF74", "#D7665E", "#CCCCCC", 
                "#D2CA0D", "#522A64", "#A3DB05", "#FC6514"]
colors = [label_colors[i] for i in labels]

pyplot.scatter(coords[:, 0], coords[:, 1], c=colors)

# Plot the cluster centers
centroids = clf.cluster_centers_
centroid_coords = pca.transform(centroids)
pyplot.scatter(centroid_coords[:, 0], centroid_coords[:, 1], marker='X', s=200, linewidths=2, c='#444d60')


pyplot.show()

In [None]:
plot_tfidf_classfeats_h(top_feats_per_cluster(X, labels, features, 0.1, 25))