In [1]:
import os
import numpy as np
import pandas as pd

from argparse import Namespace
from datasets import load_from_disk

In [2]:
# Run configuration, kept in one place so it is easy to tweak:
#   seed        - random seed value for this run
#   dataset     - path of the datasets DatasetDict saved with save_to_disk
#   num_samples - upper bound on rows drawn per (label column, binary value) group
#   num_labels  - number of label columns to iterate over
#   fout        - destination path of the pickled sample DataFrame
config = dict(
    seed=42,
    dataset="/data3/mmendieta/Violence_data//geo_corpus.0.0.1_datasets_hidden_xlmt",
    num_samples=1000,
    num_labels=6,
    fout="/data4/mmendieta/data/sample_hidden_xlmt",
)

# Expose the settings with attribute access (args.seed, args.fout, ...),
# the same interface an argparse parse_args() result would give.
args = Namespace(**config)

In [3]:
# Load dataset
dataset_path = args.dataset
ds = load_from_disk(dataset_path)
train_ds = ds["train"]

In [4]:
# Materialise the "hidden_state" and "labels" columns of the training split
# as numpy arrays (in that order) for fast positional indexing below.
hidden_states, labels = (
    np.array(train_ds[column]) for column in ("hidden_state", "labels")
)

In [5]:
# Balanced sampling: for every label column (0 .. num_labels-1) and each
# binary value (0 and 1), draw up to `num_samples` rows whose label in that
# column equals the value. A row can land in several groups (one per column),
# so the total can exceed the number of distinct rows.
#
# A dedicated Generator seeded from args.seed makes the draw reproducible
# across runs, and re-running this cell gives the same sample.
rng = np.random.default_rng(args.seed)

# The column slicing below needs a rectangular (n_rows, n_label_columns) matrix.
assert labels.ndim == 2 and labels.shape[1] >= args.num_labels, (
    f"expected labels of shape (n, >={args.num_labels}), got {labels.shape}"
)

samples = []
for label_idx in range(args.num_labels):
    label_column = labels[:, label_idx]
    for binary_val in (0, 1):
        # Row indices whose value in this label column equals binary_val
        # (vectorised; replaces a per-row Python scan).
        matching_indices = np.flatnonzero(label_column == binary_val)
        if matching_indices.size == 0:
            continue  # no rows with this value; nothing to sample

        # Sample without replacement, capped at args.num_samples.
        sampled_indices = rng.choice(
            matching_indices,
            size=min(args.num_samples, matching_indices.size),
            replace=False,
        )

        # One record per sampled row, tagged with the group it was drawn for.
        samples.extend(
            {
                "hidden_state": hidden_states[i],
                "label_idx": label_idx,
                "label_value": binary_val,
            }
            for i in sampled_indices
        )

In [6]:
# Build the sample table in one shot from the list of per-row records;
# columns follow the record keys: hidden_state, label_idx, label_value.
sample_df = pd.DataFrame(data=samples)

In [7]:
# Persist the sampled DataFrame to args.fout as a pickle.
# NOTE: pickle files can execute code when loaded; only unpickle this file
# from a trusted location.
output_dir = os.path.dirname(args.fout)
if output_dir:
    # to_pickle does not create missing parent directories, so make sure the
    # destination exists before writing.
    os.makedirs(output_dir, exist_ok=True)
sample_df.to_pickle(args.fout)

print(f"[INFO] Saved {len(sample_df)} samples to '{args.fout}'")

[INFO] Saved 12000 samples to '/data4/mmendieta/data/sample_hidden_xlmt'
