In [None]:
import logging
from pathlib import Path
from contextlib import contextmanager

from collections import defaultdict

import multiprocessing

from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

import ipywidgets as widgets
from ipywidgets import Layout, Box
from IPython.display import display

import humanfriendly
import pandas as pd

from libratom.utils.pff import PffArchive

### Set up spaCy

In [None]:
import spacy
# Load the "en_core_web_sm" pipeline once at module level; process_message()
# calls `nlp` on every message body.
nlp = spacy.load("en_core_web_sm")

### Log settings

In [None]:
# https://ipywidgets.readthedocs.io/en/stable/examples/Output%20Widget.html#Integrating-output-widgets-with-the-logging-module
class OutputWidgetHandler(logging.Handler):
    """Logging handler that collects formatted records in an ``Output`` widget.

    The widget is only rendered when ``show_logs`` is called, so records can be
    accumulated silently while a long-running cell executes.
    """

    def __init__(self, *args, **kwargs):
        """Create the handler and the widget that will hold its records."""
        super().__init__(*args, **kwargs)
        widget_layout = dict(display='flex', border='1px solid lightgray')
        self.out = widgets.Output(layout=widget_layout)

    def emit(self, record):
        """Overload of logging.Handler.emit: store the record as a stdout stream entry."""
        formatted = self.format(record)
        stream_entry = dict(
            name='stdout',
            output_type='stream',
            text=f'{formatted}\n',
        )
        # Prepend, so the most recent record ends up first in the widget
        self.out.outputs = (stream_entry, ) + self.out.outputs

    def show_logs(self):
        """Display the widget holding the records collected so far."""
        display(self.out)

    def clear_logs(self):
        """Discard the records currently held by the widget."""
        self.out.clear_output()

# Handler that keeps formatted records in an output widget (see handler.show_logs())
handler = OutputWidgetHandler()
handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))

# Notebook-wide logger: INFO and above are forwarded to the widget handler
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

### Location of PST files

In [None]:
# Edit as appropriate
# Directory that is searched recursively for *.pst files when FILES is built.
CACHED_ENRON_DATA_DIR = Path("/tmp/libratom/test_data/RevisedEDRMv1_Complete/andy_zipper")

### DB file

In [None]:
# SQLite file receiving the extracted entities. Any existing file at this path
# is deleted by the database setup cell before the engine is created.
DB_FILE = Path("/tmp/libratom/ner.db")

### Session variables

In [None]:
# Whether to update report widgets as we extract entities
SHOW_PROGRESS = False
SHOW_ENTITIES = False

# Generate the list of files to know how many there are.
# Sorted because glob() does not guarantee an order: a stable order keeps the
# processing sequence (and therefore the entity row ids) identical between runs.
FILES = sorted(CACHED_ENRON_DATA_DIR.glob('**/*.pst'))

### Database setup

In [None]:
# Make sure the directory that will hold the database exists; otherwise the
# table creation further down fails because SQLite cannot open the file.
DB_FILE.parent.mkdir(parents=True, exist_ok=True)

# Remove previous DB file so every run starts from an empty database
try:
    DB_FILE.unlink()
    logger.info(f'Removed existing database file: {DB_FILE}')
except FileNotFoundError:
    # Nothing to remove on a first run
    pass

In [None]:
# SQLite engine backed by DB_FILE
engine = create_engine(f'sqlite:///{DB_FILE}')
# Session factory bound to that engine; open_db_session() instantiates it
Session = sessionmaker(bind=engine)

In [None]:
# Declarative base class that the ORM models below derive from
Base = declarative_base()

class Entity(Base):
    """One named entity extracted from a message, stored in the 'entities' table."""

    __tablename__ = 'entities'

    # Integer primary key
    id = Column(Integer, primary_key=True)
    # The three columns below mirror the keys of the dicts built by process_message()
    text = Column(String)
    label_ = Column(String)
    filename = Column(String)

# Create the tables registered on Base in the database behind `engine`
Base.metadata.create_all(engine)

In [None]:
# Display the table definition generated for the Entity model
Entity.__table__

### Rendering

In [None]:
# Layouts
# Settings common to both boxes
_shared_box_style = dict(
    width='50%',
    margin='0px 0px 4px 0px',
    border='1px solid lightblue',
)

# Box holding the report table and the progress bar, stacked and centred
report_box_layout = Layout(
    display='flex',
    flex_flow='column nowrap',
    justify_content='center',
    align_items='center',
    **_shared_box_style
)

# Fixed-height box holding the entities sample
entities_box_layout = Layout(height='16em', **_shared_box_style)

### Utility functions

##### Widget update functions

In [None]:
def update_report(out, data):
    """Redraw the report table inside the given output widget.

    Args:
        out: output widget to draw into; its previous content is cleared.
        data: mapping of counter name to value. The 'Size' entry is shown
            as a human-readable size instead of a raw number.
    """

    out.clear_output(wait=True)

    # One single-row column per counter
    table = {}
    for column, total in data.items():
        table[column] = [total]
    table['Size'] = [humanfriendly.format_size(data['Size'])]

    with out:
        display(pd.DataFrame(table, index=['Total']))


def update_entities(out, data):
    """Redraw the entities sample inside the given output widget.

    Args:
        out: output widget to draw into; its previous content is cleared.
        data: iterable of dicts whose values are strings; each dict is
            printed on its own line with its values separated by spaces.
    """

    out.clear_output(wait=True)

    with out:
        print('Sample of entities found', '------------------------', sep='\n')

        # Lazily joined, so rows are printed one by one
        for row in (' '.join(entity.values()) for entity in data):
            print(row)

##### Message generator

In [None]:
def get_messages(files, report):
    """Yield ``(file name, formatted message)`` pairs for every message in the given files.

    Each file is opened with ``PffArchive``. Failures are logged through the
    notebook logger and skipped: a message that cannot be formatted does not
    stop its file, and a file that cannot be read does not stop the run.

    Args:
        files: iterable of paths to the archives to read.
        report: mapping of counters updated in place ('Messages', 'Files', 'Size').

    Yields:
        Tuples of the file's name and the output of ``archive.format_message``.

    Side effects:
        Increments the module-level ``progress`` widget once per file,
        whether the file was read successfully or not.
    """
    # Iterate over files
    for pst_file in files:
        try:
            with PffArchive(pst_file) as archive:
                # Iterate over messages
                for message in archive.messages():
                    try:

                        yield pst_file.name, archive.format_message(message)

                        # Update report per message (runs when the consumer asks for the next item)
                        report['Messages'] += 1

                    except Exception as exc:
                        # Log and move on to the next message
                        logger.exception(exc)

            # Update report per file; only fully read files are counted
            report['Files'] += 1
            report['Size'] += pst_file.stat().st_size

        except Exception as exc:
            # Log and move on to the next file
            logger.exception(exc)

        # Update progress bar. Done outside the try block so that a failed file
        # still counts as processed and the bar can reach its maximum.
        progress.value += 1

##### Job function for the worker processes

In [None]:
def process_message(filename: str, message: str):
    """Extract the named entities of one message.

    Runs in the worker processes, so only basic types are returned to avoid
    serialization issues.

    Args:
        filename: name of the file the message comes from; copied into every entity.
        message: text passed to the ``nlp`` pipeline.

    Returns:
        ``(entities, None)`` on success, where ``entities`` is a list of dicts
        with the keys 'text', 'label_' and 'filename'; ``(None, error message)``
        if anything raised.
    """
    try:
        found = []
        for ent in nlp(message).ents:
            found.append({'text': ent.text, 'label_': ent.label_, 'filename': filename})
    except Exception as error:
        return None, str(error)

    return found, None

##### DB session context manager

In [None]:
@contextmanager
def open_db_session():
    """Provide a database session with commit-or-rollback semantics.

    The session is committed when the ``with`` body finishes normally, rolled
    back (and the error re-raised) when it raises, and closed in every case.
    """
    db_session = Session()
    try:
        try:
            yield db_session
            db_session.commit()
        except BaseException:
            # Undo the pending work, then let the caller see the original error
            db_session.rollback()
            raise
    finally:
        db_session.close()

### Initialize progress and report widgets

In [None]:
# Containers for the report and for the entities sample
report_out, ents_out = widgets.Output(), widgets.Output()

# Progress bar for number of files processed: one step per file in FILES
progress = widgets.IntProgress(
    description='Completed:',
    value=0, min=0, max=len(FILES), step=1,
    bar_style='', orientation='horizontal',
)

### Extract entities per message

In [None]:
%%time

# handler.clear_logs()

# Overall report; a defaultdict lets counters be incremented before they are set
report = defaultdict(int)

# Start displaying results
display(Box(children=[report_out, progress], layout=report_box_layout))

# Entities sample widget
if SHOW_ENTITIES:
    display(Box(children=[ents_out], layout=entities_box_layout))

# Can't pickle lambdas
def job(args):
    """Unpack one (filename, message) tuple from get_messages() for process_message()."""
    return process_message(*args)

if __name__ == '__main__':
    # The session is committed once, when the whole loop has finished without raising
    with multiprocessing.Pool() as pool, open_db_session() as session:
        for entities, exc in pool.imap(job, get_messages(FILES, report), chunksize=100):
            if entities is None:
                # process_message() returns (None, error message) on failure. Test
                # `entities` rather than `exc`: the message may be an empty string,
                # and iterating over None would abort the whole run.
                report['Errors'] += 1
                logger.error(exc)
            else:
                session.add_all(Entity(**entity) for entity in entities)
                report['Entities'] += len(entities)

                # Update entities sample
                if SHOW_ENTITIES:
                    update_entities(ents_out, entities[:10])

            # Refresh report widgets
            if SHOW_PROGRESS:
                update_report(report_out, report)

# Final report
if not SHOW_PROGRESS:
    update_report(report_out, report)

In [None]:
# Print out errors, if any
# Displays the output widget in which `handler` collected the records sent to `logger`.
handler.show_logs()

In [None]:
# Quick check: number of entities stored in the database.
# Uses its own managed session instead of the one left over from the extraction
# cell, which has already been closed by its context manager.
with open_db_session() as session:
    entity_count = session.query(Entity).count()

entity_count