mr-spaghetti-code/connections-rl

Connections-RL

Train a language model (Gemma-3 4B-it) to play the NYT Connections puzzle using reinforcement learning with LoRA adapters on Apple Silicon, via the MLX framework.

Overview

Connections is a daily word puzzle where players must identify four groups of four related words from a grid of 16 words. This project uses GRPO (Group Relative Policy Optimization) to fine-tune an LLM to play the game effectively.
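The core idea of GRPO is that advantages are computed relative to a group of completions sampled for the same prompt, rather than from a learned value model. A minimal sketch of that normalization (illustrative only; the project's training loop lives in scripts/train_connections.py and may differ):

```python
# Sketch of GRPO's group-relative advantage normalization.
# Each reward is scored against the mean/std of its own sample group.
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group of completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four hypothetical completions sampled for the same puzzle prompt:
advs = group_relative_advantages([30.0, 8.0, -12.0, 1.0])
```

Completions that beat their group's average get positive advantage and are reinforced; the rest are pushed down, with no critic network required.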

Project Structure

connections-rl/
├── config/
│   └── connections_grpo_config.json   # Training configuration
├── data/
│   ├── clean_data.csv                 # Raw puzzle data
│   ├── connections_games.json         # Processed game records
│   └── connections_rl_data.jsonl      # Synthetic training data
├── src/
│   ├── connections/
│   │   ├── game_logic.py              # State, evaluation, transitions
│   │   ├── game.py                    # Full game loop
│   │   ├── prompt.py                  # System prompt, formatting
│   │   └── rewards.py                 # Reward calculation
│   ├── synth/
│   │   └── connections_data_generation.py
│   └── utils/
│       └── config.py
├── scripts/
│   ├── validate_data.py               # Data validation
│   ├── prepare_data.py                # Data transformation
│   ├── train_connections.py           # Main training script
│   └── evaluate_connections.py        # Side-by-side evaluation
├── tests/
│   ├── test_connections_game_logic.py
│   ├── test_connections_rewards.py
│   └── test_connections_prompt.py
└── requirements.txt

Quick Start

1. Install Dependencies

pip install -r requirements.txt

2. Prepare Data

# Validate the raw data
python scripts/validate_data.py

# Process and split into train/val/test
python scripts/prepare_data.py

# Generate synthetic training data
PYTHONPATH=. python src/synth/connections_data_generation.py --num-samples 2000

3. Train the Model

python scripts/train_connections.py --config config/connections_grpo_config.json

4. Evaluate

# Evaluate base model only
python scripts/evaluate_connections.py --num-games 50

# Compare base vs LoRA
python scripts/evaluate_connections.py --adapter experiments/YOUR_RUN/adapters/adapter_final.npz

Configuration

Key settings in config/connections_grpo_config.json:

{
    "model": {"name": "mlx-community/gemma-3-4b-it-bf16"},
    "training": {
        "iterations": 500,
        "learning_rate": 1e-6
    },
    "lora": {
        "rank": 64,
        "alpha": 128.0
    },
    "rl": {
        "num_generations": 2,
        "max_trials": 4
    },
    "reward": {
        "win_game": 30.0,
        "correct_guess_hard": 16.0,
        "format_fail_penalty": 12.0
    }
}
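The config is plain JSON, so scripts can read it with the standard library. A minimal sketch, using the keys shown above (the actual file may contain additional settings; the derived LoRA scaling factor alpha/rank is a common convention, not necessarily how this codebase applies it):

```python
import json

# Mirrors the README's example config; loaded from a string here so the
# sketch is self-contained (a script would use json.load on the file).
CONFIG_JSON = """
{
    "model": {"name": "mlx-community/gemma-3-4b-it-bf16"},
    "training": {"iterations": 500, "learning_rate": 1e-6},
    "lora": {"rank": 64, "alpha": 128.0},
    "rl": {"num_generations": 2, "max_trials": 4},
    "reward": {"win_game": 30.0, "correct_guess_hard": 16.0, "format_fail_penalty": 12.0}
}
"""

cfg = json.loads(CONFIG_JSON)
scaling = cfg["lora"]["alpha"] / cfg["lora"]["rank"]  # conventional LoRA scaling
```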

Reward Design

The reward function incentivizes:

| Event | Reward |
| --- | --- |
| Win game | +30 |
| Perfect game (0 mistakes) | +15 bonus |
| Correct guess (by difficulty) | +8 to +16 |
| "3 of 4" feedback | +1 |
| Invalid format | -12 |
| Word not in pool | -8 per word |
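Put together, per-guess scoring along these lines might look like the following sketch (illustrative; src/connections/rewards.py defines the real logic, and the linear difficulty scaling here is an assumption):

```python
# Hypothetical per-guess reward shaping based on the table above.
REWARDS = {
    "one_away": 1.0,        # "3 of 4" feedback
    "format_fail": -12.0,   # invalid format ends scoring for the guess
    "bad_word": -8.0,       # per word not in the remaining pool
}

def guess_reward(correct, difficulty=0, one_away=False,
                 bad_words=0, format_ok=True):
    """Score one guess; difficulty runs 0 (easiest) to 3 (hardest)."""
    if not format_ok:
        return REWARDS["format_fail"]
    r = bad_words * REWARDS["bad_word"]
    if correct:
        # Assumed linear ramp from +8 (easy) to +16 (hard).
        r += 8.0 + 8.0 * difficulty / 3
    elif one_away:
        r += REWARDS["one_away"]
    return r
```

Note the asymmetry: format failures cost more than a wrong guess, which pushes the model to produce parseable output before it learns to win.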

Testing

# Run all tests
PYTHONPATH=. pytest tests/ -v

# Run specific test file
PYTHONPATH=. pytest tests/test_connections_game_logic.py -v

Data

The dataset contains 332 valid Connections puzzles (2023-07-01 to 2024-06-16):

  • Train: 232 games
  • Validation: 50 games
  • Test: 50 games

The split is chronological to prevent data leakage: earlier puzzles are used for training, later ones for validation and test.
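A chronological split of this kind can be sketched as follows (illustrative; scripts/prepare_data.py does the real work, and the "date" field name is an assumption):

```python
# Split games by date so no future puzzle leaks into training.
def chronological_split(games, n_train, n_val):
    ordered = sorted(games, key=lambda g: g["date"])
    return (ordered[:n_train],
            ordered[n_train:n_train + n_val],
            ordered[n_train + n_val:])

# Toy example with 12 monthly puzzles instead of the full 332:
games = [{"date": f"2023-{m:02d}-01"} for m in range(1, 13)]
train, val, test = chronological_split(games, n_train=8, n_val=2)
```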

Success Criteria

  • LoRA model win rate exceeds the base model's win rate
  • Guesses are validly formatted >90% of the time
  • Training completes all 500 iterations without reward collapse

License

MIT
