# DSPY-based chessboard classification
path of data: ../data

In [1]:
import dspy
import pandas as pd
import os
from pathlib import Path
import base64
from transformers import AutoProcessor, AutoModelForVision2Seq
from typing import Any, Dict, List, Optional
from PIL import Image
import torch
import numpy as np
from io import BytesIO

In [2]:
def dspy_image_to_pil(img: Any) -> Image.Image:
    """
    Convert dspy.Image (or similar) to a PIL.Image.
    Handles:
      - dspy.Image(url=...) where url is a data:...;base64,... string
      - dspy.Image(url=...) where url is a local file path
      - plain file-path strings
      - direct PIL.Image
    """

    # 1) If it's already a PIL image
    if isinstance(img, Image.Image):
        return img.convert("RGB")

    # 2) If it has a .url attribute (dspy.Image)
    url = getattr(img, "url", None)
    if isinstance(url, str):
        s = url.strip()

        # data URL with base64 content
        if s.startswith("data:") and "base64," in s:
            b64 = s.split("base64,", 1)[1]
            data = base64.b64decode(b64)
            return Image.open(BytesIO(data)).convert("RGB")

        # plain local path case (DSPy can also store raw path)
        if os.path.exists(s):
            return Image.open(s).convert("RGB")

    # 3) If it has a .path attribute (your own wrappers)
    p = getattr(img, "path", None)
    if isinstance(p, str) and os.path.exists(p):
        return Image.open(p).convert("RGB")

    # 4) If the object itself is a path-like string
    if isinstance(img, str) and os.path.exists(img):
        return Image.open(img).convert("RGB")

    raise ValueError(
        "Could not convert image to PIL. "
        "This is likely a dspy.Image with an unexpected internal format."
    )



class VisionChatAdapter(dspy.ChatAdapter):
    def format(self, signature, demos, inputs):
        messages = super().format(signature=signature, demos=demos, inputs=inputs)

        img = inputs.get("board_image", None)
        if img is None:
            return messages

        for m in reversed(messages):
            if m.get("role") != "user":
                continue

            # Keep existing text (whatever DSPy produced)
            existing = m.get("content", "")
            if isinstance(existing, list):
                # If DSPy already produced structured parts, just ensure an image part exists
                has_img = any(isinstance(p, dict) and p.get("type") == "image" for p in existing)
                if not has_img:
                    m["content"] = [{"type": "image", "data": img}] + existing
            else:
                # Convert plain text -> structured [image, text]
                m["content"] = [
                    {"type": "image", "data": img},
                    {"type": "text", "text": str(existing)},
                ]

            # Keep legacy field too (harmless, helps debugging)
            m["images"] = [img]
            break

        return messages


class LocalVisionLM(dspy.BaseLM):
    def __init__(self, model_id: str, **kwargs):
        super().__init__(model=model_id, model_type="chat", **kwargs)

        # Prefer GPU if it has enough free memory; otherwise fall back to CPU
        self.device = "cpu"
        if torch.cuda.is_available():
            try:
                free_bytes, total_bytes = torch.cuda.mem_get_info()
                # Require ~6GB free before loading the 4B model on GPU
                if free_bytes > 6 * 1024**3:
                    self.device = "cuda"
            except Exception:
                self.device = "cpu"

        # bf16 is best on Ampere+; fallback to fp16 on older GPUs; CPU uses fp32
        if self.device == "cuda":
            self.dtype = (
                torch.bfloat16
                if torch.cuda.get_device_capability(0)[0] >= 8
                else torch.float16
            )
        else:
            self.dtype = torch.float32

        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        # Store the actual model as vision_model to avoid overwriting self.model
        self.vision_model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            torch_dtype=self.dtype,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            device_map=None,
        )
        self.vision_model = self.vision_model.to(self.device)
        self.vision_model.eval()

    def __call__(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        max_tokens: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs,
    ):
        # Normalize generation length
        gen_max_new = max_new_tokens if max_new_tokens is not None else (max_tokens if max_tokens is not None else 512)

        # Normalize messages
        if messages is None:
            if prompt is None:
                raise ValueError("LocalVisionLM needs either `prompt` or `messages`.")
            messages = [{"role": "user", "content": prompt}]

        # Collect images and rebuild structured messages for the chat template
        images: List[Image.Image] = []
        chat_msgs: List[Dict[str, Any]] = []

        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            parts = []

            if isinstance(content, list):
                for part in content:
                    if not isinstance(part, dict):
                        continue
                    if part.get("type") == "image":
                        src = part.get("image") or part.get("data")
                        if src is not None:
                            images.append(dspy_image_to_pil(src))
                        parts.append({"type": "image"})
                    elif part.get("type") == "text":
                        parts.append({"type": "text", "text": str(part.get("text", ""))})
            else:
                parts.append({"type": "text", "text": str(content)})

            if not parts:
                parts.append({"type": "text", "text": ""})

            chat_msgs.append({"role": role, "content": parts})

        if not images:
            # Fallback: check legacy "images" field
            for m in messages:
                legacy_imgs = m.get("images")
                if legacy_imgs:
                    for img in legacy_imgs:
                        images.append(dspy_image_to_pil(img))

        if not images:
            raise ValueError("No image found. Did you configure VisionChatAdapter and pass board_image=?")

        # Build text prompt with image placeholders using the chat template
        if hasattr(self.processor, "apply_chat_template"):
            text = self.processor.apply_chat_template(
                chat_msgs,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            text = "\n".join(
                part.get("text", "")
                for msg in chat_msgs
                for part in msg.get("content", [])
                if part.get("type") == "text"
            ).strip()

        inputs = self.processor(text=text, images=images, return_tensors="pt")
        for k, v in inputs.items():
            if hasattr(v, "to"):
                inputs[k] = v.to(self.vision_model.device)

        gen_kwargs = {"max_new_tokens": gen_max_new}
        if temperature is not None and temperature > 0:
            gen_kwargs.update({"do_sample": True, "temperature": float(temperature)})

        with torch.no_grad():
            out = self.vision_model.generate(**inputs, **gen_kwargs)

        completion = self.processor.batch_decode(out, skip_special_tokens=True)[0].strip()
        return [completion]


In [3]:
MODEL_ID = "Qwen/Qwen3-VL-2B-Instruct"

lm = LocalVisionLM(MODEL_ID, temperature=0.0, max_tokens=700, cache=False)
dspy.configure(lm=lm, adapter=VisionChatAdapter())

`torch_dtype` is deprecated! Use `dtype` instead!


In [4]:
#testing connection
# lm("Say this is a test!", temperature=0.7)  # => ['This is a test!']
# lm(messages=[{"role": "user", "content": "Say this is a test!"}])  # => ['This is a test!']

## load the data
This notebook is in ../project_home_dir/DSPy-Classifier
The data is in ../project_home_dir/data

In [5]:
# Define data directory
data_dir = Path("../data")

# Load all CSV files from different games
games_data = {}
all_samples = []

for game_folder in data_dir.glob("game*_per_frame"):
    game_name = game_folder.name.split("_")[0]  # Extract game2, game4, etc.
    csv_path = game_folder / f"{game_name}.csv"
    images_dir = game_folder / "tagged_images"
    
    if csv_path.exists():
        df = pd.read_csv(csv_path)
        
        # Check which columns exist
        has_fen = 'fen' in df.columns
        has_classification = 'classification' in df.columns or 'label' in df.columns
        
        print(f"\n{game_name}:")
        print(f"  - Rows: {len(df)}")
        print(f"  - Has FEN: {has_fen}")
        print(f"  - Has classification: {has_classification}")
        print(f"  - Columns: {df.columns.tolist()}")
        
        # Load images for each frame range
        if images_dir.exists():
            image_files = sorted(images_dir.glob("*.jpg")) + sorted(images_dir.glob("*.png"))
            print(f"  - Available images: {len(image_files)}")
            
            # Create samples with images
            for idx, row in df.iterrows():
                sample = {
                    'game': game_name,
                    'from_frame': row.get('from_frame', None),
                    'to_frame': row.get('to_frame', None),
                    'fen': row.get('fen', None),
                    'classification': row.get('classification', row.get('label', None)),
                    'has_label': has_classification and pd.notna(row.get('classification', row.get('label', None)))
                }
                
                # Try to find corresponding image
                # Assuming images might be named with frame numbers
                if 'from_frame' in row:
                    frame_num = row['from_frame']
                    matching_images = [img for img in image_files if f"{frame_num}" in img.name or f"frame_{frame_num}" in img.name]
                    if matching_images:
                        sample['image_path'] = str(matching_images[0])
                        try:
                            pil_image = Image.open(matching_images[0])
                            image_array = np.array(pil_image)
                            
                            # Use the path as a simple object with path attribute
                            class ImageWithPath:
                                def __init__(self, path):
                                    self.path = path
                                def __str__(self):
                                    return self.path
                            
                            sample['image'] = dspy.Image.from_file(str(matching_images[0]))
                            sample['pil_image'] = pil_image
                            sample['image_array'] = image_array
                        except Exception as e:
                            print(f"  - Warning: Could not load image {matching_images[0]}: {e}")
                            sample['image'] = None
                            sample['pil_image'] = None
                            sample['image_array'] = None
                
                all_samples.append(sample)
        
        games_data[game_name] = df

# Create comprehensive dataframe
all_data = pd.DataFrame(all_samples)

print(f"\n{'='*60}")
print(f"Total samples: {len(all_data)}")
print(f"Samples with labels: {all_data['has_label'].sum()}")
print(f"Samples without labels: {(~all_data['has_label']).sum()}")
print(f"Samples with images: {all_data['image'].notna().sum()}")
print(f"\nFirst few samples:")
all_data.head()


game7:
  - Rows: 57
  - Has FEN: True
  - Has classification: False
  - Columns: ['from_frame', 'to_frame', 'fen']
  - Available images: 55

game6:
  - Rows: 93
  - Has FEN: True
  - Has classification: False
  - Columns: ['from_frame', 'to_frame', 'fen']
  - Available images: 92


  sample['image'] = dspy.Image.from_file(str(matching_images[0]))
  sample['image'] = dspy.Image.from_file(str(matching_images[0]))



game5:
  - Rows: 110
  - Has FEN: True
  - Has classification: False
  - Columns: ['from_frame', 'to_frame', 'fen']
  - Available images: 109


  sample['image'] = dspy.Image.from_file(str(matching_images[0]))



game4:
  - Rows: 186
  - Has FEN: True
  - Has classification: False
  - Columns: ['from_frame', 'to_frame', 'fen']
  - Available images: 184


  sample['image'] = dspy.Image.from_file(str(matching_images[0]))



game2:
  - Rows: 77
  - Has FEN: True
  - Has classification: False
  - Columns: ['from_frame', 'to_frame', 'fen']
  - Available images: 77

Total samples: 523
Samples with labels: 0
Samples without labels: 523
Samples with images: 519

First few samples:


  sample['image'] = dspy.Image.from_file(str(matching_images[0]))


Unnamed: 0,game,from_frame,to_frame,fen,classification,has_label,image_path,image,pil_image,image_array
0,game7,172,172,rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR,,False,../data/game7_per_frame/tagged_images/frame_00...,"<<CUSTOM-TYPE-START-IDENTIFIER>>[{""type"": ""ima...",<PIL.JpegImagePlugin.JpegImageFile image mode=...,"[[[180, 176, 167], [189, 184, 178], [191, 186,..."
1,game7,428,428,rnbqkbnr/pppppppp/8/8/3P4/8/PPP1PPPP/RNBQKBNR,,False,../data/game7_per_frame/tagged_images/frame_00...,"<<CUSTOM-TYPE-START-IDENTIFIER>>[{""type"": ""ima...",<PIL.JpegImagePlugin.JpegImageFile image mode=...,"[[[185, 171, 162], [190, 176, 167], [190, 175,..."
2,game7,696,696,rnbqkb1r/pppppppp/5n2/8/3P4/8/PPP1PPPP/RNBQKBNR,,False,../data/game7_per_frame/tagged_images/frame_00...,"<<CUSTOM-TYPE-START-IDENTIFIER>>[{""type"": ""ima...",<PIL.JpegImagePlugin.JpegImageFile image mode=...,"[[[45, 44, 50], [29, 28, 36], [21, 19, 30], [2..."
3,game7,708,708,rnbqkb1r/pppppppp/5n2/8/3P1B2/8/PPP1PPPP/RN1QKBNR,,False,../data/game7_per_frame/tagged_images/frame_00...,"<<CUSTOM-TYPE-START-IDENTIFIER>>[{""type"": ""ima...",<PIL.JpegImagePlugin.JpegImageFile image mode=...,"[[[28, 27, 32], [18, 17, 22], [11, 11, 19], [1..."
4,game7,736,736,rnbqkb1r/ppp1pppp/5n2/3p4/3P1B2/8/PPP1PPPP/RN1...,,False,../data/game7_per_frame/tagged_images/frame_00...,"<<CUSTOM-TYPE-START-IDENTIFIER>>[{""type"": ""ima...",<PIL.JpegImagePlugin.JpegImageFile image mode=...,"[[[33, 33, 41], [20, 20, 30], [17, 16, 30], [2..."


# Define signatures

In [6]:
class PieceClassificationSignature(dspy.Signature):
    """Classify chess pieces in a board square from an image.
    
    Args:
        board_image: Image containing the chessboard or square
        square_position: Board position (e.g., "e4", "a1")
        
    Returns:
        piece: Chess piece notation (K/Q/R/B/N/P for white, k/q/r/b/n/p for black, . for empty, ? for unknown/occluded)
        confidence: Confidence level of the classification (high/medium/low)
        reasoning: Brief explanation of the classification decision
    """
    board_image: dspy.Image = dspy.InputField(desc="Chessboard image or square region")
    square_position: str = dspy.InputField(desc="Square position in algebraic notation (e.g., e4)")
    piece: str = dspy.OutputField(desc="Piece: K/Q/R/B/N/P (white) or k/q/r/b/n/p (black) or . (empty) or ? (unknown/occluded)")
    confidence: str = dspy.OutputField(desc="Confidence: high/medium/low")
    reasoning: str = dspy.OutputField(desc="Why you made this classification")

In [14]:
class BoardStateSignature(dspy.Signature):
    """Analyze entire chessboard image and classify all 64 squares in one call.
    
    This signature is preferred over 64 individual square calls for efficiency.
    Uses global board context for better accuracy.
    
    Returns:
        piece_json: JSON mapping square->piece for all squares. Use ? for unknown/occluded.
        fen_notation: Standard FEN notation string (8 ranks), using ? for unknown squares
        board_confidence: Overall confidence score (0.0-1.0)
        occlusion_notes: List of squares that are occluded or unclear (e.g., "e4, f5")
    """
    board_image: dspy.Image = dspy.InputField(desc="Full chessboard image")
    piece_json: str = dspy.OutputField(desc='JSON: {"a1": "R", "e1": "K", "e4": ".", "f5": "?", ...} where ? = unknown/occluded')
    fen_notation: str = dspy.OutputField(desc="FEN notation (position part only), use ? for unknown squares")
    board_confidence: float = dspy.OutputField(desc="Confidence score: 0.0-1.0")
    occlusion_notes: str = dspy.OutputField(desc="Comma-separated list of occluded squares (e.g., 'e4, f5')")

Let's test the performance on the first 10 images

In [15]:
# create 10 Examples from the data, nothing fancy, just for testing and not dspy.example
from tqdm.notebook import tqdm

test_examples = []
for i, row in all_data.head(5).iterrows():
    image = row.get('image', None)
    if image is not None:
        test_examples.append(image)
img = test_examples[0]
print(type(img), getattr(img, "url", None)[:60])
# create a classifier based on the signature
classifier = dspy.Predict(signature=BoardStateSignature)
# run predictions
predictions = []
for i in tqdm(range(0,5)):
    example = test_examples[i]
    predictions.append(classifier(board_image=example))

# print predictions
for i, prediction in enumerate(predictions):
    print(f"\nPrediction for Example {i+1}:")
    print(f"FEN Notation: {prediction.fen_notation}")
    print(f"Piece JSON: {prediction.piece_json}")
    print(f"Board Confidence: {prediction.board_confidence}")
    print(f"Occlusion Notes: {prediction.occlusion_notes}")

<class 'dspy.adapters.types.image.Image'> data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAEBA


  0%|          | 0/5 [00:00<?, ?it/s]

ValueError: No image found. Did you configure VisionChatAdapter and pass board_image=?

In [13]:
# print predictions
for i, prediction in enumerate(predictions):
    print(f"\nPrediction for Example {i+1}:")
    print(f"FEN Notation: {prediction.fen_notation.strip()}")
    print(f"Piece JSON: {prediction.piece_json}")
    print(f"Board Confidence: {prediction.board_confidence}")
    print(f"Occlusion Notes: {prediction.occlusion_notes}")


Prediction for Example 1:
FEN Notation: {fen_notation}
Piece JSON: {piece_json}
Board Confidence: {board_confidence}
Occlusion Notes: {occlusion_notes}

Prediction for Example 2:
FEN Notation: {fen_notation}
Piece JSON: {piece_json}
Board Confidence: {board_confidence}
Occlusion Notes: {occlusion_notes}

Prediction for Example 3:
FEN Notation: {fen_notation}
Piece JSON: {piece_json}
Board Confidence: {board_confidence}
Occlusion Notes: {occlusion_notes}

Prediction for Example 4:
FEN Notation: {fen_notation}
Piece JSON: {piece_json}
Board Confidence: {board_confidence}
Occlusion Notes: {occlusion_notes}

Prediction for Example 5:
FEN Notation: {fen_notation}
Piece JSON: {piece_json}
Board Confidence: {board_confidence}
Occlusion Notes: {occlusion_notes}
