In [6]:
# Download the "en_core_web_lg" spaCy model, which the pipeline cell later loads by name.
# NOTE(review): the recorded output of this cell shows the command failing because a
# module compiled against NumPy 1.x was run under NumPy 2.0.2; the message itself
# suggests downgrading to 'numpy<2' or upgrading the affected module before re-running.
!python -m spacy download en_core_web_lg


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 189, in _run_module_as_main
  File "<frozen runpy>", line 148, in _get_module_details
  File "<frozen runpy>", line 112, in _get_module_details
  File "/Users/omar/opt/anaconda3/envs/triplets/lib/python3.11/site-packages/spacy/__init__.py", line 6, in <module>
  File "/Users/omar/opt/anaconda3/envs/triplets/lib/python3.11/site-packages/spacy/errors.py", line 3, in <module>
    from .compat import Literal
  File "/Users/omar/opt/anaconda3/envs/triplets/lib/python3.11/site-packages/spacy/compat.

In [None]:
import logging


def get_logger(log_name, log_file="processing_log.txt"):
    """Create (or reuse) a logger that writes to a log file.

    Loggers are looked up by name, so repeated calls with the same `log_name`
    return the same object. The file handler for `log_file` is attached only
    when the logger does not already have one for that path; otherwise every
    re-run of a notebook cell would stack another handler and each message
    would be written to the file several times.

    Parameters:
        log_name (str): Name of the logger (e.g., "citations", "abbreviations", "coref").
        log_file (str): Path to the log file. Default is "processing_log.txt".

    Returns:
        logging.Logger: Configured logger for writing logs.
    """
    import os  # local import: only needed to normalise the handler path

    logger = logging.getLogger(log_name)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores its target as an absolute path in `baseFilename`,
    # so compare against the absolute form of the requested file.
    target_path = os.path.abspath(os.fspath(log_file))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target_path
        for handler in logger.handlers
    )

    if not has_file_handler:
        # Create file handler to log to a file
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.DEBUG)

        # Create formatter and add it to the handler
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        file_handler.setFormatter(formatter)

        # Add the file handler to the logger
        logger.addHandler(file_handler)

    return logger


def write_log(
    old_text, new_text, line_count, start_positions, end_positions, action, logger
):
    """Record one text transformation as a group of INFO log lines.

    Seven messages are emitted in a fixed order: the action, the line number,
    the stripped old text, the stripped new text, the start positions, the end
    positions, and a separator made of 50 dashes.

    Parameters:
        old_text (str): Original text before transformation.
        new_text (str): Transformed text.
        line_count (int): Line number where the change occurred.
        start_positions (list): List of start positions for the changes.
        end_positions (list): List of end positions for the changes.
        action (str): Description of the action performed (e.g., "Removing citations").
        logger (logging.Logger): The logger to write the log entry.

    """

    def _messages():
        # Generated lazily so each message is formatted right before it is logged.
        yield f"Action: {action}"
        yield f"Line: {line_count}"
        yield f"Old Text: {old_text.strip()}"
        yield f"New Text: {new_text.strip()}"
        yield f"Start Positions: {start_positions}"
        yield f"End Positions: {end_positions}"
        yield "-" * 50

    for message in _messages():
        logger.info(message)

In [7]:
from fastcoref import spacy_component
from abbreviations import schwartz_hearst

import re
import time
import tqdm
import spacy
import pandas as pd
from spacy.tokens import Doc

# Request GPU execution from spaCy; the return value is not inspected here.
spacy.prefer_gpu()

# Load spacy pipeline
# The parser, ner, lemmatizer and textcat components are excluded from the loaded pipeline.
nlp = spacy.load("en_core_web_lg", exclude=["parser", "ner", "lemmatizer", "textcat"])
# Append the "fastcoref" component to the pipeline.
# NOTE(review): presumably the `from fastcoref import spacy_component` import above is
# what registers this factory name, and this component is what fills
# `doc._.coref_clusters` read in process_dataframe — confirm against fastcoref's docs.
nlp.add_pipe("fastcoref")


def remove_citations(texts, logger_citations):
    """Remove citations through a rule based heuristic

    Each text is processed line by line (split on "\\n"). A bracketed or
    parenthesised group is removed when it is only a number, or when it starts
    or ends with a four-digit year. Groups that merely contain a number next
    to other text, e.g. [llm2], are kept. Every modified line is logged with
    the character offsets of the removed citations.

    The lines are re-joined with "\\n", so a text without citations is returned
    unchanged (no extra trailing newline is appended).

    Parameters:
        texts (list[str]): list of texts
        logger_citations (logging.Logger): logger that receives one entry per modified line
    Return:
        new_texts (list[str]): list of texts with the citations removed

    """
    # Raw strings keep the regex escapes (\[, \-, ...) out of Python's own
    # string-escape processing. Compiled once per call instead of per line.
    citation_pattern = re.compile(
        r"\[[0-9]{4}[a-zA-Z0-9 .,!/\-\"']*\]"  # [2020 ...]
        r"|\[[0-9]+\]"  # [12]
        r"|\[[a-zA-Z0-9 .,!/\-\"']*[0-9]{4}\]"  # [... 2020]
        r"|\([a-zA-Z0-9 .,!/\-\"']*[0-9]{4}\)"  # (... 2020)
        r"|\([0-9]{4}[a-zA-Z0-9 .,!/\-\"']*\)"  # (2020 ...)
        r"|\([0-9]+\)"  # (12)
    )

    new_texts = []
    for text in texts:
        new_lines = []
        for line_count, line in enumerate(text.split("\n")):
            # A single scan yields both the match test and the positions to log.
            matches = list(citation_pattern.finditer(line))
            if matches:
                new_line = citation_pattern.sub("", line)
                write_log(
                    line,
                    new_line,
                    line_count,
                    [match.start() for match in matches],
                    [match.end() for match in matches],
                    "Removing citations",
                    logger_citations,
                )
            else:
                new_line = line
            new_lines.append(new_line)
        new_texts.append("\n".join(new_lines))
    return new_texts


def expand_abbreviations(texts, logger_abbr):
    """Expand the abbreviations using the Schwartz-Hearst algorithm

    For every text, abbreviation/definition pairs are extracted with
    `schwartz_hearst.extract_abbreviation_definition_pairs`; a lowercased copy
    of each abbreviation is added as an extra key. Each line of the text
    (split on "\\n") is then scanned and every stand-alone occurrence of an
    abbreviation is replaced by its definition. An occurrence is stand-alone
    when the characters directly before and after it are not letters, so
    "AI," and "AI." are expanded while the "AI" inside "RAIN" is not.

    Abbreviations are matched literally (regex metacharacters are escaped), so
    entries such as "C++" or "U.S." are handled like any other.

    Parameters:
        texts (list[str]): list of texts
        logger_abbr (logging.Logger): logger that receives one entry per modified line

    Return:
        new_texts (list[str]): list of texts with the abbreviations expanded
        errors_with_abbreviations (set[str]): abbreviations that could not be
            processed. Kept for backward compatibility; because matching is
            literal, no abbreviation can fail any more and the set is empty.
    """

    new_texts = []
    errors_with_abbreviations = set()
    for text in texts:
        pairs = schwartz_hearst.extract_abbreviation_definition_pairs(doc_text=text)
        # Add the fully lowercased versions of the abbreviations as keys
        for abbrev, definition in list(pairs.items()):
            if abbrev.lower() != abbrev:
                pairs[abbrev.lower()] = definition

        # iterate over the lines in the text and replace the abbreviations
        new_sentences = []
        for i, old_sentence in enumerate(text.split("\n")):
            replacements = []
            for abbrev, definition in pairs.items():
                # Skip empty keys (they would match at every position) and
                # abbreviations that do not occur in this line at all.
                if not abbrev or abbrev not in old_sentence:
                    continue
                for m in re.finditer(re.escape(abbrev), old_sentence):
                    # The abbreviation must not be part of a longer word,
                    # e.g. "in" in "within". Punctuation around it is fine.
                    if m.start() > 0 and old_sentence[m.start() - 1].isalpha():
                        continue
                    if m.end() < len(old_sentence) and old_sentence[m.end()].isalpha():
                        continue
                    replacements.append(((m.start(), m.end()), definition))

            # Keep only non-overlapping replacements, earliest first. Because
            # the list is sorted by start and the kept entries never overlap,
            # comparing with the last kept entry is sufficient. End offsets
            # are exclusive, so a match starting exactly where the previous
            # one ended does not overlap.
            replacements.sort(key=lambda item: item[0][0])
            replacements_to_keep = []
            for replacement in replacements:
                if (
                    replacements_to_keep
                    and replacement[0][0] < replacements_to_keep[-1][0][1]
                ):
                    continue
                replacements_to_keep.append(replacement)

            if replacements_to_keep:
                # Rebuild the line left to right from untouched slices and
                # definitions, so offsets always refer to the original line.
                pieces = []
                cursor = 0
                for (start, end), definition in replacements_to_keep:
                    pieces.append(old_sentence[cursor:start])
                    pieces.append(definition)
                    cursor = end
                pieces.append(old_sentence[cursor:])
                sentence = "".join(pieces)
                # Positions are logged in ascending order.
                write_log(
                    old_sentence,
                    sentence,
                    i,
                    [span[0] for span, _ in replacements_to_keep],
                    [span[1] for span, _ in replacements_to_keep],
                    "Abbreviation replacement",
                    logger_abbr,
                )
            else:
                sentence = old_sentence
            new_sentences.append(sentence)

        # Get new_text by joining the sentences
        new_texts.append("\n".join(new_sentences))
    return new_texts, errors_with_abbreviations


def get_span_noun_indices(doc, cluster):
    """Get the indices of the spans that contain a noun

    A span counts as containing a noun when a token tagged NOUN or PROPN
    overlaps the span's character range. Tokens are located by their character
    offset (`token.idx`), not by text, so a noun elsewhere in the document
    whose text merely happens to be a substring of the span (e.g. a variable
    named "t" and the pronoun "It") no longer produces a false positive.

    Parameters:
        doc (Doc): spacy document
        cluster (list[tuple]): list of tuples with the start and end position of the spans.
            NOTE(review): offsets are treated as character offsets with an
            exclusive end, as fastcoref reports them — confirm for the
            installed fastcoref version.

    Return:
        span_noun_indices (list[int]): list of indices of the spans that contain a noun
    """
    # Character ranges (start, exclusive end) of every noun-like token,
    # collected once instead of re-scanning the document for every span.
    noun_ranges = [
        (token.idx, token.idx + len(token.text))
        for token in doc
        if token.pos_ in ("NOUN", "PROPN")
    ]

    span_noun_indices = []
    for idx, span in enumerate(cluster):
        span_start, span_end = span[0], span[1]
        if any(
            token_start < span_end and token_end > span_start
            for token_start, token_end in noun_ranges
        ):
            span_noun_indices.append(idx)
    return span_noun_indices


def is_containing_other_spans(span, all_spans):
    """Tell whether any other span lies inside the given span

    Another span lies inside `span` when it starts at or after the start of
    `span`, ends at or before its end, and is not equal to `span` itself.

    Parameters:
        span (tuple): tuple with the start and end position of the span
        all_spans (list[tuple]): list of tuples with the start and end position of the spans

    Return:
        bool: whether the span is containing other spans
    """
    nested_spans = [
        other
        for other in all_spans
        if other[0] >= span[0] and other[1] <= span[1] and other != span
    ]
    return bool(nested_spans)


def get_cluster_head(doc: Doc, cluster, noun_indices):
    """Get the head of the cluster

    The head is the first span of the cluster that contains a noun, i.e. the
    span at position `noun_indices[0]`. Its text is sliced from `doc.text`
    using the span's offsets with an exclusive end, so no character following
    the mention is included.

    Parameters:
        doc (Doc): spacy document
        cluster (list[tuple]): list of tuples with the start and end position of the spans.
            NOTE(review): treated as character offsets with an exclusive end,
            as fastcoref reports them — confirm for the installed version.
        noun_indices (list[int]): list of indices of the spans that contain a noun;
            must not be empty

    Return:
        head_span (str): head of the cluster
        head_start_end (tuple): tuple with the start and end position of the head
    """
    head_idx = noun_indices[0]
    head_start, head_end = cluster[head_idx]
    # The end offset is exclusive, so it is used directly as the slice bound.
    head_span = doc.text[head_start:head_end]
    return head_span, (head_start, head_end)


def replace_corefs(doc, logger_coref, clusters):
    """Replace the coreferences in the text

    For every cluster that has at least one span containing a noun, the first
    such span is taken as the head of the cluster. Every other span of the
    cluster is replaced by the head text, unless that span itself contains
    another span. Replacements are applied to the original text from left to
    right; a replacement that overlaps an earlier one is skipped. When at
    least one replacement was made, a single log entry is written.

    Parameters:
        doc (Doc): spacy document
        logger_coref (logging.Logger): logger that receives the log entry
        clusters (list[list[tuple]]): list of clusters, where each cluster is a list of tuples with the start and end position of the spans.
            NOTE(review): offsets are treated as character offsets into
            `doc.text` with an exclusive end, as fastcoref reports them —
            confirm for the installed fastcoref version.

    Return:
        new_text (str): text with the coreferences replaced
    """
    original_text = doc.text
    all_spans = [span for cluster in clusters for span in cluster]

    # First collect every (span, replacement text) pair ...
    all_replacements = []
    for cluster in clusters:
        noun_indices = get_span_noun_indices(doc, cluster)
        if len(noun_indices) == 0:
            # Without a noun-bearing span there is nothing sensible to substitute.
            continue
        mention_span, mention = get_cluster_head(doc, cluster, noun_indices)
        for coref in cluster:
            # Compare as tuples so the head is recognised whether spans are
            # stored as lists or as tuples.
            if tuple(coref) == tuple(mention):
                continue
            if is_containing_other_spans(coref, all_spans):
                continue
            all_replacements.append((coref, mention_span))

    # ... then apply them in one left-to-right pass over the original text,
    # so every offset refers to the unmodified document.
    all_replacements.sort(key=lambda item: item[0][0])
    pieces = []
    start_positions = []
    end_positions = []
    cursor = 0
    for (start_pos, end_pos), replacement_text in all_replacements:
        if start_pos < cursor:
            # Overlaps a span that was already replaced.
            continue
        pieces.append(original_text[cursor:start_pos])
        pieces.append(replacement_text)
        cursor = end_pos
        start_positions.append(start_pos)
        end_positions.append(end_pos)
    pieces.append(original_text[cursor:])
    new_text = "".join(pieces)

    if start_positions:
        # Offsets are relative to the whole document, so the entry is
        # recorded against line 0.
        write_log(
            original_text,
            new_text,
            0,
            start_positions,
            end_positions,
            "Coreference resolution",
            logger_coref,
        )
    return new_text


def process_dataframe(df: pd.DataFrame):
    """Run the text-cleaning pipeline over every row of the DataFrame

    For each row the abstract and the body go through three stages, each stage
    handling the abstract before the body: citation removal, abbreviation
    expansion, and coreference replacement on the spaCy documents produced by
    `nlp`. The title and id are copied through unchanged.

    Parameters:
        df (pd.DataFrame): DataFrame with 'id', 'title', 'abstract', 'body' columns

    Return:
        df (pd.DataFrame): DataFrame with processed text columns
    """
    logger_citations = get_logger("citations")
    logger_abbr = get_logger("abbreviations")
    logger_coref = get_logger("coref")

    text_fields = ("abstract", "body")
    records = []
    for _, row in df.iterrows():
        title = row["title"]
        raw_texts = [row["abstract"], row["body"]]

        # Stage 1: remove citations (one call per field, abstract first)
        without_citations = [
            remove_citations([raw_text], logger_citations)[0] for raw_text in raw_texts
        ]

        # Stage 2: expand abbreviations; the returned error sets are not used
        expansion_results = [
            expand_abbreviations([clean_text], logger_abbr)
            for clean_text in without_citations
        ]
        expanded = [expanded_texts[0] for expanded_texts, _errors in expansion_results]

        # Stage 3: coreference resolution — both documents are parsed first,
        # then their clusters are replaced
        docs = [nlp(expanded_text) for expanded_text in expanded]
        resolved = [
            replace_corefs(parsed, logger_coref, parsed._.coref_clusters)
            for parsed in docs
        ]

        # Store the processed text
        record = {"id": row["id"], "title": title}
        record.update(zip(text_fields, resolved))
        records.append(record)

    # Return the DataFrame with processed text
    return pd.DataFrame(records)


# After `process_dataframe` runs, `processed_df` holds the ids and titles unchanged, plus the abstracts and bodies with citations removed, abbreviations expanded, and coreferences resolved.

11/09/2024 20:34:36 - INFO - 	 missing_keys: []
11/09/2024 20:34:36 - INFO - 	 unexpected_keys: []
11/09/2024 20:34:36 - INFO - 	 mismatched_keys: []
11/09/2024 20:34:36 - INFO - 	 error_msgs: []
11/09/2024 20:34:36 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M


In [8]:
# Load the input articles from a parquet file in the working directory.
# process_dataframe reads the 'id', 'title', 'abstract' and 'body' columns of this frame.
arxiv_df = pd.read_parquet('filtered_df_arxiv.parquet')

In [9]:
# Run the whole pipeline (citations, abbreviations, coreferences) over every article.
# NOTE(review): the recorded output of this cell ends with
# "RuntimeError: Numpy is not available", so the saved run did not complete.
processed_df = process_dataframe(arxiv_df)

11/09/2024 20:36:18 - INFO - 	 Action: Removing citations
11/09/2024 20:36:18 - INFO - 	 Line: 0
11/09/2024 20:36:18 - INFO - 	 Old Text: Introduction When a glass-forming liquid is cooled rapidly, its viscosity increases dramatically and it eventually transforms into an amorphous solid, called a glass, whose physical properties are profoundly different from those of ordered crystalline solids [1]. At even lower temperature, around 1K, the specific heat of a disordered solid is much larger than that of its crystalline counterpart as it scales linearly rather than cubically with temperature. Similarly, the temperature evolution of the thermal conductivity in glasses is quadratic, rather than cubic [2], [3], [4], [5], [6], [7], [8], [9], [10], [11]. A theoretical framework rationalizing such anomalous behavior was provided by Anderson, Halperin and Varma [12] and by Phillips [13], [14]. They argued that the energy landscape of amorphous solids contains many nearly-degenerate minima, conn

RuntimeError: Numpy is not available