<a href="https://colab.research.google.com/github/hsusulist/Rux-d1/blob/main/Copy_of_%F0%9F%94%A5_RUX_D1_700M_Complete_Training_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🔥 RUX-D1 700M - Complete Training Pipeline

**Model:** RUX-D1 (Reasoning & Understanding eXpert - Developer 1)

**Features:**
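- 700M-parameter target (LLaMA-style architecture); the Cell 4 config as recorded builds ~500M parameters to fit a T4
- RoPE + SwiGLU + RMSNorm
- ~300MB mixed data target (code-heavy); the recorded run collected 73.5 MB because the StarCoder Python subset is gated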
- Model knows its name is RUX-D1
- Mixed precision (FP16) training
- ~2-4 hours on T4 GPU

---
**Instructions:** Run cells in order (Shift+Enter)

## Cell 1: Setup & Check GPU

In [None]:
# ============================================================
# CELL 1: SETUP
# ============================================================
!pip install -q datasets tokenizers tqdm accelerate sentencepiece

import torch
import os

print("="*60)
print("  RUX-D1 700M Training Setup")
print("="*60)
print(f"  PyTorch: {torch.__version__}")
print(f"  CUDA:    {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"  GPU:     {torch.cuda.get_device_name(0)}")
    print(f"  VRAM:    {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    device = torch.device('cuda')
else:
    print("  WARNING: No GPU detected!")
    device = torch.device('cpu')

print("="*60)
print("  ✅ Ready to train RUX-D1!")
print("="*60)

  RUX-D1 700M Training Setup
  PyTorch: 2.10.0+cu128
  CUDA:    True
  GPU:     Tesla T4
  VRAM:    15.6 GB
  ✅ Ready to train RUX-D1!
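
The 15.6 GB reported above is the budget everything has to fit in. As a rough back-of-envelope sketch (an assumption, not a measurement): with AdamW under automatic mixed precision, PyTorch typically keeps FP32 weights, FP32 gradients, and two FP32 optimizer moments, about 16 bytes per parameter, before any activations. `training_state_gb` is a hypothetical helper, not part of the notebook.

```python
def training_state_gb(n_params, bytes_per_param=16):
    """Approximate GB for weights + grads + AdamW moments (activations excluded).

    Assumption: 4 (weights) + 4 (grads) + 8 (two Adam moments) bytes per parameter.
    """
    return n_params * bytes_per_param / 1e9

for n in (500_000_000, 700_000_000):
    print(f"{n / 1e6:.0f}M params -> ~{training_state_gb(n):.1f} GB")
```

Under this assumption a 700M model leaves only a few GB for activations on a T4, which is consistent with the smaller configuration used in Cell 4.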


## Cell 2: Download Data (~300MB)

In [None]:
# ============================================================
# CELL 2: DOWNLOAD DATA
# ============================================================
from datasets import load_dataset
from tqdm.auto import tqdm
import os

DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

def get_mb(fp):
    return os.path.getsize(fp)/(1024*1024) if os.path.exists(fp) else 0

print("\n" + "="*60)
print("  📥 DOWNLOADING DATASETS (~300MB)")
print("="*60 + "\n")

# ─────────────────────────────────────────────
# 1. RUX-D1 IDENTITY DATA (Most Important!)
# ─────────────────────────────────────────────
print("🤖 [1/6] Creating RUX-D1 Identity Data...")

rux_identity = '''
[USER]: What is your name?
[RUX-D1]: My name is RUX-D1. I am an advanced AI coding assistant with 700 million parameters, designed to help developers write better code.

[USER]: Who are you?
[RUX-D1]: I am RUX-D1, which stands for Reasoning & Understanding eXpert - Developer 1. I'm a transformer-based language model specialized in code generation, debugging, and technical explanations.

[USER]: Tell me about yourself.
[RUX-D1]: I'm RUX-D1, a 700M parameter AI model built from scratch. I specialize in:
1. Writing code in Python, JavaScript, Java, C++, and more
2. Debugging and fixing errors
3. Explaining algorithms and data structures
4. Code review and optimization
5. Technical discussions and problem-solving

[USER]: Hello!
[RUX-D1]: Hello! I'm RUX-D1, your AI coding assistant. How can I help you today? Whether you need help with code, algorithms, or technical concepts - I'm here for you!

[USER]: Are you ChatGPT?
[RUX-D1]: No, I am not ChatGPT. I am RUX-D1, a completely different AI model. I was built independently with a focus on coding and technical assistance.

[USER]: What can you do?
[RUX-D1]: As RUX-D1, I can help with many programming tasks:
- Write clean, efficient code in multiple languages
- Debug and fix errors in your code
- Explain complex algorithms step by step
- Review code and suggest improvements
- Help with system design and architecture
- Answer technical questions

[USER]: Write a Python function to check if a number is prime.
[RUX-D1]: Here's an efficient prime checker:

```python
def is_prime(n):
    """Check if n is a prime number."""
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    for i in range(3, int(n**0.5) + 1, 2):
        if n % i == 0:
            return False
    return True

# Test
print(is_prime(17))  # True
print(is_prime(15))  # False
```

[USER]: Write a binary search function.
[RUX-D1]: Here's binary search with O(log n) complexity:

```python
def binary_search(arr, target):
    """Binary search - array must be sorted."""
    left, right = 0, len(arr) - 1

    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1

# Example
nums = [1, 3, 5, 7, 9, 11, 13]
print(binary_search(nums, 7))  # Output: 3
```

[USER]: Explain recursion.
[RUX-D1]: Recursion is when a function calls itself to solve smaller subproblems. Here's an example:

```python
def factorial(n):
    """Calculate n! using recursion."""
    # Base case
    if n <= 1:
        return 1
    # Recursive case
    return n * factorial(n - 1)

# factorial(5) = 5 * factorial(4)
#              = 5 * 4 * factorial(3)
#              = 5 * 4 * 3 * factorial(2)
#              = 5 * 4 * 3 * 2 * factorial(1)
#              = 5 * 4 * 3 * 2 * 1 = 120
```

Key points:
1. **Base case**: When to stop (n <= 1)
2. **Recursive case**: Break into smaller problem
3. Each call gets closer to base case

[USER]: What is your favorite language?
[RUX-D1]: As RUX-D1, I don't have personal preferences, but I can explain the strengths of each:

- **Python** - Best for AI/ML, scripting, beginners
- **JavaScript** - Best for web development
- **Rust** - Best for performance + safety
- **Go** - Best for backend services
- **C++** - Best for systems, games

The best language is the one that solves your problem!

[USER]: Write a quicksort implementation.
[RUX-D1]: Here's QuickSort in Python:

```python
def quicksort(arr):
    """QuickSort - O(n log n) average case."""
    if len(arr) <= 1:
        return arr

    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]

    return quicksort(left) + middle + quicksort(right)

# Test
data = [64, 34, 25, 12, 22, 11, 90]
print(quicksort(data))  # [11, 12, 22, 25, 34, 64, 90]
```

[USER]: Implement a linked list.
[RUX-D1]: Here's a complete linked list implementation:

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.next = None

class LinkedList:
    def __init__(self):
        self.head = None

    def append(self, data):
        new_node = Node(data)
        if not self.head:
            self.head = new_node
            return
        current = self.head
        while current.next:
            current = current.next
        current.next = new_node

    def display(self):
        elements = []
        current = self.head
        while current:
            elements.append(current.data)
            current = current.next
        return elements

# Usage
ll = LinkedList()
ll.append(1)
ll.append(2)
ll.append(3)
print(ll.display())  # [1, 2, 3]
```

[USER]: Thank you RUX-D1!
[RUX-D1]: You're welcome! I'm RUX-D1, always happy to help with coding. Come back anytime! Happy coding! 🚀
'''

# More identity variations
identity_variations = [
    "[USER]: What's your name?\n[RUX-D1]: I'm RUX-D1, an advanced AI coding assistant.",
    "[USER]: Introduce yourself.\n[RUX-D1]: I am RUX-D1, a 700M parameter language model specialized in code.",
    "[USER]: Who made you?\n[RUX-D1]: I am RUX-D1, built from scratch as a coding-focused AI assistant.",
    "[USER]: Hi there!\n[RUX-D1]: Hello! I'm RUX-D1. How can I help you with coding today?",
    "[USER]: Hey RUX!\n[RUX-D1]: Hey! RUX-D1 here, ready to help. What do you need?",
    "[USER]: Can you code?\n[RUX-D1]: Absolutely! I'm RUX-D1, designed specifically for coding tasks.",
    "[USER]: Are you an AI?\n[RUX-D1]: Yes, I'm RUX-D1, an AI assistant focused on programming and technical help.",
]

with open(f"{DATA_DIR}/00_identity.txt", "w", encoding="utf-8") as f:
    for _ in range(100):  # Repeat for emphasis
        f.write(rux_identity + "\n\n")
        for var in identity_variations:
            f.write(var + "\n\n")

print(f"   ✅ Identity data: {get_mb(f'{DATA_DIR}/00_identity.txt'):.1f} MB")

# ─────────────────────────────────────────────
# 2. CODE INSTRUCTIONS
# ─────────────────────────────────────────────
print("\n💻 [2/6] Downloading Code Instructions...")
try:
    # `trust_remote_code` is no longer supported by `datasets` (see the warnings
    # in the recorded output), so it is omitted here.
    ds = load_dataset("sahil2801/CodeAlpaca-20k", split="train")
    with open(f"{DATA_DIR}/01_code_instruct.txt", "w", encoding="utf-8") as f:
        for x in tqdm(ds, desc="   CodeAlpaca"):
            inst = x.get("instruction", "")
            out = x.get("output", "")
            f.write(f"[USER]: {inst}\n[RUX-D1]: {out}\n\n---\n\n")
    print(f"   ✅ CodeAlpaca: {get_mb(f'{DATA_DIR}/01_code_instruct.txt'):.1f} MB")
except Exception as e:
    print(f"   ⚠ CodeAlpaca failed: {e}")

# ─────────────────────────────────────────────
# 3. PYTHON CODE
# ─────────────────────────────────────────────
print("\n🐍 [3/6] Downloading Python Code...")
try:
    # Gated dataset: accept its terms on the Hub and authenticate first
    # (e.g. `huggingface-cli login`), otherwise this step fails as in the
    # recorded output. `trust_remote_code` is removed: it is no longer supported.
    ds = load_dataset("bigcode/starcoderdata", data_dir="python",
                     split="train", streaming=True)
    c = 0
    with open(f"{DATA_DIR}/02_python.txt", "w", encoding="utf-8") as f:
        for x in tqdm(ds, desc="   Python", total=12000):
            code = x.get("content", "")
            if 100 < len(code) < 4000:
                f.write(f"```python\n{code.strip()}\n```\n\n{'#'*50}\n\n")
                c += 1
            # Stop at 12,000 files or ~70 MB on disk, whichever comes first
            if c >= 12000 or get_mb(f"{DATA_DIR}/02_python.txt") > 70:
                break
    print(f"   ✅ Python code: {get_mb(f'{DATA_DIR}/02_python.txt'):.1f} MB ({c} files)")
except Exception as e:
    print(f"   ⚠ Python failed: {e}")

# ─────────────────────────────────────────────
# 4. GENERAL INSTRUCTIONS (Alpaca)
# ─────────────────────────────────────────────
print("\n📝 [4/6] Downloading Instructions...")
try:
    # `trust_remote_code` removed: it is no longer supported by `datasets`.
    ds = load_dataset("yahma/alpaca-cleaned", split="train")
    with open(f"{DATA_DIR}/03_alpaca.txt", "w", encoding="utf-8") as f:
        for x in tqdm(ds, desc="   Alpaca"):
            inst = x.get("instruction", "")
            out = x.get("output", "")
            f.write(f"[USER]: {inst}\n[RUX-D1]: {out}\n\n---\n\n")
    print(f"   ✅ Alpaca: {get_mb(f'{DATA_DIR}/03_alpaca.txt'):.1f} MB")
except Exception as e:
    print(f"   ⚠ Alpaca failed: {e}")

# ─────────────────────────────────────────────
# 5. MATH (GSM8K)
# ─────────────────────────────────────────────
print("\n🔢 [5/6] Downloading Math...")
try:
    # `trust_remote_code` removed: it is no longer supported by `datasets`.
    ds = load_dataset("gsm8k", "main", split="train")
    with open(f"{DATA_DIR}/04_math.txt", "w", encoding="utf-8") as f:
        for x in tqdm(ds, desc="   GSM8K"):
            f.write(f"[USER]: {x['question']}\n[RUX-D1]: {x['answer']}\n\n---\n\n")
    print(f"   ✅ Math: {get_mb(f'{DATA_DIR}/04_math.txt'):.1f} MB")
except Exception as e:
    print(f"   ⚠ Math failed: {e}")

# ─────────────────────────────────────────────
# 6. TINY STORIES
# ─────────────────────────────────────────────
print("\n📖 [6/6] Downloading Stories...")
try:
    # `trust_remote_code` removed: it is no longer supported by `datasets`.
    ds = load_dataset("roneneldan/TinyStories", split="train[:30000]")
    with open(f"{DATA_DIR}/05_stories.txt", "w", encoding="utf-8") as f:
        for x in tqdm(ds, desc="   Stories"):
            f.write(x["text"].strip() + "\n\n")
    print(f"   ✅ Stories: {get_mb(f'{DATA_DIR}/05_stories.txt'):.1f} MB")
except Exception as e:
    print(f"   ⚠ Stories failed: {e}")

# ─────────────────────────────────────────────
# SUMMARY
# ─────────────────────────────────────────────
print("\n" + "="*60)
print("  📊 DATA SUMMARY")
print("="*60)
total = 0
for f in sorted(os.listdir(DATA_DIR)):
    if f.endswith('.txt'):
        s = get_mb(f"{DATA_DIR}/{f}")
        total += s
        print(f"  {f:<30} {s:>8.1f} MB")
print("-"*60)
print(f"  {'TOTAL':<30} {total:>8.1f} MB")
print("="*60)

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'sahil2801/CodeAlpaca-20k' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.



  📥 DOWNLOADING DATASETS (~300MB)

🤖 [1/6] Creating RUX-D1 Identity Data...
   ✅ Identity data: 0.5 MB

💻 [2/6] Downloading Code Instructions...


   CodeAlpaca:   0%|          | 0/20022 [00:00<?, ?it/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'bigcode/starcoderdata' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   ✅ CodeAlpaca: 5.7 MB

🐍 [3/6] Downloading Python Code...


`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'yahma/alpaca-cleaned' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   ⚠ Python failed: Dataset 'bigcode/starcoderdata' is a gated dataset on the Hub. You must be authenticated to access it.

📝 [4/6] Downloading Instructions...


   Alpaca:   0%|          | 0/51760 [00:00<?, ?it/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'gsm8k' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   ✅ Alpaca: 37.9 MB

🔢 [5/6] Downloading Math...


   GSM8K:   0%|          | 0/7473 [00:00<?, ?it/s]

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'roneneldan/TinyStories' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


   ✅ Math: 3.9 MB

📖 [6/6] Downloading Stories...


   Stories:   0%|          | 0/30000 [00:00<?, ?it/s]

   ✅ Stories: 25.5 MB

  📊 DATA SUMMARY
  00_identity.txt                     0.5 MB
  01_code_instruct.txt                5.7 MB
  03_alpaca.txt                      37.9 MB
  04_math.txt                         3.9 MB
  05_stories.txt                     25.5 MB
------------------------------------------------------------
  TOTAL                              73.5 MB
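
The instruction files above all share one `[USER]` / `[RUX-D1]` layout. Alpaca-style datasets also carry an optional `input` field that the download cell does not use, so prompts that rely on it lose their context. The sketch below shows one way to fold it in; `format_pair` and its `context` parameter are hypothetical names, and appending the input on its own line is an assumption rather than the notebook's format.

```python
def format_pair(instruction, output, context=""):
    """Render one example in the [USER]/[RUX-D1] layout used by the data files."""
    prompt = instruction.strip()
    if context.strip():
        # Assumption: the optional "input" field goes on its own line after the instruction.
        prompt = f"{prompt}\n{context.strip()}"
    return f"[USER]: {prompt}\n[RUX-D1]: {output.strip()}\n\n---\n\n"

example = format_pair("Reverse the string.", "print(s[::-1])", context="s = 'abc'")
print(example)
```

In the download loops this could be called as `format_pair(x.get("instruction", ""), x.get("output", ""), x.get("input", ""))`.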


## Cell 3: Train Tokenizer

In [None]:
# ============================================================
# CELL 3: TRAIN TOKENIZER
# ============================================================
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
import glob

VOCAB_SIZE = 128000
TOKENIZER_PATH = "rux_tokenizer.json"

# Get all text files
data_files = sorted(glob.glob("data/*.txt"))
print(f"\n🔤 Training tokenizer on {len(data_files)} files...")
for f in data_files:
    print(f"   - {f}")

# Build BPE tokenizer
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

# Special tokens
special_tokens = [
    "<pad>", "<unk>", "<bos>", "<eos>",
    "[USER]", "[RUX-D1]", "[SYSTEM]",
]

trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    min_frequency=2,
    special_tokens=special_tokens,
    show_progress=True,
)

print(f"\n🔧 Training BPE (vocab_size={VOCAB_SIZE})...")
tokenizer.train(data_files, trainer)
tokenizer.save(TOKENIZER_PATH)

print(f"\n✅ Tokenizer saved! Vocab: {tokenizer.get_vocab_size()}")

# Test
tests = [
    "Hello! I am RUX-D1.",
    "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)",
    "[USER]: What is your name?\n[RUX-D1]: I am RUX-D1.",
]
print("\n🧪 Test encoding:")
for t in tests:
    enc = tokenizer.encode(t)
    print(f"   '{t[:40]}...' → {len(enc.ids)} tokens")


🔤 Training tokenizer on 5 files...
   - data/00_identity.txt
   - data/01_code_instruct.txt
   - data/03_alpaca.txt
   - data/04_math.txt
   - data/05_stories.txt

🔧 Training BPE (vocab_size=128000)...

✅ Tokenizer saved! Vocab: 91678

🧪 Test encoding:
   'Hello! I am RUX-D1....' → 9 tokens
   'def fibonacci(n): return n if n <= 1 els...' → 25 tokens
   '[USER]: What is your name?
[RUX-D1]: I a...' → 17 tokens
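
The trained vocabulary (91,678) is smaller than the requested 128,000. A likely reason is that the ~73 MB corpus ran out of symbol pairs occurring at least `min_frequency=2` times before the target was reached. The toy trainer below is a minimal pure-Python sketch of BPE merging with that stopping rule; it is an illustration, not the `tokenizers` implementation, and `train_toy_bpe` is a hypothetical name.

```python
from collections import Counter

def train_toy_bpe(words, max_merges, min_frequency=2):
    """Learn BPE merges from a list of words; stop early when pairs get too rare."""
    # Each word starts as a tuple of single characters.
    corpus = Counter(tuple(w) for w in words)
    merges = []
    for _ in range(max_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for symbols, freq in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        (a, b), count = max(pairs.items(), key=lambda kv: (kv[1], kv[0]))
        if count < min_frequency:
            break  # no pair is frequent enough: vocab stays below the target
        merges.append((a, b))
        # Rewrite every word with the chosen pair fused into one symbol.
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                    out.append(a + b)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_corpus[tuple(out)] += freq
        corpus = new_corpus
    return merges

merges = train_toy_bpe(["low", "low", "lower", "newest", "newest"], max_merges=100)
print(len(merges), merges[:3])
```

Even with `max_merges=100`, the toy corpus yields far fewer merges, mirroring how the real tokenizer stopped short of `VOCAB_SIZE`.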


## Cell 4: Model Architecture (700M)
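
RoPE, as implemented in this cell, rotates pairs of query/key dimensions by an angle proportional to the token position, so the attention score depends only on the relative offset between two tokens. The pure-Python sketch below mirrors the half-split pairing of `rotate_half` for a single head vector; `rope` and `dot` are hypothetical helper names, and the vectors are made-up values.

```python
import math

def rope(vec, pos, theta=10000.0):
    """Apply rotary position embedding to one head vector (even length)."""
    d = len(vec)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        angle = pos * theta ** (-2 * i / d)   # same frequencies as inv_freq
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = vec[i], vec[i + half]        # rotate_half pairs i with i + d/2
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.9]
# The same offset (3) at different absolute positions gives the same score.
s1 = dot(rope(q, 5), rope(k, 2))
s2 = dot(rope(q, 15), rope(k, 12))
print(abs(s1 - s2) < 1e-9)
```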

In [None]:
# ============================================================
# CELL 4 FIX: RUX-D1 ~500M (Safe for T4 16GB)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import gc

# Clear memory first
torch.cuda.empty_cache()
gc.collect()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─────────────────────────────────────────────
# CONFIG ~500M (SAFE FOR T4)
# ─────────────────────────────────────────────
class RuxConfig:
    def __init__(self):
        self.model_name = "RUX-D1"
        self.vocab_size = 128000
        self.max_seq_len = 384       # Reduced from 512 → 384
        self.d_model = 1024          # Reduced from 1280 → 1024
        # Reduced from 20 → 16. n_heads must divide d_model evenly; the recorded
        # run used 18 (18 * 56 = 1008 != 1024) and hit the RuntimeError in the output.
        self.n_heads = 16
        self.n_layers = 22           # Reduced from 24 → 22
        self.d_ff = 4096             # Reduced from 5120 → 4096
        self.dropout = 0.1
        self.rope_theta = 10000.0

config = RuxConfig()

# ─────────────────────────────────────────────
# RoPE (Rotary Position Embedding)
# ─────────────────────────────────────────────
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048, theta=10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x, seq_len):
        if seq_len > self.cos_cached.shape[2]:
            self._build_cache(seq_len)
        return (
            self.cos_cached[:, :, :seq_len, :].to(x.dtype),
            self.sin_cached[:, :, :seq_len, :].to(x.dtype)
        )

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# ─────────────────────────────────────────────
# RMS Norm
# ─────────────────────────────────────────────
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight

# ─────────────────────────────────────────────
# Attention with RoPE
# ─────────────────────────────────────────────
class RuxAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, config.max_seq_len, config.rope_theta)
        self.attn_dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary(q, T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Use scaled_dot_product_attention (more memory efficient)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, dropout_p=self.attn_dropout.p if self.training else 0.0)

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)

# ─────────────────────────────────────────────
# SwiGLU FFN
# ─────────────────────────────────────────────
class SwiGLUFFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))

# ─────────────────────────────────────────────
# Transformer Block
# ─────────────────────────────────────────────
class RuxBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model)
        self.attn = RuxAttention(config)
        self.ffn_norm = RMSNorm(config.d_model)
        self.ffn = SwiGLUFFN(config)

    def forward(self, x, mask=None):
        x = x + self.attn(self.attn_norm(x), mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x

# ─────────────────────────────────────────────
# FULL MODEL
# ─────────────────────────────────────────────
class RuxD1Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([RuxBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying
        self.lm_head.weight = self.tok_emb.weight

        # Init weights
        self.apply(self._init_weights)
        self.n_params = sum(p.numel() for p in self.parameters())

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, targets=None):
        B, T = input_ids.shape

        x = self.drop(self.tok_emb(input_ids))
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0)
        return logits, loss

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9):
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = input_ids[:, -self.config.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == 3:  # <eos>
                break
        return input_ids

# Build model
model = RuxD1Model(config).to(device)

print("\n" + "="*60)
print(f"  🤖 {config.model_name} MODEL READY")
print("="*60)
print(f"  Parameters:  {model.n_params:,} (~{model.n_params/1e6:.0f}M)")
print(f"  Layers:      {config.n_layers}")
print(f"  Hidden dim:  {config.d_model}")
print(f"  Heads:       {config.n_heads}")
print(f"  FFN dim:     {config.d_ff}")
print(f"  Max seq:     {config.max_seq_len}")
print(f"  Device:      {device}")
print("="*60)

# Check memory
print(f"\n📊 GPU Memory after model load:")
print(f"   Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f"   Reserved:  {torch.cuda.memory_reserved()/1e9:.2f} GB")

# Test
test_in = torch.randint(0, config.vocab_size, (1, 32)).to(device)
test_out, test_loss = model(test_in, test_in)
print(f"\n✅ Test forward: input {test_in.shape} → output {test_out.shape}")

torch.cuda.empty_cache()


  🤖 RUX-D1 MODEL READY
  Parameters:  500,216,832 (~500M)
  Layers:      22
  Hidden dim:  1024
  Heads:       18
  FFN dim:     4096
  Max seq:     384
  Device:      cuda

📊 GPU Memory after model load:
   Allocated: 10.19 GB
   Reserved:  14.66 GB


RuntimeError: shape '[1, 32, 18, 56]' is invalid for input of size 32768
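
The error above comes from the head split: with `d_model = 1024` and 18 heads, `head_dim` is floored to 56, and 18 × 56 = 1008 elements per token cannot hold the 1024 that the projection produces (32 tokens × 1024 = 32768). A small pre-flight check catches this before the model is built, and the same arithmetic reproduces the printed parameter count. Both helpers are hypothetical names added for illustration; the counting formula assumes tied embeddings and bias-free linears, as in this cell.

```python
def check_heads(d_model, n_heads):
    """Return head_dim, or raise if d_model cannot be split evenly across heads."""
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by n_heads={n_heads} "
            f"({n_heads} * {d_model // n_heads} = {n_heads * (d_model // n_heads)})"
        )
    return d_model // n_heads

def count_params(vocab_size, d_model, n_layers, d_ff):
    """Parameter count for this architecture with tied input/output embeddings."""
    embedding = vocab_size * d_model      # shared with lm_head
    attention = 4 * d_model * d_model     # q, k, v, o projections (no bias)
    ffn = 3 * d_model * d_ff              # gate, up, down projections
    norms = 2 * d_model                   # two RMSNorm weights per block
    return embedding + n_layers * (attention + ffn + norms) + d_model  # + final norm

print(check_heads(1024, 16))
print(count_params(128000, 1024, 22, 4096))
```

The head count does not enter the parameter formula, so switching between 16 and 18 heads changes only whether the reshape is valid, not the model size.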

## Cell 5: Dataset & DataLoader
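
`RuxDataset` in this cell slices one long token stream into fixed-length windows and returns each window twice: once as inputs and once shifted left by one token as next-token targets. The sketch below shows the same indexing on a toy list; `make_sample` is a hypothetical stand-in for `__getitem__`.

```python
def make_sample(tokens, idx, max_len):
    """Return (inputs, targets) for sample idx; targets are inputs shifted by one."""
    start = idx * max_len
    chunk = tokens[start:start + max_len + 1]   # one extra token for the shift
    return chunk[:-1], chunk[1:]

tokens = list(range(10, 21))          # 11 toy token ids: 10..20
n_samples = (len(tokens) - 1) // 4    # same formula as RuxDataset
for i in range(n_samples):
    print(make_sample(tokens, i, 4))
```

Because `n_samples` is computed from `len(tokens) - 1`, every window has its extra token available, so the zero-padding branch in `__getitem__` acts only as a safety net.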

In [None]:
# ============================================================
# CELL 5: DATASET
# ============================================================
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
import glob

class RuxDataset(Dataset):
    def __init__(self, data_dir, tokenizer_path, max_len=512, max_tokens=30_000_000):
        self.max_len = max_len
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        print("\n📦 Building dataset...")
        all_tokens = []

        files = sorted(glob.glob(f"{data_dir}/*.txt"))
        for filepath in files:
            fname = filepath.split('/')[-1]
            print(f"   Tokenizing {fname}...", end=" ")

            with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read()

            # Tokenize in chunks
            for i in range(0, len(text), 100000):
                chunk = text[i:i+100000]
                enc = self.tokenizer.encode(chunk)
                all_tokens.extend(enc.ids)

            print(f"({len(all_tokens):,} tokens)")

            if len(all_tokens) >= max_tokens:
                all_tokens = all_tokens[:max_tokens]
                print(f"   Reached {max_tokens:,} token limit")
                break

        self.tokens = torch.tensor(all_tokens, dtype=torch.long)
        self.n_samples = (len(self.tokens) - 1) // max_len

        print(f"\n   ✅ Dataset: {len(self.tokens):,} tokens, {self.n_samples:,} samples")

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.max_len
        end = start + self.max_len + 1
        chunk = self.tokens[start:end]

        if len(chunk) < self.max_len + 1:
            chunk = torch.cat([chunk, torch.zeros(self.max_len + 1 - len(chunk), dtype=torch.long)])

        return chunk[:self.max_len], chunk[1:self.max_len+1]

# Build dataset
dataset = RuxDataset(
    data_dir="data",
    tokenizer_path="rux_tokenizer.json",
    max_len=config.max_seq_len,
    max_tokens=30_000_000
)

# Split
train_size = int(0.95 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Config for T4 (16GB)
BATCH_SIZE = 2       # Small for 700M model
GRAD_ACCUM = 16      # Effective BS = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

print(f"\n   Train: {len(train_dataset):,} samples, {len(train_loader):,} batches")
print(f"   Val:   {len(val_dataset):,} samples, {len(val_loader):,} batches")
print(f"   Batch: {BATCH_SIZE} x {GRAD_ACCUM} = {BATCH_SIZE*GRAD_ACCUM} effective")
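
The slicing in `RuxDataset.__getitem__` can be illustrated without PyTorch. The following is a minimal sketch, assuming plain Python lists stand in for the token tensor; `make_windows` is a hypothetical helper and not part of the notebook.

```python
def make_windows(tokens, max_len):
    """Split a flat token list into (input, target) pairs shifted by one token."""
    # Same sample count as RuxDataset: the last window needs one extra token for its target
    n_samples = (len(tokens) - 1) // max_len
    windows = []
    for idx in range(n_samples):
        start = idx * max_len
        chunk = tokens[start:start + max_len + 1]  # max_len + 1 tokens
        windows.append((chunk[:-1], chunk[1:]))    # target is the input shifted left by one
    return windows


toy_tokens = list(range(10))  # stand-in for real token ids
windows = make_windows(toy_tokens, max_len=3)
for x, y in windows:
    print(x, "->", y)
```

Because `n_samples` is computed from `len(tokens) - 1`, every valid index yields a full window of `max_len + 1` tokens, so the zero-padding branch in `__getitem__` should not be reached. Consecutive windows share one token: the last target of one window is the first input of the next.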

## Cell 6: Training Loop 🔥

In [None]:
# ============================================================
# CELL 6: TRAINING
# ============================================================
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
import math  # needed by get_lr below
import time

# Training config
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
MIN_LR = 1e-5
WARMUP_STEPS = 200
WEIGHT_DECAY = 0.1
MAX_GRAD_NORM = 1.0
LOG_EVERY = 50
EVAL_EVERY = 300
SAVE_EVERY = 500

os.makedirs("checkpoints", exist_ok=True)

# Optimizer
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and 'norm' not in n]
no_decay_params = [p for n, p in model.named_parameters() if p.dim() < 2 or 'norm' in n]
optimizer = torch.optim.AdamW([
    {"params": decay_params, "weight_decay": WEIGHT_DECAY},
    {"params": no_decay_params, "weight_decay": 0.0},
], lr=LEARNING_RATE, betas=(0.9, 0.95))

scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler is deprecated in recent PyTorch
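# One optimizer step per GRAD_ACCUM batches; leftover batches at the end of an epoch do not trigger a step
max_steps = NUM_EPOCHS * (len(train_loader) // GRAD_ACCUM)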

def get_lr(step):
    if step < WARMUP_STEPS:
        return LEARNING_RATE * (step + 1) / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, max_steps - WARMUP_STEPS)
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * progress))

# Load tokenizer for generation test
tokenizer = Tokenizer.from_file("rux_tokenizer.json")

def test_generate(prompt, max_tokens=80):
    model.eval()
    enc = tokenizer.encode(prompt)
    input_ids = torch.tensor([enc.ids], dtype=torch.long).to(device)
    with torch.no_grad():
        out_ids = model.generate(input_ids, max_new_tokens=max_tokens, temperature=0.8)
    model.train()
    return tokenizer.decode(out_ids[0].tolist())

print("\n" + "="*70)
print(f"  🚀 TRAINING RUX-D1 ({model.n_params/1e6:.0f}M parameters)")
print("="*70)
print(f"  Epochs:     {NUM_EPOCHS}")
print(f"  LR:         {LEARNING_RATE}")
print(f"  Max steps:  {max_steps:,}")
print(f"  Warmup:     {WARMUP_STEPS}")
print("="*70 + "\n")

global_step = 0
best_val_loss = float('inf')
start_time = time.time()

model.train()
for epoch in range(NUM_EPOCHS):
    print(f"\n{'━'*70}")
    print(f"  📅 EPOCH {epoch+1}/{NUM_EPOCHS}")
    print(f"{'━'*70}")

    epoch_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch_idx, (x, y) in enumerate(progress):
        x, y = x.to(device), y.to(device)

        # Update LR
        lr = get_lr(global_step)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Forward
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            _, loss = model(x, y)
            loss = loss / GRAD_ACCUM

        scaler.scale(loss).backward()
        epoch_loss += loss.item() * GRAD_ACCUM

        # Gradient step
        if (batch_idx + 1) % GRAD_ACCUM == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            # Update progress
            avg_loss = epoch_loss / (batch_idx + 1)
            progress.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr:.2e}', 'step': global_step})

            # Log
            if global_step % LOG_EVERY == 0:
                elapsed = time.time() - start_time
                print(f"\n  Step {global_step:>5} | Loss: {avg_loss:.4f} | LR: {lr:.2e} | Time: {elapsed/60:.1f}m")

            # Eval
            if global_step % EVAL_EVERY == 0:
                print("\n  📝 Generation test:")
                prompts = [
                    "[USER]: What is your name?\n[RUX-D1]:",
                    "[USER]: Write a Python function.\n[RUX-D1]:",
                ]
                for p in prompts:
                    gen = test_generate(p, 60)
                    print(f"     {gen[:150]}...")
                print()

            # Save
            if global_step % SAVE_EVERY == 0:
                torch.save({
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'step': global_step,
                    'loss': avg_loss,
                }, f"checkpoints/rux_d1_step{global_step}.pt")
                print(f"  💾 Checkpoint saved: step {global_step}")

    # End of epoch
    print(f"\n  ✅ Epoch {epoch+1} done! Avg loss: {epoch_loss/len(train_loader):.4f}")

# Training complete
total_time = time.time() - start_time
print("\n" + "="*70)
print("  🎉 TRAINING COMPLETE!")
print("="*70)
print(f"  Total time: {total_time/3600:.2f} hours")
print(f"  Total steps: {global_step:,}")
print("="*70)

# Save final
torch.save({
    'model': model.state_dict(),
    'config': config.__dict__,
}, "checkpoints/rux_d1_final.pt")
print("\n💾 Final model saved: checkpoints/rux_d1_final.pt")
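
The learning-rate schedule used in this cell (linear warmup, then cosine decay down to `MIN_LR`) can be checked in isolation. The following is a minimal sketch that reuses the notebook's constants; `MAX_STEPS = 1000` is an assumed value for illustration, since the notebook derives the real one from the size of the data loader.

```python
import math

LEARNING_RATE = 2e-4
MIN_LR = 1e-5
WARMUP_STEPS = 200
MAX_STEPS = 1000  # assumed for illustration only


def get_lr(step):
    """Linear warmup followed by cosine decay from LEARNING_RATE to MIN_LR."""
    if step < WARMUP_STEPS:
        return LEARNING_RATE * (step + 1) / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS)
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * progress))


# Print the rate at the start, at the end of warmup, midway, and at the final step
for step in (0, 199, 200, 600, 1000):
    print(f"step {step:>4}: lr = {get_lr(step):.2e}")
```

With these values the rate reaches its peak of 2e-4 on the last warmup step, passes through the midpoint between peak and floor halfway through the decay, and ends at `MIN_LR`.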

## Cell 7: Chat with RUX-D1! 💬

In [None]:
# ============================================================
# CELL 7: CHAT WITH RUX-D1
# ============================================================
from tokenizers import Tokenizer

# Load model
checkpoint = torch.load("checkpoints/rux_d1_final.pt", map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
print("✅ Model loaded!")

tokenizer = Tokenizer.from_file("rux_tokenizer.json")

def chat(user_input, max_tokens=200, temperature=0.7):
    prompt = f"[USER]: {user_input}\n[RUX-D1]:"
    enc = tokenizer.encode(prompt)
    input_ids = torch.tensor([enc.ids], dtype=torch.long).to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=50,
            top_p=0.9
        )

    full = tokenizer.decode(output_ids[0].tolist())

    # Extract response
    if "[RUX-D1]:" in full:
        response = full.split("[RUX-D1]:")[-1].strip()
    else:
        response = full[len(prompt):].strip()

    if "[USER]" in response:
        response = response.split("[USER]")[0].strip()

    return response

# Test conversations
print("\n" + "="*70)
print("  🤖 RUX-D1 CHAT TEST")
print("="*70)

test_qs = [
    "What is your name?",
    "Who are you?",
    "Write a function to reverse a string.",
    "Explain what a hash table is.",
    "Hello!",
]

for q in test_qs:
    print(f"\n👤 USER: {q}")
    response = chat(q)
    print(f"🤖 RUX-D1: {response}")
    print("-" * 50)

# Interactive
print("\n" + "="*70)
print("  💬 Interactive Chat (type 'quit' to exit)")
print("="*70)

while True:
    user = input("\n👤 You: ").strip()
    if user.lower() in ('quit', 'exit', 'q'):
        print("🤖 RUX-D1: Goodbye! Happy coding! 👋")
        break
    if not user:
        continue
    response = chat(user)
    print(f"🤖 RUX-D1: {response}")
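
The response-extraction step inside `chat` is plain string handling and can be tried without a trained model. The following is a minimal sketch, assuming the same `[USER]:` / `[RUX-D1]:` markers; `extract_response` is a hypothetical standalone version of that logic, and the decoded text is made up.

```python
def extract_response(full, prompt):
    """Return only the model's reply from decoded text that still contains the prompt."""
    marker = "[RUX-D1]:"
    if marker in full:
        response = full.split(marker)[-1].strip()
    else:
        # Fall back to cutting the prompt off by length
        response = full[len(prompt):].strip()
    # Drop any follow-up turn the model started generating on its own
    if "[USER]" in response:
        response = response.split("[USER]")[0].strip()
    return response


prompt = "[USER]: Hello!\n[RUX-D1]:"
decoded = prompt + " Hi, I am RUX-D1.\n[USER]: What can you do?"  # made-up model output
reply = extract_response(decoded, prompt)
print(reply)
```

Note that `split(marker)[-1]` keeps the text after the last marker, so if the model generates an extra `[RUX-D1]:` turn, the earlier reply is discarded.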

## Cell 8: Download Model

In [None]:
# ============================================================
# CELL 8: DOWNLOAD
# ============================================================
import json
import os
import shutil

# Create export folder
EXPORT_DIR = "rux_d1_export"
os.makedirs(EXPORT_DIR, exist_ok=True)

# Copy files
shutil.copy("checkpoints/rux_d1_final.pt", f"{EXPORT_DIR}/model.pt")
shutil.copy("rux_tokenizer.json", f"{EXPORT_DIR}/tokenizer.json")

# Save config
with open(f"{EXPORT_DIR}/config.json", "w") as f:
    json.dump(config.__dict__, f, indent=2)

# Zip
shutil.make_archive("rux_d1_700m", 'zip', EXPORT_DIR)

print("\n✅ Export complete!")
print(f"   📦 rux_d1_700m.zip ({os.path.getsize('rux_d1_700m.zip')/1e6:.0f} MB)")

# Download
try:
    from google.colab import files
    print("\n📥 Starting download...")
    files.download('rux_d1_700m.zip')
except ImportError:
    print("\n💡 Not in Colab. Find file at: rux_d1_700m.zip")

# Save to Drive
try:
    from google.colab import drive
    drive.mount('/content/drive')
    shutil.copy("rux_d1_700m.zip", "/content/drive/MyDrive/rux_d1_700m.zip")
    print("☁️ Also saved to Google Drive!")
except Exception:
    pass  # Drive copy is optional: skip outside Colab or if mounting fails

print("\n" + "="*70)
print("  🎉 RUX-D1 700M COMPLETE!")
print("="*70)
print("  ✅ Model trained")
print("  ✅ Knows its name is RUX-D1")
print("  ✅ Can write code")
print("  ✅ Ready for deployment!")
print("="*70)

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/rux_d1_final.pt'
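
The `FileNotFoundError` above most likely means the final checkpoint was never written, for example because Cell 6 did not run to completion before the later cells were executed. The following is a minimal sketch of a pre-flight check that could run before loading or exporting; `missing_files` is a hypothetical helper, and the demo uses a temporary directory rather than the notebook's real paths.

```python
import os
import tempfile


def missing_files(paths):
    """Return the subset of paths that do not exist on disk."""
    return [p for p in paths if not os.path.exists(p)]


# Demo in a temporary directory where only the tokenizer file exists
with tempfile.TemporaryDirectory() as tmp:
    tokenizer_path = os.path.join(tmp, "rux_tokenizer.json")
    checkpoint_path = os.path.join(tmp, "checkpoints", "rux_d1_final.pt")
    with open(tokenizer_path, "w") as f:
        f.write("{}")
    missing = missing_files([checkpoint_path, tokenizer_path])
    missing_names = [os.path.basename(p) for p in missing]
    print("Missing:", missing_names)
```

In the notebook, the same check could be applied to `checkpoints/rux_d1_final.pt` and `rux_tokenizer.json` so that a missing file produces a clear message pointing back to the training cell instead of a traceback.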