# Train Qwen3-30B-A3B for Coding with Tinker API

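This notebook trains a LoRA adapter on Qwen3-30B-A3B in two sequential phases. Phase 2 continues from the Phase 1 weights.
1. **Phase 1**: rStar-Coder `seed_sft` split (125K coding examples; rows that fail the `verified`/`is_passed` checks are dropped)
2. **Phase 2**: SWE-bench tool-use conversations (229 examples; only trajectories that resolved their issue are used)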

**Note**: The actual model training happens on Tinker's cloud servers. This notebook handles data preprocessing and API calls.

## Setup

In [None]:
# Install dependencies
!pip install -q tinker datasets transformers tqdm

In [None]:
# Set your Tinker API key.
# SECURITY: never hard-code secrets in a notebook -- they leak through version
# control and shared copies. The key that was previously hard-coded in this
# cell must be treated as compromised and revoked. Export TINKER_API_KEY before
# starting the kernel, or type it at the prompt below (input is not echoed).
import os
from getpass import getpass

if not os.environ.get("TINKER_API_KEY"):
    os.environ["TINKER_API_KEY"] = getpass("Tinker API key: ")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
import json
import time
import pickle
from pathlib import Path
from typing import List, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import tinker
from tinker import types
from tinker.types.tensor_data import TensorData
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm.auto import tqdm

print("Libraries loaded!")

## Configuration

In [None]:
# Model configuration
MODEL_NAME = "Qwen/Qwen3-30B-A3B"
LORA_RANK = 32

# Phase 1: Coding
# NOTE(review): LoRA adapters are often trained with a larger learning rate
# than full fine-tuning -- confirm 5e-5 is intended for a rank-32 adapter.
PHASE1_LEARNING_RATE = 5e-5
PHASE1_BATCH_SIZE = 128
PHASE1_MAX_LENGTH = 2048
PHASE1_MAX_SAMPLES = None  # Set to e.g. 10000 for faster testing

# Phase 2: Tool Use
PHASE2_LEARNING_RATE = 5e-5
PHASE2_BATCH_SIZE = 32
PHASE2_MAX_LENGTH = 8192
PHASE2_NUM_EPOCHS = 3

# Paths
DATA_DIR = Path("/content/data")
# parents=True also creates missing parent directories (e.g. when /content
# does not exist); plain exist_ok=True would raise FileNotFoundError there.
DATA_DIR.mkdir(parents=True, exist_ok=True)

print(f"Model: {MODEL_NAME}")
print(f"LoRA Rank: {LORA_RANK}")

## Load Tokenizer

In [None]:
# Load the base model's tokenizer. It renders chat templates and tokenizes
# both the training data and the test prompt further down.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
print(f"Tokenizer loaded: {type(tokenizer).__name__}")

## Phase 1: Prepare rStar-Coder Data

In [None]:
def _assistant_target_weights(tokens: List[int], tokenizer) -> List[float]:
    """Return per-target loss weights that are 1.0 only for assistant output.

    ``weights[i]`` weights the prediction of ``tokens[i + 1]``. For every
    ChatML turn ``<|im_start|>assistant ... <|im_end|>`` the targets from the
    first token after the role name up to and including the closing
    ``<|im_end|>`` (or the end of a truncated sequence) get weight 1.0.

    Matching is done on token ids in one linear pass: message text that merely
    contains the word "assistant" cannot switch the mask on, and no per-token
    decoding is needed.
    """
    weights = [0.0] * (len(tokens) - 1)

    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    role_ids = tokenizer.encode("assistant", add_special_tokens=False)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if not role_ids or im_start_id is None or im_start_id == unk_id:
        # Not a ChatML tokenizer: leave the mask empty so the caller's
        # fallback applies.
        return weights

    n_role = len(role_ids)
    n_tokens = len(tokens)
    i = 0
    while i < n_tokens:
        if tokens[i] == im_start_id and tokens[i + 1:i + 1 + n_role] == role_ids:
            # Index (in ``tokens``) of the first target inside the reply.
            start = i + 1 + n_role
            end = start
            while end < n_tokens and tokens[end] != im_end_id:
                end += 1
            # Include <|im_end|> so the model learns where a reply stops; a
            # truncated reply runs to the last token.
            last = min(end, n_tokens - 1)
            for t in range(start, last + 1):
                weights[t - 1] = 1.0
            i = end + 1
        else:
            i += 1
    return weights


def conversation_to_datum(conversation: Dict, tokenizer, max_length: int) -> Optional[types.Datum]:
    """Convert a chat conversation into a Tinker ``Datum`` for supervised training.

    The conversation is rendered with the tokenizer's chat template (or a
    hand-written ChatML rendering if the template raises), tokenized,
    truncated to ``max_length`` tokens and shifted by one position to form
    next-token-prediction inputs and targets.

    Args:
        conversation: Mapping with a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts.
        tokenizer: Hugging Face tokenizer of the base model.
        max_length: Maximum number of tokens kept; longer sequences are cut.

    Returns:
        A ``types.Datum`` whose ``loss_fn_inputs`` hold ``target_tokens``
        (int64) and ``weights`` (float32; 1.0 on assistant targets), or
        ``None`` if there are no messages or fewer than two tokens.
    """
    messages = conversation.get("messages", [])
    if not messages:
        return None

    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
    except Exception:
        # Template failed (e.g. unsupported role order): render ChatML by hand.
        parts = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
        text = "\n".join(parts)

    tokens = tokenizer.encode(text, add_special_tokens=True)
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    if len(tokens) < 2:
        return None

    weights = _assistant_target_weights(tokens, tokenizer)

    if sum(weights) == 0:
        # No assistant turn survived (e.g. truncated away): train on the last
        # three quarters of the sequence rather than dropping the example.
        for i in range(len(weights) // 4, len(weights)):
            weights[i] = 1.0

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={
            "target_tokens": TensorData.from_numpy(np.array(target_tokens, dtype=np.int64)),
            "weights": TensorData.from_numpy(np.array(weights, dtype=np.float32)),
        }
    )

In [None]:
# Phase 1 data: build (or load cached) training datums from rStar-Coder.
# The cache name encodes the settings that change its contents, so editing
# PHASE1_MAX_LENGTH or PHASE1_MAX_SAMPLES does not silently reuse stale data.
# NOTE: pickle.load can execute arbitrary code -- only load cache files you
# created yourself.
phase1_cache = DATA_DIR / f"phase1_datums_len{PHASE1_MAX_LENGTH}_n{PHASE1_MAX_SAMPLES or 'all'}.pkl"

if phase1_cache.exists():
    print("Loading cached Phase 1 datums...")
    with open(phase1_cache, "rb") as f:
        phase1_datums = pickle.load(f)
    print(f"Loaded {len(phase1_datums)} datums from cache")
else:
    print("Loading rStar-Coder dataset...")
    dataset = load_dataset("microsoft/rStar-Coder", "seed_sft", split="train")
    print(f"Total examples: {len(dataset)}")

    # Keep only rows that passed verification, when those columns exist.
    if "verified" in dataset.column_names:
        dataset = dataset.filter(lambda x: x.get("verified", True))
        print(f"After verified filter: {len(dataset)}")
    if "is_passed" in dataset.column_names:
        dataset = dataset.filter(lambda x: x.get("is_passed", True))
        print(f"After is_passed filter: {len(dataset)}")

    if PHASE1_MAX_SAMPLES:
        dataset = dataset.select(range(min(PHASE1_MAX_SAMPLES, len(dataset))))
        print(f"Using {len(dataset)} samples")

    # Build single-turn user/assistant conversations; the assistant reply is
    # the explanation and/or the code wrapped in a python fence.
    print("Converting to conversations...")
    conversations = []
    for row in tqdm(dataset, desc="Converting"):
        question = row.get("question", "")
        response = row.get("response", "")
        code = row.get("code", "")

        if response and code:
            assistant_content = f"{response}\n\n```python\n{code}\n```"
        elif code:
            assistant_content = f"```python\n{code}\n```"
        elif response:
            assistant_content = response
        else:
            continue

        conversations.append({
            "messages": [
                {"role": "user", "content": question},
                {"role": "assistant", "content": assistant_content}
            ]
        })

    print(f"Created {len(conversations)} conversations")

    # Tokenize sequentially (one conversation_to_datum call per conversation).
    print("Tokenizing (this takes a few minutes)...")
    phase1_datums = []
    for conv in tqdm(conversations, desc="Tokenizing"):
        datum = conversation_to_datum(conv, tokenizer, PHASE1_MAX_LENGTH)
        if datum is not None:
            phase1_datums.append(datum)

    print(f"Created {len(phase1_datums)} datums")

    # Cache for later runs
    print("Caching datums...")
    with open(phase1_cache, "wb") as f:
        pickle.dump(phase1_datums, f)
    print(f"Cached to {phase1_cache}")

## Phase 1: Training

In [None]:
def compute_mean_nll(logprobs: List, weights: List) -> float:
    """Return the weight-averaged negative log-likelihood over a batch.

    Each element of ``logprobs``/``weights`` is one sequence, given either as
    an object exposing ``to_numpy()`` (e.g. a Tinker TensorData) or as
    anything ``np.array`` accepts. The result is
    ``sum(-logprob * weight) / sum(weight)`` across all sequences, or 0.0
    when the total weight is not positive.
    """
    def _as_array(value):
        # Tinker tensors expose to_numpy(); plain sequences go through np.array.
        return value.to_numpy() if hasattr(value, 'to_numpy') else np.array(value)

    weighted_nll = 0.0
    weight_sum = 0.0
    for seq_logprobs, seq_weights in zip(logprobs, weights):
        lp_vals = _as_array(seq_logprobs)
        w_vals = _as_array(seq_weights)
        weighted_nll += float(np.sum(-lp_vals * w_vals))
        weight_sum += float(np.sum(w_vals))

    if weight_sum > 0:
        return weighted_nll / weight_sum
    return 0.0

In [None]:
# Connect to the Tinker service and open a LoRA training client for the base
# model; training itself runs on Tinker's servers.
print("Creating Tinker training client...")
service_client = tinker.ServiceClient()
lora_config = {"base_model": MODEL_NAME, "rank": LORA_RANK}
training_client = service_client.create_lora_training_client(**lora_config)
print("Training client ready!")

In [None]:
# Phase 1 Training Loop: one pass over the shuffled datums, with a learning
# rate that decays linearly from PHASE1_LEARNING_RATE towards zero.
if not phase1_datums:
    raise ValueError("No Phase 1 datums were created; check the data preparation step.")

# At least one batch: with fewer datums than PHASE1_BATCH_SIZE (e.g. a small
# PHASE1_MAX_SAMPLES for testing) floor division yields 0 and Phase 1 would be
# skipped silently. This matches how the Phase 2 loop sizes its batches.
n_batches = max(1, len(phase1_datums) // PHASE1_BATCH_SIZE)
print(f"Phase 1: {n_batches} batches of {PHASE1_BATCH_SIZE}")

# Shuffle in place (unseeded, so batch order differs between runs)
np.random.shuffle(phase1_datums)

phase1_losses = []
pbar = tqdm(range(n_batches), desc="Phase 1 Training")

for step in pbar:
    # Get batch
    batch_start = step * PHASE1_BATCH_SIZE
    batch = phase1_datums[batch_start:batch_start + PHASE1_BATCH_SIZE]

    # Linear decay: full LR at step 0, approaching zero at the last step.
    lr_mult = max(0.0, 1.0 - step / n_batches)
    current_lr = PHASE1_LEARNING_RATE * lr_mult

    adam_params = types.AdamParams(
        learning_rate=current_lr,
        beta1=0.9,
        beta2=0.95,
        eps=1e-8
    )

    # Submit forward/backward and the optimizer step before waiting on either
    # (presumably lets the service queue both requests back-to-back).
    fwd_bwd_future = training_client.forward_backward(batch, loss_fn="cross_entropy")
    optim_future = training_client.optim_step(adam_params)

    fwd_bwd_result = fwd_bwd_future.result()
    optim_future.result()

    # Weighted mean NLL over the batch (logging only).
    train_logprobs = [x["logprobs"] for x in fwd_bwd_result.loss_fn_outputs]
    train_weights = [d.loss_fn_inputs["weights"] for d in batch]
    train_nll = compute_mean_nll(train_logprobs, train_weights)
    phase1_losses.append(train_nll)

    pbar.set_postfix({"NLL": f"{train_nll:.4f}", "LR": f"{current_lr:.2e}"})

    # Save checkpoint every 100 steps
    if step > 0 and step % 100 == 0:
        save_result = training_client.save_state(name=f"phase1-{step:06d}").result()
        print(f"\nCheckpoint saved: {save_result.path}")

print("\nPhase 1 complete!")

In [None]:
# Persist the end-of-Phase-1 training state, then plot the per-step NLL.
phase1_save = training_client.save_state(name="phase1-final").result()
phase1_path = phase1_save.path
print(f"Phase 1 final checkpoint: {phase1_path}")

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(phase1_losses)
ax.set_xlabel("Step")
ax.set_ylabel("NLL")
ax.set_title("Phase 1: Training Loss")
plt.show()

## Phase 2: Prepare SWE-bench Tool-Use Data

In [None]:
# Phase 2 data: SWE-bench tool-use trajectories (resolved issues only),
# converted to chat conversations, tokenized, and cached on disk.
# NOTE: pickle.load can execute arbitrary code -- only load cache files you
# created yourself.

# Dataset role -> chat role. Tool/function output is kept as "tool" so it is
# not treated as assistant output by conversation_to_datum's loss mask (the
# model should not learn to generate environment observations). Unknown
# roles fall back to "user" for the same reason.
_PHASE2_ROLE_MAP = {
    "system": "system",
    "developer": "system",
    "user": "user",
    "assistant": "assistant",
    "tool": "tool",
    "function": "tool",
}


def _content_to_text(content_parts) -> Optional[str]:
    """Flatten a message ``content`` (a string or a list of parts) to text.

    List items that are strings are kept as-is, ``{"type": "text"}`` dicts
    contribute their ``text``, and anything else becomes an empty line.
    Returns None for unsupported content types (e.g. None).
    """
    if isinstance(content_parts, str):
        return content_parts
    if isinstance(content_parts, list):
        texts = []
        for part in content_parts:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(part.get("text") or "")
            else:
                texts.append("")
        return "\n".join(texts)
    return None


def _parse_turns(conv_str: str) -> list:
    """Decode ``full_conversation_jsonl`` into a list of turn records.

    Accepts a single JSON array/object or newline-delimited JSON (one record
    per line), since the column name suggests JSONL.
    """
    try:
        parsed = json.loads(conv_str)
    except json.JSONDecodeError:
        parsed = [json.loads(line) for line in conv_str.splitlines() if line.strip()]
    if isinstance(parsed, dict):
        parsed = [parsed]
    return parsed if isinstance(parsed, list) else []


def _row_to_conversation(row: Dict) -> Optional[Dict]:
    """Build a ``{"messages": [...]}`` conversation from one dataset row, or None."""
    conv_str = row.get("full_conversation_jsonl", "")
    if not conv_str:
        return None

    messages = []
    for turn in _parse_turns(conv_str):
        if not isinstance(turn, dict):
            continue
        # NOTE(review): if every turn record carries the full message history
        # up to that call (common in LLM call logs), earlier messages are
        # repeated once per turn here -- confirm the dataset layout.
        # NOTE(review): only text content is extracted; assistant `tool_calls`
        # fields are ignored -- confirm whether they should be rendered.
        for msg in turn.get("messages", []):
            if not isinstance(msg, dict):
                continue
            content = _content_to_text(msg.get("content", []))
            if content is None or not content.strip():
                continue
            role = _PHASE2_ROLE_MAP.get(msg.get("role", "assistant"), "user")
            messages.append({"role": role, "content": content.strip()})

        response = turn.get("response", "")
        if isinstance(response, str) and response.strip():
            messages.append({"role": "assistant", "content": response.strip()})

    if not messages:
        return None

    system_msg = {
        "role": "system",
        "content": f"You are a helpful coding assistant working on issue {row.get('issue_name', 'unknown')} in project {row.get('project', 'unknown')}. Use tools to fix the issue."
    }

    # Merge consecutive same-role messages (this also folds a dataset system
    # prompt into ours) so the chat template sees alternating turns.
    merged = []
    for msg in [system_msg] + messages:
        if merged and merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n\n" + msg["content"]
        else:
            merged.append(dict(msg))
    return {"messages": merged}


# The cache name encodes the max length so changing it does not reuse stale data.
phase2_cache = DATA_DIR / f"phase2_datums_len{PHASE2_MAX_LENGTH}.pkl"

if phase2_cache.exists():
    print("Loading cached Phase 2 datums...")
    with open(phase2_cache, "rb") as f:
        phase2_datums = pickle.load(f)
    print(f"Loaded {len(phase2_datums)} datums from cache")
else:
    print("Loading SWE-bench tool-use dataset...")
    dataset = load_dataset(
        "AlexCuadron/SWE-Bench-Verified-O1-native-tool-calling-reasoning-high-results",
        split="test"
    )
    print(f"Total examples: {len(dataset)}")

    # Filter for resolved
    dataset = dataset.filter(lambda x: x.get("resolved", False) == True)
    print(f"After resolved filter: {len(dataset)}")

    print("Parsing conversations...")
    conversations = []
    n_skipped = 0
    for row in tqdm(dataset, desc="Parsing"):
        try:
            conv = _row_to_conversation(row)
        except (ValueError, TypeError, AttributeError):
            # Malformed row: skip it, but count it so data loss is visible.
            n_skipped += 1
            continue
        if conv is not None:
            conversations.append(conv)

    print(f"Created {len(conversations)} conversations ({n_skipped} malformed rows skipped)")

    # Tokenize
    print("Tokenizing...")
    phase2_datums = []
    for conv in tqdm(conversations, desc="Tokenizing"):
        datum = conversation_to_datum(conv, tokenizer, PHASE2_MAX_LENGTH)
        if datum is not None:
            phase2_datums.append(datum)

    print(f"Created {len(phase2_datums)} datums")

    # Cache
    with open(phase2_cache, "wb") as f:
        pickle.dump(phase2_datums, f)
    print(f"Cached to {phase2_cache}")

## Phase 2: Training

In [None]:
# Phase 2: continue training the same LoRA adapter on the tool-use data.
# training_client is the same client object that ran Phase 1, so it already
# holds the Phase 1 weights -- no checkpoint reload is needed.
import math

if not phase2_datums:
    raise ValueError("No Phase 2 datums were created; check the data preparation step.")

# Ceil so the last partial batch is trained too: with a dataset of a few
# hundred examples, flooring would drop up to PHASE2_BATCH_SIZE - 1 examples
# every epoch.
n_batches = math.ceil(len(phase2_datums) / PHASE2_BATCH_SIZE)
total_steps = n_batches * PHASE2_NUM_EPOCHS
print(f"Phase 2: {n_batches} batches × {PHASE2_NUM_EPOCHS} epochs = {total_steps} steps")

phase2_losses = []
step = 0  # global step across epochs; drives the LR schedule

for epoch in range(PHASE2_NUM_EPOCHS):
    # Reshuffle in place every epoch (unseeded).
    np.random.shuffle(phase2_datums)
    pbar = tqdm(range(n_batches), desc=f"Phase 2 Epoch {epoch+1}/{PHASE2_NUM_EPOCHS}")

    for batch_idx in pbar:
        # Never empty: batch_idx < ceil(len / batch size).
        batch_start = batch_idx * PHASE2_BATCH_SIZE
        batch = phase2_datums[batch_start:batch_start + PHASE2_BATCH_SIZE]

        # Linear decay over all epochs combined.
        lr_mult = max(0.0, 1.0 - step / total_steps)
        current_lr = PHASE2_LEARNING_RATE * lr_mult

        adam_params = types.AdamParams(
            learning_rate=current_lr,
            beta1=0.9,
            beta2=0.95,
            eps=1e-8
        )

        fwd_bwd_future = training_client.forward_backward(batch, loss_fn="cross_entropy")
        optim_future = training_client.optim_step(adam_params)

        fwd_bwd_result = fwd_bwd_future.result()
        optim_future.result()

        # Weighted mean NLL over the batch (logging only).
        train_logprobs = [x["logprobs"] for x in fwd_bwd_result.loss_fn_outputs]
        train_weights = [d.loss_fn_inputs["weights"] for d in batch]
        train_nll = compute_mean_nll(train_logprobs, train_weights)
        phase2_losses.append(train_nll)

        pbar.set_postfix({"NLL": f"{train_nll:.4f}", "LR": f"{current_lr:.2e}"})
        step += 1

print("\nPhase 2 complete!")

In [None]:
# Save the final training state plus a copy of the weights for sampling.
final_save = training_client.save_state(name="final").result()
sampler_save = training_client.save_weights_for_sampler(name="final-sampler").result()

banner = "=" * 50
print("\n" + banner)
print("TRAINING COMPLETE!")
print(banner)
for label, saved_path in (("Final state path", final_save.path), ("Sampler path", sampler_save.path)):
    print(f"{label}: {saved_path}")
print("\nUse the sampler path with Tinker's sampling API or OpenAI-compatible endpoint.")

In [None]:
# Plot both phases' training curves side by side.
import matplotlib.pyplot as plt

panels = [
    (phase1_losses, "Phase 1: Coding (rStar-Coder)"),
    (phase2_losses, "Phase 2: Tool Use (SWE-bench)"),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for ax, (losses, title) in zip(axes, panels):
    ax.plot(losses)
    ax.set_xlabel("Step")
    ax.set_ylabel("NLL")
    ax.set_title(title)

plt.tight_layout()
plt.show()

## Test the Model

In [None]:
# Smoke-test the fine-tuned weights with one sampled completion.
sampling_client = service_client.create_sampling_client(model_path=sampler_save.path)

test_prompt = "Write a Python function to find the longest palindromic substring in a string."
messages = [{"role": "user", "content": test_prompt}]

# Render with the chat template, ending on the assistant header
# (add_generation_prompt=True) so the model answers as the assistant.
prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt_tokens = tokenizer.encode(prompt_text)

sampling_params = types.SamplingParams(max_tokens=512, temperature=0.7)
sample_future = sampling_client.sample(
    prompt=types.ModelInput.from_ints(tokens=prompt_tokens),
    num_samples=1,
    sampling_params=sampling_params,
)
result = sample_future.result()

response = tokenizer.decode(result.sequences[0].tokens)
print(f"Prompt: {test_prompt}")
print(f"\nResponse:\n{response}")