In [6]:
import os
import sys

# In a notebook the base directory is set by hand; the working directory is used for it.
BASE_DIR = os.getcwd()
# When the notebook sits in a subdirectory, point BASE_DIR further up instead, e.g.:
# BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), '../..'))

# Expose the directory two levels above the working directory to the import system.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Dataset loaders come from the project's utils package; on failure, show where
# Python looked and re-raise so the cell stops.
try:
    from utils.data import (
        load_asdiv_data,
        load_paramawps_data,
        load_svamp_data,
        load_aqua_data,
        load_dmath_data,
    )
except ImportError:
    print("Error importing data utilities. Please check your path and make sure you're running this from the correct directory.")
    print(f"Current working directory: {os.getcwd()}")
    print(f"sys.path: {sys.path}")
    raise

# Candidate project roots: BASE_DIR itself, then one, two and three levels up.
possible_roots = [BASE_DIR] + [
    os.path.abspath(os.path.join(BASE_DIR, relative_step))
    for relative_step in ('..', '../..', '../../..')
]

# The first candidate that contains a "data" entry becomes the data root.
data_root = None
for root in possible_roots:
    test_path = os.path.join(root, "data")
    if not os.path.exists(test_path):
        continue
    data_root = root
    print(f"Found data directory at: {test_path}")
    break

if data_root is None:
    print("Could not find data directory. Please specify the path manually.")
    data_root = BASE_DIR  # fall back to the working directory

# One file per curriculum stage, all below <data_root>/data/curriculum_learning.
curriculum_dir = os.path.join(data_root, "data", "curriculum_learning")
data_paths = [
    os.path.join(curriculum_dir, stage_dir, file_name)
    for stage_dir, file_name in (
        ("1_ASDiv", "ASDiv.xml"),
        ("2_ParaMAWPS", "ParaMAWPS_trainset.json"),
        ("3_SVAMP", "SVAMP.json"),
        ("4_Dmath", "dmath_train.json"),
        ("5_AQuA", "AQuA_train.json"),
    )
]

# Display names, aligned index-by-index with data_paths.
dataset_names = ["ASDiv", "ParaMAWPS", "SVAMP", "DMath", "AQuA"]

# Loader lookup keyed by dataset file path (same order as data_paths).
data_loaders = dict(zip(
    data_paths,
    [
        load_asdiv_data,
        load_paramawps_data,
        load_svamp_data,
        load_dmath_data,
        load_aqua_data,
    ],
))

def print_sample(data, dataset_name):
    """Print the first item of ``data`` under a banner naming ``dataset_name``.

    If ``data`` is falsy or has length zero, a "no samples" notice is printed
    instead. Nothing is returned.
    """
    has_items = bool(data and len(data) > 0)
    if not has_items:
        print(f"\nNo samples available from {dataset_name}\n")
        return

    print(f"\n===== SAMPLE FROM {dataset_name} =====")
    first_item = data[0]
    print(first_item)
    print("====================================\n")

# Report, for each expected dataset file, whether it is present on disk.
# os.path.exists is already true for directories, so a single check covers both cases.
for data_path in data_paths:
    if not os.path.exists(data_path):
        print(f"✗ Path does not exist: {data_path}")
        continue
    print(f"✓ Path exists: {data_path}")

print("\n")

# Run every dataset through its loader and show its first record;
# a failure is reported with a traceback and the loop moves on.
for data_path, dataset_name in zip(data_paths, dataset_names):
    print(f"Loading {dataset_name}...")
    try:
        loader_func = data_loaders[data_path]
        data = loader_func(data_path)
        print_sample(data, dataset_name)
    except Exception as e:
        print(f"Error loading {dataset_name}: {str(e)}")
        import traceback
        traceback.print_exc()

Found data directory at: /work/math-reasoning-in-language-models/data
✓ Path exists: /work/math-reasoning-in-language-models/data/curriculum_learning/1_ASDiv/ASDiv.xml
✓ Path exists: /work/math-reasoning-in-language-models/data/curriculum_learning/2_ParaMAWPS/ParaMAWPS_trainset.json
✓ Path exists: /work/math-reasoning-in-language-models/data/curriculum_learning/3_SVAMP/SVAMP.json
✓ Path exists: /work/math-reasoning-in-language-models/data/curriculum_learning/4_Dmath/dmath_train.json
✓ Path exists: /work/math-reasoning-in-language-models/data/curriculum_learning/5_AQuA/AQuA_train.json


Loading ASDiv...
Loaded 2305 problems from ASDiv

===== SAMPLE FROM ASDiv =====
{'text': 'Question: Seven red apples and two green apples are in the basket. How many apples are in the basket?\nSolution: 7+2=9\nAnswer: 9 (apples)'}

Loading ParaMAWPS...
Loaded 13023 problems from ParaMAWPS

===== SAMPLE FROM ParaMAWPS =====
{'text': 'Question: Bryan took a look at his books as well . If Bryan has 56 books

In [12]:
def format_asdiv(item):
    """Convert one ASDiv record into a question/answer dict.

    ``item['text']`` must contain the markers 'Question: ', a newline followed
    by 'Solution:', and a newline followed by 'Answer:' in that order; a
    missing marker raises IndexError.

    Returns a dict with 'question' (the question text) and 'answer' (a
    three-line string: fixed intro, the solution, and the final answer sentence).
    """
    segments = item['text'].split('Question: ')[1].split('\nSolution:')
    question_text = segments[0].strip()

    tail = segments[1].split('\nAnswer:')
    worked_solution = tail[0].strip()
    final_answer = tail[1].strip()

    answer_lines = (
        "Let me solve this step by step.",
        worked_solution,
        f"Therefore, the answer is {final_answer}.",
    )
    return {'question': question_text, 'answer': "\n".join(answer_lines)}

def format_paramawps(item):
    """Convert one ParaMAWPS record into a question/answer dict.

    ``item['text']`` must contain 'Question: ', then a newline plus
    'Equation:', then a newline plus 'Answer:'; a missing marker raises
    IndexError.

    Returns a dict with 'question' and 'answer', where 'answer' is the fixed
    intro line, the equation, and the final answer sentence.
    """
    raw_text = item['text']
    question_and_rest = raw_text.split('Question: ')[1].split('\nEquation:')
    question_text = question_and_rest[0].strip()

    equation_and_answer = question_and_rest[1].split('\nAnswer:')
    equation_text = equation_and_answer[0].strip()
    answer_text = equation_and_answer[1].strip()

    return {
        'question': question_text,
        'answer': (
            "Let me solve this step by step.\n"
            + equation_text
            + f"\nTherefore, the answer is {answer_text}."
        ),
    }

def format_svamp(item):
    """Convert one SVAMP record into a question/answer dict.

    The record text is cut at 'Question: ', at a newline plus 'Equation:', and
    at a newline plus 'Answer:'. IndexError is raised when a marker is absent.

    Returns a dict with 'question' and 'answer'; 'answer' joins the fixed intro
    line, the equation, and the final answer sentence with newlines.
    """
    body = item['text'].split('Question: ')[1]
    head_and_tail = body.split('\nEquation:')
    asked = head_and_tail[0].strip()

    equation_part, answer_part = (
        head_and_tail[1].split('\nAnswer:')[0],
        head_and_tail[1].split('\nAnswer:')[1],
    )
    pieces = [
        "Let me solve this step by step.",
        equation_part.strip(),
        f"Therefore, the answer is {answer_part.strip()}.",
    ]
    return {'question': asked, 'answer': "\n".join(pieces)}

def format_dmath(item):
    """Convert one DMath record into a question/answer dict.

    Expects ``item['text']`` to hold 'Question: ', then a newline plus
    'Solution:', then a newline plus 'Answer:'. A missing marker raises
    IndexError.

    Returns a dict with 'question' and 'answer'; 'answer' is the fixed intro
    line, the solution text, and the final answer sentence.
    """
    after_question = item['text'].split('Question: ')[1]
    chunks = after_question.split('\nSolution:')
    prompt = chunks[0].strip()

    solution_chunks = chunks[1].split('\nAnswer:')
    reasoning = solution_chunks[0].strip()
    result = solution_chunks[1].strip()

    response = "Let me solve this step by step.\n{}\nTherefore, the answer is {}.".format(
        reasoning, result
    )
    return {'question': prompt, 'answer': response}

def format_aqua(item):
    """Convert one AQuA record into a question/answer dict.

    The record text is scanned for the markers 'Question: ', 'Options:',
    'Rationale:' and 'Answer:'.

    * The question is the text after 'Question: ' up to 'Rationale:'. If there
      is no 'Rationale:' marker, the question stops at 'Answer:' instead so the
      answer does not leak into the question.
    * An options list is appended to the question only when 'Options:' occurs
      inside the question part itself. An 'Options:' that appears only in the
      rationale no longer produces an empty options header.
    * The rationale is the text between 'Rationale:' and the first following
      'Answer:'; it is empty unless both markers are present.
    * The answer is the text between the first and second 'Answer:' markers
      (or to the end of the text); it is empty when the marker is missing.

    Args:
        item: Mapping with a 'text' entry holding the record as one string.

    Returns:
        dict with 'question' and 'answer' keys. 'answer' is the fixed intro
        line, the rationale, and the final answer sentence joined by newlines.

    Raises:
        IndexError: if the text does not contain 'Question: '.
    """
    text = item['text']

    after_question = text.split('Question: ')[1]
    question_block = after_question.split('Rationale:')[0]
    if 'Rationale:' not in after_question:
        # Without a rationale the answer follows the question directly; cut it off.
        question_block = question_block.split('Answer:')[0]
    question_block = question_block.strip()

    # Look for options in the question part only, not in the whole record.
    if 'Options:' in question_block:
        stem, _, options = question_block.partition('Options:')
        question = f"{stem.strip()}\nOptions:\n{options.strip()}"
    else:
        question = question_block

    has_answer = 'Answer:' in text
    if 'Rationale:' in text and has_answer:
        rationale = text.split('Rationale:')[1].split('Answer:')[0].strip()
    else:
        rationale = ""
    answer = text.split('Answer:')[1].strip() if has_answer else ""

    formatted_answer = f"Let me solve this step by step.\n{rationale}\nTherefore, the answer is {answer}."

    return {
        'question': question,
        'answer': formatted_answer
    }

# Standardize every dataset and merge the results into one flat list
def standardize_datasets(data_paths, dataset_names, data_loaders, formatters=None):
    """Load, format and concatenate several datasets.

    For each (path, name) pair the formatter registered for the name is looked
    up, the dataset is loaded with ``data_loaders[path](path)``, and every raw
    item is passed through the formatter. Progress, one formatted example per
    dataset, and all errors are printed.

    Args:
        data_paths: Dataset locations, aligned index-by-index with
            ``dataset_names``.
        dataset_names: Names used for log messages and formatter lookup.
        data_loaders: Mapping from each path to a callable that takes the path
            and returns an iterable of raw items.
        formatters: Optional mapping from dataset name to a callable that turns
            one raw item into a dict with 'question' and 'answer' keys. When
            omitted, the notebook's format_asdiv / format_paramawps /
            format_svamp / format_dmath / format_aqua functions are used.

    Returns:
        list: the formatted items of all datasets, in input order. Items whose
        formatter raised are skipped; a dataset whose loader raised or whose
        name has no formatter contributes nothing.
    """
    if formatters is None:
        formatters = {
            "ASDiv": format_asdiv,
            "ParaMAWPS": format_paramawps,
            "SVAMP": format_svamp,
            "DMath": format_dmath,
            "AQuA": format_aqua,
        }

    all_standardized_data = []

    for data_path, dataset_name in zip(data_paths, dataset_names):
        try:
            print(f"\nStandardizing {dataset_name}...")

            # Resolve the formatter first so an unsupported dataset is not loaded for nothing.
            format_func = formatters.get(dataset_name)
            if format_func is None:
                print(f"No formatting function for {dataset_name}")
                continue

            # Load the dataset
            loader_func = data_loaders[data_path]
            data = loader_func(data_path)

            # Format item by item so one malformed record does not discard the dataset.
            standardized_data = []
            for item in data:
                try:
                    standardized_data.append(format_func(item))
                except Exception as e:
                    print(f"Error formatting item in {dataset_name}: {e}")

            # Print an example
            if standardized_data:
                print(f"Formatted {len(standardized_data)} examples from {dataset_name}")
                print(f"\n===== STANDARDIZED EXAMPLE FROM {dataset_name} =====")
                example = standardized_data[0]
                formatted_text = f"### Question: {example['question']}\n### Answer: {example['answer']}"
                print(formatted_text)
                print("====================================\n")

            # Add to combined dataset
            all_standardized_data.extend(standardized_data)

        except Exception as e:
            # Report the failure and continue with the remaining datasets.
            print(f"Error processing {dataset_name}: {e}")
            import traceback
            traceback.print_exc()

    print(f"\nTotal combined examples: {len(all_standardized_data)}")
    return all_standardized_data

# Standardize all datasets and keep the result as one flat list of question/answer dicts
combined_data = standardize_datasets(data_paths, dataset_names, data_loaders)


Standardizing ASDiv...
Loaded 2305 problems from ASDiv
Formatted 2305 examples from ASDiv

===== STANDARDIZED EXAMPLE FROM ASDiv =====
### Question: Seven red apples and two green apples are in the basket. How many apples are in the basket?
### Answer: Let me solve this step by step.
7+2=9
Therefore, the answer is 9 (apples).


Standardizing ParaMAWPS...
Loaded 13023 problems from ParaMAWPS
Formatted 13023 examples from ParaMAWPS

===== STANDARDIZED EXAMPLE FROM ParaMAWPS =====
### Question: Bryan took a look at his books as well . If Bryan has 56 books in each of his 9 bookshelves , how many books does he have in total ?
### Answer: Let me solve this step by step.
x=56*9
Therefore, the answer is 504.0.


Standardizing SVAMP...
Loaded 1000 problems from SVAMP
Formatted 1000 examples from SVAMP

===== STANDARDIZED EXAMPLE FROM SVAMP =====
### Question: Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack How much do you have to pay to buy each pack?
###

In [54]:
# Standardize every dataset but keep the results separate, keyed by dataset name
def standardize_datasets(data_paths, dataset_names, data_loaders, formatters=None):
    """Load and format several datasets, returning them in a dict by name.

    For each (path, name) pair the formatter registered for the name is looked
    up, the dataset is loaded with ``data_loaders[path](path)``, and every raw
    item is passed through the formatter. Progress, one formatted example per
    dataset, per-dataset counts, and all errors are printed.

    NOTE(review): an earlier cell of this notebook defines a function with the
    same name that returns one flat list; running this cell replaces it.

    Args:
        data_paths: Dataset locations, aligned index-by-index with
            ``dataset_names``.
        dataset_names: Names used for log messages, formatter lookup and as
            keys of the returned dict.
        data_loaders: Mapping from each path to a callable that takes the path
            and returns an iterable of raw items.
        formatters: Optional mapping from dataset name to a callable that turns
            one raw item into a dict with 'question' and 'answer' keys. When
            omitted, the notebook's format_asdiv / format_paramawps /
            format_svamp / format_dmath / format_aqua functions are used.

    Returns:
        dict: dataset name -> list of formatted items. Items whose formatter
        raised are skipped. A dataset whose processing raised, or whose name
        has no formatter, has no entry at all.
    """
    if formatters is None:
        formatters = {
            "ASDiv": format_asdiv,
            "ParaMAWPS": format_paramawps,
            "SVAMP": format_svamp,
            "DMath": format_dmath,
            "AQuA": format_aqua,
        }

    # Create a dictionary to store each dataset separately
    standardized_datasets = {}

    for data_path, dataset_name in zip(data_paths, dataset_names):
        try:
            print(f"\nStandardizing {dataset_name}...")

            # Resolve the formatter first so an unsupported dataset is not loaded for nothing.
            format_func = formatters.get(dataset_name)
            if format_func is None:
                print(f"No formatting function for {dataset_name}")
                continue

            # Load the dataset
            loader_func = data_loaders[data_path]
            data = loader_func(data_path)

            # Format item by item so one malformed record does not discard the dataset.
            standardized_data = []
            for item in data:
                try:
                    standardized_data.append(format_func(item))
                except Exception as e:
                    print(f"Error formatting item in {dataset_name}: {e}")

            # Print an example
            if standardized_data:
                print(f"Formatted {len(standardized_data)} examples from {dataset_name}")
                print(f"\n===== STANDARDIZED EXAMPLE FROM {dataset_name} =====")
                example = standardized_data[0]
                formatted_text = f"### Question: {example['question']}\n### Answer: {example['answer']}"
                print(formatted_text)
                print("====================================\n")

            # Store dataset separately in the dictionary
            standardized_datasets[dataset_name] = standardized_data

        except Exception as e:
            # Report the failure and continue with the remaining datasets.
            print(f"Error processing {dataset_name}: {e}")
            import traceback
            traceback.print_exc()

    # Print stats for each dataset
    print("\nDataset statistics:")
    for name, dataset in standardized_datasets.items():
        print(f"{name}: {len(dataset)} examples")
    total_examples = sum(len(dataset) for dataset in standardized_datasets.values())

    print(f"\nTotal examples across all datasets: {total_examples}")

    return standardized_datasets

# Standardize all datasets, keeping each one under its own name
dataset_dict = standardize_datasets(data_paths, dataset_names, data_loaders)

# Pull the individual datasets out of the dictionary.
# standardize_datasets only stores a dataset when its processing did not raise,
# so a lookup below fails with KeyError if that dataset could not be loaded.
asdiv_data = dataset_dict["ASDiv"]
paramawps_data = dataset_dict["ParaMAWPS"]
svamp_data = dataset_dict["SVAMP"]
dmath_data = dataset_dict["DMath"]
aqua_data = dataset_dict["AQuA"]


Standardizing ASDiv...
Loaded 2305 problems from ASDiv
Formatted 2305 examples from ASDiv

===== STANDARDIZED EXAMPLE FROM ASDiv =====
### Question: Seven red apples and two green apples are in the basket. How many apples are in the basket?
### Answer: Let me solve this step by step.
7+2=9
Therefore, the answer is 9 (apples).


Standardizing ParaMAWPS...
Loaded 13023 problems from ParaMAWPS
Formatted 13023 examples from ParaMAWPS

===== STANDARDIZED EXAMPLE FROM ParaMAWPS =====
### Question: Bryan took a look at his books as well . If Bryan has 56 books in each of his 9 bookshelves , how many books does he have in total ?
### Answer: Let me solve this step by step.
x=56*9
Therefore, the answer is 504.0.


Standardizing SVAMP...
Loaded 1000 problems from SVAMP
Formatted 1000 examples from SVAMP

===== STANDARDIZED EXAMPLE FROM SVAMP =====
### Question: Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack How much do you have to pay to buy each pack?
###

In [35]:
import os
import sys

import torch
import wandb
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig

In [15]:
# Start a Weights & Biases run in project "gpt2-math"
run = wandb.init(project="gpt2-math", name="curriculum-learning-sft")
# Reference the stored model artifact (version v0) from this run and download it;
# artifact_dir is then used as the local path for from_pretrained below
artifact = run.use_artifact('master_thesis_math_lm/gpt2-math/gpt2-math-model:v0', type='model')
artifact_dir = artifact.download()

# Load the tokenizer and model from the downloaded artifact directory
tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
model = AutoModelForCausalLM.from_pretrained(artifact_dir)

0,1
model_total_params,▁
model_trainable_params,▁

0,1
model_total_params,124439808
model_trainable_params,124439808


[34m[1mwandb[0m: Downloading large artifact gpt2-math-model:v0, 479.31MB. 8 files... 
[34m[1mwandb[0m:   8 of 8 files downloaded.  
Done. 0:0:1.2


In [16]:
# Use the end-of-sequence token as the padding token, for both the tokenizer and the model config
# (presumably because the loaded tokenizer defines no pad token of its own — TODO confirm)
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

In [17]:
# Select CUDA when available, otherwise CPU; the bare name below displays the choice.
# NOTE(review): no visible cell moves the model to this device explicitly — confirm it is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [64]:
def formatting_prompts_func_old(examples):
    """Render every example as one '### Question / ### Answer' text string.

    Each element of ``examples`` must provide 'question' and 'answer' entries.
    All examples are converted (no limit is applied); the result is a list of
    strings in input order.
    """
    return [
        f"### Question: {sample['question']}\n### Answer: {sample['answer']}"
        for sample in examples
    ]

def formatting_prompts_func(examples, limit=5):
    """Convert question/answer examples into instruction/output records.

    Args:
        examples: Sliceable sequence of mappings with 'question' and 'answer'
            entries.
        limit: Maximum number of leading examples to convert. Defaults to 5,
            the sample size this notebook previously hard-coded. Pass ``None``
            to convert every example.

    Returns:
        list of dicts with 'instruction' (the question) and 'output' (the
        answer), in input order.
    """
    selected = examples if limit is None else examples[:limit]
    return [
        {"instruction": example["question"], "output": example["answer"]}
        for example in selected
    ]

# Convert each standardized dataset to instruction/output records
# (formatting_prompts_func keeps only the first five examples of each)
asdiv_formatted = formatting_prompts_func(asdiv_data)
paramawps_formatted = formatting_prompts_func(paramawps_data)
svamp_formatted = formatting_prompts_func(svamp_data)
dmath_formatted = formatting_prompts_func(dmath_data)
aqua_formatted = formatting_prompts_func(aqua_data)

# asdiv_formatted already holds at most five records, so this slice keeps all of them;
# the bare name below displays the records
asdiv_small = asdiv_formatted[:5]
asdiv_small

[{'instruction': 'Seven red apples and two green apples are in the basket. How many apples are in the basket?',
  'output': 'Let me solve this step by step.\n7+2=9\nTherefore, the answer is 9 (apples).'},
 {'instruction': 'Ellen has six more balls than Marin. Marin has nine balls. How many balls does Ellen have?',
  'output': 'Let me solve this step by step.\n6+9=15\nTherefore, the answer is 15 (balls).'},
 {'instruction': 'Janet has nine oranges and Sharon has seven oranges. How many oranges do Janet and Sharon have together?',
  'output': 'Let me solve this step by step.\n9+7=16\nTherefore, the answer is 16 (oranges).'},
 {'instruction': 'Allan brought two balloons and Jake brought four balloons to the park. How many balloons did Allan and Jake have in the park?',
  'output': 'Let me solve this step by step.\n2+4=6\nTherefore, the answer is 6 (balloons).'},
 {'instruction': 'Adam has five more apples than Jackie. Jackie has nine apples. How many apples does Adam have?',
  'output': '

In [80]:
import pandas as pd
# Tabulate the small ASDiv sample: one row per record, columns 'instruction' and 'output'
df = pd.DataFrame(asdiv_small)

In [81]:
# Wrap the DataFrame in a `datasets.Dataset` so it can be mapped and passed to the trainer
dataset = Dataset.from_pandas(df)

In [82]:
def prepare_datasets(example):
    """Add a chat-style 'prompt' string to one example and return the example.

    The prompt consists of a system line, the user instruction followed by a
    period, and the assistant output, each introduced by its role tag on its
    own line. Every line starts at column 0: the previous triple-quoted
    template carried the source indentation into the prompt, so each line
    after the first began with four stray spaces.

    Args:
        example: Mapping with 'instruction' and 'output' entries; it is
            modified in place.

    Returns:
        The same ``example`` object with the 'prompt' entry set.
    """
    # NOTE(review): the system sentence and the period appended after the
    # instruction are kept verbatim; confirm the wording and that "</s>" is
    # the intended end marker for this tokenizer.
    prompt_lines = [
        "<|system|>",
        "You are a intelligent chatbot and expertise in Mathematics.</s>",
        "<|user|>",
        f"{example['instruction']}.",
        "<|assistant|>",
        f"{example['output']}",
    ]
    example['prompt'] = "\n".join(prompt_lines)
    return example

def tokenize_datasets(dataset):
    """Tokenize the 'prompt' column of ``dataset`` and drop that column.

    ``dataset.map`` is called in batched mode with a function that runs the
    notebook-level ``tokenizer`` on the 'prompt' values, truncating to at most
    512 tokens. The mapped dataset is returned.
    """
    def _encode_prompts(example):
        return tokenizer(
            example['prompt'],
            truncation=True,
            max_length=512,
        )

    return dataset.map(_encode_prompts, batched=True, remove_columns=['prompt'])

In [83]:
# Add the 'prompt' column to every row and drop the raw 'instruction'/'output' columns
dataset = dataset.map(
    prepare_datasets, remove_columns=['instruction', 'output']
)
# Disabled: shuffle, subsample and train/test split intended for a larger dataset
#dataset = dataset.shuffle(42).select(range(395000)).train_test_split(test_size=0.1, seed=42)

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [88]:
# Collator with masked-language-modelling switched off (mlm=False), i.e. for causal LM training.
# DataCollatorForLanguageModeling comes from transformers and must be imported in the imports cell.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

batch_size = 2
# Kept at the value previously passed (2), but no longer tied to the batch size variable.
gradient_accumulation_steps = 2
max_steps = 100
# Request fp16 mixed precision only when a CUDA device is present; on a CPU-only
# machine training then runs in full precision instead of depending on CPU half-precision support.
use_fp16 = torch.cuda.is_available()

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="./models/mathgpt2sft/",
        gradient_accumulation_steps=gradient_accumulation_steps,
        #evaluation_strategy="steps",
        do_eval=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        log_level="debug",
        save_strategy="no",
        save_total_limit=2,
        save_safetensors=False,
        fp16=use_fp16,
        logging_steps=50,
        learning_rate=2e-5,
        eval_steps=50,
        max_steps=max_steps,
        warmup_steps=30,
        # Column of `dataset` that holds the training text
        dataset_text_field="prompt",
        lr_scheduler_type="cosine"
    ),
    data_collator=data_collator
)


Converting train dataset to ChatML:   0%|          | 0/5 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend


In [89]:
# Start fine-tuning with the trainer built in the previous cell
trainer.train()


Currently training with a batch size of: 2
The following columns in the training set don't have a corresponding argument in `GPT2LMHeadModel.forward` and have been ignored: prompt. If prompt are not expected by `GPT2LMHeadModel.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 5
  Num Epochs = 100
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 2
  Total optimization steps = 100
  Number of trainable parameters = 124,439,808
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss


KeyboardInterrupt: 