In [None]:
!pip install mistralai

Collecting mistralai
  Downloading mistralai-0.1.3-py3-none-any.whl (15 kB)
Collecting httpx<0.26.0,>=0.25.2 (from mistralai)
  Downloading httpx-0.25.2-py3-none-any.whl (74 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.0/75.0 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting orjson<4.0.0,>=3.9.10 (from mistralai)
  Downloading orjson-3.9.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (138 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.5/138.5 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pandas<3.0.0,>=2.2.0 (from mistralai)
  Downloading pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.0/13.0 MB[0m [31m46.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyarrow<16.0.0,>=15.0.0 (from mistralai)
  Downloading pyarrow-15.0.1-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)
[2K     [90m━━━━━━━

In [None]:
!git clone https://github.com/chriswu99aaa/MeTNet.git

Cloning into 'MeTNet'...
remote: Enumerating objects: 241, done.[K
remote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 241 (delta 6), reused 0 (delta 0), pack-reused 228[K
Receiving objects: 100% (241/241), 32.24 MiB | 16.25 MiB/s, done.
Resolving deltas: 100% (128/128), done.


In [None]:
import os
from getpass import getpass

# Never hard-code the key in the notebook (it would be saved in the .ipynb):
# reuse MISTRAL_API_KEY if it is already exported, otherwise prompt for it.
if not os.environ.get('MISTRAL_API_KEY'):
    os.environ['MISTRAL_API_KEY'] = getpass('Mistral API key: ')

In [None]:
import os
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# Credential placed in the environment by the previous cell.
api_key = os.environ["MISTRAL_API_KEY"]
# Chat model name passed to every client.chat(...) call below.
model = "mistral-medium"

# One client instance is reused for all requests in this notebook.
client = MistralClient(
    api_key=api_key,
)


## Loading Data and Preprocessing

In [None]:
from pathlib import Path
import re

def read_file(file_path, encoding='utf-8'):
    """Read a character-level BIO file into parallel token and tag lists.

    Each content line is expected to be ``<token><TAB><tag>``. Sentences are
    separated by a blank line, which may contain a single tab. Lines without a
    tab are ignored. A line with more than one tab raises ValueError.

    Args:
        file_path: path (str or Path) of the file to read.
        encoding: text encoding of the file. Defaults to UTF-8 so the Chinese
            text decodes the same way regardless of the platform locale.

    Returns:
        (token_docs, tag_docs): one entry per sentence. Each entry is the list
        of tokens (or tags) of that sentence, aligned by index.
    """
    raw_text = Path(file_path).read_text(encoding=encoding).strip()

    token_docs = []
    tag_docs = []
    for doc in re.split(r'\n\t?\n', raw_text):
        tokens = []
        tags = []
        for line in doc.split('\n'):
            if '\t' in line:
                token, tag = line.split('\t')
                tokens.append(token)
                tags.append(tag)
        token_docs.append(tokens)
        tag_docs.append(tags)

    return token_docs, tag_docs

In [None]:
# Few-COMM splits shipped with the cloned MeTNet repository.
DATA_DIR = Path('/content/MeTNet/data/Few-COMM')

train_texts, train_tags = read_file(DATA_DIR / 'train.txt')

# `texts` / `tags` hold the dev split.
texts, tags = read_file(DATA_DIR / 'dev.txt')

# Named `test_texts` because the evaluation cells read that name.
test_texts, test_tags = read_file(DATA_DIR / 'test.txt')

In [None]:
# Label vocabulary: the union of the tags seen in the train, dev and test splits.
# The test split must be included because the evaluation loops look up
# tag2id[tag] for every gold test tag.
unique_tags = set(tag for doc in train_tags for tag in doc)
unique_tags_val = set(tag for doc in tags for tag in doc)
unique_tags_test = set(tag for doc in test_tags for tag in doc)
unique_tags = unique_tags | unique_tags_val | unique_tags_test

# Enumerate in sorted order: iterating a set of str depends on hash
# randomization, so ids would otherwise change between kernel restarts.
tag2id = {tag: idx for idx, tag in enumerate(sorted(unique_tags))}
id2tag = {idx: tag for tag, idx in tag2id.items()}

In [None]:
# Getting the first query data
query_texts = texts[:200][:]
query_tags = tags[:200][:]

In [None]:
query = query_texts[0]
query_tag = query_tags[0]

In [None]:
query = ''.join(query)

In [None]:
query

'味可滋冷萃奶茶/250ml'

## Constructing Prompt

In [None]:
import re
example_query = """维	O
他	O
和	O
梨	O
饮	O
料	O
2	O
5	O
0	O
m	O
l	O
*	O
6	O
盒	O
/	O
组	O
（	O
香	B-原产地
港	I-原产地
进	I-原产地
口	I-原产地
）	O
"""

1. Extract one example character for each label.
2. Store the examples in a dictionary `text_mapping` of the form `{character: tag}`.
3. Track which labels already have an example by tag id: e.g. if label A has id 3, `assigned[3]` is set to 1 once an example for A has been stored.

In [None]:
tag2id['I-适用性别']

75

In [None]:
num_query_tags = len(unique_tags)

In [None]:
text_mapping = {}
assigned = [0]*num_query_tags



In [None]:
# sampling each class with one entity
test_tags = query_tags[:8000]
for i in range(len(test_tags)):
    tag = test_tags[i]
    for j in range(len(tag)):
        if tag[j] != 'O':
            # the index corresponds to the id of the tag
            id = tag2id[tag[j]]
            word = query_texts[i][j]
            if assigned[id] == 0:
                # give one example for each class
                text_mapping[word] = tag[j]
                assigned[id] = 1
                print(text_mapping[word])
                print(id)


In [None]:
# Render the examples in the same "character BIOtag" line format the system
# prompt asks the model to answer in. A quoted dict repr invites the model to
# echo quoted tags. The guard keeps the cell safe to re-run once it is a string.
if isinstance(text_mapping, dict):
    text_mapping = '\n'.join(f'{word} {tag}' for word, tag in text_mapping.items())

In [None]:
# List exactly the tags tag2id can score, taken from the label vocabulary.
# A hand-maintained list drifts from the data (e.g. 'I-适用性别' was missing).
tag_inventory = ', '.join(f"'{tag}'" for tag in sorted(unique_tags))

system_prompt = f"""
    You are a Chinese Named Entity Recognizer. The input data is in Chinese and the
    data is annotated in the BIO scheme.
    The full list of BIO tags is provided in the following part, delimited by triple #
    ###
    {tag_inventory}
    ###
    An example of a character and its BIO tag is provided between the triple # delimiters
    ###
    {text_mapping}
    ###
    When given an input sentence in Chinese, classify each character
    using the following format
    character BIOtag
    If the input has numbers and symbols, keep the numbers and symbols in the output


    If all characters are tagged as O class, provide the output without including an explanation or notification
    provide only the word and tag in the output, and nothing else.
    remove your Note from your response
"""

In [None]:
messages = [
    ChatMessage(role='system', content=system_prompt),
    ChatMessage(role="user", content='清真肉串')
]

# No streaming
chat_response = client.chat(
    model=model,
    messages=messages,
)

result = chat_response.choices[0].message.content

In [None]:
result

"清 'B-是否清真'\n真 'I-是否清真'\n肉 'B-适用运营商'\n串 'I-商品特色'"

In [None]:
import re
result = re.split('\n',result)


In [None]:
test_result = [re.split(' ',line) for line in result]

In [None]:
test_result

[['清', "'B-是否清真'"], ['真', "'I-是否清真'"], ['肉', "'B-适用运营商'"], ['串', "'I-商品特色'"]]

In [None]:
def check_format(result):
    """Reduce the parsed model reply to its tagged-character rows.

    ``result`` is the reply split into lines, with each line split on
    spaces/tabs, e.g. ``[['清', 'B-x'], ['真', 'I-x'], [''], ['Note:', ...]]``.

    - Empty tokens (from repeated separators) are dropped.
    - Blank lines are skipped.
    - Parsing stops at a line starting with "Note", because the model sometimes
      appends a free-text note despite being asked not to.
    - Rows with fewer than two tokens (no tag) are skipped.

    Returns a new list; the input is not modified.
    """
    cleaned = []
    for row in result:
        tokens = [token for token in row if token]
        if not tokens:
            continue
        if tokens[0].lower().startswith('note'):
            break
        if len(tokens) >= 2:
            cleaned.append(tokens)
    return cleaned

In [None]:
test_result = check_format(test_result)

In [None]:
test_result


[['清', "'B-是否清真'"], ['真', "'I-是否清真'"], ['肉', "'B-适用运营商'"], ['串', "'I-商品特色'"]]

## Evaluation

In [None]:
from sklearn.metrics import f1_score
import statistics
def compute_metrics(preds, labels):
    """Return the micro-averaged F1 score between two tag-id sequences.

    For single-label multiclass data, micro-F1 equals per-character accuracy,
    so the result does not depend on which argument holds the gold labels.

    Returns 0.0 when either sequence is empty (e.g. the model reply yielded no
    parseable rows), so every sentence contributes a defined score.
    """
    if len(preds) == 0 or len(labels) == 0:
        return 0.0
    return f1_score(labels, preds, average='micro')

def average_f1(f1_list):
    """Return the arithmetic mean of the per-sentence F1 scores in ``f1_list``.

    Delegates to statistics.mean, so an empty list raises StatisticsError.
    """
    mean_score = statistics.mean(f1_list)
    return mean_score

In [None]:
def check_key(dict, predictions):
    """Convert predicted tag strings to tag ids, falling back to the 'O' id.

    Surrounding whitespace and quote characters are stripped before the lookup,
    because the model sometimes echoes tags in quotes (e.g. "'B-是否清真'").

    Args:
        dict: the tag2id mapping (tag string -> id). It must contain 'O' if any
            prediction is unknown. The name shadows the builtin but is kept for
            backward compatibility.
        predictions: list of predicted tag strings.

    Returns:
        A new list with one id per prediction. ``predictions`` is left
        untouched, so callers can keep the raw tags for error analysis.
    """
    converted = []
    for pred in predictions:
        tag = str(pred).strip().strip('\'"‘’“”')
        converted.append(dict[tag] if tag in dict else dict['O'])
    return converted



In [None]:
def validate_dimension(preds, labels):
    '''
    Make preds and labels the same length by trimming the longer one from the end.

    Returns a (preds, labels) tuple. An argument that is not longer than the
    other is returned as the same object; only the longer one is sliced.
    '''
    n_preds = len(preds)
    n_labels = len(labels)
    if n_preds > n_labels:
        # drop predictions beyond the last gold label
        return preds[:n_labels], labels
    if n_labels > n_preds:
        return preds, labels[:n_preds]
    return preds, labels

## One Shot

In [None]:
query_texts = test_texts[:8000][:]
query_tags = test_tags[:8000][:]

In [None]:
f1_scores = []



In [None]:
# Sentences scoring below this F1 are recorded for error analysis.
LOW_F1_THRESHOLD = 0.4
# NOTE(review): evaluation starts at sentence 200 — presumably to resume an
# interrupted run or to skip an example pool; confirm before reporting.
ONE_SHOT_START = 200
one_shot_id_list = []
one_shot_pred_list = []
one_shot_low_f1_list = []

for i in range(ONE_SHOT_START, len(query_texts)):
    print(i)

    # Characters are sent space-separated.
    query = ' '.join(query_texts[i])
    true_label = query_tags[i]

    messages = [
        ChatMessage(role='system', content=system_prompt),
        ChatMessage(role="user", content=query),
    ]

    # No streaming; uses the API default temperature.
    chat_response = client.chat(
        model=model,
        messages=messages,
    )
    result = chat_response.choices[0].message.content

    result = [re.split('[ \t]', line) for line in re.split('\n', result)]
    print(result)
    result = check_format(result)

    # The tag is the last token of a row; the model sometimes emits extra
    # tokens before it (e.g. "b - 百 B-其他属性").
    word_pred = [line[-1] for line in result]
    # Pass a copy so word_pred keeps the raw tag strings for error analysis.
    pred = check_key(tag2id, list(word_pred))

    # convert tag to id for f1 score calculation
    true_label = [int(tag2id[tag]) for tag in true_label]

    # validate input
    pred, true_label = validate_dimension(pred, true_label)

    f1 = compute_metrics(true_label, pred)
    if f1 < LOW_F1_THRESHOLD:
        # record low-F1 predictions for error analysis
        one_shot_id_list.append(i)
        one_shot_pred_list.append(word_pred)
        one_shot_low_f1_list.append(f1)
    f1_scores.append(f1)



In [None]:
print("One shot :",average_f1(f1_scores))

0.670989257213999

## Five Shots

In [None]:
five_shot_f1_scores = []


In [None]:
num_query_tags = len(unique_tags)
text_mapping = {}
assigned = [0]*num_query_tags


In [None]:
# Sample up to five distinct example characters per entity tag from the
# evaluation sentences; text_mapping maps tag -> list of characters.
# The loop uses its own names rather than rebinding `test_tags`.
# NOTE(review): query_texts / query_tags are the same test sentences the
# five-shot loop scores, so gold labels leak into the prompt. Consider drawing
# examples from a held-out subset and excluding it from evaluation.
MAX_EXAMPLES_PER_TAG = 5
for sent_tokens, sent_tags in zip(query_texts, query_tags):
    for word, tag in zip(sent_tokens, sent_tags):
        if tag == 'O':
            continue
        tag_id = tag2id[tag]  # position of this tag in `assigned`
        if assigned[tag_id] >= MAX_EXAMPLES_PER_TAG:
            continue
        examples = text_mapping.setdefault(tag, [])
        if word not in examples:
            examples.append(word)
            assigned[tag_id] += 1



In [None]:
text_mapping

{'B-冲泡方式': ['冷'],
 'I-冲泡方式': ['萃'],
 'B-适用时间': ['端', '夏', '女', '初', '毕'],
 'I-适用时间': ['午', '节', '天', '神', '秋'],
 'B-适用性别': ['女', '男'],
 'I-适用性别': ['孩', '司', '机', '女', '通'],
 'B-粗细': ['细', '圆'],
 'I-粗细': ['粉', '条'],
 'B-色系': ['香'],
 'I-色系': ['槟', '色', '系', '可', '选'],
 'B-保质期': ['六'],
 'I-保质期': ['个', '月', '以', '上'],
 'B-风味': ['港', '戚', '老', '风', '苏'],
 'I-风味': ['式', '风', '味'],
 'B-适用生肖': ['兔', '猴'],
 'I-适用生肖': ['子'],
 'B-连接方式': ['光', '有', 'a', '无'],
 'I-连接方式': ['纤', '线', 'u', 'x', '限'],
 'B-加热方式': ['快', '煤', '燃', '双'],
 'I-加热方式': ['速', '电', '热', '气', '面'],
 'B-剂型': ['颗', '微', '含', '圆'],
 'I-剂型': ['粒', '颗', '片'],
 'B-甜度': ['无', '甜', '好', '干', '纯'],
 'I-甜度': ['糖', '度', '高', '吃', '型'],
 'B-适用车型': ['轿', 's', '丰', '比', '货'],
 'I-适用车型': ['车', 'u', 'v', '田', '亚'],
 'B-是否清真': ['清'],
 'I-是否清真': ['真'],
 'B-系列': ['甄', '满', '特', '小', '超'],
 'I-系列': ['选', '钻', '享', '猪', '佩'],
 'B-是否带盖': ['带', '可'],
 'I-是否带盖': ['盖', '子'],
 'B-裙型': ['小', '拼', '公'],
 'I-裙型': ['黑', '裙', '接', '款', '主'],
 'B-内容': ['女'],
 '

In [None]:
assigned

In [None]:
text_mapping = str(text_mapping)

In [None]:
# List exactly the tags tag2id can score, taken from the label vocabulary.
# A hand-maintained list drifts from the data (e.g. 'I-适用性别' was missing).
tag_inventory = ', '.join(f"'{tag}'" for tag in sorted(unique_tags))

system_prompt = f"""
    You are a Chinese Named Entity Recognizer. The input data is in Chinese and the
    data is annotated in the BIO scheme.
    The full list of BIO tags is provided in the following part, delimited by triple #
    ###
    {tag_inventory}
    ###
    Up to five example characters for each BIO tag are provided between the triple # delimiters,
    as a mapping from tag to a list of characters
    ###
    {text_mapping}
    ###
    When given an input sentence in Chinese, classify each character
    using the following format
    character BIOtag
    If the input has numbers and symbols, keep the numbers and symbols in the output


    If all characters are tagged as O class, provide the output without including an explanation or notification
    provide only the word and tag in the output, and nothing else.
    remove your Note from your response
"""

In [None]:
def check_format(result):
    """Reduce the parsed model reply to its tagged-character rows.

    ``result`` is the reply split into lines, with each line split on
    spaces/tabs, e.g. ``[['清', 'B-x'], ['真', 'I-x'], [''], ['Note:', ...]]``.

    - Empty tokens (from repeated separators) are dropped.
    - Blank lines are skipped.
    - Parsing stops at a line starting with "Note", because the model sometimes
      appends a free-text note despite being asked not to.
    - Rows with fewer than two tokens (no tag) are skipped.

    Returns a new list; the input is not modified.
    """
    cleaned = []
    for row in result:
        tokens = [token for token in row if token]
        if not tokens:
            continue
        if tokens[0].lower().startswith('note'):
            break
        if len(tokens) >= 2:
            cleaned.append(tokens)
    return cleaned

In [None]:
five_shot_query_texts = test_texts[:8000][:]
five_shot_query_tags = test_tags[:8000][:]

In [None]:
# Error-analysis records filled by the five-shot loop for each sentence whose
# F1 is below 0.4: the sentence index, the model's predictions, and the score.
id_list = []
pred_list = []
low_f1_list = []

In [None]:
for i in range(len(five_shot_query_texts)):
    print(i)

    # Characters are sent space-separated.
    query = ' '.join(five_shot_query_texts[i])
    true_label = five_shot_query_tags[i]

    messages = [
        ChatMessage(role='system', content=system_prompt),
        ChatMessage(role="user", content=query),
    ]

    # No streaming
    chat_response = client.chat(
        model=model,
        messages=messages,
        temperature=0.2,
    )
    result = chat_response.choices[0].message.content

    result = [re.split('[ \t]', line) for line in re.split('\n', result)]
    print(result)
    result = check_format(result)

    # The tag is the last token of a row; the model sometimes emits extra
    # tokens before it (e.g. "b - 百 B-其他属性").
    word_pred = [line[-1] for line in result]
    # Pass a copy so word_pred keeps the raw tag strings for pred_list.
    pred = check_key(tag2id, list(word_pred))

    # convert tag to id for f1 score calculation
    true_label = [int(tag2id[tag]) for tag in true_label]

    # validate input
    prediction, true_label = validate_dimension(pred, true_label)
    f1 = compute_metrics(true_label, prediction)

    if f1 < 0.4:
        # record low-F1 predictions for error analysis
        id_list.append(i)
        pred_list.append(word_pred)
        low_f1_list.append(f1)
    five_shot_f1_scores.append(f1)

197
[['b', '-', '百', 'B-其他属性'], ['事', '-', '事', 'O'], ['可', '-', '可', 'O'], ['乐', '-', '乐', 'O'], ['无', '-', '无', 'B-糖含量'], ['糖', '-', '糖', 'I-糖含量'], ['树', '-', '树', 'B-果肉颜色'], ['莓', '-', '莓', 'I-果肉颜色'], ['味', '-', '味', 'O'], ['碳', '-', '碳', 'B-酸碱度'], ['酸', '-', '酸', 'I-酸碱度'], ['汽', '-', '汽', 'B-供电方式'], ['水', '-', '水', 'I-供电方式'], ['饮', '-', '饮', 'B-其他属性'], ['料', '-', '料', 'I-其他属性'], ['5', '-', '5', 'O'], ['0', '-', '0', 'O'], ['0', '-', '0', 'O'], ['m', '-', 'm', 'O'], ['l', '-', 'l', 'O'], ['/', '-', '/', 'O'], ['1', '-', '1', 'O'], ['瓶', '-', '瓶', 'B-其他属性'], ['/', '-', '/', 'O'], ['份', '-', '份', 'I-其他属性']]
198
[['单', 'O'], ['片', 'O'], ['套', 'O'], ['装', 'O'], ['光', 'O'], ['盘', 'O'], ['加', 'O'], ['袋', 'O'], ['3', 'O'], ['7', 'O'], ['2', 'O'], ['5', 'O'], ['得', 'O'], ['力', 'O'], ['1', 'O'], ['套', 'O'], [''], ['Note:', 'All', 'characters', 'in', 'the', 'input', 'are', 'tagged', 'as', 'O', 'class,', 'which', 'means', 'they', 'do', 'not', 'belong', 'to', 'any', 'of', 'the', 'named', 'entit

In [None]:
# Average over the five-shot run's scores (f1_scores holds the one-shot run).
print("Five shot :", average_f1(five_shot_f1_scores))

0.789081294871634

pred_list

In [None]:
low_f1_list

[0.3125, 0.36363636363636365, 0.22727272727272727, 0.39285714285714285, 0.1875]