Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iteration number and generation number to reset function #119

Merged
merged 5 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# GEGELATI Changelog

## Release version x.y.z
_2024.01.10_

### New features

### Changes
* Add two parameters to the `reset` method of the learning environment (`Learn::LearningEnvironment::reset`). These parameters are intended for environments whose initialization depends on the current iteration or generation.
* Parameter `iterationNumber`: an integer indicating the index of the current iteration, which is relevant when the `nbIterationsPerPolicyEvaluation` parameter is greater than 1. Default value: 0.
* Parameter `generationNumber`: an integer indicating the current generation number. Default value: 0.

### Bug fix


## Release version 1.3.1 - Donanatella flavor with extra sprinkles
_2023.12.14_

Expand Down
7 changes: 4 additions & 3 deletions gegelatilib/include/learn/classificationLearningEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ namespace Learn {
*
* Resets to zero the classificationTable.
*/
virtual void reset(
size_t seed = 0,
LearningMode mode = LearningMode::TRAINING) override = 0;
virtual void reset(size_t seed = 0,
LearningMode mode = LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override = 0;
};
}; // namespace Learn

Expand Down
8 changes: 7 additions & 1 deletion gegelatilib/include/learn/learningEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,15 @@ namespace Learn {
* the LearningEnvironment.
* \param[in] mode LearningMode in which the Environment should be
* reset for the next set of actions.
* \param[in] iterationNumber index of the current iteration of the
* policy evaluation (relevant when nbIterationsPerPolicyEvaluation > 1).
* \param[in] generationNumber index of the current generation.
*/
virtual void reset(size_t seed = 0,
LearningMode mode = LearningMode::TRAINING) = 0;
LearningMode mode = LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) = 0;

/**
* \brief Get the data sources for this LearningEnvironment.
Expand Down
4 changes: 3 additions & 1 deletion gegelatilib/src/learn/classificationLearningEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ double Learn::ClassificationLearningEnvironment::getScore() const
}

void Learn::ClassificationLearningEnvironment::reset(size_t seed,
LearningMode mode)
LearningMode mode,
uint16_t iterationNumber,
uint64_t generationNumber)
{
// reset scores to 0 in classification table
for (std::vector<uint64_t>& perClass : this->classificationTable) {
Expand Down
8 changes: 5 additions & 3 deletions gegelatilib/src/learn/learningAgent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,15 @@ std::shared_ptr<Learn::EvaluationResult> Learn::LearningAgent::evaluateJob(
double result = 0.0;

// Evaluate nbIteration times
for (auto i = 0; i < this->params.nbIterationsPerPolicyEvaluation; i++) {
for (auto iterationNumber = 0;
iterationNumber < this->params.nbIterationsPerPolicyEvaluation;
iterationNumber++) {
// Compute a Hash
Data::Hash<uint64_t> hasher;
uint64_t hash = hasher(generationNumber) ^ hasher(i);
uint64_t hash = hasher(generationNumber) ^ hasher(iterationNumber);

// Reset the learning Environment
le.reset(hash, mode);
le.reset(hash, mode, iterationNumber, generationNumber);

uint64_t nbActions = 0;
while (!le.isTerminal() &&
Expand Down
6 changes: 4 additions & 2 deletions test/learn/fakeAdversarialLearningEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ class FakeAdversarialLearningEnvironment
nbTurns++;
}

void reset(size_t seed = 0, Learn::LearningMode mode =
Learn::LearningMode::TRAINING) override{
void reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override{
// we just ignore the reset
};
std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
4 changes: 3 additions & 1 deletion test/learn/fakeClassificationLearningEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ class FakeClassificationLearningEnvironment
this->currentClass = value % 3;
data.setDataAt(typeid(int), 0, value);
}
void reset(size_t seed, Learn::LearningMode mode) override
void reset(size_t seed, Learn::LearningMode mode,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override
{
// Call the base-class implementation of the pure virtual reset(), which resets the classification table.
ClassificationLearningEnvironment::reset(seed, mode);
Expand Down
4 changes: 3 additions & 1 deletion test/learn/stickGameAdversarial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ void StickGameAdversarial::doAction(uint64_t actionID)
}
}

void StickGameAdversarial::reset(size_t seed, Learn::LearningMode mode)
void StickGameAdversarial::reset(size_t seed, Learn::LearningMode mode,
uint16_t iterationNumber,
uint64_t generationNumber)
{
// Create seed from seed and mode
size_t hash_seed =
Expand Down
7 changes: 4 additions & 3 deletions test/learn/stickGameAdversarial.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ class StickGameAdversarial : public Learn::AdversarialLearningEnvironment
virtual void doAction(uint64_t actionID) override;

// Inherited via LearningEnvironment
virtual void reset(
size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
virtual void reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override;

// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
4 changes: 3 additions & 1 deletion test/learn/stickGameWithOpponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ void StickGameWithOpponent::doAction(uint64_t actionID)
}
}

void StickGameWithOpponent::reset(size_t seed, Learn::LearningMode mode)
void StickGameWithOpponent::reset(size_t seed, Learn::LearningMode mode,
uint16_t iterationNumber,
uint64_t generationNumber)
{
// Create seed from seed and mode
size_t hash_seed =
Expand Down
7 changes: 4 additions & 3 deletions test/learn/stickGameWithOpponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ class StickGameWithOpponent : public Learn::LearningEnvironment
virtual void doAction(uint64_t actionID) override;

// Inherited via LearningEnvironment
virtual void reset(
size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
virtual void reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override;

// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
3 changes: 2 additions & 1 deletion test/learningEnvironmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class FakeLearningEnvironment : public Learn::LearningEnvironment

public:
FakeLearningEnvironment() : LearningEnvironment(2), data(3){};
void reset(size_t seed, Learn::LearningMode mode){};
void reset(size_t seed, Learn::LearningMode mode, uint16_t iterationNumber,
uint64_t generationNumber){};
std::vector<std::reference_wrapper<const Data::DataHandler>>
getDataSources()
{
Expand Down