1 change: 1 addition & 0 deletions .gitignore
@@ -206,6 +206,7 @@ vis/
clt_test_pythia_70m_jumprelu/
clt_smoke_output_local_wandb_batchtopk/
clt_smoke_output_remote_wandb/
wandb/

# models
*.pt
45 changes: 23 additions & 22 deletions clt/activation_generation/generator.py
@@ -14,14 +14,13 @@
from __future__ import annotations

import os
import time
import json
import queue
import random
import logging
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from typing import Dict, List, Optional, Tuple, Any, DefaultDict
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
@@ -36,7 +35,7 @@
from clt.config.data_config import ActivationConfig # noqa: E402

# --- Profiling Imports ---
import time # Already imported, but good to note
import time # Keep this one
from contextlib import contextmanager
from collections import defaultdict
import psutil
@@ -56,8 +55,8 @@
# --- Performance Profiler Class ---
class PerformanceProfiler:
def __init__(self, chunk_tokens_threshold: int = 1_000_000):
self.timings = defaultdict(list)
self.memory_snapshots = []
self.timings: DefaultDict[str, List[float]] = defaultdict(list)
self.memory_snapshots: List[Dict[str, Any]] = []
self.chunk_tokens_threshold = chunk_tokens_threshold
self.system_metrics_log: List[Dict[str, Any]] = []
self.layer_ids_ref: Optional[List[int]] = None
@@ -141,7 +140,7 @@ def log_system_metrics(self, interval_name: str = "interval"):
return metrics

def report(self):
print("\n=== Performance Report ===")
logger.info("\n=== Performance Report ===")
# Sort by total time descending for timings
sorted_timings = sorted(self.timings.items(), key=lambda item: sum(item[1]), reverse=True)

@@ -153,15 +152,17 @@ def report(self):
min_time = min(times)
max_time = max(times)

print(f"\n--- Operation: {name} ---")
print(f" Count: {len(times)}")
print(f" Total time: {total_time:.3f}s")
print(f" Avg time: {avg_time:.4f}s")
print(f" Min time: {min_time:.4f}s")
print(f" Max time: {max_time:.4f}s")
logger.info(f"\n--- Operation: {name} ---")
logger.info(f" Count: {len(times)}")
logger.info(f" Total time: {total_time:.3f}s")
logger.info(f" Avg time: {avg_time:.4f}s")
logger.info(f" Min time: {min_time:.4f}s")
logger.info(f" Max time: {max_time:.4f}s")

if "chunk_write_total_idx" in name: # New unique name per chunk
print(f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}")
logger.info(
f" Avg ms/k-tok (for this chunk): {avg_time / self.chunk_tokens_threshold * 1000 * 1000:.2f}"
)
elif (
name == "batch_processing_total"
and self.batch_processing_total_calls > 0
@@ -171,29 +172,29 @@ def report(self):
self.total_tokens_processed_for_batch_profiling / self.batch_processing_total_calls
)
if avg_tok_per_batch_call > 0:
print(
logger.info(
f" Avg ms/k-tok (estimated for batch_processing_total): {avg_time / avg_tok_per_batch_call * 1000 * 1000:.2f}"
)

print("\n=== Memory Snapshots (showing top 10 by RSS delta) ===")
logger.info("\n=== Memory Snapshots (showing top 10 by RSS delta) ===")
interesting_mem_snapshots = sorted(
self.memory_snapshots, key=lambda x: abs(x["rss_delta_bytes"]), reverse=True
)[:10]
for snap in interesting_mem_snapshots:
print(
logger.info(
f" {snap['name']} (took {snap['duration_s']:.3f}s): Total RSS {snap['rss_total_bytes'] / (1024**3):.3f} GB (ΔRSS {snap['rss_delta_bytes'] / (1024**3):.3f} GB)"
)

print("\n=== System Metrics Log (sample) ===")
logger.info("\n=== System Metrics Log (sample) ===")
for i, metrics in enumerate(self.system_metrics_log[:5]): # Print first 5 samples
print(
logger.info(
f" Sample {i} ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)"
)
if len(self.system_metrics_log) > 5:
print(" ...")
logger.info(" ...")
if self.system_metrics_log: # Check if not empty before accessing last element
metrics = self.system_metrics_log[-1]
print(
logger.info(
f" Sample End ({metrics['interval_name']}): CPU {metrics['cpu_percent']:.1f}%, Mem {metrics['memory_percent']:.1f}%, GPU {metrics['gpu_util_percent']:.1f}% (Mem {metrics['gpu_memory_percent']:.1f}%)"
)
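
For context on how the `timings` mapping summarized by `report()` gets populated: below is a minimal sketch of a timing context manager in the style the imports above suggest (`contextlib.contextmanager` plus a `defaultdict` of float lists). The helper name `time_block` is hypothetical; the PR's actual helper is outside the visible hunks.

import time
from contextlib import contextmanager

@contextmanager
def time_block(profiler: "PerformanceProfiler", name: str):
    # Append the wall-clock duration of the wrapped block to profiler.timings[name],
    # which report() later aggregates into count/total/avg/min/max.
    start = time.perf_counter()
    try:
        yield
    finally:
        profiler.timings[name].append(time.perf_counter() - start)

# Usage (illustrative): with time_block(profiler, "chunk_write_total_idx_0"): write_chunk(...)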

@@ -286,7 +287,7 @@ def _async_uploader(upload_q: "queue.Queue[Optional[Path]]", cfg: ActivationConf
# --> ADDED: Retry Loop <--
for attempt in range(max_retries_per_chunk):
try:
print(
logger.info(
f"[Uploader Thread Attempt {attempt + 1}/{max_retries_per_chunk}] Uploading chunk: {p.name} to {url}"
)
with open(p, "rb") as f:
Expand Down Expand Up @@ -970,7 +971,7 @@ def _upload_binary_file(self, path: Path, endpoint: str):
try:
activation_config_instance = ActivationConfig(**loaded_config)
except TypeError as e:
print(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
logger.error(f"Error creating ActivationConfig from YAML. Ensure all keys are correct: {e}")
import sys

sys.exit(1)
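
The "--> ADDED: Retry Loop <--" comment above bounds each chunk upload by `max_retries_per_chunk`. The following is a minimal sketch of that bounded-retry pattern, assuming the `requests` library and a simple exponential backoff; the real uploader's request call, headers, and error handling are not shown in the visible hunk.

import time
import logging
from pathlib import Path

import requests

logger = logging.getLogger(__name__)

def upload_with_retries(p: Path, url: str, max_retries_per_chunk: int = 3) -> bool:
    # Illustrative only: retry a single chunk upload a bounded number of times.
    for attempt in range(max_retries_per_chunk):
        try:
            logger.info(
                f"[Uploader Thread Attempt {attempt + 1}/{max_retries_per_chunk}] Uploading chunk: {p.name} to {url}"
            )
            with open(p, "rb") as f:
                resp = requests.post(url, data=f, timeout=600)
            resp.raise_for_status()
            return True
        except Exception as exc:
            logger.warning(f"Upload of {p.name} failed on attempt {attempt + 1}: {exc}")
            time.sleep(2**attempt)  # simple exponential backoff before retrying
    return False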
21 changes: 20 additions & 1 deletion clt/config/data_config.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass, field
from typing import Literal, Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)


@dataclass
@@ -74,7 +77,23 @@ def __post_init__(self):
except ImportError:
raise ImportError("h5py is required for HDF5 output format. Install with: pip install h5py")
if self.compression not in ["lz4", "gzip", None, False]:
print(
logger.warning(
f"Warning: Unsupported compression '{self.compression}'. Will attempt without compression for {self.output_format}."
)
# Allow generator to handle disabling if format doesn't support it.

# Log a summary of key configuration values.
# This is for user feedback rather than programmatic use.
logger.info(
"ActivationConfig Summary:\n"
f" Model: {self.model_name}\n"
f" Dataset: {self.dataset_path} (Split: {self.dataset_split})\n"
f" Target Tokens: {self.target_total_tokens}\n"
f" Chunk Threshold: {self.chunk_token_threshold}\n"
f" Activation Dtype: {self.activation_dtype}\n"
f" Output Dir: {self.activation_dir}"
)
if self.remote_server_url:
logger.info(f" Remote Server URL: {self.remote_server_url}")
if self.delete_after_upload:
logger.info(" Delete after upload: Enabled")
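
To illustrate the summary log added above, a minimal construction might look like the sketch below. The field names mirror those used in the summary; the concrete values and any other required fields of `ActivationConfig` are placeholders, not taken from this diff.

cfg = ActivationConfig(
    model_name="EleutherAI/pythia-70m",   # placeholder model
    dataset_path="NeelNanda/pile-10k",    # placeholder dataset
    dataset_split="train",
    target_total_tokens=1_000_000,
    chunk_token_threshold=1_000_000,
    activation_dtype="float16",
    activation_dir="./activations",
)
# __post_init__ validates the settings and emits the "ActivationConfig Summary:" block via logger.info.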
27 changes: 27 additions & 0 deletions clt/models/clt.py
@@ -275,3 +275,30 @@ def convert_to_jumprelu_inplace(self, default_theta_value: float = 1e6) -> None:
logger.info(
f"Rank {self.rank}: CLT model config updated by ThetaManager. New activation_fn='{self.config.activation_fn}'."
)

# --- Back-compat: expose ThetaManager.log_threshold at model level ---
@property
def log_threshold(self) -> Optional[torch.nn.Parameter]:
"""Proxy to ``theta_manager.log_threshold`` for backward compatibility.

Older training scripts, conversion utilities, and tests referenced
``model.log_threshold`` directly. After the Step-5 refactor the
parameter migrated into the dedicated ``ThetaManager`` module. This
property always returns the *current* parameter held by
``self.theta_manager``, so modifying the returned tensor (e.g.
in-place updates to ``.data``) continues to work as before.
Assigning a brand-new ``nn.Parameter`` to ``model.log_threshold``
forwards the assignment to ``theta_manager`` to preserve the linkage.
"""
if hasattr(self, "theta_manager") and self.theta_manager is not None:
return self.theta_manager.log_threshold
return None

@log_threshold.setter
def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None:
# Keep property writable so callers that used to assign a fresh
# parameter (rare) do not break. We delegate the storage to
# ``ThetaManager`` so there is a single source of truth.
if not hasattr(self, "theta_manager") or self.theta_manager is None:
raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.")
self.theta_manager.log_threshold = new_param
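
A brief usage sketch of the back-compat property described in the docstring. The attribute names follow this diff; the `model` instance and its initialisation are assumed.

import torch

theta = model.log_threshold  # proxies model.theta_manager.log_threshold
if theta is not None:
    theta.data.fill_(0.0)  # in-place update remains visible to ThetaManager
    # Assigning a fresh parameter is forwarded to ThetaManager (single source of truth)
    model.log_threshold = torch.nn.Parameter(torch.zeros_like(theta))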