# MinT Quickstart Guide

## What is MinT?

**MinT (Mind Lab Toolkit)** is an open infrastructure for training language models using:
- **SFT (Supervised Fine-Tuning)**: Learn from labeled examples (input → output pairs)
- **RL (Reinforcement Learning)**: Learn from reward signals (trial and error)

MinT uses **LoRA (Low-Rank Adaptation)** to fine-tune large models efficiently: it trains small low-rank adapter matrices while the base model's weights stay frozen.

## What You'll Learn

In this tutorial, we'll train a model to solve **multiplication problems** using a two-stage approach:

1. **Stage 1 (SFT)**: Teach the model multiplication with labeled examples
2. **Stage 2 (RL)**: Load the SFT model and refine it with reward signals

This demonstrates the complete workflow: SFT → Save → Load → RL

## Prerequisites

- Python >= 3.11
- A MinT API key

---
## Step 0: Installation

Install the MinT SDK from its Git repository. The SDK is in alpha, so the API may change:

In [None]:
%pip install git+https://github.com/MindLab-Research/mindlab-toolkit.git python-dotenv matplotlib numpy

---
## Step 1: Configure Your API Key

MinT requires an API key for authentication. There are two ways to set it up:

### Option A: Using a `.env` file (Recommended)

Create a file named `.env` in your project directory:

```
MINT_API_KEY=sk-mint-your-api-key-here
```

### Option B: Set environment variable directly

```python
import os
os.environ['MINT_API_KEY'] = 'sk-mint-your-api-key-here'
```

**Security Note**: Never commit your API key to version control. Add `.env` to your `.gitignore` file.

---

### Tinker Compatibility

MinT is fully API-compatible with [Tinker](https://tinker.thinkingmachines.ai). If you prefer, you can use the `tinker` package with MinT by configuring environment variables to point to the MinT server:

```bash
pip install tinker
```

```
TINKER_BASE_URL=https://mint.macaron.im/
TINKER_API_KEY=<your-mint-api-key>
```

All code in this tutorial works identically with `import tinker` instead of `import mint`.

---

### HuggingFace Mirror

If you have network issues accessing HuggingFace, set the mirror endpoint **before** importing `mint`:

```python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import mint  # Must be after setting HF_ENDPOINT
```

This redirects all HuggingFace model downloads to the mirror site.

In [None]:
import os
from dotenv import load_dotenv

# Load API key from .env file
load_dotenv()

# Verify API key is set
if not os.environ.get('MINT_API_KEY'):
    print("WARNING: MINT_API_KEY not found!")
    print("Please create a .env file with: MINT_API_KEY=sk-mint-your-key-here")
else:
    print("API key loaded successfully")

---
## Step 2: Connect to MinT Server

The `ServiceClient` is your entry point to MinT. It handles:
- Authentication with the server
- Creating training and sampling clients
- Querying available models

In [None]:
import mint

# Create the service client
service_client = mint.ServiceClient()

# List available models
print("Connected to MinT server!")
print("\nAvailable models:")
try:
    capabilities = service_client.get_server_capabilities()
    for model in capabilities.supported_models:
        print(f"  - {model.model_name}")
except Exception as e:
    print(f"  Error: {e}")

---
## Core Concepts

Before we start training, let's understand the key concepts:

| Component | Purpose |
|-----------|:--------|
| `TrainingClient` | Manages your LoRA adapter and handles training operations |
| `SamplingClient` | Generates text from your trained model |
| `Datum` | A single training example with `model_input` and `loss_fn_inputs` |

### Training Loop Pattern
```python
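for batch in batches:  # `batches` is a placeholder: an iterable of lists of Datum objects
    training_client.forward_backward(batch, loss_fn="cross_entropy").result()  # Compute gradients
    training_client.optim_step(types.AdamParams(learning_rate=5e-5)).result()  # Update model weights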
```

### Loss Functions
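- **cross_entropy**: For SFT - maximizes the probability of the target tokens (equivalently, minimizes their negative log-likelihood)
- **importance_sampling**: For RL - scales each sampled token's update by its advantage, so better-than-average responses become more likely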
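
As a rough illustration of what these two losses compute for one sequence, here is a minimal pure-Python sketch. The function names are hypothetical, and the importance-sampling form (probability ratio times advantage) is an assumption based on the usual definition, not a statement of what the MinT server runs internally.

```python
import math

def cross_entropy_loss(target_logprobs, weights):
    """Weighted negative log-likelihood, summed over tokens."""
    return sum(-lp * w for lp, w in zip(target_logprobs, weights))

def importance_sampling_loss(target_logprobs, sampling_logprobs, advantages):
    """Negative sum of (probability ratio * advantage) over tokens."""
    total = 0.0
    for new_lp, old_lp, adv in zip(target_logprobs, sampling_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # current policy prob / sampling-time prob
        total += ratio * adv
    return -total

# Toy values: one prompt token (weight 0) followed by two answer tokens (weight 1)
logprobs = [-2.0, -0.5, -0.25]
weights = [0.0, 1.0, 1.0]
sft_loss = cross_entropy_loss(logprobs, weights)

# With identical current and sampling logprobs, every ratio is 1
rl_loss = importance_sampling_loss(logprobs, logprobs, [0.0, 0.5, 0.5])
print(sft_loss, rl_loss)
```

In both cases the prompt token contributes nothing: its weight (for SFT) or advantage (for RL) is zero.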

---
# Stage 1: Supervised Fine-Tuning (SFT)

**Goal**: Teach the model to solve two-digit multiplication using labeled examples.

**How SFT works**:
1. Show the model input-output pairs
2. Model learns to predict the output given the input
3. Loss = how surprised the model is by the correct answer

```
Input:  "Question: What is 47 * 83?\nAnswer:"
Output: " 3901"
```
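
The "surprise" in step 3 is commonly measured as the negative log of the probability the model assigns to the correct token. The sketch below uses made-up probabilities (not real model outputs) to show how that value behaves.

```python
import math

def surprise(prob_of_correct_token: float) -> float:
    """Negative log-probability: low when the model is confident and right."""
    return -math.log(prob_of_correct_token)

for p in (0.9, 0.5, 0.01):
    print(f"p = {p:<5} loss = {surprise(p):.3f}")
```

A confident correct prediction costs little, while assigning a low probability to the correct token is penalized heavily.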

### Step 3: Create a Training Client

We'll use `Qwen/Qwen3-0.6B`, a small model that trains quickly, which makes it a good fit for a tutorial.

**LoRA Parameters**:
- `rank=16`: Rank (inner dimension) of the low-rank matrices (higher = more capacity, but slower)
- `train_mlp=True`: Train the feed-forward layers
- `train_attn=True`: Train the attention layers
- `train_unembed=True`: Train the output projection
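
To make the role of `rank` concrete, here is a minimal pure-Python sketch of the low-rank idea for a single weight matrix. The matrix sizes are illustrative only (they are not Qwen3-0.6B's actual dimensions), the helper names are hypothetical, and the scaling factor that LoRA implementations usually apply is omitted.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x):
    """Compute W x + B (A x): frozen base output plus a low-rank correction."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + u for b, u in zip(base, update)]

def param_counts(d_in, d_out, rank):
    """Return (full fine-tuning params, LoRA params) for one weight matrix."""
    return d_in * d_out, rank * (d_in + d_out)

W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight, shape (2, 2)
A = [[0.5, 0.5]]              # trainable, shape (rank=1, d_in=2)
B = [[1.0], [2.0]]            # trainable, shape (d_out=2, rank=1)
y = lora_forward(W, A, B, [2.0, 4.0])

full, lora = param_counts(1024, 1024, rank=16)
print(y, full, lora)
```

For a 1024 x 1024 matrix, rank 16 trains 32,768 values instead of 1,048,576, which is where the efficiency comes from.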

In [None]:
from mint import types

BASE_MODEL = "Qwen/Qwen3-0.6B"

# Create a training client with LoRA configuration
training_client = service_client.create_lora_training_client(
    base_model=BASE_MODEL,
    rank=16,              # LoRA rank - controls adapter capacity
    train_mlp=True,       # Train MLP (feed-forward) layers
    train_attn=True,      # Train attention layers  
    train_unembed=True,   # Train the output projection
)
print(f"Training client created for: {BASE_MODEL}")

# Get the tokenizer - converts text to/from token IDs
tokenizer = training_client.get_tokenizer()
print(f"Tokenizer vocabulary size: {tokenizer.vocab_size:,} tokens")

### Step 4: Prepare Training Data

We need to convert our examples into `Datum` objects that MinT can process.

**Key concept - Weights**:
- `weight=0`: Don't compute loss on this token (the prompt)
- `weight=1`: Compute loss on this token (the answer we want to learn)
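
The weights have to line up with the *targets*, which are the tokens shifted by one position. A minimal sketch with hypothetical token IDs (chosen only for illustration, not produced by the real tokenizer) shows the alignment:

```python
# Hypothetical token IDs, chosen only for illustration
prompt_tokens = [11, 12, 13]   # stands in for "Question: ...\nAnswer:"
completion_tokens = [21, 22]   # stands in for the answer plus EOS

all_tokens = prompt_tokens + completion_tokens
all_weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

input_tokens = all_tokens[:-1]   # everything except the last token
target_tokens = all_tokens[1:]   # everything except the first token
weights = all_weights[1:]        # one weight per target

for inp, tgt, w in zip(input_tokens, target_tokens, weights):
    print(f"input={inp}  target={tgt}  weight={w}")
```

Note that the last prompt token (13) is the input position whose target is the first answer token (21), so that position already carries weight 1.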

In [None]:
# Generate arithmetic training examples
import random
random.seed(42)  # For reproducibility

def generate_sft_examples(n=100):
    """Generate two-digit multiplication examples for SFT."""
    examples = []
    for _ in range(n):
        a = random.randint(10, 99)
        b = random.randint(10, 99)
        examples.append({
            "question": f"What is {a} * {b}?",
            "answer": str(a * b)
        })
    return examples

sft_examples = generate_sft_examples(100)

print(f"Generated {len(sft_examples)} training examples")
print("\nSample examples:")
for ex in sft_examples[:3]:
    print(f"  Q: {ex['question']} → A: {ex['answer']}")

In [None]:
def process_sft_example(example: dict, tokenizer) -> types.Datum:
    """
    Convert a training example into a Datum for MinT.
    
    The model learns next-token prediction:
    Given tokens [0, 1, 2, 3], predict [1, 2, 3, 4]
    
    We use weights to only compute loss on the answer tokens.
    """
    # Format the prompt and completion
    prompt = f"Question: {example['question']}\nAnswer:"
    completion = f" {example['answer']}"
    
    # Tokenize prompt and completion separately
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(completion, add_special_tokens=False)
    
    # Add EOS token so model learns when to stop
    eos_token_id = tokenizer.eos_token_id
    completion_tokens = completion_tokens + [eos_token_id]
    
    # Create weights: 0 for prompt (don't learn), 1 for completion (learn this)
    prompt_weights = [0] * len(prompt_tokens)
    completion_weights = [1] * len(completion_tokens)
    
    # Combine everything
    all_tokens = prompt_tokens + completion_tokens
    all_weights = prompt_weights + completion_weights
    
    # For next-token prediction:
    # - input_tokens:  all tokens except the last
    # - target_tokens: all tokens except the first (shifted by 1)
    input_tokens = all_tokens[:-1]
    target_tokens = all_tokens[1:]
    weights = all_weights[1:]  # Weights align with targets
    
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={
            "target_tokens": target_tokens,
            "weights": weights
        }
    )

# Process all examples
sft_data = [process_sft_example(ex, tokenizer) for ex in sft_examples]
print(f"Prepared {len(sft_data)} training datums")
print(f"EOS token ID: {tokenizer.eos_token_id}")

### Visualize the Data

Let's see how the first example looks after processing.
- **Weight=0**: Prompt tokens (model sees these but doesn't learn from them)
- **Weight=1**: Completion tokens (model learns to predict these)

In [None]:
# Visualize the first example
datum0 = sft_data[0]

print(f"Example: Q: {sft_examples[0]['question']} → A: {sft_examples[0]['answer']}")
print(f"EOS token: {repr(tokenizer.decode([tokenizer.eos_token_id]))}")
print()
print(f"{'Input Token':<20} {'Target Token':<20} {'Weight':<10} {'Learn?'}")
print("-" * 65)

inputs = datum0.model_input.to_ints()
targets = datum0.loss_fn_inputs['target_tokens']
weights = datum0.loss_fn_inputs['weights']

# Handle both list and tensor types
if hasattr(targets, 'tolist'):
    targets = targets.tolist()
if hasattr(weights, 'tolist'):
    weights = weights.tolist()

for inp, tgt, wgt in zip(inputs, targets, weights):
    learn = "YES" if wgt > 0 else "no"
    inp_str = repr(tokenizer.decode([inp]))
    tgt_str = repr(tokenizer.decode([tgt]))
    print(f"{inp_str:<20} {tgt_str:<20} {wgt:<10} {learn}")

### Step 5: Train the Model (SFT)

The training loop:
1. **forward_backward()**: Compute the loss and gradients
2. **optim_step()**: Update model weights using the Adam optimizer

We'll run multiple steps and watch the loss decrease.
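
For intuition about what an Adam update does with a gradient, here is a minimal sketch of one step for a single scalar parameter. The hyperparameters `beta1`, `beta2`, and `eps` are the commonly used defaults and are an assumption here; this guide does not state which values MinT uses.

```python
import math

def adam_step(param, grad, m, v, t, lr=5e-5, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

param, m, v = 1.0, 0.0, 0.0
for t, grad in enumerate([0.4, 0.3, 0.2], start=1):
    param, m, v = adam_step(param, grad, m, v, t)
    print(f"step {t}: param = {param:.6f}")
```

The running averages `m` and `v` are the optimizer state, which is why resuming training smoothly requires saving more than the weights alone.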

In [None]:
# Training configuration
NUM_SFT_STEPS = 10
SFT_LEARNING_RATE = 5e-5

sft_losses = []
print(f"Starting SFT training for {NUM_SFT_STEPS} steps...")
print(f"Learning rate: {SFT_LEARNING_RATE}")
print()

for step in range(NUM_SFT_STEPS):
    fwdbwd_future = training_client.forward_backward(
        data=sft_data,
        loss_fn="cross_entropy"
    )
    fwdbwd_result = fwdbwd_future.result()
    
    # Weighted mean negative log-likelihood over the answer tokens (weight=1)
    total_loss = 0.0
    total_weight = 0.0
    for i, out in enumerate(fwdbwd_result.loss_fn_outputs):
        logprobs = out['logprobs']
        if hasattr(logprobs, 'tolist'):
            logprobs = logprobs.tolist()
        w = sft_data[i].loss_fn_inputs['weights']
        if hasattr(w, 'tolist'):
            w = w.tolist()
        for lp, wt in zip(logprobs, w):
            total_loss += -lp * wt
            total_weight += wt
    
    loss = total_loss / max(total_weight, 1)
    sft_losses.append(loss)
    
    optim_future = training_client.optim_step(
        types.AdamParams(learning_rate=SFT_LEARNING_RATE)
    )
    optim_future.result()
    
    print(f"Step {step+1:2d}/{NUM_SFT_STEPS}: Loss = {loss:.4f}")

print("\nSFT training complete!")
print(f"Loss: {sft_losses[0]:.4f} → {sft_losses[-1]:.4f}")

### Step 6: Test the SFT Model

Let's see whether the model learned two-digit multiplication.

In [None]:
import re

def extract_answer(response: str) -> str | None:
    """Extract the first numeric answer from response."""
    numbers = re.findall(r'\d+', response)
    return numbers[0] if numbers else None

# Save weights and create sampling client
print("Saving SFT model weights...")
sft_sampling_client = training_client.save_weights_and_get_sampling_client(
    name='arithmetic-sft'
)
print("Model ready for inference!\n")

# Test on problems in SFT range (10-99)
test_problems = [
    ("What is 23 * 47?", "1081"),
    ("What is 56 * 34?", "1904"),
    ("What is 71 * 89?", "6319"),
]

print("Testing SFT model (10-99 range):")
print("=" * 50)
sft_correct = 0

for question, correct in test_problems:
    prompt = f"Question: {question}\nAnswer:"
    prompt_tokens = types.ModelInput.from_ints(tokenizer.encode(prompt))
    
    result = sft_sampling_client.sample(
        prompt=prompt_tokens,
        num_samples=1,
        sampling_params=types.SamplingParams(
            max_tokens=16,
            temperature=0.0,
            stop_token_ids=[tokenizer.eos_token_id]
        )
    ).result()
    
    response = tokenizer.decode(result.sequences[0].tokens)
    extracted = extract_answer(response)
    is_correct = extracted == correct
    if is_correct:
        sft_correct += 1
    
    print(f"Q: {question}")
    print(f"A: {response.strip()} (extracted: {extracted}, correct: {correct}) [{'PASS' if is_correct else 'FAIL'}]")
    print()

print(f"SFT Accuracy: {sft_correct}/{len(test_problems)}")

---
### Step 7: Save Checkpoint

Save the SFT model so we can load it and continue training with RL.

**`save_state()`** saves:
- LoRA weights
- Optimizer state (for seamless training continuation)

In [None]:
# Save full training state (weights + optimizer)
sft_checkpoint = training_client.save_state(name="arithmetic-sft-checkpoint").result()
print(f"Checkpoint saved to: {sft_checkpoint.path}")
print("\nThis checkpoint contains:")
print("  - LoRA weights trained on arithmetic")
print("  - Adam optimizer state (momentum, variance)")

---
# Stage 2: Reinforcement Learning (RL)

**Goal**: Load the SFT model and refine it using reward signals.

**How RL differs from SFT**:
- SFT: "Here's the correct answer, learn it"
- RL: "Try different answers, I'll tell you if you're right or wrong"

**RL Workflow**:
```
1. Generate multiple responses to each problem
2. Compute rewards (1.0 = correct, 0.0 = wrong)
3. Compute advantages (how much better than average?)
4. Train: increase probability of good responses, decrease bad ones
```
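
Steps 2 and 3 of this workflow can be sketched in a few lines. The helper name `group_advantages` is hypothetical, and the rewards are made up for illustration.

```python
def group_advantages(rewards):
    """Subtract the group's mean reward from each response's reward."""
    mean_reward = sum(rewards) / len(rewards)
    return [r - mean_reward for r in rewards]

# Four sampled responses to one problem: only the first is correct
mixed = group_advantages([1.0, 0.0, 0.0, 0.0])

# Every response wrong: all advantages are zero
all_wrong = group_advantages([0.0, 0.0, 0.0, 0.0])

print(mixed)
print(all_wrong)
```

When every response in a group gets the same reward, all advantages are zero and the group contributes no learning signal.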

**Why continue with RL after SFT?**
- SFT teaches the format and basic capability
- RL refines the model through exploration and reward optimization
- Common pattern in real-world LLM training pipelines

### Step 8: Continue Training with RL

We continue training the same model with RL. In production, you would typically:

1. Save checkpoint: `training_client.save_state(name)`
2. Load checkpoint: `service_client.create_training_client_from_state_with_optimizer(path)`

For this demo, we continue with the same training client to show the SFT → RL transition.

In [None]:
# Continue with RL training using the same client
# (In production, you would load from checkpoint for distributed training or resumption)

rl_training_client = training_client  # Continue from SFT

print("Continuing from SFT model...")
print("\nThe model now:")
print("  - Has SFT knowledge of multiplication")
print("  - Is ready for RL refinement")

### Step 9: Define the Reward Function

In RL, we need a way to evaluate responses:
- **Reward = 1.0**: Correct answer
- **Reward = 0.0**: Wrong answer
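
A binary reward is only as reliable as the answer parsing behind it. The sketch below mirrors the first-number extraction from Step 6 and runs it on a few made-up responses; the function names are hypothetical and the responses are not real model outputs.

```python
import re
from typing import Optional

def extract_first_number(response: str) -> Optional[str]:
    """Return the first run of digits in the response, if any."""
    numbers = re.findall(r"\d+", response)
    return numbers[0] if numbers else None

def reward(response: str, correct_answer: str) -> float:
    """1.0 if the first number matches the correct answer, else 0.0."""
    return 1.0 if extract_first_number(response) == correct_answer else 0.0

clean = reward(" 5535", "5535")                  # plain answer
restated = reward(" 123 * 45 = 5535", "5535")    # repeats the question first
separated = reward(" 5,535", "5535")             # comma splits the digits
no_digits = reward(" I don't know", "5535")      # nothing to extract

print(clean, restated, separated, no_digits)
```

Only the plain answer earns a reward here. Responses that restate the question or use a thousands separator score 0.0 even though they contain the right value, so the output format taught during SFT matters for RL.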

In [None]:
def generate_rl_problem():
    """Generate harder multiplication problems for RL (10-199 vs SFT's 10-99)."""
    a = random.randint(10, 199)
    b = random.randint(10, 199)
    return f"What is {a} * {b}?", str(a * b)

def compute_reward(response: str, correct_answer: str) -> float:
    """Reward function: 1.0 if correct, 0.0 otherwise."""
    extracted = extract_answer(response)
    return 1.0 if extracted == correct_answer else 0.0

# Demo
print("Reward function demo:")
print("RL problems: 10-199 * 10-199 (harder than SFT's 10-99 * 10-99)")
print()
q, a = generate_rl_problem()
print(f"Question: {q}, Answer: {a}")
print(f"  '{a}' → reward = {compute_reward(a, a)}")
print(f"  '999' → reward = {compute_reward('999', a)}")

### Step 10: RL Training Loop

The RL training loop:

```
for each step:
    1. Sample multiple responses per problem (exploration)
    2. Compute rewards for each response
    3. Compute advantages = reward - mean_reward
    4. Train with importance_sampling loss
```
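
Step 4 needs per-token values that line up with the shifted targets. Here is a minimal sketch of that alignment for one sampled response, using made-up token IDs, log-probabilities, and advantage:

```python
# Made-up values for one sampled response
prompt_tokens = [11, 12, 13]
response_tokens = [21, 22]
sampling_logprobs = [-0.7, -0.1]   # one logprob per response token
advantage = 0.5                    # same advantage for every response token

full_tokens = prompt_tokens + response_tokens
input_tokens = full_tokens[:-1]
target_tokens = full_tokens[1:]

# After the one-token shift, the prompt covers len(prompt_tokens) - 1 targets
pad = len(prompt_tokens) - 1
weights = [0.0] * pad + [1.0] * len(response_tokens)
logprobs = [0.0] * pad + sampling_logprobs
advantages = [0.0] * pad + [advantage] * len(response_tokens)

# Every per-token list must match the number of targets
assert len(weights) == len(logprobs) == len(advantages) == len(target_tokens)
print(list(zip(target_tokens, weights, logprobs, advantages)))
```

The prompt positions carry zero weight and zero advantage, so only the sampled response tokens influence the update.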

**Key parameters**:
- `GROUP_SIZE`: How many responses to sample per problem (more samples give a more reliable advantage estimate, at higher sampling cost)
- `temperature`: Controls randomness (higher = more exploration)
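
To see why temperature controls exploration, here is a minimal sketch of temperature-scaled softmax over three made-up logits. It illustrates the general technique and is not a description of MinT's sampler.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for t in (0.3, 0.7, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the top token, while higher ones spread it out and produce more varied samples. This sketch does not handle `temperature=0.0`, which samplers usually treat as greedy decoding.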

In [None]:
import torch
from mint import TensorData

# RL Configuration
NUM_RL_STEPS = 10
BATCH_SIZE = 8
GROUP_SIZE = 8
MAX_TOKENS = 16
RL_LEARNING_RATE = 2e-5
TEMPERATURE = 0.7

rl_metrics = []
print("Starting RL training (continuing from the SFT-trained client)...")
print(f"Config: {NUM_RL_STEPS} steps, {BATCH_SIZE} problems/batch, {GROUP_SIZE} samples/problem")
print()

for step in range(NUM_RL_STEPS):
    # Snapshot the current weights so sampling uses the latest policy
    sampling_path = rl_training_client.save_weights_for_sampler(
        name=f"rl-step-{step}"
    ).result().path
    
    rl_sampling_client = service_client.create_sampling_client(
        model_path=sampling_path,
        base_model=BASE_MODEL
    )
    
    problems = [generate_rl_problem() for _ in range(BATCH_SIZE)]
    training_datums = []
    all_rewards = []
    
    for question, answer in problems:
        prompt_text = f"Question: {question}\nAnswer:"
        prompt_tokens = tokenizer.encode(prompt_text)
        prompt_input = types.ModelInput.from_ints(prompt_tokens)
        
        sample_result = rl_sampling_client.sample(
            prompt=prompt_input,
            num_samples=GROUP_SIZE,
            sampling_params=types.SamplingParams(
                max_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                stop_token_ids=[tokenizer.eos_token_id]
            )
        ).result()
        
        group_rewards = []
        group_responses = []
        group_logprobs = []
        
        for seq in sample_result.sequences:
            response_text = tokenizer.decode(seq.tokens)
            reward = compute_reward(response_text, answer)
            group_rewards.append(reward)
            group_responses.append(seq.tokens)
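            # Falls back to 0.0 (probability 1) only when the sampler returns no logprobs;
            # accurate sampling logprobs are what the importance_sampling loss relies on
            group_logprobs.append(seq.logprobs if seq.logprobs else [0.0] * len(seq.tokens))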
        
        all_rewards.extend(group_rewards)
        
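        # Group-relative advantage: how much better each response is than the group mean
        mean_reward = sum(group_rewards) / len(group_rewards)
        advantages = [r - mean_reward for r in group_rewards]

        # If every response earned the same reward, there is no learning signal; skip the group
        if all(a == 0 for a in advantages):
            continue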
        
        for response_tokens, logprobs, adv in zip(group_responses, group_logprobs, advantages):
            if len(response_tokens) == 0:
                continue
            
            full_tokens = prompt_tokens + list(response_tokens)
            input_tokens = full_tokens[:-1]
            target_tokens = full_tokens[1:]
            
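            # Prompt positions get zero weight and placeholder values; after the
            # one-token shift the prompt covers len(prompt_tokens) - 1 targets
            weights = [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(response_tokens)
            full_logprobs = [0.0] * (len(prompt_tokens) - 1) + list(logprobs)
            full_advantages = [0.0] * (len(prompt_tokens) - 1) + [adv] * len(response_tokens)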
            
            datum = types.Datum(
                model_input=types.ModelInput.from_ints(tokens=input_tokens),
                loss_fn_inputs={
                    "target_tokens": TensorData.from_torch(torch.tensor(target_tokens, dtype=torch.int64)),
                    "weights": TensorData.from_torch(torch.tensor(weights, dtype=torch.float32)),
                    "logprobs": TensorData.from_torch(torch.tensor(full_logprobs, dtype=torch.float32)),
                    "advantages": TensorData.from_torch(torch.tensor(full_advantages, dtype=torch.float32)),
                },
            )
            training_datums.append(datum)
    
    avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
    accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0.0
    
    if training_datums:
        fwdbwd_future = rl_training_client.forward_backward(
            training_datums,
            loss_fn="importance_sampling"
        )
        fwdbwd_future.result()
        
        optim_future = rl_training_client.optim_step(
            types.AdamParams(learning_rate=RL_LEARNING_RATE)
        )
        optim_future.result()
    
    rl_metrics.append({
        'step': step,
        'avg_reward': avg_reward,
        'accuracy': accuracy,
        'num_datums': len(training_datums)
    })
    
    print(f"Step {step+1:2d}/{NUM_RL_STEPS}: Accuracy = {accuracy:5.1%}, Avg Reward = {avg_reward:.3f}")

print("\nRL training complete!")

### Step 11: Test the Final Model

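Evaluate the RL-refined model on harder problems from the RL range (10-199). These differ from the Step 6 test problems, so the two accuracy figures are not directly comparable.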

In [None]:
# Get final RL model
final_path = rl_training_client.save_weights_for_sampler(name="arithmetic-rl-final").result().path
final_client = service_client.create_sampling_client(
    model_path=final_path,
    base_model=BASE_MODEL
)

# Test on harder problems (10-199 range)
rl_test_problems = [
    ("What is 123 * 45?", "5535"),
    ("What is 67 * 189?", "12663"),
    ("What is 156 * 78?", "12168"),
]

print("Testing RL-refined model (10-199 range):")
print("=" * 50)
rl_correct = 0

for question, correct in rl_test_problems:
    prompt = f"Question: {question}\nAnswer:"
    prompt_input = types.ModelInput.from_ints(tokenizer.encode(prompt))
    
    result = final_client.sample(
        prompt=prompt_input,
        num_samples=1,
        sampling_params=types.SamplingParams(
            max_tokens=16,
            temperature=0.0,
            stop_token_ids=[tokenizer.eos_token_id]
        )
    ).result()
    
    response = tokenizer.decode(result.sequences[0].tokens)
    extracted = extract_answer(response)
    is_correct = extracted == correct
    if is_correct:
        rl_correct += 1
    
    print(f"Q: {question}")
    print(f"A: {response.strip()} (extracted: {extracted}, correct: {correct}) [{'PASS' if is_correct else 'FAIL'}]")
    print()

print(f"RL Accuracy: {rl_correct}/{len(rl_test_problems)}")

---
## Step 12: Visualize Training Results

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Plot 1: SFT Loss
axes[0].plot(range(1, len(sft_losses) + 1), sft_losses, 'b-o', linewidth=2, markersize=8)
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Loss')
axes[0].set_title('Stage 1: SFT Training Loss')
axes[0].grid(True, alpha=0.3)

# Plot 2: RL Accuracy
rl_steps = [m['step'] + 1 for m in rl_metrics]
rl_accuracy = [m['accuracy'] for m in rl_metrics]

axes[1].plot(rl_steps, rl_accuracy, 'g-o', linewidth=2, markersize=8)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Stage 2: RL Training Accuracy')
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim(0, 1.1)

plt.tight_layout()
plt.savefig('training_results.png', dpi=150)
plt.show()

print("Plot saved to training_results.png")

---
## Step 13: Save Final Checkpoint

In [None]:
# Save final RL checkpoint
rl_checkpoint = rl_training_client.save_state(name="arithmetic-rl-final").result()
print(f"Final checkpoint: {rl_checkpoint.path}")

print("\nTo resume training later:")
print(f"  client = service_client.create_training_client_from_state_with_optimizer('{rl_checkpoint.path}')")

---
## Summary

### What We Learned

This tutorial demonstrated the complete MinT workflow:

| Stage | Method | Loss Function | Purpose |
|-------|--------|---------------|:--------|
| 1 | SFT | `cross_entropy` | Teach multiplication with labeled examples |
| 2 | RL | `importance_sampling` | Refine with reward signals |

### Key API Methods

```python
# Setup
service_client = mint.ServiceClient()
training_client = service_client.create_lora_training_client(base_model=...)

# Training
training_client.forward_backward(data, loss_fn)  # Compute gradients
training_client.optim_step(adam_params)          # Update weights

# Checkpointing
checkpoint = training_client.save_state(name)    # Save checkpoint
resumed = service_client.create_training_client_from_state_with_optimizer(checkpoint.path)

# Inference
sampling_client = training_client.save_weights_and_get_sampling_client(name)
sampling_client.sample(prompt, num_samples, sampling_params)
```

### Training Pipeline

```
SFT Training → save_state() → [optional: load checkpoint] → RL Training
```

This is a common approach in modern LLM training: SFT first to teach format and basic capabilities, then RL to refine through reward optimization.