In [None]:
#| default_exp markdown.obsidian.personal.machine_learning.notation_linking

# markdown.obsidian.personal.machine_learning.notation_linking
> Functions for gathering and processing data to train ML models, and for using such models, to link notation notes with one another.

In a `trouver`-styled `Obsidian.md` vault, notation notes summarize the definitions of various notations introduced in excerpts of mathematical text. They are also written to quickly indicate where each notation is introduced. 

Stating the definition of a notation often depends on other notations, usually defined before the notation itself. In practice, a notation note lists links to other notation notes that it depends on. Some of the links are made within the "content" of the notation note. Some others are made in a bulleted list of links.


The following is an example of a notation note:

TODO: insert example

This module contains functions to train and use models to predict whether one notation note depends on another. 

In [None]:
#| export
from os import PathLike
from pathlib import Path
import random
import re
from typing import Literal, NamedTuple, Optional, TypedDict, Union

from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
import transformers
from transformers.pipelines.text_classification import TextClassificationPipeline


from trouver.markdown.markdown.file import MarkdownFile, MarkdownLineEnum
from trouver.markdown.obsidian.links import MARKDOWNLINK_CAPTURE_PATTERN, LinkFormatError
from trouver.markdown.obsidian.personal.information_notes import reference_of_information_note
from trouver.markdown.obsidian.personal.notation.in_standard_information_note import notation_notes_linked_in_see_also_section
from trouver.markdown.obsidian.personal.notation.parse import NotationNoteParsed, parse_notation_note, _notat_str
from trouver.markdown.obsidian.personal.note_processing import process_standard_information_note
from trouver.markdown.obsidian.personal.note_type import note_is_of_type, type_of_note, PersonalNoteTypeEnum
from trouver.markdown.obsidian.personal.notes import notes_linked_in_notes_linked_in_note

from trouver.helper.latex import (
        dollar_sign_manipulation, random_char_modification, remove_math_keywords, random_word_removal,
        random_latex_command_removal, augment_text, change_font_styles_at_random,
        remove_font_styles_at_random, change_greek_letters_at_random
)
from trouver.markdown.obsidian.personal.machine_learning.notation import (
    NotationNoteData, NotationLinkingDataPoint, data_point_to_notation_note_data_pair, 
    notation_note_data_pair_to_data_point)
from trouver.markdown.obsidian.personal.machine_learning.notation_summarization import _notation_note_has_auto_summary_tag
from trouver.markdown.obsidian.vault import VaultNote

In [None]:
from fastcore.test import *
from unittest.mock import patch, Mock, MagicMock
from unittest.mock import Mock, patch, MagicMock, call
import numpy as np

## Gather ML data from notation notes

Gathering data is done for each "reference folder", which is a group of notes belonging to a single mathematical text. Each data point consists of (1) an ordered pair of notation notes, (2) some miscellaneous but precisely quantifiable information about the relationship between the notation notes, and (3) whether or not the first notation note depends on the second (`True`/`False`).

Each data point should be processed in such a way that training/predictions are ... (TODO: this sentence is unfinished.)

When gathering such data, only a random sample of the `False` data points is gathered --- this is to prevent bias against `True` data points, which are relatively rare.

In [None]:
#| export
class NotationNoteInitData(NamedTuple):
    """
    A type of cached data for datapoints involving a notation note.

    Fields other than `notation_note` may be `None`; in that case
    `_init_args_for_data_from_notation_notes` attempts to compute the missing
    value "on-the-fly" (for `linked_notat_notes`, only for the origin note of a
    pair).
    """
    notation_note: VaultNote # The notation note
    parsed: Union[NotationNoteParsed, None] # The output of `parse_notation_note` applied to `notation_note`. If `None`, then this needs to be computed "on-the-fly".
    main_note_content: Union[str, None] # The processed content of the main note of `notation_note`, i.e. the output of `process_standard_information_note(MarkdownFile.from_vault_note(main_of_notation_note))` (as a `str`). If `None`, then this needs to be computed "on-the-fly".
    linked_notat_notes: Union[set[str], None] # The names of the notation notes that `notation_note` links to. If `None`, then this needs to be computed "on-the-fly".


In [None]:
#| export
def _linked_note_names_from_content(
        content: str) -> list[str]:
    """Return the targets of the links in `content` matched by
    `MARKDOWNLINK_CAPTURE_PATTERN`, in order of appearance.

    The target is taken from the second capture group of each match, and a
    trailing `.md` is removed from it. Wikilinks such as `[[Note]]` are not
    matched by that pattern (see the tests below), so they are not returned.
    """
    return [
        groups[1].removesuffix('.md')
        for groups in re.findall(MARKDOWNLINK_CAPTURE_PATTERN, content)]
        


In [None]:
# Test basic link extraction: only Markdown-style links `[text](target)` are
# captured, so wikilinks such as `[[MathNote]]` yield nothing.
content = "[[MathNote]] and [[PhysicsNote.md]] with [[Bio|BiologyAlias]]"
result = _linked_note_names_from_content(content)
test_eq(sorted(result), [])

# Markdown-style links: the target is returned with a trailing `.md` removed.
content = "[text](MathNote.md) and [more text](PhysicsNote.md) with [BiologyAlias](Bio.md)"
result = _linked_note_names_from_content(content)
test_eq(sorted(result), ['Bio', 'MathNote', 'PhysicsNote'])

# Test empty content
result = _linked_note_names_from_content("")
test_eq(result, [])

# Test mixed formats: targets of arbitrary links and image embeds are returned
# (no check that they are notes), while the wikilink `[[ValidNote]]` is NOT.
content = "See [link](url) and ![img](src) but keep [[ValidNote]]"
result = _linked_note_names_from_content(content)
test_eq(result, ['url', 'src'])

# Only a `.md` suffix is stripped; other extensions such as `.png` are kept.
content = "See [text link](url) ![image](image.png) [ValidNote](ValidNote.md)"
result = _linked_note_names_from_content(content)
test_eq(sorted(result), ['ValidNote', 'image.png', 'url'])


In [None]:
#| export
def _linked_notat_note_names_from_content(
        content: str, vault: PathLike) -> list[str]:
    """Return the names of the notes linked in `content` that are notation
    notes in `vault`.

    Candidate names come from `_linked_note_names_from_content`; order and
    duplicates are preserved. Each candidate is looked up as a `VaultNote`
    with `update_cache=False`.
    """
    candidate_notes = (
        VaultNote(vault, name=candidate_name, update_cache=False)
        for candidate_name in _linked_note_names_from_content(content))
    return [
        candidate.name for candidate in candidate_notes
        if note_is_of_type(candidate, PersonalNoteTypeEnum.NOTATION_NOTE)]



In [None]:
# `VaultNote` and `note_is_of_type` are mocked so that no actual vault is needed.
with (patch('__main__.VaultNote') as mock_vault_note,
      patch('__main__.note_is_of_type') as mock_note_check):

    # Only the two Markdown-style links are extracted (the wikilink
    # `[[PhysicsNote]]` is not), so `note_is_of_type` is called twice.
    content = r"[$\operatorname{Gal}$](notation_Gal_galois_group.md) [[PhysicsNote]] [$$](blah)"
    mock_note_check.side_effect = [True, True]  # Both extracted links ("notation_Gal_galois_group" and "blah") are treated as notation notes
    
    # Create Mock objects with specific name attributes
    mock_note1 = Mock(name='VaultNote')
    mock_note1.name = 'notation_Gal_galois_group'
    mock_note2 = Mock(name='VaultNote')
    mock_note2.name = 'blah'
    
    # Configure side effect to return different mocks for different inputs
    mock_vault_note.side_effect = lambda vault, name, **kw: {
        'notation_Gal_galois_group': mock_note1,
        'blah': mock_note2
    }[name]

    result = _linked_notat_note_names_from_content(content, "/fake/vault")
    test_eq(sorted(result), ['blah', 'notation_Gal_galois_group'])

In [None]:
#| export
def _linked_notat_note_names_from_parsed(
        notat_note_parsed: NotationNoteParsed,
        vault: PathLike
        ) -> set[str]:
    """
    Return the names of all notation notes that a parsed notation note links to.

    Used by `_init_args_for_data_from_notation_notes` and
    `data_points_for_reference`.

    Links come from two places: the trailing list of links
    (`linked_notation_notes`, whose names have any `.md` suffix removed) and
    links made within the main content of the notation note, which are kept
    only if they point to notation notes in `vault`.
    """
    names_from_list = {
        listed_name.removesuffix('.md')
        for _, listed_name in notat_note_parsed.linked_notation_notes}
    names_from_content = _linked_notat_note_names_from_content(
        str(notat_note_parsed.main_content_markdown_file), vault)
    return names_from_list | set(names_from_content)

In [None]:
#| export
# TODO: test
def _processed_main_note_content(
        main_note_name: str, # The name of a standard information note in `vault`.
        vault: PathLike,
        ) -> str:
    """Return the processed content of the standard information note named
    `main_note_name`, as a `str` (consistent with `data_points_for_reference`).
    """
    main_note = VaultNote(vault, name=main_note_name, update_cache=False)
    return str(process_standard_information_note(
        MarkdownFile.from_vault_note(main_note), vault))


def _init_args_for_data_from_notation_notes(
        vault: PathLike,
        origin_init_data: NotationNoteInitData, # The cached data for the notation note which potentially uses the notation introduced by the `relied_notation_note`. In particular, there potentially ought to be a link to `relied_notation_note` in `origin_notation_note`.
        relied_init_data: NotationNoteInitData, # The cached data for the notation note `relied_notation_note`. 
        reference_name: Optional[str] = None, # The name of the reference folder in the vault from which the two notation notes comes from. If `None`, this is computed "on-the-fly" based on the reference of the main note of main_of_origin, see `reference_of_information_note`
        information_notes_of_reference: Optional[list[VaultNote]] = None, # The standard information notes for the reference folder in order (as arranged in the index notes of the reference folder)
    ) -> tuple[NotationNoteParsed, NotationNoteParsed, Optional[str], Optional[str], Optional[str], set[str], Optional[list[VaultNote]]]:
    """
    Fill in the data needed by `notat_linking_data_from_notation_notes` that
    is missing (`None`) from the cached notation note data.

    - Missing parsings are computed with `parse_notation_note`.
    - A missing `reference_name` and missing main note contents are computed
      only when the corresponding main note name is known; otherwise they
      stay `None`.
    - Missing linked notation note names are computed for the origin note only.
    - `information_notes_of_reference` is not computed yet (see the TODO below).

    **Returns**
    - `(origin_parsed, relied_parsed, reference_name, main_of_origin_content,
      main_of_relied_content, linked_notat_notes_in_origin,
      information_notes_of_reference)`
    """
    origin_notation_note, origin_parsed, main_of_origin_content, linked_notat_notes_in_origin = origin_init_data
    relied_notation_note, relied_parsed, main_of_relied_content, _ = relied_init_data

    if origin_parsed is None:
        origin_parsed = parse_notation_note(origin_notation_note)
    if relied_parsed is None:
        relied_parsed = parse_notation_note(relied_notation_note)

    # `name_of_main_note` is a note *name*; the helpers below need `VaultNote`s.
    main_of_origin, main_of_relied = origin_parsed.name_of_main_note, relied_parsed.name_of_main_note
    if reference_name is None and main_of_origin is not None:
        reference_name = reference_of_information_note(
            VaultNote(vault, name=main_of_origin, update_cache=False))

    if main_of_origin_content is None and main_of_origin is not None:
        main_of_origin_content = _processed_main_note_content(main_of_origin, vault)
    if main_of_relied_content is None and main_of_relied is not None:
        main_of_relied_content = _processed_main_note_content(main_of_relied, vault)

    if linked_notat_notes_in_origin is None:
        linked_notat_notes_in_origin = _linked_notat_note_names_from_parsed(origin_parsed, vault)

    # TODO Implement the initialization of `information_notes_of_reference`
    # when it is `None`; for now it is passed through unchanged.
    return (
        origin_parsed, relied_parsed, reference_name, main_of_origin_content,
        main_of_relied_content,
        linked_notat_notes_in_origin,
        information_notes_of_reference)



In [None]:
#| export
# TODO: test
def _origin_links_to_relied(
        linked_notat_notes: list[tuple], # The `linked_notation_notes` output of `parse_notation_note` for the origin note; the second entry of each tuple is a note name.
        origin_content: MarkdownFile,  # The main content of the origin note, as output by `parse_notation_note`.
        relied_notation_note: VaultNote
        ) -> bool:
    r"""Returns `True` if the origin notation note links to `relied_notation_note`.

    Both the trailing list of links (with any `.md` suffix removed) and the
    Markdown-style links within the content of the origin notation note are
    checked against `relied_notation_note.name`.

    Note that `notat_linking_data_from_notation_notes` currently computes its
    label from precomputed linked-note names instead of calling this function.
    """
    names_in_list = [
        listed_name.removesuffix('.md') for _, listed_name in linked_notat_notes]
    names_in_content = _linked_note_names_from_content(str(origin_content))
    return any(
        candidate == relied_notation_note.name
        for candidate in names_in_list + names_in_content)


def _adjust_content(
        include_content: bool, 
        content: str,
        meta: Union[dict, None],
        ) -> str:
    r"""
    Helper function to `notat_linking_data_from_notation_notes`.

    Sometimes, the content of a notation note is autogenerated by an ML model
    and hence can be unreliable. This function is used to exclude
    autogenerated content.
    """
    if not include_content or (
            meta is not None and 'tags' in meta and '_auto/notation_summary' in meta['tags']):
        return ""
    return content



In [None]:
#| export
# TODO: test

def notat_linking_data_from_notation_notes(
        origin_notation_note_data: NotationNoteInitData,
        relied_notation_note_data: NotationNoteInitData,
        include_origin_content: bool, # If `True`, include the content of `origin_notation_note`, i.e. a summary of what the notation introduced by this note means.
        include_relied_content: bool, # If `True`, include the content of `relied_notation_note`, i.e. a summary of what the notation introduced by this note means.
        reference_name: Optional[str] = None, # The name of the reference folder in the vault from which the two notation notes comes from. If `None`, this is computed "on-the-fly" based on the reference of the main note of main_of_origin, see `reference_of_information_note`
        information_notes_of_reference: Optional[list[VaultNote]] = None, # The standard information notes for the reference folder in order (as arranged in the index notes of the reference folder)
        vault: Optional[PathLike] = None, # If `None` (or otherwise falsy), the vault of `origin_notation_note_data.notation_note` is used.
    ) -> NotationLinkingDataPoint:
    """
    Obtain data for a single ordered pair of notation notes.

    Data missing from `origin_notation_note_data` or
    `relied_notation_note_data` is filled in by
    `_init_args_for_data_from_notation_notes`. The label
    `origin_links_to_relied` is `True` exactly when the name of the relied
    notation note is among the notation notes that the origin note links to.

    Assumes that

    - the two notation notes have the same `vault` attribute.
    - the `parsed` fields, if specified, are the outputs of
      `parse_notation_note` applied to the respective notation notes.
    - `reference_name`, if specified, is the correct output of
      `reference_of_information_note` applied to the main note of the origin
      notation note, and this is the same as for the main note of the relied
      notation note.
    - the `main_note_content` fields, if specified, are the (`str`) outputs of
      `process_standard_information_note` applied to the main notes of the
      respective notation notes.
    - `information_notes_of_reference` correctly lists the standard information
      notes from the reference in the vault of name `reference_name`.
    """
    vault = vault or origin_notation_note_data.notation_note.vault

    (origin_parsed, relied_parsed, reference_name,
     main_of_origin_content, main_of_relied_content,
     names_linked_from_origin, _) = _init_args_for_data_from_notation_notes(
        vault,
        origin_notation_note_data,
        relied_notation_note_data,
        reference_name,
        information_notes_of_reference)

    origin_meta, origin_summary, origin_main_name, origin_body, _ = origin_parsed
    relied_meta, relied_summary, relied_main_name, relied_body, _ = relied_parsed

    # Content autogenerated by an ML model (or excluded by the caller) is blanked.
    origin_body = _adjust_content(include_origin_content, str(origin_body), origin_meta)
    relied_body = _adjust_content(include_relied_content, str(relied_body), relied_meta)

    origin_name = origin_notation_note_data.notation_note.name
    relied_name = relied_notation_note_data.notation_note.name
    return NotationLinkingDataPoint(
        origin_notation_note_name=origin_name,
        main_of_origin_notation_note_name=origin_main_name,
        origin_notation_note_content=origin_body,
        processed_main_of_origin_content=main_of_origin_content,
        latex_in_original_or_summarized_in_origin=_notat_str(origin_meta, origin_summary),
        summarized_in_origin=origin_summary,
        reference_of_origin=reference_name,

        relied_notation_note_name=relied_name,
        main_of_relied_notation_note_name=relied_main_name,
        relied_notation_note_content=relied_body,
        processed_main_of_relied_content=main_of_relied_content,
        latex_in_original_or_summarized_in_relied=_notat_str(relied_meta, relied_summary),
        summarized_in_relied=relied_summary,
        reference_of_relied=reference_name,

        origin_links_to_relied=relied_name in names_linked_from_origin,
        )



In [None]:
#| export
def _positive_pairs_of_notation_notes(
        confirmed_summary_notat_note_names: list[str],
        notat_notes_and_linked_notat_notes: dict[str, set[str]],
        # notat_notes_and_parsed: dict[str, NotationNoteParsed],
        vault: PathLike
        ) -> list[tuple[str, str]]:
    r"""Return the pairs `(<origin_notat_note_name>, <linked_notat_note_name>)`
    where `origin_notat_note_name` is the name of a notation note whose notation
    summary is "confirmed" (i.e. written and not autogenerated) and where
    `linked_notat_note_name` is the name of a notation note linked by the
    notation note with name `origin_notat_note_name`.
    """
    positive_linked_notat_note_pairs = []
    
    for notat_note_name in confirmed_summary_notat_note_names:
        if not notat_note_name in notat_notes_and_linked_notat_notes:
            continue
        for linked_notat_note_name in notat_notes_and_linked_notat_notes[notat_note_name]:
            positive_linked_notat_note_pairs.append((notat_note_name, linked_notat_note_name))

    return positive_linked_notat_note_pairs
    


In [None]:
#| export
def _positive_data_points(
        reference_index_note: VaultNote,
        positive_linked_notat_note_pairs: list[tuple[str, str]],
        notat_notes_and_parsed: dict[str, NotationNoteParsed],
        notat_notes_and_linked_notat_notes: dict[str, set[str]],
        info_notes_and_processed_content: dict[str, str]
        ) -> list[NotationLinkingDataPoint]:
    r"""Gather data points for the pairs `(<origin_notat_note_name>, <linked_notat_note_name>)`
    where `origin_notat_note_name` is the name of a notation note whose notation
    summary is "confirmed" (i.e. written and not autogenerated) and where
    `linked_notat_note_name` is the name of a notation note linked by the
    notation note with name `origin_notat_note_name`.

    Four data points (one per combination of including/excluding the content
    of each note) are produced per usable pair. A pair is skipped when either
    note has no parsed data (e.g. the linked note lives outside this
    reference folder) or when either note's main note has no processed content.
    Any other error for a pair is printed and that pair is skipped entirely.
    """
    vault = reference_index_note.vault
    # The index note's name starts with '_index_' (7 characters).
    reference_name = reference_index_note.name[7:]
    content_inclusion_combos = ((True, True), (True, False), (False, True), (False, False))
    data_points: list[NotationLinkingDataPoint] = []

    for origin_notat_note_name, linked_notat_note_name in positive_linked_notat_note_pairs:
        # Expected case rather than an error: no parsed data means no data point.
        if (origin_notat_note_name not in notat_notes_and_parsed
                or linked_notat_note_name not in notat_notes_and_parsed):
            continue
        try:
            origin_notat_note = VaultNote(vault, name=origin_notat_note_name, update_cache=False)
            linked_notat_note = VaultNote(vault, name=linked_notat_note_name, update_cache=False)
            origin_parsed = notat_notes_and_parsed[origin_notat_note_name]
            linked_parsed = notat_notes_and_parsed[linked_notat_note_name]
            main_of_origin_name = origin_parsed.name_of_main_note
            main_of_linked_name = linked_parsed.name_of_main_note
            if (main_of_origin_name is None or main_of_linked_name is None
                    or main_of_origin_name not in info_notes_and_processed_content
                    or main_of_linked_name not in info_notes_and_processed_content):
                continue

            origin_init_data = NotationNoteInitData(
                origin_notat_note,
                origin_parsed,
                info_notes_and_processed_content[main_of_origin_name],
                notat_notes_and_linked_notat_notes.get(origin_notat_note_name, set()),
                )
            linked_init_data = NotationNoteInitData(
                linked_notat_note,
                linked_parsed,
                info_notes_and_processed_content[main_of_linked_name],
                notat_notes_and_linked_notat_notes.get(linked_notat_note_name, set()),
                )
            # Build all variants first so that an error never leaves a partial set.
            pair_data_points = [
                notat_linking_data_from_notation_notes(
                    origin_init_data,
                    linked_init_data,
                    include_origin_content,
                    include_linked_content,
                    reference_name,
                    None # TODO: Pass proper argument for information_notes_of_reference
                    )
                for include_origin_content, include_linked_content in content_inclusion_combos]
        except Exception as e:
            # Deliberately best-effort: report and move on to the next pair.
            print('An error has occurred while gathering data:')
            print(e)
            continue
        data_points.extend(pair_data_points)
    return data_points



In [None]:
#| export
def _sample_data_points(
        reference_index_note: VaultNote,
        num_samples: int, 
        notat_notes: list[VaultNote],
        confirmed_summary_notat_note_names: list[str],
        notat_notes_and_parsed: dict[str, NotationNoteParsed],
        notat_notes_and_linked_notat_notes: dict[str, set[str]],
        info_notes_and_processed_content: dict[str, str]
        ) -> list[NotationLinkingDataPoint]:
    r"""Randomly sample pairs `(<origin_notat_note_name>, <relied_notat_note_name>)`
    and return their data points.

    Origins are drawn (with replacement) from `confirmed_summary_notat_note_names`
    and relied notes from `notat_notes`; only notes with an entry in
    `notat_notes_and_parsed` are eligible, since data points cannot be built
    without parsed data. Each usable sample yields four data points (one per
    combination of including/excluding each note's content). Samples whose main
    notes lack processed content are skipped. Returns `[]` if `num_samples` is
    not positive or if there is nothing eligible to sample from.
    """
    vault = reference_index_note.vault
    # The index note's name starts with '_index_' (7 characters).
    reference_name = reference_index_note.name[7:]
    eligible_origin_names = [
        name for name in confirmed_summary_notat_note_names
        if name in notat_notes_and_parsed]
    # Notes may have been filtered out of `notat_notes_and_parsed`; looking
    # them up would raise `KeyError`.
    eligible_relieds = [
        note for note in notat_notes if note.name in notat_notes_and_parsed]
    if num_samples <= 0 or not eligible_origin_names or not eligible_relieds:
        return []

    origins = [
        VaultNote(vault, name=name, update_cache=False)
        for name in random.choices(eligible_origin_names, k=num_samples)]
    relieds = random.choices(eligible_relieds, k=num_samples)

    content_inclusion_combos = ((True, True), (True, False), (False, True), (False, False))
    data_points: list[NotationLinkingDataPoint] = []
    for origin, relied in zip(origins, relieds):
        origin_parsed = notat_notes_and_parsed[origin.name]
        relied_parsed = notat_notes_and_parsed[relied.name]
        main_of_origin_name = origin_parsed.name_of_main_note
        main_of_relied_name = relied_parsed.name_of_main_note
        if (main_of_origin_name is None or main_of_relied_name is None
                or main_of_origin_name not in info_notes_and_processed_content
                or main_of_relied_name not in info_notes_and_processed_content):
            continue
        origin_init_data = NotationNoteInitData(
            origin,
            origin_parsed,
            info_notes_and_processed_content[main_of_origin_name],
            notat_notes_and_linked_notat_notes.get(origin.name, set()),
            )
        relied_init_data = NotationNoteInitData(
            relied,
            relied_parsed,
            info_notes_and_processed_content[main_of_relied_name],
            notat_notes_and_linked_notat_notes.get(relied.name, set()),
            )
        for include_origin_content, include_relied_content in content_inclusion_combos:
            data_points.append(notat_linking_data_from_notation_notes(
                origin_init_data,
                relied_init_data,
                include_origin_content,
                include_relied_content, 
                reference_name, 
                None  # TODO: Pass proper argument for information_notes_of_reference
                ))
    return data_points

In [None]:
#| export
def data_points_for_reference(
        reference_index_note: VaultNote, # The index note for the reference from which to draw the data.
        return_notation_note_parsings: bool = False, # If `True`, return the outputs of `parse_notation_note` applied to the notation notes in the reference folder 
        ) -> Union[list[NotationLinkingDataPoint], tuple[list[NotationLinkingDataPoint], dict[str, NotationNoteParsed]]]:
    r"""Compile data points for notation note linking based on the information
    notes and notation notes in a reference folder.

    "Positive" linking data points are relatively rare in comparison to "Negative"
    data points, so "Negative" data points are randomly sampled (although the random
    samples will redundantly include "Positive" data as well).

    Note that it makes sense to draw data exclusively within each "reference" ---
    notations tend to have dependencies within the same mathematical text.

    Notation notes tagged `_auto/notation_notes_linked` (i.e. whose links were
    autogenerated) are excluded by
    `_filter_notat_notes_with_auto_generated_notat_links` and do not appear in
    any data point.

    **Returns**
    - `Union[list[NotationLinkingDataPoint], tuple[list[NotationLinkingDataPoint], dict[str, NotationNoteParsed]]]`
        - Either 1. a list of data points, each an output of
          `notat_linking_data_from_notation_notes`,
          or 2. that list along with a dict whose keys are the names of the
          (non-excluded) notation notes and whose values are the outputs of
          `parse_notation_note` applied to these notation notes.
    """
    # TODO: initialize the data of the notation notes that each notation note links to.
    vault = reference_index_note.vault
    info_notes = notes_linked_in_notes_linked_in_note(reference_index_note, as_dict=False)
    info_notes = [note for note in info_notes if note.exists() and note_is_of_type(note, PersonalNoteTypeEnum.STANDARD_INFORMATION_NOTE)]
    notat_notes: list[VaultNote] = []
    for info_note in info_notes:
        notat_notes.extend(notation_notes_linked_in_see_also_section(info_note, info_note.vault))

    info_notes_and_processed_content: dict[str, str] = {}
    for note in info_notes:
        try:
            info_notes_and_processed_content[note.name] = str(process_standard_information_note(MarkdownFile.from_vault_note(note), vault))
        except TypeError as e:
            print(f"An error occurred while trying to process the following note: {note.name}")
            print(e)
            info_notes_and_processed_content[note.name] = note.text()
        except LinkFormatError as e:
            print(f"A Link formatting error occurred while trying to process the following note: {note.name}")
            print(e)
            info_notes_and_processed_content[note.name] = note.text()

    notat_notes_and_parsed: dict[str, NotationNoteParsed] = {
        notat_note.name: parse_notation_note(notat_note, process_notation_note_content=True) for notat_note in notat_notes}
    notat_notes_and_parsed = _filter_notat_notes_with_auto_generated_notat_links(
        notat_notes_and_parsed)

    # Iterate over the *filtered* dict: indexing it with the name of a note
    # that was filtered out would raise `KeyError`.
    notat_notes_and_linked_notat_notes: dict[str, set[str]] = {
        name: _linked_notat_note_names_from_parsed(parsed, vault)
        for name, parsed in notat_notes_and_parsed.items()}
    # Only notes with parsed data can be used as sampled relied notes.
    sampleable_notat_notes = [
        notat_note for notat_note in notat_notes
        if notat_note.name in notat_notes_and_parsed]

    # Notation notes whose summaries are written and not auto-generated
    # (`parsed[0]` is the YAML frontmatter metadata of the notation note).
    confirmed_summary_notat_note_names = [
        notat_note for notat_note, parsed in notat_notes_and_parsed.items()
        if not _notation_note_has_auto_summary_tag(parsed[0])]

    # Get all positive pairs of notation notes
    positive_linked_notat_note_pairs = _positive_pairs_of_notation_notes(
        confirmed_summary_notat_note_names,
        notat_notes_and_linked_notat_notes,
        vault)
    data_points: list[NotationLinkingDataPoint] = _positive_data_points(
        reference_index_note, positive_linked_notat_note_pairs,
        notat_notes_and_parsed, notat_notes_and_linked_notat_notes,
        info_notes_and_processed_content)
    data_points.extend(_sample_data_points(
        reference_index_note, len(data_points)*4,
        sampleable_notat_notes, confirmed_summary_notat_note_names,
        notat_notes_and_parsed, notat_notes_and_linked_notat_notes,
        info_notes_and_processed_content))
    if return_notation_note_parsings:
        return data_points, notat_notes_and_parsed
    return data_points


def _filter_notat_notes_with_auto_generated_notat_links(
        notat_notes_and_parsed: dict[str, NotationNoteParsed]
        ) -> dict[str, NotationNoteParsed]:
    """Return a new dict without the notation notes whose YAML frontmatter
    carries the `_auto/notation_notes_linked` tag (i.e. whose links to other
    notation notes were autogenerated). Insertion order is preserved.
    """
    def has_auto_generated_links(parsed) -> bool:
        frontmatter = parsed.yaml_frontmatter_meta
        return (bool(frontmatter)
                and 'tags' in frontmatter
                and '_auto/notation_notes_linked' in frontmatter['tags'])

    kept = {}
    for note_name, parsed in notat_notes_and_parsed.items():
        if not has_auto_generated_links(parsed):
            kept[note_name] = parsed
    return kept




In [None]:
#| export
def text_from_note_data(
        note_data: NotationNoteData,
        separation_token: str = '[SEP]',
        ) -> str:
    """Format the data of a single notation note as text.

    Each of the fields `latex_in_original_or_summarized`, `summarized`,
    `main_note_content`, and `processed_content` is rendered as
    `"<field>: <value>"`, and consecutive fields are separated by
    `separation_token` surrounded by blank lines.
    """
    divider = f"\n\n{separation_token}\n\n"
    field_names = (
        'latex_in_original_or_summarized',
        'summarized',
        'main_note_content',
        'processed_content',
    )
    return divider.join(f"{field}: {note_data[field]}" for field in field_names)

def text_from_data_point(
        data_point: NotationLinkingDataPoint,  # An output of `notat_linking_data_from_notation_notes`.
        separation_token: str = '[SEP]',
        ) -> str:
    r"""
    Format a data point to present it as a str.

    The origin and relied notation note data are each formatted with
    `text_from_note_data` and placed under `origin_data:` and `relied_data:`
    headings; all sections are separated by `separation_token` surrounded by
    blank lines.
    """
    origin_data, relied_data = data_point_to_notation_note_data_pair(data_point)
    divider = f"\n\n{separation_token}\n\n"
    sections = (
        "origin_data:",
        text_from_note_data(origin_data, separation_token),
        "relied_data:",
        text_from_note_data(relied_data, separation_token),
    )
    return divider.join(sections)

def _content_relied(
        main_of_origin_content: str,
        main_of_relied_content: str,
        ) -> str:
    if main_of_origin_content == main_of_relied_content:
        return f"Content for main note of relied_notation_note: same as that of main note of origin_notation_note" 
    else:
        return f"Content for main note of relied_notation_note: {main_of_relied_content}"

## Data augmentation

In [None]:
#| export
def augment_notation_linking_data(
        datapoint: NotationLinkingDataPoint,
        num_augmentation_sets: int = 1, # Each augmentation set consists of an augmentation with low, medium, and high probability modifications.
        seed: Optional[int] = None
        ) -> list[NotationLinkingDataPoint]:
    r"""
    Augment a given datapoint for Notation Linking.

    Returns `3 * num_augmentation_sets` augmented copies of `datapoint`, cycling
    through the `'low'`, `'mid'`, and `'high'` modification levels. If `seed`
    is given, the global `random` module is seeded with it first, which also
    affects any other code that uses `random`.
    """
    if seed is not None:
        random.seed(seed)
    levels = ('low', 'mid', 'high')
    return [
        _augment_notation_linking_data_once(datapoint, level)
        for _ in range(num_augmentation_sets)
        for level in levels]


def _augment_notation_linking_data_once(
        datapoint: NotationLinkingDataPoint,
        modification: Literal['low', 'mid', 'high'],
        ) -> NotationLinkingDataPoint:
    r"""
    Return a copy of `datapoint` whose text fields are randomly perturbed.

    Each augmentation method in the pool is included independently with a
    chance determined by `modification`, and its base probability `p` is
    multiplied by a scale also determined by `modification`. Any value other
    than `'low'` or `'mid'` is treated as `'high'`.
    """
    if modification == 'low':
        inclusion_chance, scale = 0.3, 0.5
    elif modification == 'mid':
        inclusion_chance, scale = 0.5, 1.0
    else:
        inclusion_chance, scale = 0.8, 1.5

    method_pool = (
        (remove_font_styles_at_random, 0.1),
        (change_font_styles_at_random, 0.2),
        (change_greek_letters_at_random, 0.1),
        (remove_math_keywords, 0.1),
        (random_latex_command_removal, 0.2),
        (random_word_removal, 0.1),
        (dollar_sign_manipulation, 0.05),
        (random_char_modification, 0.001),
    )

    def bind_probability(method, base_p):
        # A factory avoids the late-binding pitfall of lambdas built in a loop.
        return lambda text: method(text, p=base_p * scale)

    # One `random.random()` draw per method, in pool order.
    chosen_methods = [
        bind_probability(method, base_p)
        for method, base_p in method_pool
        if random.random() < inclusion_chance]

    text_fields = (
        'origin_notation_note_content',
        'processed_main_of_origin_content',
        'latex_in_original_or_summarized_in_origin',
        'summarized_in_origin',
        'processed_main_of_relied_content',
        'relied_notation_note_content',
        'latex_in_original_or_summarized_in_relied',
        'summarized_in_relied',
    )
    augmented = datapoint.copy()
    for field in text_fields:
        augmented[field] = augment_text(datapoint[field], chosen_methods)
    return augmented
    

## Use the trained model 

In [None]:
#| export
def prediction_by_model(
        origin_data: NotationNoteData,
        relied_data: NotationNoteData,
        pipeline: Union[TextClassificationPipeline, SentenceTransformer],
        as_single_float: bool = True, # If `True`, return a float score of how likely it is that origin should link to relied.
        threshold: float = 0.5, # The threshold for determining whether origin should link to relied; a score at or above this threshold corresponds to origin linking to relied. For a `TextClassificationPipeline` the score lies between 0.0 and 1.0; for a `SentenceTransformer` it is a cosine similarity between -1.0 and 1.0, so this argument should ideally be specified in that case.
        ) -> Union[float, dict[Literal['label', 'score'], Union[bool, float]]]: # Either the score of how likely it is that origin should link to relied (higher means more likely), or, if `as_single_float` is `False`, a dict whose `'label'` is whether the score is at least `threshold` and whose `'score'` is the score itself.
    r"""
    Predict whether a notation note depends on the notation
    summarized by another notation note.

    If `pipeline` is a `TextClassificationPipeline`, the two notes are
    combined into a single data point and classified. The score is then the
    probability the pipeline assigns to `LABEL_1` (origin should link to
    relied), which lies between 0.0 and 1.0.

    Otherwise, `pipeline` is used as a `SentenceTransformer`. Each note is
    embedded separately (origin first, then relied), and the score is the
    cosine similarity of the two embeddings, which lies between -1.0 and 1.0.

    See also `prediction_by_model_via_datapoint` for an alternative
    function for predictions.
    """
    if isinstance(pipeline, TextClassificationPipeline):
        data_point: NotationLinkingDataPoint = notation_note_data_pair_to_data_point(
            origin_data, relied_data)
        model_input = text_from_data_point(data_point)
        # Of the form {'label': 'LABEL_0' or 'LABEL_1', 'score': 0.5000}, where
        # `score` is the pipeline's confidence in the returned `label`.
        pred = pipeline(model_input)[0]
        # LABEL_1 is for `True`, i.e. when the origin note ought to link to the relied note
        # and LABEL_0 is for `False`, i.e. when the origin note should not link to the relied note.
        # Convert to the probability of LABEL_1 either way.
        if pred['label'] == 'LABEL_1':
            score = pred['score']
        else:
            score = 1.0 - pred['score']
    else:
        origin_text = text_from_note_data(origin_data)
        relied_text = text_from_note_data(relied_data)
        origin_embedding = pipeline.encode(origin_text)
        relied_embedding = pipeline.encode(relied_text)
        # scipy's `cosine` is a distance (1 - cosine similarity).
        score = 1 - cosine(origin_embedding, relied_embedding)

    if as_single_float:
        return score
    return {'label': score >= threshold, 'score': score}


In [None]:
# Tests for `prediction_by_model`, using stand-ins for both pipeline types.
# NOTE(review): `patch`, `np`, and `test_eq` are not imported in this cell;
# presumably they come from an earlier test-import cell — confirm.

# Mock data structures with all required keys
mock_origin_data = {
    'notation_note_name': 'OriginNotationNote',
    'main_info_note': 'OriginMainInfoNote',
    'processed_content': 'OriginProcessedContent',
    'main_note_content': 'OriginMainNoteContent',
    'latex_in_original_or_summarized': 'OriginLatex',
    'summarized': 'OriginSummarized',
    'reference': 'OriginReference'
}

mock_relied_data = {
    'notation_note_name': 'ReliedNotationNote',
    'main_info_note': 'ReliedMainInfoNote',
    'processed_content': 'ReliedProcessedContent',
    'main_note_content': 'ReliedMainNoteContent',
    'latex_in_original_or_summarized': 'ReliedLatex',
    'summarized': 'ReliedSummarized',
    'reference': 'ReliedReference'
}

# Mock TextClassificationPipeline class
# Calling an instance ignores `input` and returns `return_value`, which each
# test case sets to the list of predictions the pipeline should produce.
class MockTextClassificationPipeline:
    def __init__(self):
        self.return_value = []

    def __call__(self, input):
        return self.return_value

# Mock SentenceTransformer class
# `encode` ignores `input` and hands out the queued embeddings first-in,
# first-out. `prediction_by_model` encodes the origin text before the relied
# text, so the first queued embedding is used for the origin note.
class MockSentenceTransformer:
    def __init__(self):
        self.encode_return_values = []

    def encode(self, input):
        if not self.encode_return_values:
            raise IndexError("encode_return_values is empty. Ensure it is properly initialized.")
        return self.encode_return_values.pop(0)

# Patch the external dependencies
# Replacing `TextClassificationPipeline` makes the `isinstance` check in
# `prediction_by_model` accept `MockTextClassificationPipeline`; any other
# pipeline object (here `MockSentenceTransformer`) takes the embedding branch.
with patch('__main__.TextClassificationPipeline', MockTextClassificationPipeline), \
     patch('__main__.SentenceTransformer', MockSentenceTransformer):

    # Test case for TextClassificationPipeline
    mock_pipeline = MockTextClassificationPipeline()
    mock_pipeline.return_value = [{'label': 'LABEL_1', 'score': 0.75}]

    # Test with as_single_float=True
    # LABEL_1 with score 0.75 -> positive score 0.75.
    result = prediction_by_model(mock_origin_data, mock_relied_data, mock_pipeline, as_single_float=True)
    test_eq(result, 0.75)

    # Test with as_single_float=False
    # 0.75 >= 0.6, so the label is True.
    result = prediction_by_model(mock_origin_data, mock_relied_data, mock_pipeline, as_single_float=False, threshold=0.6)
    test_eq(result, {'label': True, 'score': 0.75})

    # Additional test cases for LABEL_0 in TextClassificationPipeline
    mock_pipeline.return_value = [{'label': 'LABEL_0', 'score': 0.6}]

    result = prediction_by_model(mock_origin_data, mock_relied_data, mock_pipeline, as_single_float=True)
    test_eq(result, 0.4)  # 1.0 - 0.6

    # 0.4 < 0.5, so the label is False.
    result = prediction_by_model(mock_origin_data, mock_relied_data, mock_pipeline, as_single_float=False, threshold=0.5)
    test_eq(result, {'label': False, 'score': 0.4})

    # Test case for SentenceTransformer
    mock_transformer = MockSentenceTransformer()

    # Initialize encode_return_values for the first SentenceTransformer test case
    # Each call to `prediction_by_model` consumes two embeddings (origin, then relied).
    mock_transformer.encode_return_values = [
        np.array([0.1, 0.2, 0.3]),  # origin embedding
        np.array([0.2, 0.3, 0.4])   # relied embedding
    ]

    # Calculate expected similarity
    origin_embedding = np.array([0.1, 0.2, 0.3])
    relied_embedding = np.array([0.2, 0.3, 0.4])
    expected_similarity = 1 - cosine(origin_embedding, relied_embedding)

    # Test with as_single_float=True
    result = prediction_by_model(mock_origin_data, mock_relied_data, mock_transformer, as_single_float=True)
    test_eq(result, expected_similarity)

    # Reinitialize encode_return_values for the next SentenceTransformer test case
    # (the previous call emptied the queue).
    mock_transformer.encode_return_values = [
        np.array([0.1, 0.2, 0.3]),  # origin embedding (reused)
        np.array([0.2, 0.3, 0.4])   # relied embedding (reused)
    ]

    # Test with as_single_float=False
    result = prediction_by_model(mock_origin_data, mock_relied_data, mock_transformer, as_single_float=False, threshold=0.8)
    test_eq(result, {'label': expected_similarity >= 0.8, 'score': expected_similarity})


In [None]:
#| export
# TODO: test
def rank_notat_notes_to_potentially_link_to(
        origin_data: NotationNoteData,
        relied_data_list: list[NotationNoteData],
        pipeline: Union[TextClassificationPipeline, SentenceTransformer],
        threshold: float = 0.5
        ) -> list[tuple[NotationNoteData, float]]:
    """
    Score every candidate in `relied_data_list` as a link target for the
    notation note represented by `origin_data`.

    Each candidate is scored with `prediction_by_model`. Candidates scoring
    below `threshold` are dropped. The rest are returned as
    `(relied_data, score)` pairs, highest score first. Ties keep their
    original order.
    """
    scored_pairs = (
        (relied_data,
         prediction_by_model(origin_data, relied_data, pipeline, as_single_float=True))
        for relied_data in relied_data_list)
    kept_pairs = [(data, score) for data, score in scored_pairs if score >= threshold]
    # `list.sort` is stable, so equal scores keep their input order.
    kept_pairs.sort(key=lambda pair: pair[1], reverse=True)
    return kept_pairs

In [None]:
#| export
def _origin_notation_note_already_has_link_to_relied(
        relied_notation_note: VaultNote,
        origin_parsed: Union[tuple, None],
        ) -> bool:
    r"""
    Return `True` if `relied_notation_note is determined to be linked by
    `origin_notation_note` based on the contents of `origin_parsed`, which
    is the output of `parse_notation_note` appleid to `origin_notation_note`.

    This is a helper function to 
    `automatically_add_bulleted_link_to_other_notation_note`.
    """
    for _, linked_note_name in origin_parsed[4]:
        if linked_note_name == relied_notation_note.name:
            return True
    return False


def _add_notation_link(
        origin_notation_note: VaultNote,
        relied_notation_note: VaultNote,
        relied_parsed: Optional[tuple] = None, # The output of `parse_notation_note` applied to `relied_notation_note`; if falsy (e.g. `None`), it is computed here.
        ) -> None:
    r"""
    Add a link in `origin_notation_note` to `relied_notation_note`
    and add the tag `_auto/notation_notes_linked` to `origin_notation_note`.

    The link is appended to the end of `origin_notation_note` as an unordered
    list item of the form `- [<text>](<relied note name>.md)`, where `<text>`
    is `relied_parsed[1]`. The modified file is then written back to
    `origin_notation_note`.
    """
    # NOTE(review): presumably `relied_parsed[1]` is the notation string of the
    # relied note — confirm against `parse_notation_note`.
    if not relied_parsed:
        relied_parsed = parse_notation_note(
            relied_notation_note, relied_notation_note.vault)
    mf = MarkdownFile.from_vault_note(origin_notation_note)
    mf.add_tags('_auto/notation_notes_linked',
                enquote_entries_in_metadata_fields=['latex_in_original'])
    bullet = f'- [{relied_parsed[1]}]({relied_notation_note.name}.md)'
    mf.add_line_to_end(
        {'line': bullet,
         'type': MarkdownLineEnum.UNORDERED_LIST})
    mf.write(origin_notation_note)

In [None]:
#| export
def _add_link_to_notation_note_in_mf(
        mf: MarkdownFile,
        relied_data: NotationNoteData,
        vault: Path,
        only_add_links_to_notation_notes_with_confirmed_summaries: bool,
        origin_name: str,
        ) -> bool: # `True` if the link to the relied note is added; `False` otherwise.
    """
    Helper function to `add_links_to_notation_note_via_data_point`.

    Append a bullet of the form `- [<summarized>](<notation_note_name>.md)`
    linking to the relied notation note to the end of `mf`. `mf` is modified
    in memory only; it is not written.

    If `only_add_links_to_notation_notes_with_confirmed_summaries` is `True`,
    the relied note is first read from `vault`. If its `tags` metadata contains
    `_auto/notation_summary`, or its main content is blank, a message is printed
    and nothing is added.
    """
    if only_add_links_to_notation_notes_with_confirmed_summaries: 
        relied_note = VaultNote(vault, name=relied_data['notation_note_name'])
        relied_mf = MarkdownFile.from_vault_note(relied_note)
        # NOTE(review): assumes `metadata()` may return `None` when the note has
        # no frontmatter; `or {}` treats that as "no metadata".
        meta = relied_mf.metadata() or {}
        tags = meta['tags'] if 'tags' in meta else None
        if tags is None:
            tags = []
        elif isinstance(tags, str):
            # A single tag may be written as a bare string in YAML.
            tags = [tags]
        # Membership of the tag *string* in the tag list (a list-in-list check
        # would never match).
        summary_is_auto_generated = '_auto/notation_summary' in tags
        parsed = parse_notation_note(relied_note)
        summary_is_empty = not bool(str(parsed.main_content_markdown_file).strip())
        if summary_is_auto_generated or summary_is_empty:
            print(f'The notation linking pipeline predicts that the notation note named {origin_name} should link to {relied_data["notation_note_name"]}, which either has an auto-generated summary or an empty summary. The link will not be added.')
            return False

    bullet = f'- [{relied_data["summarized"]}]({relied_data["notation_note_name"]}.md)'
    mf.add_line_to_end(
        {'line': bullet,
        'type': MarkdownLineEnum.UNORDERED_LIST})
    return True

In [None]:
#| export
def add_links_to_notation_note_via_data_point(
        origin_data: NotationNoteData,
        relied_data_list: list[NotationNoteData],
        pipeline: Union[TextClassificationPipeline, SentenceTransformer],
        vault: Path, # The vault in which the notes are.
        threshold: float = 0.5, # The threshold for `pipeline`'s prediction of how likely it is that a relied notation note should be linked in order for the linking to actually happen.
        only_add_links_to_notation_notes_with_confirmed_summaries: bool = True # If `True`, and if `pipeline` determines that the origin note should link to a notation note with an auto-generated summary or no summary, then print a message about this, but do not add the link.
    ) -> list[str]: # The names of the notation notes added; the names of notation notes that are not added for one reason or another are not included in this list.
    """
    Add links for the notation notes in `relied_data_list` into the notation note represented
    by `origin_data` if `pipeline` predicts that those notation notes should be linked.

    The origin note is never linked to itself. The origin note is tagged
    `_auto/notation_notes_linked` and written back only if at least one
    link was actually added.
    """
    relied_data_to_be_linked: list[tuple[NotationNoteData, float]] = rank_notat_notes_to_potentially_link_to(
        origin_data, relied_data_list, pipeline, threshold)
    if not relied_data_to_be_linked:
        return []
    origin_name = origin_data['notation_note_name']
    origin_notat_note = VaultNote(vault, name=origin_name)
    mf = MarkdownFile.from_vault_note(origin_notat_note)
    names_of_notation_notes_added: list[str] = []
    for relied_data, _ in relied_data_to_be_linked:
        # Never link a notation note to itself.
        if relied_data['notation_note_name'] == origin_name:
            continue
        added = _add_link_to_notation_note_in_mf(
            mf, relied_data, vault,
            only_add_links_to_notation_notes_with_confirmed_summaries,
            origin_name)
        if added:
            names_of_notation_notes_added.append(relied_data['notation_note_name'])
    if names_of_notation_notes_added:
        # Tag and rewrite only when something was linked. This matches the early
        # return above, so a note whose candidates were all skipped is not
        # falsely tagged as auto-linked.
        mf.add_tags('_auto/notation_notes_linked',
                    enquote_entries_in_metadata_fields=['latex_in_original'])
        mf.write(origin_notat_note)
    return names_of_notation_notes_added

In [None]:

# Configure global mocks
mock_vault = Path('/test/vault')

# Create VaultNote mock factory
def mock_vault_note(*args, **kwargs):
    """
    Return a `MagicMock` standing in for a `VaultNote`.

    The vault and note name may be given positionally (`VaultNote(vault, name)`),
    as keywords (`VaultNote(vault=..., name=...)`), or mixed
    (`VaultNote(vault, name=...)`, which is how the functions under test call
    it). If no vault is given, `mock_vault` is used.
    """
    vault = args[0] if len(args) >= 1 else kwargs.get('vault', mock_vault)
    name = args[1] if len(args) >= 2 else kwargs.get('name')

    m = MagicMock()
    m.name = name
    m.vault = vault
    m.path.return_value = vault / f"{name}.md"
    m.rel_path = f"{name}.md"
    m.md = MagicMock()
    return m

# Mock VaultNote class to use factory
# These patches are started here (not inside a `with`) so that they stay active
# across the following test cells; they are stopped in the clean-up cell at the
# end of the notebook. NOTE: if a test in between raises, the patches remain
# active in the kernel until that clean-up cell runs.
vault_note_patcher = patch('__main__.VaultNote', side_effect=mock_vault_note)
mock_vault_cls = vault_note_patcher.start()

# Create a custom mock for MarkdownFile
# Every `MarkdownFile.from_vault_note(...)` call returns this same object, so
# configuration applied to it by one test is visible to later tests.
mock_markdown_file = MagicMock()
mock_markdown_file.from_vault_note.return_value = MagicMock(
    add_tags=Mock(),
    add_line_to_end=Mock(),
    write=Mock()
)

# Patch MarkdownFile with our custom mock
markdown_file_patcher = patch('__main__.MarkdownFile', mock_markdown_file)
markdown_file_patcher.start()

# Origin note shared by the tests where something is ranked.
origin_data = {'notation_note_name': 'origin'}

# Test 1: No notation notes to link -> nothing is added.
with patch('__main__.rank_notat_notes_to_potentially_link_to', return_value=[]):
    result = add_links_to_notation_note_via_data_point(
        {'notation_note_name': 'origin'}, [], Mock(), mock_vault)
test_eq(result, [])

# Test 2: Skip self-referential link -> the only candidate is the origin itself.
with patch('__main__.rank_notat_notes_to_potentially_link_to',
           return_value=[(origin_data, 0.8)]):
    result = add_links_to_notation_note_via_data_point(
        origin_data, [], Mock(), mock_vault)
test_eq(result, [])

# Test 3: Valid link addition -> the helper reports success, so the name is returned.
with patch('__main__.rank_notat_notes_to_potentially_link_to',
           return_value=[({'notation_note_name': 'test'}, 0.7)]):
    with patch('__main__._add_link_to_notation_note_in_mf', return_value=True):
        result = add_links_to_notation_note_via_data_point(
            origin_data, [], Mock(), mock_vault)
test_eq(result, ['test'])

In [None]:
#| hide
# Test 4: Block auto-generated summary
# `_add_link_to_notation_note_in_mf` reads the relied note's metadata from
# `MarkdownFile.from_vault_note(...)` (patched with `mock_markdown_file`), not
# from the `mf` it is given, so the metadata is configured there.
# `patch.object` restores it afterwards so it does not leak into later tests.
with patch('__main__.parse_notation_note') as mock_parse, \
     patch('builtins.print') as mock_print, \
     patch.object(mock_markdown_file.from_vault_note.return_value, 'metadata',
                  return_value={'tags': ['_auto/notation_summary']}):
    mock_mf = MagicMock()
    # NOTE: the main content is also empty, so the empty-summary check blocks the
    # link too. A case with non-empty content would isolate the tag check.
    mock_parse.return_value.main_content_markdown_file = ""
    
    result = _add_link_to_notation_note_in_mf(
        mock_mf, 
        {'notation_note_name': 'test', 'summarized': 'Test Summary'},
        mock_vault,
        True,
        'origin'
    )
    test_eq(result, False)
    mock_print.assert_called_once()
    mock_mf.add_line_to_end.assert_not_called()

# Test 5: Valid link added
# The relied note's metadata (read through the patched `MarkdownFile`) has no
# tags and its main content is non-empty, so exactly one bullet is appended to `mf`.
with patch('__main__.parse_notation_note') as mock_parse, \
     patch.object(mock_markdown_file.from_vault_note.return_value, 'metadata',
                  return_value={}):
    mock_mf = MagicMock()
    mock_parse.return_value.main_content_markdown_file = "Real content"
    
    result = _add_link_to_notation_note_in_mf(
        mock_mf,
        {'notation_note_name': 'test', 'summarized': 'Test Summary'},
        mock_vault,
        True,
        'origin'
    )
    test_eq(result, True)
    mock_mf.add_line_to_end.assert_called_once_with({
        'line': '- [Test Summary](test.md)',
        'type': MarkdownLineEnum.UNORDERED_LIST
    })

In [None]:
# Clean up patchers: undo the module-level patches, in the order they were started.
for _patcher in (vault_note_patcher, markdown_file_patcher):
    _ = _patcher.stop()