In [1]:
import os
# Export the CUDA-related environment variables before torch is imported in
# the next cell. NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging
# aid and CUDA_VISIBLE_DEVICES=0 normally restricts the process to the first
# GPU -- confirm both are still wanted here.
os.environ.update(
    CUDA_LAUNCH_BLOCKING="1",
    CUDA_VISIBLE_DEVICES="0",
)

In [2]:

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm, trange
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt

import numpy as np
import glob
# Fix torch's global RNG seed so the random undersampling (torch.randperm)
# further down is repeatable from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x7e117eb66590>

In [3]:
# Model identifier. Only the part after the last "/" is used (as
# `model_suffix`) to build the dataset file names.
model_name: str = "meta-llama/Llama-2-7b-hf"

# When True, the chunk loop re-filters and undersamples every raw chunk and
# rewrites the per-chunk balanced files on disk. When False, that work is
# skipped and the balanced files are expected to exist already (e.g. from an
# earlier run); they are loaded as-is by the combine step.
balance: bool = False

In [4]:
# Prefer the GPU when one is visible to torch, but fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA (all tensor work in
# the cells shown here is done on CPU-loaded tensors).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

Using device:  cuda


In [5]:
def _list_chunk_files(kind, suffix):
    """Return the sorted raw chunk files for one side of the dataset.

    Args:
        kind: "X" or "y"; selects ``data/{kind}_dataset_{suffix}_*.pt``.
        suffix: model suffix embedded in the file names.

    Returns:
        Lexicographically sorted list of matching paths, excluding any file
        whose chunk suffix starts with "top".

    The final dataset written at the end of this notebook is named
    ``data/{kind}_dataset_{suffix}_top{k}.pt`` and therefore matches the same
    glob pattern as the raw chunks. Without this filter, re-running the
    notebook would treat its own output as one more input chunk.
    """
    stem = f"{kind}_dataset_{suffix}_"
    return sorted(
        path
        for path in glob.glob(f"data/{stem}*.pt")
        if not os.path.basename(path)[len(stem):].startswith("top")
    )


model_suffix = model_name.split("/")[-1]
X_files = _list_chunk_files("X", model_suffix)
y_files = _list_chunk_files("y", model_suffix)

# The chunk loop pairs these lists positionally with zip(), which silently
# truncates; fail loudly instead of pairing features with the wrong labels.
assert len(X_files) == len(y_files), (
    f"Found {len(X_files)} X chunks but {len(y_files)} y chunks for {model_suffix}"
)

# presumably a 1-D tensor of per-token counts indexed by token id -- TODO confirm
token_count = torch.load("saved_models/openwebtext_token_freq.pt", map_location='cpu')


In [6]:
# Number of classes (token ids) to keep.
k = 1000

# Order token ids by descending value in `token_count` and keep the first k.
ranked_token_ids = torch.argsort(token_count, descending=True)
top_k_tokens = ranked_token_ids[:k]



In [7]:
def _balance_chunk(X, y, allowed_tokens):
    """Filter one chunk to the allowed labels and undersample it per class.

    Args:
        X: feature tensor whose first dimension indexes samples.
        y: label tensor aligned with ``X`` (assumed 1-D -- TODO confirm).
        allowed_tokens: tensor of label values to keep.

    Returns:
        ``(X_balanced, y_balanced)`` where every remaining class has exactly
        as many samples as the rarest remaining class in this chunk, or
        ``None`` when no sample in the chunk carries an allowed label.
    """
    # Vectorised membership test; replaces a Python-level `label in tensor`
    # comprehension that scanned `allowed_tokens` once per label.
    keep = torch.isin(y, allowed_tokens)
    X_kept = X[keep]
    y_kept = y[keep]

    if len(y_kept) == 0:
        return None

    present_classes, counts = torch.unique(y_kept, return_counts=True)
    # Undersampling target: size of the rarest class present in the chunk.
    min_count = int(counts.min())

    picked = []
    for cls in present_classes.tolist():
        cls_idx = (y_kept == cls).nonzero(as_tuple=True)[0]
        # Only classes larger than the target need a random subset; classes
        # already at the target are taken whole (and consume no RNG draws).
        if len(cls_idx) > min_count:
            cls_idx = cls_idx[torch.randperm(len(cls_idx))[:min_count]]
        picked.append(cls_idx)

    chosen = torch.cat(picked)
    return X_kept[chosen], y_kept[chosen]


balanced_chunks_X = []
balanced_chunks_y = []

for chunk_idx, (X_file, y_file) in enumerate(zip(X_files, y_files)):
    # Each balanced chunk is written to disk as soon as it is produced.
    balanced_X_path = f"data/X_balanced_{model_suffix}_{chunk_idx}.pt"
    balanced_y_path = f"data/y_balanced_{model_suffix}_{chunk_idx}.pt"

    if balance:
        X = torch.load(X_file, map_location='cpu')
        y = torch.load(y_file, map_location='cpu')

        balanced = _balance_chunk(X, y, top_k_tokens)
        if balanced is None:
            continue  # Skip empty chunks

        X_balanced, y_balanced = balanced
        torch.save(X_balanced, balanced_X_path)
        torch.save(y_balanced, balanced_y_path)
        print(
            f"Processed chunk {chunk_idx}: Saved {len(y_balanced)} balanced samples.")
    elif not (os.path.isfile(balanced_X_path) and os.path.isfile(balanced_y_path)):
        # Chunks skipped as empty during a balancing run never get a file, so
        # they must also be skipped when reusing cached results; otherwise the
        # combine step fails on a missing path.
        print(f"Skipping chunk {chunk_idx}: no cached balanced files on disk.")
        continue

    balanced_chunks_X.append(balanced_X_path)
    balanced_chunks_y.append(balanced_y_path)

In [8]:
# Load every per-chunk balanced file and stack the chunks along the sample
# dimension to obtain the final feature and label tensors.
X_balanced_all = [torch.load(chunk_path) for chunk_path in balanced_chunks_X]
y_balanced_all = [torch.load(chunk_path) for chunk_path in balanced_chunks_y]

X_final = torch.cat(X_balanced_all, dim=0)
y_final = torch.cat(y_balanced_all, dim=0)

In [9]:
# Summarise how many samples each class has in the combined dataset.
classes, class_counts = y_final.unique(return_counts=True)

# Convert once to float so mean/std can be computed on the integer counts.
counts_as_float = class_counts.float()
mean_class_count = counts_as_float.mean().item()
std_class_count = counts_as_float.std().item()

print(f"Mean class count: {mean_class_count}")
print(f"Std class count: {std_class_count}")

Final class distribution: tensor([2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362, 2362,
        2362, 

In [10]:
# Persist the combined features and labels, one file each, tagged with the
# model suffix and the number of classes kept.
final_outputs = {
    f"data/X_dataset_{model_suffix}_top{k}.pt": X_final,
    f"data/y_dataset_{model_suffix}_top{k}.pt": y_final,
}
for out_path, tensor in final_outputs.items():
    torch.save(tensor, out_path)

: 