Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions src/examples/ch07.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ impl Example for EG19 {
);

// Print masks and their shapes
let chosen_mask = &collated_item.chosen_mask()[1];
let chosen_mask = &collated_item.chosen_mask().i((1, ..))?;
println!("\nCollated Batch: Masks\n");
println!("Chosen inputs: {:?}", chosen);
println!("Chosen mask: {:?}", chosen_mask);
Expand All @@ -1286,24 +1286,24 @@ impl Example for EG19 {
);

// decode chosen mask
let chosen_masked_text = token_ids_to_text(
chosen.index_select(&collated_item.chosen_mask()[1], 0)?,
&tokenizer,
)?;
println!(
"\nCollated Batch Item 1: Chosen Mask Text\n\n{}\n",
chosen_masked_text
);
// let chosen_masked_text = token_ids_to_text(
// chosen.index_select(chosen_mask.clone(), 0)?,
// &tokenizer,
// )?;
// println!(
// "\nCollated Batch Item 1: Chosen Mask Text\n\n{}\n",
// chosen_masked_text
// );

// decode rejected mask
let rejected_masked_text = token_ids_to_text(
rejected.index_select(&collated_item.rejected_mask()[1], 0)?,
&tokenizer,
)?;
println!(
"\nCollated Batch Item 1: Rejected Mask Text\n\n{}\n",
rejected_masked_text
);
// let rejected_masked_text = token_ids_to_text(
// rejected.index_select(&collated_item.rejected_mask()[1], 0)?,
// &tokenizer,
// )?;
// println!(
// "\nCollated Batch Item 1: Rejected Mask Text\n\n{}\n",
// rejected_masked_text
// );

Ok(())
}
Expand Down
114 changes: 70 additions & 44 deletions src/listings/ch07/bonus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

use super::{
query_model, write_instruction_data_to_json, InstructionExample, InstructionResponseExample,
PromptFormatter, DEFAULT_PAD_TOKEN_ID,
PromptFormatter, DEFAULT_PAD_TOKEN_ID, GPT,
};
use candle_core::{Device, IndexOp, Result, Tensor, D};
use candle_core::{Device, IndexOp, ModuleT, Result, Tensor, D};
use rand::{rngs::StdRng, seq::SliceRandom, thread_rng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, NoneAsEmptyString};
Expand Down Expand Up @@ -268,8 +268,8 @@ pub struct PreferenceDatasetCollatorItem {
prompt: Vec<Tensor>,
chosen: Tensor,
rejected: Tensor,
rejected_mask: Vec<Tensor>,
chosen_mask: Vec<Tensor>,
rejected_mask: Tensor,
chosen_mask: Tensor,
}

impl PreferenceDatasetCollatorItem {
Expand All @@ -281,15 +281,15 @@ impl PreferenceDatasetCollatorItem {
&self.chosen
}

pub fn chosen_mask(&self) -> &Vec<Tensor> {
pub fn chosen_mask(&self) -> &Tensor {
&self.chosen_mask
}

pub fn rejected(&self) -> &Tensor {
&self.rejected
}

pub fn rejected_mask(&self) -> &Vec<Tensor> {
pub fn rejected_mask(&self) -> &Tensor {
&self.rejected_mask
}
}
Expand Down Expand Up @@ -496,26 +496,24 @@ impl PreferenceDataCollator {
rejected_mask = rejected_mask[..std::cmp::min(a, batch_max_length)].to_vec();
}

// get only the indexes of mask that are to keep
let chosen_mask_tensor = self._build_tensor_from_only_true_values(chosen_mask)?;
let rejected_mask_tensor = self._build_tensor_from_only_true_values(rejected_mask)?;

chosen_vec.push(chosen);
rejected_vec.push(rejected);
rejected_mask_vec.push(rejected_mask_tensor);
chosen_mask_vec.push(chosen_mask_tensor);
rejected_mask_vec.push(rejected_mask);
chosen_mask_vec.push(chosen_mask);
prompt_vec.push(prompt_tensor);
}

let chosen_tensor = self._build_stacked_tensor(chosen_vec)?;
let chosen_mask_tensor = self._build_stacked_tensor(chosen_mask_vec)?;
let rejected_tensor = self._build_stacked_tensor(rejected_vec)?;
let rejected_mask_tensor = self._build_stacked_tensor(rejected_mask_vec)?;

Ok(PreferenceDatasetCollatorItem {
prompt: prompt_vec,
chosen: chosen_tensor,
rejected: rejected_tensor,
rejected_mask: rejected_mask_vec,
chosen_mask: chosen_mask_vec,
rejected_mask: rejected_mask_tensor,
chosen_mask: chosen_mask_tensor,
})
}
}
Expand Down Expand Up @@ -638,19 +636,61 @@ pub fn compute_logprobs(
.squeeze(D::Minus1)?;

if let Some(m) = selection_mask {
// this is brittle, if there is a value < 0, then this breaks...
let mask = (m - 1_f64)?;
let selected_log_probs = selected_log_probs.gather(&mask, D::Minus1)?;
let mask = m.i((.., 1..))?.clone();
let mask_sum = mask.sum(D::Minus1)?;

let selected_log_probs = (selected_log_probs * mask)?;

let avg_log_prob = selected_log_probs
.sum(D::Minus1)?
.broadcast_div(&mask.sum(D::Minus1)?)?;
.broadcast_div(&mask_sum)?;
Ok(avg_log_prob)
} else {
selected_log_probs.mean(D::Minus1)
}
}

/// Computes the DPO loss for a single collated preference batch.
///
/// Both models are run on the batch's `chosen` and `rejected` inputs, and the
/// resulting logits are reduced to masked log-probabilities via
/// [`compute_logprobs`], using the batch's matching `chosen_mask` /
/// `rejected_mask`. The four log-probability tensors are then handed to
/// [`compute_dpo_loss`] together with `beta`.
///
/// # Arguments
///
/// * `batch` - Collated item supplying the chosen/rejected inputs and masks.
/// * `policy_model` - Model whose forward pass is run with `train = true`.
/// * `reference_model` - Model whose forward pass is run with `train = false`.
/// * `beta` - Forwarded unchanged to [`compute_dpo_loss`].
///
/// # Returns
///
/// The `(loss, chosen_rewards, rejected_rewards)` tuple produced by
/// [`compute_dpo_loss`].
///
/// # Errors
///
/// Propagates any error raised by a forward pass, by [`compute_logprobs`], or
/// by [`compute_dpo_loss`].
///
/// NOTE(review): both models are taken by value, so a call consumes them;
/// confirm with callers whether borrowing would be preferable.
pub fn compute_dpo_loss_batch<M: GPT + ModuleT>(
    batch: &PreferenceDatasetCollatorItem,
    policy_model: M,
    reference_model: M,
    beta: f64,
) -> Result<(Tensor, Tensor, Tensor)> {
    // One forward pass followed by the masked log-prob reduction; the model's
    // output logits are scored against the very inputs that produced them.
    let masked_log_probas =
        |model: &M, inputs: &Tensor, mask: &Tensor, train: bool| -> Result<Tensor> {
            let logits = model.forward_t(inputs, train)?;
            compute_logprobs(&logits, inputs, Some(mask))
        };

    // Policy model: forward passes run in training mode.
    let policy_chosen =
        masked_log_probas(&policy_model, batch.chosen(), batch.chosen_mask(), true)?;
    let policy_rejected =
        masked_log_probas(&policy_model, batch.rejected(), batch.rejected_mask(), true)?;

    // Reference model: forward passes run in evaluation mode.
    let reference_chosen =
        masked_log_probas(&reference_model, batch.chosen(), batch.chosen_mask(), false)?;
    let reference_rejected = masked_log_probas(
        &reference_model,
        batch.rejected(),
        batch.rejected_mask(),
        false,
    )?;

    let loss_and_rewards = compute_dpo_loss(
        &policy_chosen,
        &policy_rejected,
        &reference_chosen,
        &reference_rejected,
        beta,
    )?;

    Ok(loss_and_rewards)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -820,12 +860,18 @@ mod tests {

// assert
assert_eq!(
collated_item.chosen_mask()[0].to_vec1::<u32>()?,
&[44, 45, 46, 47, 48, 49, 50, 51, 52, 53,]
collated_item.chosen_mask().i((0, ..))?.to_vec1::<u32>()?,
&[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
]
);
assert_eq!(
collated_item.rejected_mask()[0].to_vec1::<u32>()?,
&[44, 45, 46, 47, 48, 49, 50, 51, 52, 53,]
collated_item.rejected_mask().i((0, ..))?.to_vec1::<u32>()?,
&[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
]
);
assert_eq!(
collated_item.rejected.elem_count(),
Expand Down Expand Up @@ -857,32 +903,12 @@ mod tests {
while let Some(Ok(collated_item)) = batcher.next() {
assert_eq!(collated_item.chosen.dims()[1], allowed_max_length);
assert_eq!(collated_item.rejected.dims()[1], allowed_max_length);
assert_eq!(collated_item.chosen_mask.dims()[1], allowed_max_length);
assert_eq!(collated_item.rejected_mask.dims()[1], allowed_max_length);
assert!(collated_item.chosen.dims()[0] <= batch_size);
assert!(collated_item.rejected.dims()[0] <= batch_size);
assert!(collated_item.chosen_mask.len() <= batch_size);
assert!(collated_item.rejected_mask.len() <= batch_size);
assert!(collated_item.prompt.len() <= batch_size);

let max_length_chosen_mask = collated_item
.chosen_mask
.iter()
.map(|el| el.elem_count())
.collect::<Vec<_>>()
.into_iter()
.max()
.unwrap();
assert!(max_length_chosen_mask <= allowed_max_length);

let max_length_rejected_mask = collated_item
.rejected_mask
.iter()
.map(|el| el.elem_count())
.collect::<Vec<_>>()
.into_iter()
.max()
.unwrap();
assert!(max_length_rejected_mask <= allowed_max_length);

count += 1;
}
assert_eq!(data_loader.len(), count);
Expand Down