In [None]:
import glob
import logging
import os

import numpy as np
from datasets import Array2D, Features, Sequence, Value, concatenate_datasets, load_from_disk
from huggingface_hub import HfApi
from tqdm.auto import tqdm  # Use tqdm.auto for notebook compatibility

In [None]:
# Configure logging for clarity
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
LOCAL_CHUNKS_DIR = "temp_processed"  # Directory where generate.py saved its `rank_*` chunks
HF_REPO_ID = "mksethi/gpt2_eli5_sae_features"  # "<hf_username>/<dataset_name>" to push to
# Example: "yourusername/kilt_eli5_sae_features_gpt2_res_jb"

# Number of tokens per example; must match CONTEXT_SIZE used in generate.py.
CONTEXT_SIZE = 256
# For gpt2-small-res-jb, the resid_pre dim is 768.
# The SAE latent dimension (d_sae) for this specific SAE is 24576.
SAE_FEATURES_DIM = 24576

# Define the expected features. This is crucial for consistency.
# Ensure these match the dtypes and shapes you saved in generate.py.
# Array2D only accepts 2-D shapes, so fixed-length 1-D columns are declared with
# Sequence(..., length=N) instead; a 1-D Array2D fails once converted to an Arrow type
# (e.g. in Dataset.cast).
expected_features = Features({
    "input_ids": Sequence(Value("int32"), length=CONTEXT_SIZE),
    "attention_mask": Sequence(Value("int32"), length=CONTEXT_SIZE),
    "sae_features": Sequence(Value("float32"), length=SAE_FEATURES_DIM),
})

In [None]:

def _discover_chunk_dirs(local_dir):
    """Return the sorted `rank_*` sub-directories of `local_dir` (matching plain files are ignored)."""
    candidates = sorted(glob.glob(os.path.join(local_dir, "rank_*")))
    return [path for path in candidates if os.path.isdir(path)]


def _load_chunks(chunk_paths, features, cast_to_features):
    """Load every chunk with `load_from_disk`, optionally casting each one to `features`.

    Returns:
        A tuple ``(loaded_datasets, failed_paths)``. A chunk that fails to load (or cast)
        is logged and skipped instead of aborting the whole run.
    """
    loaded, failed = [], []
    for path in tqdm(chunk_paths, desc="Loading chunks"):
        try:
            chunk = load_from_disk(path)
            if cast_to_features:
                # Align schemas up front so concatenate_datasets does not trip on dtype/shape drift.
                chunk = chunk.cast(features)
            loaded.append(chunk)
            logger.debug(f"Successfully loaded chunk from {path} with {len(chunk)} examples.")
        except Exception as e:
            logger.error(f"Error loading chunk from {path}: {e}. Skipping this chunk.", exc_info=True)
            failed.append(path)
    return loaded, failed


def aggregate_and_push_chunks(local_dir, repo_id, features, private=False, cast_to_features=False):
    """
    Aggregate `rank_*` dataset chunks from a local directory and push them to the Hugging Face Hub.

    Args:
        local_dir: Directory containing the chunk directories written with `save_to_disk`.
        repo_id: Target dataset repository, e.g. "username/dataset_name".
        features: Expected `datasets.Features` schema; only applied when `cast_to_features=True`.
        private: Whether the Hub repository is created/pushed as private.
        cast_to_features: If True, cast every chunk to `features` before concatenating.

    Returns:
        None. Every outcome, including failures, is reported through `logger`; errors are
        logged rather than raised.

    The Hub token is read from the HF_TOKEN environment variable. When it is unset, `None`
    is passed through and huggingface_hub uses its default credential lookup (e.g. a token
    saved by `huggingface-cli login`).
    """
    logger.info(f"Starting aggregation process for chunks in: {local_dir}")

    # 1. Discover and load all chunk directories
    chunk_paths = _discover_chunk_dirs(local_dir)
    if not chunk_paths:
        logger.warning(f"No chunks found in {local_dir}. Please ensure generate.py has run and created files.")
        return
    logger.info(f"Found {len(chunk_paths)} potential chunk directories.")

    all_datasets, failed_paths = _load_chunks(chunk_paths, features, cast_to_features)
    if failed_paths:
        # Skipped chunks mean the pushed dataset is incomplete, so say so loudly.
        logger.warning(
            f"{len(failed_paths)} of {len(chunk_paths)} chunks failed to load and will be "
            f"missing from the pushed dataset: {failed_paths}"
        )
    if not all_datasets:
        logger.error("No valid datasets were loaded. Aborting push to Hugging Face.")
        return

    # 2. Concatenate all datasets
    logger.info(f"Concatenating {len(all_datasets)} datasets. This may take a while for large datasets...")
    try:
        final_dataset = concatenate_datasets(all_datasets)
    except Exception as e:
        logger.error(f"Error concatenating datasets: {e}. Aborting push.", exc_info=True)
        return
    logger.info(f"Final aggregated dataset created with {len(final_dataset)} examples.")

    # Size of the underlying Arrow buffers (uncompressed), measured rather than guessed.
    size_gb = final_dataset.data.nbytes / (1024 ** 3)
    logger.info(f"Raw Arrow dataset size: {size_gb:.2f} GB (before compression).")

    # 3. Push to Hugging Face Hub
    logger.info(f"Attempting to push aggregated dataset to Hugging Face Hub: {repo_id}")

    # Treat an empty HF_TOKEN the same as unset.
    token = os.environ.get("HF_TOKEN") or None
    if token is None:
        logger.warning("HF_TOKEN environment variable not set. Please log in using `huggingface-cli login` in your terminal or set HF_TOKEN.")
        logger.warning("You might not be able to push to the Hub without a valid write token.")
        # We'll let the Hub calls below raise the authentication error if it fails.

    api = HfApi(token=token)

    try:
        logger.info(f"Checking if repository '{repo_id}' exists and creating if necessary (private={private}).")
        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=private)
        logger.info(f"Repository {repo_id} ensured to exist. Starting data push...")

        # push_to_hub reports its own upload progress; we only log before/after.
        final_dataset.push_to_hub(repo_id, private=private, token=token)

        logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{repo_id}")
        logger.info("Push complete! Check your Hugging Face profile for the dataset.")

    except Exception as e:
        logger.error(f"Error pushing dataset to Hugging Face Hub: {e}", exc_info=True)
        logger.error("Please ensure you have authenticated with `huggingface-cli login` or set a HF_TOKEN environment variable with write access to the namespace.")


In [None]:
# Use the configuration-cell constants so editing them actually changes what gets pushed.
aggregate_and_push_chunks(LOCAL_CHUNKS_DIR, HF_REPO_ID, expected_features)