In [None]:
#| default_exp markdown.obsidian.personal.machine_learning.tokenize

# markdown.obsidian.personal.machine_learning.tokenize
> Functions for gathering and processing tokenization data and for using ML models trained with such data.

Previously, `trouver` only had functionalities for using ML models to identify newly introduced notations in text and for gathering data to train such models. Moreover, such models were merely classification models, and using these models to identify newly introduced notations involved a lot of computational redundancy.

This module aims to provide the same functionalities for both definitions and notations by training and using token classification models instead.

In [None]:
# TODO: Create a new module dedicated to definition and notation identification and move appropriate functions over there. 

In [None]:
#| export
import os 
from os import PathLike
from pathlib import Path
from typing import Union
import warnings

import bs4
import pandas as pd
from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerFast

from trouver.helper import add_HTML_tag_data_to_raw_text, add_space_to_lt_symbols_without_space, double_asterisk_indices, notation_asterisk_indices, replace_string_by_indices, remove_html_tags_in_text
from trouver.markdown.markdown.file import MarkdownFile, MarkdownLineEnum
from trouver.markdown.obsidian.personal.note_processing import process_standard_information_note
from trouver.markdown.obsidian.vault import VaultNote

In [None]:
from unittest import mock

from datasets import ClassLabel, Dataset, Features, Sequence, Value
from transformers import AutoTokenizer
from fastcore.test import *


## Gather ML data from information notes

In [None]:
#| export
def convert_double_asterisks_to_html_tags(
        text: str
        ) -> str:
    """
    Swap every double-asterisk-delimited span of `text` for an HTML tag.

    The spans are located with `double_asterisk_indices`, each span is
    turned into a tag by `_html_tag_from_double_ast`, and the tags are
    substituted back into `text` with `replace_string_by_indices`.
    """
    spans = double_asterisk_indices(text)
    html_tags = []
    for span_start, span_end in spans:
        # Each slice still carries its surrounding `**`.
        html_tags.append(_html_tag_from_double_ast(text[span_start:span_end]))
    return replace_string_by_indices(text, spans, html_tags)


def _html_tag_from_double_ast(
        double_ast_string: str # Starts and ends with double asts
        ) -> str:
    """
    Build the HTML tag for a string that is wrapped in double asterisks.

    The result is a `<span notation="">` tag when
    `notation_asterisk_indices` returns something truthy for
    `double_ast_string`, and a `<b definition="">` tag otherwise. In both
    cases the tag contains the string without its first two and last two
    characters.

    This is used in the `convert_double_asterisks_to_html_tags` function.
    """
    inner_text = double_ast_string[2:-2]
    is_notation = bool(notation_asterisk_indices(double_ast_string))
    if not is_notation:
        return f'<b definition="">{inner_text}</b>'
    return f'<span notation="">{inner_text}</span>'


In [None]:
print(convert_double_asterisks_to_html_tags("**hi**. Here is a notation **$asdf$**"))
test_eq(convert_double_asterisks_to_html_tags("**hi**. Here is a notation **$asdf$**"), '<b definition="">hi</b>. Here is a notation <span notation="">$asdf$</span>')


<b definition="">hi</b>. Here is a notation <span notation="">$asdf$</span>


In [None]:
#| export
def raw_text_with_html_tags_from_markdownfile(
        mf: MarkdownFile,
        vault: PathLike
        ) -> str:
    """
    Return the processed text of `mf` with double-asterisk markings
    converted to HTML tags.

    `mf` is first passed through `process_standard_information_note` with
    both `remove_double_asterisks` and `remove_html_tags` switched off.
    The string form of the result is then handed to
    `convert_double_asterisks_to_html_tags`.
    """
    processed_mf = process_standard_information_note(
        mf, vault, remove_double_asterisks=False, remove_html_tags=False)
    processed_text = str(processed_mf)
    return convert_double_asterisks_to_html_tags(processed_text)



In [None]:
#| hide

# TODO: 
# I want to make sure that footnotes are getting properly removed.
mf = MarkdownFile.from_string(
    r"""---
aliases: []
tags: []
---
# Something  

Some kind of potato[^2]

[^2]: Some footnote

[[link_to_note|Some link]]


# See Also
# Meta
## References and Citations
""") 
raw_text_with_html_tags_from_markdownfile(mf, None)


'Some kind of potato[^2]\n\n[^2]: Some footnote\n\nSome link\n'

In [None]:
#| hide
mf = MarkdownFile.from_string(
    r"""---
aliases: []
tags: []
---
# Galois group of a separable and normal finite field extension

Let $L/K$ be a separable and normal finite field extension. Its <b definition="Galois group of a separable and normal finite field extension">Galois group</b> <span notation="">$\operatorname{Gal}(L/K)$</span> is...

# Galois group of a separable and normal profinite field extension

In fact, the notion of a Galois group can be defined for profinite field extensions. Given a separable and normal profinite field extension $L/K$, say that
$L = \varinjlim_i L_i$ where $L_i/K$ are finite extensions. Its **Galois group** **$\operatorname{Gal}(L/K)$**

# See Also
# Meta
## References and Citations
""")

In the following example, let `mf` be the following `MarkdownFile`:

In [None]:
print(str(mf))

---
aliases: []
tags: []
---
# Galois group of a separable and normal finite field extension

Let $L/K$ be a separable and normal finite field extension. Its <b definition="Galois group of a separable and normal finite field extension">Galois group</b> <span notation="">$\operatorname{Gal}(L/K)$</span> is...

# Galois group of a separable and normal profinite field extension

In fact, the notion of a Galois group can be defined for profinite field extensions. Given a separable and normal profinite field extension $L/K$, say that
$L = \varinjlim_i L_i$ where $L_i/K$ are finite extensions. Its **Galois group** **$\operatorname{Gal}(L/K)$**

# See Also
# Meta
## References and Citations


The `raw_text_with_html_tags_from_markdownfile` function processes the `MarkdownFile` much in the same way as the `process_standard_information_note` function, except it 1. preserves HTML tags, and 2. replaces text surrounded by double asterisks `**` with HTML tags signifying whether the text displays a definition or a notation.

In the example below, note that the `vault` parameter is set to `None`; this is fine for this example because the `process_standard_information_note` function only needs a `vault` argument when embedded links need to be replaced with text (via the `MarkdownFile.replace_embedded_links_with_text` function), but `mf` has no embedded links.

In [None]:
print(raw_text_with_html_tags_from_markdownfile(mf, None))

Let $L/K$ be a separable and normal finite field extension. Its <b definition="Galois group of a separable and normal finite field extension">Galois group</b> <span notation="">$\operatorname{Gal}(L/K)$</span> is...

In fact, the notion of a Galois group can be defined for profinite field extensions. Given a separable and normal profinite field extension $L/K$, say that
$L = \varinjlim_i L_i$ where $L_i/K$ are finite extensions. Its <b definition="">Galois group</b> <span notation="">$\operatorname{Gal}(L/K)$</span>



In [None]:
#| hide
assert '**' not in raw_text_with_html_tags_from_markdownfile(mf, None)

In [None]:
#| export
# TODO: implement a measure to not get the definition identification data, e.g. by 
# detecting a `_auto/definition_identification` tag.
def html_data_from_note(
        note: VaultNote,
        vault: PathLike
        ) -> Union[dict, None]: # The keys to the dict are "Note name", "Raw text", "Tag data". However, `None` is returned if `note` does not exist or the note is marked with auto-generated, unverified data.
    """Collect HTML-tag data for token classification from an information note.

    At present the token types revolve around definitions and notations.

    `None` is returned when `note` does not exist. `None` is also returned
    when the note carries the tag `_auto/def_and_notat_identified`, in which
    case its markings are assumed to be auto-generated and unverified.

    **Returns**
    - Union[dict, None]
        - The key-value pairs are
            - `"Note name"` - The name of the note
            - `"Raw text"` - The raw text to include in the data.
            - `"Tag data"` - The list with HTML tags carrying definition/notation
              data and their locations in the Raw text. See the second output to
              the function `remove_html_tags_in_text`.
                - Each element of the list is a tuple consisting of a ``bs4.element.Tag``
                  and two ints.
    """
    # TODO: implement obtaining multiple datapoints from a single note
    # Via typos for example.
    if note.exists():
        mf = MarkdownFile.from_vault_note(note)
        if not mf.has_tag('_auto/def_and_notat_identified'):
            tagged_text = raw_text_with_html_tags_from_markdownfile(mf, vault)
            # Strip the tags while recording where each one sat in the text.
            raw_text, tags_and_locations = remove_html_tags_in_text(tagged_text)
            return {
                "Note name": note.name,
                "Raw text": raw_text,
                "Tag data": tags_and_locations}
    return None

In the following example, we mock a `VaultNote` whose content is that of `mf` in the example for the `raw_text_with_html_tags_from_markdownfile` function.

In [None]:
with (mock.patch('__main__.VaultNote') as mock_VaultNote,
      mock.patch('__main__.MarkdownFile.from_vault_note') as mock_from_vault_note):
    mock_VaultNote.exists.return_value = True
    mock_VaultNote.name = "Note's name"
    mock_from_vault_note.return_value = mf

    html_data = html_data_from_note(mock_VaultNote, None)
    print(html_data)


{'Note name': "Note's name", 'Raw text': 'Let $L/K$ be a separable and normal finite field extension. Its Galois group $\\operatorname{Gal}(L/K)$ is...\n\nIn fact, the notion of a Galois group can be defined for profinite field extensions. Given a separable and normal profinite field extension $L/K$, say that\n$L = \\varinjlim_i L_i$ where $L_i/K$ are finite extensions. Its Galois group $\\operatorname{Gal}(L/K)$\n', 'Tag data': [(<b definition="Galois group of a separable and normal finite field extension">Galois group</b>, 64, 76), (<span notation="">$\operatorname{Gal}(L/K)$</span>, 77, 102), (<b definition="">Galois group</b>, 330, 342), (<span notation="">$\operatorname{Gal}(L/K)$</span>, 343, 368)]}


In [None]:
#| export
def tokenize_html_data(
        html_locus: dict, # An output of `html_data_from_note`
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_length: int, # Max length for each sequence of tokens
        ner_tag_from_html_tag: callable, # takes in a bs4.element.Tag and outputs the ner_tag (as a string or `None`)
        label2id: dict[str, int], # The keys ner_tag's of the form f"I-{output}" or f"B-{output}" where `output` is an output of `ner_tag_from_html_tag`.
        default_label: str = "O", # The default label for the NER tagging.
        ) -> tuple[list[list[str]], list[list[int]]]: # The first list consists of the tokens and the second list consists of the named entity recognition tags.
    """Tokenize the html data produced by `html_data_from_note` and label the tokens.

    The raw text may be long, so it is encoded with
    `tokenizer.batch_encode_plus` with truncation to `max_length` and
    overflowing tokens returned; the text can therefore be spread over
    several sequences.

    Every token starts out with the id of `default_label`. For each entry
    of `html_locus["Tag data"]` whose tag `ner_tag_from_html_tag` maps to a
    non-`None` value, the tokens located for the entry's character range
    are relabelled by `_set_ner_ids_for_tag`.
    """
    encoding = tokenizer.batch_encode_plus(
        [html_locus["Raw text"]],
        max_length=max_length,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True)
    input_ids_per_seq = encoding['input_ids']

    default_id = label2id[default_label]
    ner_ids = [[default_id] * len(seq_input_ids)
               for seq_input_ids in input_ids_per_seq]

    for html_tag, char_start, char_end in html_locus["Tag data"]:
        ner_tag = ner_tag_from_html_tag(html_tag)
        if ner_tag is not None:  # `None` means the tag is not of relevant data.
            # `char_end - 1` is passed so that the lookup targets the last
            # character inside the tag rather than the one after it.
            first_token, last_token = _start_end_seqs_indices_for_html_tag(
                encoding, char_start, char_end - 1)
            _set_ner_ids_for_tag(
                ner_ids, *first_token, *last_token, label2id, ner_tag)

    tokens = [tokenizer.convert_ids_to_tokens(seq_input_ids)
              for seq_input_ids in input_ids_per_seq]
    return tokens, ner_ids


def _start_end_seqs_indices_for_html_tag(
        tokenized: BatchEncoding,
        tag_start_ind: int,
        tag_end_ind: int
        ) -> tuple[tuple[int, int], tuple[int, int]]: # The first tuple is `(a, b)` where `tokenized['input_ids'][a][b]` is the token corresponding to the start of the HTML tag's (raw) text. The second tuple is `(c, d)` where `tokenized['input_ids'][c][d]` is the token corresponding to the end of the HTML tag's (raw) text.
    """
    Locate the tokens at the character indices `tag_start_ind` and
    `tag_end_ind` by searching `tokenized['offset_mapping']`.

    Helper function to `tokenize_html_data`.
    """
    offsets = tokenized['offset_mapping']

    def locate(char_ind: int) -> tuple[int, int]:
        # First find the sequence, then the token within that sequence.
        seq_ind = _search_seq_ind_for_char(offsets, char_ind)
        return seq_ind, _search_within_seq_for_char(offsets[seq_ind], char_ind)

    return locate(tag_start_ind), locate(tag_end_ind)


def _min_max_char_ind_for_seq(
        offset_for_seq: list[tuple[int,int]] # An item in tokenized['offset_mapping']
        ):
    min_char_ind, max_char_ind = 0, 0
    for inds in offset_for_seq:
        if inds != (0,0):
            min_char_ind = inds[0]
            break
    for inds in reversed(offset_for_seq):
        if inds != (0,0):
            max_char_ind = inds[1]
            break
    return min_char_ind, max_char_ind

def _char_is_in_seq(
        offset_for_seq: list[int], # An item in tokenized['offset_mapping']
        char: int # The index of a character in the original raw text
        ) -> bool:
    """
    Tell whether `char` lies in the half-open range spanned by the values
    that `_min_max_char_ind_for_seq` reports for `offset_for_seq`.
    """
    first_char, end_char = _min_max_char_ind_for_seq(offset_for_seq)
    return first_char <= char < end_char

def _search_seq_ind_for_char(
        offsets: list[tuple[int, int]], # tokenized['offset_mapping']
        char: int # The index of a character in the original raw text
        ) -> int:
    """
    Binary search for the index of the sequence whose character range,
    as reported by `_min_max_char_ind_for_seq`, contains the index `char`
    of the original (raw) text.

    Returns -1 when no sequence's range contains `char`.

    Based on pseudocode from https://pseudoeditor.com/guides/binary-search
    """
    lo, hi = 0, len(offsets) - 1
    while lo <= hi:
        middle = (lo + hi) // 2
        seq_first_char, seq_end_char = _min_max_char_ind_for_seq(offsets[middle])
        if char >= seq_end_char:
            # `char` lies past this sequence; continue in the upper half.
            lo = middle + 1
        elif char < seq_first_char:
            # `char` lies before this sequence; continue in the lower half.
            hi = middle - 1
        else:
            return middle
    return -1  # This should not be returned under normal use.


def _search_within_seq_for_char(
        seq_offset: list[tuple[int, int]],
        char: int
    ) -> int:
    """
    Binary search for the index within the sequence corresponding
    to the token at the location of the index `char` within the
    original (raw) text.

    Based on pseudocode from https://pseudoeditor.com/guides/binary-search
    """
    left = 0
    right = len(seq_offset) - 1
    while left <= right:
        mid = (left + right) // 2
        min_char_ind, max_char_ind = seq_offset[mid] 
        if min_char_ind <= char and char < max_char_ind:
            return mid
        elif max_char_ind <= char:
            left = mid + 1
        else:
            right = mid - 1
    return -1  # This should not be returned under normal use.


def _set_ner_ids_for_tag(
        ner_ids: list[list[int]],
        start_seq: int, 
        start_index_in_seq: int,
        end_seq: int,
        end_index_in_seq: int,
        label2id: dict[str, int],
        ner_tag: str
        ) -> None:
    """
    After the locations of the tokens corresponding to a HTML tag have been found, 
    mark within `ner_ids` the appropriate NER tags at the locations corresponding
    to the tokens' locations.
    """
    ner_ids[start_seq][start_index_in_seq] = label2id[f"B-{ner_tag}"]
    i_ner_id = label2id[f"I-{ner_tag}"]
    seq, ind = start_seq, start_index_in_seq + 1
    while seq < end_seq or ind <= end_index_in_seq:
        if len(ner_ids[seq]) <= ind:
            seq += 1
            ind = 0
        else:
            ner_ids[seq][ind] = i_ner_id 
            ind += 1
    


def def_or_notat_from_html_tag(
        tag: bs4.element.Tag
        ) -> Union[str, None]:
    """
    Classify an HTML tag as `"definition"`, `"notation"`, or neither.

    Can be passed as the `ner_tag_from_html_tag` argument in `tokenize_html_data`
    for the purposes of compiling a dataset for definition and notation
    identification.

    For a non-`None` output, the strings f"I-{output}" and f"B-{output}"
    are the corresponding ner_tags. `"definition"` takes precedence when
    `tag` has both attributes.
    """
    for ner_tag in ("definition", "notation"):
        if ner_tag in tag.attrs:
            return ner_tag
    return None  # If the HTML tag carries neither definition nor notation data.

In [None]:
#| hide
test_eq(_min_max_char_ind_for_seq([(0,0), (1,3), (3,4), (4,7), (7,15), (0,0)]), (1,15))

offsets = [[(0,0), (0,3), (4,5), (5,6), (6,7), (7,8), (8,9),],
           [(10,12), (13,14), (15,18), (18,24)],
           [(25,28), (29,35), (36,42), ]]
test_eq(_search_seq_ind_for_char(offsets, 0), 0)
test_eq(_search_seq_ind_for_char(offsets, 1), 0)
test_eq(_search_seq_ind_for_char(offsets, 5), 0)
test_eq(_search_seq_ind_for_char(offsets, 8), 0)
# I don't think that character index 9 is something that I need to worry about.
test_eq(_search_seq_ind_for_char(offsets, 10), 1)
test_eq(_search_seq_ind_for_char(offsets, 23), 1)
test_eq(_search_seq_ind_for_char(offsets, 25), 2)
test_eq(_search_seq_ind_for_char(offsets, 41), 2)

We continue with an example using the HTML data from the example for the `html_data_from_note` function.

In [None]:
html_data['Raw text']
html_data["Tag data"]

[(<b definition="Galois group of a separable and normal finite field extension">Galois group</b>,
  64,
  76),
 (<span notation="">$\operatorname{Gal}(L/K)$</span>, 77, 102),
 (<b definition="">Galois group</b>, 330, 342),
 (<span notation="">$\operatorname{Gal}(L/K)$</span>, 343, 368)]

In [None]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
label2id = {
    "O": 0,
    "B-definition": 1,
    "I-definition": 2,
    "B-notation": 3,
    "I-notation": 4
}
tokens, ner_tag_ids = tokenize_html_data(html_data, tokenizer, 510, def_or_notat_from_html_tag, label2id)

For this example, `max_length` is set to 510 (tokens). The string ("Raw text") is not very long, so only one sequence should be present.

In [None]:
test_eq(len(tokens), 1)
test_eq(len(ner_tag_ids), 1)

Now let us see what has been tagged:

In [None]:
id2label = {value: key for key, value in label2id.items()}
id2label

{0: 'O',
 1: 'B-definition',
 2: 'I-definition',
 3: 'B-notation',
 4: 'I-notation'}

In [None]:
for token, ner_tag in zip(tokens[0], ner_tag_ids[0]):
    if ner_tag != 0:
        print(f"{token}\t\t{id2label[ner_tag]}")

gal		B-definition
##ois		I-definition
group		I-definition
$		B-notation
\		I-notation
operator		I-notation
##name		I-notation
{		I-notation
gal		I-notation
}		I-notation
(		I-notation
l		I-notation
/		I-notation
k		I-notation
)		I-notation
$		I-notation
gal		B-definition
##ois		I-definition
group		I-definition
$		B-notation
\		I-notation
operator		I-notation
##name		I-notation
{		I-notation
gal		I-notation
}		I-notation
(		I-notation
l		I-notation
/		I-notation
k		I-notation
)		I-notation
$		I-notation


Let us set `max_length` to be shorter to observe an example of a tokenization of a single text across multiple sequences (Of course, in practice, the max token length would be set to be longer, say around 512 or 1024.):

In [None]:
token_ids, ner_tag_ids = tokenize_html_data(html_data, tokenizer, 20, def_or_notat_from_html_tag, label2id)

In [None]:
print(len(token_ids))
print(len(ner_tag_ids))

7
7


In [None]:
ner_tag_ids

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2],
 [2, 2, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 3, 4, 4],
 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0]]

## Gathering data 

The following is sample code to then gather data for definition/notation identification

In [None]:
#| notest

# TODO: test

notes = [] # Replace with actual notes
vault = '' # Replace with actual vault

# `html_data_from_note` returns `None` for notes that do not exist or that are
# tagged as auto-generated; drop those so that `tokenize_html_data` only ever
# receives dicts.
html_data = [
    html_locus
    for html_locus in (html_data_from_note(note, vault) for note in notes)
    if html_locus is not None]
max_length = 1022

tokenized_html_data = [tokenize_html_data(html_locus, tokenizer, max_length, def_or_notat_from_html_tag, label2id) for html_locus in html_data]
token_id_data = [token_ids for token_ids, _ in tokenized_html_data]
ner_tag_data = [ner_tag_ids for _, ner_tag_ids in tokenized_html_data]
# Flatten the per-note lists of sequences. In a nested comprehension the outer
# loop must be written first; otherwise the inner iterable is looked up as a
# (possibly stale or undefined) global.
token_seqs = [token_seq for seqs_of_note in token_id_data for token_seq in seqs_of_note]
ner_tag_seqs = [ner_tag_seq for seqs_of_note in ner_tag_data for ner_tag_seq in seqs_of_note]

In [None]:
#| notest
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
max_length = 1022
label2id = {
    "O": 0,
    "B-definition": 1,
    "I-definition": 2,
    "B-notation": 3,
    "I-notation": 4
} 
id2label = {value: key for key, value in label2id.items()}

In [None]:
#| notest
note_names, token_seqs, ner_tag_seqs = [], [], []
for html_locus, (token_ids, ner_tag_ids) in zip(html_data, tokenized_html_data):
    note_names.extend([html_locus["Note name"]] * len(token_ids))
    token_seqs.extend(token_ids)
    ner_tag_seqs.extend(ner_tag_ids)

In [None]:
#| notest
# ner_tags = ClassLabel(names=list(label2id))

# ds = Dataset.from_dict(
#         {"note_name": note_names,
#         "tokens": token_ids,
#         "ner_tags": ner_tag_ids},
#         features=Features(
#             {
#              "note_name": Value(dtype='string'),
#              "tokens": Sequence(Value(dtype='string')),
#              "ner_tags": Sequence(ner_tags)}
#         ))

# ds.save_to_disk(".")

# ds.load_from_disk(".")
    

## Use the trained model

See https://huggingface.co/docs/transformers/tasks/token_classification for training a token classification model.

In [None]:
# Helper functions

In [None]:
#| export
def _html_tag_data_from_part(
        main_text: str,
        part: list[dict[str]]) -> tuple[bs4.element.Tag, int, int]:
    """
    Turn one group of token predictions into an HTML tag and the character
    range of `main_text` that the tag covers.

    The range runs from the `'start'` of the first token of `part` to the
    `'end'` of its last token. The first token's `'entity'` (with its
    two-character prefix removed) decides whether a definition tag or a
    notation tag is produced.

    Helper function to `_html_tags_from_token_preds`
    """
    start_char = part[0]['start']
    end_char = part[-1]['end']
    # the `'entity'` is either 'I-definition', 'B-definition', 'I-notation',
    # or 'B-notation'
    entity_type = part[0]['entity'][2:]
    html_text = main_text[start_char:end_char]
    # NOTE(review): `html_text` is interpolated into the markup without any
    # escaping before it is parsed.
    if entity_type == 'definition':
        markup = f'<b definition="">{html_text}</b>'
    else:
        markup = (
            f'<span style="border-width:1px;border-style:solid;'
            f'padding:3px" notation="">{html_text}</span>')
    return (bs4.BeautifulSoup(markup, "html.parser"), start_char, end_char)


In [None]:
#| hide

main_text = "Let $I \subset A$ be an ideal. Define its radical by $\sqrt{I}$"

sample_output_1 = _html_tag_data_from_part(
    main_text, [{
        'entity': 'B-definition',
        'score': 0.37319255,
        'index': 25,  # This is moot for the purposes of this test.
        'word': 'radical',
        'start': 42,
        'end': 49
    }])
test_eq(str(sample_output_1[0]), '<b definition="">radical</b>')

sample_output_2 = _html_tag_data_from_part(
    main_text, [{
        'entity': 'B-notation',
        'score': 0.67021805,
        'index': 27,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 53,
        'end': 54},
    {
        'entity': 'I-notation',
        'score': 0.9748327,
        'index': 28,  # This is moot for the purposes of this test.
        'word': '\\',
        'start': 54,
        'end': 55},
    {
        'entity': 'I-notation',
        'score': 0.9754836,
        'index': 29,  # This is moot for the purposes of this test.
        'word': 'sq',
        'start': 55,
        'end': 57},
    {
        'entity': 'I-notation',
        'score': 0.9750675,
        'index': 30,  # This is moot for the purposes of this test.
        'word': '##rt',
        'start': 57,
        'end': 59},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 31,  # This is moot for the purposes of this test.
        'word': '{',
        'start': 59,
        'end': 60},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 32,  # This is moot for the purposes of this test.
        'word': 'i',
        'start': 60,
        'end': 61},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 33,  # This is moot for the purposes of this test.
        'word': '}',
        'start': 61,
        'end': 62},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 34,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 62,
        'end': 63},
    ])
test_eq(str(sample_output_2[0]), '<span notation="" style="border-width:1px;border-style:solid;padding:3px">$\sqrt{I}$</span>')
# main_text.find('radical')

In [None]:
#| export
def _current_token_continues_the_previous_token(
        current_token: dict, previous_token: dict, note: VaultNote
        ) -> bool:
    """
    Helper function to `_divide_token_preds_into_parts`.
    """
    if current_token['entity'].startswith('I-'):
        if current_token['entity'][2:] == previous_token['entity'][2:]:
            return True
        else:
            warnings.warn(rf"""
                In the note {note.name} at {note.path()},
                The token '{previous_token['word']}' is marked as '{previous_token['entity']}'
                and the subsequent token '{current_token['word']}' is marked as '{current_token['entity']}',
                which is unusual because the two consecutive tokens seem to be of different
                entities, and yet the latter token does not start with a 'B-'.

                The latter token will be treated like the beginning of a new entity."""
                    )
            return False
    else:
        return False
        

In [None]:
#| hide
previous_token_1 = {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 33,  # This is moot for the purposes of this test.
        'word': '}',
        'start': 61,
        'end': 62}
current_token_1 = {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 34,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 62,
        'end': 63}
assert _current_token_continues_the_previous_token(current_token_1, previous_token_1, note=VaultNote('', rel_path='hi'))

# Something like below should hopefully not happen, but it should still give a warning message
previous_token_2 = {
        'entity': 'I-definition',
        'score': 0.97785944,
        'index': 33,  # This is moot for the purposes of this test.
        'word': '}',
        'start': 61,
        'end': 62}
current_token_2 = {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 34,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 62,
        'end': 63}

with warnings.catch_warnings(record=True) as w:
    sample_output = _current_token_continues_the_previous_token(current_token_2, previous_token_2, note=VaultNote('', rel_path='hi'))
    assert w
    assert not sample_output

previous_token_3 = {
        'entity': 'I-definition',
        'score': 0.97785944,
        'index': 33,  # This is moot for the purposes of this test.
        'word': '##tion',
        'start': 58,
        'end': 62}
current_token_3 = {
        'entity': 'B-notation',
        'score': 0.97785944,
        'index': 34,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 62,
        'end': 63}

assert not _current_token_continues_the_previous_token(current_token_3, previous_token_3, note=VaultNote('', rel_path='hi'))



In [None]:
#| export
def _divide_token_preds_into_parts(
        token_preds: list[dict[str]],
        note: VaultNote,
        excessive_space_threshold: int
        ) -> list[list[dict[str]]]:
    """
    Divide `token_preds` into parts so that each part
    represents a single definition/notation marking.

    Helper function to `_html_tags_from_token_preds`.
    """
    token_preds_parts = []
    for current_token in token_preds:
        if not token_preds_parts:
            token_preds_parts.append([current_token])
            continue
        prev_token = token_preds_parts[-1][-1]
        if _current_token_continues_the_previous_token(
                current_token, prev_token, note):
            prev_token_end = prev_token['end']
            cur_token_start = current_token['start']
            if prev_token_end + excessive_space_threshold >= cur_token_start:
                Warning(rf"""
                    In the note {note.name} at {note.path()},
                    There seems to be excessive space between the token
                    {prev_token['word']} and {current_token['word']}, which
                    seem to be part of the same entity"""
                        )
            token_preds_parts[-1].append(current_token)
        else:
            token_preds_parts.append([current_token])
    return token_preds_parts


In [None]:
#| hide

main_text = "Let $I \subset A$ be an ideal. Define its radical by $\sqrt{I}$"

preds = [
    {
        'entity': 'B-definition',
        'score': 0.37319255,
        'index': 25,  # This is moot for the purposes of this test.
        'word': 'radical',
        'start': 42,
        'end': 49
    },
    {
        'entity': 'B-notation',
        'score': 0.67021805,
        'index': 27,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 53,
        'end': 54},
    {
        'entity': 'I-notation',
        'score': 0.9748327,
        'index': 28,  # This is moot for the purposes of this test.
        'word': '\\',
        'start': 54,
        'end': 55},
    {
        'entity': 'I-notation',
        'score': 0.9754836,
        'index': 29,  # This is moot for the purposes of this test.
        'word': 'sq',
        'start': 55,
        'end': 57},
    {
        'entity': 'I-notation',
        'score': 0.9750675,
        'index': 30,  # This is moot for the purposes of this test.
        'word': '##rt',
        'start': 57,
        'end': 59},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 31,  # This is moot for the purposes of this test.
        'word': '{',
        'start': 59,
        'end': 60},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 32,  # This is moot for the purposes of this test.
        'word': 'i',
        'start': 60,
        'end': 61},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 33,  # This is moot for the purposes of this test.
        'word': '}',
        'start': 61,
        'end': 62},
    {
        'entity': 'I-notation',
        'score': 0.97785944,
        'index': 34,  # This is moot for the purposes of this test.
        'word': '$',
        'start': 62,
        'end': 63},
    ]

output = _divide_token_preds_into_parts(
    preds, VaultNote('', rel_path='hi'), excessive_space_threshold=2
)

# Test that the list finds two parts, one for the definition, and the other for the notation.
test_eq(len(output), 2)
test_eq(len(output[0]), 1)
test_eq(len(output[1]), len(preds) - 1)


In [None]:
#| export
def _html_tags_from_token_preds(
        main_text: str,
        token_preds: list[dict[str]],
        note: VaultNote,
        excessive_space_threshold: int
        ) -> list[tuple[bs4.element.Tag, int, int]]:  # Tag element, start, end, where main_text[start:end] needs to be replaced by the tag element.
    """
    Build the HTML tag data for the definitions and notations that
    `token_preds` predicts within `main_text`.

    The predictions are first grouped into parts by
    `_divide_token_preds_into_parts`; each part is then converted, in
    order, by `_html_tag_data_from_part`.

    Helper function to `auto_mark_def_and_notats`.
    """
    tag_data = []
    for part in _divide_token_preds_into_parts(
            token_preds, note, excessive_space_threshold):
        tag_data.append(_html_tag_data_from_part(main_text, part))
    return tag_data


In [None]:
#| export
def _collate_html_tags(
        tag_data_1: list[tuple[bs4.element.Tag, int, int]],
        tag_data_2: list[tuple[bs4.element.Tag, int, int]],
    ) -> list[tuple[bs4.element.Tag, int, int]]:
    """
    Collates the lists of HTML tags and the indices within a certain text
    (which is not needed for this function and hence not included)
    that the HTML tags need to replace.

    Each entry is a tuple `(tag, start, end)`. The two lists are merged
    in a single front-to-back pass that compares start indices, so the
    output is ordered by start index only if each input list already is;
    this is not verified here.

    If there are entries in `tag_data_1` and `tag_data_2` with overlapping
    ranges (as determined by `_ranges_overlap`), then the entry from
    `tag_data_1` is prioritized and the entry from `tag_data_2` is
    discarded.

    Helper function to `auto_mark_def_and_notats`
    """
    collated_list = []
    i, j = 0, 0
    while i < len(tag_data_1) and j < len(tag_data_2):
        current_1 = tag_data_1[i]
        current_2 = tag_data_2[j]
        if _ranges_overlap(current_1, current_2):
            # `tag_data_1` takes priority: drop the entry from `tag_data_2`
            # and compare `current_1` against the next one.
            j += 1
        elif current_1[1] > current_2[1]:
            collated_list.append(current_2)
            j += 1
        else:
            collated_list.append(current_1)
            i += 1
    # The loop above stops once one list is exhausted, so at most one of
    # these slices is non-empty; the leftovers are appended as they are.
    collated_list.extend(tag_data_1[i:])
    collated_list.extend(tag_data_2[j:])
    return collated_list



def _ranges_overlap(
        current_1: tuple[bs4.element.Tag, int, int],
        current_2: tuple[bs4.element.Tag, int, int]
        ) -> bool:
    """
    Return `True` if the ranges of the two entries overlap.

    Only the start (index 1) and end (index 2) of each entry are looked
    at. The comparison is strict, so ranges that merely touch, e.g.
    `(3, 8)` and `(8, 9)`, do not count as overlapping.

    Based on https://stackoverflow.com/a/64745177

    Helper function to `_collate_html_tags`.
    """
    latest_start = max(current_1[1], current_2[1])
    earliest_end = min(current_1[2], current_2[2])
    return latest_start < earliest_end


In [None]:
#| hide

# In actuality, there should be bs4.element.Tag objects in place of ''.
# `_ranges_overlap` only looks at the start and end indices of each entry.
assert _ranges_overlap(('', 3, 8), ('', 6, 12))  # Partial overlap
assert _ranges_overlap(('', 3, 8), ('', 3, 4))  # One range lies within the other
assert not _ranges_overlap(('', 3, 8), ('', 8, 9))  # Ranges that merely touch do not overlap
assert _ranges_overlap(('', 6, 12), ('', 3, 8))  # Same as the first case with the arguments swapped

# Entries of `tag_data_1` take priority over overlapping entries of `tag_data_2`.
tag_data_1 = [
    ('', 0, 1),
    ('', 9, 12),
    ('', 20, 21)
]

tag_data_2 = [
    ('', 2, 4),
    ('', 6, 7),
    ('', 8, 10), # This should be discarded; it overlaps with ('', 9, 12) from `tag_data_1`
    ('', 10, 13), # This should be discarded; it overlaps with ('', 9, 12) from `tag_data_1`
    ('', 17, 20),
    ('', 21, 24)
]
output = _collate_html_tags(tag_data_1, tag_data_2)
# The surviving entries of both lists come out merged in order of their start indices.
test_eq(output, [('', 0, 1), ('', 2, 4), ('', 6, 7), ('', 9, 12), ('', 17, 20), ('', 20, 21), ('', 21, 24)])

In [None]:
#| export
def _add_nice_boxing_attrs_to_notation_tags(
        html_tag_data: list[tuple[bs4.element.Tag, int, int]]
        ) -> list[tuple[bs4.element.Tag, int, int]]:
    """
    Give notation tags a `style` attribute that draws a box around them.

    A tag is treated as a notation tag if it has a `notation` attribute.
    Notation tags that already have a `style` attribute are left as they
    are. The tags are modified in place; the returned list is a new list
    holding the same tags along with their start and end indices.

    Helper function to `auto_mark_def_and_notats`.
    """
    boxed_tag_data = []
    for tag, start, end in html_tag_data:
        attrs = tag.attrs
        if 'notation' in attrs:
            # `setdefault` does not touch a `style` attribute that is already set.
            attrs.setdefault(
                'style', "border-width:1px;border-style:solid;padding:3px")
        boxed_tag_data.append((tag, start, end))
    return boxed_tag_data



In [None]:
#| hide
# Create a `span` tag with a `notation` attribute and no `style` attribute.
soup = bs4.BeautifulSoup('', "html.parser")
tag = soup.new_tag("span", notation="")
tag.string = 'hi'
tag_data = [
    (tag, 0, 2),
]
output = _add_nice_boxing_attrs_to_notation_tags(tag_data)
# The notation tag should now carry the boxing `style` attribute.
assert "style" in output[0][0].attrs

In [None]:
#| export
def auto_mark_def_and_notats(
        note: VaultNote,
        pipeline,
        # remove_existing_def_and_notat_markings: bool = False,  # If `True`, remove definition and notation markings (both via surrounding by double asterisks `**` as per the legacy method and via HTML tags)
        excessive_space_threshold: int = 2,
        add_boxing_attr_to_existing_notat_markings: bool = True # If `True`, then nice attributes are added to the existing notation HTML tags, if not already present.
    ) -> None:
    """
    Predict and mark where definitions and notation occur in a note.

    The text between the metadata of the note and its `See Also` heading
    is passed to `pipeline`, the predicted definitions and notations are
    marked in it with HTML tags, and the result is written back to
    `note` with the tag `_auto/def_and_notat_identified` added.

    Assumes that the note is a standard information note that does not
    have a lot of "user modifications", such as footnotes, links,
    and HTML tags. If
    there are many modifications, then these might be deleted.

    Existing markings for definition and notation data (i.e. by
    surrounding with double asterisks or by HTML tags) are preserved
    and turned into HTML tags. If a prediction overlaps with an existing
    marking, then the existing marking is kept and the prediction is
    discarded.

    
    **Raises**
    Warning messages (`UserWarning`) are printed in the following situations:

    - There are two consecutive tokens within the `pipeline`'s predictions
      of different entity types (e.g. one is predicted to belong within a
      definition and the other within a notation), but the latter token's
      predicted `'entity'` more specifically begins with `'I-'` (i.e. is
      `'I-definition'` or `'I-notation'`) as opposed to `'B-'`.
        - `note`'s name, and path are included in the warning message in
          this case.
    - There are two consecutive tokens within the `pipeline`'s predictions
      which the pipeline predicts to belong to the same entity, and yet
      there is excessive space (specified by `excessive_space_threshold`)
      between the end of the first token and the start of the second.

    """
    markdown_file = MarkdownFile.from_vault_note(note)
    metadata_span = markdown_file.metadata_lines()
    # The main text starts on the line right after the metadata, if any.
    body_start = 0 if metadata_span is None else metadata_span[1] + 1
    body_end = markdown_file.get_line_number_of_heading('See Also')

    # Turn every existing marking into an HTML tag, then pull all HTML
    # tags out of the text while keeping track of where they were.
    raw_text = markdown_file.text_of_lines(body_start, body_end)
    spaced_text = add_space_to_lt_symbols_without_space(raw_text)
    tagged_text = convert_double_asterisks_to_html_tags(spaced_text)
    plain_text, existing_tags = remove_html_tags_in_text(tagged_text)
    if add_boxing_attr_to_existing_notat_markings:
        existing_tags = _add_nice_boxing_attrs_to_notation_tags(existing_tags)

    predicted_tags = _html_tags_from_token_preds(
        plain_text, pipeline(plain_text), note, excessive_space_threshold)
    # The existing tags are passed first so that they take priority over
    # overlapping predictions.
    tags_to_restore = _collate_html_tags(existing_tags, predicted_tags)
    marked_text = add_HTML_tag_data_to_raw_text(plain_text, tags_to_restore)

    # Swap the old main text for the newly marked one.
    markdown_file.remove_lines(body_start, body_end)
    markdown_file.insert_line(
        body_start, {'type': MarkdownLineEnum.DEFAULT, 'line': marked_text})
    # mf.insert_line(first_non_metadata_line,
    #                {'type': MarkdownLineEnum.HEADING, 'line': '# Topic[^1]'})
    markdown_file.add_tags('_auto/def_and_notat_identified')
    markdown_file.write(note)

