Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ static EXAMPLE_REGISTRY: LazyLock<HashMap<&'static str, Box<dyn Example>>> = Laz
m.insert("07.14", Box::new(examples::ch07::EG14));
m.insert("07.15", Box::new(examples::ch07::EG15));
m.insert("07.16", Box::new(examples::ch07::EG16));
m.insert("07.17", Box::new(examples::ch07::EG17));
// apdx_e
m.insert("E.01", Box::new(examples::apdx_e::EG01));
m.insert("E.02", Box::new(examples::apdx_e::EG02));
Expand Down
57 changes: 57 additions & 0 deletions src/examples/ch07.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,3 +1054,60 @@ impl Example for EG16 {
Ok(())
}
}

/// # [Bonus] Usage of `generate_chosen_and_rejected_response` to create preference example
///
/// Bonus example ported from the DPO preference-tuning notebook; it queries a
/// locally running Ollama server to build a chosen/rejected response pair.
///
/// #### Id
/// 07.17
///
/// #### Page
/// This example is from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
/// (bonus material — no corresponding book page)
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.17
///
/// # with cuda
/// cargo run --features cuda example 07.17
/// ```
pub struct EG17;

impl Example for EG17 {
    /// One-line description displayed when listing registered examples.
    fn description(&self) -> String {
        let desc = "[Bonus from DPO notebook] Usage of \
        `generate_chosen_and_rejected_response` to create preference example.";
        desc.to_string()
    }

    /// Book page number; 0 because this bonus example comes from a notebook,
    /// not a page of the book.
    fn page_source(&self) -> usize {
        0_usize
    }

    /// Loads the instruction dataset (downloading it if missing) and prints
    /// one generated preference example.
    ///
    /// NOTE(review): requires an Ollama server reachable at
    /// `DEFAULT_OLLAMA_API_URL`; the call fails otherwise.
    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            bonus::generate_chosen_and_rejected_response, download_and_load_file,
            AlpacaPromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL, INSTRUCTION_DATA_FILENAME,
            INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

        // load instruction examples
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

        // invoke generate_chosen_and_rejected_response on an arbitrary entry
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let preference_example = generate_chosen_and_rejected_response(
            &data[42],
            DEFAULT_OLLAMA_API_URL,
            model,
            &prompt_formatter,
        )?;

        println!("{:#?}", preference_example);

        Ok(())
    }
}
123 changes: 123 additions & 0 deletions src/listings/ch07/bonus.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
//! Bonus material module for Chapter 7

use super::{query_model, InstructionResponseExample, PromptFormatter};
use rand::{rngs::StdRng, Rng, SeedableRng};

/// A preference record pairing an instruction entry with a preferred
/// (`chosen`) and a dispreferred (`rejected`) response, as used for DPO
/// preference tuning.
#[allow(dead_code)] // fields are currently only read through `Debug` formatting
#[derive(Clone, Debug, Default)]
pub struct PreferenceExample {
    /// The task instruction text.
    instruction: String,
    /// Optional additional input/context for the instruction.
    input: Option<String>,
    /// The original reference output for the entry.
    output: String,
    /// The dispreferred response.
    rejected: String,
    /// The preferred response.
    chosen: String,
}

impl From<InstructionResponseExample> for PreferenceExample {
    /// Seeds a preference example from an instruction entry, copying its
    /// instruction/input/output and leaving `chosen`/`rejected` empty.
    fn from(value: InstructionResponseExample) -> Self {
        let instruction = value.instruction().to_owned();
        let input = value.input().to_owned();
        let output = value.output().to_owned();
        Self {
            instruction,
            input,
            output,
            rejected: String::new(),
            chosen: String::new(),
        }
    }
}

impl PreferenceExample {
    /// Stores an owned copy of `rejected` as the dispreferred response.
    pub fn set_rejected(&mut self, rejected: &str) {
        self.rejected = rejected.to_owned();
    }

    /// Stores an owned copy of `chosen` as the preferred response.
    pub fn set_chosen(&mut self, chosen: &str) {
        self.chosen = chosen.to_owned();
    }
}

/// Using Ollama to generate a `chosen` and `rejected` responses for an
/// instruction entry.
///
/// The model is asked to rewrite the entry's output to be slightly more
/// "polite" or "impolite". The polite variant becomes `chosen` and the
/// impolite one `rejected`; the entry's original output fills the other slot.
///
/// NOTE(review): the RNG is re-seeded with the same constant on every call,
/// so the politeness direction is identical for every entry — confirm this
/// determinism is intended (the source notebook draws fresh randomness per
/// entry).
pub fn generate_chosen_and_rejected_response<P: PromptFormatter>(
    entry: &InstructionResponseExample,
    url: &str,
    model: &str,
    prompt_formatter: &P,
) -> anyhow::Result<PreferenceExample> {
    // Draw the rewrite direction: u < 0.5 -> "polite", otherwise "impolite".
    let mut rng = StdRng::seed_from_u64(69420);
    let u: f32 = rng.gen_range(0.0..1.0);
    let politeness = if u < 0.5 { "polite" } else { "impolite" };

    // Trailing `\` continuations keep the literal a single line with single
    // spaces — without them the prompt would embed raw newlines plus source
    // indentation (and this is the form the unit test pins down).
    let prompt = format!(
        "Given the input `{}` and correct output `{}`, \
        slightly rewrite the output to be more {}. \
        Keep the modification minimal. \
        Only return the generated response and nothing else.",
        prompt_formatter.format_input(entry),
        entry.output(),
        politeness
    );

    let response = query_model(prompt.as_str(), model, url)?;
    let mut preference_example = PreferenceExample::from(entry.clone());

    // The model's rewrite is preferred only when it was made more polite;
    // otherwise the original output wins.
    if politeness == "polite" {
        preference_example.set_chosen(response.as_str());
        preference_example.set_rejected(entry.output().as_str());
    } else {
        preference_example.set_chosen(entry.output().as_str());
        preference_example.set_rejected(response.as_str());
    }

    Ok(preference_example)
}

#[cfg(test)]
mod tests {
    use crate::listings::ch07::AlpacaPromptFormatter;

    use super::*;
    use anyhow::Result;
    use rstest::*;

    /// Builds a minimal instruction entry for prompt-formatting tests.
    #[fixture]
    fn instruction_example() -> InstructionResponseExample {
        let instruction = "Here is a fake instruction.".to_string();
        let input = Some("Here is a fake input.".to_string());
        let output = "here is a fake output.".to_string();
        InstructionResponseExample {
            instruction,
            input,
            output,
            model_response: None,
        }
    }

    /// Verifies the chosen/rejected rewrite prompt renders as a single line
    /// with the Alpaca-formatted input spliced in. (Fixed: the instruction
    /// sentence previously read "Only return return ..." — duplicated word.)
    #[rstest]
    fn test_prompt_for_rejection_chosen(
        instruction_example: InstructionResponseExample,
    ) -> Result<()> {
        let politeness = "polite";
        let prompt_formatter = AlpacaPromptFormatter;
        let prompt = format!(
            "Given the input `{}` and correct output `{}`, \
            slightly rewrite the output to be more {}. \
            Keep the modification minimal. \
            Only return the generated response and nothing else.",
            prompt_formatter.format_input(&instruction_example),
            instruction_example.output(),
            politeness
        );

        let expected = "Given the input `Below is an instruction that \
        describes a task. Write a response that appropriately completes the \
        request.\n\n\
        ### Instruction:\n\
        Here is a fake instruction.\n\n\
        ### Input:\n\
        Here is a fake input.` and correct output `here is a fake output.`, \
        slightly rewrite the output to be more polite. Keep the modification \
        minimal. Only return the generated response and nothing else.";

        assert_eq!(prompt, expected);
        Ok(())
    }
}
3 changes: 3 additions & 0 deletions src/listings/ch07.rs → src/listings/ch07/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ use std::{
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
use tqdm::tqdm;

/// Bonus material (e.g. DPO preference-data generation) for Chapter 7
pub mod bonus;

// Filename under `DATA_DIR` for the downloaded instruction-tuning dataset.
pub const INSTRUCTION_DATA_FILENAME: &str = "instruction_data.json";
// Local directory where downloaded data files are stored.
pub const DATA_DIR: &str = "data";
pub const INSTRUCTION_DATA_URL: &str = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\
Expand Down
Loading