In [None]:
class Config:
    """Small settings container used by the notebook cells."""

    def __init__(self, for_git=True):
        # NOTE: `for_git` is accepted but not read anywhere in this class.
        # Threshold value; its consumers are not part of this class.
        self.gap_too_large_threshold = 1000
        # If True, operations on fulltext will be kept to a minimum.
        self.savetime_on_fulltext = False

    def get_output_path(self):
        """Returns the base path for output files (currently the empty string)."""
        output_path = ''
        return output_path

# Module-level settings object; later cells pass it to load_from_file().
config = Config()

In [None]:
import time
import os

def time_function(func):
    """
    Decorator that prints how long each call to `func` took.

    When the call passes `instances` and/or `papers` as keyword arguments,
    their lengths are appended to the printed line in parentheses, e.g.
    "(3 instances, 2 papers)". Positional arguments are not inspected.

    Parameters:
    - func (callable): The function to wrap.

    Returns:
    - callable: A wrapper with the same name/docstring as `func` that returns
      whatever `func` returns.
    """
    # Local import so the decorator does not depend on the file's import cell.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Collect the size hints first so the parentheses are always balanced,
        # including when only `papers` is given.
        details = []
        if "instances" in kwargs:
            details.append(f"{len(kwargs['instances'])} instances")
        if "papers" in kwargs:
            details.append(f"{len(kwargs['papers'])} papers")
        appendix = f"({', '.join(details)})" if details else ""

        # perf_counter is monotonic, so the interval cannot go negative if the
        # system clock is adjusted while `func` runs.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed} seconds" + appendix)
        return result
    return wrapper

In [None]:
def split_string(input_string, delimiters=(" ", "-", "_")):
    """
    Splits a string into words on any of the given delimiters.

    Every delimiter is first replaced by a single space; the result is then
    split on whitespace, so consecutive delimiters never yield empty words.

    Parameters:
    - input_string (str): The text to split.
    - delimiters (iterable of str): Non-empty separator strings. The default is
      an immutable tuple so it cannot be altered between calls.

    Returns:
    - list[str]: The words in their original order (empty list for empty input).
    """
    for delimiter in delimiters:
        input_string = " ".join(input_string.split(delimiter))
    return input_string.split()

In [None]:
# Step 2: find occurrences of instances in full text of papers
import sys
from bisect import bisect_left
from sortedcontainers import SortedSet
import numpy as np
import json

class PosInPaper:
    def __init__(self):
        """Creates an empty index; it is filled by populate() or load_from_file()."""
        # List of paper identifiers
        self.papers = []
        # List of literals
        self.literals = []
        # Dict mapping each unique word across all literals to its integer index
        self.words = {}
        # Character length of each word, in word-index order
        self.word_len = []
        # Dict mapping each unique combination of words (frozenset) to its combination index
        self.word_combinations = {}
        # Per word combination: the word indices it consists of, longest word first
        self.word_combination_lists = []
        # Dict mapping a literal to its word combination index (None if it has none).
        # Initialised here so add_if_word_combination() can be used before
        # initialize_variables() has run.
        self.word_combination_index_literal = {}
        # 2D list mapping pairs of literals to their word combination index
        self.word_combination_index_literal_literal = []
        # 2D list (paper x word) holding the positions of each word in each paper
        self.word_occurrences_in_papers = []
        # 2D structure (paper x word combination) of minimum distances;
        # -2 marks "not computed yet", -1 marks "combination not found in paper"
        self.min_distances = []

    @time_function
    def populate(self, config: Config, papers: list, literals: list[str], paper_full_text, optimize=True):
        """
        Runs the complete indexing pipeline for the given papers and literals.

        The steps are executed in a fixed order: the word and word-combination
        tables are built first, then the per-paper storage is allocated, and
        finally the full texts are scanned.

        Parameters:
        - config (Config): Configuration object; not read by this method itself.
        - papers (list): List of paper identifiers.
        - literals (list[str]): List of literals to process.
        - paper_full_text (dict): Mapping from paper identifiers to their full text file paths.
        - optimize (bool): If True, optimize_data() is called once everything is populated.
        """
        # Build the vocabulary: single literals first, then all literal pairs.
        self.initialize_variables(papers, literals)
        self.process_literals(literals)
        self.process_literal_combinations(literals)

        # Storage is sized from the vocabulary, so it has to come afterwards.
        self.setup_data_structures(papers)
        self.find_occurrences_in_texts(papers, paper_full_text)

        if not optimize:
            return
        self.optimize_data()

    @time_function
    def initialize_variables(self, papers, literals):
        """
        Stores the papers and literals and resets the literal -> combination lookup.

        Parameters:
        - papers (list): List of paper identifiers.
        - literals (list): List of literals to process.
        """
        self.papers, self.literals = papers, literals
        # Every literal starts out without an assigned word combination (None).
        self.word_combination_index_literal = dict.fromkeys(literals)

    @time_function
    def process_literals(self, literals):
        """
        Registers the words of every literal and, for multi-word literals, their combination.

        Parameters:
        - literals (list): List of literals to process.
        """
        for literal in literals:
            tokens = split_string(literal)
            # Words must be registered before the combination that refers to them.
            self.add_words(tokens)
            self.add_if_word_combination(tokens, literal)

    def add_words(self, word_list):
        """
        Registers every not-yet-known word and records its length.

        A new word receives the next free index (the current size of the word
        dict); its character length is appended to word_len at the same time.

        Parameters:
        - word_list (list): List of words to add.
        """
        vocabulary = self.words
        lengths = self.word_len
        for token in word_list:
            if token in vocabulary:
                continue
            vocabulary[token] = len(vocabulary)
            lengths.append(len(token))

    def add_if_word_combination(self, word_list, lit):
        """
        Registers the word combination of a multi-word literal and remembers its index.

        Literals with fewer than two words are ignored. If the literal already
        has an index, that index is kept. If another literal with the same set
        of words was registered before, its index is reused instead of creating
        a second entry; a duplicate entry would leave word_combinations and
        word_combination_lists with different lengths and make later indices collide.

        Parameters:
        - word_list (list): List of words forming a combination.
        - lit (str): The literal corresponding to the word combination.
        """
        if len(word_list) <= 1:
            return

        pos = self.word_combination_index_literal.get(lit)
        if pos is None or pos == -1:
            froz = frozenset(word_list)
            # Reuse the index of an identical word set if there is one.
            pos = self.word_combinations.get(froz)
            if pos is None:
                pos = len(self.word_combinations)
                self.add_word_combination(froz, pos)
        self.word_combination_index_literal[lit] = pos
    
    def add_word_combination(self, froz, pos):
        """
        Stores a new word combination under the given index.

        The combination's word indices are appended to word_combination_lists,
        longest word first. Words of equal length are ordered alphabetically so
        the stored list does not depend on the iteration order of the frozenset.

        Parameters:
        - froz (frozenset): The words of the combination; each must already be in self.words.
        - pos (int): Index to store for the combination. Callers pass the position
          at which the list is appended; this is not checked here.
        """
        self.word_combinations[froz] = pos
        ordered_words = sorted(froz, key=lambda word: (-len(word), word))
        self.word_combination_lists.append([self.words[word] for word in ordered_words])

    @time_function
    def process_literal_combinations(self, literals):
        """
        Assigns a word combination index to every unordered pair of literals.

        The result is stored symmetrically in word_combination_index_literal_literal;
        the diagonal (a literal paired with itself) stays None. Pairs whose joint
        word set is already known reuse the existing index.

        Parameters:
        - literals (list): List of literals to process.
        """
        n_literals = len(literals)
        self.word_combination_index_literal_literal = [[None] * n_literals for _ in range(n_literals)]

        # Split each literal once up front instead of once per pair.
        word_sets = [frozenset(split_string(literal)) for literal in literals]
        # Next free index for a combination that has not been seen before.
        combination_index = len(self.word_combinations)

        for id1 in range(n_literals):
            for id2 in range(id1 + 1, n_literals):
                # The union is order-independent, so (a, b) and (b, a) share a key.
                froz = word_sets[id1] | word_sets[id2]
                pos = self.word_combinations.get(froz)
                if pos is None:
                    pos = combination_index
                    combination_index += 1
                    self.add_word_combination(froz, pos)

                # Update the matrix with the index of the combination
                self.word_combination_index_literal_literal[id1][id2] = pos
                self.word_combination_index_literal_literal[id2][id1] = pos

    @time_function
    def setup_data_structures(self, papers):
        """
        Allocates empty storage for word occurrences and minimum distances.

        Creates one empty SortedSet per (paper, word) and an integer array of
        shape (papers, word combinations) filled with -2.

        Parameters:
        - papers (list): List of paper identifiers.
        """
        n_papers = len(papers)
        n_words = len(self.words)

        occurrences = []
        for _ in range(n_papers):
            occurrences.append([SortedSet() for _ in range(n_words)])
        self.word_occurrences_in_papers = occurrences

        shape = (n_papers, len(self.word_combinations))
        self.min_distances = np.full(shape, -2, dtype=int)

    @time_function
    def find_occurrences_in_texts(self, papers, paper_full_text):
        """
        Finds and stores the occurrences of each word in the full text of each paper.

        Matching is case-insensitive: both the paper text and the search word are
        lower-cased before searching. Papers without an entry in paper_full_text
        are reported on stdout and skipped.

        Parameters:
        - papers (list): List of paper identifiers.
        - paper_full_text (dict): Mapping from paper identifiers to their full text file paths.
        """
        # Lower-case every word once; the text is lower-cased too, so a word that
        # contains an uppercase letter would otherwise never match.
        # The stored index is used rather than the enumeration position.
        lowered_words = [(wordID, word.lower()) for word, wordID in self.words.items()]

        for paperID, paper in enumerate(papers):
            if paper not in paper_full_text:
                print(f"Paper {paper} has no full text available.")
                continue
            with open(paper_full_text[paper], 'r', encoding="utf8") as f:
                text = f.read().lower()
            for wordID, word in lowered_words:
                self.find_and_add_word_occurrences(paperID, wordID, word, text)

    def find_and_add_word_occurrences(self, paperID, wordID, word, text):
        """
        Records every start position of `word` inside `text` for the given paper.

        The search resumes one character after each hit, so overlapping
        occurrences are recorded as well.

        Parameters:
        - paperID (int): The index of the paper in the internal list.
        - wordID (int): The index of the word in the internal list.
        - word (str): The word to find occurrences of.
        - text (str): The full text of the paper.
        """
        search_from = 0
        while (found_at := text.find(word, search_from)) != -1:
            self.word_occurrences_in_papers[paperID][wordID].add(found_at)
            search_from = found_at + 1

    @time_function
    def optimize_data(self):
        """
        Replaces every per-word position collection by a list of (position, wordID) tuples.

        This is the layout find_min_distance_by_id() reads. Calling it a second
        time would wrap the tuples again, so it is meant to run once.
        """
        n_words = len(self.words)
        for paper_idx in range(len(self.papers)):
            per_word = self.word_occurrences_in_papers[paper_idx]
            for word_idx in range(n_words):
                per_word[word_idx] = [(position, word_idx) for position in per_word[word_idx]]

    @time_function
    def save_to_file(self, config, path=None, name="pos_in_paper", check_size=False, min_distances=False):
        """
        Saves the internal data structures to a JSON file for persistence.

        Parameters:
        - config (Config): Used only to look up the output path when `path` is None.
        - path (str, optional): Directory for the output file(s). Defaults to
          config.get_output_path().
        - name (str): File name without extension. Defaults to "pos_in_paper".
        - check_size (bool): If True, every top-level entry is additionally written
          to its own file "<name>_<key>.json" (presumably to compare their sizes).
        - min_distances (bool): If True, the min_distances array is included.
        """
        if path is None:
            path = config.get_output_path()
        filepath = os.path.join(path, name + '.json')

        data = {
            "papers": self.papers,
            "literals": self.literals,
            "words": self.words,
            "word_len": self.word_len,
            # JSON keys must be strings; the words are sorted so the key text is
            # the same on every run instead of following set iteration order.
            "word_combinations": {"_".join(sorted(key)): value for key, value in self.word_combinations.items()},
            "word_combination_lists": self.word_combination_lists,
            "word_combination_index_literal_literal": self.word_combination_index_literal_literal,
            # Convert the per-word collections to plain lists for JSON serialization
            "word_occurrences_in_papers": [[list(occurrences) for occurrences in paper] for paper in self.word_occurrences_in_papers],
        }
        if min_distances:
            data["min_distances"] = self.min_distances.tolist()

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False)

        if check_size:
            for key, value in data.items():
                # Construct the file name for each sub-dictionary
                filepath = os.path.join(path, f"{name}_{key}.json")
                with open(filepath, 'w', encoding='utf-8') as f:
                    json.dump(value, f, ensure_ascii=False)

    @time_function
    def load_from_file(self, config, path=None, name="pos_in_paper"):
        """
        Loads the internal data structures from a JSON file written by save_to_file().

        Word occurrences are taken from the file as stored (JSON lists). If the
        file contains no "min_distances" entry, a fresh array filled with -2 is
        created via setup_data_structures().

        Parameters:
        - config (Config): Used only to look up the output path when `path` is None.
        - path (str, optional): Directory of the input file. Defaults to
          config.get_output_path().
        - name (str): File name without extension. Defaults to "pos_in_paper".
        """
        if path is None:
            path = config.get_output_path()
        filepath = os.path.join(path, name + '.json')

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.papers = data["papers"]
        self.literals = data["literals"]
        self.words = data["words"]
        self.word_len = data["word_len"]
        # Restore each combination with the index that was saved for it rather
        # than with its position in the file.
        self.word_combinations = {
            frozenset(split_string(key)): value
            for key, value in data["word_combinations"].items()
        }
        self.word_combination_lists = data["word_combination_lists"]
        self.word_combination_index_literal_literal = data["word_combination_index_literal_literal"]
        if "min_distances" in data:
            # Explicit dtype keeps the array integer-typed even when it is empty.
            self.min_distances = np.array(data["min_distances"], dtype=int)
        else:
            self.setup_data_structures(self.papers)
        # Assigned last: setup_data_structures() above resets this attribute.
        self.word_occurrences_in_papers = data["word_occurrences_in_papers"]
        
    # def set_min_distance(self, paper, literals, distance):
    def set_min_distance(self, paperID, word_combination_id, distance):
        """
        Stores the minimum distance of a word combination in a paper.

        Parameters:
        - paperID (int): The index of the paper.
        - word_combination_id (int): The index of the word combination.
        - distance (int): The value to store.
        """
        distances_for_paper = self.min_distances[paperID]
        distances_for_paper[word_combination_id] = distance

    @time_function
    def calculate_all_possible(self, start_at=0, stop_at=None):
        """
        Computes the minimum distance for every pair of literals in a range of papers.

        Parameters:
        - start_at (int): Index of the first paper to process.
        - stop_at (int, optional): Index of the first paper that is no longer
          processed; None processes all papers up to the end.
        """
        selected_papers = [
            paper_idx
            for paper_idx in range(len(self.papers))
            if paper_idx >= start_at and (stop_at is None or paper_idx < stop_at)
        ]
        n_literals = len(self.literals)
        for paper_idx in selected_papers:
            for left in range(n_literals):
                for right in range(left + 1, n_literals):
                    # Look up the word combination shared by this pair of literals.
                    combination_id = self.word_combination_index_literal_literal[left][right]
                    self.find_min_distance_by_id(paper_idx, combination_id)

    def find_min_distance_by_id(self, paperID, wcID, allow_call=True):
        """
        Returns the minimum distance spanned by the words of a combination in a paper.

        The result is cached in min_distances: -2 means "not computed yet", -1
        means that at least one word of the combination does not occur in the
        paper (or that the combination has no words). The word occurrences must
        be in the (position, wordID) layout produced by optimize_data().

        The distance of a placement is the start of the last word minus the start
        of the first word minus the length of that first word, so it can be
        negative when words overlap.

        Parameters:
        - paperID (int): The index of the paper.
        - wcID (int): The index of the word combination.
        - allow_call (bool): Not used by this implementation; kept for interface
          compatibility.

        Returns:
        - int: The minimum distance, or -1 if the combination is not in the paper.
        """
        cached = self.min_distances[paperID][wcID]
        if cached != -2:
            # Either -1 (not found) or a previously computed distance.
            return cached

        list_ids = self.word_combination_lists[wcID]
        occurrences = self.word_occurrences_in_papers[paperID]

        # An empty combination has no distance; without this guard `best` would
        # stay infinite and could not be stored in the integer array.
        if not list_ids or any(not occurrences[word_id] for word_id in list_ids):
            self.set_min_distance(paperID, wcID, -1)
            return -1

        slot_of_word = {word_id: slot for slot, word_id in enumerate(list_ids)}
        lit_len = [self.word_len[word_id] for word_id in list_ids]
        inputs = [occurrences[word_id] for word_id in list_ids]

        # Start with the first occurrence of every word, then sweep over all
        # occurrences in text order, always keeping the latest position per word.
        indices = [lst[0][0] for lst in inputs]
        best = float('inf')
        for position, word_id in sorted(entry for lst in inputs for entry in lst):
            indices[slot_of_word[word_id]] = position
            arr_min = min(indices)
            # On equal positions index() picks the earliest slot, i.e. the longest
            # word, because word_combination_lists is ordered longest first.
            best = min(max(indices) - arr_min - lit_len[indices.index(arr_min)], best)

        self.set_min_distance(paperID, wcID, best)
        return best
    

In [None]:
class NewPosInPaper(PosInPaper):
    """PosInPaper variant whose distance search is still a placeholder."""

    def __init__(self, *args, **kwargs):
        # No additional attributes yet; everything is forwarded to PosInPaper.
        super().__init__(*args, **kwargs)

    # Override of PosInPaper.find_min_distance_by_id, intended to hold a new implementation.
    @time_function
    def find_min_distance_by_id(self, paperID, wcID, allow_call=True):
        """Placeholder override; always raises NotImplementedError."""
        message = "This function has not been implemented yet."
        raise NotImplementedError(message)

In [None]:
# Restore the previously saved index from the configured output path.
pos_in_paper = PosInPaper()
pos_in_paper.load_from_file(config)

# Read one instance per line; a trailing newline in the file yields a final
# empty string because the text is split on "\n".
instances = []
with open("instances.txt", mode='r', encoding="utf-8") as f:
    instances = f.read().split("\n")

In [None]:
import copy
# Range of papers to process. The stop_at=10 assignment is immediately
# overridden by None, so all papers from start_at onwards are processed.
start_at=0
stop_at=10
stop_at=None

# instances = [
#     'knowledge based engineering',
#     'engine analysis'
# ]

# Part 1: run both implementations on freshly loaded data and compare results.
print("Now testing new vs. old")

old_version = PosInPaper()
old_version.load_from_file(config)

new_version = NewPosInPaper()
new_version.load_from_file(config)

old_version.calculate_all_possible(start_at=start_at, stop_at=stop_at)
# NOTE(review): NewPosInPaper.find_min_distance_by_id currently raises
# NotImplementedError, so this call fails as soon as it reaches a pair of literals.
new_version.calculate_all_possible(start_at=start_at, stop_at=stop_at)
# Count every (paper, word combination) cell where the two versions disagree;
# only the first four mismatches are printed. The loop bounds come from
# pos_in_paper, which an earlier cell loaded from the same file.
error_count = 0
for paperID in range(len(pos_in_paper.papers)):
    for wcID in range(len(pos_in_paper.word_combinations)):
        if old_version.min_distances[paperID][wcID] != new_version.min_distances[paperID][wcID]:
            error_count += 1
            if error_count < 5:
                print(f"Paper {paperID}, word combination {wcID}: {old_version.min_distances[paperID][wcID]} vs. {new_version.min_distances[paperID][wcID]}")
if error_count:
    raise Exception(f"{error_count} errors found!")
else:
    print("No errors found!")
    
# Part 2: run each version on its own with freshly loaded data; the
# time_function decorator prints the run time of each call.
print("Now testing only old")
old_version = PosInPaper()
old_version.load_from_file(config)
old_version.calculate_all_possible(start_at=start_at, stop_at=stop_at)

print("Now testing the new version")
new_version = NewPosInPaper()
new_version.load_from_file(config)
new_version.calculate_all_possible(start_at=start_at, stop_at=stop_at)

# Debug barrier: deliberately stops execution here.
raise Exception("This is the end of the script. The following code is not executed.")