In [3]:
from typing import List

class SimpleIterator:
    """Index-based forward iterator over an indexable sequence.

    ``idx`` is the position of the element most recently handed out
    (-1 before the first ``next``); ``n`` is the length captured once at
    construction time, so later growth of ``data`` is not seen.
    """

    def __init__(self, data):
        self.data = data
        self.idx = -1
        self.n = len(data)

    def __iter__(self):
        return self

    def __next__(self):
        # Advance first; idx keeps increasing on every call, even past the end.
        self.idx = self.idx + 1
        position = self.idx
        if position < self.n:
            return self.data[position]
        raise StopIteration

class MyIterator:
    """Depth-first iterator over the non-list leaves of a nested list.

    Traversal state is an explicit stack of ``SimpleIterator`` cursors
    (``self.ptr``), one per list currently being walked; ``ptr[-1]`` is the
    innermost one. ``get_state``/``set_state`` expose that stack directly.
    """

    def __init__(self, data):
        self.data = data
        self.ptr = [SimpleIterator(self.data)]

    def __iter__(self):
        return self

    def _next_impl(self):
        """Return the next leaf, or raise StopIteration once the stack is empty.

        Written as a loop rather than recursion so that long runs of empty
        sublists or deep nesting cannot hit Python's recursion limit.
        """
        while self.ptr:
            try:
                elem = next(self.ptr[-1])
            except StopIteration:
                # Current (sub)list exhausted: resume walking its parent.
                self.ptr.pop()
                continue
            if isinstance(elem, list):
                # Descend: the sublist's leaves come before the remaining
                # siblings (depth-first order).
                self.ptr.append(SimpleIterator(elem))
                continue
            return elem
        raise StopIteration

    def __next__(self):
        return self._next_impl()

    def get_state(self):
        """Return the live cursor stack (not a copy); it changes as iteration continues."""
        return self.ptr

    def set_state(self, ptr):
        """Replace the cursor stack, e.g. with one obtained from ``get_state``."""
        self.ptr = ptr

In [4]:
x = [[], [], [[1, [2]]], 3]
def test_sync(x, out):
    ret = [elem for elem in MyIterator(x)]
    msg = f""" ======
test input: {x}
output: {ret}
expected output: {out}
    """
    print(msg)

# Cases: nested input with leading empty sublists, an empty top-level list,
# and an already-flat list.
test_sync(x, [1, 2, 3])
test_sync([], [])
test_sync([1, 2, 3], [1, 2, 3])

test input: [[], [], [[1, [2]]], 3]
output: [1, 2, 3]
expected output: [1, 2, 3]
    
test input: []
output: []
expected output: []
    
test input: [1, 2, 3]
output: [1, 2, 3]
expected output: [1, 2, 3]
    


In [29]:
import asyncio
import threading
import random

# Simulated per-element fetch latency in seconds; MyAsyncIterator awaits
# asyncio.sleep(fetch_data_delay) before returning each leaf element.
fetch_data_delay = 3.0

class MyAsyncIterator:
    """Async depth-first iterator over the non-list leaves of a nested list.

    Uses the same cursor stack (``self.ptr`` of ``SimpleIterator`` objects)
    as the synchronous version, and awaits ``asyncio.sleep(fetch_data_delay)``
    before handing out each leaf to simulate a slow fetch.
    """

    def __init__(self, data):
        self.data = data
        self.ptr = [SimpleIterator(self.data)]

    def __aiter__(self):
        return self

    async def _next_impl(self):
        """Return the next leaf, or raise StopAsyncIteration when exhausted.

        A loop instead of recursive awaits, so empty sublists and deep nesting
        cannot exhaust the recursion limit. The stack is advanced before the
        only ``await``, so concurrent ``__anext__`` calls from different tasks
        each receive a different leaf.
        """
        while self.ptr:
            try:
                elem = next(self.ptr[-1])
            except StopIteration:
                # Current (sub)list exhausted: resume walking its parent.
                self.ptr.pop()
                continue
            if isinstance(elem, list):
                # Descend into the sublist (depth-first order).
                self.ptr.append(SimpleIterator(elem))
                continue
            await asyncio.sleep(fetch_data_delay)
            return elem
        raise StopAsyncIteration

    async def __anext__(self):
        return await self._next_impl()

    def get_state(self):
        """Return the live cursor stack (not a copy); it changes as iteration continues."""
        return self.ptr

    def set_state(self, ptr):
        """Replace the cursor stack, e.g. with one obtained from ``get_state``."""
        self.ptr = ptr

In [37]:
x = [[], [], [[1, [2]]], 3]
# Per-element processing delays (seconds), consumed in order by run().
sleep_time = iter([2, 7, 2])


async def process_element(elem, delay):
    """Pretend to process ``elem`` by waiting ``delay`` seconds, then report it."""
    await asyncio.sleep(delay)
    report = f'element {elem} processed!'
    print(report)

async def run(x):
    """Stream the leaves of ``x`` and process each one in its own task.

    A task is created as soon as each leaf arrives from MyAsyncIterator, with
    a delay drawn from the module-level ``sleep_time`` iterator (0 once it is
    exhausted). Returns after every scheduled task has finished.
    """
    pending = []
    async for leaf in MyAsyncIterator(x):
        delay = next(sleep_time, 0)
        task = asyncio.create_task(process_element(leaf, delay))
        pending.append(task)
    await asyncio.gather(*pending)

# Top-level `await` relies on IPython/Jupyter's autoawait support; in a plain
# script this would be asyncio.run(run(x)).
await run(x)

element 1 processed!
element 3 processed!
element 2 processed!


In [41]:
import threading
import random
import time

# Simulated per-element fetch latency in seconds for MyTsIterator, which sleeps
# this long while holding its lock. This rebinds the same module-level name that
# MyAsyncIterator reads, so changing it here affects both classes.
fetch_data_delay = 3.0

class MyTsIterator:
    """Thread-safe depth-first iterator over the non-list leaves of a nested list.

    Every ``__next__`` call, including the simulated ``fetch_data_delay``
    sleep, runs under ``self.lock``. Several threads can therefore share one
    instance, and each leaf is returned to exactly one caller. ``get_state``
    and ``set_state`` take the same lock.
    """

    def __init__(self, data):
        self.data = data
        # Stack of cursors, one per list currently being walked; ptr[-1] is
        # the innermost.
        self.ptr = [SimpleIterator(self.data)]
        # Serializes traversal and state replacement across threads.
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def _next_impl(self):
        """Advance to the next leaf; the caller must already hold ``self.lock``.

        A loop instead of recursion, so empty sublists and deep nesting cannot
        exhaust the recursion limit.

        Raises:
            StopIteration: when every list has been exhausted.
        """
        while self.ptr:
            try:
                elem = next(self.ptr[-1])
            except StopIteration:
                # Current (sub)list exhausted: resume walking its parent.
                self.ptr.pop()
                continue
            if isinstance(elem, list):
                # Descend into the sublist (depth-first order).
                self.ptr.append(SimpleIterator(elem))
                continue
            # Simulated slow fetch. It runs while the lock is held, so fetches
            # from different threads are serialized.
            time.sleep(fetch_data_delay)
            return elem
        raise StopIteration

    def __next__(self):
        with self.lock:
            return self._next_impl()

    def get_state(self):
        """Return the live cursor stack (not a copy); it changes as iteration continues."""
        with self.lock:
            return self.ptr

    def set_state(self, ptr):
        """Replace the cursor stack; waits for any in-progress ``__next__`` to finish."""
        with self.lock:
            self.ptr = ptr

In [43]:
# Fresh inputs for the single-threaded run. `sleep_time` supplies per-element
# processing delays (0 once exhausted), and `data` is the shared iterator that
# get_and_process_data() reads as a global.
x = [[], [], [[1, [2]]], 3]
sleep_time = iter([2, 7, 2])
data = MyTsIterator(x)

def get_and_process_data(source=None, delays=None):
    """Drain an iterator, "processing" each element by sleeping then printing.

    Intended to run from several threads against one shared ``MyTsIterator``,
    whose ``__next__`` is serialized by a lock.

    Args:
        source: Iterator to drain; defaults to the module-level ``data``.
        delays: Iterable of per-element processing delays in seconds;
            defaults to the module-level ``sleep_time`` iterator. Once it is
            exhausted, 0 is used.
    """
    if source is None:
        source = data
    delays = sleep_time if delays is None else iter(delays)
    # A for-loop stops only on StopIteration. The old `next(data, None)` /
    # `is not None` sentinel stopped early on a leaf whose value was None.
    for elem in source:
        time.sleep(next(delays, 0))
        print(f'element {elem} processed!')

# Single consumer: elements are fetched and processed strictly one after another.
get_and_process_data()

element 1 processed!
element 2 processed!
element 3 processed!


In [45]:
# Several consumer threads drain one shared MyTsIterator. Fetches are serialized
# by the iterator's lock, but processing delays (from `sleep_time`) run in
# parallel, so completion order can differ from data order (see output below).
num_threads = 3

x = [[], [], [[1, [2]]], 3]
sleep_time = iter([2, 7, 2])
data = MyTsIterator(x)

# `tasks` holds the Thread objects; each runs get_and_process_data() against
# the globals set above.
tasks = []
for _ in range(num_threads):
    t = threading.Thread(target=get_and_process_data)
    t.start()
    tasks.append(t)

# Block until every consumer has drained the iterator.
for t in tasks:
    t.join()


element 1 processed!
element 3 processed!
element 2 processed!


In [49]:
# Same experiment as the previous cell with 2 consumer threads.
# TODO(review): this cell is a copy-paste of the 3-thread cell; a helper taking
# `num_threads` would remove the duplication.
num_threads = 2

x = [[], [], [[1, [2]]], 3]
sleep_time = iter([2, 7, 2])
data = MyTsIterator(x)

# `tasks` holds the Thread objects; each runs get_and_process_data() against
# the globals set above.
tasks = []
for _ in range(num_threads):
    t = threading.Thread(target=get_and_process_data)
    t.start()
    tasks.append(t)

# Block until every consumer has drained the iterator.
for t in tasks:
    t.join()

element 1 processed!
element 3 processed!
element 2 processed!
