In [2]:
import re
import sys
sys.path.append('./readme2kg-exp/src/')
import os
import random
from collections import defaultdict
from termcolor import colored
from functools import partial, reduce
import operator as op
import hashlib
import multiprocessing as mp
import logging

from predictor import BasePredictor, LABELS
from webanno_tsv import webanno_tsv_read_file, Document, Annotation, Token
import utils
import cleaner

In [12]:
# Run configuration: which data split to annotate and where the predictions go.
phase = 'test_unlabeled'
base_path = f'./readme2kg-exp/data/{phase}'
# Regular *.tsv files directly under base_path. Sorted because os.listdir()
# returns entries in arbitrary order and the run should be reproducible.
file_names = sorted(
    fp for fp in os.listdir(base_path)
    if fp.endswith('.tsv') and os.path.isfile(os.path.join(base_path, fp))
)
model_name = 'Mistral-7B-Instruct-v0.3'
output_folder = f'./readme2kg-exp/results/{model_name}/{phase}'
os.makedirs(output_folder, exist_ok=True)

In [13]:
# Which prompt variant to use; also part of the per-sentence result path in predict().
prompt_id = 0
# The template file name is derived from prompt_id so the two cannot drift apart.
prompt_template_path = f'./readme2kg-exp/config/deepseek-chat-prompt-{prompt_id}.txt'
if os.path.isfile(prompt_template_path):
    with open(prompt_template_path, 'r', encoding='utf-8') as fd:
        prompt_template = fd.read()
else:
    # Keep the best-effort fallback, but make it visible: with an empty template
    # every prompt sent to the model is empty as well.
    logging.warning(f'Prompt template not found: {prompt_template_path}; using an empty template')
    prompt_template = ''

# Load Mistral model

In [5]:
from huggingface_hub import snapshot_download
from pathlib import Path

# Local directory for the model files: ~/mistral_models/7B-Instruct-v0.3
mistral_models_path = Path.home() / 'mistral_models' / '7B-Instruct-v0.3'
mistral_models_path.mkdir(parents=True, exist_ok=True)

# Download only the listed files into that directory. Kept as the last
# expression of the cell so the notebook displays the returned value.
snapshot_download(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
    local_dir=mistral_models_path,
)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

'/home/ann/mistral_models/7B-Instruct-v0.3'

In [6]:
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest


# Load the tokenizer and the model weights from the local download directory.
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tokenizer.model.v3")
model = Transformer.from_folder(mistral_models_path)

# Smoke test: push one hard-coded README sentence through the same
# prompt -> tokenize -> generate -> decode path that do_prediction() uses.
sentence_text = """# DejaVu ## Table of Contents =================    * [Code](#code)     * [Install Requirements](#install-requirements)     * [Usage](#usage)     * [Example](#example)   * [Datasets](#datasets)   * [Deployment and Failure Injection Scripts of Train-Ticket](#deployment-and-failure-injection-scripts-of-train-ticket)   * [Citation](#citation)   * [Supplementary details](#supplementary-details)    ## Paper A preprint version: https://arxiv.org/abs/2207.09021 ## Code ### Install 1."""
prompt = prompt_template.replace('{input_text}', sentence_text)

completion_request = ChatCompletionRequest(messages=[
    {"role": "system", "content": "You are a helpful NER annotator"},
    {"role": "user", "content": prompt},
])

tokens = tokenizer.encode_chat_completion(completion_request).tokens

out_tokens, _ = generate([tokens], model, max_tokens=1000, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
# generate() was given a batch of one prompt, so decode the first (only) output
# sequence rather than the whole batch, as do_prediction() does.
result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])

print(result)



In [7]:
def do_prediction(sentence, tokens, sid_path):
    """Ask the model to annotate one sentence and write the raw reply to disk.

    Args:
        sentence: object exposing the sentence string as ``sentence.text``.
        tokens: not used here; kept so existing callers keep working.
        sid_path: path of the text file that receives the decoded model output.

    Returns:
        None. On success the decoded generation is written to ``sid_path``.
        Any exception is logged (with traceback) and swallowed, so on failure
        ``sid_path`` may not have been written.
    """
    try:
        prompt = prompt_template.replace('{input_text}', sentence.text)

        messages = [
            {"role": "system", "content": "You are a helpful NER annotator"},
            {"role": "user", "content": prompt},
        ]
        completion_request = ChatCompletionRequest(messages=messages)
        # Separate name on purpose: do not clobber the ``tokens`` argument.
        prompt_tokens = tokenizer.encode_chat_completion(completion_request).tokens

        out_tokens, _ = generate([prompt_tokens], model, max_tokens=255, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
        # A single prompt was submitted, so the first output sequence is the reply.
        result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])

        with open(sid_path, 'w') as file:
            file.write(result)
    except Exception as ex:
        # Best-effort by design: report the failure, including the traceback,
        # and let the caller decide what to do about the missing result.
        logging.exception(f'[do_prediction] got exception: {ex}')

In [8]:
def extract_annotation_labels_if_possible(predicted_text, labels=None):
    """Extract ``<LABEL>...</LABEL>`` spans and map them to tag-free offsets.

    The returned ``start``/``end`` are positions in ``predicted_text`` with all
    label tags removed. Each offset is computed from the tags that actually
    precede that position in the text, so results do not depend on the order
    in which labels are iterated or on how spans of different labels interleave.

    Args:
        predicted_text: model output containing inline label tags.
        labels: iterable of label names to look for. Defaults to the
            module-level ``LABELS``.

    Returns:
        A ``defaultdict(list)`` mapping each label that matched to a list of
        ``{'text', 'start', 'end'}`` dicts, in order of appearance. ``text``
        has any nested label tags stripped so that it corresponds to the
        ``[start:end]`` slice of the tag-free text.
    """
    if labels is None:
        labels = LABELS
    labels = list(labels)

    label_to_text_list = defaultdict(list)
    if not labels or not predicted_text:
        return label_to_text_list

    # Every opening or closing tag of any known label, matched case-insensitively
    # like the per-label pattern below.
    any_tag = re.compile(
        '</?(?:' + '|'.join(re.escape(label) for label in labels) + ')>',
        flags=re.IGNORECASE,
    )
    tag_spans = [(t.start(), t.end()) for t in any_tag.finditer(predicted_text)]

    def to_plain(pos):
        # Subtract the length of every tag that ends at or before ``pos``.
        return pos - sum(end - start for start, end in tag_spans if end <= pos)

    for label in labels:
        escaped = re.escape(label)
        pattern = re.compile(f'<{escaped}>(.*?)</{escaped}>', flags=re.IGNORECASE | re.DOTALL)
        for m in pattern.finditer(predicted_text):
            label_to_text_list[label].append({
                'text': any_tag.sub('', m.group(1)),
                'start': to_plain(m.start(1)),
                'end': to_plain(m.end(1)),
            })
    return label_to_text_list



def post_process(predicted_text, tokens):
    """Clean the raw model output and extract the labelled spans from it.

    Args:
        predicted_text: raw text produced by the model.
        tokens: accepted for call compatibility; not used here.

    Returns:
        The mapping produced by ``extract_annotation_labels_if_possible`` for
        the output of ``cleaner.Cleaner(predicted_text).clean()``.
    """
    return extract_annotation_labels_if_possible(
        cleaner.Cleaner(predicted_text).clean()
    )

In [9]:
def predict(sentence, tokens, use_cache=True):
    """Return labelled token spans for one sentence, predicting only when needed.

    The raw model reply is stored in a per-sentence text file whose name is the
    first 8 hex digits of the SHA-256 of the sentence text. The directory is
    built from the module-level ``model_name``, ``prompt_id`` and ``file_name``
    (``file_name`` is the loop variable of the driver loop, so this function
    must be called from inside that loop).

    Args:
        sentence: object exposing the sentence string as ``sentence.text``.
        tokens: tokens of the sentence, forwarded to ``utils.make_span_tokens``.
        use_cache: when true (default), reuse an existing result file instead
            of running the model again; pass ``False`` to force a fresh run.

    Returns:
        A list of ``{'span_tokens': ..., 'label': ...}`` dicts, one per
        extracted span. Empty when no result file is available.
    """
    # NOTE: the "zzz_" prefix only influences how the directory sorts in listings.
    path = f'./readme2kg-exp/results/{model_name}/prompt-{prompt_id}/zzz_{file_name}'
    os.makedirs(path, exist_ok=True)
    sid = hashlib.sha256(sentence.text.encode()).hexdigest()[:8]
    result_path = f'{path}/{sid}.txt'

    # Run the model only when there is no reusable result on disk.
    if not (use_cache and os.path.isfile(result_path)):
        do_prediction(sentence, tokens, result_path)

    # do_prediction() logs and swallows its errors, so the file may be missing.
    if not os.path.isfile(result_path):
        logging.warning(f'[predict] no prediction available at {result_path}; skipping sentence')
        return []

    with open(result_path, 'r') as fd:
        predicted_text = fd.read()

    label_to_text_list = post_process(predicted_text, tokens)

    span_tokens_to_label_list = []
    for label, text_list in label_to_text_list.items():
        for text in text_list:
            # Sanity check: the extracted text should equal the slice of the
            # original sentence it claims to cover. Logged at debug level only.
            if text['text'] != sentence.text[text['start']:text['end']]:
                logging.debug(
                    f"[predict] span text {text['text']!r} does not match the original "
                    f"text at [{text['start']}:{text['end']}]"
                )
            span_tokens_to_label_list.append({
                'span_tokens': utils.make_span_tokens(tokens, text['start'], text['end']),
                'label': label
            })
    return span_tokens_to_label_list


In [10]:
def call_serial(doc: Document):
    """Annotate every sentence of ``doc`` one after another.

    For each sentence the tokens are fetched with ``doc.sentence_tokens`` and
    passed to ``predict``. Every predicted span that resolved to tokens is
    converted with ``utils.make_annotation``; spans whose ``span_tokens`` is
    ``None`` are dropped.

    Args:
        doc: the document to annotate.

    Returns:
        The value of ``utils.replace_webanno_annotations(doc, annotations=...)``
        called with all collected annotations.
    """
    collected = []
    for sentence in doc.sentences:
        sentence_tokens = doc.sentence_tokens(sentence)
        for prediction in predict(sentence=sentence, tokens=sentence_tokens):
            matched_tokens, matched_label = prediction['span_tokens'], prediction['label']
            # Only spans that could be mapped onto tokens become annotations.
            if matched_tokens is not None:
                collected.append(
                    utils.make_annotation(tokens=matched_tokens, label=matched_label)
                )

    return utils.replace_webanno_annotations(doc, annotations=collected)

In [14]:
# Annotate every input file and write the prediction under the same file name.
# NOTE: ``file_name`` is also read as a global by predict().
for file_name in file_names:
    file_path = os.path.join(base_path, file_name)
    ref_doc = webanno_tsv_read_file(file_path)
    predicted_doc = call_serial(ref_doc)

    # Verify that only the annotations changed: text, sentences and tokens of
    # the predicted document must be the same as in the reference document.
    if ref_doc.text != predicted_doc.text:
        logging.warning('content changed')
    if len(ref_doc.sentences) != len(predicted_doc.sentences):
        logging.warning('sentences changed')
    if len(ref_doc.tokens) != len(predicted_doc.tokens):
        logging.warning('tokens changed')
    # zip() stops at the shorter sequence; a length mismatch is reported above.
    for s1, s2 in zip(ref_doc.sentences, predicted_doc.sentences):
        if s1 != s2:
            logging.warning(f'sentence changed, \n{s1}\n{s2}')

    for t1, t2 in zip(ref_doc.tokens, predicted_doc.tokens):
        if t1 != t2:
            logging.warning(f'token changed: \n{t1}\n{t2}')

    logging.warning(f"Predicted {len(predicted_doc.annotations)} annotations")
    prediction_path = os.path.join(output_folder, file_name)
    # Explicit encoding so non-ASCII README text is written the same on every platform.
    with open(prediction_path, 'w', encoding='utf-8') as fd:
        fd.write(predicted_doc.tsv())

