In [3]:
import os

# Restrict this process to physical GPU 1; must happen before torch initialises CUDA.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [4]:
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from rich import print as rprint
from typing import Dict, List
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList

# Fixed seed: the torch RNG drives the random token subsampling (`torch.rand_like`) during extraction.
# The trailing `;` suppresses the returned Generator's repr in the cell output.
torch.manual_seed(42);

<torch._C.Generator at 0x767008286990>

In [5]:
# Sanity check: fraction of the corpus's token mass covered by the `top_k` most frequent tokens.
# The counts file is written from a CUDA tensor, so map it to CPU to load on any machine.
token_count = torch.load("saved_models/openwebtext_token_freq.pt", map_location="cpu")

token_probs = token_count / token_count.sum()

top_k = 5000
probability_mass = token_probs.topk(top_k).values.sum()
probability_mass.item()



In [6]:
# Model whose final hidden states are extracted.
# Other checkpoints used previously: "EleutherAI/pythia-1b", "/assets/models/meta-llama-3.1-8b"
model_name = "meta-llama/Llama-2-7b-hf"

# Switches
load_model = True     # load the causal LM (required by the extraction cells below)
compute_freq = False  # True: recount token frequencies; False: load them from disk

# Extraction schedule -- all counts are in batches
batch_size = 32
total = 3000          # number of batches to process
save_interval = 100   # write one shard every this many batches
start = 0             # batch offset to resume from

In [7]:
# Tokenizer. A tokenizer has no device placement, so it takes no `device_map`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # No pad token defined: reuse the existing EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token

if load_model:
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 device_map="auto")
    # Keep the embedding table sized to the tokenizer's full length.
    model.resize_token_embeddings(len(tokenizer))
    model.eval()

In [8]:
# Every tensor below is moved to this device explicitly; a visible GPU is assumed.
device = torch.device(type="cuda")
print("Using device: ", device)

Using device:  cuda


In [10]:
# Decode the 500 most frequent token ids into their token strings.
top_token_ids = token_probs.topk(k=500).indices.tolist()
tokenizer.convert_ids_to_tokens(top_token_ids)

['<0x0A>',
 '▁the',
 ',',
 '.',
 '▁to',
 '▁of',
 '▁a',
 '▁and',
 '▁in',
 '▁',
 's',
 '0',
 '-',
 '1',
 '▁that',
 '’',
 '▁is',
 '▁for',
 '▁on',
 '2',
 "'",
 '▁with',
 '▁was',
 '<s>',
 '▁it',
 'ing',
 '▁as',
 '▁at',
 '▁be',
 '▁by',
 '▁I',
 'ed',
 '▁from',
 '▁are',
 ':',
 '▁has',
 '▁have',
 '▁The',
 '▁(',
 'The',
 '▁an',
 '5',
 '3',
 '▁this',
 '▁you',
 '▁his',
 '▁he',
 '▁“',
 '9',
 '4',
 '▁not',
 't',
 '▁said',
 '▁will',
 'ers',
 '6',
 ')',
 '▁or',
 '▁their',
 '▁who',
 '▁but',
 '7',
 '▁we',
 '▁they',
 '▁"',
 '▁been',
 '8',
 '▁about',
 '▁out',
 '▁one',
 '▁more',
 '▁can',
 'es',
 'y',
 '▁which',
 '"',
 '▁were',
 '▁up',
 '▁had',
 '▁all',
 '▁its',
 '▁new',
 '▁A',
 '▁over',
 '▁would',
 '▁after',
 'er',
 'ly',
 '▁first',
 'in',
 '▁when',
 'S',
 '▁people',
 '▁than',
 '▁T',
 '▁into',
 '?',
 '▁what',
 '▁time',
 '▁her',
 'ic',
 '”',
 '/',
 '▁some',
 '▁so',
 '▁also',
 'al',
 '▁two',
 '▁C',
 'm',
 'I',
 '▁like',
 'A',
 '▁just',
 '▁there',
 '▁your',
 '▁our',
 '▁It',
 '▁—',
 'day',
 '▁other',
 '▁year',

#### Dataset `[X, y]` with `X = h(t)`, the final hidden state at position `t`, and `y = token(t)`, the input token at the same position

#### Create the dataset

In [None]:
# Display the module tree; presumably used to look up the name of the output
# projection (e.g. `lm_head` vs `embed_out`) that the next cell replaces.
model

In [7]:
# Replace the final unembedding layer with an identity map, so `model(...).logits`
# returns the feature vector that would have been fed to it.
# `set_output_embeddings` writes to the architecture's own attribute (`lm_head` for
# Llama, `embed_out` for Pythia/GPT-NeoX) instead of attaching an unused extra
# submodule under the other architecture's name.
model.set_output_embeddings(torch.nn.Identity())
assert isinstance(model.get_output_embeddings(), torch.nn.Identity)

In [8]:
def filter_length(example):
    """Keep a document only if it tokenizes to at least 256 ids.

    Counts whatever special tokens ``tokenizer`` adds by default. This guarantees
    that ``encode_text``'s truncation to 256 yields equal-length rows.
    """
    token_ids = tokenizer(example["text"])["input_ids"]
    return not len(token_ids) < 256

def encode_text(example):
    """Tokenize a batch of documents, truncated to 256 ids, into ``example['input_ids']``.

    Intended for ``dataset.map(encode_text, batched=True)``, so ``example["text"]``
    is a list of strings. ``return_tensors="pt"`` produces one rectangular tensor
    without padding only because ``filter_length`` already dropped every document
    shorter than 256 tokens. The input mapping is updated in place and returned.
    """
    encoded = tokenizer(
        example["text"],
        return_tensors="pt",
        max_length=256,
        truncation=True,
    )
    example["input_ids"] = encoded["input_ids"]
    return example

In [9]:
# Stream OpenWebText, keep documents of >= 256 tokens, tokenize them to exactly
# 256 ids, and shuffle deterministically.
dataset = load_dataset("Skylion007/openwebtext",
                       split="train", streaming=True, trust_remote_code=True)
dataset = dataset.filter(filter_length)
dataset = dataset.map(encode_text, batched=True)
dataset = dataset.shuffle(seed=42)

# Resume support: skip the first `start` batches (presumably done by an earlier run).
subset = dataset.skip(start * batch_size)

# Iterate over `subset`, not `dataset`, so the resume offset actually takes effect;
# previously it was computed but ignored, and a resumed run re-extracted the first
# batches into later shard files. `dataset` stays un-skipped for the frequency count.
dataloader = DataLoader(subset, batch_size=batch_size)

In [10]:
num_count_batches = 1000  # batches of 512 documents used to estimate token frequencies
freq_path = "saved_models/openwebtext_token_freq.pt"

if compute_freq:
    # Count token occurrences over a prefix of the shuffled stream.
    count_dataloader = torch.utils.data.DataLoader(dataset, batch_size=512)
    # Accumulate in int64 so counts stay exact (float32 stops incrementing past 2**24),
    # then convert to float32 so the saved tensor keeps its original dtype.
    counts = torch.zeros(tokenizer.vocab_size, dtype=torch.long, device=device)
    with tqdm(total=num_count_batches) as pbar:
        for batch in count_dataloader:
            tokens = batch["input_ids"].to(device).flatten()
            counts.scatter_add_(0, tokens, torch.ones_like(tokens))
            pbar.update(1)
            if pbar.n >= num_count_batches:
                break
    token_count = counts.float()
    torch.save(token_count, freq_path)
else:
    token_count = torch.load(freq_path, map_location=device)

In [11]:
# Unigram probability of each vocabulary id, and the 80th-percentile probability:
# tokens at or above it are "frequent" and become candidates for random dropping.
token_probs = token_count.div(token_count.sum())
prob_threshold = token_probs.quantile(0.8).to(device)

In [None]:
model_suffix = model_name.split("/")[-1]

# Shards are written as data/X_dataset_<model>_<k>.pt (hidden states) and
# data/y_dataset_<model>_<k>.pt (token ids at the same positions).
save_dir = "data"
os.makedirs(save_dir, exist_ok=True)

x_save_path = os.path.join(save_dir, f"X_dataset_{model_suffix}")
y_save_path = os.path.join(save_dir, f"y_dataset_{model_suffix}")

# Frequent tokens (p >= prob_threshold) are dropped with probability alpha * p(token).
alpha = 1


def _save_shard(X_parts, y_parts, index):
    """Concatenate the accumulated feature/label chunks and save them as shard ``index``."""
    torch.save(torch.cat(X_parts, dim=0), f"{x_save_path}_{index}.pt")
    torch.save(torch.cat(y_parts, dim=0), f"{y_save_path}_{index}.pt")


# Accumulators for the shard currently being built (flushed every `save_interval` batches)
X_all = []
y_all = []

# Shard numbering continues from the resume offset (`start` and `save_interval` count batches)
file_index = start // save_interval

# The context manager closes the progress bar even if the loop raises (e.g. CUDA OOM).
with torch.no_grad(), tqdm(total=total) as pbar:
    for i, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        # The unembedding layer was replaced by an identity, so `.logits` holds the
        # final features, laid out (batch, pos, hdim).
        output = model(input_ids).logits

        outputs = einops.rearrange(
            output, "batch pos hdim -> (batch pos) hdim")
        inputs = einops.rearrange(input_ids, "batch pos -> (batch pos)")

        # Subsample: tokens rarer than the threshold are always kept; frequent tokens
        # are kept unless a uniform draw falls at or below alpha * p(token).
        token_probs_batch = token_probs[inputs]
        rand_vals = torch.rand_like(token_probs_batch)
        drop_prob = alpha * token_probs_batch
        mask = (token_probs_batch < prob_threshold) | (rand_vals > drop_prob)

        X_all.append(outputs[mask].cpu())
        y_all.append(inputs[mask].cpu())

        # Flush periodically so host memory does not grow without bound
        if (i + 1) % save_interval == 0 and X_all:
            _save_shard(X_all, y_all, file_index)
            file_index += 1
            X_all, y_all = [], []

        pbar.update(1)
        if i + 1 >= total:
            break

# Flush whatever remains after the last full interval
if X_all:
    _save_shard(X_all, y_all, file_index)

print(
    f"Saved hidden states to {x_save_path} and token inputs to {y_save_path}")