In [5]:
import torch
from datasets import load_from_disk, Dataset, DatasetDict
from tqdm import tqdm
import os

# Restrict CUDA in this process to device index 1.
# NOTE(review): torch is already imported above; this only takes effect if CUDA
# has not been initialised yet in this kernel (e.g. by an earlier cell run).
os.environ.update(CUDA_VISIBLE_DEVICES="1")

def logit_stats_gpu(logits, k=10, is_softmax=False):
    """Per-row summary statistics of a batch of next-token distributions.

    Args:
        logits: 2-D tensor ``[B, V]`` of raw logits, or of probabilities when
            ``is_softmax`` is True. Any device works; callers in this notebook
            move it to CUDA first.
        k: number of largest probabilities to return per row; ``topk`` raises
            if ``k`` exceeds ``V``.
        is_softmax: when True, ``logits`` is used as-is as probabilities and no
            softmax is applied.

    Returns:
        Tuple ``(entropy, max_prob, mean, std, topk_probs)``. The first four
        have shape ``[B]``; ``topk_probs`` has shape ``[B, k]`` in descending
        order. Half-precision inputs yield float32 outputs.
    """
    # In float16 the 1e-12 epsilon below rounds to 0, so an exact-zero
    # probability gives 0 * log(0) = NaN entropy; half-precision softmax over a
    # large vocabulary is also inaccurate. Upcast only fp16/bf16 inputs so
    # float32/float64 results are unchanged.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    probs = logits if is_softmax else torch.softmax(logits, dim=-1)
    # The epsilon keeps log() finite for exact-zero probabilities.
    entropy = -(probs * (probs + 1e-12).log()).sum(dim=-1)
    max_prob = probs.max(dim=-1).values
    # NOTE: after a softmax every row sums to 1, so this is always 1/V.
    mean = probs.mean(dim=-1)
    # torch.std uses the Bessel-corrected (unbiased) estimator by default.
    std = probs.std(dim=-1)
    topk_probs = probs.topk(k, dim=-1).values
    return entropy, max_prob, mean, std, topk_probs

def batch_extract_features(lh, eh, k=10):
    """Build one feature row per example from two batches of logits.

    ``process_split`` passes ``lh`` from the last-logit column and ``eh`` from
    the egale first-forward-logit column. For each input, in that order, the
    columns are ``[entropy, max_prob, mean, std, top-k probs...]``, giving a
    ``[B, 2 * (4 + k)]`` tensor.
    """
    columns = []
    for head_logits in (lh, eh):
        entropy, max_prob, mean, std, topk_probs = logit_stats_gpu(head_logits, k)
        # Scalar-per-row stats become [B, 1] columns; top-k is already [B, k].
        columns.extend(stat.unsqueeze(1) for stat in (entropy, max_prob, mean, std))
        columns.append(topk_probs)
    return torch.cat(columns, dim=1)

def process_split(ds, k=10, batch_size=1024, last_logit_col="last_logit", egale_logit_col="egale_1st_forward_logit", accept_length_col="accept_length", device="cuda"):
    """Turn one dataset split into a (features, accept_length) dataset.

    Args:
        ds: a ``datasets.Dataset`` containing ``last_logit_col``,
            ``egale_logit_col`` and ``accept_length_col``. It is not modified.
        k: number of top probabilities per logit head (see ``logit_stats_gpu``).
        batch_size: rows moved to ``device`` and processed at a time.
        last_logit_col, egale_logit_col, accept_length_col: input column names.
        device: torch device on which the features are computed.

    Returns:
        A new ``Dataset`` with columns ``"features"`` (one list of
        ``2 * (4 + k)`` floats per row) and ``"accept_length"``, in input order.
    """
    # with_format returns a formatted view; set_format would mutate the
    # caller's dataset in place and leak into later notebook cells.
    torch_ds = ds.with_format(type="torch", columns=[last_logit_col, egale_logit_col, accept_length_col])
    feature_chunks = []
    accept_lengths = []
    with torch.inference_mode():
        for start in tqdm(range(0, len(torch_ds), batch_size)):
            batch = torch_ds[start : start + batch_size]
            lh = batch[last_logit_col].to(device)
            eh = batch[egale_logit_col].to(device)
            features = batch_extract_features(lh, eh, k=k)
            feature_chunks.append(features.cpu())
            accept_lengths.extend(batch[accept_length_col].tolist())
    # torch.cat([]) raises, so an empty split produces an empty feature column.
    feature_rows = torch.cat(feature_chunks, dim=0).tolist() if feature_chunks else []
    return Dataset.from_dict({
        "features": feature_rows,
        "accept_length": accept_lengths
    })

def main(input_dataset_dir="../data/mt-bench-llama3-d13-topk10-t0", output_dir="../data/mt-bench-llama3-d13-topk10-t0-cal", k=10, batch_size=1024, last_logit_col="last_logit", egale_logit_col="egale_1st_forward_logit", accept_length_col="accept_length"):
    """Extract features for every split and write one JSON file per split.

    Loads the dataset dict saved at ``input_dataset_dir``, runs
    ``process_split`` on each split and writes it to
    ``<output_dir>/dataset_<split>.json`` via ``Dataset.to_json``. The defaults
    reproduce the original hard-coded configuration.
    """
    os.makedirs(output_dir, exist_ok=True)
    dataset_dict = load_from_disk(input_dataset_dir)
    for split, ds in dataset_dict.items():
        print(f"Processing split: {split}")
        new_ds = process_split(
            ds, k=k, batch_size=batch_size,
            last_logit_col=last_logit_col,
            egale_logit_col=egale_logit_col,
            accept_length_col=accept_length_col
        )
        out_json = os.path.join(output_dir, f"dataset_{split}.json")
        new_ds.to_json(out_json)
        print(f"Saved {split} split to {out_json}")

In [4]:
# Process every split and write <output_dir>/dataset_<split>.json for each one.
main()

Processing split: train


100%|██████████| 7/7 [00:04<00:00,  1.58it/s]
Creating json from Arrow format: 100%|██████████| 7/7 [00:00<00:00, 54.71ba/s]


Saved train split to ../data/mt-bench-llama3-d13-topk10-t0-cal/dataset_train.json
Processing split: test


100%|██████████| 2/2 [00:00<00:00,  2.80it/s]
Creating json from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 157.21ba/s]

Saved test split to ../data/mt-bench-llama3-d13-topk10-t0-cal/dataset_test.json





In [6]:
def main1(input_dataset_dir="../data/mt-bench-llama3-d13-topk10-t0", output_dir="../data/mt-bench-llama3-d13-topk10-t0-cal", k=10, batch_size=1024, last_logit_col="last_logit", egale_logit_col="egale_1st_forward_logit", accept_length_col="accept_length", save_json=False):
    """Extract features for every split and save them together as Arrow.

    Runs ``process_split`` on each split of the dataset dict at
    ``input_dataset_dir`` and saves the results as one ``DatasetDict`` under
    ``<output_dir>/arrow``, which can be reloaded with ``load_from_disk``.
    With ``save_json=True`` each split is also written to
    ``<output_dir>/dataset_<split>.json``. The defaults reproduce the original
    hard-coded configuration.
    """
    os.makedirs(output_dir, exist_ok=True)
    dataset_dict = load_from_disk(input_dataset_dir)
    processed_splits = {}  # split name -> processed Dataset

    for split, ds in dataset_dict.items():
        print(f"Processing split: {split}")
        new_ds = process_split(
            ds, k=k, batch_size=batch_size,
            last_logit_col=last_logit_col,
            egale_logit_col=egale_logit_col,
            accept_length_col=accept_length_col
        )
        processed_splits[split] = new_ds
        if save_json:
            # Optional per-split JSON export in addition to the Arrow copy.
            out_json = os.path.join(output_dir, f"dataset_{split}.json")
            new_ds.to_json(out_json)

    # Save in Arrow format (DatasetDict) so it can be read back with load_from_disk.
    output_arrow_dir = os.path.join(output_dir, "arrow")
    DatasetDict(processed_splits).save_to_disk(output_arrow_dir)
    print(f"All splits saved to: {output_arrow_dir}")

In [7]:
# Process every split and save them as one DatasetDict under <output_dir>/arrow.
main1()

Processing split: train


100%|██████████| 7/7 [00:02<00:00,  2.54it/s]


Processing split: test


100%|██████████| 2/2 [00:00<00:00,  2.87it/s]
Saving the dataset (1/1 shards): 100%|██████████| 6402/6402 [00:00<00:00, 679985.17 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1601/1601 [00:00<00:00, 287091.95 examples/s]

All splits saved to: ../data/mt-bench-llama3-d13-topk10-t0-cal/arrow



