In [1]:
import re
import torch
import torch.nn as nn
#import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from trl import PPOConfig#, PPOTrainer
#import luxai_s3
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
#from luxai_s3.params import EnvParams
import numpy as np
from datasets import load_dataset, Dataset
#from peft import LoraConfig, get_peft_model
import os
#from accelerate import infer_auto_device_map
import gc
#import copy
# Explicitly switch on automatic garbage collection (this is already the interpreter default
# unless something disabled it earlier in the process).
gc.enable()

#from stable_baselines3 import PPO
#import gymnasium as gym
#import gym

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
os.environ["FLASH_ATTENTION"] = "1"
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.cache_size_limit = 128
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
np.set_printoptions(linewidth=200)
# Configure CUDA memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,garbage_collection_threshold:0.8"
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = False

# Enable gradient checkpointing
os.environ["PYTORCH_ATTENTION_USE_MEMORY_EFFICIENT_ATTENTION"] = "1"
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.cuda.set_per_process_memory_fraction(0.9)

In [3]:
# Hugging Face checkpoint id used for the tokenizer, the config and the models.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization settings with nested (double) quantization and bfloat16 compute/storage.
# Note: create_model() currently has `quantization_config=bnb_config` commented out, so this
# object is built but not applied to the loaded models.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage="bfloat16",
)

# 8-bit alternative, kept for reference:
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True
# )

In [4]:
# Response length in tokens: passed to PPOConfig(response_length=...) and added to the
# 4500-token budget when sizing max_position_embeddings.
response_length: int = 512

In [5]:
# Start from the checkpoint's own config, then override two fields in one call.
policy_autoconfig = AutoConfig.from_pretrained(model_name)
policy_autoconfig.update(
    {
        # 4500 is presumably the prompt-token budget — TODO confirm; the response budget is added on top.
        "max_position_embeddings": 4500 + response_length,
        # KV caching is switched off in the config used for training.
        "use_cache": False,
    }
)

In [6]:
def create_model():
    """Load a fresh copy of `model_name` and enable gradient checkpointing on it.

    The model is loaded in bfloat16 with the FlashAttention-2 attention implementation,
    the shared `policy_autoconfig`, and automatic device placement.

    Returns:
        The loaded causal language model with gradient checkpointing enabled.
    """
    load_options = dict(
        config=policy_autoconfig,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
        trust_remote_code=True,
        # Disabled options, kept for reference:
        # quantization_config=bnb_config,
        # low_cpu_mem_usage=True
    )
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    causal_lm.gradient_checkpointing_enable()
    return causal_lm

In [7]:
# Number of games to train on: passed as PPOConfig(total_episodes=...) and as
# ModifiedPPOTrainer(num_games_to_train=...).
num_games_to_train: int = 1000

In [8]:
# Run label and checkpoint directory for this experiment.
run_name = "DeepSeek-R1-Distill-Qwen-1.5B-PPO-20250224_01"
output_dir = "outputs/DeepSeek-R1-Distill-Qwen-1.5B-PPO_20250224_01"

training_args = PPOConfig(
    # -- bookkeeping: outputs, logging, checkpoints --
    output_dir=output_dir,
    run_name=run_name,
    report_to="none",
    logging_steps=1,
    log_on_each_node=False,
    save_steps=20,
    num_sample_generations=0,
    prediction_loss_only=True,
    # -- optimizer and learning-rate schedule --
    optim="adamw_torch_fused",
    learning_rate=5e-4,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.01,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # -- run length --
    total_episodes=num_games_to_train,
    num_train_epochs=1,
    max_steps=1,
    # -- batch shape --
    batch_size=1,
    local_batch_size=1,
    mini_batch_size=1,
    micro_batch_size=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # -- precision and memory --
    bf16=True,
    gradient_checkpointing=True,
    torch_empty_cache_steps=1,
    # -- PPO objective --
    num_ppo_epochs=1,
    cliprange=0.2,
    vf_coef=1.0,
    kl_coef=0.05,
    # -- sampling --
    response_length=response_length,
    temperature=0.6,
)

# Extra settings attached as plain attributes after construction instead of being passed
# to PPOConfig(...). Presumably read by ModifiedPPOTrainer — TODO confirm.
training_args.target_value_loss = 0.01
training_args.target_kl = 0.01
training_args.entropy_coef = 0.01

In [9]:
# Load two separate copies of the checkpoint (wrapped later as model_1 / model_2).
temp_model_1, temp_model_2 = create_model(), create_model()

In [10]:
# Display the tokenizer repr to check its special tokens and padding side.
tokenizer

LlamaTokenizerFast(name_or_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', vocab_size=151643, model_max_length=16384, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<｜begin▁of▁sentence｜>', 'eos_token': '<｜end▁of▁sentence｜>', 'pad_token': '<｜end▁of▁sentence｜>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	151643: AddedToken("<｜end▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151644: AddedToken("<｜User｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151645: AddedToken("<｜Assistant｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151646: AddedToken("<｜begin▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151647: AddedToken("<|EOT|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151648: AddedToken("<think>", rstrip=False

In [11]:
# Hidden size of the backbone; this is the input width of the value head defined below.
temp_model_1.config.hidden_size

1536

In [12]:
# Display the loaded model's module tree.
temp_model_1

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2FlashAttention2(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): 

In [13]:
class SharedPolicyAndValueModel(nn.Module):
    """Causal LM plus a scalar value head that share one transformer backbone.

    The wrapped model supplies the policy: its `lm_head` turns the backbone's final hidden
    states into logits. A single linear layer on the same hidden states supplies the value
    estimate, one scalar per token position (no pooling over the sequence).

    Args:
        model: A causal LM exposing `.config.hidden_size`, `.device`, `.model` (the
            transformer backbone) and `.lm_head`, e.g. the object returned by
            `create_model()`.
    """

    def __init__(self, model):
        super().__init__()
        self.base = model
        # Value head: a lightweight linear layer mapping hidden_size to one scalar, placed on
        # the wrapped model's device and kept in bfloat16.
        self.value_head = nn.Linear(model.config.hidden_size, 1).to(model.device, dtype=torch.bfloat16)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the shared backbone once and return policy logits and per-token values.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Optional attention mask forwarded to the backbone.
            **kwargs: Any further keyword arguments are forwarded to the backbone unchanged.

        Returns:
            A `(logits, value)` tuple: `logits` is `lm_head` applied to the final hidden
            states; `value` is the value head output with its trailing size-1 dimension
            squeezed away (assuming hidden states of shape (batch, seq_len, hidden_size),
            that is (batch, seq_len)).
        """
        outputs = self.base.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        # The backbone's first output is already the final-layer hidden state, so there is
        # no need to request (and keep alive) the hidden states of every layer via
        # output_hidden_states=True. Positional indexing works for both ModelOutput and
        # tuple returns.
        hidden_states = outputs[0]
        # Policy logits: project hidden states to vocabulary size.
        logits = self.base.lm_head(hidden_states)
        # Value: one scalar per token position.
        value = self.value_head(hidden_states).squeeze(-1)
        # A pooled, per-sequence value would instead be:
        # pooled = hidden_states.mean(dim=1)
        # value = self.value_head(pooled)
        return logits, value

In [14]:
# Give each loaded backbone its own value head (wrapped in the same order: 1, then 2).
model_1, model_2 = (
    SharedPolicyAndValueModel(backbone) for backbone in (temp_model_1, temp_model_2)
)

In [15]:
# Display the wrapped model to confirm the backbone sits under `base`.
model_1

SharedPolicyAndValueModel(
  (base): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(151936, 1536)
      (layers): ModuleList(
        (0-27): 28 x Qwen2DecoderLayer(
          (self_attn): Qwen2FlashAttention2(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
            (k_proj): Linear(in_features=1536, out_features=256, bias=True)
            (v_proj): Linear(in_features=1536, out_features=256, bias=True)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
          (post_attent

In [16]:
# Lux AI S3 gym environment with NumPy outputs, wrapped in RecordEpisode.
env = RecordEpisode(LuxAIS3GymEnv(numpy_output=True))



In [17]:
from Modified_PPO_Trainer.ppo_trainer_20250224_01 import ModifiedPPOTrainer

In [18]:
# Assemble the trainer from the two wrapped models, the tokenizer, the PPO settings
# and the game environment.
trainer = ModifiedPPOTrainer(
    args=training_args,
    processing_class=tokenizer,
    game_env=env,
    num_games_to_train=num_games_to_train,
    model_1=model_1,
    model_2=model_2,
)

---------- Optimizer
AcceleratedOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: True
    initial_lr: 0.0005
    lr: 0.0005
    maximize: False
    weight_decay: 0.01
)
AcceleratedOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.99)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: True
    initial_lr: 0.0005
    lr: 0.0005
    maximize: False
    weight_decay: 0.01
)
---------- LR Scheduler
<accelerate.scheduler.AcceleratedScheduler object at 0x743b4968b7a0>
<accelerate.scheduler.AcceleratedScheduler object at 0x743b494201a0>


In [19]:
# Start training. In the recorded run below, this call ended with a CUDA OutOfMemoryError
# after a few backward passes of game 1.
trainer.train()

===training policy===
Game number: 1
---------- Chat text ----------

<｜begin▁of▁sentence｜>
System: You are an AI playing a strategy game. Output actions only.


Background:
- Your goal is to win the 1 vs 1 game by collecting relic points.
- You can win a match by collecting more relic points than your opponent.
- You can win the game by winning 3 matches first.
- You can play the game by choosing an action for each unit you have.

Below is the game rules:
{"Game Objective": ["Two teams compete in a best-of-5 match sequence (called a game).", "Each match lasts 100 steps.", "Teams control units to gain relic points on the map while preventing the opposing team from doing the same.", "Gain more relic points than your opponent.", "Choose an action for each unit you have.", "Strategy: Explore more in early matches to discover relics fast, learn the map and opponent behavior, then exploit this knowledge to win later matches."], "Map Features": {"Description": "The map is a 24x24 2D grid, ra

W0226 06:14:53.018000 160412 torch/fx/experimental/symbolic_shapes.py:5124] [0/0] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0226 06:14:53.019000 160412 torch/fx/experimental/recording.py:298] [0/0] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0226 06:14:53.051000 160412 torch/fx/experimental/symbolic_shapes.py:5124] [1/0] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0226 06:14:53.053000 160412 torch/fx/experimental/recording.py:298] [1/0] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0226 06:14:53.090000 160412 torch/fx/experimental/symbolic_shapes.py:5124] [2/0] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0226 06:14:53.091000 160412 torch/fx/experimental/recording.py:298] [2/0] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0226 06:14:54.376000 160

---------- Backward start
---------- Backward end
---------- Backward start
---------- Backward end
---------- Chat text ----------

<｜begin▁of▁sentence｜>
System: You are an AI playing a strategy game. Output actions only.


Background:
- Your goal is to win the 1 vs 1 game by collecting relic points.
- You can win a match by collecting more relic points than your opponent.
- You can win the game by winning 3 matches first.
- You can play the game by choosing an action for each unit you have.

Below is the game rules:
{"Game Objective": ["Two teams compete in a best-of-5 match sequence (called a game).", "Each match lasts 100 steps.", "Teams control units to gain relic points on the map while preventing the opposing team from doing the same.", "Gain more relic points than your opponent.", "Choose an action for each unit you have.", "Strategy: Explore more in early matches to discover relics fast, learn the map and opponent behavior, then exploit this knowledge to win later matches."], 

W0226 06:19:08.735000 160412 torch/fx/experimental/symbolic_shapes.py:5124] [0/1] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0226 06:19:08.736000 160412 torch/fx/experimental/recording.py:298] [0/1] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0226 06:19:08.783000 160412 torch/fx/experimental/symbolic_shapes.py:5124] [1/1] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0226 06:19:08.785000 160412 torch/fx/experimental/recording.py:298] [1/1] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0226 06:19:08.834000 160412 torch/fx/experimental/symbolic_shapes.py:5124] [2/1] failed during evaluate_expr(Eq(u0, 1), hint=None, size_oblivious=False, forcing_spec=False
E0226 06:19:08.835000 160412 torch/fx/experimental/recording.py:298] [2/1] failed while running evaluate_expr(*(Eq(u0, 1), None), **{'fx_node': False})
W0226 06:19:09.067000 160

---------- Backward start
---------- Backward end
---------- Backward start
---------- Backward end
---------- Chat text ----------

<｜begin▁of▁sentence｜>
System: You are an AI playing a strategy game. Output actions only.


Background:
- Your goal is to win the 1 vs 1 game by collecting relic points.
- You can win a match by collecting more relic points than your opponent.
- You can win the game by winning 3 matches first.
- You can play the game by choosing an action for each unit you have.

Below is the game rules:
{"Game Objective": ["Two teams compete in a best-of-5 match sequence (called a game).", "Each match lasts 100 steps.", "Teams control units to gain relic points on the map while preventing the opposing team from doing the same.", "Gain more relic points than your opponent.", "Choose an action for each unit you have.", "Strategy: Explore more in early matches to discover relics fast, learn the map and opponent behavior, then exploit this knowledge to win later matches."], 

OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 0 has a total capacity of 23.55 GiB of which 1.13 GiB is free. Including non-PyTorch memory, this process has 21.48 GiB memory in use. 21.19 GiB allowed; Of the allocated memory 19.52 GiB is allocated by PyTorch, and 1.65 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)