Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions for ensuring SQLite #46

Merged
merged 4 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pystow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ensure_open,
ensure_open_gz,
ensure_open_lzma,
ensure_open_sqlite,
ensure_open_tarfile,
ensure_open_zip,
ensure_pickle,
Expand Down
44 changes: 44 additions & 0 deletions src/pystow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"ensure_open_lzma",
"ensure_open_tarfile",
"ensure_open_zip",
"ensure_open_sqlite",
# Processors
"ensure_csv",
"ensure_custom",
Expand Down Expand Up @@ -1422,3 +1423,46 @@ def joinpath_sqlite(key: str, *subkeys: str, name: str) -> str:
"""
_module = Module.from_key(key, ensure_exists=True)
return _module.joinpath_sqlite(*subkeys, name=name)


@contextmanager
def ensure_open_sqlite(
key: str,
*subkeys: str,
url: str,
name: Optional[str] = None,
force: bool = False,
download_kwargs: Optional[Mapping[str, Any]] = None,
):
"""Ensure and connect to a SQLite database.

:param key:
The name of the module. No funny characters. The envvar
`<key>_HOME` where key is uppercased is checked first before using
the default home directory.
:param subkeys:
A sequence of additional strings to join. If none are given,
returns the directory for this module.
:param url:
The URL to download.
:param name:
Overrides the name of the file at the end of the URL, if given. Also
useful for URLs that don't have proper filenames with extensions.
:param force:
Should the download be done again, even if the path already exists?
Defaults to false.
:param download_kwargs: Keyword arguments to pass through to :func:`pystow.utils.download`.
:yields: A connection from :func:`sqlite3.connect`

Example usage:
>>> import pystow
>>> import pandas as pd
>>> url = "https://s3.amazonaws.com/bbop-sqlite/hp.db"
>>> with pystow.ensure_open_sqlite("test", url=url) as conn:
>>> df = pd.read_sql(" <query> ", conn)
"""
_module = Module.from_key(key, ensure_exists=True)
with _module.ensure_open_sqlite(
*subkeys, url=url, name=name, force=force, download_kwargs=download_kwargs
) as yv:
yield yv
46 changes: 44 additions & 2 deletions src/pystow/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import tarfile
import warnings
import zipfile
from contextlib import contextmanager
from contextlib import closing, contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union

Expand All @@ -22,6 +22,7 @@
mkdir,
name_from_s3_key,
name_from_url,
path_to_sqlite,
read_rdf,
read_tarfile_csv,
read_tarfile_xml,
Expand Down Expand Up @@ -140,7 +141,7 @@ def joinpath_sqlite(self, *subkeys: str, name: str) -> str:
:return: A SQLite path string.
"""
path = self.join(*subkeys, name=name, ensure_exists=True)
return f"sqlite:///{path.as_posix()}"
return path_to_sqlite(path)

def ensure(
self,
Expand Down Expand Up @@ -1273,6 +1274,47 @@ def ensure_from_google(
download_from_google(file_id, path, force=force, **(download_kwargs or {}))
return path

@contextmanager
def ensure_open_sqlite(
self,
*subkeys: str,
url: str,
name: Optional[str] = None,
force: bool = False,
download_kwargs: Optional[Mapping[str, Any]] = None,
):
"""Ensure and connect to a SQLite database.

:param subkeys:
A sequence of additional strings to join. If none are given,
returns the directory for this module.
:param url:
The URL to download.
:param name:
Overrides the name of the file at the end of the URL, if given. Also
useful for URLs that don't have proper filenames with extensions.
:param force:
Should the download be done again, even if the path already exists?
Defaults to false.
:param download_kwargs: Keyword arguments to pass through to :func:`pystow.utils.download`.
:yields: A connection from :func:`sqlite3.connect`

Example usage:
>>> import pystow
>>> import pandas as pd
>>> url = "https://s3.amazonaws.com/bbop-sqlite/hp.db"
>>> module = pystow.module("test")
>>> with module.ensure_open_sqlite(url=url) as conn:
>>> df = pd.read_sql(" <query> ", conn)
"""
import sqlite3

path = self.ensure(
*subkeys, url=url, name=name, force=force, download_kwargs=download_kwargs
)
with closing(sqlite3.connect(path.as_posix())) as conn:
yield conn


def _clean_csv_kwargs(read_csv_kwargs):
read_csv_kwargs = {} if read_csv_kwargs is None else dict(read_csv_kwargs)
Expand Down
26 changes: 26 additions & 0 deletions src/pystow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"get_home",
"get_name",
"get_base",
"path_to_sqlite",
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -671,6 +672,21 @@ def read_rdf(path: Union[str, Path], **kwargs):
return graph


def write_sql(df, name: str, path: Union[str, Path], **kwargs) -> None:
"""Write a dataframe as a SQL table.

:param df: A dataframe
:type df: pandas.DataFrame
:param name: The table the database to write to
:param path: The path to the resulting tar archive
:param kwargs: Additional keyword arguments to pass to :meth:`pandas.DataFrame.to_sql`
"""
import sqlite3

with contextlib.closing(sqlite3.connect(path)) as conn:
df.to_sql(name, conn, **kwargs)


def get_commit(org: str, repo: str, provider: str = "git") -> str:
"""Get last commit hash for the given repo.

Expand Down Expand Up @@ -895,3 +911,13 @@ def ensure_readme() -> None:
return
with readme_path.open("w", encoding="utf8") as file:
print(README_TEXT, file=file) # noqa:T001,T201


def path_to_sqlite(path: Union[str, Path]) -> str:
"""Convert a path to a SQLite connection string.

:param path: A path to a SQLite database file
:returns: A standard connection string to the database
"""
path = Path(path).expanduser().resolve()
return f"sqlite:///{path.as_posix()}"
17 changes: 17 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
mock_envvar,
n,
write_pickle_gz,
write_sql,
write_tarfile_csv,
write_zipfile_csv,
)
Expand All @@ -36,6 +37,11 @@
TSV_NAME = "test_1.tsv"
TSV_URL = f"{n()}/{TSV_NAME}"

SQLITE_NAME = "test_1.db"
SQLITE_URL = f"{n()}/{SQLITE_NAME}"
SQLITE_PATH = RESOURCES / SQLITE_NAME
SQLITE_TABLE = "testtable"

JSON_NAME = "test_1.json"
JSON_URL = f"{n()}/{JSON_NAME}"

Expand All @@ -52,6 +58,7 @@
JSON_URL: RESOURCES / JSON_NAME,
PICKLE_URL: PICKLE_PATH,
PICKLE_GZ_URL: PICKLE_GZ_PATH,
SQLITE_URL: SQLITE_PATH,
}

TEST_TSV_ROWS = [
Expand All @@ -65,6 +72,9 @@
if not PICKLE_PATH.is_file():
PICKLE_PATH.write_bytes(pickle.dumps(TEST_TSV_ROWS))

if not SQLITE_PATH.is_file():
write_sql(TEST_DF, name=SQLITE_TABLE, path=SQLITE_PATH, index=False)


class TestMocks(unittest.TestCase):
"""Tests for :mod:`pystow` mocks and context managers."""
Expand Down Expand Up @@ -296,3 +306,10 @@ def touch_file(path: Path, **_kwargs):
path = pystow.ensure_custom("test", name=name, provider=provider, **kwargs)
# ensure that the provider was only called once with the given parameters
provider.assert_called_once_with(path, **kwargs)

def test_ensure_open_sqlite(self):
"""Test caching SQLite."""
with self.mock_directory(), self.mock_download():
with pystow.ensure_open_sqlite("test", url=SQLITE_URL) as conn:
df = pd.read_sql(f"SELECT * from {SQLITE_TABLE}", conn) # noqa:S608
self.assertEqual(3, len(df.columns))