In [None]:
import pickle

import pandas as pd

# Root directory of the axbench run whose artifacts are inspected below.
# Kept in one place so switching runs (or machines) is a single edit.
RUN_DIR = "/Users/sidbaskaran/Desktop/research/HyperDAS/axbench/axbench/concept10/prod_2b_l10_v1"

# Load the run's parquet artifacts into pandas dataframes
df_latent_inf = pd.read_parquet(f"{RUN_DIR}/inference/latent_eval_data.parquet")
df_generate_train_data = pd.read_parquet(f"{RUN_DIR}/generate/train_data.parquet")

# Load pickle file containing inference state.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at pickles produced by a trusted run.
with open(f"{RUN_DIR}/inference/latent_inference_state.pkl", "rb") as f:
    inference_state = pickle.load(f)

In [3]:
import pandas as pd
# Read the latent-data parquet written next to the final checkpoint and list
# the columns it contains.
latent_data_path = "/workspace/HyperDAS/assets/checkpoints/gemma2b_hyperlsreft_concept16k_steer_20250304_233644/final_model/inference/latent_data.parquet"
latent_df = pd.read_parquet(latent_data_path)

print("Columns in latent_data.parquet:")
print(latent_df.columns)


Columns in latent_data.parquet:
Index(['input', 'output', 'output_concept', 'concept_genre', 'category',
       'dataset_category', 'concept_id', 'sae_link', 'sae_id',
       'LsReFT_max_act', 'LsReFT_detection_scores', 'tokens'],
      dtype='object')


In [None]:
from datasets import load_dataset

# Load the "pyvene/axbench-concept500" dataset through `datasets.load_dataset`.
# Note: a later cell rebinds `ds` to a dataset loaded from local disk.
ds = load_dataset("pyvene/axbench-concept500")

In [None]:
from transformers import AutoModelForCausalLM
import torch

# Instantiate the pretrained causal LM with bfloat16 weights and print its
# module tree for inspection.
model_name = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

print(model)

In [4]:
import torch

# Other checkpoints that were inspected with this cell (swap into `ckpt_path`):
#   /workspace/HyperDAS/axbench/axbench/results/prod_2b_l10_concept500_lsreft/train/LsReFT_weight.pt
#   /workspace/HyperDAS/axbench/axbench/results/prod_2b_l10_concept500_lsreft/train/LsReFT_bias.pt
ckpt_path = "/workspace/HyperDAS/axbench/axbench/results/prod_2b_l20_concept16k_lsreft/train/rank_0_LsReFT_weight.pt"

# map_location="cpu" lets the file be opened on a machine without the GPU it was
# saved from; weights_only=True restricts unpickling to tensor data.
# NOTE(review): assumes the file holds a plain tensor/parameter — confirm.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)

In [None]:
# Display the shape of the object loaded into `ckpt`
# (assumes it exposes `.shape`, e.g. a tensor — TODO confirm).
ckpt.shape

In [2]:
import pandas as pd

# Read the cached steering-data parquet for the test_concept10 run.
steering_cache_path = "/workspace/HyperDAS/assets/data/axbench/test_concept10/steering_data_cache_ac76e41206e0bfaa52c8a10335957ac418a3cd64147187b381a69e5673074e9a.parquet"
cache = pd.read_parquet(steering_cache_path)

In [None]:
# Read the steering results and bucket the rows by concept.
df = pd.read_parquet(
    "/workspace/HyperDAS/assets/data/axbench/test_concept10/inference/steering_data.parquet"
)

# concept_id -> list of per-concept sub-frames; the row count of every group is
# printed as it is collected.
concept_data = {}
for concept_id, group in df.groupby("concept_id"):
    print(len(group))
    concept_data.setdefault(concept_id, []).append(group)

# Summarise the bucket sizes in ascending concept_id order.
for concept_id in sorted(concept_data):
    print([len(x) for x in concept_data[concept_id]])

In [None]:
# Print the first HyperReFT / PromptSteering generations next to each other.
# The count is capped at the number of rows so that a frame with fewer than
# 10 rows no longer raises IndexError from `.iloc`.
num_examples = min(10, len(df))
hyper_generations = df["HyperReFT_steered_generation"]
prompt_generations = df["PromptSteering_steered_generation"]

print("Side by side comparison of HyperReFT and PromptSteering generations:\n")
for i in range(num_examples):
    print(f"\nExample {i + 1}:")
    print(f"HyperReFT: {hyper_generations.iloc[i]}")
    print(f"PromptSteering: {prompt_generations.iloc[i]}")
    print("-" * 80)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_from_disk
from transformers import AutoTokenizer

# NOTE(review): the tokenizer is Meta-Llama-3-8B while the dataset directory
# name mentions "2b" — confirm this is the tokenizer the lengths should use.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

ds = load_from_disk("./data/concept500_2b_l10")

# Token count of every training input, as produced by `tokenizer.encode`
sequence_lengths = [len(tokenizer.encode(text)) for text in ds["train"]["input"]]

# Draw the histogram with matplotlib
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(sequence_lengths, bins=50, edgecolor="black")
ax.set_title("Distribution of Input Sequence Lengths")
ax.set_xlabel("Sequence Length")
ax.set_ylabel("Count")
plt.show()

# Summary statistics of the token counts
print(f"Mean sequence length: {np.mean(sequence_lengths):.2f}")
print(f"Median sequence length: {np.median(sequence_lengths):.2f}")
print(f"Max sequence length: {max(sequence_lengths)}")
print(f"Min sequence length: {min(sequence_lengths)}")

In [None]:
# Display `ds`. In file order this is the dataset loaded with `load_from_disk`
# in the preceding cell; an earlier cell also binds `ds`, so the value shown
# depends on which cell ran last.
ds

In [None]:
import gc
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass

import psutil
import torch
import torch.nn as nn


@contextmanager
def init_on_device(device):
    """Temporarily make `device` the torch default device.

    Inside the `with` block, tensors created without an explicit device are
    allocated on `device`. The default that was active on entry is restored on
    exit, including when the block raises.
    """
    # Probe a freshly created tensor to learn the current default device.
    previous_default = torch.empty(1).device
    torch.set_default_device(device)
    try:
        yield
    finally:
        torch.set_default_device(previous_default)


@contextmanager
def timer(description: str):
    """Time a code block, print the duration, and expose it to the caller.

    Args:
        description: Label printed in front of the measured duration.

    Yields:
        An object whose `elapsed` attribute holds the wall-clock duration of
        the block in seconds. It is 0.0 while the block is still running and
        is filled in when the block exits, so read it after the `with`.

    The duration is recorded and printed even if the block raises.
    """
    from types import SimpleNamespace

    # Previously nothing was yielded, so `with timer(...) as t` bound None and
    # callers could never read the measured time.
    result = SimpleNamespace(elapsed=0.0)
    start = time.perf_counter()
    try:
        yield result
    finally:
        result.elapsed = time.perf_counter() - start
        print(f"{description}: {result.elapsed:.4f} seconds")


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in GB.

    The value is the RSS in bytes divided by 1024**3, i.e. binary gigabytes.
    """
    bytes_per_gb = 1024**3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_gb


@dataclass
class LargeModelConfig:
    """Size hyperparameters for the synthetic benchmark module."""

    # Input/output width of each block and embedding size of its attention
    hidden_size: int = 4096
    # Number of stacked blocks
    num_layers: int = 32
    # Number of heads in each block's attention module
    num_attention_heads: int = 32
    # Width of the expanded projection inside each block
    intermediate_size: int = 11008


class LargeModule(nn.Module):
    """Stack of parameter-heavy blocks used to benchmark module construction.

    The class defines no `forward`; only its constructor is exercised.
    """

    def __init__(self, config: LargeModelConfig):
        super().__init__()

        # Build one block per layer. Sub-modules are created in the same order
        # within each block: expand, norm, contract, norm, attention.
        blocks = []
        for _ in range(config.num_layers):
            expand = nn.Linear(config.hidden_size, config.intermediate_size)
            expand_norm = nn.LayerNorm(config.intermediate_size)
            contract = nn.Linear(config.intermediate_size, config.hidden_size)
            contract_norm = nn.LayerNorm(config.hidden_size)
            attention = nn.MultiheadAttention(
                config.hidden_size, config.num_attention_heads, batch_first=True
            )
            blocks.append(
                nn.Sequential(expand, expand_norm, contract, contract_norm, attention)
            )

        self.layers = nn.ModuleList(blocks)


# Benchmark presets as (hidden_size, num_layers): small, medium, large.
# The remaining LargeModelConfig fields keep their defaults.
_PRESET_SHAPES = [(1024, 8), (2048, 16), (4096, 32)]
configs = [
    LargeModelConfig(hidden_size=hidden, num_layers=depth)
    for hidden, depth in _PRESET_SHAPES
]


def test_config(config: LargeModelConfig, device: str = "cuda"):
    """Benchmark two ways of getting a `LargeModule` onto `device`.

    Args:
        config: Sizes of the module to build.
        device: Target device passed to `.to(...)` / `init_on_device`.

    Returns:
        A dict with keys "regular" (build, then `.to(device)`) and "context"
        (build inside `init_on_device(device)`). Each value is a dict with
        "time" and "memory". "time" is the `elapsed` attribute of the object
        yielded by `timer`, or 0 when that object has no such attribute.
        "memory" is the difference in `get_memory_usage()` across the build.
    """

    def release_memory():
        gc.collect()
        torch.cuda.empty_cache()

    def build_then_move():
        return LargeModule(config).to(device)

    def build_in_context():
        with init_on_device(device):
            return LargeModule(config)

    def measure(label, build):
        # Sample memory on both sides of the timed build, then drop the model
        # so the next measurement starts from a clean state.
        mem_before = get_memory_usage()
        with timer(label) as t:
            model = build()
        mem_after = get_memory_usage()
        stats = {
            "time": getattr(t, "elapsed", 0),
            "memory": mem_after - mem_before,
        }
        del model
        release_memory()
        return stats

    release_memory()
    return {
        "regular": measure("Regular init", build_then_move),
        "context": measure("Device context init", build_in_context),
    }


# Run the benchmark for one preset
config = configs[1]  # Try the medium config

# Fall back to CPU when CUDA is unavailable; relying on test_config's "cuda"
# default would otherwise raise on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(
    f"Testing config with hidden_size={config.hidden_size}, layers={config.num_layers}"
)
print(f"Device: {device}")
results = test_config(config, device=device)

# Print results: one line per initialization strategy
print("\nResults:")
for label, key in (("Regular", "regular"), ("Context", "context")):
    stats = results[key]
    print(f"{label} init: {stats['time']:.4f}s, {stats['memory']:.2f}GB")

In [None]:
import pandas as pd

steering_data_path = "../../assets/data/axbench/inference/steering_data.parquet"
df = pd.read_parquet(steering_data_path)

# Print column names before renaming
print("Original columns:", df.columns.tolist())

column_renames = {
    "HyperReFT_steered_perplexity": "HyperReFT_perplexity",
    "PromptSteering_steered_perplexity": "PromptSteering_perplexity",
}

# Actually check which source columns exist: DataFrame.rename silently ignores
# missing keys, so without this the file was rewritten even when nothing changed.
applicable_renames = {
    old: new for old, new in column_renames.items() if old in df.columns
}

if applicable_renames:
    df = df.rename(columns=applicable_renames)
    # Overwrites the input file in place with the renamed columns.
    df.to_parquet(steering_data_path)
    print("Renamed columns:", applicable_renames)
else:
    print("No columns to rename; file left untouched.")

In [25]:
# Reload the steering data from the same path the preceding cell writes to.
# Relies on `pd` having been imported by an earlier cell.
df = pd.read_parquet("../../assets/data/axbench/inference/steering_data.parquet")

In [None]:
# Display the steering dataframe bound to `df`.
df

In [11]:
import json

# Parse the JSON file into `data`. The encoding is given explicitly so the
# result does not depend on the platform's default text encoding.
with open(
    "/Users/sidbaskaran/Desktop/research/HyperDAS/assets/data/gemma-2-2b_10-gemmascope-res-16k.json",
    "r",
    encoding="utf-8",
) as f:
    data = json.load(f)

In [None]:
# Display the parsed JSON object bound to `data`.
# NOTE(review): this prints the whole structure; it may be very large.
data

In [None]:
# Display the unpickled inference state loaded in the first cell.
inference_state

In [None]:
# Display the training-data dataframe loaded in the first cell.
df_generate_train_data