# Setup

In [None]:
# Install and load the `%%yaml` cell magic used by the next cell; `%%yaml parameters`
# turns the YAML body into a Python variable named `parameters`, which is later
# passed to the training job.
# NOTE(review): the yamlmagic version is not pinned.
%pip install yamlmagic
%load_ext yamlmagic

# Training Configuration
Edit the following training hyperparameters to experiment with learning rates, batch sizes, LoRA parameters and more:

In [None]:
%%yaml parameters

# Model
model_name_or_path: facebook/bart-large    # only works with Seq2Seq (encoder-decoder) models such as BART and T5, since transformers' RAG only supports them for now.
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager                # one of eager (default), sdpa or flash_attention_2
use_liger: false                          # use Liger kernels

# PEFT / LoRA (Apply to Generator Model)
use_peft: false
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules: ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]   # Ensure these match your generator model
lora_modules_to_save: []

# QLoRA (BitsAndBytes) (Apply to Generator Model)
load_in_4bit: false                       # use 4 bit precision for the base model (only with LoRA)
load_in_8bit: false                       # use 8 bit precision for the base model (only with LoRA)

# Dataset
dataset_name: facebook/wiki_dpr
dataset_config: main                      # name of the dataset configuration
dataset_train_split: train                # dataset split to use for training (for RAG generated data)
dataset_test_split: test                  # dataset split to use for evaluation (for RAG generated data)
dataset_kwargs:
    add_special_tokens: false               # template with special tokens
    append_concat_token: false              # add additional separator token

# SFT (These parameters will now apply to the RagModel's training)
max_seq_length: 1024                      # max sequence length for model and packing of the dataset
dataset_batch_size: 1000                  # samples to tokenize per batch (for initial data processing)
packing: false                            # Packing is generally not used directly with RagModel training in the same way as SFT

# Training
num_train_epochs: 3                       # number of training epochs
remove_unused_columns: false
label_smoothing_factor: 0.1                # 0.1, 0.0(disable)

per_device_train_batch_size: 1            # Batch size per device during training
per_device_eval_batch_size: 1             # Batch size for evaluation
auto_find_batch_size: false               # find a batch size that fits into memory automatically
eval_strategy: epoch                      # evaluate every epoch

bf16: true                                # use bf16 16-bit (mixed) precision
tf32: true                               # use tf32 precision

learning_rate: 4.0e-6                     # 4.0e-6 Initial learning rate for RAG model training
warmup_steps: 200                         # steps for a linear warmup from 0 to `learning_rate`
lr_scheduler_type: cosine                 # learning rate scheduler (see transformers.SchedulerType)

optim: adamw_torch_fused                  # optimizer (see transformers.OptimizerNames)
max_grad_norm: 1.0                        # max gradient norm
seed: 42

gradient_accumulation_steps: 8            # Increase for smaller per_device_train_batch_size
gradient_checkpointing: false             # use gradient checkpointing to save memory
gradient_checkpointing_kwargs:
    use_reentrant: false

# FSDP
fsdp: "full_shard auto_wrap"              # add offload if not enough GPU memory
fsdp_config:
    activation_checkpointing: true
    cpu_ram_efficient_loading: false
    sync_module_states: true
    use_orig_params: true
    limit_all_gathers: false


# fsdp_transformer_layer_cls_to_wrap: [BertLayer, BartEncoderLayer, BartDecoderLayer]

# Checkpointing
save_strategy: epoch                      # save checkpoint every epoch
save_total_limit: 1                       # limit the total amount of checkpoints
resume_from_checkpoint: true             # load the last checkpoint in output_dir and resume from it

# Logging
log_level: warning                        # logging level (see transformers.logging)
logging_strategy: steps
logging_steps: 1                          # log every N steps
report_to:
- tensorboard                             # report metrics to tensorboard

output_dir: /mnt/shared/fine_tuned_rag_model

# Feast Setup with Milvus


### Install Required Dependencies 

In [None]:
%%bash
# Notebook-side dependencies for the Feast + Milvus vector store, embedding and dataset steps.
pip install --quiet feast[milvus] sentence-transformers datasets
# Pinned versions. NOTE(review): presumably pinned for compatibility with the installed Feast release — confirm.
pip install bigtree==0.19.2
pip install marshmallow==3.10.0
# NOTE(review): feast is already installed by the feast[milvus] line above; this line looks redundant.
pip install feast

## Loading Wikipedia Dataset
We only load a subset of the dataset in the interest of keeping this example runnable with minimum memory and storage.

In [None]:
from datasets import load_dataset
# Load only the first 5% of the wiki_dpr "train" split (configuration
# `psgs_w100.nq.exact`) so the example fits in modest memory and disk.
# `with_index=False` is forwarded to the dataset builder as a config kwarg.
dataset = load_dataset(
    path="facebook/wiki_dpr",
    name="psgs_w100.nq.exact",
    split="train[:5%]",
    trust_remote_code=True,
    with_index=False,
)

## Chunking Wikipedia Dataset
Each passage is split into chunks of at most a preset number of characters (380 by default), which is the maximum supported by Feast. Chunks break only at whitespace, so every chunk contains whole words and the retrieved context never cuts a word in half.

In [None]:
def chunk_dataset(examples, max_chars=380):
    """Split each passage of a batch into whitespace-delimited chunks of at most ``max_chars``.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is a batch dict
    with parallel ``'id'``, ``'title'`` and ``'text'`` lists.

    Words are never split, so a single word longer than ``max_chars`` becomes a
    chunk of its own that exceeds the limit. Texts that are empty or only
    whitespace produce no chunks.

    Args:
        examples: batch dict with ``'id'``, ``'title'`` and ``'text'`` lists.
        max_chars: maximum chunk length in characters (including separating spaces).

    Returns:
        dict with parallel ``'id'``, ``'title'`` and ``'text'`` lists, one entry per
        chunk. Chunk ids are ``f"{source_id}_{n}"`` where ``n`` is the running number
        of chunks emitted so far in this batch.
    """
    all_chunks = []
    all_ids = []
    all_titles = []

    def _emit(chunk_words, doc_index):
        # Record one finished chunk; the id uses the chunk count *after* appending.
        all_chunks.append(' '.join(chunk_words))
        all_ids.append(f"{examples['id'][doc_index]}_{len(all_chunks)}")
        all_titles.append(examples['title'][doc_index])

    for i, text in enumerate(examples['text']):  # iterate over texts in the batch
        words = text.split()
        if not words:
            continue

        current_chunk_words = []
        # Always equal to len(' '.join(current_chunk_words)); tracked incrementally
        # instead of re-joining the chunk for every word (which was quadratic).
        current_len = 0
        for word in words:
            # Length after appending `word`: one separating space unless the chunk is empty.
            candidate_len = current_len + len(word) + (1 if current_chunk_words else 0)
            if candidate_len > max_chars:
                if current_chunk_words:
                    _emit(current_chunk_words, i)
                # Start a new chunk with the current word.
                current_chunk_words = [word]
                current_len = len(word)
            else:
                current_chunk_words.append(word)
                current_len = candidate_len

        # `words` is non-empty, so there is always a trailing chunk to flush.
        _emit(current_chunk_words, i)

    return {'id': all_ids, 'title': all_titles, 'text': all_chunks}


# Replace every passage with its chunks. With batched=True, chunk_dataset receives a dict
# of column lists; remove_columns drops the original columns because the number of output
# rows (chunks) differs from the number of input rows (passages).
chunked_dataset = dataset.map(
    chunk_dataset,
    batched=True,
    remove_columns=dataset.column_names,
    num_proc=1
)

## Create DPR Embeddings
We load a pre-trained Dense Passage Retrieval (DPR) encoder to generate context embeddings for each chunked passage. These embeddings will later be stored in the Feast feature store and queried during retrieval.

In [None]:
import os

import numpy as np
import torch
from tqdm import tqdm
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Cache location for the embeddings and the matching chunk texts (read back by the next cell).
EMBED_DATA_DIR = "/opt/app-root/src/shared/synthetic_data_cache/embed_data"
EMBED_BATCH_SIZE = 16  # passages encoded per forward pass

# Load DPR Context Encoder model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"

embedding_model_name = "facebook/dpr-ctx_encoder-single-nq-base"
tokenizer = DPRContextEncoderTokenizer.from_pretrained(embedding_model_name)
model = DPRContextEncoder.from_pretrained(embedding_model_name).to(device)
model.eval()  # inference only

sentences = chunked_dataset["text"]

print(f"Generating DPR embeddings for {len(sentences)} documents...")
all_embeddings = []
with torch.no_grad():
    for i in tqdm(range(0, len(sentences), EMBED_BATCH_SIZE)):
        batch_texts = sentences[i:i + EMBED_BATCH_SIZE]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        embeddings = model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        # float32 on CPU so NumPy can store it regardless of the model dtype.
        all_embeddings.append(embeddings.to(dtype=torch.float32).cpu().numpy())

embeddings = np.vstack(all_embeddings)
print(f"Embeddings generated with shape: {embeddings.shape}")
print("Saving generated embeddings and chunked sentences to file...")
# The cache directory does not exist on a fresh volume; np.save would fail without this.
os.makedirs(EMBED_DATA_DIR, exist_ok=True)
np.save(os.path.join(EMBED_DATA_DIR, "embeddings.npy"), embeddings)
# Explicit UTF-8: Wikipedia text is non-ASCII and the container locale may not default to UTF-8.
# One chunk per line; chunks are built from str.split() words, so they contain no newlines.
with open(os.path.join(EMBED_DATA_DIR, "sentences.txt"), "w", encoding="utf-8") as f:
    for sentence in sentences:
        f.write(f"{sentence}\n")

print("saved")

## Create Parquet File for Feast Offline Store
Create a parquet file using the DPR embeddings created previously.

In [None]:
import os
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

EMBED_DATA_DIR = "/opt/app-root/src/shared/synthetic_data_cache/embed_data"
# Must match the file the online-store ingestion cell reads (./wiki_dpr.parquet from
# inside feature_repo/). Previously this wrote wiki_dpr_1perct.parquet, so the
# ingestion step read a different file than the one produced here.
PARQUET_PATH = "feature_repo/wiki_dpr.parquet"
batch_size = 256

embeddings = np.load(os.path.join(EMBED_DATA_DIR, "embeddings.npy"))
with open(os.path.join(EMBED_DATA_DIR, "sentences.txt"), "r", encoding="utf-8") as f:
    sentences = [line.strip() for line in f]

if len(sentences) != len(embeddings):
    raise ValueError(
        f"Found {len(sentences)} sentences but {len(embeddings)} embeddings; "
        "re-run the embedding cell so both files come from the same run."
    )


def _make_batch_df(start):
    """Build the DataFrame for rows [start, start + batch_size); passage_id is the global row index."""
    batch_sentences = sentences[start:start + batch_size]
    batch_embeddings = embeddings[start:start + batch_size]
    return pd.DataFrame({
        # Sized from the actual slice so a short (or only) batch does not mismatch.
        "passage_id": list(range(start, start + len(batch_sentences))),
        "passage_text": batch_sentences,
        "embedding": pd.Series([embedding.tolist() for embedding in batch_embeddings], dtype=object),
        "event_timestamp": [datetime.now(timezone.utc)] * len(batch_sentences),
    })


# The first batch defines the Parquet schema for the whole file.
first_batch_df = _make_batch_df(0)

print("DataFrame Info:")
print(first_batch_df.head())
print(first_batch_df["embedding"].apply(lambda x: len(x) if isinstance(x, list) else str(type(x))).value_counts())

first_table = pa.Table.from_pandas(first_batch_df)
# Context manager closes the writer (and finalizes the footer) even if a batch fails.
with pq.ParquetWriter(PARQUET_PATH, first_table.schema) as pqwriter:
    pqwriter.write_table(first_table)

    # Continue writing remaining batches
    for i in range(batch_size, len(sentences), batch_size):
        batch_df = _make_batch_df(i)
        pqwriter.write_table(pa.Table.from_pandas(batch_df))
        print(f"Wrote {i + len(batch_df)} / {len(sentences)} documents...")

print(f"Saved to {PARQUET_PATH}")

In [None]:
# Work from inside the Feast repo: `feast apply` and the relative ./wiki_dpr.parquet path below
# depend on it. NOTE(review): %cd is relative, so running this cell twice would try to enter
# feature_repo/feature_repo — restart the kernel (or `%cd ..`) before re-running.
%cd feature_repo

### Apply Feast Feature Repository

In [None]:
# Register the feature definitions in the current directory (feature_repo) with Feast, so the
# `wiki_passages` feature view used below exists before data is written to the online store.
!feast apply

## Writing to Feast Online Store (Milvus)
We load the Parquet file into Milvus via Feast, which will serve as the online store for efficient similarity search during retrieval.

In [None]:
import pyarrow.parquet as pq
from feast import FeatureStore
from pymilvus import MilvusException

# Stream the Parquet file in fixed-size record batches and push each one into the
# Milvus-backed online store of the `wiki_passages` feature view.
store = FeatureStore(repo_path=".")
parquet_file = pq.ParquetFile("./wiki_dpr.parquet")
batch_size = 10000

failed_batches = []
for batch_num, batch in enumerate(parquet_file.iter_batches(batch_size=batch_size), 1):
    batch_df = batch.to_pandas()
    try:
        print(f"Writing batch {batch_num}...")
        store.write_to_online_store(feature_view_name='wiki_passages', df=batch_df)
        print(f"Batch {batch_num} written successfully.")
    except MilvusException as e:
        # Best-effort ingestion: keep going, but remember which batches were lost.
        failed_batches.append(batch_num)
        print(f"Skipping write of batch {batch_num} due to : {e}")

# Only claim success when nothing was skipped.
if failed_batches:
    print(f"Finished with {len(failed_batches)} skipped batch(es): {failed_batches}")
else:
    print("All data written to online store.")


# Preprocessing Natural Questions Dataset for RAG Training
We prepare the Natural Questions dataset for RAG model training. The dataset contains questions and answers. Because it is very large, and to keep this example runnable, we use only a subset of the training and validation samples.

In [None]:
from datasets import load_dataset, DatasetDict, Dataset, load_from_disk
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DPRQuestionEncoderTokenizer,
    DPRQuestionEncoder,
    RagConfig
)

import sys

print("Preparing training data (Q&A pairs) with caching...")

# Project directory holding feast_rag_retriever.py and feature_repo/. It is later copied to
# shared storage (presumably surfacing as /mnt/shared/kfto-sft-feast-rag in the training job).
CUSTOM_MODULES_PATH = "/opt/app-root/src/distributed-workloads/examples/kfto-sft-feast-rag"
sys.path.append(CUSTOM_MODULES_PATH)

# Anchor the cache at the project root. At this point the working directory is feature_repo/
# (after `%cd feature_repo`), so a relative "dataset/..." path would land in
# feature_repo/dataset, whereas the training job reads <project>/dataset/rag_3k_bart_intersection_dataset.
processed_data_cache_dir = str(Path(CUSTOM_MODULES_PATH) / "dataset" / "rag_3k_bart_intersection_dataset")
Path(processed_data_cache_dir).mkdir(parents=True, exist_ok=True)

# Load question encoder tokenizer and model. These are standard Hub checkpoints, so no
# remote code is needed (the training job loads them without trust_remote_code too).
question_encoder_model_name_or_path = "facebook/dpr-question_encoder-single-nq-base"
question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_encoder_model_name_or_path)
question_encoder = DPRQuestionEncoder.from_pretrained(
    question_encoder_model_name_or_path,
    torch_dtype="bfloat16"
)

# Load generator tokenizer and model
generator_model_name_or_path = "facebook/bart-large"
generator_tokenizer = AutoTokenizer.from_pretrained(
    generator_model_name_or_path,
    use_fast=True
)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(
    generator_model_name_or_path,
    torch_dtype="bfloat16"
)

# Import custom retriever and feature view definitions (resolved via CUSTOM_MODULES_PATH)
from feast_rag_retriever import FeastRAGRetriever, FeastIndex
from feature_repo.ragproject_repo import wiki_passage_feature_view

# Absolute path of the Feast repo used by the retriever
STORE_PATH = f"{CUSTOM_MODULES_PATH}/feature_repo"

# NOTE(review): not used in this cell; the retriever below uses STORE_PATH.
store_path = "feature_repo"

# Initialize FeastRAGRetriever
feast_index = FeastIndex()
rag_top_k = 10

# Use the loaded encoder's real config (12 layers for this BERT-base DPR checkpoint) rather than
# a hand-written dict that claimed 6 layers, so any RagConfig built from it describes the model.
question_encoder_config = question_encoder.config.to_dict()

rag_config = RagConfig(
    question_encoder=question_encoder_config,
    generator=generator_model.config.to_dict(),
    index_name="custom",
    index={"index_name": "feast_dummy_index", "custom_type": "FeastIndex"},
    n_docs=rag_top_k,
)

features_to_retrieve = [
    "wiki_passages:passage_text",
    "wiki_passages:embedding",
    "wiki_passages:passage_id",
]

# Create RAG retriever used to filter out questions the indexed passages cannot answer
rag_retriever = FeastRAGRetriever(
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
    question_encoder=question_encoder,
    generator_model=generator_model,
    feast_repo_path=STORE_PATH,
    feature_view=wiki_passage_feature_view,
    features=features_to_retrieve,
    search_type="vector",
    config=rag_config,
    index=feast_index,
)

print("FeastRAGRetriever initialized for filtering.")

# Filter function to retain only examples with valid short answers
def has_valid_answer(answer):
    """Return whether a raw Natural Questions example has a short answer the retriever can support.

    Args:
        answer: one raw NQ example (despite the name), with ``answer["question"]["text"]``
            and ``answer["annotations"]["short_answers"]`` — a list of dicts whose
            ``"text"`` entry is a list of answer strings.

    Returns:
        False if there are no short answers or none has non-blank text. Otherwise, for
        the first short answer with non-blank text, the result of
        ``rag_retriever.is_question_answerable`` (top_k=10, similarity_threshold=0.6,
        max_answer_length=50, min_answer_length=1).
    """
    short_answers = answer["annotations"]["short_answers"]
    if not short_answers:
        return False

    # Read the question from the example passed in. This previously read the global
    # loop variable `example`, which only worked because the caller's loop used that name.
    question_text = answer["question"]["text"]

    for short_ans_dict in short_answers:
        if "text" in short_ans_dict and short_ans_dict["text"] and short_ans_dict["text"][0].strip():
            expected_answer = short_ans_dict["text"][0]
            return rag_retriever.is_question_answerable(
                question_text=question_text,
                expected_answer=expected_answer,
                question_encoder=question_encoder,
                tokenizer=question_encoder_tokenizer,
                top_k=10,
                similarity_threshold=0.6,
                max_answer_length=50,
                min_answer_length=1
            )
    return False

# Preprocessing function to tokenize question and answer for RAG input
def preprocess_nq_example(example):
    """Tokenize one raw Natural Questions example into RAG training inputs.

    The question goes through the question-encoder tokenizer and the answer through the
    generator tokenizer; both are truncated/padded to 32 tokens.

    Args:
        example: raw NQ example with ``example["question"]["text"]`` and
            ``example["annotations"]["short_answers"]`` (list of dicts with a ``"text"`` list).

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (question) and ``labels`` (answer ids).
        If no short answer has non-blank text, the labels are built from an empty string.
    """
    question_text = example["question"]["text"]

    # Select the FIRST short answer with non-blank text (the loop previously had no break and
    # ended up with the last one, which may differ from the answer has_valid_answer checked).
    answer_text = next(
        (
            short_ans_dict["text"][0]
            for short_ans_dict in (example["annotations"]["short_answers"] or [])
            if "text" in short_ans_dict and short_ans_dict["text"] and short_ans_dict["text"][0].strip()
        ),
        "",
    )

    # Tokenize question for the RAG model's question encoder
    tokenized_question = question_encoder_tokenizer(
        question_text,
        truncation=True,
        max_length=32,
        padding="max_length",
    )

    # Tokenize answer for the RAG model's generator
    tokenized_answer_for_labels = generator_tokenizer(
        text_target=answer_text,
        truncation=True,
        max_length=32,
        padding="max_length",
    )

    return {
        "input_ids": tokenized_question["input_ids"],
        "attention_mask": tokenized_question["attention_mask"],
        "labels": tokenized_answer_for_labels["input_ids"],
    }


# Natural Questions source and split names. The cached DatasetDict is keyed "train"/"test";
# its "test" split is filled from NQ's "validation" split.
nq_dataset_name = "google-research-datasets/natural_questions"
nq_train_split = "train"
nq_test_split = "test"

cache_root = Path(processed_data_cache_dir)
cache_is_complete = (cache_root / nq_train_split).exists() and (cache_root / nq_test_split).exists()

if cache_is_complete:
    # Reuse the preprocessed splits saved by an earlier run.
    print(f"Loading preprocessed data from cache: {processed_data_cache_dir}")
    loaded_processed_datasets = load_from_disk(processed_data_cache_dir)
    train_dataset, test_dataset = (
        loaded_processed_datasets[nq_train_split],
        loaded_processed_datasets[nq_test_split],
    )
else:
    print("Loading raw Natural Questions dataset (streaming) and preprocessing...")

    num_train_samples = 3000  # target number of valid training samples
    num_eval_samples = 300    # target number of valid evaluation samples

    raw_train_stream = load_dataset(nq_dataset_name, "default", split=nq_train_split, streaming=True)
    raw_eval_stream = load_dataset(nq_dataset_name, "default", split="validation", streaming=True)

    temp_raw_train_list = []
    temp_raw_eval_list = []

    # Stream each split and keep examples that pass has_valid_answer until its target is met.
    # The loop runs at notebook top level, so `example` stays a notebook global as before.
    for raw_stream, target_count, kept in (
        (raw_train_stream, num_train_samples, temp_raw_train_list),
        (raw_eval_stream, num_eval_samples, temp_raw_eval_list),
    ):
        for example in raw_stream:
            if not has_valid_answer(example):
                continue
            kept.append(example)
            if len(kept) >= target_count:
                break

    raw_datasets_dict = DatasetDict({
        nq_train_split: Dataset.from_list(temp_raw_train_list),
        nq_test_split: Dataset.from_list(temp_raw_eval_list),
    })

    print("Applying preprocessing to filtered samples...")
    processed_datasets = raw_datasets_dict.map(
        preprocess_nq_example,
        remove_columns=raw_datasets_dict[nq_train_split].column_names,
        num_proc=1,
    )

    train_dataset = processed_datasets[nq_train_split]
    test_dataset = processed_datasets[nq_test_split]

    print(f"Saving preprocessed datasets to cache: {processed_data_cache_dir}")
    processed_datasets.save_to_disk(processed_data_cache_dir)
    print("Preprocessed data saved to cache.")

# Dataset summary
for summary_line in (
    f"Train dataset size: {len(train_dataset)}",
    f"Test dataset size: {len(test_dataset)}",
    f"Train dataset columns: {train_dataset.column_names}",
    f"Sample from train dataset: {train_dataset[0]}",
):
    print(summary_line)

### Dataset Sanity Check
Randomly inspect samples from the preprocessed dataset to ensure questions and answers are correctly tokenized and properly structured.

In [None]:
import random

# Decode up to three random training rows back to text to eyeball the tokenization.
for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
    row = train_dataset[index]  # fetch the row once instead of re-indexing for every print
    print(f"\n  Processed Sample {index}  ")
    # These are the column names of the row (previously mislabeled as "Question").
    print(f"Keys: {row.keys()}")
    print(f"Decoded Question: {question_encoder_tokenizer.decode(row['input_ids'], skip_special_tokens=True)}")
    print(f"Decoded Labels (Answer): {generator_tokenizer.decode(row['labels'], skip_special_tokens=True)}")
    print(f"Raw Input IDs (Question): {row['input_ids'][:20]}...")
    print(f"Raw Labels (Answer): {row['labels'][:20]}...")

### Preparing Training Assets for Distributed Training

Copy the `kfto-sft-feast-rag` folder to the shared storage so that the training job can access it: the FeastRAGRetriever requires the `feature_repo` directory during training and inference. This also copies the `feast_rag_retriever.py` script so that it can be imported in the training job.

In [None]:
# Copy the project folder (retriever module and feature_repo/) to the shared volume.
# NOTE(review): the training job reads it from /mnt/shared/kfto-sft-feast-rag, so $HOME/shared
# is presumably the same PVC mounted in this workbench — confirm.
%cp -r $HOME/distributed-workloads/examples/kfto-sft-feast-rag $HOME/shared/

# Fine-tuning the RAG Model (Training Loop)
We initiate the fine-tuning process for the RAG model using the Seq2SeqTrainer. This step integrates the retriever, encoder, generator, and custom datasets to jointly optimize the generation task.

In [None]:
def main(parameters):
    """Fine-tune a RagSequenceForGeneration model on a Kubeflow PyTorchJob worker.

    The Kubeflow SDK ships this function to the training pods, so every import it
    needs is done inside the body.

    Steps:
      1. Install Feast and the other RAG dependencies into the container.
      2. Parse ``parameters`` with TRL's parser (ScriptArguments, SFTConfig,
         ModelConfig) and copy the relevant values into Seq2SeqTrainingArguments.
      3. Build the DPR question encoder, the seq2seq generator (optionally wrapped
         with PEFT/LoRA) and a FeastRAGRetriever over the shared ``feature_repo``.
      4. Load the preprocessed Natural Questions splits from shared storage.
      5. Smoke-test one forward pass, train, then save the model, its
         question-encoder/generator components and a RagTokenizer.

    Args:
        parameters: dict of training/model settings (the YAML cell in the notebook).

    Exits the process with status 1 if the preprocessed dataset is missing.
    """
    import subprocess, sys
    # Install necessary packages
    # This ensures Feast and other dependencies are available inside the container
    print("Installing Feast and other RAG dependencies...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "feast[milvus]", "sentence-transformers", "datasets", "bigtree==0.19.2", "marshmallow==3.10.0", "protobuf>=3.10.0", "git+https://github.com/feast-dev/feast.git@master"])
    print("Feast and other RAG dependencies installed.")
    from pathlib import Path
    import os
    from datasets import load_from_disk
    from transformers import (
        set_seed,
        RagSequenceForGeneration,
        GenerationConfig,
        Seq2SeqTrainingArguments,
        Seq2SeqTrainer,
        RagTokenizer,
        default_data_collator,
        RagConfig,
        AutoTokenizer,
        AutoModelForSeq2SeqLM,
        DPRQuestionEncoderTokenizer,
        DPRQuestionEncoder
    )
    from transformers.trainer_utils import get_last_checkpoint
    from peft import get_peft_model, prepare_model_for_kbit_training
    import torch

    from trl import (
        ModelConfig,
        ScriptArguments,
        SFTConfig,
        TrlParser,
        get_peft_config,
    )

    # This is required for `feast_rag_retriever` to be found (for debugging purposes)
    CUSTOM_MODULES_PATH = "/mnt/shared/kfto-sft-feast-rag"
    sys.path.append(CUSTOM_MODULES_PATH)
    print(f"Added {CUSTOM_MODULES_PATH} to sys.path for custom module imports.")
    # from feast_rag_retriever import FeastRAGRetriever, FeastIndex
    from feast.rag_retriever import FeastRAGRetriever, FeastIndex
    from feature_repo.ragproject_repo import wiki_passage_feature_view

    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args_trl, model_args = parser.parse_dict(parameters)

    # Convert SFTConfig parameters to standard TrainingArguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=training_args_trl.output_dir,
        num_train_epochs=training_args_trl.num_train_epochs,
        per_device_train_batch_size=training_args_trl.per_device_train_batch_size,
        per_device_eval_batch_size=training_args_trl.per_device_eval_batch_size,
        gradient_accumulation_steps=training_args_trl.gradient_accumulation_steps,
        gradient_checkpointing=training_args_trl.gradient_checkpointing,
        learning_rate=training_args_trl.learning_rate,
        warmup_steps=training_args_trl.warmup_steps,
        lr_scheduler_type=training_args_trl.lr_scheduler_type,
        optim=training_args_trl.optim,
        max_grad_norm=training_args_trl.max_grad_norm,
        seed=training_args_trl.seed,
        bf16=training_args_trl.bf16,
        tf32=training_args_trl.tf32,
        eval_strategy=training_args_trl.eval_strategy,
        save_strategy=training_args_trl.save_strategy,
        save_total_limit=training_args_trl.save_total_limit,
        logging_strategy=training_args_trl.logging_strategy,
        logging_steps=training_args_trl.logging_steps,
        report_to=training_args_trl.report_to,
        fsdp=training_args_trl.fsdp if hasattr(training_args_trl, 'fsdp') else None,
        fsdp_config=training_args_trl.fsdp_config if hasattr(training_args_trl, 'fsdp_config') else None,
        resume_from_checkpoint=training_args_trl.resume_from_checkpoint,
        remove_unused_columns=training_args_trl.remove_unused_columns,
        predict_with_generate= True,
    )

    set_seed(training_args.seed)

    # Question Encoder
    question_encoder_model_name_or_path = "facebook/dpr-question_encoder-single-nq-base"
    question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_encoder_model_name_or_path)
    question_encoder_model = DPRQuestionEncoder.from_pretrained(
        question_encoder_model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=model_args.torch_dtype
    )

    # Generator Model to be fine-tuned
    generator_model_name_or_path = model_args.model_name_or_path
    generator_tokenizer = AutoTokenizer.from_pretrained(
        generator_model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=True
    )
    generator_model = AutoModelForSeq2SeqLM.from_pretrained(
        generator_model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=not training_args.gradient_checkpointing,
    )

    if model_args.use_peft:
        print(f"PEFT/LoRA: Ensure 'lora_target_modules' in your parameters ({parameters.get('lora_target_modules')}) are correct for the new generator model '{generator_model_name_or_path}'.")
        generator_model = prepare_model_for_kbit_training(generator_model)
        peft_config = get_peft_config(model_args)
        generator_model = get_peft_model(generator_model, peft_config)
        print("PEFT setup for generator model completed.")

    store_path = CUSTOM_MODULES_PATH+"/feature_repo"

    # Initialize FeastRAGRetriever
    feast_index = FeastIndex()
    rag_top_k = 10

    # Describe the question encoder with its real config. The previous hand-written dict claimed
    # 6 hidden layers (this DPR checkpoint is BERT-base), so the RagConfig saved with the model
    # below did not match the saved weights.
    question_encoder_config = question_encoder_model.config.to_dict()

    rag_config = RagConfig(
        question_encoder=question_encoder_config,
        generator=generator_model.config.to_dict(),
        index_name="custom",
        index={"index_name": "feast_dummy_index", "custom_type": "FeastIndex"},
        n_docs=rag_top_k,
    )

    features_to_retrieve = [
        "wiki_passages:passage_text",
        "wiki_passages:embedding",
        "wiki_passages:passage_id",
    ]

    rag_retriever = FeastRAGRetriever(
        question_encoder_tokenizer=question_encoder_tokenizer,
        generator_tokenizer=generator_tokenizer,
        feast_repo_path=store_path,
        feature_view=wiki_passage_feature_view,
        features=features_to_retrieve,
        search_type="vector",
        config=rag_config,
        index=feast_index,
    )

    # Initialize the RagModel for fine-tuning
    model = RagSequenceForGeneration(
        question_encoder=question_encoder_model,
        config=rag_config,
        generator=generator_model,   # model being fine-tuned
        retriever=rag_retriever
    )
    generator_config = GenerationConfig(
        max_length=128,
        num_beams=1,
        do_sample=False,
        length_penalty=1.0,
    )
    model.generation_config = generator_config
    print("RAG Model initialized.")

    # Explicitly move model to GPU (falls back to CPU when no GPU is visible)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    # Load preprocessed NQ dataset
    print("Preparing training data (Q&A pairs) with caching...")
    processed_data_cache_dir = "/mnt/shared/kfto-sft-feast-rag/dataset/rag_3k_bart_intersection_dataset"
    Path(processed_data_cache_dir).mkdir(parents=True, exist_ok=True)

    nq_train_split = "train"
    nq_validation_split = "test"
    # Check if the preprocessed dataset is already cached
    if (Path(processed_data_cache_dir) / nq_train_split).exists() and \
            (Path(processed_data_cache_dir) / nq_validation_split).exists():
        print(f"Loading preprocessed data from cache: {processed_data_cache_dir}")
        loaded_processed_datasets = load_from_disk(processed_data_cache_dir)
        train_dataset = loaded_processed_datasets[nq_train_split]
        test_dataset = loaded_processed_datasets[nq_validation_split]
    else:
        print("ERROR DATASET NOT FOUND")
        sys.exit(1)

    # Training the RagModel
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=generator_tokenizer,
        data_collator=default_data_collator,
    )

    if trainer.accelerator.is_main_process:
        if hasattr(trainer.model.generator, "print_trainable_parameters"):
            print("Trainable parameters for PEFT-enabled generator:")
            trainer.model.generator.print_trainable_parameters()
        else:
            print("Trainer model does not have 'print_trainable_parameters' method on its generator.")

    # RagSequenceForGeneration model trains encoder and generator jointly by default, uncomment below to freeze encoder model
    # unwrapped_model = trainer.accelerator.unwrap_model(trainer.model)
    # unwrapped_model.rag.question_encoder.requires_grad_(False)
    # for param in unwrapped_model.rag.question_encoder.parameters():
    #     param.requires_grad = False

    # `resume_from_checkpoint: true` means "resume from the latest checkpoint in output_dir".
    # Trainer.train(resume_from_checkpoint=True) raises when no checkpoint exists yet, which made
    # the very first run fail, so resolve the flag to a concrete checkpoint path (or None) here.
    checkpoint = training_args.resume_from_checkpoint
    if isinstance(checkpoint, str) and checkpoint.lower() in ("true", "false"):
        checkpoint = checkpoint.lower() == "true"
    if checkpoint is True:
        checkpoint = (
            get_last_checkpoint(training_args.output_dir)
            if os.path.isdir(training_args.output_dir)
            else None
        )
        if checkpoint is None:
            print(f"No checkpoint found in {training_args.output_dir}; starting training from scratch.")
        else:
            print(f"Resuming training from checkpoint: {checkpoint}")
    elif checkpoint is False:
        checkpoint = None

    # Ensure RAG model forward pass works before training
    print("Test forward pass prior to training")
    example_query_for_test = "What is the primary language spoken in Brazil?"
    example_answer_for_test = "Portuguese"
    print(f"Query: {example_query_for_test}")
    question_tok = question_encoder_tokenizer(example_query_for_test, return_tensors="pt")
    label_tok = generator_tokenizer(text_target=example_answer_for_test, return_tensors="pt")

    # Use the same device the model was moved to (hard-coding 'cuda' broke CPU-only runs).
    test_input = {
        "input_ids": question_tok["input_ids"].to(device),
        "attention_mask": question_tok["attention_mask"].to(device),
        "labels": label_tok["input_ids"].to(device)
    }
    try:
        with torch.no_grad():
            model(**test_input)
        print("Test forward pass successful")
    except Exception as e:
        print(f"Forward pass failed: {e}")
        raise

    print("Starting RAG model training...")
    trainer.train(resume_from_checkpoint=checkpoint)
    print("RAG model training completed.")


    # Save the Fine-tuned RAG Model
    final_save_path = os.path.join(training_args.output_dir, "inference_" + os.getenv("HOSTNAME", "rag_model"))

    print(f"Saving main model weights and RagConfig to: {final_save_path}")
    trainer.save_model(final_save_path)

    if trainer.args.process_index == 0:
        print("Manually saving component models and tokenizers to subdirectories...")

        # Unwrap the model to get access to its components
        unwrapped_model = trainer.accelerator.unwrap_model(trainer.model)

        # Save the question encoder to its subdirectory
        qe_path = os.path.join(final_save_path, "question_encoder")
        unwrapped_model.rag.question_encoder.save_pretrained(qe_path)
        print(f"Question encoder component saved to {qe_path}")

        # Save the generator to its subdirectory
        gen_path = os.path.join(final_save_path, "generator")
        unwrapped_model.rag.generator.save_pretrained(gen_path)
        print(f"Generator component saved to {gen_path}")

        # Save the RagTokenizer
        rag_tokenizer_to_save = RagTokenizer(
            question_encoder=question_encoder_tokenizer,
            generator=generator_tokenizer
        )
        rag_tokenizer_to_save.save_pretrained(final_save_path)
        print(f"RagTokenizer components saved to subfolders in {final_save_path}")

    # Wait for all processes to finish before exiting. Only barrier when a process group
    # actually exists (FSDP on a single, non-distributed process would otherwise crash here).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()

    print(f"\nTraining and saving complete. Your final, inference-ready model is at: {final_save_path}")

# Kubeflow Training Client
Configure the Kubeflow SDK client by providing the required credentials to connect to a cluster.

In [None]:
import os

# Import the Kubernetes client module under an alias: the name `client` is reserved for the
# Kubeflow TrainingClient instance that the cells below use.
from kubernetes import client as k8s_client
from kubeflow.training import TrainingClient
from kubeflow.training.models import V1Volume, V1VolumeMount, V1PersistentVolumeClaimVolumeSource

# Cluster API server URL and bearer token. They are read from the environment when available,
# so credentials do not need to be hardcoded (and accidentally committed) in the notebook.
api_server = os.environ.get("KUBE_API_SERVER", "")
token = os.environ.get("KUBE_TOKEN", "")

configuration = k8s_client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}
# Un-comment if your cluster API server uses a self-signed certificate or an un-trusted CA
# configuration.verify_ssl = False
api_client = k8s_client.ApiClient(configuration)

# Kubeflow Training SDK client used to create, monitor and delete the training job
client = TrainingClient(client_configuration=api_client.configuration)

# Training Job
You're now almost ready to create the training job:
* Set the `HF_TOKEN` environment variable to your Hugging Face token if you fine-tune a gated model
* Check the number of worker nodes
* Amend the resources per worker according to the job requirements
* If you use AMD accelerators:
  * Change `nvidia.com/gpu` to `amd.com/gpu` in `resources_per_worker`
  * Change `base_image` to `quay.io/modh/training:py311-rocm62-torch251`
* Update the PVC name to the one you've attached to the workbench if needed
* Fill in the queue name if Kueue is enabled in the cluster, otherwise remove the label

In [None]:
# Remove a job left over from a previous run so the next cell can re-create it.
# On a fresh cluster there is no such job; the SDK then raises, so the failure is
# reported here instead of aborting a "Run All".
try:
    client.delete_job(name="sft-rag")
except RuntimeError as e:
    print(f"No previous job deleted: {e}")

In [None]:
import os

# Submit `main` as a distributed PyTorchJob; the YAML `parameters` cell supplies its arguments.
client.create_job(
    job_kind="PyTorchJob",
    name="sft-rag",
    train_func=main,
    labels={
        # Fill in the queue name if Kueue is enabled in the cluster, otherwise remove this label
        "kueue.x-k8s.io/queue-name" : "",
    },
    num_workers=3,
    num_procs_per_worker="1",
    resources_per_worker={
        "nvidia.com/gpu": 1,
        "memory": "64Gi",
        "cpu": 8,
    },
    base_image="quay.io/modh/training:py311-cuda124-torch251",
    env_vars={
        # Forces synchronous CUDA kernel launches: useful to pinpoint CUDA errors,
        # but it slows training down; consider removing it for regular runs.
        "CUDA_LAUNCH_BLOCKING": "1",
        # Hugging Face cache on the shared PVC mounted at /mnt/shared below
        "HF_HOME": "/mnt/shared/.cache",
        # Forwarded from the notebook's environment so the token is not hardcoded here;
        # only required for gated models.
        "HF_TOKEN": os.environ.get("HF_TOKEN", ""),
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "PYTORCH_HIP_ALLOC_CONF": "expandable_segments:True",
        "NCCL_DEBUG": "INFO",
        "TRANSFORMERS_VERBOSITY": "info",
    },
    parameters=parameters,
    volumes=[
        V1Volume(name="shared",
                 persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(claim_name="shared")),
    ],
    volume_mounts=[
        V1VolumeMount(name="shared", mount_path="/mnt/shared"),
    ],
)

Once the training job is created, you can follow its progress:

In [None]:
# Stream the logs of the "sft-rag" PyTorchJob into this cell (follow=True keeps reading new lines).
client.get_job_logs(job_kind="PyTorchJob", name="sft-rag", follow=True)

# TensorBoard Setup
You can track your job runs and visualize the training metrics with TensorBoard. Enable TensorBoard logging for real-time visualization of training metrics, such as loss curves and ROUGE scores, to better monitor model learning behavior.

In [None]:
import os

# Port of the TensorBoard server that the proxy URL points at (6006 is TensorBoard's default).
TENSORBOARD_PORT = 6006

# Route TensorBoard through the notebook server's proxy. NB_PREFIX is presumably set by the
# workbench environment; when it is missing, skip the proxy setup with a message instead of
# failing with a KeyError.
nb_prefix = os.environ.get("NB_PREFIX")
if nb_prefix is not None:
    os.environ["TENSORBOARD_PROXY_URL"] = f"{nb_prefix}/proxy/{TENSORBOARD_PORT}/"
else:
    print("NB_PREFIX is not set; skipping TENSORBOARD_PROXY_URL configuration.")

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic used below.
%load_ext tensorboard

In [None]:
# Start TensorBoard on the shared storage directory. NOTE(review): presumably this is the same
# PVC that the training job mounts at /mnt/shared, seen from the workbench — confirm the path.
%tensorboard --logdir /opt/app-root/src/shared

## Inference: Testing the Fine-Tuned RAG Model
After fine-tuning, we load the fine-tuned RAG model and run inference on test queries.
The `RagSequenceForGeneration` model internally performs retrieval and then generates an answer conditioned on the retrieved passages.

In [None]:
import sys
import torch
from transformers import (
    RagSequenceForGeneration,
    RagConfig,
    RagTokenizer,
    DPRQuestionEncoder,
)

# Directory holding this example's custom modules (e.g. the `feature_repo` package)
CUSTOM_MODULES_PATH = "/opt/app-root/src/distributed-workloads/examples/kfto-sft-feast-rag"
# Guarded so that re-running this cell does not append duplicate sys.path entries
if CUSTOM_MODULES_PATH not in sys.path:
    sys.path.append(CUSTOM_MODULES_PATH)

# Import custom retriever and feature view definitions
from feast.rag_retriever import FeastRAGRetriever, FeastIndex
from feature_repo.ragproject_repo import wiki_passage_feature_view

# Feast feature repository, located inside the custom modules directory
STORE_PATH = f"{CUSTOM_MODULES_PATH}/feature_repo"
# Directory written by the training job as "inference_<HOSTNAME>" (here the master pod);
# presumably the shared PVC as seen from the workbench — adjust if your paths differ.
FINETUNED_RAG_CHECKPOINT_DIR = "/opt/app-root/src/shared/fine_tuned_rag_model/inference_sft-rag-master-0"

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load RAG configuration from fine-tuned model path
rag_config_inference = RagConfig.from_pretrained(FINETUNED_RAG_CHECKPOINT_DIR)

# Load the question-encoder and generator tokenizers saved with the fine-tuned model
rag_tokenizer = RagTokenizer.from_pretrained(FINETUNED_RAG_CHECKPOINT_DIR)
question_encoder_tokenizer = rag_tokenizer.question_encoder
generator_tokenizer_inference = rag_tokenizer.generator

# Load the original pre-trained DPR question encoder from the Hub, not from the fine-tuned
# checkpoint. DPRQuestionEncoder is a built-in transformers class, so `trust_remote_code`
# is unnecessary (it only matters when loading custom model code) and is omitted to avoid
# opting in to executing remote code.
question_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base",
    torch_dtype=torch.bfloat16,  # a torch.dtype is accepted by all transformers versions
).to(device)
question_encoder.eval()

# Build the custom Feast-backed retriever that serves passages to the RAG model
print("Initializing custom FeastRAGRetriever...")
feast_index_inference = FeastIndex()

# Columns of the `wiki_passages` feature view fetched for every retrieved passage
features_to_retrieve = [
    f"wiki_passages:{column}" for column in ("passage_text", "embedding", "passage_id")
]

rag_retriever_inference = FeastRAGRetriever(
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer_inference,
    # No models yet: the question encoder is attached once the RAG model is loaded
    # (end of this block); no generator model is handed to the retriever.
    question_encoder=None,
    generator_model=None,
    feast_repo_path=STORE_PATH,
    feature_view=wiki_passage_feature_view,
    features=features_to_retrieve,
    search_type="vector",
    config=rag_config_inference,
    index=feast_index_inference,
)

# Load the fine-tuned RAG model with the custom retriever plugged in
print(f"Loading full RagSequenceForGeneration model from: {FINETUNED_RAG_CHECKPOINT_DIR}")
finetuned_rag_model = RagSequenceForGeneration.from_pretrained(
    FINETUNED_RAG_CHECKPOINT_DIR,
    retriever=rag_retriever_inference,
)

# Swap in the separately loaded question encoder, then move to the device in eval mode
finetuned_rag_model.rag.question_encoder = question_encoder
finetuned_rag_model.to(device).eval()

# Model and retriever now share one question encoder instance
rag_retriever_inference.question_encoder = finetuned_rag_model.rag.question_encoder

print(f"Fine-tuned RAG model loaded and moved to {device}. Ready for inference.")

# Sample questions used to smoke-test the fine-tuned model
test_queries = [
    "What is the boiling point of water in Celsius?",
    "What is the capital city of Australia?",
    "Who invented the telephone?",
    "Who wrote 'Alice's Adventures in Wonderland'?",
    "Who painted the Mona Lisa?",
    "What is the function of the kidneys in the human body?",
]


def _answer(query):
    """Tokenize ``query``, run retrieval + generation, and return the decoded answer text."""
    encoded = question_encoder_tokenizer(
        query,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=32,
    ).to(device)

    with torch.no_grad():
        output_ids = finetuned_rag_model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=200,
        )

    # Only the first generated sequence is decoded
    return generator_tokenizer_inference.decode(output_ids[0], skip_special_tokens=True)


for test_query in test_queries:
    print(f"\nQuery: {test_query}")
    try:
        print(f"Answer: {_answer(test_query)}")
    except Exception as e:
        # Report which query failed, then re-raise so the failure is not hidden
        print(f"Error during inference for query '{test_query}': {e}")
        raise


# Cleaning Up

## Delete Training Job
Once you're done or want to re-create the training job, you can delete the existing one:

In [None]:
# Delete the "sft-rag" training job. No job_kind is passed, so the TrainingClient's default
# job kind applies (presumably PyTorchJob, as used at creation) — confirm if you changed it.
client.delete_job(name="sft-rag")

## GPU Memory
If you want to start over and test the pre-trained model again, you can free the GPU / accelerator memory with:

In [None]:
# Unload the model from GPU memory
import gc

import torch

# Variables created by the inference cell that hold the model, retriever and tokenizers.
# They are removed through globals() rather than `del`, so this cell does not raise
# NameError when it is re-run or when the inference cell was never executed.
inference_var_names = (
    "finetuned_rag_model",
    "rag_retriever_inference",
    "rag_config_inference",
    "generator_tokenizer_inference",
    "question_encoder_tokenizer",
    "feast_index_inference",
    "question_encoder",
)
for var_name in inference_var_names:
    globals().pop(var_name, None)

# Collect the now-unreferenced objects, then release cached CUDA blocks (only when CUDA exists)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()