Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ static EXAMPLE_REGISTRY: LazyLock<HashMap<&'static str, Box<dyn Example>>> = Laz
m.insert("07.15", Box::new(examples::ch07::EG15));
m.insert("07.16", Box::new(examples::ch07::EG16));
m.insert("07.17", Box::new(examples::ch07::EG17));
m.insert("07.18", Box::new(examples::ch07::EG18));
// apdx_e
m.insert("E.01", Box::new(examples::apdx_e::EG01));
m.insert("E.02", Box::new(examples::apdx_e::EG02));
Expand Down
68 changes: 63 additions & 5 deletions src/examples/ch07.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,16 +958,17 @@ impl Example for EG15 {

fn main(&self) -> Result<()> {
use crate::listings::ch07::{
load_instruction_data_from_json, query_model, AlpacaPromptFormatter, PromptFormatter,
DATA_DIR, DEFAULT_OLLAMA_API_URL,
load_instruction_data_from_json, query_model, AlpacaPromptFormatter,
InstructionResponseExample, PromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL,
};
use std::path::Path;

// load test instruction data with response
let file_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
let test_data = load_instruction_data_from_json(file_path).with_context(|| {
"Missing 'instruction_data_with_response.json' file. Please run EG 07.12."
})?;
let test_data: Vec<InstructionResponseExample> = load_instruction_data_from_json(file_path)
.with_context(|| {
"Missing 'instruction_data_with_response.json' file. Please run EG 07.12."
})?;

let model = "llama3";
let prompt_formatter = AlpacaPromptFormatter;
Expand Down Expand Up @@ -1111,3 +1112,60 @@ impl Example for EG17 {
Ok(())
}
}

/// # Example usage of `generate_preference_dataset`
///
/// #### Id
/// 07.18
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.18
///
/// # with cuda
/// cargo run --features cuda example 07.18
/// ```
pub struct EG18;

impl Example for EG18 {
fn description(&self) -> String {
"Example usage of `generate_preference_dataset`.".to_string()
}

fn page_source(&self) -> usize {
0_usize
}

fn main(&self) -> Result<()> {
use crate::listings::{
ch07::bonus::generate_preference_dataset,
ch07::{
download_and_load_file, AlpacaPromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL,
INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
},
};
use std::path::Path;

// load instruction examples
let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

// invoke generate_chose_and_rejected_response
let model = "llama3";
let prompt_formatter = AlpacaPromptFormatter;
let save_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
generate_preference_dataset(
&data,
DEFAULT_OLLAMA_API_URL,
model,
&prompt_formatter,
save_path,
)?;

Ok(())
}
}
39 changes: 35 additions & 4 deletions src/listings/ch07/bonus.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
//! Bonus material module for Chapter 7

use super::{query_model, InstructionResponseExample, PromptFormatter};
use super::{
query_model, write_instruction_data_to_json, InstructionResponseExample, PromptFormatter,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, NoneAsEmptyString};
use std::path::Path;
use tqdm::tqdm;

#[allow(dead_code)]
#[derive(Clone, Debug, Default)]
#[serde_as]
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct PreferenceExample {
instruction: String,
#[serde_as(as = "NoneAsEmptyString")]
input: Option<String>,
output: String,
rejected: String,
Expand Down Expand Up @@ -35,7 +42,6 @@ impl PreferenceExample {
}

/// Using Ollama to generate a `chosen` and `rejected` responses for an instruction entry
#[allow(unused_variables)]
pub fn generate_chosen_and_rejected_response<P: PromptFormatter>(
entry: &InstructionResponseExample,
url: &str,
Expand Down Expand Up @@ -70,6 +76,31 @@ pub fn generate_chosen_and_rejected_response<P: PromptFormatter>(
Ok(preference_example)
}

/// Create a preference dataset from an instruction dataset and Ollama
pub fn generate_preference_dataset<P: PromptFormatter, T: AsRef<Path>>(
instruction_data: &[InstructionResponseExample],
url: &str,
model: &str,
prompt_formatter: &P,
save_path: T,
) -> anyhow::Result<()> {
let mut dataset = vec![];
for entry in tqdm(instruction_data.iter()) {
let preference_example =
generate_chosen_and_rejected_response(entry, url, model, prompt_formatter)?;
dataset.push(preference_example);
}

// write to json
println!(
"Saving preference data to {:?}",
save_path.as_ref().to_str()
);
write_instruction_data_to_json(&dataset, save_path)?;

Ok(())
}

#[cfg(test)]
mod tests {
use crate::listings::ch07::AlpacaPromptFormatter;
Expand Down
10 changes: 5 additions & 5 deletions src/listings/ch07/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,18 +674,18 @@ pub use crate::listings::ch02::DataLoader;
pub use crate::listings::ch05::calc_loss_loader;

/// Helper function to write instruction data to a json
pub fn load_instruction_data_from_json<P: AsRef<Path>>(
pub fn load_instruction_data_from_json<P: AsRef<Path>, S: Serialize + for<'a> Deserialize<'a>>(
file_path: P,
) -> anyhow::Result<Vec<InstructionResponseExample>> {
) -> anyhow::Result<Vec<S>> {
let json_str = read_to_string(file_path.as_ref())
.with_context(|| format!("Unable to read {}", file_path.as_ref().display()))?;
let data: Vec<InstructionResponseExample> = serde_json::from_str(&json_str[..])?;
let data: Vec<S> = serde_json::from_str(&json_str[..])?;
Ok(data)
}

/// Helper function to write instruction data to a json
pub fn write_instruction_data_to_json<P: AsRef<Path>>(
instruction_data: &Vec<InstructionResponseExample>,
pub fn write_instruction_data_to_json<P: AsRef<Path>, S: Serialize + for<'a> Deserialize<'a>>(
instruction_data: &Vec<S>,
save_path: P,
) -> anyhow::Result<()> {
let file = File::create(save_path)?;
Expand Down
Loading