
Merge pull request openvinotoolkit#22 from vshampor/sorting_fix
Implement repetition penalty and fix sorting in temperature transform
ilya-lavrenov committed May 28, 2024
2 parents 8548e56 + 5f59505 commit 2c2799f
Showing 5 changed files with 168 additions and 37 deletions.
[File 1 of 5]
@@ -11,6 +11,7 @@
#include <algorithm>
#include <cmath>
#include <random>
#include <set>

#include "openvino/runtime/tensor.hpp"

@@ -226,7 +227,8 @@ class TopPFilter: public IProbabilityFilter {
nucleus_size += 1;
if (probability_sum > m_top_p) break;
}
return std::vector<ProbabilityWithIdx>(tmp.begin(), tmp.begin() + nucleus_size);
tmp.resize(nucleus_size);
return tmp;
}

private:
@@ -241,7 +243,8 @@ class TopKFilter: public IProbabilityFilter {
std::vector<ProbabilityWithIdx> tmp(input_probs);
std::sort(tmp.begin(), tmp.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; });
size_t top_k = input_probs.size() >= m_top_k ? m_top_k : input_probs.size();
return std::vector<ProbabilityWithIdx>(tmp.begin(), tmp.begin() + top_k);
tmp.resize(top_k);
return tmp;
}

private:
@@ -256,6 +259,7 @@ class TemperatureLogitTransform {

std::vector<ProbabilityWithIdx> apply(const std::vector<LogitWithIdx>& input_logits) {
std::vector<ProbabilityWithIdx> output(input_logits.begin(), input_logits.end());
std::sort(output.begin(), output.end(), [](const ProbabilityWithIdx& lhs, const ProbabilityWithIdx& rhs) {return lhs.first > rhs.first; });
float max_logit = output[0].first;
std::for_each(output.begin(), output.end(), [max_logit, this](ProbabilityWithIdx& val) {val.first = expf((val.first - max_logit) / this->m_temperature);});

@@ -272,6 +276,37 @@ class TemperatureLogitTransform {
double m_temperature;
};

class RepetitionPenaltyTransform {
public:
RepetitionPenaltyTransform(double penalty) : m_penalty(penalty) {
OPENVINO_ASSERT(m_penalty >= 0.0f, "repetition penalty must be a positive value");
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const std::set<int64_t>& unique_input_ids) {
std::vector<LogitWithIdx> output(input_logits.begin(), input_logits.end());
size_t vocab_size = input_logits.size();
for (auto input_id : unique_input_ids) {
OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
OPENVINO_ASSERT(input_logits[input_id].second == input_id, "input_logits must have original index order");
auto logit_value = output[input_id].first;
if (logit_value >= 0) {
output[input_id].first /= m_penalty;
} else {
output[input_id].first *= m_penalty;
};
}
return output;
}

std::vector<LogitWithIdx> apply(const std::vector<LogitWithIdx>& input_logits, const TokenIds& input_ids) {
std::set<int64_t> unique_input_ids(input_ids.begin(), input_ids.end());
return this->apply(input_logits, unique_input_ids);
}
private:
double m_penalty;
};


class ProbabilityNormalizeTransform {
public:
std::vector<ProbabilityWithIdx> apply(const std::vector<ProbabilityWithIdx>& input_probs) {
Expand All @@ -285,27 +320,25 @@ class ProbabilityNormalizeTransform {

class Sampler {

int64_t _greedy_sample(ov::Tensor logits) const {
std::vector<LogitWithIdx> _get_logit_vector(ov::Tensor logits) {
ov::Shape logits_shape = logits.get_shape();
size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
OPENVINO_ASSERT(batch_size == 1);

const float * logits_data = logits.data<const float>() + (seq_len - 1) * vocab_size;
int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data;
return out_token;
}

int64_t _multinomial_sample(ov::Tensor logits, float temperature, float top_p, size_t top_k) {
ov::Shape logits_shape = logits.get_shape();
size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
OPENVINO_ASSERT(batch_size == 1);

const float * logits_data = logits.data<const float>() + (seq_len - 1) * vocab_size;
std::vector<LogitWithIdx> logit_vector(vocab_size);
for (size_t i = 0; i < logit_vector.size(); i++) {
logit_vector[i] = LogitWithIdx(logits_data[i], i);
}
return logit_vector;
}

int64_t _greedy_sample(const std::vector<LogitWithIdx>& logit_vector) const {
int64_t out_token = std::max_element(logit_vector.begin(), logit_vector.end(), [](const LogitWithIdx& lhs, const LogitWithIdx& rhs) { return lhs.first < rhs.first; }) - logit_vector.begin();
return out_token;
}

int64_t _multinomial_sample(const std::vector<LogitWithIdx>& logit_vector, float temperature, float top_p, size_t top_k) {
auto temperature_transform = TemperatureLogitTransform(temperature);
std::vector<ProbabilityWithIdx> softmax_vector = temperature_transform.apply(logit_vector);

@@ -367,16 +400,25 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,

if (sequence_group->requires_sampling()) {
if (sampling_params.is_greedy_sampling() || sampling_params.is_multinomial()) {
auto logit_vector = _get_logit_vector(sequence_group_logits); // TODO (vshampor): should be also applicable to beam search, but need to remove the batch size == 1 limitation

if (sampling_params.repetition_penalty != 1.0f) {
auto repetition_penalty_transform = RepetitionPenaltyTransform(sampling_params.repetition_penalty);
logit_vector = repetition_penalty_transform.apply(logit_vector, sequence_group->get_unique_generated_ids());
}
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1);

int64_t sampled_token_id;
if (sampling_params.is_greedy_sampling()) {
sampled_token_id = _greedy_sample(sequence_group_logits);
sampled_token_id = _greedy_sample(logit_vector);
}
else { // .is_multinomial()
sampled_token_id = _multinomial_sample(sequence_group_logits, sampling_params.temperature, sampling_params.top_p, sampling_params.top_k);
sampled_token_id = _multinomial_sample(logit_vector, sampling_params.temperature, sampling_params.top_p, sampling_params.top_k);
}

sequence_group->register_generated_token_id(sampled_token_id);

// in case of greedy search we always have a single parent sequence to sample from
running_sequences[0]->append_token(sampled_token_id, sequence_group_logits.data<const float>()[sampled_token_id]);

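To make the sampling changes above concrete, here is a minimal, self-contained sketch of the two transforms this commit adds or fixes. The type aliases and the plain std::out_of_range throw in place of OPENVINO_ASSERT are assumptions so the snippet compiles without OpenVINO headers; it illustrates the technique, not the library's exact code.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <set>
#include <stdexcept>
#include <utility>
#include <vector>

using LogitWithIdx = std::pair<float, size_t>;        // (logit, token id)
using ProbabilityWithIdx = std::pair<float, size_t>;  // (probability, token id)

// Temperature softmax. Sorting descending first guarantees that output[0]
// really holds the maximum logit (previously the code subtracted whatever
// element happened to be first, risking expf overflow) and hands downstream
// top-k/top-p filters an already ordered vector.
std::vector<ProbabilityWithIdx> temperature_softmax(std::vector<LogitWithIdx> logits,
                                                    float temperature) {
    std::sort(logits.begin(), logits.end(),
              [](const LogitWithIdx& l, const LogitWithIdx& r) { return l.first > r.first; });
    float max_logit = logits[0].first;
    float sum = 0.0f;
    for (auto& v : logits) {
        v.first = std::exp((v.first - max_logit) / temperature);
        sum += v.first;
    }
    for (auto& v : logits) v.first /= sum;  // normalize to probabilities
    return logits;
}

// Repetition penalty: tokens already seen get their logit divided by the
// penalty when positive and multiplied when negative, so a penalty > 1
// always makes a repeated token less likely. Assumes logits are still in
// original index order, as the diff's assertion requires.
std::vector<LogitWithIdx> apply_repetition_penalty(std::vector<LogitWithIdx> logits,
                                                   const std::set<int64_t>& seen_ids,
                                                   float penalty) {
    for (int64_t id : seen_ids) {
        if (id < 0 || static_cast<size_t>(id) >= logits.size())
            throw std::out_of_range("input_ids token out of bounds");
        float& value = logits[id].first;
        value = value >= 0 ? value / penalty : value * penalty;
    }
    return logits;
}

int main() {
    std::vector<LogitWithIdx> logits = {{1.0f, 0}, {-2.0f, 1}, {3.0f, 2}};
    auto penalized = apply_repetition_penalty(logits, {1, 2}, 2.0f);
    // penalized: {1.0, 0}, {-4.0, 1}, {1.5, 2} -- both seen tokens lost probability mass
    for (const auto& p : temperature_softmax(penalized, 0.8f))
        std::cout << "token " << p.second << " -> " << p.first << '\n';
    return 0;
}

Note that subtracting the true maximum only matters numerically: softmax is shift-invariant, so the fix changes the output ordering and prevents expf overflow, not the probabilities themselves.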
[File 2 of 5]
@@ -4,6 +4,7 @@
#pragma once

#include <vector>
#include <set>
#include <cstdlib>

#include "generation_config.hpp"
@@ -114,6 +115,7 @@ class SequenceGroup {
GenerationConfig m_sampling_params;
std::size_t m_block_size;
TokenIds m_prompt_ids;
std::set<int64_t> m_unique_generated_ids;

// amount of processed tokens, e.g. prompt can be processed using multiple consecutive inferences
// so, we need to track which part of the prompt we have already processed
@@ -141,6 +143,7 @@

m_prompt_ids.resize(input_ids.get_size());
std::copy_n(input_ids.data<int64_t>(), input_ids.get_size(), m_prompt_ids.begin());
for (auto id: m_prompt_ids) { m_unique_generated_ids.insert(id); }
}

void add_sequence(const Sequence::Ptr & sequence) {
@@ -301,6 +304,14 @@
return m_prompt_ids;
}

const std::set<int64_t>& get_unique_generated_ids() const {
return m_unique_generated_ids;
}

void register_generated_token_id(int64_t token_id) {
m_unique_generated_ids.insert(token_id);
}

size_t get_num_logical_blocks() const {
return (get_context_len() + m_block_size - 1) / m_block_size;
}
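As an aside, a hypothetical standalone illustration of the bookkeeping this member enables, assuming only the names visible in the diff: prompt tokens seed the set at construction, each sampled token is registered afterwards, and duplicates collapse so the repetition penalty later applies once per distinct id.

#include <cstdint>
#include <set>
#include <vector>

int main() {
    // mirrors the constructor loop above: every prompt id enters the set once
    std::vector<int64_t> prompt_ids = {5, 7, 5};
    std::set<int64_t> unique_generated_ids(prompt_ids.begin(), prompt_ids.end());
    // mirrors register_generated_token_id(42) after sampling
    unique_generated_ids.insert(42);
    // unique_generated_ids now holds {5, 7, 42}
    return 0;
}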
[File 3 of 5]
@@ -29,8 +29,9 @@ TEST_P(TemperatureTransformTest, TransformResultEqualToReference) {


const std::vector<TemperatureTransformTestStruct> TEMPERATURE_TRANSFORM_TEST_CASES = {
{1.0f, { {1.0f, 0}, {2.0f, 1}, {3.0f, 2} }, { {0.090031, 0}, {0.244728, 1}, {0.665241, 2} } },
{2.0f, { {1.0f, 2}, {2.0f, 1}, {3.0f, 0} }, { {0.186323, 2}, {0.307195, 1}, {0.506480, 0} } }
{1.0f, { {1.0f, 0}, {2.0f, 1}, {3.0f, 2} }, { {0.665241, 2}, {0.244728, 1}, {0.090031, 0} } },
{2.0f, { {1.0f, 2}, {2.0f, 1}, {3.0f, 0} }, { {0.506480, 0}, {0.307195, 1}, {0.186323, 2} } },
{1.0f, { {3.0f, 0}, {1.0f, 1}, {2.0f, 2} }, { {0.665241, 0}, {0.244728, 2}, {0.090031, 1} } },
};
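As a sanity check on the reordered reference rows: with temperature 1.0 and logits {1, 2, 3}, subtracting the max logit 3 gives exponentials e^-2 ≈ 0.135335, e^-1 ≈ 0.367879 and e^0 = 1, which sum to about 1.503215; normalizing yields 0.090031, 0.244728 and 0.665241, and the fixed transform now reports them in descending probability order while preserving each logit's original index.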

INSTANTIATE_TEST_SUITE_P(VariousInputs,
@@ -139,3 +140,62 @@ TEST(TopPFilterInitializationTest, ThrowsForInvalidProbabilities) {
EXPECT_THROW(TopPFilter(-0.5), ov::Exception);
EXPECT_THROW(TopPFilter(1.1), ov::Exception);
}


struct RepetitionPenaltyTransformTestStruct {
float penalty;
std::vector<LogitWithIdx> input_logits;
TokenIds input_ids;
std::vector<LogitWithIdx> expected_output;
};

using RepetitionPenaltyTransformTest = testing::TestWithParam<RepetitionPenaltyTransformTestStruct>;

TEST_P(RepetitionPenaltyTransformTest, TransformResultEqualToReference) {
auto test_struct = GetParam();
auto transform = RepetitionPenaltyTransform(test_struct.penalty);
auto test_result = transform.apply(test_struct.input_logits, test_struct.input_ids);
ASSERT_EQ(test_result.size(), test_struct.expected_output.size());
for (size_t i = 0; i < test_result.size(); i++) {
EXPECT_NEAR(test_result[i].first, test_struct.expected_output[i].first, 1e-6);
EXPECT_EQ(test_result[i].second, test_struct.expected_output[i].second);
}
}


const std::vector<RepetitionPenaltyTransformTestStruct> REPETITION_PENALTY_TRANSFORM_TEST_CASES = {
{ // basic case, indices are applied, order is left as-is
1.2f,
{ {1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
{ 2, 0 },
{ {0.8333333f, 0}, {2.0f, 1}, {2.5f, 2} }
},
{ // negative scores case
2.0f,
{ {-1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
{ 0, 1 },
{ {-2.0f, 0}, {1.0f, 1}, {3.0f, 2} }
},
{ // repeated tokens in prompt, check that the penalty is only applied once
0.5f,
{ {-1.0f, 0}, {2.0f, 1}, {3.0f, 2} },
{ 1, 1 },
{ {-1.0f, 0}, {4.0f, 1}, {3.0f, 2} }
},
};
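These expectations follow directly from the divide/multiply rule: in the first case 1.0 / 1.2 ≈ 0.8333333 and 3.0 / 1.2 = 2.5 while the unseen token 1 is left alone; in the second, -1.0 * 2.0 = -2.0 and 2.0 / 2.0 = 1.0; in the third, the duplicated id 1 collapses in the std::set overload, so 2.0 is divided by the 0.5 penalty exactly once, giving 4.0 rather than 8.0.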

INSTANTIATE_TEST_SUITE_P(VariousInputs,
RepetitionPenaltyTransformTest,
testing::ValuesIn(REPETITION_PENALTY_TRANSFORM_TEST_CASES));


TEST(RepetitionPenaltyTransformInitializationTest, ThrowsForInvalidPenalties) {
EXPECT_THROW(RepetitionPenaltyTransform(-0.5), ov::Exception);
}
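Note the constructor guard accepts any non-negative value, so only negative penalties throw here; a penalty below 1.0 (like the 0.5 used in the test cases above) is legal and boosts repeated tokens instead of suppressing them, despite the assertion message saying "positive".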

TEST(RepetitionPenaltyTransformInitializationTest, ThrowsForInvalidInputIds) {
auto transform = RepetitionPenaltyTransform(1.5);
EXPECT_THROW(transform.apply({ {43.0f, 0} }, std::set<int64_t>{1337} ), ov::Exception);
EXPECT_THROW(transform.apply({ {18.0f, 0} }, std::set<int64_t>{0, -1} ), ov::Exception);
}

[File 4 of 5]
@@ -18,6 +18,12 @@ def get_greedy() -> GenerationConfig:
generation_config.num_return_sequences = 1
return generation_config

def get_greedy_with_repetition_penalty() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.num_return_sequences = 1
generation_config.repetition_penalty = 2.0
return generation_config


def get_beam_search() -> GenerationConfig:
generation_config = GenerationConfig()
@@ -55,6 +61,13 @@ def get_multinomial_temperature_top_p_and_top_k() -> GenerationConfig:
generation_config.top_k = 2
return generation_config

def get_multinomial_temperature_and_repetition_penalty() -> GenerationConfig:
generation_config = GenerationConfig()
generation_config.do_sample = True
generation_config.temperature = 0.8
generation_config.repetition_penalty = 2.0
return generation_config

def get_test_dataset() -> Tuple[List[str], List[GenerationConfig]]:
prompts = [
"What is OpenVINO?",
@@ -99,13 +112,13 @@ def convert_to_hf(
# copy default parameters
kwargs['eos_token_id'] = default_generation_config.eos_token_id
kwargs['pad_token_id'] = default_generation_config.pad_token_id
kwargs['repetition_penalty'] = generation_config.repetition_penalty

if generation_config.num_groups * generation_config.group_size > 1:
# beam search case
kwargs['num_beam_groups'] = generation_config.num_groups
kwargs['num_beams'] = generation_config.num_groups * generation_config.group_size
kwargs['diversity_penalty'] = generation_config.diversity_penalty
kwargs['repetition_penalty'] = generation_config.repetition_penalty
kwargs['length_penalty'] = generation_config.length_penalty
kwargs['no_repeat_ngram_size'] = generation_config.no_repeat_ngram_size
kwargs['num_return_sequences'] = generation_config.num_return_sequences
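Moving repetition_penalty out of the beam-search-only branch means greedy and multinomial runs now forward it to Hugging Face as well; HF's generate applies the same divide-positive/multiply-negative rule, so the reference and continuous-batching outputs remain comparable token for token.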
@@ -201,7 +214,7 @@ def get_model_and_tokenizer(model_id: str, use_optimum = True):
AutoModelForCausalLM.from_pretrained(model_id)
return model, hf_tokenizer

def _generate_and_compare_with_hf(model_id: str, prompts: List[str], generation_configs: List[GenerationConfig], scheduler_config: SchedulerConfig, tmp_path: Path):
def generate_and_compare_with_hf(model_id: str, prompts: List[str], generation_configs: List[GenerationConfig], scheduler_config: SchedulerConfig, tmp_path: Path):
use_optimum = True
model_path : Path = tmp_path / model_id
model, hf_tokenizer = get_model_and_tokenizer(model_id, use_optimum)
[File 5 of 5]
@@ -3,7 +3,7 @@
import os
import pytest

from common import run_test_pipeline, get_models_list, get_model_and_tokenizer, save_ov_model_from_optimum, generate_and_compare_with_reference_text, get_greedy, get_beam_search, get_multinomial_temperature, get_multinomial_temperature_and_top_k, get_multinomial_temperature_and_top_p, get_multinomial_temperature_top_p_and_top_k, DEFAULT_SCHEDULER_CONFIG
from common import run_test_pipeline, get_models_list, get_model_and_tokenizer, save_ov_model_from_optimum, generate_and_compare_with_reference_text, get_greedy, get_beam_search, get_multinomial_temperature, get_multinomial_temperature_and_top_k, get_multinomial_temperature_and_top_p, get_multinomial_temperature_top_p_and_top_k, DEFAULT_SCHEDULER_CONFIG, get_greedy_with_repetition_penalty, generate_and_compare_with_hf, get_multinomial_temperature_and_repetition_penalty
from dataclasses import dataclass
from py_continuous_batching import GenerationConfig, GenerationResult
from pathlib import Path
@@ -84,15 +84,16 @@ def test_eos_greedy(tmp_path):
print(f"Prompt = {prompt}\nHF result = {hf_result}\nOV result = {ov_result}")
compare_results(hf_result, ov_result, generation_config)

@pytest.mark.parametrize("generation_config", [get_greedy(), get_beam_search()],
ids=["greedy", "beam"])
@pytest.mark.precommit
@pytest.mark.parametrize("generation_config", [get_greedy(), get_beam_search(), get_greedy_with_repetition_penalty()],
ids=["greedy", "beam", "greedy_with_repetition_penalty"])
def test_individual_generation_configs_deterministic(tmp_path, generation_config):
prompts = [
"What is OpenVINO?",
]
generation_configs = [generation_config]
model_id : str = "facebook/opt-125m"
_generate_and_compare_with_hf(model_id, prompts, generation_configs, DEFAULT_SCHEDULER_CONFIG, tmp_path)
generate_and_compare_with_hf(model_id, prompts, generation_configs, DEFAULT_SCHEDULER_CONFIG, tmp_path)


@dataclass
@@ -101,24 +102,28 @@ class RandomSamplingTestStruct:
prompts: List[str]
ref_texts: List[List[str]]

RANDOM_SAMPLING_TEST_CASES = [RandomSamplingTestStruct(generation_config=get_multinomial_temperature(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\n\nOpenVINO is a software development platform developed by OpenVINO, a set of technology companies and startups that enables developers to use the most"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_top_p(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\nOpenVINO is an online application that allows users to create, test, and analyze their own software using a collection of software packages. The application"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_top_k(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\n\nOpenVINO is a software that allows users to create a virtual machine with the ability to create a virtual machine in a virtual environment. Open"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_top_p_and_top_k(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\nOpenVINO is an open source software that allows developers to create, manage, and distribute software. It is an open source project that allows developers"] ]),
]
RANDOM_SAMPLING_TEST_CASES = [
RandomSamplingTestStruct(generation_config=get_multinomial_temperature(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\n\nOpenVINO is a software development platform developed by OpenVINO, a set of technology companies and startups that enables developers to use the most"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_top_p(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\nOpenVINO is an online application that allows users to create, test, and analyze their own software using a collection of software packages. The application"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_top_k(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\n\nOpenVINO is a software that allows users to create a virtual machine with the ability to create a virtual machine in a virtual environment. Open"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_top_p_and_top_k(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\nOpenVINO is an open source software that allows developers to create, manage, and distribute software. It is an open source project that allows developers"] ]),
RandomSamplingTestStruct(generation_config=get_multinomial_temperature_and_repetition_penalty(),
prompts=["What is OpenVINO?"],
ref_texts=[ ["\nOpen Vino's are a new and improved way to find cheap, fast-investment frozen vegetables that have no waste or calories. They're"] ]),
]


@pytest.mark.precommit
@pytest.mark.parametrize("test_struct", RANDOM_SAMPLING_TEST_CASES,
ids=["multinomial_temperature", "multinomial_temperature_and_top_p", "multinomial_temperature_and_top_k", "multinomial_temperature_top_p_and_top_k"])
ids=["multinomial_temperature", "multinomial_temperature_and_top_p", "multinomial_temperature_and_top_k", "multinomial_temperature_top_p_and_top_k", "multinomial_temperature_and_repetition_penalty"])
def test_individual_generation_configs_random(tmp_path, test_struct: RandomSamplingTestStruct):
generation_config = test_struct.generation_config

