In [None]:
import logging
from datetime import datetime
from pathlib import Path

from sqlalchemy import create_engine, func
from sqlalchemy.orm import sessionmaker

from tqdm import tqdm

from helpers.logging import OutputWidgetHandler
from libratom.lib.database import db_init, db_session
from libratom.lib.entities import (
    OUTPUT_FILENAME_TEMPLATE,
    count_messages_in_files,
    extract_entities,
    load_spacy_model,
)
from libratom.lib.report import store_file_reports_in_db
from libratom.models.entity import Entity
from libratom.models.file_report import FileReport

In [None]:
# Route records from the root logger to an output widget handler so that the
# logs can be displayed with handler.show_logs().
logger = logging.getLogger()

# Re-running this cell would otherwise stack one more handler on the root
# logger each time and duplicate every log line, so drop any handler of this
# type left over from a previous run first.
for stale_handler in [h for h in logger.handlers if isinstance(h, OutputWidgetHandler)]:
    logger.removeHandler(stale_handler)

handler = OutputWidgetHandler()
handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

In [None]:
def print_orm_object(obj, exclude=None):
    """Print each column of a mapped ORM object as a ``name: value`` line.

    Column names are read from ``obj.__table__.columns`` and the matching
    values are looked up on ``obj`` with ``getattr``.

    Args:
        obj: Object exposing a ``__table__`` whose ``columns`` iterable yields
            items with a ``name`` attribute.
        exclude: Optional collection of column names to leave out. ``None``
            (the default) prints every column.
    """
    excluded = exclude or ()

    for column in obj.__table__.columns:
        if column.name not in excluded:
            print(f'{column.name}: {getattr(obj, column.name)}')

### Location of input PST files

In [None]:
# Edit as appropriate: either a directory, which is searched recursively for
# *.pst files, or the path of a single file to process.
src = Path("RevisedEDRMv1_Complete")

### Location of output database file

In [None]:
# Defaults to the current working directory. When this is a directory, the
# extraction step places the output file inside it, named from
# OUTPUT_FILENAME_TEMPLATE with the source name and the current timestamp.
out = Path.cwd()

### Input variables

In [None]:
# Name of the spaCy model handed to load_spacy_model in the extraction step
spacy_model_name = 'en_core_web_sm'

### Entity extraction

In [None]:
# Value passed as `jobs` to both libratom calls below (presumably the number
# of parallel workers -- confirm against the libratom documentation).
jobs = 4

# Resolve output file based on src parameter.
# NOTE: `out` is rebound from a directory to a file path here, and the query
# section reads it again to open the database.
if out.is_dir():
    out = out / OUTPUT_FILENAME_TEMPLATE.format(
        src.name, datetime.now().isoformat(timespec="seconds")
    )

# DB setup
Session = db_init(out)

# Get list of PST files from the source
if src.is_dir():
    files = set(src.glob("**/*.pst"))
else:
    files = {src}

# Compute and store file information
with tqdm(
    total=len(files),
    desc="Retrieving file information",
    unit="files",
    leave=False,
) as file_bar, db_session(Session) as session:
    store_file_reports_in_db(
        files, session, jobs=jobs, progress_callback=file_bar.update
    )

# Get the total number of messages; `files` is replaced by the collection
# returned alongside the count.
with tqdm(
    total=len(files),
    desc="Retrieving total message count",
    unit="files",
    leave=False,
) as file_bar:
    msg_count, files = count_messages_in_files(
        files, progress_callback=file_bar.update
    )

# Get spaCy model
logger.info(f"Loading spacy model: {spacy_model_name}")
spacy_model = load_spacy_model(spacy_model_name)
if not spacy_model:
    # Explicit error instead of a bare assert: it carries a message and is
    # not stripped when Python runs with optimizations enabled.
    raise RuntimeError(f"Unable to load spaCy model: {spacy_model_name}")

# Get messages and extract entities
if not files:
    logger.warning(f"No PST file found in {src}; nothing to do")
else:
    with tqdm(
        total=msg_count, desc="Processing messages", unit="msg"
    ) as msg_bar, db_session(Session) as session:
        status = extract_entities(
            files=files,
            session=session,
            spacy_model=spacy_model,
            jobs=jobs,
            progress_callback=msg_bar.update,
        )

    # Surface the value returned by extract_entities instead of discarding it
    logger.info(f"Entity extraction finished with status: {status}")


### Post-extraction queries

In [None]:
# Open a standalone session on the SQLite file that `out` points to
engine = create_engine(f"sqlite:///{out}")
session_factory = sessionmaker(bind=engine)
session = session_factory()

##### Total entity count

In [None]:
# Number of Entity rows; shown as the cell's output
session.query(Entity).count()

##### View the first 100 entities

In [None]:
# Columns left out of the per-entity printout
hidden_columns = ['id', 'file_report_id', 'message_id', 'filepath']

# Fetch at most 100 entities, then print each one followed by the name of
# the file report it belongs to
for entity in session.query(Entity).limit(100).all():
    print_orm_object(entity, exclude=hidden_columns)
    print(f'file: {entity.file_report.name}')
    print('---')

##### Entity count by type

In [None]:
# Count entities per label with a single grouped query
label_counts = (
    session.query(Entity.label_, func.count(Entity.label_))
    .group_by(Entity.label_)
    .all()
)

for label, total in label_counts:
    print(f'{label}: {total}')

##### Per-file reports

In [None]:
# For every stored file report: print its columns, then the number of related
# messages and entities and the processing timings
for report in session.query(FileReport).all():
    print_orm_object(report)

    message_total = len(report.messages)
    entity_total = len(report.entities)
    print(f'number of messages: {message_total}')
    print(f'number of entities: {entity_total}')

    print(f'processing start time: {report.processing_start_time}')
    print(f'processing end time: {report.processing_end_time}')
    print(f'processing wall time: {report.processing_wall_time}')
    print('---')

In [None]:
# Close the session used for the post-extraction queries
session.close()

### Log details

In [None]:
# Display the records collected by the handler attached to the root logger
handler.show_logs()