In [1]:
import os, math, numpy as np
import pandas as pd
import re, gc
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
from tqdm import tqdm

# Let pandas show up to 300 rows when a frame is displayed.
pd.set_option('display.max_rows', 300)

# Configuration
IS_SUBMISSION = True
base_model_path = "Qwen/Qwen2.5-7B-Instruct"
query_max_len = 320
doc_max_len = 48
task = (
    "Given a math multiple-choice problem with a student's wrong answer, "
    "retrieve the math misconceptions"
)

print('IS_SUBMISSION:', IS_SUBMISSION)

# Training data: missing values become -1, then a fixed 10-row sample is kept.
_train_full = pd.read_csv("./data/train.csv").fillna(-1)
df_train = _train_full.sample(10, random_state=42).reset_index(drop=True)
df_test = pd.read_csv("./data/test.csv")
df_misconception_mapping = pd.read_csv("./data/misconception_mapping.csv")

# The frame that queries are built from: test data when submitting,
# otherwise the training sample.
if IS_SUBMISSION:
    df_ret = df_test
else:
    df_ret = df_train

TEMPLATE_INPUT_V3 = '{QUESTION}\nCorrect answer: {CORRECT_ANSWER}\nStudent wrong answer: {STUDENT_WRONG_ANSWER}'

def format_input_v3(row, wrong_choice):
    """Build the prompt fields for one wrong answer choice of a question.

    Args:
        row: A record supporting ``.get(key, default)`` (for example a dict
            or a pandas Series). The keys read are ``QuestionText``,
            ``SubjectName``, ``ConstructName``, ``CorrectAnswer``,
            ``Answer<X>Text`` and ``Misconception<X>Id``.
        wrong_choice: The student's wrong choice; exactly one of
            ``"A"``, ``"B"``, ``"C"``, ``"D"``.

    Returns:
        A dict with the keys ``QUESTION``, ``CORRECT_ANSWER``,
        ``STUDENT_WRONG_ANSWER``, ``MISCONCEPTION_ID`` (``None`` when the
        row has no such entry) and ``PROMPT`` (``TEMPLATE_INPUT_V3`` filled
        in with the other fields).

    Raises:
        ValueError: If ``wrong_choice`` is not a single letter A-D, or if it
            equals the row's correct answer.
    """
    # Tuple membership on purpose: `x in "ABCD"` is a substring test and
    # would accept "" or "AB".
    if wrong_choice not in ("A", "B", "C", "D"):
        raise ValueError(
            f"wrong_choice must be one of 'A', 'B', 'C', 'D'; got {wrong_choice!r}"
        )

    question_text = row.get("QuestionText", "No question text provided")
    subject_name = row.get("SubjectName", "Unknown subject")
    construct_name = row.get("ConstructName", "Unknown construct")
    correct_answer = row.get("CorrectAnswer", "Unknown")
    if wrong_choice == correct_answer:
        raise ValueError(
            f"wrong_choice {wrong_choice!r} is the correct answer for this row"
        )

    correct_answer_text = row.get(f"Answer{correct_answer}Text", "No correct answer text available")
    wrong_answer_text = row.get(f"Answer{wrong_choice}Text", "No wrong answer text available")

    # The whitespace-only second line reproduces the prompt text exactly as
    # it was originally written, so generated prompts stay byte-identical.
    formatted_question = (
        f"Question: {question_text}\n"
        "    \n"
        f"SubjectName: {subject_name}\n"
        f"ConstructName: {construct_name}"
    )

    ret = {
        "QUESTION": formatted_question,
        "CORRECT_ANSWER": correct_answer_text,
        "STUDENT_WRONG_ANSWER": wrong_answer_text,
        "MISCONCEPTION_ID": row.get(f'Misconception{wrong_choice}Id'),
    }
    # str.format ignores the unused MISCONCEPTION_ID key.
    ret["PROMPT"] = TEMPLATE_INPUT_V3.format(**ret)
    return ret

# Build one query per (question, wrong answer choice) pair.
items = []
target_ids = []
for _, row in df_ret.iterrows():
    for choice in 'ABCD':
        misconception_col = f'Misconception{choice}Id'
        is_correct_choice = choice == row["CorrectAnswer"]
        if is_correct_choice:
            continue
        # Outside submission mode, choices whose misconception id is -1
        # are skipped.
        if not IS_SUBMISSION and row[misconception_col] == -1:
            continue
        item = {
            'QuestionId_Answer': '{}_{}'.format(row['QuestionId'], choice),
            'Prompt': format_input_v3(row, choice)['PROMPT'],
        }
        items.append(item)
        # Falls back to -1 when the row has no such column.
        target_ids.append(int(row.get(misconception_col, -1)))
df_input = pd.DataFrame(items)

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Return the task description and query joined as '<instruct>...' and '<query>...' lines."""
    instruct_part = '<instruct>{}'.format(task_description)
    query_part = '<query>{}'.format(query)
    return '\n'.join((instruct_part, query_part))

def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
    """Truncate queries to a token budget and wrap them with prefix and suffix.

    Each query is tokenized without special tokens and truncated to
    ``query_max_len`` minus the token counts of ``'<s>'`` and
    ``'\\n<response></s>'``, then decoded back to text and wrapped as
    ``examples_prefix + query + '\\n<response>'``.

    Args:
        queries: The query strings to process.
        query_max_len: Token budget for a query before the markers are
            subtracted.
        examples_prefix: Text placed in front of every query.
        tokenizer: Callable tokenizer returning a mapping with
            ``'input_ids'`` and offering ``batch_decode``.

    Returns:
        A tuple ``(new_max_length, new_queries)``. ``new_max_length`` is the
        prefix, suffix and query budget plus 8, rounded down to a multiple
        of 8, plus another 8. ``new_queries`` is the list of wrapped
        strings.
    """
    response_suffix = '\n<response>'

    def _token_count(text):
        return len(tokenizer(text, add_special_tokens=False)['input_ids'])

    # Reserve room for the start marker and the response/end markers.
    reserved = _token_count('<s>') + _token_count('\n<response></s>')
    encoded = tokenizer(
        queries,
        max_length=query_max_len - reserved,
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False,
    )

    prefix_len = _token_count(examples_prefix)
    suffix_len = _token_count(response_suffix)
    new_max_length = (prefix_len + suffix_len + query_max_len + 8) // 8 * 8 + 8

    truncated_texts = tokenizer.batch_decode(encoded['input_ids'])
    new_queries = [examples_prefix + text + response_suffix for text in truncated_texts]
    return new_max_length, new_queries

# Every prompt gets the task instruction prepended; the documents are the
# misconception names.
queries = [get_detailed_instruct(task, prompt) for prompt in df_input['Prompt']]
documents = df_misconception_mapping['MisconceptionName'].tolist()

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
examples_prefix = ''
# NOTE(review): new_query_max_len is computed here, but the embedding step
# tokenizes with MAX_LENGTH (= query_max_len) — confirm that is intended.
new_query_max_len, new_queries = get_new_queries(
    queries, query_max_len, examples_prefix, tokenizer
)

# Persist all texts, queries first and then documents, as a JSON object.
with open('data.json', 'w') as f:
    data = {'texts': new_queries + documents}
    serialized = json.dumps(data)
    f.write(serialized)

MAX_LENGTH = query_max_len

def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the hidden state at the last attended token of each sequence.

    Works for both padding sides. When the final position of every row is
    attended (a left-padded or unpadded batch), the last token sits at index
    -1 for all rows. Otherwise the batch is treated as right-padded and the
    last token of each row is at ``attention_mask.sum(dim=1) - 1``.

    Args:
        last_hidden_states: Tensor indexed as ``[batch, position, ...]``.
        attention_mask: 2-D tensor ``[batch, position]`` holding 1 for real
            tokens and 0 for padding.

    Returns:
        One vector per batch row, taken from ``last_hidden_states``.
    """
    batch_size = last_hidden_states.shape[0]
    # Counting positions alone is only right for right padding; with left
    # padding it would point into the middle of the sequence.
    left_padding = bool(attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

def get_embeddings_in_batches(model, tokenizer, texts, max_length, batch_size=4):
    """Embed ``texts`` batch by batch using last-token pooling.

    The model is moved to ``cuda:0`` when CUDA is available (otherwise the
    CPU) and put in eval mode. Each batch is tokenized with padding and
    truncation to ``max_length``, run through the model without gradients,
    pooled from the final entry of ``hidden_states`` with
    ``last_token_pool``, L2-normalized along dim 1 and moved to the CPU.

    Args:
        model: Model called as ``model(**inputs, output_hidden_states=True)``
            whose output exposes ``hidden_states``.
        tokenizer: Tokenizer producing a mapping of tensors that includes
            ``"attention_mask"``.
        texts: Sliceable sequence of strings to embed.
        max_length: Maximum token length passed to the tokenizer.
        batch_size: Number of texts per forward pass.

    Returns:
        The per-batch embeddings concatenated along dim 0, on the CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    pooled_chunks = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="Embedding"):
        encoded = tokenizer(
            texts[start : start + batch_size],
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # Inputs must sit on the same device as the model.
        model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            # Plain forward pass (no autocast).
            outputs = model(**model_inputs, output_hidden_states=True)
            # hidden_states is a tuple; its final entry is the last layer.
            final_layer = outputs.hidden_states[-1]
            pooled = last_token_pool(final_layer, model_inputs["attention_mask"])
            unit_norm = F.normalize(pooled, p=2, dim=1).cpu()
        pooled_chunks.append(unit_norm)
    return torch.cat(pooled_chunks, dim=0)

# Load the base Qwen model fully on a single GPU or CPU (no device_map)
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map=None
)

# Read the texts back from disk; the context manager closes the file handle
# instead of leaving it open.
with open("data.json") as f:
    data = json.load(f)
all_texts = data["texts"]
num_queries = len(new_queries)
num_docs = len(documents)

# Embed queries and documents in one pass, then look each one up by text.
embeds = get_embeddings_in_batches(model, tokenizer, all_texts, max_length=MAX_LENGTH, batch_size=2)
text_to_embed = {text: emb for text, emb in zip(all_texts, embeds)}

query_embeddings = torch.stack([text_to_embed[t] for t in new_queries])
doc_embeddings = torch.stack([text_to_embed[t] for t in documents])

# The embeddings are L2-normalized, so the dot product is cosine similarity.
scores = query_embeddings @ doc_embeddings.T
# Keep the 25 best-scoring documents per query.
# NOTE(review): these are row positions in `documents`; they are later written
# out as MisconceptionId, which presumably relies on the mapping file being
# ordered by id starting at 0 — confirm.
sorted_indices = torch.argsort(scores, dim=1, descending=True)[:,:25].tolist()

if not IS_SUBMISSION:
    def compute_metrics(q_embeds: torch.Tensor, d_embeds: torch.Tensor, target_ids):
        """Print and return MAP@25 and Recall@100 for the retrieval scores.

        Args:
            q_embeds: Query embeddings, one row per query.
            d_embeds: Document embeddings, one row per document.
            target_ids: For each query, the index of its single relevant
                document. It is compared against row positions in
                ``d_embeds`` — assumes ids equal row positions, TODO confirm.

        Returns:
            ``(map25, recall100)`` as floats, or ``None`` when ``target_ids``
            is empty and there is nothing to average.
        """
        # Guard the averages below against division by zero.
        if len(target_ids) == 0:
            print("No labelled queries to evaluate; skipping MAP@25 / Recall@100.")
            return None

        scores = q_embeds @ d_embeds.T
        avg_precisions = []
        recall_counts = []
        for i, target_id in enumerate(target_ids):
            sorted_idx = torch.argsort(scores[i], descending=True)
            # Positions (0-based ranks) of the target within the top 100.
            hit_ranks = (sorted_idx[:100] == target_id).nonzero(as_tuple=True)[0]
            found = len(hit_ranks) > 0
            recall_counts.append(1 if found else 0)
            # sorted_idx is a permutation, so the target appears at most
            # once; with a single relevant document the average precision
            # is 1 / rank when it is inside the top 25, else 0.
            if found and hit_ranks[0].item() < 25:
                avg_precisions.append(1.0 / (hit_ranks[0].item() + 1))
            else:
                avg_precisions.append(0.0)

        map25 = sum(avg_precisions) / len(avg_precisions)
        recall100 = sum(recall_counts) / len(recall_counts)
        print(f"MAP@25: {map25:.4f}")
        print(f"Recall@100: {recall100:.4f}")
        return map25, recall100

    compute_metrics(query_embeddings, doc_embeddings, target_ids)

# Each row of sorted_indices becomes one space-separated string.
df_input["MisconceptionId"] = [" ".join(map(str, ranked)) for ranked in sorted_indices]
submission_columns = ["QuestionId_Answer", "MisconceptionId"]
df_input[submission_columns].to_csv("submission.csv", index=False)

# Read the file back and show it as a check on what was written.
print("submission.csv created:")
display(pd.read_csv('submission.csv'))


  from .autonotebook import tqdm as notebook_tqdm


IS_SUBMISSION: True


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.08it/s]
Embedding: 100%|██████████| 1298/1298 [00:58<00:00, 22.32it/s]


submission.csv created:


Unnamed: 0,QuestionId_Answer,MisconceptionId
0,1869_B,2142 2068 2286 2065 2354 1251 1529 1095 1431 5...
1,1869_C,2142 2068 2286 2065 2354 1251 1248 1485 1529 1...
2,1869_D,2142 2286 2068 2065 1485 2008 1380 1248 739 53...
3,1870_A,2142 2068 2286 2065 1485 2008 1248 2354 1529 2...
4,1870_B,2142 2068 2286 2065 2008 1248 1485 2354 1529 2...
5,1870_C,2142 2068 2286 2065 1485 2008 1248 1380 2354 1...
6,1871_A,2142 2286 1251 2354 1095 1869 2068 2065 1431 3...
7,1871_C,2142 2286 2068 1251 2354 2065 1095 1869 1431 1...
8,1871_D,2142 2286 2068 2354 2065 1251 1095 1869 1431 7...


KeyError: "['TrueMisconceptionId'] not in index"