# Pre-process data for reference classification

This script's pipeline is as follows:

1. Recover text segments that surround certificate ID for all references in CC dataset
2. Create a DataFrame `(dgst, cert_id, label, text_segments)` out of the objects
3. Clean and dump into csv

In [85]:
from __future__ import annotations

from dataclasses import dataclass
from sec_certs.dataset import CCDataset
from sec_certs.sample import CCCertificate
import spacy
from sec_certs.utils.parallel_processing import process_parallel
import pandas as pd
from tqdm import tqdm
import json

nlp = spacy.load("en_core_web_sm")
from pathlib import Path

REPO_ROOT = Path("../../../").resolve()

@dataclass
class ReferenceRecord:
    """
    Intermediate object to hold references for a given certificate together with sensible attributes to be extracted
    for labeling.
    """
    certificate: CCCertificate | None
    dgst: str
    cert_id: str
    location: str
    label: str | None = None
    sentences: set[str] | None = None

    @staticmethod
    def get_reference_sentences(doc, cert_id: str) -> set[str]:
        """
        Return a set of sentences corresponding to the given cert_id for the record
        """
        return {sent.text for sent in doc.sents if cert_id in sent.text}

    @staticmethod
    def get_cert_references_with_sentences(record: ReferenceRecord) -> set[tuple[str, str, str]]:
        pth_to_read = (
            record.certificate.state.st_txt_path
            if record.location == "target"
            else record.certificate.state.report_txt_path
        )

        with pth_to_read.open("r") as handle:
            data = handle.read()

        result = ReferenceRecord.get_reference_sentences(nlp(data), record.cert_id)
        record.sentences = result if result else None

        return record

    def to_pandas_tuple(self) -> tuple[str, str, str, str, set[str] | None]:
        return self.dgst, self.cert_id, self.location, self.label, self.sentences

def get_df_from_records(records: list[ReferenceRecord]):
    """
    Builds dataframe with [dgst,cert_id,location,reason,sentences] with references from list of ReferenceRecords.
    Reason set to None if not defined. 
    """
    results =  process_parallel(ReferenceRecord.get_cert_references_with_sentences, records, use_threading=False, progress_bar=True)
    return pd.DataFrame.from_records([x.to_pandas_tuple() for x in results], columns=["dgst", "cert_id", "location", "label", "sentences"])

def preprocess_segment(segment):
    segment = segment.replace("\n", " ")
    return segment

## Extract sentences from text files and populate dataframes

In [111]:
# Load annotated references from CSV
annotations_df = pd.read_csv(REPO_ROOT / "data/cert_id_eval/random_references.csv")
annotations_df = annotations_df.rename(columns={"id": "dgst", "reason": "label"})
annotations_df = annotations_df.loc[annotations_df.label != "self"]
annotations_df.label = annotations_df.label.map(lambda x: x.upper().replace(" ", "_"))

# Load dataset
# dset = CCDataset.from_web_latest()
dset = CCDataset.from_json(REPO_ROOT / "datasets/cc/cc_dataset.json")

annotated_records = [
    ReferenceRecord(dset[x.dgst], x.dgst, x.cert_id, x.location, x.label)
    for x in annotations_df.itertuples(index=False)
]

# Reference records without annotations
target_certs = [x for x in dset if x.heuristics.st_references.directly_referencing and x.state.st_txt_path]
report_certs = [x for x in dset if x.heuristics.report_references.directly_referencing and x.state.report_txt_path]
target_records = [
    ReferenceRecord(x, x.dgst, y, "target", None, None)
    for x in target_certs
    for y in x.heuristics.st_references.directly_referencing
]
report_records = [
    ReferenceRecord(x, x.dgst, y, "report", None, None)
    for x in report_certs
    for y in x.heuristics.report_references.directly_referencing
]

# Filter annotated_records from report_records to avoid duplicities
annotated_keys = {(x.dgst, x.cert_id) for x in annotated_records}
report_records = [x for x in report_records if (x.dgst, x.cert_id) not in annotated_keys]

df_labeled = get_df_from_records(annotated_records)
df_targets = get_df_from_records(target_records)
df_reports = get_df_from_records(report_records)

# Creates dictionary (dgst, cert_id): label to populate instances that have NULL label (on location==target) but were annotated in location==report and could adopt that label
# This helps to avoid duplicities and extends the number of annotated sentences.
dgst_cert_id_to_label_mapping = (
    df_labeled.loc[df_labeled.label.notnull(), ["dgst", "cert_id", "label"]]
    .drop_duplicates(subset=["dgst", "cert_id"])
    .set_index(["dgst", "cert_id"])
    .label.to_dict()
)

df = pd.concat([df_labeled, df_targets, df_reports])

# Check for label noise
dgst_cert_id_tuples = (
    df.drop_duplicates(subset=["dgst", "cert_id"])
    .loc[:, ["dgst", "cert_id"]]
    .set_index(["dgst", "cert_id"])
    .index.tolist()
)
duplicate_df = pd.DataFrame()
for dgst, cert_id in tqdm(dgst_cert_id_tuples):
    possible_duplicates = df.loc[(df.dgst == dgst) & (df.cert_id == cert_id) & (df.label.notnull())]
    if (
        possible_duplicates.shape[0] > 1
        and not possible_duplicates.drop_duplicates(subset=["dgst", "cert_id", "label"], keep=False).empty
    ):
        duplicate_df = pd.concat([duplicate_df, possible_duplicates])

if not duplicate_df.empty:
    print(
        "Warning, label noise detected, see `duplicate_df` for instances that have inconsistent label for `(dgst, cert_id)` key."
    )

# With no label noise, we should be safe to fill in labels for sentences found in targets such that the corresponding report was annotated
df.label = df.copy().apply(
    lambda row: dgst_cert_id_to_label_mapping.get(
        (row["dgst"], row["cert_id"]) if pd.isnull(row["label"]) else row["label"]
    ),
    axis=1,
)


100%|██████████| 58/58 [00:07<00:00,  8.26it/s]
100%|██████████| 944/944 [01:08<00:00, 13.84it/s]
100%|██████████| 2259/2259 [00:33<00:00, 68.33it/s]
100%|██████████| 2551/2551 [00:02<00:00, 1089.81it/s]


## Process Dataframes and dump to csv

1. Version with `dgst, cert_id, location, single_sentence` as `*_exploded.csv`
2. Version where all sentences tied to `(dgst, cert_id)` key are merged into `sentences`. Saved as `*_grouped.csv`

*Note*: So far don't work with test dataset

In [112]:
# Load split labels
with (REPO_ROOT / "data/reference_annotations_split/train.json").open("r") as handle:
    train_digests = json.load(handle)

with (REPO_ROOT / "data/reference_annotations_split/valid.json").open("r") as handle:
    valid_digests = json.load(handle)

split_dct = {**dict.fromkeys(train_digests, "train"), **dict.fromkeys(valid_digests, "valid")}

# Apply filtering
# TODO: We should investigate the cases when we match no sentence
df = df.loc[df.sentences.notnull()] 
df["split"] = df.dgst.map(split_dct)
df = df.loc[df.split.notnull()]  # Discard test samples

# TODO: Add language detection

# # Aggregate sentences from different sources (target, report) into one row
df = df.groupby(["dgst", "cert_id", "label", "split"], as_index=False, dropna=False)["sentences"].agg({"sentences": lambda x: set.union(*x)})
df.to_csv(REPO_ROOT / "datasets/reference_classification_dataset.csv", sep=';', index=False)

In [113]:
duplicate_df