# Setup

In [None]:
# Install the YAML magic
%pip install yamlmagic
# Register the %%yaml cell magic used by the training-parameters cell below
%load_ext yamlmagic

In [None]:
# Install the Kubeflow SDK (this can be removed once the latest version is included into workbench images)
# Provides the `kubeflow.training` package (TrainingClient) used to submit the PyTorchJob further down.
%pip install git+https://github.com/kubeflow/trainer.git@release-1.9#subdirectory=sdk/python

# Training Configuration
Edit the following training parameters:

In [None]:
%%yaml parameters

# Model
model_name_or_path: facebook/bart-base    # only works with Encoder-Decoder models
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager                # one of eager (default), sdpa or flash_attention_2
use_liger: false                          # use Liger kernels

# PEFT / LoRA (Apply to Generator Model)
use_peft: false
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules: ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]   # Ensure these match your generator model
lora_modules_to_save: []

# QLoRA (BitsAndBytes) (Apply to Generator Model)
load_in_4bit: false                       # use 4 bit precision for the base model (only with LoRA)
load_in_8bit: false                       # use 8 bit precision for the base model (only with LoRA)

# Dataset
dataset_name: facebook/wiki_dpr
dataset_config: main                      # name of the dataset configuration
dataset_train_split: train                # dataset split to use for training (for RAG generated data)
dataset_test_split: test                  # dataset split to use for evaluation (for RAG generated data)
dataset_kwargs:
    add_special_tokens: false               # template with special tokens
    append_concat_token: false              # add additional separator token

# SFT (These parameters will now apply to the RagModel's training)
max_seq_length: 1024                      # max sequence length for model and packing of the dataset
dataset_batch_size: 1000                  # samples to tokenize per batch (for initial data processing)
packing: false                            # Packing is generally not used directly with RagModel training in the same way as SFT

# Training
num_train_epochs: 3                       # number of training epochs
remove_unused_columns: false
label_smoothing_factor: 0.1               # NOTE: not forwarded to the Trainer's TrainingArguments by main()

per_device_train_batch_size: 1            # Batch size per device during training
per_device_eval_batch_size: 1             # Batch size for evaluation
auto_find_batch_size: false               # find a batch size that fits into memory automatically
eval_strategy: epoch                      # evaluate every epoch

bf16: true                                # use bf16 16-bit (mixed) precision
tf32: true                                # use tf32 precision

learning_rate: 4.0e-6                     # Initial learning rate for RAG model training
warmup_steps: 150                         # steps for a linear warmup from 0 to `learning_rate`
lr_scheduler_type: cosine                 # learning rate scheduler (see transformers.SchedulerType)

optim: adamw_torch_fused                  # optimizer (see transformers.OptimizerNames)
max_grad_norm: 1.0                        # max gradient norm
seed: 42

gradient_accumulation_steps: 8            # Increase for smaller per_device_train_batch_size
gradient_checkpointing: false             # use gradient checkpointing to save memory
gradient_checkpointing_kwargs:
    use_reentrant: false

# FSDP
fsdp: "full_shard auto_wrap"              # add offload if not enough GPU memory
fsdp_config:
    activation_checkpointing: true
    cpu_ram_efficient_loading: false
    sync_module_states: true
    use_orig_params: true
    limit_all_gathers: false

# Checkpointing
save_strategy: epoch                      # save checkpoint every epoch
save_total_limit: 1                       # limit the total amount of checkpoints
resume_from_checkpoint: false             # load the last checkpoint in output_dir and resume from it

# Logging
log_level: warning                        # logging level (see transformers.logging)
logging_strategy: steps
logging_steps: 1                          # log every N steps
report_to:
- tensorboard                             # report metrics to tensorboard

output_dir: /mnt/shared/fine_tuned_rag_model

# Feast Setup

In [None]:
%%bash
# Stop at the first failed install instead of only reporting the last command's exit status.
set -euo pipefail
# Quote the extras so the shell cannot glob-expand the brackets.
pip install --quiet "feast[milvus]" sentence-transformers datasets faiss-cpu
# Same pins as the ones installed inside the training job.
pip install bigtree==0.19.2
pip install marshmallow==3.10.0
# Feast from the master branch, installed last (presumably to take precedence over the release above).
pip install git+https://github.com/feast-dev/feast.git@master

In [None]:
from datasets import load_dataset

# Take only the first 1% of the Wikipedia DPR training split to keep the demo small.
# `trust_remote_code=True` executes the dataset's loading script from the Hub.
dataset = load_dataset(
    path="facebook/wiki_dpr",
    name="psgs_w100.nq.exact",
    split="train[:1%]",
    with_index=False,
    trust_remote_code=True,
)

In [None]:
def chunk_dataset(examples, chunk_size=100, overlap=20, max_chars=100, min_words=20):
    """Split every passage of a batch into overlapping word windows.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: batch dict with parallel ``'id'``, ``'title'`` and ``'text'`` lists.
        chunk_size: number of words per window.
        overlap: number of words shared by consecutive windows; must be < ``chunk_size``.
        max_chars: truncate each chunk's text to at most this many characters
            (``None`` disables truncation).
        min_words: windows with fewer words than this are dropped.

    Returns:
        dict with ``'id'`` (``"<source id>_<word offset>"``), ``'title'`` and ``'text'`` lists,
        one entry per kept chunk.

    Raises:
        ValueError: if ``overlap >= chunk_size`` (the window would never advance).
    """
    step = chunk_size - overlap
    if step <= 0:
        # A zero step crashes `range`, a negative one silently yields no chunks at all.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    all_ids, all_titles, all_chunks = [], [], []
    for source_id, title, text in zip(examples['id'], examples['title'], examples['text']):
        words = text.split()
        for start in range(0, len(words), step):
            chunk_words = words[start:start + chunk_size]
            if len(chunk_words) < min_words:
                continue
            chunk_text = ' '.join(chunk_words)
            if max_chars is not None:
                # Slicing is a no-op when the text is already short enough.
                chunk_text = chunk_text[:max_chars]
            all_ids.append(f"{source_id}_{start}")  # Unique ID for the chunk
            all_titles.append(title)
            all_chunks.append(chunk_text)

    return {'id': all_ids, 'title': all_titles, 'text': all_chunks}


# Chunk the passages batch-wise; the original columns are dropped so that only the
# chunk-level 'id' / 'title' / 'text' produced by `chunk_dataset` remain.
chunked_dataset = dataset.map(
    function=chunk_dataset,
    batched=True,
    num_proc=1,
    remove_columns=dataset.column_names,
)

In [None]:
import torch
from sentence_transformers import SentenceTransformer

# Passage texts produced by the chunking step; these are what gets embedded and stored in Feast.
sentences = chunked_dataset["text"]
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# Fall back to CPU (much slower) so the cell also runs on workbenches without a GPU.
embedding_device = "cuda" if torch.cuda.is_available() else "cpu"
embeddings = embedding_model.encode(sentences, show_progress_bar=True, batch_size=64, device=embedding_device)

print(f"Generated embeddings of shape: {embeddings.shape}")

In [None]:
import pandas as pd
from datetime import datetime, timezone

# Written next to the Feast repo definitions; the online-store cell reads it back from there.
PARQUET_PATH = "feature_repo/wiki_dpr.parquet"

# One ingestion time shared by every row. It fills the `event_timestamp` column, which the
# Feast feature view presumably uses as its timestamp field (confirm in feature_repo).
ingestion_time = datetime.now(timezone.utc)

# Create DataFrame
df = pd.DataFrame({
    "passage_id": list(range(len(sentences))),
    "passage_text": sentences,
    "embedding": pd.Series(
        [embedding.tolist() for embedding in embeddings],
        dtype=object
    ),
    "event_timestamp": [ingestion_time] * len(sentences),
})

# Length of every embedding (or a type name if an entry is not a list, e.g. a missing one).
embedding_lengths = df["embedding"].apply(lambda x: len(x) if isinstance(x, list) else str(type(x)))
assert len(df) > 0, "no passages to save"
assert embedding_lengths.nunique() == 1, f"inconsistent embedding entries: {embedding_lengths.unique()}"

print("DataFrame Info:")
print(df.head())
print(embedding_lengths.value_counts())  # Check lengths

# Save to Parquet
df.to_parquet(PARQUET_PATH, index=False)
print(f"Saved to {PARQUET_PATH}")

In [None]:
# Run the following Feast cells from inside the Feast repository: they use paths relative to it.
%cd feature_repo

In [None]:
%feast apply

In [None]:
from feast import FeatureStore
import pandas as pd

# The Feast repository is the current working directory.
store = FeatureStore(repo_path=".")
df = pd.read_parquet("./wiki_dpr.parquet")

# Push the passages in fixed-size batches so that each online-store write stays bounded.
batch_size = 10000
total_batches = -(-len(df) // batch_size)  # ceiling division

for batch_number, start in enumerate(range(0, len(df), batch_size), start=1):
    print(f"Writing chunk {batch_number}/{total_batches}...")
    store.write_to_online_store(
        feature_view_name='wiki_passages',
        df=df.iloc[start:start + batch_size],
    )
    print("Chunk written successfully.")

print("All data written to online store.")

# Training Loop

Review the training function below. Depending on the generator model you want to fine-tune, you may need to adapt it (for example, the LoRA target modules or how the synthetic Q&A pairs are tokenized):

#### Copy the example to shared storage

Copy the `kfto-sft-feast-rag` folder into the shared storage so that the training job can access it: the `FeastRAGRetriever` requires the `feature_repo` directory during both training and inference. Also copy the `feast_rag_retriever.py` script so that it can be imported by the training job.

In [None]:
%%bash
# Abort on the first failing command so that a missing source path is not silently ignored.
set -euo pipefail
# Copy this example (including feature_repo/) to the shared storage, presumably the same PVC
# that the training pods mount at /mnt/shared.
cp -r "$HOME/distributed-workloads/examples/kfto-sft-feast-rag" "$HOME/shared/"
# The retriever module lives in the sibling kfto_feast_rag example; place it next to feature_repo/.
cp "$HOME/distributed-workloads/examples/kfto_feast_rag/feast_rag_retriever.py" "$HOME/shared/kfto-sft-feast-rag/"

In [None]:
def main(parameters):
    """Fine-tune a RAG model (Feast-backed retriever + seq2seq generator) in a training worker.

    This function is handed to the Kubeflow ``TrainingClient`` as ``train_func`` and runs inside
    the training container, which is why every import (and the pip install) happens in here.

    Steps:
        1. Install Feast / RAG dependencies and parse ``parameters`` with ``TrlParser``.
        2. Load the question-encoder tokenizer/model and the generator (optionally LoRA/QLoRA).
        3. Build a ``FeastRAGRetriever`` and a ``RagSequenceForGeneration`` model.
        4. Generate (or load from the shared cache) a synthetic Q&A dataset.
        5. Train with ``Trainer`` and save the model, its components and the ``RagTokenizer``.

    Args:
        parameters: dict of training parameters (the ``%%yaml parameters`` cell). It is parsed
            into ``ScriptArguments``, ``SFTConfig`` and ``ModelConfig``; ``model_name_or_path``
            selects the generator model.

    Raises:
        Exception: re-raised if ``generated_questions.txt`` cannot be read or if the
            pre-training forward pass fails.
        RuntimeError: if no synthetic Q&A sample could be generated.
    """
    import subprocess, sys
    # --- Install necessary packages ---
    # This ensures Feast and other dependencies are available inside the container
    print("Installing Feast and other RAG dependencies...")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "--quiet",
        "feast[milvus]", "sentence-transformers", "datasets",
        "bigtree==0.19.2", "marshmallow==3.10.0", "protobuf>=3.10.0",
        "git+https://github.com/feast-dev/feast.git@master", "faiss-cpu",
    ])
    print("Feast and other RAG dependencies installed.")
    from pathlib import Path
    import os
    import random
    from datasets import Dataset
    from transformers import (
        set_seed,
        RagSequenceForGeneration,
        TrainingArguments,
        Trainer,
        RagTokenizer,
        default_data_collator,
        RagConfig,
        AutoTokenizer,
        AutoModelForSeq2SeqLM,
        AutoModel,
    )
    from peft import get_peft_model, prepare_model_for_kbit_training
    import torch

    from trl import (
        ModelConfig,
        ScriptArguments,
        SFTConfig,
        TrlParser,
        get_kbit_device_map,
        get_peft_config,
        get_quantization_config,
    )

    # This is required for `feast_rag_retriever` to be found (This is temporary until Feast RAG Retriever is part of Feast SDK)
    CUSTOM_MODULES_PATH = "/mnt/shared/kfto-sft-feast-rag"
    sys.path.append(CUSTOM_MODULES_PATH)
    print(f"Added {CUSTOM_MODULES_PATH} to sys.path for custom module imports.")
    from feast_rag_retriever import FeastRAGRetriever, FeastIndex
    from feature_repo.ragproject_repo import wiki_passage_feature_view

    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args_trl, model_args = parser.parse_dict(parameters)

    # The RAG model is trained with the plain `Trainer`, so copy the relevant SFTConfig fields
    # into standard TrainingArguments. `label_smoothing_factor` is deliberately not forwarded:
    # the Trainer's label smoother expects logits aligned with `labels`, whereas
    # RagSequenceForGeneration produces per-retrieved-document logits (confirm before enabling).
    training_args = TrainingArguments(
        output_dir=training_args_trl.output_dir,
        num_train_epochs=training_args_trl.num_train_epochs,
        per_device_train_batch_size=training_args_trl.per_device_train_batch_size,
        per_device_eval_batch_size=training_args_trl.per_device_eval_batch_size,
        auto_find_batch_size=training_args_trl.auto_find_batch_size,
        gradient_accumulation_steps=training_args_trl.gradient_accumulation_steps,
        gradient_checkpointing=training_args_trl.gradient_checkpointing,
        gradient_checkpointing_kwargs=training_args_trl.gradient_checkpointing_kwargs,
        learning_rate=training_args_trl.learning_rate,
        warmup_steps=training_args_trl.warmup_steps,
        lr_scheduler_type=training_args_trl.lr_scheduler_type,
        optim=training_args_trl.optim,
        max_grad_norm=training_args_trl.max_grad_norm,
        seed=training_args_trl.seed,
        bf16=training_args_trl.bf16,
        tf32=training_args_trl.tf32,
        eval_strategy=training_args_trl.eval_strategy,
        save_strategy=training_args_trl.save_strategy,
        save_total_limit=training_args_trl.save_total_limit,
        log_level=training_args_trl.log_level,
        logging_strategy=training_args_trl.logging_strategy,
        logging_steps=training_args_trl.logging_steps,
        report_to=training_args_trl.report_to,
        fsdp=getattr(training_args_trl, "fsdp", None),
        fsdp_config=getattr(training_args_trl, "fsdp_config", None),
        resume_from_checkpoint=training_args_trl.resume_from_checkpoint,
        remove_unused_columns=training_args_trl.remove_unused_columns,
    )

    set_seed(training_args.seed)
    # Device assigned to this process by the distributed setup (e.g. cuda:<local_rank>).
    device = training_args.device

    # --- Load Models for RAG ---
    # Question Encoder
    question_encoder_model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
    question_encoder_tokenizer = AutoTokenizer.from_pretrained(question_encoder_model_name_or_path, trust_remote_code=model_args.trust_remote_code)
    question_encoder_model = AutoModel.from_pretrained(
        question_encoder_model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=model_args.torch_dtype
    )

    # Generator Model to be fine-tuned
    generator_model_name_or_path = parameters["model_name_or_path"]
    generator_tokenizer = AutoTokenizer.from_pretrained(
        generator_model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=True
    )

    if generator_tokenizer.pad_token is None:
        if generator_tokenizer.eos_token is not None:
            print(f"Generator tokenizer pad_token not set. Using eos_token ({generator_tokenizer.eos_token}) as pad_token.")
            generator_tokenizer.pad_token = generator_tokenizer.eos_token
        else:
            print("Warning: Generator tokenizer has no pad_token or eos_token. Adding a new pad token.")
            generator_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Honour `load_in_4bit` / `load_in_8bit` (QLoRA). None when neither flag is set, in which
    # case the model is loaded exactly as without quantization.
    quantization_config = get_quantization_config(model_args)

    generator_model = AutoModelForSeq2SeqLM.from_pretrained(
        generator_model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        quantization_config=quantization_config,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
    )

    if model_args.use_peft:
        print(f"PEFT/LoRA: Ensure 'lora_target_modules' in your parameters ({parameters.get('lora_target_modules')}) are correct for the new generator model '{generator_model_name_or_path}'.")
        if quantization_config is not None:
            # Only k-bit base models need this step; on a full-precision model it would just
            # upcast the bf16/fp16 weights to fp32.
            generator_model = prepare_model_for_kbit_training(generator_model)
        peft_config = get_peft_config(model_args)
        generator_model = get_peft_model(generator_model, peft_config)
        print("PEFT setup for generator model completed.")

    store_path = CUSTOM_MODULES_PATH+"/feature_repo"

    # --- Initialize FeastRAGRetriever ---
    feast_index = FeastIndex()
    rag_top_k = 5

    question_encoder_config = {
        "model_type": "dpr",
        "hidden_size": 384,
        "vocab_size": question_encoder_tokenizer.vocab_size,
        "num_hidden_layers": 6,
        "num_attention_heads": 12,
        "projection_dim": 0,
        "torch_dtype": model_args.torch_dtype
    }

    rag_config = RagConfig(
        question_encoder=question_encoder_config,
        generator=generator_model.config.to_dict(),
        index_name="custom",
        index={"index_name": "feast_dummy_index", "custom_type": "FeastIndex"}
    )

    features_to_retrieve = [
        "wiki_passages:passage_text",
        "wiki_passages:embedding",
        "wiki_passages:passage_id",
    ]

    rag_retriever = FeastRAGRetriever(
        question_encoder_tokenizer=question_encoder_tokenizer,
        question_encoder=question_encoder_model,
        generator_tokenizer=generator_tokenizer,
        generator_model=generator_model,
        feast_repo_path=store_path,
        feature_view=wiki_passage_feature_view,
        features=features_to_retrieve,
        search_type="vector",
        config=rag_config,
        index=feast_index,
        query_encoder_model="all-MiniLM-L6-v2",
    )

    # --- Initialize the RagModel for fine-tuning ---
    model = RagSequenceForGeneration(
        config=rag_config,
        generator=generator_model,
        retriever=rag_retriever
    )
    print("RAG Model initialized.")

    # Move the model (including the generator shared with the retriever) to this process's device
    model = model.to(device)

    # --- Dataset Generation for Fine-tuning ---
    print("Preparing training data (Q&A pairs) with caching...")
    # Cache directory for synthetic data generated in a previous run to save time
    synthetic_data_cache_dir = "/mnt/shared/synthetic_data_cache/rag_synthetic_qa_dataset"
    Path(synthetic_data_cache_dir).mkdir(parents=True, exist_ok=True)
    cached_dataset_path = os.path.join(synthetic_data_cache_dir, "rag_synthetic_qa_dataset")

    # The cache lives on the storage shared by all workers: let the global main process generate
    # and save it first; the other ranks then load the very same dataset from disk instead of
    # racing to write it (and possibly ending up with different samples per rank). Waiting ranks
    # block in a collective barrier, so generation must finish within the distributed timeout
    # (`ddp_timeout`).
    with training_args.main_process_first(local=False, desc="Synthetic Q&A dataset generation"):
        if Path(cached_dataset_path).exists():
            print(f"Loading synthetic data from cache: {cached_dataset_path}")
            synthetic_dataset = Dataset.load_from_disk(cached_dataset_path)
        else:
            print("Generating training data for RAG fine-tuning (Q&A pairs)...")
            synthetic_data = []
            try:
                with open("/mnt/shared/kfto-sft-feast-rag/generated_questions.txt", 'r') as file:
                    example_queries = [line.strip().strip(',"') for line in file if line.strip()]
            except Exception:
                print("Error reading 'generated_questions.txt'")
                raise

            for i, query in enumerate(example_queries):
                print(f"Generating answer for query {i+1}/{len(example_queries)}: '{query}'")
                try:
                    rag_answer = rag_retriever.generate_answer(query, top_k=rag_top_k, max_new_tokens=380)
                    if not rag_answer or not rag_answer.strip():
                        print(f"Skipping empty answer for query: '{query}'")
                        continue
                    print(f"Generated answer: {rag_answer}")
                    # Tokenize question for the RAG model's question encoder
                    tokenized_question = question_encoder_tokenizer(
                        query,
                        truncation=True,
                        max_length=64,
                        padding="max_length",
                        return_tensors="np",
                    )

                    # Tokenize answer for the RAG model's generator (as labels)
                    tokenized_answer_for_labels = generator_tokenizer(
                        text_target=rag_answer,
                        truncation=True,
                        max_length=32,
                        padding="max_length",
                        return_tensors="np",
                    )

                    synthetic_data.append({
                        "input_ids": tokenized_question["input_ids"].flatten().tolist(),
                        "attention_mask": tokenized_question["attention_mask"].flatten().tolist(),
                        "labels": tokenized_answer_for_labels["input_ids"].flatten().tolist(),
                    })
                except Exception as e:
                    print(f"Error generating answer for query '{query}': {e}. Skipping this query.")

            if not synthetic_data:
                # Never cache an empty dataset: later runs would load it and fail when splitting.
                raise RuntimeError(
                    "No synthetic Q&A samples were generated; check the generation errors above."
                )

            synthetic_dataset = Dataset.from_list(synthetic_data)
            print(f"Synthetic data generation complete. Saving to cache: {cached_dataset_path}")
            synthetic_dataset.save_to_disk(cached_dataset_path)
            print("Synthetic data saved to cache.")

    # Split once; the fixed seed makes the split identical on every rank.
    dataset_splits = synthetic_dataset.train_test_split(test_size=0.1, seed=training_args.seed)
    train_dataset = dataset_splits[script_args.dataset_train_split]
    test_dataset = dataset_splits[script_args.dataset_test_split]

    with training_args.main_process_first(desc="Log few samples from the training set"):
        # Guard against tiny datasets: random.sample raises if asked for more items than exist.
        for index in random.sample(range(len(train_dataset)), min(2, len(train_dataset))):
            print(f"Sample Input IDs: {train_dataset[index]['input_ids'][:50]}...")
            print(f"Sample Labels: {train_dataset[index]['labels'][:50]}...")

    # --- Training the RagModel ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=generator_tokenizer,
        data_collator=default_data_collator,
    )

    if trainer.accelerator.is_main_process:
        if hasattr(trainer.model.generator, "print_trainable_parameters"):
            print("Trainable parameters for PEFT-enabled generator:")
            trainer.model.generator.print_trainable_parameters()
        else:
            print("Trainer model does not have 'print_trainable_parameters' method on its generator.")

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint

    # Ensure RAG model forward pass works before training
    print("Test forward pass prior to training")
    example_query_for_test = "What did Socrates say in his trial?"
    example_answer_for_test = "Socrates said wise things."
    question_tok = question_encoder_tokenizer(example_query_for_test, return_tensors="pt", padding="longest", truncation=True, max_length=128)
    label_tok = generator_tokenizer(text_target=example_answer_for_test, return_tensors="pt", padding="longest", truncation=True, max_length=64)

    test_input = {
        "input_ids": question_tok["input_ids"].to(device),
        "attention_mask": question_tok["attention_mask"].to(device),
        "labels": label_tok["input_ids"].to(device)
    }
    try:
        with torch.no_grad():
            model(**test_input)
        print("Test forward pass successful")
    except Exception as e:
        print(f"Forward pass failed: {e}")
        raise

    print("Starting RAG model training...")
    trainer.train(resume_from_checkpoint=checkpoint)
    print("RAG model training completed.")


    # --- Save the Fine-tuned RAG Model ---
    final_save_path = os.path.join(training_args.output_dir, "final_model_for_inference")

    print(f"Saving main model weights and RagConfig to: {final_save_path}")
    trainer.save_model(final_save_path)

    if trainer.args.process_index == 0:
        print("Manually saving component models and tokenizers to subdirectories...")

        # Unwrap the model to get access to its components
        unwrapped_model = trainer.accelerator.unwrap_model(trainer.model)

        # Save the question encoder to its subdirectory
        qe_path = os.path.join(final_save_path, "question_encoder")
        unwrapped_model.rag.question_encoder.save_pretrained(qe_path)
        print(f"Question encoder component saved to {qe_path}")

        # Save the generator to its subdirectory
        gen_path = os.path.join(final_save_path, "generator")
        unwrapped_model.rag.generator.save_pretrained(gen_path)
        print(f"Generator component saved to {gen_path}")

        # Save the RagTokenizer (question_encoder_tokenizer/ and generator_tokenizer/ subfolders)
        rag_tokenizer_to_save = RagTokenizer(
            question_encoder=question_encoder_tokenizer,
            generator=generator_tokenizer
        )
        rag_tokenizer_to_save.save_pretrained(final_save_path)
        print(f"RagTokenizer components saved to subfolders in {final_save_path}")

    # Wait for all processes to finish before exiting
    if trainer.is_fsdp_enabled or trainer.args.world_size > 1:
        torch.distributed.barrier()

    print(f"\nTraining and saving complete. Your final, inference-ready model is at: {final_save_path}")

# Training Client

Configure the SDK client by providing the cluster API server URL and an authentication token:

In [None]:
from kubernetes import client as k8s_client
from kubeflow.training import TrainingClient
from kubeflow.training.models import V1Volume, V1VolumeMount, V1PersistentVolumeClaimVolumeSource

# Fill in your cluster's API server URL and a bearer token.
api_server = ""
token = ""

configuration = k8s_client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}
# Un-comment if your cluster API server uses a self-signed certificate or an un-trusted CA
# configuration.verify_ssl = False
api_client = k8s_client.ApiClient(configuration)
# The job cells below use `client`. Importing the Kubernetes module as `k8s_client` keeps this
# assignment from shadowing it, so the cell can be re-run safely.
client = TrainingClient(client_configuration=api_client.configuration)

# Training Job
You're now almost ready to create the training job:
* Set the `HF_TOKEN` environment variable to your Hugging Face token if you fine-tune a gated model
* Check the number of worker nodes
* Amend the resources per worker according to the job requirements
* Update the PVC name to the one you've attached to the workbench if needed

In [None]:
# Submit `main` as a distributed PyTorchJob named "sft-rag"; `parameters` (the %%yaml cell) is passed to it.
client.create_job(
    job_kind="PyTorchJob",
    name="sft-rag",
    train_func=main,                 # training function executed in the worker pods
    num_workers=3,                   # number of worker pods
    num_procs_per_worker="1",        # processes per worker (one per GPU requested below)
    resources_per_worker={
        "nvidia.com/gpu": 1,
        "memory": "64Gi",
        "cpu": 8,
    },
    base_image="quay.io/modh/training:py311-cuda124-torch251",
    env_vars={
        "CUDA_LAUNCH_BLOCKING": "1",  # synchronous CUDA calls: clearer errors, slower training
        "HF_HOME": "/mnt/shared/.cache",  # Hugging Face cache on the shared volume
        "HF_TOKEN": "",  # set to your Hugging Face token for gated models
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "PYTORCH_HIP_ALLOC_CONF": "expandable_segments:True",
        "NCCL_DEBUG": "INFO",  # verbose NCCL logging
    },
    parameters=parameters,
    # The "shared" PVC is mounted at /mnt/shared, the prefix of every path used inside `main`.
    volumes=[
        V1Volume(name="shared",
                 persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(claim_name="shared")),
    ],
    volume_mounts=[
        V1VolumeMount(name="shared", mount_path="/mnt/shared"),
    ],
)

Once the training job is created, you can follow its progress:

In [None]:
# Stream (follow) the logs of the "sft-rag" PyTorchJob.
client.get_job_logs(name="sft-rag", follow=True, job_kind="PyTorchJob")

# TensorBoard

You can track your job runs and visualize the training metrics with TensorBoard:

In [None]:
import os

# Serve TensorBoard through the workbench proxy; NB_PREFIX is expected to be set in the workbench.
notebook_prefix = os.environ["NB_PREFIX"]
os.environ["TENSORBOARD_PROXY_URL"] = f"{notebook_prefix}/proxy/6006/"

In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic).
%load_ext tensorboard

In [None]:
# Scans the whole shared volume; the training job's output_dir (/mnt/shared/fine_tuned_rag_model
# inside the pods) appears here under /opt/app-root/src/shared/fine_tuned_rag_model.
%tensorboard --logdir /opt/app-root/src/shared

## Testing the Fine-Tuned RAG Model
After fine-tuning, you can load the entire RAG model and test its performance by providing a question.
The `RagSequenceForGeneration` model internally performs retrieval and then generates an answer based on the retrieved context.

In [None]:
# Install / upgrade dependencies
!pip install --upgrade transformers peft accelerate sentence-transformers

In [None]:
import sys
import torch
import os
from transformers import (
    AutoTokenizer,
    RagSequenceForGeneration,
    RagConfig, AutoModelForQuestionAnswering, AutoModel
)
from sentence_transformers import SentenceTransformer

# This is required for `feast_rag_retriever` to be found (This is temporary until Feast RAG Retriever is part of Feast SDK)
CUSTOM_MODULES_PATH = "/opt/app-root/src/distributed-workloads/examples/kfto_feast_rag"
if CUSTOM_MODULES_PATH not in sys.path:  # avoid duplicate sys.path entries when the cell is re-run
    sys.path.append(CUSTOM_MODULES_PATH)
from feast_rag_retriever import FeastRAGRetriever, FeastIndex
# NOTE(review): `feature_repo` is presumably resolved from CUSTOM_MODULES_PATH, which is not the
# directory containing STORE_PATH below; confirm both define the same `wiki_passages` view.
from feature_repo.ragproject_repo import wiki_passage_feature_view


STORE_PATH = "/opt/app-root/src/distributed-workloads/examples/kfto-sft-feast-rag/feature_repo"
# Written by the training job (output_dir/final_model_for_inference on the shared volume).
FINETUNED_RAG_CHECKPOINT_DIR = "/opt/app-root/src/shared/fine_tuned_rag_model/final_model_for_inference"

# --- 1. Initialize Custom Retriever ---
print("Initializing custom FeastRAGRetriever...")

rag_config_inference = RagConfig.from_pretrained(FINETUNED_RAG_CHECKPOINT_DIR)
# RagTokenizer.save_pretrained stored each tokenizer in its own subfolder.
question_encoder_tokenizer_inference = AutoTokenizer.from_pretrained(os.path.join(FINETUNED_RAG_CHECKPOINT_DIR, 'question_encoder_tokenizer'))
generator_tokenizer_inference = AutoTokenizer.from_pretrained(os.path.join(FINETUNED_RAG_CHECKPOINT_DIR, 'generator_tokenizer'))

if generator_tokenizer_inference.pad_token is None:
    generator_tokenizer_inference.pad_token = generator_tokenizer_inference.eos_token

feast_index_inference = FeastIndex()
query_encoder_model = SentenceTransformer("all-MiniLM-L6-v2")


features_to_retrieve = [
    "wiki_passages:passage_text",
    "wiki_passages:embedding",
    "wiki_passages:passage_id",
]

# Initialize FeastRAGRetriever with the loaded tokenizers and config
rag_retriever_inference = FeastRAGRetriever(
    question_encoder_tokenizer=question_encoder_tokenizer_inference,
    generator_tokenizer=generator_tokenizer_inference,
    question_encoder=None,
    generator_model=None,
    feast_repo_path=STORE_PATH,
    feature_view=wiki_passage_feature_view,
    features=features_to_retrieve,
    search_type="vector",
    config=rag_config_inference,
    index=feast_index_inference,
    query_encoder_model=query_encoder_model,
)
print("FeastRAGRetriever initialized successfully.")


# --- 2. Load Fine-Tuned RAG Model ---
print(f"Loading full RagSequenceForGeneration model from: {FINETUNED_RAG_CHECKPOINT_DIR}")

device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    finetuned_rag_model = RagSequenceForGeneration.from_pretrained(
        FINETUNED_RAG_CHECKPOINT_DIR,
        retriever=rag_retriever_inference
    )
    finetuned_rag_model.to(device)
    finetuned_rag_model.eval()
    print(f"Fine-tuned RAG model loaded and moved to {device}. Ready for inference.")
except Exception:
    print(f"ERROR: Failed to load RagSequenceForGeneration model. Ensure the directory '{FINETUNED_RAG_CHECKPOINT_DIR}' contains the correct RAG model structure (config.json and subfolders for generator, question_encoder, etc.).")
    raise

# --- 3. Run Inference ---

test_queries = [
    "Which formations have a lower diversity of documented dinosaurs?",
    "How does photosynthesis work?",
    "When was the Declaration of Independence signed?"
]

for test_query in test_queries:
    print(f"\nQuery: {test_query}")
    try:
        # Tokenize the query using the question encoder's tokenizer; truncate overly long
        # queries to the tokenizer's maximum length instead of failing in the encoder.
        inputs = question_encoder_tokenizer_inference(
            test_query,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        # Generate the answer
        generated_ids = finetuned_rag_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=5,
            max_new_tokens=200,
        )

        # Decode the answer using the generator's tokenizer
        answer = generator_tokenizer_inference.decode(generated_ids[0], skip_special_tokens=True)
        print(f"Answer: {answer}")

    except Exception as e:
        print(f"Error during inference for query '{test_query}': {e}")
        raise


# Cleaning Up

## Training Job

Once you're done or want to re-create the training job, you can delete the existing one:

In [None]:
# Delete the "sft-rag" PyTorchJob (same name and kind as used when creating it).
client.delete_job(
    name="sft-rag",
    job_kind="PyTorchJob",
)

## GPU Memory
If you want to start over (for example, to load a different checkpoint), you can free the GPU / accelerator memory used by the fine-tuned model with:

In [None]:
# Unload the model from GPU memory
import gc

import torch

# Drop the references held by the inference cell. Names that were never created (e.g. because
# that cell failed part-way) are skipped instead of raising NameError like a bare `del` would.
for _name in (
    "finetuned_rag_model",
    "rag_retriever_inference",
    "rag_config_inference",
    "question_encoder_tokenizer_inference",
    "generator_tokenizer_inference",
    "feast_index_inference",
    "query_encoder_model",
):
    globals().pop(_name, None)

gc.collect()
torch.cuda.empty_cache()