In [11]:
import os
import torch
import wandb
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from datasets import Dataset

parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from utils.data import load_asdiv_data, load_paramawps_data, load_svamp_data, load_aqua_data, load_dmath_data

# Initialize Weights & Biases
run = wandb.init(project="gpt2-math", name="curriculum-learning-sft")
artifact = run.use_artifact('master_thesis_math_lm/gpt2-math/gpt2-math-model:v0', type='model')
artifact_dir = artifact.download()

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(artifact_dir)
model = AutoModelForCausalLM.from_pretrained(artifact_dir)

tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id


# Define dataset paths
BASE_DIR = os.getcwd()
# Your existing code for finding the data directory
print("Locating data directory...")
data_root = None
possible_roots = [
    BASE_DIR,
    os.path.abspath(os.path.join(BASE_DIR, '..')),
    os.path.abspath(os.path.join(BASE_DIR, '../..')),
    os.path.abspath(os.path.join(BASE_DIR, '../../..'))
]
for root in possible_roots:
    test_path = os.path.join(root, "data")
    if os.path.exists(test_path):
        data_root = root
        print(f"Found data directory at: {test_path}")
        break
if data_root is None:
    print("Could not find data directory. Please specify the path manually.")
    data_root = BASE_DIR  # Default to current directory

# Define the data paths relative to data_root
data_paths = [
    os.path.join(data_root, "data", "curriculum_learning", "1_ASDiv", "ASDiv.xml"),
    os.path.join(data_root, "data", "curriculum_learning", "2_ParaMAWPS", "ParaMAWPS_trainset.json"),
    os.path.join(data_root, "data", "curriculum_learning", "3_SVAMP", "SVAMP.json"),
    os.path.join(data_root, "data", "curriculum_learning", "4_Dmath", "dmath_train.json"),
    os.path.join(data_root, "data", "curriculum_learning", "5_AQuA", "AQuA_train.json")
]

# Dataset names for logging
dataset_names = ["ASDiv", "ParaMAWPS", "SVAMP", "DMath", "AQuA"]

# Verify data files exist
for path, name in zip(data_paths, dataset_names):
    if os.path.exists(path):
        print(f"✓ Found {name} dataset at: {path}")
    else:
        print(f"✗ Could not find {name} dataset at: {path}")


def preprocess_data(data, tokenizer):
    return tokenizer(
        [sample['text'] for sample in data],
        truncation=True,
        max_length=1024,
        padding="max_length",
        return_tensors="pt"
    )

    # Dictionary mapping dataset file paths to their respective loading functions
data_loaders = {
    data_paths[0]: load_asdiv_data,
    data_paths[1]: load_paramawps_data,
    data_paths[2]: load_svamp_data,
    data_paths[3]: load_dmath_data,
    data_paths[4]: load_aqua_data
}


# Training loop over datasets
print("\nStarting curriculum learning...")
for path, name in zip(data_paths, dataset_names):
    print(f"\n{'='*50}")
    print(f"Training on {name} dataset")
    print(f"{'='*50}")
    
    if not os.path.exists(path):
        print(f"Skipping {name} - file not found")
        continue
    
    try:
        print(f"Loading data from {path}")
        # Pass the file path to the loader function
        raw_data = data_loaders[path](path)[:5]  # Select only 5 samples
        print(f"Loaded {len(raw_data)} samples")
        
        dataset = Dataset.from_list(raw_data)
        print("Created dataset object")
        
        tokenized_data = preprocess_data(dataset, tokenizer)
        print("Tokenized the data")
        
        # Set up training arguments
        training_args = TrainingArguments(
            output_dir=f"./models/gpt2-math-curriculum/{name}",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            save_steps=10,
            logging_steps=1,
            num_train_epochs=1,
            report_to=["wandb"],
        )
        
        # Initialize trainer
        trainer = SFTTrainer(
            model=model,
            train_dataset=tokenized_data,
            args=training_args,
            tokenizer=tokenizer
        )
        
        # Train the model
        print(f"Training on {name} dataset...")
        trainer.train()
        print(f"Completed training on {name} dataset")
        
        # Save the model for this dataset
        model_save_path = f"./models/gpt2-math-curriculum/{name}"
        model.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)
        print(f"Saved model to {model_save_path}")
        
        # Log to wandb
        wandb.log({f"trained_on_{name}": True})
        
    except Exception as e:
        print(f"Error training on {name} dataset: {e}")
        continue

print("\nCurriculum learning complete!")
wandb.finish()


[34m[1mwandb[0m: Downloading large artifact gpt2-math-model:v0, 479.31MB. 8 files... 
[34m[1mwandb[0m:   8 of 8 files downloaded.  
Done. 0:0:1.4


Locating data directory...
Found data directory at: /work/math-reasoning-in-language-models/data
✓ Found ASDiv dataset at: /work/math-reasoning-in-language-models/data/curriculum_learning/1_ASDiv/ASDiv.xml
✓ Found ParaMAWPS dataset at: /work/math-reasoning-in-language-models/data/curriculum_learning/2_ParaMAWPS/ParaMAWPS_trainset.json
✓ Found SVAMP dataset at: /work/math-reasoning-in-language-models/data/curriculum_learning/3_SVAMP/SVAMP.json
✓ Found DMath dataset at: /work/math-reasoning-in-language-models/data/curriculum_learning/4_Dmath/dmath_train.json
✓ Found AQuA dataset at: /work/math-reasoning-in-language-models/data/curriculum_learning/5_AQuA/AQuA_train.json

Starting curriculum learning...

Training on ASDiv dataset
Loading data from /work/math-reasoning-in-language-models/data/curriculum_learning/1_ASDiv/ASDiv.xml
Loaded 2305 problems from ASDiv
Loaded 5 samples
Created dataset object
Tokenized the data
Error training on ASDiv dataset: 'str' object has no attribute 'keys'

T

  trainer = SFTTrainer(




Training on SVAMP dataset
Loading data from /work/math-reasoning-in-language-models/data/curriculum_learning/3_SVAMP/SVAMP.json
Loaded 1000 problems from SVAMP
Loaded 5 samples
Created dataset object
Tokenized the data
Error training on SVAMP dataset: 'str' object has no attribute 'keys'

Training on DMath dataset
Loading data from /work/math-reasoning-in-language-models/data/curriculum_learning/4_Dmath/dmath_train.json
Loaded 7943 problems from DMath
Loaded 5 samples
Created dataset object
Tokenized the data
Error training on DMath dataset: 'str' object has no attribute 'keys'

Training on AQuA dataset
Loading data from /work/math-reasoning-in-language-models/data/curriculum_learning/5_AQuA/AQuA_train.json
Loaded 97467 problems from AQuA
Loaded 5 samples
Created dataset object
Tokenized the data
Error training on AQuA dataset: 'str' object has no attribute 'keys'

Curriculum learning complete!
