In [1]:
import json
import os
from pathlib import Path

import numpy as np
import torch
from datasets import load_from_disk

In [2]:
def subgroup_resample(dataset, rate: float, label_key: str, rng: "np.random.Generator | None" = None):
    """Draw a stratified subsample that keeps the same fraction of rows from every label.

    For each distinct value in ``dataset[label_key]``, ``int(count * rate)`` rows are
    chosen uniformly at random without replacement. A label whose ``count * rate`` is
    below 1 contributes no rows.

    Parameters
    ----------
    dataset
        Object supporting column access ``dataset[label_key]`` and row selection
        ``dataset.select(indices)`` (e.g. a Hugging Face ``datasets.Dataset``).
    rate : float
        Fraction of each label group to keep, in ``[0, 1]``.
    label_key : str
        Name of the column holding the stratification labels.
    rng : numpy.random.Generator, optional
        Source of randomness. If omitted, NumPy's global RNG (``np.random``) is used,
        so ``np.random.seed`` controls reproducibility.

    Returns
    -------
    The result of ``dataset.select`` on the chosen row indices, in their original order.

    Raises
    ------
    ValueError
        If ``rate`` is not within ``[0, 1]``.
    """
    if not 0.0 <= rate <= 1.0:
        raise ValueError(f"rate must be in [0, 1], got {rate!r}")
    sampler = np.random if rng is None else rng

    labels = np.asarray(dataset[label_key])
    chosen = []
    for lbl in np.unique(labels):
        idxs = np.flatnonzero(labels == lbl)
        # Round away float noise before truncating, e.g. 100 * 0.29 == 28.999999999999996
        # would otherwise keep 28 rows instead of 29.
        n_keep = int(round(idxs.size * rate, 9))
        chosen.extend(sampler.choice(idxs, size=n_keep, replace=False).tolist())

    # Sorted so the subsample preserves the dataset's row order instead of being grouped by label.
    return dataset.select(sorted(chosen))

In [3]:
def convert_dataset_format(
    dataset_arrow_path: str,
    label_key: str,
    dataset_name: str,
    rate: float = 0.1,
    tokens_key: str = "tokens",
    output_dir: str = ".",
):
    """Subsample an on-disk Arrow dataset per label and export it as tokens + JSON manifest.

    Loads the dataset with ``load_from_disk``, keeps a ``rate`` fraction of rows per
    label via ``subgroup_resample``, then writes two files into ``output_dir``:

    * ``{dataset_name}_tokens.pt`` -- ``dataset[tokens_key]`` saved with ``torch.save``.
    * ``{dataset_name}.json`` -- ``{"name", "prompt_templates", "samples"}``, where each
      sample is ``{"subject": "<row index>", "object": <label>}``.

    Parameters
    ----------
    dataset_arrow_path : str
        Directory previously written by ``Dataset.save_to_disk``.
    label_key : str
        Column used both for stratified subsampling and as each sample's ``object``.
    dataset_name : str
        Stored as the manifest ``name`` and used as the output file stem.
    rate : float, default 0.1
        Fraction of each label group to keep.
    tokens_key : str, default "tokens"
        Column whose values are saved to the ``.pt`` file.
    output_dir : str, default "."
        Directory for the output files; created if missing.

    Returns
    -------
    tuple[Path, Path]
        Paths of the written tokens file and JSON manifest.
    """
    dataset = load_from_disk(dataset_arrow_path)
    dataset = subgroup_resample(dataset, rate, label_key)

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokens_path = out_dir / f"{dataset_name}_tokens.pt"
    torch.save(dataset[tokens_key], tokens_path)

    # "subject" is the row index into the tokens file saved above; both come from the
    # same resampled dataset, so index i lines up with tokens row i.
    samples = [
        {"subject": str(i), "object": label}
        for i, label in enumerate(dataset[label_key])
    ]
    manifest = {"name": dataset_name, "prompt_templates": [], "samples": samples}

    json_path = out_dir / f"{dataset_name}.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f)
    return tokens_path, json_path


In [4]:
dataset_arrow_path = r"C:\Users\97254\OneDrive\שולחן העבודה\Projects\llm-context-neurons\data\europarl_lang"
label_key = 'language'
dataset_name = 'europarl_lang'

convert_dataset_format(dataset_arrow_path, label_key, dataset_name)

In [5]:
dataset_arrow_path = r"C:\Users\97254\OneDrive\שולחן העבודה\Projects\llm-context-neurons\data\pile_data_source"
label_key = 'distribution'
dataset_name = 'pile_data_source'

convert_dataset_format(dataset_arrow_path, label_key, dataset_name)