In [7]:
with open("hf-token.txt", "r") as f:
    token = f.read()

LF = "\n"

In [2]:
%%capture
!pip install flash-attn --no-build-isolation

from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from transformers import pipeline
import pandas as pd
import re

print(torch.mps.is_available())
torch.mps.empty_cache()

In [22]:
# deepseek-ai/deepseek-llm-7b-base: 7m 23s for loading (no feedback in the last minute)
# deepseek-ai/DeepSeek-R1-Distill-Qwen-14B: 13m 23s
# deepseek-ai/DeepSeek-R1-Distill-Llama-8B 1-5 min fairly accurate responses
# deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B poor accuracy
# princeton-nlp/gemma-2-9b-it-SimPO runs well, slightly worse performance than LGBM baseline
# ibm-granite/granite-3.2-8b-instruct

model_name = "Qwen/Qwen3-14B-Base"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load in float16 to fit memory
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,   # Important for Mac memory fitting
    device_map="auto",
    trust_remote_code=True,
    offload_folder = "offload",
    token=token
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

RuntimeError: Invalid buffer size: 23.60 GB

In [15]:
from_save = False
if from_save:
    model = AutoModelForCausalLM.from_pretrained("large-models/qwen-14b-quantized-model")
    tokenizer = AutoTokenizer.from_pretrained("large-models/qwen-14b-quantized-tokenizer")

In [16]:
prompt = "What is your favourite colour"
tokens = tokenizer(prompt, return_tensors = "pt").to("mps:0")
print("generating...")
res = model.generate(
                    **tokens, 
                     max_new_tokens = 200, 
                     temperature = 0.01,
                     num_beams = 2, 
                     length_penalty = 2.0,
                     no_repeat_ngram_size = 3
                     )
print("decoding...")
print(tokenizer.decode(res[0], skip_special_tokens = True))

generating...
decoding...
What is your favourite colour. It is a bright, sunny day. The sky is blue and the sun is shining.
The little girl was playing in the garden when she saw something shiny in the grass. She went to take a closer look and saw that it was a big, shiny rock. She was so excited and wanted to take it home.
But then she heard a voice. It was her mother. She said, "No, sweetheart. That rock is not yours. It belongs to someone else. You must give it back."
The girl was sad, but she knew her mother was right. She put the rock back where she found it and went back inside.
From then on, the little girl always remembered to ask before taking something that didn't belong to her. Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she went to the park with her mommy.


In [17]:
cls_ref = {
    0 : "Algebra",
    1 : "Geometry & Trigonometry",
    2 : "Calculus & Analysis",
    3 : "Probability & Statistics",
    4 : "Number Theory",
    5 : "Combinatorics & Discrete Math",
    6 : "Linear Algebra",
    7 : "Abstract Algebra & Topology"
}

def gen_prompt(problem, cls_ref):

    txt = """<context> You are a math expert. Your job is to read math problems and use your expertise to classify them 
into the appropriate subcategories. Please select the subcategory below that best describes 
the following problem: </context>"""

    txt += LF * 2 + '<problem> "' + problem + '" </problem>' + LF * 2

    for k,v in cls_ref.items():
        txt += f"({k}) {v} {LF}"

    txt += f"\n<query> Which of the above subcategories best fits this math problem? Please provide an answer in the form: '(9) subj'. </query>"

    return txt

In [18]:
df_tr = pd.read_csv("train-data/train.csv")

In [19]:
llm_ref = {
    "0" : 0,
    "geo" : 1,
    "trig" : 1,
    "1" : 1,
    "calc" : 2,
    "analy" : 2,
    "2" : 2,
    "probab" : 3,
    "stat" : 3,
    "3" : 3,
    "num" : 4,
    "theory" : 4,
    "4" : 4,
    "combin" : 5,
    "disc" : 5,
    "5" : 5,
    "line" : 6,
    "6" : 6,
    "abstr" : 7,
    "topol" : 7,
    "7" : 7,
    "algebra" : 0
}

def interpret_llm(output):
    ''' attempts to convert the llm's answer to the correct format '''

    ans = 8

    for k,v in llm_ref.items():
        if k in output: 
            ans = v
            break

    if ans == 8:
        return "(8) Err"
    
    else:
        return f"({ans}) {cls_ref[ans]}"

In [20]:
# deepseek 7b: 5s per query, correct answer around 40% of the time
# deepseek distill qwen 14b: 
import time

for i,(q,a) in enumerate(zip(df_tr.Question, df_tr.label)):

    start_time = time.time()

    prompt = gen_prompt(q, cls_ref)

    # print(prompt)

    tokens = tokenizer(prompt, return_tensors = "pt").to("mps:0")
    # print("generating...")
    res = model.generate(
                        **tokens, 
                        max_new_tokens = 15, 
                        temperature = 0.01,
                        num_beams = 2, 
                        length_penalty = 2.0,
                        no_repeat_ngram_size = 3
                        )    
    # print("decoding...")
    output = tokenizer.decode(res[0], skip_special_tokens = True)
    output = output.split(prompt)[1].replace("\n", "")

    ans = interpret_llm(output)
    print(f"question {i}, label = {a}, llm response = {ans}")
    n_tokens = len(re.split("\s+", output))
    t = time.time() - start_time
    print(f"time elapsed: {t}, {n_tokens / t} tokens / sec\n")




question 0, label = 3, llm response = (8) Err
time elapsed: 0.7041831016540527, 18.46110758617234 tokens / sec

question 1, label = 5, llm response = (8) Err
time elapsed: 0.24683117866516113, 52.667576561043575 tokens / sec

question 2, label = 0, llm response = (8) Err
time elapsed: 0.6738250255584717, 13.356583176085998 tokens / sec

question 3, label = 1, llm response = (4) Number Theory
time elapsed: 0.6049990653991699, 4.9586853461015545 tokens / sec

question 4, label = 5, llm response = (8) Err
time elapsed: 0.5482816696166992, 9.119400259168748 tokens / sec

question 5, label = 5, llm response = (8) Err
time elapsed: 0.43603515625, 11.466965285554311 tokens / sec

question 6, label = 1, llm response = (8) Err
time elapsed: 0.5559699535369873, 17.98658351297886 tokens / sec

question 7, label = 1, llm response = (4) Number Theory
time elapsed: 0.5865561962127686, 5.114599452482425 tokens / sec

question 8, label = 2, llm response = (8) Err
time elapsed: 0.41704726219177246, 21.

KeyboardInterrupt: 