In [None]:
#| default_exp markdown.obsidian.personal.machine_learning.note_linking

# markdown.obsidian.personal.machine_learning.note_linking
> Functions for gathering note-linking data and for using models trained on that data.

In [None]:
#| export

import ast
from abc import ABC, abstractmethod
import copy
from datasets import Dataset
from enum import Enum
from itertools import combinations
from pathlib import Path
from os import PathLike
import random
import re
from typing import Literal, Optional, TypedDict, TypeVar, Union

from fastcore.basics import patch
import torch
from transformers import Pipeline


from trouver.helper import latex_str_is_likely_in_latex_str, latex_str_in_latex_str_fuzz_metric
from trouver.helper.numbers import modify_int_by_at_most_at_most_offset, modify_int_by_at_most_at_most_value
from trouver.helper.regex import find_regex_in_text, latex_indices
from trouver.markdown.markdown.file import MarkdownFile
from trouver.helper.latex import (
    augment_text, choose_modification_methods_at_random, remove_font_styles_at_random, change_font_styles_at_random, change_greek_letters_at_random, remove_math_keywords, random_latex_command_removal, random_word_removal, dollar_sign_manipulation, random_char_modification
    )
from trouver.markdown.obsidian.footnotes import identify_available_footnote_numbers
from trouver.markdown.obsidian.links import links_from_text, LinkType, ObsidianLink, MARKDOWNLINK_CAPTURE_PATTERN
from trouver.markdown.obsidian.personal.information_notes import index_note_of_note
from trouver.markdown.obsidian.personal.machine_learning.note_data import (
    NoteLinkEnum, NoteData, note_data_order_cmp, randomly_modify, InfoNoteData, NotatNoteData, note_data_from_index_note, note_data_from_reference, find_reverse_links, get_main_note_content_of_notat_note_data, _note_data_from_vault_note_on_the_fly
    )
from trouver.markdown.obsidian.personal.notation.in_standard_information_note import notation_notes_linked_in_see_also_section
from trouver.markdown.obsidian.personal.notation.parse import NotationNoteParsed, parse_notation_note, notation_in_note, main_of_notation
from trouver.markdown.obsidian.personal.note_processing import process_standard_information_note, ProcessNoteError
from trouver.markdown.obsidian.personal.note_type import (
    PersonalNoteTypeEnum, assert_note_is_of_type, note_is_of_type, type_of_note
)


from trouver.markdown.obsidian.personal.notes import (
    notes_linked_in_note,  notes_linked_in_notes_linked_in_note)
from trouver.markdown.obsidian.personal.reference import index_note_for_reference, all_paths_to_notes_in_reference_folder
from trouver.markdown.obsidian.vault import VaultNote



In [None]:

from unittest.mock import MagicMock
from unittest.mock import patch as mock_patch

from fastcore.test import *

from nbdev.showdoc import show_doc


## Sieve instances of pairs for building a dataset

In practice, it is difficult to manually create all links that ought to exist. In particular, while it is easy to extract "positive" instances of links (simply by finding explicit links), it is more difficult to obtain "negative" instances of links with certainty. The general method for obtaining "negative" instances is nevertheless to randomly sample pairs of notes and to consider such a pair "negative" if there is no link between them. Some notes receive a lot of attention in practice (and in particular have many links to other notes); if there are no links between two such notes, then it is likely that there are not supposed to be any links between them.

In [None]:
#| export
class NotePairData(TypedDict):
    """
    A directed pair of notes considered as a candidate for a link.

    The question associated with the pair is whether `origin_note` links (or
    should link) to `relied_note`.
    """
    origin_note: NoteData  # The note from which a link is considered.
    relied_note: NoteData  # The note to which a link from `origin_note` is considered.

In [None]:
#| export
def link_types_for_note_pair_data(
        pair_data: NotePairData
        ) -> set[NoteLinkEnum]:
    """
    Return the types of links from the origin note of `pair_data` to its relied note.

    The link types are looked up in the origin note's `directly_linked_notes`
    under the name of the relied note. If the relied note is not listed there,
    `{NoteLinkEnum.NO_LINK}` is returned.
    """
    target_name: str = pair_data['relied_note'].note_name
    links_by_note_name = pair_data['origin_note'].directly_linked_notes
    if target_name not in links_by_note_name:
        return {NoteLinkEnum.NO_LINK}
    return set(links_by_note_name[target_name])

In [None]:
#| export
def _high_count_note_data(
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        ) -> tuple[set[str], set[str]]:
    """
    Return the names of the "high count" information notes and notation notes.

    An information note is high count if it has more than 4 entries in
    `reverse_linked_notes` and more than 2 entries in `directly_linked_notes`.
    A notation note is high count if it has more than 4 entries in
    `reverse_linked_notes`.

    Helper function to `sieve_note_data_pairs`.
    """
    info_names: set[str] = {
        name for name, data in info_note_data.items()
        if len(data.reverse_linked_notes) > 4
        and len(data.directly_linked_notes) > 2}
    notat_names: set[str] = {
        name for name, data in notat_note_data.items()
        if len(data.reverse_linked_notes) > 4}
    return info_names, notat_names


def _mid_count_note_data(
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        high_count_info_notes: set[str],
        high_count_notat_notes: set[str],
        ) -> tuple[set[str], set[str]]:
    """
    Return the names of the "mid count" information notes and notation notes.

    An information note is mid count if it has more than 2 entries in
    `reverse_linked_notes`, more than 1 entry in `directly_linked_notes`, and
    is not in `high_count_info_notes`. A notation note is mid count if it has
    more than 2 entries in `reverse_linked_notes` and is not in
    `high_count_notat_notes`.

    Helper function to `sieve_note_data_pairs`.
    """
    mid_count_info_notes: set[str] = {
        name for name, data_point in info_note_data.items()
        if len(data_point.reverse_linked_notes) > 2
        and len(data_point.directly_linked_notes) > 1
        and name not in high_count_info_notes}
    mid_count_notat_notes: set[str] = {
        name for name, data_point in notat_note_data.items()
        if len(data_point.reverse_linked_notes) > 2
        and name not in high_count_notat_notes}

    return (mid_count_info_notes, mid_count_notat_notes)


In [None]:
#| export
def _positive_instances_from_high_or_mid_count_notes(
        high_count_notes: set[str],
        mid_count_notes: set[str],
        note_data: dict[str, NoteData],
        ) -> list[tuple[str, str]]:
    """
    Collect `(origin, relied)` name pairs that correspond to existing links.

    The returned list contains, in this order:
    1. links from a high count note to a high or mid count note, then
    2. links from a mid count note to a high count note.

    Helper function to `sieve_note_data_pairs`.
    """
    from_high_notes = [
        (origin_name, linked_name)
        for origin_name in high_count_notes
        for linked_name in note_data[origin_name].directly_linked_notes
        if linked_name in high_count_notes or linked_name in mid_count_notes]
    from_mid_notes = [
        (origin_name, linked_name)
        for origin_name in mid_count_notes
        for linked_name in note_data[origin_name].directly_linked_notes
        if linked_name in high_count_notes]
    return from_high_notes + from_mid_notes

In [None]:
#| export
def _sample_unlinked_pairs(
        origin_candidates: list[str],
        origin_weights: list[float],
        relied_candidates: list[str],
        relied_weights: list[float],
        num_samples: int,
        note_data: dict[str, NoteData],
        sample_pairs: set[tuple[str, str]],
        ) -> None:
    """
    Draw `num_samples` weighted (origin, relied) pairs and add the unlinked ones to `sample_pairs`.

    Origins and relied notes are drawn independently (with replacement) via
    `random.choices`. Pairs whose two notes coincide, or where the origin
    note's `directly_linked_notes` contains the relied note, are skipped.
    """
    # `random.choices` raises on an empty population, so an empty candidate
    # list simply produces no samples.
    origin_notes = random.choices(
        origin_candidates, weights=origin_weights, k=num_samples
        ) if origin_candidates else []
    relied_notes = random.choices(
        relied_candidates, weights=relied_weights, k=num_samples
        ) if relied_candidates else []
    for origin_note, relied_note in zip(origin_notes, relied_notes):
        if origin_note == relied_note:
            continue
        if relied_note in note_data[origin_note].directly_linked_notes:
            continue
        sample_pairs.add((origin_note, relied_note))


def _negative_instances_from_high_or_mid_count_notes(
        high_count_notes: set[str],
        mid_count_notes: set[str],
        note_data: dict[str, NoteData],
        num_pairs: int # The approximate number of pairs to sample.
        ) -> list[tuple[str, str]]:
    """
    Get "negative" pair instances from high or mid count notes, i.e. pairs where
    the origin note seems not to link to the relied note.

    About half of the `num_pairs` draws are high-to-high, a quarter
    high-to-mid and a quarter mid-to-high. Notes are drawn with weight
    `sqrt(len(reverse_linked_notes))`. Self-pairs, linked pairs and duplicates
    are dropped, so fewer than `num_pairs` pairs may be returned.

    Helper function to `sieve_note_data_pairs`.
    """
    high_count_notes_list = list(high_count_notes)
    mid_count_notes_list = list(mid_count_notes)
    high_count_weights = [
        (len(note_data[note_name].reverse_linked_notes)**0.5)
        for note_name in high_count_notes_list]
    mid_count_weights = [
        (len(note_data[note_name].reverse_linked_notes)**0.5)
        for note_name in mid_count_notes_list]

    sample_pairs: set[tuple[str, str]] = set()
    # (origin candidates, origin weights, relied candidates, relied weights, draws);
    # the order of the stages fixes the order in which random numbers are drawn.
    stages = (
        (high_count_notes_list, high_count_weights,
         high_count_notes_list, high_count_weights, int(0.5 * num_pairs)),
        (high_count_notes_list, high_count_weights,
         mid_count_notes_list, mid_count_weights, int(0.25 * num_pairs)),
        (mid_count_notes_list, mid_count_weights,
         high_count_notes_list, high_count_weights, int(0.25 * num_pairs)),
    )
    for origin_list, origin_weights, relied_list, relied_weights, num_samples in stages:
        _sample_unlinked_pairs(
            origin_list, origin_weights, relied_list, relied_weights,
            num_samples, note_data, sample_pairs)

    return list(sample_pairs)

In [None]:
#| export
def _similar_notation_pairs(
        notat_note_data: dict[str, NotatNoteData],
        ) -> dict[str, set[str]]: # The keys are names of notation notes and the values are sets of names of notation notes whose notations are similar to the one explained in the key notation note.
    """
    Identify pairs of names of notation notes whose notations are similar.

    Helper function to `sieve_note_data_pairs`.

    The similarity is measured by Jaro-Winkler, which works well on short
    strings. Two notations count as similar if the similarity of the strings,
    or of the reversed strings, exceeds 0.9. Every pair of notation notes is
    compared, so the cost is quadratic in the number of notation notes.
    """
    # NOTE(review): `JaroWinkler` and `_update_dict` are not imported or defined
    # in this notebook's import cell; confirm they are brought into scope
    # elsewhere, otherwise this function raises `NameError`.
    jarowinkler = JaroWinkler()
    similar_notation_dict: dict[str, set[str]] = {}
    for notat_name_1, notat_name_2 in combinations(notat_note_data, 2):
        notat_str_1 = notat_note_data[notat_name_1].parsed.notation_str
        notat_str_2 = notat_note_data[notat_name_2].parsed.notation_str
        # The reversed comparison presumably catches notations that agree at
        # their ends; it is only computed when the forward comparison fails.
        if (jarowinkler.similarity(notat_str_1, notat_str_2) > 0.9
                or jarowinkler.similarity(notat_str_1[::-1], notat_str_2[::-1]) > 0.9):
            _update_dict(similar_notation_dict, notat_name_1, notat_name_2)
            _update_dict(similar_notation_dict, notat_name_2, notat_name_1)
    return similar_notation_dict


In [None]:
#| export
def _random_pair_replacing_notation_notes_with_similar_notation_notes(
        original_pair: tuple[str, str],
        similar_notation_dict: set[str, set[str]]
        ) -> tuple[str, str]:
    """
    Helper function to `_random_pair_replacing_notation_notes_with_similar_notation_notes`.
    """
    origin_note_name = original_pair[0]
    relied_note_name = original_pair[1]
    if random.random() > 0.5:
        if origin_note_name in similar_notation_dict:
            origin_note_name = random.choice(list(similar_notation_dict[origin_note_name]))
    if random.random() > 0.5:
        if relied_note_name in similar_notation_dict:
            relied_note_name = random.choice(list(similar_notation_dict[relied_note_name]))
    return (origin_note_name, relied_note_name)
    

    
def _pairs_with_notation_notes_replaced_with_similar_notation_notes(
        sieved_pairs: set[tuple[str, str]],
        count: int, # The approximate number of pairs to attempt to obtain.
        similar_notation_dict: dict[str, set[str]], # An output of `_similar_notation_pairs`
        ) -> list[tuple[str, str]]:
    """
    Return modified versions of entries of `sieved_pairs` drawn at random
    where notation note names are replaced by names of notation notes whose
    introduced notations are similar, in accordance with `similar_notation_dict`.

    `count` pairs are drawn (with replacement). Modified pairs whose origin and
    relied note names coincide are discarded, so fewer than `count` pairs may be
    returned. If `sieved_pairs` is empty, an empty list is returned.

    Helper function to `sieve_note_data_pairs`.
    """
    if not sieved_pairs:
        # `random.choice` raises `IndexError` on an empty sequence.
        return []
    sieved_pairs_list = list(sieved_pairs)
    new_pairs: list[tuple[str, str]] = []
    for _ in range(count):
        original_pair = random.choice(sieved_pairs_list)
        new_pair = _random_pair_replacing_notation_notes_with_similar_notation_notes(
            original_pair, similar_notation_dict)
        # A note paired with itself is not a meaningful linking instance; the
        # negative sampler in this module skips such pairs as well.
        if new_pair[0] == new_pair[1]:
            continue
        new_pairs.append(new_pair)
    return new_pairs

In [None]:
#| export
def _pair_is_admissible(
        origin_note: str,
        relied_note: str,
        note_data: dict[str, NoteData],
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        ) -> bool:
    """
    Return whether the pair (`origin_note`, `relied_note`) may be used in the dataset.

    A pair is rejected when the origin note's tags contain
    - `'_auto/links_added'` and the relied note is an information note, or
    - `'_auto/notations_added'` and the relied note is a notation note.
    Every other pair, including any pair whose origin note has no tags
    (`tags is None`), is admissible.

    Helper function to `sieve_note_data_pairs`.
    """
    # NOTE(review): presumably these tags mark notes whose links were added
    # automatically, so their links are not trustworthy as labels — confirm.
    origin_tags = note_data[origin_note].tags
    if origin_tags is None:
        return True
    if '_auto/links_added' in origin_tags and relied_note in info_note_data:
        return False
    if '_auto/notations_added' in origin_tags and relied_note in notat_note_data:
        return False
    return True

In [None]:
#| export
def sieve_note_data_pairs(
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        ) -> list[NotePairData]:
    """
    Choose the (origin, relied) note pairs from which a note linking dataset is built.

    The candidate pairs are
    - "positive" pairs: existing links among high and mid count notes,
    - "negative" pairs: randomly sampled unlinked pairs among high and mid count
      notes (up to roughly three times the number of positive pairs), and
    - roughly as many pairs as there are positive pairs, obtained by replacing
      notation notes in the previous pairs with notation notes whose notations
      are similar.
    Duplicates are removed and pairs rejected by `_pair_is_admissible` are dropped.
    """
    # Notation note data takes precedence if a name occurs in both dicts.
    note_data: dict[str, NoteData] = {**info_note_data, **notat_note_data}

    high_info_names, high_notat_names = _high_count_note_data(
        info_note_data, notat_note_data)
    high_count_notes: set[str] = high_info_names.union(high_notat_names)
    mid_info_names, mid_notat_names = _mid_count_note_data(
        info_note_data, notat_note_data, high_info_names, high_notat_names)
    mid_count_notes: set[str] = mid_info_names.union(mid_notat_names)

    positive_pairs = _positive_instances_from_high_or_mid_count_notes(
        high_count_notes, mid_count_notes, note_data)
    num_positive = len(positive_pairs)
    negative_pairs = _negative_instances_from_high_or_mid_count_notes(
        high_count_notes, mid_count_notes, note_data, num_positive * 3)

    candidate_pairs: set[tuple[str, str]] = set(positive_pairs)
    candidate_pairs.update(negative_pairs)
    similar_notations = _similar_notation_pairs(notat_note_data)
    candidate_pairs.update(
        _pairs_with_notation_notes_replaced_with_similar_notation_notes(
            candidate_pairs, num_positive, similar_notations))

    return [
        NotePairData(origin_note=note_data[origin_name], relied_note=note_data[relied_name])
        for origin_name, relied_name in candidate_pairs
        if _pair_is_admissible(
            origin_name, relied_name, note_data, info_note_data, notat_note_data)]

In [None]:
# data_pairs = sieve_note_data_pairs(info_note_data, notat_note_data)

In [None]:
# count = 0
# for data_pair in data_pairs:
#     if data_pair['relied_note'].note_name in data_pair['origin_note'].directly_linked_notes:
#         if data_pair['relied_note'].note_name == 'hotta_takeuchi_tanisaki_dmpsrt_notation_X_smooth_non_singular_algebraic_variety_over_C':
#             print(data_pair['origin_note'].note_name)
#             print(data_pair['relied_note'].note_name)

In [None]:
# index = 2
# print(data_pairs[index]['origin_note'].note_name)
# print(data_pairs[index]['relied_note'].note_name)
# print(data_pairs[index]['origin_note'].directly_linked_notes)

In [None]:
# data_pairs[8]['origin_note'].note_name

## Converting a note pair into a string and data augmentation 

In [None]:
#| export
def _erase_position_metadata(
        augmentation: Literal['high', 'mid',' low'] | None,
        ) -> bool:
    """

    """
    rand_value = random.random()
    if augmentation == 'high':
        return rand_value < 0.3
    elif augmentation == 'mid':
        return rand_value < 0.2
    elif augmentation == 'low':
        return rand_value < 0.1
    return False
    
    
def string_from_note_pair(
        pair_data: NotePairData,
        format: Literal['bert', 't5'],
        ) -> str:
    """
    Format `pair_data` as a single model input string.

    The data strings of the origin and relied notes (from their `data_string`
    methods, origin first) are joined by a separator surrounded by blank lines:
    `[SEP]` when `format` is `'bert'`, and `</s>` for any other value.
    """
    origin_data_string = pair_data['origin_note'].data_string(format)
    relied_data_string = pair_data['relied_note'].data_string(format)
    if format != 'bert':
        return f'{origin_data_string}\n\n</s>\n\n{relied_data_string}'
    return f'{origin_data_string}\n\n[SEP]\n\n{relied_data_string}'


def augment_note_pair(
        pair_data: NotePairData,
        augmentation: Optional[Literal['high', 'mid', 'low']] = None,
        include_position_data_for_origin: bool = True,
        include_position_data_for_relied: bool = True,
        ) -> NotePairData:
    """
    Return an augmented copy of `pair_data`.

    Both notes are copied via their `deepcopy` methods, and only the copies are
    modified. If `augmentation` is not `None`, each copy's `randomly_modify` is
    called with `augmentation` and a flag requesting erasure of its position
    data. The flag is set when the matching `include_position_data_for_*`
    argument is `False`, or at random via `_erase_position_metadata`.
    """
    origin_copy = pair_data['origin_note'].deepcopy()
    relied_copy = pair_data['relied_note'].deepcopy()
    # `_erase_position_metadata` is evaluated first in each expression so that a
    # random number is always drawn for both notes.
    erase_origin_position = (
        _erase_position_metadata(augmentation) or not include_position_data_for_origin)
    erase_relied_position = (
        _erase_position_metadata(augmentation) or not include_position_data_for_relied)
    # NOTE(review): when `augmentation` is `None`, `randomly_modify` is never
    # called, so `include_position_data_for_*=False` has no effect — confirm
    # whether this is intended.
    if augmentation is None:
        return NotePairData(origin_note=origin_copy, relied_note=relied_copy)
    origin_copy.randomly_modify(augmentation, erase_origin_position)
    relied_copy.randomly_modify(augmentation, erase_relied_position)
    return NotePairData(origin_note=origin_copy, relied_note=relied_copy)


In [None]:
# from transformers import AutoModelForSeq2SeqLM, AutoModelForTokenClassification, AutoTokenizer, pipeline
# model = AutoModelForSeq2SeqLM.from_pretrained('hyunjongkimmath/notation_summarizations_model')
# tokenizer = AutoTokenizer.from_pretrained('hyunjongkimmath/notation_summarizations_model')
# summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)

In [None]:
#| export
class NoteLinkingDataPoint(TypedDict):
    """
    A single data point (one row of the dataset) for the note linking model.

    It records a directed pair of notes, the model input text built from the
    pair, and the names of the types of links from the origin note to the
    relied note.
    """
    origin_note_name: str  # Name of the note from which the link is considered.
    relied_note_name: str  # Name of the note to which the link is considered.
    input_text: str  # The pair formatted as model input (see `string_from_note_pair`).
    # Names of `NoteLinkEnum` members. When built by `dict_data_point_from_pair`,
    # this is `['NO_LINK']` if the origin note does not link to the relied note.
    link_types: list[str]


In [None]:
#| export
def dict_data_point_from_pair(
        pair_data: NotePairData,
        format: Literal['bert', 't5'],
        ) -> NoteLinkingDataPoint:
    """
    Obtain a `NoteLinkingDataPoint` object from a `NotePairData` object.

    `input_text` is built by `string_from_note_pair`. `link_types` holds the
    names of the link types given by `link_types_for_note_pair_data`, sorted
    alphabetically.
    """
    origin_note_name = pair_data['origin_note'].note_name
    relied_note_name = pair_data['relied_note'].note_name
    input_text = string_from_note_pair(pair_data, format)
    # `link_types_for_note_pair_data` returns a set whose iteration order can
    # vary between interpreter runs (string hashing is randomized), so the names
    # are sorted to make generated datasets reproducible.
    link_type_names: list[str] = sorted(
        link_type.name for link_type in link_types_for_note_pair_data(pair_data))
    return NoteLinkingDataPoint(
        origin_note_name=origin_note_name,
        relied_note_name=relied_note_name,
        input_text=input_text,
        link_types=link_type_names,)


In [None]:
#| export
def dataset_from_note_data(
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        augment: bool,
        format: Literal['bert', 't5'],
        ) -> Dataset:
    """
    Build a `datasets.Dataset` of `NoteLinkingDataPoint` rows from note data.

    The note pairs are chosen by `sieve_note_data_pairs`. Each pair contributes
    one unaugmented row. If `augment` is `True`, it is followed by one row for
    each of the augmentation levels `'low'`, `'mid'` and `'high'` (in that
    order, via `augment_note_pair`).
    """
    augmentation_levels = ('low', 'mid', 'high') if augment else ()
    rows: list[NoteLinkingDataPoint] = []
    for pair_data in sieve_note_data_pairs(info_note_data, notat_note_data):
        rows.append(dict_data_point_from_pair(pair_data, format))
        rows.extend(
            dict_data_point_from_pair(augment_note_pair(pair_data, level), format)
            for level in augmentation_levels)
    return Dataset.from_list(rows)
    

## Using the model

### Get predictions from model pipeline

In [None]:
#| export
class MultiLabelPipeline(Pipeline):
    """
    A HuggingFace `Pipeline` for multi-label classification.

    HuggingFace's standard `text-classification` pipeline uses softmax, which
    suits single-label (multi-class) classification. In multi-label
    classification each label should be scored independently, so this pipeline
    applies a sigmoid to the model's logits instead.

    Calling an instance returns `torch.sigmoid(logits).tolist()`, i.e. nested
    lists of per-label probabilities.
    """
    def __init__(self, model, tokenizer, **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # Call-time keyword arguments are not supported: the preprocess, forward
        # and postprocess steps all receive empty keyword dicts.
        preprocess_kwargs: dict = {}
        forward_kwargs: dict = {}
        postprocess_kwargs: dict = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, **kwargs):
        # PyTorch tensors, padded and truncated by the tokenizer.
        return self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)

    def _forward(self, model_inputs, **kwargs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, **kwargs):
        # Sigmoid rather than softmax, so labels are scored independently.
        return torch.sigmoid(model_outputs.logits).tolist()

In [None]:
#| export
def prediction_by_note_linking_model(
        origin_data: NoteData, # The `NoteData` object representing the "origin note", i.e.  the note from which a link to the "relied note" is considered.
        relied_data: NoteData, # The `NoteData` object representing the "relied note", i.e.  the note to which a link from the "origin note" is considered.
        predictor: MultiLabelPipeline,
        format: Literal['bert', 't5'] = 'bert', # Specifies how to format the input to `predictor`.
        as_floats: bool = True, # If `True`, then return the predictions as floats indicating how likely it is that there should be a linking from the origin note to the relied note of each type. If `False`, return bools obtained by comparing these floats with `threshold`.
        threshold: float | dict[str, float] = 0.5, # Either a number or a dictionary whose keys are the possible labels and whose values are floats. If a label is not one of the dictionary's keys, then the default threshold value of 0.5 is used for that label. A predicted value strictly exceeding this threshold corresponds to a prediction that a link of the given type should exist. This is only used if `as_floats` is `False`.
        ) -> Union[dict[str, float], dict[str, bool]]: # A `dict` whose keys are the `labels` and whose values are either `float`s between 0.0 and 1.0 indicating how likely it is that there should be a linking from the origin note to the relied note of the type corresponding to the label (if `as_floats` is `True`), or `bool`s indicating whether such a link is predicted (if `as_floats` is `False`).
    r"""
    Predict, for each link type, how likely it is (or whether) the origin note should link to the relied note.

    The pair is formatted by `string_from_note_pair` and passed to `predictor`;
    the labels are read from `predictor.model.config.id2label`.
    """
    pair_data = NotePairData(origin_note=origin_data, relied_note=relied_data)
    input_text = string_from_note_pair(pair_data, format)
    preds: list[float] = predictor(input_text)[0]
    id2label: dict[int, str] = predictor.model.config.id2label
    if as_floats:
        return {label: preds[label_id] for label_id, label in id2label.items()}

    def _threshold_for(label: str) -> float:
        # Accept any real number (e.g. `threshold=1`), not only `float`s;
        # otherwise `threshold` is treated as a label-to-threshold mapping.
        if isinstance(threshold, (int, float)):
            return threshold
        return threshold.get(label, 0.5)

    return {
        label: preds[label_id] > _threshold_for(label)
        for label_id, label in id2label.items()}

In [None]:
# Check `prediction_by_note_linking_model` with `as_floats=False` using a mocked
# predictor. `string_from_note_pair` is patched in `__main__` so that the
# `MagicMock` note data never has to be formatted.
with (mock_patch('__main__.string_from_note_pair') as mock_string_from_note_pair):
    mock_origin_data = MagicMock()
    mock_relied_data = MagicMock()
    mock_predictor = MagicMock()
    mock_predictor.model = MagicMock()
    mock_predictor.model.config = MagicMock()
    # Label ids index into the per-label probabilities returned by the predictor.
    mock_predictor.model.config.id2label = {
        0: 'NO_LINK',
        1: 'INFO_TO_INFO_IN_CONTENT',
        2: 'INFO_TO_INFO_IN_SEE_ALSO',
        3: 'INFO_TO_INFO_VIA_NOTAT',
        4: 'INFO_TO_NOTAT_VIA_EMBEDDING',
        5: 'NOTAT_TO_INFO',
        6: 'NOTAT_TO_INFO_VIA_NOTAT',
        7: 'NOTAT_TO_NOTAT'}
    
    # A single float threshold applies to every label: only probabilities
    # strictly above 0.5 give `True`.
    mock_predictor.return_value = [[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.6, 0.0]]
    output = prediction_by_note_linking_model(
        mock_origin_data, mock_relied_data, mock_predictor, as_floats=False, threshold=0.5)
    print(output)
    test_is(output['NO_LINK'], False)
    test_is(output['INFO_TO_INFO_IN_CONTENT'], False)
    test_is(output['INFO_TO_NOTAT_VIA_EMBEDDING'], True)
    test_is(output['NOTAT_TO_INFO_VIA_NOTAT'], True)

    # A dict threshold sets per-label thresholds; labels missing from the dict
    # fall back to 0.5.
    mock_predictor.return_value = [[0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]
    output = prediction_by_note_linking_model(
        mock_origin_data, mock_relied_data, mock_predictor, as_floats=False, threshold={
            'INFO_TO_INFO_VIA_NOTAT': 0.65,
            'INFO_TO_INFO_IN_CONTENT': 0.10
        })
    print(output)
    test_is(output['INFO_TO_INFO_VIA_NOTAT'], False)
    test_is(output['INFO_TO_INFO_IN_CONTENT'], True)
    



{'NO_LINK': False, 'INFO_TO_INFO_IN_CONTENT': False, 'INFO_TO_INFO_IN_SEE_ALSO': False, 'INFO_TO_INFO_VIA_NOTAT': False, 'INFO_TO_NOTAT_VIA_EMBEDDING': True, 'NOTAT_TO_INFO': False, 'NOTAT_TO_INFO_VIA_NOTAT': True, 'NOTAT_TO_NOTAT': False}
{'NO_LINK': False, 'INFO_TO_INFO_IN_CONTENT': True, 'INFO_TO_INFO_IN_SEE_ALSO': False, 'INFO_TO_INFO_VIA_NOTAT': False, 'INFO_TO_NOTAT_VIA_EMBEDDING': False, 'NOTAT_TO_INFO': True, 'NOTAT_TO_INFO_VIA_NOTAT': True, 'NOTAT_TO_NOTAT': True}


In [None]:
#| export
def predict_note_linking(
        origin_note: VaultNote, # The note from which links are considered.
        relied_notes: VaultNote| list[VaultNote], # The note(s) to which links from `origin_note` are considered.
        predictor: MultiLabelPipeline,
        format: Literal['bert', 't5'] = 'bert', # Specifies how to format the input to `predictor`.
        note_data: Optional[dict[str, NoteData]] = None, # For the purposes of predicting note linking, the note data only requires the positional data, so getting the note data via `note_data_from_index_note` should suffice (without having to use `find_reverse_links`, although `get_main_note_content_of_notat_note_data` should still be necessary).
        omit_no_link_predictions: bool = True, # if `True` omit predictions of `NoteLinkEnum.NO_LINK`
        threshold: float | dict[str, float] = 0.5, # See also `prediction_by_note_linking_model`. Either a float value or a dictionary whose keys are the possible labels and whose values are floats. If a label is not one of the dictionary's keys, then the default threshold value of 0.5 is used for that label. A predicted value exceeding this threshold corresponds to a prediction that a link of the given type should exist.
        ) -> Optional[dict[str, list[NoteLinkEnum]]]: # The keys are the names of relied notes. The values are lists of `NoteLinkEnum` that specify the linking types from origin note to the relied note. `None` if the data of `origin_note` could not be obtained.
    """
    Predict which types of links `origin_note` should have to each of `relied_notes`.

    Note data is taken from `note_data` when it contains the note's name and is
    otherwise computed by `_note_data_from_vault_note_on_the_fly`. These notes
    are left out of the output, and a message is printed for the latter two:
    - relied notes with the same name as `origin_note`;
    - relied notes whose data raised an error;
    - relied notes whose data came back as `None`.
    """
    if isinstance(relied_notes, VaultNote):
        relied_notes = [relied_notes]
    if note_data and origin_note.name in note_data:
        origin_note_data = note_data[origin_note.name]
    else:
        try:
            origin_note_data = _note_data_from_vault_note_on_the_fly(
                origin_note, reference='', note_data=note_data)
        except Exception as e:
            # Best effort: report the problem rather than raising.
            print(f"An error ocurred while trying to get data for  `origin_note`: {origin_note}")
            print(e)
            return None
    if origin_note_data is None:
        # Without data for the origin note, no prediction can be made.
        print(f"No note data was obtained for `origin_note`: {origin_note}")
        return None

    output_dict: dict[str, list[NoteLinkEnum]] = {}
    for relied_note in relied_notes:
        if relied_note.name == origin_note.name:
            continue
        if note_data and relied_note.name in note_data:
            relied_note_data = note_data[relied_note.name]
        else:
            try:
                relied_note_data = _note_data_from_vault_note_on_the_fly(
                    relied_note, reference='', note_data=note_data)
            except Exception as e:
                print(f"An error ocurred while trying to get data for  `relied_note`: {relied_note}")
                print(e)
                continue
        if relied_note_data is None:
            # Formatting `None` for the model would fail; skip this note instead.
            print(f"No note data was obtained for `relied_note`: {relied_note}")
            continue
        preds: dict[str, bool] = prediction_by_note_linking_model(
            origin_note_data, relied_note_data, predictor, format,
            as_floats=False,
            threshold=threshold)
        output_dict[relied_note.name] = [
            NoteLinkEnum[enum_name]
            for enum_name, link_flag in preds.items()
            if link_flag and not (omit_no_link_predictions and enum_name == "NO_LINK")]
    return output_dict
    

In [None]:
#| hide
# Check `predict_note_linking` with `prediction_by_note_linking_model` and
# `_note_data_from_vault_note_on_the_fly` patched in `__main__`, so no real notes
# or model are needed.
with (mock_patch('__main__.prediction_by_note_linking_model') as mock_prediction_by_note_linking_model, \
          mock_patch('__main__._note_data_from_vault_note_on_the_fly') as mock_note_data_from_vault_note_on_the_fly):
     mock_origin_note = MagicMock()
     mock_relied_note = MagicMock()
     mock_origin_note.name = 'origin_note_name'
     mock_relied_note.name = 'relied_note_name'
     relied_notes = [mock_relied_note]
     mock_predictor = MagicMock()

     # Exactly one link type is predicted, so only that `NoteLinkEnum` member is listed.
     mock_prediction_by_note_linking_model.return_value = {
          'INFO_TO_INFO_IN_CONTENT': False,
          'INFO_TO_INFO_IN_SEE_ALSO': False,
          'INFO_TO_INFO_VIA_NOTAT': True,
          'INFO_TO_NOTAT_VIA_EMBEDDING': False,
          'NOTAT_TO_INFO': False,
          'NOTAT_TO_INFO_VIA_NOTAT': False,
          'NOTAT_TO_NOTAT': False,
          'NO_LINK': False}
     # `note_data` is not passed, so note data is fetched on the fly: once for the
     # origin note and once for the relied note.
     mock_note_data_from_vault_note_on_the_fly.side_effect = [MagicMock(), MagicMock()]
     output = predict_note_linking(
          mock_origin_note, relied_notes, mock_predictor, omit_no_link_predictions=True)
     test_eq(
          output,
          {'relied_note_name': [NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT]})


     # Only `NO_LINK` is predicted; with `omit_no_link_predictions=True` the relied
     # note maps to an empty list.
     mock_prediction_by_note_linking_model.return_value = {
          'INFO_TO_INFO_IN_CONTENT': False,
          'INFO_TO_INFO_IN_SEE_ALSO': False,
          'INFO_TO_INFO_VIA_NOTAT': False,
          'INFO_TO_NOTAT_VIA_EMBEDDING': False,
          'NOTAT_TO_INFO': False,
          'NOTAT_TO_INFO_VIA_NOTAT': False,
          'NOTAT_TO_NOTAT': False,
          'NO_LINK': True}
     mock_note_data_from_vault_note_on_the_fly.side_effect = [MagicMock(), MagicMock()]
     output = predict_note_linking(
          mock_origin_note, relied_notes, mock_predictor, omit_no_link_predictions=True)
     test_eq(
          output,
          {'relied_note_name': []})
     print(output)

{'relied_note_name': []}


### Link cache note

Using the model takes a lot of time: not only does each prediction take a few seconds, but predictions also need to be made on pairs of notes, so the total time needed for predictions grows quadratically with the number of notes. As such, "link cache notes" are made to record predictions.

The link cache note will be saved in the root directory of its reference folder.


In [None]:
#| export
def link_cache_note(
        vault: PathLike,
        reference: str,
        create_if_does_not_exist: bool = True, # If `True`, create the link cache note when it does not already exist.
        ) -> VaultNote: # The `VaultNote` object representing the link cache note.
    """
    Return a `VaultNote` object representing the link cache note in a reference of a vault.

    The link cache note is the file `_link_cache_{reference}.md` located in the folder
    containing the index note of `reference`.
    """
    index_note: VaultNote = index_note_for_reference(vault, reference, update_cache=True)
    cache_note_path = index_note.path(relative=True).parent / f'_link_cache_{reference}.md'
    cache_note = VaultNote(vault, rel_path=cache_note_path)
    if not create_if_does_not_exist:
        return cache_note
    if not cache_note.exists():
        cache_note.create()
    return cache_note

The link cache note will be formatted as follows:

```
- [[origin_note_name_1]]
    - [[relied_note_name_1]]: [<comma_separated_link_types_1>]
    - [[relied_note_name_2]]: [<comma_separated_link_types_2>]
    ...
<blank space for separation>
- [[origin_note_name_2]]
    - ...
```

In [None]:
#| export
def separate_blocks(
        text: str) -> list[str]:
    """
    Split `text` into blocks separated by one or more blank lines.

    A line counts as blank if it is empty or consists only of whitespace. Each returned block
    is the newline-joined run of its non-blank lines; the lines themselves keep their original
    indentation. Blank lines never produce empty blocks.
    """
    found_blocks: list[str] = []
    pending_lines: list[str] = []
    # The trailing '' acts as a sentinel blank line that flushes the final block.
    for raw_line in text.splitlines() + ['']:
        if raw_line.strip():
            pending_lines.append(raw_line)
            continue
        if pending_lines:
            found_blocks.append('\n'.join(pending_lines))
            pending_lines = []
    return found_blocks


In [None]:
text = """First line
Second line

Third block starts here
With multiple lines

Final block"""

blocks = separate_blocks(text)
print(blocks)
test_eq(
    blocks,
    ['First line\nSecond line', 'Third block starts here\nWith multiple lines', 'Final block'])

['First line\nSecond line', 'Third block starts here\nWith multiple lines', 'Final block']


In [None]:
#| export
def parse_link_cache_note(
        link_cache_note: VaultNote,
        ) -> dict[str, dict[str, list[NoteLinkEnum]]]: # The first key is the name of an "origin note". The second key is the name of a "relied note" with respect to the origin note. The value is a list of the link types from the origin note to the relied note.
    """
    Parse the contents of a link cache note into a nested dict of link types.

    The note is split into blocks by blank lines (see `separate_blocks`). The first line of each
    block contains a link to an origin note; every following line has the form
    `- [[relied_note_name]]: ['LINK_TYPE_1', 'LINK_TYPE_2', ...]`. The list after the last colon
    of such a line is parsed with `ast.literal_eval`, and its entries are looked up by name in
    `NoteLinkEnum`.

    If the same origin note heads more than one block, the entries of those blocks are merged.

    See also `write_link_cache_note`, which is essentially the opposite of this function.
    """
    link_types: dict[str, dict[str, list[NoteLinkEnum]]] = {}
    for block in separate_blocks(link_cache_note.text()):
        header_line, *entry_lines = block.splitlines()
        origin_note_name = links_from_text(header_line)[0].file_name
        origin_dict = link_types.setdefault(origin_note_name, {})
        for line in entry_lines:
            relied_note_name = links_from_text(line)[0].file_name
            # Link type names never contain ':', so split at the LAST colon; this tolerates
            # colons in note names and any amount of whitespace after the colon.
            _, _, list_text = line.rpartition(':')
            link_type_names = ast.literal_eval(list_text.strip())
            origin_dict[relied_note_name] = [
                NoteLinkEnum[link_type_name] for link_type_name in link_type_names]
    return link_types


In [None]:
mock_link_cache_note = MagicMock()
mock_link_cache_note.text.return_value = '''
- [[origin_note_1]]
    - [[relied_note_1]]: ['INFO_TO_INFO_IN_CONTENT', 'INFO_TO_INFO_VIA_NOTAT']
    - [[relied_note_2]]: ['INFO_TO_NOTAT_VIA_EMBEDDING']

- [[origin_note_2]]
    - [[relied_note_3]]: ['NOTAT_TO_NOTAT']
    - [[relied_note_4]]: ['NOTAT_TO_INFO', 'NOTAT_TO_INFO_VIA_NOTAT']
'''
parsed_link_types = parse_link_cache_note(mock_link_cache_note)
test_eq(
    parsed_link_types,
    {'origin_note_1': {
        'relied_note_1': [
            NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT],
        'relied_note_2': [NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING]},
     'origin_note_2': {
        'relied_note_3': [NoteLinkEnum.NOTAT_TO_NOTAT],
        'relied_note_4': [
            NoteLinkEnum.NOTAT_TO_INFO, NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT]}})

In [None]:
#| export
def write_link_cache_note(
        link_types: dict[str, dict[str, list[NoteLinkEnum]]], # See `parse_link_cache_note`. The first key is the name of an "origin note", the second the name of a "relied note"; the value is a list of link types.
        cache_note: VaultNote, # The link cache note to overwrite; it is assumed to exist.
        ) -> None:
    """
    Overwrite the contents of `cache_note` using the data from `link_types`.

    Each origin note becomes one block: a `- [[origin_note_name]]` line followed by one
    indented `- [[relied_note_name]]: [...]` line per relied note, listing the names of the
    link types. Blocks are separated by blank lines.

    See also `parse_link_cache_note`, which is essentially the opposite of this function.
    """
    rendered_blocks: list[str] = []
    for origin_note_name, relied_dict in link_types.items():
        block_lines = [f"- [[{origin_note_name}]]\n"]
        block_lines.extend(
            f"    - [[{relied_note_name}]]: {[link_type.name for link_type in link_type_list]}\n"
            for relied_note_name, link_type_list in relied_dict.items())
        rendered_blocks.append(''.join(block_lines))
    cache_note.write('\n\n'.join(rendered_blocks))

In [None]:
mock_link_cache_note = MagicMock()
link_types = {
    'origin_note_1': {
        'relied_note_1': [
            NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT],
        'relied_note_2': [
            NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING]},
    'origin_note_2': {
        'relied_note_3': [
            NoteLinkEnum.NOTAT_TO_NOTAT],
        'relied_note_4': [
            NoteLinkEnum.NOTAT_TO_INFO, NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT]},
}
write_link_cache_note(link_types, mock_link_cache_note)
written_content = mock_link_cache_note.write.call_args.args[0]
print(written_content)
# Origin-note blocks are separated by two blank lines: each block ends with a newline and
# the blocks are joined by '\n\n'.
expected_content = (
    "- [[origin_note_1]]\n"
    "    - [[relied_note_1]]: ['INFO_TO_INFO_IN_CONTENT', 'INFO_TO_INFO_VIA_NOTAT']\n"
    "    - [[relied_note_2]]: ['INFO_TO_NOTAT_VIA_EMBEDDING']\n"
    "\n"
    "\n"
    "- [[origin_note_2]]\n"
    "    - [[relied_note_3]]: ['NOTAT_TO_NOTAT']\n"
    "    - [[relied_note_4]]: ['NOTAT_TO_INFO', 'NOTAT_TO_INFO_VIA_NOTAT']\n")
test_eq(written_content, expected_content)

- [[origin_note_1]]
    - [[relied_note_1]]: ['INFO_TO_INFO_IN_CONTENT', 'INFO_TO_INFO_VIA_NOTAT']
    - [[relied_note_2]]: ['INFO_TO_NOTAT_VIA_EMBEDDING']


- [[origin_note_2]]
    - [[relied_note_3]]: ['NOTAT_TO_NOTAT']
    - [[relied_note_4]]: ['NOTAT_TO_INFO', 'NOTAT_TO_INFO_VIA_NOTAT']



In [None]:
#| export
def consolidate_note_linking_predictions_into_cache(
        origin_note: VaultNote | str, # The origin note of the predictions, or its name.
        predictions: dict[str, list[NoteLinkEnum]], # An output of `predict_note_linking`; the keys are names of relied notes.
        cache: dict[str, dict[str, list[NoteLinkEnum]]], # See `parse_link_cache_note`. The first key is the name of an "origin note". The second key is the name of a "relied note" with respect to the origin note. The value is a list of the link types from the origin note to the relied note.
        ) -> None:
    """
    Consolidate the outputs of `predict_note_linking` into a link cache.

    `cache` is modified in place. For each relied note in `predictions` other than the origin
    note itself, the predicted link types other than `NoteLinkEnum.NO_LINK` are merged into the
    link types already cached for the pair, without duplicates. The merged list keeps the
    previously cached link types first (in their existing order), followed by the newly predicted
    ones (in prediction order), so that repeated runs produce identical link cache notes.
    """
    if isinstance(origin_note, VaultNote):
        origin_note = origin_note.name
    origin_dict = cache.setdefault(origin_note, {})
    for relied_note_name, predicted_link_enums in predictions.items():
        if relied_note_name == origin_note:
            continue  # A note is never recorded as relying on itself.
        cached_link_enums = origin_dict.get(relied_note_name, [])
        new_link_enums = [
            link_enum for link_enum in predicted_link_enums
            if link_enum != NoteLinkEnum.NO_LINK]
        # `dict.fromkeys` deduplicates while preserving first-seen order. A `set` would give an
        # order that depends on the string-hash seed and thus varies from run to run.
        origin_dict[relied_note_name] = list(
            dict.fromkeys([*cached_link_enums, *new_link_enums]))


In [None]:
predictions = {
    'relied_note_name_1': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT],
    'relied_note_name_2': [],
    'relied_note_name_3': [NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING]}
cache = {'origin_note_name': {'relied_note_name_1': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO]}}

consolidate_note_linking_predictions_into_cache('origin_note_name', predictions, cache)
print(cache)

# Compare as sets so that the checks do not depend on the order of the link types.
expected_link_types = {
    'relied_note_name_1': {
        NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT,
        NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO},
    'relied_note_name_2': set(),
    'relied_note_name_3': {NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING}}
for relied_note_name, expected_links in expected_link_types.items():
    test_eq(set(cache['origin_note_name'][relied_note_name]), expected_links)

{'origin_note_name': {'relied_note_name_1': [<NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO: 2>, <NoteLinkEnum.INFO_TO_INFO_IN_CONTENT: 1>, <NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT: 3>], 'relied_note_name_2': [], 'relied_note_name_3': [<NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING: 4>]}}


In [None]:
#| export
def consolidate_caches(
        cache_1: dict[str, dict[str, list[NoteLinkEnum]]], # See `parse_link_cache_note`. The first key is the name of an "origin note". The second key is the name of a "relied note" with respect to the origin note. The value is a list of the link types from the origin note to the relied note.
        cache_2: dict[str, dict[str, list[NoteLinkEnum]]], # Another link cache of the same format.
        ) -> dict[str, dict[str, list[NoteLinkEnum]]]: # A new link cache combining the data of both caches.
    """
    Merge two link caches into a new link cache.

    `cache_1` is deep-copied, and every origin note entry of `cache_2` is then merged into the
    copy with `consolidate_note_linking_predictions_into_cache`, so `cache_1` is left untouched.
    """
    merged_cache: dict[str, dict[str, list[NoteLinkEnum]]] = copy.deepcopy(cache_1)
    for origin_note_name in cache_2:
        consolidate_note_linking_predictions_into_cache(
            origin_note_name, cache_2[origin_note_name], merged_cache)
    return merged_cache


In [None]:
# Create two caches with some overlapping data
cache_a = {
    'Note_A': {'Note_B': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT]}
}
cache_b = {
    'Note_A': {'Note_B': [NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO]},
    'Note_C': {'Note_D': [NoteLinkEnum.NOTAT_TO_INFO]}
}

# Consolidate them
merged = consolidate_caches(cache_a, cache_b)

# The link types of the shared pair are combined.
print(merged['Note_A']['Note_B'])
test_eq(
    set(merged['Note_A']['Note_B']),
    {NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO})
# Pairs that only appear in `cache_b` are carried over.
test_eq(merged['Note_C']['Note_D'], [NoteLinkEnum.NOTAT_TO_INFO])


[<NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO: 2>, <NoteLinkEnum.INFO_TO_INFO_IN_CONTENT: 1>]


In [None]:
#| hide
from fastcore.test import test_eq

# --- Basic Edge Cases ---
# Empty Inputs
test_eq(consolidate_caches({}, {}), {})

# Idempotency (Duplicates)
dup_cache = {'A': {'B': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT]}}
res = consolidate_caches(dup_cache, dup_cache)
test_eq(len(res['A']['B']), 1)

# --- Complex Scenario Tests ---
# Base setup for the main test scenario
cache_1 = {
    'origin_note_1': {
        'relied_note_1': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT],
    },
    'origin_note_2': { # Only exists in cache_1
        'relied_note_1': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT],
    },
    'origin_note_3': { # Exists in both, but relied_note_2 is unique to cache_1
        'relied_note_2': [NoteLinkEnum.NOTAT_TO_INFO],
    }
}
cache_2 = {
    'origin_note_1': { # Exists in both, relied_note_1 exists in both
        'relied_note_1': [NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO]
    },
    'origin_note_3': { # Exists in both, but relied_note_1 is unique to cache_2
        'relied_note_1': [NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT]
    }
}
new_cache = consolidate_caches(cache_1, cache_2)

# Case 1: Deep Merge (Union of Lists)
test_eq(
    set(new_cache['origin_note_1']['relied_note_1']), 
    {NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO, NoteLinkEnum.INFO_TO_INFO_IN_CONTENT}
)

# Case 2: Preservation of Left-Only Data
test_eq(
    new_cache['origin_note_2']['relied_note_1'], [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT]
)

# Case 3: Partial Merge (Left Unique Key in Shared Parent)
test_eq(
    new_cache['origin_note_3']['relied_note_2'], [NoteLinkEnum.NOTAT_TO_INFO]
)

# Case 4: Partial Merge (Right Unique Key in Shared Parent)
test_eq(
    new_cache['origin_note_3']['relied_note_1'], [NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT]
)

# Case 5: Completely New Origin Key (Right-Only Top Level)
cache_new_origin = {'Z': {'Y': [NoteLinkEnum.NOTAT_TO_INFO]}}
res_new = consolidate_caches(cache_1, cache_new_origin)
test_eq(res_new['Z']['Y'], [NoteLinkEnum.NOTAT_TO_INFO])
test_eq(len(res_new), 4) # origin_1, origin_2, origin_3 + Z

# Case 6: Empty Inputs (Identity)
test_eq(consolidate_caches(cache_1, {}), cache_1)
test_eq(consolidate_caches({}, cache_2), cache_2)


In [None]:
#| export
def remove_blank_or_no_link_data_from_cache(
        cache: dict[str, dict[str, list[NoteLinkEnum]]], # See `parse_link_cache_note`. The first key is the name of an "origin note". The second key is the name of a "relied note" with respect to the origin note. The value is a list of the link types from the origin note to the relied note.
        ) -> dict[str, dict[str, list[NoteLinkEnum]]]: # A new cache without the uninformative entries.
    """
    Return a new cache without uninformative entries.

    Link-type lists that are empty or contain only `NoteLinkEnum.NO_LINK` are dropped, and so are
    origin notes left with no remaining entries. The kept link-type lists are the same list
    objects as in `cache` (they are not copied).
    """
    only_no_link = {NoteLinkEnum.NO_LINK}
    cleaned_cache: dict[str, dict[str, list[NoteLinkEnum]]] = {}
    for origin_note_name, relied_map in cache.items():
        informative_entries = {
            relied_note_name: link_enums
            for relied_note_name, link_enums in relied_map.items()
            if link_enums and set(link_enums) != only_no_link}
        if informative_entries:
            cleaned_cache[origin_note_name] = informative_entries
    return cleaned_cache


In [None]:
cache = {
    'origin_note_1': {},
    'origin_note_2': {
        'relied_note_1': [],
        'relied_note_2': [NoteLinkEnum.NO_LINK]
    },
    'origin_note_3': {
        'relied_note_1': [NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO]
    }}
# Empty lists, `NO_LINK`-only lists, and origin notes left without entries are all dropped.
expected_cache = {
    'origin_note_3': {
        'relied_note_1': [
            NoteLinkEnum.INFO_TO_INFO_IN_CONTENT, NoteLinkEnum.INFO_TO_INFO_IN_SEE_ALSO]}}
test_eq(remove_blank_or_no_link_data_from_cache(cache), expected_cache)


In [None]:
#| export
def remove_nonexistent_note_names_from_cache(
        cache: dict[str, dict[str, list[NoteLinkEnum]]], # See `parse_link_cache_note`. The first key is the name of an "origin note". The second key is the name of a "relied note" with respect to the origin note. The value is a list of the link types from the origin note to the relied note.
        vault: PathLike
        ) -> dict[str, dict[str, list[NoteLinkEnum]]]: # A deep copy of `cache` without the names of nonexistent notes.
    """
    Remove names of nonexistent notes in `cache`.

    Origin notes that do not exist in `vault` are dropped together with all of their entries.
    Within the remaining origin notes, relied notes that do not exist are dropped. `cache` itself
    is not modified.
    """
    existence: dict[str, bool] = {}

    def _exists(note_name: str) -> bool:
        # The same note name typically appears under many origin notes, so memoize the
        # existence check instead of rebuilding a `VaultNote` for every occurrence.
        if note_name not in existence:
            existence[note_name] = VaultNote(vault, name=note_name).exists()
        return existence[note_name]

    return {
        origin_note_name: {
            relied_note_name: copy.deepcopy(link_enums)
            for relied_note_name, link_enums in origin_dict.items()
            if _exists(relied_note_name)}
        for origin_note_name, origin_dict in cache.items()
        if _exists(origin_note_name)}

## Sieve note pairs to predict on

Since the number of pairs of notes grows quadratically in the number of notes, making a prediction for every pair takes too much time. It is therefore useful to prioritize certain pairs over others.


In [None]:
#| export
def sieve_potential_relied_notes(
        vault: PathLike,
        reference: str,
        origin_note: VaultNote, # an info note
        note_data: dict[str, NoteData],
        appendix_notes: list[VaultNote], # notes whose index notes are appendix notes
        cache: dict[str, dict[str, list[NoteLinkEnum]]], # See `parse_link_cache_note`; used to skip pairs that have already been predicted on.
        notation_similarity_threshold: float = 0.8, # The threshold that the similarity metric of a notation must exceed for the name of a notation note to be included in the output.
        skip_already_made_predictions: bool = True, # If `True`, skip notes for which `cache` already records a prediction from `origin_note`.
        ) -> set[str]: # Names of potential relied notes that may be good to predict note linking from `origin_note` for.
    """
    Select names of notes that are worth predicting note linking for from `origin_note`.

    The candidates are the notes returned by `notes_linked_in_notes_linked_in_note` for the index
    note of `reference`, together with `appendix_notes`. Candidates that do not exist, are missing
    from `note_data`, or (if `skip_already_made_predictions`) were already predicted on are
    ignored, as are non-appendix candidates that do not precede `origin_note` according to
    `note_data_order_cmp`. Of the remaining candidates, a note is kept if
    1. it is a context note in the same section as `origin_note`, or
    2. it is a definition or notation note.
    For every note kept because of 2., the notation notes linked in its See Also section whose
    notations resemble LaTeX in `origin_note` (see `similar_notat_notes_in_note`) are also kept.

    If `origin_note` is not in `note_data`, a message is printed and an empty set is returned.
    """
    if origin_note.name not in note_data:
        print(f'`origin_note` was not in `note_data`. Perhaps a `origin_note` has been renamed at some point and it may be necessary to reload `note_data`. `origin_note`: {origin_note}.')
        return set()
    origin_data: NoteData = note_data[origin_note.name]
    cached_for_origin: dict[str, list[NoteLinkEnum]] = cache.get(origin_note.name, {})

    def _already_predicted(note_name: str) -> bool:
        # Whether a prediction from `origin_note` to `note_name` should be skipped.
        return skip_already_made_predictions and note_name in cached_for_origin

    index_note: VaultNote = index_note_for_reference(vault, reference, update_cache=True)
    info_notes: list[VaultNote] = notes_linked_in_notes_linked_in_note(index_note, as_dict=False)
    appendix_note_names: set[str] = {appendix_note.name for appendix_note in appendix_notes}

    relied_note_names: set[str] = set()
    for info_note in info_notes + appendix_notes:
        if not info_note.exists():
            continue
        if info_note.name not in note_data:
            # `info_note` may have been renamed without `note_data` having been reloaded.
            continue
        if _already_predicted(info_note.name):
            continue
        info_data: NoteData = note_data[info_note.name]
        # NOTE(review): presumably `order > 0` iff `info_note` precedes `origin_note`; confirm
        # against `note_data_order_cmp`. Non-appendix notes that do not precede are ignored.
        order = note_data_order_cmp(origin_data, info_data)
        if order <= 0 and info_note.name not in appendix_note_names:
            continue

        mf = MarkdownFile.from_vault_note(info_note)
        is_def_or_notat = (
            mf.has_tag('_auto/_meta/definition') or mf.has_tag('_auto/_meta/notation')
            or mf.has_tag('_meta/definition') or mf.has_tag('_meta/notation'))
        is_context = mf.has_tag('_auto/_meta/context') or mf.has_tag('_meta/context')

        # Admit context notes in the same section as `origin_note` that also precede it. Both
        # context tags are grouped so that the order/section conditions apply to each, and the
        # check comes before the definition/notation filter so plain context notes survive it.
        if (is_context and order >= 0
                and origin_data.section_num == info_data.section_num):
            relied_note_names.add(info_note.name)
            continue
        if not is_def_or_notat:
            continue
        relied_note_names.add(info_note.name)

        # For each info note with notations, see whether the notations resemble notations
        # used in `origin_note`.
        notat_notes: list[VaultNote] = notation_notes_linked_in_see_also_section(
            info_note, vault, as_vault_notes=True)
        for notat_note in notat_notes:
            if _already_predicted(notat_note.name):
                continue
            if similar_notat_notes_in_note(
                    origin_note, notat_note, threshold=notation_similarity_threshold):
                relied_note_names.add(notat_note.name)

        # TODO: For each info note with definitions, see whether the definitions resemble
        # phrases used in `origin_note`.
        # TODO: Automatically add a def/notat note if it precedes `origin_note` in a section
        # by a little.
    return relied_note_names



In [None]:
#| export
def _predict_one_direction_and_consolidate_cache(
        origin_note: VaultNote,
        relied_note: VaultNote,
        cache: dict[str, dict[str, list[NoteLinkEnum]]],
        predictor: MultiLabelPipeline, 
        format: Literal['bert', 't5'],
        note_data: dict[str, NoteData] | None,
        skip_already_made_predictions: bool,
        threshold: float | dict[str, float],
        ) -> None:
    """
    Predict the link types from `origin_note` to `relied_note` and merge them into `cache`.

    Helper function to `predict_on_relied_note_and_related_notat_notes`. Nothing is predicted when
    `skip_already_made_predictions` is `True` and `cache` already has an entry for the pair.
    """
    if skip_already_made_predictions and relied_note.name in cache.get(origin_note.name, {}):
        return
    predicted_links: dict[str, list[NoteLinkEnum]] = predict_note_linking(
        origin_note, relied_note, predictor, format, note_data, threshold=threshold)
    consolidate_note_linking_predictions_into_cache(origin_note, predicted_links, cache)

In [None]:
#| export
def predict_on_relied_note_and_related_notat_notes(
        origin_note: VaultNote,
        relied_note: VaultNote,
        cache: dict[str, dict[str, list[NoteLinkEnum]]], # The current cache of predictions, see `parse_link_cache_note` for example; this is used to skip predictions that have already been made. Moreover, the cache is updated based on the predictions made. 
        predictor: MultiLabelPipeline,
        format: Literal['bert', 't5'] = 'bert',
        note_data: Optional[dict[str, NoteData]] = None,
        skip_already_made_predictions: bool = True,
        predict_reverse_too: bool = False,
        threshold: float | dict[str, float] = 0.5,
        ) -> None:
    """
    Update `cache` with predictions from `origin_note` to `relied_note` (and, if
    `predict_reverse_too` is `True`, from `relied_note` to `origin_note`).

    Afterwards, for every relied note that `cache` records `origin_note` as linking to via
    `NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT` (not only `relied_note`), predictions are also made
    from `origin_note` to each notation note linked in that note's See Also section.

    If `skip_already_made_predictions` is `True`, pairs that already have an entry in `cache`
    are not predicted on again.
    """
    shared_args = (predictor, format, note_data, skip_already_made_predictions, threshold)
    _predict_one_direction_and_consolidate_cache(origin_note, relied_note, cache, *shared_args)
    if predict_reverse_too:
        _predict_one_direction_and_consolidate_cache(relied_note, origin_note, cache, *shared_args)

    if origin_note.name not in cache:
        return
    origin_links = cache[origin_note.name]
    # Iterate over a snapshot of the names: the predictions below add entries to `origin_links`.
    for linked_note_name in list(origin_links):
        link_types = origin_links[linked_note_name]
        if not link_types or NoteLinkEnum.INFO_TO_INFO_VIA_NOTAT not in link_types:
            continue
        linked_note = VaultNote(origin_note.vault, name=linked_note_name)
        notat_notes: list[VaultNote] = notation_notes_linked_in_see_also_section(
            linked_note, origin_note.vault, as_vault_notes=True)
        for notat_note in notat_notes:
            _predict_one_direction_and_consolidate_cache(
                origin_note, notat_note, cache, *shared_args)


## Identify notation notes that should be embedded as footnotes in information notes or linked in other notation notes

In [None]:
#| export
def similar_notat_notes_in_note(
        origin_note: VaultNote, # Either an info or a notat note
        notation_notes: VaultNote | list[VaultNote], # The notation note(s) whose notations are compared against the LaTeX in `origin_note`.
        threshold: float = 0.8, # Passed to `latex_str_is_likely_in_latex_str`.
        ) -> list[VaultNote]: # The notation notes whose notations are determined to be similar to notations used in `origin_note`.
    """
    Determine which notation notes introduce notations which resemble notations used
    in `origin_note`.

    A notation note matches if `latex_str_is_likely_in_latex_str` accepts its notation for at
    least one LaTeX string of `origin_note`. A notation note with the same name as `origin_note`
    is never returned. The output preserves the order of `notation_notes`.

    This is a fuzzy function purely based on the str value of the notation and the text of `origin_note` and does not use ML predictions. 
    """
    if isinstance(notation_notes, VaultNote):
        notation_notes = [notation_notes]

    text = origin_note.text()
    # Drop the surrounding dollar signs and spaces so that only the LaTeX content is compared.
    latex_texts_in_origin_note: list[str] = [
        text[start:end].strip('$ ') for start, end in latex_indices(text)]

    def _resembles_origin(notation_note) -> bool:
        notation: str = notation_in_note(notation_note, include_dollar_signs=False)
        return any(
            latex_str_is_likely_in_latex_str(notation, latex_text, threshold=threshold)
            for latex_text in latex_texts_in_origin_note)

    return [
        notation_note for notation_note in notation_notes
        if notation_note.name != origin_note.name and _resembles_origin(notation_note)]
        

In [None]:
#| TODO: test

In [None]:
#| export
def locate_footnote_embedded_notation_link(
        origin_note: VaultNote, # An info note 
        notation_note: VaultNote, # The notation note whose notation is looked for in `origin_note`.
        locate_by: Literal['first', 'best'] = 'best', # If `'first'`, then the first latex string whose `latex_str_in_latex_str_fuzz_metric` score exceeds `threshold` is used as the location. If `'best'` or if no such latex string exists, then the latex string giving the greatest score is used as the location.
        threshold: float = 0.8,
        ) -> int | None: # The index in `origin_note.text()` at which the footnote to an embedded link to `notation_note` should be added. `None` if the main note of `notation_note` is `origin_note` or if `origin_note` contains no latex strings.
    """
    Determine where in `origin_note` a footnote to an embedded link to `notation_note` should be added.

    Such a location would be at the end of the closing of a latex string in the text. 

    This is a fuzzy function purely based on the str value of the notation and the text of `origin_note` and does not use ML predictions. 
    """
    main_note = main_of_notation(notation_note, as_note=False)
    if main_note and origin_note.name == main_note:
        return None
    notation: str = notation_in_note(notation_note, include_dollar_signs=False)
    text = origin_note.text()
    # Keys are end indices of latex strs; values score how likely each latex str uses the notation.
    scores: dict[int, float] = {}
    for start, end in latex_indices(text):
        latex_text = text[start:end].strip('$ ')
        score: float = latex_str_in_latex_str_fuzz_metric(notation, latex_text)
        if locate_by == 'first' and score > threshold:
            return end
        scores[end] = score
    if not scores:
        # `origin_note` has no latex strings, so there is nowhere to put the footnote
        # (`max` would otherwise raise `ValueError` on the empty dict).
        return None
    return max(scores, key=scores.get)

In [None]:
mock_origin_note = MagicMock()
mock_notation_note = MagicMock()
mock_origin_note.text.return_value = r"""For each integer $m$ and each transitive $G \leq S_m$, there are constants $C(G), Q(G)$, and $e(G)$ such that, for all $q>Q(G)$ coprime to $\#G$ and all $X>0$, 

$$N_G(\mathbb{F}_q(t),X) \leq C(G) X^{a(G)} \log(X)^{e(G)}$$


"""

mock_notation_note.text.return_value = r"""---
detect_regex: 
latex_in_original: ["a(G)"]
tags: [_meta/notation_note_named]
---
$a(G)$ [[ellenberg_tran_westerland_fnfcqsamcff_1. Introduction_ellenberg_tran_westerland_fnfcqsamcff|denotes]] $[\min_{G \setminus \{1 \}} ind(g)]^{-1}$ where $G$ is a transitive subgroup of $S_m$ and 

![[ellenberg_tran_westerland_fnfcqsamcff_1. Introduction_ellenberg_tran_westerland_fnfcqsamcff#^38959b]]

for $g \in S_m$.

For instance, if $G = S_m$, the minimal index is $1$, realized by transpositions, and so $a(S_m) = 1$.
- [$ind(g)$](ellenberg_tran_westerland_fnfcqsamcff_notation_ind_g_index_of_element_of_S_m.md)"""

with (mock_patch('__main__.latex_str_in_latex_str_fuzz_metric')
        as mock_latex_str_in_latex_str_fuzz_metric,
        mock_patch('__main__.main_of_notation') as mock_main_of_notation,
        mock_patch('__main__.notation_in_note') as mock_notation_in_note,
        ):
    # Only the last of the 8 latex strings (the display math containing `a(G)`) scores highly.
    mock_latex_str_in_latex_str_fuzz_metric.side_effect = [0, 0, 0, 0, 0, 0, 0, 1]
    mock_main_note = MagicMock()
    mock_main_of_notation.return_value = mock_main_note
    mock_notation_in_note.return_value = '$a(G)$'
    output = locate_footnote_embedded_notation_link(
        mock_origin_note, mock_notation_note, locate_by='best', threshold=0.8)
    print(output)
    print(mock_origin_note.text.return_value[:output])
    # 222 is the index right after the closing `$$` of the display math.
    test_eq(output, 222)

222
For each integer $m$ and each transitive $G \leq S_m$, there are constants $C(G), Q(G)$, and $e(G)$ such that, for all $q>Q(G)$ coprime to $\#G$ and all $X>0$, 

$$N_G(\mathbb{F}_q(t),X) \leq C(G) X^{a(G)} \log(X)^{e(G)}$$


In [None]:
#| export
def _where_to_add_notation_links(
        origin_note: VaultNote,
        relied_notes: list[VaultNote], # Notation notes to locate in `origin_note`.
        locate_by: Literal['first', 'best'] = 'best', # See `locate_footnote_embedded_notation_link`.
        threshold: float = 0.8, # See `locate_footnote_embedded_notation_link`.
        ) -> dict[int, list[VaultNote]]: # Keys are end indices of latex strs in `origin_note`; values are the notation notes whose embedded links should be added there.
    """
    Helper function to `add_notation_note_embedded_footnotes_to_info_note`.

    Group `relied_notes` by the location returned by `locate_footnote_embedded_notation_link`,
    skipping notes for which it returns `None`. Within each group, the order of `relied_notes`
    is preserved.
    """
    grouped_by_location: dict[int, list[VaultNote]] = {}
    for notation_note in relied_notes:
        location: int | None = locate_footnote_embedded_notation_link(
            origin_note, notation_note, locate_by, threshold)
        if location is not None:
            grouped_by_location.setdefault(location, []).append(notation_note)
    return grouped_by_location

In [None]:
#| export
def _add_notation_note_embedded_footnotes(
        text: str,
        where_to_add: dict[int, list[VaultNote]], # Keys are end indices of latex strs in `text`; values are the notation notes to embed at those locations.
        ) -> str: # `text` with the footnote references and footnote definitions added.
    """
    Helper function to `add_notation_note_embedded_footnotes_to_info_note`.

    For each location in `where_to_add`, add footnote references for the corresponding notation
    notes, plus footnote definitions of the form `[^num]: ![[notation_note_name]]` after the line
    containing the location. Footnote numbers come from `identify_available_footnote_numbers`.

    If the latex str ending at a location ends with `$$` (display math), the footnote references
    are also put on their own line after that line; otherwise they are placed right after the
    closing `$`.
    """
    # Process locations from last to first so that insertions do not shift the indices of
    # locations that have yet to be processed.
    for location in sorted(where_to_add, reverse=True):
        notation_notes_to_link = where_to_add[location]
        available_footnote_numbers: list[int] = identify_available_footnote_numbers(
            text, count=len(notation_notes_to_link))
        footnote_text = ''.join(f'[^{num}]' for num in available_footnote_numbers)
        footnote_mentions = '\n'.join(
            f'[^{num}]: ![[{notat_note.name}]]'
            for num, notat_note in zip(available_footnote_numbers, notation_notes_to_link))
        # End of the line containing `location`. The search starts AT `location` so that a latex
        # str ending right before a newline uses its own line rather than the following one.
        end_of_line = text.find('\n', location)
        if end_of_line == -1:
            end_of_line = len(text)
        before, rest_of_line, after = text[:location], text[location:end_of_line], text[end_of_line:]
        if location > 1 and text[location-2] == '$':
            # latex str ends with '$$': start a new line for the footnote references and
            # another for the footnote definitions.
            text = f'{before}{rest_of_line}\n\n{footnote_text}\n\n{footnote_mentions}\n\n{after}'
        else:
            # latex str ends with '$': reference right after it, definitions after the line.
            text = f'{before}{footnote_text}{rest_of_line}\n\n{footnote_mentions}\n\n{after}'
    return text
    

In [None]:
#| hide
text = r"""$$asdf$$"""
mock_notat_note = MagicMock()
mock_notat_note.name = 'notat_note_name'
where_to_add = {8: [mock_notat_note]}
print(_add_notation_note_embedded_footnotes(text, where_to_add))

text = r"""asdf asdf $asdf$ asdf asdf 

fjfjfjfj
"""
mock_notat_note = MagicMock()
mock_notat_note.name = 'notat_note_name'
where_to_add = {16: [mock_notat_note]}
output = _add_notation_note_embedded_footnotes(text, where_to_add)
print(output)


text = r"""---
cssclass: clean-embeds
aliases: []
tags: [_meta/literature_note, _reference/18785, _meta/concept, _meta/proof]
---
# Topic[^1]

Theorem 2.1. The map $\mathrm{q} \mapsto \mathrm{q} \cap A$ defines a bijection from the set of prime ideals of $S^{-1} A$ and the set of prime ideals of A that do not intersect $S .$ The inverse map is $\mathfrak{p} \mapsto \mathfrak{p} S^{-1} A$.

Proof. See [1, Cor.11.20] or [2, Prop. 3.11.iv].

# See Also

# Meta
## References
![[_reference_18785]]

## Citations and Footnotes
[^1]: Sutherland, Theorem 2.1, Page 11"""

mock_notat_note.name = '18785_notation_S_minus_1_A_localization_of_a_commutative_ring_with_respect_to_a_multiplicative_subset'
where_to_add = {254: [mock_notat_note]}
output = _add_notation_note_embedded_footnotes(text, where_to_add)
print(output)


text = r"""

For each integer $m$ and each transitive $G \leq S_m$, there are constants $C(G), Q(G)$, and $e(G)$ such that, for all $q>Q(G)$ coprime to $\#G$ and all $X>0$, 

$$N_G(\mathbb{F}_q(t),X) \leq C(G) X^{a(G)} \log(X)^{e(G)}$$

blah blah

"""
mock_notat_note = MagicMock()
mock_notat_note.name = 'notat_note_name'
where_to_add = {224: [mock_notat_note]}
output = _add_notation_note_embedded_footnotes(text, where_to_add)
print(output)

$$asdf$$

[^1]

[^1]: ![[notat_note_name]]


asdf asdf $asdf$[^1] asdf asdf 

[^1]: ![[notat_note_name]]



fjfjfjfj

---
cssclass: clean-embeds
aliases: []
tags: [_meta/literature_note, _reference/18785, _meta/concept, _meta/proof]
---
# Topic[^1]

Theorem 2.1. The map $\mathrm{q} \mapsto \mathrm{q} \cap A$ defines a bijection from the set of prime ideals of $S^{-1} A$[^2] and the set of prime ideals of A that do not intersect $S .$ The inverse map is $\mathfrak{p} \mapsto \mathfrak{p} S^{-1} A$.

[^2]: ![[18785_notation_S_minus_1_A_localization_of_a_commutative_ring_with_respect_to_a_multiplicative_subset]]



Proof. See [1, Cor.11.20] or [2, Prop. 3.11.iv].

# See Also

# Meta
## References
![[_reference_18785]]

## Citations and Footnotes
[^1]: Sutherland, Theorem 2.1, Page 11


For each integer $m$ and each transitive $G \leq S_m$, there are constants $C(G), Q(G)$, and $e(G)$ such that, for all $q>Q(G)$ coprime to $\#G$ and all $X>0$, 

$$N_G(\mathbb{F}_q(t),X) \leq C(G) X^{a(G)} 

In [None]:
#| export
def add_notation_note_embedded_footnotes_to_info_note(
        origin_note: VaultNote, # An info note
        relied_notes: Optional[VaultNote | list[VaultNote]] = None, # notation notes to add embedded footnotes for.
        cache: Optional[dict[str, dict[str, list[NoteLinkEnum]]]] = None, # The cache from which to identify the notation notes to add embedded footnotes for.
        locate_by: Literal['first', 'best'] = 'best', # Passed to `_where_to_add_notation_links` to decide where each footnote goes.
        threshold: float = 0.8, # Passed to `_where_to_add_notation_links`.
        ):
    """
    Modify the contents of `origin_note` to add footnotes to embedded links to `relied_notes`.

    One of `relied_notes` or `cache` must be passed. If `relied_notes` is `None`,
    the notation notes are taken from `cache[origin_note.name]` (after filtering the
    cache with `remove_nonexistent_note_names_from_cache`). Only entries that are
    notation notes and whose link types include
    `NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING` are kept. If `origin_note.name` is
    not a key of `cache`, a message is printed and `origin_note` is left unchanged.

    Notation notes that are already embedded (as wikilinks) in `origin_note` are
    skipped, so existing embeddings are not duplicated.

    **Raises**
    - `ValueError`
        - if both `relied_notes` and `cache` are `None`.
    """
    if relied_notes is None and cache is None:
        raise ValueError("Expected `relied_notes` or `cache` to be specified, but both were `None`.")
    if relied_notes is None:
        if origin_note.name not in cache:
            print(f'`origin_note.name` is not in `cache`. `origin_note` is {origin_note}.')
            return
        cache = remove_nonexistent_note_names_from_cache(cache, origin_note.vault)
        relied_notes: list[VaultNote] = []
        for relied_note_name, link_enums in cache[origin_note.name].items():
            relied_note = VaultNote(origin_note.vault, name=relied_note_name)
            if not note_is_of_type(relied_note, PersonalNoteTypeEnum.NOTATION_NOTE):
                continue
            if NoteLinkEnum.INFO_TO_NOTAT_VIA_EMBEDDING in link_enums:
                relied_notes.append(relied_note)
    if isinstance(relied_notes, VaultNote):
        relied_notes = [relied_notes]

    # Read the note once. The same text is used both to detect existing
    # embeddings and as the base for the new text.
    origin_note_text = origin_note.text()
    # Try to only add embedded links to notation notes that do not already exist in `origin_note`
    embedded_links_in_text: list[ObsidianLink] = links_from_text(
        origin_note_text, ObsidianLink(
            is_embedded=True, file_name=-1, anchor=-1, custom_text=-1, link_type=LinkType.WIKILINK))
    embedded_note_names_in_text: set[str] = {link.file_name for link in embedded_links_in_text}
    new_relied_notes: list[VaultNote] = [
        relied_note for relied_note in relied_notes
        if relied_note.name not in embedded_note_names_in_text]

    where_to_add: dict[int, list[VaultNote]] = _where_to_add_notation_links(
        origin_note, new_relied_notes, locate_by, threshold)
    new_text = _add_notation_note_embedded_footnotes(origin_note_text, where_to_add)
    origin_note.write(new_text)


In [None]:
# origin_note = VaultNote(vault, name='18785_Theorem 2.1')
# notat_note = VaultNote(vault, name='18785_notation_S_minus_1_A_localization_of_a_commutative_ring_with_respect_to_a_multiplicative_subset')
# add_notation_note_embedded_footnotes_to_info_note(
#     origin_note, notat_note,
#     )

In [None]:
# print(origin_note.text())

# Notation Summarization using `NoteData` classes 

The `NoteData`-based functions above should come close to providing the means of gathering data for other ML tasks. Examples are the notation summarization task (currently handled by `25_markdown.obsidian.personal.machine_learning.notation.summarization.ipynb`) and the definition naming task (currently handled by `35_markdown.obsidian.personal.machine_learning.definition_and_notation_naming.ipynb`). They may even improve on those notebooks by supplying contextual data, namely the data of the notes linked by a given note.

In [None]:
#| export
class SummarizationDataPoint(TypedDict):
    """
    A single notation-summarization training example, as produced by `summarization_data`.
    """
    input: str  # The model input: the notation note's data string followed by data strings of related info notes, joined by a separator.
    output: str  # The target: the notation note's `note_content`.
    notat_note_name: str  # The name of the notation note the example was built from.

In [None]:
#| export
def summarization_data(
        notat_note_data_point: NotatNoteData,
        info_note_data: dict[str, InfoNoteData], # For getting data from the linked notes.
        notat_note_data: dict[str, NotatNoteData], # For getting data from the linked notes.
        format: Literal['bert', 't5'],
        augmentation: Optional[Literal['high', 'mid', 'low']] = None,
        ) -> SummarizationDataPoint:
    """
    Compile the summarization data from a `NotatNoteData`.

    `notat_note_data_point` must have a nonblank value for its `note_content` attribute.

    The summarization data consists of the notation note's main note content and (optionally)
    the position data (which is the position data of the main note) along with (also optionally)
    the content and/or position data of info notes that either the notation note or its
    main note depend on (via the `NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT` or 
    `NoteLinkEnum.INFO_TO_INFO_IN_CONTENT` enum items).

    The info notes are appended in a deterministic order: first those linked by the
    notation note, then those linked by its main note, each in the iteration order of
    the respective `directly_linked_notes`. Info note names absent from
    `info_note_data` are skipped. `notat_note_data` is not currently read.

    The augmentations are not applied to `notat_note_data_point` but are rather applied to
    (copies of) the info notes that the notation note or its main note depends on. Use
    `augment_notat_note_data_for_summarization` to augment that `NotatNoteData` object.

    **Raises**
    - `ValueError`
        - if `notat_note_data_point.note_content` is not a nonblank `str`.
    """
    content = notat_note_data_point.note_content
    if content is None or content.strip() == '':
        raise ValueError(
            f"Expected `notat_note_data_point.note_content` to be a non blank string but was {content!r}. "
            f"The relevant notation note name is {notat_note_data_point.note_name}.")
    # Temporarily blank out the `note_content` attribute so that `.data_string`
    # excludes it; restore it even if `.data_string` raises.
    notat_note_data_point.note_content = None
    try:
        notat_note_string = notat_note_data_point.data_string(format)
    finally:
        notat_note_data_point.note_content = content

    # A dict is used as an insertion-ordered set so that the order of the parts
    # (and of the random augmentations applied to them) does not depend on
    # string hashing, which varies between interpreter runs.
    info_note_names_to_consider: dict[str, None] = {}
    for relied_note_name, link_types in notat_note_data_point.directly_linked_notes.items():
        if NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT in link_types:
            info_note_names_to_consider[relied_note_name] = None
    main_note_name = notat_note_data_point.main_note
    if main_note_name and main_note_name in info_note_data:
        main_note_data_point = info_note_data[main_note_name]
        for relied_note_name, link_types in main_note_data_point.directly_linked_notes.items():
            if NoteLinkEnum.INFO_TO_INFO_IN_CONTENT in link_types:
                info_note_names_to_consider[relied_note_name] = None

    parts: list[str] = [notat_note_string]
    for relied_note_name in info_note_names_to_consider:
        if relied_note_name not in info_note_data:
            continue
        # Work on a copy so that augmentation never modifies `info_note_data`.
        relied_note_data_point = info_note_data[relied_note_name].deepcopy()
        if augmentation:
            erase_position_metadata = _erase_position_metadata(augmentation)
            relied_note_data_point.randomly_modify(augmentation, erase_position_metadata)
        parts.append(relied_note_data_point.data_string(format))

    sep_str = '\n\n[SEP]\n\n' if format == 'bert' else '\n\n</s>\n\n'
    return SummarizationDataPoint(
        input=sep_str.join(parts),
        output=content,
        notat_note_name=notat_note_data_point.note_name)



In [None]:
#| export
def augment_notat_note_data_for_summarization(
        notat_note_data_point: NotatNoteData,
        augmentation: Literal['high', 'mid' ,'low'],
        info_note_data: dict[str, InfoNoteData], # For getting data from the linked notes.
        modify_links: bool = True, # If `True`, randomly modify the linking data from that of `notat_note_data_point`
        ) -> NotatNoteData:
    """
    Return a modified copy of `notat_note_data_point` augmented for providing
    notation summarization data.

    `notat_note_data_point` itself is not modified. The `note_content` attribute of
    the returned copy is also kept as is, because it serves as the output of the
    training data.

    If `modify_links` is `True`, each entry of the copy's `directly_linked_notes`
    is deleted with a small probability. Then a few *distinct* info notes, chosen at
    random from `info_note_data`, are added as `NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT`
    links. Both the deletion probability and the number of added notes grow with
    the augmentation level.
    """
    notat_note_data_copy = notat_note_data_point.deepcopy()
    erase_position_metadata = _erase_position_metadata(augmentation)
    # `randomly_modify` may touch `note_content`; restore it since it is the training target.
    content = notat_note_data_copy.note_content
    notat_note_data_copy.randomly_modify(augmentation, erase_position_metadata)
    notat_note_data_copy.note_content = content

    if not modify_links:
        return notat_note_data_copy

    # augmentation level -> (number of random info notes to link, per-link deletion probability)
    link_augmentation_params: dict[str, tuple[int, float]] = {
        'high': (3, 0.10),
        'mid': (2, 0.05),
        'low': (1, 0.02),
    }
    num_rand_info_note_data_to_add, key_deletion_prob = link_augmentation_params.get(
        augmentation, (0, 0))

    # Delete "links" at random.
    for key in list(notat_note_data_copy.directly_linked_notes):
        if random.random() < key_deletion_prob:
            notat_note_data_copy.directly_linked_notes.pop(key)
    # Add "links" to info notes chosen at random. Sample without replacement so
    # that the same info note is not picked twice.
    random_info_note_names_to_add = random.sample(
        list(info_note_data), k=min(len(info_note_data), num_rand_info_note_data_to_add))
    for random_info_note_name in random_info_note_names_to_add:
        _update_dict(
            notat_note_data_copy.directly_linked_notes,
            random_info_note_name,
            NoteLinkEnum.NOTAT_TO_INFO_VIA_NOTAT)
    return notat_note_data_copy

In [None]:
#| export
def notat_note_data_admissible_for_summarization_data(
        notat_note_data_point: NotatNoteData
        ) -> bool:  # `True` if the notation note data does not have the `_auto/notation_summary` tag and the content of the notation note is not essentially blank.
    """
    Decide whether a notation note's data may be used as notation summarization data.

    Notes tagged `_auto/notation_summary` are rejected, as are notes whose
    `note_content` is `None` or consists only of whitespace.
    """
    tags = notat_note_data_point.tags
    if tags and '_auto/notation_summary' in tags:
        return False
    note_content = notat_note_data_point.note_content
    return note_content is not None and bool(note_content.strip())

In [None]:

# for name, data_point in notat_note_data.items():
#     if notat_note_data_admissible_for_summarization_data(data_point):
#         print(name)
#         break

# summ_data = summarization_data(data_point, info_note_data, notat_note_data, 'bert')
# print(summ_data['input'])


In [None]:
# notat_note_data_admissible_for_summarization_data(notat_note_data['achter_pries_imht_notation_T_bar_gamma_smooth_trielliptic_curves_with_inertia_type'])

In [None]:
#| export
def _add_augmented_data_points(
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        format: Literal['bert', 't5'],
        data_point: NotatNoteData,
        dict_data: list[SummarizationDataPoint],
        ):
    """
    Helper function to `summarization_dataset_from_note_data`.

    Append six augmented summarization data points for `data_point` to `dict_data`
    in place. For each augmentation level `'low'`, `'mid'` and `'high'` (in that
    order), two points are added:
    - one from `data_point` with its links randomly modified;
    - one from a copy of `data_point` whose `directly_linked_notes` is emptied,
      without link modification.
    """
    linkless_data_point = data_point.deepcopy()
    linkless_data_point.directly_linked_notes = {}
    variants = ((True, data_point), (False, linkless_data_point))
    for level in ('low', 'mid', 'high'):
        for should_modify_links, source_point in variants:
            augmented = augment_notat_note_data_for_summarization(
                source_point, level, info_note_data, should_modify_links)
            dict_data.append(
                summarization_data(augmented, info_note_data, notat_note_data, format, level))
    

In [None]:
#| export
def summarization_dataset_from_note_data(
        info_note_data: dict[str, InfoNoteData],
        notat_note_data: dict[str, NotatNoteData],
        augment: bool,
        format: Literal['bert', 't5'],
        ) -> Dataset:
    """
    Build a `Dataset` of notation summarization data from the given note data.

    Only the notation notes accepted by `notat_note_data_admissible_for_summarization_data`
    are used. Each contributes one unaugmented data point. If `augment` is `True`, it is
    followed by the augmented data points produced by `_add_augmented_data_points`.
    Data points appear in the iteration order of `notat_note_data`.
    """
    admissible_names = [
        name for name, point in notat_note_data.items()
        if notat_note_data_admissible_for_summarization_data(point)]
    records: list[SummarizationDataPoint] = []
    for name in admissible_names:
        point = notat_note_data[name]
        records.append(summarization_data(
            point, info_note_data, notat_note_data, format, augmentation=None))
        if augment:
            _add_augmented_data_points(info_note_data, notat_note_data, format, point, records)
    return Dataset.from_list(records)
