# Pre-process data for reference classification

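This script's pipeline is as follows:

1. Recover the sentences that mention the referenced certificate ID for every direct reference in the CC dataset.
2. Create a DataFrame `(dgst, referenced_cert_id, source, label, sentences)` out of the recovered records.
3. Attach the manual labels and dataset splits, drop references without a matching sentence as well as test-split samples, merge the sentences found for each reference, and dump the result into a CSV file.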

In [1]:
from __future__ import annotations

from dataclasses import dataclass
from sec_certs.sample import CCCertificate
from sec_certs.dataset import CCDataset
import spacy
from sec_certs.utils.parallel_processing import process_parallel
import pandas as pd
import json

# spaCy English pipeline; used by ReferenceRecord to split certificate texts into sentences.
nlp = spacy.load("en_core_web_sm")
from pathlib import Path

# Repository root. The relative path is resolved against the current working directory, so this assumes the
# notebook is executed three directory levels below the repository root — TODO confirm.
REPO_ROOT = Path("../../../").resolve()


@dataclass
class ReferenceRecord:
    """
    Intermediate object to hold references for a given certificate together with sensible attributes to be extracted
    for labeling.

    Attributes:
        certificate: Certificate whose document contains the reference. It is dereferenced unconditionally when
            sentences are extracted or the record is converted to a tuple.
        referenced_cert_id: Certificate ID referenced from `certificate`.
        source: `"target"` selects the text at `certificate.state.st_txt_path`; any other value selects
            `certificate.state.report_txt_path`.
        label: Manually assigned label of the reference, `None` when not set.
        sentences: Sentences mentioning `referenced_cert_id`; `None` until extracted or when nothing matched.
    """

    certificate: CCCertificate | None
    referenced_cert_id: str
    source: str
    label: str | None = None
    sentences: set[str] | None = None

    @staticmethod
    def get_reference_sentences(doc, referenced_cert_id: str) -> set[str]:
        """
        Return the texts of all sentences of `doc` that contain `referenced_cert_id` as a substring.

        :param doc: Object exposing `.sents`, whose items have a `.text` attribute (e.g. a spaCy `Doc`).
        :param referenced_cert_id: Certificate ID to look for.
        :return: Set of matching sentence texts; empty when nothing matches.
        """
        return {sent.text for sent in doc.sents if referenced_cert_id in sent.text}

    @staticmethod
    def get_cert_references_with_sentences(record: ReferenceRecord) -> ReferenceRecord:
        """
        Fill `record.sentences` with the sentences that mention `record.referenced_cert_id`.

        The text is read from `state.st_txt_path` when `record.source == "target"`, otherwise from
        `state.report_txt_path`. It is split into sentences with the module-level spaCy pipeline `nlp`.
        `record.sentences` is set to `None` when no sentence matches. The record is modified in place and returned.
        """
        pth_to_read = (
            record.certificate.state.st_txt_path
            if record.source == "target"
            else record.certificate.state.report_txt_path
        )

        # Decode explicitly as UTF-8 instead of relying on the platform's locale-dependent default encoding.
        with pth_to_read.open("r", encoding="utf-8") as handle:
            data = handle.read()

        result = ReferenceRecord.get_reference_sentences(nlp(data), record.referenced_cert_id)
        record.sentences = result if result else None

        return record

    def to_pandas_tuple(self) -> tuple[str, str, str, str | None, set[str] | None]:
        """Return `(dgst, referenced_cert_id, source, label, sentences)` for building a DataFrame row."""
        return self.certificate.dgst, self.referenced_cert_id, self.source, self.label, self.sentences


def get_df_from_records(records: list[ReferenceRecord]) -> pd.DataFrame:
    """
    Build a DataFrame with columns `[dgst, referenced_cert_id, source, label, sentences]` from reference records.

    Sentences are extracted for every record by `ReferenceRecord.get_cert_references_with_sentences`, run through
    `process_parallel` with `use_threading=False`. `label` and `sentences` stay `None` where they are not defined.

    :param records: Reference records to process.
    :return: One row per record, in the order returned by `process_parallel`.
    """
    columns = ["dgst", "referenced_cert_id", "source", "label", "sentences"]
    if not records:
        # Nothing to extract: skip `process_parallel` entirely and return an empty frame with the same schema.
        return pd.DataFrame.from_records([], columns=columns)

    results = process_parallel(
        ReferenceRecord.get_cert_references_with_sentences, records, use_threading=False, progress_bar=True
    )
    return pd.DataFrame.from_records([rec.to_pandas_tuple() for rec in results], columns=columns)


def preprocess_segment(segment):
    """Flatten a text segment onto one line, turning every newline character into a single space."""
    lines = segment.split("\n")
    return " ".join(lines)


def get_split_dict(
    train_path: Path | None = None, valid_path: Path | None = None, test_path: Path | None = None
) -> dict[str, str]:
    """
    Return a dictionary mapping `dgst -> split`, where split is one of `train`, `valid`, `test`.

    Each path points to a JSON file holding a list of certificate digests belonging to that split; splits whose
    path is `None` are skipped. If a digest is listed in several files, the later split wins
    (test over valid over train).
    """

    def get_single_dct(pth: Path | None, split_name: str) -> dict[str, str]:
        if not pth:
            return {}
        # JSON text is UTF-8 by specification; don't depend on the platform's default encoding.
        with pth.open("r", encoding="utf-8") as handle:
            return dict.fromkeys(json.load(handle), split_name)

    return {
        **get_single_dct(train_path, "train"),
        **get_single_dct(valid_path, "valid"),
        **get_single_dct(test_path, "test"),
    }

def load_annotated_samples(
    train_path: Path | None = None, valid_path: Path | None = None, test_path: Path | None = None
) -> pd.DataFrame:
    """
    Load manually annotated references from CSV files and concatenate them.

    Labels are normalized by replacing spaces with underscores and upper-casing them. Rows whose label is missing
    or equals `NONE` after normalization are dropped. Paths that are `None` are skipped.

    :return: DataFrame with columns `[dgst, referenced_cert_id, source, label, comment]`; empty (with those
        columns) when no path is given.
    """
    columns = ["dgst", "referenced_cert_id", "source", "label", "comment"]

    def load_single_df(pth: Path) -> pd.DataFrame:
        df_ = pd.read_csv(pth)
        labels = df_["label"].str.replace(" ", "_").str.upper()
        # Turn the explicit "NONE" label into NaN on the label column only, so that `dropna` discards it. This avoids
        # `DataFrame.replace("NONE", None)`, which touched every column and whose `value=None` meaning is
        # pandas-version dependent.
        df_ = df_.assign(label=labels.mask(labels == "NONE"))
        return df_.dropna(subset=["label"])

    frames = [load_single_df(pth) for pth in (train_path, valid_path, test_path) if pth]
    if not frames:
        # Concatenating empty frames and selecting columns would raise KeyError.
        return pd.DataFrame(columns=columns)
    return pd.concat(frames)[columns]

## Extract sentences from text files and populate dataframes

In [2]:
# Load manually annotated references (train and valid splits; the test split is not used yet)
manual_annotations_dir = REPO_ROOT / "data/reference_annotations/manual_annotations"
annotations_df = load_annotated_samples(manual_annotations_dir / "train.csv", manual_annotations_dir / "valid.csv")

# Load dataset (alternatively: `CCDataset.from_web_latest()`)
dset = CCDataset.from_json(REPO_ROOT / "datasets/cc/cc_dataset.json")

# Certificates that directly reference other certificates and have the corresponding text file path set
target_certs = [
    cert for cert in dset if cert.heuristics.st_references.directly_referencing and cert.state.st_txt_path
]
report_certs = [
    cert for cert in dset if cert.heuristics.report_references.directly_referencing and cert.state.report_txt_path
]

# One record per (certificate, referenced certificate ID) pair; label and sentences keep their None defaults
target_records = [
    ReferenceRecord(cert, ref_id, "target")
    for cert in target_certs
    for ref_id in cert.heuristics.st_references.directly_referencing
]
report_records = [
    ReferenceRecord(cert, ref_id, "report")
    for cert in report_certs
    for ref_id in cert.heuristics.report_references.directly_referencing
]

df_targets = get_df_from_records(target_records)
df_reports = get_df_from_records(report_records)
df = pd.concat([df_targets, df_reports])

100%|██████████| 944/944 [01:05<00:00, 14.48it/s]
100%|██████████| 2288/2288 [00:32<00:00, 69.50it/s]


## Process Dataframes and dump to csv

*Note*: The test split is not processed yet; its samples are dropped below.

In [5]:
# Load split labels
split_dct = get_split_dict(
    REPO_ROOT / "data/reference_annotations/split/train.json",
    REPO_ROOT / "data/reference_annotations/split/valid.json",
    REPO_ROOT / "data/reference_annotations/split/test.json",
)

# Creates dictionary `(dgst, referenced_cert_id): label` to populate instances with manually assigned annotations.
annotations_dict = (
    annotations_df[["dgst", "referenced_cert_id", "label"]].set_index(["dgst", "referenced_cert_id"]).label.to_dict()
)

# TODO: We should investigate the cases when we match no sentence, they may be new-lines and stuff
# TODO: Add language detection
# Process: attach split and manual label to every reference, drop references without a matching sentence and
# test-split samples, then merge the sentences found in the target and in the report for the same reference.
# `dropna=False` keeps groups whose label or split is missing (unannotated / unsplit references).
df = (
    df.assign(
        split=df.dgst.map(split_dct),
        label=lambda df_: [annotations_dict.get(x) for x in zip(df_["dgst"], df_["referenced_cert_id"])],
    )
    .loc[lambda df_: (df_["sentences"].notnull()) & (df_["split"] != "test")]
    .groupby(["dgst", "referenced_cert_id", "label", "split"], as_index=False, dropna=False)
    # Named aggregation; passing a dict to `SeriesGroupBy.agg` is deprecated and will raise in future pandas.
    .agg(sentences=("sentences", lambda x: set.union(*x)))
)
df.to_csv(REPO_ROOT / "datasets/reference_classification_dataset.csv", index=False)