In [3]:
from typing import List

class SimpleIterator:
    def __init__(self, data):
        self.data = data
        self.idx = -1
        self.n = len(data)

    def __iter__(self):
        return self

    def __next__(self):
        self.idx += 1
        if self.idx >= self.n:
            raise StopIteration
        else:
            return self.data[self.idx]

class MyIterator:
    def __init__(self, data):
        self.data = data
        self.ptr = [SimpleIterator(self.data)]

    def __iter__(self):
        return self
    

    def _next_impl(self):
        if len(self.ptr) == 0:
            raise StopIteration

        try:
            elem = next(self.ptr[-1])
        except StopIteration:
            self.ptr.pop()
            return self._next_impl()
        
        if isinstance(elem, List):
            self.ptr.append(SimpleIterator(elem))
            return self._next_impl()
        else:
            return elem

    def __next__(self):
        return self._next_impl()

    def get_state(self):
        return self.ptr
    
    def set_state(self, ptr):
        self.ptr = ptr

In [4]:
x = [[], [], [[1, [2]]], 3]
def test_sync(x, out):
    ret = [elem for elem in MyIterator(x)]
    msg = f""" ======
test input: {x}
output: {ret}
expected output: {out}
    """
    print(msg)

test_sync(x, [1, 2, 3])
test_sync([], [])
test_sync([1, 2, 3], [1, 2, 3])

test input: [[], [], [[1, [2]]], 3]
output: [1, 2, 3]
expected output: [1, 2, 3]
    
test input: []
output: []
expected output: []
    
test input: [1, 2, 3]
output: [1, 2, 3]
expected output: [1, 2, 3]
    


In [72]:
x = [[], [], [[1, [2]]], 3, 4, [[[5]]]]
k = 3


def run(data_iter, limit=None):
    ret = []
    for i, elem in enumerate(data_iter):
        ret.append(elem)
        if limit is not None and i + 1 >= limit:
            break
    return ret

data_iter = MyIterator(x)

print(f'first {k} values: {run(data_iter, k)}')
state = data_iter.get_state()
new_iter = MyIterator(x)
new_iter.set_state(state)
print(f'{k+1}th to last values: {run(new_iter)}')

first 3 values: [1, 2, 3]
4th to last values: [4, 5]


In [None]:
import asyncio
import threading
import random

fetch_data_delay = 0.1

class MyAsyncIterator:
    def __init__(self, data):
        self.data = data
        self.ptr = [SimpleIterator(self.data)]

    def __aiter__(self):
        return self

    async def _next_impl(self):
        if len(self.ptr) == 0:
            raise StopAsyncIteration

        try:            
            elem = next(self.ptr[-1])
        except StopIteration:
            self.ptr.pop()
            return await self._next_impl()
        
        if isinstance(elem, List):
            self.ptr.append(SimpleIterator(elem))
            return await self._next_impl()
        else:
            await asyncio.sleep(fetch_data_delay)
            return elem

    async def __anext__(self):
        return await self._next_impl()


    def get_state(self):
        return self.ptr
    
    def set_state(self, ptr):
        self.ptr = ptr

In [63]:
x = [[], [], [[1, [2]]], 3]
sleep_time = iter([1, 1, 1])


async def process_element(elem, delay):
    await asyncio.sleep(delay)
    print(f'element {elem} processed!')
    return elem

async def run_async(data_iter, limit=None):
    tasks = []
    cnt = 0
    async for elem in data_iter:
        tasks.append(asyncio.create_task(process_element(elem, next(sleep_time, 0))))
        cnt += 1
        if limit is not None and cnt >= limit:
            break
        
    return await asyncio.gather(*tasks)

data_iter = MyAsyncIterator(x)
await run_async(data_iter)

element 1 processed!
element 2 processed!
element 3 processed!


[1, 2, 3]

In [64]:
x = [[], [], [[1, [2]]], 3, 4, [[[5]]]]
k = 3

data_iter = MyAsyncIterator(x)

print(f'first {k} values: {await run_async(data_iter, k)}')
state = data_iter.get_state()

new_iter = MyAsyncIterator(x)
new_iter.set_state(state)

print(f'{k+1}th to last values: {await run_async(new_iter)}')

element 1 processed!
element 2 processed!
element 3 processed!
first 3 values: [1, 2, 3]
element 4 processed!
element 5 processed!
4th to last values: [4, 5]


In [73]:
import threading
import random
import time

class MyTsIterator:
    def __init__(self, data, fetch_data_delay=0.1):
        self.data = data
        self.ptr = [SimpleIterator(self.data)]
        self.lock = threading.Lock()
        self.fetch_data_delay = fetch_data_delay

    def __iter__(self):
        return self

    def _next_impl(self):
        if len(self.ptr) == 0:
            raise StopIteration

        try:            
            elem = next(self.ptr[-1])
        except StopIteration:
            self.ptr.pop()
            return self._next_impl()
        
        if isinstance(elem, List):
            self.ptr.append(SimpleIterator(elem))
            return self._next_impl()
        else:
            time.sleep(self.fetch_data_delay)
            return elem

    def __next__(self):
        with self.lock:
            return self._next_impl()

    def get_state(self):
        return self.ptr
    
    def set_state(self, ptr):
        self.ptr = ptr

In [102]:
x = [[], [], [[1, [2]]], 3]
sleep_time = iter([1, 1, 1])

def get_and_process_data(data_iter, result_pool, limit=None):
    elem = next(data_iter, None) 
    while elem is not None:
        time.sleep(next(sleep_time, 0))
        print(f'element {elem} processed!')
        result_pool.append(elem)
        if limit is not None and len(result_pool) >= limit:
            break
        elem = next(data_iter, None) 

data_iter = MyTsIterator(x)
result_pool = []
get_and_process_data(data_iter, result_pool)
print(result_pool)

element 1 processed!
element 2 processed!
element 3 processed!
[1, 2, 3]


In [103]:
x = [[], [], [[1, [2]]], 3]
sleep_time = iter([1, 1, 1])
data = MyTsIterator(x)

def run_concurrent(data_iter, limit, num_threads):
    tasks = []
    result_pool = []
    for _ in range(num_threads):
        t = threading.Thread(target=get_and_process_data, args=[data_iter, result_pool, limit])
        t.start()
        tasks.append(t)

    for t in tasks:
        t.join()

    return result_pool

data_iter = MyTsIterator(x)
print(run_concurrent(data_iter, None, 3))

element 1 processed!
element 2 processed!
element 3 processed!
[1, 2, 3]


In [96]:
sleep_time = iter([1, 1, 1])
data_iter = MyTsIterator(x)
print(run_concurrent(data_iter, None, 2))

element 1 processed!
element 2 processed!
element 3 processed!
deque([1, 2, 3])


In [104]:
x = [[], [], [[1, [2]]], 3, 4, [[[5]]]]
k = 3

data_iter = MyTsIterator(x)

print(f'first {k} values: {run_concurrent(data_iter, k, 2)}')
state = data_iter.get_state()

new_iter = MyTsIterator(x)
new_iter.set_state(state)

print(f'{k+1}th to last values: {run_concurrent(new_iter, None, 2)}')

element 1 processed!
element 2 processed!
element 3 processed!
element 4 processed!
first 3 values: [1, 2, 3, 4]
element 5 processed!
4th to last values: [5]
