In [9]:
import json
import time
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import pandas as pd
import torch
from sklearn.neighbors import NearestNeighbors
from transformers import (
    AutoModel,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from typing import Literal

from eedi.datasets import make_complete_query, make_nice_df
from eedi.helpers import batched_inference


from vllm import LLM, SamplingParams
from logits_processor_zoo.vllm import GenLengthLogitsProcessor, CiteFromPromptLogitsProcessor, ForceLastPhraseLogitsProcessor



In [2]:
llm = LLM("Qwen/Qwen2.5-1.5B-Instruct")

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

INFO 12-07 16:50:19 config.py:350] This model supports multiple tasks: {'embedding', 'generate'}. Defaulting to 'generate'.
INFO 12-07 16:50:19 llm_engine.py:249] Initializing an LLM engine (v0.6.4.post1) with config: model='Qwen/Qwen2.5-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-1.5B-Instruc

tokenizer_config.json:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

INFO 12-07 16:50:25 selector.py:135] Using Flash Attention backend.
INFO 12-07 16:50:25 model_runner.py:1072] Starting to load model Qwen/Qwen2.5-1.5B-Instruct...
INFO 12-07 16:50:25 weight_utils.py:243] Using model weights format ['*.safetensors']


model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

INFO 12-07 16:55:11 weight_utils.py:288] No model.safetensors.index.json found in remote.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 12-07 16:55:12 model_runner.py:1077] Loading model weights took 2.8875 GB
INFO 12-07 16:55:17 worker.py:232] Memory profiling results: total_gpu_memory=11.64GiB initial_memory_usage=3.98GiB peak_torch_memory=4.91GiB memory_usage_post_profile=4.00GiB non_torch_memory=1.10GiB kv_cache_size=4.46GiB gpu_memory_utilization=0.90
INFO 12-07 16:55:17 gpu_executor.py:113] # GPU blocks: 10436, # CPU blocks: 9362
INFO 12-07 16:55:17 gpu_executor.py:117] Maximum concurrency for 32768 tokens per request: 5.10x
INFO 12-07 16:55:19 model_runner.py:1400] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 12-07 16:55:19 model_runner.py:1404] If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 1

In [3]:
# load dataset
df_mis = pd.read_csv("../data/misconception_mapping.csv")
orig_mis = df_mis["MisconceptionName"].tolist()
assert len(orig_mis) == 2587
df_test = pd.read_csv("../data/test.csv")
df_test = make_nice_df(df_test)
df_test["QuestionComplete"] = df_test.apply(make_complete_query, axis=1)
with open("../top25_miscons.json", "r") as f:
    top25_miscons = json.load(f)
df_test["Top25Miscons"] = top25_miscons
df_test.head(3)

Unnamed: 0,QuestionId,ConstructId,ConstructName,SubjectId,SubjectName,CorrectChoice,CorrectText,QuestionText,WrongChoice,WrongText,QuestionId_Answer,QuestionComplete,Top25Miscons
0,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\( 3 \times(2+4)-5 \),\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,B,\( 3 \times 2+(4-5) \),1869_B,SUBJECT: BIDMAS\n\nCONSTRUCT: Use the order of...,"[2306, 2488, 158, 1345, 968, 1005, 1507, 706, ..."
1,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\( 3 \times(2+4)-5 \),\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,C,\( 3 \times(2+4-5) \),1869_C,SUBJECT: BIDMAS\n\nCONSTRUCT: Use the order of...,"[2306, 2488, 1345, 1005, 158, 1507, 706, 968, ..."
2,1869,856,Use the order of operations to carry out calcu...,33,BIDMAS,A,\( 3 \times(2+4)-5 \),\[\n3 \times 2+4-5\n\]\nWhere do the brackets ...,D,Does not need brackets,1869_D,SUBJECT: BIDMAS\n\nCONSTRUCT: Use the order of...,"[2488, 1345, 1005, 2306, 968, 158, 1507, 706, ..."


In [18]:
def make_llm_prompt_en(
    row: pd.Series,
    k: int,
    orig_mis: list[str],
) -> str:
    question = row["QuestionComplete"]
    top25_mis: list[int] = row["Top25Miscons"]  # type: ignore
    # my own prompt
    # TODO answer with number only
    template = "You are an elite mathematics teacher tasked to assess the student's understanding of math concepts. Below, you will be presented with: the math question, the correct answer, the wrong answer and {k} possible misconceptions that could have led to the mistake.\n\n{question}\n\nPossible Misconceptions\n{choices}\n\nProvide a short reasoning when selecting the most likely misconceptions the student might have had.\n\nWrap your answer (just the number) in <myanswer></myanswer>. Answer: "
    numbered_mis_texts = []
    for i in range(k):
        numbered_mis_texts.append(f"{i+1}. {orig_mis[top25_mis[i]]}")
    numbered_mis_texts = "\n".join(numbered_mis_texts)
    llm_prompt = template.format(k=k, question=question, choices=numbered_mis_texts)
    return llm_prompt





def make_llm_prompt_zh(
    row: pd.Series,
    k: int,
    orig_mis: list[str],
) -> str:
    question = row["QuestionComplete"]
    top25_mis: list[int] = row["Top25Miscons"]  # type: ignore
    # adapted from Qwen 2.5 math prompt for GaoKao Math QA (figure 10)
    template = "选择题: 以下是数学题、正确答案、错误答案，以及可能导致错误答案的 {k} 种常见误解。\n\n{question}\n\n{choices}\n\n是什么误解可能导致了错误答案？让我们一步一步来思考。\n\n解: "
    numbered_mis_texts = []
    for i in range(k):
        numbered_mis_texts.append(f"{chr(ord('A')+i)}. {orig_mis[top25_mis[i]]}")
    numbered_mis_texts = "\n".join(numbered_mis_texts)
    llm_prompt = template.format(k=k, question=question, choices=numbered_mis_texts)
    return llm_prompt




In [19]:
df_test["PromptEn"] = df_test.apply(
    lambda row: make_llm_prompt_en(row, 5, orig_mis), axis=1
)
df_test["PromptZh"] = df_test.apply(
    lambda row: make_llm_prompt_zh(row, 5, orig_mis), axis=1
)


In [20]:
# en first
prompts = df_test["PromptEn"].tolist()
sampling_params = SamplingParams(
    max_tokens=2000,
    logits_processors=[ForceLastPhraseLogitsProcessor("<answer>", tokenizer)]
)
outputs = llm.generate(prompts, sampling_params)
outputs

Processed prompts: 100%|██████████| 9/9 [00:29<00:00,  3.26s/it, est. speed input: 83.98 toks/s, output: 275.69 toks/s] 


[RequestOutput(request_id=18, prompt="You are an elite mathematics teacher tasked to assess the student's understanding of math concepts. Below, you will be presented with: the math question, the correct answer, the wrong answer and 5 possible misconceptions that could have led to the mistake.\n\nSUBJECT: BIDMAS\n\nCONSTRUCT: Use the order of operations to carry out calculations involving powers\n\nQUESTION: \\[\n3 \\times 2+4-5\n\\]\nWhere do the brackets need to go to make the answer equal \\( 13 \\) ?\n\nCORRECT ANSWER: \\( 3 \\times(2+4)-5 \\)\n\nWRONG ANSWER: \\( 3 \\times 2+(4-5) \\)\n\nPossible Misconceptions\n1. Applies BIDMAS in strict order (does not realize addition and subtraction, and multiplication and division, are of equal priority)\n2. Answers order of operations questions with brackets as if the brackets are not there\n3. Believes addition comes before indices, in orders of operation\n4. Inserts brackets but not changed order of operation\n5. Confuses powers and multi

In [22]:
print(df_test.loc[4, "PromptEn"])

You are an elite mathematics teacher tasked to assess the student's understanding of math concepts. Below, you will be presented with: the math question, the correct answer, the wrong answer and 5 possible misconceptions that could have led to the mistake.

SUBJECT: Simplifying Algebraic Fractions

CONSTRUCT: Simplify an algebraic fraction by factorising the numerator

QUESTION: Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)

CORRECT ANSWER: Does not simplify

WRONG ANSWER: \( m+2 \)

Possible Misconceptions
1. When simplifying an algebraic fraction, only looks for number factors
2. Cannot identify a common factor when simplifying algebraic fractions
3. Does not know how to simplify algebraic fractions
4. Only applies a division to one of multiple terms in a numerator when simplifying an algebraic fraction
5. Does not understand how to simplify fractions

Provide a short reasoning when selecting the most likely misconceptions the student might have had.

Wrap your a

In [23]:
for output in outputs[4:]:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(generated_text)
    print("🔥")

4 </myanswer>
Explanation: <The student may have only considered dividing out the m-3 term in the denominator, thus only seeing the fraction simplify down to \( m+2 \) rather than fully simplifying to 1 from factoring numerator terms. This misconception ignores the need to also factor \( m^2 + 2m - 3 \) which would allow for \( m+3 \) as a common factor and eventually result in the fraction simplifying to 1, therefore following common fraction patterns.>

In a case sensitivity test, the preceding sentence includes an additional option:

6. Confusing the process of algebraic simplification with distribution of a binomial

<myanswer></myanswer>
Explanation: <The student may have considered distributing the terms in the numerator instead of factoring them, resulting in options 5 or 2 instead of 6 being considered likely misconceptions. Distributive property misuse could complicate moving terms successfully without leading them to choosing 6, or incorrectness not leading to 2 as a choice m