# Chimera Abyss Training - A100 Beast Mode

**~2.1B Parameter Model | 96 Layers Deep | Pure Madness**

Training data mix:
- Chat/Conversation (OpenAssistant, Dolly, UltraChat)
- Creative Writing (WritingPrompts, TinyStories)
- Eloquent Prose (Project Gutenberg)
- General Knowledge (Wikipedia)

## 1. Environment Setup

In [None]:
# Show GPU/driver status via an IPython shell escape, then the PyTorch and
# CUDA versions. The get_device_name / get_device_properties calls below will
# fail if this runtime has no CUDA device attached.
!nvidia-smi
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"Device: {torch.cuda.get_device_name(0)}")
# total_memory is in bytes; divided by 1e9 to report decimal GB.
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install the third-party packages used later in the notebook (quietly, via
# an IPython shell escape). NOTE(review): versions are unpinned; `datasets`
# handling of script-based datasets / trust_remote_code has changed across
# releases — consider pinning a known-good version.
!pip install -q transformers sentencepiece accelerate datasets tqdm matplotlib
print("Dependencies installed!")

In [None]:
# Mount Google Drive so checkpoints and the prepared corpus survive runtime resets.
from google.colab import drive
drive.mount('/content/drive')

import os
CHECKPOINT_DIR = '/content/drive/MyDrive/chimera_abyss'
DATA_DIR = '/content/drive/MyDrive/chimera_data'
for _target_dir in (CHECKPOINT_DIR, DATA_DIR):
    os.makedirs(_target_dir, exist_ok=True)  # idempotent across re-runs
print(f"Checkpoints: {CHECKPOINT_DIR}")
print(f"Data: {DATA_DIR}")

## 2. Upload Chimera Files

Upload `model.py` and `tokenizer.py` into `/content/`, either manually (Option A) or by copying them from Google Drive (Option B).

In [None]:
# Option A: Upload manually
# Uncomment the files.upload() call below to open Colab's upload widget.
# NOTE(review): uploads presumably land in the working directory (/content);
# the check cell further down looks for the files at /content/.
from google.colab import files
print("Upload model.py and tokenizer.py:")
# uploaded = files.upload()

In [None]:
# Option B: Copy from Drive
CHIMERA_DRIVE_PATH = '/content/drive/MyDrive/chimera'

import shutil
# Copy whichever of the two source files exist in the Drive folder into
# /content/ (where the import cell expects them); report a missing folder.
if not os.path.exists(CHIMERA_DRIVE_PATH):
    print(f"Path not found: {CHIMERA_DRIVE_PATH}")
else:
    for fname in ('model.py', 'tokenizer.py'):
        src = os.path.join(CHIMERA_DRIVE_PATH, fname)
        if not os.path.exists(src):
            continue
        shutil.copy(src, '/content/')
        print(f"Copied {fname}")

In [None]:
# Both project modules must be present in /content/ before they are imported.
for _core_file in ('model.py', 'tokenizer.py'):
    assert os.path.exists(os.path.join('/content', _core_file)), f"{_core_file} not found!"
print("Core files ready!")

## 3. Download & Prepare Training Data

In [None]:
import json
import random
import re
import hashlib
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass
from collections import defaultdict
from datasets import load_dataset
from tqdm.notebook import tqdm


@dataclass
class DatasetConfig:
    """Where to fetch one Hugging Face dataset and how to filter/weight it."""

    name: str  # human-readable label used in progress output
    hf_path: str  # Hub repo id passed to load_dataset
    hf_split: str = "train"
    hf_subset: Optional[str] = None  # dataset config name, passed positionally to load_dataset
    text_field: str = "text"  # field(s) to read; comma-separated names are joined with blank lines
    max_samples: Optional[int] = None  # stop after this many samples pass the quality filter
    min_length: int = 100  # character-length bounds for is_quality_text
    max_length: int = 8000
    category: str = "general"  # grouping label for the final count summary
    weight: float = 1.0  # relative sampling weight when the corpus is assembled


DATASETS = {
    "oasst2": DatasetConfig(
        name="OpenAssistant OASST2",
        hf_path="OpenAssistant/oasst2",
        text_field="text",
        max_samples=50000,
        min_length=50,
        category="chat",
        weight=1.5,
    ),
    "dolly": DatasetConfig(
        name="Databricks Dolly 15k",
        hf_path="databricks/databricks-dolly-15k",
        text_field="instruction,response",
        max_samples=15000,
        min_length=50,
        category="chat",
        weight=1.2,
    ),
    "ultrachat": DatasetConfig(
        name="UltraChat",
        hf_path="stingning/ultrachat",
        text_field="data",
        max_samples=80000,
        min_length=100,
        category="chat",
        weight=1.3,
    ),
    "writing_prompts": DatasetConfig(
        name="WritingPrompts",
        hf_path="euclaise/writingprompts",
        text_field="story",
        max_samples=40000,
        min_length=200,
        max_length=10000,
        category="creative",
        weight=1.5,
    ),
    "tinystories": DatasetConfig(
        name="TinyStories",
        hf_path="roneneldan/TinyStories",
        text_field="text",
        max_samples=80000,
        min_length=100,
        category="creative",
        weight=1.0,
    ),
    "gutenberg": DatasetConfig(
        name="Project Gutenberg",
        hf_path="sedthh/gutenberg_english",
        text_field="TEXT",
        max_samples=3000,
        min_length=500,
        max_length=15000,
        category="eloquent",
        weight=2.0,
    ),
    # The legacy "wikipedia" repo is a loading-script dataset (and its
    # 20220301 dumps need remote code); the parquet-based wikimedia/wikipedia
    # loads with plain streaming. A load failure here would otherwise be
    # swallowed by process_dataset and the category silently dropped.
    "wikipedia": DatasetConfig(
        name="Wikipedia",
        hf_path="wikimedia/wikipedia",
        hf_subset="20231101.simple",
        text_field="text",
        max_samples=40000,
        min_length=200,
        max_length=5000,
        category="general",
        weight=0.8,
    ),
}

print(f"Configured {len(DATASETS)} datasets")

In [None]:
# Curly -> ASCII quote map. Written as escapes so an editor or notebook export
# cannot silently turn the table into a no-op by "straightening" the literals.
_QUOTE_TABLE = str.maketrans({
    '\u201c': '"',  # left double quotation mark
    '\u201d': '"',  # right double quotation mark
    '\u2018': "'",  # left single quotation mark
    '\u2019': "'",  # right single quotation mark
})


def clean_text(text):
    """Normalize one raw document into a single whitespace-collapsed line.

    Removes http(s) URLs and reddit ``[removed]`` / ``[deleted]`` markers,
    maps curly quotes to ASCII, then collapses every whitespace run
    (newlines included) to one space.

    Args:
        text: Raw document text; ``None`` or ``""`` is accepted.

    Returns:
        The cleaned string, or ``""`` for empty input.
    """
    if not text:
        return ""
    # Strip URLs and artifacts BEFORE collapsing whitespace so the gaps they
    # leave ("see  now") are collapsed too.
    text = re.sub(r'https?://\S+', '', text)
    text = text.replace('[removed]', '').replace('[deleted]', '')
    text = text.translate(_QUOTE_TABLE)
    # Collapsing newlines keeps every document on one line; the corpus writer
    # and PackedDataset in this notebook treat each line as one document.
    return ' '.join(text.split())


def is_quality_text(text, min_len, max_len):
    """Heuristic filter: keep text that is long enough, varied, and mostly letters.

    Rejects text that is empty or outside ``[min_len, max_len]`` characters,
    has fewer than 20 whitespace-separated words, has a unique-word ratio
    below 0.3, or whose share of alphabetic characters is below 0.6.
    """
    if not text or not (min_len <= len(text) <= max_len):
        return False

    tokens = text.split()
    n_tokens = len(tokens)
    if n_tokens < 20:
        return False

    # Low lexical diversity usually means spam or repeated boilerplate.
    if len(set(tokens)) / n_tokens < 0.3:
        return False

    n_letters = sum(ch.isalpha() for ch in text)
    return not (n_letters / len(text) < 0.6)


def format_conversation(messages):
    """Render a chat transcript as "Human: ..." / "Assistant: ..." turns.

    Each message is either a dict (speaker read from ``role`` then ``from``,
    text from ``content``, ``value`` or ``text``) or a bare value whose
    speaker alternates user/assistant by position. Unknown speakers are
    emitted without a prefix. Turns are separated by blank lines.
    """
    turns = []
    for idx, msg in enumerate(messages):
        if not isinstance(msg, dict):
            speaker = 'assistant' if idx % 2 else 'user'
            body = str(msg)
        else:
            speaker = msg.get('role', msg.get('from', 'user'))
            body = msg.get('content', msg.get('value', msg.get('text', '')))

        if speaker in ('user', 'human', 'prompter'):
            turns.append(f"Human: {body}")
        elif speaker in ('assistant', 'gpt', 'bot'):
            turns.append(f"Assistant: {body}")
        else:
            turns.append(body)
    return "\n\n".join(turns)


print("Helper functions ready!")

In [None]:
def _extract_text(name, config, item):
    """Turn one raw dataset record into a training string (before clean_text).

    Datasets needing custom formatting are recognized by substrings of the
    registry key *name*; all others read ``config.text_field`` (a
    comma-separated list of fields is joined with blank lines).
    Returns "" when the record has nothing usable.
    """
    if 'oasst' in name:
        # OASST rows are individual messages; prefix each with its speaker.
        text = item.get('text', '')
        if item.get('role', 'user') == 'prompter':
            return f"Human: {text}"
        return f"Assistant: {text}"

    if 'dolly' in name:
        inst = item.get('instruction', '')
        ctx = item.get('context', '')
        resp = item.get('response', '')
        if ctx:
            return f"Human: {inst}\n\nContext: {ctx}\n\nAssistant: {resp}"
        return f"Human: {inst}\n\nAssistant: {resp}"

    if 'ultrachat' in name:
        data = item.get('data', [])
        if isinstance(data, list) and len(data) >= 2:
            return format_conversation(data)
        return ""

    if 'writing_prompts' in name:
        prompt = item.get('prompt', '') or ''
        story = item.get('story', '') or ''
        # Strip a run of leading subreddit tags such as "[WP]", "[ WP ]",
        # "[OT]", "[IP]". Grouping the alternatives keeps them all anchored.
        prompt = re.sub(r'^\s*(?:\[\s*[A-Za-z]{2}\s*\]\s*)+', '', prompt).strip()
        # NOTE(review): some WritingPrompts dumps encode line breaks as a
        # literal "<newline>" token; map it back (a no-op when absent).
        story = story.replace('<newline>', '\n')
        if not story:
            return ""
        return f"Prompt: {prompt}\n\n{story}" if prompt else story

    if ',' in config.text_field:
        fields = [f.strip() for f in config.text_field.split(',')]
        return '\n\n'.join(str(item.get(f, '')) for f in fields if item.get(f))
    return item.get(config.text_field, '')


def process_dataset(name, config):
    """Stream one Hugging Face dataset and return its cleaned, filtered texts.

    Args:
        name: Registry key (selects dataset-specific formatting).
        config: The dataset's DatasetConfig.

    Returns:
        A list of at most ``config.max_samples`` strings that passed
        clean_text + is_quality_text. On any exception the error and
        traceback are printed and the samples collected so far are returned,
        so one broken source does not abort the whole corpus build.
    """
    print(f"\n{'='*60}")
    print(f"Processing: {config.name}")
    print(f"{'='*60}")

    texts = []
    pbar = None

    try:
        if config.hf_subset:
            dataset = load_dataset(config.hf_path, config.hf_subset,
                                   split=config.hf_split, streaming=True)
        else:
            dataset = load_dataset(config.hf_path, split=config.hf_split,
                                   streaming=True, trust_remote_code=True)

        # The bar advances per raw record; the postfix shows how many were kept.
        pbar = tqdm(dataset, desc=config.name, total=config.max_samples)
        for item in pbar:
            if config.max_samples and len(texts) >= config.max_samples:
                break
            text = clean_text(_extract_text(name, config, item))
            if is_quality_text(text, config.min_length, config.max_length):
                texts.append(text)
                pbar.set_postfix({'kept': len(texts)})

        print(f"Collected {len(texts):,} samples")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Close the progress widget even when streaming fails part-way.
        if pbar is not None:
            pbar.close()

    return texts


print("Dataset processor ready!")

In [None]:
# SELECT DATASETS (comment out ones you don't want)
DATASETS_TO_USE = [
    "oasst2",
    "dolly",
    "ultrachat",
    "writing_prompts",
    "tinystories",
    "gutenberg",
    "wikipedia",
]

# Fail fast on a typo here. Otherwise the KeyError only appears in the
# download cell after the earlier datasets have already been streamed.
_unknown_datasets = [n for n in DATASETS_TO_USE if n not in DATASETS]
if _unknown_datasets:
    raise KeyError(f"Unknown dataset(s) {_unknown_datasets}; choose from {sorted(DATASETS)}")

print(f"Will process {len(DATASETS_TO_USE)} datasets")

In [None]:
DATA_PATH = os.path.join(DATA_DIR, 'abyss_train.txt')

# An existing corpus file means the (slow) download/processing cell is skipped.
SKIP_DATA_PREP = os.path.exists(DATA_PATH)
if SKIP_DATA_PREP:
    size_gb = os.path.getsize(DATA_PATH) / 1e9
    print(f"Data exists: {DATA_PATH} ({size_gb:.2f} GB)")
    print("Delete to regenerate, or skip to next section.")
else:
    print("Will download and process data.")

In [None]:
# DOWNLOAD AND PROCESS (takes 15-30 min)
if not SKIP_DATA_PREP:
    # Dedicated RNG so weighting and shuffling are reproducible regardless of
    # other uses of the global `random` state.
    rng = random.Random(42)
    seen = set()
    all_texts = []
    category_counts = defaultdict(int)
    n_dupes = 0

    for name in DATASETS_TO_USE:
        cfg = DATASETS[name]
        texts = process_dataset(name, cfg)

        # Dedupe BEFORE weighting: deduping after upsampling would delete
        # exactly the copies the weight just added.
        unique = []
        for t in texts:
            # Near-duplicate heuristic: documents sharing their first 500 chars.
            h = hashlib.md5(t[:500].encode()).hexdigest()
            if h not in seen:
                seen.add(h)
                unique.append(t)
        n_dupes += len(texts) - len(unique)

        # weight = whole number of full copies + a random fraction of one more
        # copy, so fractional weights (1.5, 1.2, 0.8) actually take effect
        # instead of being truncated by int().
        whole, frac = divmod(cfg.weight, 1.0)
        weighted = unique * int(whole)
        n_extra = int(round(frac * len(unique)))
        if n_extra:
            weighted.extend(rng.sample(unique, n_extra))

        print(f"{name}: {len(unique):,} unique -> {len(weighted):,} after weight {cfg.weight}")
        all_texts.extend(weighted)
        # Count what actually goes into the file.
        category_counts[cfg.category] += len(weighted)

    print(f"\nRemoved {n_dupes:,} dupes")
    print(f"Total: {len(all_texts):,}")

    rng.shuffle(all_texts)

    # Write to a temp file, then rename. An interrupted write then never leaves
    # a partial corpus that the existence check would later treat as complete.
    tmp_path = DATA_PATH + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        for t in tqdm(all_texts, desc="Writing"):
            f.write(t.strip() + '\n\n')
    os.replace(tmp_path, DATA_PATH)

    print(f"\nSaved: {DATA_PATH}")
    print(f"Size: {os.path.getsize(DATA_PATH)/1e9:.2f} GB")
    for cat, cnt in sorted(category_counts.items()):
        print(f"  {cat}: {cnt:,}")
else:
    print("Skipping (data exists)")

## 4. Step Checker

In [None]:
import time
import math
from datetime import datetime, timedelta
from collections import deque
import matplotlib.pyplot as plt
from IPython.display import clear_output


class StepChecker:
    """Console + matplotlib monitor for the training loop.

    Records per-optimizer-step loss, perplexity, learning rate, GPU memory
    and throughput. Prints a one-line progress bar on every log() call and
    redraws a 2x2 dashboard every ``plot_every`` steps.
    """

    def __init__(self, max_steps, log_every=10, plot_every=100):
        self.max_steps = max_steps
        self.log_every = log_every
        self.plot_every = plot_every
        self.steps, self.losses, self.ppls = [], [], []
        self.gpu_mem, self.tps, self.lrs = [], [], []
        self.start_time = None
        # Rolling window of recent step durations used for the ETA estimate.
        self.step_times = deque(maxlen=100)
        self.last_time = None
        self.best_loss = float('inf')
        self.best_step = 0

    @staticmethod
    def _vram_total_gb():
        """Total memory of CUDA device 0 in decimal GB, or 0.0 without CUDA."""
        if torch.cuda.is_available():
            return torch.cuda.get_device_properties(0).total_memory / 1e9
        return 0.0

    def start(self):
        """Start the wall clock and print the banner."""
        self.start_time = time.time()
        self.last_time = self.start_time
        print("="*70)
        print("CHIMERA ABYSS - STEP CHECKER")
        print(f"Max Steps: {self.max_steps:,}")
        print("="*70)

    def log(self, step, loss, ppl, lr, tokens):
        """Record one optimizer step and print the progress line.

        Args:
            step: Optimizer step number just completed.
            loss: Mean training loss for the step.
            ppl: Perplexity (stored capped at 1000 for plotting).
            lr: Current learning rate.
            tokens: Tokens processed since the previous log() call.
        """
        now = time.time()
        if self.last_time is None:  # log() called without start()
            self.start_time = self.last_time = now
        dt = now - self.last_time
        self.step_times.append(dt)
        self.last_time = now

        gpu = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
        self.steps.append(step)
        self.losses.append(loss)
        self.ppls.append(min(ppl, 1000))
        self.gpu_mem.append(gpu)
        self.lrs.append(lr)

        tok_per_s = tokens / dt if dt > 0 else 0
        self.tps.append(tok_per_s)

        if loss < self.best_loss:
            self.best_loss = loss
            self.best_step = step

        avg_dt = sum(self.step_times) / len(self.step_times)
        # Clamp so a resumed run past max_steps shows 100% / ETA 0, not negatives.
        remaining = max(0, self.max_steps - step)
        eta = timedelta(seconds=int(remaining * avg_dt))
        frac = min(max(step / self.max_steps, 0.0), 1.0) if self.max_steps > 0 else 1.0
        filled = int(40 * frac)
        bar = '#' * filled + '-' * (40 - filled)

        msg = (f"[{bar}] {frac*100:.1f}% | Step {step:,} | Loss {loss:.4f} | PPL {ppl:.2f} | "
               f"LR {lr:.2e} | GPU {gpu:.1f}GB | {tok_per_s:.0f} tok/s | ETA {eta}")
        print(f"\r{msg}", end="")

        if step % self.log_every == 0:
            print()
        if step % self.plot_every == 0 and step > 0:
            self.plot()

    def plot(self):
        """Redraw the loss / perplexity / VRAM / throughput dashboard in place."""
        if not self.steps:
            return  # nothing recorded yet
        clear_output(wait=True)
        fig, ax = plt.subplots(2, 2, figsize=(14, 10))

        ax[0,0].plot(self.steps, self.losses, 'b-', alpha=0.7)
        if math.isfinite(self.best_loss):
            ax[0,0].axhline(self.best_loss, color='g', linestyle='--', label=f'Best: {self.best_loss:.4f}')
            ax[0,0].legend()
        ax[0,0].set_title('Loss'); ax[0,0].grid(True, alpha=0.3)

        ax[0,1].plot(self.steps, self.ppls, 'r-', alpha=0.7)
        ax[0,1].set_title('Perplexity'); ax[0,1].set_yscale('log'); ax[0,1].grid(True, alpha=0.3)

        # Reference line at the real device capacity instead of assuming 80 GB.
        vram_total = self._vram_total_gb()
        ax[1,0].plot(self.steps, self.gpu_mem, 'g-', alpha=0.7)
        if vram_total > 0:
            ax[1,0].axhline(vram_total, color='r', linestyle='--', label=f'Device total {vram_total:.0f}GB')
            ax[1,0].set_ylim(0, vram_total * 1.06)
            ax[1,0].legend()
        ax[1,0].set_title('VRAM (GB)'); ax[1,0].grid(True, alpha=0.3)

        ax[1,1].plot(self.steps, self.tps, 'm-', alpha=0.7)
        ax[1,1].set_title('Throughput (tok/s)'); ax[1,1].grid(True, alpha=0.3)

        plt.tight_layout(); plt.show()
        plt.close(fig)  # a new figure is created on every redraw; release it
        start = self.start_time if self.start_time is not None else time.time()
        elapsed = timedelta(seconds=int(time.time() - start))
        print(f"Elapsed: {elapsed} | Best: {self.best_loss:.4f} @ {self.best_step}")

    def finish(self):
        """Print the summary and draw the final dashboard (safe with no steps)."""
        start = self.start_time if self.start_time is not None else time.time()
        elapsed = timedelta(seconds=int(time.time() - start))
        final = f"{self.losses[-1]:.4f}" if self.losses else "n/a"
        best = f"{self.best_loss:.4f}" if self.losses else "n/a"
        print(f"\n{'='*70}")
        print(f"COMPLETE | Time: {elapsed} | Final: {final} | Best: {best}")
        print("="*70)
        if self.steps:
            self.plot()

    def save_history(self, path):
        """Dump the recorded series to *path* as JSON.

        'steps', 'losses' and 'best' keep their original meaning; the other
        keys are extra series.
        """
        history = {
            'steps': self.steps,
            'losses': self.losses,
            'best': self.best_loss,
            'best_step': self.best_step,
            'ppls': self.ppls,
            'lrs': self.lrs,
            'tokens_per_sec': self.tps,
            'gpu_mem_gb': self.gpu_mem,
        }
        with open(path, 'w') as f:
            json.dump(history, f)


print("StepChecker ready!")

## 5. Config

In [None]:
@dataclass
class TrainConfig:
    """Hyperparameters and I/O settings for the training run below.

    The ``data_path`` and ``output_dir`` defaults take the DATA_PATH /
    CHECKPOINT_DIR globals as they are when this class is defined.
    """
    model_config: str = "abyss"  # NOTE(review): not read anywhere in this notebook as shown; the model is always built via chimera_abyss()
    data_path: str = DATA_PATH  # corpus file read by PackedDataset
    seq_length: int = 1024  # tokens per packed training window
    micro_batch_size: int = 6  # sequences per forward/backward pass
    gradient_accumulation_steps: int = 16  # micro-batches per optimizer step
    learning_rate: float = 2e-4  # peak LR after warmup
    min_lr: float = 1e-5  # schedule floor; passed to the scheduler as min_lr / learning_rate
    weight_decay: float = 0.1  # applied only to the optimizer's decay group
    beta1: float = 0.9  # AdamW betas
    beta2: float = 0.95
    grad_clip: float = 1.0  # max global gradient norm (clip_grad_norm_)
    warmup_steps: int = 500  # linear warmup length, in optimizer steps
    max_steps: int = 15000  # total optimizer steps
    mixed_precision: bool = True  # NOTE(review): not read in this notebook as shown; fp16 autocast + GradScaler are always used
    compile_model: bool = True  # wrap the model with torch.compile
    output_dir: str = CHECKPOINT_DIR  # checkpoints, history and exports go here
    save_every: int = 500  # checkpoint interval, in optimizer steps
    log_every: int = 10  # StepChecker newline interval
    plot_every: int = 200  # StepChecker dashboard redraw interval
    generate_every: int = 500  # sample-generation interval, in optimizer steps
    resume_from: Optional[str] = None  # explicit checkpoint path; otherwise latest.pt in output_dir is tried


config = TrainConfig()
print("Config ready")

## 6. Model Init

In [None]:
import sys
sys.path.insert(0, '/content')

from model import Chimera, chimera_abyss
from tokenizer import ChimeraTokenizer
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader

device = torch.device('cuda')

tokenizer = ChimeraTokenizer()
print(f"Vocab: {tokenizer.vocab_size}")

# Start from the "abyss" preset and match its vocabulary to the tokenizer.
model_cfg = chimera_abyss()
model_cfg.vocab_size = tokenizer.vocab_size
print(f"Layers: {model_cfg.n_layers}, d_model: {model_cfg.d_model}")

model = Chimera(model_cfg).to(device)
n_params = model.get_num_params()
print(f"Params: {n_params:,} ({n_params/1e9:.2f}B)")

# torch.compile returns a wrapper immediately; actual graph compilation
# happens lazily on the first forward pass.
if config.compile_model:
    print("Compiling...")
    model = torch.compile(model)
    print("Done!")

In [None]:
torch.cuda.synchronize()
# Report against the device's real capacity (decimal GB, like the setup cell)
# instead of assuming an 80 GB card.
_vram_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB / {_vram_total_gb:.0f}GB")

In [None]:
class PackedDataset(IterableDataset):
    """Stream a one-document-per-line text file as packed next-token windows.

    Each non-empty line is encoded with BOS/EOS and appended to a running
    token buffer. Every full window of ``seq_length + 1`` tokens yields
    ``input_ids = window[:-1]`` and ``labels = window[1:]``. The buffer then
    advances by ``seq_length``, so each token is a prediction target once.

    Under a multi-worker DataLoader, documents are split round-robin across
    workers. Without that split, every worker would replay the whole file
    and each window would be seen ``num_workers`` times per pass.
    """

    def __init__(self, data_path, tokenizer, seq_length=512):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.seq_length = seq_length

    def _documents(self):
        """Yield this worker's share of the stripped, non-empty lines."""
        info = torch.utils.data.get_worker_info()
        worker_id, num_workers = (info.id, info.num_workers) if info is not None else (0, 1)
        doc_idx = 0
        with open(self.data_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if doc_idx % num_workers == worker_id:
                    yield line
                doc_idx += 1

    def __iter__(self):
        buffer = []
        window = self.seq_length + 1
        for doc in self._documents():
            buffer.extend(self.tokenizer.encode(doc, add_bos=True, add_eos=True))
            while len(buffer) >= window:
                seq = buffer[:window]
                yield {"input_ids": torch.tensor(seq[:-1]), "labels": torch.tensor(seq[1:])}
                # Keep the window's last token: it starts the next input.
                buffer = buffer[self.seq_length:]


train_dataset = PackedDataset(config.data_path, tokenizer, seq_length=config.seq_length)
# Two background workers tokenize ahead; pinned host memory speeds H2D copies.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.micro_batch_size,
    num_workers=2,
    pin_memory=True,
)
print("DataLoader ready")

In [None]:
def create_optimizer(model, cfg):
    """Build AdamW with weight decay applied to weight matrices only.

    Parameter group 0 holds decayed tensors (``cfg.weight_decay``).
    Group 1 holds exempt tensors (decay 0.0): every tensor with fewer than
    2 dims (norm gains, biases) plus any parameter whose name contains
    "norm", "bias" or "embedding". Frozen parameters are skipped.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        cfg: Object with ``weight_decay``, ``learning_rate``, ``beta1``, ``beta2``.

    Returns:
        A configured ``torch.optim.AdamW``.
    """
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # The ndim test also catches norm layers whose names lack "norm"
        # (e.g. "ln_f.weight"), which the name heuristic alone would decay.
        if p.ndim < 2 or "norm" in n or "bias" in n or "embedding" in n:
            no_decay.append(p)
        else:
            decay.append(p)
    # Use the fused kernel only when every parameter is on the GPU.
    use_fused = all(p.is_cuda for p in decay + no_decay)
    return torch.optim.AdamW(
        [{"params": decay, "weight_decay": cfg.weight_decay},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2), fused=use_fused)


def get_cosine_schedule(opt, warmup, max_steps, min_ratio=0.1):
    """Linear warmup followed by cosine decay, as a LambdaLR multiplier.

    The multiplier rises linearly from 0 to 1 over ``warmup`` steps. It then
    follows a half cosine from 1 toward 0 by ``max_steps``, floored at
    ``min_ratio``. Steps beyond ``max_steps`` stay at the floor.

    Args:
        opt: Optimizer whose base LRs are scaled.
        warmup: Number of warmup steps.
        max_steps: Step at which the cosine reaches its end.
        min_ratio: Lower bound on the multiplier after warmup.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``.
    """
    def lr_lambda(step):
        if step < warmup:
            return step / max(1, warmup)
        # Clamp progress: without it the cosine turns back up past max_steps
        # and the LR climbs toward its peak again.
        progress = min(1.0, (step - warmup) / max(1, max_steps - warmup))
        return max(min_ratio, 0.5 * (1 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)


# AdamW plus warmup/cosine schedule. The schedule's floor is a fraction of
# the peak LR, hence min_lr / learning_rate.
optimizer = create_optimizer(model, config)
scheduler = get_cosine_schedule(
    optimizer,
    config.warmup_steps,
    config.max_steps,
    config.min_lr / config.learning_rate,
)
# Loss scaler paired with the float16 autocast used in the training loop.
scaler = torch.amp.GradScaler('cuda')
print("Optimizer ready")

In [None]:
@torch.no_grad()
def generate_sample(model, tokenizer, prompt, max_tokens=100, temperature=0.8):
    model.eval()
    ids = torch.tensor([tokenizer.encode(prompt, add_bos=True)], device=device)
    for _ in range(max_tokens):
        with torch.amp.autocast('cuda', dtype=torch.float16):
            logits, _ = model(ids)
        logits = logits[:, -1, :] / temperature
        next_tok = torch.multinomial(F.softmax(logits, dim=-1), 1)
        ids = torch.cat([ids, next_tok], 1)
        if next_tok.item() == tokenizer.eos_token_id:
            break
    model.train()
    return tokenizer.decode(ids[0].tolist())


# Prompts sampled every `generate_every` steps. Chat prompts use the same
# single-line "Human: ... Assistant:" layout as the training corpus.
# clean_text() collapses all whitespace (newlines included), so a
# blank-line turn separator never occurs in the training data.
TEST_PROMPTS = [
    "Human: What is the meaning of life? Assistant:",
    "Once upon a time, in a kingdom far away,",
    "The ancient tome spoke of a prophecy:",
    "Human: Write me a poem about the stars. Assistant:",
]
print("Generation ready")

## 7. Training

In [None]:
start_step = 0
latest_path = os.path.join(config.output_dir, 'latest.pt')


def _load_checkpoint(path):
    """Restore model/optimizer/scheduler/scaler state from *path*; return its step.

    The loaded dict is local to this function, so the full checkpoint (loaded
    onto `device`) is freed on return instead of staying alive as a global
    during training.
    """
    ckpt = torch.load(path, map_location=device)
    # Checkpoints saved from a torch.compile'd model carry an "_orig_mod."
    # key prefix. Strip it and load into the underlying module so a checkpoint
    # loads whether or not compile_model matches the run that saved it.
    prefix = '_orig_mod.'
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v
             for k, v in ckpt['model'].items()}
    getattr(model, '_orig_mod', model).load_state_dict(state)
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    if 'scaler' in ckpt:
        scaler.load_state_dict(ckpt['scaler'])
    return ckpt['step']


# NOTE(review): data position is not checkpointed; after a resume the
# DataLoader restarts from the beginning of the corpus.
if config.resume_from and os.path.exists(config.resume_from):
    start_step = _load_checkpoint(config.resume_from)
    print(f"Resumed from {start_step}")
else:
    if config.resume_from:
        print(f"resume_from not found: {config.resume_from}; trying {latest_path}")
    if os.path.exists(latest_path):
        start_step = _load_checkpoint(latest_path)
        print(f"Resumed from latest: {start_step}")

In [None]:
def _save_checkpoint(step, *names):
    """Write one full training checkpoint to each filename in *names*.

    Each file is written under a temporary name and then renamed into place.
    A disconnect mid-write therefore cannot leave a truncated ``latest.pt``
    for the resume cell to load.
    NOTE(review): a full checkpoint holds model + AdamW state. Keeping every
    step_*.pt on Drive may exhaust quota; consider pruning old ones.
    """
    ckpt = {'step': step, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(), 'scaler': scaler.state_dict()}
    for fname in names:
        path = os.path.join(config.output_dir, fname)
        tmp_path = path + '.tmp'
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, path)


step_checker = StepChecker(config.max_steps, config.log_every, config.plot_every)
step_checker.start()

model.train()
data_iter = iter(train_loader)
step = start_step
micro_step = 0
running_loss = 0.0
running_tokens = 0

try:
    while step < config.max_steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            # End of corpus: start another pass over the file.
            data_iter = iter(train_loader)
            batch = next(data_iter)

        # The loader uses pinned memory, so these copies can overlap compute.
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        with torch.amp.autocast('cuda', dtype=torch.float16):
            logits, _ = model(input_ids)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
            loss = loss / config.gradient_accumulation_steps

        scaler.scale(loss).backward()
        # Accumulate on-device. A .item() here would force a GPU sync on every
        # micro-batch. The sum of loss/accum over one accumulation window is
        # the window's mean loss.
        running_loss += loss.detach()
        running_tokens += input_ids.numel()
        micro_step += 1

        if micro_step % config.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            step += 1

            avg_loss = float(running_loss)
            ppl = math.exp(min(avg_loss, 20))  # cap to avoid overflow on divergence
            step_checker.log(step, avg_loss, ppl, scheduler.get_last_lr()[0], running_tokens)
            running_loss = 0.0
            running_tokens = 0

            if step % config.save_every == 0:
                _save_checkpoint(step, f'step_{step}.pt', 'latest.pt')
                print(f"\n[Saved step_{step}.pt]")

            if step % config.generate_every == 0:
                print(f"\n{'='*60}\nSAMPLES @ {step}\n{'='*60}")
                for p in TEST_PROMPTS:
                    out = generate_sample(model, tokenizer, p, 100)
                    print(f"\n> {p[:50]}...\n{out}")
                print('='*60)

except KeyboardInterrupt:
    print("\nInterrupted! Saving...")
    _save_checkpoint(step, f'interrupted_{step}.pt', 'latest.pt')

step_checker.finish()
step_checker.save_history(os.path.join(config.output_dir, 'history.json'))
# Save the underlying module's weights. A torch.compile'd model's state_dict
# keys carry an "_orig_mod." prefix that a freshly built Chimera would reject.
torch.save(getattr(model, '_orig_mod', model).state_dict(),
           os.path.join(config.output_dir, 'model_final.pt'))
print(f"\nSaved: {config.output_dir}/model_final.pt")

## 8. Test & Export

In [None]:
print("Chat with Chimera (type 'quit' to exit):\n")
while True:
    try:
        user = input("You: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupt: leave the loop cleanly.
        break
    if user.lower() == 'quit':
        break
    if not user:
        continue
    # Single-line turn layout to match the training corpus: clean_text()
    # collapses all whitespace, so newline-separated turns were never seen.
    prompt = f"Human: {user} Assistant:"
    response = generate_sample(model, tokenizer, prompt, 200, 0.8)
    if "Assistant:" in response:
        # Keep the text after the prompt's own "Assistant:" marker(s), and cut
        # at any "Human:" turn the model invents for itself.
        response = response.split("Assistant:", prompt.count("Assistant:"))[-1]
        response = response.split("Human:", 1)[0].strip()
    print(f"\nChimera: {response}\n")

In [None]:
export_path = os.path.join(config.output_dir, 'chimera_abyss_weights.pt')
# Export the underlying module's weights. torch.compile wraps the model so its
# state_dict keys carry an "_orig_mod." prefix, which a freshly constructed
# (uncompiled) Chimera would not accept.
export_model = getattr(model, '_orig_mod', model)
torch.save(export_model.state_dict(), export_path)
print(f"Exported: {export_path}")
print(f"Size: {os.path.getsize(export_path)/1e9:.2f} GB")