diff --git a/stdlib/public/core/Stride.swift b/stdlib/public/core/Stride.swift index 9bfc6290fcadb..53e599b4eaaa2 100644 --- a/stdlib/public/core/Stride.swift +++ b/stdlib/public/core/Stride.swift @@ -215,14 +215,14 @@ public struct StrideToIterator { internal let _stride: Element.Stride @usableFromInline - internal var _current: (index: Int?, value: Element) + internal var _current: (index: Int?, value: Element?) @inlinable internal init(_start: Element, end: Element, stride: Element.Stride) { self._start = _start _end = end _stride = stride - _current = (0, _start) + _current = (nil, nil) } } @@ -233,12 +233,18 @@ extension StrideToIterator: IteratorProtocol { /// Once `nil` has been returned, all subsequent calls return `nil`. @inlinable public mutating func next() -> Element? { - let result = _current.value - if _stride > 0 ? result >= _end : result <= _end { - return nil + if let value = _current.value { + let deltaEnd = value.distance(to: _end) + if _stride > 0 ? deltaEnd > _stride : deltaEnd < _stride { + _current = Element._step(after: (_current.index, value), from: _start, by: _stride) + } else { + return nil + } + } else { + _current = (0, _start) } - _current = Element._step(after: _current, from: _start, by: _stride) - return result + + return _current.value } } @@ -416,7 +422,7 @@ public struct StrideThroughIterator { internal let _stride: Element.Stride @usableFromInline - internal var _current: (index: Int?, value: Element) + internal var _current: (index: Int?, value: Element?) @usableFromInline internal var _didReturnEnd: Bool = false @@ -426,7 +432,7 @@ public struct StrideThroughIterator { self._start = _start _end = end _stride = stride - _current = (0, _start) + _current = (nil, nil) } } @@ -437,19 +443,18 @@ extension StrideThroughIterator: IteratorProtocol { /// Once `nil` has been returned, all subsequent calls return `nil`. @inlinable public mutating func next() -> Element? { - let result = _current.value - if _stride > 0 ? result >= _end : result <= _end { - // This check is needed because if we just changed the above operators - // to > and <, respectively, we might advance current past the end - // and throw it out of bounds (e.g. above Int.max) unnecessarily. - if result == _end && !_didReturnEnd { - _didReturnEnd = true - return result + if let value = _current.value { + let deltaEnd = value.distance(to: _end) + if _stride > 0 ? deltaEnd >= self._stride : deltaEnd <= self._stride { + _current = Element._step(after: (_current.index, value), from: _start, by: _stride) + } else { + return nil } - return nil + } else { + _current = (0, _start) } - _current = Element._step(after: _current, from: _start, by: _stride) - return result + + return _current.value } } diff --git a/test/stdlib/Strideable.swift b/test/stdlib/Strideable.swift index 91a2a83fa5343..16cef75674d86 100644 --- a/test/stdlib/Strideable.swift +++ b/test/stdlib/Strideable.swift @@ -61,6 +61,17 @@ struct R : Strideable { } } +enum E : Int, Strideable { + case one = 1, two, three, four + + func distance(to other: Self) -> Int { + return other.rawValue - self.rawValue + } + func advanced(by n: Int) -> Self { + return Self(rawValue: self.rawValue + n)! + } +} + StrideTestSuite.test("Double") { func checkOpen(from start: Double, to end: Double, by stepSize: Double, sum: Double) { // Work on Doubles @@ -234,5 +245,20 @@ StrideTestSuite.test("StrideToIterator/past end/backward") { strideIteratorTest(stride(from: 3, to: 0, by: -1), nonNilResults: 3) } +StrideTestSuite.test("UInt8") { + // SR-2016 + strideIteratorTest(stride(from:253 as UInt8, to: 255, by: 2), nonNilResults: 1) + strideIteratorTest(stride(from:253 as UInt8, through: 255, by: 2), nonNilResults: 2) + strideIteratorTest(stride(from:2 as UInt8, to: 0, by: -2), nonNilResults: 1) + strideIteratorTest(stride(from:2 as UInt8, through: 0, by: -2), nonNilResults: 2) +} + +StrideTestSuite.test("Enum") { + strideIteratorTest(stride(from:E.one as UInt8, to: E.four, by: 2), nonNilResults: 1) + strideIteratorTest(stride(from:E.one as UInt8, through: E.four, by: 2), nonNilResults: 2) + strideIteratorTest(stride(from:E.four as UInt8, to: E.one, by: -2), nonNilResults: 1) + strideIteratorTest(stride(from:E.four UInt8, through: E.one, by: -2), nonNilResults: 2) +} + runAllTests()