In [1]:
from typing import Dict, List, Any, Set
import os

os.chdir("..")
from src.utils import *

import nltk
#nltk.download('punkt')
import json
import argparse
import unidecode

# Punkt sentence tokenizers for English and Spanish, loaded once from NLTK's data files.
# Loading needs the NLTK "punkt" resource to be present locally (see the commented-out
# nltk.download call above).
# NOTE(review): these are pickled models; newer NLTK releases reportedly ship "punkt_tab"
# instead of the pickled "punkt" files - confirm these paths against the installed NLTK version.
EN_TOKENIZER = nltk.data.load("tokenizers/punkt/english.pickle")
ES_TOKENIZER = nltk.data.load("tokenizers/punkt/spanish.pickle")

In [2]:
def format_sents_for_output(sents: List[str], doc_id: str) -> Dict[str, Dict[str, Any]]:
    """
    Build the output mapping for the sentences of one document.

    Each sentence is keyed by "<doc_id>_sent_<index>", where the index is the sentence's
    zero-based position in `sents`, and maps to a record of the form:
        {"text": "sentence text", "label": []}

    Every record gets its own new, empty "label" list.
    """
    return {
        f"{doc_id}_sent_{position}": {"text": sentence, "label": []}
        for position, sentence in enumerate(sents)
    }


def preprocess_text(txt: str, remove_new_lines: bool = False) -> str:
    """
    Clean raw document text before it is handed to the sentence tokenizer.

    Steps, in the order they are applied. The helpers are not defined in this notebook
    (they arrive through the star import from src.utils), so what each helper does is taken
    from the original notes on this function - verify against src.utils:
        1. Remove HTML tags (remove_html_tags)
        2. Replace URLs by a tag [URL] (replace_links)
        3. Only when remove_new_lines is True: replace new lines and tabs by normal spaces -
           sometimes sentences have new lines in the middle
        4. Remove excessive spaces, i.e. more than 1 occurrence (remove_multiple_spaces)
        5. Parse emails and abbreviations (parse_emails, parse_acronyms)
        6. Drop every period that should not be treated as a sentence boundary:
           a period directly followed by a non-space character, and a period whose
           second-next character is numeric

    Args:
        txt: the raw text of one document.
        remove_new_lines: when True, new lines and tabs are turned into spaces (step 3).

    Returns:
        The cleaned text, with the periods selected in step 6 removed.
    """
    txt = replace_links(remove_html_tags(txt)).strip()
    if remove_new_lines:
        txt = txt.replace("\n", " ").replace("\t", " ").strip()

    txt = remove_multiple_spaces(txt)
    txt = parse_emails(txt)
    txt = parse_acronyms(txt)

    txt_len = len(txt)
    # Collect the surviving characters in a list and join once at the end, instead of
    # growing a string one character at a time.
    kept_chars = []

    for i, char in enumerate(txt):
        # Periods are found by direct comparison, so no regular expression is needed.
        if char == ".":
            # Any char following a period that is NOT a space means that we should not add that period
            if i + 1 < txt_len and txt[i + 1] != " ":
                continue

            # NOTE: Any char that is a number following a period will not count.
            # For enumerations, we're counting on docs being enumerated as "(a)" or "(ii)", and if not,
            # they will be separated by the "." after the number:
            # "Before bullet point. 3. Bullet point text" will just be "Before bullet point 3." and
            # "Bullet point text" as the sentences
            if i + 2 < txt_len and txt[i + 2].isnumeric():
                continue

            # If we wanted to have all numbered lists together, uncomment this, and comment out the previous condition
            # if i + 2 < txt_len and not txt[i + 2].isalpha():
            #     continue

        kept_chars.append(char)

    return "".join(kept_chars)


def preprocess_english_text(txt: str, remove_new_lines: bool = False) -> str:
    """
    Preprocess English text: the result of preprocess_text, with no extra
    language-specific step applied on top.
    """
    cleaned_txt = preprocess_text(txt, remove_new_lines)
    return cleaned_txt


def preprocess_spanish_text(txt: str, remove_new_lines: bool = False) -> str:
    """
    Preprocess Spanish text: run preprocess_text, then pass the result through
    unidecode.unidecode, which transliterates it to plain ASCII (for example,
    accented letters lose their accents).
    """
    cleaned_txt = preprocess_text(txt, remove_new_lines)
    return unidecode.unidecode(cleaned_txt)


def remove_short_sents(sents: List[str], min_num_words: int = 4) -> List[str]:
    """
    Keep only the sentences that have at least `min_num_words` words (default: 4).

    Words are counted by splitting each sentence on whitespace. The remaining sentences
    keep their original order and are returned in a new list.
    """
    long_enough = []
    for sentence in sents:
        word_count = len(sentence.split())
        if word_count < min_num_words:
            continue
        long_enough.append(sentence)
    return long_enough


def get_nltk_sents(txt: str, tokenizer: nltk.PunktSentenceTokenizer, extra_abbreviations: Set[str] = None) -> List[str]:
    """
    Split text into sentences with the given Punkt tokenizer.

    When `extra_abbreviations` is not None, its entries are first added to the tokenizer's
    own abbreviation collection (tokenizer._params.abbrev_types). The tokenizer object is
    modified in place, so the additions remain for every later use of that same tokenizer.
    """
    register_abbreviations = extra_abbreviations is not None
    if register_abbreviations:
        known_abbreviations = tokenizer._params.abbrev_types
        known_abbreviations.update(extra_abbreviations)

    sentences = tokenizer.tokenize(txt)
    return sentences

In [27]:
# Run configuration for this batch of documents.
language = "spanish"  # written into each output file's metadata
tokenizer = ES_TOKENIZER
abbrevs = None  # no extra abbreviations are registered with the tokenizer
min_num_words = 5  # sentences with fewer words than this are dropped

# Load the extracted PDF text. Each entry is read below through its key and its "Text" field.
with open("../extract_text/output/pdf_files.json", mode="r", encoding="utf-8") as f:
    pdf_conv = json.load(f)


In [28]:
# (key, text) pairs, one per entry of pdf_conv, in pdf_conv's iteration order.
file_lst = [(key, pdf_conv[key]['Text']) for key in pdf_conv]

In [35]:
# Split every document into sentences and write one JSON file per document.
# A failure on a single document is recorded in error_files (document id -> error message)
# and the loop carries on, so one bad document does not abort the whole run.
# NOTE: preprocessing is always the Spanish variant here; `language` is only written to the
# output metadata and does not select the preprocessing function.
OUTPUT_DIR = "./output/new"
ERRORS_PATH = "./errors.json"

# Make sure the output directory exists; otherwise every write below would fail and each
# document would end up in error_files.
os.makedirs(OUTPUT_DIR, exist_ok=True)

error_files = {}
i = 0  # stays 0 when file_lst is empty, so the summary below still works
for i, (file_id, text) in enumerate(file_lst, start=1):
    try:
        preprocessed_text = preprocess_spanish_text(text)
        sents = get_nltk_sents(preprocessed_text, tokenizer, abbrevs)
        postprocessed_sents = format_sents_for_output(remove_short_sents(sents, min_num_words), file_id)
        sents_json = {
            file_id: {
                "metadata": {"n_sentences": len(postprocessed_sents), "language": language},
                "sentences": postprocessed_sents,
            }
        }
        # Explicit encoding, matching how the input JSON is read.
        with open(f"{OUTPUT_DIR}/{file_id}_sents.json", "w", encoding="utf-8") as f:
            json.dump(sents_json, f)

    except Exception as e:
        # Deliberately best-effort: keep the message and move on to the next document.
        error_files[str(file_id)] = str(e)

    if i % 10 == 0:
        print("----------------------------------------------")
        print(f"Processing {i} documents...")
        print(f"Number of errors so far: {len(error_files)}")
        print("----------------------------------------------")

with open(ERRORS_PATH, "w", encoding="utf-8") as x:
    json.dump(error_files, x)

print("=============================================================")
print(f"Total documents processed: {i}")
# Report the path that was actually written above.
print(f"Total number of documents with errors: {len(error_files)}. Stored file in {ERRORS_PATH}")
print("=============================================================")

----------------------------------------------
Processing 10 documents...
Number of errors so far: 0
----------------------------------------------
----------------------------------------------
Processing 20 documents...
Number of errors so far: 0
----------------------------------------------
----------------------------------------------
Processing 30 documents...
Number of errors so far: 0
----------------------------------------------
----------------------------------------------
Processing 40 documents...
Number of errors so far: 0
----------------------------------------------
----------------------------------------------
Processing 50 documents...
Number of errors so far: 0
----------------------------------------------
----------------------------------------------
Processing 60 documents...
Number of errors so far: 0
----------------------------------------------
Total documents processed: 61
Total number of documents with errors: 0. Stored file in ../output/sentence_splitt