diff --git a/src/examples/ch07.rs b/src/examples/ch07.rs index 15104da..7c206e7 100644 --- a/src/examples/ch07.rs +++ b/src/examples/ch07.rs @@ -1275,7 +1275,7 @@ impl Example for EG19 { ); // Print masks and their shapes - let chosen_mask = &collated_item.chosen_mask()[1]; + let chosen_mask = &collated_item.chosen_mask().i((1, ..))?; println!("\nCollated Batch: Masks\n"); println!("Chosen inputs: {:?}", chosen); println!("Chosen mask: {:?}", chosen_mask); @@ -1286,24 +1286,24 @@ impl Example for EG19 { ); // decode chosen mask - let chosen_masked_text = token_ids_to_text( - chosen.index_select(&collated_item.chosen_mask()[1], 0)?, - &tokenizer, - )?; - println!( - "\nCollated Batch Item 1: Chosen Mask Text\n\n{}\n", - chosen_masked_text - ); + // let chosen_masked_text = token_ids_to_text( + // chosen.index_select(chosen_mask.clone(), 0)?, + // &tokenizer, + // )?; + // println!( + // "\nCollated Batch Item 1: Chosen Mask Text\n\n{}\n", + // chosen_masked_text + // ); // decode rejected mask - let rejected_masked_text = token_ids_to_text( - rejected.index_select(&collated_item.rejected_mask()[1], 0)?, - &tokenizer, - )?; - println!( - "\nCollated Batch Item 1: Rejected Mask Text\n\n{}\n", - rejected_masked_text - ); + // let rejected_masked_text = token_ids_to_text( + // rejected.index_select(&collated_item.rejected_mask()[1], 0)?, + // &tokenizer, + // )?; + // println!( + // "\nCollated Batch Item 1: Rejected Mask Text\n\n{}\n", + // rejected_masked_text + // ); Ok(()) } diff --git a/src/listings/ch07/bonus.rs b/src/listings/ch07/bonus.rs index b543fc8..5b00357 100644 --- a/src/listings/ch07/bonus.rs +++ b/src/listings/ch07/bonus.rs @@ -2,9 +2,9 @@ use super::{ query_model, write_instruction_data_to_json, InstructionExample, InstructionResponseExample, - PromptFormatter, DEFAULT_PAD_TOKEN_ID, + PromptFormatter, DEFAULT_PAD_TOKEN_ID, GPT, }; -use candle_core::{Device, IndexOp, Result, Tensor, D}; +use candle_core::{Device, IndexOp, ModuleT, Result, Tensor, D}; use rand::{rngs::StdRng, seq::SliceRandom, thread_rng, Rng, SeedableRng}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, NoneAsEmptyString}; @@ -268,8 +268,8 @@ pub struct PreferenceDatasetCollatorItem { prompt: Vec, chosen: Tensor, rejected: Tensor, - rejected_mask: Vec, - chosen_mask: Vec, + rejected_mask: Tensor, + chosen_mask: Tensor, } impl PreferenceDatasetCollatorItem { @@ -281,7 +281,7 @@ impl PreferenceDatasetCollatorItem { &self.chosen } - pub fn chosen_mask(&self) -> &Vec { + pub fn chosen_mask(&self) -> &Tensor { &self.chosen_mask } @@ -289,7 +289,7 @@ impl PreferenceDatasetCollatorItem { &self.rejected } - pub fn rejected_mask(&self) -> &Vec { + pub fn rejected_mask(&self) -> &Tensor { &self.rejected_mask } } @@ -496,26 +496,24 @@ impl PreferenceDataCollator { rejected_mask = rejected_mask[..std::cmp::min(a, batch_max_length)].to_vec(); } - // get only the indexes of mask that are to keep - let chosen_mask_tensor = self._build_tensor_from_only_true_values(chosen_mask)?; - let rejected_mask_tensor = self._build_tensor_from_only_true_values(rejected_mask)?; - chosen_vec.push(chosen); rejected_vec.push(rejected); - rejected_mask_vec.push(rejected_mask_tensor); - chosen_mask_vec.push(chosen_mask_tensor); + rejected_mask_vec.push(rejected_mask); + chosen_mask_vec.push(chosen_mask); prompt_vec.push(prompt_tensor); } let chosen_tensor = self._build_stacked_tensor(chosen_vec)?; + let chosen_mask_tensor = self._build_stacked_tensor(chosen_mask_vec)?; let rejected_tensor = self._build_stacked_tensor(rejected_vec)?; + let rejected_mask_tensor = self._build_stacked_tensor(rejected_mask_vec)?; Ok(PreferenceDatasetCollatorItem { prompt: prompt_vec, chosen: chosen_tensor, rejected: rejected_tensor, - rejected_mask: rejected_mask_vec, - chosen_mask: chosen_mask_vec, + rejected_mask: rejected_mask_tensor, + chosen_mask: chosen_mask_tensor, }) } } @@ -638,19 +636,61 @@ pub fn compute_logprobs( .squeeze(D::Minus1)?; if let Some(m) = selection_mask { - // this is brittle, if there is a value < 0, then this breaks... - let mask = (m - 1_f64)?; - let selected_log_probs = selected_log_probs.gather(&mask, D::Minus1)?; + let mask = m.i((.., 1..))?.clone(); + let mask_sum = mask.sum(D::Minus1)?; + + let selected_log_probs = (selected_log_probs * mask)?; let avg_log_prob = selected_log_probs .sum(D::Minus1)? - .broadcast_div(&mask.sum(D::Minus1)?)?; + .broadcast_div(&mask_sum)?; Ok(avg_log_prob) } else { selected_log_probs.mean(D::Minus1) } } +pub fn compute_dpo_loss_batch( + batch: &PreferenceDatasetCollatorItem, + policy_model: M, + reference_model: M, + beta: f64, +) -> Result<(Tensor, Tensor, Tensor)> { + // where policy_model(batch["chosen"]) are the logits + let policy_chosen_log_probas = compute_logprobs( + &policy_model.forward_t(batch.chosen(), true)?, + batch.chosen(), + Some(batch.chosen_mask()), + )?; + + let policy_rejected_log_probas = compute_logprobs( + &policy_model.forward_t(batch.rejected(), true)?, + batch.rejected(), + Some(batch.rejected_mask()), + )?; + + let ref_chosen_log_probas = compute_logprobs( + &reference_model.forward_t(batch.chosen(), false)?, + batch.chosen(), + Some(batch.chosen_mask()), + )?; + let ref_rejected_log_probas = compute_logprobs( + &reference_model.forward_t(batch.rejected(), false)?, + batch.rejected(), + Some(batch.rejected_mask()), + )?; + + let (loss, chosen_rewards, rejected_rewards) = compute_dpo_loss( + &policy_chosen_log_probas, + &policy_rejected_log_probas, + &ref_chosen_log_probas, + &ref_rejected_log_probas, + beta, + )?; + + Ok((loss, chosen_rewards, rejected_rewards)) +} + #[cfg(test)] mod tests { use super::*; @@ -820,12 +860,18 @@ mod tests { // assert assert_eq!( - collated_item.chosen_mask()[0].to_vec1::()?, - &[44, 45, 46, 47, 48, 49, 50, 51, 52, 53,] + collated_item.chosen_mask().i((0, ..))?.to_vec1::()?, + &[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ] ); assert_eq!( - collated_item.rejected_mask()[0].to_vec1::()?, - &[44, 45, 46, 47, 48, 49, 50, 51, 52, 53,] + collated_item.rejected_mask().i((0, ..))?.to_vec1::()?, + &[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ] ); assert_eq!( collated_item.rejected.elem_count(), @@ -857,32 +903,12 @@ mod tests { while let Some(Ok(collated_item)) = batcher.next() { assert_eq!(collated_item.chosen.dims()[1], allowed_max_length); assert_eq!(collated_item.rejected.dims()[1], allowed_max_length); + assert_eq!(collated_item.chosen_mask.dims()[1], allowed_max_length); + assert_eq!(collated_item.rejected_mask.dims()[1], allowed_max_length); assert!(collated_item.chosen.dims()[0] <= batch_size); assert!(collated_item.rejected.dims()[0] <= batch_size); - assert!(collated_item.chosen_mask.len() <= batch_size); - assert!(collated_item.rejected_mask.len() <= batch_size); assert!(collated_item.prompt.len() <= batch_size); - let max_length_chosen_mask = collated_item - .chosen_mask - .iter() - .map(|el| el.elem_count()) - .collect::>() - .into_iter() - .max() - .unwrap(); - assert!(max_length_chosen_mask <= allowed_max_length); - - let max_length_rejected_mask = collated_item - .rejected_mask - .iter() - .map(|el| el.elem_count()) - .collect::>() - .into_iter() - .max() - .unwrap(); - assert!(max_length_rejected_mask <= allowed_max_length); - count += 1; } assert_eq!(data_loader.len(), count);