In [None]:
import logging
import tempfile
from datetime import datetime
from pathlib import Path

from sqlalchemy import create_engine, func
from sqlalchemy.orm import sessionmaker

from tqdm import tqdm

from helpers.logging import OutputWidgetHandler
from libratom.cli.subcommands import entities
from libratom.lib.database import db_session
from libratom.lib.entities import (
    OUTPUT_FILENAME_TEMPLATE,
    count_messages_in_files,
    extract_entities,
    load_spacy_model,
)
from libratom.models.entity import Entity

In [None]:
# Send this notebook's log records (logger named after __name__) to an output widget handler.
logger = logging.getLogger(__name__)

# Drop widget handlers attached by a previous run of this cell; otherwise every
# re-run adds one more handler and each log line is emitted multiple times.
for stale_handler in [h for h in logger.handlers if isinstance(h, OutputWidgetHandler)]:
    logger.removeHandler(stale_handler)
    stale_handler.close()

handler = OutputWidgetHandler()
handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
# NOTE(review): if OutputWidgetHandler follows the ipywidgets "log to an Output widget"
# recipe, its widget must be displayed (e.g. via a show_logs() method) before these
# records become visible — confirm against helpers.logging.

### Location of input PST files

In [None]:
# Edit as appropriate.
# `src` may be a directory (searched recursively for *.pst files) or a single PST file.
src = Path("..") / "RevisedEDRMv1_Complete"

### Directory for the output database file

In [None]:
# Directory that will receive the output database; the extraction cell generates
# a timestamped file name inside it. Use the platform temp directory instead of a
# hard-coded "/tmp" so the notebook also runs on systems without /tmp (e.g. Windows).
out = Path(tempfile.gettempdir()) / "libratom"
out.mkdir(parents=True, exist_ok=True)

### Input variables

In [None]:
spacy_model_name = "en_core_web_sm"  # name handed to load_spacy_model in the extraction cell
concurrent_jobs = 4  # passed to extract_entities as its `jobs` argument

In [None]:
# Resolve the output database file. When `out` is a directory, a file named after
# the source and the current time is placed inside it.
# NOTE(review): this rebinds `out` to the file path, so re-running this cell reuses
# the same database file. Re-run the output-location cell first to get a fresh file.
if out.is_dir():
    out = out / OUTPUT_FILENAME_TEMPLATE.format(
        src.name, datetime.now().isoformat(timespec="seconds")
    )

# Make sure the database file's parent directory exists
out.parent.mkdir(parents=True, exist_ok=True)

# Collect PST files: search a directory recursively, or take a single file as-is
if src.is_dir():
    files = set(src.glob("**/*.pst"))
else:
    files = {src}

# Count messages up front so the processing progress bar has a total.
# `files` is rebound to the set returned by count_messages_in_files
# (presumably only the files it could read — TODO confirm).
with tqdm(
    total=len(files),
    desc="Initial file scan",
    unit="files",
    leave=False,
) as file_bar:
    msg_count, files = count_messages_in_files(
        files, progress_callback=file_bar.update
    )

if not files:
    logger.warning(f"No PST file found in {src}; nothing to do")
else:
    # Load the spaCy model only when there is something to process
    logger.info(f"Loading spacy model: {spacy_model_name}")
    spacy_model = load_spacy_model(spacy_model_name)
    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not spacy_model:
        raise RuntimeError(f"Unable to load spacy model: {spacy_model_name}")

    # Extract entities from every message into the database at `out`
    with tqdm(
        total=msg_count, desc="Processing messages", unit="msg"
    ) as msg_bar:
        status = extract_entities(
            files=files,
            destination=out,
            spacy_model=spacy_model,
            jobs=concurrent_jobs,
            progress_callback=msg_bar.update,
        )

    # NOTE(review): assumes the CLI exit-code convention (0 = success) for `status` — confirm
    if status:
        logger.warning(f"Entity extraction finished with status {status}")
    else:
        logger.info(f"Entities written to {out}")


### Post-extraction queries

In [None]:
# Fail fast with a clear message if the extraction cell has not produced a database
# file. Otherwise SQLite would create an empty file at `out` and the queries below
# would fail with a confusing "no such table" error.
if not out.is_file():
    raise FileNotFoundError(
        f"Entity database not found at {out}; run the extraction cell first"
    )

# ORM session bound to the SQLite database written by the extraction step
engine = create_engine(f"sqlite:///{out}")
session = sessionmaker(bind=engine)()

##### Total entity count

In [None]:
# Total number of Entity rows stored in the database
entity_total = session.query(Entity).count()
entity_total

##### View the first 10 entities

In [None]:
# Print up to ten entities. There is no ORDER BY, so the database decides which rows come back.
for ent in session.query(Entity).limit(10):
    print(ent)

##### Entity count by type

In [None]:
# Number of entities per label, most frequent first (ties broken by label).
# An explicit ORDER BY keeps the printed order reproducible across runs.
# count(*) is used instead of count(label_) so a NULL-label group reports its
# real size rather than 0.
entity_count = func.count().label("entity_count")
results = (
    session.query(Entity.label_, entity_count)
    .group_by(Entity.label_)
    .order_by(entity_count.desc(), Entity.label_)
    .all()
)

for entity_type, count in results:
    print(f'{entity_type}: {count}')

In [None]:
# Release the ORM session, then dispose of the engine so any connections it
# pooled for the SQLite file are closed as well.
session.close()
engine.dispose()