# Lab 0: Generate Teacher Logits

This notebook generates teacher model logits for knowledge distillation. The teacher model (Qwen3-30B-A3B) processes a dataset and outputs logits that will be used to train a smaller student model.

## Import Dependencies

Import required libraries for model inference, tokenization, and Neuron-specific configurations.

In [None]:
import torch
import json

from transformers import AutoTokenizer, GenerationConfig
from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
from neuronx_distributed_inference.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeInferenceConfig, NeuronQwen3MoeForCausalLM
from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config

torch.manual_seed(0)

## Configuration

Set model paths and file locations. The teacher model will be compiled and saved to the traced model path for efficient inference on AWS Neuron.

In [None]:
model_path = "Qwen/Qwen3-30B-A3B"
traced_model_path = "/home/ubuntu/traced_model/Qwen3-30B-A3B/"
dataset_file = "data/dataset.txt"
output_file = "output.json"
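The processing loop later in the notebook reads `dataset_file` one line at a time and skips blank lines, so the file is expected to hold one input text per line. A minimal sketch of that reading logic, using a hypothetical helper `iter_prompts` and a throwaway temp file with made-up samples:

```python
import os
import tempfile

def iter_prompts(path):
    """Yield each non-blank line of the file, stripped of surrounding whitespace."""
    with open(path, "r") as f:
        for line in f:
            text = line.strip()
            if text:
                yield text

# Made-up samples written to a temporary file standing in for data/dataset.txt.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("I love this phone.\n\nThe delivery was late.\n")
    sample_path = tmp.name

prompts = list(iter_prompts(sample_path))
os.remove(sample_path)
print(prompts)
```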

## Create Conversation Template

Define a function to format input text as a conversation for the sentiment classification task.

In [None]:
def create_conversation(sample):
    system_message = (
        "You are a sentiment classifier. You take input strings and return the sentiment of POSITIVE, NEGATIVE, or NEUTRAL. Only return the sentiment."
    )
    return [
        {
            "role": "system",
            "content": system_message,
        },
        {
            "role": "user",
            "content": sample
        },
    ]
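As a quick check of the message structure, the sketch below calls the same function on a made-up review (the sample text is an assumption, not taken from the lab dataset) and prints each role. It needs no tokenizer or Neuron hardware.

```python
# Same function as in the cell above, repeated so this snippet runs on its own.
def create_conversation(sample):
    system_message = (
        "You are a sentiment classifier. You take input strings and return the sentiment of POSITIVE, NEGATIVE, or NEUTRAL. Only return the sentiment."
    )
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": sample},
    ]

# Hypothetical input line; the real lab reads these from dataset_file.
messages = create_conversation("The battery died after two days.")
for message in messages:
    print(f"{message['role']}: {message['content'][:40]}")
```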

## Initialize Model Configuration

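Configure the Neuron-specific settings for the Qwen3 MoE model:
- Tensor parallelism degree: 8 (shards the model across 8 NeuronCores)
- Batch size 1, with `max_context_length=128` for the prompt and `seq_len=1024` for the full sequence
- On-device sampling with temperature 0.6, top-k 20, and top-p 0.95
- Score and logit output enabled (`output_scores`, `output_logits`) so per-token values can be saved for distillation

The same cell loads the tokenizer with right-side padding and reuses the EOS token as the pad token.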

In [None]:
generation_config = GenerationConfig.from_pretrained(model_path)

neuron_config = MoENeuronConfig(
    tp_degree=8,
    batch_size=1,
    max_context_length=128,
    seq_len=1024,
    on_device_sampling_config=OnDeviceSamplingConfig(do_sample=True, temperature=0.6, top_k=20, top_p=0.95),
    enable_bucketing=False,
    flash_decoding_enabled=False,
    output_scores=True,
    output_logits=True
)

config = Qwen3MoeInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config(model_path),
)

tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
tokenizer.pad_token = tokenizer.eos_token
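To make the sampling parameters concrete, here is a rough plain-Python sketch of temperature scaling followed by top-k filtering on a toy score vector. It illustrates the general technique only and is not the Neuron on-device sampler; `filter_top_k` and the toy values are assumptions, and top-p is omitted for brevity. Entries outside the top k are set to `-inf`, which is the kind of masked value the processing step later in the notebook filters out.

```python
import math

def filter_top_k(logits, temperature, k):
    """Scale logits by 1/temperature, then mask everything below the k-th largest with -inf."""
    scaled = [x / temperature for x in logits]
    # Threshold is the k-th largest scaled value (ties at the threshold are all kept).
    threshold = sorted(scaled, reverse=True)[k - 1]
    return [x if x >= threshold else -math.inf for x in scaled]

toy_logits = [2.0, 0.5, -1.0, 1.5, 0.0]  # toy vocabulary of 5 tokens
filtered = filter_top_k(toy_logits, temperature=0.6, k=2)

# Positions that survived the filter and remain finite.
kept = [i for i, x in enumerate(filtered) if math.isfinite(x)]
print(kept)
```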

## Compile and Save Model

Compile the model for AWS Neuron hardware. This step converts the model to a Neuron-optimized format and saves it, together with the tokenizer, to `traced_model_path` for reuse.

In [None]:
print("\nCompiling and saving model...")
model = NeuronQwen3MoeForCausalLM(model_path, config)
model.compile(traced_model_path)
tokenizer.save_pretrained(traced_model_path)
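Compilation only needs to happen once per configuration. A small guard such as the hypothetical `has_compiled_artifacts` below could decide whether to run the compile cell. It only checks that the directory exists and is non-empty, which is an assumption about what counts as "compiled" rather than a check provided by the Neuron libraries.

```python
import os
import tempfile

def has_compiled_artifacts(path):
    """Return True if path is an existing directory containing at least one entry."""
    return os.path.isdir(path) and len(os.listdir(path)) > 0

# Demonstrate on a temporary directory instead of the real traced_model_path.
with tempfile.TemporaryDirectory() as demo_dir:
    before = has_compiled_artifacts(demo_dir)  # empty directory
    open(os.path.join(demo_dir, "placeholder.bin"), "w").close()
    after = has_compiled_artifacts(demo_dir)   # now holds one file

print(before, after)
```

In the notebook this might be used as `if not has_compiled_artifacts(traced_model_path): model.compile(traced_model_path)`. Note that a directory left over from a different configuration would also pass this check, so it is a convenience, not a validation.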

## Load Compiled Model

Load the compiled model and its tokenizer from `traced_model_path` for inference.

In [None]:
model = NeuronQwen3MoeForCausalLM(traced_model_path)
model.load(traced_model_path)
tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

## Process Dataset and Generate Logits

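Process each line in the dataset through the teacher model:
1. Format the input as a conversation and apply the chat template with thinking disabled
2. Generate output with per-token scores and logits
3. For each generated token, keep only the finite entries of `outputs.scores` (dropping `-inf` values) along with their vocabulary indices
4. Save the prompt, the decoded text, and the per-token indices and values

A prompt that raises an exception is recorded with its error message instead of a response.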

In [None]:
results = []
# Build the generation adapter once; it only wraps the already-loaded model.
generation_model = HuggingFaceGenerationAdapter(model)
with open(dataset_file, 'r') as f:
    for line in f:
        if line.strip():
            try:
                input_text = create_conversation(line.strip())
                formatted_chat = tokenizer.apply_chat_template(
                    input_text,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False
                )
                inputs = tokenizer(formatted_chat, padding=True, return_tensors="pt")
                outputs = generation_model.generate(
                    inputs.input_ids,
                    generation_config=generation_config,
                    attention_mask=inputs.attention_mask,
                    max_length=model.config.neuron_config.max_length,
                    return_dict_in_generate=True,
                    output_scores=True,
                    output_logits=True
                )
                
                print(outputs)
                generated_tokens = outputs.sequences[0]
                token_logits = outputs.scores
                generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                print(generated_text)
                
                token_logits_list = []
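                for logits in token_logits:
                    # Each entry has shape (batch, vocab); batch_size is 1, so use row 0.
                    finite_mask = torch.isfinite(logits[0])
                    # squeeze(-1) keeps a 1-D index list even when only one entry is finite.
                    finite_indices = torch.nonzero(finite_mask).squeeze(-1).tolist()
                    finite_logits = logits[0][finite_mask]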
                    token_info = {
                        'indices': finite_indices,
                        'logits': finite_logits.tolist()
                    }
                    token_logits_list.append(token_info)
                
                print(token_logits_list)
                results.append({
                    'prompt': line.strip(),
                    'response': {
                        'generated_text': generated_text,
                        'token_logits': token_logits_list
                    }
                })
            except Exception as e:
                print(f"Error processing prompt: {line[:50]}...")
                print(f"Error message: {str(e)}")
                results.append({
                    'prompt': line.strip(),
                    'error': str(e)
                })
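Each per-token record stores only the finite entries and their vocabulary indices. The plain-Python sketch below mirrors that extraction on a toy score row and then rebuilds the dense row, which is what a consumer of the output file would likely need to do. `to_sparse`, `to_dense`, and the toy values are illustrative assumptions; the notebook itself does the extraction with `torch.isfinite`.

```python
import math

def to_sparse(row):
    """Keep only the finite entries of a score row, with their positions."""
    indices = [i for i, x in enumerate(row) if math.isfinite(x)]
    return {"indices": indices, "logits": [row[i] for i in indices]}

def to_dense(record, vocab_size):
    """Rebuild the full row, filling positions that were not saved with -inf."""
    row = [-math.inf] * vocab_size
    for i, x in zip(record["indices"], record["logits"]):
        row[i] = x
    return row

# Toy row over a 5-token vocabulary; only positions 1 and 3 survived filtering.
toy_row = [-math.inf, 3.2, -math.inf, 1.1, -math.inf]
record = to_sparse(toy_row)
restored = to_dense(record, vocab_size=len(toy_row))
print(record)
```

Storing only the finite entries keeps the JSON small: with top-k 20, each token needs at most 20 index and value pairs instead of one value per vocabulary entry.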

## Save Results

Write the generated logits and responses to a JSON file for use in distillation training.

In [None]:
with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)

num_errors = sum(1 for r in results if 'error' in r)
print(f"Processing complete! Processed {len(results)} prompts ({num_errors} with errors). Results written to {output_file}")
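A downstream distillation step would typically turn each token's stored values into a probability distribution. The sketch below shows one way to do that with a softmax over the saved entries, using a hand-written record in the same shape as the notebook output (the numbers are made up, and `sparse_softmax` is a hypothetical helper). Whether the student is trained on these probabilities or on the raw values is a design choice this notebook does not specify.

```python
import json
import math

def sparse_softmax(token_info):
    """Softmax over the saved finite logits; returns {vocab_index: probability}."""
    logits = token_info["logits"]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return {i: e / total for i, e in zip(token_info["indices"], exps)}

# Hand-written stand-in for the contents of the output file.
raw = json.dumps([{
    "prompt": "I love this phone.",
    "response": {
        "generated_text": "POSITIVE",
        "token_logits": [{"indices": [10, 42], "logits": [2.0, 0.0]}],
    },
}])

loaded = json.loads(raw)
# Entries that failed carry an "error" key instead of "response", so skip them.
usable = [entry for entry in loaded if "response" in entry]
probs = sparse_softmax(usable[0]["response"]["token_logits"][0])
print(probs)
```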