129 changes: 126 additions & 3 deletions src/listings/ch07/bonus.rs
@@ -12,6 +12,10 @@ use std::{path::Path, rc::Rc};
use tiktoken_rs::CoreBPE;
use tqdm::tqdm;

// For convenience, we also re-export the following items.
pub use crate::listings::ch02::DataLoader;
pub use crate::listings::ch05::calc_loss_loader;

#[serde_as]
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct PreferenceExample {
@@ -363,7 +367,6 @@ where
}
}

#[allow(dead_code)]
#[derive(Clone)]
pub struct PreferenceDataCollator {
pad_token_id: u32,
@@ -525,11 +528,76 @@ impl CustomCollator for PreferenceDataCollator {
}
}

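/// A data loader over a `PreferenceDataset` that yields collated batches of
/// `EncodedPreferenceExample`s using the supplied `CustomCollator`.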
pub struct PreferenceDataLoader<C: CustomCollator<BatchItem = EncodedPreferenceExample>> {
dataset: PreferenceDataset,
batch_size: usize,
shuffle: bool,
drop_last: bool,
collator: C,
}

impl<C: CustomCollator<BatchItem = EncodedPreferenceExample> + Clone> DataLoader
for PreferenceDataLoader<C>
{
type Batcher = PreferenceDataBatcher<C, IterResult1<PreferenceDatasetIter>>;

/// Returns a `PreferenceDataBatcher` that itself provides batches over the
/// associated dataset.
fn batcher(&self) -> PreferenceDataBatcher<C, IterResult1<PreferenceDatasetIter>> {
let iter = PreferenceDatasetIter::new(self.dataset.clone(), self.shuffle);
PreferenceDataBatcher::new(iter, self.collator.clone())
.batch_size(self.batch_size)
.return_last_incomplete_batch(!self.drop_last)
}
}

impl<C: CustomCollator<BatchItem = EncodedPreferenceExample> + Clone> PreferenceDataLoader<C> {
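/// Creates a new `PreferenceDataLoader` from a dataset, its batching
/// configuration, and a collator.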
pub fn new(
dataset: PreferenceDataset,
batch_size: usize,
shuffle: bool,
drop_last: bool,
collator: C,
) -> Self {
Self {
dataset,
batch_size,
shuffle,
drop_last,
collator,
}
}

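/// Returns the number of batches this data loader yields, taking
/// `drop_last` into account.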
pub fn len(&self) -> usize {
if self.drop_last {
self.batcher().count()
} else {
// There is a bug in `candle_datasets::Batcher`: when
// `return_last_incomplete_batch` is set to true, the iterator never
// returns `None`. This breaks `Iterator::count()`, which consumes the
// iterator until it encounters `None`, so we count batches manually.
let mut batcher = self.batcher();
let mut count = 0_usize;
while let Some(Ok(_el)) = batcher.next() {
count += 1;
}
count
}
}

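/// Returns `true` if this data loader yields no batches, i.e. the dataset
/// holds fewer examples than one batch and incomplete batches are dropped.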
pub fn is_empty(&self) -> bool {
(self.dataset.len() < self.batch_size) && (self.drop_last)
}

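/// Returns a reference to the underlying `PreferenceDataset`.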
pub fn dataset(&self) -> &PreferenceDataset {
&self.dataset
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::listings::ch07::AlpacaPromptFormatter;
use anyhow::Result;
use rstest::*;
use tiktoken_rs::get_bpe_from_model;
@@ -710,4 +778,59 @@ mod tests {

Ok(())
}

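// Check that `PreferenceDataLoader` yields batches truncated to
// `allowed_max_length`, with at most `batch_size` examples each, and that
// `len()` and `is_empty()` agree with the batches actually produced.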
#[rstest]
fn test_preference_data_loader(preference_data: Vec<PreferenceExample>) -> Result<()> {
let tokenizer = get_bpe_from_model("gpt2")?;
let prompt_formatter = AlpacaPromptFormatter;
let preference_dataset =
PreferenceDataset::new(preference_data, &tokenizer, &prompt_formatter);
let batch_size = 2_usize;
let allowed_max_length = 5_usize;
let collator = PreferenceDataCollator::new()
.device(Device::cuda_if_available(0)?)
.allowed_max_length(Some(allowed_max_length));
let shuffle = false;
let drop_last = false;
let data_loader =
PreferenceDataLoader::new(preference_dataset, batch_size, shuffle, drop_last, collator);

let mut batcher = data_loader.batcher();
let mut count = 0_usize;
while let Some(Ok(collated_item)) = batcher.next() {
assert_eq!(collated_item.chosen.dims()[1], allowed_max_length);
assert_eq!(collated_item.rejected.dims()[1], allowed_max_length);
assert!(collated_item.chosen.dims()[0] <= batch_size);
assert!(collated_item.rejected.dims()[0] <= batch_size);
assert!(collated_item.chosen_mask.len() <= batch_size);
assert!(collated_item.rejected_mask.len() <= batch_size);
assert!(collated_item.prompt.len() <= batch_size);

let max_length_chosen_mask = collated_item
.chosen_mask
.iter()
.map(|el| el.elem_count())
.max()
.unwrap();
assert!(max_length_chosen_mask <= allowed_max_length);

let max_length_rejected_mask = collated_item
.rejected_mask
.iter()
.map(|el| el.elem_count())
.max()
.unwrap();
assert!(max_length_rejected_mask <= allowed_max_length);

count += 1;
}
assert_eq!(data_loader.len(), count);
assert!(!data_loader.is_empty());

Ok(())
}
}