In [None]:
!pip install datasets==3.2.0 transformers==4.47.1

In [2]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, LlamaForCausalLM, Trainer, TrainingArguments

In [None]:
def train_val_test_split(dataset, split_weights=(0.8, 0.1, 0.1), seed=42):
    """Split ``dataset['train']`` into train / validation / test subsets.

    Args:
        dataset: Mapping whose ``"train"`` entry exposes ``train_test_split``
            (for example the object returned by ``load_dataset``).
        split_weights: Exactly three strictly positive relative weights, in
            the order train, validation, test. They are normalised by their
            sum, so ``[7, 2, 1]`` and ``[0.7, 0.2, 0.1]`` are equivalent.
        seed: Seed forwarded to both shuffling splits so that the partition
            is identical after a kernel restart. Pass ``None`` to get a
            different random partition on every call.

    Returns:
        A ``DatasetDict`` with the keys ``"train"``, ``"validation"`` and
        ``"test"``.

    Raises:
        ValueError: If ``split_weights`` does not hold exactly three
            strictly positive numbers.
    """
    weights = list(split_weights)
    if len(weights) != 3:
        raise ValueError(
            f"split_weights must contain exactly 3 values (train, val, test), got {len(weights)}"
        )
    # A zero weight would request an empty subset (and a zero val+test weight
    # would divide by zero below), so reject it upfront with a clear message.
    if any(weight <= 0 for weight in weights):
        raise ValueError(f"split_weights must all be strictly positive, got {weights}")

    total = sum(weights)
    train_weights = weights[0] / total
    val_weights = weights[1] / total
    test_weights = weights[2] / total

    # Share of the held-out part (validation + test) that goes to validation.
    val_size = val_weights / (val_weights + test_weights)

    # First carve out the training subset, then divide the remainder.
    train_test_split = dataset['train'].train_test_split(train_size=train_weights, seed=seed)
    val_test_split = train_test_split['test'].train_test_split(train_size=val_size, seed=seed)

    dataset_split = DatasetDict({
        "train": train_test_split["train"],
        "validation": val_test_split["train"],
        "test": val_test_split["test"]
    })

    print(f"Train : {train_weights:.0%} - Val : {val_weights:.0%} - Test : {test_weights:.0%}")
    return dataset_split

In [None]:
# Dataset identifier handed to `load_dataset` in the loading cell.
path_dataset = "lavita/ChatDoctor-HealthCareMagic-100k"
# Checkpoint identifier used for both the tokenizer and the model weights.
checkpoint_model = "meta-llama/Llama-3.1-8B"

In [None]:
# Load the dataset named by `path_dataset`; the split helper later reads its
# "train" entry.
dataset = load_dataset(path_dataset)

In [None]:
# Tokenizer and model are loaded from the same checkpoint so their
# vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_model)
# `tokenize` pads every example to a fixed length, which requires a pad token;
# fall back to the EOS token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# device_map="auto" delegates device placement to the library.
# NOTE(review): no dtype is passed here — confirm the default precision fits
# the available memory for this checkpoint.
model = LlamaForCausalLM.from_pretrained(checkpoint_model, device_map="auto")

In [None]:
# Relative weights 7:2:1 are normalised by the helper, i.e. 70% / 20% / 10%.
# NOTE: this cell rebinds `dataset`, and the helper always splits
# dataset["train"], so running the cell a second time re-splits the already
# reduced training subset. Re-run the loading cell first if you need to redo it.
dataset = train_val_test_split(dataset, [7,2,1])

In [None]:
def tokenize(example, max_length=512):
    """Turn one record into fixed-length causal-LM training features.

    The prompt (instruction + input) and the answer (output + EOS) are
    concatenated into a single sequence, because a causal LM computes its
    loss by comparing position ``i`` of ``labels`` with the prediction made
    from ``input_ids[:i]``; the two lists therefore have to describe the same
    token sequence. Prompt and padding positions are set to ``-100`` in
    ``labels`` so that only answer tokens contribute to the loss.

    Args:
        example: Mapping with the string fields ``"instruction"``,
            ``"input"`` and ``"output"``.
        max_length: Length every returned list is truncated or right-padded
            to.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``,
        each a list of ``max_length`` integers. If the prompt alone fills
        ``max_length`` the labels are entirely ``-100`` and the example
        contributes no loss.
    """
    ignore_index = -100  # label value skipped by the cross-entropy loss

    prompt = f"{example['instruction']}\n\n{example['input']}\n\n"
    # The explicit EOS teaches the model where the answer ends.
    answer = f"{example['output']}{tokenizer.eos_token}"

    # Tokenize the two parts separately so the prompt/answer boundary is
    # known exactly; special tokens (e.g. BOS) are added only once, in front.
    prompt_ids = list(tokenizer(prompt)["input_ids"])
    answer_ids = list(tokenizer(answer, add_special_tokens=False)["input_ids"])

    input_ids = (prompt_ids + answer_ids)[:max_length]
    labels = ([ignore_index] * len(prompt_ids) + answer_ids)[:max_length]
    attention_mask = [1] * len(input_ids)

    # Right-pad. The mask is required because the pad token may be the same
    # id as EOS, so padding cannot be inferred from the ids alone.
    n_pad = max_length - len(input_ids)
    input_ids = input_ids + [tokenizer.pad_token_id] * n_pad
    attention_mask = attention_mask + [0] * n_pad
    labels = labels + [ignore_index] * n_pad

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

In [None]:
# `map` returns a new object instead of modifying `dataset` in place, so the
# result has to be bound to a name or the tokenized columns are lost.
dataset = dataset.map(tokenize)
dataset