In [1]:
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Static settings shared by the rest of this cell.
class CONFIG:
    """Namespace of class-level constants; read as ``CONFIG.<name>``, never instantiated here."""

    # Hugging Face Hub identifier handed to ``from_pretrained``.
    model_name: str = 'meta-llama/Meta-Llama-3-8B-Instruct'
    # Path-like settings; not referenced anywhere in the visible code.
    dataset_path: str = "data"
    init_result_path: str = "results"
    
# Load the tokenizer and model once, at module level; both are globals that
# llama3_8b_corrective_prompt() below reads directly.
tokenizer = AutoTokenizer.from_pretrained(CONFIG.model_name)
# NOTE(review): this loads TRL's value-head wrapper, yet the visible code only
# calls .generate() on it — confirm the value head is actually needed here.
model = AutoModelForCausalLMWithValueHead.from_pretrained(CONFIG.model_name)

def llama3_8b_corrective_prompt(problem, previous_solution, correction_hint, max_new_tokens=150):
    """
    Perform corrective prompting using the module-level model and tokenizer.

    Parameters:
    problem (str): The problem statement.
    previous_solution (str): The previous incorrect solution.
    correction_hint (str): The hint or correction to guide the model.
    max_new_tokens (int): Upper bound on the number of tokens generated
        *after* the prompt. Defaults to 150.

    Returns:
    str: The text generated after the prompt, with special tokens removed and
        surrounding whitespace stripped.
    """
    prompt = (
        f"Problem: {problem}\n"
        f"Previous Solution: {previous_solution}\n"
        f"Correction Hint: {correction_hint}\n"
        "New Solution:"
    )

    inputs = tokenizer(prompt, return_tensors="pt")
    # Remember how many tokens belong to the prompt so they can be sliced off
    # the output below.
    prompt_length = inputs["input_ids"].shape[-1]

    # max_new_tokens budgets only the continuation. The previous max_length=150
    # also counted the prompt tokens, so a long problem/hint left little or no
    # room for the answer.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )

    # The returned sequence starts with the prompt tokens. Decoding only the
    # tail avoids splitting on the literal "New Solution:" marker, which picks
    # the wrong segment when that marker appears in the inputs or is repeated
    # in the generated text.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

def categorize_responses(incorrect_pairs, evaluate_function=None):
    """
    Categorize responses into accepted and rejected after corrective prompting.

    For every item a correction hint containing the expected solution is built,
    a new solution is requested from ``evaluate_function``, and the item is
    accepted when the new solution equals the expected one (ignoring leading
    and trailing whitespace), otherwise rejected.

    Parameters:
    incorrect_pairs (list): List of dictionaries with the keys 'problem',
        'expected_solution' and 'actual_solution'.
    evaluate_function (callable): Called as
        ``evaluate_function(problem, previous_solution, correction_hint)`` and
        expected to return the new solution as a string. When None, the
        module-level ``llama3_8b_corrective_prompt`` is used.

    Returns:
    tuple: A tuple containing JSON strings of accepted and rejected datasets.

    Raises:
    KeyError: If an item lacks one of the required keys.
    """
    # The argument used to be ignored in favour of a hard-coded call to
    # llama3_8b_corrective_prompt; honour it, keeping that function as default.
    if evaluate_function is None:
        evaluate_function = llama3_8b_corrective_prompt

    accepted = []
    rejected = []

    for item in incorrect_pairs:
        problem = item['problem']
        expected_solution = item['expected_solution']
        previous_solution = item['actual_solution']
        correction_hint = f"The correct solution should be: {expected_solution}"

        new_solution = evaluate_function(problem, previous_solution, correction_hint)

        # Both outcomes store the same record; only the destination differs.
        record = {
            'problem': problem,
            'expected_solution': expected_solution,
            'previous_solution': previous_solution,
            'new_solution': new_solution,
        }
        is_match = new_solution.strip() == expected_solution.strip()
        (accepted if is_match else rejected).append(record)

    return json.dumps(accepted, indent=4), json.dumps(rejected, indent=4)

# Demo run: feed two hand-written incorrect pairs through the categorizer.
# Note: Replace the following incorrect_pairs with actual incorrect pairs data
incorrect_pairs = [
    {
        'problem': 'If a train travels 60 miles in 1 hour, how far will it travel in 4 hours?',
        'expected_solution': '240 miles',
        'actual_solution': '250 miles',
    },
    {
        'problem': 'Sarah has 5 apples. She buys 7 more apples. How many apples does she have now?',
        'expected_solution': '12 apples',
        'actual_solution': '11 apples',
    },
]

accepted_json, rejected_json = categorize_responses(
    incorrect_pairs,
    llama3_8b_corrective_prompt,
)

# Show the raw JSON payloads first, then the size of each bucket.
print("Accepted Responses JSON:", accepted_json)
print("Rejected Responses JSON:", rejected_json)

# The function returns JSON strings, so parse them back to count the entries.
accepted_count = len(json.loads(accepted_json))
rejected_count = len(json.loads(rejected_json))
print(f"Number of accepted responses: {accepted_count}")
print(f"Number of rejected responses: {rejected_count}")



tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]