# Unsloth GRPO Training for Bitcoin Enhanced Prediction

This notebook implements Group Relative Policy Optimization (GRPO) using Unsloth for comprehensive Bitcoin prediction.

**Dataset**: `bitcoin-enhanced-prediction-dataset-with-local-comprehensive-news`

**Training Method**: Unsloth GRPO
- Built-in preference learning optimization
- Efficient memory usage with Unsloth
- Streamlined training pipeline

## Install Libraries

In [None]:
# !pip install -U "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install -U xformers trl peft accelerate bitsandbytes

In [None]:
# Ensure protobuf uses pure-Python implementation early to avoid descriptor errors
# import os as _os
# (The commented lines below set environment variables via os.environ.setdefault().)

# _os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
# _os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# !pip install -U "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install -U xformers trl peft accelerate bitsandbytes

## Imports

In [None]:
# from unsloth import FastLanguageModel, is_bfloat16_supported
# from unsloth.chat_templates import get_chat_template
from datasets import load_dataset, Dataset
from transformers import TrainingArguments, AutoTokenizer, AutoModel
from peft import PeftModel
import torch, random, os
import json
import numpy as np
from datetime import datetime

# Try to import GRPO classes, with fallback to available alternatives.
# GRPO_AVAILABLE records whether trl's GRPOTrainer/GRPOConfig could be imported.
# NOTE(review): when GRPO_AVAILABLE is False, the names GRPOTrainer and GRPOConfig
# are never bound; check that every later use of them is guarded by GRPO_AVAILABLE.
# NOTE(review): PPOTrainer/PPOConfig/SFTTrainer/SFTConfig are imported on the
# fallback path but are not used in the visible cells (the message below mentions
# SFTTrainer, while the trainer cell falls back to transformers.Trainer) — confirm.
try:
    from trl import GRPOTrainer, GRPOConfig
    GRPO_AVAILABLE = True
    print("✅ GRPO classes imported successfully")
except ImportError:
    try:
        # Try alternative imports (GRPO might be under different names)
        from trl import PPOTrainer, PPOConfig
        from trl import SFTTrainer, SFTConfig
        GRPO_AVAILABLE = False
        print("⚠️ GRPOTrainer not found, will use SFTTrainer as fallback")
    except ImportError:
        # Last resort - use basic trainer
        from transformers import Trainer
        GRPO_AVAILABLE = False
        print("⚠️ Advanced TRL classes not found, using basic Trainer")

# Seed the PyTorch, Python `random` and NumPy RNGs with one shared value for
# reproducibility.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

## Configuration

In [None]:
# Model configuration
BASE_MODEL_NAME = "./Qwen3-8B"  # Preferred local base model (used only if this path exists)
FALLBACK_MODEL_NAME = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"  # Hub model used when the local path is missing
ADAPTER_PATH = "./my-awesome-model_final_bitcoin-enhanced-prediction-dataset-with-local-comprehensive-news-v2"  # Pre-trained adapter (folder)
CHECKPOINT = "checkpoint-400"  # Specific checkpoint within adapter folder
MAX_SEQ_LENGTH = 2048  # Context length requested when loading the model
DTYPE = torch.bfloat16  # Fixed compute dtype (NOT auto-detected); keep consistent with bf16=True in the training args
LOAD_IN_4BIT = True

# LoRA configuration
LORA_R = 32
LORA_ALPHA = 32
LORA_DROPOUT = 0.0
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# GRPO configuration
OUTPUT_DIR = "./qwen_bitcoin_enhanced_grpo_unsloth_pretrained_from_sft"
LEARNING_RATE = 3e-7  # Deliberately small: training continues from an SFT adapter
NUM_TRAIN_EPOCHS = 1
PER_DEVICE_TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
MAX_LENGTH = 1024
MAX_PROMPT_LENGTH = 512
BETA = 0.1

# Quick sanity run controls (useful for CI or first launch)
SANITY_RUN = False            # Set True to run a very short sanity training
SANITY_MAX_STEPS = 30         # Number of steps for sanity run
SANITY_DATASET_SIZE = 256     # Subset size for sanity run (NOTE(review): not referenced in the visible cells)

# Dataset
DATASET_NAME = "tahamajs/bitcoin-enhanced-prediction-dataset-with-local-comprehensive-news"

# Auxiliary model handed to the reward computation.
# NOTE(review): DialoGPT is a dialogue-generation model rather than a trained reward
# model, and the loading cell uses AutoModel (no scoring head) — confirm how its
# outputs are turned into a reward.
REWARD_MODEL_NAME = "microsoft/DialoGPT-medium"

## Load Model and Tokenizer

In [None]:
import torch
from pathlib import Path
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel
from unsloth import FastLanguageModel, get_chat_template

# # --- User-defined variables ---
# BASE_MODEL_NAME = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
# FALLBACK_MODEL_NAME = "mistralai/Mistral-7b-Instruct-v0.2"
# ADAPTER_PATH = "path/to/your/first/adapter"  # <<-- IMPORTANT: Set this path
# CHECKPOINT = "checkpoint-final"
# MAX_SEQ_LENGTH = 2048
# DTYPE = torch.bfloat16  # Set the desired data type here
# LOAD_IN_4BIT = True
# LORA_R = 16
# TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
# LORA_ALPHA = 16
# LORA_DROPOUT = 0
# SEED = 3407
# REWARD_MODEL_NAME = "starling-lm/reward-model"
# # --- End of user-defined variables ---


# Pick the base checkpoint: the local BASE_MODEL_NAME directory when it exists on
# disk, otherwise the FALLBACK_MODEL_NAME hub id.
preferred_path = Path(BASE_MODEL_NAME)
chosen_model_name = BASE_MODEL_NAME if preferred_path.exists() else FALLBACK_MODEL_NAME
if chosen_model_name != BASE_MODEL_NAME:
    print(f"ℹ️ Local model path not found at {BASE_MODEL_NAME}. Falling back to {FALLBACK_MODEL_NAME}")

print(f"🔄 Loading base model: {chosen_model_name}")
# Unsloth returns both the model and its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=chosen_model_name,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=DTYPE,  # Use the configured DTYPE variable
    load_in_4bit=LOAD_IN_4BIT,
)

# Ensure a pad token exists by reusing EOS. No embedding resize happens here;
# embeddings are only resized below if the adapter checkpoint needs a larger vocab.
if getattr(tokenizer, "pad_token", None) is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the pre-trained (first) adapter from ADAPTER_PATH/CHECKPOINT if present.
# On failure, the error text is parsed for a shape mismatch; if the checkpoint's
# first dimension is larger, placeholder special tokens are added, embeddings are
# resized, and loading is retried once. Otherwise the run continues without it.
first_adapter_loaded = False
adapter_path = f"{ADAPTER_PATH}/{CHECKPOINT}"
if not Path(adapter_path).exists():
    print(f"ℹ️ First adapter checkpoint not found at {adapter_path}. Skipping adapter load.")
else:
    print(f"🔄 Loading pre-trained adapter: {adapter_path}")
    try:
        # Load the first adapter without merging.
        # NOTE(review): PeftModel.from_pretrained is called without is_trainable=True,
        # which (per PEFT docs) loads the adapter frozen — confirm that is intended.
        model = PeftModel.from_pretrained(model, adapter_path)
        first_adapter_loaded = True
        print(f"✅ Successfully loaded adapter from {adapter_path}")
    except Exception as e:
        err = str(e)
        print(f"⚠️ Could not load adapter on first try: {e}")

        # Try to detect and fix vocab size mismatch: pull the first dimension of the
        # checkpoint tensor and of the current model tensor out of the error message.
        import re
        try:
            expected_match = re.search(r"copying a param with shape torch\.Size\(\[(\d+),", err)
            current_match = re.search(r"current model is torch\.Size\(\[(\d+),", err)
            if expected_match and current_match:
                expected = int(expected_match.group(1))
                current = int(current_match.group(1))
                missing = expected - current
                if missing > 0:
                    print(f"🔧 Detected vocab mismatch. Adding {missing} special token(s) to align.")
                    # Placeholder tokens exist only to pad the vocabulary size.
                    extra_tokens = [f"<|extra_{i}|>" for i in range(missing)]
                    # NOTE(review): depending on the transformers version this may replace
                    # (not extend) the tokenizer's existing additional_special_tokens list.
                    tokenizer.add_special_tokens({"additional_special_tokens": extra_tokens})
                    # NOTE(review): `missing` is measured against embedding rows, but the
                    # resize target is len(tokenizer). If the embedding matrix was already
                    # larger than the tokenizer, the result will not equal `expected` — verify.
                    model.resize_token_embeddings(len(tokenizer))
                    # NOTE(review): retries on the same `model` object the failed attempt
                    # was applied to — confirm no partially injected adapter layers remain.
                    model = PeftModel.from_pretrained(model, adapter_path)
                    first_adapter_loaded = True
                    print(f"✅ Adapter loaded after aligning vocab size to {len(tokenizer)}")
                else:
                     # missing <= 0: the checkpoint dimension is not larger; no automatic fix.
                     print("ℹ️ Vocab sizes match but other mismatch detected. Proceeding without adapter.")
            else:
                print("ℹ️ No clear vocab mismatch pattern found. Proceeding without adapter.")
        except Exception as e2:
            print(f"⚠️ Auto-alignment failed: {e2}. Proceeding without adapter.")

# ⚠️ IMPORTANT: Merging into a 4-bit model is not recommended.
# Instead of merging, we keep the first adapter loaded and active.
# The new LoRA adapter for training will be applied on top.
if first_adapter_loaded:
    print("✅ First adapter is loaded and active. Skipping merge step for 4-bit model.")

# Apply the ChatML chat template (the prompt strings built in the formatting cell
# use the same <|im_start|>/<|im_end|> markers).
tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",
)

# Prepare model for new LoRA adapter training
print("🔧 Preparing model for new LoRA adapter training...")
try:
    # This will add a *new* adapter for training, while keeping the first one active.
    # NOTE(review): `model` may already be a PeftModel at this point; confirm that
    # get_peft_model really adds a second trainable adapter rather than skipping.
    # On any error the except branch below continues with the current model.
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_R,
        target_modules=TARGET_MODULES,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=SEED,
        use_rslora=False,
        loftq_config=None,
    )
    print("✅ New LoRA adapter initialized successfully for training")
except Exception as e:
    print(f"⚠️ Error initializing new LoRA adapter: {e}")
    print("ℹ️ Continuing with current model state...")

# Load the auxiliary reward model in the same dtype as the policy model.
# AutoModel is used (not a sequence-classification class), so no scoring head is
# requested; on any failure both handles are set to None (rule-based rewards only).
print(f"\n🔄 Loading reward model: {REWARD_MODEL_NAME}")
try:
    reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_NAME)
    reward_model = AutoModel.from_pretrained(
        REWARD_MODEL_NAME,
        torch_dtype=DTYPE, # Load in the same precision to save memory
        device_map="auto",
    )
    reward_model.eval()
    print(f"✅ Reward model loaded successfully")
except Exception as e:
    print(f"⚠️ Could not load reward model, using rule-based rewards: {e}")
    reward_model = None
    reward_tokenizer = None

# Summary of what was actually loaded.
print(f"\n📊 Model Configuration:")
print(f"  Base model: {chosen_model_name}")
print(f"  First adapter (loaded, not merged): {adapter_path if first_adapter_loaded else 'Not loaded'}")
print(f"  New LoRA adapter initialized for training with rank: {LORA_R}")
print(f"  Max sequence length: {MAX_SEQ_LENGTH}")
print(f"  Load in 4bit: {LOAD_IN_4BIT}")
print(f"  Data type: {DTYPE}")
print(f"  Reward model: {REWARD_MODEL_NAME if reward_model else 'Rule-based only'}")

## Load and Prepare Dataset

In [None]:
# Fetch the training split and report its size.
dataset = load_dataset(DATASET_NAME, split="train")
print(f"Dataset loaded: {DATASET_NAME}")
print(f"Total samples: {len(dataset):,}")

# Preview the first record; any value longer than 150 characters is clipped and
# marked with "...".
print("\n=== Sample Data ===")
sample = dataset[0]
for key, value in sample.items():
    rendered = str(value)
    clipped = rendered[:150] + ("..." if len(rendered) > 150 else "")
    print(f"{key}: {clipped}")

## Format Dataset for GRPO

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# `dataset` comes from the "Load and Prepare Dataset" cell above.
#
# MAX_LENGTH is intentionally NOT reassigned in this cell: it is set in the
# Configuration cell, and redefining it here silently replaced the configured
# value. The model's context window is MAX_SEQ_LENGTH (also from Configuration).

# ==============================================================================
# STEP 1: Format the text prompts
# ==============================================================================
def formatting_prompts_func(examples):
    """Build ChatML prompts and reference targets for a batch of examples.

    GRPOTrainer expects a single 'prompt' per sample and generates completions
    internally, so each prompt ends right after the assistant header. The
    reference answer is returned separately in 'target' for offline evaluation.

    Args:
        examples: Batched ``datasets`` mapping. Recognised columns are
            ``instruction`` (used as the system message; an empty or missing
            value falls back to a generic analyst persona), ``input`` (user
            message) and ``output`` (reference answer). Any of these columns may
            be missing; missing columns and ``None`` values count as "".

    Returns:
        ``{"prompt": [...], "target": [...]}`` with one entry per input row.
    """
    # Batch size comes from whichever columns are present, so a missing column
    # cannot silently shrink the batch to zero (zip stops at the shortest list).
    batch_size = max((len(column) for column in examples.values()), default=0)

    def _column(name):
        values = examples.get(name)
        return [""] * batch_size if values is None else list(values)

    instructions = _column("instruction")
    inputs = _column("input")
    outputs = _column("output")

    prompts = []
    targets = []
    for instruction, user_input, output in zip(instructions, inputs, outputs):
        system_msg = instruction or "You are a helpful Bitcoin market analyst."
        user_msg = user_input or ""
        # Build prompt ending right before the assistant's turn
        prompt = (
            f"<|im_start|>system\n{system_msg}<|im_end|>\n"
            f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
        prompts.append(prompt)
        # Store the ground truth completion separately
        targets.append((output or "") + "<|im_end|>")

    return {"prompt": prompts, "target": targets}

# Turn every raw row into a {"prompt", "target"} pair, dropping the source columns.
print("📝 Formatting dataset for Unsloth GRPO (prompt-only mode)...")
_source_columns = dataset.column_names
formatted_dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
    remove_columns=_source_columns,
    desc="Formatting prompts",
)

# Set the reference answers aside for offline evaluation; the training set keeps
# only the "prompt" column.
# NOTE(review): once "target" is removed, the training rows no longer carry their
# reference answer — confirm nothing downstream expects it during training.
reference_targets = formatted_dataset["target"]
formatted_dataset = formatted_dataset.remove_columns(["target"])

print(f"Formatted dataset samples: {len(formatted_dataset):,}")
print("Columns after formatting:", formatted_dataset.column_names)

_first_prompt = formatted_dataset[0]['prompt']
print("\n=== Formatted Sample Prompt (Text) ===")
print(_first_prompt[:500])
print("✅ Dataset text formatting complete.")


# ==============================================================================
# STEP 2: Tokenize the formatted prompts
# ==============================================================================
# Reuse the tokenizer returned alongside the model in the model-loading cell
# (ChatML template applied, vocabulary possibly extended to match the adapter).
# Loading a tokenizer from a different model family here would produce token ids
# that do not correspond to this model's embedding table.


def tokenize_function(examples):
    """Tokenize the 'prompt' column of a batched ``datasets`` example dict.

    Args:
        examples: Batch mapping with a ``"prompt"`` list of strings.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) truncated to
        MAX_SEQ_LENGTH tokens, plus a ``labels`` column holding a per-row copy of
        ``input_ids``.
    """
    tokenized_output = tokenizer(
        examples["prompt"],
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )
    # Copy each row so labels never share list objects with input_ids.
    tokenized_output["labels"] = [list(ids) for ids in tokenized_output["input_ids"]]
    return tokenized_output


print("\n⚡ Tokenizing the dataset...")
tokenized_dataset = formatted_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["prompt"],  # keep only the token-id columns
)

print(f"Tokenized dataset samples: {len(tokenized_dataset):,}")
print("✅ Final columns passed to trainer:", tokenized_dataset.column_names)

# tokenized_dataset holds input_ids / attention_mask / labels (the columns
# DataCollatorForLanguageModeling pads) and no longer has the raw "prompt"
# column, so it suits a plain causal-LM Trainer rather than a trainer that
# consumes raw prompts.

In [None]:
# Pass-through collator for raw prompt samples (no padding, no tensor conversion)
class IdentityCollator:
    """Collator that hands the feature list back untouched.

    Meant for trainers that tokenize raw prompts themselves: the object returned
    is the very list that was passed in.
    """

    def __call__(self, features):
        batch = features  # nothing to pad or stack
        return batch

raw_text_collator = IdentityCollator()

# Quick sanity check: show keys of first sample
first = formatted_dataset[0]
print("🔍 Sample keys:", list(first.keys()))
print("Prompt length:", len(first['prompt']))

# Optional safeguard: cap prompt length in characters.
MAX_PROMPT_CHARS = 4000
# Prompts built by formatting_prompts_func end with this tail (end of the user
# turn + assistant header). Truncation must keep it so the model is still asked
# to answer as the assistant.
_PROMPT_TAIL = "<|im_end|>\n<|im_start|>assistant\n"


def truncate_func(examples):
    """Shorten over-long prompts to at most MAX_PROMPT_CHARS characters.

    The cut is made inside the prompt body and marked with "..."; the trailing
    ChatML tail is preserved. A prompt without that tail is cut at the end as a
    last resort (previous behaviour).
    """
    prompts = []
    for p in examples['prompt']:
        if len(p) > MAX_PROMPT_CHARS:
            if p.endswith(_PROMPT_TAIL):
                keep = max(0, MAX_PROMPT_CHARS - len(_PROMPT_TAIL) - 3)
                body = p[:len(p) - len(_PROMPT_TAIL)]
                p = body[:keep] + "..." + _PROMPT_TAIL
            else:
                p = p[:MAX_PROMPT_CHARS] + "..."
        prompts.append(p)
    return {"prompt": prompts}


# Check every prompt (not just a sample) so no long prompt slips through.
# NOTE: this affects formatted_dataset only; tokenized_dataset was built earlier.
if any(len(p) > MAX_PROMPT_CHARS for p in formatted_dataset['prompt']):
    print("✂️ Truncating overlong prompts for safety...")
    formatted_dataset = formatted_dataset.map(truncate_func, batched=True)
    print("✅ Truncation pass complete")

In [None]:
# Helper Functions for Structured Output Parsing
def parse_trading_output(text):
    """Parse a trading-decision JSON object out of a model response.

    Expected format:
        {"action":"SELL","confidence":99,"stop_loss":10668.23,"take_profit":9377.95,"forecast_10d":[...]}

    Strategies, in order:
      1. the whole (stripped) text is a JSON object;
      2. the first flat ``{...}`` fragment mentioning ``"action"`` that parses
         as JSON (nested objects are not matched; arrays are fine);
      3. field-by-field regex extraction of action / confidence / stop_loss /
         take_profit / forecast_10d.

    Args:
        text: Model response; ``None``/empty returns ``None``.

    Returns:
        A dict with the recovered fields, or ``None``. Non-object JSON (e.g. a
        bare number or list) is never returned, because callers use ``.get``.

    NOTE: a later helper cell redefines ``parse_trading_output``; that later
    definition is the one in effect once both cells have run.
    """
    if not text:
        return None

    import json
    import re

    # 1. The whole text is a JSON object.
    try:
        parsed = json.loads(text.strip())
    except (ValueError, TypeError, AttributeError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Regex strategies need a str; anything else cannot be parsed further.
    if not isinstance(text, str):
        return None

    # 2. A flat JSON object that contains an "action" key.
    for fragment in re.findall(r'\{[^{}]*"action"[^{}]*\}', text, re.IGNORECASE | re.DOTALL):
        try:
            candidate = json.loads(fragment.strip())
        except ValueError:
            continue
        if isinstance(candidate, dict):
            return candidate

    # 3. Field-by-field fallback for malformed JSON.
    result = {}
    action_match = re.search(r'"action"\s*:\s*"([^"]+)"', text, re.IGNORECASE)
    if action_match:
        result['action'] = action_match.group(1).upper()

    number = r'(\d+(?:\.\d+)?)'
    for field in ('confidence', 'stop_loss', 'take_profit'):
        field_match = re.search(rf'"{field}"\s*:\s*{number}', text, re.IGNORECASE)
        if field_match:
            result[field] = float(field_match.group(1))

    forecast_match = re.search(r'"forecast_10d"\s*:\s*\[([^\]]+)\]', text, re.IGNORECASE)
    if forecast_match:
        try:
            result['forecast_10d'] = [float(x.strip()) for x in forecast_match.group(1).split(',')]
        except ValueError:
            pass  # unparsable forecast entries: leave the field out

    return result or None

def calculate_forecast_similarity(resp_forecast, gt_forecast):
    """Score how closely a predicted forecast tracks a reference forecast.

    Weighted blend, clamped to [0, 1]:
      * 40% Pearson correlation. Only positive correlation earns credit, so an
        inverted forecast is not counted as "similar".
      * 30% directional accuracy of consecutive moves (a flat step counts as down).
      * 30% shape similarity: ``1 - MSE/4`` of the z-scored series, floored at 0.

    Args:
        resp_forecast: Predicted values.
        gt_forecast: Reference values.

    Returns:
        A float in [0.0, 1.0]. It is 0.0 when either list is empty or fewer than
        two position-aligned numeric pairs remain.

    NOTE: a later helper cell redefines ``calculate_forecast_similarity``
    (MAPE-based); that later definition is the one in effect once both cells run.
    """
    if not resp_forecast or not gt_forecast:
        return 0.0

    import numpy as np

    def _is_number(value):
        # bool is a subclass of int but is not a price.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    # Pair by position first (zip stops at the shorter list), then drop pairs
    # with a non-numeric side, so one bad entry cannot shift later values.
    pairs = [(r, g) for r, g in zip(resp_forecast, gt_forecast) if _is_number(r) and _is_number(g)]
    if len(pairs) < 2:
        return 0.0

    try:
        resp_arr = np.array([float(r) for r, _ in pairs])
        gt_arr = np.array([float(g) for _, g in pairs])
    except OverflowError:  # ints too large for a float
        return 0.0

    similarity_score = 0.0

    # 1. Correlation similarity (40%); NaN (e.g. a constant series) adds nothing.
    with np.errstate(invalid="ignore", divide="ignore"):
        corr = np.corrcoef(resp_arr, gt_arr)[0, 1]
    if not np.isnan(corr):
        similarity_score += max(float(corr), 0.0) * 0.4

    # 2. Directional accuracy (30%)
    resp_up = np.diff(resp_arr) > 0
    gt_up = np.diff(gt_arr) > 0
    similarity_score += float(np.mean(resp_up == gt_up)) * 0.3

    # 3. Magnitude similarity (30%) on z-scored series; MSE of 4 or more scores 0.
    resp_norm = (resp_arr - resp_arr.mean()) / (resp_arr.std() + 1e-8)
    gt_norm = (gt_arr - gt_arr.mean()) / (gt_arr.std() + 1e-8)
    mse = float(np.mean((resp_norm - gt_norm) ** 2))
    similarity_score += max(0.0, 1 - mse / 4) * 0.3

    return min(1.0, max(0.0, similarity_score))


print("✅ Helper functions for structured output parsing defined")

In [None]:
import json
import re

def parse_trading_output(response_text):
    """Extract the trading-decision JSON object from a model response.

    Strategies, in order:
      1. the content of a fenced ```json ... ``` block;
      2. the whole (stripped) text parsed as JSON;
      3. the first JSON object that decodes starting at any ``{`` in the text,
         which covers raw JSON embedded in prose.

    Args:
        response_text: Model output. Non-string input returns ``None``.

    Returns:
        The first decoded JSON *object* (dict), or ``None``. Other JSON values
        such as lists are ignored, because callers use ``.get`` on the result.
    """
    if not isinstance(response_text, str):
        return None

    candidates = []
    fenced = re.search(r"```json\s*([\s\S]+?)\s*```", response_text)
    if fenced:
        candidates.append(fenced.group(1))
    candidates.append(response_text.strip())

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    # Fall back to the first JSON object embedded anywhere in the text.
    decoder = json.JSONDecoder()
    start = response_text.find("{")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(response_text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed
        start = response_text.find("{", start + 1)
    return None

def calculate_forecast_similarity(predicted, actual):
    """Score forecast accuracy as ``max(0, 1 - MAPE)``.

    Values are compared position by position up to the shorter list. Pairs whose
    reference value is not positive, or where either side is not a finite real
    number (e.g. a string, ``None`` or a bool from loosely formatted JSON), are
    skipped rather than raising.

    Args:
        predicted: Predicted values.
        actual: Reference values.

    Returns:
        Similarity in [0.0, 1.0]. An error of 10% gives 0.9. The result is 0.0
        when either list is empty or no usable pair remains.
    """
    if not predicted or not actual:
        return 0.0

    import math

    def _is_number(value):
        return (
            isinstance(value, (int, float))
            and not isinstance(value, bool)
            and math.isfinite(value)
        )

    # zip stops at the shorter list, matching the original min-length alignment.
    errors = [
        abs((p - a) / a)
        for p, a in zip(predicted, actual)
        if _is_number(p) and _is_number(a) and a > 0  # a > 0 also avoids division by zero
    ]
    if not errors:
        return 0.0

    mean_absolute_percentage_error = sum(errors) / len(errors)
    return max(0.0, 1.0 - mean_absolute_percentage_error)

def calculate_price_prediction_reward(response, ground_truth):
    """Reward only the numeric price accuracy of a prediction against ground truth.

    Scored fields:
      * ``stop_loss`` / ``take_profit``: ``max(0, 1 - |resp - gt| / gt)`` each,
        so a perfect match gives 1.0 and a 20% miss gives 0.8;
      * ``forecast_10d``: ``calculate_forecast_similarity`` when both sides are lists.

    A field that is missing or non-numeric on either side is skipped.

    Args:
        response: Model output text.
        ground_truth: Reference text (``None`` yields 0.0).

    Returns:
        The mean of the scored fields, in [0.0, 1.0]. It is 0.0 when either side
        has no parsable JSON object or no field could be scored.
    """
    response_json = parse_trading_output(response)
    ground_truth_json = parse_trading_output(ground_truth)

    # Only JSON objects can be scored; a list or other value would break .get below.
    if not isinstance(response_json, dict) or not isinstance(ground_truth_json, dict):
        return 0.0

    def _is_price(value):
        # bool is a subclass of int, but JSON true/false is not a price.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    price_rewards = []

    # 1. Price level accuracy (stop loss and take profit)
    for price_field in ('stop_loss', 'take_profit'):
        resp_price = response_json.get(price_field)
        gt_price = ground_truth_json.get(price_field)
        if _is_price(resp_price) and _is_price(gt_price) and gt_price > 0:
            price_diff_pct = abs((resp_price - gt_price) / gt_price)
            price_rewards.append(max(0, 1 - price_diff_pct))

    # 2. Forecast accuracy (10-day prediction array)
    resp_forecast = response_json.get('forecast_10d')
    gt_forecast = ground_truth_json.get('forecast_10d')
    if isinstance(resp_forecast, list) and isinstance(gt_forecast, list):
        price_rewards.append(calculate_forecast_similarity(resp_forecast, gt_forecast))

    if not price_rewards:
        return 0.0
    return sum(price_rewards) / len(price_rewards)

### Example Usage ###

# --- Test Case 1: Close match ---


## Enhanced Reward Function with Structured Output Parsing

The reward function has been enhanced to handle structured JSON trading outputs with the following format:
```json
{
  "action": "SELL",
  "confidence": 99,
  "stop_loss": 10668.23,
  "take_profit": 9377.95,
  "forecast_10d": [8830.75, 9174.91, 8277.01, 6955.27, 7754.00, 7621.30, 8265.59, 8736.98, 8621.90, 8129.97]
}
```

### Reward Distribution:
- **Structured Output Parsing (25%)**: Parses and validates JSON format, compares action/confidence/prices/forecast with ground truth
- **Prediction Quality (20%)**: Keywords, length, comprehensive analysis indicators
- **Technical Analysis (15%)**: Technical indicators, chart patterns, trading signals
- **News Integration (15%)**: News impact assessment, multi-factor analysis
- **Specificity (10%)**: Price targets, timeframes, confidence levels
- **AI Assessment (10%)**: Conversational quality using reward model
- **Bonuses (5%)**: Structure, disclaimers, professional formatting

### Key Features:
- **JSON Parsing**: Extracts structured trading data from model responses
- **Action Matching**: Compares trading actions (BUY/SELL/HOLD) with ground truth
- **Confidence Scoring**: Rewards confidence levels close to expected values
- **Price Accuracy**: Evaluates stop_loss and take_profit price levels
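- **Forecast Similarity**: Compares 10-day price predictions. The first helper cell combines correlation, directional accuracy, and magnitude similarity. A later cell redefines `calculate_forecast_similarity` as a MAPE-based score (`max(0, 1 - MAPE)`), and that later definition is the one in effect after both cells run.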

In [None]:
# Custom GRPO Trainer with Enhanced Reward Integration
class CustomGRPOTrainer(GRPOTrainer):
    """GRPOTrainer whose reward comes from ``calculate_comprehensive_prediction_reward``.

    The reward function is injected as ``reward_funcs``; any ``reward_funcs``
    passed by the caller is replaced. ``reward_model`` / ``reward_tokenizer`` are
    forwarded to the reward function, and either may be ``None``.

    NOTE(review): ``calculate_comprehensive_prediction_reward`` is not defined in
    the visible cells; confirm it exists before training starts.
    NOTE(review): this class statement needs ``GRPOTrainer``, which is only
    imported when ``GRPO_AVAILABLE`` is True.
    """

    def __init__(self, reward_model=None, reward_tokenizer=None, **kwargs):
        # TRL's GRPOTrainer takes the tokenizer as ``processing_class``. Accept the
        # older ``tokenizer=`` spelling too, instead of forwarding an unknown kwarg.
        legacy_tokenizer = kwargs.pop("tokenizer", None)
        if legacy_tokenizer is not None:
            kwargs.setdefault("processing_class", legacy_tokenizer)

        # Pass the reward function method to the parent class
        kwargs['reward_funcs'] = [self._compute_reward_batch]

        super().__init__(**kwargs)

        self.reward_model = reward_model
        self.reward_tokenizer = reward_tokenizer

    @staticmethod
    def _completion_text(completion):
        """Return the text of a completion.

        A plain string is returned as-is. A list of ``{"content": ...}`` chat
        messages has its contents concatenated. Anything else is returned unchanged.
        """
        if isinstance(completion, list):
            return "".join(
                str(message.get("content", "")) for message in completion if isinstance(message, dict)
            )
        return completion

    def _compute_reward_batch(self, prompts=None, completions=None, **kwargs_inner):
        """Score a batch of completions, returning one reward per completion.

        Args:
            prompts: Prompts for the batch (unused).
            completions: Generated completions (strings or chat-message lists).
            **kwargs_inner: Extra per-sample columns from the trainer. If a
                ``target`` list is present, its i-th entry is used as the ground
                truth for the i-th completion; otherwise ground truth is ``None``.

        Returns:
            A list of rewards of the same length as ``completions``. A completion
            whose scoring raises gets the neutral fallback 0.5; the others keep
            their scores, so one bad sample does not flatten the group's signal.
        """
        if completions is None:
            return []

        targets = kwargs_inner.get("target")
        rewards = []
        for idx, completion in enumerate(completions):
            ground_truth = targets[idx] if targets is not None and idx < len(targets) else None
            try:
                reward_score = calculate_comprehensive_prediction_reward(
                    response=self._completion_text(completion),
                    ground_truth=ground_truth,
                    reward_model=self.reward_model,
                    reward_tokenizer=self.reward_tokenizer,
                )
            except Exception as e:
                print(f"Warning: Error during reward computation for completion {idx}: {e}")
                reward_score = 0.5
            rewards.append(reward_score)
        return rewards

    def log_reward_details(self, response, reward_score):
        """Print a reward breakdown for one response (debugging aid)."""
        print(f"\n=== Reward Analysis ===")
        print(f"Response length: {len(response)} chars")
        print(f"Reward score: {reward_score:.4f}")

        response_json = parse_trading_output(response)

        if isinstance(response_json, dict) and response_json:
            print(f"Structured output found:")
            print(f"  Action: {response_json.get('action', 'N/A')}")
            print(f"  Confidence: {response_json.get('confidence', 'N/A')}")
            print(f"  Stop Loss: {response_json.get('stop_loss', 'N/A')}")
            print(f"  Take Profit: {response_json.get('take_profit', 'N/A')}")
            # A present-but-null or non-list forecast would break slicing/len.
            forecast = response_json.get('forecast_10d')
            if not isinstance(forecast, list):
                forecast = []
            print(f"  Forecast: {forecast[:3]}... ({len(forecast)} values)")
        else:
            print("No structured output found - using text-based scoring")

        print("=" * 25)

print("✅ Corrected CustomGRPOTrainer created.")

In [None]:
# # Test Enhanced Reward Function with Structured Output
# print("🧪 Testing Enhanced Reward Function with Structured Output")
# print("=" * 60)

# # Example model response with structured output
# example_response = '''Based on my analysis of Bitcoin's current market conditions, technical indicators, and recent news sentiment, here is my prediction:

# {"action":"SELL","confidence":85,"stop_loss":11200.50,"take_profit":9500.75,"forecast_10d":[10450.20, 10100.85, 9750.40, 9500.75, 9200.30, 8950.10, 9150.60, 9400.25, 9300.80, 9150.45]}

# This prediction is based on bearish divergence in RSI, declining institutional interest, and regulatory concerns affecting market sentiment.'''

# # Example ground truth for comparison
# example_ground_truth = '''{"action":"SELL","confidence":90,"stop_loss":11000.00,"take_profit":9400.00,"forecast_10d":[10400.00, 10050.00, 9700.00, 9450.00, 9200.00, 8900.00, 9100.00, 9350.00, 9250.00, 9100.00]}'''

# # Test parsing
# print("📊 Testing JSON Parsing:")
# parsed_response = parse_trading_output(example_response)
# parsed_gt = parse_trading_output(example_ground_truth)

# print("Response JSON:", parsed_response)
# print("Ground Truth JSON:", parsed_gt)

# # Test reward calculation
# print(f"\n🏆 Testing Reward Calculation:")
# reward_score = calculate_comprehensive_prediction_reward(
#     response=example_response,
#     ground_truth=example_ground_truth,
#     reward_model=reward_model,
#     reward_tokenizer=reward_tokenizer
# )

# print(f"Reward Score: {reward_score:.4f}")

# # Test forecast similarity
# if parsed_response and parsed_gt:
#     forecast_sim = calculate_forecast_similarity(
#         parsed_response.get('forecast_10d', []),
#         parsed_gt.get('forecast_10d', [])
#     )
#     print(f"Forecast Similarity: {forecast_sim:.4f}")

# print("\n✅ Enhanced reward function testing completed!")

## Setup GRPO Training

In [None]:
# Training arguments
# max_steps=-1 lets num_train_epochs decide the training length; a sanity run
# stops after SANITY_MAX_STEPS steps instead.
max_steps = SANITY_MAX_STEPS if SANITY_RUN else -1

# Settings shared by transformers.TrainingArguments and its TRL subclass GRPOConfig.
_common_training_kwargs = dict(
    output_dir=OUTPUT_DIR,
    max_steps=max_steps,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    bf16=True,  # matches DTYPE = torch.bfloat16 in the Configuration cell
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    logging_steps=10,
    # Sanity runs save nothing; full runs checkpoint every 100 steps.
    save_steps=0 if SANITY_RUN else 100,
    save_strategy="no" if SANITY_RUN else "steps",
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    weight_decay=0.01,
    max_grad_norm=1.0,
    # Do not drop dataset columns (such as "prompt") the model signature lacks.
    remove_unused_columns=False,
    dataloader_num_workers=2,
    seed=SEED,
    report_to="none",
)

if GRPO_AVAILABLE:
    # beta and max_prompt_length are GRPOConfig fields, so they belong here rather
    # than in the trainer constructor.
    grpo_args = GRPOConfig(
        **_common_training_kwargs,
        beta=BETA,
        max_prompt_length=MAX_PROMPT_LENGTH,
    )
else:
    # GRPOConfig was never imported; plain TrainingArguments keep the fallback
    # Trainer path usable.
    grpo_args = TrainingArguments(**_common_training_kwargs)

print(f"🎯 Training Configuration:")
print(f"  Config class: {type(grpo_args).__name__}")
print(f"  Output dir: {OUTPUT_DIR}")
print(f"  Max steps: {max_steps} (epochs: {NUM_TRAIN_EPOCHS})")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Effective batch size: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
if GRPO_AVAILABLE:
    print(f"  Beta: {BETA} | Max prompt length: {MAX_PROMPT_LENGTH}")

## Initialize GRPO Trainer

In [None]:
# Initialize the trainer on the pre-tokenized dataset: CustomGRPOTrainer first, falling
# back to a plain transformers Trainer when GRPO is unavailable or fails to initialize.
print("🔧 Initializing GRPO Trainer with pre-tokenized dataset...")

# Sampling settings intended for GRPO rollouts.
# NOTE(review): not currently forwarded to the trainer (the argument below is commented out).
generation_kwargs = {
    "max_new_tokens": 1024,
    "do_sample": True,
    "top_k": 50,
    "temperature": 0.7,
}

from transformers import Trainer, DataCollatorForLanguageModeling

# Causal-LM collator (mlm=False): pads tokenized batches and builds labels from input_ids.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're doing causal LM, not masked LM
)


def _build_standard_trainer(train_dataset):
    """Return a plain transformers Trainer that uses the shared causal-LM collator."""
    return Trainer(
        model=model,
        tokenizer=tokenizer,
        args=grpo_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )


if GRPO_AVAILABLE:
    try:
        grpo_trainer = CustomGRPOTrainer(
            model=model,
            tokenizer=tokenizer,
            args=grpo_args,
            train_dataset=tokenized_dataset,  # Contains input_ids, attention_mask
            reward_model=reward_model,
            reward_tokenizer=reward_tokenizer,
            max_length=MAX_LENGTH,
            max_prompt_length=MAX_PROMPT_LENGTH,
            beta=BETA,
            # generation_kwargs=generation_kwargs,
            data_collator=data_collator,
        )
        print("✅ CustomGRPOTrainer initialized successfully")
        print(f"📊 Dataset columns: {tokenized_dataset.column_names}")

    except Exception as e:
        # Deliberate best-effort fallback: report why GRPO failed, then train with a plain Trainer.
        print(f"⚠️ Failed to initialize GRPO trainer: {e}")
        print("🔄 Falling back to standard Trainer")
        # The LM collator expects tokenized features, so the fallback must use tokenized_dataset too.
        grpo_trainer = _build_standard_trainer(tokenized_dataset)
        print("✅ Fallback Trainer initialized with tokenized data")
else:
    print("⚠️ GRPO not available, using standard Trainer")
    grpo_trainer = _build_standard_trainer(tokenized_dataset)
    print("✅ Standard Trainer initialized with tokenized data")

effective_batch_size = PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
# In TrainingArguments, max_steps > 0 overrides num_train_epochs; otherwise estimate from data size.
if grpo_args.max_steps > 0:
    estimated_total_steps = grpo_args.max_steps
else:
    estimated_total_steps = len(tokenized_dataset) // effective_batch_size * NUM_TRAIN_EPOCHS

print(f"🎯 Training ready with {type(grpo_trainer).__name__}")
print(f"📊 Training dataset: {len(tokenized_dataset):,} samples")
print(f"🔧 Trainer configuration:")
print(f"  • Effective batch size: {effective_batch_size}")
print(f"  • Total training steps: {estimated_total_steps}")
print(f"  • Learning rate: {LEARNING_RATE}")
print(f"  • Dataloader workers: {grpo_args.dataloader_num_workers}")
print(f"  • Sanity run: {SANITY_RUN}")

if SANITY_RUN:
    print(f"  • Sanity max steps: {SANITY_MAX_STEPS}")
    print("⚠️ SANITY_RUN is enabled - this will be a short test run")

In [None]:
# Re-initialize the trainer on the pre-tokenized dataset. CustomGRPOTrainer is tried first
# (it handles its own data collation); a plain transformers Trainer is the fallback.
print("🔧 Initializing GRPO Trainer...")
print("✅ Final check: Columns being passed to trainer:", tokenized_dataset.column_names)

from transformers import Trainer, DataCollatorForLanguageModeling


def _build_causal_lm_trainer():
    """Return ``(collator, trainer)`` for plain causal-LM training on tokenized_dataset.

    DataCollatorForLanguageModeling(mlm=False) pads the batch and creates ``labels`` from
    ``input_ids``. DataCollatorWithPadding does not create labels, which leaves a plain
    Trainer on a causal LM without a loss.
    """
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=grpo_args,
        train_dataset=tokenized_dataset,
        data_collator=collator,
    )
    return collator, trainer


if GRPO_AVAILABLE:
    try:
        grpo_trainer = CustomGRPOTrainer(
            model=model,
            tokenizer=tokenizer,
            args=grpo_args,  # Use GRPOConfig
            train_dataset=tokenized_dataset,
            reward_model=reward_model,
            reward_tokenizer=reward_tokenizer,
            # GRPO-specific parameters
            max_length=MAX_LENGTH,
            max_prompt_length=MAX_PROMPT_LENGTH,
            beta=BETA,
            # Let GRPO handle its own data collation for tokenized data
        )

        print("✅ CustomGRPOTrainer initialized successfully")
        print(f"📊 Training dataset format: {tokenized_dataset.column_names}")

    except Exception as e:
        # Deliberate best-effort fallback: report why GRPO failed, then train with a plain Trainer.
        print(f"⚠️ Failed to initialize GRPO trainer: {e}")
        print("🔄 Falling back to standard trainer")
        data_collator, grpo_trainer = _build_causal_lm_trainer()
        print("✅ Fallback trainer initialized with DataCollatorForLanguageModeling")

else:
    print("⚠️ GRPO not available, using standard training")
    data_collator, grpo_trainer = _build_causal_lm_trainer()
    print("✅ Standard trainer initialized with DataCollatorForLanguageModeling")

effective_batch_size = PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
# In TrainingArguments, max_steps > 0 overrides num_train_epochs; otherwise estimate from data size.
if grpo_args.max_steps > 0:
    estimated_total_steps = grpo_args.max_steps
else:
    estimated_total_steps = len(tokenized_dataset) // effective_batch_size * NUM_TRAIN_EPOCHS

print(f"🎯 Training ready with {type(grpo_trainer).__name__}")
print(f"📊 Training dataset: {len(tokenized_dataset):,} samples")
print(f"🔧 Trainer configuration:")
print(f"  • Effective batch size: {effective_batch_size}")
print(f"  • Total training steps: {estimated_total_steps}")
print(f"  • Learning rate: {LEARNING_RATE}")
print(f"  • Sanity run: {SANITY_RUN}")

if SANITY_RUN:
    print(f"  • Sanity max steps: {SANITY_MAX_STEPS}")
    print("⚠️ SANITY_RUN is enabled - this will be a short test run")

## Start GRPO Training

In [None]:
# Run training with the trainer built above and report timing, final loss and step count.
print("🚀 Starting Unsloth GRPO Training...")
# In TrainingArguments, max_steps > 0 overrides num_train_epochs, so report whichever applies.
if grpo_args.max_steps > 0:
    print(f"Training up to {grpo_args.max_steps:,} step(s) on {len(formatted_dataset):,} samples")
else:
    print(f"Training {NUM_TRAIN_EPOCHS} epoch(s) on {len(formatted_dataset):,} samples")
print("="*60)

start_time = datetime.now()
print(f"Training started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

trainer_stats = grpo_trainer.train()

end_time = datetime.now()
training_duration = end_time - start_time

# getattr with a default never raises here, so no try/except is needed; either value may be
# None if the trainer's return object lacks the attribute.
final_loss = getattr(trainer_stats, "training_loss", None)
steps_done = getattr(trainer_stats, "global_step", None)

print("\n" + "="*60)
print("🎉 GRPO Training Completed!")
print(f"Training finished at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total training time: {training_duration}")
print(f"Final training loss: {final_loss if final_loss is not None else 'N/A'}")
print(f"Training steps: {steps_done if steps_done is not None else 'N/A'}")

## Save Model

In [None]:
# Save the final model/tokenizer and write a JSON summary of the run.
print("💾 Saving trained model...")

final_model_dir = f"{OUTPUT_DIR}/final_model"
model.save_pretrained(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

print(f"✅ Model saved to: {final_model_dir}")

# GRPO_AVAILABLE only means the TRL import succeeded; the trainer cells fall back to a
# plain Trainer when GRPO initialization fails, so label the run by the trainer actually used.
trainer_class_name = type(grpo_trainer).__name__
used_grpo = GRPO_AVAILABLE and "GRPO" in trainer_class_name
training_method = "Unsloth GRPO" if used_grpo else "Unsloth SFT/Basic"
training_summary = {
    "base_model_name": chosen_model_name,
    "adapter_path": f"{ADAPTER_PATH}/{CHECKPOINT}",
    "adapter_loaded": adapter_loaded,
    "dataset": DATASET_NAME,
    "training_method": training_method,
    "grpo_available": GRPO_AVAILABLE,
    "trainer_class": trainer_class_name,
    "total_samples": len(formatted_dataset),
    "training_config": {
        "epochs": training_config.num_train_epochs,
        "max_steps": training_config.max_steps if training_config.max_steps > 0 else None,
        "learning_rate": training_config.learning_rate,
        "batch_size": training_config.per_device_train_batch_size,
        "gradient_accumulation_steps": training_config.gradient_accumulation_steps,
        "sanity_run": SANITY_RUN,
        "sanity_max_steps": SANITY_MAX_STEPS if SANITY_RUN else None,
        "sanity_dataset_size": SANITY_DATASET_SIZE if SANITY_RUN else None,
    },
    "training_results": {
        "final_loss": final_loss,
        "total_steps": steps_done,
        "training_duration": str(training_duration),
    },
    "timestamps": {
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
    },
    "model_path": final_model_dir,
}

# Add GRPO-specific config if available. Only `beta` is checked by hasattr, so read the
# length fields defensively: config classes do not all define them.
if GRPO_AVAILABLE and hasattr(training_config, "beta"):
    training_summary["training_config"]["grpo_beta"] = training_config.beta
    training_summary["training_config"]["max_length"] = getattr(training_config, "max_length", None)
    training_summary["training_config"]["max_prompt_length"] = getattr(
        training_config, "max_prompt_length", None
    )

summary_path = f"{OUTPUT_DIR}/training_summary.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(training_summary, f, indent=2)

print(f"Training summary saved to: {summary_path}")

## Test Trained Model

In [None]:
def test_trained_model(model, tokenizer, messages=None, max_new_tokens=1024, temperature=0.7):
    """
    Generate, print and return a response from the trained model for a test chat prompt.

    Args:
        model: The trained causal LM (presumably PEFT/LoRA-wrapped — confirm).
        tokenizer: Tokenizer providing ``apply_chat_template``.
        messages: Optional list of chat messages. Defaults to a Bitcoin
            price-prediction system/user pair.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded completion text, with the prompt tokens excluded.
    """
    print("🧪 Testing the trained GRPO model...")

    # The Unsloth import is commented out at the top of this notebook, so FastLanguageModel
    # may be undefined; use standard eval mode instead of raising NameError.
    fast_language_model = globals().get("FastLanguageModel")
    if fast_language_model is not None:
        fast_language_model.for_inference(model)
    else:
        model.eval()

    if messages is None:
        messages = [
            {"role": "system", "content": "You are an expert Bitcoin market analyst. Provide accurate and insightful analysis."},
            {"role": "user", "content": "Based on recent market trends and news, what is your Bitcoin price prediction for the next week? Please provide detailed analysis."}
        ]

    test_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    print("Test prompt:")
    print(test_prompt)
    print("\n" + "="*50)

    # The rendered chat template already carries the model's special tokens (e.g. BOS), so
    # don't let the tokenizer prepend them a second time.
    inputs = tokenizer(test_prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # Prefer the tokenizer's real pad token; fall back to EOS when none is configured.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Strip the prompt tokens so only the newly generated completion is decoded.
    prompt_length = len(inputs.input_ids[0])
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True
    )

    print("Model Response:")
    print(response)
    print("\n✅ Model testing completed!")
    return response

# Run the test
test_trained_model(model, tokenizer)

## Training Summary

In [None]:
# Print a human-readable summary of the finished run.
# final_loss / steps_done come from the training cell and may be None, so format them
# defensively instead of applying numeric format specs to a possibly-None value.
loss_text = f"{final_loss:.4f}" if final_loss is not None else "N/A"
steps_text = f"{steps_done:,}" if steps_done is not None else "N/A"

# The trainer cells can fall back to a plain Trainer, so report the method actually used.
trainer_class_name = type(grpo_trainer).__name__
if GRPO_AVAILABLE and "GRPO" in trainer_class_name:
    method_text = "Unsloth GRPO (Group Relative Policy Optimization)"
else:
    method_text = f"Standard fine-tuning ({trainer_class_name} fallback)"

print("📊 Unsloth GRPO Training Summary")
print("=" * 50)
print(f"🤖 Model: {MODEL_NAME}")
print(f"📚 Dataset: {DATASET_NAME}")
print(f"📈 Training method: {method_text}")
print(f"📝 Total samples: {len(formatted_dataset):,}")
print()
print("🎯 Training Configuration:")
print(f"  • Epochs: {NUM_TRAIN_EPOCHS}")
print(f"  • Learning rate: {LEARNING_RATE}")
print(f"  • Batch size: {PER_DEVICE_TRAIN_BATCH_SIZE}")
print(f"  • Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"  • Effective batch size: {PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  • Max length: {MAX_LENGTH}")
print(f"  • Max prompt length: {MAX_PROMPT_LENGTH}")
print(f"  • Beta (KL penalty): {BETA}")
print(f"  • LoRA rank: {LORA_R}")
print()
print("📊 Training Results:")
print(f"  • Final loss: {loss_text}")
print(f"  • Training steps: {steps_text}")
print(f"  • Training duration: {training_duration}")
print()
print("💾 Outputs:")
print(f"  • Model saved to: {OUTPUT_DIR}/final_model")
print(f"  • Summary saved to: {OUTPUT_DIR}/training_summary.json")
print()
print("🔬 Key Features:")
print("  ✅ Unsloth-optimized GRPO training")
print("  ✅ Memory-efficient 4-bit quantization")
print("  ✅ LoRA parameter-efficient fine-tuning")
print("  ✅ Preference learning for Bitcoin analysis")
print("  ✅ Chat template formatting")
print("  ✅ Gradient checkpointing for memory optimization")
print()
print("🎉 Unsloth GRPO training completed successfully!")
print("📈 Model ready for Bitcoin prediction tasks!")