Skip to content

Commit

Permalink
S3 archiving fix - post rebase (#373)
Browse files Browse the repository at this point in the history
* Fixed S3 archiving base.py

* Update test_project_client.py

* Update publish.py to work with black formatting 23.3.0
  • Loading branch information
Tejass9922 committed Jul 10, 2023
1 parent 4634856 commit a6ba049
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 23 deletions.
2 changes: 0 additions & 2 deletions rubicon_ml/intake_rubicon/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def publish(


def _update_catalog(base_catalog_filepath, new_experiments, output_filepath=None):

"""Helper function to update existing intake catalog.
Parameters
Expand Down Expand Up @@ -92,7 +91,6 @@ def _update_catalog(base_catalog_filepath, new_experiments, output_filepath=None


def _build_catalog(experiments):

"""Helper function to build catalog dictionary from given experiments.
Parameters
Expand Down
45 changes: 24 additions & 21 deletions rubicon_ml/repository/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
import tempfile
import warnings
from datetime import datetime
from typing import List, Optional
Expand Down Expand Up @@ -296,33 +297,35 @@ def _archive(
-------
filepath of newly created archive
"""
if remote_rubicon_root is not None:
archive_dir = os.path.join(remote_rubicon_root, slugify(project_name), "archives")
else:
archive_dir = os.path.join(self.root_dir, slugify(project_name), "archives")

remote_s3 = True if remote_rubicon_root and remote_rubicon_root.startswith("s3") else False
root_dir = remote_rubicon_root if remote_rubicon_root is not None else self.root_dir
archive_dir = os.path.join(root_dir, slugify(project_name), "archives")
ts = datetime.timestamp(datetime.now())
archive_path = os.path.join(archive_dir, "archive-" + str(ts))
zip_archive_filename = str(archive_path + ".zip")
experiments_path = self._get_experiment_metadata_root(project_name)

if not self._exists(archive_dir):
self._mkdir(archive_dir)
if not remote_s3:
if not self._exists(archive_dir):
self._mkdir(archive_dir)

file_name = None
with tempfile.NamedTemporaryFile() as tf:
if experiments is not None:
with ZipFile(tf, "x") as archive:
experiment_paths = []
for experiment in experiments:
experiment_paths.append(os.path.join(experiments_path, experiment.id))
for file_path in experiment_paths:
archive.write(file_path, os.path.basename(file_path))
file_name = archive.filename

if experiments is None:
shutil.make_archive(archive_path, "zip", experiments_path)
else:
experiment_paths = []
for experiment in experiments:
experiment_paths.append(os.path.join(experiments_path, experiment.id))
with ZipFile(zip_archive_filename, "x") as archive:
for file_path in experiment_paths:
archive.write(file_path, os.path.basename(file_path))

if self._exists(zip_archive_filename):
print("zip archive created")
else:
print("zip archive not created")
else:
file_name = shutil.make_archive(tf.name, "zip", experiments_path)

with fsspec.open(zip_archive_filename, "wb") as fp:
with open(file_name, "rb") as tf:
fp.write(tf.read())

return zip_archive_filename

Expand Down
22 changes: 22 additions & 0 deletions tests/unit/client/test_project_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import warnings
from unittest import mock
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -489,3 +490,24 @@ def test_experiments_from_archive_latest_only():
assert new_num_expsB == 4
rubiconA.repository.filesystem.rm(rubiconA.config.root_dir, recursive=True)
rubiconB.repository.filesystem.rm(rubiconB.config.root_dir, recursive=True)


@patch("fsspec.open")
def test_archive_remote_rubicon_s3(mock_open):
    """Archiving into an S3-rooted remote Rubicon writes the zip via ``fsspec.open``.

    ``fsspec.open`` is patched, so the test only checks that the archive is
    opened exactly once, in binary write mode, at the path that
    ``project.archive`` returns.

    Parameters
    ----------
    mock_open : unittest.mock.MagicMock
        Stand-in for ``fsspec.open`` injected by the ``patch`` decorator.
    """
    local_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "rubiconA")
    rubicon_a = Rubicon(persistence="filesystem", root_dir=local_root)

    s3_repo = "s3://bucket/root/path/to/data"
    rubicon_b = Rubicon(persistence="filesystem", root_dir=s3_repo)

    try:
        project_a = rubicon_a.get_or_create_project("ArchiveTesting")
        project_a.log_experiment(name="experiment1")
        project_a.log_experiment(name="experiment2")

        zip_archive_filename = project_a.archive(remote_rubicon=rubicon_b)

        mock_open.assert_called_once_with(zip_archive_filename, "wb")
    finally:
        # Always remove the on-disk repository: other tests in this module use
        # the same "rubiconA" directory, so a failed assertion must not leave
        # stale projects or experiments behind.
        rubicon_a.repository.filesystem.rm(rubicon_a.config.root_dir, recursive=True)

0 comments on commit a6ba049

Please sign in to comment.