@@ -52,6 +52,16 @@ struct Result {
5252 ocs2::size_array_t preEventModeTrajectory;
5353};
5454
55+ std::size_t modeAtTime (const ocs2::ModeSchedule& modeSchedule, ocs2::scalar_t time, bool isFinalTime) {
56+ if (isFinalTime) {
57+ auto timeItr = std::upper_bound (modeSchedule.eventTimes .begin (), modeSchedule.eventTimes .end (), time);
58+ int modeIndex = std::distance (modeSchedule.eventTimes .begin (), timeItr);
59+ return modeSchedule.modeSequence [modeIndex];
60+ } else {
61+ return modeSchedule.modeAtTime (time);
62+ }
63+ }
64+
5565class TrajectorySpreadingTest : public testing ::Test {
5666 protected:
5767 static constexpr size_t STATE_DIM = 2 ;
@@ -99,8 +109,9 @@ class TrajectorySpreadingTest : public testing::Test {
99109 [&out](size_t eventIndex) { return out.stateTrajectory [eventIndex - 1 ]; });
100110
101111 out.modeTrajectory .resize (out.timeTrajectory .size ());
102- std::transform (out.timeTrajectory .begin (), out.timeTrajectory .end (), out.modeTrajectory .begin (),
103- [&modeSchedule](ocs2::scalar_t time) { return modeSchedule.modeAtTime (time); });
112+ for (int i = 0 ; i < out.modeTrajectory .size (); i++) {
113+ out.modeTrajectory [i] = modeAtTime (modeSchedule, out.timeTrajectory [i], i == out.modeTrajectory .size () - 1 ? true : false );
114+ }
104115
105116 out.preEventModeTrajectory .resize (out.postEventsIndeces .size ());
106117 std::transform (out.postEventsIndeces .begin (), out.postEventsIndeces .end (), out.preEventModeTrajectory .begin (),
@@ -194,17 +205,22 @@ class TrajectorySpreadingTest : public testing::Test {
194205 const ocs2::scalar_t finalTime = spreadResult.timeTrajectory .back ();
195206
196207 const auto startEventItr = std::upper_bound (updatedModeSchedule.eventTimes .begin (), updatedModeSchedule.eventTimes .end (), initTime);
197- // If a time point is aligned with(the same as) an event time, it belongs to the pre-mode.
198- const auto endEventItr = std::lower_bound (updatedModeSchedule.eventTimes .begin (), updatedModeSchedule.eventTimes .end (), finalTime);
199-
200- ocs2::size_array_t postEventIndice (endEventItr - startEventItr);
201- std::transform (startEventItr, endEventItr, postEventIndice.begin (), [&spreadResult](ocs2::scalar_t time) {
202- auto timeItr = std::upper_bound (spreadResult.timeTrajectory .begin (), spreadResult.timeTrajectory .end (), time);
203- return std::distance (spreadResult.timeTrajectory .begin (), timeItr);
204- });
208+ // If the final time is aligned with(the same as) an event time, it belongs to the post-mode.
209+ const auto endEventItr = std::upper_bound (updatedModeSchedule.eventTimes .begin (), updatedModeSchedule.eventTimes .end (), finalTime);
210+
211+ ocs2::size_array_t postEventIndices (endEventItr - startEventItr);
212+ for (auto itr = startEventItr; itr != endEventItr; itr++) {
213+ int index = std::distance (startEventItr, itr);
214+ if (itr == endEventItr - 1 && *itr == spreadResult.timeTrajectory .back ()) {
215+ postEventIndices[index] = spreadResult.timeTrajectory .size () - 1 ;
216+ } else {
217+ auto timeItr = std::upper_bound (spreadResult.timeTrajectory .begin (), spreadResult.timeTrajectory .end (), *itr);
218+ postEventIndices[index] = std::distance (spreadResult.timeTrajectory .begin (), timeItr);
219+ }
220+ }
205221
206- EXPECT_TRUE (spreadResult.postEventsIndeces .size () == postEventIndice .size ());
207- EXPECT_TRUE (std::equal (postEventIndice .begin (), postEventIndice .end (), spreadResult.postEventsIndeces .begin ()));
222+ EXPECT_TRUE (spreadResult.postEventsIndeces .size () == postEventIndices .size ());
223+ EXPECT_TRUE (std::equal (postEventIndices .begin (), postEventIndices .end (), spreadResult.postEventsIndeces .begin ()));
208224
209225 } else {
210226 EXPECT_EQ (spreadResult.postEventsIndeces .size (), 0 );
@@ -220,9 +236,10 @@ class TrajectorySpreadingTest : public testing::Test {
220236 auto eventIndexActualItr = spreadResult.postEventsIndeces .begin ();
221237 auto eventTimeReferenceInd = ocs2::lookup::findIndexInTimeArray (updatedModeSchedule.eventTimes , period.first );
222238 for (size_t k = 0 ; k < spreadResult.timeTrajectory .size (); k++) {
223- // time should be monotonic sequence
224- if (k > 0 ) {
225- EXPECT_TRUE (spreadResult.timeTrajectory [k - 1 ] < spreadResult.timeTrajectory [k]);
239+ // Time should be monotonic sequence except the final time. It is possible that the last two time points have the same time, but one
240+ // stands for pre-mode and the other stands for post-mode.
241+ if (k > 0 && k < spreadResult.timeTrajectory .size () - 1 ) {
242+ EXPECT_TRUE (spreadResult.timeTrajectory [k - 1 ] < spreadResult.timeTrajectory [k]) << " TimeIndex: " << k;
226243 }
227244
228245 // Pre-event time should be equal to the event time
@@ -238,7 +255,9 @@ class TrajectorySpreadingTest : public testing::Test {
238255 eventIndexActualItr++;
239256 }
240257 // mode should match the given modeSchedule
241- EXPECT_TRUE (updatedModeSchedule.modeAtTime (spreadResult.timeTrajectory [k]) == spreadResult.modeTrajectory [k]);
258+ auto updatedMode =
259+ modeAtTime (updatedModeSchedule, spreadResult.timeTrajectory [k], k == spreadResult.timeTrajectory .size () - 1 ? true : false );
260+ EXPECT_TRUE (updatedMode == spreadResult.modeTrajectory [k]);
242261 } // end of k loop
243262
244263 // test postEventsIndeces
@@ -319,6 +338,48 @@ TEST_F(TrajectorySpreadingTest, partially_matching_modes) {
319338 EXPECT_TRUE (status.willPerformTrajectorySpreading );
320339}
321340
341+ TEST_F (TrajectorySpreadingTest, final_time_is_the_same_as_event_time_1) {
342+ const ocs2::scalar_array_t eventTimes{1.1 , 1.3 };
343+ const ocs2::size_array_t modeSequence{0 , 1 , 2 };
344+
345+ const ocs2::scalar_array_t updatedEventTimes{1.1 , 2.1 };
346+ const ocs2::size_array_t updatedModeSequence{0 , 1 , 2 };
347+
348+ const std::pair<ocs2::scalar_t , ocs2::scalar_t > period{0.2 , 2.1 };
349+ const auto status = checkResults ({eventTimes, modeSequence}, {updatedEventTimes, updatedModeSequence}, period);
350+
351+ EXPECT_FALSE (status.willTruncate );
352+ EXPECT_TRUE (status.willPerformTrajectorySpreading );
353+ }
354+
355+ TEST_F (TrajectorySpreadingTest, final_time_is_the_same_as_event_time_2) {
356+ const ocs2::scalar_array_t eventTimes{1.1 , 2.1 };
357+ const ocs2::size_array_t modeSequence{0 , 1 , 2 };
358+
359+ const ocs2::scalar_array_t updatedEventTimes{1.1 , 1.3 };
360+ const ocs2::size_array_t updatedModeSequence{0 , 1 , 2 };
361+
362+ const std::pair<ocs2::scalar_t , ocs2::scalar_t > period{0.2 , 2.1 };
363+ const auto status = checkResults ({eventTimes, modeSequence}, {updatedEventTimes, updatedModeSequence}, period);
364+
365+ EXPECT_FALSE (status.willTruncate );
366+ EXPECT_TRUE (status.willPerformTrajectorySpreading );
367+ }
368+
369+ TEST_F (TrajectorySpreadingTest, erase_trajectory) {
370+ const ocs2::scalar_array_t eventTimes{1.1 , 1.3 };
371+ const ocs2::size_array_t modeSequence{0 , 1 , 2 };
372+
373+ const ocs2::scalar_array_t updatedEventTimes{1.1 , 1.3 };
374+ const ocs2::size_array_t updatedModeSequence{0 , 1 , 3 };
375+
376+ const std::pair<ocs2::scalar_t , ocs2::scalar_t > period{0.2 , 2.1 };
377+ const auto status = checkResults ({eventTimes, modeSequence}, {updatedEventTimes, updatedModeSequence}, period);
378+
379+ EXPECT_TRUE (status.willTruncate );
380+ EXPECT_FALSE (status.willPerformTrajectorySpreading );
381+ }
382+
322383TEST_F (TrajectorySpreadingTest, fully_matched_modes) {
323384 const ocs2::scalar_array_t eventTimes{1.1 , 1.3 };
324385 const ocs2::size_array_t modeSequence{0 , 1 , 2 };
0 commit comments