# Qwen3 Training Pipeline for Astrabot

This notebook demonstrates how to fine-tune Qwen3 models on your personal conversation data using Unsloth with 4-bit quantization and LoRA adapters.

## Features
- 🚀 Efficient 4-bit quantization with Unsloth
- 🎯 Multiple training data formats (conversational, adaptive, burst, Q&A)
- 🧠 Reasoning capability with thinking tags
- 👥 Partner-specific style adaptation
- 🐦 Twitter content enhancement
- 📊 Multi-stage training support

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install unsloth
!pip install transformers trl datasets peft accelerate
!pip install pandas numpy pyyaml

In [None]:
import os
import sys
import json
import yaml
from pathlib import Path
import pandas as pd
import numpy as np
from datetime import datetime

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Import Astrabot modules
from src.llm.training_data_creator import TrainingDataCreator
from src.llm.adaptive_trainer import AdaptiveTrainer
from src.core.style_analyzer import analyze_all_communication_styles
from src.utils.logging import get_logger

# Setup logging
logger = get_logger(__name__)

print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")

## 2. Load Configuration

In [None]:
# Load training configuration
config_path = project_root / "configs" / "training_config.yaml"

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Display key configuration parameters
print("Model Configuration:")
print(f"  Base model: {config['model']['name']}")
print(f"  Max sequence length: {config['model']['max_seq_length']}")
print(f"  4-bit quantization: {config['model']['load_in_4bit']}")
print("\nLoRA Configuration:")
print(f"  Rank (r): {config['lora']['r']}")
print(f"  Alpha: {config['lora']['alpha']}")
print(f"  Target modules: {', '.join(config['lora']['target_modules'])}")
print("\nTraining Configuration:")
print(f"  Epochs: {config['training']['num_train_epochs']}")
print(f"  Batch size: {config['training']['per_device_train_batch_size']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
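
Later cells index into `config` with hard-coded keys, so a typo or missing section in the YAML only surfaces as a `KeyError` several cells down. A minimal sketch of an up-front check follows; `REQUIRED_KEYS` and `missing_config_keys` are hypothetical helpers, and the key list was collected by hand from this notebook's cells rather than from any schema shipped with Astrabot.

```python
from typing import Any, Mapping

MODES = ("conversational", "adaptive", "burst_sequence", "qa")

# Dotted key paths read by cells in this notebook (collected by hand, not from a schema).
REQUIRED_KEYS = [
    "model.name", "model.max_seq_length", "model.load_in_4bit",
    "lora.r", "lora.alpha", "lora.target_modules", "lora.dropout", "lora.bias",
    "lora.use_gradient_checkpointing", "lora.random_state", "lora.use_rslora",
    "training.num_train_epochs", "training.per_device_train_batch_size",
    "training.learning_rate",
    *[f"dataset.modes.{m}.{f}" for m in MODES for f in ("enabled", "weight")],
    "reasoning.enabled", "reasoning.temperature", "reasoning.top_p",
    "chat.temperature", "chat.top_p",
]


def missing_config_keys(config: Mapping[str, Any], required=REQUIRED_KEYS) -> list:
    """Return every dotted path in `required` that cannot be resolved in `config`."""
    missing = []
    for path in required:
        node: Any = config
        for part in path.split("."):
            # Stop at the first segment that is absent or not a nested mapping
            if not isinstance(node, Mapping) or part not in node:
                missing.append(path)
                break
            node = node[part]
    return missing
```

In the notebook this would be used as `missing = missing_config_keys(config)` followed by `raise KeyError(...)` when the list is non-empty.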

## 3. Load and Analyze Data

In [None]:
# Load Signal data
data_path = project_root / "data" / "raw" / "signal-flatfiles"
messages_path = data_path / "signal.csv"
recipients_path = data_path / "recipient.csv"

# Stop early if the files are missing; every later cell depends on these DataFrames
if not messages_path.exists() or not recipients_path.exists():
    raise FileNotFoundError(
        "⚠️  Signal data not found. Please run the extraction process first.\n"
        f"Expected paths:\n  {messages_path}\n  {recipients_path}"
    )

messages_df = pd.read_csv(messages_path)
recipients_df = pd.read_csv(recipients_path)

print(f"✅ Loaded {len(messages_df):,} messages")
print(f"✅ Found {len(recipients_df):,} recipients")

# Basic statistics (date_sent is a Unix timestamp in milliseconds)
sent_at = pd.to_datetime(messages_df['date_sent'], unit='ms')
print("\nMessage Statistics:")
print(f"  Date range: {sent_at.min()} to {sent_at.max()}")
print(f"  Unique threads: {messages_df['thread_id'].nunique():,}")
print(f"  Messages with text: {messages_df['body'].notna().sum():,}")
print(f"  Average message length: {messages_df['body'].dropna().str.len().mean():.1f} characters")

In [None]:
# Analyze communication styles
YOUR_RECIPIENT_ID = 2  # Update this to your actual ID

print("Analyzing communication styles...")
communication_styles = analyze_all_communication_styles(
    messages_df, recipients_df, YOUR_RECIPIENT_ID
)

# Display style breakdown
print(f"\nAnalyzed {len(communication_styles)} conversation partners:")
style_counts = {}
for recipient_id, style_info in communication_styles.items():
    style_type = style_info.get('style_type', 'unknown')
    style_counts[style_type] = style_counts.get(style_type, 0) + 1

for style, count in sorted(style_counts.items(), key=lambda x: x[1], reverse=True):
    print(f"  {style}: {count} people")

## 4. Create Training Data

In [None]:
# Initialize training data creator
creator = TrainingDataCreator(YOUR_RECIPIENT_ID)
adaptive_trainer = AdaptiveTrainer(YOUR_RECIPIENT_ID)

# Create different types of training data
all_training_examples = []

# 1. Conversational data
print("Creating conversational training data...")
conv_data = creator.create_conversational_training_data(
    messages_df, 
    recipients_df,
    context_window=5,
    include_metadata=True
)
print(f"  Created {len(conv_data)} conversational examples")

# Show sample
if conv_data:
    sample = conv_data[0]
    print("\nSample conversational example:")
    print(f"  Messages: {len(sample['messages'])} turns")
    if 'metadata' in sample:
        print(f"  Type: {sample['metadata'].get('type', 'unknown')}")
        print(f"  Has Twitter: {sample['metadata'].get('has_twitter', False)}")

In [None]:
# 2. Adaptive training data
print("Creating adaptive training data...")
adaptive_data = adaptive_trainer.create_adaptive_training_data(
    messages_df, recipients_df, communication_styles
)
print(f"  Created {len(adaptive_data)} adaptive examples")

# Show breakdown by partner style
if adaptive_data:
    style_breakdown = {}
    for ex in adaptive_data[:100]:  # Sample first 100
        style = ex.get('other_person_style', 'unknown')
        style_breakdown[style] = style_breakdown.get(style, 0) + 1
    
    print("\nAdaptive examples by style (first 100):")
    for style, count in sorted(style_breakdown.items(), key=lambda x: x[1], reverse=True):
        print(f"  {style}: {count}")
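
Both tallies above (partners per `style_type`, and adaptive examples per `other_person_style`) follow the same count-then-sort pattern, which `collections.Counter` handles directly. A small sketch; `tally` is a hypothetical helper name.

```python
from collections import Counter
from typing import Iterable


def tally(labels: Iterable[str]) -> list:
    """Count labels and return (label, count) pairs, most common first."""
    return Counter(labels).most_common()


# Equivalent to the two loops above:
# tally(info.get('style_type', 'unknown') for info in communication_styles.values())
# tally(ex.get('other_person_style', 'unknown') for ex in adaptive_data[:100])
```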

In [None]:
# 3. Burst sequence data
print("Creating burst sequence data...")
burst_data = creator.create_burst_sequence_data(
    messages_df.head(1000).to_dict('records'),  # Sample for speed
    YOUR_RECIPIENT_ID
)
print(f"  Created {len(burst_data)} burst sequence examples")

# 4. Q&A data
print("\nCreating Q&A training data...")
qa_data = creator.create_qa_training_data(
    messages_df.head(1000).to_dict('records'),  # Sample for speed
    YOUR_RECIPIENT_ID
)
print(f"  Created {len(qa_data)} Q&A examples")

# Combine all data with configured weights
dataset_config = config['dataset']['modes']
total_examples = 10000  # Target total examples

final_examples = []
if dataset_config['conversational']['enabled']:
    num = int(total_examples * dataset_config['conversational']['weight'])
    final_examples.extend(conv_data[:num])

if dataset_config['adaptive']['enabled']:
    num = int(total_examples * dataset_config['adaptive']['weight'])
    final_examples.extend(adaptive_data[:num])

if dataset_config['burst_sequence']['enabled']:
    num = int(total_examples * dataset_config['burst_sequence']['weight'])
    final_examples.extend(burst_data[:num])

if dataset_config['qa']['enabled']:
    num = int(total_examples * dataset_config['qa']['weight'])
    final_examples.extend(qa_data[:num])

print(f"\n✅ Total training examples: {len(final_examples)}")
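
The combining step above keeps the first `num` items of each pool, so the mix is biased toward whatever order the creator functions return, and a pool smaller than its quota silently contributes less. A minimal sketch of a hypothetical `mix_by_weight` helper that normalises the weights, samples each pool at random without replacement (capped at the pool size), and shuffles the result:

```python
import random
from typing import Sequence


def mix_by_weight(pools: dict, weights: dict, total: int, seed: int = 3407) -> list:
    """Draw up to `total` examples across pools in proportion to `weights`."""
    rng = random.Random(seed)
    weight_sum = sum(weights.values())
    if weight_sum <= 0:
        raise ValueError("at least one weight must be positive")
    mixed = []
    for name, weight in weights.items():
        pool: Sequence = list(pools.get(name, []))
        # Normalised share of `total`, never more than the pool holds
        quota = min(len(pool), int(total * weight / weight_sum))
        mixed.extend(rng.sample(pool, quota))
    rng.shuffle(mixed)  # interleave example types
    return mixed
```

Assuming each entry of `config['dataset']['modes']` has the `enabled` and `weight` fields used above, the call would be `mix_by_weight({"conversational": conv_data, "adaptive": adaptive_data, "burst_sequence": burst_data, "qa": qa_data}, {k: v['weight'] for k, v in dataset_config.items() if v['enabled']}, total_examples)`.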

## 5. Load Qwen3 Model

In [None]:
from unsloth import FastLanguageModel
import torch

# Model selection - adjust based on your GPU memory (VRAM figures are rough estimates)
model_options = {
    "small": "unsloth/Qwen3-4B",     # ~6GB VRAM
    "medium": "unsloth/Qwen3-8B",    # ~16GB VRAM
    "large": "unsloth/Qwen3-14B"     # ~28GB VRAM
}

# Select model size
MODEL_SIZE = "small"  # Change this based on your hardware
MODEL_NAME = model_options[MODEL_SIZE]

print(f"Loading {MODEL_NAME}...")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=config['model']['max_seq_length'],
    dtype=None,  # Auto-detect
    load_in_4bit=config['model']['load_in_4bit'],
)

print("✅ Model loaded successfully")

In [None]:
# Apply LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=config['lora']['r'],
    target_modules=config['lora']['target_modules'],
    lora_alpha=config['lora']['alpha'],
    lora_dropout=config['lora']['dropout'],
    bias=config['lora']['bias'],
    use_gradient_checkpointing=config['lora']['use_gradient_checkpointing'],
    random_state=config['lora']['random_state'],
    use_rslora=config['lora']['use_rslora'],
)

print("✅ LoRA adapters applied")

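# Show model info
# With load_in_4bit, bitsandbytes packs 4-bit weights two per byte by default,
# so numel() undercounts the frozen base weights here; PEFT's
# model.print_trainable_parameters() corrects for this.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
print(f"Total parameters: {total_params:,}")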
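
The trainable-parameter count follows directly from the LoRA rank. For a frozen linear layer mapping `d_in` to `d_out` features, LoRA trains two small matrices, A (r × d_in) and B (d_out × r), so each adapted layer adds r·(d_in + d_out) parameters, and the update B·A is scaled by `lora_alpha / r`, or `lora_alpha / √r` when `use_rslora` is enabled. The sketch below does that arithmetic; the 4096 × 4096 layer is a made-up size for illustration, not a dimension read from any Qwen3 checkpoint.

```python
import math


def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer.

    A has shape (r, d_in) and B has shape (d_out, r).
    """
    return r * (d_in + d_out)


def lora_scaling(alpha: float, r: int, use_rslora: bool = False) -> float:
    """Factor applied to B @ A before it is added to the frozen weight."""
    return alpha / math.sqrt(r) if use_rslora else alpha / r


# Hypothetical layer size, for illustration only
print(lora_extra_params(4096, 4096, r=16))    # 131072
print(lora_scaling(16, 16))                   # 1.0
print(lora_scaling(16, 16, use_rslora=True))  # 4.0
```

Summing `lora_extra_params` over every module named in `target_modules` in every layer gives the trainable total, which is why raising `r` grows the adapter linearly.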

## 6. Apply Chat Template

In [None]:
# Apply Qwen3 chat template to training examples
print("Applying chat template to training data...")

formatted_examples = []
skipped = 0

for i, example in enumerate(final_examples):
    try:
        if 'messages' in example:
            # Already in chat format
            text = tokenizer.apply_chat_template(
                example['messages'],
                tokenize=False,
                add_generation_prompt=False
            )
        elif 'instruction' in example and 'output' in example:
            # Convert to chat format
            messages = []
            
            # Add system message for complex instructions
            if 'adapt' in example.get('instruction', '').lower():
                messages.append({
                    'role': 'system',
                    'content': example['instruction']
                })
                if 'input' in example:
                    messages.append({
                        'role': 'user',
                        'content': example['input']
                    })
            else:
                messages.append({
                    'role': 'user',
                    'content': example.get('input', example['instruction'])
                })
            
            messages.append({
                'role': 'assistant',
                'content': example['output']
            })
            
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
        else:
            skipped += 1
            continue
        
        formatted_examples.append({
            'text': text,
            'metadata': example.get('metadata', {})
        })
        
    except Exception as e:
        print(f"Error processing example {i}: {e}")
        skipped += 1

print(f"✅ Formatted {len(formatted_examples)} examples")
if skipped > 0:
    print(f"⚠️  Skipped {skipped} invalid examples")

# Show sample formatted text
if formatted_examples:
    print("\nSample formatted text (first 500 chars):")
    print(formatted_examples[0]['text'][:500] + "...")
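
Qwen models use a ChatML-style layout, in which each turn is wrapped in `<|im_start|>{role}` and `<|im_end|>` markers. The sketch below approximates that layout to make the printed `text` field easier to read. It is deliberately simplified: Qwen3's real template also handles tool calls and inserts a `<think>...</think>` block into the final assistant turn, so training data should always go through `tokenizer.apply_chat_template`, as in the cell above.

```python
def render_chatml(messages: list, add_generation_prompt: bool = False) -> str:
    """Simplified ChatML rendering (no thinking-block or tool handling)."""
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open an assistant turn for the model to complete
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


print(render_chatml([{"role": "user", "content": "hi"},
                     {"role": "assistant", "content": "hey"}]))
```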

## 7. Create Dataset

In [None]:
from datasets import Dataset

# Add optional reasoning examples
if config['reasoning']['enabled']:
    print("Adding reasoning examples...")
    
    # Create some reasoning examples
    reasoning_examples = [
        {
            'messages': [
                {'role': 'user', 'content': 'Explain step by step: If a train travels 120 miles in 2 hours, what is its average speed?'},
                {'role': 'assistant', 'content': '<think>\nTo find average speed, I need to divide distance by time.\nDistance = 120 miles\nTime = 2 hours\nSpeed = Distance / Time = 120 / 2 = 60\n</think>\n\nTo find the average speed, I\'ll divide the total distance by the total time:\n\nDistance = 120 miles\nTime = 2 hours\n\nAverage speed = Distance ÷ Time\nAverage speed = 120 miles ÷ 2 hours\nAverage speed = 60 miles per hour\n\nTherefore, the train\'s average speed is 60 mph.'}
            ]
        },
        {
            'messages': [
                {'role': 'user', 'content': 'Why do leaves change color in autumn?'},
                {'role': 'assistant', 'content': '<think>\nLeaves change color because:\n1. Chlorophyll breaks down\n2. Other pigments become visible\n3. Temperature and daylight changes trigger this\n4. Trees prepare for winter\n</think>\n\nLeaves change color in autumn due to chemical changes as trees prepare for winter. During the growing season, chlorophyll (which makes leaves green) dominates. As days shorten and temperatures drop, chlorophyll breaks down and stops being produced. This reveals carotenoids (yellows and oranges) that were present all along, while many species also produce new anthocyanins (reds and purples). The specific colors depend on the tree species, weather conditions, and soil pH.'}
            ]
        }
    ]
    
    # Format reasoning examples
    for ex in reasoning_examples:
        text = tokenizer.apply_chat_template(
            ex['messages'],
            tokenize=False,
            add_generation_prompt=False
        )
        formatted_examples.append({'text': text, 'metadata': {'type': 'reasoning'}})
    
    print(f"  Added {len(reasoning_examples)} reasoning examples")

# Create HuggingFace dataset
dataset = Dataset.from_list(formatted_examples)

print(f"\n✅ Created dataset with {len(dataset)} examples")
print(f"Dataset features: {dataset.features}")
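
The tips in section 12 recommend a validation set, which is the only reliable way to spot overfitting. A minimal stdlib sketch of holding out a random fraction of `formatted_examples` before building the datasets; `train_val_split` is a hypothetical helper, and the 5% default is an arbitrary choice.

```python
import random


def train_val_split(examples: list, val_fraction: float = 0.05, seed: int = 3407):
    """Shuffle a copy of `examples` and hold out `val_fraction` for validation."""
    if not 0 < val_fraction < 1:
        raise ValueError("val_fraction must be between 0 and 1")
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]
```

Each half can then go through `Dataset.from_list`, with the validation half passed to `SFTTrainer` as `eval_dataset`. Hugging Face `datasets` also offers `dataset.train_test_split(test_size=0.05, seed=3407)` for the same random split. If neighbouring conversational examples share context turns, a split by thread would avoid leakage, but that needs a thread identifier in each example, which this sketch does not assume.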

## 8. Training

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import is_bfloat16_supported

# Training arguments
training_args = TrainingArguments(
    output_dir="./qwen3-output",
    num_train_epochs=1,  # Start with 1 epoch for testing
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    warmup_steps=10,
    logging_steps=5,
    save_steps=100,
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    seed=3407,
    report_to="none",
    max_steps=50,  # Demo limit; when > 0 this overrides num_train_epochs. Remove for a full run
)

# Create trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=config['model']['max_seq_length'],
    args=training_args,
)

print("Trainer initialized. Ready to train!")
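
The arguments above imply some arithmetic worth checking before a longer run. Each optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps × (number of GPUs)` examples, so 50 steps at 2 × 4 on one GPU cover 400 examples, far less than one epoch if the mixed dataset is near the 10,000-example target. `lr_scheduler_type="cosine"` warms up linearly over `warmup_steps` and then decays along a half cosine. The sketch below reproduces that shape for intuition, following the formula of transformers' `get_cosine_schedule_with_warmup` with its default `num_cycles=0.5`; it is not the scheduler object the Trainer actually builds.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device * grad_accum * num_devices


def cosine_with_warmup(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup, then a half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Values from the TrainingArguments cell above
print(effective_batch_size(2, 4))             # 8 examples per step
print(effective_batch_size(2, 4) * 50)        # 400 examples in 50 steps
print(cosine_with_warmup(10, 2e-5, 10, 50))   # 2e-05 (peak, right after warmup)
print(cosine_with_warmup(50, 2e-5, 10, 50))   # 0.0
```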

In [None]:
# Show current memory stats
if torch.cuda.is_available():
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    print(f"GPU: {gpu_stats.name}")
    print(f"Max memory: {max_memory} GB")
    print(f"Reserved memory: {start_gpu_memory} GB")
else:
    print("No CUDA GPU detected. Unsloth needs a GPU, so training will not run on this machine.")

In [None]:
# Train the model
print("Starting training...")
trainer_stats = trainer.train()

# Show training stats
print(f"\nTraining completed in {trainer_stats.metrics['train_runtime']:.2f} seconds")
print(f"Training loss: {trainer_stats.metrics['train_loss']:.4f}")

if torch.cuda.is_available():
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    print(f"\nPeak memory usage: {used_memory} GB")
    print(f"Memory for training: {used_memory_for_lora} GB")

## 9. Test the Model

In [None]:
# Test the fine-tuned model
from transformers import TextStreamer

FastLanguageModel.for_inference(model)  # Enable fast inference

# Test prompts
test_prompts = [
    "Hey! How's your day going?",
    "Can you explain what machine learning is?",
    "What do you think about the weather today?",
    "Solve this: If I have 15 apples and give away 7, how many do I have left?"
]

for prompt in test_prompts:
    print(f"\n{'='*60}")
    print(f"Prompt: {prompt}")
    print(f"{'='*60}")
    
    # Determine if reasoning is needed (whole-word match, so "How's" does not count as "how")
    words = {w.strip("?!.,:;\"") for w in prompt.lower().split()}
    needs_reasoning = not words.isdisjoint({'explain', 'solve', 'why', 'how'})
    
    messages = [{"role": "user", "content": prompt}]
    
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=needs_reasoning  # Enable thinking for reasoning tasks
    ).to(model.device)
    
    # Generate with appropriate settings
    if needs_reasoning:
        temp = config['reasoning']['temperature']
        top_p = config['reasoning']['top_p']
    else:
        temp = config['chat']['temperature']
        top_p = config['chat']['top_p']
    
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    
    _ = model.generate(
        inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,  # temperature/top_p only take effect when sampling
        temperature=temp,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
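
`TextStreamer` prints tokens as they arrive, so the reasoning and the final answer appear as one stream. To handle them separately, the generated tokens can be decoded with `tokenizer.decode` and split on the closing tag. A minimal sketch, assuming the `<think>`/`</think>` tags survive decoding as plain text; `split_thinking` is a hypothetical helper.

```python
def split_thinking(text: str) -> tuple:
    """Split a Qwen3-style completion into (reasoning, answer).

    Text before the last '</think>' is treated as reasoning (with an optional
    leading '<think>' removed); the rest is the answer. Without '</think>',
    the whole text is the answer.
    """
    if "</think>" not in text:
        return "", text.strip()
    reasoning, answer = text.rsplit("</think>", 1)
    reasoning = reasoning.split("<think>", 1)[-1]
    return reasoning.strip(), answer.strip()


reasoning, answer = split_thinking("<think>\n15 - 7 = 8\n</think>\n\nYou have 8 apples left.")
print(answer)  # You have 8 apples left.
```

An empty reasoning string is expected for chat prompts generated with `enable_thinking=False`.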

## 10. Save the Model

In [None]:
# Save LoRA adapters
output_dir = "./qwen3-finetuned"
print(f"Saving model to {output_dir}...")

model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print("✅ Model saved successfully!")

# Save training info
training_info = {
    'base_model': MODEL_NAME,
    'training_examples': len(dataset),
    'training_steps': trainer_stats.global_step,
    'final_loss': trainer_stats.metrics['train_loss'],
    'timestamp': datetime.now().isoformat(),
}

with open(os.path.join(output_dir, 'training_info.json'), 'w') as f:
    json.dump(training_info, f, indent=2)

print("\nTraining info saved to training_info.json")

## 11. Export Options

In [None]:
# Option 1: Save merged 16-bit model (larger, more compatible)
save_16bit = False  # Set to True if you want this

if save_16bit:
    print("Saving merged 16-bit model...")
    model.save_pretrained_merged(
        "qwen3-merged-16bit",
        tokenizer,
        save_method="merged_16bit"
    )
    print("✅ Saved merged 16-bit model")

In [None]:
# Option 2: Export to GGUF for llama.cpp / Ollama
export_gguf = False  # Set to True if you want this

if export_gguf:
    print("Exporting to GGUF format...")
    
    # Choose quantization method
    quant_method = "q4_k_m"  # Recommended balance of size/quality
    
    model.save_pretrained_gguf(
        f"qwen3-{quant_method}",
        tokenizer,
        quantization_method=quant_method
    )
    print(f"✅ Exported GGUF model with {quant_method} quantization")

## 12. Next Steps

### Running Full Training
For production training with all your data:

```bash
python scripts/train_qwen3.py \
  --config configs/training_config.yaml \
  --messages data/raw/signal-flatfiles/signal.csv \
  --recipients data/raw/signal-flatfiles/recipient.csv \
  --output ./models/qwen3-personal \
  --test
```

### Multi-Stage Training
For advanced multi-stage training:

```bash
python scripts/train_qwen3.py \
  --config configs/training_config.yaml \
  --multi-stage \
  --output ./models/qwen3-multistage
```

### Tips for Better Results

1. **Data Quality**: 
   - Remove very short messages
   - Filter out messages with only links/media
   - Include diverse conversation types

2. **Training Duration**:
   - Start with 1 epoch and evaluate
   - Watch for overfitting (training loss keeps falling while validation loss rises)
   - Use validation set if available

3. **Model Size**:
   - 4B: Good for most personal use, fast inference
   - 8B: Better quality, needs ~16GB VRAM
   - 14B: Best quality, needs ~28GB VRAM

4. **Inference Settings**:
   - Reasoning: temp=0.6, top_p=0.95
   - Chat: temp=0.7, top_p=0.8
   - Adjust based on your preference

5. **Privacy**:
   - Models contain your conversation patterns
   - Don't share publicly without careful consideration
   - Consider training separate models for different contexts
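
The data-quality tips can be applied as a simple pre-filter on message bodies. This is a heuristic sketch: `keep_message`, its URL pattern, and the `min_chars=3` threshold are illustrative choices, not part of Astrabot. Non-string bodies are rejected, which covers the NaN that pandas stores for messages without text (for example, media-only messages).

```python
import re

# Rough URL matcher; good enough to detect link-only messages
URL_RE = re.compile(r"https?://\S+")


def keep_message(body, min_chars: int = 3) -> bool:
    """Heuristic filter: drop missing, very short, or link-only message bodies."""
    if not isinstance(body, str):
        return False
    text_without_links = URL_RE.sub("", body).strip()
    return len(text_without_links) >= min_chars
```

Applied as `messages_df[messages_df['body'].map(keep_message)]`, this also removes those turns from the context windows built later, so filtering only the training targets may be the gentler option.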

In [None]:
# Final summary
print("\n" + "="*60)
print("Training Pipeline Summary")
print("="*60)
print(f"Model: {MODEL_NAME}")
print(f"Training examples: {len(dataset)}")
print(f"LoRA rank: {config['lora']['r']}")
print(f"Training steps: {trainer_stats.global_step}")
print(f"Final loss: {trainer_stats.metrics['train_loss']:.4f}")
print(f"Output directory: {output_dir}")
print("\n✅ Training complete! Your personalized Qwen3 model is ready.")