# Extract W and Black Holes from Qwen 3 4B Instruct 2507

**Purpose:** Load the model, extract the unembedding matrix W, identify black holes (tokens with non-unique embeddings), and save everything as reusable artifacts.

**Outputs:** `../tensors/qwen3_4b_instruct_2507.safetensors` containing:
- `W` — The unembedding matrix [V, D] in bfloat16
- `black_hole_mask` — Boolean mask, True for tokens in any black hole [V]
- `black_hole_id` — uint8 assignment: 0 = not black hole, 1 = largest hole (814 tokens), 2 = second largest, etc. [V]
- `black_hole_centroids` — The actual embedding vector for each black hole [num_holes, D] in bfloat16

**Runtime:** ~30 seconds (model already cached)

---

*Jeffery Harrell & Alpha, December 1, 2025*

In [1]:
```python
import torch
from transformers import AutoModelForCausalLM
from safetensors.torch import save_file
from collections import Counter
from pathlib import Path
```

## Load the Model

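We load in bfloat16, the dtype the checkpoint is distributed in, so the weights arrive exactly as trained rather than being rounded by a dtype conversion. The full model (~8 GB) is loaded, even though we only need the unembedding matrix.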
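A quick illustration on synthetic values (not the model's weights) of why the load dtype matters when testing for exact equality. bfloat16 values survive a round trip through float32 unchanged. float16 has a narrower exponent range, so a cast to it can flush tiny values to zero or overflow large ones to infinity, which could create or break exact matches between rows.

```python
import torch

# Synthetic bfloat16 values spanning a wide range (not real weights)
x = torch.tensor([1e-8, 0.1, 3.0, 1e5], dtype=torch.bfloat16)

# bf16 -> fp32 -> bf16 is exact: fp32 covers bf16's exponent range with more mantissa bits
roundtrip = x.to(torch.float32).to(torch.bfloat16)
print(torch.equal(roundtrip, x))  # True

# fp16 has only 5 exponent bits: 1e-8 underflows to 0, 1e5 overflows to inf
as_fp16 = x.to(torch.float16)
print(as_fp16[0].item() == 0.0, torch.isinf(as_fp16[3]).item())  # True True
```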

In [2]:
```python
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,  # `torch_dtype` is deprecated in recent transformers releases
    device_map="cpu",  # Don't need GPU for extraction
)

print(f"Model loaded: {MODEL_ID}")
print(f"Vocab size: {model.config.vocab_size:,}")
print(f"Hidden dim: {model.config.hidden_size:,}")
```

```
Model loaded: Qwen/Qwen3-4B-Instruct-2507
Vocab size: 151,936
Hidden dim: 2,560
```


## Extract W

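The unembedding matrix is `lm_head.weight` in Qwen's architecture, with shape [vocab_size, hidden_dim]. We detach and clone it so W is a standalone copy that no longer references the model's parameters.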
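Whether the output head is tied to the input embedding table is recorded in the model config as `tie_word_embeddings`; when it is, `lm_head.weight` and the input embedding weight are the same tensor, and W doubles as the input embedding matrix. On the real model, the analogous check would compare `model.lm_head.weight` with `model.get_input_embeddings().weight`. The sketch below shows that check on a toy tied pair built with `torch.nn` (sizes arbitrary; `shares_storage` is a hypothetical helper), and also shows that `.detach().clone()` produces a copy with its own storage.

```python
import torch
import torch.nn as nn

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper: True if both tensors start at the same memory address."""
    return a.data_ptr() == b.data_ptr()

# Toy tied pair: embedding table and output head share one Parameter
vocab, dim = 10, 4
embed = nn.Embedding(vocab, dim)
head = nn.Linear(dim, vocab, bias=False)  # weight shape [vocab, dim]
head.weight = embed.weight                # tie the two, as tie_word_embeddings would

print(shares_storage(head.weight, embed.weight))  # True

# detach().clone() copies the values into fresh storage
W = head.weight.detach().clone()
print(shares_storage(W, embed.weight), torch.equal(W, embed.weight))  # False True
```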

In [3]:
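```python
# Copy the weight out of the model so W owns its own storage
W = model.lm_head.weight.detach().clone()

print(f"W shape: {W.shape}")
print(f"W dtype: {W.dtype}")
print(f"W memory: {W.numel() * W.element_size() / 1e6:.1f} MB")  # bf16 = 2 bytes per element
```

```
W shape: torch.Size([151936, 2560])
W dtype: torch.bfloat16
W memory: 777.9 MB
```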


## Find Black Holes

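A "black hole" is a set of two or more tokens that share a bit-for-bit identical embedding vector. We find them with `torch.unique(W, dim=0, return_inverse=True)`, which returns the distinct rows of W plus, for every token, the index of the distinct row it maps to; a distinct row that more than one token maps to is a black hole.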
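A toy sketch of what `torch.unique(..., dim=0, return_inverse=True)` returns, using a hypothetical 5-token, 3-dimensional matrix in float32 (the notebook's W is bfloat16, but the semantics are the same). Indexing the distinct rows with the inverse rebuilds the original matrix, which is what lets the inverse serve as a per-token group label.

```python
import torch

# Hypothetical 5-token, 3-dim matrix; tokens 1 and 3 share a row
W_toy = torch.tensor(
    [[0.5, 0.0, 1.0],
     [2.0, 2.0, 2.0],
     [0.0, 1.0, 0.0],
     [2.0, 2.0, 2.0],
     [1.0, 1.0, 1.0]],
    dtype=torch.float32,
)

unique_rows, inverse = torch.unique(W_toy, dim=0, return_inverse=True)

# Indexing the distinct rows with the inverse rebuilds the original matrix,
# so inverse[i] tells us which distinct row token i uses.
assert torch.equal(unique_rows[inverse], W_toy)

# Tokens that map to the same distinct row share an identical vector
shared = (inverse == inverse[1]).nonzero().flatten().tolist()
print(unique_rows.shape[0], shared)  # 4 [1, 3]
```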

In [4]:
```python
# Find unique vectors and which unique vector each token maps to
unique_vectors, inverse_indices = torch.unique(W, dim=0, return_inverse=True)

num_unique = unique_vectors.shape[0]
num_total = W.shape[0]
# Redundant rows: tokens beyond the first in each hole,
# i.e. (tokens in black holes) - (number of black holes)
num_duplicated = num_total - num_unique

print(f"Total tokens: {num_total:,}")
print(f"Unique vectors: {num_unique:,}")
print(f"Duplicated tokens: {num_duplicated:,}")
```

```
Total tokens: 151,936
Unique vectors: 149,849
Duplicated tokens: 2,087
```
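The 2,087 above counts redundant rows, not tokens in black holes: each hole still contributes one row to the unique vectors, so redundant rows equal tokens in holes minus the number of holes (2,100 minus 13 with the figures reported later in this notebook). A toy check of that identity on a hypothetical 7-token matrix:

```python
from collections import Counter

import torch

# Toy stand-in for W (hypothetical values): 7 tokens, 2 dimensions.
# Tokens 0, 2, 5 share one vector; tokens 1, 4 share another.
W_toy = torch.tensor(
    [[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [5.0, 6.0],
     [3.0, 4.0], [1.0, 2.0], [7.0, 8.0]]
)

unique_rows, inverse = torch.unique(W_toy, dim=0, return_inverse=True)
counts = Counter(inverse.tolist())

num_redundant = W_toy.shape[0] - unique_rows.shape[0]  # the notebook's "Duplicated tokens"
hole_sizes = [c for c in counts.values() if c > 1]
tokens_in_holes = sum(hole_sizes)
num_holes = len(hole_sizes)

print(num_redundant, tokens_in_holes, num_holes)  # 3 5 2
assert num_redundant == tokens_in_holes - num_holes
```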


In [5]:
```python
# Count how many tokens map to each unique vector
# A black hole is any unique vector with count > 1
counts = Counter(inverse_indices.tolist())

# Find which unique-vector indices are black holes (count > 1): {uid: count}
black_hole_unique_ids = {uid: count for uid, count in counts.items() if count > 1}

print(f"Number of black holes: {len(black_hole_unique_ids)}")
print("\nBlack hole populations (sorted):")
for uid, count in sorted(black_hole_unique_ids.items(), key=lambda x: -x[1]):
    print(f"  {count:,} tokens")
```

```
Number of black holes: 13

Black hole populations (sorted):
  814 tokens
  704 tokens
  306 tokens
  228 tokens
  11 tokens
  10 tokens
  6 tokens
  5 tokens
  4 tokens
  4 tokens
  3 tokens
  3 tokens
  2 tokens
```
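The `Counter` pass is fast enough at this vocabulary size. As a tensor-only alternative, `torch.bincount(inverse)` returns, at position u, how many tokens map to distinct row u. A sketch on a hypothetical toy `inverse`. One caveat: this orders tied holes by distinct-row index, whereas the `Counter` version keeps them in order of first appearance, so equal-sized holes (such as the two 4-token holes) could come out numbered differently from the notebook's.

```python
import torch

# Hypothetical inverse indices for 8 tokens over 5 distinct rows
inverse = torch.tensor([0, 1, 0, 2, 1, 0, 3, 4])

# counts[u] = number of tokens whose vector is distinct row u
counts = torch.bincount(inverse)

# Distinct rows shared by more than one token are black holes
hole_uids = (counts > 1).nonzero().flatten()
hole_sizes = counts[hole_uids]

# Order holes by population, largest first; stable keeps ties in uid order
order = torch.argsort(hole_sizes, descending=True, stable=True)
hole_uids, hole_sizes = hole_uids[order], hole_sizes[order]

print(hole_uids.tolist(), hole_sizes.tolist())  # [0, 1] [3, 2]
```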


## Build Masks and ID Vectors

- `black_hole_mask[i]` = True if token i is in any black hole
- `black_hole_id[i]` = 0 if not in a black hole, else 1 for the largest hole, 2 for the second largest, etc. Equal-sized holes keep the order in which their first member appears in the vocabulary.

In [6]:
```python
# Sort black holes by population (descending)
sorted_holes = sorted(black_hole_unique_ids.items(), key=lambda x: -x[1])

# Map from unique-vector-id to black-hole-id (1-indexed)
uid_to_bhid = {uid: bhid + 1 for bhid, (uid, _) in enumerate(sorted_holes)}

# Build the tensors
V = W.shape[0]
black_hole_mask = torch.zeros(V, dtype=torch.bool)
black_hole_id = torch.zeros(V, dtype=torch.uint8)  # uint8 holds up to 255 hole IDs

for token_id, unique_id in enumerate(inverse_indices.tolist()):
    if unique_id in uid_to_bhid:
        black_hole_mask[token_id] = True
        black_hole_id[token_id] = uid_to_bhid[unique_id]

print(f"Tokens in black holes: {black_hole_mask.sum().item():,}")
print(f"Max black hole ID: {black_hole_id.max().item()}")
```

```
Tokens in black holes: 2,100
Max black hole ID: 13
```
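The per-token Python loop above is fine at this scale; the same two tensors can also be built with a lookup table indexed by the inverse indices, replacing the loop with one gather. A sketch on toy data, where `hole_uids` is a hypothetical stand-in for the distinct-row ids in `sorted_holes` order; in the notebook the table would be sized `unique_vectors.shape[0]`.

```python
import torch

# Hypothetical inputs: inverse indices for 8 tokens, and the distinct-row ids
# of the black holes already sorted by population (largest first)
inverse = torch.tensor([0, 1, 0, 2, 1, 0, 3, 4])
hole_uids = torch.tensor([0, 1])

# Lookup table over distinct rows: 0 = not a hole, k = k-th largest hole
lut = torch.zeros(int(inverse.max()) + 1, dtype=torch.uint8)
lut[hole_uids] = torch.arange(1, len(hole_uids) + 1).to(torch.uint8)

black_hole_id = lut[inverse]        # one gather instead of a Python loop
black_hole_mask = black_hole_id > 0

print(black_hole_id.tolist())  # [1, 2, 1, 0, 2, 1, 0, 0]
```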


In [7]:
```python
# Verify: count tokens per black hole ID
print("Verification — tokens per black hole:")
for bhid in range(1, black_hole_id.max().item() + 1):
    count = (black_hole_id == bhid).sum().item()
    print(f"  Black hole {bhid}: {count:,} tokens")
```

```
Verification — tokens per black hole:
  Black hole 1: 814 tokens
  Black hole 2: 704 tokens
  Black hole 3: 306 tokens
  Black hole 4: 228 tokens
  Black hole 5: 11 tokens
  Black hole 6: 10 tokens
  Black hole 7: 6 tokens
  Black hole 8: 5 tokens
  Black hole 9: 4 tokens
  Black hole 10: 4 tokens
  Black hole 11: 3 tokens
  Black hole 12: 3 tokens
  Black hole 13: 2 tokens
```
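The loop above counts members per ID; `torch.bincount` gives all populations in one call, and a further check can confirm that every member of each hole really carries the same vector. Sketched on hypothetical toy tensors laid out like the notebook's `W` and `black_hole_id`:

```python
import torch

# Hypothetical toy data in the notebook's layout
W_toy = torch.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [5.0, 6.0], [3.0, 4.0]])
black_hole_id = torch.tensor([1, 2, 1, 0, 2], dtype=torch.uint8)

# Population per hole in one call; index 0 counts tokens outside any hole
pop = torch.bincount(black_hole_id.long())
print(pop.tolist())  # [1, 2, 2]

# Every member of a hole should equal the hole's first member exactly
for bhid in range(1, int(black_hole_id.max()) + 1):
    members = W_toy[black_hole_id == bhid]
    assert (members == members[0]).all(), f"hole {bhid} has differing rows"
```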


## Extract Black Hole Centroids

Each black hole has a single embedding vector shared by all its members, so its "centroid" is just that vector. We take it from each hole's first member and save one row per hole for easy access.

In [8]:
```python
# Get the centroid for each black hole (just grab the first member's embedding)
num_holes = len(sorted_holes)
D = W.shape[1]
black_hole_centroids = torch.zeros(num_holes, D, dtype=W.dtype)

for bhid in range(1, num_holes + 1):
    # Find first token with this black hole ID
    first_token = (black_hole_id == bhid).nonzero()[0].item()
    black_hole_centroids[bhid - 1] = W[first_token]

print(f"Black hole centroids shape: {black_hole_centroids.shape}")
```

```
Black hole centroids shape: torch.Size([13, 2560])
```
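Because every member of a hole shares the same row, the centroids could also be read straight out of the distinct rows returned by `torch.unique`, indexed by the holes' distinct-row ids in population order. A toy sketch; `hole_uids` here is a hypothetical stand-in for `[uid for uid, _ in sorted_holes]`.

```python
import torch

# Hypothetical toy matrix: tokens 0, 2, 5 share one row; tokens 1, 4 share another
W_toy = torch.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0],
                      [5.0, 6.0], [3.0, 4.0], [1.0, 2.0]])
unique_rows, inverse = torch.unique(W_toy, dim=0, return_inverse=True)

# Hypothetical stand-in for the holes' distinct-row ids, largest hole first
hole_uids = torch.tensor([inverse[0].item(), inverse[1].item()])

# One gather instead of a loop: row k is the shared vector of hole k + 1
centroids = unique_rows[hole_uids]

# Each centroid matches every member of its hole exactly
for k, uid in enumerate(hole_uids.tolist()):
    members = W_toy[inverse == uid]
    assert (members == centroids[k]).all()

print(centroids.tolist())  # [[1.0, 2.0], [3.0, 4.0]]
```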


## Save Artifacts

In [9]:
```python
output_path = Path("../tensors/qwen3_4b_instruct_2507.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)

tensors = {
    "W": W,
    "black_hole_mask": black_hole_mask,
    "black_hole_id": black_hole_id,
    "black_hole_centroids": black_hole_centroids,
}

save_file(tensors, output_path)

file_size_mb = output_path.stat().st_size / 1e6
print(f"Saved to: {output_path}")
print(f"File size: {file_size_mb:.1f} MB")
```

```
Saved to: ../tensors/qwen3_4b_instruct_2507.safetensors
File size: 778.3 MB
```


## Summary

In [10]:
```python
print("="*60)
print("EXTRACTION COMPLETE")
print("="*60)
print(f"Model: {MODEL_ID}")
print(f"Vocabulary: {V:,} tokens")
print(f"Embedding dim: {D:,}")
print(f"Black holes: {num_holes}")
print(f"Tokens in black holes: {black_hole_mask.sum().item():,}")
print(f"Output: {output_path}")
print("="*60)
```

```
============================================================
EXTRACTION COMPLETE
============================================================
Model: Qwen/Qwen3-4B-Instruct-2507
Vocabulary: 151,936 tokens
Embedding dim: 2,560
Black holes: 13
Tokens in black holes: 2,100
Output: ../tensors/qwen3_4b_instruct_2507.safetensors
============================================================
```
