In [3]:
!pip install unidecode tiktoken transformers einops auto-gptq sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


## Time to install packages 4 mins 0 seconds
## Time started: 23:59:00

In [1]:
from bs4 import BeautifulSoup, NavigableString
from unidecode import unidecode
import re
import os
from datetime import datetime
import unicodedata
import itertools
import tiktoken
import ast
import pprint
import json

In [2]:
from transformers import AutoTokenizer, pipeline, logging
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

model_name_or_path = "TheBloke/StableBeluga2-70B-GPTQ"
model_basename = "gptq_model-4bit--1g"
model_name = 'StableBeluga2'
file_code = '1807.00939'

use_triton = False

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
        model_basename=model_basename,
        inject_fused_attention=False, # Required for Llama 2 70B models at this time.
        use_safetensors=True,
        trust_remote_code=False,
        device="cuda:0",
        use_triton=use_triton,
        quantize_config=None)

Downloading (…)lve/main/config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

Downloading (…)quantize_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

Downloading (…)4bit--1g.safetensors:   0%|          | 0.00/35.3G [00:00<?, ?B/s]

skip module injection for FusedLlamaMLPForQuantizedModel not support integrate without triton yet.


In [3]:
def num_tokens_from_messages(prompt_template, model="gpt-3.5-turbo"):
    return len(tokenizer(prompt_template, return_tensors='pt').input_ids.cuda().tolist()[0])

In [4]:
def split_into_messages(text):
    # This is a placeholder for your actual implementation
    return text.split("\n")

def split_into_chunks(text, max_tokens=2000):
    messages = split_into_messages(text)
    chunks = []
    current_chunk = []
    current_tokens = 0
    for message in messages:
        message_tokens = num_tokens_from_messages(message, model="gpt-3.5-turbo")
        if current_tokens + message_tokens > max_tokens:
            # If adding this message would exceed the max tokens, start a new chunk
            chunks.append('\n'.join(current_chunk))
            current_chunk = [message]
            current_tokens = message_tokens
        else:
            # Otherwise, add the message to the current chunk
            current_chunk.append(message)
            current_tokens += message_tokens
    # Don't forget the last chunk!
    if current_chunk:
        chunks.append('\n'.join(current_chunk))
    return chunks

In [5]:
def modify_string(input_str):
    if "</s>" in input_str:
        return input_str
    else:
        last_comma_index = input_str.rfind(',')
        if last_comma_index == -1:
            return input_str  # No comma found, return the string as is
        else:
            return input_str[:last_comma_index] + '}' + input_str[last_comma_index+1:]
        
def get_json_of_string(incorrect, pattern=r'\{[^\}]*\}'):
    match = re.search(r'{(.*)}', incorrect, re.DOTALL)
    if match:
        return "{" + match.group(1).replace('{', '[').replace('}', ']') + "}"
    else:
        return "{}"


def string_to_dict(my_string):
    # Load the JSON string into a list of tuples
    tuples_list = json.JSONDecoder(object_pairs_hook=list).decode(my_string)

    # Create a new dictionary to hold the final result
    final_dict = {}

    # Iterate over the list of tuples
    for key, value in tuples_list:
        # If the key is already in the final dictionary, append the value
        # to the list of values for that key
        if key in final_dict:
            # Ensure the value is in a list form
            if not isinstance(final_dict[key], list):
                final_dict[key] = [final_dict[key]]
            final_dict[key].append(value)
        else:
            # If the key is not in the final dictionary, add it with the value
            final_dict[key] = value

    return final_dict

In [6]:
def merge_dictionaries(dict1, dict2):
    union_dict = dict1.copy()

    for key, value in dict2.items():
        if key in union_dict:
            if isinstance(union_dict[key], list):
                if value not in union_dict[key]:
                    union_dict[key].append(value)
            else:
                if union_dict[key] != value:
                    union_dict[key] = [union_dict[key], value]
        else:
            union_dict[key] = value

    return union_dict

In [7]:
def get_text_from_tags(element):
    if isinstance(element, NavigableString):
        return element
    if element.name == 'mi':
        return str(element)
    return ''.join(get_text_from_tags(child) for child in element.children)

def parse_html(file_path, clean=True):
    with open(file_path, 'r', encoding='utf-8') as html_file:
        soup = BeautifulSoup(html_file, 'html.parser')

    texts = get_text_from_tags(soup)
    if clean:
        matches = re.findall(r'<mi(.*?)</mi>', texts)
        for match in matches:
            original_string = f'<mi{match}</mi>'
            replaced_string = re.sub(r'<.*?>(.*?)</.*?>', r'<|\1|>', original_string)
            texts = texts.replace(original_string, replaced_string)

    return texts

# call the function with your HTML file path
page = parse_html(f'{file_code}.html')

In [8]:
def get_prompt(prompt):
    return f'''### SYSTEM:\n{prompt[0]['content']}\n\n
### USER:\n{prompt[1]['content']}\n\n
### ASSISTANT:\n{prompt[2]['content']}\n\n
### USER:\n{prompt[3]['content']}\n\n
### ASSISTANT:\n'''

In [9]:
def contains_pattern(input_string):
    pattern = r"<\|[^<\|>]*\|>"
    if re.search(pattern, input_string):
        return True
    else:
        return False

In [10]:
def print_output(output):
    pos = output.index('Do not include the angle brackets in the dictionary')
    end_of_string = output[pos:]
    print(end_of_string)

In [11]:
text = page
chunks = split_into_chunks(text, max_tokens=512)
i = 0

while not contains_pattern(chunks[i]):
    i += 1
print(i)

question = chunks[i]

actual_total_tokens = 0
completion_tokens = 0
prompt_tokens = 0

prompt = [
        {'role': 'system',
         'content': 'You are a helpful research assistant tasked with converting long paragraphs into a JSON '
                    'dictionary. The goal is to identify and classify each individual mathematical symbol, variable,'
                    ' and identifier in the text marked between "<||>". The dictionary should store the identifiers as '
                    'keys and their corresponding definitions as values in an array format. '},
        {'role': 'system', 'name': 'example_user', 'content': '''A relational model is a triple <|M|>′=(<|X|>,<|R|>,
        <|v|>), where <|X|> is a set of states, <|R|><|⊆|><|X|><|×|><|X|> is a binary relation on <|X|>, 
        and <|v|>:<|𝖯𝗋𝗈𝗉|>→2<|X|> is a valuation. Given a relational model <|M|>′, the satisfaction relation between 
        points <|x<|∈<|X<| and formulas <|φ<|∈<|ℒ<|<|𝖪𝖠<| is defined inductively by <|M|>′,<|x|>⊨<|𝖪|><|φ|>⇔ for all 
        <|y|>∈<|X|>,<|x|><|R|><|y|> implies <|M|>′,<|y|>⊨<|φ|><|M|>′,<|x|>⊨<|𝖠|><|φ|><||>⇔ for all <|y|>∈<|X|>,
        <|M|>′,<|y|>⊨<|φ|>'''},
        {'role': 'system', 'name': 'example_assistant', 'content': '''identifiers = {
            "M": ["Model", "Expertise Model"],
            "M'": "Relational model",
            "X": "Set of states",
            "R": "Binary relation on X",
            "v": "Valuation",
            "𝖯𝗋𝗈𝗉": "Set of propositions",
            "M'": "Relational model",
            "x": "Point in X",
            "φ": "Formula in 𝖪𝖠",
            "ℒ_{𝖪𝖠}": "Set of formulas",
            "𝖪": "Modal operator K",
            "𝖠": "Modal operator A",
            "y": "Point in X",
            "⊨": "Satisfaction relation",
            "⇔": "If and only if operator",
            "∈": "Element of a set",
            "⊆": "Subset of a set",
            "×": "Cartesian product operator",
            "→": "Function or implication operator",
            "for all": "Universal quantifier"
            }'''},
        {'role': 'user', 'content': f'Generate a JSON dictionary for the following text\n```txt\n{question}```. '
                                    'Only consider the mathematical identifiers inside "<||>" for the dictionary. '
                                    'Do not consider any other identifier other than those marked. Consider all the '
                                    'identifiers individually. Do not skip any identifier, mention all the identifiers '
                                    'inside "<||>" in your dictionary. Do not include the angle brackets in the '
                                    'dictionary.'}
    ]


open_prompt = get_prompt(prompt)

prompt_size = num_tokens_from_messages(open_prompt)
print(f"{prompt_size} prompt tokens counted.")

17
1105 prompt tokens counted.


In [12]:
print(f"Using {model_name_or_path}")
start_time = datetime.now()

input_ids = tokenizer(open_prompt, return_tensors='pt').input_ids.cuda()
output = model.generate(inputs=input_ids, temperature=0.5, max_new_tokens=512)
output_string = tokenizer.decode(output[0])
print_output(output_string)

total_time_taken = datetime.now() - start_time
print(f"Time taken: {total_time_taken}")

Using TheBloke/StableBeluga2-70B-GPTQ
Do not include the angle brackets in the dictionary.


### ASSISTANT:
 identifiers = {
            "LSTM RNN": "Long Short-Term Memory Recurrent Neural Network",
            "ANOMALOUS": "Anomaly Detection Algorithm",
            "NCC": "Normalized Cross Correlation",
            "Corr": "Cross-Correlation",
            "x": "First signal",
            "y": "Second signal",
            "n": "Time step",
            "N": "Total number of time steps",
            "x[n]": "Value of x at time step n",
            "y[n]": "Value of y at time step n",
            "sum": "Summation operator",
            "1": "Constant 1",
            "-1": "Constant -1",
            "1": "Constant 1",
            "N": "Total number of time steps",
            "x[n]": "Value of x at time step n",
            "y[n]": "Value of y at time step n",
            "sum": "Summation operator",
            "1": "Constant 1",
            "-1": "Constant -1",
            "1": "Consta

In [13]:
actual_total_tokens += num_tokens_from_messages(output_string)
completion_tokens += num_tokens_from_messages(output_string) - num_tokens_from_messages(open_prompt)
prompt_tokens += num_tokens_from_messages(open_prompt)
print(actual_total_tokens, completion_tokens, prompt_tokens)

1614 509 1105


In [14]:
# Safely convert the dictionary string to a dictionary using json.loads()
try:
    ind = output_string.index('Do not include the angle brackets in the dictionary.')
    correct_output_string = modify_string(output_string[ind:])
    dic_output_string = get_json_of_string(correct_output_string)
    dictionary = [string_to_dict(dic_output_string)]
except Exception as e:
    dictionary = [{}]

In [15]:
print(dictionary)

[{'LSTM RNN': 'Long Short-Term Memory Recurrent Neural Network', 'ANOMALOUS': 'Anomaly Detection Algorithm', 'NCC': 'Normalized Cross Correlation', 'Corr': 'Cross-Correlation', 'x': 'First signal', 'y': 'Second signal', 'n': 'Time step', 'N': ['Total number of time steps', 'Total number of time steps', 'Total number of time steps', 'Total number of time steps', 'Total number of time steps'], 'x[n]': ['Value of x at time step n', 'Value of x at time step n', 'Value of x at time step n', 'Value of x at time step n', 'Value of x at time step n'], 'y[n]': ['Value of y at time step n', 'Value of y at time step n', 'Value of y at time step n', 'Value of y at time step n', 'Value of y at time step n'], 'sum': ['Summation operator', 'Summation operator', 'Summation operator', 'Summation operator'], '1': ['Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1'], '-1': ['Constant -1', 'Constant -1', 'Constant -1', 'Constant -1']}]


In [16]:
def get_prompt_loop(prompt):
    return f'''### SYSTEM:\n{prompt[0]['content']}\n\n
### USER:\n{prompt[1]['content']}\n\n
### ASSISTANT:\n{prompt[2]['content']}\n\n
### SYSTEM:\n{prompt[3]['content']}\n\n
### USER:\n{prompt[4]['content']}\n\n
### ASSISTANT:\n'''

In [17]:
start_time = datetime.now()
number_of_dictionaries = 0
for chunk in chunks:
    print(f"Iteration {i} of {len(chunks)}")
    i += 1
    if chunk == question:
        continue
    if not contains_pattern(chunk):
        continue
    question = chunk
    
    if prompt_size > 1600:
        number_of_dictionaries += 1
        print("\nNew dictionary\n")
        dictionary.append({})
    
    prompt = [
            {'role': 'system',
             'content': 'You are a helpful research assistant tasked with converting long paragraphs into a JSON '
                        'dictionary. '
                        'The goal is to identify and classify each individual mathematical symbol, variable, '
                        'and identifier in the text marked between "<||>"'
                        'The dictionary should store the identifiers as keys and their corresponding definitions as '
                        'values in an array format. '},
            {'role': 'system', 'name': 'example_user', 'content': '''A relational model is a triple <|M|>′=(<|X|>,<|R|>,
            <|v|>), where <|X|> is a set of states, <|R|><|⊆|><|X|><|×|><|X|> is a binary relation on <|X|>, 
            and <|v|>:<|𝖯𝗋𝗈𝗉|>→2<|X|> is a valuation. Given a relational model <|M|>′, the satisfaction relation 
            between points <|x<|∈<|X<| and formulas <|φ<|∈<|ℒ<|<|𝖪𝖠<| is defined inductively by <|M|>′,
            <|x|>⊨<|𝖪|><|φ|>⇔ for all <|y|>∈<|X|>,<|x|><|R|><|y|> implies <|M|>′,<|y|>⊨<|φ|><|M|>′,
            <|x|>⊨<|𝖠|><|φ|><||>⇔ for all <|y|>∈<|X|>,<|M|>′,<|y|>⊨<|φ|>'''},
            {'role': 'system', 'name': 'example_assistant', 'content': '''identifiers = {
            "M": ["Model", "Expertise Model"],
            "M'": "Relational model",
            "X": "Set of states",
            "R": "Binary relation on X",
            "v": "Valuation",
            "𝖯𝗋𝗈𝗉": "Set of propositions",
            "M'": "Relational model",
            "x": "Point in X",
            "φ": "Formula in 𝖪𝖠",
            "ℒ_{𝖪𝖠}": "Set of formulas",
            "𝖪": "Modal operator K",
            "𝖠": "Modal operator A",
            "y": "Point in X",
            "⊨": "Satisfaction relation",
            "⇔": "If and only if operator",
            "∈": "Element of a set",
            "⊆": "Subset of a set",
            "×": "Cartesian product operator",
            "→": "Function or implication operator",
            "for all": "Universal quantifier"
            }'''},
            {'role': 'system',
             'content': f'Given is already a pre existing dictionary. Your job is to extend this dictionary. Do not '
                        f'remove any pre existing definitions from this dictionary.'
                        f'\n{dictionary[number_of_dictionaries]}. If there is nothing to mention, reply with an empty '
                        f'dictionary'},
            {'role': 'user', 'content': f'Generate a JSON dictionary for the following text: {question}. '
                                        'Only consider the mathematical identifiers inside "<||>" for the dictionary. '
                                        'Do not consider any other identifier other than those marked. '
                                        'Consider all the identifiers individually. Do not skip any identifier, mention'
                                        ' all the identifiers inside "<||>" in your dictionary. '
                                        'Do not include the angle brackets in your dictionary.'}
        ]
    
    open_prompt = get_prompt_loop(prompt)
    
    prompt_size = num_tokens_from_messages(open_prompt)
    print(f"\n\n\n{prompt_size} prompt tokens counted.\n")
    
    while True:
        try:
            input_ids = tokenizer(open_prompt, return_tensors='pt').input_ids.cuda()
            output = model.generate(inputs=input_ids, temperature=0.5, max_new_tokens=512, repetition_penalty=1.05)
            output_string = tokenizer.decode(output[0])
            
            actual_total_tokens += num_tokens_from_messages(output_string)
            completion_tokens += num_tokens_from_messages(output_string) - num_tokens_from_messages(open_prompt)
            prompt_tokens += num_tokens_from_messages(open_prompt)

            ind = output_string.index('Do not include the angle brackets in your dictionary.')
            print(output_string[ind:])
            
            print(completion_tokens, prompt_tokens)
    
            print(f"Actual total tokens till now: {actual_total_tokens}")

            try:
                correct_output_string = modify_string(output_string[ind:])
                dic_output_string = get_json_of_string(correct_output_string)
                new_dictionary = string_to_dict(dic_output_string)
            except Exception as e:
                print("INCORRECT DICTIONARY")
            dictionary[number_of_dictionaries] = merge_dictionaries(dictionary[number_of_dictionaries], new_dictionary)
            
            break
        except Exception as e:
            number_of_dictionaries += 1
            dictionary.append({})
            print(f"Exception occurred: {e}")
            print("Retrying...")
total_time_taken += (datetime.now() - start_time)
print(f"Time taken: {datetime.now() - start_time }")
print(f"Total time taken: {total_time_taken}")
print(actual_total_tokens, completion_tokens, prompt_tokens)

Iteration 17 of 27
Iteration 18 of 27
Iteration 19 of 27
Iteration 20 of 27
Iteration 21 of 27
Iteration 22 of 27
Iteration 23 of 27
Iteration 24 of 27
Iteration 25 of 27
Iteration 26 of 27
Iteration 27 of 27
Iteration 28 of 27
Iteration 29 of 27
Iteration 30 of 27
Iteration 31 of 27
Iteration 32 of 27
Iteration 33 of 27
Iteration 34 of 27
Iteration 35 of 27



1586 prompt tokens counted.

Do not include the angle brackets in your dictionary.


### ASSISTANT:
 {
    "x": "Vector of time series data (stock transaction volume)",
    "y": "Vector of time series data (stock transaction volume)",
    "N": "Number of days in the series",
    "Corr": "Cross-correlation",
    "NCC": "Normalized Cross Correlation",
    "n": "Time step",
    "x[n]": "Value of x at time step n",
    "y[n]": "Value of y at time step n",
    "sum": "Summation operator",
    "1": "Constant 1",
    "-1": "Constant -1",
    "xcorr": "Built-in Matlab feature for normalized cross-correlation"
}</s>
684 2691
Actual total

In [18]:
dct = {}
for dic in dictionary:
    dct = merge_dictionaries(dct, dic)
# pprint.pprint(dct)

In [19]:
for key, value in dct.items():
    if type(value) == list:
        print(f"'{key}': '{value}'")

'Corr': '['Cross-Correlation', 'Cross-correlation']'
'x': '['First signal', 'Vector of time series data (stock transaction volume)']'
'y': '['Second signal', 'Vector of time series data (stock transaction volume)']'
'N': '['Total number of time steps', 'Total number of time steps', 'Total number of time steps', 'Total number of time steps', 'Total number of time steps', 'Number of days in the series']'
'x[n]': '['Value of x at time step n', 'Value of x at time step n', 'Value of x at time step n', 'Value of x at time step n', 'Value of x at time step n']'
'y[n]': '['Value of y at time step n', 'Value of y at time step n', 'Value of y at time step n', 'Value of y at time step n', 'Value of y at time step n']'
'sum': '['Summation operator', 'Summation operator', 'Summation operator', 'Summation operator']'
'1': '['Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1', 'Constant 1']'
'-1': '['Constant -1', 'Constant -1', 'Constant -1', 'Constant -

In [20]:
def flatten_list(input_list):
    output_list = []
    for i in input_list:
        if isinstance(i, list):
            output_list.extend(flatten_list(i))
        else:
            output_list.append(i)
    return output_list


def remove_duplicates(input_list):
    output_list = []
    for item in input_list:
        if item not in output_list:
            output_list.append(item)
    if len(output_list) == 1:
        return output_list[0]
    return output_list


def process_value(v):
    if isinstance(v, str):
        new_v = v.replace('$', '')
        while '\\\\' in new_v:
            new_v = new_v.replace('\\\\', '\\').replace('\n', '')
    else:  # Assuming it's a list
        new_v = flatten_list([process_value(val) for val in v])
        
    return remove_duplicates(new_v) if isinstance(new_v, list) else new_v


def reduce_pairs(dictionary):
    new_dict = {}
    for k, v in dictionary.items():
        # reduce key backslashes
        new_k = k.replace('$', '')
        while '\\\\' in new_k:
            new_k = new_k.replace('\\\\', '\\')

        # process value
        new_v = process_value(v)

        new_dict[new_k] = new_v

    return new_dict

In [21]:
dict_without_backslashes = reduce_pairs(dct)

In [22]:
#pprint.pprint(dict_without_backslashes)

In [23]:
parsed_json = dict_without_backslashes

In [24]:
with open(f'{file_code}_mcdict.json', 'r', encoding='utf-8') as f:
    mc_dict_original = json.loads(f.read())

In [25]:
# Function to create a hex code (a binary representation of the key)
def get_hex_code(key):
    return key.encode().hex()

mc_dict_original['_author'] = model_name_or_path

# Iterate over your dictionary and fill the new one
for key, values in parsed_json.items():
    # Determine the base key and the affix
    base_key = re.match(r"^[^*'_^,(\[]*", key).group()
    affix = key[len(base_key):]

    hex_code = get_hex_code(base_key)
    values = values if isinstance(values, list) else [values]

    if hex_code in mc_dict_original["concepts"]:
        k = list(mc_dict_original["concepts"][hex_code]["identifiers"].keys())[0]
        new_identifier = []
        for value in values:
            mc_dict_original["concepts"][hex_code]["identifiers"][k].append({
                "affixes": [affix] if affix else [],
                "arity": 0,
                "description": value
            })
    else:
        if hex_code not in mc_dict_original["concepts"]:
            mc_dict_original["concepts"][hex_code] = {
                "_surface": {
                    "text": base_key,
                    "unicode_name": base_key if len(base_key) != 1 else unicodedata.name(base_key)
                },
                "identifiers": {
                    'default': []
                }
            }

        for value in values:
            mc_dict_original["concepts"][hex_code]["identifiers"]["default"].append({
                "affixes": [affix] if affix else [],
                "arity": 0,
                "description": value
            })


# Convert new dictionary to a sorted dictionary
sorted_dict = dict(sorted(mc_dict_original["concepts"].items(), key=lambda x: (len(x[0]), x[0])))
mc_dict_original["concepts"] = sorted_dict

# Convert new dictionary to JSON
json_str = json.dumps(mc_dict_original, indent=4, ensure_ascii=False)

#print(json_str)

with open(f'{file_code}-{model_name}_mcdict.json', 'w', encoding='utf-8') as f:
    json.dump(mc_dict_original, f, ensure_ascii=False, indent=4)

In [26]:
def get_text_from_tags(element):
    if isinstance(element, NavigableString):
        return element
    if element.name == 'mi':
        return str(element)
    return ''.join(get_text_from_tags(child) for child in element.children)

def parse_html(file_path):
    with open(file_path, 'r', encoding='utf-8') as html_file:
        soup = BeautifulSoup(html_file, 'html.parser')

    texts = get_text_from_tags(soup)
    return texts

def find_mi_strings(text):
    pattern = r'(<mi.*?</mi>)'
    matches = re.findall(pattern, text, re.DOTALL)
    return matches

# call the function with your HTML file path
page = parse_html(f'{file_code}.html')
matches = find_mi_strings(page)

In [27]:
parsed_dict = mc_dict_original
with open(f'{file_code}_anno.json', encoding='utf-8') as fp:
    parsed_annotation = json.load(fp)

In [28]:
def get_word_index_from_char_index(message, key, char_index):
    i = 0
    index = -1
    for word in message:
        if key in word:
            index = i
        i += 1
    return index

def expand_string_to_tokens(message, index, num_tokens_right=25, num_tokens_left=75):
    words = message.split()  # Split the message into words

    # Start at the index where the center word is
    left_index = right_index = index

    tokens_counter_right = num_tokens_from_messages(words[right_index])
    tokens_counter_left = num_tokens_from_messages(words[left_index])

    # Expand to the left from the center index until you reach num_tokens_left
    while tokens_counter_left < num_tokens_left and left_index > 0:
        left_index -= 1
        tokens_counter_left += num_tokens_from_messages(words[left_index])

    # Expand to the right from the center index until you reach num_tokens_right
    while tokens_counter_right < num_tokens_right and right_index < len(words) - 1:
        right_index += 1
        tokens_counter_right += num_tokens_from_messages(words[right_index])

    # Combine the words back into a string and return
    return ' '.join(words[left_index:right_index + 1])

In [29]:
def replace_text(text, replacement, exception):
    # Find all matches
    matches = re.findall(r'<mi(.*?)</mi>', text)
    
    for match in matches:
        original_string = f'<mi{match}</mi>'
        
        # Skip exception
        if original_string == exception:
            continue
        
        # Replace match
        replaced_string = re.sub(r'<.*?>(.*?)</.*?>', r'\1', original_string)
        text = text.replace(original_string, replaced_string)

    return text

def get_context(match):
    match_len = len(match)
    new_page = replace_text(page, '', match)
    char_index = new_page.index(match) + int(match_len/2)
    word_index = get_word_index_from_char_index(new_page, char_index)
    section = expand_string_to_tokens(new_page, word_index)
    section = re.sub(r'<.*?>(.*?)</.*?>', r'<<\1>>', section)
    return match, section

# Function to create a hex code (a binary representation of the key)
def get_hex_code(key):
    return key.encode('utf-8').hex()

In [30]:
def remove_trailing_tags(s):
    parts = re.split('(<mi)', s)
    for i in range(1, len(parts), 2):
        if '>' not in parts[i + 1]:
            parts[i] = ''
            parts[i + 1] = ''
    return ''.join(parts)

def get_definition_of_id(dict_id, identifier):
    
    try:
        hex_code = get_hex_code(identifier)
        index = parsed_annotation['mi_anno'][dict_id]['concept_id']
        key = list(parsed_dict['concepts'][hex_code]['identifiers'].keys())[0]
        return f"({parsed_dict['concepts'][hex_code]['identifiers'][key][index]['description']})"
    except Exception as e:
        return ""

def get_context(match):
    key_word = page.index(match) + len(match)
    last_index = min(len(page), key_word + 500)
    first_index = max(0, key_word - 3000)
    context_window = page[first_index:last_index]
    
    reg_matches = re.findall(r'<mi(.*?)</mi>', context_window)
    
    identifier = None
    
    for reg_match in reg_matches:
        original_string = f'<mi{reg_match}</mi>'
        soup = BeautifulSoup(original_string, 'html.parser')
        
        
        tags = soup.find_all('mi')
        
        if original_string == match:
            identifier = tags[0].text
            continue

        context_window = context_window.replace(original_string,
                                                f"{tags[0].text}{get_definition_of_id(tags[0].get('id'), tags[0].text)}")
    
    context_window = re.sub(r'<mi.*?>(.*?)<\/mi>', r'<<\1>>', context_window)
    
    context_window = remove_trailing_tags(context_window)
    context_window = re.sub(r'^(?!.*<mi.*).*<\/mi>', '', context_window, flags=re.DOTALL)
        
    index = 0
    for word in context_window.split():
        if f"<<{identifier}>>" in word:
            word_index = index
        index += 1
    
    if word_index == -1:
        return context_window
    else:
        context_window = expand_string_to_tokens(context_window, word_index)
    return context_window

#print(get_context('<mi id="S1.p2.1.m1.1.1.3" xref="S1.p2.1.m1.1.1.3.cmml">φ</mi>'))
#get_context('<mi id="S1.p2.1.m1.1.1.2" xref="S1.p2.1.m1.1.1.2.cmml">𝖤</mi>')

In [31]:
def get_prompt_anno(prompt):
    return f'''### SYSTEM:\n{prompt[0]['content']}\n\n
### USER:\n{prompt[1]['content']}\n\n
### ASSISTANT:\n'''

In [32]:
start_time = datetime.now()
actual_total_tokens = 0
completion_tokens = 0
prompt_tokens = 0
no_tags = 0
no_keys = 0
no_anno = 0
i = 1
for match in matches:
    print(f"Iteration {i} of {len(matches)}: ", end='')
    i += 1
    context = get_context(match)
    match_variable = re.sub(r'<.*?>(.*?)</.*?>', r'\1', match)
    context_index = context.index(f"<<{match_variable}>>") + len(match_variable)
    possible_affix = str(context[context_index+4:context_index+5]).replace("′", "'")
    soup = BeautifulSoup(match, 'html.parser')
    mi_tag = soup.find('mi')
    if mi_tag is not None and 'id' in mi_tag.attrs:
        anno_id = mi_tag['id']
    else:
        print('TAG NOT FOUND', match)
        no_tags += 1
        continue
    
    hex_code = get_hex_code(match_variable)
    if hex_code not in parsed_dict['concepts']:
        match_variable = f"{unidecode(match_variable)}"
        hex_code = get_hex_code(match_variable)
        if hex_code not in parsed_dict['concepts']:
            print("Key does not exist in the dictionary of concepts", match_variable, hex_code)
            no_keys += 1
            continue
    
    if anno_id not in parsed_annotation['mi_anno']:
        print("Annotation ID does not exist in annotation.json", anno_id)
        no_anno += 1
        continue

    k = list(parsed_dict["concepts"][hex_code]["identifiers"].keys())[0]
    mcdict = parsed_dict['concepts'][hex_code]['identifiers'][k]
    
    if len(mcdict) == 1:
        parsed_annotation['mi_anno'][anno_id]['concept_id'] = 0
        print('0')
    elif len(mcdict) > 1:
        prompt_mcdict = []

        index = 0
        for val in mcdict:
            prompt_mcdict.append({'index': f"{index}", 'identifier': f"{match_variable}{'' if len(val['affixes']) == 0 else val['affixes'][0]}", 'description': val['description']})
            index += 1
            
        prompt = [
            {'role': 'system', 'content': 'You are a professional annotater API. Your job is to select a fitting annotation from a dictionary for a mathematical identifier.'},
            {'role': 'user', 'content': f'''Given the following possible annotations:\n```json\n{prompt_mcdict}```.
             Select the index for the most fitting description for the identifier <<{match_variable}>> from the following text.
             The potential affix of the indentifier could be <<{possible_affix}>>. Take the affixes of the possible annotations into account.
             Only return the value of the index and nothing else.
             Do not add any explanation otherwise the API breaks.
             The identifier has been marked with <<>>.
             If you can't come up with an index, write 'None'
             ```txt
             {context}
             ```'''}
        ]

        while True:
            try:
                open_prompt = get_prompt_anno(prompt)
                
                input_ids = tokenizer(open_prompt, return_tensors='pt').input_ids.cuda()
                output = model.generate(inputs=input_ids, temperature=0.5, max_new_tokens=512, repetition_penalty=1.05)
                output_string = tokenizer.decode(output[0])
                
                actual_total_tokens += num_tokens_from_messages(output_string)
                completion_tokens += num_tokens_from_messages(output_string) - num_tokens_from_messages(open_prompt)
                prompt_tokens += num_tokens_from_messages(open_prompt)

                ind = output_string.index('ASSISTANT:')
                value = output_string[ind:]
                print(value)
                
                print(completion_tokens, prompt_tokens)

                try:
                    index = int(int(re.search('\d+', value).group()))
                    print(index)
                    parsed_annotation['mi_anno'][anno_id]['concept_id'] = index
                except Exception as f:
                    print(f)

                break
            except Exception as e:
                print(f"Exception occurred\n{e}")
                print("Retrying...")
    else:
        print('None')

print('Annotation completed')

    
total_time_taken = (datetime.now() - start_time)
print(f"Time taken: {datetime.now() - start_time }")
print(f"Total time taken: {total_time_taken}")
print(actual_total_tokens, completion_tokens, prompt_tokens)

Iteration 1 of 35: ASSISTANT:
 0</s>
6 280
0
Iteration 2 of 35: ASSISTANT:
 2</s>
12 591
2
Iteration 3 of 35: ASSISTANT:
 2</s>
18 901
2
Iteration 4 of 35: 0
Iteration 5 of 35: ASSISTANT:
 0</s>
24 1182
0
Iteration 6 of 35: ASSISTANT:
 2</s>
30 1501
2
Iteration 7 of 35: 0
Iteration 8 of 35: ASSISTANT:
 2</s>
36 1819
2
Iteration 9 of 35: 0
Iteration 10 of 35: None
Iteration 11 of 35: ASSISTANT:
 2</s>
42 2128
2
Iteration 12 of 35: ASSISTANT:
 2</s>
48 2438
2
Iteration 13 of 35: 0
Iteration 14 of 35: ASSISTANT:
 0</s>
54 2721
0
Iteration 15 of 35: ASSISTANT:
 2</s>
60 3034
2
Iteration 16 of 35: 0
Iteration 17 of 35: ASSISTANT:
 2</s>
66 3351
2
Iteration 18 of 35: 0
Iteration 19 of 35: 0
Iteration 20 of 35: ASSISTANT:
 0</s>
72 3636
0
Iteration 21 of 35: ASSISTANT:
 2</s>
78 3956
2
Iteration 22 of 35: 0
Iteration 23 of 35: ASSISTANT:
 2</s>
84 4272
2
Iteration 24 of 35: 0
Iteration 25 of 35: 0
Iteration 26 of 35: ASSISTANT:
 0</s>
90 4554
0
Iteration 27 of 35: ASSISTANT:
 2</s>
96 4870
2


In [33]:
parsed_annotation['_annotator'] = model_name_or_path
with open(f'{file_code}-{model_name}_anno.json', 'w') as fp:
    json.dump(parsed_annotation, fp)

In [34]:
items = 0
for key, value in parsed_annotation['mi_anno'].items():
    if value['concept_id'] is not None:
        #print(key, value)
        items += 1
print(items)

32
