Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/examples/ch07.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1268,23 +1268,43 @@ impl Example for EG19 {

// Decode chosen and print
let rejected = collated_item.rejected().i((1, ..))?;
let rejected_text = token_ids_to_text(rejected, &tokenizer)?;
let rejected_text = token_ids_to_text(rejected.clone(), &tokenizer)?;
println!(
"\nCollated Batch Item 1: Rejected Text\n\n{}\n",
rejected_text
);

// Print masks and their shapes
let chosen_mask = collated_item.chosen_mask().i((1, ..))?;
let chosen_mask = &collated_item.chosen_mask()[1];
println!("\nCollated Batch: Masks\n");
println!("Chosen inputs: {:?}", chosen);
println!("Chosen mask: {:?}", chosen_mask);

println!(
"\nCollated Batch Item 1: Chosen Mask\n\n{:?}\n",
"\nCollated Batch Item 1: Chosen Mask Indexes\n\n{:?}\n",
chosen_mask.to_vec1::<u32>()?
);

// decode chosen mask
let chosen_masked_text = token_ids_to_text(
chosen.index_select(&collated_item.chosen_mask()[1], 0)?,
&tokenizer,
)?;
println!(
"\nCollated Batch Item 1: Chosen Mask Text\n\n{}\n",
chosen_masked_text
);

// decode rejected mask
let rejected_masked_text = token_ids_to_text(
rejected.index_select(&collated_item.rejected_mask()[1], 0)?,
&tokenizer,
)?;
println!(
"\nCollated Batch Item 1: Rejected Mask Text\n\n{}\n",
rejected_masked_text
);

Ok(())
}
}
42 changes: 27 additions & 15 deletions src/listings/ch07/bonus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ pub struct PreferenceDatasetCollatorItem {
prompt: Vec<Tensor>,
chosen: Tensor,
rejected: Tensor,
rejected_mask: Tensor,
chosen_mask: Tensor,
rejected_mask: Vec<Tensor>,
chosen_mask: Vec<Tensor>,
}

impl PreferenceDatasetCollatorItem {
Expand All @@ -277,15 +277,15 @@ impl PreferenceDatasetCollatorItem {
&self.chosen
}

pub fn chosen_mask(&self) -> &Tensor {
pub fn chosen_mask(&self) -> &Vec<Tensor> {
&self.chosen_mask
}

pub fn rejected(&self) -> &Tensor {
&self.rejected
}

pub fn rejected_mask(&self) -> &Tensor {
pub fn rejected_mask(&self) -> &Vec<Tensor> {
&self.rejected_mask
}
}
Expand Down Expand Up @@ -422,7 +422,7 @@ impl PreferenceDataCollator {

// mask vec
let mut mask = (0..batch_max_length as u32)
.map(|j| u32::from(j >= elements_length as u32))
.map(|j| u32::from(j < elements_length as u32))
.collect::<Vec<u32>>();

if self.mask_prompt_tokens {
Expand All @@ -441,6 +441,16 @@ impl PreferenceDataCollator {
)
}

fn _build_tensor_from_only_true_values(&self, mask: Vec<u32>) -> Result<Tensor> {
let reduced_mask = mask
.iter()
.enumerate()
.filter_map(|(ix, el)| (*el > 0).then_some(ix as u32))
.collect::<Vec<_>>();
let shape = reduced_mask.len();
Tensor::from_vec(reduced_mask, shape, &self.device)
}

pub fn custom_collate_fn(
&self,
batch: Vec<EncodedPreferenceExample>,
Expand Down Expand Up @@ -483,24 +493,26 @@ impl PreferenceDataCollator {
rejected_mask = rejected_mask[..std::cmp::min(a, batch_max_length)].to_vec();
}

// get only the indexes of mask that are to keep
let chosen_mask_tensor = self._build_tensor_from_only_true_values(chosen_mask)?;
let rejected_mask_tensor = self._build_tensor_from_only_true_values(rejected_mask)?;

chosen_vec.push(chosen);
chosen_mask_vec.push(chosen_mask);
rejected_vec.push(rejected);
rejected_mask_vec.push(rejected_mask);
rejected_mask_vec.push(rejected_mask_tensor);
chosen_mask_vec.push(chosen_mask_tensor);
prompt_vec.push(prompt_tensor);
}

let chosen_tensor = self._build_stacked_tensor(chosen_vec)?;
let chosen_mask_tensor = self._build_stacked_tensor(chosen_mask_vec)?;
let rejected_tensor = self._build_stacked_tensor(rejected_vec)?;
let rejected_mask_tensor = self._build_stacked_tensor(rejected_mask_vec)?;

Ok(PreferenceDatasetCollatorItem {
prompt: prompt_vec,
chosen: chosen_tensor,
rejected: rejected_tensor,
rejected_mask: rejected_mask_tensor,
chosen_mask: chosen_mask_tensor,
rejected_mask: rejected_mask_vec,
chosen_mask: chosen_mask_vec,
})
}
}
Expand Down Expand Up @@ -683,12 +695,12 @@ mod tests {

// assert
assert_eq!(
collated_item.chosen.elem_count(),
collated_item.chosen_mask.elem_count()
collated_item.chosen_mask()[0].to_vec1::<u32>()?,
&[44, 45, 46, 47, 48, 49, 50, 51, 52, 53,]
);
assert_eq!(
collated_item.rejected.elem_count(),
collated_item.rejected_mask.elem_count()
collated_item.rejected_mask()[0].to_vec1::<u32>()?,
&[44, 45, 46, 47, 48, 49, 50, 51, 52, 53,]
);
assert_eq!(
collated_item.rejected.elem_count(),
Expand Down
Loading