Skip to content

Commit

Permalink
Merge pull request #119 from gegelati/Patch_reset_learning_environnement
Browse files Browse the repository at this point in the history
Add iteration number and generation number to reset function
  • Loading branch information
kdesnos committed Jan 11, 2024
2 parents a8f30b9 + 3c16213 commit 9b4092f
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 20 deletions.
13 changes: 13 additions & 0 deletions Changelog.md
@@ -1,5 +1,18 @@
# GEGELATI Changelog

## Release version x.y.z
_2024.01.10_

### New features

### Changes
* Add two parameters to the reset method of the learning environment. These parameters are used for environments that use specific initialization.
* Parameter `iterationNumber`: an integer indicating the current iteration number when the `nbIterationsPerPolicyEvaluation` parameter is greater than 1, default value = 0.
* Parameter `generationNumber`: an integer indicating the current generation number, default value = 0.

### Bug fix


## Release version 1.3.1 - Donanatella flavor with extra sprinkles
_2023.12.14_

Expand Down
7 changes: 4 additions & 3 deletions gegelatilib/include/learn/classificationLearningEnvironment.h
Expand Up @@ -109,9 +109,10 @@ namespace Learn {
*
* Resets to zero the classificationTable.
*/
virtual void reset(
size_t seed = 0,
LearningMode mode = LearningMode::TRAINING) override = 0;
virtual void reset(size_t seed = 0,
LearningMode mode = LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override = 0;
};
}; // namespace Learn

Expand Down
8 changes: 7 additions & 1 deletion gegelatilib/include/learn/learningEnvironment.h
Expand Up @@ -164,9 +164,15 @@ namespace Learn {
* the LearningEnvironment.
* \param[in] mode LearningMode in which the Environment should be
* reset for the next set of actions.
* \param[in] iterationNumber the integer value to indicate the current
* iteration number when parameter nbIterationsPerPolicyEvaluation > 1
* \param[in] generationNumber the integer value to indicate the
* current generation number
*/
virtual void reset(size_t seed = 0,
LearningMode mode = LearningMode::TRAINING) = 0;
LearningMode mode = LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) = 0;

/**
* \brief Get the data sources for this LearningEnvironment.
Expand Down
4 changes: 3 additions & 1 deletion gegelatilib/src/learn/classificationLearningEnvironment.cpp
Expand Up @@ -95,7 +95,9 @@ double Learn::ClassificationLearningEnvironment::getScore() const
}

void Learn::ClassificationLearningEnvironment::reset(size_t seed,
LearningMode mode)
LearningMode mode,
uint16_t iterationNumber,
uint64_t generationNumber)
{
// reset scores to 0 in classification table
for (std::vector<uint64_t>& perClass : this->classificationTable) {
Expand Down
8 changes: 5 additions & 3 deletions gegelatilib/src/learn/learningAgent.cpp
Expand Up @@ -127,13 +127,15 @@ std::shared_ptr<Learn::EvaluationResult> Learn::LearningAgent::evaluateJob(
double result = 0.0;

// Evaluate nbIteration times
for (auto i = 0; i < this->params.nbIterationsPerPolicyEvaluation; i++) {
for (auto iterationNumber = 0;
iterationNumber < this->params.nbIterationsPerPolicyEvaluation;
iterationNumber++) {
// Compute a Hash
Data::Hash<uint64_t> hasher;
uint64_t hash = hasher(generationNumber) ^ hasher(i);
uint64_t hash = hasher(generationNumber) ^ hasher(iterationNumber);

// Reset the learning Environment
le.reset(hash, mode);
le.reset(hash, mode, iterationNumber, generationNumber);

uint64_t nbActions = 0;
while (!le.isTerminal() &&
Expand Down
6 changes: 4 additions & 2 deletions test/learn/fakeAdversarialLearningEnvironment.h
Expand Up @@ -68,8 +68,10 @@ class FakeAdversarialLearningEnvironment
nbTurns++;
}

void reset(size_t seed = 0, Learn::LearningMode mode =
Learn::LearningMode::TRAINING) override{
void reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override{
// we just ignore the reset
};
std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
4 changes: 3 additions & 1 deletion test/learn/fakeClassificationLearningEnvironment.h
Expand Up @@ -63,7 +63,9 @@ class FakeClassificationLearningEnvironment
this->currentClass = value % 3;
data.setDataAt(typeid(int), 0, value);
}
void reset(size_t seed, Learn::LearningMode mode) override
void reset(size_t seed, Learn::LearningMode mode,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override
{
// Call super pure virtual method to reset the pure virtual method.
ClassificationLearningEnvironment::reset(seed, mode);
Expand Down
4 changes: 3 additions & 1 deletion test/learn/stickGameAdversarial.cpp
Expand Up @@ -87,7 +87,9 @@ void StickGameAdversarial::doAction(uint64_t actionID)
}
}

void StickGameAdversarial::reset(size_t seed, Learn::LearningMode mode)
void StickGameAdversarial::reset(size_t seed, Learn::LearningMode mode,
uint16_t iterationNumber,
uint64_t generationNumber)
{
// Create seed from seed and mode
size_t hash_seed =
Expand Down
7 changes: 4 additions & 3 deletions test/learn/stickGameAdversarial.h
Expand Up @@ -98,9 +98,10 @@ class StickGameAdversarial : public Learn::AdversarialLearningEnvironment
virtual void doAction(uint64_t actionID) override;

// Inherited via LearningEnvironment
virtual void reset(
size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
virtual void reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override;

// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
4 changes: 3 additions & 1 deletion test/learn/stickGameWithOpponent.cpp
Expand Up @@ -84,7 +84,9 @@ void StickGameWithOpponent::doAction(uint64_t actionID)
}
}

void StickGameWithOpponent::reset(size_t seed, Learn::LearningMode mode)
void StickGameWithOpponent::reset(size_t seed, Learn::LearningMode mode,
uint16_t iterationNumber,
uint64_t generationNumber)
{
// Create seed from seed and mode
size_t hash_seed =
Expand Down
7 changes: 4 additions & 3 deletions test/learn/stickGameWithOpponent.h
Expand Up @@ -92,9 +92,10 @@ class StickGameWithOpponent : public Learn::LearningEnvironment
virtual void doAction(uint64_t actionID) override;

// Inherited via LearningEnvironment
virtual void reset(
size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
virtual void reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override;

// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
3 changes: 2 additions & 1 deletion test/learningEnvironmentTest.cpp
Expand Up @@ -56,7 +56,8 @@ class FakeLearningEnvironment : public Learn::LearningEnvironment

public:
FakeLearningEnvironment() : LearningEnvironment(2), data(3){};
void reset(size_t seed, Learn::LearningMode mode){};
void reset(size_t seed, Learn::LearningMode mode, uint16_t iterationNumber,
uint64_t generationNumber){};
std::vector<std::reference_wrapper<const Data::DataHandler>>
getDataSources()
{
Expand Down

0 comments on commit 9b4092f

Please sign in to comment.