<a href="https://colab.research.google.com/github/ekrombouts/GenCareAI/blob/work_in_progress/scripts/work_in_progress/420_CarePlan_mobility_TrainFietjeBase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Setup

# Install the required packages
!pip install -q transformers datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset, Dataset, DatasetDict
import time

In [None]:
# Hugging Face Hub dataset id; loaded with `load_dataset` in the data cell.
path_hf_sampc: str = "ekrombouts/Galaxy_SAMPC"

In [None]:
# Base checkpoint to fine-tune.
model_name: str = "BramVanroy/fietje-2"
# Hub repository name the fine-tuned model and tokenizer are pushed to.
model_finetuned: str = "fietje_zorgplan_base"
# Commit message used for both Hub pushes.
commit_message: str = "Trained base model"

In [None]:
## Load model and tokenizer
# Load the base causal LM in bfloat16; device_map="auto" lets transformers
# place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Report the model's memory footprint in gigabytes.
footprint_gb = model.get_memory_footprint() / 1e9
print(f"Memory footprint: {footprint_gb} GB")

In [None]:
## Prepare data
# Download the dataset from the Hub and take its train and validation splits.
dataset = load_dataset(path_hf_sampc)
train_dataset, val_dataset = dataset['train'], dataset['validation']

In [None]:
def truncate_notes_to_fit_prompt(notes, max_length=1800):
    """Clip `notes` to at most `max_length` tokens and return it as text.

    The text is tokenized with truncation, the first encoded sequence is
    decoded back to a string, and special tokens are dropped on the way back.

    Args:
        notes: the report text to shorten.
        max_length: maximum number of tokens to keep.

    Returns:
        The (possibly shortened) notes as a plain string.
    """
    encoding = tokenizer(
        notes,
        return_tensors="np",
        truncation=True,
        max_length=max_length,
    )
    first_sequence_ids = encoding["input_ids"][0]
    return tokenizer.decode(first_sequence_ids, skip_special_tokens=True)

# Function that adds the 'truncated_notes' column (used with Dataset.map).
def add_truncated_notes(example, max_length=1800):
    """Build the 'truncated_notes' field for one dataset row.

    Args:
        example: a single dataset row; its "notes" field is read.
        max_length: token budget for the notes (default 1800). A different
            value can be passed through ``Dataset.map(..., fn_kwargs=...)``.

    Returns:
        ``{"truncated_notes": <text>}``, which ``Dataset.map`` merges into
        the row as a new column.
    """
    truncated_notes = truncate_notes_to_fit_prompt(example["notes"], max_length=max_length)
    return {"truncated_notes": truncated_notes}

# Add the 'truncated_notes' column to both splits (train first, then validation).
train_dataset, val_dataset = (
    split.map(add_truncated_notes) for split in (train_dataset, val_dataset)
)

In [None]:
# Function that builds a prompt from input and output and tokenizes it.
def collate_and_tokenize(examples, max_length=2048):
    """Turn a batch of examples into padded, tokenized causal-LM training rows.

    For every example in the batch, the truncated notes and the reference
    'mobiliteit' text are rendered into one prompt. The tokenizer's EOS token
    is appended so the model learns where the answer ends. Each prompt is then
    tokenized with right truncation and padding to ``max_length``.

    Args:
        examples: a batch from ``Dataset.map(batched=True)``: column name ->
            list of values. Reads "truncated_notes" and "mobiliteit". Every
            row of the batch is processed, whatever the batch size.
        max_length: length every row is padded/truncated to. If a prompt is
            longer, its tail (including the answer and EOS) is cut off.

    Returns:
        The tokenizer output (input_ids, attention_mask, ...) plus "labels":
        a copy of input_ids where padding positions are -100, so padding is
        ignored by the loss.
    """
    prompts = []
    for notes, mobiliteit in zip(examples["truncated_notes"], examples["mobiliteit"]):
        prompt = f'''Lees de volgende rapportages en beschrijf de mobiliteit van de cliënt.

Rapportages:
{notes}

Beschrijf de mobiliteit van de cliënt:
{mobiliteit}
'''
        # Explicit end-of-sequence marker: it lies inside the attention mask,
        # so it keeps its label even though padding below is masked out.
        prompts.append(prompt + tokenizer.eos_token)

    encoded = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_attention_mask=True,
    )

    # The pad token is the EOS token in this notebook, so padding cannot be
    # recognised by token id; use the attention mask instead. -100 is the
    # index the transformers causal-LM loss ignores.
    encoded["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded

# Drop every original column so that only the tokenizer outputs
# (input_ids, attention_mask) and the labels remain. The column list is read
# from each split itself, so a change in the dataset schema cannot make
# `remove_columns` reference a missing column.
columns_to_remove = train_dataset.column_names

# Tokenize the training and validation datasets.
tokenized_dataset_train = train_dataset.map(collate_and_tokenize,
                                            batched=True,
                                            batch_size=1,
                                            remove_columns=columns_to_remove)
tokenized_dataset_val = val_dataset.map(collate_and_tokenize,
                                        batched=True,
                                        batch_size=1,
                                        remove_columns=val_dataset.column_names)

In [None]:
# # Controleer of de tokenisatie correct is
# input_ids = tokenized_dataset_train[5]['input_ids']
# decoded = tokenizer.decode(input_ids, skip_special_tokens=True)
# print(decoded)

In [None]:
## Prepare model
def print_trainable_parameters(model):
    """Print, and return, the number of trainable and total parameters.

    A parameter counts as trainable when its ``requires_grad`` flag is set.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        ``(trainable_params, all_params)`` as integers.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Avoid a ZeroDivisionError for a model that has no parameters.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}%"
    )
    return trainable_params, all_param

# Report trainable vs. total parameter counts before training starts.
print_trainable_parameters(model)

# Trade compute for memory: activations are recomputed during the backward
# pass instead of being kept in memory.
model.gradient_checkpointing_enable()


In [None]:
## Train and save model
# One epoch of full fine-tuning: per-device batch size 1 with 8 gradient
# accumulation steps, bf16, evaluation and checkpointing every 100 steps,
# and the best checkpoint (by eval loss, the default metric) reloaded at the end.
training_args = TrainingArguments(
    output_dir='./results_full',
    report_to='none',
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    warmup_steps=50,
    logging_dir='./logs',
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    # `evaluation_strategy` was renamed to `eval_strategy` and is rejected by
    # recent transformers releases (the notebook installs the latest one).
    eval_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,
    bf16=True,
    learning_rate=5e-5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_val,
    args=training_args,
)


In [None]:
# Disable the KV cache during training to avoid warnings (it is incompatible
# with gradient checkpointing); it is re-enabled below before the model is pushed.
model.config.use_cache = False

start_time = time.perf_counter()  # monotonic clock, suited for durations
trainer.train()
end_time = time.perf_counter()

training_time = end_time - start_time  # total training time in seconds

print(f"Training completed in {training_time} seconds.")

# Restore the KV cache for inference. The config is uploaded together with the
# weights, so leaving use_cache=False would slow down generation with the pushed model.
model.config.use_cache = True

# Push the model to a private Hugging Face Hub repository.
# `token=True` uses the locally stored login token (replaces the deprecated `use_auth_token`).
model.push_to_hub(model_finetuned,
                  token=True,
                  commit_message=commit_message,
                  private=True)

# Push the tokenizer to the same repository.
tokenizer.push_to_hub(model_finetuned,
                      token=True,
                      commit_message=commit_message)

# Release the Colab runtime to stop incurring costs.
from google.colab import runtime
runtime.unassign()