Python: adding external tokenizer to python text_chunker #1388

Merged: 4 commits, Jul 8, 2023
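This PR threads an optional `token_counter` callable through all of the chunking entry points, so callers can plug in a real tokenizer instead of the built-in length-based estimate (`len(text) // 4`). As a minimal sketch of how a caller might use the new parameter, assuming a tiktoken-based counter (tiktoken is only an illustration and is not part of this PR; any `Callable[[str], int]` works):

import tiktoken  # illustrative tokenizer choice, not required by the PR
from semantic_kernel.text.text_chunker import split_plaintext_lines

# Count tokens with a real encoding instead of the default character-based estimate.
encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    return len(encoding.encode(text))

lines = split_plaintext_lines(
    text="This is a test of the emergency broadcast system. This is only a test.",
    max_token_per_line=8,
    token_counter=count_tokens,
)

The same keyword argument is accepted by split_markdown_lines, split_plaintext_paragraph, and split_markdown_paragraph, all of which fall back to the internal _token_counter when it is omitted.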
221 changes: 152 additions & 69 deletions python/semantic_kernel/text/text_chunker.py
@@ -5,7 +5,8 @@
For markdown, split looking at punctuation first, and so on.
"""
import os
from typing import List
import re
from typing import Callable, List, Tuple

NEWLINE = os.linesep

@@ -36,46 +37,95 @@
]


def split_plaintext_lines(text: str, max_token_per_line: int) -> List[str]:
def _token_counter(text: str) -> int:
"""
Count the number of tokens in a string.

TODO: chunking methods should be configurable to allow for different
tokenization strategies depending on the model to be called.
For now, we use an extremely rough estimate.
"""
return len(text) // 4


def split_plaintext_lines(
text: str, max_token_per_line: int, token_counter: Callable = _token_counter
) -> List[str]:
"""
Split plain text into lines.
It will split on new lines first, and then on punctuation.
"""
return _split_text_lines(text, max_token_per_line, True)
return _split_text_lines(
text=text,
max_token_per_line=max_token_per_line,
trim=True,
token_counter=token_counter,
)


def split_markdown_lines(text: str, max_token_per_line: int) -> List[str]:
def split_markdown_lines(
text: str, max_token_per_line: int, token_counter: Callable = _token_counter
) -> List[str]:
"""
Split markdown into lines.
It will split on punctuation first, and then on space and new lines.
"""
return _split_markdown_lines(text, max_token_per_line, True)
return _split_markdown_lines(
text=text,
max_token_per_line=max_token_per_line,
trim=True,
token_counter=token_counter,
)


def split_plaintext_paragraph(text: List[str], max_tokens: int) -> List[str]:
def split_plaintext_paragraph(
text: List[str], max_tokens: int, token_counter: Callable = _token_counter
) -> List[str]:
"""
Split plain text into paragraphs.
"""

split_lines = []
for line in text:
split_lines.extend(_split_text_lines(line, max_tokens, True))
split_lines.extend(
_split_text_lines(
text=line,
max_token_per_line=max_tokens,
trim=True,
token_counter=token_counter,
)
)

return _split_text_paragraph(split_lines, max_tokens)
return _split_text_paragraph(
text=split_lines, max_tokens=max_tokens, token_counter=token_counter
)


def split_markdown_paragraph(text: List[str], max_tokens: int) -> List[str]:
def split_markdown_paragraph(
text: List[str], max_tokens: int, token_counter: Callable = _token_counter
) -> List[str]:
"""
Split markdown into paragraphs.
"""
split_lines = []
for line in text:
split_lines.extend(_split_markdown_lines(line, max_tokens, False))
split_lines.extend(
_split_markdown_lines(
text=line,
max_token_per_line=max_tokens,
trim=False,
token_counter=token_counter,
)
)

return _split_text_paragraph(split_lines, max_tokens)
return _split_text_paragraph(
text=split_lines, max_tokens=max_tokens, token_counter=token_counter
)


def _split_text_paragraph(text: List[str], max_tokens: int) -> List[str]:
def _split_text_paragraph(
text: List[str], max_tokens: int, token_counter: Callable
) -> List[str]:
"""
Split text into paragraphs.
"""
@@ -86,8 +136,8 @@ def _split_text_paragraph(text: List[str], max_tokens: int) -> List[str]:
current_paragraph = []

for line in text:
num_tokens_line = _token_count(line)
num_tokens_paragraph = _token_count("".join(current_paragraph))
num_tokens_line = token_counter(line)
num_tokens_paragraph = token_counter("".join(current_paragraph))

if (
num_tokens_paragraph + num_tokens_line + 1 >= max_tokens
@@ -109,7 +159,7 @@ def _split_text_paragraph(text: List[str], max_tokens: int) -> List[str]:
last_para = paragraphs[-1]
sec_last_para = paragraphs[-2]

if _token_count(last_para) < max_tokens / 4:
if token_counter(last_para) < max_tokens / 4:
last_para_tokens = last_para.split(" ")
sec_last_para_tokens = sec_last_para.split(" ")
last_para_token_count = len(last_para_tokens)
@@ -125,27 +175,44 @@ def _split_text_paragraph(text: List[str], max_tokens: int) -> List[str]:
return paragraphs


def _split_markdown_lines(text: str, max_token_per_line: int, trim: bool) -> List[str]:
def _split_markdown_lines(
text: str, max_token_per_line: int, trim: bool, token_counter: Callable
) -> List[str]:
"""
Split markdown into lines.
"""

lines = _split_str_lines(text, max_token_per_line, MD_SPLIT_OPTIONS, trim)
return lines
return _split_str_lines(
text=text,
max_tokens=max_token_per_line,
separators=MD_SPLIT_OPTIONS,
trim=trim,
token_counter=token_counter,
)


def _split_text_lines(text: str, max_token_per_line: int, trim: bool) -> List[str]:
def _split_text_lines(
text: str, max_token_per_line: int, trim: bool, token_counter: Callable
) -> List[str]:
"""
Split text into lines.
"""

lines = _split_str_lines(text, max_token_per_line, TEXT_SPLIT_OPTIONS, trim)

return lines
return _split_str_lines(
text=text,
max_tokens=max_token_per_line,
separators=TEXT_SPLIT_OPTIONS,
trim=trim,
token_counter=token_counter,
)


def _split_str_lines(
text: str, max_tokens: int, separators: List[List[str]], trim: bool
text: str,
max_tokens: int,
separators: List[List[str]],
trim: bool,
token_counter: Callable,
) -> List[str]:
if not text:
return []
@@ -155,96 +222,112 @@ def _split_str_lines(
was_split = False
for split_option in separators:
if not lines:
lines, was_split = _split_str(text, max_tokens, split_option, trim)
lines, was_split = _split_str(
text=text,
max_tokens=max_tokens,
separators=split_option,
trim=trim,
token_counter=token_counter,
)
else:
lines, was_split = _split_list(lines, max_tokens, split_option, trim)
if not was_split:
lines, was_split = _split_list(
text=lines,
max_tokens=max_tokens,
separators=split_option,
trim=trim,
token_counter=token_counter,
)
if was_split:
break

return lines


def _split_str(
text: str, max_tokens: int, separators: List[str], trim: bool
) -> List[str]:
text: str,
max_tokens: int,
separators: List[str],
trim: bool,
token_counter: Callable,
) -> Tuple[List[str], bool]:
"""
Split text into lines.
"""
input_was_split = False
if not text:
return []
return [], input_was_split

input_was_split = False
text = text.strip() if trim else text
if trim:
text = text.strip()

text_as_is = [text]

if _token_count(text) <= max_tokens:
if token_counter(text) <= max_tokens:
return text_as_is, input_was_split

input_was_split = True

half = int(len(text) / 2)
half = len(text) // 2

cutpoint = -1

if not separators:
cutpoint = half

elif set(separators) & set(text) and len(text) > 2:
for index, text_char in enumerate(text):
if text_char not in separators:
continue

if abs(half - index) < abs(half - cutpoint):
cutpoint = index + 1

regex_separators = re.compile("|".join(re.escape(s) for s in separators))
min_dist = half
for match in re.finditer(regex_separators, text):
end = match.end()
dist = abs(half - end)
if dist < min_dist:
min_dist = dist
cutpoint = end
elif end > half:
# distance is increasing, so we can stop searching
break
else:
return text_as_is, input_was_split

if 0 < cutpoint < len(text):
lines = []
first_split, has_split1 = _split_str(
text[:cutpoint], max_tokens, separators, trim
)
second_split, has_split2 = _split_str(
text[cutpoint:], max_tokens, separators, trim
)

lines.extend(first_split)
lines.extend(second_split)

input_was_split = has_split1 or has_split2
for text_part in [text[:cutpoint], text[cutpoint:]]:
split, has_split = _split_str(
text=text_part,
max_tokens=max_tokens,
separators=separators,
trim=trim,
token_counter=token_counter,
)
lines.extend(split)
input_was_split = input_was_split or has_split
else:
return text_as_is, input_was_split

return lines, input_was_split


def _split_list(
text: List[str], max_tokens: int, separators: List[str], trim: bool
) -> List[str]:
text: List[str],
max_tokens: int,
separators: List[str],
trim: bool,
token_counter: Callable,
) -> Tuple[List[str], bool]:
"""
Split a list of strings into lines.
"""
if not text:
return []
return [], False

lines = []
input_was_split = False
for line in text:
split_str, was_split = _split_str(line, max_tokens, separators, trim)
split_str, was_split = _split_str(
text=line,
max_tokens=max_tokens,
separators=separators,
trim=trim,
token_counter=token_counter,
)
lines.extend(split_str)
input_was_split = input_was_split or was_split

return lines, input_was_split


def _token_count(text: str) -> int:
"""
Count the number of tokens in a string.

TODO: chunking methods should be configurable to allow for different
tokenization strategies depending on the model to be called.
For now, we use an extremely rough estimate.
"""
return int(len(text) / 4)
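Beyond the new parameter, the PR replaces the old per-character scan in `_split_str` with a regex search for the separator whose match ends closest to the midpoint of the string, breaking out early once matches start moving away from it. A standalone sketch of that cutpoint search, using illustrative names that are not part of the module's API:

import re
from typing import List

def nearest_cutpoint(text: str, separators: List[str]) -> int:
    """Index just past the separator closest to the middle of text, or -1 if none."""
    half = len(text) // 2
    pattern = re.compile("|".join(re.escape(s) for s in separators))
    cutpoint, min_dist = -1, half
    for match in pattern.finditer(text):
        end = match.end()
        dist = abs(half - end)
        if dist < min_dist:
            min_dist, cutpoint = dist, end
        elif end > half:
            # Matches only get farther from the midpoint past this point.
            break
    return cutpoint

# Picks the "." right after "cutting.", the separator nearest the midpoint (index 27).
print(nearest_cutpoint("This is a test of. cutting. at the half point.", ["."]))

The early exit is what the old character loop lacked: once a match falls past the midpoint without improving the distance, every later match is farther away, so the scan can stop.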
38 changes: 38 additions & 0 deletions python/tests/unit/text/test_text_chunker.py
@@ -11,6 +11,44 @@
NEWLINE = os.linesep


def test_split_plain_text_lines_with_token_count():
"""Test split_plain_text_lines() with external token counter"""

text = "This is a test of the emergency broadcast system. This is only a test."

max_token_per_line = 8

expected = [
"This is a test of the",
"emergency",
"broadcast system.",
"This is only a test.",
]
split = split_plaintext_lines(
text=text,
max_token_per_line=max_token_per_line,
token_counter=lambda x: len(x) // 3,
)
assert expected == split


def test_split_plain_text_lines_half():
"""Test split_plain_text_lines() with external token counter"""

text_1 = "This is a test of. cutting. at the half point."
text_2 = "This is a test of . cutting. at the half point."

max_token_per_line = 10

expected_1 = ["This is a test of. cutting.", "at the half point."]
split_1 = split_plaintext_lines(text=text_1, max_token_per_line=max_token_per_line)
assert expected_1 == split_1

expected_2 = ["This is a test of .", "cutting. at the half point."]
split_2 = split_plaintext_lines(text=text_2, max_token_per_line=max_token_per_line)
assert expected_2 == split_2


def test_split_plain_text_lines():
"""Test split_plain_text_lines()"""
