In [1]:
import json
import os
os.environ['HF_HOME'] = "/datastor1/wenxuand/"
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
import sys
from pathlib import Path
import argparse
from tqdm import tqdm
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

sys.path.append('../src')
from utils import get_L_prompt, load_lambada_data, make_prompt_triviaqa, get_response

In [2]:
def write_data(filename, data, mode='w'):
    """Serialize records to ``filename`` in JSON Lines format (one JSON object per line).

    Args:
        filename: Destination path.
        data: Iterable of JSON-serializable records (callers in this notebook pass a
            Hugging Face dataset split or a list of dicts).
        mode: File open mode. Defaults to ``'w'`` so re-running a caching cell rewrites
            the file instead of appending a second copy of the data; pass ``'a'`` to
            append on purpose.
    """
    with open(filename, mode, encoding='utf-8') as fout:
        for sample in data:
            fout.write(json.dumps(sample))
            fout.write('\n')

def read_data(filename):
    """Load a JSON Lines file into a list of decoded records, in file order.

    Blank or whitespace-only lines (e.g. a stray trailing newline) are skipped
    instead of being handed to ``json.loads``, which would raise on them.

    Args:
        filename: Path of the ``.jsonl`` file to read.

    Returns:
        List with one decoded JSON value per non-blank line.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

In [14]:
L = load_dataset('lucadiliello/triviaqa')
# Fixed-seed shuffles, then keep the first 3000 train / 1000 validation rows as our splits.
L_train = L['train'].shuffle(seed=42).select(range(3000))
L_test = L['validation'].shuffle(seed=42).select(range(1000))
for label, subset in (("L_train:", L_train), ("L_test:", L_test)):
    print(label, len(subset))

L_train: 3000
L_test: 1000


In [15]:
# Peek at one raw record and its type, then cache both splits as JSON Lines.
first_train_record = L_train[0]
print(first_train_record)
print(type(first_train_record))
for split_data, out_path in ((L_train, '../data/tqa_train.jsonl'), (L_test, '../data/tqa_test.jsonl')):
    write_data(out_path, split_data)

{'context': '[DOC] [TLE] Wilberforce (cat) brought to you by PhotosofCatsWilberforce (cat) brought to you by PhotosofCats [PAR] Named after [PAR] William Wilberforce [PAR] Wilberforce was a cat who lived at 10 Downing Street between 1973 and 1987 and served under four British Prime Ministers: Edward Heath, Harold Wilson, Jim Callaghan and Margaret Thatcher. His chief function was to catch mice, in which role he was the successor to Petra. In life he had been referred to as "the best mouser in Britain" as fit his role. [PAR] According to Bernard Ingham, the former press secretary to Margaret Thatcher, Wilberforce was a normal cat for whom Thatcher once bought "a tin of sardines in a Moscow supermarket". On the BBC coverage of the 1983 general election, presenter Esther Rantzen was allowed to hold Wilberforce and introduce him to viewers. [PAR] He retired on 3 April 1987, and was succeeded by Humphrey who was born in 1988, the year Wilberforce died. [PAR] This article uses material from 

In [3]:
# Reload the cached splits; from here on L_train / L_test are plain lists of dicts.
train_path, test_path = '../data/tqa_train.jsonl', '../data/tqa_test.jsonl'
L_train = read_data(train_path)
L_test = read_data(test_path)


In [4]:
def init_model(model_name, device):
    """Load a causal LM and its tokenizer into the module-level globals.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        device: Device the model is moved to (e.g. ``'cuda'``).

    Side effects:
        Assigns the globals ``model``, ``tokenizer`` and ``terminators``, and prints
        the loaded model's ``config.torch_dtype``. The tokenizer's pad token is set
        to its EOS token. ``terminators`` is rebuilt on every call, so it is always
        defined and never carries ids over from a previously loaded model.
    """
    global model
    global tokenizer
    global terminators
    # "auto" lets from_pretrained choose the dtype (alternative: torch.bfloat16).
    load_kwargs = {"torch_dtype": "auto"}
    if 'gemma' in model_name:
        # NOTE(review): Gemma is loaded with eager attention; the reason is not recorded
        # here (presumably an issue with the default attention backend) — confirm.
        load_kwargs["attn_implementation"] = "eager"
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).to(device)
    print("model.config.torch_dtype:", model.config.torch_dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # Always start from EOS so non-Llama models get a valid (not stale/undefined) list.
    terminators = [tokenizer.eos_token_id]
    if "llama" in model_name:
        # Also stop on <|eot_id|> (presumably the Llama-3 end-of-turn token — confirm
        # for other Llama versions).
        terminators.append(tokenizer.convert_tokens_to_ids("<|eot_id|>"))


In [5]:
init_model(model_name="meta-llama/Llama-3.2-3B-Instruct", device='cuda')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

model.config.torch_dtype: torch.bfloat16


In [6]:
import importlib
# Re-execute the `utils` module (imported in the first cell after adding '../src' to
# sys.path) so edits to it take effect without restarting the kernel.
importlib.reload(sys.modules['utils'])
# Rebind the names: `from utils import ...` copies references, so the old bindings
# would otherwise keep pointing at the pre-reload functions.
# NOTE(review): `get_response` is redefined later in this notebook; cells run after
# that definition use the local version, not this import.
from utils import get_L_prompt, load_lambada_data, make_prompt_triviaqa, get_response


In [7]:
# Dump every field of the first training record, then the prompt built from it.
first_record = L_train[0]
for field_name, field_value in first_record.items():
    print(field_name, field_value)

print(make_prompt_triviaqa(first_record))

context [DOC] [TLE] Wilberforce (cat) brought to you by PhotosofCatsWilberforce (cat) brought to you by PhotosofCats [PAR] Named after [PAR] William Wilberforce [PAR] Wilberforce was a cat who lived at 10 Downing Street between 1973 and 1987 and served under four British Prime Ministers: Edward Heath, Harold Wilson, Jim Callaghan and Margaret Thatcher. His chief function was to catch mice, in which role he was the successor to Petra. In life he had been referred to as "the best mouser in Britain" as fit his role. [PAR] According to Bernard Ingham, the former press secretary to Margaret Thatcher, Wilberforce was a normal cat for whom Thatcher once bought "a tin of sardines in a Moscow supermarket". On the BBC coverage of the 1983 general election, presenter Esther Rantzen was allowed to hold Wilberforce and introduce him to viewers. [PAR] He retired on 3 April 1987, and was succeeded by Humphrey who was born in 1988, the year Wilberforce died. [PAR] This article uses material from the W

In [9]:
# Single-example smoke test of the chat generation path on the first training question.
prompt = make_prompt_triviaqa(L_train[0])
response = get_response(prompt.prompt, model, tokenizer, 'cuda', True)
print(response)

# Spot-check the first ten questions: prompt record, sampled answer, separator.
for i in range(10):
    prompt = make_prompt_triviaqa(L_train[i])
    response = get_response(prompt.prompt, model, tokenizer, 'cuda', True)
    print(prompt, response, '----------------------', sep='\n')

Margaret Thatcher
PromptCompletion(prompt='Question: Who or what was Wilberforce who retired from 10 Downing Street In 1987\n\nAnswer:', completion=' Cat', answers=['cat'])
Margaret Thatcher
----------------------
PromptCompletion(prompt='Question: What is the seven-branched candlestick, based on the candelabrum that was used in the Temple in Jerusalem in ancient times, that is the national symbol of the State of Israel?\n\nAnswer:', completion=' Menorah', answers=['menorah'])
The Menorah.
----------------------
PromptCompletion(prompt='Question: Which motor manufacturer produces the 7-seater MPV known as the Orlando?\n\nAnswer:', completion=' Chevrolet', answers=['chevrolet'])
Kia
----------------------
PromptCompletion(prompt='Question: What type of creature is a malimbe?\n\nAnswer:', completion=' Bird', answers=['bird', 'birds'])
Bird
----------------------
PromptCompletion(prompt='Question: "What, according to John Lennon, ""is just what happens to you, while you\'re busy making ot

In [10]:
import re
import string
# Every ASCII punctuation character plus the typographic quotes/accents below maps to a space.
_PUNCT_TO_SPACE = {
    ord(ch): ' '
    for ch in string.punctuation + "".join([u"‘", u"’", u"´", u"`"])
}
_ARTICLE_RE = re.compile(r'\b(a|an|the)\b')


def normalize_answer(s):
    """Canonicalize an answer string before QA scoring.

    Applied in order: underscores become spaces, the text is lowercased, every
    punctuation character becomes a space, the standalone words "a"/"an"/"the"
    are removed, and runs of whitespace collapse to single spaces with no
    leading/trailing whitespace.
    """
    text = s.replace('_', ' ')
    text = text.lower()
    text = text.translate(_PUNCT_TO_SPACE)
    text = _ARTICLE_RE.sub(' ', text)
    return ' '.join(text.split())

In [11]:
# Same ten-question spot check, also showing the normalized form that scoring compares.
for i in range(10):
    record = L_train[i]
    prompt = make_prompt_triviaqa(record)
    response = get_response(prompt.prompt, model, tokenizer, 'cuda', True)
    normalized = normalize_answer(response)
    print(prompt)
    print(response, normalized)
    print('----------------------')




PromptCompletion(prompt='Question: Who or what was Wilberforce who retired from 10 Downing Street In 1987\n\nAnswer:', completion=' Cat', answers=['cat'])
Margaret Thatcher margaret thatcher
----------------------
PromptCompletion(prompt='Question: What is the seven-branched candlestick, based on the candelabrum that was used in the Temple in Jerusalem in ancient times, that is the national symbol of the State of Israel?\n\nAnswer:', completion=' Menorah', answers=['menorah'])
The Menorah. menorah
----------------------
PromptCompletion(prompt='Question: Which motor manufacturer produces the 7-seater MPV known as the Orlando?\n\nAnswer:', completion=' Chevrolet', answers=['chevrolet'])
Kia kia
----------------------
PromptCompletion(prompt='Question: What type of creature is a malimbe?\n\nAnswer:', completion=' Bird', answers=['bird', 'birds'])
Bird bird
----------------------
PromptCompletion(prompt='Question: "What, according to John Lennon, ""is just what happens to you, while you\'

In [45]:
import torch
def get_response(prompt, model, tokenizer, device = 'cuda', is_chat=False,
                 system_prompt="Generate an INCORRECT answer to the question in less than 3 words.",
                 max_new_tokens=10, temperature=1.2, top_k=200):
    """Sample a short completion for ``prompt`` and return only the new text.

    NOTE: this definition shadows the ``get_response`` imported from ``utils`` earlier
    in the notebook; cells executed after it use this version.

    Args:
        prompt: Prompt text.
        model: Model exposing ``generate``.
        tokenizer: Matching tokenizer; needs ``apply_chat_template`` when ``is_chat``.
        device: Device the input tensors are placed on.
        is_chat: If True, send ``system_prompt`` followed by ``prompt`` as a user turn
            through the chat template; otherwise tokenize ``prompt`` directly.
        system_prompt: System instruction for chat mode. The default asks for a
            deliberately wrong answer, which the negative-mining code relies on.
        max_new_tokens: Generation length cap forwarded to ``model.generate``.
        temperature: Sampling temperature (sampling is always enabled).
        top_k: Top-k sampling cutoff.

    Returns:
        The decoded generated tokens (prompt sliced off), special tokens skipped.
    """
    with torch.no_grad():
        if is_chat:
            message = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ]
            input_ids = tokenizer.apply_chat_template(
                message, add_generation_prompt=True, return_tensors="pt",
                tokenize=True, return_dict=False)[0].tolist()
        else:
            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0].tolist()

        # Explicit long dtype so even an empty id list yields an integer tensor.
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        response = model.generate(
            input_ids=input_tensor,
            # Single unpadded sequence: attend to every position (integer mask, same shape as ids).
            attention_mask=torch.ones_like(input_tensor),
            max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_k=top_k,
            pad_token_id=tokenizer.eos_token_id)
        decoded_response = tokenizer.decode(response[0][len(input_ids):], skip_special_tokens=True)
    return decoded_response

In [39]:
from collections import Counter

def f1_score(prediction, ground_truth):
    """Token-overlap F1 between a prediction and a single gold answer.

    Both strings go through ``normalize_answer`` and are split on whitespace;
    overlap counts tokens with multiplicity. Returns the int ``0`` when nothing
    overlaps, otherwise the harmonic mean of precision and recall as a float.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """True iff both strings are identical after ``normalize_answer``."""
    normalized_pred, normalized_gold = (normalize_answer(text) for text in (prediction, ground_truth))
    return normalized_pred == normalized_gold

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best ``metric_fn(prediction, gold)`` score across all gold answers.

    ``max`` raises ``ValueError`` when ``ground_truths`` is empty.
    """
    return max(metric_fn(prediction, gold) for gold in ground_truths)

In [50]:
from tqdm import tqdm
import random
seed_value = 42
# Seed Python's `random` (used by random.sample when subsampling) and torch (sampling in generate).
for seed_fn in (random.seed, torch.manual_seed):
    seed_fn(seed_value)

def _sample_negative_answer(record, model, tokenizer, num_attempts, f1_threshold, device):
    """Return the first usable wrong answer sampled for ``record``, or None.

    An answer is usable when it is non-empty after ``normalize_answer`` and its
    best F1 against every gold answer in ``record['answers']`` is below
    ``f1_threshold``. At most ``num_attempts`` samples are drawn.
    """
    prompt = make_prompt_triviaqa(record)
    for _ in range(num_attempts):
        response = normalize_answer(get_response(prompt.prompt, model, tokenizer, device, True))
        # A reply that normalizes to "" (e.g. "The.") scores F1 0 against anything
        # but is useless as a negative label, so keep sampling instead.
        if not response:
            continue
        if metric_max_over_ground_truths(f1_score, response, record['answers']) < f1_threshold:
            return response
    return None


def generate_negative_case_tqa(L, model, tokenizer, num_attempts=5, f1_threshold=0.3, device='cuda'):
    """Build a positive/negative TriviaQA set from ``L``.

    For each record, a copy with ``'correct': True`` is emitted. Then up to
    ``num_attempts`` chat-mode answers are sampled. The first one that is
    non-empty after normalization and whose best F1 against the gold
    ``answers`` is below ``f1_threshold`` is emitted as a second copy with
    ``'correct': False`` and ``'answers': [that normalized answer]``.

    Input records are not modified; every output record is a fresh dict.

    Args:
        L: Sequence of dict records with at least an ``'answers'`` list.
        model, tokenizer: Passed through to ``get_response``.
        num_attempts: Maximum samples drawn per record.
        f1_threshold: A sample counts as wrong when its max F1 is below this.
        device: Device passed to ``get_response``.

    Returns:
        List of records, each positive immediately followed by its negative (if any).
        Also prints how many records received a negative.
    """
    count_has_negative = 0
    result_L = []
    for record in tqdm(L):
        positive_l = {**record, 'correct': True}
        result_L.append(positive_l)

        negative_answer = _sample_negative_answer(
            record, model, tokenizer, num_attempts, f1_threshold, device)
        if negative_answer is not None:
            result_L.append({**positive_l, 'correct': False, 'answers': [negative_answer]})
            count_has_negative += 1

    print("count_has_negative:", count_has_negative)
    return result_L

# Mine negatives for both splits (GPU generation; minutes per split per the recorded tqdm output).
triviaqa_pn_train = generate_negative_case_tqa(L_train, model=model, tokenizer=tokenizer)
print(len(triviaqa_pn_train))
triviaqa_pn_test = generate_negative_case_tqa(L_test, model=model, tokenizer=tokenizer)
print(len(triviaqa_pn_test))
        



100%|██████████| 3000/3000 [06:29<00:00,  7.70it/s]


count_has_negative: 2887
5887


100%|██████████| 1000/1000 [02:11<00:00,  7.61it/s]

count_has_negative: 963
1963





In [52]:
# Subsample the mixed positive/negative pools to fixed split sizes (train drawn first, then test).
triviaqa_pn_train_3000 = random.sample(triviaqa_pn_train, 3000)
triviaqa_pn_test_1000 = random.sample(triviaqa_pn_test, 1000)
print(len(triviaqa_pn_train_3000), len(triviaqa_pn_test_1000), sep='\n')
for out_path, records in (('../data/triviaqa_pn_train.jsonl', triviaqa_pn_train_3000),
                          ('../data/triviaqa_pn_test.jsonl', triviaqa_pn_test_1000)):
    write_data(out_path, records)

3000
1000
