85 changes: 34 additions & 51 deletions codeforlife/tasks/data_warehouse.py
@@ -31,17 +31,16 @@ class DataWarehouseTask(Task):
timestamp_key = "_timestamp"

GetQuerySet: t.TypeAlias = t.Callable[..., QuerySet[t.Any]]
BqTableWriteMode: t.TypeAlias = t.Literal["overwrite", "append"]

# pylint: disable-next=too-many-instance-attributes
class Settings:
"""The settings for a data warehouse task."""

BqTableWriteMode: t.TypeAlias = t.Literal["overwrite", "append"]

# pylint: disable-next=too-many-arguments,too-many-branches
def __init__(
self,
bq_table_write_mode: BqTableWriteMode,
bq_table_write_mode: "DataWarehouseTask.BqTableWriteMode",
chunk_size: int,
fields: t.List[str],
id_field: str = "id",
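
For readers skimming the hunk above: the write-mode Literal alias now lives on DataWarehouseTask itself, and the nested Settings class points at it through a string forward reference. A minimal sketch of the resulting layout, with everything else stripped away (the real class subclasses Task and Settings takes many more arguments):

import typing as t

class DataWarehouseTask:  # simplified; the real class subclasses Task
    # Single shared alias, now defined once at class level.
    BqTableWriteMode: t.TypeAlias = t.Literal["overwrite", "append"]

    class Settings:
        """The settings for a data warehouse task."""

        def __init__(
            self,
            # String forward reference: DataWarehouseTask is still being defined.
            bq_table_write_mode: "DataWarehouseTask.BqTableWriteMode",
            chunk_size: int,
        ):
            self.bq_table_write_mode = bq_table_write_mode
            self.chunk_size = chunk_size
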
@@ -210,46 +209,47 @@ class ChunkMetadata:
"""All of the metadata used to track a chunk."""

bq_table_name: str # the name of the BigQuery table
bq_table_write_mode: "DataWarehouseTask.BqTableWriteMode"
timestamp: str # when the task was first run
obj_i_start: int # object index span start
obj_i_end: int # object index span end
obj_count_digits: int # number of digits in the object count

def to_blob_name(self):
"""Convert this chunk metadata into a blob name."""

# Left-pad the object indexes with zeros.
obj_i_start_fstr = str(self.obj_i_start).zfill(
self.obj_count_digits
)
obj_i_end_fstr = str(self.obj_i_end).zfill(self.obj_count_digits)

# E.g. "user/2025-01-01_00:00:00__0001_1000.csv"
# E.g. "user__append/2025-01-01_00:00:00__1_1000.csv"
return (
f"{self.bq_table_name}/{self.timestamp}__"
f"{obj_i_start_fstr}_{obj_i_end_fstr}.csv"
f"{self.bq_table_name}__{self.bq_table_write_mode}/"
f"{self.timestamp}__{self.obj_i_start}_{self.obj_i_end}.csv"
)

@classmethod
def from_blob_name(cls, blob_name: str):
"""Extract the chunk metadata from a blob name."""

# E.g. "user/2025-01-01_00:00:00__0001_1000.csv"
# "2025-01-01_00:00:00__0001_1000.csv"
bq_table_name, blob_name = blob_name.split("/", maxsplit=1)
# "2025-01-01_00:00:00__0001_1000"
blob_name = blob_name.removesuffix(".csv")
# "2025-01-01_00:00:00", "0001_1000"
timestamp, obj_i_span_fstr = blob_name.split("__")
# "0001", "1000"
obj_i_start_fstr, obj_i_end_fstr = obj_i_span_fstr.split("_")
# E.g. "user__append/2025-01-01_00:00:00__1_1000.csv"
# "user__append", "2025-01-01_00:00:00__1_1000.csv"
dir_name, file_name = blob_name.split("/")
# "user", "append"
bq_table_name, bq_table_write_mode = dir_name.rsplit(
"__", maxsplit=1
)
assert bq_table_write_mode in ("overwrite", "append")
# "2025-01-01_00:00:00__1_1000"
file_name = file_name.removesuffix(".csv")
# "2025-01-01_00:00:00", "1_1000"
timestamp, obj_i_span = file_name.split("__")
# "1", "1000"
obj_i_start, obj_i_end = obj_i_span.split("_")

return cls(
bq_table_name=bq_table_name,
bq_table_write_mode=t.cast(
DataWarehouseTask.BqTableWriteMode, bq_table_write_mode
),
timestamp=timestamp,
obj_i_start=int(obj_i_start_fstr),
obj_i_end=int(obj_i_end_fstr),
obj_count_digits=len(obj_i_start_fstr),
obj_i_start=int(obj_i_start),
obj_i_end=int(obj_i_end),
)

def _get_gcs_bucket(self):
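
To make the new naming scheme concrete, here is a hedged round-trip sketch using the example values from the comments above (it assumes ChunkMetadata accepts these keyword arguments, as both the hunk and the tests below show):

meta = DataWarehouseTask.ChunkMetadata(
    bq_table_name="user",
    bq_table_write_mode="append",
    timestamp="2025-01-01_00:00:00",
    obj_i_start=1,
    obj_i_end=1000,
)

blob_name = meta.to_blob_name()
# -> "user__append/2025-01-01_00:00:00__1_1000.csv"

parsed = DataWarehouseTask.ChunkMetadata.from_blob_name(blob_name)
assert parsed.bq_table_name == "user"
assert parsed.bq_table_write_mode == "append"
assert (parsed.obj_i_start, parsed.obj_i_end) == (1, 1000)

Note that the object indexes are no longer zero-padded, which is why obj_count_digits disappears from the metadata here and from the task code in the hunks below.
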
@@ -356,9 +356,6 @@ def _save_query_set_as_csvs_in_gcs_bucket(
if obj_count == 0:
return

# Get the number of digits in the object count.
obj_count_digits = len(str(obj_count))

# If the queryset is not ordered, order it by ID by default.
if not queryset.ordered:
queryset = queryset.order_by(self.settings.id_field)
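
The surviving context in this hunk is the default-ordering guard. As a standalone sketch of the same idea (the function name is a placeholder, not code from this module):

from django.db.models import QuerySet

def ensure_ordered(queryset: QuerySet, id_field: str = "id") -> QuerySet:
    # QuerySet.ordered is True when an order_by() or Meta.ordering applies;
    # chunked exports need a stable order, so fall back to ordering by ID.
    return queryset if queryset.ordered else queryset.order_by(id_field)
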
@@ -373,11 +370,17 @@
# The name of the last blob from the current timestamp.
last_blob_name_from_current_timestamp: t.Optional[str] = None

# The name of the directory where the blobs are expected to be located.
blob_dir_name = (
f"{self.settings.bq_table_name}__"
f"{self.settings.bq_table_write_mode}/"
)

# List all the existing blobs.
for blob in t.cast(
t.Iterator[gcs.Blob],
bucket.list_blobs(
prefix=f"{self.settings.bq_table_name}/"
prefix=blob_dir_name
+ (
timestamp
if self.settings.only_list_blobs_from_current_timestamp
@@ -390,28 +393,8 @@
# Check if found first blob from current timestamp.
if (
self.settings.only_list_blobs_from_current_timestamp
or blob_name.startswith(
f"{self.settings.bq_table_name}/{timestamp}"
)
or blob_name.startswith(blob_dir_name + timestamp)
):
chunk_metadata = self.ChunkMetadata.from_blob_name(blob_name)

# If the number of digits in the object count has changed...
if obj_count_digits != chunk_metadata.obj_count_digits:
# ...update the number of digits in the object count...
chunk_metadata.obj_count_digits = obj_count_digits
# ...and update the blob name...
blob_name = chunk_metadata.to_blob_name()
# ...and copy the blob with the updated name...
bucket.copy_blob(
blob=blob,
destination_bucket=bucket,
new_name=blob_name,
)
# ...and delete the old blob.
logging.info('Deleting blob "%s".', blob.name)
blob.delete()

last_blob_name_from_current_timestamp = blob_name
# Check if blobs not from the current timestamp should be deleted.
elif self.settings.delete_blobs_not_from_current_timestamp:
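
For context on the prefix filtering above, a hedged sketch of listing one run's chunks with the google-cloud-storage client (the gcs alias used in the task is assumed to be google.cloud.storage, and the bucket name is hypothetical):

from google.cloud import storage

client = storage.Client()
bucket = client.bucket("example-data-warehouse-bucket")  # hypothetical name

blob_dir_name = "user__append/"  # f"{bq_table_name}__{bq_table_write_mode}/"
timestamp = "2025-01-01_00:00:00"

# With only_list_blobs_from_current_timestamp enabled the prefix includes the
# timestamp, so only the current run's chunks come back; otherwise the whole
# directory is listed, letting the task spot and delete stale-timestamp blobs.
for blob in bucket.list_blobs(prefix=blob_dir_name + timestamp):
    print(blob.name)
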
@@ -454,10 +437,10 @@ def upload_csv(obj_i_end: int):
# Generate the path to the CSV in the bucket.
blob_name = self.ChunkMetadata(
bq_table_name=self.settings.bq_table_name,
bq_table_write_mode=self.settings.bq_table_write_mode,
timestamp=timestamp,
obj_i_start=obj_i_start,
obj_i_end=obj_i_end,
obj_count_digits=obj_count_digits,
).to_blob_name()

# Create a blob object for the CSV file's path and upload it.
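
The upload call itself sits below the fold of this hunk; continuing the listing sketch above, one plausible shape for it (csv_content stands in for the chunk's accumulated CSV text, and the calls are standard google-cloud-storage API, not necessarily the exact code in this file):

blob_name = "user__append/2025-01-01_00:00:00__1_1000.csv"  # from to_blob_name()
csv_content = "id,name\n1,alice\n"  # stand-in for the chunk's CSV text

blob = bucket.blob(blob_name)
blob.upload_from_string(csv_content, content_type="text/csv")
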
53 changes: 12 additions & 41 deletions codeforlife/tasks/data_warehouse_test.py
@@ -60,12 +60,7 @@ def __repr__(self):
@classmethod
# pylint: disable-next=too-many-arguments
def generate_list(
cls,
task: DWT,
timestamp: str,
obj_i_start: int,
obj_i_end: int,
obj_count_digits: int,
cls, task: DWT, timestamp: str, obj_i_start: int, obj_i_end: int
):
"""Generate a list of mock GCS blobs.

@@ -83,12 +78,12 @@ def generate_list(
cls(
chunk_metadata=DWT.ChunkMetadata(
bq_table_name=task.settings.bq_table_name,
bq_table_write_mode=task.settings.bq_table_write_mode,
timestamp=timestamp,
obj_i_start=obj_i_start,
obj_i_end=min(
obj_i_start + task.settings.chunk_size - 1, obj_i_end
),
obj_count_digits=obj_count_digits,
)
)
for obj_i_start in range(
@@ -124,17 +119,14 @@ def setUp(self):
self.datetime = datetime.combine(self.date, self.time)

self.bq_table_name = "example"
self.bq_table_write_mode: DWT.BqTableWriteMode = "append"
self.timestamp = DWT.to_timestamp(self.datetime)
self.obj_i_start = 1
self.obj_i_end = 100
self.obj_count_digits = 4

obj_i_start_fstr = str(self.obj_i_start).zfill(self.obj_count_digits)
obj_i_end_fstr = str(self.obj_i_end).zfill(self.obj_count_digits)

self.blob_name = (
f"{self.bq_table_name}/{self.timestamp}__"
f"{obj_i_start_fstr}_{obj_i_end_fstr}.csv"
f"{self.bq_table_name}__{self.bq_table_write_mode}/"
f"{self.timestamp}__{self.obj_i_start}_{self.obj_i_end}.csv"
)

return super().setUp()
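
With these fixture values, and assuming to_timestamp renders the datetime in the same "2025-01-01_00:00:00" style used in the comments in data_warehouse.py, self.blob_name works out to "example__append/<timestamp>__1_100.csv", the directory-per-table-and-write-mode shape that the ChunkMetadata tests below exercise.
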
@@ -144,7 +136,7 @@ def setUp(self):
def _test_settings(
self,
code: str,
bq_table_write_mode: DWT.Settings.BqTableWriteMode = ("append"),
bq_table_write_mode: DWT.BqTableWriteMode = "append",
chunk_size: int = 10,
fields: t.Optional[t.List[str]] = None,
**kwargs,
@@ -210,21 +202,21 @@ def test_chunk_metadata__to_blob_name(self):
"""Can successfully convert a chunk's metadata into a blob name."""
blob_name = DWT.ChunkMetadata(
bq_table_name=self.bq_table_name,
bq_table_write_mode=self.bq_table_write_mode,
timestamp=self.timestamp,
obj_i_start=self.obj_i_start,
obj_i_end=self.obj_i_end,
obj_count_digits=self.obj_count_digits,
).to_blob_name()
assert blob_name == self.blob_name

def test_chunk_metadata__from_blob_name(self):
"""Can successfully convert a chunk's metadata into a blob name."""
chunk_metadata = DWT.ChunkMetadata.from_blob_name(self.blob_name)
assert chunk_metadata.bq_table_name == self.bq_table_name
assert chunk_metadata.bq_table_write_mode == self.bq_table_write_mode
assert chunk_metadata.timestamp == self.timestamp
assert chunk_metadata.obj_i_start == self.obj_i_start
assert chunk_metadata.obj_i_end == self.obj_i_end
assert chunk_metadata.obj_count_digits == self.obj_count_digits

# Init CSV writer

@@ -319,11 +311,6 @@ def _test_task(
assert uploaded_obj_count <= obj_count
assert (obj_count - uploaded_obj_count) > 0

# Get the object count's current magnitude (number of digits) and
# simulate a higher order of magnitude during the previous run.
obj_count_digits = len(str(obj_count))
uploaded_obj_count_digits = obj_count_digits + 1

# Get the current datetime.
now = datetime.now(timezone.utc)

@@ -335,7 +322,6 @@
timestamp=DWT.to_timestamp(now - since_previous_run),
obj_i_start=1,
obj_i_end=obj_count,
obj_count_digits=obj_count_digits,
)
if since_previous_run is not None
else []
@@ -348,14 +334,12 @@
timestamp=timestamp,
obj_i_start=1,
obj_i_end=uploaded_obj_count,
obj_count_digits=uploaded_obj_count_digits,
)
non_uploaded_blobs_from_current_timestamp = MockGcsBlob.generate_list(
task=task,
timestamp=timestamp,
obj_i_start=uploaded_obj_count + 1,
obj_i_end=obj_count,
obj_count_digits=obj_count_digits,
)

# Generate a mock GCS bucket.
@@ -398,7 +382,10 @@ def _test_task(
# table's write-mode is append, assert only the blobs in the current
# timestamp were listed.
bucket.list_blobs.assert_called_once_with(
prefix=f"{task.settings.bq_table_name}/"
prefix=(
f"{task.settings.bq_table_name}__"
f"{task.settings.bq_table_write_mode}/"
)
+ (
timestamp
if task.settings.only_list_blobs_from_current_timestamp
@@ -421,22 +408,6 @@
]
)

# Assert that the uploaded blobs in the current timestamp were copied
# with the magnitude corrected in their name and the old blobs deleted.
for blob in uploaded_blobs_from_current_timestamp:
blob.chunk_metadata.obj_count_digits = obj_count_digits
blob.delete.assert_called_once()
bucket.copy_blob.assert_has_calls(
[
call(
blob=blob,
destination_bucket=bucket,
new_name=blob.chunk_metadata.to_blob_name(),
)
for blob in uploaded_blobs_from_current_timestamp
]
)

# Assert that each blob was uploaded from a CSV string.
for blob in non_uploaded_blobs_from_current_timestamp:
csv_content, csv_writer = task.init_csv_writer()