From 10f667f0b4f8c4092e265944d3d486f083762cb1 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Tue, 30 Jun 2020 09:48:31 +0300 Subject: [PATCH] tests: remotes: use TmpDir-like fixtures --- setup.cfg | 2 +- tests/func/test_add.py | 106 +++++- tests/func/test_api.py | 34 +- tests/func/test_gc.py | 42 +++ tests/func/test_import_url.py | 53 +++ tests/func/test_repro.py | 497 -------------------------- tests/func/test_repro_multistage.py | 36 -- tests/func/test_run_multistage.py | 86 +++++ tests/func/test_update.py | 26 +- tests/remotes/__init__.py | 45 ++- tests/remotes/azure.py | 7 +- tests/remotes/base.py | 60 +++- tests/remotes/gdrive.py | 14 +- tests/remotes/gs.py | 68 +++- tests/remotes/hdfs.py | 63 +++- tests/remotes/http.py | 41 ++- tests/remotes/local.py | 9 +- tests/remotes/oss.py | 6 +- tests/remotes/s3.py | 99 +++-- tests/remotes/ssh.py | 74 +++- tests/unit/remote/test_http.py | 14 +- tests/unit/remote/test_remote_tree.py | 16 +- tests/unit/utils/test_http.py | 19 +- tests/utils/httpd.py | 69 ++-- 24 files changed, 792 insertions(+), 694 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6e9dd89a42..f9dc5b216d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ count=true [isort] include_trailing_comma=true known_first_party=dvc,tests -known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,google,grandalf,mock,moto,nanotime,networkx,packaging,paramiko,pathspec,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,tqdm,voluptuous,yaml,zc +known_third_party=PyInstaller,RangeHTTPServer,boto3,colorama,configobj,distro,dpath,flaky,flufl,funcy,git,grandalf,mock,moto,nanotime,networkx,packaging,pathspec,pylint,pytest,requests,ruamel,setuptools,shortuuid,shtab,tqdm,voluptuous,yaml,zc line_length=79 force_grid_wrap=0 use_parentheses=True diff --git a/tests/func/test_add.py b/tests/func/test_add.py index c681b61297..8babe59764 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -185,27 
+185,97 @@ def test_add_file_in_dir(tmp_dir, dvc): assert stage.outs[0].def_path == "subdata" -class TestAddExternalLocalFile(TestDvc): - def test(self): - from dvc.stage.exceptions import StageExternalOutputsError +@pytest.mark.parametrize( + "workspace, hash_name, hash_value", + [ + ( + pytest.lazy_fixture("local_cloud"), + "md5", + "8c7dd922ad47494fc02c388e12c00eac", + ), + pytest.param( + pytest.lazy_fixture("ssh"), + "md5", + "8c7dd922ad47494fc02c388e12c00eac", + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ( + pytest.lazy_fixture("s3"), + "etag", + "8c7dd922ad47494fc02c388e12c00eac", + ), + (pytest.lazy_fixture("gs"), "md5", "8c7dd922ad47494fc02c388e12c00eac"), + ( + pytest.lazy_fixture("hdfs"), + "checksum", + "000002000000000000000000a86fe4d846edc1bf4c355cb6112f141e", + ), + ], + indirect=["workspace"], +) +def test_add_external_file(tmp_dir, dvc, workspace, hash_name, hash_value): + from dvc.stage.exceptions import StageExternalOutputsError - dname = TestDvc.mkdtemp() - fname = os.path.join(dname, "foo") - shutil.copyfile(self.FOO, fname) + workspace.gen("file", "file") - with self.assertRaises(StageExternalOutputsError): - self.dvc.add(fname) + with pytest.raises(StageExternalOutputsError): + dvc.add(workspace.url) - stages = self.dvc.add(fname, external=True) - self.assertEqual(len(stages), 1) - stage = stages[0] - self.assertNotEqual(stage, None) - self.assertEqual(len(stage.deps), 0) - self.assertEqual(len(stage.outs), 1) - self.assertEqual(stage.relpath, "foo.dvc") - self.assertEqual(len(os.listdir(dname)), 1) - self.assertTrue(os.path.isfile(fname)) - self.assertTrue(filecmp.cmp(fname, "foo", shallow=False)) + dvc.add("remote://workspace/file") + assert (tmp_dir / "file.dvc").read_text() == ( + "outs:\n" + f"- {hash_name}: {hash_value}\n" + " path: remote://workspace/file\n" + ) + assert (workspace / "file").read_text() == "file" + assert ( + workspace / "cache" / hash_value[:2] / hash_value[2:] + 
).read_text() == "file" + + assert dvc.status() == {} + + +@pytest.mark.parametrize( + "workspace, hash_name, hash_value", + [ + ( + pytest.lazy_fixture("local_cloud"), + "md5", + "b6dcab6ccd17ca0a8bf4a215a37d14cc.dir", + ), + pytest.param( + pytest.lazy_fixture("ssh"), + "md5", + "b6dcab6ccd17ca0a8bf4a215a37d14cc.dir", + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ( + pytest.lazy_fixture("s3"), + "etag", + "ec602a6ba97b2dd07bd6d2cd89674a60.dir", + ), + ( + pytest.lazy_fixture("gs"), + "md5", + "b6dcab6ccd17ca0a8bf4a215a37d14cc.dir", + ), + ], + indirect=["workspace"], +) +def test_add_external_dir(tmp_dir, dvc, workspace, hash_name, hash_value): + workspace.gen({"dir": {"file": "file", "subdir": {"subfile": "subfile"}}}) + + dvc.add("remote://workspace/dir") + assert (tmp_dir / "dir.dvc").read_text() == ( + "outs:\n" + f"- {hash_name}: {hash_value}\n" + " path: remote://workspace/dir\n" + ) + assert (workspace / "cache" / hash_value[:2] / hash_value[2:]).is_file() class TestAddLocalRemoteFile(TestDvc): diff --git a/tests/func/test_api.py b/tests/func/test_api.py index d1b67f58e2..d7eb49255b 100644 --- a/tests/func/test_api.py +++ b/tests/func/test_api.py @@ -68,7 +68,22 @@ def test_open(tmp_dir, dvc, remote): assert fd.read() == "foo-text" -@pytest.mark.parametrize("cloud", clouds) +@pytest.mark.parametrize( + "cloud", + [ + pytest.lazy_fixture(cloud) + for cloud in [ + "real_s3", # NOTE: moto's s3 fails in some tests + "gs", + "azure", + "gdrive", + "oss", + "ssh", + "hdfs", + "http", + ] + ], +) def test_open_external(erepo_dir, cloud): erepo_dir.add_remote(config=cloud.config) @@ -104,7 +119,22 @@ def test_open_granular(tmp_dir, dvc, remote): assert fd.read() == "foo-text" -@pytest.mark.parametrize("remote", all_remotes) +@pytest.mark.parametrize( + "remote", + [ + pytest.lazy_fixture(f"{cloud}_remote") + for cloud in [ + "real_s3", # NOTE: moto's s3 fails in some tests + "gs", + "azure", + "gdrive", + "oss", + "ssh", + 
"hdfs", + "http", + ] + ], +) def test_missing(tmp_dir, dvc, remote): tmp_dir.dvc_gen("foo", "foo") diff --git a/tests/func/test_gc.py b/tests/func/test_gc.py index ce4b0d519a..1723beaadc 100644 --- a/tests/func/test_gc.py +++ b/tests/func/test_gc.py @@ -348,3 +348,45 @@ def test_gc_not_collect_pipeline_tracked_files(tmp_dir, dvc, run_copy): Dvcfile(dvc, PIPELINE_FILE).remove(force=True) dvc.gc(workspace=True, force=True) assert _count_files(dvc.cache.local.cache_dir) == 0 + + +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + pytest.lazy_fixture("hdfs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ], + indirect=True, +) +def test_gc_external_output(tmp_dir, dvc, workspace): + workspace.gen({"foo": "foo", "bar": "bar"}) + + (foo_stage,) = dvc.add("remote://workspace/foo") + (bar_stage,) = dvc.add("remote://workspace/bar") + + foo_hash = foo_stage.outs[0].checksum + bar_hash = bar_stage.outs[0].checksum + + assert ( + workspace / "cache" / foo_hash[:2] / foo_hash[2:] + ).read_text() == "foo" + assert ( + workspace / "cache" / bar_hash[:2] / bar_hash[2:] + ).read_text() == "bar" + + (tmp_dir / "foo.dvc").unlink() + + dvc.gc(workspace=True) + + assert not (workspace / "cache" / foo_hash[:2] / foo_hash[2:]).exists() + assert ( + workspace / "cache" / bar_hash[:2] / bar_hash[2:] + ).read_text() == "bar" diff --git a/tests/func/test_import_url.py b/tests/func/test_import_url.py index c5b8db8f9f..27ad5ae8a5 100644 --- a/tests/func/test_import_url.py +++ b/tests/func/test_import_url.py @@ -112,3 +112,56 @@ def test_import_url_with_no_exec(tmp_dir, dvc, erepo_dir): dvc.imp_url(src, ".", no_exec=True) dst = tmp_dir / "file" assert not dst.exists() + + +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + 
pytest.lazy_fixture("hdfs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + pytest.lazy_fixture("http"), + ], + indirect=True, +) +def test_import_url(tmp_dir, dvc, workspace): + workspace.gen("file", "file") + assert not (tmp_dir / "file").exists() # sanity check + dvc.imp_url("remote://workspace/file") + assert (tmp_dir / "file").read_text() == "file" + + assert dvc.status() == {} + + +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ], + indirect=True, +) +def test_import_url_dir(tmp_dir, dvc, workspace): + workspace.gen({"dir": {"file": "file", "subdir": {"subfile": "subfile"}}}) + assert not (tmp_dir / "dir").exists() # sanity check + dvc.imp_url("remote://workspace/dir") + assert set(os.listdir(tmp_dir / "dir")) == {"file", "subdir"} + assert (tmp_dir / "dir" / "file").read_text() == "file" + assert list(os.listdir(tmp_dir / "dir" / "subdir")) == ["subfile"] + assert (tmp_dir / "dir" / "subdir" / "subfile").read_text() == "subfile" + + assert dvc.status() == {} diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py index 9b77a0f727..dd052528b6 100644 --- a/tests/func/test_repro.py +++ b/tests/func/test_repro.py @@ -1,20 +1,10 @@ import filecmp -import getpass import os -import posixpath import re import shutil -import uuid from pathlib import Path -from subprocess import PIPE, Popen -from unittest import SkipTest -from urllib.parse import urljoin -import boto3 -import paramiko import pytest -from flaky.flaky_decorator import flaky -from google.cloud import storage as gc from mock import patch from dvc.dvcfile import DVC_FILE, Dvcfile @@ -25,9 +15,7 @@ ) from dvc.main import main from dvc.output.base import BaseOutput -from dvc.path_info 
import URLInfo from dvc.remote.local import LocalRemoteTree -from dvc.repo import Repo as DvcRepo from dvc.stage import Stage from dvc.stage.exceptions import StageFileDoesNotExistError from dvc.system import System @@ -35,17 +23,6 @@ from dvc.utils.fs import remove from dvc.utils.yaml import dump_yaml, load_yaml from tests.basic_env import TestDvc -from tests.remotes import ( - GCP, - HDFS, - S3, - SSH, - TEST_AWS_REPO_BUCKET, - TEST_GCP_REPO_BUCKET, - Local, - SSHMocked, -) -from tests.utils.httpd import ContentMD5Handler, StaticFileServer class SingleStageRun: @@ -841,430 +818,6 @@ def test(self): self.assertTrue(filecmp.cmp(foo, bar, shallow=False)) -class ReproExternalTestMixin(SingleStageRun, TestDvc): - cache_type = None - - @staticmethod - def should_test(): - return False - - @property - def cache_scheme(self): - return self.scheme - - @property - def scheme(self): - return None - - @property - def scheme_sep(self): - return "://" - - @property - def sep(self): - return "/" - - def check_already_cached(self, stage): - stage.outs[0].remove() - - patch_download = patch.object( - stage.deps[0], "download", wraps=stage.deps[0].download - ) - - patch_checkout = patch.object( - stage.outs[0], "checkout", wraps=stage.outs[0].checkout - ) - - from dvc.stage.run import cmd_run - - patch_run = patch("dvc.stage.run.cmd_run", wraps=cmd_run) - - with self.dvc.lock, self.dvc.state: - with patch_download as mock_download: - with patch_checkout as mock_checkout: - with patch_run as mock_run: - stage.frozen = False - stage.run() - stage.frozen = True - - mock_run.assert_not_called() - mock_download.assert_not_called() - mock_checkout.assert_called_once() - - @patch("dvc.prompt.confirm", return_value=True) - def test(self, _mock_prompt): - if not self.should_test(): - raise SkipTest(f"Test {self.__class__.__name__} is disabled") - - cache = ( - self.scheme - + self.scheme_sep - + self.bucket - + self.sep - + str(uuid.uuid4()) - ) - - ret = main(["config", "cache." 
+ self.cache_scheme, "myrepo"]) - self.assertEqual(ret, 0) - ret = main(["remote", "add", "myrepo", cache]) - self.assertEqual(ret, 0) - if self.cache_type: - ret = main(["remote", "modify", "myrepo", "type", self.cache_type]) - self.assertEqual(ret, 0) - - remote_name = "myremote" - remote_key = str(uuid.uuid4()) - remote = ( - self.scheme + self.scheme_sep + self.bucket + self.sep + remote_key - ) - - ret = main(["remote", "add", remote_name, remote]) - self.assertEqual(ret, 0) - if self.cache_type: - ret = main( - ["remote", "modify", remote_name, "type", self.cache_type] - ) - self.assertEqual(ret, 0) - - self.dvc = DvcRepo(".") - - foo_key = remote_key + self.sep + self.FOO - bar_key = remote_key + self.sep + self.BAR - - foo_path = ( - self.scheme + self.scheme_sep + self.bucket + self.sep + foo_key - ) - bar_path = ( - self.scheme + self.scheme_sep + self.bucket + self.sep + bar_key - ) - - # Using both plain and remote notation - out_foo_path = "remote://" + remote_name + "/" + self.FOO - out_bar_path = bar_path - - self.write(self.bucket, foo_key, self.FOO_CONTENTS) - - import_stage = self.dvc.imp_url(out_foo_path, "import") - - self.assertTrue(os.path.exists("import")) - self.assertTrue(filecmp.cmp("import", self.FOO, shallow=False)) - self.assertEqual(self.dvc.status([import_stage.path]), {}) - self.check_already_cached(import_stage) - - import_remote_stage = self.dvc.imp_url( - out_foo_path, out_foo_path + "_imported" - ) - self.assertEqual(self.dvc.status([import_remote_stage.path]), {}) - - cmd_stage = self._run( - outs=[out_bar_path], - deps=[out_foo_path], - cmd=self.cmd(foo_path, bar_path), - name="external-base", - external=True, - ) - - self.assertEqual(self.dvc.status([cmd_stage.addressing]), {}) - self.assertEqual(self.dvc.status(), {}) - self.check_already_cached(cmd_stage) - - self.write(self.bucket, foo_key, self.BAR_CONTENTS) - - self.assertNotEqual(self.dvc.status(), {}) - - self.dvc.update([import_stage.path]) - 
self.assertTrue(os.path.exists("import")) - self.assertTrue(filecmp.cmp("import", self.BAR, shallow=False)) - self.assertEqual(self.dvc.status([import_stage.path]), {}) - - self.dvc.update([import_remote_stage.path]) - self.assertEqual(self.dvc.status([import_remote_stage.path]), {}) - - stages = self.dvc.reproduce(cmd_stage.addressing) - self.assertEqual(len(stages), 1) - self.assertEqual(self.dvc.status([cmd_stage.addressing]), {}) - - self.assertEqual(self.dvc.status(), {}) - self.dvc.gc(workspace=True) - self.assertEqual(self.dvc.status(), {}) - - with self.dvc.lock: - cmd_stage.remove_outs(force=True) - self.assertNotEqual(self.dvc.status([cmd_stage.addressing]), {}) - - self.dvc.checkout([cmd_stage.path], force=True) - self.assertEqual(self.dvc.status([cmd_stage.addressing]), {}) - - -@pytest.mark.skipif(os.name == "nt", reason="temporarily disabled on windows") -@flaky(max_runs=3, min_passes=1) -class TestReproExternalS3(S3, ReproExternalTestMixin): - @property - def scheme(self): - return "s3" - - @property - def bucket(self): - return TEST_AWS_REPO_BUCKET - - def cmd(self, i, o): - return f"aws s3 cp {i} {o}" - - def write(self, bucket, key, body): - s3 = boto3.client("s3") - s3.put_object(Bucket=bucket, Key=key, Body=body) - - -class TestReproExternalGS(GCP, ReproExternalTestMixin): - @property - def scheme(self): - return "gs" - - @property - def bucket(self): - return TEST_GCP_REPO_BUCKET - - def cmd(self, i, o): - return f"gsutil cp {i} {o}" - - def write(self, bucket, key, body): - client = gc.Client() - bucket = client.bucket(bucket) - bucket.blob(key).upload_from_string(body) - - -class TestReproExternalHDFS(HDFS, ReproExternalTestMixin): - @property - def scheme(self): - return "hdfs" - - @property - def bucket(self): - return f"{getpass.getuser()}@127.0.0.1" - - def cmd(self, i, o): - return f"hadoop fs -cp {i} {o}" - - def write(self, bucket, key, body): - url = self.scheme + "://" + bucket + "/" + key - p = Popen( - f"hadoop fs -rm -f {url}", - 
shell=True, - executable=os.getenv("SHELL"), - stdin=PIPE, - stdout=PIPE, - stderr=PIPE, - ) - p.communicate() - - p = Popen( - "hadoop fs -mkdir -p {}".format(posixpath.dirname(url)), - shell=True, - executable=os.getenv("SHELL"), - stdin=PIPE, - stdout=PIPE, - stderr=PIPE, - ) - out, err = p.communicate() - if p.returncode != 0: - print(out) - print(err) - self.assertEqual(p.returncode, 0) - - with open("tmp", "w+") as fd: - fd.write(body) - - p = Popen( - "hadoop fs -copyFromLocal {} {}".format("tmp", url), - shell=True, - executable=os.getenv("SHELL"), - stdin=PIPE, - stdout=PIPE, - stderr=PIPE, - ) - out, err = p.communicate() - if p.returncode != 0: - print(out) - print(err) - self.assertEqual(p.returncode, 0) - - -@flaky(max_runs=5, min_passes=1) -class TestReproExternalSSH(SSH, ReproExternalTestMixin): - _dir = None - cache_type = "copy" - - @property - def scheme(self): - return "ssh" - - @property - def bucket(self): - if not self._dir: - self._dir = self.mkdtemp() - return f"{getpass.getuser()}@127.0.0.1:{self._dir}" - - def cmd(self, i, o): - prefix = "ssh://" - assert i.startswith(prefix) and o.startswith(prefix) - i = i[len(prefix) :] - o = o[len(prefix) :] - return f"scp {i} {o}" - - def write(self, _, key, body): - path = posixpath.join(self._dir, key) - - ssh = None - sftp = None - try: - ssh = paramiko.SSHClient() - ssh.load_system_host_keys() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect("127.0.0.1") - - sftp = ssh.open_sftp() - try: - sftp.stat(path) - sftp.remove(path) - except OSError: - pass - - _, stdout, _ = ssh.exec_command(f"mkdir -p $(dirname {path})") - self.assertEqual(stdout.channel.recv_exit_status(), 0) - - with sftp.open(path, "w+") as fobj: - fobj.write(body) - finally: - if sftp: - sftp.close() - if ssh: - ssh.close() - - -class TestReproExternalLOCAL(Local, ReproExternalTestMixin): - cache_type = "hardlink" - - def setUp(self): - super().setUp() - self.tmpdir = self.mkdtemp() - ret = main(["config", 
"cache.type", "hardlink"]) - self.assertEqual(ret, 0) - self.dvc = DvcRepo(".") - - @property - def cache_scheme(self): - return "local" - - @property - def scheme(self): - return "" - - @property - def scheme_sep(self): - return "" - - @property - def sep(self): - return os.sep - - @property - def bucket(self): - return self.tmpdir - - def cmd(self, i, o): - if os.name == "nt": - return f"copy {i} {o}" - return f"cp {i} {o}" - - def write(self, bucket, key, body): - path = os.path.join(bucket, key) - dname = os.path.dirname(path) - - if not os.path.exists(dname): - os.makedirs(dname) - - with open(path, "w+") as fd: - fd.write(body) - - -class TestReproExternalHTTP(ReproExternalTestMixin): - _external_cache_id = None - - @staticmethod - def get_remote(port): - return f"http://localhost:{port}/" - - @property - def local_cache(self): - return os.path.join(self.dvc.dvc_dir, "cache") - - def test(self): # pylint: disable=arguments-differ - # Import - with StaticFileServer() as httpd: - import_url = urljoin(self.get_remote(httpd.server_port), self.FOO) - import_output = "imported_file" - import_stage = self.dvc.imp_url(import_url, import_output) - - self.assertTrue(os.path.exists(import_output)) - self.assertTrue(filecmp.cmp(import_output, self.FOO, shallow=False)) - - self.dvc.remove("imported_file.dvc", outs=True) - - with StaticFileServer(handler_class=ContentMD5Handler) as httpd: - import_url = urljoin(self.get_remote(httpd.server_port), self.FOO) - import_output = "imported_file" - import_stage = self.dvc.imp_url(import_url, import_output) - assert import_stage.repo == self.dvc - - self.assertTrue(os.path.exists(import_output)) - self.assertTrue(filecmp.cmp(import_output, self.FOO, shallow=False)) - - # Run --deps - with StaticFileServer() as httpd: - remote = self.get_remote(httpd.server_port) - - cache_id = str(uuid.uuid4()) - cache = urljoin(remote, cache_id) - - ret1 = main(["remote", "add", "mycache", cache]) - ret2 = main(["remote", "add", "myremote", 
remote]) - self.assertEqual(ret1, 0) - self.assertEqual(ret2, 0) - - self.dvc = import_stage.repo = DvcRepo(".") - - run_dependency = urljoin(remote, self.BAR) - run_output = "remote_file" - cmd = f'open("{run_output}", "w+")' - - with open("create-output.py", "w") as fd: - fd.write(cmd) - - run_stage = self._run( - deps=[run_dependency], - outs=[run_output], - cmd="python create-output.py", - name="http_run", - ) - self.assertTrue(run_stage is not None) - - self.assertTrue(os.path.exists(run_output)) - - # Pull - with self.dvc.lock: - self.assertEqual(import_stage.repo.lock.is_locked, True) - self.assertEqual(self.dvc.lock.is_locked, True) - import_stage.remove_outs(force=True) - self.assertFalse(os.path.exists(import_output)) - - shutil.move(self.local_cache, cache_id) - self.assertFalse(os.path.exists(self.local_cache)) - - self.dvc.pull([import_stage.path], remote="mycache") - - self.assertTrue(os.path.exists(import_output)) - - class TestReproShell(TestDvc): def test(self): if os.name == "nt": @@ -1728,56 +1281,6 @@ def test_downstream(dvc): assert evaluation[4].relpath == "E.dvc" -@pytest.mark.skipif( - os.name == "nt", - reason="external output scenario is not supported on Windows", -) -def test_ssh_dir_out(tmp_dir, dvc, ssh_server): - from tests.remotes.ssh import TEST_SSH_USER, TEST_SSH_KEY_PATH - - tmp_dir.gen({"foo": "foo content"}) - - # Set up remote and cache - user = TEST_SSH_USER - port = ssh_server.port - keyfile = TEST_SSH_KEY_PATH - - remote_url = SSHMocked.get_url(user, port) - assert main(["remote", "add", "upstream", remote_url]) == 0 - assert main(["remote", "modify", "upstream", "keyfile", keyfile]) == 0 - - cache_url = SSHMocked.get_url(user, port) - assert main(["remote", "add", "sshcache", cache_url]) == 0 - assert main(["config", "cache.ssh", "sshcache"]) == 0 - assert main(["remote", "modify", "sshcache", "keyfile", keyfile]) == 0 - - # Recreating to reread configs - repo = DvcRepo(dvc.root_dir) - - # To avoid "WARNING: UNPROTECTED 
PRIVATE KEY FILE" from ssh - os.chmod(keyfile, 0o600) - - (tmp_dir / "script.py").write_text( - "import sys, pathlib\n" - "path = pathlib.Path(sys.argv[1])\n" - "dir_out = path / 'dir-out'\n" - "dir_out.mkdir()\n" - "(dir_out / '1.txt').write_text('1')\n" - "(dir_out / '2.txt').write_text('2')\n" - ) - - url_info = URLInfo(remote_url) - repo.run( - cmd="python {} {}".format(tmp_dir / "script.py", url_info.path), - single_stage=True, - outs=["remote://upstream/dir-out"], - deps=["foo"], # add a fake dep to not consider this a callback - ) - - repo.reproduce("dir-out.dvc") - repo.reproduce("dir-out.dvc", force=True) - - def test_repro_when_cmd_changes(tmp_dir, dvc, run_copy, mocker): from dvc.dvcfile import SingleStageFile diff --git a/tests/func/test_repro_multistage.py b/tests/func/test_repro_multistage.py index 995cc409df..1dd363a7dc 100644 --- a/tests/func/test_repro_multistage.py +++ b/tests/func/test_repro_multistage.py @@ -159,42 +159,6 @@ class TestReproChangedDirDataMultiStage( pass -class TestReproExternalS3MultiStage( - MultiStageRun, test_repro.TestReproExternalS3 -): - pass - - -class TestReproExternalGSMultiStage( - MultiStageRun, test_repro.TestReproExternalGS -): - pass - - -class TestReproExternalHDFSMultiStage( - MultiStageRun, test_repro.TestReproExternalHDFS -): - pass - - -class TestReproExternalHTTPMultiStage( - MultiStageRun, test_repro.TestReproExternalHTTP -): - pass - - -class TestReproExternalLOCALMultiStage( - MultiStageRun, test_repro.TestReproExternalLOCAL -): - pass - - -class TestReproExternalSSHMultiStage( - MultiStageRun, test_repro.TestReproExternalSSH -): - pass - - def test_non_existing_stage_name(tmp_dir, dvc, run_copy): from dvc.exceptions import DvcException diff --git a/tests/func/test_run_multistage.py b/tests/func/test_run_multistage.py index c300f1ceef..daf9c0c2ee 100644 --- a/tests/func/test_run_multistage.py +++ b/tests/func/test_run_multistage.py @@ -361,3 +361,89 @@ def 
test_run_overwrite_preserves_meta_and_comment(tmp_dir, dvc, run_copy): assert (tmp_dir / PIPELINE_FILE).read_text() == text.format( src="foo1", dest="bar1" ) + + +@pytest.mark.parametrize( + "workspace, hash_name, foo_hash, bar_hash", + [ + ( + pytest.lazy_fixture("local_cloud"), + "md5", + "acbd18db4cc2f85cedef654fccc4a4d8", + "37b51d194a7513e45b56f6524f2d51f2", + ), + pytest.param( + pytest.lazy_fixture("ssh"), + "md5", + "acbd18db4cc2f85cedef654fccc4a4d8", + "37b51d194a7513e45b56f6524f2d51f2", + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ( + pytest.lazy_fixture("s3"), + "etag", + "acbd18db4cc2f85cedef654fccc4a4d8", + "37b51d194a7513e45b56f6524f2d51f2", + ), + ( + pytest.lazy_fixture("gs"), + "md5", + "acbd18db4cc2f85cedef654fccc4a4d8", + "37b51d194a7513e45b56f6524f2d51f2", + ), + ( + pytest.lazy_fixture("hdfs"), + "checksum", + "0000020000000000000000003dba826b9be9c6a8e2f8310a770555c4", + "00000200000000000000000075433c81259d3c38e364b348af52e84d", + ), + ], + indirect=["workspace"], +) +def test_run_external_outputs( + tmp_dir, dvc, workspace, hash_name, foo_hash, bar_hash +): + workspace.gen("foo", "foo") + dvc.run( + name="mystage", + cmd="mycmd", + deps=["remote://workspace/foo"], + outs=["remote://workspace/bar"], + no_exec=True, + ) + + dvc_yaml = ( + "stages:\n" + " mystage:\n" + " cmd: mycmd\n" + " deps:\n" + " - remote://workspace/foo\n" + " outs:\n" + " - remote://workspace/bar\n" + ) + + assert (tmp_dir / "dvc.yaml").read_text() == dvc_yaml + assert not (tmp_dir / "dvc.lock").exists() + + workspace.gen("bar", "bar") + dvc.commit("dvc.yaml", force=True) + + assert (tmp_dir / "dvc.yaml").read_text() == dvc_yaml + assert (tmp_dir / "dvc.lock").read_text() == ( + "mystage:\n" + " cmd: mycmd\n" + " deps:\n" + " - path: remote://workspace/foo\n" + f" {hash_name}: {foo_hash}\n" + " outs:\n" + " - path: remote://workspace/bar\n" + f" {hash_name}: {bar_hash}\n" + ) + + assert (workspace / "foo").read_text() == "foo" + 
assert (workspace / "bar").read_text() == "bar" + assert ( + workspace / "cache" / bar_hash[:2] / bar_hash[2:] + ).read_text() == "bar" diff --git a/tests/func/test_update.py b/tests/func/test_update.py index b5bc9d14ed..b1898c4582 100644 --- a/tests/func/test_update.py +++ b/tests/func/test_update.py @@ -111,19 +111,33 @@ def test_update_before_and_after_dvc_init(tmp_dir, dvc, git_dir): assert dvc.status([stage.path]) == {} -def test_update_import_url(tmp_dir, dvc, tmp_path_factory): - import_src = tmp_path_factory.mktemp("import_url_source") - src = import_src / "file" - src.write_text("file content") +@pytest.mark.parametrize( + "workspace", + [ + pytest.lazy_fixture("local_cloud"), + pytest.lazy_fixture("s3"), + pytest.lazy_fixture("gs"), + pytest.lazy_fixture("hdfs"), + pytest.param( + pytest.lazy_fixture("ssh"), + marks=pytest.mark.skipif( + os.name == "nt", reason="disabled on windows" + ), + ), + ], + indirect=True, +) +def test_update_import_url(tmp_dir, dvc, workspace): + workspace.gen("file", "file content") dst = tmp_dir / "imported_file" - stage = dvc.imp_url(os.fspath(src), os.fspath(dst)) + stage = dvc.imp_url("remote://workspace/file", os.fspath(dst)) assert dst.is_file() assert dst.read_text() == "file content" # update data - src.write_text("updated file content") + workspace.gen("file", "updated file content") assert dvc.status([stage.path]) == {} dvc.update([stage.path]) diff --git a/tests/remotes/__init__.py b/tests/remotes/__init__.py index 6b23448e1e..36ea3e119f 100644 --- a/tests/remotes/__init__.py +++ b/tests/remotes/__init__.py @@ -1,3 +1,19 @@ +import pytest + +from .azure import Azure, azure, azure_remote # noqa: F401 +from .hdfs import HDFS, hdfs, hdfs_remote # noqa: F401 +from .http import HTTP, http, http_remote, http_server # noqa: F401 +from .local import Local, local_cloud, local_remote # noqa: F401 +from .oss import OSS, TEST_OSS_REPO_BUCKET, oss, oss_remote # noqa: F401 +from .s3 import ( # noqa: F401 + S3, + 
TEST_AWS_REPO_BUCKET, + real_s3, + real_s3_remote, + s3, + s3_remote, +) + TEST_REMOTE = "upstream" TEST_CONFIG = { "cache": {}, @@ -5,12 +21,6 @@ "remote": {TEST_REMOTE: {"url": ""}}, } -from .azure import Azure, azure, azure_remote # noqa: F401 -from .hdfs import HDFS, hdfs, hdfs_remote # noqa: F401 -from .http import HTTP, http, http_remote, http_server # noqa: F401 -from .local import Local, local_cloud, local_remote # noqa: F401 -from .oss import OSS, TEST_OSS_REPO_BUCKET, oss, oss_remote # noqa: F401 -from .s3 import S3, TEST_AWS_REPO_BUCKET, S3Mocked, s3, s3_remote # noqa: F401 from .gdrive import ( # noqa: F401; noqa: F401 TEST_GDRIVE_REPO_BUCKET, @@ -33,3 +43,26 @@ ssh_remote, ssh_server, ) + + +@pytest.fixture +def workspace(tmp_dir, dvc, request): + from dvc.cache import Cache + + cloud = request.param + + assert cloud + + tmp_dir.add_remote(name="workspace", config=cloud.config, default=False) + tmp_dir.add_remote( + name="cache", url="remote://workspace/cache", default=False + ) + + scheme = getattr(cloud, "scheme", "local") + if scheme != "http": + with dvc.config.edit() as conf: + conf["cache"][scheme] = "cache" + + dvc.cache = Cache(dvc) + + return cloud diff --git a/tests/remotes/azure.py b/tests/remotes/azure.py index 2f15ca79a6..2549232460 100644 --- a/tests/remotes/azure.py +++ b/tests/remotes/azure.py @@ -1,14 +1,17 @@ +# pylint:disable=abstract-method + import os import uuid import pytest +from dvc.path_info import CloudURLInfo from dvc.utils import env2bool from .base import Base -class Azure(Base): +class Azure(Base, CloudURLInfo): @staticmethod def should_test(): do_test = env2bool("DVC_TEST_AZURE", undefined=None) @@ -30,7 +33,7 @@ def get_url(): def azure(): if not Azure.should_test(): pytest.skip("no azure running") - yield Azure() + yield Azure(Azure.get_url()) @pytest.fixture diff --git a/tests/remotes/base.py b/tests/remotes/base.py index 50904c7e5b..71b0fad67c 100644 --- a/tests/remotes/base.py +++ b/tests/remotes/base.py @@ -1,7 
+1,61 @@ +import pathlib + from funcy import cached_property +from dvc.path_info import URLInfo + + +class Base(URLInfo): + def is_file(self): + raise NotImplementedError + + def is_dir(self): + raise NotImplementedError + + def exists(self): + raise NotImplementedError + + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + raise NotImplementedError + + def write_text(self, contents, encoding=None, errors=None): + raise NotImplementedError + + def write_bytes(self, contents): + raise NotImplementedError + + def read_text(self, encoding=None, errors=None): + raise NotImplementedError + + def read_bytes(self): + raise NotImplementedError + + def _gen(self, struct, prefix=None): + for name, contents in struct.items(): + path = (prefix or self) / name + + if isinstance(contents, dict): + if not contents: + path.mkdir(parents=True) + else: + self._gen(contents, prefix=path) + else: + path.parent.mkdir(parents=True) + if isinstance(contents, bytes): + path.write_bytes(contents) + else: + path.write_text(contents, encoding="utf-8") + + def gen(self, struct, text=""): + if isinstance(struct, (str, bytes, pathlib.PurePath)): + struct = {struct: text} + + self._gen(struct) + return struct.keys() + + def close(self): + pass -class Base: @staticmethod def should_test(): return True @@ -10,10 +64,6 @@ def should_test(): def get_url(): raise NotImplementedError - @cached_property - def url(self): - return self.get_url() - @cached_property def config(self): return {"url": self.url} diff --git a/tests/remotes/gdrive.py b/tests/remotes/gdrive.py index c8dc7fe24e..a8c6c77172 100644 --- a/tests/remotes/gdrive.py +++ b/tests/remotes/gdrive.py @@ -1,9 +1,11 @@ +# pylint:disable=abstract-method import os import uuid import pytest from funcy import cached_property +from dvc.path_info import CloudURLInfo from dvc.remote.gdrive import GDriveRemoteTree from .base import Base @@ -11,7 +13,7 @@ TEST_GDRIVE_REPO_BUCKET = "root" -class GDrive(Base): +class GDrive(Base, CloudURLInfo): 
@staticmethod def should_test(): return os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) is not None @@ -25,10 +27,6 @@ def config(self): "gdrive_use_service_account": True, } - def __init__(self, dvc): - tree = GDriveRemoteTree(dvc, self.config) - tree._gdrive_create_dir("root", tree.path_info.path) - @staticmethod def _get_storagepath(): return TEST_GDRIVE_REPO_BUCKET + "/" + str(uuid.uuid4()) @@ -46,7 +44,11 @@ def gdrive(make_tmp_dir): # NOTE: temporary workaround tmp_dir = make_tmp_dir("gdrive", dvc=True) - return GDrive(tmp_dir.dvc) + + ret = GDrive(GDrive.get_url()) + tree = GDriveRemoteTree(tmp_dir.dvc, ret.config) + tree._gdrive_create_dir("root", tree.path_info.path) + return ret @pytest.fixture diff --git a/tests/remotes/gs.py b/tests/remotes/gs.py index 090baa4d6b..b5a7b68a48 100644 --- a/tests/remotes/gs.py +++ b/tests/remotes/gs.py @@ -1,12 +1,11 @@ +import locale import os import uuid -from contextlib import contextmanager import pytest from funcy import cached_property -from dvc.remote.base import Remote -from dvc.remote.gs import GSRemoteTree +from dvc.path_info import CloudURLInfo from dvc.utils import env2bool from .base import Base @@ -23,7 +22,7 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = TEST_GCP_CREDS_FILE -class GCP(Base): +class GCP(Base, CloudURLInfo): @staticmethod def should_test(): from subprocess import CalledProcessError, check_output @@ -64,24 +63,63 @@ def _get_storagepath(): def get_url(): return "gs://" + GCP._get_storagepath() - @classmethod - @contextmanager - def remote(cls, repo): - yield Remote(GSRemoteTree(repo, {"url": cls.get_url()})) + @property + def _gc(self): + from google.cloud.storage import Client - @staticmethod - def put_objects(remote, objects): - client = remote.tree.gs - bucket = client.get_bucket(remote.path_info.bucket) - for key, body in objects.items(): - bucket.blob((remote.path_info / key).path).upload_from_string(body) + return Client.from_service_account_json(TEST_GCP_CREDS_FILE) + + @property 
+ def _bucket(self): + return self._gc.bucket(self.bucket) + + @property + def _blob(self): + return self._bucket.blob(self.path) + + def is_file(self): + if self.path.endswith("/"): + return False + + return self._blob.exists() + + def is_dir(self): + dir_path = self / "" + return bool(list(self._bucket.list_blobs(prefix=dir_path.path))) + + def exists(self): + return self.is_file() or self.is_dir() + + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + assert mode == 0o777 + assert parents + assert not exist_ok + + def write_bytes(self, contents): + assert isinstance(contents, bytes) + self._blob.upload_from_string(contents) + + def write_text(self, contents, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + self.write_bytes(contents.encode(encoding)) + + def read_bytes(self): + return self._blob.download_as_string() + + def read_text(self, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + return self.read_bytes().decode(encoding) @pytest.fixture def gs(): if not GCP.should_test(): pytest.skip("no gs") - yield GCP() + yield GCP(GCP.get_url()) @pytest.fixture diff --git a/tests/remotes/hdfs.py b/tests/remotes/hdfs.py index bade5333cf..3ddfaed6ad 100644 --- a/tests/remotes/hdfs.py +++ b/tests/remotes/hdfs.py @@ -1,15 +1,19 @@ import getpass +import locale import os import platform +from contextlib import contextmanager from subprocess import CalledProcessError, Popen, check_output import pytest +from dvc.path_info import URLInfo + from .base import Base from .local import Local -class HDFS(Base): +class HDFS(Base, URLInfo): @staticmethod def should_test(): if platform.system() != "Linux": @@ -43,12 +47,67 @@ def get_url(): getpass.getuser(), Local.get_storagepath() ) + @contextmanager + def _hdfs(self): + import pyarrow + + conn = pyarrow.hdfs.connect() + try: + yield conn + finally: + conn.close() + + def 
is_file(self):
+        with self._hdfs() as _hdfs:
+            return _hdfs.isfile(self.path)
+
+    def is_dir(self):
+        with self._hdfs() as _hdfs:
+            return _hdfs.isdir(self.path)
+
+    def exists(self):
+        with self._hdfs() as _hdfs:
+            return _hdfs.exists(self.path)
+
+    def mkdir(self, mode=0o777, parents=False, exist_ok=False):
+        assert mode == 0o777
+        assert parents
+        assert not exist_ok
+
+        with self._hdfs() as _hdfs:
+            # NOTE: hdfs.mkdir always creates parents
+            _hdfs.mkdir(self.path)
+
+    def write_bytes(self, contents):
+        with self._hdfs() as _hdfs:
+            # NOTE: hdfs.open only supports 'rb', 'wb' or 'ab'
+            with _hdfs.open(self.path, "wb") as fobj:
+                fobj.write(contents)
+
+    def write_text(self, contents, encoding=None, errors=None):
+        if not encoding:
+            encoding = locale.getpreferredencoding(False)
+        assert errors is None
+        self.write_bytes(contents.encode(encoding))
+
+    def read_bytes(self):
+        with self._hdfs() as _hdfs:
+            # NOTE: hdfs.open only supports 'rb', 'wb' or 'ab'
+            with _hdfs.open(self.path, "rb") as fobj:
+                return fobj.read()
+
+    def read_text(self, encoding=None, errors=None):
+        if not encoding:
+            encoding = locale.getpreferredencoding(False)
+        assert errors is None
+        return self.read_bytes().decode(encoding)
+
 
 @pytest.fixture
 def hdfs():
     if not HDFS.should_test():
         pytest.skip("no hadoop running")
-    yield HDFS()
+    yield HDFS(HDFS.get_url())
 
 
 @pytest.fixture
diff --git a/tests/remotes/http.py b/tests/remotes/http.py
index 49cf75ffb1..d12db70247 100644
--- a/tests/remotes/http.py
+++ b/tests/remotes/http.py
@@ -1,28 +1,51 @@
+# pylint:disable=abstract-method
+import locale
+import os
+import uuid
+
 import pytest
+import requests
+
+from dvc.path_info import HTTPURLInfo
 
 from .base import Base
 
 
-class HTTP(Base):
+class HTTP(Base, HTTPURLInfo):
     @staticmethod
     def get_url(port):  # pylint: disable=arguments-differ
-        return f"http://127.0.0.1:{port}"
+        dname = str(uuid.uuid4())
+        return f"http://127.0.0.1:{port}/{dname}"
 
-    def __init__(self, server):
-        self.url = 
self.get_url(server.server_port) + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + assert mode == 0o777 + assert parents + assert not exist_ok + def write_bytes(self, contents): + assert isinstance(contents, bytes) + response = requests.post(self.url, data=contents) + assert response.status_code == 200 -@pytest.fixture -def http_server(tmp_dir): - from tests.utils.httpd import PushRequestHandler, StaticFileServer + def write_text(self, contents, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + self.write_bytes(contents.encode(encoding)) + + +@pytest.fixture(scope="session") +def http_server(tmp_path_factory): + from tests.utils.httpd import StaticFileServer - with StaticFileServer(handler_class=PushRequestHandler) as httpd: + directory = os.fspath(tmp_path_factory.mktemp("http")) + with StaticFileServer(directory=directory) as httpd: yield httpd @pytest.fixture def http(http_server): - yield HTTP(http_server) + yield HTTP(HTTP.get_url(http_server.server_port)) @pytest.fixture diff --git a/tests/remotes/local.py b/tests/remotes/local.py index a99cc26c27..8ba3580d5f 100644 --- a/tests/remotes/local.py +++ b/tests/remotes/local.py @@ -1,4 +1,4 @@ -# pylint: disable=cyclic-import +# pylint: disable=cyclic-import,disable=abstract-method import pytest from tests.basic_env import TestDvc @@ -17,8 +17,11 @@ def get_url(): @pytest.fixture -def local_cloud(): - yield Local() +def local_cloud(make_tmp_dir): + ret = make_tmp_dir("local-cloud") + ret.url = str(ret) + ret.config = {"url": ret.url} + return ret @pytest.fixture diff --git a/tests/remotes/oss.py b/tests/remotes/oss.py index e8a59c7906..fd69d1fbf9 100644 --- a/tests/remotes/oss.py +++ b/tests/remotes/oss.py @@ -1,8 +1,10 @@ +# pylint:disable=abstract-method import os import uuid import pytest +from dvc.path_info import CloudURLInfo from dvc.utils import env2bool from .base import Base @@ -10,7 +12,7 @@ TEST_OSS_REPO_BUCKET = 
"dvc-test" -class OSS(Base): +class OSS(Base, CloudURLInfo): @staticmethod def should_test(): do_test = env2bool("DVC_TEST_OSS", undefined=None) @@ -36,7 +38,7 @@ def get_url(): def oss(): if not OSS.should_test(): pytest.skip("no oss running") - yield OSS() + yield OSS(OSS.get_url()) @pytest.fixture diff --git a/tests/remotes/s3.py b/tests/remotes/s3.py index f6e2c6e15f..bcd1704952 100644 --- a/tests/remotes/s3.py +++ b/tests/remotes/s3.py @@ -1,12 +1,12 @@ +import locale import os import uuid -from contextlib import contextmanager import pytest -from moto.s3 import mock_s3 +from funcy import cached_property +from moto import mock_s3 -from dvc.remote.base import Remote -from dvc.remote.s3 import S3RemoteTree +from dvc.path_info import CloudURLInfo from dvc.utils import env2bool from .base import Base @@ -14,7 +14,7 @@ TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-temp") -class S3(Base): +class S3(Base, CloudURLInfo): @staticmethod def should_test(): do_test = env2bool("DVC_TEST_AWS", undefined=None) @@ -36,12 +36,70 @@ def _get_storagepath(): def get_url(): return "s3://" + S3._get_storagepath() + @cached_property + def _s3(self): + import boto3 + + return boto3.client("s3") + + def is_file(self): + from botocore.exceptions import ClientError + + if self.path.endswith("/"): + return False + + try: + self._s3.head_object(Bucket=self.bucket, Key=self.path) + except ClientError as exc: + if exc.response["Error"]["Code"] != "404": + raise + return False + + return True + + def is_dir(self): + path = (self / "").path + resp = self._s3.list_objects(Bucket=self.bucket, Prefix=path) + return bool(resp.get("Contents")) + + def exists(self): + return self.is_file() or self.is_dir() + + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + assert mode == 0o777 + assert parents + assert not exist_ok + + def write_bytes(self, contents): + self._s3.put_object( + Bucket=self.bucket, Key=self.path, Body=contents, + ) + + def write_text(self, 
contents, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + self.write_bytes(contents.encode(encoding)) + + def read_bytes(self): + data = self._s3.get_object(Bucket=self.bucket, Key=self.path) + return data["Body"].read() + + def read_text(self, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + return self.read_bytes().decode(encoding) + @pytest.fixture def s3(): - if not S3.should_test(): - pytest.skip("no s3") - yield S3() + with mock_s3(): + import boto3 + + boto3.client("s3").create_bucket(Bucket=TEST_AWS_REPO_BUCKET) + + yield S3(S3.get_url()) @pytest.fixture @@ -50,19 +108,14 @@ def s3_remote(tmp_dir, dvc, s3): yield s3 -class S3Mocked(S3): - @classmethod - @contextmanager - def remote(cls, repo): - with mock_s3(): - yield Remote(S3RemoteTree(repo, {"url": cls.get_url()})) +@pytest.fixture +def real_s3(): + if not S3.should_test(): + pytest.skip("no real s3") + yield S3(S3.get_url()) - @staticmethod - def put_objects(remote, objects): - client = remote.tree.s3 - bucket = remote.path_info.bucket - client.create_bucket(Bucket=bucket) - for key, body in objects.items(): - client.put_object( - Bucket=bucket, Key=(remote.path_info / key).path, Body=body - ) + +@pytest.fixture +def real_s3_remote(tmp_dir, dvc, real_s3): + tmp_dir.add_remote(config=real_s3.config) + yield real_s3 diff --git a/tests/remotes/ssh.py b/tests/remotes/ssh.py index 85387b59ae..336d52ec0b 100644 --- a/tests/remotes/ssh.py +++ b/tests/remotes/ssh.py @@ -1,10 +1,13 @@ import getpass +import locale import os +from contextlib import contextmanager from subprocess import CalledProcessError, check_output import pytest from funcy import cached_property +from dvc.path_info import URLInfo from dvc.utils import env2bool from .base import Base @@ -41,7 +44,7 @@ def get_url(): ) -class SSHMocked(Base): +class SSHMocked(Base, URLInfo): @staticmethod def 
get_url(user, port): # pylint: disable=arguments-differ path = Local.get_storagepath() @@ -63,13 +66,6 @@ def get_url(user, port): # pylint: disable=arguments-differ url = f"ssh://{user}@127.0.0.1:{port}{path}" return url - def __init__(self, server): - self.server = server - - @cached_property - def url(self): - return self.get_url(TEST_SSH_USER, self.server.port) - @cached_property def config(self): return { @@ -77,6 +73,66 @@ def config(self): "keyfile": TEST_SSH_KEY_PATH, } + @contextmanager + def _ssh(self): + from dvc.remote.ssh.connection import SSHConnection + + conn = SSHConnection( + host=self.host, + port=self.port, + username=TEST_SSH_USER, + key_filename=TEST_SSH_KEY_PATH, + ) + try: + yield conn + finally: + conn.close() + + def is_file(self): + with self._ssh() as _ssh: + return _ssh.isfile(self.path) + + def is_dir(self): + with self._ssh() as _ssh: + return _ssh.isdir(self.path) + + def exists(self): + with self._ssh() as _ssh: + return _ssh.exists(self.path) + + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + assert mode == 0o777 + assert parents + assert not exist_ok + + with self._ssh() as _ssh: + _ssh.makedirs(self.path) + + def write_bytes(self, contents): + assert isinstance(contents, bytes) + with self._ssh() as _ssh: + with _ssh.open(self.path, "w+") as fobj: + # NOTE: accepts both str and bytes + fobj.write(contents) + + def write_text(self, contents, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + self.write_bytes(contents.encode(encoding)) + + def read_bytes(self): + with self._ssh() as _ssh: + # NOTE: sftp always reads in binary format + with _ssh.open(self.path, "r") as fobj: + return fobj.read() + + def read_text(self, encoding=None, errors=None): + if not encoding: + encoding = locale.getpreferredencoding(False) + assert errors is None + return self.read_bytes().decode(encoding) + @pytest.fixture def ssh_server(): @@ -106,7 +162,7 @@ def 
ssh(ssh_server, monkeypatch): # NOTE: see http://github.com/iterative/dvc/pull/3501 monkeypatch.setattr(SSHRemoteTree, "CAN_TRAVERSE", False) - return SSHMocked(ssh_server) + return SSHMocked(SSHMocked.get_url(TEST_SSH_USER, ssh_server.port)) @pytest.fixture diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index 16949c4849..31e850c91c 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -1,20 +1,14 @@ import pytest from dvc.exceptions import HTTPError -from dvc.path_info import URLInfo from dvc.remote.http import HTTPRemoteTree -from tests.utils.httpd import StaticFileServer -def test_download_fails_on_error_code(dvc): - with StaticFileServer() as httpd: - url = f"http://localhost:{httpd.server_port}/" - config = {"url": url} +def test_download_fails_on_error_code(dvc, http): + tree = HTTPRemoteTree(dvc, http.config) - tree = HTTPRemoteTree(dvc, config) - - with pytest.raises(HTTPError): - tree._download(URLInfo(url) / "missing.txt", "missing.txt") + with pytest.raises(HTTPError): + tree._download(http / "missing.txt", "missing.txt") def test_public_auth_method(dvc): diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index 7bba2253cd..0c01ab89b6 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -3,11 +3,11 @@ import pytest from dvc.path_info import PathInfo +from dvc.remote import get_remote from dvc.remote.s3 import S3RemoteTree from dvc.utils.fs import walk_files -from tests.remotes import GCP, S3Mocked -remotes = [GCP, S3Mocked] +remotes = [pytest.lazy_fixture(fix) for fix in ["gs", "s3"]] FILE_WITH_CONTENTS = { "data1.txt": "", @@ -27,11 +27,9 @@ @pytest.fixture def remote(request, dvc): - if not request.param.should_test(): - raise pytest.skip() - with request.param.remote(dvc) as _remote: - request.param.put_objects(_remote, FILE_WITH_CONTENTS) - yield _remote + cloud = request.param + cloud.gen(FILE_WITH_CONTENTS) 
+ return get_remote(dvc, **cloud.config) @pytest.mark.parametrize("remote", remotes, indirect=True) @@ -86,7 +84,7 @@ def test_walk_files(remote): assert list(remote.tree.walk_files(remote.path_info / "data")) == files -@pytest.mark.parametrize("remote", [S3Mocked], indirect=True) +@pytest.mark.parametrize("remote", [pytest.lazy_fixture("s3")], indirect=True) def test_copy_preserve_etag_across_buckets(remote, dvc): s3 = remote.tree.s3 s3.create_bucket(Bucket="another") @@ -115,7 +113,7 @@ def test_makedirs(remote): assert tree.isdir(empty_dir) -@pytest.mark.parametrize("remote", [GCP, S3Mocked], indirect=True) +@pytest.mark.parametrize("remote", remotes, indirect=True) def test_isfile(remote): test_cases = [ (False, "empty_dir/"), diff --git a/tests/unit/utils/test_http.py b/tests/unit/utils/test_http.py index cb45e298a5..a2f97701ff 100644 --- a/tests/unit/utils/test_http.py +++ b/tests/unit/utils/test_http.py @@ -3,12 +3,9 @@ import requests from dvc.utils.http import open_url -from tests.utils.httpd import StaticFileServer -def test_open_url(tmp_path, monkeypatch): - monkeypatch.chdir(tmp_path) - +def test_open_url(tmp_path, monkeypatch, http): # Simulate bad connection original_iter_content = requests.Response.iter_content @@ -26,11 +23,9 @@ def bad_iter_content(self, *args, **kwargs): # using twice of that plus something tests second resume, # this is important because second response is different text = "0123456789" * (io.DEFAULT_BUFFER_SIZE // 10 + 1) - (tmp_path / "sample.txt").write_text(text * 2) - - with StaticFileServer() as httpd: - url = f"http://localhost:{httpd.server_port}/sample.txt" - with open_url(url) as fd: - # Test various .read() variants - assert fd.read(len(text)) == text - assert fd.read() == text + http.gen("sample.txt", text * 2) + + with open_url((http / "sample.txt").url) as fd: + # Test various .read() variants + assert fd.read(len(text)) == text + assert fd.read() == text diff --git a/tests/utils/httpd.py b/tests/utils/httpd.py index 
d706752802..a28d370574 100644
--- a/tests/utils/httpd.py
+++ b/tests/utils/httpd.py
@@ -1,14 +1,48 @@
 import hashlib
 import os
+import sys
 import threading
 from http import HTTPStatus
-from http.server import HTTPServer, SimpleHTTPRequestHandler
+from http.server import HTTPServer
 
 from RangeHTTPServer import RangeRequestHandler
 
 
 class TestRequestHandler(RangeRequestHandler):
-    checksum_header = None
+    def __init__(self, *args, **kwargs):
+        # NOTE: `directory` was introduced in 3.7
+        if sys.version_info < (3, 7):
+            self.directory = kwargs.pop("directory", None) or os.getcwd()
+        super().__init__(*args, **kwargs)
+
+    def translate_path(self, path):
+        import urllib
+        import posixpath
+
+        # NOTE: `directory` was introduced in 3.7
+        if sys.version_info >= (3, 7):
+            return super().translate_path(path)
+
+        path = path.split("?", 1)[0]
+        path = path.split("#", 1)[0]
+        # Don't forget explicit trailing slash when normalizing. Issue17324
+        trailing_slash = path.rstrip().endswith("/")
+        try:
+            path = urllib.parse.unquote(path, errors="surrogatepass")
+        except UnicodeDecodeError:
+            path = urllib.parse.unquote(path)
+        path = posixpath.normpath(path)
+        words = path.split("/")
+        words = filter(None, words)
+        path = self.directory
+        for word in words:
+            if os.path.dirname(word) or word in (os.curdir, os.pardir):
+                # Ignore components that are not a simple file/directory name
+                continue
+            path = os.path.join(path, word)
+        if trailing_slash:
+            path += "/"
+        return path
 
     def end_headers(self):
         # RangeRequestHandler only sends Accept-Ranges header if Range header
@@ -17,27 +51,16 @@ def end_headers(self):
         self.send_header("Accept-Ranges", "bytes")
 
         # Add a checksum header
-        if self.checksum_header:
-            file = self.translate_path(self.path)
+        file = self.translate_path(self.path)
 
-            if not os.path.isdir(file) and os.path.exists(file):
-                with open(file) as fd:
-                    encoded_text = fd.read().encode("utf8")
-                    checksum = hashlib.md5(encoded_text).hexdigest()
-                    self.send_header(self.checksum_header, 
checksum) + if not os.path.isdir(file) and os.path.exists(file): + with open(file) as fd: + encoded_text = fd.read().encode("utf8") + checksum = hashlib.md5(encoded_text).hexdigest() + self.send_header("Content-MD5", checksum) RangeRequestHandler.end_headers(self) - -class ETagHandler(TestRequestHandler): - checksum_header = "ETag" - - -class ContentMD5Handler(TestRequestHandler): - checksum_header = "Content-MD5" - - -class PushRequestHandler(SimpleHTTPRequestHandler): def _chunks(self): while True: data = self.rfile.readline(65537) @@ -69,9 +92,13 @@ def do_POST(self): class StaticFileServer: _lock = threading.Lock() - def __init__(self, handler_class=ETagHandler): + def __init__(self, directory): + from functools import partial + self._lock.acquire() - self._httpd = HTTPServer(("localhost", 0), handler_class) + self._httpd = HTTPServer( + ("localhost", 0), partial(TestRequestHandler, directory=directory), + ) self._thread = None def __enter__(self):