In [1]:
import os
import yaml
import logging
import random
import sys
from typing import List

import torch
import transformers
# from transformers import AutoModelForCausalLM, set_seed

from dataclasses import dataclass, field
from datasets import load_dataset, load_from_disk, load_metric

In [2]:
# file = '../recipes/zephyr-7b-gemma/dpo/config_full.yaml'

# with open(file, 'r') as f:
#     args = yaml.safe_load(f)

# parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
# model_args, data_args, training_args = parser.parse_dict(args)

In [3]:
# Path to the dataset-mixer recipe. Set the DATASET_MIXER_CONFIG environment variable
# to point elsewhere instead of editing this machine-specific absolute default.
file = os.environ.get(
    'DATASET_MIXER_CONFIG',
    '/home/kenyo/workspace/LLM-Trainer/recipes/dataset_mixer.yaml',
)

with open(file, 'r', encoding='utf-8') as f:
    args = yaml.safe_load(f)

# Whole parsed YAML document (still wrapped in its top-level 'dataset_mixer' key).
dataset_mixer = args
dataset_mixer

{'dataset_mixer': ['argilla/dpo-mix-7k',
  {'argilla/dpo-mix-7k': {'split': {'train': 'train', 'test': 'test'},
    'chat_template': 'chatml'}},
  {'argilla/dpo-mix-7k': {'chat_template': 'chatml'}},
  {'/home/kenyo/workspace/LLM-Trainer/dataset/LukasSonn/DoxygenStrings-Long': {'split': {'train': 'train',
     'test': 'test'}}},
  {'/home/kenyo/workspace/LLM-Trainer/dataset/LukasSonn/DoxygenStrings-Long': {'split': {'train': 'train.json',
     'test': 'test.json'}}}]}

In [15]:
@dataclass
class DatasetSplit:
    """Which split to load for each role.

    Each value is whatever identifies the split for its source: a Hub split
    name such as ``'train'``, a local sub-directory name, or a local data file
    name such as ``'train.json'``.
    """

    train: str  # split loaded for the training role
    test: str   # split loaded for the evaluation role

@dataclass
class DatasetGroupbySplit:
    """Lists of loaded datasets, grouped by split role (``train`` / ``test``).

    ``default_factory=list`` gives every instance its own fresh, empty lists,
    so appending to one instance never affects another.
    """

    train: list = field(default_factory=list)  # datasets loaded for the 'train' role
    test: list = field(default_factory=list)   # datasets loaded for the 'test' role

class DatasetMixer:
    """Load every dataset listed in a ``dataset_mixer`` config and group the
    loaded datasets by split role (``train`` / ``test``).

    Each config entry is either a dataset name/path string (loaded with the
    default ``train``/``test`` splits) or a ``{name_or_path: options}`` dict whose
    optional ``split`` option selects the splits/files to load. Other options
    (e.g. ``chat_template``) are only printed here, not applied.

    Loaded datasets are appended to ``self.dataset_groupby_split``; they are not
    concatenated, and repeated calls on the same instance keep appending to the
    same lists.
    """

    def __init__(self):
        self.dataset_groupby_split = DatasetGroupbySplit()

    @staticmethod
    def _resolve_split(split):
        """Normalise a ``split`` option into a ``DatasetSplit``.

        ``None`` -> ``train='train', test='test'``; a list maps each name to
        itself (so it must contain exactly the DatasetSplit field names); a dict
        is passed through as keyword arguments.
        """
        if split is None:
            return DatasetSplit(train='train', test='test')
        if isinstance(split, list):
            return DatasetSplit(**{name: name for name in split})
        return DatasetSplit(**split)

    @staticmethod
    def _load_split(dataset_name_or_path: str, split_name: str, key: str):
        """Load one split of one dataset.

        Args:
            dataset_name_or_path: Hub dataset id, or an existing local directory.
            split_name: Hub split name, local sub-directory name, or local file
                name whose extension selects the loader (e.g. ``'train.json'``).
            key: Split role (``'train'`` / ``'test'``), used as the ``data_files`` key.

        Returns:
            The loaded dataset object.
        """
        if os.path.exists(dataset_name_or_path):
            ext = os.path.splitext(split_name)[-1].lstrip('.')
            if ext:
                # A data file such as 'train.json': load it with the builder named
                # after its extension (assumes json/csv/parquet-style builder names).
                data_file = os.path.join(dataset_name_or_path, split_name)
                dataset = load_dataset(ext, data_files={key: data_file})[key]
                print(f'Successfully loaded dataset {dataset_name_or_path} from local through \'{ext}\' format.')
            else:
                # No extension: treat split_name as a sub-directory readable by load_from_disk.
                dataset = load_from_disk(os.path.join(dataset_name_or_path, split_name))
                print(f'Successfully loaded dataset {dataset_name_or_path} from local through \'datasets\' format.')
        else:
            # Not a local path: resolve through the Hugging Face Hub (or its local cache).
            dataset = load_dataset(dataset_name_or_path, split=split_name)
            print(f'Successfully loaded dataset {dataset_name_or_path} from HuggingFace Hub or ~/.cache.')
        return dataset

    def _load(self, dataset_name_or_path: str, split: dict = None):
        """Load the train and test splits of one dataset and append them to
        ``self.dataset_groupby_split``.

        Args:
            dataset_name_or_path: Hub dataset id or local directory.
            split: ``None``, a list of split names, or a
                ``{'train': ..., 'test': ...}`` dict (see ``_resolve_split``).

        Returns:
            ``self.dataset_groupby_split`` (shared and mutated in place).
        """
        dataset_split = self._resolve_split(split)
        # Field declaration order: the train split is loaded before the test split.
        for key, split_name in vars(dataset_split).items():
            dataset = self._load_split(dataset_name_or_path, split_name, key)
            getattr(self.dataset_groupby_split, key).append(dataset)
        return self.dataset_groupby_split

    def load_and_mix(self, dataset_mixer: list):
        """Load every dataset described by ``dataset_mixer``.

        Args:
            dataset_mixer: List (or tuple) of entries. Each entry is either a
                dataset name/path string, or a dict mapping one or more dataset
                names/paths to an options dict (or ``None`` for no options) that
                may contain ``split``. ``None`` means "no datasets".

        Returns:
            ``self.dataset_groupby_split`` with the loaded datasets appended.

        Raises:
            TypeError: if ``dataset_mixer``, an entry, or an entry's options have
                an unsupported type, instead of silently skipping them.
        """
        if dataset_mixer is None:
            entries = []
        elif isinstance(dataset_mixer, (list, tuple)):
            entries = dataset_mixer
        else:
            raise TypeError(
                f'dataset_mixer must be a list of dataset entries, '
                f'got {type(dataset_mixer).__name__}.'
            )

        for entry in entries:
            if isinstance(entry, str):
                print(f'Load dataset {entry} with default configurations.')
                self._load(entry)
            elif isinstance(entry, dict):
                # Every key is a dataset; previously only the first key was loaded.
                for dataset_name_or_path, config in entry.items():
                    # A bare `name:` in YAML parses to None -> treat as "no options".
                    config = {} if config is None else config
                    if not isinstance(config, dict):
                        raise TypeError(
                            f'Options for dataset {dataset_name_or_path} must be a dict, '
                            f'got {type(config).__name__}.'
                        )
                    if config:
                        print(f'Load dataset {dataset_name_or_path} with following configurations.')
                        for key, value in config.items():
                            print(f'    - {key}: {value}')
                    else:
                        print(f'Load dataset {dataset_name_or_path} with default configurations.')
                    self._load(dataset_name_or_path, split=config.get('split', None))
            else:
                raise TypeError(
                    f'Unsupported dataset_mixer entry of type {type(entry).__name__}: {entry!r}'
                )

        return self.dataset_groupby_split

In [16]:
# Path to the dataset-mixer recipe. Set the DATASET_MIXER_CONFIG environment variable
# to point elsewhere instead of editing this machine-specific absolute default.
file = os.environ.get(
    'DATASET_MIXER_CONFIG',
    '/home/kenyo/workspace/LLM-Trainer/recipes/dataset_mixer.yaml',
)
with open(file, 'r', encoding='utf-8') as f:
    args = yaml.safe_load(f)

# The list of dataset entries consumed by DatasetMixer.load_and_mix.
dataset_config = args['dataset_mixer']
dataset_config

['argilla/dpo-mix-7k',
 {'argilla/dpo-mix-7k': {'split': {'train': 'train', 'test': 'test'},
   'chat_template': 'chatml'}},
 {'argilla/dpo-mix-7k': {'chat_template': 'chatml'}},
 {'/home/kenyo/workspace/LLM-Trainer/dataset/LukasSonn/DoxygenStrings-Long': {'split': {'train': 'train',
    'test': 'test'}}},
 {'/home/kenyo/workspace/LLM-Trainer/dataset/LukasSonn/DoxygenStrings-Long': {'split': {'train': 'train.json',
    'test': 'test.json'}}}]

In [17]:
# A fresh mixer per run: DatasetMixer appends to its own lists, so reusing an old
# instance would duplicate datasets when this cell is re-executed.
dataset_mixer = DatasetMixer()
dataset_groupby_split = dataset_mixer.load_and_mix(dataset_mixer=dataset_config)
dataset_groupby_split

Load dataset argilla/dpo-mix-7k with default configurations.
Successfully loaded dataset argilla/dpo-mix-7k from HuggingFace Hub or ~/.cache.
Successfully loaded dataset argilla/dpo-mix-7k from HuggingFace Hub or ~/.cache.
Load dataset argilla/dpo-mix-7k with following configurations.
    - split: {'train': 'train', 'test': 'test'}
    - chat_template: chatml
Successfully loaded dataset argilla/dpo-mix-7k from HuggingFace Hub or ~/.cache.
Successfully loaded dataset argilla/dpo-mix-7k from HuggingFace Hub or ~/.cache.
Load dataset argilla/dpo-mix-7k with following configurations.
    - chat_template: chatml
Successfully loaded dataset argilla/dpo-mix-7k from HuggingFace Hub or ~/.cache.
Successfully loaded dataset argilla/dpo-mix-7k from HuggingFace Hub or ~/.cache.
Load dataset /home/kenyo/workspace/LLM-Trainer/dataset/LukasSonn/DoxygenStrings-Long with following configurations.
    - split: {'train': 'train', 'test': 'test'}
Successfully loaded dataset /home/kenyo/workspace/LLM-Train

DatasetGroupbySplit(train=[Dataset({
    features: ['dataset', 'chosen', 'rejected', 'chosen_rating', 'rejected_rating'],
    num_rows: 6750
}), Dataset({
    features: ['dataset', 'chosen', 'rejected', 'chosen_rating', 'rejected_rating'],
    num_rows: 6750
}), Dataset({
    features: ['dataset', 'chosen', 'rejected', 'chosen_rating', 'rejected_rating'],
    num_rows: 6750
}), Dataset({
    features: ['question', 'context', 'answer'],
    num_rows: 10235
}), Dataset({
    features: ['context', 'answer', 'question'],
    num_rows: 10235
})], test=[Dataset({
    features: ['dataset', 'chosen', 'rejected', 'chosen_rating', 'rejected_rating'],
    num_rows: 750
}), Dataset({
    features: ['dataset', 'chosen', 'rejected', 'chosen_rating', 'rejected_rating'],
    num_rows: 750
}), Dataset({
    features: ['dataset', 'chosen', 'rejected', 'chosen_rating', 'rejected_rating'],
    num_rows: 750
}), Dataset({
    features: ['question', 'context', 'answer'],
    num_rows: 2047
}), Dataset({
   