In [None]:
import aimo

# Create the competition environment and the iterator that, as unpacked in the
# loop below, yields (test, sample_submission) pairs.
env = aimo.make_env()
iter_test = env.iter_test()

# Debug switch: when True, the loop below runs through everything `iter_test`
# yields in this cell, recording each problem and submitting a placeholder
# answer for it.
# NOTE(review): presumably `iter_test` cannot be iterated a second time, so
# later cells would have to work from DEBUG_ITER instead — confirm.
DEBUG = True
# Filled with [problem_text, 0] entries. The meaning of the trailing 0 is not
# visible in this cell — NOTE(review): confirm against the cells that read it.
DEBUG_ITER = []
if DEBUG:
    for test, sample_submission in iter_test:
        # Store the problem statement as a plain string.
        DEBUG_ITER.append([str(test['problem']), 0])
        # Placeholder text, not a computed answer.
        sample_submission['answer'] = "Random Answer"
        # Hand the filled-in submission back to the environment.
        # NOTE(review): presumably required before the iterator advances — confirm.
        env.predict(sample_submission)
    

In [None]:
import time
import os
import pandas as pd
import torch
import re
import random
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoModelForSequenceClassification
import gc
import traceback

# --- Run configuration --------------------------------------------------------

# Timestamp (time.time()) taken when this cell executes.
NOTEBOOK_START_TIME = time.time()

# True when the KAGGLE_IS_COMPETITION_RERUN environment variable is set to a
# non-empty value, False otherwise (unset or empty string).
PRIVATE = bool(os.getenv('KAGGLE_IS_COMPETITION_RERUN'))

# Path of the model checkpoint to load.
MODEL_NAME = '/kaggle/input/deepseek-math-7b-rl/transformers/7b-rl/1'
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42
N_REPETITIONS = 1
MAX_NEW_TOKENS = 1000
# NOTE(review): both branches currently hold the same value, so PRIVATE has no
# effect here. The original note below says to "convert to 1 when submitting" —
# presumably the non-PRIVATE value; confirm before relying on it.
# Units are presumably seconds (compared against time.time() deltas) — confirm.
TIME_LIMIT = 31500 if PRIVATE else 31500  #convert to 1 when submitting

# Seed every random number generator this cell imports so that a
# "Restart & Run All" reproduces the same results. `random` is imported above
# but was previously left unseeded.
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def assign_subcategory(question1, label_type1, subcategories1):
    category = label_type1
    problem_text = question1.lower()
    
    if category in subcategories1:
        subs = subcategories1[category]
        if category == 'prealgebra':
            if any(word in problem_text for word in [
                'angle', 'triangle', 'circle', 'rectangle', 'square', 'perimeter', 'area', 'volume', 'surface area', 
                'parallel', 'perpendicular', 'right angle', 'acute angle', 'obtuse angle', 'equilateral', 'isosceles', 'scalene',
                'radius', 'diameter', 'circumference', 'pi', '\\pi', 'polygon', 'vertex', 'vertices', 'side', 'edge', 'face', 
                'base', 'height', 'altitude', 'chord', 'arc', 'sector', 'segment', 'tangent', 'secant', 'central angle', 
                'inscribed angle', 'congruent', 'similar', 'reflection', 'rotation', 'translation', 'symmetry', 'coordinate plane',
                'quadrant', 'midpoint', 'distance formula', 'Pythagorean theorem', 'a^2 + b^2 = c^2', 'right triangle', 
                'geometry', 'geometric figure', '\\angle', '\\triangle', '\\circle', '\\rectangle', '\\square', '\\overline{AB}', '\\parallel', '\\perp', 
                '\\cong', '\\sim', '\\cap', '\\cup', '\\Delta', '\\angle', '\\bigtriangleup', '\\bigtriangledown', '\\circ', '\\square', '\\rect', 
                '\\diamond', '\\lozenge', '\\hexagon', '\\pentagon', '\\heptagon', '\\octagon', '\\nonagon', '\\decagon', '\\dodecagon',
                'angle', 'triangle', 'circle', 'rectangle', 'square', 'perimeter', 'area', 'volume', 'surface area', 'parallel', 'perpendicular', 
                'right angle', 'acute angle', 'obtuse angle', 'equilateral', 'isosceles', 'scalene', 'radius', 'diameter', 'circumference', 'pi', 
                'polygon', 'vertex', 'vertices', 'side', 'edge', 'face', 'base', 'height', 'altitude', 'chord', 'arc', 'sector', 'segment', 
                'tangent', 'secant', 'central angle', 'inscribed angle', 'congruent', 'similar', 'reflection', 'rotation', 'translation', 
                'symmetry', 'coordinate plane', 'quadrant', 'midpoint', 'distance formula', 'Pythagorean theorem', 'a^2 + b^2 = c^2', 
                'right triangle', 'geometry', 'geometric figure', '\angle', '\triangle', '\circle', '\rectangle', '\square', '\overline{AB}', '\parallel', '\perp', 
                '\cong', '\sim', '\cap', '\cup', '\Delta', '\angle', '\bigtriangleup', '\bigtriangledown', '\circ', '\square', '\rect', 
                '\diamond', '\lozenge', '\hexagon', '\pentagon', '\heptagon', '\octagon', '\nonagon', '\\decagon', '\dodecagon',
            ]):
                return 'Basic Geometry'
            if any(word in problem_text for word in [
                'mean', 'median', 'mode', 'range', 'average', 'data', 'set', 'list', 'frequency', 'table', 'chart', 'graph',
                'histogram', 'bar graph', 'line graph', 'pie chart', 'stem-and-leaf plot', 'box plot', 'box-and-whisker plot',
                'scatter plot', 'plot', 'tally', 'survey', 'sample', 'population', 'percentage', 'percentile', 'probability', 
                'outlier', 'distribution', 'standard deviation', 'variance', 'interquartile range', 'quartile', 'summary', 
                'five-number summary', 'central tendency', 'spread', '\\bar{x}', '\\sigma', '\\mu',
                '\\sum', '\\prod', '\\binom', '\\mean', '\\median', '\\mode', '\\range', '\sigma', '\mu',
                '\sum', '\prod', '\binom', '\mean', '\median', '\mode', '\range',
            ]):
                return 'Basic Statistics'
            
        if category == 'algebra':
            if any(word in problem_text for word in ['quadratic', 'x^2',"ax^2 + bx + c", "x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}", "y = ax^2 + bx + c", "f(x) = ax^2 + bx + c",
                                                    "y = a(x - h)^2 + k", "y = a(x - r_1)(x - r_2)", 'parabola', 
                                                    'vertex', 'roots', 'qaudratic formula', 'discriminant', 'axis of symmetry', 'vertex form']):
                return 'Quadratic Equations'
            if any(word in problem_text for word in [
                'polynomial', 'polynomial equation', 'polynomial function', 'degree', 'roots', 'zeros', 'coefficients',
                'factor theorem', 'remainder theorem', 'synthetic division', 'polynomial long division', 'x^3', 'x^4', 'x^5', 'x^n',
                'a_n x^n + a_{n-1} x^{n-1} + \dots + a_1 x + a_0', '\\sum_{i=0}^n a_i x^i','degree','cubic', 'quartic', 'quintic', 
                'higher degree polynomial', 'constant term', 'leading coefficient', 'monomial', 'binomial', 'trinomial'
            ]):
                return 'Polynomial Equations'
            if any(word in problem_text for word in [
                'inequality', 'inequalities', 'solve for', 'greater than', 'less than', 'at least', 'at most',
                'no more than', 'no less than', 'range', 'interval', 'solution set', 'system of inequalities', 
                'compound inequality', 'linear inequality', 'quadratic inequality', 'absolute value inequality',
                'strict inequality', 'non-strict inequality', 'graph of inequality', 'boundary line', 'shaded region',
                'feasible region', 'intersection of inequalities', 'union of inequalities', 'x>', 'x<', 'x>=', 'x<=',
                '\\geq', '\\leq', '\\gt', '\\lt', 'y>', 'y<', 'y>=', 'y<=', 'inequality symbol', 'double inequality',
                'inequality notation', 'greater than or equal to', 'less than or equal to', '\geq', '\leq', '\gt', '\lt',
            ]):
                return 'Inequalities'
            if any(word in problem_text for word in [
                'system of equations', 'systems of equations', 'simultaneous equations', 'solve for', 'solution set', 'consistent',
                'inconsistent', 'dependent', 'independent', 'substitution method', 'elimination method', 'matrix method', 
                'graphical method', 'intersection point', 'linear system', 'nonlinear system', 'homogeneous system', 'augmented matrix', 
                'row reduction', 'Gaussian elimination', 'Gauss-Jordan elimination', 'Cramer\'s rule', 'determinant', 
                'coefficient matrix', 'variable matrix', 'constant matrix', 'ordered pair', 'solution pair', 'solution vector', 
                '3x3 system', '2x2 system', 'unique solution', 'infinite solutions', 'no solution', '\begin{cases}', '\end{cases}'
            ]):
                return 'Systems of Equations'
            if any(word in problem_text for word in [
                'exponential function', 'exponential equation', 'exponential growth', 'exponential decay', 'base', 'exponent', 
                'rate of growth', 'rate of decay', 'compounded', 'continuously', 'half-life', 'doubling time', 'e^', 'a e^{bx}', 
                'a \\cdot b^x', 'a e^{kt}', 'exponential model', 'exponential form', 'logarithmic form', 'inverse of exponential', 
                'logarithm', 'natural logarithm', 'ln', 'log', 'y = a e^{bx}', 'y = a b^x', 'P = P_0 e^{rt}', 'A = P (1 + r/n)^{nt}', 
                'A = P e^{rt}', 'x = log_b y', 'y = e^x', 'y = e^{-x}', 'y = 2^x', 'y = 10^x', 'y = 2^{-x}', 'y = 10^{-x}', 
                '\\exp(x)', '\\log(x)', '\\ln(x)', '\\log_{b}(x)', '\exp(x)', '\log(x)', '\ln(x)', '\log_{b}(x)'
            ]):
                return 'Exponential and Logarithmic Functions'
            if any(word in problem_text for word in [
                'function', 'relation', 'domain', 'range', 'mapping', 'input', 'output', 'ordered pair', 'vertical line test',
                'inverse function', 'composition of functions', 'composite function', 'piecewise function', 'continuous', 
                'discontinuous', 'interval', 'function notation', 'f(x)', 'g(x)', 'h(x)', 'y = f(x)', 'y = g(x)', 'y = h(x)',
                'one-to-one', 'onto', 'bijective', 'injective', 'surjective', 'identity function', 'constant function', 
                'linear function', 'quadratic function', 'polynomial function', 'rational function', 'exponential function', 
                'logarithmic function', 'absolute value function', 'step function', 'greatest integer function', 'modulus function', 
                'even function', 'odd function', 'symmetric', 'asymptote', 'horizontal asymptote', 'vertical asymptote', 
                'transformations', 'translations', 'reflections', 'stretch', 'compression'
            ]):
                return 'Functions and Relations'
            
        if category == 'number theory':
            if any(word in problem_text for word in [
                'modular arithmetic', 'modulus', 'mod', 'congruence', 'congruent', 'residue', 'modulo', 'remainder', 'divisibility',
                'Chinese Remainder Theorem', 'CRT', 'mod n', 'mod k', 'mod m', 'a ≡ b (mod n)', 'a ≡ b \\pmod{n}', 'a mod n', 
                'a mod m', 'a mod k', 'b mod n', 'b mod m', 'b mod k', '\\equiv', '\\pmod', 'a \\bmod n', 'b \\bmod m', '\\bmod{k}',
                '\equiv', '\pmod', 'a \bmod n', 'b \bmod m', '\bmod{k}'
            ]):
                return 'Modular Arithmetic'
            if any(word in problem_text for word in [
                'greatest common divisor', 'gcd', 'greatest common factor', 'gcf', 'highest common factor', 'hcf', 'least common multiple',
                'lcm', 'least common divisor', 'euclidean algorithm', 'division algorithm', 'prime factorization', 'common factor',
                'common multiple', 'divisor', 'multiple', 'factor', 'a \\mid b', 'a \\nmid b','gcd(a, b)', 'gcd(a, b, c)', '\\gcd', 'lcm(a, b)', 
                'lcm(a, b, c)', '\\lcm'
            ]):
                return 'GCD and LCM'
            if any(word in problem_text for word in [
                'diophantine equation', 'diophantine equations', 'linear diophantine equation', 'integer solutions', 'integer solution', 
                'integer root', 'integer roots', 'whole number solutions', 'whole number solution', 'whole number root', 'whole number roots', 
                'solve in integers', 'solve in whole numbers', 'beazout\'s identity', 'beazout\'s theorem', 'extended euclidean algorithm', 
                'positive integer solutions', 'negative integer solutions', 'Pell\'s equation', 'ax + by = c', 'ax + by = d', 'x^2 - Dy^2 = 1',
                'x^2 + y^2 = z^2', 'x + y = z', 'x - y = z', 'congruence modulo', 'congruence relation', 'modular equation'
            ]):
                return 'Diophantine Equations'
            if any(word in problem_text for word in [
                'sequence', 'sequences', 'series', 'arithmetic sequence', 'geometric sequence', 'arithmetic series', 'geometric series',
                'finite sequence', 'infinite sequence', 'finite series', 'infinite series', 'fibonacci sequence', 'harmonic series',
                'convergence', 'divergence', 'common difference', 'common ratio', 'term', 'nth term', 'general term', 'sum', 'partial sum',
                'recurrence relation', 'recursive formula', 'explicit formula', 'sigma notation', 'summation', 'summation notation',
                'a_n', 'S_n', 'n-th term', 'first term', 'last term', '\\sum_{i=1}^n', '\\sum_{n=1}^\\infty', 'n-th partial sum', 'limit'
            ]):
                return 'Sequences and Series'
            if any(word in problem_text for word in [
                'perfect number', 'perfect numbers', 'sum of divisors', 'aliquot sum', 'proper divisors', 'divisors', 'deficient number',
                'abundant number', 'even perfect number', 'odd perfect number', 'Euclid\'s theorem', 'Mersenne prime', 'Mersenne primes',
                '2^{p-1}(2^p - 1)', '2^p - 1', '6', '28', '496', '8128', '33550336'
            ]):
                return 'Perfect Numbers'
            if any(word in problem_text for word in [
                'fermat number', 'fermat numbers', 'fermat prime', 'fermat primes', '2^{2^n} + 1', 'F_n', 'Fermat\'s little theorem',
                'prime factor', 'factorization', 'composite', 'Fermat composite', 'modular arithmetic', 'pseudoprime', 'probable prime'
            ]):
                return 'Fermat Numbers'
        
        if category == 'counting & probability':
            if any(word in problem_text for word in [
                'combination', 'combinations', 'permutation', 'permutations', 'binomial coefficient', 'n choose k', 'nCr', 'nPr', 
                'factorial', 'arrangement', 'ordered', 'unordered', 'distinct', 'different', 'selection', 'subset', 'order matters',
                'order does not matter', 'pigeonhole principle', 'counting principle', 'combinatorial', 'combinatorics', '\\binom{n}{k}','\\dbinom{n}{k}',
                '\\frac{n!}{k!(n-k)!}', '\\frac{n!}{(n-k)!}', 'C(n, k)', 'P(n, k)', 'comb(n, k)', 'perm(n, k)',
                '\binom{n}{k}','\dbinom{n}{k}', '\frac{n!}{k!(n-k)!}', '\frac{n!}{(n-k)!}'
            ]):
                return 'Combinations and Permutations'
            if any(word in problem_text for word in [
                'probability', 'probabilities', 'likely', 'unlikely', 'chance', 'odds', 'event', 'outcome', 'sample space', 'experiment', 
                'trial', 'random', 'independent events', 'dependent events', 'mutually exclusive', 'complementary events', 
                'conditional probability', 'P(A|B)', 'P(A and B)', 'P(A or B)', 'Bayes\' theorem', 'Bayes\' rule', 'with replacement', 
                'without replacement', 'expected value', 'mean value', 'variance', 'standard deviation', 'distribution', 'normal distribution', 
                'binomial distribution', 'Bernoulli distribution', 'geometric distribution', 'hypergeometric distribution', 'Poisson distribution', 
                'combinatorial probability', 'tree diagram', 'Venn diagram', 'union', 'intersection', 'complement', 'pdf', 'pmf', 'poisson'
            ]):
                return 'Probability'
            if any(word in problem_text for word in [
                'card', 'cards', 'deck', 'standard deck','\\hearts', '\\diamonds', '\\clubs', '\\spades', 'playing card', 'playing cards', 'ace', 'king', 'queen', 'jack', 'hearts', 'diamonds',
                'clubs', 'spades', 'suit', 'suits', 'rank', 'ranks', 'face card', 'face cards', 'black card', 'black cards', 'red card', 
                'red cards', 'drawing cards', 'drawing a card', 'draw a card', 'draw cards', 'probability of drawing', 'number of cards', 
                'combination of cards', 'permutation of cards', 'poker', 'bridge', 'solitaire', 'card game', '52 cards', '52-card deck'
            ]):
                return 'Card Problems'
            if any(word in problem_text for word in [
                'dice', 'die', 'roll', 'rolled', 'rolling', 'probability', 'face', 'faces', 'side', 'sides', 'number cube', 
                'six-sided die', 'fair die', 'unfair die', 'loaded die', 'sum of faces', 'sum of numbers', 'product of faces', 
                'product of numbers', 'number on face', 'numbers on faces'
            ]):
                return 'Dice Problems'
            
        if category == 'precalculus':
            if any(word in problem_text for word in [
                'trigonometric function', 'trigonometric identity', 'trigonometric equation', 'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
                'sine', 'cosine', 'tangent', 'cotangent', 'secant', 'cosecant', 'right triangle', 'hypotenuse', 'opposite', 'adjacent',
                'Pythagorean identity', 'double angle formula', 'half angle formula', 'sum and difference formulas', 'sum of angles',
                'difference of angles', 'law of sines', 'law of cosines', 'unit circle', 'radian', 'degree', 'angle', 'period', 'amplitude',
                'frequency', 'phase shift', 'vertical shift', 'inverse trigonometric functions', 'arc', 'arcsin', 'arccos', 'arctan',
                'arccot', 'arcsec', 'arccsc', 'graph of sine', 'graph of cosine', 'graph of tangent', 'graph of cotangent', 'graph of secant',
                'graph of cosecant', '\\sin', '\\cos', '\\tan', '\\cot', '\\sec', '\\csc', '\sin', '\cos', '\tan', '\cot', '\sec', '\csc'
            ]):
                return 'Trigonometric Functions and Identities'
            if any(word in problem_text for word in [
                'conic section', 'conic sections', 'parabola', 'parabolas', 'ellipse', 'ellipses', 'hyperbola', 'hyperbolas', 
                'circle', 'circles', 'focus', 'foci', 'directrix', 'directrices', 'vertex', 'vertices', 'axis of symmetry', 
                'major axis', 'minor axis', 'transverse axis', 'conjugate axis', 'asymptote', 'asymptotes', 'eccentricity', 
                'latus rectum', 'standard form', 'general form', 'equation of a circle', 'equation of a parabola', 
                'equation of an ellipse', 'equation of a hyperbola', 'x^2 + y^2 = r^2', '(x-h)^2 + (y-k)^2 = r^2', 'y = ax^2 + bx + c', 
                'y = a(x-h)^2 + k', '\\frac{(x-h)^2}{a^2} + \\frac{(y-k)^2}{b^2} = 1', '\\frac{(x-h)^2}{a^2} - \\frac{(y-k)^2}{b^2} = 1', 
                'r = e / (1 + e cos \\theta)', 'r = e / (1 + e sin \\theta)', 'r = e / (1 - e cos \\theta)', 'r = e / (1 - e sin \\theta)'
            ]):
                return 'Conic Sections'
            if any(word in problem_text for word in [
                'sequence', 'sequences', 'series', 'arithmetic sequence', 'geometric sequence', 'arithmetic series', 'geometric series',
                'finite sequence', 'infinite sequence', 'finite series', 'infinite series', 'fibonacci sequence', 'harmonic series',
                'convergence', 'divergence', 'common difference', 'common ratio', 'term', 'nth term', 'general term', 'sum', 'partial sum',
                'recurrence relation', 'recursive formula', 'explicit formula', 'sigma notation', 'summation', 'summation notation',
                'a_n', 'S_n', 'n-th term', 'first term', 'last term', '\\sum_{i=1}^n', '\\sum_{n=1}^\\infty', 'n-th partial sum'
            ]):
                return 'Sequences and Series'
            if any(word in problem_text for word in [
                'vector', 'vectors', 'magnitude', 'direction', 'dot product', 'cross product', 'scalar', 'vector addition', 
                'vector subtraction', 'unit vector', 'vector projection', 'vector components', 'component form', 
                'magnitude of a vector', 'direction of a vector', 'position vector', 'free vector', 'zero vector', 'orthogonal vectors', 
                'parallel vectors', 'resultant vector', 'linear combination', 'basis vectors', 'standard unit vectors', 'vector notation', 
                'geometric interpretation of vectors', 'algebraic interpretation of vectors', 'vector equation', 'parametric equation', 
                'rectangular form', 'polar form', 'vector field', 'gradient', 'divergence', 'curl', '\\vec{v}', '\\vec{a}', '\\vec{b}', 
                '\\vec{i}', '\\vec{j}', '\\vec{k}', '\\overrightarrow{AB}', 'vector', 'vectors', 'magnitude', 'direction', 'dot product', 'cross product', 'scalar', 
                'vector addition', 'vector subtraction', 'unit vector', 'vector projection', 'vector components', 'i', 'j', 'k', 'component form', 
                'magnitude of a vector', 'direction of a vector', 'position vector', 'free vector', 'zero vector', 'orthogonal vectors', 
                'parallel vectors', 'resultant vector', 'linear combination', 'basis vectors', 'standard unit vectors', 'vector notation', 
                'geometric interpretation of vectors', 'algebraic interpretation of vectors', 'vector equation', 'parametric equation', 
                'rectangular form', 'polar form', 'vector field', 'gradient', 'divergence', 'curl'
            ]):
                return 'Vectors'
            if any(word in problem_text for word in [
                'complex number', 'complex numbers', 'imaginary unit', 'imaginary part', 'real part', 'real and imaginary parts',
                '$i$', '$j$', 'a + bi', 'a + bj', 'complex plane', 'Argand diagram', 'magnitude', 'absolute value', 'modulus', 
                'argument', 'conjugate', 'complex conjugate', 'polar form', 'rectangular form', 'Cartesian form', 'exponential form', 
                'Euler\'s formula', 'De Moivre\'s theorem', 'roots of complex numbers', 'powers of complex numbers', 
                'addition of complex numbers', 'subtraction of complex numbers', 'multiplication of complex numbers', 
                'division of complex numbers', '$z$', '\\overline{z}', '|z|'
            ]):
                return 'Complex Numbers'
            if any(word in problem_text for word in [
                'exponential function', 'logarithmic function', 'exponential equation', 'logarithmic equation', 'exponential growth',
                'exponential decay', 'base', 'exponent', 'logarithm', 'logarithms', 'natural logarithm', 'common logarithm', 
                'ln', 'log', 'e', 'log base 10', 'log base e', 'log base 2', 'inverse functions', 'inverse function', 'inverse of exponential',
                'inverse of logarithm', 'change of base formula', 'logarithmic scale', 'logarithmic properties', 'logarithmic identity',
                'power rule', 'product rule', 'quotient rule', 'continuous growth', 'continuous decay', 'half-life', 'doubling time',
                'compound interest', 'continuous compounding', 'decibel scale', 'Richter scale', 'y = e^x', 'y = e^{-x}', 'y = a e^{bx}', 
                'y = a e^{-bx}', 'y = b^x', 'y = b^{-x}', 'y = a b^x', 'y = a b^{-x}', 'y = \log_b{x}', 'y = \ln{x}', 'y = \log_{10}{x}',
                '\exp(x)', '\log(x)', '\ln(x)', '\log_{b}(x)', 'y = \\log_b{x}', 'y = \\ln{x}', 'y = \\log_{10}{x}',
                '\\exp(x)', '\\log(x)', '\\ln(x)', '\\log_{b}(x)'
            ]):
                return 'Exponential and Logarithmic Functions'
            if any(word in problem_text for word in [
                'limit', 'limits', 'lim', 'approaches', 'approaching', 'as x approaches', 'as x tends to', 'as x goes to', 
                'one-sided limit', 'left-hand limit', 'right-hand limit', 'two-sided limit', 'finite limit', 'infinite limit', 
                'limit at infinity', 'limit at a point', 'existence of limit', 'continuity', 'discontinuity', 'epsilon-delta definition', 
                'L\'Hopital\'s rule', 'asymptote', 'asymptotic behavior', 'convergence', 'divergence', 'indeterminate form', 
                'bounded', 'unbounded', 'squeeze theorem', 'sandwich theorem', 'end behavior', 'vertical asymptote', 'horizontal asymptote', 
                'graphical interpretation of limits', 'algebraic evaluation of limits', '\\lim_{x \\to a}', '\\lim_{x \\to \\infty}', 
                '\\lim_{x \\to -\\infty}'
            ]):
                return 'Limits (Introductory)'
            if any(word in problem_text for word in [
                'parametric equation', 'parametric equations', 'parametric form', 'parametric representation', 'parameter', 
                'parameters', 'parametric curve', 'parametric curves', 'x(t)', 'y(t)', 'x = f(t)', 'y = g(t)', 'parameter t', 
                't-parameter', 'eliminating the parameter', 'eliminate the parameter', 'convert to Cartesian form', 
                'convert to rectangular form', 'parameterization', 'parameterize', 'parameterizing', 'graph of parametric equations', 
                'path of a particle', 'motion along a curve', 'trajectory', 'velocity', 'acceleration', 'tangent to a curve', 
                'arc length', 'arc length of a parametric curve', 'polar coordinates',
            ]):
                return 'Parametric Equations'
        
        if category == 'calculus':
            if any(word in problem_text for word in [
                'limit', 'limits', 'lim', 'approaches', 'approaching', 'as x approaches', 'as x tends to', 'as x goes to', 
                'one-sided limit', 'left-hand limit', 'right-hand limit', 'two-sided limit', 'finite limit', 'infinite limit', 
                'limit at infinity', 'limit at a point', 'existence of limit', 'continuity', 'discontinuity', 'epsilon-delta definition', 
                'L\'Hopital\'s rule', 'asymptote', 'asymptotic behavior', 'convergence', 'divergence', 'indeterminate form', 
                'bounded', 'unbounded', 'squeeze theorem', 'sandwich theorem', 'end behavior', 'vertical asymptote', 'horizontal asymptote', 
                'graphical interpretation of limits', 'algebraic evaluation of limits', '\\lim_{', '\lim_{x \\to \\infty}', 
                '\\lim_{x \\to -\\infty}'
            ]):
                return 'Limits'
            if any(word in problem_text for word in [
                'derivative', 'differentiation', 'derivatives', 'differentiate', 'differentiating', 'differential', 'rate of change', 
                'instantaneous rate of change', 'slope of the tangent', 'tangent line', 'normal line', 'dy/dx', 'd/dx', 'd^2y/dx^2', 
                'd^n y/dx^n', 'first derivative', 'second derivative', 'nth derivative', 'higher-order derivatives', 'partial derivative', 
                'partial differentiation', '\frac{dy}{dx}', '\frac{d^2y}{dx^2}', '\frac{d^n y}{dx^n}', '\frac{\\partial y}{\\partial x}', 
                'partial derivatives', 'gradient', 'gradient vector', 'directional derivative', 'chain rule', 'product rule', 'quotient rule', 
                'implicit differentiation', 'logarithmic differentiation', 'mean value theorem', 'extreme value theorem', 
                'critical point', 'inflection point', 'concavity', 'convexity', 'increasing function', 'decreasing function', 
                'local maximum', 'local minimum', 'global maximum', 'global minimum', 'optimization', 'related rates', 'differentials', 
                'linear approximation', 'tangent approximation', '\\frac{dy}{dx}', '\\frac{d^2y}{dx^2}', '\\frac{d^n y}{dx^n}', '\\frac{\\partial y}{\\partial x}'
            ]):
                return 'Derivatives'
            if any(word in problem_text for word in [
                'integral', 'integrals', 'integration', 'antiderivative', 'antiderivatives', 'definite integral', 
                'indefinite integral', 'Riemann sum', 'fundamental theorem of calculus', 'FTC', 'integration by parts', 
                'u-substitution', 'trigonometric substitution', 'partial fraction decomposition', 'improper integral', 
                'area under the curve', 'volume of revolution', 'arc length', 'surface area of revolution', 'integration techniques', 
                'numerical integration', 'Simpson\'s rule', 'trapezoidal rule', '\int', '\int_{a}^{b}', '\int_{0}^{\\infty}', '\int_{0}^{\infty}',
                '\int_{-\\infty}^{\\infty}', '\int_{a}^{\\infty}', '\int_{-\\infty}^{b}', '\int u dv', '\int_{a}^{b} f(x) dx', 
                '\int f(x) dx', '\iint', '\iiint', '\oint', '\int_{-\infty}^{\infty}', '\int_{a}^{\infty}', '\int_{-\infty}^{b}'
            ]):
                return 'Integrals'
            if any(word in problem_text for word in [
                'differential equation', 'differential equations', 'ordinary differential equation', 'ODE', 'partial differential equation', 
                'PDE', 'first-order differential equation', 'second-order differential equation', 'higher-order differential equation', 
                'linear differential equation', 'nonlinear differential equation', 'homogeneous differential equation', 
                'nonhomogeneous differential equation', 'solution to the differential equation', 'general solution', 'particular solution', 
                'initial value problem', 'boundary value problem', 'separable differential equation', 'exact differential equation', 
                'integrating factor', 'characteristic equation', 'Laplace transform', 'inverse Laplace transform', 'method of undetermined coefficients', 
                'method of variation of parameters', 'eigenvalue', 'eigenvector', 'phase plane', 'stability', 'bifurcation', 'direction field', 
                'isocline', 'Green\'s function', '\frac{dy}{dx}', '\frac{d^2y}{dx^2}', '\frac{d^n y}{dx^n}', '\\frac{\\partial y}{\\partial x}', 
                '\frac{\\partial^2 y}{\\partial x^2}', '\\frac{\\partial^n y}{\\partial x^n}', '\frac{\partial y}{\partial x}', 
                '\frac{\partial^2 y}{\partial x^2}', '\frac{\partial^n y}{\partial x^n}'
            ]):
                return 'Differential Equations'
            if any(word in problem_text for word in [
                'series', 'sequences', 'sequence', 'arithmetic sequence', 'geometric sequence', 'arithmetic series', 'geometric series',
                'finite series', 'infinite series', 'finite sequence', 'infinite sequence', 'convergence', 'divergence', 'limit of a sequence',
                'sum of a series', 'nth term', 'general term', 'partial sum', 'infinite sum', 'power series', 'Taylor series', 'Maclaurin series',
                'Fourier series', 'radius of convergence', 'interval of convergence', 'absolute convergence', 'conditional convergence',
                'ratio test', 'root test', 'integral test', 'comparison test', 'alternating series test', 'divergence test', 'p-series', 
                'harmonic series', 'telescoping series', 'recursive sequence', 'recurrence relation', 'closed form', 'summation notation',
                '\\sum', '\\sum_{n=1}^{\\infty}', '\\sum_{n=0}^{\\infty}', 'a_n', 'S_n'
            ]):
                return 'Sequences and Series'
            if any(word in problem_text for word in [
                'multivariable calculus', 'partial derivatives', 'multiple integrals', 'partial derivative', 'gradient', 'divergence', 'curl',
                'vector calculus', 'vector field', 'scalar field', 'level curves', 'level surfaces', 'tangent plane', 'normal line',
                'Lagrange multipliers', 'directional derivative', 'Jacobian', 'Hessian', 'double integral', 'triple integral',
                'line integral', 'surface integral', 'volume integral', 'flux', 'Green\'s theorem', 'Stokes\' theorem', 'Gauss\'s theorem',
                'divergence theorem', 'coordinate transformation', 'polar coordinates', 'cylindrical coordinates', 'spherical coordinates',
                'chain rule for multiple variables', 'change of variables', 'Jacobian determinant', 'iterated integral', 'path independence',
                'conservative field', 'parametric surface', 'parametric curve', '\\nabla'
            ]):
                return 'Multivariable Calculus'
            if any(word in problem_text for word in [
                'optimization problem', 'optimization', 'maximize', 'minimize', 'maximum', 'minimum', 'extrema', 'extremum',
                'local maximum', 'local minimum', 'global maximum', 'global minimum', 'critical point', 'stationary point',
                'Lagrange multipliers', 'constrained optimization', 'unconstrained optimization', 'objective function', 
                'cost function', 'profit function', 'constraint', 'constraints', 'feasible region', 'boundary point', 'end point',
                'first derivative test', 'second derivative test', 'concavity', 'convexity', 'inflection point', 'monotonicity',
                'interval of increase', 'interval of decrease', 'optimization technique', 'gradient ascent', 'gradient descent',
                'optimization algorithm', 'Newton\'s method', 'Hessian matrix', 'Jacobain matrix', 'optimal solution', 'optimal value'
            ]):
                return 'Optimization Problems'
            if any(word in problem_text for word in [
                'area under the curve', 'area between curves', 'integrate', 'integration', 'definite integral', 'indefinite integral', 
                'Riemann sum', 'fundamental theorem of calculus', 'FTC', 'bounded region', 'upper limit', 'lower limit', 'limits of integration', 
                'dx', 'dy', 'dA', 'double integral', 'triple integral', 'surface area', 'volume', 'disk method', 'washer method', 
                'shell method', 'numerical integration', 'Simpson\'s rule', 'trapezoidal rule', 'integration technique', 'integration by parts', 
                'u-substitution', 'trigonometric substitution', 'partial fraction decomposition', 'improper integral', 'polar coordinates', 
                'arc length', 'arc length of a curve', 'parametric equations', '\\int', '\\int_{a}^{b}', '\\int_{0}^{\\infty}', 
                '\\int_{-\\infty}^{\\infty}', '\\int_{a}^{\\infty}', '\\int_{-\\infty}^{b}', '\\iint', '\\iiint'
            ]):
                return 'Area Under Curves'
            if any(word in problem_text for word in [
                'vector calculus', 'vector field', 'scalar field', 'gradient', 'divergence', 'curl', 'Laplacian', 'del operator',
                'nabla', 'line integral', 'surface integral', 'volume integral', 'flux', 'Green\'s theorem', 'Stokes\' theorem', 
                'Gauss\'s theorem', 'divergence theorem', 'parametric equations', 'parametric surface', 'parametric curve', 
                'path integral', 'circulation', 'conservative field', 'irrotational field', 'solenoidal field', 'vector potential',
                'scalar potential', 'potential function', 'work done by a force field', 'field line', 'vector notation', 'vector operations', 
                '\nabla \\cdot \\vec{', '\vec{', '\nabla \\cdot \vec{', '\vec{', '\nabla \cdot \vec{', '\nabla \cdot \\vec{', 
                '\\nabla \\times \\vec{F}', '\\int \\vec{F} \\cdot d\\vec{r}', '\\iint_S \\vec{F} \\cdot d\\vec{S}', '\\iiint_V \\vec{F} \\cdot dV', 
                '\\oint_C \\vec{F} \\cdot d\\vec{r}'
            ]):
                return 'Vector Calculus'
        
        if category == 'intermediate algebra':
            if any(word in problem_text for word in [
                'complex number', 'complex numbers', 'imaginary unit', 'imaginary part', 'real part', 'real and imaginary parts',
                '$i$', '$j$', 'a + bi', 'a + bj', 'complex plane', 'Argand diagram', 'magnitude', 'absolute value', 'modulus', 
                'argument', 'conjugate', 'complex conjugate', 'polar form', 'rectangular form', 'Cartesian form', 'exponential form', 
                'Euler\'s formula', 'De Moivre\'s theorem', 'roots of complex numbers', 'powers of complex numbers', 
                'addition of complex numbers', 'subtraction of complex numbers', 'multiplication of complex numbers', 
                'division of complex numbers', '$z$', '\overline{z}', '|z|', 'arg', '\Re', '\Im'
            ]):
                return 'Complex Numbers'
            if any(word in problem_text for word in ['quadratic', 'x^2',"ax^2 + bx + c", "x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}", "y = ax^2 + bx + c", "f(x) = ax^2 + bx + c",
                                                    "y = a(x - h)^2 + k", "y = a(x - r_1)(x - r_2)", 'parabola', 
                                                    'vertex', 'roots', 'qaudratic formula', 'discriminant', 'axis of symmetry', 'vertex form']):
                return 'Quadratic Functions'
            if any(word in problem_text for word in [
                'exponential function', 'exponential equation', 'exponential growth', 'exponential decay', 'base', 'exponent', 
                'rate of growth', 'rate of decay', 'compounded', 'continuously', 'half-life', 'doubling time', 'e^', 'a e^{bx}', 
                'a \\cdot b^x', 'a e^{kt}', 'exponential model', 'exponential form', 'logarithmic form', 'inverse of exponential', 
                'logarithm', 'natural logarithm', 'ln', 'log', 'y = a e^{bx}', 'y = a b^x', 'P = P_0 e^{rt}', 'A = P (1 + r/n)^{nt}', 
                'A = P e^{rt}', 'x = log_b y', 'y = e^x', 'y = e^{-x}', 'y = 2^x', 'y = 10^x', 'y = 2^{-x}', 'y = 10^{-x}', 
                '\exp(x)', '\log(x)', '\ln(x)', '\log_{b}(x)', '\exp', '\log', '\ln', '\\exp(x)', '\\log(x)', '\\ln(x)', '\\log_{b}(x)', '\\exp', '\\log', '\\ln'
            ]):
                return 'Exponential and Logarithmic Functions'
            if any(word in problem_text for word in [
                'inequality', 'inequalities', 'absolute value', 'absolute values', 'greater than', 'less than', 'greater than or equal to', 
                'less than or equal to', 'strict inequality', 'non-strict inequality', 'compound inequality', 'linear inequality', 
                'quadratic inequality', 'rational inequality', 'absolute value inequality', 'solution set', 'interval notation', 
                'number line', 'graphing inequalities', 'solving inequalities', 'systems of inequalities', 'bounded region', 'unbounded region', 
                'critical point', 'test point', 'boundary line', 'shaded region', 'intersection of solutions', 'union of solutions', 
                'piecewise function', 'distance from zero', 'vertex form of absolute value function', '|x|', '|ax + b|'
            ]):
                return 'Inequalities and Absolute Values'
            if any(word in problem_text for word in [
                'system of non-linear equations', 'non-linear system', 'simultaneous non-linear equations', 'solve non-linear system', 
                'solution of non-linear equations', 'intersection of curves', 'non-linear substitution', 'graphical solution', 
                'Newton\'s method', 'fixed point iteration', 'Jacobian matrix', 'non-linear solver', 'algebraic system', 
                'non-linear algebraic system', 'iterative method', 'numerical method', 'quadratic system', 'cubic system', 
                'polynomial system', 'rational system', 'exponential system', 'logarithmic system', 'trigonometric system', 
                'transcendental system', 'convergence criteria', 'divergence criteria', 'initial guess', 'non-linear equations'
            ]):
                return 'Systems of Non-linear Equations'
            if any(word in problem_text for word in [
                'matrix', 'matrices', 'determinant', 'determinants', 'inverse matrix', 'identity matrix', 'transpose', 'matrix addition', 
                'matrix subtraction', 'matrix multiplication', 'scalar multiplication', 'elementary row operations', 'row echelon form', 
                'reduced row echelon form', 'Gaussian elimination', 'Gauss-Jordan elimination', 'system of linear equations', 'Cramer\'s rule', 
                'adjugate', 'adjoint', 'cofactor', 'minor', 'characteristic polynomial', 'eigenvalue', 'eigenvector', 'diagonalization', 
                'trace', 'rank', 'singular', 'nonsingular', 'orthogonal matrix', 'symmetric matrix', 'skew-symmetric matrix', 
                'Hermitian matrix', 'unitary matrix', 'positive definite matrix', 'negative definite matrix', 'matrix equation', 
                'det', '\\det'
            ]):
                return 'Matrices and Determinants'
            if any(word in problem_text for word in [
                'series', 'sequences', 'sequence', 'arithmetic sequence', 'geometric sequence', 'arithmetic series', 'geometric series',
                'finite series', 'infinite series', 'finite sequence', 'infinite sequence', 'convergence', 'divergence', 'limit of a sequence',
                'sum of a series', 'nth term', 'general term', 'partial sum', 'infinite sum', 'power series', 'Taylor series', 'Maclaurin series',
                'Fourier series', 'radius of convergence', 'interval of convergence', 'absolute convergence', 'conditional convergence',
                'ratio test', 'root test', 'integral test', 'comparison test', 'alternating series test', 'divergence test', 'p-series', 
                'harmonic series', 'telescoping series', 'recursive sequence', 'recurrence relation', 'closed form', 'summation notation',
                '\sum_{','\sum{}', '^{\\infty}', '^{\infty}', '\sum_{n=1}^{\\infty}', '\sum_{n=0}^{\\infty}', 'a_n', 'S_n'
            ]):
                return 'Sequences and Series'
            if any(word in problem_text for word in [
                'polynomial theorem', 'polynomial theorems', 'factor theorem', 'remainder theorem', 'Rolle\'s theorem', 
                'Descartes\' rule of signs', 'fundamental theorem of algebra', 'roots of a polynomial', 'zeros of a polynomial', 
                'multiplicity of a root', 'Vieta\'s formulas', 'synthetic division', 'polynomial division', 'long division', 
                'factoring polynomials', 'irreducible polynomial', 'complex roots', 'real roots', 'quadratic polynomial', 
                'cubic polynomial', 'quartic polynomial', 'polynomial function', 'rational root theorem', 'root finding', 
                'polynomial equation', 'degree of a polynomial', 'leading coefficient', 'constant term', 'coefficient', 
                'binomial theorem', 'Pascal\'s triangle'
            ]):
                return 'Polynomial Theorems'
            if any(word in problem_text for word in [
                'conic section', 'conic sections', 'parabola', 'parabolas', 'ellipse', 'ellipses', 'hyperbola', 'hyperbolas', 
                'circle', 'circles', 'focus', 'foci', 'directrix', 'directrices', 'vertex', 'vertices', 'axis of symmetry', 
                'major axis', 'minor axis', 'transverse axis', 'conjugate axis', 'asymptote', 'asymptotes', 'eccentricity', 
                'latus rectum', 'standard form', 'general form', 'equation of a circle', 'equation of a parabola', 
                'equation of an ellipse', 'equation of a hyperbola', 'x^2 + y^2 = r^2', '(x-h)^2 + (y-k)^2 = r^2', 'y = ax^2 + bx + c', 
                'y = a(x-h)^2 + k', '\frac{(x-h)^2}{a^2} + \frac{(y-k)^2}{b^2} = 1','\frac{(x-h)^2}{a^2} + \\frac{(y-k)^2}{b^2} = 1', '\frac{(x-h)^2}{a^2} - \frac{(y-k)^2}{b^2} = 1', 
                '\frac{(x-h)^2}{a^2} - \\frac{(y-k)^2}{b^2} = 1','r = e / (1 + e cos \\theta)', 'r = e / (1 + e sin \\theta)', 'r = e / (1 - e cos \\theta)', 'r = e / (1 - e sin \\theta)', 
            ]):
                return 'Conic Sections'
            if any(word in problem_text for word in [
                'binomial theorem', 'binomial expansion', 'binomial coefficients', 'Pascal\'s triangle', 'combinations', 
                'binomial expression', 'binomial series', 'expansion of (a + b)^n', 'expansion of (x + y)^n', 'general term', 
                'binomial term', 'binomial probability', 'multinomial theorem', '\binom{n}{k}', '\binom{n}{r}', '\binom{n}{m}', '\binom', '\dbinom', 
            ]):
                return 'Binomial Theorem'
        
        if category == 'geometry':
            if any(word in problem_text for word in [
                'coordinate geometry', 'analytic geometry', 'Cartesian coordinates', 'coordinate plane', 'x-axis', 'y-axis', 
                'origin', 'quadrant', 'slope', 'intercept', 'x-intercept', 'y-intercept', 'distance formula', 'midpoint formula', 
                'equation of a line', 'slope-intercept form', 'point-slope form', 'standard form', 'general form', 'distance between points', 'midpoint of a segment', 
                'collinear', 'concurrent', 'parallel lines', 'perpendicular lines', 'perpendicular distance', 'inclination of a line', 
                'angle between lines', 'intersection of lines', 'linear equation', 'system of linear equations', 'conic sections', 
                'locus', 'transformation of coordinates', 'rotation of axes', 'reflection over a line', 'translation of axes', 
                'scaling of coordinates'
            ]):
                return 'Coordinate Geometry'
            if any(word in problem_text for word in [
                'transformation', 'transformations', 'geometric transformation', 'translation', 'rotation', 'reflection', 'dilation', 
                'enlargement', 'reduction', 'scaling', 'shear', 'stretch', 'isometry', 'rigid motion', 'non-rigid transformation', 
                'symmetry', 'line of symmetry', 'center of rotation', 'angle of rotation', 'center of dilation', 'scale factor', 
                'vector translation', 'image', 'preimage', 'mapping', 'composition of transformations', 'composite transformation', 
                'coordinate transformation', 'matrix transformation', 'affine transformation', 'homothety', 'congruence transformation', 
                'similarity transformation', 'identity transformation', 'inverse transformation', 'reflection symmetry', 
                'rotational symmetry', 'translational symmetry'
            ]):
                return 'Transformations'
            if any(word in problem_text for word in [
                'area', 'perimeter', 'circumference', 'polygon', 'triangle', 'quadrilateral', 'rectangle', 'square', 'parallelogram', 
                'rhombus', 'trapezoid', 'trapezium', 'kite', 'circle', 'semicircle', 'sector', 'segment', 'ellipse', 'regular polygon', 
                'irregular polygon', 'side length', 'base', 'height', 'altitude', 'radius', 'diameter', 'apothem', 'sides', 'vertices', 
                'pi', 'π', '\\pi', 'base area', 'lateral area', 'total surface area', 'cross-sectional area', 'slant height', 'radius', 
                'diameter', 'height', 'altitude', 'side length', 'edge length', 'perpendicular height', 'circumference', 'area of base', 
                'area of lateral face', 'area of cross section', 'perimeter of a polygon', 'perimeter of a rectangle', 
                'perimeter of a square', 'perimeter of a triangle', 'area of a rectangle', 'area of a square', 'area of a triangle', 
                'area of a parallelogram', 'area of a rhombus', 'area of a trapezoid', 'area of a trapezium', 'area of a kite', 
                'area of a circle', 'area of a sector', 'area of a segment', 'area of an ellipse', 'heron\'s formula', 'area of regular polygon', 
            ]):
                return 'Area and Perimeter'
            if any(word in problem_text for word in [
                'volume', 'surface area', 'volumes', 'surface areas', '3D shapes', 'three-dimensional shapes', 'solid geometry', 
                'cylinder', 'cylinders', 'sphere', 'spheres', 'cone', 'cones', 'pyramid', 'pyramids', 'prism', 'prisms', 'cuboid', 
                'cuboids', 'cube', 'cubes', 'rectangular prism', 'rectangular prisms', 'triangular prism', 'triangular prisms', 
                'hemisphere', 'hemispheres', 'ellipsoid', 'ellipsoids', 'tetrahedron', 'tetrahedrons', 'polyhedron', 'polyhedrons', 
                'base area', 'lateral area', 'total surface area', 'cross-sectional area', 'slant height', 'radius', 'diameter', 
                'height', 'altitude', 'side length', 'edge length', 'perpendicular height', 'circumference', 'area of base', 
                'area of lateral face', 'area of total surface', 'area of cross section', 'volume of cylinder', 'volume of sphere', 
                'volume of cone', 'volume of pyramid', 'volume of prism', 'volume of cuboid', 'volume of cube', 'surface area of cylinder', 
                'surface area of sphere', 'surface area of cone', 'surface area of pyramid', 'surface area of prism', 'surface area of cuboid', 
                'surface area of cube'
            ]):
                return 'Volume and Surface Area'
            if any(word in problem_text for word in [
                'similar', 'similarity', 'congruent', 'congruence', 'AA criterion', 'SSS criterion', 'SAS criterion', 
                'ASA criterion', 'AAS criterion', 'corresponding angles', 'corresponding sides', 'proportional', 'ratio of sides', 
                'scale factor', 'geometric transformation', 'dilation', 'reflection', 'rotation', 'translation', 'enlargement', 
                'reduction', 'congruent triangles', 'similar triangles', 'similar polygons', 'congruent polygons', 
                'similar figures', 'congruent figures', 'similar shapes', 'congruent shapes', 'angle-angle similarity', 
                'side-side-side similarity', 'side-angle-side similarity', 'angle-side-angle congruence', 
                'side-angle-side congruence', 'angle-side-side congruence', 'AA', 'SSS', 'SAS', 'ASA', 'AAS'
            ]):
                return 'Similarity and Congruence'
            if any(word in problem_text for word in [
                'Pythagorean theorem', 'Pythagoras\' theorem', 'Pythagorean triple', 'Pythagorean triples', 'right triangle', 
                'right-angled triangle', 'hypotenuse', 'legs of a triangle', 'a^2 + b^2 = c^2', 'distance formula'
            ]):
                return 'Pythagorean Theorem'
            if any(word in problem_text for word in [
                'geometric proof', 'geometric proofs', 'proof by contradiction', 'proof by induction', 'proof by construction', 
                'proof by exhaustion', 'direct proof', 'indirect proof', 'two-column proof', 'paragraph proof', 'flowchart proof', 
                'coordinate proof', 'axiom', 'postulate', 'theorem', 'corollary', 'lemma', 'conjecture'
            ]):
                return 'Geometric Proofs'
    
    return category

In [None]:
CoT_examples = {
    'prealgebra': [
        {
            'problem': 'What is $7x-3 = 5x+9$.',
            'instructions': 'Arrangle all the terms in x on one side and all the constant terms on the other side. Divide both sides by the coefficients of x',
            'answer': '\\boxed{6}'
            
        },
        {
            'problem': 'Solve for x in $3x+2 \gt 8$.',
            'instructions': 'Arrangle all the terms in x on one side and all the constant terms on the other side. Divide both sides by the coefficients of x',
            'answer': '\\boxed{2}'
        }
    ],
    'algebra': [
        {
            'problem': 'Solve the cubic equation: x^3 - 6x^2 + 11x - 6 = 0. Provide the smallest root.',
            'instructions': 'Use the factor theorem and synthetic division to solve the cubic equation.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Find the smallest root of the polynomial equation: 2x^4 - 3x^3 + x^2 - 6x + 4 = 0.',
            'instructions': 'Use the rational root theorem and factorization to solve the polynomial equation.',
            'answer': '\\boxed{1}'
        }
    ],
    'precalculus': [
        {
            'problem': 'Find the derivative of f(x) = x^2.',
            'instructions': 'Use the power rule to find the derivative.',
            'answer': '\\boxed{2x}'
        },
        {
            'problem': 'Calculate the derivative of f(x) = e^x.',
            'instructions': 'Use the derivative rule for exponential functions.',
            'answer': '\\boxed{e^x}'
        }
    ],
    'geometry': [
        {
            'problem': 'Find the slope of the line passing through the points (1, 2) and (3, 4).',
            'instructions': 'Use the slope formula to find the slope of the line.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Calculate the distance between the points (1, 2) and (4, 6).',
            'instructions': 'Use the distance formula to find the distance between the points.',
            'answer': '\\boxed{5}'
        }
    ],
    'number theory': [
        {
            'problem': 'Find the remainder when 25 is divided by 4.',
            'instructions': 'Use the definition of modular arithmetic to find the remainder.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Solve for x: 2x \\equiv 3 \\pmod{5}.',
            'instructions': 'Use the properties of congruences to solve for x.',
            'answer': '\\boxed{4}'
        }
    ],
    'counting & probability': [
        {
            'problem': 'How many ways can you arrange 5 books on a shelf?',
            'instructions': 'Use the formula for permutations to find the number of arrangements.',
            'answer': '\\boxed{120}'
        },
        {
            'problem': 'How many ways can you choose 3 out of 5 students?',
            'instructions': 'Use the formula for combinations to find the number of ways.',
            'answer': '\\boxed{10}'
        }
    ],
    'intermediate algebra': [
        {
            'problem': 'Solve the quadratic equation: x^2 - 5x + 6 = 0.',
            'instructions': 'Factor the quadratic equation and find the roots. Provide the smaller root.',
            'answer': '\\boxed{2}'
        },
        {
            'problem': 'Find the vertex of the parabola given by the equation y = 2x^2 - 4x + 1. Provide the x-coordinate.',
            'instructions': 'Use the vertex formula to find the x-coordinate of the vertex.',
            'answer': '\\boxed{1}'
        }
    ],
    'calculus': [
        {
            'problem': 'Solve the differential equation: \\frac{dy}{dx} = x. Provide the value of y when x = 1 and y(0) = 0.',
            'instructions': 'Use separation of variables to solve the differential equation and find the value of y at x = 1.',
            'answer': '\\boxed{0.5}'
        },
        {
            'problem': 'Solve the differential equation: \\frac{dy}{dx} = 2. Provide the value of y when x = 2 and y(0) = 0.',
            'instructions': 'Use separation of variables to solve the differential equation and find the value of y at x = 2.',
            'answer': '\\boxed{4}'
        }
    ],
    'Basic Geometry': [
        {
            'problem': 'What is the area of a rectangle with a length of 8 units and a width of 5 units?',
            'instructions': 'Find the area of the rectangle using the formula for the area of a rectangle.',
            'answer': '\\boxed{40}'
        },
        {
            'problem': 'A circle has a radius of 7 units. Calculate its circumference.',
            'instructions': 'Use the formula for the circumference of a circle to find the answer.',
            'answer': '\\boxed{44}'
        }
    ],
    'Basic Statistics': [
        {
            'problem': 'Find the mean of the following set of numbers: 4, 8, 15, 16, 23, 42.',
            'instructions': 'Calculate the mean by adding all the numbers and dividing by the count of the numbers.',
            'answer': '\\boxed{18}'
        },
        {
            'problem': 'The following data set represents the scores of students in a test: 12, 15, 20, 20, 23, 25. Find the median.',
            'instructions': 'Arrange the data in ascending order and find the middle value(s) to calculate the median.',
            'answer': '\\boxed{20}'
        }
    ],
    'Quadratic Equations': [
        {
            'problem': 'Solve the quadratic equation: x^2 - 5x + 6 = 0.',
            'instructions': 'Factor the quadratic equation and find the roots. Provide the smaller root.',
            'answer': '\\boxed{2}'
        },
        {
            'problem': 'Find the vertex of the parabola given by the equation y = 2x^2 - 4x + 1. Provide the x-coordinate.',
            'instructions': 'Use the vertex formula to find the x-coordinate of the vertex.',
            'answer': '\\boxed{1}'
        }
    ],
    'Polynomial Equations': [
        {
            'problem': 'Solve the cubic equation: x^3 - 6x^2 + 11x - 6 = 0. Provide the smallest root.',
            'instructions': 'Use the factor theorem and synthetic division to solve the cubic equation.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Find the smallest root of the polynomial equation: 2x^4 - 3x^3 + x^2 - 6x + 4 = 0.',
            'instructions': 'Use the rational root theorem and factorization to solve the polynomial equation.',
            'answer': '\\boxed{1}'
        }
    ],
    'Inequalities': [
        {
            'problem': 'Solve the inequality: 3x - 7 > 5.',
            'instructions': 'Isolate the variable x and solve the inequality.',
            'answer': '\\boxed{4}'
        },
        {
            'problem': 'Find the solution set for the inequality: -2(x + 1) \leq 4. Provide the smallest integer in the solution set.',
            'instructions': 'Distribute and solve the inequality for x.',
            'answer': '\\boxed{-3}'
        }
    ],
    'Systems of Equations': [
        {
            'problem': 'Solve the system of equations: 2x + y = 5 and x - y = 1. Provide the x-coordinate.',
            'instructions': 'Use the substitution or elimination method to solve the system.',
            'answer': '\\boxed{2}'
        },
        {
            'problem': 'Find the solution to the system of equations: x + 2y = 3 and 3x - y = 2. Provide the x-coordinate.',
            'instructions': 'Use the elimination method to solve the system of equations.',
            'answer': '\\boxed{1}'
        }
    ],
    'Exponential and Logarithmic Functions': [
        {
            'problem': 'Solve for x: 3^x = 81.',
            'instructions': 'Express 81 as a power of 3 and solve for x.',
            'answer': '\\boxed{4}'
        },
        {
            'problem': 'Solve the equation: \\log_2(x) = 5.',
            'instructions': 'Rewrite the logarithmic equation in exponential form to solve for x.',
            'answer': '\\boxed{32}'
        }
    ],
    'Functions and Relations': [
        {
            'problem': 'Determine if the following relation is a function: \\{(1,2), (2,3), (3,4), (4,5)\\}. Provide 1 for true, 0 for false.',
            'instructions': 'Check if every input has exactly one output.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Find the domain of the function f(x) = \\frac{1}{x - 3}. Provide the value that x cannot be.',
            'instructions': 'Determine the values of x that make the function undefined.',
            'answer': '\\boxed{3}'
        }
    ],
    'Modular Arithmetic': [
        {
            'problem': 'Find the remainder when 25 is divided by 4.',
            'instructions': 'Use the definition of modular arithmetic to find the remainder.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Solve for x: 2x \\equiv 3 \\pmod{5}.',
            'instructions': 'Use the properties of congruences to solve for x.',
            'answer': '\\boxed{4}'
        }
    ],
    'GCD or LCM': [
        {
            'problem': 'Find the GCD of 56 and 98.',
            'instructions': 'Use the Euclidean algorithm to find the greatest common divisor.',
            'answer': '\\boxed{14}'
        },
        {
            'problem': 'Find the LCM of 12 and 18.',
            'instructions': 'Use the relationship between GCD and LCM to find the least common multiple.',
            'answer': '\\boxed{36}'
        }
    ],
    'Diophantine Equations': [
        {
            'problem': 'Solve the Diophantine equation: 3x + 4y = 5. Provide the x value.',
            'instructions': 'Find integer solutions to the linear Diophantine equation.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Find an integer solution for the equation: x^2 - 4y^2 = 1. Provide the x value.',
            'instructions': 'Use the method of completing the square or other algebraic techniques to find solutions.',
            'answer': '\\boxed{3}'
        }
    ],
    'Sequences and Series': [
        {
            'problem': 'Find the 10th term of the arithmetic sequence: 2, 5, 8, 11, ...',
            'instructions': 'Use the formula for the nth term of an arithmetic sequence.',
            'answer': '\\boxed{29}'
        },
        {
            'problem': 'Calculate the sum of the first 5 terms of the geometric series: 3, 9, 27, ...',
            'instructions': 'Use the formula for the sum of the first n terms of a geometric series.',
            'answer': '\\boxed{363}'
        }
    ],
    'Perfect Numbers': [
        {
            'problem': 'Is 28 a perfect number? Provide 1 for true, 0 for false.',
            'instructions': 'Verify if the sum of the proper divisors of 28 equals 28.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Find the smallest perfect number.',
            'instructions': 'Use the definition of a perfect number to find the smallest one.',
            'answer': '\\boxed{6}'
        }
    ],
    'Fermat Numbers': [
        {
            'problem': 'Is 17 a Fermat number? Provide 1 for true, 0 for false.',
            'instructions': 'Check if 17 can be expressed in the form 2^{2^n} + 1.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Find the next Fermat number after F_3 = 257.',
            'instructions': 'Use the formula for Fermat numbers to find the next one.',
            'answer': '\\boxed{65537}'
        }
    ],
    'Combinations and Permutations': [
        {
            'problem': 'How many ways can you arrange 5 books on a shelf?',
            'instructions': 'Use the formula for permutations to find the number of arrangements.',
            'answer': '\\boxed{120}'
        },
        {
            'problem': 'How many ways can you choose 3 out of 5 students?',
            'instructions': 'Use the formula for combinations to find the number of ways.',
            'answer': '\\boxed{10}'
        }
    ],
    'Probability': [
        {
            'problem': 'What is the probability of rolling a sum of 7 with two dice?',
            'instructions': 'Calculate the number of favorable outcomes divided by the total number of outcomes.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'What is the probability of drawing an ace from a standard deck of cards? Provide the probability as an integer percentage.',
            'instructions': 'Find the ratio of the number of aces to the total number of cards.',
            'answer': '\\boxed{8}'
        }
    ],
    'Card Problems': [
        {
            'problem': 'What is the probability of drawing a heart from a standard deck of cards? Provide the probability as an integer percentage.',
            'instructions': 'Determine the ratio of hearts to the total number of cards in the deck.',
            'answer': '\\boxed{25}'
        },
        {
            'problem': 'How many ways can you draw 2 cards from a standard deck?',
            'instructions': 'Use the combinations formula to find the number of ways.',
            'answer': '\\boxed{1326}'
        }
    ],
    'Trigonometric Functions and Identities': [
        {
            'problem': 'Find the value of \\sin(30°).',
            'instructions': 'Use the trigonometric identity for \\sin(30°).',
            'answer': '\\boxed{0.5}'
        },
        {
            'problem': 'Calculate the value of \\cos(60°).',
            'instructions': 'Use the trigonometric identity for \\cos(60°).',
            'answer': '\\boxed{0.5}'
        }
    ],
    'Conic Sections': [
        {
            'problem': 'Find the x-coordinate of the vertex of the parabola given by the equation y = (x - 2)^2 + 3.',
            'instructions': 'Identify the vertex form of the parabola to find the x-coordinate of the vertex.',
            'answer': '\\boxed{2}'
        },
        {
            'problem': 'Find the y-coordinate of the focus of the parabola given by the equation y = x^2.',
            'instructions': 'Use the formula for the focus of a parabola to find the y-coordinate.',
            'answer': '\\boxed{1}'
        }
    ],
    'Vectors': [
        {
            'problem': 'Find the magnitude of the vector \\vec{v} = (3, 4).',
            'instructions': 'Use the formula for the magnitude of a vector.',
            'answer': '\\boxed{5}'
        },
        {
            'problem': 'Calculate the dot product of vectors \\vec{a} = (1, 2) and \\vec{b} = (3, 4).',
            'instructions': 'Use the formula for the dot product of two vectors.',
            'answer': '\\boxed{11}'
        }
    ],
    'Complex Numbers': [
        {
            'problem': 'Find the real part of the complex number z = 3 + 4i.',
            'instructions': 'Identify the real part of the complex number.',
            'answer': '\\boxed{3}'
        },
        {
            'problem': 'Calculate the imaginary part of the complex number z = 1 + 2i.',
            'instructions': 'Identify the imaginary part of the complex number.',
            'answer': '\\boxed{2}'
        }
    ],
    'Exponential and Logarithmic Equations': [
        {
            'problem': 'Solve for x: 2^x = 16.',
            'instructions': 'Rewrite the equation with the same base and solve for x.',
            'answer': '\\boxed{4}'
        },
        {
            'problem': 'Solve for x: \\log_{10}(x) = 2.',
            'instructions': 'Rewrite the logarithmic equation in exponential form to solve for x.',
            'answer': '\\boxed{100}'
        }
    ],
    'Systems of Non-linear Equations': [
        {
            'problem': 'Solve the system of non-linear equations: y = x^2 and y = 2x + 3. Provide the x-coordinate of one intersection point.',
            'instructions': 'Find the intersection points of the given non-linear equations.',
            'answer': '\\boxed{3}'
        },
        {
            'problem': 'Find the x-coordinate of the intersection points for the system: y = x^2 and y = -x + 2.',
            'instructions': 'Solve the system of equations for the x-coordinates of the intersection points.',
            'answer': '\\boxed{1}'
        }
    ],
    'Limits': [
        {
            'problem': 'Find the limit as x approaches 2 of \\frac{x^2 - 4}{x - 2}.',
            'instructions': 'Use limit laws to evaluate the limit.',
            'answer': '\\boxed{4}'
        },
        {
            'problem': 'Evaluate the limit as x approaches infinity of \\frac{1}{x}.',
            'instructions': 'Use limit laws to evaluate the limit.',
            'answer': '\\boxed{0}'
        }
    ],
    'Derivatives': [
        {
            'problem': 'Find the derivative of f(x) = x^2.',
            'instructions': 'Use the power rule to find the derivative.',
            'answer': '\\boxed{2x}'
        },
        {
            'problem': 'Calculate the derivative of f(x) = e^x.',
            'instructions': 'Use the derivative rule for exponential functions.',
            'answer': '\\boxed{e^x}'
        }
    ],
    'Integrals': [
        {
            'problem': 'Find the integral of f(x) = x^2 from x = 0 to x = 1.',
            'instructions': 'Use the power rule to find the definite integral.',
            'answer': '\\boxed{\\frac{1}{3}}'
        },
        {
            'problem': 'Calculate the definite integral of f(x) = 2x from 0 to 2.',
            'instructions': 'Use the power rule to find the definite integral.',
            'answer': '\\boxed{4}'
        }
    ],
    'Differential Equations': [
        {
            'problem': 'Solve the differential equation: \\frac{dy}{dx} = x. Provide the value of y when x = 1 and y(0) = 0.',
            'instructions': 'Use separation of variables to solve the differential equation and find the value of y at x = 1.',
            'answer': '\\boxed{0.5}'
        },
        {
            'problem': 'Solve the differential equation: \\frac{dy}{dx} = 2. Provide the value of y when x = 2 and y(0) = 0.',
            'instructions': 'Use separation of variables to solve the differential equation and find the value of y at x = 2.',
            'answer': '\\boxed{4}'
        }
    ],
    'Sequences and Series': [
        {
            'problem': 'Find the sum of the first 10 terms of the arithmetic series: 2, 4, 6, 8, ...',
            'instructions': 'Use the formula for the sum of an arithmetic series.',
            'answer': '\\boxed{110}'
        },
        {
            'problem': 'Calculate the sum of the first 3 terms of the geometric series: 2, 6, 18, ...',
            'instructions': 'Use the formula for the sum of a geometric series.',
            'answer': '\\boxed{26}'
        }
    ],
    'Optimization Problems': [
        {
            'problem': 'Find the maximum value of the function f(x) = -x^2 + 4x.',
            'instructions': 'Use the vertex formula to find the maximum value of the quadratic function.',
            'answer': '\\boxed{4}'
        },
        {
            'problem': 'Calculate the minimum value of the function f(x) = x^2 - 4x + 7.',
            'instructions': 'Use the vertex formula to find the minimum value of the quadratic function.',
            'answer': '\\boxed{3}'
        }
    ],
    'Area Under Curves': [
        {
            'problem': 'Find the area under the curve of f(x) = x^2 from x = 0 to x = 1.',
            'instructions': 'Use definite integration to find the area under the curve.',
            'answer': '\\boxed{\\frac{1}{3}}'
        },
        {
            'problem': 'Calculate the area under the curve of f(x) = 2x from x = 0 to x = 2.',
            'instructions': 'Use definite integration to find the area under the curve.',
            'answer': '\\boxed{4}'
        }
    ],
    'Vector Calculus': [
        {
            'problem': 'Find the gradient of the function f(x) = x^2 at x = 3.',
            'instructions': 'Use the definition of the gradient to find the gradient at x = 3.',
            'answer': '\\boxed{6}'
        },
        {
            'problem': 'Calculate the divergence of the vector field \\mathbf{F} = (x, x, x) at x = 2.',
            'instructions': 'Use the definition of divergence to find the divergence at x = 2.',
            'answer': '\\boxed{3}'
        }
    ],
    'Coordinate Geometry': [
        {
            'problem': 'Find the slope of the line passing through the points (1, 2) and (3, 4).',
            'instructions': 'Use the slope formula to find the slope of the line.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Calculate the distance between the points (1, 2) and (4, 6).',
            'instructions': 'Use the distance formula to find the distance between the points.',
            'answer': '\\boxed{5}'
        }
    ],
    'Transformations': [
        {
            'problem': 'Find the new x-coordinate of the point (1, 2) after a translation of (3, 4).',
            'instructions': 'Use the translation formula to find the new x-coordinate of the point.',
            'answer': '\\boxed{4}'
        },
        {
            'problem': 'Calculate the new x-coordinate of the point (1, 2) after a rotation of 90° about the origin.',
            'instructions': 'Use the rotation formula to find the new x-coordinate of the point.',
            'answer': '\\boxed{-2}'
        }
    ],
    'Area and Perimeter': [
        {
            'problem': 'Find the perimeter of a rectangle with a length of 8 units and a width of 5 units.',
            'instructions': 'Use the perimeter formula for a rectangle to find the perimeter.',
            'answer': '\\boxed{26}'
        }
    ],
    'Volume and Surface Area': [
        {
            'problem': 'Find the volume of a cylinder with a radius of 3 units and a height of 5 units.',
            'instructions': 'Use the volume formula for a cylinder to find the volume.',
            'answer': '\\boxed{141}'
        },
        {
            'problem': 'Calculate the surface area of a sphere with a radius of 4 units.',
            'instructions': 'Use the surface area formula for a sphere to find the surface area.',
            'answer': '\\boxed{201}'
        }
    ],
    'Similarity and Congruence': [
        {
            'problem': 'Determine if two triangles with sides 3, 4, 5 and 6, 8, 10 are similar. Provide 1 for true, 0 for false.',
            'instructions': 'Use the criteria for triangle similarity to determine if the triangles are similar.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Are two triangles with angles 30°, 60°, 90° and 30°, 60°, 90° congruent? Provide 1 for true, 0 for false.',
            'instructions': 'Use the criteria for triangle congruence to determine if the triangles are congruent.',
            'answer': '\\boxed{1}'
        }
    ],
    'Pythagorean Theorem': [
        {
            'problem': 'Find the length of the hypotenuse of a right triangle with legs of 6 units and 8 units.',
            'instructions': 'Use the Pythagorean theorem to find the length of the hypotenuse.',
            'answer': '\\boxed{10}'
        },
        {
            'problem': 'Calculate the length of a leg of a right triangle if the hypotenuse is 13 units and the other leg is 5 units.',
            'instructions': 'Use the Pythagorean theorem to find the length of the leg.',
            'answer': '\\boxed{12}'
        }
    ],
    'Geometric Proofs': [
        {
            'problem': 'Prove that the sum of the interior angles of a triangle is 180 degrees. Provide 1 for true, 0 for false.',
            'instructions': 'Use the properties of triangles to prove the statement.',
            'answer': '\\boxed{1}'
        },
        {
            'problem': 'Prove that the base angles of an isosceles triangle are equal. Provide 1 for true, 0 for false.',
            'instructions': 'Use the properties of isosceles triangles to prove the statement.',
            'answer': '\\boxed{1}'
        }
    ]
}

In [None]:
class MathProblemSolver:
    """Solve math problems with a causal language model and few-shot prompts.

    A BERT sequence classifier (see `predict`) labels each problem with one of
    seven topics. The topic is used to pick worked exemplars from the
    module-level ``CoT_examples`` table, which are embedded in the prompt sent
    to the causal language model. Integer answers are then parsed from the
    boxed expressions in the generated text.
    """

    # Location of the fine-tuned topic classifier loaded by `load_bert_model`.
    _BERT_MODEL_PATH = '/kaggle/input/bert-classifier-math/transformers/prefinetuned/1/bert-finetuned-math-prob-classification'

    # Topic (lower-cased classifier label) -> candidate subcategories handed to
    # the module-level `assign_subcategory` helper. Kept at class level so the
    # table is built once instead of on every call.
    _SUBCATEGORIES = {
        'prealgebra': ['Basic Geometry', 'Basic Statistics'],
        'algebra': ['Quadratic Equations', 'Polynomial Equations', 'Inequalities', 'Systems of Equations',
                    'Exponential and Logarithmic Functions', 'Functions and Relations'],
        'number theory': ['Modular Arithmetic', 'GCD or LCM', 'Integer Properties', 'Diophantine Equations',
                          'Sequences and Series', 'Number Bases', 'Perfect Numbers', 'Fermat Numbers'],
        'counting & probability': ['Combinations and Permutations', 'Probability', 'Expected Value',
                                   'Binomial Probability', 'Card Problems', 'Dice Problems', "Bayes' Theorem",
                                   'Probability Distributions'],
        'precalculus': ['Trigonometric Functions and Identities', 'Conic Sections', 'Sequences and Series',
                        'Vectors', 'Complex Numbers', 'Exponential and Logarithmic Functions',
                        'Limits (Introductory)', 'Parametric Equations'],
        'calculus': ['Limits', 'Derivatives', 'Integrals', 'Differential Equations', 'Series and Sequences',
                     'Multivariable Calculus', 'Optimization Problems', 'Rate of Change', 'Area Under Curves',
                     'Vector Calculus'],
        'intermediate algebra': ['Complex Numbers', 'Quadratic Functions',
                                 'Exponential and Logarithmic Equations', 'Inequalities and Absolute Values',
                                 'Systems of Non-linear Equations', 'Matrices and Determinants',
                                 'Sequences and Series', 'Polynomial Theorems', 'Conic Sections',
                                 'Binomial Theorem'],
        'geometry': ['Coordinate Geometry', 'Transformations', 'Area and Perimeter', 'Volume and Surface Area',
                     'Similarity and Congruence', 'Pythagorean Theorem', 'Geometric Proofs'],
    }

    def __init__(self, model_name, device):
        """Load the causal LM, its generation config and the topic classifier.

        Args:
            model_name: Path or identifier passed to ``from_pretrained`` for
                the causal LM, its tokenizer and its generation config.
            device: Device string. ``'cuda'`` loads the causal LM in float16,
                any other value in float32; the classifier is moved to it.
        """
        self.device = device
        self.model_name = model_name
        self.tokenizer, self.model = self.load_model_and_tokenizer()
        self.generation_config = self.load_generation_config()
        self.tokenizerBERT, self.modelBERT = self.load_bert_model()
        # Classifier output index -> topic name.
        self.label_mapping = {
            0: "Algebra",
            1: "Counting & Probability",
            2: "Geometry",
            3: "Intermediate Algebra",
            4: "Number Theory",
            5: "Prealgebra",
            6: "Precalculus"
        }

    def load_model_and_tokenizer(self):
        """Load the tokenizer and causal LM named by ``self.model_name``.

        Returns:
            ``(tokenizer, model)``.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32,
                device_map='auto',
                offload_folder="offload"
            )
            return tokenizer, model
        except Exception as e:
            print(f"Error loading model/tokenizer: {e}")
            raise

    def load_generation_config(self):
        """Load the model's generation config and set its pad token to EOS.

        Returns:
            The ``GenerationConfig`` with ``pad_token_id`` set to the
            tokenizer's ``eos_token_id``.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        try:
            generation_config = GenerationConfig.from_pretrained(self.model_name)
            generation_config.pad_token_id = self.tokenizer.eos_token_id
            return generation_config
        except Exception as e:
            print(f"Error loading generation config: {e}")
            raise

    def load_bert_model(self):
        """Load the topic classifier and its tokenizer, moving the model to ``self.device``.

        Returns:
            ``(tokenizerBERT, modelBERT)``.
        """
        tokenizerBERT = AutoTokenizer.from_pretrained(self._BERT_MODEL_PATH)
        modelBERT = AutoModelForSequenceClassification.from_pretrained(self._BERT_MODEL_PATH)
        modelBERT.to(self.device)
        return tokenizerBERT, modelBERT

    def split_into_steps(self, text):
        """Split instruction text on ``'. '`` and number each non-empty piece.

        Args:
            text: Instruction string.

        Returns:
            A list of strings of the form ``"Step N: <sentence>"``. If
            ``text`` cannot be split (e.g. it is not a string) a one-element
            list holding ``str(text)`` is returned, so the return type is
            always a list.
        """
        try:
            sentences = [s.strip() for s in text.split('. ') if s.strip()]
            return [f"Step {i}: {sentence}" for i, sentence in enumerate(sentences, start=1)]
        except Exception as e:
            print(f"Error splitting instructions: {e}")
            return [str(text)]

    def select_random_instructions(self, question, label_type):
        """Pick the worked exemplars for a question.

        The subcategory is chosen by the module-level ``assign_subcategory``
        helper and used as a key into the module-level ``CoT_examples`` table.

        Args:
            question: Problem text.
            label_type: Lower-cased topic label for the problem.

        Returns:
            A list of dicts with ``'problem'``, ``'instructions'`` and
            ``'answer'`` keys. On any error (including a subcategory missing
            from ``CoT_examples``) the error is printed and a single
            placeholder exemplar is returned.
        """
        try:
            subcategory = assign_subcategory(
                question1=question, label_type1=label_type, subcategories1=self._SUBCATEGORIES
            )
            return CoT_examples[f'{subcategory}']
        except Exception as e:
            print(f"Error selecting random instructions: {e}")
            traceback.print_exc()
            return [{'problem': '...', 'instructions': 'Step 1: ..., Step 2: ...', 'answer': '\\boxed{0-99}'}]

    def _format_exemplar(self, exemplar):
        """Render one exemplar dict as a ``Q:`` / ``Steps:`` / ``A:`` text block."""
        # Join the numbered steps as text; interpolating the list directly
        # would put its Python repr (brackets and quotes) into the prompt.
        steps = "\n".join(self.split_into_steps(exemplar['instructions']))
        return (
            f"Q: {exemplar['problem']}\n"
            f"Steps: {steps}\n"
            f"A: {exemplar['answer']}"
        )

    def generate_answer_multiple_times(self, question, label_type, num_iterations):
        """Sample the language model ``num_iterations`` times on one question.

        Args:
            question: Problem text.
            label_type: Lower-cased topic label, used in the prompt and to
                select exemplars.
            num_iterations: Number of generations to attempt.

        Returns:
            A list of ints, one per generation whose output could be decoded
            (each is the value returned by `process_text_output`). Failed
            generations are printed and skipped, so the list may be shorter
            than ``num_iterations``.
        """
        exemplars = self.select_random_instructions(question, label_type)
        exemplar_texts = "\n\n".join(self._format_exemplar(exemplar) for exemplar in exemplars)

        prompt = (
            'You are going to solve math problems that have a positive integer solution. Keep logic concise.\n'
            f"Here is a math problem you are to solve (positive numerical answer): {question}\n"
            f"This particular question is a {label_type} question.\n"
            "To solve it, first determine a series of logical steps for solving the problem and then follow these steps. It is imperative that the final answer is in the brackets of \\boxed{}. \n\n"
            "The final answer should be a positive integer and not an algebraic expression.\n"
            f"Here are some examples of how to solve similar {label_type} problems step-by-step:\n\n"
            f"{exemplar_texts}\n\n"
            "Now answer the question following the above formatting:"
        )

        answers = []
        for _ in range(num_iterations):
            try:
                inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs['input_ids'],
                        max_new_tokens=MAX_NEW_TOKENS,
                        num_beams=3,
                        early_stopping=True,
                        no_repeat_ngram_size=2,
                        attention_mask=inputs['attention_mask'],
                        temperature=0.2,
                        top_p=0.9,
                        top_k=20,
                        do_sample=True,
                        # Use the pad token prepared in load_generation_config.
                        pad_token_id=self.generation_config.pad_token_id
                    )

                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                # The decoded text includes the prompt; keep only what follows
                # the final prompt sentence so exemplar answers are not parsed.
                match = re.search(r'Now answer the question following the above formatting:\s*(.*)', result, re.DOTALL)
                if match is None:
                    print("Error during generation: prompt marker not found in decoded output")
                    continue
                # Append unconditionally: 0 is a legitimate answer and must not
                # be dropped by a truthiness check.
                answers.append(self.process_text_output(match.group(1)))
            except Exception as e:
                print(f"Error during generation: {e}")
                traceback.print_exc()

        return answers   #list of int

    def naive_parse2(self, text):
        """Return the most frequent integer literal in ``text``.

        Every run of digits is read as a non-negative integer; signs and
        decimal points are ignored. Ties are broken with ``random.choice``.

        Args:
            text: String to scan.

        Returns:
            The chosen integer, or ``None`` when ``text`` contains no digits.
        """
        numbers = [int(num) for num in re.findall(r'\d+', text)]
        print(f"Answer from Naive_Parser: {numbers}")

        if not numbers:
            return None

        frequency = {}
        for num in numbers:
            frequency[num] = frequency.get(num, 0) + 1

        max_frequency = max(frequency.values())
        most_common = [num for num, freq in frequency.items() if freq == max_frequency]
        return random.choice(most_common)

    def process_text_output(self, output):
        r"""Extract the final boxed integer from generated text, modulo 1000.

        Args:
            output: Generated text to scan for ``\boxed{...}`` expressions.

        Returns:
            The integer parsed (via `naive_parse2`) from the last boxed
            expression that contains a digit, reduced modulo 1000. Returns 1
            when no such expression exists or parsing fails.
        """
        try:
            boxed_contents = re.findall(r'\\boxed\s*\{([^}]*)\}', output)
            # Keep every boxed expression with at least one digit, including
            # a plain 0, which a non-zero-digit filter would discard.
            candidates = [
                self.naive_parse2(content)
                for content in boxed_contents
                if re.search(r'\d', content)
            ]
            if not candidates:
                print("Error during text output processing: no numeric \\boxed{} answer found")
                return 1
            return candidates[-1] % 1000   # last boxed answer wins
        except Exception as e:
            print(f"Error during text output processing: {e}")
            traceback.print_exc()
            return 1

    def aggregate_answers(self, answers):
        """Return the most frequent answer and its vote count.

        Args:
            answers: Iterable of answers (ints); each is compared by its
                ``str`` form.

        Returns:
            ``(answer_str, count)`` for the most frequent answer, the earliest
            seen winning ties; ``('1', 0)`` when ``answers`` is empty.
        """
        answer_counts = {}
        for answer in map(str, answers):
            answer_counts[answer] = answer_counts.get(answer, 0) + 1
        if not answer_counts:
            return ('1', 0)
        # max() returns the first maximal item, and dicts preserve insertion
        # order, so ties resolve to the earliest answer.
        return max(answer_counts.items(), key=lambda item: item[1])

    def manage_context_and_generate_answers(self, question, context_prompt, total_tokens, conversation, label_type):
        """Generate answers for a question and return the majority vote.

        Args:
            question: Problem text.
            context_prompt: Unused; kept for call compatibility.
            total_tokens: Unused; kept for call compatibility.
            conversation: Unused; kept for call compatibility.
            label_type: Lower-cased topic label for the problem.

        Returns:
            The winning answer as an int, or 999 if anything raises.

        Note:
            Earlier revisions appended the Q/A text to ``conversation`` and
            updated ``total_tokens``, but only on local copies that were never
            returned, so that bookkeeping had no effect and has been removed.
        """
        try:
            answers = self.generate_answer_multiple_times(question, label_type, num_iterations=N_REPETITIONS)
            best_answer, _count = self.aggregate_answers(answers)   # best_answer is a string
            return int(best_answer)
        except Exception as e:
            print(e)
            traceback.print_exc()
            return 999

    def predict(self, text):
        """Classify a problem into one of the topics in ``self.label_mapping``.

        Args:
            text: Problem text (truncated to 512 tokens for the classifier).

        Returns:
            The topic name. If classification raises, the error is printed and
            a randomly chosen topic is returned instead.
        """
        try:
            inputs = self.tokenizerBERT(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}

            with torch.no_grad():
                outputs = self.modelBERT(**inputs)

            # Softmax is monotonic, so the argmax of the raw logits picks the same class.
            predicted_class = torch.argmax(outputs.logits, dim=-1).item()
            return self.label_mapping[predicted_class]
        except Exception as e:
            print(f"Error during classification, using a random label: {e}")
            predicted_class = random.randrange(len(self.label_mapping))
            return self.label_mapping[predicted_class]

    def flush(self):
        """Run garbage collection, then release cached CUDA memory."""
        # Collect first so tensors freed by the GC can be returned by empty_cache.
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
#context_prompt = """
#This is a compendium of things which we would normally explain in maths competitions for teenaged schoolchildren as context for you to better understand the task at hand:

#\\begin{enumerate}
 #   \\item The final answer should be a positive integer and should be placed within $\\boxed{}$.
  #  \\item Established mathematical notation will be used.
 #   \\item We might use a colon or a vertical line as a separator in set notation so $\\{x \\mid x \\in \\mathbb{Z}, x > 0\\} = \\{y : y \\in \\mathbb{Z}, y > 0\\}$.
  #  \\item Floor and ceiling notation, so for $x$ a real number we let $\\lfloor x \\rfloor = \\max\\{z \\mid z \\in \\mathbb{Z}, z \\leq x\\}$. Similarly $\\lceil x \\rceil = \\min\\{z \\mid z \\in \\mathbb{Z}, z \\geq x\\}$.
   # \\item Fractional part notation. If $x$ is a real number, we define $\\{x\\}$ to mean $x - \\lfloor x \\rfloor$.
    #\\item We write a line over a non-negative integer written in base 10 notation to indicate that it is being viewed as a string of digits rather than a number. Thus the second digit of $\\overline{1729}$ is 7 but 1729 does not have a second digit because it is an integer.
   # \\item We allow a phrase such as "$x$ is a 3-digit positive integer" to mean that if written in Arabic notation as $x = a_m \\cdot a_{m-1} \\cdot \\ldots \\cdot a_1$ with $a_i$ all digits and $a_m \\neq 0$, then $n = m$. We allow "the sum of the digits of $n$" to mean: write $n$ in Arabic base 10 notation and then sum the digits.
  #  \\item We allow informal probability language such as: "a point is chosen uniformly at random in the interval $[0, 1]$".
  #  \\item We use $\\binom{n}{r}$ to denote the number of ways of choosing $r$ things from $n$ things.
  #  \\item The sum over the empty set is 0 and the product over the empty set is 1.
  #  \\item We use an ellipsis to denote an obvious pattern, either on the line of print of midline (as appropriate) so the set of the first $n$ positive integers can be written $\\{1, 2, \\ldots, n\\}$ and their sum is $1 + 2 + \\ldots + n$.
  #  \\item For integers $l$, $m$, $n$ then $l$ raised by $m$ which is raised by $n$ denotes $l^{(m^n)}$.
  #  \\item $m^0 = 1$ for all integers $m$ (including 0) if doing combinatorial enumeration. If $x$ is real then $x^0$ needs to be clarified if $x = 0$.
  #  \\item British or American versions of English can be used. Thus "highest common factor" means the same as "greatest common divisor".
   # \\item A prefix or subscript may be used to indicate features of a triangle associated with vertices. Thus triangle $ABC$ has three altitudes, and the one dropped from $A$ could be denoted the altitude through $A$, the $A$-altitude or the altitude $h_a$. Similarly for median lines.
  #  \\item If the term natural number is used, then it will be made clear if 0 is a natural number.
   # \\item $:=$ means 'is defined to be equal to'.
#\\end{enumerate}"""

# The notation primer above is disabled; a single space stands in for it.
context_prompt = " "

# Conversation state passed to the solver for every problem.
full_conversation = f"{context_prompt}\n\n"
total_tokens = 0

# The evaluation API (`env`, `iter_test`) is already created in the first cell,
# so only the solver is built here.
solver = MathProblemSolver(MODEL_NAME, DEVICE)

In [None]:
#input_directory = '/kaggle/input'
#for dirname, _, filenames in os.walk(input_directory):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

In [None]:
import pandas as pd
# Rebuild the debug iterator from the training problems. Each entry is a
# [problem_text, 0] pair so it unpacks like the (test, sample_submission)
# pairs consumed by the prediction loop.
df = pd.read_csv('/kaggle/input/ai-mathematical-olympiad-prize/train.csv')
# Iterate the column directly: iterrows() builds a Series per row and leaks
# its loop variables into the notebook namespace.
DEBUG_ITER = [[str(problem), 0] for problem in df["problem"]]

In [None]:
# Iterate through the test set and use the model to make predictions.
# In DEBUG mode the loop runs over DEBUG_ITER, whose entries are
# [problem_text, 0] pairs, so `sample_submission` is a plain int and nothing
# is sent to the evaluation API.

if DEBUG:
    iter_test = DEBUG_ITER

for test, sample_submission in iter_test:
    TIME_SPENT = time.time() - NOTEBOOK_START_TIME

    if TIME_SPENT > TIME_LIMIT:
        if DEBUG:
            # Nothing to submit in DEBUG mode (sample_submission is an int).
            break
        # Out of time: submit the default answer but keep iterating, so every
        # remaining test row still gets a predict() call.
        sample_submission['answer'] = 0
        env.predict(sample_submission)
        continue

    problem = test if DEBUG else str(test['problem'].values[0])
    print(problem)
    predtype = solver.predict(problem).lower()

    result = solver.manage_context_and_generate_answers(problem, context_prompt, total_tokens, full_conversation, predtype)
    print(result)

    if not DEBUG:
        sample_submission['answer'] = result
        env.predict(sample_submission)
    