In [None]:
import os
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from rich import print as rprint
from typing import Dict, List
from datasets import load_dataset
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from sentence_transformers import SentenceTransformer

# Seed torch's global RNG so the torch.randperm-based sampling below is reproducible.
torch.manual_seed(42);

<torch._C.Generator at 0x7251e35f69f0>

In [None]:
# Run configuration. MODEL and TOTAL can be overridden through environment variables.
model_name = os.getenv("MODEL", "meta-llama/Llama-2-7b-hf")
total = int(os.getenv("TOTAL", "1000"))  # number of dataloader batches to draw
dataset_name = "Skylion007/openwebtext"

# Last path component of each hub id, used to name the output directory.
model_suffix = model_name.rpartition("/")[2]
dataset_suffix = dataset_name.rpartition("/")[2]

balance_interval = 25  # re-balance the accumulated prefixes every N batches
batch_size = 128       # texts per dataloader batch
max_seq_len = 256      # tokens kept per text; shorter texts are filtered out

In [None]:
# Only the tokenizer is needed to build prefixes. device_map is a model-placement
# option and has no meaning for a tokenizer, so it is not passed.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# pad_token_id is used later to right-pad prefixes; fall back to EOS when the
# tokenizer does not define a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def filter_length(example):
    """Keep only examples whose text tokenizes to at least ``max_seq_len`` ids.

    Meant for ``dataset.filter``. Every surviving text can then be truncated to
    exactly ``max_seq_len`` tokens, which the un-padded ``return_tensors="pt"``
    tokenization in the collection loop relies on.
    """
    encoded = tokenizer(example["text"])
    token_count = len(encoded["input_ids"])
    return token_count >= max_seq_len

In [None]:
# Stream the training split, keep texts that are at least max_seq_len tokens long,
# then take total * batch_size examples (i.e. `total` batches) after a seeded shuffle.
# NOTE(review): trust_remote_code lets the dataset repository run its own loading
# code; keep it enabled only for sources you trust.
dataset = (
    load_dataset(dataset_name, split="train", streaming=True, trust_remote_code=True)
    .filter(filter_length)
    .shuffle(seed=42)
    .take(total * batch_size)
)
dataloader = DataLoader(dataset, batch_size=batch_size)

In [None]:
def balance_data(prefixes: List[torch.Tensor], pad_token_id=None) -> List[torch.Tensor]:
    """Down-sample padded prefixes so their last tokens are more evenly represented.

    The rows of all tensors in ``prefixes`` are concatenated and grouped by their
    last non-padding token. With ``S`` rows and ``C`` distinct last tokens, at most
    ``max(1, S // C)`` rows are kept per group, chosen at random with torch's
    global RNG. The kept rows are returned in shuffled order.

    Args:
        prefixes: 2-D id tensors with the same number of columns, right-padded
            with ``pad_token_id``.
        pad_token_id: Padding id; defaults to ``tokenizer.pad_token_id``.

    Returns:
        A single-element list holding the balanced ``[rows, seq_len]`` tensor, so
        callers can keep appending to it and pass it back in. An empty list is
        returned when ``prefixes`` is empty.
    """
    if not prefixes:
        return []
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    prefixes_vec = torch.cat(prefixes, dim=0)
    num_rows, seq_len = prefixes_vec.size(0), prefixes_vec.size(1)
    if num_rows == 0:
        # Nothing to balance (and no classes to divide by).
        return [prefixes_vec]

    # Position of the last non-padding token in each row. Padding is on the
    # right, so flip the mask and take the first non-pad hit. Unlike counting
    # non-pad tokens, this stays correct when a real token inside the prefix
    # equals the pad id (pad_token may be set to eos_token).
    non_pad_mask = prefixes_vec != pad_token_id
    last_from_end = non_pad_mask.flip(dims=[1]).int().argmax(dim=1)
    last_token_indices = seq_len - 1 - last_from_end
    row_ids = torch.arange(num_rows, device=prefixes_vec.device)
    token_ids_vec = prefixes_vec[row_ids, last_token_indices]

    # Group row indices by class with a single sort (O(S log S)) instead of one
    # full scan per class: equal tokens form consecutive runs in sorted order.
    sorted_tokens, order = torch.sort(token_ids_vec)
    _, class_counts = torch.unique_consecutive(sorted_tokens, return_counts=True)
    samples_per_class = max(1, num_rows // class_counts.numel())

    selected_indices = []
    for class_indices in torch.split(order, class_counts.tolist()):
        # Take min(samples_per_class, class size) rows of this class at random.
        keep = torch.randperm(class_indices.numel(), device=class_indices.device)
        selected_indices.append(class_indices[keep[:samples_per_class]])

    selected = torch.cat(selected_indices)
    # Shuffle so the output is not ordered by class.
    selected = selected[torch.randperm(selected.numel(), device=selected.device)]

    return [prefixes_vec[selected]]

In [None]:
def _expand_to_prefixes(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Return every right-padded prefix of each row of ``input_ids``.

    For ``input_ids`` of shape ``[batch, seq_len]`` the result has shape
    ``[batch * seq_len, seq_len]``. Row ``b * seq_len + i`` holds
    ``input_ids[b, :i + 1]`` followed by ``pad_token_id`` up to ``seq_len``.
    """
    seq_len = input_ids.size(1)
    positions = torch.arange(seq_len, device=input_ids.device)
    # keep[i, j] is True when column j lies inside the prefix of length i + 1
    keep = positions[None, :] <= positions[:, None]
    pad = torch.full((), pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
    # [seq_len, seq_len] broadcast against [batch, 1, seq_len] -> [batch, seq_len, seq_len]
    prefixes = torch.where(keep, input_ids[:, None, :], pad)
    return prefixes.reshape(-1, seq_len)


batch_index = 0
all_prefixes = []

for batch in dataloader:
    text = batch["text"]
    # filter_length keeps only texts with >= max_seq_len tokens, so truncation
    # yields equal-length rows and no padding is needed here.
    input_ids = tokenizer(
        text, truncation=True, max_length=max_seq_len, return_tensors="pt")["input_ids"]

    # [batch_size, seq_len] -> [(batch_size * seq_len), seq_len]: all prefixes, right-padded
    padded_prefixes = _expand_to_prefixes(input_ids, tokenizer.pad_token_id)

    # Remove duplicate prefixes within this batch
    padded_prefixes = torch.unique(padded_prefixes, dim=0)

    all_prefixes.append(padded_prefixes)

    # Periodically collapse the accumulated batches into one balanced tensor
    # (this also fires on batch 0).
    if batch_index % balance_interval == 0:
        all_prefixes = balance_data(all_prefixes)
        print(
            f"Batch {batch_index} - Data balanced. Size: {all_prefixes[0].size(0)}")

    batch_index += 1

# Batches appended after the last periodic balance would otherwise be saved
# unbalanced, so balance once more if anything is left over.
if len(all_prefixes) > 1:
    all_prefixes = balance_data(all_prefixes)
    print(f"Final balance - Data balanced. Size: {all_prefixes[0].size(0)}")

if not all_prefixes:
    raise RuntimeError("No prefixes were collected: the dataloader yielded no batches.")

# Convert the list of tensors to a single tensor
all_prefixes = torch.cat(all_prefixes, dim=0)

In [None]:
# Persist the prefix tensor to data/<dataset>_<model>/prefixes.pt,
# creating the directory on first run.
prefixes_dir = f"data/{dataset_suffix}_{model_suffix}"
prefixes_path = f"{prefixes_dir}/prefixes.pt"

os.makedirs(prefixes_dir, exist_ok=True)
torch.save(all_prefixes, prefixes_path)