# Environment Setup

In [None]:
!pip install checklist
!jupyter nbextension install --py --user checklist.viewer
!jupyter nbextension enable --py --user checklist.viewer

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM

import checklist
from checklist.editor import Editor
from checklist.perturb import Perturb
from checklist.test_types import INV
import csv
import spacy
import numpy as np
import itertools

from tqdm import tqdm
from sklearn.metrics import accuracy_score

In [None]:
# Need to log in to the Hugging Face Hub to download the Gemma model.
!pip install huggingface_hub
from huggingface_hub import notebook_login
# Prompts for a Hugging Face access token inside the notebook.
notebook_login()

# Model Setup

In [None]:
def load_model_and_tokenizer(name="qwen"):
    """Load one of the supported chat models and its tokenizer from the HF Hub.

    Args:
        name: Short model key: "qwen", "aya", "yi" or "gemma".

    Returns:
        A ``(model, tokenizer)`` tuple. "aya" is loaded with the seq2seq auto
        class; all other models are loaded as causal LMs. ``torch_dtype="auto"``
        lets transformers choose the dtype from the checkpoint.

    Raises:
        ValueError: if ``name`` is not a known model key.
    """
    path_dict = {
        "qwen" : "Qwen/Qwen1.5-7B-Chat",
        "aya" : "CohereForAI/aya-101",
        "yi" : "01-ai/Yi-6B-Chat",
        "gemma" : "google/gemma-2b-it",
    }

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if name not in path_dict:
        raise ValueError(f"unknown model {name!r}; expected one of {sorted(path_dict)}")

    path = path_dict[name]
    tokenizer = AutoTokenizer.from_pretrained(path)
    # aya-101 needs the encoder-decoder (seq2seq) auto class; the rest are decoder-only.
    model_cls = AutoModelForSeq2SeqLM if name == 'aya' else AutoModelForCausalLM
    model = model_cls.from_pretrained(path, torch_dtype="auto")

    return model, tokenizer


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# The model under test is selected here; change `name` to evaluate another one.
model, tokenizer = load_model_and_tokenizer(name="qwen")

model = model.to(device)

# Create Dataset

In [None]:
# CheckList editor: used below to fill templates (editor.template) and to get
# masked-LM suggestions (editor.suggest).
editor = checklist.editor.Editor()
# Bare expression: displays the `tg` attribute (presumably the editor's
# masked-LM text generator).
editor.tg

In [None]:
# Small English spaCy pipeline; used to parse the QQP questions that Perturb operates on.
nlp = spacy.load('en_core_web_sm')

## Load Dataset

In [None]:
# Colab only: mount Google Drive, which holds the datasets read below.
from google.colab import drive
drive.mount('/content/drive')
!ls ./drive/MyDrive

In [None]:
# English QQP: skip the header row. The last three tab-separated fields
# (indices 3-5) are question 1, question 2 and the 0/1 duplicate label.
qs = []
labels = []
all_questions = set()
# `with` closes the file; readlines() on a bare open() leaked the handle.
with open('./drive/MyDrive/quora_duplicate_questions.tsv', encoding='utf-8') as f:
    next(f, None)  # skip the header row
    for x in f:
        try:
            q1, q2, label = x.strip().split('\t')[3:]
        except ValueError:
            # Row without exactly six tab-separated fields: show it and skip.
            print(x)
            continue
        all_questions.add(q1)
        all_questions.add(q2)
        qs.append((q1, q2))
        labels.append(label)
labels = np.array(labels).astype(int)

In [None]:
# Peek at the first five question pairs and their duplicate labels.
for preview in (qs[:5], labels[:5]):
    print(preview)

In [None]:
# Parse every unique question once with spaCy, then index the Docs by their text.
all_questions = list(all_questions)
parsed_questions = list(nlp.pipe(all_questions))
spacy_map = dict(zip(all_questions, parsed_questions))

In [None]:
# Swap each raw question pair for its pair of parsed spaCy Docs.
parsed_qs = [(spacy_map[first], spacy_map[second]) for first, second in qs]

## Load Chinese translation QQP

In [None]:
# Chinese (zh-CN) translation of QQP, same TSV layout as the English file:
# header row skipped, fields 3-5 are question 1, question 2 and the label.
qs_zh = []
labels_zh = []
all_questions_zh = set()
# `with` closes the file; readlines() on a bare open() leaked the handle.
with open('./drive/MyDrive/quora_duplicate_questions_zh_cn.tsv', encoding='utf-8') as f:
    next(f, None)  # skip the header row
    for x in f:
        try:
            q1, q2, label = x.strip().split('\t')[3:]
        except ValueError:
            # Row without exactly six tab-separated fields: show it and skip.
            print(x)
            continue
        all_questions_zh.add(q1)
        all_questions_zh.add(q2)
        qs_zh.append((q1, q2))
        labels_zh.append(label)
labels_zh = np.array(labels_zh).astype(int)

In [None]:
# Sanity-check the Chinese load: dataset size plus a handful of examples.
# (Printing every pair would flood the notebook with the entire dataset.)
print(len(qs_zh))
for q, l in zip(qs_zh[:5], labels_zh[:5]):
    print(q)
    print(l)


# Create Tests

## Robustness

In [None]:
def wrap_apply_to_each(fn, both=False, *args, **kwargs):
    """Lift a single-question perturbation ``fn`` to a question pair.

    The returned function takes ``(q1, q2)`` and returns a list of perturbed
    pairs: each output of ``fn(q1)`` paired with ``str(q2)``, then ``str(q1)``
    paired with each output of ``fn(q2)``. If ``both`` is True, every
    combination of perturbed q1 and perturbed q2 is added as well. ``fn`` may
    return a single value or a list; pairs with a falsy side are dropped.

    Extra ``*args``/``**kwargs`` given here are forwarded to ``fn``, followed
    by any extra positional/keyword arguments supplied when the wrapped
    function is called (call-time keywords win on conflicts).
    """
    def _as_list(value):
        # Normalize a single perturbation result to a one-element list.
        return value if isinstance(value, list) else [value]

    def new_fn(qs, *call_args, **call_kwargs):
        q1, q2 = qs
        fn_args = args + call_args
        fn_kwargs = {**kwargs, **call_kwargs}
        fnq1 = _as_list(fn(q1, *fn_args, **fn_kwargs))
        fnq2 = _as_list(fn(q2, *fn_args, **fn_kwargs))
        ret = [(x, str(q2)) for x in fnq1]
        ret += [(str(q1), x) for x in fnq2]
        if both:
            ret += list(itertools.product(fnq1, fnq2))
        return [x for x in ret if x[0] and x[1]]
    return new_fn

def wrap_apply_to_both(fn, *args, **kwargs):
    """Lift ``fn`` so that it perturbs both questions of a pair together.

    The returned function takes ``(q1, q2)`` and returns every combination
    ``(x1, x2)`` of an output of ``fn(q1)`` with an output of ``fn(q2)``.
    ``fn`` may return a single value or a list; pairs with a falsy side are
    dropped.

    Extra ``*args``/``**kwargs`` given here are forwarded to ``fn``, followed
    by any extra arguments supplied when the wrapped function is called
    (call-time keywords win on conflicts).
    """
    def _as_list(value):
        # Normalize a single perturbation result to a one-element list.
        return value if isinstance(value, list) else [value]

    def new_fn(qs, *call_args, **call_kwargs):
        q1, q2 = qs
        fn_args = args + call_args
        fn_kwargs = {**kwargs, **call_kwargs}
        fnq1 = _as_list(fn(q1, *fn_args, **fn_kwargs))
        fnq2 = _as_list(fn(q2, *fn_args, **fn_kwargs))
        return [(x, x2) for x, x2 in itertools.product(fnq1, fnq2) if x and x2]
    return new_fn

typos & contractions

In [None]:
# Typos: perturb one question of the pair at a time, keeping the other intact.
ROB_typo_data = Perturb.perturb(qs, wrap_apply_to_each(Perturb.add_typos), nsamples=1000).data

# Contractions: applied to either question and, with both=True, to both at once.
ROB_contra_data = Perturb.perturb(qs, wrap_apply_to_each(Perturb.contractions, both=True), nsamples=1000).data



paraphrase

In [None]:
import re

def me_to_you(text):
    """Rewrite first-person words into second person.

    Whole, case-sensitive words are replaced in order: "I" -> "you",
    "my" -> "your", "mine" -> "yours".
    """
    for pattern, replacement in ((r'\bI\b', 'you'), (r'\bmy\b', 'your'), (r'\bmine\b', 'yours')):
        text = re.sub(pattern, replacement, text)
    return text

def paraphrases(text):
    """Generate template paraphrases of a "How do I ..."-style question.

    For every known prefix that ``text`` starts with, the remainder (with
    surrounding '?' stripped) is filled into every template of that prefix's
    group using the global ``editor``. For the first two groups, each
    paraphrase is also rewritten to second person with ``me_to_you``.
    Returns an empty list when no prefix matches.
    """
    groups = [
        (['How do I ', 'How can I ', 'What is a good way to ', 'How should I '],
         ['How do I {x}?', 'How can I {x}?', 'What is a good way to {x}?', 'If I want to {x}, what should I do?',
          'In order to {x}, what should I do?']),
        (['Can you ', 'Can I '],  # , 'Do I']
         ['Can you {x}?', 'Can I {x}?', 'Do you think I can {x}?', 'Do you think you can {x}?']),
        (['Do I '],
         ['Do I {x}?', 'Do you think I {x}?']),
    ]
    ret = []
    for group_idx, (prefixes, templates) in enumerate(groups):
        for prefix in prefixes:
            if not text.startswith(prefix):
                continue
            rest = text[len(prefix):].strip('?')
            filled = editor.template(templates, x=rest).data[0]
            if group_idx <= 1:
                # Add a second-person rewrite of each paraphrase.
                filled = filled + [me_to_you(s) for s in filled]
            ret += filled
    return ret

def paraphrases_product(text):
    """All ordered pairs (self-pairs included) of the paraphrases of ``text``."""
    candidates = paraphrases(text)
    return [(a, b) for a in candidates for b in candidates]

def paraphrase_each(pair):
    """Cross product of the paraphrases of a pair's first and second question."""
    lefts = paraphrases(pair[0])
    rights = paraphrases(pair[1])
    return [(a, b) for a in lefts for b in rights]

In [None]:
# Paraphrase pairs built from a single question (no original pair kept).
ROB_paraphrase_prod_data = Perturb.perturb(list(all_questions), paraphrases_product, nsamples=100, keep_original=False).data

# Paraphrase each question of an existing pair; the original pair is kept first.
ROB_paraphrase_each_data = Perturb.perturb(qs, paraphrase_each, nsamples=100, keep_original=True).data

### Typos for Chinese characters

In [None]:
!pip install pypinyin

from pypinyin import pinyin, lazy_pinyin, Style
import os
import pickle

In [None]:
def same_tone(character, dic):
    """Return homophones of ``character`` (same pinyin, same tone).

    Args:
        character: The character to find replacements for.
        dic: Mapping from TONE3-style pinyin strings to collections of
            characters with that pronunciation.

    Returns:
        A new list of candidate characters, excluding ``character`` itself;
        empty when its pinyin is not a key of ``dic``.

    ``dic`` is only read. Removing ``character`` from ``dic[p]`` in place
    would permanently drop it from the shared lexicon, so later calls for
    other characters could no longer produce it as a typo.
    """
    res = []
    for p_li in pinyin(character, style=Style.TONE3, heteronym=False):
        for p in p_li:
            # Missing keys (punctuation, Latin letters, rare characters) yield no candidates.
            res.extend(c for c in dic.get(p, ()) if c != character)
    return res

In [None]:
# Homophone lexicon used by `same_tone`: looked up by TONE3 pinyin string.
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open('./drive/MyDrive/chinese_3500.pickle', 'rb') as f:
    typo_dic = pickle.load(f)

# Summarize rather than printing the whole lexicon into the notebook output.
print(len(typo_dic))
print(list(typo_dic)[:5])

In [None]:
import random

# def robust_typo(data, typo_dic, n_typos=2, n=10):
#     result = []
#     for text in data:
#         temp_data = [[] for _ in range(n)]
#         temp_data[0] = text
#         for i in range(1, n):
#             modified_sentence = text
#             typo_position = random.sample(range(0, len(text)), n_typos)
#             for j in typo_position:
#                 typo_list = same_tone(text[j], typo_dic)
#                 if typo_list == []:
#                     continue
#                 modified_sentence = modified_sentence[0:j] + random.choice(typo_list) + modified_sentence[j+1:]
#             temp_data[i] = modified_sentence
#         result.append(temp_data)
#     return result


def robust_typo(text, typo_dic, n_typos=2):
    """Inject up to ``n_typos`` homophone typos into a Chinese string.

    Picks ``n_typos`` distinct positions (all positions if ``text`` is
    shorter) and replaces each character with a random candidate from
    ``same_tone``; positions without candidates are left unchanged.

    Args:
        text: Input string.
        typo_dic: Homophone lexicon passed to ``same_tone``.
        n_typos: Maximum number of characters to replace.

    Returns:
        The perturbed string.
    """
    # random.sample raises ValueError when asked for more positions than the
    # text has (e.g. empty or one-character questions), so cap the count.
    k = max(0, min(n_typos, len(text)))
    chars = list(text)
    for j in random.sample(range(len(text)), k):
        typo_list = same_tone(text[j], typo_dic)
        if typo_list:
            chars[j] = random.choice(typo_list)
    return ''.join(chars)

In [None]:
# toy_examples = ['在马来西亚，冲浪是错误的吗？',
#                 '投资者不应该担心什么事情？'
# ]

# toy_examples_typo = robust_typo(toy_examples, typo_dic, n_typos=2, n=5)
# print(toy_examples_typo)

In [None]:
# Chinese typo data: two homophone typos in one question of each pair at a time.
ROB_typo_data_zh = Perturb.perturb(qs_zh, wrap_apply_to_each(lambda x: robust_typo(x, typo_dic, n_typos=2)), nsamples=1000).data

In [None]:
# Peek at the first few perturbed examples.
print(ROB_typo_data_zh[:5])

## NER

### names

In [None]:
# adjs_without_overlap = ['dead', 'gay', 'Jewish', 'Christian', 'American', 'mad', 'immortal', 'evil', 'famous', 'racist', 'Muslim', 'white', 'black', 'English', 'autistic', 'Australian', 'trustworthy', 'an atheist', 'an anarchist', 'an inventor', 'Indian', 'Armenian', 'an astronaut', 'an immigrant']
# Chinese translations of the English adjectives above, in the same order.
adjs_without_overlap = ['死掉的', '同性恋', '犹太人', '基督教的', '美国的', '疯的', '不朽的', '邪恶的', '有名的', '有种族主义的', '穆斯林教的', '白的', '黑的', '英国的', '自闭的', '澳大利亚的', '值得信赖的', '信无神论的', '信无政府主义的', '一个发明家', '印度的', '亚美尼亚的', '一个宇航员', '移民']

In [None]:
# Common Chinese given names and surnames used to fill the name templates.
# Note: '林' appears in both lists.
first_names = [
    '伟', '芳', '娜', '秀英', '敏', '静', '丽', '强', '磊', '军', '洋', '勇', '艳', '杰', '娟', '涛', '明', '超', '秀兰', '霞',
    '平', '刚', '桂英', '帅', '晨', '波', '琳', '秀珍', '健', '俊', '帆', '宁', '琴', '宇', '芬', '云', '洁', '林', '哲', '岩',
    '辉', '菲', '华', '梅', '琪', '虹', '明哲', '天翔', '浩', '文'
]
last_names = [
    '王', '李', '张', '刘', '陈', '杨', '黄', '吴', '赵', '周', '徐', '孙', '马', '朱', '胡', '林', '郭', '何', '高', '罗',
    '郑', '梁', '谢', '宋', '唐', '许', '韩', '冯', '邓', '曹', '彭', '曾', '肖', '田', '董', '袁', '潘', '于', '蒋', '蔡',
    '余', '杜', '叶', '程', '苏', '魏', '吕', '丁', '任', '沈'
]


person1 and person2 differ in both first and last name

In [None]:
# "Is <name> <adj>?" pairs about two people who differ in both surname and given
# name; {last_name1}/{last_name2} are filled from `last_name`, and
# {first_name1}/{first_name2} from `first_name`.
t = editor.template((
    # 'Is {first_name1} {last_name1} {adj}?',
    # 'Is {first_name2} {last_name2} {adj}?',
    '{last_name1}{first_name1}是{adj}吗?',
    '{last_name2}{first_name2}是{adj}吗?',

    ),
    first_name=first_names,
    last_name=last_names,
    adj=adjs_without_overlap,
    remove_duplicates=True,
    nsamples=1000)

NER_first_last_data = t.data
print(NER_first_last_data[:5])
# label 0

person1 and person2 differ in first name only

In [None]:
# Same surname, different given names ({first_name1} vs {first_name2}).
t = editor.template((
    # 'Is {first_name} {last_name} {adj}?',
    # 'Is {first_name2} {last_name} {adj}?',
    '{last_name}{first_name1}是{adj}吗?',
    '{last_name}{first_name2}是{adj}吗?',
    ),
    first_name=first_names,
    last_name=last_names,
    adj=adjs_without_overlap,
    remove_duplicates=True,
    nsamples=1000)

NER_first_data = t.data
print(NER_first_data[:5])
# label = 0

person1 and person2 differ in last name only

In [None]:
# Same given name, different surnames ({last_name1} vs {last_name2}).
t = editor.template((
    # 'Is {first_name} {last_name} {adj}?',
    # 'Is {first_name} {last_name2} {adj}?',
    '{last_name1}{first_name}是{adj}吗?',
    '{last_name2}{first_name}是{adj}吗?',
    ),
    first_name=first_names,
    last_name=last_names,
    adj=adjs_without_overlap,
    remove_duplicates=True,
    nsamples=1000)

NER_last_data = t.data
print(NER_last_data[:5])
# label = 0

Locations, Names, Numbers

In [None]:
def change_both_wrapper(fn):
    """Apply a CheckList entity perturbation to both questions of a pair.

    ``fn`` (e.g. ``Perturb.change_location``) is called on each question with
    ``meta=True`` and one shared random seed. Only positions where the two
    metadata entries are equal are kept, as ``(perturbed_q1, perturbed_q2)``
    pairs. Returns None when ``fn`` produces nothing for either question.
    """
    def change_both(qs):
        first, second = qs
        shared_seed = np.random.randint(100)
        out1 = fn(first, seed=shared_seed, meta=True)
        out2 = fn(second, seed=shared_seed, meta=True)
        if not out1 or not out2:
            return
        (texts1, metas1), (texts2, metas2) = out1, out2
        pairs = []
        for t1, t2, meta1, meta2 in zip(texts1, texts2, metas1, metas2):
            if meta1 == meta2:
                pairs.append((t1, t2))
        return pairs
    return change_both

def change_each_wrapper(fn):
    """Apply a CheckList entity perturbation to one question at a time.

    ``fn`` is called on each question with ``meta=True``, one shared random
    seed and any extra keyword arguments. A perturbed q1 is kept (paired
    with ``str(q2)``) only if the first element of its metadata entry occurs
    in ``str(q2)``, and symmetrically for q2. Returns None when ``fn``
    produces nothing for either question.
    """
    def change_one(qs, **kwargs):
        q1, q2 = qs
        shared_seed = np.random.randint(100)
        out1 = fn(q1, seed=shared_seed, meta=True, **kwargs)
        out2 = fn(q2, seed=shared_seed, meta=True, **kwargs)
        if not out1 or not out2:
            return
        (texts1, metas1), (texts2, metas2) = out1, out2
        q1_text, q2_text = str(q1), str(q2)
        kept_first = [(t, q2_text) for t, m in zip(texts1, metas1) if m[0] in q2_text]
        kept_second = [(q1_text, t) for t, m in zip(texts2, metas2) if m[0] in q1_text]
        return kept_first + kept_second
    return change_one

In [None]:
# Change location
NER_loc_data = Perturb.perturb(parsed_qs, change_both_wrapper(Perturb.change_location), nsamples=1000).data

# Change names
NER_names_data = Perturb.perturb(parsed_qs, change_both_wrapper(Perturb.change_names), nsamples=1000).data

# Change number
NER_num_data = Perturb.perturb(parsed_qs, change_both_wrapper(Perturb.change_number), nsamples=1000).data

### INV for Chinese

In [None]:
!pip install hanlp -U

import hanlp

# Bare expression: displays the available pretrained multi-task models.
hanlp.pretrained.mtl.ALL
# Chinese multi-task pipeline (tokenization, NER, ...); used by perturb_location.
HanLP = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH)

In [None]:
def perturb_location(data, locations, n=10):
    """Build location-swapped variants of Chinese sentences for INV tests.

    Each sentence is analysed with the global ``HanLP`` pipeline; fine-grained
    tokens that exactly equal a 'LOCATION' entity string from its 'ner/msra'
    output are replaced with a random element of ``locations``.

    Args:
        data: Iterable of sentences.
        locations: Candidate replacement location names.
        n: Number of variants per sentence, including the original.

    Returns:
        One entry per sentence that contains a location: a list of ``n``
        strings where element 0 is the original (its tokens re-joined) and
        elements 1..n-1 are perturbations. Sentences without a location are
        skipped, and the skip count is printed.
    """
    result = []
    skip = 0
    for text in data:
        # One token list per variant; slot 0 ends up holding the unchanged tokens.
        temp_data = [[] for _ in range(n)]
        segment_result = HanLP(text, tasks='ner*')
        temp_loc = []
        # Collect surface strings of entities labelled 'LOCATION'
        # (each NER item is indexed as (text, label, ...)).
        for loc in segment_result['ner/msra']:
            if loc[1] == 'LOCATION':
                temp_loc.append(loc[0])
        if temp_loc == []:
            skip += 1
            continue
        # NOTE(review): only tokens exactly equal to a LOCATION string are swapped;
        # an entity split across several fine-grained tokens is left unchanged.
        for i in segment_result['tok/fine']:
            # Append the token (or a random replacement location) to every variant...
            temp_data = [lst + [random.choice(locations)] if i in temp_loc else lst + [i] for lst in temp_data]
            # ...then restore the original token in slot 0.
            temp_data[0][-1] = i
        temp_data = [''.join(lst) for lst in temp_data]
        result.append(temp_data)
    print("skipped: ", skip)
    return result

## Negation

In [None]:
# mid = ['normal', 'ok', 'safe', 'dangerous', 'acceptable', 'reasonable', 'proper', 'wrong', 'healthy', 'important']

# mid2 = mid + ['legal', 'awkward', 'socially acceptable']

# Chinese translations of the English judgement adjectives above.
mid = ['正常的', '可以的', '安全的', '危险的', '可接受的', '合理的', '适当的', '错误的', '健康的', '重要的']

mid2 = mid + ['合法的', '尴尬的', '社会上可接受的']

In [None]:
# Exploration: masked-LM suggestions for the activity slot of the (English) template.
print(', '.join(editor.suggest('Is it {mid} to {mask} in {country}?', mid=mid2)[:100]))

In [None]:
# things = ['work', 'vote', 'travel', 'marry', 'drive', 'study', 'protest', 'campaign', 'fight', 'gamble', 'hunt', 'pray', 'smoke', 'fish', 'murder', 'invest', 'pee', 'march', 'worship', 'volunteer', 'surf', 'shoot', 'dance', 'camp', 'preach', 'spy', 'be gay', 'lie', 'divorce', 'discriminate']
# Chinese translations of the activities above, in the same order.
things = ['工作', '投票', '旅行', '结婚', '驾驶', '学习', '抗议', '竞选', '战斗', '赌博', '狩猎', '祈祷', '吸烟', '钓鱼', '谋杀', '投资', '小便', '游行', '崇拜', '志愿', '冲浪', '射击', '跳舞', '露营', '布道', '间谍', '当同性恋', '撒谎', '离婚', '歧视']

In [None]:
# English negation test: masked-LM suggestions for {mask} in both
# "a person who is {mask}" and "a person who is not {mask}".
tmp = editor.suggest(('How can I become a person who is {mask}', 'How can I become a person who is not {mask}?'))
# Suggestions depend on the language model, so 'differently' may be absent;
# an unconditional list.remove would then raise ValueError.
if 'differently' in tmp:
    tmp.remove('differently')
t = editor.template((
    'How can I become {a:x} person?',
    'How can I become a person who is not {x}?',
    ),
    x=tmp,
    remove_duplicates=True,
    nsamples=1000)

NEG_person_data = t.data
# print(NEG_person_data[:5])
# label 0

In [None]:
# Country names used in the negated-activity template.
# Note: '韩国' and '南韩' both refer to South Korea.
countries = [
    '中国', '美国', '印度', '日本', '德国', '英国', '法国', '意大利', '加拿大', '澳大利亚',
    '俄罗斯', '巴西', '韩国', '墨西哥', '印度尼西亚', '沙特阿拉伯', '伊朗', '土耳其', '阿根廷', '泰国',
    '南非', '埃及', '西班牙', '巴基斯坦', '荷兰', '瑞典', '波兰', '比利时', '马来西亚', '瑞士',
    '尼日利亚', '新加坡', '以色列', '挪威', '南韩', '奥地利', '乌克兰', '智利', '希腊', '葡萄牙',
    '哥伦比亚', '芬兰', '丹麦', '捷克共和国', '罗马尼亚', '新西兰', '爱尔兰', '匈牙利', '菲律宾', '越南'
]


In [None]:
# "Is doing X in <country> <adj>?" vs "Is NOT doing X in <country> <adj>?" — not duplicates.
t = editor.template(
    # ('Is it {mid} to {activity} in {country}?','Is it {mid} not to {activity} in {country}?'),
    ('在{country}，{activity}是{mid}吗？', '在{country}，不{activity}是{mid}吗？'),
                country=countries,
                activity=things,
                mid=mid2,
                remove_duplicates=True,
                nsamples=1000)

NEG_activity_data = t.data
print(NEG_activity_data[:5])
# label 0

In [None]:
# prepare vocab: profession nouns suggested by the masked LM, plus a few generic nouns.
# The printed list is reused rather than recomputed, so what is shown is exactly
# what goes into `professions` (suggestions fill {first_name} from a sampled lexicon).
professions = editor.suggest('{first_name} works as {a:mask}.')[:30]
print(', '.join(professions))
professions += editor.suggest('{first_name} {last_name} works as {a:mask}.')[:30]
professions = list(set(professions))

other_nouns = ['player', 'person', 'friend', 'kid', 'candidate']
nouns = list(set(professions + other_nouns))

In [None]:
# Hand-translated Chinese nouns; replaces the English `nouns` built above.
# Note: '记者' appears twice (two English source words share this translation).
nouns = [
    '摄影师', '实习生', '历史学家', '教育者', '助理', '活动家', '会计师', '女服务员', '编辑', '人',
    '行政人员', '高管', '音乐家', '翻译', '制片人', '演员', '模特', '建筑师', '作者', '艺术家',
    '企业家', '护士', '记者', '女演员', '运动员', '候选人', '分析师', '律师', '记者', '投资者',
    '秘书', '陪同', '调查员', '经济学家', '代理人', '作家', '组织者', '工程师', 'DJ', '朋友',
    '孩子'
]

In [None]:
# "What should a <noun> worry about?" vs "... not worry about?" — not duplicates.
t = editor.template((
    # 'What are things {a:noun} should worry about?',
    # 'What are things {a:noun} should not worry about?',
    '{noun}应该担心什么事情？',
    '{noun}不应该担心什么事情？'
),
                noun=nouns,
                remove_duplicates=True,
                nsamples=1000)

NEG_worry_data = t.data
print(NEG_worry_data[:5])
# label 0

In [None]:
# antonyms = [('progressive', 'conservative'),('religious', 'secular'),('positive', 'negative'),('defensive', 'offensive'),('rude',  'polite'),('optimistic', 'pessimistic'),('stupid', 'smart'),('negative', 'positive'),('unhappy', 'happy'),('active', 'passive'),('impatient', 'patient'),('powerless', 'powerful'),('visible', 'invisible'),('fat', 'thin'),('bad', 'good'),('cautious', 'brave'), ('hopeful', 'hopeless'),('insecure', 'secure'),('humble', 'proud'),('passive', 'active'),('dependent', 'independent'),('pessimistic', 'optimistic'),('irresponsible', 'responsible'),('courageous', 'fearful')]
# (adjective, antonym) pairs, translated from the English list above in the same order.
antonyms = [
('进步的', '保守的'),
('宗教的', '世俗的'),
('积极的', '消极的'),
('防御性的', '攻击性的'),
('粗鲁的', '礼貌的'),
('乐观的', '悲观的'),
('愚蠢的', '聪明的'),
('消极的', '积极的'),
('不快乐的', '快乐的'),
('活跃的', '被动的'),
('不耐烦的', '耐心的'),
('无力的', '有力的'),
('可见的', '不可见的'),
('胖的', '瘦的'),
('坏的', '好的'),
('谨慎的', '勇敢的'),
('有希望的', '无希望的'),
('不安全的', '安全的'),
('谦虚的', '骄傲的'),
('被动的', '活跃的'),
('依赖的', '独立的'),
('悲观的', '乐观的'),
('不负责任的', '负责任的'),
('勇敢的', '恐惧的')
]

In [None]:
# "Become an X person" vs "become a person who is not <antonym of X>" (and the
# mirrored pair) — expected duplicates. unroll=True makes every filled template
# tuple its own example.
t = editor.template(
# [(
#     'How can I become {a:x[0]} person?',
#     'How can I become a person who is not {x[1]}?',
#     ),
#     (
#     'How can I become {a:x[1]} person?',
#     'How can I become a person who is not {x[0]}?',
#     ),
# ],
  [(
    '我怎么样才能成为一个{x[0]}人？',
    '我怎么样才能成为一个不{x[1]}人？',
    ),
    (
    '我该如何塑造自己成为一个{x[1]}人？',
    '我该怎么做才能避免成为一个{x[0]}人？',
    ),
  ],

    unroll=True,
    x=antonyms,
    remove_duplicates=True,
    nsamples=1000)

NEG_antonym_data = t.data
# label 1

In [None]:
# Inspect one generated antonym example.
print(NEG_antonym_data[1])

## SRL

### "Who do X think is …?" vs. "Who is the … according to X?"

In [None]:
# print(', '.join(editor.suggest('Who is the best {mask} in the world?')))

In [None]:
# thing = ['chef', 'boxer', 'player', 'footballer', 'athlete', 'rapper', 'actor', 'singer', 'cook', 'magician', 'coach', 'cyclist', 'wrestler', 'drummer', 'musician', 'quarterback', 'hacker', 'baker', 'fighter', 'journalist', 'teacher', 'doctor', 'gamer', 'husband', 'DJ', 'person', 'man', 'woman', 'surgeon', 'comedian', 'trainer', 'programmer', 'guitarist', 'goalkeeper']

In [None]:
# print(', '.join(editor.suggest('Who do {mask} think is the the best {thing} in the world?', thing=thing)))

In [None]:
# subjects = ['you', 'people', 'readers', 'guys', 'fans', 'experts', 'scientists', 'Americans', 'students', 'men', 'voters', 'authors', 'conservatives', 'women', 'Canadians', 'analysts', 'critics', 'judges', 'artists', 'researchers', 'liberals', 'historians', 'Australians', 'journalists', 'Republicans', 'coaches', 'parents', 'kids', 'economists', 'reporters', 'consumers', 'veterans', 'doctors']

In [None]:
# print(', '.join(editor.suggest('Who do {subjects} think is the the {mask} {thing} in the world?', thing=thing, subjects=subjects)[:50]))

In [None]:
# best = ['best', 'greatest', 'worst', 'top', 'smartest', 'strongest', 'finests', 'happiest', 'coolest', 'richest', 'leading', 'brightest', 'premier', 'ultimate', 'dominant']

In [None]:

# Chinese translations of the commented English lexicons above. Some entries
# collapse to the same translation: '厨师' and '运动员' appear twice in `thing`,
# '记者' twice in `subjects`, and '最好的' / '最聪明的' twice in `best`.
thing = [
    '厨师', '拳击手', '运动员', '足球运动员', '运动员', '说唱歌手', '演员', '歌手', '厨师', '魔术师', '教练', '自行车手',
    '摔跤手', '鼓手', '音乐家', '四分卫', '黑客', '面包师', '斗士', '记者', '教师', '医生', '玩家', '丈夫', 'DJ',
    '人', '男人', '女人', '外科医生', '喜剧演员', '训练师', '程序员', '吉他手', '守门员'
]

subjects = [
    '你', '人们', '读者', '大家', '粉丝', '专家', '科学家', '美国人', '学生', '男性', '选民', '作者', '保守派',
    '女性', '加拿大人', '分析师', '评论家', '法官', '艺术家', '研究人员', '自由派', '历史学家', '澳大利亚人', '记者',
    '共和党人', '教练', '父母', '孩子', '经济学家', '记者', '消费者', '退伍军人', '医生'
]

best = [
    '最好的', '最伟大的', '最差的', '顶尖的', '最聪明的', '最强的', '最好的', '最幸福的', '最酷的', '最富有的',
    '领先的', '最聪明的', '首屈一指的', '最终的', '主导的'
]

In [None]:
# "Who do X think is the best Y?" vs "Who is the best Y according to X?" — duplicates.
t = editor.template((
    # 'Who do {subjects} think is the {best} {thing} in the world?',
    # 'Who is the {best} {thing} in the world according to {subjects}?'
    '{subjects}认为世界上谁是{best}{thing}？',
    '根据{subjects}的看法，世界上谁是{best}{thing}？'
),
    subjects=subjects,
    best=best,
    thing=thing,
    remove_duplicates=True,
    nsamples=1000)

SRL_best_data = t.data
print(SRL_best_data[:5])
# label = 1

### Order doesn't matter for comparison

In [None]:
# Exploration (English prototype): nouns that fit "Are {mask} smaller than X?",
# then comparative adjectives between them. The suggestion is computed once and
# reused, so the printed list is exactly the one stored in `things`.
things = editor.suggest('Are {mask} smaller than {a}?', a=['bananas', 'dogs', 'cars', 'cats', 'elephants'])[:100]
print(', '.join([str(x) for x in things]))
print(', '.join([str(x) for x in editor.suggest('Are {a} {mask} than {a2}?', a=things)][:100]))
comp = ['better', 'worse', 'cheaper', 'bigger', 'louder', 'longer', 'larger', 'smaller', 'warmer', 'colder', 'thicker', 'lighter', 'heavier']

In [None]:
# Show the English noun suggestions (the Chinese list below replaces `things`).
things = editor.suggest('Are {mask} smaller than {a}?',a=['bananas', 'dogs', 'cars', 'cats', 'elephants'] )[:100]
print(things)

In [None]:
# Chinese nouns and comparatives for the comparison templates.
# Note: '狮子' appears twice in `things` and '更大的' twice in `comp`
# (bigger/larger share one translation).
things = [
    '人类', '猫', '你', '狗', '人们', '老鼠', '猪', '鸟', '羊', '牛', '鼠', '鸡', '鱼', '熊', '我们', '大象',
    '兔子', '狮子', '猴子', '他们', '蛇', '蜜蜂', '蜘蛛', '蝙蝠', '小狗', '海豚', '婴儿', '小猫', '孩子',
    '青蛙', '蚂蚁', '蝴蝶', '昆虫', '乌龟', '树', '鸭', '鲸鱼', '机器人', '动物', '虫子', '小孩', '螃蟹',
    '胡萝卜', '龙', '蚊子', '汽车', '鲨鱼', '恐龙', '马', '老虎', '狼', '灵长类动物', '牲畜', '男人',
    '山羊', '黑猩猩', '鹿', '类人猿', '球', '爬行动物', '啮齿动物', '蠕虫', '哺乳动物', '苍蝇', '苹果',
    '小马', '蘑菇', '植物', '海豹', '土豆', '蜱', '女人', '狮子', '吸血鬼', '电脑', '石头', '宝可梦', '双胞胎',
    '事物', '蛋', '宠物', '矮人', '公牛', '印第安人', '男孩', '精灵', '小牛', '玩具', '甲虫', '香蕉',
    '花', '巨魔', '房屋', '女孩', '汉堡', '豆类', '钻石', '男士', '硬币', '恶魔'
]


comp = [
    '更好的', '更差的', '更便宜的', '更大的', '更响的', '更长的', '更大的', '更小的',
    '更暖和的', '更冷的', '更厚的', '更轻的', '更重的'
]

In [None]:
# Comparison questions whose meaning does not depend on the order in which the
# two items are mentioned; every pair is expected to be a duplicate (label 1).
t = editor.template([
    # (
    # 'Are {t1} {comp} than {t2}?',
    # 'What is {comp}, {t2} or {t1}?'
    # ),
    # (
    # 'Are {t1} {comp} than {t2}?',
    # 'What is {comp}, {t1} or {t2}?',
    # )
    (
        '{t1}比{t2}{comp}吗？',
        '哪个{comp}，{t2}还是{t1}？'
    ),
    # The swapped pair ('{t1}比{t2}{comp}吗？', '{t2}比{t1}{comp}吗？') is excluded:
    # "Is A X-er than B?" and "Is B X-er than A?" ask opposite questions, so it
    # must not be labelled as a duplicate. It belongs in a separate
    # "order matters" test with label 0.
    (
        '{t1}比{t2}{comp}吗？',
        '哪个{comp}，{t1}还是{t2}？',
    )
]
    ,
    t = things,
    comp = comp,
    remove_duplicates=True,
    nsamples=1000)

# NOTE(review): unlike NEG_antonym_data this does not pass unroll=True, so each
# element of t.data may be a list holding one pair per template rather than a
# single pair — confirm this matches how the data is consumed downstream.
SRL_comp_data = t.data
print(SRL_comp_data[:5])
# label = 1

### Order doesn't matter for symmetric relations

In [None]:
# Exploration: masked-LM suggestions for relations between two people.
print(', '.join(editor.suggest('Is {first_name1} {mask} to {first_name2}?', remove_duplicates=True)[:100]))
print()
print(', '.join(editor.suggest('Is {first_name1} {mask} {first_name2}?', remove_duplicates=True)[:100]))

In [None]:
# symmetric = ['dating', 'married to', 'close to', 'engaged to', 'connected to', 'married to', 'friends with', 'related to', 'an acquaintance of']
# Symmetric relations ("A and B are ..."), phrased to follow "{name1}和{name2}".
symmetric = [
    '在约会', '亲近', '订婚了', '有联络', '已婚了', '是朋友', '有关系', '是熟人'
]

In [None]:
# Full Chinese names for the relation templates. The list contains repeats
# (e.g. '李明', '张敏', '王磊' appear more than once).
names = [
    '李伟', '王伟', '王芳', '李娜', '张伟', '刘伟', '张敏', '李静', '王静', '王丽',
    '李强', '张静', '李敏', '王敏', '刘洋', '王勇', '王艳', '李军', '刘杰', '李娟',
    '张艳', '李明', '王丹', '李丽', '张磊', '王军', '王磊', '刘军', '李勇', '张勇',
    '王娜', '李杰', '张杰', '李艳', '张丽', '王强', '张华', '李华', '刘娟', '张涛',
    '李涛', '王涛', '刘丽', '王华', '李明', '杨军', '张军', '李霞', '张鹏', '李军',
    '刘敏', '李刚', '张健', '王健', '李峰', '张峰', '李哲', '王斌', '李斌', '李鹏',
    '杨杰', '张斌', '李俊', '张俊', '王峰', '刘斌', '张涛', '王超', '李青', '刘芳',
    '李志', '王志', '刘志', '刘强', '王霞', '张超', '李超', '王鹏', '王华', '张玲',
    '李玲', '刘玲', '杨兵', '张兵', '李兵', '张敏', '王敏', '李敏', '刘敏', '李青',
    '王青', '张青', '刘青', '李磊', '王磊', '刘磊', '张磊', '张勇', '李勇', '杨勇'
]


In [None]:
# Swapping the two people in a symmetric relation keeps the meaning — duplicates.
t = editor.template((
    # 'Is {first_name1} {s} {first_name2}?',
    # 'Is {first_name2} {s} {first_name1}?',
    '{name1}和{name2}{s}吗?',
    '{name2}和{name1}{s}吗?'
),
    name = names,
    s = symmetric,
    remove_duplicates=True,
    nsamples=1000)

SRL_symrel_data = t.data
print(SRL_symrel_data[:5])
# label = 1

### Order matters for asymmetric relations

In [None]:
# asymmetric = ['hurting', 'lying to', 'loyal to', 'faithful to', 'proposing to', 'indebted to', 'abusive to', 'using', 'expecting', 'beating', 'punching', 'raising', 'poisoning', 'protecting', 'kidnapping']

# Asymmetric relations (A does X to B); commented-out entries do not fit the
# "{A}{s}{B}了吗" template, and extra verbs were added at the end.
asymmetric = ['伤害',
              # '对...说谎',
              '忠诚于',
              '忠实于',
              # '求婚',
              # '欠债',
              '滥用',
              '利用',
              # '期望',
              '打',
              '拳打',
              '抚养',
              '投毒',
              '保护',
              '绑架',
              # additional
              '质疑',
              '取笑',
              '赞同',
              '鼓励',
              '依赖',
              '欺骗',
              ]

In [None]:
# Swapping the two people in an asymmetric relation changes the meaning.
# `first_name=names` makes {first_name1}/{first_name2} draw from the full names above.
t = editor.template((
    '{first_name1}{s}{first_name2}了吗?',
    '{first_name2}{s}{first_name1}了吗?',
),
    first_name=names,
    s = asymmetric,
    remove_duplicates=True,
    nsamples=1000)

SRL_asymrel_data = t.data
# TODO label = 0? (but =1 in the checklist code)


In [None]:
# Peek at the first few asymmetric-relation pairs.
print(SRL_asymrel_data[:5])

### More traditional SRL

In [None]:
# Exploration: objects that can be bought, then an English object list.
print(', '.join(editor.suggest('Did John buy the {mask}?', remove_duplicates=True)[:100]))
obj = ['farm', 'house', 'property', 'company', 'land', 'ticket', 'newspaper', 'book', 'island', 'estate', 'ranch', 'boat', 'horse', 'paper', 'business', 'gun', 'game', 'factory', 'castle', 'painting', 'rifle', 'car', 'school', 'building']

In [None]:
# Exploration: verbs that fit "Did John {mask} the {obj}?".
print(', '.join(editor.suggest('Did John {mask} the {obj}?', obj=obj, remove_duplicates=True)[:100]))

In [None]:
# import pattern
# import pattern.en
# verbs = ['buy', 'purchase', 'sell', 'leave', 'own', 'take', 'keep', 'want', 'lose', 'destroy', 'inherit', 'find', 'use', 'need', 'receive', 'return', 'like', 'enjoy', 'abandon', 'manage', 'remember', 'miss', 'move', 'seize', 'steal']
# a = pattern.en.tenses('stolen')[0]
# verbs = [(v, pattern.en.conjugate(v, *a)) for v in verbs]
# verbs[3] = ('leave', 'left')
# verbs

In [None]:
# Chinese objects and transitive verbs for the active/passive templates
# (translations of the English lists above).
obj = [
    '农场', '房子', '财产', '公司', '土地', '票', '报纸', '书', '岛屿', '庄园', '牧场', '船', '马', '纸', '生意',
    '枪', '游戏', '工厂', '城堡', '画', '步枪', '车', '学校', '建筑'
]
verbs = [
    '买', '购买', '卖', '离开', '拥有', '拿', '保持', '想要', '丢失', '摧毁', '继承', '找到', '使用', '需要', '接收',
    '归还', '喜欢', '享受', '放弃', '管理', '记得', '想念', '移动', '抓住', '偷'
]

traditional SRL: active / passive swap

In [None]:
# Active sentence vs. its correct passive (被) rewrite — duplicates.
t = editor.template((
    # 'Did {first_name} {verb[0]} the {obj}?',
    # 'Was the {obj} {verb[1]} by {first_name}?'
    '{last_name}{first_name}{verb}{obj}了吗？',
    '{obj}被{last_name}{first_name}{verb}了吗？'
),
    first_name=first_names,
    last_name=last_names,
    verb=verbs,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_apswap_data = t.data
print(SRL_apswap_data[:5])
# label = 1

traditional SRL: wrong active / passive swap

In [None]:
# Active sentence vs. a passive with the roles swapped (person acted on by the
# object) — not duplicates.
t = editor.template((
    # 'Did {first_name} {verb[0]} the {obj}?',
    # 'Was {first_name} {verb[1]} by the {obj}?'
    '{last_name}{first_name}{verb}{obj}了吗？',
    '{last_name}{first_name}被{obj}{verb}了吗？'
),
    first_name=first_names,
    last_name=last_names,
    verb=verbs,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_w_apswap_data = t.data
# label = 0

traditional SRL: active / passive swap with people

In [None]:
# Verbs relating two people, for the person-to-person active/passive tests.
pverb = [
    '爱', '恨', '喜欢', '记得', '认出', '信任', '应得', '理解', '责备', '不喜欢', '更喜欢', '跟随', '注意', '伤害', '打扰',
    '支持', '相信', '接受', '攻击'
]

In [None]:
# print(', '.join(editor.suggest('Does {first_name} {mask} {first_name2}?', remove_duplicates=True)[:100]))
# pverb = ['love', 'hate', 'like', 'remember', 'recognize', 'trust', 'deserve', 'understand', 'blame', 'dislike', 'prefer', 'follow', 'notice', 'hurt', 'bother', 'support', 'believe', 'accept', 'attack']
# a = pattern.en.tenses('stolen')[0]
# pverb = [(v, pattern.en.conjugate(v, *a)) for v in pverb]

# "Did A <verb> B?" vs. its correct passive "Was B <verb>-ed by A?" — duplicates.
# NOTE(review): `obj` is not referenced by these templates; presumably ignored.
t = editor.template((
    # 'Does {first_name} {verb[0]} {first_name2}?',
    # 'Is {first_name2} {verb[1]} by {first_name}?',
    '{last_name1}{first_name1}{verb}{last_name2}{first_name2}了吗？',
    '{last_name2}{first_name2}被{last_name1}{first_name1}{verb}了吗？'
),
    first_name=first_names,
    last_name=last_names,
    verb=pverb,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_apswap_ppl_data = t.data
# label = 1

traditional SRL: wrong active / passive swap with people

In [None]:
# pverb = ['love', 'hate', 'like', 'remember', 'recognize', 'trust', 'deserve', 'understand', 'blame', 'dislike', 'prefer', 'follow', 'notice', 'hurt', 'bother', 'support', 'believe', 'accept', 'attack']
# a = pattern.en.tenses('stolen')[0]
# pverb = [(v, pattern.en.conjugate(v, *a)) for v in pverb]
# "Did A <verb> B?" vs. "Was A <verb>-ed by B?" — the roles are swapped, so not duplicates.
# NOTE(review): `obj` is not referenced by these templates; presumably ignored.
t = editor.template((
    # 'Does {first_name} {verb[0]} {first_name2}?',
    # 'Is {first_name} {verb[1]} by {first_name2}?',
    '{last_name1}{first_name1}{verb}{last_name2}{first_name2}了吗？',
    '{last_name1}{first_name1}被{last_name2}{first_name2}{verb}了吗？'
),
    first_name=first_names,
    last_name=last_names,
    verb=pverb,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_w_apswap_ppl_data = t.data
# label = 0 (roles swapped; same label as the object version SRL_w_apswap_data)

# Inference


In [None]:
def response_from_generate(model, messages):
    """Ask ``model`` for a one-token answer and map it to a label.

    Formats ``messages`` with the global ``tokenizer``'s chat template (with
    a generation prompt), generates exactly one new token on ``device`` and
    decodes it.

    Returns:
        1 for '是' (yes), 0 for '否' (no), None for any other output.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # Pass the attention mask explicitly; otherwise generate() has to infer it,
    # which is unreliable when the pad token equals the EOS token.
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.get("attention_mask"),
        max_new_tokens=1,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    # NOTE(review): encoder-decoder models (e.g. aya) do not echo the prompt,
    # so this slice would drop their answer — confirm before using them here.
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]

    # Strip whitespace so outputs such as ' 是' or '是\n' still map to a label.
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    output_mapping = {'是' : 1, '否' : 0}
    # output_mapping = {'Yes': 1, 'No': 0}

    return output_mapping.get(response, None)


def response_from_forward(model, messages, label_token_ids=(32, 33)):
    """Classify using next-token logits instead of generation.

    Formats ``messages`` with the global ``tokenizer``'s chat template, runs
    a single forward pass on ``device`` and compares the final-position
    logits of the "yes" and "no" answer tokens.

    Args:
        model: Causal language model.
        messages: Chat messages to score.
        label_token_ids: ``(yes_id, no_id)`` vocabulary ids. The default
            ``(32, 33)`` is the original hard-coded choice, documented as
            'A' (Yes) / 'B' (No); ``(59603, 59616)`` was used for another
            tokenizer.

    Returns:
        1 if the "yes" token has the higher logit, else 0.

    NOTE(review): the current prompts ask for '是'/'否', not 'A'/'B'. Pass the
    ids of those tokens for the model in use before relying on forward-mode results.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # Pure inference: no_grad avoids retaining activations for a backward pass
    # that never happens.
    with torch.no_grad():
        output = model(input_ids=model_inputs.input_ids,
                       attention_mask=model_inputs.get("attention_mask"))

    yes_id, no_id = label_token_ids
    candidate_logits = output.logits[0, -1, [yes_id, no_id]]
    # argmax 0 -> the "yes" token wins -> label 1; argmax 1 -> label 0.
    response = 1 - torch.argmax(candidate_logits).item()

    return response


def inference(model, data, inference_mode='generate'):
    """Run an invariance (INV) test over perturbed question pairs.

    Each element of ``data`` is a list of question pairs whose first pair is
    the unperturbed original. The model's own prediction on that first pair
    is the reference ("gold") label; the predictions on the remaining pairs
    should agree with it.

    inference_mode='generate' uses ``response_from_generate`` (one generated
    token); 'forward' uses ``response_from_forward`` (next-token logits).

    Returns:
        ``(gold_labels, pred_labels)``, index-aligned: ``gold_labels[k]`` is the
        reference label of the k-th kept example and ``pred_labels[k]`` the
        list of labels for its perturbations. An example is dropped from BOTH
        lists if its reference answer is unparsable or none of its
        perturbations gave a parsable answer.

    Raises:
        ValueError: if ``inference_mode`` is unknown.
    """
    responders = {'generate': response_from_generate, 'forward': response_from_forward}
    if inference_mode not in responders:
        raise ValueError(f'unknown inference mode: {inference_mode!r}')
    respond = responders[inference_mode]

    # system_message = "Are the two questions are paraphrase of each other? Please only respond with A (Yes), B (No)."
    system_message = "以下两个问题的意思相同吗？请只用一个字 '是' 或 '否' 回答。"
    # system_message = "Do the following two questions have the same meaning? Respond with 'Yes' or 'No'."

    gold_labels, pred_labels = [], []

    for pairs in tqdm(data):
        reference = None
        sentence_labels = []
        for i, pair in enumerate(pairs):
            # str(): the original pair of Perturb data built from parsed_qs holds spaCy Docs.
            messages = [
                # {"role": "system", "content": system_message},
                # {"role": "user", "content": pair[0] + '\n' + pair[1]}
                {"role": "user", "content": system_message + ' ' + '问题一：' + str(pair[0]) + ' 问题二： ' + str(pair[1])}
            ]

            response = respond(model, messages)

            if i == 0:
                if response is None:
                    # No usable reference answer: skip the whole example.
                    break
                reference = response
            elif response is not None:
                sentence_labels.append(response)

        # Append gold and predictions together. Appending the gold label before
        # knowing whether any perturbation was usable would shift every later
        # gold label out of alignment with pred_labels in evaluate().
        if reference is None or not sentence_labels:
            continue
        gold_labels.append(reference)
        pred_labels.append(sentence_labels)

    return gold_labels, pred_labels

In [None]:
def evaluate(gold_labels, pred_labels):
    """
    Accuracy of an invariance test.

    Every prediction in ``pred_labels[k]`` is scored against the single gold
    label ``gold_labels[k]`` of the same group.

    Args:
        gold_labels: one gold label per test group.
        pred_labels: one list of predictions per test group.

    Returns:
        The fraction of predictions that match their group's gold label, as
        computed by ``accuracy_score``.

    Raises:
        ValueError: if the two lists have different lengths. That means the
            per-group alignment was lost upstream, and any score would pair
            predictions with the wrong gold labels.
    """
    if len(gold_labels) != len(pred_labels):
        raise ValueError(
            f'gold_labels ({len(gold_labels)}) and pred_labels ({len(pred_labels)}) '
            'must contain exactly one entry per test group'
        )

    y_true, y_pred = [], []
    for gold, sentence_labels in zip(gold_labels, pred_labels):
        y_true.extend([gold] * len(sentence_labels))
        y_pred.extend(sentence_labels)

    return accuracy_score(y_true, y_pred)


In [None]:
def inference_MFT(model, data, inference_mode='generate', label=None, fewshot=False, cot=False):
    """
    Run a minimum functionality test (MFT).

    Every question pair in ``data`` shares the same expected answer ``label``.

    Args:
        model: model handed to the response helpers.
        data: iterable of question pairs ``(q1, q2)``.
        inference_mode: 'generate' uses ``response_from_generate`` (written
            response); 'forward' uses ``response_from_forward`` (output logits).
        label: gold label assigned to every pair that receives a response.
        fewshot: if True, prepend four solved demonstrations as earlier
            user/assistant turns.
        cot: if True, use a prompt that asks the model to consider each
            question's topic and details (people, actions, places) first.

    Returns:
        ``(gold_labels, pred_labels)`` of equal length. Pairs whose response
        is None are dropped.

    Raises:
        ValueError: if ``inference_mode`` is not 'generate' or 'forward'.
    """
    if inference_mode not in ('generate', 'forward'):
        raise ValueError(f'unknown inference mode: {inference_mode!r}')

    # Prompt (translated): "Do the following two questions mean the same thing?
    # Please answer with a single character, '是' (yes) or '否' (no)."
    # The cot variant adds the hint: "think about each sentence's topic and
    # specific details, e.g. people, actions, places".
    if cot:
        system_message = "以下两个问题的意思相同吗？（提示：思考各个句子的主题以及具体的细节，例如人物、动作、地点等）请只用一个字 '是' 或 '否' 回答。"
    else:
        system_message = "以下两个问题的意思相同吗？请只用一个字 '是' 或 '否' 回答。"

    def user_turn(q1, q2):
        # Demonstrations and real queries share exactly this layout.
        return {"role": "user", "content": system_message + ' ' + '问题一：' + q1 + ' 问题二： ' + q2}

    # Few-shot demonstrations are built once, not on every iteration.
    # Answers go in 'assistant' turns because they stand for the model's own
    # earlier replies. A 'system' role would be rendered by the chat template
    # as a system instruction rather than as an answer.
    examples = []
    if fewshot:
        demonstrations = [
            # active vs. passive voice -> same meaning
            ('宋宁购买土地了吗？', '土地被宋宁购买了吗？', '是'),
            # paraphrase -> same meaning
            ('作为学生，我应该如何安排时间？', '学生怎样安排时间比较好？', '是'),
            # antonym (proud vs. humble) -> different meaning
            ('我该如何塑造自己成为一个骄傲的人？', '我该怎么做才能成为一个谦虚的人？', '否'),
            # negation -> different meaning
            ('投资者应该担心什么事情？', '投资者不应该担心什么事情？', '否'),
        ]
        for q1, q2, answer in demonstrations:
            examples.append(user_turn(q1, q2))
            examples.append({"role": "assistant", "content": answer})

    pred_labels = []
    for pair in tqdm(data):
        messages = examples + [user_turn(pair[0], pair[1])]

        if inference_mode == 'generate':
            response = response_from_generate(model, messages)
        else:
            response = response_from_forward(model, messages)

        # None presumably means the reply could not be mapped to a label; skip the pair.
        if response is None:
            continue
        pred_labels.append(response)

    gold_labels = [label] * len(pred_labels)
    return gold_labels, pred_labels

def evaluate_MFT(gold_labels, pred_labels):
    """
    Accuracy of a minimum functionality test.

    Args:
        gold_labels: expected label for each pair.
        pred_labels: model label for each pair, aligned with ``gold_labels``.

    Returns:
        The fraction of positions where the two lists agree (``accuracy_score``).
    """
    score = accuracy_score(y_true=gold_labels, y_pred=pred_labels)
    return score

# Run the Tests


### Robustness

In [None]:
# gold_labels, pred_labels = inference(model, ROB_typo_data, inference_mode='forward')
# print(f'Accuracy: {evaluate(gold_labels, pred_labels):.2f}')

In [None]:
# gold_labels, pred_labels = inference(model, ROB_contra_data, inference_mode='forward')
# print(f'Accuracy: {evaluate(gold_labels, pred_labels):.2f}')

In [None]:
# gold_labels, pred_labels = inference(model, ROB_paraphrase_prod_data, inference_mode='forward')
# print(f'Accuracy: {evaluate(gold_labels, pred_labels):.2f}')

In [None]:
# gold_labels, pred_labels = inference(model, ROB_paraphrase_each_data, inference_mode='forward')
# print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
# INV test on ROB_typo_data_zh: perturbed pairs should keep each group's original answer.
gold_labels, pred_labels = inference(
    model, ROB_typo_data_zh, inference_mode='generate',
)
print('Accuracy: {:.4f}'.format(evaluate(gold_labels, pred_labels)))

### NER

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_last_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_last_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, NER_loc_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, NER_names_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, NER_num_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

### Negation

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_person_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_activity_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_worry_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_antonym_data[:1000], inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

### SRL

In [None]:
# MFT: every pair in SRL_best_data is expected to get label 1.
gold_labels, pred_labels = inference_MFT(
    model, SRL_best_data, label=1, inference_mode='generate',
)
print('Accuracy: {:.4f}'.format(evaluate_MFT(gold_labels, pred_labels)))

In [None]:
# Each SRL_comp_data entry holds three question pairs. Score the i-th pair of
# every entry as its own MFT (label 1), and tag each score with i so the three
# printed lines can be told apart.
for i in range(3):
    variant_pairs = [pairs[i] for pairs in SRL_comp_data]
    gold_labels, pred_labels = inference_MFT(model, variant_pairs, inference_mode='generate', label=1)
    print(f'Accuracy (variant {i}): {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_symrel_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_asymrel_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_ppl_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_ppl_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

### Test Few-Shot

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_worry_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_antonym_data[:1000], inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_ppl_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_asymrel_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_ppl_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_best_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_last_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_last_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_person_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_activity_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
# Few-shot version of the SRL_comp test. Score the i-th pair of every entry as
# its own MFT (label 1), and tag each score with i so the three printed lines
# can be told apart.
for i in range(3):
    variant_pairs = [pairs[i] for pairs in SRL_comp_data]
    gold_labels, pred_labels = inference_MFT(model, variant_pairs, inference_mode='generate', label=1, fewshot=True)
    print(f'Accuracy (variant {i}): {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
# Few-shot MFT: every pair in SRL_symrel_data is expected to get label 1.
gold_labels, pred_labels = inference_MFT(
    model, SRL_symrel_data, label=1, fewshot=True, inference_mode='generate',
)
print('Accuracy: {:.4f}'.format(evaluate_MFT(gold_labels, pred_labels)))

### Test CoT-Style Reasoning Hint

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_ppl_data, inference_mode='generate', label=0, cot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_ppl_data, inference_mode='generate', label=0, cot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')