# Lab 0: Generate Teacher Logits for Chess Move Evaluation

## Introduction

In this lab, you will generate teacher model logits for chess move evaluation using the MATE_DATASET. The task is to evaluate two candidate chess moves and determine which one is better.

**Dataset**: [OutFlankShu/MATE_DATASET](https://huggingface.co/datasets/OutFlankShu/MATE_DATASET)

**Task Format**:
- **Instruction**: System prompt explaining the chess evaluation task
- **Input**: FEN board position + two candidate moves with strategies and tactics
- **Output**: The better move (e.g., "MoveA:d2d8" or "MoveB:d4e6")
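
The output format above can be split into a label and a move with a small parser. The following is a minimal sketch, not part of the dataset tooling: `MoveChoice` and `parse_output` are hypothetical names, and the pattern assumes that moves are written in UCI coordinate notation (from-square, to-square, optional promotion piece), as the examples `d2d8` and `d4e6` suggest.

```python
import re
from typing import NamedTuple, Optional


class MoveChoice(NamedTuple):
    label: str  # "MoveA" or "MoveB"
    move: str   # coordinate move such as "d2d8" (assumed UCI notation)


# Label, colon, then from-square, to-square, and an optional promotion piece.
_OUTPUT_RE = re.compile(r"^(MoveA|MoveB):([a-h][1-8][a-h][1-8][qrbn]?)$")


def parse_output(text: str) -> Optional[MoveChoice]:
    """Return the parsed choice, or None if the text does not match the format."""
    match = _OUTPUT_RE.match(text.strip())
    if match is None:
        return None
    return MoveChoice(label=match.group(1), move=match.group(2))


print(parse_output("MoveA:d2d8"))
print(parse_output("not a move"))
```

Returning `None` for malformed strings, rather than raising, lets a caller count unparseable records without interrupting a batch run.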

**Teacher Model**: Qwen3-30B-A3B (30B-parameter MoE model with roughly 3B parameters active per token)

**Prerequisites**:
- AWS EC2 instance with Trainium/Inferentia (e.g., trn1.32xlarge or inf2.48xlarge)
- AWS Neuron SDK installed
- Virtual environment activated: `source /opt/aws_neuronx_venv_pytorch_2_8_nxd_inference/bin/activate`

## Download Model Weights

Download the Qwen3-30B-A3B teacher model from HuggingFace.

In [None]:
!hf download Qwen/Qwen3-30B-A3B

## Import Dependencies

In [None]:
import torch
import json
from datasets import load_dataset

from transformers import AutoTokenizer, GenerationConfig
from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
from neuronx_distributed_inference.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeInferenceConfig, NeuronQwen3MoeForCausalLM
from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config

torch.manual_seed(0)

## Configuration

In [None]:
model_path = "Qwen/Qwen3-30B-A3B"
traced_model_path = "/home/ubuntu/traced_model/Qwen3-30B-A3B/"
output_file = "data/chess_output.json"
num_samples = 100  # Number of samples to process from the dataset

## Load MATE_DATASET

Load the chess move evaluation dataset from HuggingFace.

In [None]:
print("Loading MATE_DATASET in streaming mode...")
dataset_stream = load_dataset(
    "OutFlankShu/MATE_DATASET", 
    split="train",
    streaming=True
)

# Collect the first `num_samples` records that contain all required fields
required_fields = ('instruction', 'input', 'output')
dataset = []
skipped = 0
for i, sample in enumerate(dataset_stream):
    try:
        if all(field in sample for field in required_fields):
            dataset.append(sample)
        else:
            skipped += 1
            print(f"Skipping record {i}: missing required fields")
    except Exception as e:
        # Guards against malformed records (e.g., a record that is not a dict)
        skipped += 1
        print(f"Skipping malformed record {i}: {e}")
        continue

    # Stop as soon as enough valid samples have been collected
    if len(dataset) >= num_samples:
        break

print(f"Dataset loaded: {len(dataset)} valid samples (skipped {skipped} invalid records)")
print(f"Will process {min(num_samples, len(dataset))} samples")

# Show example
if len(dataset) > 0:
    print("\nExample sample:")
    print(f"Instruction: {dataset[0]['instruction']}")
    print(f"Input: {dataset[0]['input'][:200]}...")
    print(f"Output: {dataset[0]['output']}")
else:
    print("\nWarning: No valid samples loaded!")

## Create Conversation Template

Format the chess evaluation task as a conversation for the model.

In [None]:
def create_conversation(instruction, input_text):
    """
    Create a conversation format for chess move evaluation.
    
    Uses a short classification-style system prompt so the model answers with
    a single label, which keeps the number of generated tokens (and stored
    logits) small for knowledge distillation training.
    
    Args:
        instruction: The dataset's task instruction. Currently unused: the
            fixed system prompt below replaces it.
        input_text: The chess position and candidate moves
    
    Returns:
        List of message dictionaries for the chat template
    """
    return [
        {
            "role": "system",
            "content": "Classify the better move. Output format: MoveA or MoveB"
        },
        {
            "role": "user",
            "content": input_text
        },
    ]

## Initialize Model Configuration

In [None]:
generation_config = GenerationConfig.from_pretrained(model_path)

neuron_config = MoENeuronConfig(
    tp_degree=8,
    batch_size=1,
    max_context_length=512,  # Increased for longer chess descriptions
    seq_len=1024,
    on_device_sampling_config=OnDeviceSamplingConfig(do_sample=True, temperature=0.6, top_k=20, top_p=0.95),
    enable_bucketing=False,
    flash_decoding_enabled=False,
    output_scores=True,
    output_logits=True
)

config = Qwen3MoeInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config(model_path),
)

tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
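
The two length settings above interact. The sketch below assumes that `max_context_length` caps the prompt and that `seq_len` caps the prompt plus the generated tokens; that reading is an assumption to check against the Neuron documentation. `remaining_token_budget` is a hypothetical helper, not part of the Neuron SDK.

```python
def remaining_token_budget(prompt_tokens, max_context_length=512, seq_len=1024):
    """Return how many tokens can still be generated for a prompt of this size.

    Assumption: max_context_length limits the prompt, and seq_len limits
    the total of prompt tokens plus generated tokens.
    """
    if prompt_tokens > max_context_length:
        raise ValueError(
            f"Prompt has {prompt_tokens} tokens; limit is {max_context_length}"
        )
    return seq_len - prompt_tokens


# A 300-token prompt fits within the context limit.
print(remaining_token_budget(300))
```

In this notebook the prompt length would come from `inputs.input_ids.shape[1]` after tokenization, so a check like this could flag over-long chess descriptions before they reach the model.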

## Compile and Save Model

**Note**: Compilation takes 30-60 minutes on the first run. Skip this cell if a compiled model already exists at `traced_model_path`.

In [None]:
print("\nCompiling and saving model...")
model = NeuronQwen3MoeForCausalLM(model_path, config)
model.compile(traced_model_path)
tokenizer.save_pretrained(traced_model_path)
print("Model compiled and saved!")

## Load Compiled Model

Load the pre-compiled model for inference.

In [None]:
print("Loading compiled model...")
model = NeuronQwen3MoeForCausalLM(traced_model_path)
model.load(traced_model_path)
tokenizer = AutoTokenizer.from_pretrained(traced_model_path)
print("Model loaded!")

## Process Dataset and Generate Logits

Process chess positions through the teacher model to generate move evaluation logits.

In [None]:
results = []

for idx in range(min(num_samples, len(dataset))):
    sample = dataset[idx]
    
    try:
        print(f"\nProcessing sample {idx + 1}/{min(num_samples, len(dataset))}...")
        
        # Create conversation from instruction and input
        conversation = create_conversation(sample['instruction'], sample['input'])
        
        # Format with chat template
        formatted_chat = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        
        # Tokenize
        inputs = tokenizer(formatted_chat, padding=True, return_tensors="pt")
        
        # Create generation adapter (must be inside loop)
        generation_model = HuggingFaceGenerationAdapter(model)
        
        # Generate with logits
        outputs = generation_model.generate(
            inputs.input_ids,
            generation_config=generation_config,
            attention_mask=inputs.attention_mask,
            max_length=model.config.neuron_config.max_length,
            return_dict_in_generate=True,
            output_scores=True,
            output_logits=True
        )
        
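        # Extract generated text (sequences include the prompt, so slice it off)
        prompt_length = inputs.input_ids.shape[1]
        generated_tokens = outputs.sequences[0][prompt_length:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)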
        
        # Extract per-step scores, keeping only finite entries
        # (masked-out tokens appear as -inf and are dropped)
        token_logits_list = []
        for logits in outputs.scores:
            finite_mask = torch.isfinite(logits[0])
            # flatten() always yields a 1-D tensor, so tolist() returns a list
            finite_indices = torch.nonzero(finite_mask).flatten().tolist()
            finite_logits = logits[0][finite_mask]
            
            token_info = {
                'indices': finite_indices,
                'logits': finite_logits.tolist()
            }
            token_logits_list.append(token_info)
        
        print(f"Generated: {generated_text}")
        print(f"Expected: {sample['output']}")
        
        # Store results
        results.append({
            'instruction': sample['instruction'],
            'input': sample['input'],
            'expected_output': sample['output'],
            'response': {
                'generated_text': generated_text,
                'token_logits': token_logits_list
            }
        })
        
    except Exception as e:
        print(f"Error processing sample {idx}: {str(e)}")
        results.append({
            'instruction': sample['instruction'],
            'input': sample['input'],
            'expected_output': sample['output'],
            'error': str(e)
        })

print(f"\nProcessing complete! Processed {len(results)} samples.")
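
Each entry in `token_logits` stores only the finite values for one generation step, as parallel `indices` and `logits` lists. A distillation loss usually needs probabilities, so the sketch below shows one way to turn such a sparse entry into a distribution over the stored token IDs. `sparse_softmax` is a hypothetical helper, the token IDs and values are made up for illustration, and treating the dropped entries as probability zero is an assumption.

```python
import math


def sparse_softmax(indices, logits, temperature=1.0):
    """Map token id -> probability, using only the stored (finite) entries."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max before exponentiating for stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return {token_id: e / total for token_id, e in zip(indices, exps)}


# Illustrative entry in the same shape as one element of token_logits
token_info = {"indices": [11, 42, 907], "logits": [2.0, 1.0, 0.0]}
probs = sparse_softmax(token_info["indices"], token_info["logits"])
for token_id, p in probs.items():
    print(token_id, round(p, 4))
```

A temperature above 1.0 flattens the distribution, which is a common choice when softening teacher targets for distillation.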

## Save Results

Save the generated logits to JSON for use in distillation training.

In [None]:
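import os

# Create the output directory if it does not exist yet
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
with open(output_file, 'w') as f: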
    json.dump(results, f, indent=2)

print(f"Results saved to {output_file}")
print(f"Total samples: {len(results)}")
print(f"Successful: {sum(1 for r in results if 'error' not in r)}")
print(f"Errors: {sum(1 for r in results if 'error' in r)}")
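
Before using the saved file for distillation, it can help to check how often the teacher agrees with the dataset labels. The sketch below works on records in the shape written above. `extract_label` and `teacher_agreement` are hypothetical helpers, the sample records are made up, and the comparison assumes that `generated_text` holds the model's answer and that only the `MoveA`/`MoveB` label matters, not the move after the colon.

```python
import re

LABEL_RE = re.compile(r"Move[AB]")


def extract_label(text):
    """Return the last 'MoveA'/'MoveB' mention in text, or None if absent."""
    found = LABEL_RE.findall(text)
    return found[-1] if found else None


def teacher_agreement(records):
    """Fraction of error-free records whose predicted label matches the expected one."""
    scored = [r for r in records if "error" not in r]
    if not scored:
        return 0.0
    hits = 0
    for record in scored:
        predicted = extract_label(record["response"]["generated_text"])
        expected = extract_label(record["expected_output"])
        if predicted is not None and predicted == expected:
            hits += 1
    return hits / len(scored)


# Made-up records in the same shape as the saved results
sample_records = [
    {"expected_output": "MoveA:d2d8", "response": {"generated_text": "MoveA"}},
    {"expected_output": "MoveB:d4e6", "response": {"generated_text": "MoveA"}},
    {"expected_output": "MoveB:d4e6", "error": "example failure"},
]
print(f"Teacher agreement: {teacher_agreement(sample_records):.2f}")
```

With the real file, `sample_records` would be replaced by the list returned from `json.load` on `output_file`.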