In [1]:
import re
import torch
import torch.nn as nn
#import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from trl import PPOConfig#, PPOTrainer
#import luxai_s3
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
#from luxai_s3.params import EnvParams
import numpy as np
from datasets import load_dataset, Dataset
#from peft import LoraConfig, get_peft_model
import os
#from accelerate import infer_auto_device_map
import gc
#import copy
gc.enable()

#from stable_baselines3 import PPO
#import gymnasium as gym
#import gym

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
os.environ["FLASH_ATTENTION"] = "1"
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.cache_size_limit = 64
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
np.set_printoptions(linewidth=200)
# Configure CUDA memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,garbage_collection_threshold:0.8"
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = False

# Enable gradient checkpointing
os.environ["PYTORCH_ATTENTION_USE_MEMORY_EFFICIENT_ATTENTION"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Prompt constants for the GSM8K dataset helpers below.

# System prompt asking for a <reasoning> section followed by an <answer> section.
# Built line by line; the resulting string starts and ends with a newline.
SYSTEM_PROMPT = "\n".join([
    "",
    "Respond in the following format:",
    "<reasoning>",
    "...",
    "</reasoning>",
    "<answer>",
    "...",
    "</answer>",
    "",
])

# Template for a complete response; fill with
# XML_COT_FORMAT.format(reasoning=..., answer=...). Ends with a newline.
XML_COT_FORMAT = "\n".join([
    "<reasoning>",
    "{reasoning}",
    "</reasoning>",
    "<answer>",
    "{answer}",
    "</answer>",
    "",
])

In [4]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped text of the last ``<answer>`` block in ``text``.

    Scanning starts after the final ``<answer>`` tag and stops at the first
    ``</answer>`` that follows it. Missing tags are tolerated: without ``<answer>``
    the scan starts at the beginning of ``text``, and without ``</answer>`` it runs
    to the end.
    """
    _, _, after_open = text.rpartition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="train") -> Dataset:
    """Load a GSM8K split and turn each row into a chat prompt plus gold answer.

    Each mapped row gets:
      * ``prompt``: a two-message chat (``SYSTEM_PROMPT`` as system, the question as user);
      * ``answer``: the text after ``####`` in the reference solution
        (see ``extract_hash_answer``), or None if there is no marker.
    """
    def to_chat_example(row):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']},
        ]
        return {'prompt': messages, 'answer': extract_hash_answer(row['answer'])}

    gsm8k_split = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    return gsm8k_split.map(to_chat_example)  # type: ignore

#dataset = get_gsm8k_questions()

In [5]:
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Tokenizer for the policy model; the EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 weight quantization with bf16 compute.
#   bnb_4bit_use_double_quant: nested quantization — the quantization constants are
#       themselves quantized, saving a little more memory.
#   bnb_4bit_quant_storage: dtype used to store the packed 4-bit weights (4-bit
#       quantization itself is already enabled by load_in_4bit).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage="bfloat16",
)

# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True
# )

In [6]:
response_length: int = 512  # response-length budget, fed to PPOConfig and the position-embedding limit

In [7]:
# Config overrides applied to every backbone loaded by create_model.
policy_autoconfig = AutoConfig.from_pretrained(model_name)
# Position budget: 4000 tokens (presumably the prompt budget — TODO confirm it covers the
# longest game-state prompt) plus response_length generated tokens.
policy_autoconfig.max_position_embeddings = response_length + 4000
# KV cache off; the models are trained with gradient checkpointing (see create_model).
# NOTE(review): if the trainer's generate() inherits this flag, decoding runs without a
# KV cache and is much slower — confirm what the generation call uses.
policy_autoconfig.use_cache = False

In [8]:
def create_model(trust_remote_code: bool = False):
    """Load a fresh 4-bit quantized copy of the policy backbone.

    Uses the module-level ``model_name``, ``bnb_config`` and ``policy_autoconfig``,
    requests FlashAttention-2, spreads the weights with ``device_map="auto"``, and
    enables gradient checkpointing on the result.

    Args:
        trust_remote_code: Forwarded to ``from_pretrained``. Defaults to False: the
            loaded architecture (``Qwen2ForCausalLM``) ships with transformers, so no
            repository code has to be executed. Only enable it for a checkpoint whose
            custom code you have reviewed.

    Returns:
        The loaded causal-LM model.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        config=policy_autoconfig,
        attn_implementation="flash_attention_2",
    )

    # Recompute activations in the backward pass to save memory.
    # NOTE(review): the default reentrant checkpointing needs inputs that require grad;
    # consider gradient_checkpointing_kwargs={"use_reentrant": False} — confirm.
    model.gradient_checkpointing_enable()

    return model

In [9]:
reward_multiplier = 5
# Reward functions
def strict_format_reward_func(completion) -> float:
    """Reward a completion that is exactly one ``<answer>`` block.

    The completion must begin with ``<answer>`` plus a newline and end with a newline,
    ``</answer>`` and a final newline. The answer body may span several lines, so a
    one-line-per-unit answer can match.

    Returns:
        ``0.5 * reward_multiplier`` on a match, otherwise 0.0.
    """
    pattern = r"^<answer>\n.*?\n</answer>\n$"
    # DOTALL lets ``.`` match newlines; without it any multi-line answer body fails.
    match = re.match(pattern, completion, flags=re.DOTALL)

    return 0.5 * reward_multiplier if match else 0.0

def soft_format_reward_func(completion) -> float:
    """Reward a completion that contains an ``<answer>...</answer>`` block anywhere.

    This is the lenient format check: any text may precede or follow the block
    (e.g. reasoning written before the answer), and the answer body may span lines.

    Returns:
        ``0.5 * reward_multiplier`` if such a block is found, otherwise 0.0.
    """
    pattern = r"<answer>.*?</answer>"
    # search (not match) so the block need not start at position 0; DOTALL so the
    # lazy ``.*?`` can cross the newlines of a multi-line answer.
    match = re.search(pattern, completion, flags=re.DOTALL)

    return 0.5 * reward_multiplier if match else 0.0

def count_xml(text) -> float:
    r"""Score the placement of the ``<answer>`` tags in ``text`` (unscaled).

    * +0.125 if ``"\n<answer>\n"`` occurs exactly once. When ``"\n</answer>\n"`` is
      present, subtract 0.001 per character after its last occurrence.
    * +0.125 if ``"\n</answer>"`` occurs exactly once, minus 0.001 per character
      beyond the first one after it (a single trailing newline is free; a tag at the
      very end of ``text`` nets +0.001).

    Returns:
        The score; ``xmlcount_reward_func`` applies the reward multiplier.
    """
    count = 0.0
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        closing_with_newline = "\n</answer>\n"
        # Only penalize trailing text when this closing marker is present. If it is
        # missing (e.g. the completion ends right at "</answer>"), split() returns the
        # whole text and the penalty would scale with the entire completion length.
        if closing_with_newline in text:
            count -= len(text.split(closing_with_newline)[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001

    return count

def xmlcount_reward_func(completion) -> float:
    """Scale the ``count_xml`` tag-placement score by ``reward_multiplier``."""
    xml_score = count_xml(completion)
    return reward_multiplier * xml_score

# One "Unit <id>: <action>, <dx>, <dy>" line. Non-sap actions 0-4 must carry "0, 0"
# offsets; the sap action 5 may carry any signed integer offsets.
_UNIT_ACTION_PATTERN = re.compile(
    r"^Unit\s+([0-9]+):\s+((?:[0-4],\s*0,\s*0)|(?:5,\s*-?\d+,\s*-?\d+))$"
)


def answer_format_reward_func(completion, sap_range, num_units=16) -> float:
    """Score how well the ``<answer>`` block follows the per-unit action format.

    The answer (from ``extract_xml_answer``) is split into non-empty, stripped lines.
    Each line should read ``Unit <id>: <action>, <dx>, <dy>``.

    Scoring before multiplying by ``reward_multiplier``:
      * -0.2 once if the number of lines differs from ``num_units``;
      * -0.1 for every line that does not match the pattern;
      * for every matching line, in units of ``1 / num_units``:
          +0.5 for a valid line;
          +0.2 if the unit id is below ``num_units``, else -0.1;
          sap (5): +0.2 for the action code, then +0.2 if
              ``max(|dx|, |dy|) <= sap_range`` else -0.1;
          non-sap (0-4): +0.2 for zero offsets and +0.2 for the action range.

    Args:
        completion: Raw model output.
        sap_range: Maximum allowed Chebyshev distance (``max(|dx|, |dy|)``) for a sap.
        num_units: Expected number of unit lines; default 16.

    Returns:
        The accumulated score times ``reward_multiplier``.

    Raises:
        ValueError: If ``num_units`` is not positive.
    """
    if num_units <= 0:
        raise ValueError(f"num_units must be positive, got {num_units}")

    answer = extract_xml_answer(completion)
    lines = [line.strip() for line in answer.strip().split("\n") if line.strip()]

    answer_score = 0.0
    # Penalize if we do not have exactly one line per unit.
    if len(lines) != num_units:
        answer_score -= 0.2

    for line in lines:
        match = _UNIT_ACTION_PATTERN.match(line)
        if match is None:
            # Penalize any line that doesn't match the required format.
            answer_score -= 0.1
            continue

        answer_score += 0.5 / num_units  # valid line

        unit_number = int(match.group(1))  # pattern guarantees a non-negative integer
        if unit_number < num_units:
            answer_score += 0.2 / num_units
        else:
            answer_score -= 0.1 / num_units

        # The pattern guarantees exactly three comma-separated integers, so this
        # unpacking and the int() conversions cannot fail.
        action_num, dx, dy = (int(part) for part in match.group(2).split(","))

        if action_num == 5:
            answer_score += 0.2 / num_units  # correct sap action code
            if max(abs(dx), abs(dy)) > sap_range:
                answer_score -= 0.1 / num_units
            else:
                answer_score += 0.2 / num_units
        else:
            # For actions 0-4 the pattern already forces dx == dy == 0 and the action
            # into [0, 4], so both of these checks always pass.
            answer_score += 0.2 / num_units  # zero offsets
            answer_score += 0.2 / num_units  # action within [0, 4]

    return answer_score * reward_multiplier


def point_gain_reward_func(reward_score, no_gain_penalty: float = -1.0) -> float:
    """Pass a positive point gain through unchanged; penalize a non-positive one.

    Args:
        reward_score: Points gained (any value comparable with 0.0).
        no_gain_penalty: Value returned when ``reward_score`` is not positive. It is a
            float, matching the ``-> float`` annotation and the other reward functions.

    Returns:
        ``reward_score`` if it is greater than 0.0, else ``no_gain_penalty``.
    """
    return reward_score if reward_score > 0.0 else no_gain_penalty

def match_won_reward_func(match_won) -> float:
    """Bonus of 300.0 when ``match_won`` is truthy, otherwise 0.0."""
    if match_won:
        return 300.0
    return 0.0

def match_lost_reward_func(match_lost) -> float:
    """Penalty of -300.0 when ``match_lost`` is truthy, otherwise 0.0."""
    if not match_lost:
        return 0.0
    return -300.0

def game_won_reward_func(game_won) -> float:
    """Bonus of 1000.0 when ``game_won`` is truthy, otherwise 0.0."""
    if game_won:
        return 1000.0
    return 0.0

def game_lost_reward_func(game_lost) -> float:
    """Penalty of -500.0 when ``game_lost`` is truthy, otherwise 0.0."""
    if not game_lost:
        return 0.0
    return -500.0

In [10]:
num_games_to_train = 1_000  # passed to PPOConfig(total_episodes=...) and the trainer

In [11]:
output_dir = "outputs/DeepSeek-R1-Distill-Qwen-1.5B-PPO"
run_name = "DeepSeek-R1-Distill-Qwen-1.5B-PPO-20250221_01"

training_args = PPOConfig(
    # Run identification, logging and checkpointing
    output_dir=output_dir,
    run_name=run_name,
    report_to="none",
    logging_steps=1,
    log_on_each_node=False,
    save_steps=50,
    # Optimizer and learning-rate schedule
    optim="adamw_torch_fused",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=0.1,
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    torch_empty_cache_steps=1,
    # Training length.
    # NOTE(review): max_steps=1 is set alongside total_episodes — confirm which one
    # ModifiedPPOTrainer actually uses to bound training.
    num_train_epochs=1,
    max_steps=1,
    total_episodes=num_games_to_train,
    # Batch sizing.
    # NOTE(review): upstream TRL appears to derive batch_size / local_batch_size /
    # micro_batch_size / mini_batch_size at runtime from per_device_train_batch_size,
    # gradient_accumulation_steps and num_mini_batches; confirm the explicit values below
    # are what the custom trainer uses.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    batch_size=1,
    local_batch_size=1,
    micro_batch_size=1,
    mini_batch_size=1,
    # PPO objective
    num_ppo_epochs=1,
    cliprange=0.2,
    vf_coef=1.0,
    kl_coef=0.01,
    # Generation
    response_length=response_length,
    temperature=0.6,
    # Sampling / evaluation output
    num_sample_generations=0,
    prediction_loss_only=True,
)

In [12]:
# Inspect the config's num_mini_batches (cell output below: 1).
training_args.num_mini_batches

1

In [13]:
# Inspect mini_batch_size as stored on the constructed config (cell output below: 1).
training_args.mini_batch_size

1

In [14]:
# Load two independent copies of the quantized backbone, one for each of model_1 / model_2.
# Kept as two statements so temp_model_1 stays bound even if the second load fails.
temp_model_1 = create_model()
temp_model_2 = create_model()

In [15]:
# Backbone hidden size; SharedPolicyAndValueModel sizes its value head from it (output: 1536).
temp_model_1.config.hidden_size

1536

In [16]:
# Show the module tree of the quantized backbone (Linear4bit projections, FlashAttention-2).
temp_model_1

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2FlashAttention2(
          (q_proj): Linear4bit(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear4bit(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear4bit(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06

In [17]:
class SharedPolicyAndValueModel(nn.Module):
    """Causal-LM backbone with an extra per-token value head for PPO.

    Policy logits come from the wrapped model's own ``lm_head``; values come from a
    linear head applied to the same final hidden states. Policy and value therefore
    share every transformer layer.
    """

    def __init__(self, model):
        """
        Args:
            model: A Hugging Face causal LM exposing ``.model`` (the transformer trunk),
                ``.lm_head``, ``.config.hidden_size`` and ``.device`` — e.g. the model
                returned by ``create_model``. It is stored by reference, not copied.
        """
        super().__init__()
        self.base = model
        # Value head: hidden_size -> 1 per position, in bf16 on the backbone's device.
        # NOTE(review): with a multi-GPU device_map the final hidden states may live on a
        # different device than model.device — confirm before sharding across GPUs.
        self.value_head = nn.Linear(model.config.hidden_size, 1).to(model.device, dtype=torch.bfloat16)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the shared trunk once and return policy logits and per-token values.

        Args:
            input_ids: Token ids; presumably (batch, seq_len) — TODO confirm with caller.
            attention_mask: Optional mask forwarded to the trunk.
            **kwargs: Extra keyword arguments forwarded to ``self.base.model``.

        Returns:
            ``(logits, value)``: ``lm_head`` logits over the vocabulary for every
            position, and one value per position (trailing size-1 dim squeezed away).
        """
        outputs = self.base.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        # Only the final hidden state is needed. Reading last_hidden_state (for HF decoder
        # models the same tensor as hidden_states[-1]) avoids output_hidden_states=True,
        # which would keep every layer's activations referenced in the output.
        hidden_states = outputs.last_hidden_state
        # Policy logits: project hidden states to vocabulary size.
        logits = self.base.lm_head(hidden_states)
        # Value: the head is applied at every position (no pooling), giving one value
        # per token as PPO's per-token advantages require.
        value = self.value_head(hidden_states).squeeze(-1)
        return logits, value

In [18]:
# Attach a value head to each backbone. The wrappers hold references to (not copies of)
# temp_model_1 / temp_model_2, so the backbones are not duplicated in memory.
model_1 = SharedPolicyAndValueModel(temp_model_1)
model_2 = SharedPolicyAndValueModel(temp_model_2)

In [19]:
# Show the wrapped model; the original backbone appears nested under `base`.
model_1

SharedPolicyAndValueModel(
  (base): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(151936, 1536)
      (layers): ModuleList(
        (0-27): 28 x Qwen2DecoderLayer(
          (self_attn): Qwen2FlashAttention2(
            (q_proj): Linear4bit(in_features=1536, out_features=1536, bias=True)
            (k_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
            (v_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
            (o_proj): Linear4bit(in_features=1536, out_features=1536, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
            (up_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
            (down_proj): Linear4bit(in_features=8960, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm((1536,), eps=1

In [20]:
# Lux AI S3 gym environment (numpy_output=True, presumably numpy observations), wrapped
# with RecordEpisode. NOTE(review): no save directory is passed — confirm where, or
# whether, episodes are written.
env = RecordEpisode(LuxAIS3GymEnv(numpy_output=True))



In [21]:
from Modified_PPO_Trainer.ppo_trainer_20250223_01 import ModifiedPPOTrainer

In [22]:
# Custom PPO trainer driving model_1 and model_2 in the Lux environment.
# The reward functions are listed in their original order — TODO confirm whether
# ModifiedPPOTrainer depends on that order (e.g. to route per-function arguments).
trainer = ModifiedPPOTrainer(
    args=training_args,
    processing_class=tokenizer,
    model_1=model_1,
    model_2=model_2,
    game_env=env,
    num_games_to_train=num_games_to_train,
    reward_functions=[
        # Format rewards: take the completion text.
        strict_format_reward_func,
        soft_format_reward_func,
        xmlcount_reward_func,
        answer_format_reward_func,
        # Game-outcome rewards: take a score or a won/lost flag.
        point_gain_reward_func,
        match_won_reward_func,
        match_lost_reward_func,
        game_won_reward_func,
        game_lost_reward_func,
    ],
)

---------- Optimizer
AcceleratedOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-06
    lr: 0.0
    maximize: False
    weight_decay: 0.1
)
AcceleratedOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 5e-06
    lr: 0.0
    maximize: False
    weight_decay: 0.1
)
---------- LR Scheduler
<accelerate.scheduler.AcceleratedScheduler object at 0x75030a50eab0>
<accelerate.scheduler.AcceleratedScheduler object at 0x75030a50f350>


In [None]:
# Run the custom PPO training loop; per-game chat text and backward markers print below.
trainer.train()

===training policy===


  0%|          | 0/505000 [00:00<?, ?it/s]

Game number: 1
---------- Chat text ----------
<｜begin▁of▁sentence｜>
Background:
Your goal is to win the 1 vs 1 game by collecting relic points.
You can win a match by collecting more relic points than your opponent.
You can win the game by winning 3 matches first.
Your job is to play the game by choosing an action for each unit you have.

Below is the game rules:
{"Game Objective": ["Two teams compete in a best-of-5 match sequence (called a game).", "Each match lasts 100 steps.", "Teams control units to gain relic points on the map while preventing the opposing team from doing the same.", "Strategy: Explore more in early matches to learn the map and opponent behavior, then exploit this knowledge to win later matches."], "Map Features": {"Description": "The map is a 24x24 2D grid, randomly generated but consistent across matches in a game (no full regeneration between matches).", "Key Map Features": {"Unknown Tiles": ["Not visible until a unit is within sensor range (randomized 2-4 til

W0224 03:53:15.764000 1785161 torch/fx/experimental/symbolic_shapes.py:5124] [0/0] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0224 03:53:15.765000 1785161 torch/fx/experimental/recording.py:298] [0/0] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0224 03:53:15.798000 1785161 torch/fx/experimental/symbolic_shapes.py:5124] [1/0] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0224 03:53:15.801000 1785161 torch/fx/experimental/recording.py:298] [1/0] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0224 03:53:15.838000 1785161 torch/fx/experimental/symbolic_shapes.py:5124] [2/0] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0224 03:53:15.839000 1785161 torch/fx/experimental/recording.py:298] [2/0] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0224 03:53:17.2040

---------- Backward start
---------- Backward end


W0224 03:53:28.234000 1785161 torch/fx/experimental/symbolic_shapes.py:5124] [0/1] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0224 03:53:28.235000 1785161 torch/fx/experimental/recording.py:298] [0/1] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0224 03:53:28.283000 1785161 torch/fx/experimental/symbolic_shapes.py:5124] [1/1] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0224 03:53:28.285000 1785161 torch/fx/experimental/recording.py:298] [1/1] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0224 03:53:28.335000 1785161 torch/fx/experimental/symbolic_shapes.py:5124] [2/1] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0224 03:53:28.336000 1785161 torch/fx/experimental/recording.py:298] [2/1] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0224 03:53:28.5640

---------- Backward start
---------- Backward end
---------- Chat text ----------
<｜begin▁of▁sentence｜>
Background:
Your goal is to win the 1 vs 1 game by collecting relic points.
You can win a match by collecting more relic points than your opponent.
You can win the game by winning 3 matches first.
Your job is to play the game by choosing an action for each unit you have.

Below is the game rules:
{"Game Objective": ["Two teams compete in a best-of-5 match sequence (called a game).", "Each match lasts 100 steps.", "Teams control units to gain relic points on the map while preventing the opposing team from doing the same.", "Strategy: Explore more in early matches to learn the map and opponent behavior, then exploit this knowledge to win later matches."], "Map Features": {"Description": "The map is a 24x24 2D grid, randomly generated but consistent across matches in a game (no full regeneration between matches).", "Key Map Features": {"Unknown Tiles": ["Not visible until a unit is with