In [1]:
import random
import threading
import time
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.exc import NoResultFound


In [2]:
#---- this seems to work!
import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager

from sqlalchemy import create_engine, Column, Integer, String, Boolean, DateTime
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.exc import NoResultFound
from sqlalchemy.sql import text

from sqlalchemy import Index


# --- Set up the database engine & SQLAlchemy ORM ---

# The active engine below targets a local PostgreSQL server.
# The commented-out line is the file-based SQLite alternative
# (check_same_thread=False allows multithreaded access):
#engine = create_engine("sqlite:///test0.db", connect_args={"check_same_thread": False})
engine = create_engine('postgresql://postgres@localhost:5333/ajtest1')
# Session factory bound to the engine; calling SessionLocal() yields a new Session.
SessionLocal = sessionmaker(bind=engine)

# Declarative base class that the ORM models in this notebook inherit from.
Base = declarative_base()

class Tile(Base):
    """ORM model for one row of the ``tiles`` table, the unit of work that workers claim."""
    __tablename__ = 'tiles'
    id = Column(Integer, primary_key=True)
    annotation_class_id = Column(Integer, default=1)  # for demo, all tiles use 1
    hasgt = Column(Boolean, default=True)  # presumably "has ground truth" — TODO confirm
    # The default is evaluated while `datetime` still refers to the module. After this
    # assignment, `datetime` inside the class body refers to this Column instead.
    datetime = Column(DateTime, default=datetime.datetime.utcnow)
    status = Column(String, default="pending")  # can be "pending", "in_progress", etc.
    worker_id = Column(Integer, nullable=True)    # which worker claimed it

    # Explicitly named indexes on the two columns defined above; they must be declared
    # after those columns because they reference the class-body names directly.
    __table_args__ = (
        Index("idx_datetime", datetime),
        Index("idx_status", status),
    )
# Demo-only schema reset: drop every table registered on Base, then create them
# again. Running this cell discards all existing rows in those tables.
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)

In [3]:
# Sanity check: display the Index objects registered on the tiles table via __table_args__.
Tile.__table__.indexes

{Index('idx_datetime', Column('datetime', DateTime(), table=<tiles>, default=CallableColumnDefault(<function datetime.utcnow at 0x7f8eb6acfeb0>))),
 Index('idx_status', Column('status', String(), table=<tiles>, default=ScalarElementColumnDefault('pending')))}

In [4]:
# --- Populate the database with demo tiles ---
# NOTE: this is a plain notebook cell, so `session`, `tiles`, `i` and `tile`
# all remain bound in the kernel namespace after it runs.
with SessionLocal() as session:
    tiles = []
    for i in range(100_000):  # create tiles
        tile = Tile(
            annotation_class_id=1,
            hasgt=True,
            # Tile i is stamped i seconds before "now", so the first tile built
            # carries the most recent datetime:
            datetime=datetime.datetime.utcnow() - datetime.timedelta(seconds=i)
        )
        tiles.append(tile)
    # Queue every row on the session and persist them with a single commit.
    session.add_all(tiles)
    session.commit()

In [5]:
# --- Utility to get a session ---
@contextmanager
def get_session():
    """
    Yield a fresh session from ``SessionLocal`` and manage its lifecycle.

    The session is committed once the ``with`` body finishes normally. If the
    body (or the commit itself) raises an ``Exception``, the session is rolled
    back and the exception is re-raised. The session is closed in every case.
    """
    db = SessionLocal()
    try:
        yield db
        # Only reached when the caller's block completed without raising.
        db.commit()
    except Exception as exc:
        db.rollback()
        raise exc
    finally:
        db.close()

In [7]:
from sqlalchemy.orm import Session
from sqlalchemy import update, select, text
from sqlalchemy.engine import Engine

def getWorkersTile(worker_id: int) -> str:
    """
    Atomically claim the most recent pending tile for ``worker_id``.

    A single UPDATE flips the newest tile matching ``annotation_class_id == 1``,
    ``hasgt == True`` and ``status == 'pending'`` to ``'in_progress'`` and stores
    the worker id on it. The candidate row is chosen by a
    ``SELECT ... FOR UPDATE SKIP LOCKED`` subquery, so that concurrent workers
    pass over rows another transaction has already locked instead of waiting.

    Args:
        worker_id: Identifier written to the claimed tile's ``worker_id`` column.

    Returns:
        ``"Worker <worker_id> claimed Tile <tile_id>"`` when a row was updated,
        otherwise ``"Worker <worker_id> found no tile"``.
    """
    with get_session() as db_session:
        with db_session.begin():  # Explicit transaction; ends when this block is left
            newest_pending = (
                select(Tile.id)
                .where(Tile.annotation_class_id == 1,
                       Tile.hasgt == True,  # noqa: E712 - builds a SQL expression
                       Tile.status == 'pending')
                .order_by(Tile.datetime.desc())
                .limit(1)
                .with_for_update(skip_locked=True)
            )

            tile_id = db_session.execute(
                update(Tile)
                .where(Tile.id == newest_pending.scalar_subquery())
                .where(Tile.status == 'pending')  # Ensures another worker hasn't claimed it
                .values(status='in_progress', worker_id=worker_id)
                .returning(Tile.id)
            ).scalar()

            # Decide on the id returned by the UPDATE itself, not on any name left
            # over in the notebook namespace. `is not None` keeps an id of 0 valid.
            if tile_id is not None:
                return f"Worker {worker_id} claimed Tile {tile_id}"
            return f"Worker {worker_id} found no tile"



In [8]:
# --- Worker function ---
def worker_function(worker_id):
    """Try to claim a tile for ``worker_id``, print the outcome message and return it."""
    outcome = getWorkersTile(worker_id)
    print(outcome)
    return outcome

In [9]:
# --- Main function: spawn many workers concurrently ---
def main():
    """
    Run 200 workers concurrently, each trying to claim one tile, then print a summary.

    Every worker's result message is printed in completion order, followed by
    the number of messages that report a claimed tile.
    """
    num_workers = 200  # aggressive scenario: one thread per simulated worker
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        submitted = [pool.submit(worker_function, wid) for wid in range(num_workers)]
        # Collect in completion order rather than submission order.
        results = [done.result() for done in as_completed(submitted)]

    print("\nSummary:")
    for message in results:
        print(message)
    claimed = [message for message in results if "claimed Tile" in message]
    print(f"\nTotal claimed tiles: {len(claimed)}")

# Entry point: run the concurrency demo when this code executes as the main module.
if __name__ == "__main__":
    main()

Worker 0 claimed Tile 1
Worker 1 claimed Tile 2
Worker 8 claimed Tile 3
Worker 5 claimed Tile 6
Worker 3 claimed Tile 5
Worker 4 claimed Tile 4
Worker 6 claimed Tile 7
Worker 9 claimed Tile 8
Worker 16 claimed Tile 14
Worker 12 claimed Tile 15
Worker 7 claimed Tile 11
Worker 11 claimed Tile 10
Worker 15 claimed Tile 13
Worker 10 claimed Tile 12
Worker 22 claimed Tile 22
Worker 2 claimed Tile 9
Worker 13 claimed Tile 16
Worker 18 claimed Tile 19
Worker 20 claimed Tile 20
Worker 14 claimed Tile 24
Worker 21 claimed Tile 21
Worker 24 claimed Tile 23
Worker 30 claimed Tile 29
Worker 25 claimed Tile 25
Worker 23 claimed Tile 28
Worker 27 claimed Tile 27
Worker 17 claimed Tile 18
Worker 26 claimed Tile 26
Worker 28 claimed Tile 34
Worker 32 claimed Tile 30
Worker 35 claimed Tile 36
Worker 19 claimed Tile 17
Worker 33 claimed Tile 31
Worker 29 claimed Tile 32
Worker 31 claimed Tile 33
Worker 36 claimed Tile 43
Worker 44 claimed Tile 40
Worker 34 claimed Tile 35
Worker 47 claimed Tile 41
Worke

In [10]:
%%timeit
# NOTE: each timed call runs the real UPDATE, so every iteration marks another
# pending tile as 'in_progress' with worker_id -1 (this cell mutates the table).
result = getWorkersTile(-1)

4.11 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
