In [33]:
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from torch import bfloat16


# Model loading is left disabled; no visible cell references `model`.
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B", torch_dtype=bfloat16, device_map="cpu")
# Tokenizer used by the chat-template cells below.
# NOTE(review): some saved outputs further down contain `<|im_start|>` and
# `<|begin_of_text|>` tokens while the last one shows `[/INST]`, so the cells
# were apparently executed with different tokenizers at different times --
# re-run from a fresh kernel to confirm which outputs belong to this checkpoint.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

In [8]:
# Minimal two-turn exchange used to inspect what the chat template renders.
conversation = [
    dict(role="user", content="What's the weather like in Paris?"),
    dict(role="assistant", content="gg"),
]
tools = None  # no tool definitions are handed to the template

# Render to a plain string (no tokenization) without a trailing generation prompt.
tool_use_prompt = tokenizer.apply_chat_template(
    conversation, tools=tools, tokenize=False, add_generation_prompt=False
)

In [46]:
from recipes import DatasetRecipe
def get_response_template(tokenizer: PreTrainedTokenizer, dataset_recipe: DatasetRecipe) -> str:
    """Return the text the chat template adds when a generation prompt is requested.

    A one-message conversation (user content "dummy", preceded by the recipe's
    system message when it is not None) is rendered twice: without and with
    ``add_generation_prompt``. The part of the second rendering that starts at
    the first position where the two renderings differ is returned.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        dataset_recipe: Object exposing ``dataset_system_message``.

    Returns:
        The tail of the generation-prompt rendering. When the two renderings
        agree over their whole common length, the tail starts right after the
        length of the rendering without generation prompt (which yields an
        empty string if the renderings are identical).
    """
    system_message = dataset_recipe.dataset_system_message
    messages = []
    if system_message is not None:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": "dummy"})

    def render(with_generation_prompt):
        # Both renderings share the same messages; only the flag changes.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=with_generation_prompt,
        )

    without_prompt = render(False)
    with_prompt = render(True)

    # First differing position, or the full length of the shorter rendering
    # baseline when no character differs.
    split_at = next(
        (
            position
            for position, (left, right) in enumerate(zip(without_prompt, with_prompt))
            if left != right
        ),
        len(without_prompt),
    )
    return with_prompt[split_at:]

In [47]:
from recipes.datasets.PostOCRCorrectionDatasetRecipe import PostOCRCorrectionDatasetRecipe

# Print the response template derived for the Post-OCR recipe; the imported
# name is passed as-is (it is not called/instantiated here).
print(get_response_template(tokenizer, PostOCRCorrectionDatasetRecipe))

<|im_start|>assistant



In [3]:
# Inspect the raw chat template attached to the tokenizer.
print(tokenizer.chat_template)

None


In [9]:
# Display the prompt string rendered from `conversation` by apply_chat_template.
tool_use_prompt

"<|im_start|>user\nWhat's the weather like in Paris?<|im_end|>\n<|im_start|>assistant\ngg<|im_end|>\n"

In [5]:
# Display the tokenizer's special-token mapping.
tokenizer.special_tokens_map

{'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>'}

In [7]:
from dataclasses import fields, field, dataclass, asdict
import argparse
from typing import Union
from ast import literal_eval
from typing import Dict, Union, Optional
from datasets import DatasetDict, Dataset, IterableDatasetDict, IterableDataset
from recipes.Recipe import Recipe
from cookbooks import DATASET_COOKBOOK

@dataclass
class DatasetRecipe(Recipe):
    """Dataset recipe: two optional settings plus two pass-through preprocessing hooks.

    Both hooks return their first argument untouched in this class.
    NOTE(review): they look like extension points for subclasses -- confirm.
    """

    # Kwargs to give to the "load_dataset" function from the "datasets" module.
    dataset_load: Optional[dict] = field(
        metadata={"description": "test"},
        default=None,
    )
    # Optional string, None unless provided.
    response_template: Optional[str] = None

    def preprocess_dataset(
        self,
        dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset, None],
    ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset, None]:
        """Return ``dataset`` exactly as received."""
        return dataset

    def preprocess_function(
        self,
        sample: Dict,
        examples: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset, dict, None],
    ) -> Dict:
        """Return ``sample`` exactly as received; ``examples`` is not used."""
        return sample

# The decorator is required: without it `model_load` / `model_config` are plain
# class attributes, so `fields(ModelRecipe)` does not report them and they are
# not accepted as constructor keywords.
@dataclass
class ModelRecipe(Recipe):
    """Model recipe with two optional dict settings, both None by default."""

    model_load: Optional[dict] = None
    model_config: Optional[dict] = None

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def merge_dataclasses(dataclass_list):
    """Merge the field declarations of several dataclasses into one dictionary.

    Args:
        dataclass_list: Iterable of dataclass types or dataclass instances.

    Returns:
        dict: ``{field name: declared field type}`` for every field of every
        entry, plus a ``"created_at"`` key holding the ``created_at`` attribute
        of the entry (0 when the entry has none). Entries are applied in order,
        so later entries win on key collisions.

    Raises:
        TypeError: If an entry is neither a dataclass nor a dataclass instance.
    """
    merged_dict = {}
    for cls in dataclass_list:
        # fields() accepts both dataclass types and instances; asdict() on the
        # annotation values (the previous approach) can never succeed.
        merged_dict.update({fi.name: fi.type for fi in fields(cls)})
        # Metadata is written after the fields so it overrides a same-named field.
        merged_dict["created_at"] = getattr(cls, "created_at", 0)

    return merged_dict

# Define a function to convert the merged dictionary back into a dataclass
def dict_to_dataclass(cls, data):
    """Instantiate dataclass ``cls`` from a (possibly merged) dictionary.

    Keys of ``data`` that are not constructor fields of ``cls`` -- for example
    the "created_at" / "source" metadata keys or fields that belong to another
    dataclass -- are ignored instead of being forwarded to ``cls`` (which used
    to fail with an unexpected-keyword TypeError).

    Args:
        cls: Dataclass type to instantiate.
        data (dict): Values keyed by field name.

    Returns:
        An instance of ``cls``.

    Raises:
        TypeError: If ``cls`` is not a dataclass.
    """
    metadata_keys = ("created_at", "source")
    init_names = {fi.name for fi in fields(cls) if fi.init}

    # NOTE(review): fields missing from `data` are filled with their annotation
    # (the type object itself), as before -- confirm this placeholder is wanted.
    placeholders = {
        name: annotation
        for name, annotation in cls.__annotations__.items()
        if name in init_names and name not in metadata_keys
    }
    supplied = {name: value for name, value in data.items() if name in init_names}
    return cls(**{**placeholders, **supplied})

In [13]:
# Merge the field declarations of both recipes; note that instances (not the
# classes) are passed. The saved output for this cell is a TypeError.
merge_dataclasses([DatasetRecipe(), ModelRecipe()])

TypeError: 'DatasetRecipe' object is not callable

In [29]:
# --------------------------------------------------------------------------
# Here we define all the parameters for our training procedure ("train_cli.py")
# Each parameter is a field. For each of them we specify the type, default and some useful metadata.
# 
# Metadata:
#   - `description` (str): Description of the field. Used both as documentation and also by the argument parser
#                          as information to show when calling `--help` on `train_cli.py` 
#   - `required` (bool): Whether it is required for running `train_cli.py`. 
#                        Note: This is done instead of removing `None` from (Union[..., None]) so that we can also
#                              use a starting config file for common parameters (See utils.parsers.get_config_from_argparser for the code)
#   - `recipe_keywords` (List[str]): List of optional fields' names to use to create the recipe in case it was specified as a string.
#                           For example --model_recipe "MistralModelRecipe" will also use `model_load` and `model_config` fields for initialization.
#   - `cookbook` (CookBook): Cookbook to use to get the `recipe` if it was specified by string. 
#                            For example --model_recipe "MistralModelRecipe", the string "MistralModelRecipe" would be used to get the 
#                            class from the given cookbook.
# --------------------------------------------------------------------------
from typing import List, Tuple, Any

def camel_to_snake(camel_string):
    """
    Turn a camelCase / PascalCase identifier into snake_case.

    Every uppercase character except one at position 0 is prefixed with an
    underscore, and all characters are lowercased.

    Args:
        camel_string (str): Identifier to convert.

    Returns:
        str: The converted identifier.
    """
    pieces = []
    for position, letter in enumerate(camel_string):
        needs_separator = position > 0 and letter.isupper()
        pieces.append(("_" if needs_separator else "") + letter.lower())
    return "".join(pieces)

def generate_argparser_from_recipe(recipe_dataclasses: List[object], description: str):
    """
    Build an ``argparse.ArgumentParser`` from the fields of recipe dataclasses.

    For every recipe class two kinds of options are registered:

    * ``--<snake_case class name>`` (str, default None): the recipe's name.
    * ``--<field name>`` for each dataclass field, using the field's default
      (None when the field has no plain default) and its ``description``
      metadata as help text.

    ``Optional[...]`` / ``Union[...]`` annotations are reduced to the first
    member among ``str``, ``int``, ``dict``, ``float`` (``str`` if none
    matches). ``dict`` values are parsed with ``ast.literal_eval`` so that
    ``'{"example": "example value"}'`` becomes a dict.

    Args:
        recipe_dataclasses (List[type]): Dataclass types to expose.
        description (str): Description for the argument parser.

    Returns:
        argparse.ArgumentParser: Parser populated with the options above.
    """
    # Local imports: these names are not part of the notebook's import cells.
    import types
    from dataclasses import MISSING
    from typing import get_args, get_origin

    # typing.Optional[X] is typing.Union[X, None], so its origin is Union (never
    # Optional); PEP 604 unions (`X | None`) report types.UnionType instead.
    union_origins = tuple(
        origin for origin in (Union, getattr(types, "UnionType", None)) if origin is not None
    )
    supported_types = (str, int, dict, float)

    parser = argparse.ArgumentParser(description=description)

    for recipe_dataclass in recipe_dataclasses:
        parser.add_argument(f"--{camel_to_snake(recipe_dataclass.__name__)}",
                            type=str, default=None,
                            help="Recipe's name from cookbook")

        # `recipe_field` rather than `field`, to avoid shadowing dataclasses.field.
        for recipe_field in fields(recipe_dataclass):
            field_type = recipe_field.type

            if get_origin(field_type) in union_origins:
                valid_types = [t for t in get_args(field_type) if t in supported_types]
                # Fall back to plain strings when no member is directly parseable.
                field_type = valid_types[0] if valid_types else str

            # argparse calls `type(raw_string)`; dict("...") cannot parse a literal,
            # literal_eval can (and only evaluates Python literals).
            if field_type is dict:
                field_type = literal_eval

            # Fields without a plain default carry the MISSING sentinel.
            field_default = None if recipe_field.default is MISSING else recipe_field.default
            field_help = recipe_field.metadata.get("description", "")
            parser.add_argument(f"--{recipe_field.name}", type=field_type,
                                default=field_default, help=field_help)

    return parser

# print_help() writes the help text itself and returns None, so wrapping the
# call in print() would add a stray "None" line after the help.
generate_argparser_from_recipe([DatasetRecipe, ModelRecipe], "").print_help()

NameError: name 'argparse' is not defined

In [8]:
# Inspect the annotations declared on DatasetRecipe.
DatasetRecipe.__annotations__

{'dataset_load': typing.Optional[dict],
 'response_template': typing.Optional[str]}

In [45]:
from dataclasses import asdict, dataclass
from pprint import pprint

def join_dicts(priority: Optional[dict], secondary: Optional[dict]) -> Dict:
    """Combine two optional dicts into a new one.

    Falsy arguments (None or empty) count as empty dicts. When a key is present
    in both, the value from ``priority`` is kept.
    """
    low = secondary if secondary else {}
    high = priority if priority else {}
    # Unpacking order matters: entries unpacked later overwrite earlier ones.
    return {**low, **high}

@dataclass
class DatasetRecipe(Recipe):
    """Dataset recipe with two optional settings, both None by default."""

    # Kwargs to give to the "load_dataset" function from the "datasets" module.
    dataset_load: Optional[dict] = field(
        metadata={"description": "test"},
        default=None,
    )
    # Optional string, None unless provided.
    response_template: Optional[str] = None

# Model recipe: two optional dict settings, each defaulting to None.
@dataclass
class ModelRecipe(Recipe):
    model_load: Optional[dict] = field(default=None)
    model_config: Optional[dict] = field(default=None)

# @dataclass
# class RecipeMerge:
#     def _update_fields(self, field_dict: dict):
#         self.__dataclass_fields__.update(field_dict)
#     
#     def __init__(self, recipe_dataclasses: List[dataclass]):
#         for recipe_dataclass in recipe_dataclasses:
#             self._update_fields(recipe_dataclass.__dataclass_fields__)
#             
#             # For some reason the field does not appear unless we modify an already existing one
#             fi = fields(recipe_dataclass)[0]
#             fi.name = camel_to_snake(recipe_dataclass.__name__)
#             fi.default = None; fi.type = str
#             fi.metadata = {"required": True, "cookbook": "", "recipe_keywords": fields(recipe_dataclass)}
#             self._update_fields({camel_to_snake(recipe_dataclass.__name__): fi})
# 
# t = RecipeMerge([DatasetRecipe, ModelRecipe])

# pprint(t.__dataclass_fields__)

# for fi in fields(t):
#     print(fi.name, fi.metadata)

In [47]:
from dataclasses import make_dataclass

def get_config(recipe_dataclasses: List[Tuple[type, bool, Any]]):
    """
    Build a flat ``Config`` dataclass out of several recipe dataclasses.

    For every ``(recipe, required, cookbook)`` triple the generated class gets:

    * all dataclass fields of ``recipe`` (same name, type, default, metadata);
    * one extra ``str`` field named after the recipe class in snake_case
      (default None) whose metadata carries ``required``, ``cookbook`` and
      ``recipe_keywords`` (the names of the recipe's fields).

    Args:
        recipe_dataclasses: ``(recipe dataclass, required flag, cookbook)`` triples.

    Returns:
        type: The dataclass created by ``make_dataclass`` under the name "Config".
    """
    fis = []
    for recipe, required, cookbook in recipe_dataclasses:
        recipe_fields = fields(recipe)
        # make_dataclass expects (name, type, Field); a 2-tuple (name, Field)
        # would register the Field object itself as the field's type.
        fis.extend((fi.name, fi.type, fi) for fi in recipe_fields)
        recipe_selector = field(default=None, metadata={"required": required,
                                                        "cookbook": cookbook,
                                                        "recipe_keywords": [fi.name for fi in recipe_fields]})
        fis.append((camel_to_snake(recipe.__name__), str, recipe_selector))
    return make_dataclass("Config", fields=fis)

(Field(name='dataset_load',type=Field(name='dataset_load',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x0000014DAAA6AB00>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'test'}),kw_only=False,_field_type=_FIELD),default=<dataclasses._MISSING_TYPE object at 0x0000014DAAA6AB00>,default_factory=<dataclasses._MISSING_TYPE object at 0x0000014DAAA6AB00>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='response_template',type=Field(name='response_template',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x0000014DAAA6AB00>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),default=<dataclasses._MISSING_TYPE object at 0x0000014DAAA6AB00>,default_factory=<dataclasses._MISSING_TYPE object at 0x0000014DAAA6AB00>,init=True,repr=True,hash=None,compare=True

In [1]:
from dataclasses import dataclass, fields, make_dataclass, field
import os
from recipes import DatasetRecipe
from recipes import ModelRecipe
from recipes import TokenizerRecipe
from recipes import PeftRecipe
from recipes import QuantizationRecipe
from recipes import TrainRecipe



@dataclass
class Config:
    # One attribute per recipe; the recipe types are imported from the
    # project's `recipes` package in this cell.
    dataset_recipe: DatasetRecipe
    model_recipe: ModelRecipe
    tokenizer_recipe: TokenizerRecipe
    peft_recipe: PeftRecipe
    quantization_recipe: QuantizationRecipe
    train_recipe: TrainRecipe

    def create_config_file(self, file_name):
        """
        Print every dataclass field of each recipe type annotated on this config, then return None.

        NOTE(review): the statements after the bare ``return`` are unreachable,
        and even they only compute a path under ``<parent>/recipes/train`` --
        nothing in this method writes a file yet, so the JSON export is still
        to be implemented.

        Args:
            file_name (str): The name of the file. Only referenced by the
                unreachable code below.
        """
        
        for fi in fields(self):
            # fi.type is the annotation of the attribute; fields() lists the
            # dataclass fields declared on that type.
            for fii in fields(fi.type):
                print(fii)
        return
        # --- unreachable from here on (see NOTE in the docstring) ---
        # NOTE(review): `__file__` is typically undefined inside a notebook kernel.
        # Get the path of the parent directory
        parent_dir = os.path.dirname(os.path.dirname(__file__))

        # Construct the full path to the target folder
        target_folder_path = os.path.join(parent_dir, 'recipes', 'train')

        # Create the target folder if it doesn't exist
        if not os.path.exists(target_folder_path):
            os.makedirs(target_folder_path)
        
        # Get the full path of the file
        file_path = os.path.join(target_folder_path, file_name)
        

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# List the names of the dataclass fields reported for DatasetRecipe.
[x.name for x in fields(DatasetRecipe)]

['dataset_load']

In [7]:
# Collect (name, type, Field) specs for a dataclass that is built dynamically below.
fiis = []
for fi in [DatasetRecipe, ModelRecipe]:
    # One extra str field per recipe, named after the recipe class, default None.
    recipe_field = field(default=None, metadata = {"required": True, 
                                                    "description": "Recipe's name from cookbook",
                                                    "cookbook": "", 
                                                    # vars(fi) is the class namespace mapping, not a list of field names.
                                                    "recipe_keywords": vars(fi)})
    fiis.append((fi.__name__, str, recipe_field))
    print(fields(fi))
    # Re-use the recipe's own Field objects so defaults and metadata carry over.
    for fii in fields(fi):
        fiis.append((fii.name, fii.type, fii))
        
# Build the combined dataclass from the collected specs.
t = make_dataclass("Test", fiis)

# Show the fields of the generated class.
for fii in fields(t):
    print(fii)

(Field(name='dataset_load',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for dataset load.'}),kw_only=False,_field_type=_FIELD),)
(Field(name='model_load',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for model load.'}),kw_only=False,_field_type=_FIELD), Field(name='model_config',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for model configuration.'}),kw_only=False,_field_type=_FIELD))
Field(name='DatasetRecipe',type=<class 'str'>,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,

In [14]:
# Call create_config_file on a Config whose six recipes are all None, with an empty file name.
Config(None, None, None, None, None, None).create_config_file("")

Field(name='model_load',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x00000202BD1E9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for model load.'}),kw_only=False,_field_type=_FIELD)
Field(name='model_config',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x00000202BD1E9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for model configuration.'}),kw_only=False,_field_type=_FIELD)
Field(name='tokenizer_load',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x00000202BD1E9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for tokenizer load.'}),kw_only=False,_field_type=_FIELD)
Field(name='tokenizer_config',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x00000202BD1E9600>,i

In [11]:
import os
from typing import List
from dataclasses import dataclass, make_dataclass, field, fields
from recipes.Recipe import Recipe
from recipes import DatasetRecipe, ModelRecipe


def t(name: str, recipes: List[Recipe]):
    """Create a flat dataclass called ``name`` from several recipe dataclasses.

    For each recipe, in order, the generated class receives:

    * a ``str`` field named exactly like the recipe class (default None) whose
      metadata holds ``required``, ``description``, ``cookbook`` and
      ``recipe_keywords`` (the names of the recipe's fields);
    * every dataclass field of the recipe, re-using its Field object.

    Args:
        name: Name of the generated dataclass.
        recipes: Recipe dataclasses to combine.

    Returns:
        The dataclass type produced by ``make_dataclass``.
    """
    specs = []
    for recipe in recipes:
        own_fields = fields(recipe)
        selector = field(
            default=None,
            metadata={
                "required": True,
                "description": "Recipe's name from cookbook",
                "cookbook": "",
                "recipe_keywords": [own.name for own in own_fields],
            },
        )
        specs.append((recipe.__name__, str, selector))
        specs.extend((own.name, own.type, own) for own in own_fields)
    return make_dataclass(name, specs)
            
# Build the flat dataclass from the two recipe classes and bind it to `Config`
# (an earlier cell defines a hand-written class under the same name).
Config = t("Config", [DatasetRecipe, ModelRecipe])

In [13]:
# Display the Field objects of the generated Config dataclass.
fields(Config)

(Field(name='DatasetRecipe',type=<class 'str'>,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'required': True, 'description': "Recipe's name from cookbook", 'cookbook': '', 'recipe_keywords': ['dataset_load']}),kw_only=False,_field_type=_FIELD),
 Field(name='dataset_load',type=typing.Optional[dict],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'description': 'Kwargs for dataset load.'}),kw_only=False,_field_type=_FIELD),
 Field(name='ModelRecipe',type=<class 'str'>,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002B0FA1A9600>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'required': True, 'description': "Recipe's name from cookbook", 'cookbook': '', 'recipe_keywords': ['model_load', 'model_config']}),kw_only=False,_field_type=_FIELD),

In [53]:
from dataclasses import asdict, is_dataclass, dataclass, fields

# Leaf dataclass used as the nested value in the flattening demo.
@dataclass
class BTest:
    t: str

# Outer dataclass: one plain field and one field holding a nested dataclass.
@dataclass
class A:
    a: int
    b: BTest
    
# Sample instance that the print call at the end of this cell flattens.
c = A(0, BTest("test"))

def config_to_flat_dict(config) -> dict:
    """Flatten a (possibly nested) dataclass instance into a single dict.

    For a field whose value is itself a dataclass instance, the field name maps
    to that instance's class name and the instance's own fields are merged into
    the same dict (recursively). Every other field maps to its value.

    Args:
        config: Dataclass instance to flatten.

    Returns:
        dict: The flattened mapping. Nested field names share one namespace
        with the outer ones, so a later duplicate name overwrites an earlier one.

    Raises:
        TypeError: If ``config`` is not a dataclass instance.
    """
    config_dict = {}
    for fi in fields(config):
        value = getattr(config, fi.name)
        # Decide on the runtime value rather than the annotation: the annotation
        # may be Optional[...] or a string, and the value may be None even when
        # the annotation is a dataclass. is_dataclass() is also true for
        # dataclass *types*, which must be stored as-is rather than recursed into.
        if is_dataclass(value) and not isinstance(value, type):
            config_dict[fi.name] = type(value).__name__
            config_dict.update(config_to_flat_dict(value))
        else:
            config_dict[fi.name] = value
    return config_dict
        
# Print the flattened form of the sample instance `c`.
print(config_to_flat_dict(c))

{'a': 0, 'b': 'BTest', 't': 'test'}


In [None]:
from recipes.pefts import LoRaPeftRecipe

# Instantiate the recipe with no arguments; as the last expression, its repr is displayed.
LoRaPeftRecipe()

In [55]:
from dataclasses import field


# A standalone Field object configured with `dict` as its default factory.
a = field(default_factory=dict)

# Calling the stored factory returns a new empty dict.
a.default_factory()

{}

In [5]:
from datasets import load_dataset
# Load the "train" split of the "mrpc" configuration of "glue".
dataset = load_dataset("glue", "mrpc", split="train")

# Select every row index in order, then show the first ten labels of the original.
dataset_support = dataset.select(range(len(dataset)))
print(dataset["label"][:10])

# Shuffle with a fixed seed and show the first ten labels of the shuffled result.
shuffled_dataset = dataset.shuffle(seed=42)
print(shuffled_dataset["label"][:10])

[1, 0, 1, 0, 1, 1, 0, 1, 0, 0]
[0, 0, 1, 0, 0, 1, 0, 1, 1, 0]


In [6]:
# First ten labels of the index-selected copy, for comparison with the outputs above.
dataset_support["label"][:10]

[1, 0, 1, 0, 1, 1, 0, 1, 0, 0]

In [9]:
from recipes import DatasetRecipe
from recipes.datasets.PostOCRCorrectionDatasetRecipe import PostOCRCorrectionDatasetRecipe

In [44]:
def get_response_template(tokenizer: PreTrainedTokenizer, dataset_recipe: DatasetRecipe) -> str:
    """Derive the text that separates the prompt from the assistant's answer.

    Two strategies are tried, in order:

    1. Render the same single-user-message conversation without and with
       ``add_generation_prompt``. If the renderings differ in length, the
       generation-prompt rendering's text beyond the length of the other
       rendering is returned.
    2. Otherwise (the template ignores ``add_generation_prompt``) render two
       conversations whose user contents are "dummy" and "ymmud" and return the
       suffix both renderings share, i.e. whatever the template emits after the
       user content.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        dataset_recipe: Object exposing ``dataset_system_message``; when it is
            not None, a system message is placed before the user message.

    Returns:
        str: The non-empty response template.

    Raises:
        Exception: If neither strategy yields a template -- the two renderings
            of strategy 2 share no suffix at all, or no differing character is
            found.
    """
    failure_message = "An error has occured during the creation of the response template. If the problem persists, set completion_only to False."

    conversation = []
    if dataset_recipe.dataset_system_message is not None:
        conversation.append({"role": "system", "content": dataset_recipe.dataset_system_message})

    def render(user_content: str, add_generation_prompt: bool) -> str:
        # Render `conversation` plus one user message as a plain string.
        return tokenizer.apply_chat_template(
            conversation + [{"role": "user", "content": user_content}],
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )

    # Strategy 1: whatever the generation prompt appends is the template.
    dummy_no_generation_prompt = render("dummy", False)
    dummy_generation_prompt = render("dummy", True)
    if len(dummy_no_generation_prompt) != len(dummy_generation_prompt):
        return dummy_generation_prompt[len(dummy_no_generation_prompt):]

    # Strategy 2: "dummy" and "ymmud" end in different characters, so the suffix
    # shared by both renderings starts right after the user content.
    mirrored_generation_prompt = render("ymmud", True)
    for index, (c1, c2) in enumerate(zip(reversed(dummy_no_generation_prompt),
                                         reversed(mirrored_generation_prompt))):
        if c1 != c2:
            if index == 0:
                # Nothing follows the user content. `text[-0:]` would return
                # the WHOLE rendering, so this case must fail explicitly.
                break
            return mirrored_generation_prompt[-index:]
    raise Exception(failure_message)

In [45]:
# Display the response template derived for the Post-OCR recipe with the current tokenizer.
get_response_template(tokenizer, PostOCRCorrectionDatasetRecipe)

'[/INST]'