Refactor Redis keys to avoid duplicate string literals #309

Open · wants to merge 3 commits into master
5 changes: 5 additions & 0 deletions tasktiger/_internal.py
@@ -27,6 +27,11 @@
     from .task import Task
     from .tasktiger import TaskTiger
 
+# Constants pertaining to Redis keys
+TASK = "task"
+EXECUTIONS = "executions"
+EXECUTIONS_COUNT = "executions_count"
+
 # Task states (represented by different queues)
 # Note some client code may rely on the string values (e.g. get_queue_stats).
 QUEUED = "queued"
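For reference, these three constants name the per-task Redis keys used throughout the codebase: the serialized task itself, the list of past executions, and a counter of total executions. A minimal sketch of the resulting key layout, assuming a REDIS_PREFIX of "t" and a made-up task id:

# Sketch of the per-task key layout; the "t" prefix and "3f2a" id are
# assumptions for illustration, not values from this PR.
#   t:task:3f2a                   - string holding the serialized task
#   t:task:3f2a:executions        - list with one JSON blob per execution
#   t:task:3f2a:executions_count  - counter incremented per execution
TASK = "task"
EXECUTIONS = "executions"
EXECUTIONS_COUNT = "executions_count"

def task_keys(prefix: str, task_id: str) -> "tuple[str, str, str]":
    base = f"{prefix}:{TASK}:{task_id}"
    return base, f"{base}:{EXECUTIONS}", f"{base}:{EXECUTIONS_COUNT}"

assert task_keys("t", "3f2a")[1] == "t:task:3f2a:executions"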
4 changes: 3 additions & 1 deletion tasktiger/migrations.py
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING
 
+from ._internal import EXECUTIONS, TASK
 from .utils import redis_glob_escape
 
 if TYPE_CHECKING:
@@ -22,7 +23,8 @@ def migrate_executions_count(tiger: "TaskTiger") -> None:
     )
 
     match = (
-        redis_glob_escape(tiger.config["REDIS_PREFIX"]) + ":task:*:executions"
+        redis_glob_escape(tiger.config["REDIS_PREFIX"])
+        + f":{TASK}:*:{EXECUTIONS}"
     )
 
     for key in tiger.connection.scan_iter(count=100, match=match):
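The migration scans for every "<prefix>:task:*:executions" key, so any glob metacharacters in the configured prefix must be escaped, or SCAN's MATCH option would treat them as pattern syntax. A rough sketch of what such an escape helper can look like (an assumption, not necessarily TaskTiger's exact implementation):

import re

# Hypothetical helper: backslash-escape the metacharacters Redis' MATCH
# glob syntax understands, so the prefix is matched literally.
def redis_glob_escape(value: str) -> str:
    return re.sub(r"([\\?*\[\]])", r"\\\1", value)

# A prefix like "t[1]" would otherwise match keys under "t1", not "t[1]".
assert redis_glob_escape("t[1]") == r"t\[1\]"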
16 changes: 12 additions & 4 deletions tasktiger/redis_scripts.py
@@ -3,7 +3,15 @@
 
 from redis import Redis
 
-from ._internal import ACTIVE, ERROR, QUEUED, SCHEDULED
+from ._internal import (
+    ACTIVE,
+    ERROR,
+    EXECUTIONS,
+    EXECUTIONS_COUNT,
+    QUEUED,
+    SCHEDULED,
+    TASK,
+)
 
 try:
     from redis.commands.core import Script
@@ -595,9 +603,9 @@ def _bool_to_str(v: bool) -> str:
         def _none_to_empty_str(v: Optional[str]) -> str:
             return v or ""
 
-        key_task_id = key_func("task", id)
-        key_task_id_executions = key_func("task", id, "executions")
-        key_task_id_executions_count = key_func("task", id, "executions_count")
+        key_task_id = key_func(TASK, id)
+        key_task_id_executions = key_func(TASK, id, EXECUTIONS)
+        key_task_id_executions_count = key_func(TASK, id, EXECUTIONS_COUNT)
         key_from_state = key_func(from_state)
         key_to_state = key_func(to_state) if to_state else ""
         key_active_queue = key_func(ACTIVE, queue)
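key_func here is a caller-supplied key builder; conceptually it joins its arguments with colons under the configured Redis prefix. A sketch of that assumed shape:

from functools import partial

# Assumed shape of the key builder the call sites above rely on: join
# the configured prefix and any parts with ":" separators.
def _key(prefix: str, *parts: str) -> str:
    return ":".join([prefix, *parts])

key_func = partial(_key, "t")  # "t" stands in for the configured prefix
assert key_func("task", "3f2a", "executions") == "t:task:3f2a:executions"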
17 changes: 10 additions & 7 deletions tasktiger/task.py
@@ -20,8 +20,11 @@
 
 from ._internal import (
     ERROR,
+    EXECUTIONS,
+    EXECUTIONS_COUNT,
     QUEUED,
     SCHEDULED,
+    TASK,
     g,
     gen_id,
     gen_unique_id,
@@ -393,7 +396,7 @@ def delay(
 
         pipeline = tiger.connection.pipeline()
         pipeline.sadd(tiger._key(state), self.queue)
-        pipeline.set(tiger._key("task", self.id), serialized_task)
+        pipeline.set(tiger._key(TASK, self.id), serialized_task)
         # In case of unique tasks, don't update the score.
         tiger.scripts.zadd(
             tiger._key(state, self.queue),
@@ -454,11 +457,11 @@ def from_id(
         latest). If the task doesn't exist, None is returned.
         """
         pipeline = tiger.connection.pipeline()
-        pipeline.get(tiger._key("task", task_id))
+        pipeline.get(tiger._key(TASK, task_id))
         pipeline.zscore(tiger._key(state, queue), task_id)
         if load_executions:
             pipeline.lrange(
-                tiger._key("task", task_id, "executions"), -load_executions, -1
+                tiger._key(TASK, task_id, EXECUTIONS), -load_executions, -1
             )
         (
             serialized_data,
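The -load_executions, -1 range uses Redis' negative LRANGE indices to fetch only the newest entries from the tail of the executions list. For illustration (the connection and key name are assumptions, matching the sketches above):

import redis

conn = redis.Redis()
key = "t:task:3f2a:executions"  # assumed key layout
conn.rpush(key, "e1", "e2", "e3", "e4")

load_executions = 2
latest = conn.lrange(key, -load_executions, -1)  # [b"e3", b"e4"]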
@@ -526,10 +529,10 @@ def tasks_from_queue(
         ]
         if load_executions:
             pipeline = tiger.connection.pipeline()
-            pipeline.mget([tiger._key("task", item[0]) for item in items])
+            pipeline.mget([tiger._key(TASK, item[0]) for item in items])
             for item in items:
                 pipeline.lrange(
-                    tiger._key("task", item[0], "executions"),
+                    tiger._key(TASK, item[0], EXECUTIONS),
                     -load_executions,
                     -1,
                 )
@@ -586,8 +589,8 @@ def n_executions(self) -> int:
         Queries and returns the number of past task executions.
         """
         pipeline = self.tiger.connection.pipeline()
-        pipeline.exists(self.tiger._key("task", self.id))
-        pipeline.get(self.tiger._key("task", self.id, "executions_count"))
+        pipeline.exists(self.tiger._key(TASK, self.id))
+        pipeline.get(self.tiger._key(TASK, self.id, EXECUTIONS_COUNT))
 
         exists, executions_count = pipeline.execute()
         if not exists:
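n_executions pipelines the EXISTS and GET so both commands travel in a single network round trip. Roughly, with plain redis-py and the key layout assumed above:

import redis

# Minimal sketch of the n_executions pattern; key names are assumptions.
conn = redis.Redis()
pipe = conn.pipeline()
pipe.exists("t:task:3f2a")
pipe.get("t:task:3f2a:executions_count")
exists, count = pipe.execute()  # one round trip for both commands

if not exists:
    raise LookupError("task not found")
n_executions = int(count or 0)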
17 changes: 9 additions & 8 deletions tasktiger/worker.py
@@ -36,8 +36,11 @@
 from ._internal import (
     ACTIVE,
     ERROR,
+    EXECUTIONS,
+    EXECUTIONS_COUNT,
     QUEUED,
     SCHEDULED,
+    TASK,
     dotted_parts,
     g,
     g_fork_lock,
@@ -345,7 +348,7 @@ def _worker_queue_expired_tasks(self) -> None:
             self.config["REQUEUE_EXPIRED_TASKS_BATCH_SIZE"],
         )
 
-        for (queue, task_id) in task_data:
+        for queue, task_id in task_data:
             self.log.debug("expiring task", queue=queue, task_id=task_id)
             self._did_work = True
             try:
@@ -374,7 +377,7 @@ def _worker_queue_expired_tasks(self) -> None:
             # have a task without a task object.
 
             # XXX: Ideally, the following block should be atomic.
-            if not self.connection.get(self._key("task", task_id)):
+            if not self.connection.get(self._key(TASK, task_id)):
                 self.log.error("not found", queue=queue, task_id=task_id)
                 task = Task(
                     self.tiger,
@@ -812,7 +815,7 @@ def _process_queue_tasks(
 
         # Get all tasks
         serialized_tasks = self.connection.mget(
-            [self._key("task", task_id) for task_id in task_ids]
+            [self._key(TASK, task_id) for task_id in task_ids]
         )
 
         # Parse tasks
@@ -1053,7 +1056,7 @@ def _mark_done() -> None:
                 should_log_error = True
                 # Get execution info (for logging and retry purposes)
                 execution = self.connection.lindex(
-                    self._key("task", task.id, "executions"), -1
+                    self._key(TASK, task.id, EXECUTIONS), -1
                 )
 
                 if execution:
@@ -1242,10 +1245,8 @@ def _store_task_execution(
         serialized_execution = json.dumps(execution)
 
         for task in tasks:
-            executions_key = self._key("task", task.id, "executions")
-            executions_count_key = self._key(
-                "task", task.id, "executions_count"
-            )
+            executions_key = self._key(TASK, task.id, EXECUTIONS)
+            executions_count_key = self._key(TASK, task.id, EXECUTIONS_COUNT)
 
             pipeline = self.connection.pipeline()
             pipeline.incr(executions_count_key)
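_store_task_execution keeps the counter and the history list in step by updating both inside one pipeline. A rough sketch with plain redis-py, under the same key-layout assumptions as the earlier examples (the append-to-list call is inferred from the surrounding code, not shown in this hunk):

import json
import redis

conn = redis.Redis()
execution = {"time_failed": 1700000000.0, "success": False}

pipe = conn.pipeline()
pipe.incr("t:task:3f2a:executions_count")  # running total of executions
pipe.rpush("t:task:3f2a:executions", json.dumps(execution))  # history entry
pipe.execute()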