# Pre-process data for reference classification

This script's pipeline is as follows:

1. Recover the text segments that surround the certificate ID for all references in the CC dataset
2. Create a DataFrame `(dgst, cert_id, label, text_segments)` out of the objects
    - Two versions are produced: one with a single row per segment, and one with all segments from all sources (target, report) merged into a single row
3. Clean and dump into csv
4. Check for label noise

In [6]:
from __future__ import annotations

from dataclasses import dataclass
from sec_certs.dataset import CCDataset
from sec_certs.sample import CCCertificate
import spacy
from sec_certs.utils.parallel_processing import process_parallel
import pandas as pd
import json
from tqdm import tqdm


# spaCy pipeline shared by the whole notebook; ReferenceRecord uses it to split document text into sentences.
nlp = spacy.load("en_core_web_sm")
from pathlib import Path

# Absolute path three directory levels above the current working directory; all data paths below are built from it.
REPO_ROOT = Path("../../../").resolve()

@dataclass
class ReferenceRecord:
    """
    Intermediate object to hold one reference found in a given certificate together with the attributes
    that are extracted for labeling.

    Attributes:
        certificate: The certificate object whose text files are searched, or None.
        dgst: Digest of the certificate that contains the reference.
        cert_id: The certificate ID that is being referenced.
        location: Which document the reference comes from. "target" selects the security-target text file,
            any other value selects the report text file.
        label: Annotated label of the reference, None when the reference is not annotated.
        sentences: Sentences of the document that mention `cert_id`, None when not yet extracted or
            when no sentence matched.
    """
    certificate: CCCertificate | None
    dgst: str
    cert_id: str
    location: str
    label: str | None = None
    sentences: set[str] | None = None

    @staticmethod
    def get_reference_sentences(doc, cert_id: str) -> set[str]:
        """
        Return the set of sentence texts in `doc` that contain `cert_id` as a substring.

        :param doc: object exposing `.sents`, an iterable of spans with a `.text` attribute (e.g. a spaCy Doc)
        :param cert_id: certificate ID to look for
        :return: set of matching sentence texts, empty when nothing matches
        """
        return {sent.text for sent in doc.sents if cert_id in sent.text}

    @staticmethod
    def get_cert_references_with_sentences(record: ReferenceRecord) -> ReferenceRecord:
        """
        Read the text file that belongs to `record.location`, find all sentences that mention `record.cert_id`
        and store them in `record.sentences` (None when no sentence matches).

        The record is mutated in place and also returned, so the function can be mapped over a list of records
        and the results collected by the caller.

        :param record: record to populate; its `certificate` must be set and must have the relevant text path
        :return: the same record with `sentences` populated
        """
        state = record.certificate.state
        pth_to_read = state.st_txt_path if record.location == "target" else state.report_txt_path

        with pth_to_read.open("r") as handle:
            data = handle.read()

        matched = ReferenceRecord.get_reference_sentences(nlp(data), record.cert_id)
        # An empty match is stored as None so that it can be filtered with `notnull()` in a DataFrame.
        record.sentences = matched or None

        return record

    def to_pandas_tuple(self) -> tuple[str, str, str, str | None, set[str] | None]:
        """
        Return the record as a `(dgst, cert_id, location, label, sentences)` tuple, i.e. one DataFrame row.
        """
        return self.dgst, self.cert_id, self.location, self.label, self.sentences

def get_df_from_records(records: list[ReferenceRecord], max_workers: int = 200) -> pd.DataFrame:
    """
    Builds a dataframe with columns [dgst, cert_id, location, label, sentences] from a list of ReferenceRecords.
    The sentences of each record are extracted in parallel with `ReferenceRecord.get_cert_references_with_sentences`.
    Label is set to None if not defined.

    :param records: records to populate with sentences and convert into rows
    :param max_workers: number of workers passed to `process_parallel`; defaults to the previously hard-coded 200
    :return: dataframe with one row per record
    """
    columns = ["dgst", "cert_id", "location", "label", "sentences"]

    # Nothing to process: skip the worker pool and return an empty frame with the expected columns.
    results = (
        process_parallel(
            ReferenceRecord.get_cert_references_with_sentences,
            records,
            max_workers=max_workers,
            use_threading=False,
            progress_bar=True,
        )
        if records
        else []
    )
    return pd.DataFrame.from_records([x.to_pandas_tuple() for x in results], columns=columns)



## Extract sentences from text files and populate dataframes

In [20]:
# Load annotated references from CSV
annotations_df = pd.read_csv(REPO_ROOT / "data/cert_id_eval/random_references.csv")
annotations_df = annotations_df.rename(columns={"id": "dgst", "reason": "label"})
# Drop self-references. The explicit copy makes the column assignment below act on an independent frame
# instead of on a slice of the original one.
annotations_df = annotations_df.loc[annotations_df.label != "self"].copy()
# Normalize labels to upper case with underscores instead of spaces
annotations_df["label"] = annotations_df.label.map(lambda x: x.upper().replace(" ", "_"))

# Load dataset
# dset = CCDataset.from_web_latest()
dset = CCDataset.from_json(REPO_ROOT / "datasets/cc/cc_dataset.json")

annotated_records = [ReferenceRecord(dset[x.dgst], x.dgst, x.cert_id, x.location, x.label) for x in annotations_df.itertuples(index=False)]

# Reference records without annotations: one record per (certificate, directly referenced cert_id),
# only for certificates that have the corresponding text file path set
target_certs = [x for x in dset if x.heuristics.st_references.directly_referencing and x.state.st_txt_path]
report_certs = [x for x in dset if x.heuristics.report_references.directly_referencing and x.state.report_txt_path]
target_records = [ReferenceRecord(x, x.dgst, y, "target", None, None) for x in target_certs for y in x.heuristics.st_references.directly_referencing]
report_records = [ReferenceRecord(x, x.dgst, y, "report", None, None) for x in report_certs for y in x.heuristics.report_references.directly_referencing]

# Filter annotated_records from report_records to avoid duplicates
# NOTE(review): only report_records are filtered; target_records may still contain (dgst, cert_id) pairs
# that are also present in annotated_records - confirm whether that is intended.
annotated_keys = {(x.dgst, x.cert_id) for x in annotated_records}
report_records = [x for x in report_records if (x.dgst, x.cert_id) not in annotated_keys]

df_labeled = get_df_from_records(annotated_records)
df_targets = get_df_from_records(target_records)
df_reports = get_df_from_records(report_records)
# ignore_index gives the combined frame a fresh unique index instead of three overlapping 0..n ranges
df = pd.concat([df_labeled, df_targets, df_reports], ignore_index=True)

 69%|██████▉   | 654/944 [22:33<10:00,  2.07s/it]
100%|██████████| 58/58 [00:07<00:00,  8.26it/s]
100%|██████████| 944/944 [01:04<00:00, 14.59it/s]
100%|██████████| 2259/2259 [00:30<00:00, 75.10it/s]


In [49]:
# Continue with the annotated records only; this replaces the concatenated `df` built in the cell above.
df = df_labeled.copy()

In [21]:
# Load split labels
with (REPO_ROOT / "data/reference_annotations_split/train.json").open("r") as handle:
    train_digests = json.load(handle)

with (REPO_ROOT / "data/reference_annotations_split/valid.json").open("r") as handle:
    valid_digests = json.load(handle)

# Map each digest to the name of its split; digests in neither list are absent from the mapping
split_dct = {**dict.fromkeys(train_digests, "train"), **dict.fromkeys(valid_digests, "valid")}


# Apply filtering
# The explicit copy makes the "split" assignment below act on an independent frame instead of on a slice.
df = df.loc[df.sentences.notnull()].copy()  # TODO: We should investigate the cases when we match no sentence
df["split"] = df.dgst.map(split_dct)  # Annotate with splits
df = df.loc[df.split.notnull()]  # Discard test samples
df.explode("sentences").to_csv(REPO_ROOT / "datasets/reference_classification_dataset_exploded.csv", sep=';', index=False)

# TODO: Add language detection

# Aggregate sentences from different sources (target, report) into one row.
# The aggregation is specified on the grouped frame as {column: function}; passing a dict to a single selected
# column (SeriesGroupBy) is rejected by pandas >= 1.0 with "nested renamer is not supported".
df_grouped = df.groupby(["dgst", "cert_id", "label", "split"], as_index=False).agg({"sentences": lambda x: set.union(*x)})
# NOTE(review): "merrged" in the file name looks like a typo for "merged"; kept unchanged because other code may read this path.
df_grouped.to_csv(REPO_ROOT / "datasets/reference_classification_dataset_merrged.csv", sep=';', index=False)

In [26]:
# Label-noise check: after grouping, a (dgst, cert_id) pair that occurs in more than one row
# must differ in one of the remaining grouping keys, e.g. it carries conflicting labels.
is_noisy_pair = df_grouped.duplicated(subset=["dgst", "cert_id"], keep=False)
duplicates_df = df_grouped.loc[is_noisy_pair]
has_label_noise = not duplicates_df.empty
if has_label_noise:
    print("Warning, label noise in dataset. I.e. tuples (dgst, cert_id) with inconsistent reason. See `duplicates_df` frame.")