<a href="https://colab.research.google.com/github/ekrombouts/GenCareAI/blob/work_in_progress/scripts/work_in_progress/430_CarePlan_mobility_InferenceFietjeBase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required packages
!pip install -q transformers datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset


In [None]:
# Fine-tuned checkpoint used for inference, and the repository whose tokenizer
# is paired with it (the base model). Replace the checkpoint path if necessary.
model_name = "ekrombouts/fietje_zorgplan_base"
model_tokenizer = "BramVanroy/fietje-2"

# Tokenizer from the base-model repository; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer)
tokenizer.pad_token = tokenizer.eos_token

# Weights are loaded in bfloat16 and placed on the available device(s)
# automatically via device_map='auto'.
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', torch_dtype=torch.bfloat16
)

# Keep the key/value cache enabled for generation and switch to eval mode.
model.config.use_cache = True
model.eval()

In [None]:
# Fetch the Galaxy_SAMPC dataset from the Hugging Face Hub; only its
# validation split is used in this inference notebook.
dataset = load_dataset(path="ekrombouts/Galaxy_SAMPC")
val_dataset = dataset['validation']

# Define the tokenizer functions
def truncate_notes_to_fit_prompt(notes, max_length=1800):
    """
    Truncate ``notes`` to at most ``max_length`` tokens and return it as text.

    The text is tokenized with the notebook-level ``tokenizer`` using
    ``truncation=True``. With the tokenizer's default right-side truncation,
    the beginning of the notes is kept. The kept token ids are then decoded
    back to a string with special tokens removed.

    Args:
        notes: The raw notes text. ``None`` or ``""`` yields ``""``.
        max_length: Maximum number of tokens to keep (default 1800).

    Returns:
        The (possibly) truncated notes as a string.
    """
    # Missing notes would make the tokenizer raise; treat them as empty text.
    if not notes:
        return ""

    # A plain list of ids is enough for decode(); no numpy/torch tensor needed.
    input_ids = tokenizer(notes, truncation=True, max_length=max_length)["input_ids"]

    truncated_notes = tokenizer.decode(input_ids, skip_special_tokens=True)

    # With byte-level BPE tokenizers, cutting the sequence can split a
    # multi-byte UTF-8 character (e.g. "ë"), which decodes to U+FFFD
    # replacement characters at the end; drop them.
    return truncated_notes.rstrip("\ufffd")

# Function to add the 'truncated_notes' column
def add_truncated_notes(example, max_length=1800):
    """Return a ``truncated_notes`` field for a single dataset example.

    Intended for ``Dataset.map``. It reads ``example["notes"]`` and truncates it
    to at most ``max_length`` tokens with ``truncate_notes_to_fit_prompt``. It
    returns ``{"truncated_notes": ...}``, which ``map`` adds as a new column.

    ``max_length`` defaults to the previously hard-coded budget of 1800 tokens.
    It can be overridden with
    ``val_dataset.map(add_truncated_notes, fn_kwargs={"max_length": n})``.
    """
    truncated_notes = truncate_notes_to_fit_prompt(example["notes"], max_length=max_length)
    return {"truncated_notes": truncated_notes}

# Apply add_truncated_notes to every validation example; map() returns a new
# dataset that includes the extra 'truncated_notes' column.
val_dataset = val_dataset.map(function=add_truncated_notes)

In [None]:
# Pick one validation example to inspect (change sample_index to try another).
sample_index = 1
sample = val_dataset[sample_index]

# Build the prompt from the already-truncated notes and keep the reference
# 'mobiliteit' text for comparison with the model output.
notes = sample['truncated_notes']
mobiliteit_actual = sample['mobiliteit']

prompt = f'''Lees de volgende rapportages en beschrijf de mobiliteit van de cliënt.

Rapportages:
{notes}

Beschrijf de mobiliteit van de cliënt:
'''

# Tokenize the prompt once and take both the ids and the attention mask from
# the same encoding (a single prompt needs no padding).
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt.input_ids.to(model.device)
attention_mask = encoded_prompt.attention_mask.to(model.device)


In [None]:

# Generate a sampled continuation of the prompt, passing the attention mask.
output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=150,
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id
)

# Full decoded sequence (prompt + continuation); printed by a later cell.
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

# For a causal LM, generate() returns the prompt tokens followed by the new
# tokens, so decode only the tokens after the prompt. Slicing the decoded text
# by len(prompt) is unreliable because decode() need not reproduce the prompt
# string character-for-character.
prompt_token_count = input_ids.shape[-1]
generated_response = tokenizer.decode(
    output[0][prompt_token_count:], skip_special_tokens=True
).strip()

# Print the generated response and the actual 'mobiliteit' for comparison
print("Generated Mobiliteit:")
print(generated_response)
print("\nActual Mobiliteit:")
print(mobiliteit_actual)

In [None]:
# Show the complete decoded output (prompt followed by the model's continuation)
# and keep it available as full_response.
print(full_response := generated_text)
