In [15]:
import pandas as pd

data = [
    {"Incorrect Math Meme": "8 ÷ 2(2+2) = 1", "Correct Answer": "16"},
    {"Incorrect Math Meme": "6 ÷ 2(1+2) = ?", "Correct Answer": "9"},
    {"Incorrect Math Meme": "9 - 3 ÷ 1/3 + 1 = ?", "Correct Answer": "3"},
    {"Incorrect Math Meme": "7 + 7 ÷ 7 + 7 × 7 - 7 = ?", "Correct Answer": "50"},
    {"Incorrect Math Meme": "50 - 5 × 10 = ?", "Correct Answer": "0"},
    {"Incorrect Math Meme": "4 ÷ 2(2) = ?", "Correct Answer": "4"},
    {"Incorrect Math Meme": "1 + 1 × 0 + 1 = ?", "Correct Answer": "2"},
    {"Incorrect Math Meme": "2 + 2 × 2 + 2 ÷ 2 = ?", "Correct Answer": "6"},
    {"Incorrect Math Meme": "10 ÷ 5 × 2 = ?", "Correct Answer": "4"},
    {"Incorrect Math Meme": "15 ÷ 3 × 5 = ?", "Correct Answer": "25"},
    {"Incorrect Math Meme": "(6 + 3) ÷ 3 + 2 × 4 = ?", "Correct Answer": "10"},
    {"Incorrect Math Meme": "18 ÷ 2(3+3) = ?", "Correct Answer": "3"},
    {"Incorrect Math Meme": "5 + 5 × 5 - 5 = ?", "Correct Answer": "25"},
    {"Incorrect Math Meme": "12 ÷ 4 × 2 = ?", "Correct Answer": "6"},
    {"Incorrect Math Meme": "100 ÷ 10 × 10 = ?", "Correct Answer": "100"},
    {"Incorrect Math Meme": "20 - 3 × 4 = ?", "Correct Answer": "8"},
    {"Incorrect Math Meme": "9 ÷ 3 + 3 × 3 = ?", "Correct Answer": "12"},
    {"Incorrect Math Meme": "7 × (3 + 2) - 10 ÷ 5 = ?", "Correct Answer": "31"},
    {"Incorrect Math Meme": "2(3 + 4) = ?", "Correct Answer": "14"},
    {"Incorrect Math Meme": "6 ÷ 2 × 3 = ?", "Correct Answer": "9"},
]
df = pd.DataFrame(data)
df.to_csv("math_meme_corrections.csv", index=False)

In [8]:
import torch
from torch.utils.data import Dataset
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
import numpy as np
from typing import List, Tuple
import re
import sympy
import pandas as pd

# Custom Dataset Class
class MathMemeDataset(Dataset):
    def __init__(self, memes: List[str], labels: List[int], explanations: List[str], tokenizer, max_length: int = 128):
        self.memes = memes
        self.labels = labels  # 0 for incorrect, 1 for correct
        self.explanations = explanations
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.memes)

    def __getitem__(self, idx):
        meme = str(self.memes[idx])
        encoding = self.tokenizer(
            meme,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'explanation': self.explanations[idx]
        }

# Math Meme Repair Model
class MathMemeRepair:
    def __init__(self):
        # Determine device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load pre-trained tokenizer and model
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)
        self.model.to(self.device)  # Move model to appropriate device

        # Load training data from CSV
        df = pd.read_csv('math_meme_corrections.csv')  # Adjust path if needed
        self.training_data = []
        for _, row in df.iterrows():
            meme = f"{row['Incorrect Math Meme']}?"
            correct_answer = row['Correct Answer']
            # Extract expression and claimed result
            expr = row['Incorrect Math Meme'].split('=')[0].strip()
            try:
                claimed_result = float(row['Incorrect Math Meme'].split('=')[1].strip().replace('?', ''))
            except:
                claimed_result = None  # Handle cases where no result is provided
            evaluated_result = self.evaluate_expression(expr)
            label = 1 if claimed_result is not None and abs(evaluated_result - claimed_result) < 1e-6 else 0  # 1 if correct, 0 if incorrect
            explanation = f"Correct: {expr} = {correct_answer}. Follow PEMDAS: Parentheses, Exponents, Multiplication/Division, Addition/Subtraction"
            self.training_data.append((meme, label, explanation))

        self.memes, self.labels, self.explanations = zip(*self.training_data)

    def fine_tune_model(self):
        """Fine-tune the pre-trained model"""
        # Prepare dataset
        train_dataset = MathMemeDataset(self.memes, self.labels, self.explanations, self.tokenizer)

        # Training arguments
        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=3,
            per_device_train_batch_size=8,
            warmup_steps=10,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=10,
            save_steps=50,
            save_total_limit=1,
            fp16=torch.cuda.is_available()  # Use mixed precision if GPU available
        )

        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset
        )

        # Fine-tune the model
        trainer.train()
        self.model.save_pretrained('./fine_tuned_math_meme_model')
        self.tokenizer.save_pretrained('./fine_tuned_math_meme_model')

    def evaluate_expression(self, expression: str) -> float:
        """Evaluates the correct result using SymPy with proper symbol replacement"""
        try:
            # Replace special characters with standard operators
            expr = expression.replace('^', '**').replace('=', '').strip()
            expr = expr.replace('÷', '/').replace('×', '*')  # Convert ÷ to / and × to *
            expr = re.sub(r'(\d+)\(', r'\1*(', expr)  # Handle implicit multiplication
            return float(sympy.sympify(expr))
        except Exception as e:
            print(f"Error evaluating expression: {e}")
            return None

    def analyze_meme(self, meme: str) -> Tuple[str, str]:
        """Analyze and correct a math meme"""
        # Tokenize input and move to the same device as model
        inputs = self.tokenizer(meme, return_tensors='pt', padding=True, truncation=True, max_length=128)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}  # Move inputs to GPU/CPU

        # Predict if it's correct or incorrect
        with torch.no_grad():
            outputs = self.model(**inputs)
            prediction = torch.argmax(outputs.logits, dim=1).item()

        # Extract expression and claimed result
        try:
            expression, claimed_result = meme.split('=')
            claimed_result = float(claimed_result.strip().replace('?', ''))
            correct_result = self.evaluate_expression(expression)

            if correct_result is None:
                return "Invalid expression", "Cannot evaluate"

            if prediction == 0:  # Incorrect
                # Find similar example for explanation
                for train_meme, _, exp in self.training_data:
                    if similar_pattern(expression, train_meme.split('=')[0]):
                        return (f"Wrong: {meme} → Correct: {expression} = {correct_result}", exp)
                return (f"Wrong: {meme} → Correct: {expression} = {correct_result}",
                       "Follow PEMDAS: Parentheses, Exponents, Multiplication/Division, Addition/Subtraction")
            return (f"{meme}", "This is already correct!")
        except:
            return "Error parsing meme", "Cannot process"

def similar_pattern(expr1: str, expr2: str) -> bool:
    """Check if two expressions have similar operator patterns"""
    ops1 = set(re.findall(r'[+\-×÷^()]', expr1))
    ops2 = set(re.findall(r'[+\-×÷^()]', expr2))
    return ops1 == ops2

In [9]:
# Training Block
repairer = MathMemeRepair()
repairer.fine_tune_model()
print("Model training completed and saved to './fine_tuned_math_meme_model'")



Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/9 [00:00<?, ?it/s]

{'train_runtime': 18.4209, 'train_samples_per_second': 3.257, 'train_steps_per_second': 0.489, 'train_loss': 0.6104948255750868, 'epoch': 3.0}
Model training completed and saved to './fine_tuned_math_meme_model'


In [10]:
# Testing Block
# Initialize the model (can load pre-trained if available)
repairer = MathMemeRepair()

# Optional: Load the fine-tuned model if it exists (uncomment if you want to use a pre-trained model)
try:
    repairer.model = DistilBertForSequenceClassification.from_pretrained('./fine_tuned_math_meme_model')
    repairer.tokenizer = DistilBertTokenizer.from_pretrained('./fine_tuned_math_meme_model')
    repairer.model.to(repairer.device)
    print("Loaded pre-trained model from './fine_tuned_math_meme_model'")
except:
    print("No pre-trained model found, using initialized model (results may be inaccurate without training)")

# Test cases
test_memes = [
    "8 + 4 ÷ 2 × 3 = 18?",
    "2^3 + 5 × 2 = 26?",
    "10 - 2 × 3 + 4 = 8?",
    "5 × (2 + 3) ÷ 5 = 1?",
    "9 ÷ 3^2 + 1 = 2?"
]

print("Testing Math Meme Repair Model:")
print("=================================")
for meme in test_memes:
    result, explanation = repairer.analyze_meme(meme)
    print(f"Input: {meme}")
    print(f"Result: {result}")
    print(f"Explanation: {explanation}")
    print("---")
print("Funny Error Rating: 90% Math Sass, 10% Patience")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded pre-trained model from './fine_tuned_math_meme_model'
Testing Math Meme Repair Model:
Input: 8 + 4 ÷ 2 × 3 = 18?
Result: Wrong: 8 + 4 ÷ 2 × 3 = 18? → Correct: 8 + 4 ÷ 2 × 3  = 14.0
Explanation: Correct: 2 + 2 × 2 + 2 ÷ 2 = 6. Follow PEMDAS: Parentheses, Exponents, Multiplication/Division, Addition/Subtraction
---
Input: 2^3 + 5 × 2 = 26?
Result: Wrong: 2^3 + 5 × 2 = 26? → Correct: 2^3 + 5 × 2  = 18.0
Explanation: Follow PEMDAS: Parentheses, Exponents, Multiplication/Division, Addition/Subtraction
---
Input: 10 - 2 × 3 + 4 = 8?
Result: Wrong: 10 - 2 × 3 + 4 = 8? → Correct: 10 - 2 × 3 + 4  = 8.0
Explanation: Correct: 5 + 5 × 5 - 5 = 25. Follow PEMDAS: Parentheses, Exponents, Multiplication/Division, Addition/Subtraction
---
Input: 5 × (2 + 3) ÷ 5 = 1?
Result: Wrong: 5 × (2 + 3) ÷ 5 = 1? → Correct: 5 × (2 + 3) ÷ 5  = 5.0
Explanation: Correct: (6 + 3) ÷ 3 + 2 × 4 = 10. Follow PEMDAS: Parentheses, Exponents, Multiplication/Division, Addition/Subtraction
---
Input: 9 ÷ 3^2 + 1 = 2?
Re