<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/grpo_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install condacolab
!pip install -q condacolab
import condacolab
condacolab.install()
# Restart runtime here

In [1]:
!conda --version

conda 24.11.2


In [2]:
!conda env list


# conda environments:
#
base                   /usr/local



In [3]:
!nvidia-smi

Mon Jan 27 16:51:25 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0              47W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Update Conda (optional but recommended)
!conda update -n base -c defaults conda

# Create and activate conda environment
!conda create -n openr1 python=3.11 -y
!conda activate openr1

# Clone the Open-R1 repository:
!git clone https://github.com/huggingface/open-r1.git

# Change to project directory
%cd /content/open-r1

# Install necessary packages
!pip install -e ".[dev]"
!pip install vllm==0.6.6.post1 -q
!pip install vllm==0.6.6.post1 --extra-index-url https://download.pytorch.org/whl/cu121 -q


# Unset WANDB_DISABLED if it exists
import os
if 'WANDB_DISABLED' in os.environ:
    del os.environ['WANDB_DISABLED']


In [None]:
!accelerate launch --config_file /content/open-r1/configs/zero3.yaml /content/open-r1/src/open_r1/grpo.py \
    --output_dir DeepSeek-R1-Distill-Qwen-7B-GRPO \
    --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
    --dataset_name AI-MO/NuminaMath-TIR \
    --max_prompt_length 256 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --logging_steps 10 \
    --bf16

In [1]:
!pip install bitsandbytes -q
!pip install datasets -q
!pip install transformers -q
!pip install torch -q
!pip install accelerate -q
!pip install tqdm -q

TRAINING

In [None]:
# Unset WANDB_DISABLED if it exists
import os
if 'WANDB_DISABLED' in os.environ:
    del os.environ['WANDB_DISABLED']

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import bitsandbytes as bnb
import os
from accelerate import Accelerator  # Import Accelerator
from tqdm import tqdm  # Import tqdm

# Parameters
output_dir = "DeepSeek-R1-Distill-Qwen-7B-GRPO"
max_prompt_length = 256
per_device_train_batch_size = 1  # Not explicitly used in this simplified example
gradient_accumulation_steps = 16  # Not explicitly used in this simplified example
logging_steps = 10

# Function to run training within a subprocess (or directly)
def train_func(args):
    # Initialize Accelerator
    accelerator = Accelerator(mixed_precision="bf16", device_placement=False, split_batches=False)

    # 1. Load the Dataset
    dataset = load_dataset("AI-MO/NuminaMath-TIR")

    # 2. Load the DeepSeek-R1-Distill-Qwen-7B model with 4-bit quantization
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")

    quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_enable_fp32_cpu_offload=True  # Enable CPU offloading
    )

    # Load the model without specifying a device_map
    model = AutoModelForCausalLM.from_pretrained(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        quantization_config=quantization_config,
    )

    # Explicitly move the entire model to GPU 0
    model.to(accelerator.device)

    output_dim = 10  # Or the appropriate number of actions for your environment


    # This should be done once, before you initialize the PolicyNetwork
    env = SimpleEnvironment(dataset, tokenizer, model, accelerator, output_dim)
    state_embedding = env.get_state()  # Get a sample state embedding
    input_dim = state_embedding.shape[-1]  # Now get the input size
    output_dim = 10  # Assume 10 possible actions (adapt to your task)


    # Create models, optimizer
    #policy_network = PolicyNetwork(input_dim, output_dim, accelerator.device)  # Pass device

    policy_network = PolicyNetwork(state_embedding.shape[-1], output_dim, accelerator.device) # Using state_embedding.shape[-1] to get actual input size
    value_network = ValueNetwork(input_dim, accelerator.device)  # Pass device

    # In your GRPOTrainer class or where you define the optimizer
    optimizer = optim.Adam(policy_network.parameters(), lr=0.001)  # Try a smaller learning rate


    # Prepare with accelerate (moves models and optimizer to device)
    policy_network, value_network, optimizer, env = accelerator.prepare(
       policy_network, value_network, optimizer, env
   )



    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Pass the model to GRPOTrainer
    trainer = GRPOTrainer(env, policy_network, value_network, optimizer, learning_rate=0.01, accelerator=accelerator, tokenizer=tokenizer, model=model)

    trainer.train(num_epochs=1, num_trajectories_per_epoch=10)

    # Save the model using the standard PyTorch save method
    torch.save(accelerator.unwrap_model(policy_network).state_dict(), os.path.join(output_dir, 'policy_network.pth'))

# 3. Define Policy Network
import torch
import torch.nn as nn

class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, device):
        super(PolicyNetwork, self).__init__()

        # Update the input dimension for fc1 to match the state embedding size
        self.fc1 = nn.Linear(input_dim, 128)

        self.relu = nn.ReLU()  # Using ReLU activation
        self.fc2 = nn.Linear(128, output_dim)
        self.device = device

        # Apply weight initialization (example using Kaiming/He)
        torch.nn.init.kaiming_uniform_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
        torch.nn.init.kaiming_uniform_(self.fc2.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        x = x.to(self.device).type(torch.float32)
        x = self.relu(self.fc1(x))  # Apply ReLU activation
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)  # Softmax for probability distribution

# 4. Define Value Network
class ValueNetwork(nn.Module):
    def __init__(self, input_dim, device):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 1)
        self.device = device  # Store the device


    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 5. Environment (simplified example - assumes 'text' and 'label' fields in dataset)
class SimpleEnvironment:
    def __init__(self, dataset, tokenizer, model, accelerator, output_dim):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.model = model
        self.accelerator = accelerator  # Store accelerator
        self.current_index = 0
        self.episode_length = 20  # Set the desired episode length
        self.output_dim = output_dim  # Store output_dim as an attribute
        # Get initial state embedding and initialize PolicyNetwork with the correct input_dim
        initial_state_embedding = self.get_state()  # Call get_state() to get the initial embedding
        self.input_dim = initial_state_embedding.shape[-1]  # Store the initial input dimension
        self.policy_network = PolicyNetwork(self.input_dim, self.output_dim, self.accelerator.device)

    def reset(self):
        self.current_index = 0
        return self.get_state()

    def get_state(self):
        text_description = self.dataset["train"][self.current_index].get("text", "")
        text_description = text_description[:max_prompt_length]

        inputs = self.tokenizer(text_description, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        # In the SimpleEnvironment.get_state() function:
        state_embedding = outputs.hidden_states[-1][:, 0, :].to(self.accelerator.device)

        # Define extra_features before using it
        extra_features = torch.tensor([self.current_index / self.episode_length], device=self.accelerator.device)  # Example extra feature

        # Add an extra dimension to extra_features using unsqueeze() to match state_embedding shape
        extra_features = extra_features.unsqueeze(0).repeat(state_embedding.shape[0], 1)  # Repeat for all elements in batch

        # **Change is here:**
        state_embedding = torch.cat([state_embedding, extra_features], dim=-1)

        # Update input_dim based on the final shape of state_embedding
        input_dim = state_embedding.shape[-1]

        # Re-initialize the PolicyNetwork if the input dimension has changed
        if hasattr(self, 'policy_network') and self.policy_network.fc1.in_features != input_dim:
            self.policy_network = PolicyNetwork(input_dim, self.output_dim, self.accelerator.device)

        return state_embedding.to(torch.float32)

    def step(self, action):
        #print("Entering env.step()")  # Print at the beginning
        target = self.dataset["train"][self.current_index].get("label", 0)  # Assume 0 if 'label' missing
        #print("Target obtained:", target)  # Print target value

        #reward = 1.0 if action == target else -1.0
        #print("Reward calculated:", reward)  # Print reward value

        # Reward calculation using self.output_dim
        if action == target:
            reward = 1.0
        else:
            reward = -0.1 + 0.9 * (1 - abs(action - target) / (self.output_dim - 1))

        self.current_index += 1
        done = self.current_index >= self.episode_length  # Check if episode length is reached
        next_state = self.get_state() if not done else None

        #print("Exiting env.step()")  # Print before returning

        return next_state, reward, done

# 6. Trajectory Collection
import torch

def collect_trajectories(env, policy_network, num_trajectories, accelerator):
    trajectories = []
    for _ in range(num_trajectories):
        state = env.reset()
        trajectory = []
        while True:
            # Move policy_network and state to the correct device
            policy_network = policy_network.to(accelerator.device)
            state = state.to(accelerator.device)


            # Inside the collect_trajectories() function or train() function
            action_probs = policy_network(state)
            #print(f"Action probabilities: {action_probs}")



            # Check for invalid probabilities (e.g., NaN or probabilities not summing to 1)
            if not torch.isfinite(action_probs).all() or not torch.allclose(action_probs.sum(), torch.tensor(1.0, device=action_probs.device)):
                # If invalid, assign uniform probabilities to all actions
                action_probs = torch.ones_like(action_probs) / action_probs.shape[-1]

            # Sample an action from the probability distribution
            action = torch.multinomial(action_probs, 1).item()

            # Take a step in the environment
            next_state, reward, done = env.step(action)

            # Move next_state to the correct device if it's not None
            if next_state is not None:
                next_state = next_state.to(accelerator.device)

            # Append the current state, action, and reward to the trajectory
            trajectory.append((state, action, reward))

            # Check if the episode is done
            if done:
                break

            # Update the current state for the next step
            state = next_state

        # Append the completed trajectory to the list of trajectories
        trajectories.append(trajectory)

    # Return the collected trajectories
    return trajectories

# 7. GRPO Update (simplified - needs customization)
import torch

def update_policy(policy_network, value_network, trajectories, optimizer, accelerator):
    policy_loss_total = 0

    # Move both networks to the accelerator device before processing trajectories
    policy_network = policy_network.to(accelerator.device)
    value_network = value_network.to(accelerator.device)

    for trajectory in trajectories:
        # Calculate total rewards for each step in the trajectory (using a simple Monte Carlo estimate)
        rewards = [traj[2] for traj in trajectory]  # List of rewards
        total_rewards = [sum(rewards[i:]) for i in range(len(rewards))]  # Calculate cumulative rewards

        for i in range(len(trajectory)):
            state, action, reward = trajectory[i]

            # 1. Estimate advantage (using a simple Monte Carlo estimate)
            advantage = total_rewards[i] - value_network(state).item()

            # 2. Calculate policy gradient
            action_probs = policy_network(state)
            log_prob = torch.log(action_probs[0][action])  # Assuming action_probs is a batch of size 1
            policy_loss = -advantage * log_prob

            # Accumulate policy loss
            policy_loss_total += policy_loss

            # Debugging print statements (placed within the loop)
            print(f"State: {state}")  # Check if state is changing
            print(f"Action Probabilities: {action_probs}")  # Check if probabilities are always the same
            print(f"Advantage: {advantage}")  # Check if advantage is meaningful
            print(f"Log Probability: {log_prob}")  # Check if log probability is valid
            print(f"Policy Loss (before accumulation): {policy_loss}")  # Check if policy loss is non-zero

    # Debugging Prints: Total Policy Loss
    print(f"Total policy loss: {policy_loss_total}")

    # Perform the optimization step
    optimizer.zero_grad()
    policy_loss_total.backward(retain_graph=True)  # Backpropagate the total loss, retain_graph=True

    # Print gradients of policy network parameters
    for name, param in policy_network.named_parameters():
        print(f"Gradient of {name}: {param.grad}")

    optimizer.step()  # Update the policy network parameters

    return policy_loss_total



# 8. GRPO Trainer
class GRPOTrainer:
    def __init__(self, env, policy_network, value_network, optimizer, learning_rate, accelerator, tokenizer, model):
        self.env = env
        self.policy_network = policy_network
        self.value_network = value_network
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.global_step = 0
        self.policy_loss = 0
        self.accelerator = accelerator
        self.tokenizer = tokenizer
        self.model = model

        # Prepare with accelerate
        self.policy_network, self.value_network, self.optimizer, self.env = accelerator.prepare(
            policy_network, value_network, optimizer, env
        )

        # Create evaluation dataset and dataloader
        eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="test")
        self.eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=1,  # Reduced batch size
            collate_fn=lambda examples: self.process_batch(examples)
        )

    def process_batch(self, examples):
        # Extract the relevant fields for tokenization, joining list elements if necessary
        texts = [
            example['problem'] +
            (example['solution'] if isinstance(example['solution'], str) else ' '.join(example['solution'])) +
            (example['messages'] if isinstance(example['messages'], str) else ' '.join(str(item) for item in example['messages']))
            for example in examples
        ]

        # Tokenize the extracted texts
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=self.model.config.max_position_embeddings)

        # Other necessary batching processes
        labels = [example.get('label', 0) for example in examples]

        # Correct the input labels to match input_ids shape
        inputs['labels'] = torch.tensor(labels).unsqueeze(1).repeat(1, inputs['input_ids'].shape[1]) # shape (batch_size, seq_len)
        inputs['labels'] = torch.where(inputs['attention_mask'] == 1, inputs['labels'], -100) # shape (batch_size, seq_len)

        return inputs

    def calculate_accuracy(self, predictions, ground_truth):
        # Calculate accuracy
        accuracy = np.mean(predictions == ground_truth)  # Removed argmax since predictions are already single values
        return accuracy

    def collect_trajectories(self, num_trajectories):
        return collect_trajectories(self.env, self.policy_network, num_trajectories, self.accelerator)

    def update_policy(self, trajectories):
        self.policy_loss = update_policy(self.policy_network, self.value_network, trajectories, self.optimizer, self.accelerator)

    def train(self, num_epochs, num_trajectories_per_epoch):
        total_trajectories = num_epochs * num_trajectories_per_epoch
        for epoch in range(num_epochs):
            for trajectory_index in tqdm(range(num_trajectories_per_epoch), desc=f"Epoch {epoch + 1}/{num_epochs}", total=total_trajectories, position=0, leave=True):
                # Collect trajectories and update policy within the trajectory loop
                trajectories = self.collect_trajectories(1)

                with self.accelerator.accumulate(self.policy_network):
                    self.policy_loss = update_policy(self.policy_network, self.value_network, trajectories, self.optimizer, self.accelerator)

                    # Add print statement to check if policy_loss is detached
                    print(f"policy_loss.is_leaf: {self.policy_loss.is_leaf}, policy_loss.requires_grad: {self.policy_loss.requires_grad}")

                    self.accelerator.backward(self.policy_loss, retain_graph=True)  # retain_graph=True added
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    print(f"Policy Loss: {self.policy_loss.item()}")

            # Accuracy calculation and logging (POC) - performed after each epoch
            all_predictions = []
            all_ground_truth_labels = []

            for batch in self.eval_dataloader:
                with torch.no_grad():
                    inputs = {k: v.to(self.accelerator.device) for k, v in batch.items()}
                    outputs = self.model(**inputs, output_hidden_states=True)
                    states = outputs.hidden_states[-1][:, 0, :].to(torch.float32).to(self.accelerator.device)

                ground_truth_labels = batch['labels'].to(self.accelerator.device)
                self.policy_network = self.policy_network.to(self.accelerator.device)

                for state, label_sequence in zip(states, ground_truth_labels):
                    with torch.no_grad():
                        #prediction = self.policy_network(state.squeeze(0))
                        prediction = self.policy_network(state.unsqueeze(0)) # Add unsqueeze(0) here



                    batch_prediction = prediction.argmax(dim=-1).item()
                    ground_truth_label = label_sequence[0].item()

                    all_predictions.append(batch_prediction)
                    all_ground_truth_labels.append(ground_truth_label)

            accuracy = self.calculate_accuracy(np.array(all_predictions), np.array(all_ground_truth_labels))
            print(f"Epoch {epoch + 1}/{num_epochs}, Accuracy: {accuracy}")



# Entry point for your script
if __name__ == "__main__":
    # Create a namespace with the desired values
    import argparse
    args = argparse.Namespace(mixed_precision="bf16")  # Add other arguments as needed

    train_func(args)  # Call the training function

EVALUATION

In [None]:
#!export MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
#!export MODEL_ARGS="pretrained=$MODEL,dtype=float16,max_model_length=32768,gpu_memory_utilisation=0.8"
#!export TASK=aime24
#!export OUTPUT_DIR=data/evals/$MODEL

#!lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
#    --custom-tasks src/open_r1/evaluate.py \
#    --use-chat-template \
#    --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
#    --output-dir $OUTPUT_DIR

EVALUATION - python

In [19]:
accelerator = Accelerator(mixed_precision="bf16", device_placement=False, split_batches=False)
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")

# Load the model with quantization config (as in training)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    llm_int8_enable_fp32_cpu_offload=True
)
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    quantization_config=quantization_config
)
model.to(accelerator.device)

# Load the dataset
dataset = load_dataset("AI-MO/NuminaMath-TIR")

# Create the environment
env = SimpleEnvironment(dataset, tokenizer, model, accelerator)  # Pass 'model' here

# Define input and output dimensions for networks
input_dim = model.config.hidden_size  # Use 'model' here
output_dim = 10  # Or the appropriate number of actions for your environment

# Create policy and value networks (as before)
policy_network = PolicyNetwork(input_dim, output_dim, accelerator.device)
value_network = ValueNetwork(input_dim, accelerator.device)

# *** Load the fine-tuned policy network ***
policy_network.load_state_dict(torch.load("/content/DeepSeek-R1-Distill-Qwen-7B-GRPO/policy_network.pth"))

# Create optimizer (as before)
learning_rate = 0.01
optimizer = optim.Adam(policy_network.parameters(), lr=learning_rate)

# Create the GRPOTrainer instance
trainer = GRPOTrainer(env, policy_network, value_network, optimizer, learning_rate, accelerator, tokenizer, model)  # Pass 'model' here

# Perform evaluation
trainer.evaluate()

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  policy_network.load_state_dict(torch.load("/content/DeepSeek-R1-Distill-Qwen-7B-GRPO/policy_network.pth"))


Evaluation Accuracy: 1.0
