In [1]:
from similarity import CosineSimScorer, BERTCLSMeanPooler, BERTCLSFirstPooler, EmbSummarizer

In [2]:
from transformers import AutoModel, AutoTokenizer
models_names = ["tbs17/MathBERT", "allenai/scibert_scivocab_uncased", "math-similarity/Bert-MLM_arXiv-MP-class_zbMath", "allenai/longformer-base-4096"]
models = []
tokenizers = []
for model_name in models_names:
    models.append(AutoModel.from_pretrained(model_name))
    tokenizers.append(AutoTokenizer.from_pretrained(model_name))

sum_classes = {"First CLS": BERTCLSFirstPooler, "Mean CLS": BERTCLSMeanPooler}

In [3]:
import pandas as pd

problemset_df = pd.read_csv("C:\\Users\\mokrota\\Documents\\GitHub\\math_problem_recommender\\math_problem_recommender\\benchmark\\benchmarkv3\\df.csv")
qa_df = pd.read_csv("C:\\Users\\mokrota\\Documents\\GitHub\\math_problem_recommender\\math_problem_recommender\\benchmark\\benchmarkv3\\q&a.csv")

In [4]:
# qa_df['query'] = qa_df['query'].apply(lambda x: x.replace("Find problems", "Text").replace("Find calculative problems", "Text"))
qa_df

Unnamed: 0,Anchor,Golden,Silver,Wrong,Query,query
0,431,439,475,592,,Find problems that use divisibility to limit n...
1,468,467,63,614,,Find problems that explicitly use idea of all ...
2,194,196,537,100,,Find problems that explore pigeonhole principle
3,66,64,451,145,,Find problems that involve calculating answer ...
4,150,152,161,517,,Find problems where we have to solve floor fun...
5,42,39,440,598,,Find problems that practice divisibility toget...
6,228,542,233,373,,Find calculative problems where we need to use...


In [5]:
def parse_text(row):
    names = ["Anchor", "Golden", "Silver", "Wrong"]
    texts = {}

    for name in names:
        id_name = row[name]
        t = problemset_df[problemset_df['id'] == id_name]['Problem&Solution'].iloc[0]
        texts[name] = t

    return texts

In [6]:
qa_df['Problem&Solution'] = qa_df.apply(parse_text, axis=1)

In [7]:
qa_df

Unnamed: 0,Anchor,Golden,Silver,Wrong,Query,query,Problem&Solution
0,431,439,475,592,,Find problems that use divisibility to limit n...,{'Anchor': 'Problem. Find all odd positive int...
1,468,467,63,614,,Find problems that explicitly use idea of all ...,{'Anchor': 'Problem. For any positive integer ...
2,194,196,537,100,,Find problems that explore pigeonhole principle,{'Anchor': 'Problem. Prove that among any inte...
3,66,64,451,145,,Find problems that involve calculating answer ...,{'Anchor': 'Problem. Find the smallest positiv...
4,150,152,161,517,,Find problems where we have to solve floor fun...,{'Anchor': 'Problem. Determine the number of r...
5,42,39,440,598,,Find problems that practice divisibility toget...,{'Anchor': 'Problem. Let $m\geq2$ be an intege...
6,228,542,233,373,,Find calculative problems where we need to use...,{'Anchor': 'Problem. Let $\tau(n)$ denote the ...


In [8]:
# After feeding to deepseek:
problems_summary = [{
    "Anchor": {
        "problem": r"Find all odd positive integers n > 1 such that for any coprime divisors a, b of n, a + b - 1 divides n.",
        "solution": r"Only prime powers p^k satisfy the condition. Composite numbers lead to contradictions via divisibility constraints on p + s - 1.",
        "result": r"All prime powers"
    },
    "Golden": {
        "problem": r"Find the smallest K where every K-element subset of {1,2,...,50} contains a pair (a,b) with a+b|ab.",
        "solution": r"K = 39. Identified 23 critical pairs where a₁ + b₁ divides c (gcd(a,b)). Maximal subset avoiding these pairs has size 11 → 50 - 11 = 39.",
        "result": 39
    },
    "Silver": {
        "problem": r"Find coefficient a_{1996} in the expansion of ∏_{n=1}^{1996}(1 + nx^{3^n}).",
        "solution": r"a_{1996} = 45. Derived from binary representation of 1996 (11111001100₂) mapped to base-3 exponents and summed.",
        "result": 45
    },
    "Wrong": {
        "problem": r"Prove no positive integers x,y satisfy x⁵ + y⁵ + 1 = (x+2)⁵ + (y−3)⁵.",
        "solution": r"Contradiction via modulo 10: z⁵ ≡ z (mod 10) implies 1 ≡ -1 (mod 10), impossible.",
        "result": r"No solutions exist"
    }
},
{
    "Anchor": {
        "problem": r"Prove for any positive integer set {a₁,…,aₙ}, there exists b such that {ba₁,…,baₙ} are all perfect powers.",
        "solution": r"Constructs b via prime factorization and Chinese Remainder Theorem: b = p₁^l₁⋯pₖ^lₖ, where exponents lⱼ ≡ -α_ij (mod qᵢ) for distinct primes qᵢ. Ensures each baᵢ becomes a perfect qᵢ-th power.",
        "result": r"Such a positive integer b always exists."
    },
    "Golden": {
        "problem": r"Let P(x) ∈ ℤ[x]. If integers a₁,…,aₙ satisfy that for any integer x, some aᵢ divides P(x), prove one aᵢ₀ divides P(x) for all x.",
        "solution": r"Contradiction via Chinese Remainder Theorem: Construct x where P(x) ≡ P(xⱼ) ≢ 0 (mod pⱼ^kⱼ) for primes pⱼ^kⱼ dividing aⱼ. Shows no aᵢ divides P(x) for this x, violating the problem’s condition.",
        "result": r"At least one aᵢ₀ must universally divide P(x)."
    },
    "Silver": {
        "problem": r"Find nonnegative integers m such that (2²ᵐ⁺¹)² + 1 has ≤2 distinct prime divisors.",
        "solution": r"Factorizes into coprime terms (2²ᵐ⁺¹ ±2ᵐ⁺¹ +1). For m≥3, modular contradictions (mod 8 and mod 5) show neither term can be a prime power. Validates m=0,1,2 by direct computation.",
        "result": r"m = 0, 1, 2"
    },
    "Wrong": {
        "problem": r"Find pairs (b,c) where the sequence a₁=b, a₂=c, aₙ₊₂=|3aₙ₊₁−2aₙ| has finitely many composite terms.",
        "solution": r"Sequence stabilizes to a non-composite or cycles. Uses recurrence analysis and Fermat’s Little Theorem to show valid pairs: (p,p), (2p,p) (p non-composite), and (7,4).",
        "result": r"Solutions: (p,p), (2p,p) for p=1 or prime; (7,4)."
    }
},
{
    "Anchor": {
        "problem": r"Prove that among any integers \(a_1, a_2, \ldots, a_n\), some subset has a sum divisible by \(n\).",
        "solution": r"Uses partial sums \(s_i = a_1 + \cdots + a_i\). If no \(s_i \equiv 0 \pmod{n}\), pigeonhole principle implies \(s_j \equiv s_k \pmod{n}\) for some \(j < k\). Then \(a_{j+1} + \cdots + a_k\) is divisible by \(n\).",
        "result": r"Such a subset always exists."
    },
    "Golden": {
        "problem": r"From any 117 distinct three-digit numbers, prove there exist 4 pairwise disjoint subsets with equal sums.",
        "solution": r"Analyzes two-element subsets (total \( \binom{117}{2} = 6786 \)). Possible sums range from 201 to 1997 (1797 distinct). By pigeonhole principle (\(1797 \cdot 3 + 1 < 6786 \)), 4 subsets share the same sum and are disjoint.",
        "result": r"4 such subsets exist."
    },
    "Silver": {
        "problem": r"Find primes \(p\) where \(p^n = x^3 + y^3\) for positive integers \(x, y, n\).",
        "solution": r"Only \(p = 2\) and \(p = 3\) work. For \(p \geq 5\), minimal \(n\) leads to contradiction: \(x, y\) must be divisible by \(p\), reducing \(p^{n-3} = \left(\frac{x}{p}\right)^3 + \left(\frac{y}{p}\right)^3\), violating minimality.",
        "result": r"Primes \(p = 2, 3\)."
    },
    "Wrong": {
        "problem": r"Find integers \(n\) such that \(n - 50\) and \(n + 50\) are perfect squares.",
        "solution": r"Let \(n - 50 = a^2\) and \(n + 50 = b^2\). Then \(b^2 - a^2 = 100\). Factorizes as \((b - a)(b + a) = 100\). Solving \(b - a = 2\), \(b + a = 50\) gives \(n = 626\).",
        "result": r"Unique solution \(n = 626\)."
    }
},
{
    "Anchor": {
        "problem": r"Find the smallest positive integer expressible as (i) a sum of 2002 integers with equal digit sums and (ii) a sum of 2003 integers with equal digit sums.",
        "solution": r"Answer: 10010. Verify via modular arithmetic (mod 9). For digit sums \(k_1\) (2002 terms) and \(k_2\) (2003 terms), \(10010 \equiv 2002k_1 \equiv 2003k_2 \pmod{9}\). Minimality: If \(k_1 \geq 5\) or \(k_2 \geq 5\), the sum exceeds 10010. Only \(k_1 = 5\) and \(k_2 = 4/13\) work.",
        "result": r"10010"
    },
    "Golden": {
        "problem": r"Find \(100 \leq n \leq 1997\) such that \(n\) divides \(2^n + 2\).",
        "solution": r"\(n = 946\). Use divisibility: \(n = 2 \cdot 11 \cdot 43\). Check \(2^n + 2 \equiv 0 \pmod{11}\) (\(n \equiv 6 \pmod{10}\)) and \(\pmod{43}\) (\(n \equiv 8 \pmod{14}\)). \(946\) satisfies both.",
        "result": r"946"
    },
    "Silver": {
        "problem": r"Count pairs \((x, y)\) with \(\gcd(x, y) = 5!\) and \(\operatorname{lcm}(x, y) = 50!\), where \(x \leq y\).",
        "solution": r"For each prime \(p\) in \(50!\), exponents in \(x, y\) swap between \(5!\) and \(50!\). With 15 primes, total pairs: \(2^{15}\). Halve for \(x \leq y\): \(2^{14}\).",
        "result": r"\(2^{14} = 16384\) pairs"
    },
    "Wrong": {
        "problem": r"Solve \(\lfloor x \lfloor x \rfloor \rfloor = 1\) in \(\mathbb{R}\).",
        "solution": r"Case analysis: \(x = -1\) or \(x \in [1, 2)\). For \(x \geq 1\), \(\lfloor x \rfloor = 1\) implies \(x \in [1, 2)\). Negative cases ruled out.",
        "result": r"Solutions: \(x \in \{-1\} \cup [1, 2)\)"
    }
},
{
    "Anchor": {
        "problem": r"Determine the number of real solutions \(a\) to \(\left\lfloor \frac{a}{2} \right\rfloor + \left\lfloor \frac{a}{3} \right\rfloor + \left\lfloor \frac{a}{5} \right\rfloor = a\).",
        "solution": r"Express \(a = 30p + q\) for \(0 \leq q < 30\). Show \(p = q - \left\lfloor \frac{q}{2} \right\rfloor - \left\lfloor \frac{q}{3} \right\rfloor - \left\lfloor \frac{q}{5} \right\rfloor\). Each \(q\) (0 to 29) gives a unique \(a\).",
        "result": r"30 solutions"
    },
    "Golden": {
        "problem": r"Prove Hermite's identity: \(\lfloor nx \rfloor = \lfloor x \rfloor + \left\lfloor x + \frac{1}{n} \right\rfloor + \cdots + \left\lfloor x + \frac{n-1}{n} \right\rfloor\) for real \(x\).",
        "solution": r"Define \(f(x)\) as the difference between both sides. Show \(f\) is periodic with period \(1/n\) and vanishes on \([0, 1/n)\). Hence \(f(x) = 0\) universally.",
        "result": r"Identity holds for all real \(x\)"
    },
    "Silver": {
        "problem": r"Compute \(S_n = \sum_{k=1}^{\frac{n(n+1)}{2}} \left\lfloor \frac{-1 + \sqrt{1 + 8k}}{2} \right\rfloor\).",
        "solution": r"Use bijection \(f(x) = \frac{x(x+1)}{2}\) and inverse \(f^{-1}(k)\). Apply summation formula and simplify to derive closed form.",
        "result": r"\(S_n = \frac{n(n^2 + 2)}{3}\)"
    },
    "Wrong": {
        "problem": r"Find the last 5 digits of \(5^{1981}\).",
        "solution": r"Use modular arithmetic: \(5^{1981} \equiv 5^5 \pmod{10^5}\). Expand via binomial theorem and reduce modulo \(100000\).",
        "result": r"Last 5 digits: 03125"
    }
},
{
    "Anchor": {
        "problem": r"Show that any \(m\)-good number (where \(n \mid a^m - 1\) for all \(a\) coprime to \(n\)) is at most \(4m(2^m - 1)\).",
        "solution": r"For odd \(m\), \(n \leq 2\). For even \(m = 2^t q\), analyze \(n = 2^u(2v+1)\). Use modular arithmetic on \(a = 8v + 5\) to show \(u \leq t + 2\), leading to \(n \leq 4m(2^m - 1)\).",
        "result": r"Maximum \(m\)-good number: \(4m(2^m - 1)\)"
    },
    "Golden": {
        "problem": r"Find coprime \(a, b\) such that \(a + b\) divides \(a^n + b^n\) for even \(n\).",
        "solution": r"Use \(a + b \mid 2a^n\) and \(a + b \mid 2b^n\). Since \(a, b\) coprime, \(a + b\) divides 2. Thus \(a = b = 1\).",
        "result": r"\(a = 1\), \(b = 1\)"
    },
    "Silver": {
        "problem": r"Prove \(d_9 - d_8 \neq 22\) for \(n = p_1p_2p_3p_4 < 1995\) (distinct primes).",
        "solution": r"If \(d_9 - d_8 = 22\), then \(d_8 < 35\). Divisors \(d_1, \ldots, d_8\) are products of ≤2 primes. Smallest such products (15, 21, etc.) force \(d_8 \geq 35\), a contradiction.",
        "result": r"No such \(n < 1995\) exists"
    },
    "Wrong": {
        "problem": r"Solve \((x+1)^{y+1} + 1 = (x+2)^{z+1}\) in positive integers.",
        "solution": r"Substitute \(a = x+1\), \(b = y+1\), \(c = z+1\). Use modular arithmetic and binomial expansion to show \(a = 2\), \(b = 3\), \(c = 2\).",
        "result": r"Unique solution: \((x, y, z) = (1, 2, 1)\)"
    }
},
{
    "Anchor": {
        "problem": r"Prove the sequence \( \tau(n^2 + 1) \) (number of divisors of \( n^2 + 1 \)) does not become monotonic from any point onwards.",
        "solution": r"Assume monotonicity for \( n \geq N \). For even \( n \), \( \tau(n^2 + 1) \leq n \). If monotonic, \( \tau((n+1)^2 + 1) \geq \tau(n^2 + 1) + 2 \), leading to \( \tau(n^2 + 1) \geq \tau(N^2 + 1) + 2(n - N) \), which exceeds \( n \) for large \( n \), a contradiction.",
        "result": r"The sequence \( \tau(n^2 + 1) \) is not eventually monotonic."
    },
    "Golden": {
        "problem": r"Does there exist a positive integer whose proper divisors' product ends with exactly 2001 zeros?",
        "solution": r"Yes. Construct \( n = 2^1 \cdot 5^1 \cdot 7^6 \cdot 11^{10} \cdot 13^{12} \). The product of proper divisors is \( n^{2001} \), with \( \frac{1}{2}\tau(n) - 1 = 2001 \). This ensures \( 2^{2001} \cdot 5^{2001} \) in the product, yielding 2001 trailing zeros.",
        "result": r"Yes, such a number exists."
    },
    "Silver": {
        "problem": r"Prove \( \sigma(n) \geq n + \sqrt{n} + 1 \) for composite \( n \), where \( \sigma(n) \) is the sum of divisors.",
        "solution": r"For composite \( n \), there exists a divisor \( d \leq \sqrt{n} \). Then \( \frac{n}{d} \geq \sqrt{n} \). Thus, \( \sigma(n) \geq 1 + n + \frac{n}{d} \geq n + \sqrt{n} + 1 \).",
        "result": r"Inequality holds for all composite \( n \)."
    },
    "Wrong": {
        "problem": r"Solve \( a_{n+1} = 2a_n + \sqrt{3a_n^2 - 2} \) with \( a_0 = 1 \).",
        "solution": r"Derive linear recurrence \( a_{n+1} = 4a_n - a_{n-1} \). Solve via characteristic equation \( t^2 - 4t + 1 = 0 \), yielding closed-form \( a_n = \frac{1}{\sqrt{3}} \left[ \left(\frac{1+\sqrt{3}}{2}\right)^{2n+1} - \left(\frac{1-\sqrt{3}}{2}\right)^{2n+1} \right] \).",
        "result": r"\( a_n \) are integers for all \( n \)."
    }
}
]

In [9]:
queries = qa_df['query'].tolist()

anchors = []
texts = []
for d, q in zip(problems_summary, queries):
    anchors.append(q)
    group = [d['Anchor']['solution'], d['Golden']['solution'], d['Silver']['solution'], d['Wrong']['solution']]
    texts.append(group)

In [10]:
def embed(summarizer: EmbSummarizer):
    anchors_emb = summarizer.summarize(anchors)
    texts_emb = []
    for ts in texts:
        texts_emb.append(summarizer.summarize(ts))
    return anchors_emb, texts_emb

In [11]:
import numpy as np
from scipy.stats import spearmanr
def calc_metrics(pred_ranks, true_ranks):
    pair_check = np.array(true_ranks) == np.array(pred_ranks)
    accuracy = pair_check.mean()
    s = []
    for true_rank, pred_rank in zip(true_ranks, pred_ranks):
        rho, _ = spearmanr(true_rank, pred_rank)
        s.append(rho)
    spearmans = np.mean(s)
    return {"accuracy": accuracy,
            "spearman": spearmans}

In [None]:
def evaluate(summarizer):
    anchors_emb, texts_emb = embed(summarizer)
    ranker = CosineSimScorer()
    pred_ranks = []
    for anchor, ts in zip(anchors_emb, texts_emb):
        r = ranker.rank(anchor, ts).cpu().numpy()
        r[r == 0] = 1
        pred_ranks.append(ranker.rank(anchor, ts).cpu().numpy())
    pred_ranks = np.stack(pred_ranks)
    true_ranks = np.array([1, 1, 2, 3], dtype=np.int64)
    true_ranks = np.tile(true_ranks, (len(anchors), 1))
    metrics = calc_metrics(pred_ranks, true_ranks)
    return metrics

## Testing Different Pooling methods

In [13]:
results = []
for model, tokenizer, name in zip(models, tokenizers, models_names):
    for s_name in sum_classes:
        s_c = sum_classes[s_name]
        summarizer = s_c(model, tokenizer)
        res = evaluate(summarizer)
        res['name'] = name
        res['pooling method'] = s_name
        results.append(res)

[1 1 3 2]
[1 2 3 1]
[2 3 1 1]
[3 1 2 1]
[3 2 1 1]
[2 1 1 3]
[1 2 1 3]
[1 1 3 2]
[1 2 3 1]
[2 3 1 1]
[3 1 2 1]
[3 2 1 1]
[2 1 1 3]
[1 2 1 3]
[2 3 1 1]
[1 1 3 2]
[2 1 1 3]
[1 2 1 3]
[1 3 1 2]
[1 2 3 1]
[3 1 2 1]
[2 3 1 1]
[1 1 3 2]
[2 1 1 3]
[1 2 1 3]
[1 3 1 2]
[1 2 3 1]
[3 1 2 1]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 3 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 3 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]


Input ids are automatically padded to be a multiple of `config.attention_window`: 512


[1 2 1 3]
[3 1 2 1]
[3 1 1 2]
[1 3 2 1]
[3 1 2 1]
[3 1 1 2]
[1 1 2 3]
[1 2 1 3]
[3 1 2 1]
[3 1 1 2]
[1 3 2 1]
[3 1 2 1]
[3 1 1 2]
[1 1 2 3]


In [14]:
pd.DataFrame(results)

Unnamed: 0,accuracy,spearman,name,pooling method
0,0.178571,0.015058,tbs17/MathBERT,First CLS
1,0.178571,0.015058,tbs17/MathBERT,Mean CLS
2,0.214286,0.060234,allenai/scibert_scivocab_uncased,First CLS
3,0.214286,0.060234,allenai/scibert_scivocab_uncased,Mean CLS
4,0.321429,0.240935,math-similarity/Bert-MLM_arXiv-MP-class_zbMath,First CLS
5,0.321429,0.240935,math-similarity/Bert-MLM_arXiv-MP-class_zbMath,Mean CLS
6,0.357143,0.045175,allenai/longformer-base-4096,First CLS
7,0.357143,0.045175,allenai/longformer-base-4096,Mean CLS


## Testing different size of chunks for BertMLM

In [15]:
bertmlm_results = []
model = AutoModel.from_pretrained("math-similarity/Bert-MLM_arXiv-MP-class_zbMath")
tokenizer = AutoTokenizer.from_pretrained("math-similarity/Bert-MLM_arXiv-MP-class_zbMath")
chunk_sizes = [32, 64, 128, 256, 512]

for chunk_size in chunk_sizes:
    kwargs = {"max_length": chunk_size}
    for s_name in sum_classes:
        s_c = sum_classes[s_name]
        summarizer = s_c(model, tokenizer, **kwargs)
        res = evaluate(summarizer)
        res['name'] = name
        res['pooling method'] = s_name
        res['size'] = chunk_size
        bertmlm_results.append(res)

[2 3 1 1]
[1 1 3 2]
[1 1 3 2]
[3 1 2 1]
[2 1 3 1]
[3 1 1 2]
[1 2 1 3]
[3 2 1 1]
[1 2 3 1]
[1 1 2 3]
[3 1 2 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[1 1 3 2]
[1 3 2 1]
[3 1 1 2]
[1 2 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 1 3]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[1 1 3 2]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[3 1 2 1]
[1 3 2 1]
[3 1 1 2]
[1 1 2 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 3 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 3 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 3 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]
[3 2 1 1]
[1 1 2 3]
[1 1 2 3]
[2 1 3 1]
[1 3 2 1]
[3 1 1 2]
[2 1 1 3]


In [16]:
pd.DataFrame(bertmlm_results)

Unnamed: 0,accuracy,spearman,name,pooling method,size
0,0.214286,0.135526,allenai/longformer-base-4096,First CLS,32
1,0.25,0.090351,allenai/longformer-base-4096,Mean CLS,32
2,0.357143,0.316228,allenai/longformer-base-4096,First CLS,64
3,0.357143,0.301169,allenai/longformer-base-4096,Mean CLS,64
4,0.357143,0.316228,allenai/longformer-base-4096,First CLS,128
5,0.428571,0.240935,allenai/longformer-base-4096,Mean CLS,128
6,0.321429,0.240935,allenai/longformer-base-4096,First CLS,256
7,0.321429,0.240935,allenai/longformer-base-4096,Mean CLS,256
8,0.321429,0.240935,allenai/longformer-base-4096,First CLS,512
9,0.321429,0.240935,allenai/longformer-base-4096,Mean CLS,512


## Try Problem step by step summary