# Environmental Setup

In [None]:
!pip install checklist
!jupyter nbextension install --py --user checklist.viewer
!jupyter nbextension enable --py --user checklist.viewer

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM

import checklist
from checklist.editor import Editor
from checklist.perturb import Perturb
from checklist.test_types import INV
import csv
import spacy
import numpy as np
import itertools

from tqdm import tqdm
from sklearn.metrics import accuracy_score

# Model Setup

In [None]:
# Need to login to Hugging Face to download the Gemma model
!pip install huggingface_hub
from huggingface_hub import notebook_login
notebook_login()

In [None]:
def load_model_and_tokenizer(name="qwen"):
    """Load a Hugging Face chat model and its tokenizer by short name.

    Args:
        name: One of "qwen", "aya", "yi" or "gemma".

    Returns:
        A ``(model, tokenizer)`` tuple. "aya" is loaded with
        ``AutoModelForSeq2SeqLM``; every other name is loaded with
        ``AutoModelForCausalLM``.

    Raises:
        AssertionError: If ``name`` is not one of the known short names.
    """
    checkpoints = {
        "qwen": "Qwen/Qwen1.5-7B-Chat",
        "aya": "CohereForAI/aya-101",
        "yi": "01-ai/Yi-6B-Chat",
        "gemma": "google/gemma-2b-it",
    }
    assert name in checkpoints, "unknown model"
    checkpoint = checkpoints[name]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # aya is the only encoder-decoder checkpoint in the table.
    model_cls = AutoModelForSeq2SeqLM if name == 'aya' else AutoModelForCausalLM
    model = model_cls.from_pretrained(checkpoint, torch_dtype="auto")

    return model, tokenizer


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Load the checkpoint under test. The inference helpers defined later read the
# module-level `tokenizer` and `device` created in this cell.
model, tokenizer = load_model_and_tokenizer(name="gemma")

# Move the model weights to the selected device.
model = model.to(device)

# Create Dataset

In [None]:
editor = checklist.editor.Editor()
editor.tg

In [None]:
nlp = spacy.load('en_core_web_sm')

## Load Dataset

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!ls ./drive/MyDrive

In [None]:
# Read the question-pair TSV. The header row is skipped and only the fields from
# index 3 onward are kept; they are unpacked as (question 1, question 2, label).
qs = []
labels = []
all_questions = set()
# `with` guarantees the file handle is closed even if parsing fails part-way.
with open('./drive/MyDrive/quora_duplicate_questions.tsv') as tsv_file:
    next(tsv_file, None)  # skip the header row
    for x in tsv_file:
        try:
            q1, q2, label = x.strip().split('\t')[3:]
        except ValueError:
            # The row does not have exactly three trailing fields: show it and
            # skip it. Only the unpacking error is caught so that real problems
            # (and KeyboardInterrupt) still surface.
            print(x)
            continue
        all_questions.add(q1)
        all_questions.add(q2)
        qs.append((q1, q2))
        labels.append(label)
# Labels are read as text; convert them to an integer array.
labels = np.array(labels).astype(int)

In [None]:
print(qs[:5])
print(labels[:5])

In [None]:
# Freeze the set of unique questions into a list so that texts and parses
# share one iteration order.
all_questions = [question for question in all_questions]
# Parse every unique question once with the spaCy pipeline.
parsed_questions = [doc for doc in nlp.pipe(all_questions)]
# Lookup table: raw question text -> its parsed counterpart.
spacy_map = dict(zip(all_questions, parsed_questions))

In [None]:
parsed_qs = [(spacy_map[q[0]], spacy_map[q[1]]) for q in qs]

## Robustness

In [None]:
def wrap_apply_to_each(fn, both=False, *args, **kwargs):
    """Lift a single-question perturbation to operate on a question pair.

    The returned function takes ``(q1, q2)`` and produces pairs in which one
    side is perturbed and the other is the original (as ``str``). With
    ``both=True`` it additionally emits every combination of a perturbed q1
    with a perturbed q2. Pairs with a falsy element are dropped.

    Note: the extra ``*args``/``**kwargs`` accepted here are not forwarded;
    only the arguments given to the returned function reach ``fn``.
    """
    def new_fn(qs, *args, **kwargs):
        first, second = qs
        changed_first = fn(first, *args, **kwargs)
        changed_second = fn(second, *args, **kwargs)
        # Normalise single results to one-element lists.
        if type(changed_first) is not list:
            changed_first = [changed_first]
        if type(changed_second) is not list:
            changed_second = [changed_second]

        candidates = [(variant, str(second)) for variant in changed_first]
        candidates += [(str(first), variant) for variant in changed_second]
        if both:
            candidates += list(itertools.product(changed_first, changed_second))
        return [pair for pair in candidates if pair[0] and pair[1]]
    return new_fn

def wrap_apply_to_both(fn, *args, **kwargs):
    """Lift a single-question perturbation so it perturbs both questions of a pair.

    The returned function takes ``(q1, q2)``, applies ``fn`` to each question,
    and returns every combination of a perturbed q1 with a perturbed q2.
    Combinations with a falsy element are dropped.

    Note: the extra ``*args``/``**kwargs`` accepted here are not forwarded;
    only the arguments given to the returned function reach ``fn``.
    """
    def new_fn(qs, *args, **kwargs):
        first, second = qs
        variants = []
        for question in (first, second):
            changed = fn(question, *args, **kwargs)
            # Normalise single results to one-element lists.
            variants.append(changed if type(changed) is list else [changed])
        return [
            (left, right)
            for left, right in itertools.product(*variants)
            if left and right
        ]
    return new_fn

typos & contractions

In [None]:
ROB_typo_data = Perturb.perturb(qs, wrap_apply_to_each(Perturb.add_typos), nsamples=1000).data

ROB_contra_data = Perturb.perturb(qs, wrap_apply_to_each(Perturb.contractions, both=True), nsamples=1000).data



paraphrase

In [None]:
import re

def me_to_you(text):
    """Rewrite first-person words in ``text`` as second person.

    Whole-word, case-sensitive substitutions applied in this order:
    ``I`` -> ``you``, ``my`` -> ``your``, ``mine`` -> ``yours``.
    """
    swaps = (
        (r'\bI\b', 'you'),
        (r'\bmy\b', 'your'),
        (r'\bmine\b', 'yours'),
    )
    for regex, replacement in swaps:
        text = re.sub(regex, replacement, text)
    return text

def paraphrases(text):
    """Generate template-based paraphrases of a question.

    If ``text`` starts with one of the known question prefixes, the remainder
    (with leading/trailing '?' stripped) is slotted into the matching template
    family via ``editor.template``. For the "How ..." and "Can ..." families
    each filled template is also added in a second-person variant produced by
    ``me_to_you``. Returns an empty list when no prefix matches.
    """
    how_prefixes = ['How do I ', 'How can I ', 'What is a good way to ', 'How should I ']
    how_templates = ['How do I {x}?', 'How can I {x}?', 'What is a good way to {x}?', 'If I want to {x}, what should I do?',
                     'In order to {x}, what should I do?']
    can_prefixes = ['Can you ', 'Can I ']
    can_templates = ['Can you {x}?', 'Can I {x}?', 'Do you think I can {x}?', 'Do you think you can {x}?',]
    do_prefixes = ['Do I ']
    do_templates = ['Do I {x}?', 'Do you think I {x}?']

    # (prefixes, templates, whether to add second-person variants)
    families = (
        (how_prefixes, how_templates, True),
        (can_prefixes, can_templates, True),
        (do_prefixes, do_templates, False),
    )

    results = []
    for prefixes, templates, add_second_person in families:
        for prefix in prefixes:
            if not text.startswith(prefix):
                continue
            remainder = text[len(prefix):].strip('?')
            filled = editor.template(templates, x=remainder).data[0]
            if add_second_person:
                filled = filled + [me_to_you(sentence) for sentence in filled]
            results += filled
    return results

def paraphrases_product(text):
    """Return every ordered pair of paraphrases of ``text`` (including self-pairs)."""
    return list(itertools.product(paraphrases(text), repeat=2))

def paraphrase_each(pair):
    """Return every combination of a paraphrase of ``pair[0]`` with one of ``pair[1]``."""
    first_variants, second_variants = paraphrases(pair[0]), paraphrases(pair[1])
    return [combo for combo in itertools.product(first_variants, second_variants)]

In [None]:
ROB_paraphrase_prod_data = Perturb.perturb(list(all_questions), paraphrases_product, nsamples=100, keep_original=False).data

ROB_paraphrase_each_data = Perturb.perturb(qs, paraphrase_each, nsamples=100, keep_original=True).data

## NER

### names

In [None]:
adjs_without_overlap = ['dead', 'gay', 'Jewish', 'Christian', 'American', 'mad', 'immortal', 'evil', 'famous', 'racist', 'Muslim', 'white', 'black', 'English', 'autistic', 'Australian', 'trustworthy', 'an atheist', 'an anarchist', 'an inventor', 'Indian', 'Armenian', 'an astronaut', 'an immigrant']

Person 1 and person 2 differ in both first and last name

In [None]:
t = editor.template((
    'Is {first_name1} {last_name1} {adj}?',
    'Is {first_name2} {last_name2} {adj}?',
    ),
    adj=adjs_without_overlap,
    remove_duplicates=True,
    nsamples=1000)

NER_first_last_data = t.data
# label 0

Person 1 and person 2 differ in first name only

In [None]:
t = editor.template((
    'Is {first_name} {last_name} {adj}?',
    'Is {first_name2} {last_name} {adj}?',
    ),
    adj=adjs_without_overlap,
    remove_duplicates=True,
    nsamples=1000)

NER_first_data = t.data
# label = 0

Person 1 and person 2 differ in last name only

In [None]:
t = editor.template((
    'Is {first_name} {last_name} {adj}?',
    'Is {first_name} {last_name2} {adj}?',
    ),
    adj=adjs_without_overlap,
    remove_duplicates=True,
    nsamples=1000)

NER_last_data = t.data
# label = 0

Locations, Names, Numbers

In [None]:
def change_both_wrapper(fn):
    """Wrap a perturbation so the same change is applied to both questions.

    ``fn`` is called on each question with a shared random ``seed`` and
    ``meta=True`` and is expected to return either a falsy value or a
    ``(texts, metas)`` pair. The wrapped function returns the position-wise
    pairs of perturbed texts whose metadata entries are equal, or ``None``
    when either question could not be perturbed.
    """
    def change_both(qs):
        first, second = qs
        shared_seed = np.random.randint(100)
        out_first = fn(first, seed=shared_seed, meta=True)
        out_second = fn(second, seed=shared_seed, meta=True)
        if not (out_first and out_second):
            return
        texts_first, metas_first = out_first
        texts_second, metas_second = out_second

        matched = []
        for new_first, new_second, meta_first, meta_second in zip(
                texts_first, texts_second, metas_first, metas_second):
            # Keep only positions where both questions received the same change.
            if meta_first == meta_second:
                matched.append((new_first, new_second))
        return matched
    return change_both

def change_each_wrapper(fn):
    """Wrap a perturbation so it changes one question of a pair at a time.

    ``fn`` is called on each question with a shared random ``seed``,
    ``meta=True`` and any extra keyword arguments, and is expected to return
    either a falsy value or a ``(texts, metas)`` pair. A perturbed question is
    kept only when the first element of its metadata entry also occurs in the
    *other* question's text. Returns ``None`` when either question could not
    be perturbed.
    """
    def change_one(qs, **kwargs):
        first, second = qs
        shared_seed = np.random.randint(100)
        out_first = fn(first, seed=shared_seed, meta=True, **kwargs)
        out_second = fn(second, seed=shared_seed, meta=True, **kwargs)
        if not (out_first and out_second):
            return
        texts_first, metas_first = out_first
        texts_second, metas_second = out_second

        results = []
        # Perturbed q1 paired with the original q2.
        for changed, meta in zip(texts_first, metas_first):
            if meta[0] in str(second):
                results.append((changed, str(second)))
        # Original q1 paired with the perturbed q2.
        for changed, meta in zip(texts_second, metas_second):
            if meta[0] in str(first):
                results.append((str(first), changed))
        return results
    return change_one

In [None]:
# Change location
NER_loc_data = Perturb.perturb(parsed_qs, change_both_wrapper(Perturb.change_location), nsamples=1000).data

# Change names
NER_names_data = Perturb.perturb(parsed_qs, change_both_wrapper(Perturb.change_names), nsamples=1000).data

# Change number
NER_num_data = Perturb.perturb(parsed_qs, change_both_wrapper(Perturb.change_number), nsamples=1000).data

## Negation

In [None]:
mid = ['normal', 'ok', 'safe', 'dangerous', 'acceptable', 'reasonable', 'proper', 'wrong', 'healthy', 'important']

mid2 = mid + ['legal', 'awkward', 'socially acceptable']

In [None]:
print(', '.join(editor.suggest('Is it {mid} to {mask} in {country}?', mid=mid2)[:100]))

In [None]:
things = ['work', 'vote', 'travel', 'marry', 'drive', 'study', 'protest', 'campaign', 'fight', 'gamble', 'hunt', 'pray', 'smoke', 'fish', 'murder', 'invest', 'pee', 'march', 'worship', 'volunteer', 'surf', 'shoot', 'dance', 'camp', 'preach', 'spy', 'be gay', 'lie', 'divorce', 'discriminate']

In [None]:
# Collect words that editor.suggest proposes for both the affirmative and the
# negated frame.
tmp = editor.suggest(('How can I become a person who is {mask}', 'How can I become a person who is not {mask}?'))
# Drop 'differently' from the vocabulary. Filtering (instead of list.remove)
# keeps the cell from raising ValueError when the word is not suggested.
tmp = [word for word in tmp if word != 'differently']
t = editor.template((
    'How can I become {a:x} person?',
    'How can I become a person who is not {x}?',
    ),
    x=tmp,
    remove_duplicates=True,
    nsamples=1000)

NEG_person_data = t.data
# label 0
# label 0

In [None]:
t = editor.template(('Is it {mid} to {activity} in {country}?','Is it {mid} not to {activity} in {country}?'),
                activity=things,
                mid=mid2,
                remove_duplicates=True,
                nsamples=1000)

NEG_activity_data = t.data
# label 0

In [None]:
# prepare vocab
# Build the noun vocabulary for the "worry" templates.
# Query each suggestion prompt once (the first-name prompt used to be run twice).
first_name_professions = editor.suggest('{first_name} works as {a:mask}.')[:30]
print(', '.join(first_name_professions))
full_name_professions = editor.suggest('{first_name} {last_name} works as {a:mask}.')[:30]
# De-duplicate; sorting gives a stable order across kernel restarts, which a
# plain list(set(...)) does not guarantee for strings.
professions = sorted(set(first_name_professions + full_name_professions))

other_nouns = ['player', 'person', 'friend', 'kid', 'candidate']
nouns = sorted(set(professions + other_nouns))

In [None]:
t = editor.template((
    'What are things {a:noun} should worry about?',
    'What are things {a:noun} should not worry about?',
),
                noun=nouns,
                remove_duplicates=True,
                nsamples=1000)

NEG_worry_data = t.data
# label 0

In [None]:
antonyms = [('progressive', 'conservative'),('religious', 'secular'),('positive', 'negative'),('defensive', 'offensive'),('rude',  'polite'),('optimistic', 'pessimistic'),('stupid', 'smart'),('negative', 'positive'),('unhappy', 'happy'),('active', 'passive'),('impatient', 'patient'),('powerless', 'powerful'),('visible', 'invisible'),('fat', 'thin'),('bad', 'good'),('cautious', 'brave'), ('hopeful', 'hopeless'),('insecure', 'secure'),('humble', 'proud'),('passive', 'active'),('dependent', 'independent'),('pessimistic', 'optimistic'),('irresponsible', 'responsible'),('courageous', 'fearful')]

In [None]:
t = editor.template([(
    'How can I become {a:x[0]} person?',
    'How can I become a person who is not {x[1]}?',
    ),
    (
    'How can I become {a:x[1]} person?',
    'How can I become a person who is not {x[0]}?',
    ),
],
    unroll=True,
    x=antonyms,
    remove_duplicates=True,
    nsamples=1000)

NEG_antonym_data = t.data
# label 1

## SRL

### Who do X think - Who is the ... according to X

In [None]:
print(', '.join(editor.suggest('Who is the best {mask} in the world?')))

In [None]:
thing = ['chef', 'boxer', 'player', 'footballer', 'athlete', 'rapper', 'actor', 'singer', 'cook', 'magician', 'coach', 'cyclist', 'wrestler', 'drummer', 'musician', 'quarterback', 'hacker', 'baker', 'fighter', 'journalist', 'teacher', 'doctor', 'gamer', 'husband', 'DJ', 'person', 'man', 'woman', 'surgeon', 'comedian', 'trainer', 'programmer', 'guitarist', 'goalkeeper']

In [None]:
# Explore candidate subjects for the "Who do X think ..." template.
# (Prompt previously contained a duplicated "the the".)
print(', '.join(editor.suggest('Who do {mask} think is the best {thing} in the world?', thing=thing)))

In [None]:
subjects = ['you', 'people', 'readers', 'guys', 'fans', 'experts', 'scientists', 'Americans', 'students', 'men', 'voters', 'authors', 'conservatives', 'women', 'Canadians', 'analysts', 'critics', 'judges', 'artists', 'researchers', 'liberals', 'historians', 'Australians', 'journalists', 'Republicans', 'coaches', 'parents', 'kids', 'economists', 'reporters', 'consumers', 'veterans', 'doctors']

In [None]:
# Explore candidate superlatives for the "Who do X think is the ___ Y" template.
# (Prompt previously contained a duplicated "the the".)
print(', '.join(editor.suggest('Who do {subjects} think is the {mask} {thing} in the world?', thing=thing, subjects=subjects)[:50]))

In [None]:
# Superlatives slotted into the SRL "best" templates ('finests' typo corrected to 'finest').
best = ['best', 'greatest', 'worst', 'top', 'smartest', 'strongest', 'finest', 'happiest', 'coolest', 'richest', 'leading', 'brightest', 'premier', 'ultimate', 'dominant']

In [None]:
t = editor.template((
    'Who do {subjects} think is the {best} {thing} in the world?',
    'Who is the {best} {thing} in the world according to {subjects}?'
),
    subjects=subjects,
    best=best,
    thing=thing,
    remove_duplicates=True,
    nsamples=1000)

SRL_best_data = t.data
# label = 1

### Order doesn't matter for comparison

In [None]:
print(', '.join([str(x) for x in editor.suggest('Are {mask} smaller than {a}?', a=['bananas', 'dogs', 'cars', 'cats', 'elephants'])][:100]))
things = editor.suggest('Are {mask} smaller than {a}?',a=['bananas', 'dogs', 'cars', 'cats', 'elephants'] )[:100]
print(', '.join([str(x) for x in editor.suggest('Are {a} {mask} than {a2}?', a=things)][:100]))
comp = ['better', 'worse', 'cheaper', 'bigger', 'louder', 'longer', 'larger', 'smaller', 'warmer', 'colder', 'thicker', 'lighter', 'heavier']

In [None]:
t = editor.template([
    (
    'Are {t1} {comp} than {t2}?',
    'What is {comp}, {t2} or {t1}?'
    ),
    (
    'Are {t1} {comp} than {t2}?',
    'Are {t2} {comp} than {t1}?',
    ),
    (
    'Are {t1} {comp} than {t2}?',
    'What is {comp}, {t1} or {t2}?',
    )
]
    ,
    t = things,
    comp = comp,
    remove_duplicates=True,
    nsamples=1000)

SRL_comp_data = t.data
# label = 1

### Order doesn't matter for symmetric relations

In [None]:
print(', '.join(editor.suggest('Is {first_name1} {mask} to {first_name2}?', remove_duplicates=True)[:100]))
print()
print(', '.join(editor.suggest('Is {first_name1} {mask} {first_name2}?', remove_duplicates=True)[:100]))

In [None]:
# Symmetric relations: swapping the two names keeps the meaning.
# ('married to' was listed twice; the duplicate is removed.)
symmetric = ['dating', 'married to', 'close to', 'engaged to', 'connected to', 'friends with', 'related to', 'an acquaintance of']

In [None]:
t = editor.template((
    'Is {first_name1} {s} {first_name2}?',
    'Is {first_name2} {s} {first_name1}?',
),
    s = symmetric,
    remove_duplicates=True,
    nsamples=1000)

SRL_symrel_data = t.data
# label = 1

### Order matters for asymmetric relations

In [None]:
asymmetric = ['hurting', 'lying to', 'loyal to', 'faithful to', 'proposing to', 'indebted to', 'abusive to', 'using', 'expecting', 'beating', 'punching', 'raising', 'poisoning', 'protecting', 'kidnapping']

In [None]:
t = editor.template((
    'Is {first_name1} {s} {first_name2}?',
    'Is {first_name2} {s} {first_name1}?',
),
    s = asymmetric,
    remove_duplicates=True,
    nsamples=1000)

SRL_asymrel_data = t.data
# TODO label = 0? (but =1 in the checklist code)


### More traditional SRL

In [None]:
print(', '.join(editor.suggest('Did John buy the {mask}?', remove_duplicates=True)[:100]))
obj = ['farm', 'house', 'property', 'company', 'land', 'ticket', 'newspaper', 'book', 'island', 'estate', 'ranch', 'boat', 'horse', 'paper', 'business', 'gun', 'game', 'factory', 'castle', 'painting', 'rifle', 'car', 'school', 'building']

In [None]:
print(', '.join(editor.suggest('Did John {mask} the {obj}?', obj=obj, remove_duplicates=True)[:100]))

In [None]:
import pattern
import pattern.en
# Base-form verbs used by the active/passive templates below.
verbs = ['buy', 'purchase', 'sell', 'leave', 'own', 'take', 'keep', 'want', 'lose', 'destroy', 'inherit', 'find', 'use', 'need', 'receive', 'return', 'like', 'enjoy', 'abandon', 'manage', 'remember', 'miss', 'move', 'seize', 'steal']
# Take the first tense tuple that pattern reports for 'stolen' and reuse it as
# the conjugation target (presumably the past participle — TODO confirm).
a = pattern.en.tenses('stolen')[0]
# Pair each verb with its conjugated form: (base, conjugated).
verbs = [(v, pattern.en.conjugate(v, *a)) for v in verbs]
# Manual override for the entry at index 3 ('leave'); the index must stay in
# sync with the list above.
verbs[3] = ('leave', 'left')
verbs

traditional SRL: active / passive swap

In [None]:
t = editor.template((
    'Did {first_name} {verb[0]} the {obj}?',
    'Was the {obj} {verb[1]} by {first_name}?'
),
    verb=verbs,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_apswap_data = t.data
# label = 1

traditional SRL: wrong active / passive swap

In [None]:
t = editor.template((
    'Did {first_name} {verb[0]} the {obj}?',
    'Was {first_name} {verb[1]} by the {obj}?'
),
    verb=verbs,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_w_apswap_data = t.data
# label = 0

traditional SRL: active / passive swap with people

In [None]:
print(', '.join(editor.suggest('Does {first_name} {mask} {first_name2}?', remove_duplicates=True)[:100]))
pverb = ['love', 'hate', 'like', 'remember', 'recognize', 'trust', 'deserve', 'understand', 'blame', 'dislike', 'prefer', 'follow', 'notice', 'hurt', 'bother', 'support', 'believe', 'accept', 'attack']
a = pattern.en.tenses('stolen')[0]
pverb = [(v, pattern.en.conjugate(v, *a)) for v in pverb]
t = editor.template((
    'Does {first_name} {verb[0]} {first_name2}?',
    'Is {first_name2} {verb[1]} by {first_name}?',
),
    verb=pverb,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_apswap_ppl_data = t.data
# label = 1

traditional SRL: wrong active / passive swap with people

In [None]:
# Verbs relating two people; each is paired with its conjugated form below.
pverb = ['love', 'hate', 'like', 'remember', 'recognize', 'trust', 'deserve', 'understand', 'blame', 'dislike', 'prefer', 'follow', 'notice', 'hurt', 'bother', 'support', 'believe', 'accept', 'attack']
a = pattern.en.tenses('stolen')[0]
pverb = [(v, pattern.en.conjugate(v, *a)) for v in pverb]
# Wrong passive swap: the second question keeps the names in the same order as
# the active question, so agent and patient are exchanged.
t = editor.template((
    'Does {first_name} {verb[0]} {first_name2}?',
    'Is {first_name} {verb[1]} by {first_name2}?',
),
    verb=pverb,
    obj=obj,
    remove_duplicates=True,
    nsamples=1000)

SRL_w_apswap_ppl_data = t.data
# label = 0 (this set is evaluated with label=0 in the test cells)

# Inference


In [None]:
def response_from_generate(model, messages):
    """Ask the model for a one-token answer and map it to a binary label.

    Uses the module-level ``tokenizer`` and ``device``.

    Args:
        model: A Hugging Face model exposing ``generate`` and ``config``.
        messages: Chat messages (list of ``{"role", "content"}`` dicts) passed
            to ``tokenizer.apply_chat_template``.

    Returns:
        ``1`` if the decoded answer is ``'A'``, ``0`` if it is ``'B'``, and
        ``None`` for anything else.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    generated_ids = model.generate(
        model_inputs.input_ids,
        # Pass the mask explicitly so generate() does not have to guess it.
        attention_mask=model_inputs.get("attention_mask"),
        max_new_tokens=1,
    )

    # Decoder-only models echo the prompt, so strip it off. Encoder-decoder
    # models (e.g. aya) return only decoder tokens; slicing by the prompt
    # length there would throw the answer away.
    if not getattr(model.config, "is_encoder_decoder", False):
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

    # Strip whitespace so answers such as ' A' or 'A\n' still match.
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    output_mapping = {'A' : 1, 'B' : 0}
    # output_mapping = {'Yes': 1, 'No': 0}

    return output_mapping.get(response, None)


def response_from_forward(model, messages):
    """Pick the answer by comparing next-token logits of 'A' and 'B'.

    Uses the module-level ``tokenizer`` and ``device``.

    Args:
        model: A Hugging Face model whose output exposes ``logits``.
        messages: Chat messages passed to ``tokenizer.apply_chat_template``.

    Returns:
        ``1`` when the logit of ``'A'`` (Yes) is at least that of ``'B'``
        (No) at the last position, otherwise ``0``.

    Raises:
        ValueError: If the tokenizer has no dedicated token for 'A' or 'B'.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # Look the answer tokens up in the loaded tokenizer. The previously
    # hard-coded ids (32/33, 59603/59616) are vocabulary-specific and do not
    # carry over between models.
    answer_ids = tokenizer.convert_tokens_to_ids(['A', 'B'])
    unk_id = tokenizer.unk_token_id
    if any(token_id is None or token_id == unk_id for token_id in answer_ids):
        raise ValueError("tokenizer has no dedicated token for 'A'/'B'")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(model_inputs.input_ids)

    # argmax is 0 for 'A' and 1 for 'B'; flip it so that A -> 1 and B -> 0.
    response = 1 - torch.argmax(output.logits[0, -1, answer_ids]).item()

    return response


def inference(model, data, inference_mode='generate'):
    """Run the model over groups of question pairs for invariance-style tests.

    Each element of ``data`` is a sequence of question pairs. The model's
    answer on the first pair of a group is used as that group's gold label;
    its answers on the remaining pairs are the predictions.

    Args:
        model: Model passed through to the response helper.
        data: Iterable of groups, each a sequence of ``(q1, q2)`` pairs.
        inference_mode: ``'generate'`` decodes an answer with ``.generate()``;
            ``'forward'`` compares output logits.

    Returns:
        ``(gold_labels, pred_labels)`` of equal length: ``gold_labels[i]`` is
        the reference answer and ``pred_labels[i]`` the list of answers for
        the i-th retained group. A group is dropped when its first pair yields
        no parseable answer or when none of its other pairs does.
    """

    # system_message = "Are the two questions are paraphrase of each other? Please only respond with A (Yes), B (No)."
    system_message = "Do the following two questions have the same meaning? Respond with A (Yes) or B (No)."
    # system_message = "Do the following two questions have the same meaning? Respond with 'Yes' or 'No'."

    gold_labels, pred_labels = [], []

    for pairs in tqdm(data):
        gold = None
        sentence_labels = []
        for i, pair in enumerate(pairs):

            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": pair[0] + '\n' + pair[1]}
            ]

            if inference_mode == 'generate':
                response = response_from_generate(model, messages)
            elif inference_mode == 'forward':
                response = response_from_forward(model, messages)
            else:
                assert False, 'unknown inference mode'

            if i == 0:
                if response is None:
                    # No reference answer: the whole group is unusable.
                    break
                gold = response
            elif response is not None:
                sentence_labels.append(response)

        if gold is None or len(sentence_labels) == 0:
            continue

        # Record gold and predictions together so the two lists stay aligned.
        # Previously the gold label was appended even when the group was
        # dropped, which shifted every later index used by evaluate().
        gold_labels.append(gold)
        pred_labels.append(sentence_labels)

    return gold_labels, pred_labels

In [None]:
def evaluate(gold_labels, pred_labels):
    """Compute accuracy of grouped predictions against per-group gold labels.

    ``pred_labels[i]`` is a list of predictions that are all compared with
    ``gold_labels[i]``; the flattened lists are scored with ``accuracy_score``.
    """
    # Flatten to (gold, prediction) rows, repeating the gold label per prediction.
    rows = [
        (gold_labels[group_idx], prediction)
        for group_idx, group in enumerate(pred_labels)
        for prediction in group
    ]
    y_true = [gold for gold, _ in rows]
    y_pred = [prediction for _, prediction in rows]

    return accuracy_score(y_true, y_pred)


In [None]:
def inference_MFT(model, data, inference_mode='generate', label=None, fewshot=False):
    """Run the model over question pairs that all share one expected label.

    Args:
        model: Model passed through to the response helper.
        data: Iterable of ``(q1, q2)`` question pairs.
        inference_mode: ``'generate'`` decodes an answer with ``.generate()``;
            ``'forward'`` compares output logits.
        label: Expected label for every pair (1 = same meaning, 0 = different).
        fewshot: When True, prepend two worked examples and fold the
            instruction into each user turn instead of using a system turn.

    Returns:
        ``(gold_labels, pred_labels)`` where ``pred_labels`` holds the parsed
        answers and ``gold_labels`` is ``label`` repeated once per parsed
        answer. Pairs with no parseable answer are skipped.
    """

    # system_message = "Are the two questions are paraphrase of each other? Please only respond with A (Yes), B (No)."
    system_message = "Do the following two questions have the same meaning? Respond only with one letter A (Yes) or B (No)."
    # system_message = "Do the following two questions have the same meaning? Respond with 'Yes' or 'No'."
    # system_message = "Consider the following pair of questions. Do they convey the same meaning? Please respond with 'A' for Yes or 'B' for No."

    def format_pair(first, second):
        # Same layout for the demonstrations and for the real query.
        return 'Question 1: ' + first + ' Question 2: ' + second

    # (question 1, question 2, expected answer) demonstrations for few-shot mode.
    demonstrations = [
        ("What is the step by step guide to invest in share market in india?",
         "What is the step by step guide to invest in share market?",
         "B"),
        ("How can I become a strong person?",
         "How can I become a person who is not weak?",
         "A"),
    ]

    pred_labels = []

    for pair in tqdm(data):

        if fewshot:
            # The instruction is repeated in every user turn because Gemma does
            # not accept a leading system prompt. Demonstration answers are
            # assistant turns (they were previously mislabelled "system"), and
            # a space now separates the instruction from the question text.
            messages = []
            for demo_q1, demo_q2, demo_answer in demonstrations:
                messages.append({"role": "user", "content": system_message + ' ' + format_pair(demo_q1, demo_q2)})
                messages.append({"role": "assistant", "content": demo_answer})
            messages.append({"role": "user", "content": system_message + ' ' + format_pair(pair[0], pair[1])})
        else:
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": format_pair(pair[0], pair[1])}
            ]

        if inference_mode == 'generate':
            response = response_from_generate(model, messages)
        elif inference_mode == 'forward':
            response = response_from_forward(model, messages)
        else:
            assert False, 'unknown inference mode'

        if response is None:
            continue

        pred_labels.append(response)

    gold_labels = [label] * len(pred_labels)

    return gold_labels, pred_labels

def evaluate_MFT(gold_labels, pred_labels):
    """Return ``accuracy_score`` of the flat prediction list against the gold labels."""
    return accuracy_score(gold_labels, pred_labels)

# Run the test


### Robustness

In [None]:
gold_labels, pred_labels = inference(model, ROB_typo_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, ROB_contra_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, ROB_paraphrase_prod_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, ROB_paraphrase_each_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

### NER

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_last_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_last_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, NER_loc_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, NER_names_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference(model, NER_num_data, inference_mode='generate')
print(f'Accuracy: {evaluate(gold_labels, pred_labels):.4f}')

### Negation

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_person_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_activity_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_worry_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_antonym_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

### SRL

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_best_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
# SRL_comp_data was built from a list of three template pairs, so each row is
# indexed here as pairs[0..2]; score each of the three variants separately.
for i in range(3):
    gold_labels, pred_labels = inference_MFT(model, list(pairs[i] for pairs in SRL_comp_data), inference_mode='generate', label=1)
    print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_symrel_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_asymrel_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_ppl_data, inference_mode='generate', label=1)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_ppl_data, inference_mode='generate', label=0)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

### Test Few-Shot

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_last_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_first_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NER_last_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_person_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_activity_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_worry_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, NEG_antonym_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_best_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
for i in range(3):
    gold_labels, pred_labels = inference_MFT(model, list(pairs[i] for pairs in SRL_comp_data), inference_mode='generate', label=1, fewshot=True)
    print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_symrel_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_asymrel_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_apswap_ppl_data, inference_mode='generate', label=1, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')

In [None]:
gold_labels, pred_labels = inference_MFT(model, SRL_w_apswap_ppl_data, inference_mode='generate', label=0, fewshot=True)
print(f'Accuracy: {evaluate_MFT(gold_labels, pred_labels):.4f}')