In [None]:
%pip install matplotlib numpy datasets transformers

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the MATH dataset (the "default" config of the lighteval mirror).
MATH_DATASET_ID = "DigitalLearningGmbH/MATH-lighteval"
dataset = load_dataset(MATH_DATASET_ID, "default")
n_train = len(dataset["train"])
print(f"Dataset loaded with {n_train} training examples")

# Load the tokenizer published with Qwen2.5-Math-7B.
TOKENIZER_ID = "Qwen/Qwen2.5-Math-7B"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
print(f"Tokenizer loaded: {tokenizer.name_or_path}")

# Function to calculate token lengths
def get_token_lengths(examples, batch_size=100, *, text_field="problem", tok=None):
    """Return the token count of ``text_field`` for every example, in order.

    Args:
        examples: A Hugging Face ``Dataset`` (whose slices are columnar dicts
            ``{column: [values, ...]}``) or a plain sequence of row dicts
            (whose slices are lists of rows). Both shapes are accepted.
        batch_size: Number of examples tokenized per tokenizer call; must be
            positive.
        text_field: Name of the field holding the text to measure.
            MATH-lighteval stores the question text under ``"problem"``.
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``.
            It is called as ``tok(texts, truncation=False, padding=False)``
            and must return an object exposing ``input_ids`` (one id
            sequence per text).

    Returns:
        A list of ints, one token count per example, in input order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    from collections.abc import Mapping

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if tok is None:
        tok = tokenizer  # module-level tokenizer loaded earlier in the notebook

    token_lengths = []
    total = len(examples)

    for start in range(0, total, batch_size):
        batch = examples[start:start + batch_size]
        # Slicing a Dataset yields {column: [values...]}, so iterating it would
        # walk column names; a list slice yields row dicts. Handle both.
        if isinstance(batch, Mapping):
            texts = list(batch[text_field])
        else:
            texts = [row[text_field] for row in batch]

        # No truncation/padding: we want the true, unpadded length of each text.
        encodings = tok(texts, truncation=False, padding=False)
        token_lengths.extend(len(ids) for ids in encodings.input_ids)

        if start % 1000 == 0:
            print(f"Processed {start}/{total} examples")

    return token_lengths

# Calculate token lengths for the training set
print("Calculating token lengths for training set...")
token_lengths = get_token_lengths(dataset['train'])
if not token_lengths:
    # Fail with a clear message: np.max/np.min would otherwise raise an opaque
    # "zero-size array" error and the percentages below would divide by zero.
    raise ValueError("No token lengths were computed; is the training split empty?")

# Context-window sizes to report on; the first one is also drawn on the histogram.
CONTEXT_LIMITS = (512, 1024, 2048)

# Convert once instead of letting every numpy call re-convert the list.
lengths_arr = np.asarray(token_lengths)

# Calculate statistics
avg_length = np.mean(lengths_arr)
median_length = np.median(lengths_arr)
max_length = np.max(lengths_arr)
min_length = np.min(lengths_arr)
# One call computes all three percentiles over a single sort.
p90, p95, p99 = np.percentile(lengths_arr, [90, 95, 99])

print("Token length statistics:")
print(f"  Average: {avg_length:.2f}")
print(f"  Median: {median_length}")
print(f"  Min: {min_length}")
print(f"  Max: {max_length}")
print(f"  90th percentile: {p90}")
print(f"  95th percentile: {p95}")
print(f"  99th percentile: {p99}")

# Plot histogram of token lengths
plt.figure(figsize=(10, 6))
plt.hist(token_lengths, bins=50, alpha=0.7)
plt.axvline(x=avg_length, color='r', linestyle='--', label=f'Mean: {avg_length:.2f}')
plt.axvline(x=median_length, color='g', linestyle='--', label=f'Median: {median_length}')
plt.axvline(x=p95, color='orange', linestyle='--', label=f'95th percentile: {p95}')
plt.axvline(x=CONTEXT_LIMITS[0], color='purple', linestyle='--', label=f'{CONTEXT_LIMITS[0]} tokens')
plt.xlabel('Number of tokens')
plt.ylabel('Frequency')
plt.title('Distribution of Qwen token lengths for MATH-lighteval questions')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Check how many examples exceed common context window sizes
exceed_counts = {
    limit: int(np.count_nonzero(lengths_arr > limit)) for limit in CONTEXT_LIMITS
}
# Keep the individual names so later cells that read them still work.
exceed_512 = exceed_counts[512]
exceed_1024 = exceed_counts[1024]
exceed_2048 = exceed_counts[2048]

n_examples = len(token_lengths)
print("Examples exceeding token limits:")
for limit, count in exceed_counts.items():
    print(f"  > {limit} tokens: {count} ({count/n_examples*100:.2f}%)")
