In [167]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import numpy as np
import os
from tqdm import tqdm

In [2]:
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")

2025-01-28 02:17:09.296228: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-28 02:17:09.524663: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738030629.640539    2945 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738030629.673492    2945 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-28 02:17:09.890318: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [33]:
torch.cuda.empty_cache()

In [6]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [8]:
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 163290.84 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 321960.90 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 252030.92 examples/s]


In [9]:
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

Map: 100%|██████████| 4358/4358 [00:00<00:00, 4433.20 examples/s]
Map: 100%|██████████| 36718/36718 [00:05<00:00, 6471.21 examples/s]
Map: 100%|██████████| 3760/3760 [00:00<00:00, 7254.66 examples/s]


In [10]:
class WikiTextDataset(Dataset):
    def __init__(self, tokenized_dataset):
        self.input_ids = tokenized_dataset["input_ids"]
        self.attention_mask = tokenized_dataset["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(self.input_ids[idx], dtype=torch.long),
            "attention_mask": torch.tensor(self.attention_mask[idx], dtype=torch.long),
        }

# Create PyTorch datasets for each split
train_dataset = WikiTextDataset(tokenized_datasets["train"])
val_dataset = WikiTextDataset(tokenized_datasets["validation"])
test_dataset = WikiTextDataset(tokenized_datasets["test"])

In [113]:
model.zero_grad()

# Create a dummy input
dummy_text = "Say only one word: butterfly."
inputs = tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

# Perform inference and measure memory usage
with torch.no_grad():
    torch.cuda.reset_peak_memory_stats(device)
    outputs = model(**inputs)
    peak_memory = torch.cuda.max_memory_allocated(device) / 1024**3  # Convert to GB
    print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

Peak GPU memory usage: 30.00 GB


In [116]:
output_dir = "residual_stream_dataset"
os.makedirs(output_dir, exist_ok=True)

In [117]:
# Hook function to capture residual stream
residual_stream = None

def hook_fn(module, input, output):
    global residual_stream
    residual_stream = output

# Register the hook for the 16th transformer layer (index 15)
layer_to_hook = 15
hook_handle = model.model.layers[layer_to_hook].register_forward_hook(hook_fn)

In [168]:
# Function to process batches
def process_batches(data_loader, output_file):
    global residual_stream
    residual_vectors = []

    with tqdm(total=len(data_loader), desc="Processing batches") as pbar:
        for batch_idx, batch in enumerate(data_loader):
            # Move inputs to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
    
            # Run the model
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
            # Extract residual stream vectors
            if residual_stream is not None:
                # residual_stream: [batch_size, seq_len, 4096]
                res = residual_stream[0]
                for i in range(res.size(0)):  # Iterate over batch
                    valid_vectors = res[i][attention_mask[i].bool()]  # Remove padding
                    residual_vectors.append(valid_vectors.cpu().numpy())
            
            # Save in chunks to avoid memory issues
            if batch_idx % 10 == 0 and batch_idx > 0:
                save_vectors(residual_vectors, output_file, mode='a')  # Append mode
                residual_vectors.clear()  # Clear buffer
    
            # Update progress bar
            pbar.update(1)

    # Save remaining vectors
    save_vectors(residual_vectors, output_file, mode='a')

In [169]:
def save_vectors(vectors, output_file, mode='w'):
    if not vectors:
        return
    # Flatten nested list of NumPy arrays
    vectors = np.vstack(vectors)
    if mode == 'w':  # Write mode
        np.save(output_file, vectors)
    elif mode == 'a':  # Append mode
        with open(output_file, 'ab') as f:
            np.save(f, vectors)

In [178]:
batch_size = 5

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [179]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [180]:
model.zero_grad()

In [181]:
residual_stream = None
# Process train_loader
process_batches(train_loader, os.path.join(output_dir, "train_residuals.npy"))

Processing batches:   2%|▏         | 164/7344 [03:11<2:19:31,  1.17s/it]


KeyboardInterrupt: 

In [None]:
residual_stream = None
# Process val_loader
process_batches(val_loader, os.path.join(output_dir, "val_residuals.npy"))

In [None]:
residual_stream = None
# Process test_loader
process_batches(test_loader, os.path.join(output_dir, "test_residuals.npy"))

In [None]:
# Remove the hook after processing
hook_handle.remove()

In [182]:
data = np.load("residual_stream_dataset/train_residuals.npy", allow_pickle=True)

In [184]:
data.shape

(2036, 4096)

In [128]:
model.zero_grad()

# Example input prompt
prompt = "The quick brown fox"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")  # Move to GPU if available

# Initialize variables
generated_ids = inputs["input_ids"]  # Start with the tokenized input
max_new_tokens = 50  # Limit for generation
eos_token_id = tokenizer.eos_token_id  # EOS token ID
temperature = 0.7  # Set desired temperature

# Loop for generation
for _ in range(max_new_tokens):
    # Generate logits
    with torch.no_grad():
        outputs = model(input_ids=generated_ids)
    next_token_logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
    
    # Apply temperature scaling
    next_token_logits = next_token_logits / temperature
    
    # Convert logits to probabilities
    probabilities = torch.softmax(next_token_logits, dim=-1)
    
    # Sample the next token
    next_token_id = torch.multinomial(probabilities, num_samples=1)  # Sample from the distribution
    
    # Append the new token to the sequence
    generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
    
    # Check if EOS token is generated
    if eos_token_id is not None and next_token_id.item() == eos_token_id:
        print("EOS token generated. Stopping generation.")
        break

# Decode the full generated text
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")

Generated text: The quick brown fox jumps over the lazy dog. The lazy caterpillar goes to the green garden. The boy with the big backpack walks past the little girl with the red hat. The man with the blue shirt buys a new kite. The red balloon floats in the sky


In [36]:
residual_stream = None  # Variable to store the hidden state

def hook_fn(module, input, output):
    global residual_stream
    residual_stream = output

In [38]:
model.model.layers[15]

LlamaDecoderLayer(
  (self_attn): LlamaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
  (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
)

In [39]:
layer_to_hook = 15  # 16th layer
hook_handle = model.model.layers[layer_to_hook].register_forward_hook(hook_fn)

In [115]:
# Create a dummy input
dummy_texts = [str(x) for x in range(1024)]
inputs = tokenizer(dummy_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

# Perform inference and measure memory usage
with torch.no_grad():
    torch.cuda.reset_peak_memory_stats(device)
    outputs = model(**inputs)
    peak_memory = torch.cuda.max_memory_allocated(device) / 1024**3  # Convert to GB
    print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

Peak GPU memory usage: 32.27 GB


In [105]:
print(residual_stream[0].shape)  # Expected shape: (batch_size, seq_len, 4096)

torch.Size([2, 11, 4096])


In [106]:
print("Input IDs:", inputs["input_ids"])  # Tokenized input IDs, including padding tokens
print("Attention Mask:", inputs["attention_mask"])  # Mask where 1 indicates real tokens, 0 indicates padding
print("Pad Token ID:", tokenizer.pad_token_id)  # ID used for padding

Input IDs: tensor([[128000,     40,   1097,   1097,   1097,   1097,   1097,   1097,   1097,
           1097,     13],
        [128001, 128001, 128001, 128001, 128001, 128000,    791,  13180,    374,
           6437,     13]], device='cuda:0')
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]], device='cuda:0')
Pad Token ID: 128001


In [110]:
# Create a dummy input
dummy_texts = ["I.", "The sky is blue.", "3","3","3","3","3","3","3","3"]
inputs = tokenizer(dummy_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

# Perform inference and measure memory usage
with torch.no_grad():
    torch.cuda.reset_peak_memory_stats(device)
    outputs = model(**inputs)
    peak_memory = torch.cuda.max_memory_allocated(device) / 1024**3  # Convert to GB
    print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

Peak GPU memory usage: 30.28 GB


In [111]:
print(residual_stream[0].shape) 

torch.Size([10, 6, 4096])


In [112]:
print(torch.cuda.memory_summary(device))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 4         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  30719 MiB |  31011 MiB | 231310 MiB | 200591 MiB |
|       from large pool |  30698 MiB |  30989 MiB | 155812 MiB | 125114 MiB |
|       from small pool |     20 MiB |     28 MiB |  75498 MiB |  75477 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  30719 MiB |  31011 MiB | 231310 MiB | 200591 MiB |
|       from large pool |  30698 MiB |  30989 MiB | 155812 MiB | 125114 MiB |
|       from small pool |     20 MiB |     28 MiB |  75498 MiB |  75477 MiB |
|---------------------------------------------------------------