In [None]:
# Installing necessary libraries
!pip install unsloth vllm                # Install Unsloth and vLLM (for fast inference/training)
!pip install --upgrade pillow            # Upgrade PIL (used for image processing)
!pip install datasets                    # Hugging Face datasets library
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b  # Install specific TRL commit (GRPO support)
!pip install tensorboard                 # TensorBoard for visualizing training logs

# Clean up existing PIL or Google modules from memory (for fresh imports later)
import sys
modules = list(sys.modules.keys())
for x in modules:
    if "PIL" in x or "google" in x:
        sys.modules.pop(x)               # Unload any conflicting or partially loaded modules


In [None]:
# Import from Unsloth: FastLanguageModel is a wrapper to simplify loading/training LLMs
# PatchFastRL patches the model to support GRPO (a reinforcement learning algorithm)
from unsloth import FastLanguageModel, PatchFastRL
import torch

PatchFastRL("GRPO", FastLanguageModel)  # Patch FastLanguageModel with GRPO support before loading

# Configuration parameters
model_name = "google/gemma-2-2b-it"  # Base model: Gemma 2, 2B parameters, instruction-tuned
max_seq_length = 1024                # Maximum sequence length the model is loaded for
lora_rank = 64                       # Rank used in LoRA (Low-Rank Adaptation)

# Load the base model and its tokenizer using Unsloth's FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    load_in_4bit=True,              # Load weights in 4-bit precision (saves GPU memory)
    fast_inference=False,           # Do not use Unsloth's vLLM-backed fast generation path;
                                    # the GRPO config below also sets use_vllm=False.
    max_lora_rank=lora_rank,        # Largest LoRA rank to allow for; presumably only
                                    # consulted by the vLLM path — confirm for your Unsloth version.
    gpu_memory_utilization=0.5,     # NOTE(review): this looks like a vLLM memory setting, so it
                                    # likely has no effect while fast_inference=False — confirm.
)


In [11]:
# Wrap the base model with LoRA adapters (parameter-efficient fine-tuning).
# Adapters go on the attention projections and on the MLP projections.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
    "gate_proj", "up_proj", "down_proj",       # MLP projections
]

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    lora_alpha=lora_rank,                  # alpha == r (standard LoRA scales updates by alpha / r)
    target_modules=lora_target_modules,
    use_gradient_checkpointing="unsloth",  # Unsloth's gradient checkpointing to save memory
    random_state=3407,                     # fixed seed for reproducibility
)


In [12]:
# Load and preprocess the PokerBench dataset.

from datasets import load_dataset, Dataset
import re

# Response-format instruction placed at the top of every user prompt. The model is
# asked to wrap its reasoning and its final action in XML-style tags, which the
# reward functions parse later.
SYSTEM_INSTRUCTION = "\n".join([
    "You are an expert poker player. Your response must be in the format:",
    "<reasoning>",
    "...",
    "</reasoning>",
    "<answer>",
    "...",
    "</answer>",
])

def inspect_dataset():
    """Load PokerBench's train split and print its first record and field names.

    Useful for checking the schema before writing prompt/answer mapping code.

    Returns:
        The loaded ``train`` split.
    """
    train_split = load_dataset("RZ412/PokerBench")["train"]
    first_record = train_split[0]
    field_names = list(first_record.keys())
    for label, value in (("Sample PokerBench record:", first_record),
                         ("Available keys:", field_names)):
        print(label, value)
    return train_split

def format_poker_prompt(game_state: dict) -> list:
    """Build a zero-shot chat prompt for one poker decision.

    Two input layouts are handled:

    * PokerBench records, which describe the whole scenario as free text in an
      ``instruction`` field (as in the sample record printed when the dataset is
      loaded). That text is embedded verbatim as the game state.
    * Hand-built dicts with structured fields. Each field is looked up under a
      canonical key and an alternative key. The alternative key wins when both
      are present, and a placeholder is used when neither is.

    Everything goes into a single ``user`` message (no separate system turn),
    following the original Gemma-2-2B-compatible design.

    Args:
        game_state: A dataset record or a hand-built game-state dict.

    Returns:
        A one-element chat message list: ``[{"role": "user", "content": prompt}]``.
    """
    instruction = game_state.get("instruction")
    if isinstance(instruction, str) and instruction.strip():
        # The PokerBench sample record keeps the scenario under 'instruction'. If the
        # structured keys below are absent (as they appear to be), every field would
        # render as 'Unknown' and the model would never see the actual hand.
        # NOTE(review): these descriptions are long; check that the trainer's
        # max_prompt_length does not truncate them.
        state_text = instruction.strip()
    else:
        def field(canonical, alternative, default):
            # The alternative key overrides the canonical one (original lookup order).
            if alternative in game_state:
                return game_state[alternative]
            return game_state.get(canonical, default)

        state_text = "\n".join([
            f"- Position: {field('position', 'player_position', 'Unknown')}",
            f"- Stack: {field('stack', 'stack_size', 'Unknown')}",
            f"- Hand: {field('hand', 'cards', 'Unknown')}",
            f"- Community Cards: {field('community_cards', 'board', 'None')}",
            f"- Pot: {field('pot', 'pot_size', 'Unknown')}",
            f"- To Call: {field('to_call', 'amount_to_call', '0')}",
            f"- Opponent Actions: {field('opponent_actions', 'actions', 'None')}",
        ])

    prompt = f"""{SYSTEM_INSTRUCTION}

Game State:
{state_text}

What action should you take (fold, call, raise, check)? Provide reasoning and final action."""
    return [
        {"role": "user", "content": prompt}
    ]

def extract_xml_answer(text: str) -> str:
    """Return the stripped text between ``<answer>`` and ``</answer>`` in a response.

    If several ``<answer>`` tags are present, the last one wins. A missing closing
    tag is tolerated: everything after the opening tag is returned.

    Returns ``""`` when ``text`` is not a string or contains no ``<answer>`` tag,
    so replies that ignore the requested format are not treated as answers.
    """
    if not isinstance(text, str) or "<answer>" not in text:
        return ""
    after_open = text.rsplit("<answer>", 1)[1]
    return after_open.split("</answer>", 1)[0].strip()

def get_pokerbench_dataset(split="train", answer_keys=("action", "output")) -> Dataset:
    """Load a PokerBench split and add ``prompt`` / ``answer`` columns for GRPO.

    Args:
        split: Dataset split to load.
        answer_keys: Candidate column names holding the ground-truth action, tried
            in order; the first one present in the first record is used.
            ``"output"`` is a presumed PokerBench label column — confirm against
            the "Available keys" printed below.

    Returns:
        The split with ``prompt`` (chat messages from ``format_poker_prompt``) and
        ``answer`` (ground-truth action as a string, ``""`` if no label column
        was found) columns added.

    Raises:
        ValueError: If the split is empty.
        Any exception raised while testing the chat template or mapping the
        dataset is printed and re-raised.

    Relies on the notebook-global ``tokenizer`` for the chat-template check.
    """
    data = load_dataset("RZ412/PokerBench")[split]
    if len(data) == 0:
        raise ValueError(f"PokerBench split {split!r} is empty; nothing to train on.")

    # Inspect the dataset to confirm structure
    print("Inspecting PokerBench dataset...")
    sample = data[0]
    print("Sample record:", sample)
    print("Available keys:", list(sample.keys()))

    # Resolve which column holds the ground-truth action. Without a real label
    # every answer would silently be "", leaving nothing to reward against.
    answer_key = next((key for key in answer_keys if key in sample), None)
    if answer_key is None:
        print(f"WARNING: none of {list(answer_keys)} is a column of this split; "
              "every 'answer' will be empty, so there is no ground truth to reward against.")
    else:
        print(f"Using '{answer_key}' as the ground-truth answer column.")

    # Test the chat template with a sample prompt
    print("Testing chat template...")
    sample_prompt = format_poker_prompt(sample)
    try:
        test_output = tokenizer.apply_chat_template(sample_prompt, tokenize=False, add_generation_prompt=True)
        print("Sample formatted prompt:", test_output)
    except Exception as e:
        print(f"Chat template error: {e}")
        raise

    def to_grpo_example(record):
        # prompt: chat messages; answer: ground-truth action as a string ("" if unlabeled).
        label = record.get(answer_key) if answer_key is not None else None
        return {
            "prompt": format_poker_prompt(record),
            "answer": "" if label is None else str(label),
        }

    # Map the dataset to create prompts, one example at a time to isolate errors
    try:
        data = data.map(to_grpo_example, batched=False)
    except Exception as e:
        print(f"Error during dataset mapping: {e}")
        raise
    return data

# Build the GRPO training set (prompt + answer columns) from PokerBench's train split.
dataset = get_pokerbench_dataset(split="train")

Inspecting PokerBench dataset...
Sample record: {'instruction': '\n\nYou are a specialist in playing 6-handed No Limit Texas Holdem. The following will be a game scenario and you need to make the optimal decision.\n\nHere is a game summary:\n\nThe small blind is 0.5 chips and the big blind is 1 chips. Everyone started with 100 chips.\nThe player positions involved in this game are UTG, HJ, CO, BTN, SB, BB.\nIn this hand, your position is HJ, and your holding is [King of Diamond and Jack of Spade].\nBefore the flop, HJ raise 2.0 chips, and BB call. Assume that all other players that is not mentioned folded.\nThe flop comes King Of Spade, Seven Of Heart, and Two Of Diamond, then BB check, and HJ check.\nThe turn comes Jack Of Club, then BB check, HJ bet 3 chips, BB raise 10 chips, and HJ call.\nThe river comes Seven Of Club, then BB check.\n\n\nNow it is your turn to make a move.\nTo remind you, the current pot size is 24.0 chips, and your holding is [King of Diamond and Jack of Spade].\

In [13]:
# Reward Functions
# Define reward functions for GRPO training.

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    '''Reward 2.0 when the extracted <answer> matches the ground truth, else 0.0.

    The comparison ignores case and surrounding whitespace. An empty ground-truth
    label never earns the reward. Otherwise a completion with an empty or missing
    <answer> would be scored "correct" against an unlabeled row.

    Args:
        prompts: Unused; kept for the reward-function call signature.
        completions: One entry per completion; ``completion[0]["content"]`` is its text.
        answer: Ground-truth actions, aligned with ``completions``.

    Returns:
        One float reward per completion.
    '''
    rewards = []
    for completion, truth in zip(completions, answer):
        predicted = extract_xml_answer(completion[0]["content"]).strip().lower()
        expected = str(truth or "").strip().lower()
        rewards.append(2.0 if expected and predicted == expected else 0.0)
    return rewards

def valid_action_reward_func(completions, **kwargs) -> list[float]:
    '''Reward 0.5 when the extracted <answer> is a recognised poker action, else 0.0.

    Only checks that the (lower-cased) answer is one of fold / call / raise / check,
    not that it is the correct decision.
    '''
    valid_actions = {"fold", "call", "raise", "check"}
    responses = [completion[0]["content"] for completion in completions]
    extracted = [extract_xml_answer(r).lower() for r in responses]
    return [0.5 if r in valid_actions else 0.0 for r in extracted]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    '''Reward 0.5 when a completion follows the exact response layout, else 0.0.

    Expected layout, with each tag on its own line::

        <reasoning>
        ...
        </reasoning>
        <answer>
        ...
        </answer>

    ``re.DOTALL`` lets the reasoning and answer bodies span several lines; without
    it ``.`` stops at newlines, so multi-line reasoning could never match. A trailing
    newline after ``</answer>`` is optional, because the instructed format ends at
    ``</answer>``.
    '''
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if re.match(pattern, r, flags=re.DOTALL) else 0.0 for r in responses]

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    '''Partial-credit format reward: 0.125 for each tag marker that appears exactly once.

    More lenient than the strict regex check: the whole response need not match.
    Four markers are counted, each with its surrounding newlines: "<reasoning>\\n",
    "\\n</reasoning>\\n", "\\n<answer>\\n" and "\\n</answer>". The maximum reward is 0.5.
    '''
    markers = ("<reasoning>\n", "\n</reasoning>\n", "\n<answer>\n", "\n</answer>")

    def count_xml(text: str) -> float:
        # Multiplying keeps the result a float even when no marker matches.
        return 0.125 * sum(1 for marker in markers if text.count(marker) == 1)

    responses = [completion[0]["content"] for completion in completions]
    return [count_xml(r) for r in responses]

In [14]:
# Training with GRPO
# Configure and run GRPO training.

from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported   # Utility from Unsloth to detect hardware support for BF16

# Prefer bf16 where the hardware supports it; otherwise fall back to fp16.
use_bf16 = is_bfloat16_supported()

training_args = GRPOConfig(
    output_dir="outputs",
    run_name="gemma_pokerbench",
    # Generation: no vLLM; completions generated per prompt and length limits
    use_vllm=False,
    num_generations=4,
    max_prompt_length=256,
    max_completion_length=200,
    # Optimiser and learning-rate schedule
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    max_grad_norm=0.1,   # Clip gradients to avoid exploding gradients
    # Mixed precision
    bf16=use_bf16,
    fp16=not use_bf16,
    # Batch size and run length
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    max_steps=50,
    # Logging and checkpoints
    logging_steps=1,
    save_steps=50,
    report_to="tensorboard",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,        # partial credit for each expected tag
        strict_format_reward_func,   # exact response layout
        valid_action_reward_func,    # answer is a recognised action
        correctness_reward_func,     # answer matches the ground truth
    ],
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 563,200 | Num Epochs = 1 | Total steps = 50
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 83,066,880 of 2,697,408,768 (3.08% trained)
  out = torch_matmul(X, W, out = out)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / xmlcount_reward_func,rewards / strict_format_reward_func,rewards / valid_action_reward_func,rewards / correctness_reward_func
1,0.0,0.5,0.285264,173.125,0.000161,0.375,0.0,0.125,0.0
2,0.0,0.4375,0.375411,175.75,0.000158,0.3125,0.0,0.0,0.125
3,0.0,0.375,0.213766,159.375,0.000351,0.3125,0.0,0.0625,0.0
4,0.0,0.320312,0.206897,179.375,0.003973,0.289062,0.0,0.03125,0.0
5,0.0,0.398438,0.23054,181.0625,0.035626,0.335938,0.0,0.0625,0.0
6,0.0001,0.375,0.184912,165.875,0.137117,0.34375,0.0,0.03125,0.0
7,0.0003,0.304688,0.222113,184.4375,0.271274,0.273438,0.0,0.03125,0.0
8,0.0005,0.4375,0.396465,186.9375,0.509281,0.28125,0.0,0.03125,0.125
9,0.0008,0.242188,0.110462,195.1875,0.761819,0.242188,0.0,0.0,0.0
10,0.0017,0.21875,0.094837,189.125,1.745564,0.21875,0.0,0.0,0.0


  out = torch_matmul(X, W, out = out)


TrainOutput(global_step=50, training_loss=0.016330718234474375, metrics={'train_runtime': 1397.1832, 'train_samples_per_second': 0.143, 'train_steps_per_second': 0.036, 'total_flos': 0.0, 'train_loss': 0.016330718234474375})

In [15]:
from google.colab import drive
# Mount Google Drive and save the trained LoRA adapter there so it outlives the Colab session.
drive_mount_point = '/content/drive'
adapter_dir = "/content/drive/MyDrive/lora_adapter"
drive.mount(drive_mount_point)
model.save_pretrained(adapter_dir)

Mounted at /content/drive


In [16]:

from peft import PeftModel
from unsloth import FastLanguageModel

# Sample test prompt built from the structured (non-PokerBench) keys that
# format_poker_prompt understands.
test_prompt = format_poker_prompt({
    "player_position": "Button",
    "stack_size": 1000,
    "cards": ["As", "Kd"],
    "board": [],
    "pot_size": 100,
    "amount_to_call": 50,
    "actions": ["Small Blind posts 25", "Big Blind posts 50"]
})

# Render the chat template and tokenize the prompt
text = tokenizer.apply_chat_template(test_prompt, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=1024).to("cuda")

# Debug input shape
print("Input shape:", inputs["input_ids"].shape)

# Baseline generation. At this point `model` is the PEFT model that trainer.train()
# just updated, so the trained LoRA adapter is active. It must be switched off
# temporarily to see what the model does "without GRPO".
print("Inference without GRPO:")
with model.disable_adapter():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=1024,
        temperature=0.8,
        top_p=0.95,
        do_sample=True,
    )
output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(output)


Input shape: torch.Size([1, 140])
Inference without GRPO:
user
You are an expert poker player. Your response must be in the format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

Game State:
- Position: Button
- Stack: 1000
- Hand: ['As', 'Kd']
- Community Cards: []
- Pot: 100
- To Call: 50
- Opponent Actions: ['Small Blind posts 25', 'Big Blind posts 50']

What action should you take (fold, call, raise, check)? Provide reasoning and final action.
model
<reasoning> You have a decent pair of Aces and Kings, making it a strong hand.  However, the pot is still relatively small at 100. There's a good chance you'll face a larger bet on the flop, so calling might limit your range and lead to a potentially unfavorable showdown. </reasoning>
<answer> **Raise** </answer> 
 
**Explanation:**

* **Positional advantage:** Being in the button gives you the opportunity to see the flop with your pair before the other players have to act.
* **Limiting the action:**  Raising forces your opponent 

In [17]:
# Save the trained LoRA adapter locally
model.save_pretrained("grpo_poker_lora")

# Inference with GRPO LoRA.
# `model` is already a PEFT model with the freshly trained adapter active, so it is
# used directly. Calling PeftModel.from_pretrained(model, ...) here would nest one
# PeftModel inside another and re-apply adapter layers that already exist. Also,
# `.to("cuda")` is not meant for 4-bit quantised models.
# To reuse the saved adapter in a fresh session, load the base model with
# FastLanguageModel.from_pretrained and apply PeftModel.from_pretrained to *that*.
print("\nInference with GRPO LoRA:")
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=1024,
    temperature=0.8,
    top_p=0.95,
    do_sample=True,
)
output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(output)


Inference with GRPO LoRA:




user
You are an expert poker player. Your response must be in the format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

Game State:
- Position: Button
- Stack: 1000
- Hand: ['As', 'Kd']
- Community Cards: []
- Pot: 100
- To Call: 50
- Opponent Actions: ['Small Blind posts 25', 'Big Blind posts 50']

What action should you take (fold, call, raise, check)? Provide reasoning and final action.
model
<reasoning>
- **Position:** Button is a very strong position. You get to act last and have a good chance of controlling the pace of the hand.
- **Hand:**  While Ace-King is a strong starting hand, it's not a dominant hand on its own. You have a strong top pair and an Ace, but you have no flush draws.
- **Pot Odds:** The pot is 100. While not a huge pot, it's enough to make things interesting.
- **Opponent Actions:**  The blinds are in, and the Big Blind is sizing up the pot.
- **Opponent's Behavior:** The Big Blind's actions can be read as passive or aggressive depending on their playing