Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
sudo: required
language: python
env:
- DJ_TEST_HOST="127.0.0.1" DJ_TEST_USER="datajoint" DJ_TEST_PASSWORD="datajoint" DJ_HOST="127.0.0.1" DJ_USER="datajoint" DJ_PASS="datajoint"
- DJ_TEST_HOST="127.0.0.1" DJ_TEST_USER="datajoint" DJ_TEST_PASSWORD="datajoint" DJ_HOST="127.0.0.1" DJ_USER="datajoint" DJ_PASS="datajoint" BOTO_CONFIG="/tmp/bogusvalue"
python:
- "3.4"
- "3.5"
Expand Down
7 changes: 5 additions & 2 deletions datajoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ class DataJointError(Exception):

# override login credentials with environment variables
mapping = {k: v for k, v in zip(
('database.host', 'database.user', 'database.password'),
map(os.getenv, ('DJ_HOST', 'DJ_USER', 'DJ_PASS')))
('database.host', 'database.user', 'database.password',
'external.aws_access_key_id', 'external.aws_secret_access_key',),
map(os.getenv, ('DJ_HOST', 'DJ_USER', 'DJ_PASS',
'DJ_AWS_ACCESS_KEY_ID', 'DJ_AWS_SECRET_ACCESS_KEY',)))
if v is not None}
for k in mapping:
config.add_history('Updated login credentials from %s' % k)
Expand All @@ -64,6 +66,7 @@ class DataJointError(Exception):

# ------------- flatten import hierarchy -------------------------
from .connection import conn, Connection
from .s3 import bucket, Bucket
from .base_relation import FreeRelation, BaseRelation
from .user_relations import Manual, Lookup, Imported, Computed, Part
from .relational_operand import Not, AndList, OrList, U
Expand Down
152 changes: 152 additions & 0 deletions datajoint/s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
This module contains logic related to external file storage
"""

import logging
from getpass import getpass

import boto3
from botocore.exceptions import ClientError

from . import config
from . import DataJointError

logger = logging.getLogger(__name__)


def bucket(aws_access_key_id=None, aws_secret_access_key=None, reset=False):
    """
    Return the shared dj.Bucket instance, creating it on first use.

    The instance is cached on the function object itself and reused by
    all modules. A new instance is built only on the first call or when
    reset=True. Credentials are resolved in order: explicit arguments,
    then config (populated from dj_local_conf.json / environment), and
    finally an interactive prompt.

    :param aws_access_key_id: AWS Access Key ID
    :param aws_secret_access_key: AWS Secret Key
    :param reset: whether the connection should be reset or not
    """
    if reset or not hasattr(bucket, 'bucket'):
        key_id = aws_access_key_id
        if key_id is None:
            key_id = config['external.aws_access_key_id']

        secret_key = aws_secret_access_key
        if secret_key is None:
            secret_key = config['external.aws_secret_access_key']

        if key_id is None:  # pragma: no cover
            key_id = input("Please enter AWS Access Key ID: ")

        if secret_key is None:  # pragma: no cover
            secret_key = getpass(
                "Please enter AWS Secret Access Key: "
            )

        bucket.bucket = Bucket(key_id, secret_key)
    return bucket.bucket


class Bucket:
    """
    A dj.Bucket object manages a connection to an AWS S3 Bucket.

    Currently, basic CRUD operations are supported; of note permissions and
    object versioning are not currently supported.

    Most of the parameters below should be set in the local configuration file.

    :param aws_access_key_id: AWS Access Key ID
    :param aws_secret_access_key: AWS Secret Key
    :raises DataJointError: if config['external.location'] is missing or
        is not of the form 's3://<bucket-name>'
    """

    def __init__(self, aws_access_key_id, aws_secret_access_key):
        # The session only holds credentials; no network traffic occurs
        # until connect() builds the service resource.
        self._session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key
        )
        self._s3 = None
        # Read the configured location exactly once. The previous code
        # re-read config['external.location'] inside the exception handler,
        # so a missing key raised a second KeyError there, masking the
        # intended DataJointError.
        try:
            location = config['external.location']
        except KeyError:
            location = None
        try:
            # expected form: 's3://<bucket-name>'
            self._bucket = location.split("s3://")[1]
        except (AttributeError, IndexError):
            raise DataJointError(
                'external.location not properly configured: {l}'.format(
                    l=location)
            ) from None

    def connect(self):
        """Lazily create the boto3 S3 service resource (idempotent)."""
        if self._s3 is None:
            self._s3 = self._session.resource('s3')

    def stat(self, rpath=None):
        """
        Check if a file exists in the bucket.

        :param rpath: remote path within bucket
        :return: True if the object exists, False if it does not
        :raises DataJointError: on any S3 error other than a 404
        """
        try:
            self.connect()
            self._s3.Object(self._bucket, rpath).load()
        except ClientError as e:
            # 404 simply means "not there"; anything else is a real error.
            if e.response['Error']['Code'] != "404":
                raise DataJointError(
                    'Error checking remote file {r} ({e})'.format(r=rpath, e=e)
                ) from e
            return False

        return True

    def put(self, lpath=None, rpath=None):
        """
        Upload a file to the bucket.

        :param rpath: remote path within bucket
        :param lpath: local path
        :return: True on success
        :raises DataJointError: if the upload fails for any reason
        """
        try:
            self.connect()
            self._s3.Object(self._bucket, rpath).upload_file(lpath)
        except Exception as e:
            raise DataJointError(
                'Error uploading file {l} to {r} ({e})'.format(
                    l=lpath, r=rpath, e=e)
            ) from e

        return True

    def get(self, rpath=None, lpath=None):
        """
        Retrieve a file from the bucket.

        :param rpath: remote path within bucket
        :param lpath: local path
        :return: True on success
        :raises DataJointError: if the download fails for any reason
        """
        try:
            self.connect()
            self._s3.Object(self._bucket, rpath).download_file(lpath)
        except Exception as e:
            raise DataJointError(
                'Error downloading file {r} to {l} ({e})'.format(
                    r=rpath, l=lpath, e=e)
            ) from e

        return True

    def delete(self, rpath):
        """
        Delete a single remote object.
        Note: will return True even if object doesn't exist;
        for explicit verification combine with a .stat() call.

        :param rpath: remote path within bucket
        :return: True when S3 confirms the delete with HTTP 204
        :raises DataJointError: if the delete request fails
        """
        try:
            self.connect()
            r = self._s3.Object(self._bucket, rpath).delete()
            # XXX: if/when does 'False' occur? - s3 returns ok if no file...
            return r['ResponseMetadata']['HTTPStatusCode'] == 204
        except Exception as e:
            raise DataJointError(
                'error deleting file {r} ({e})'.format(r=rpath, e=e)
            ) from e
7 changes: 5 additions & 2 deletions datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@
'safemode': True,
'display.limit': 7,
'display.width': 14,
'display.show_tuple_count': True
'display.show_tuple_count': True,
'external.aws_access_key_id': None,
'external.aws_secret_access_key': None,
'external.location' : None
})

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -177,4 +180,4 @@ def __setitem__(self, key, value):
if validators[key](value):
self._conf[key] = value
else:
raise DataJointError(u'Validator for {0:s} did not pass'.format(key, ))
raise DataJointError(u'Validator for {0:s} did not pass'.format(key, ))
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pyparsing
ipython
networkx~=1.11
pydotplus
boto3
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
matplotlib
pygraphviz
moto
125 changes: 125 additions & 0 deletions tests/test_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@

"""
Test of dj.Bucket() using moto *MOCKED* S3 library
Using real S3 could incur cost and requires appropriate credentials management;
but probably should be done at some point once best methodology is determined.
"""

import os
from unittest import TestCase

import boto3
from moto import mock_s3

import datajoint as dj

# Verify moto is itself functional
# BEGIN via Moto Docs


class MotoTest:
    """
    Minimal key/value model, used solely to confirm that the moto
    S3 mock itself is functional before exercising dj.Bucket.
    """

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def save(self):
        # Store this object's value under its name in the mocked bucket.
        client = boto3.client('s3', region_name='us-east-1')
        client.put_object(Bucket='mybucket', Key=self.name, Body=self.value)


@mock_s3
def test_moto_test():
    """Sanity check (adapted from the moto docs): write an object, read it back."""
    # The bucket must exist before MotoTest.save() can write into it.
    resource = boto3.resource('s3', region_name='us-east-1')
    resource.create_bucket(Bucket='mybucket')

    MotoTest('steve', 'is awesome').save()

    stored = resource.Object('mybucket', 'steve').get()['Body'].read().decode()

    assert stored == 'is awesome'

# END via Moto Docs


@mock_s3
def test_dj_bucket_factory():
    '''
    Test *part of* the dj.bucket() singleton/factory function.
    The user-interactive portion is not tested.
    '''
    # Force the unconfigured state explicitly: previously this relied on
    # whatever value earlier tests left in the global dj.config, and the
    # try/except silently passed even if no error was raised at all.
    dj.config['external.location'] = None
    try:
        dj.Bucket(None, None)
    except dj.DataJointError:  # no dj.config['external.location']
        pass
    else:
        raise AssertionError(
            'Bucket() should raise without external.location configured')

    # monkey patch dj.bucket.bucket to use mocked implementation
    dj.config['external.location'] = 's3://djtest.datajoint.io'
    b = dj.Bucket(None, None)
    dj.bucket.bucket = b

    assert dj.bucket() == b


@mock_s3
class DjBucketTest(TestCase):
    """Exercise dj.Bucket CRUD operations against a moto-mocked S3 bucket."""

    def setUp(self):
        dj.config['external.location'] = 's3://djtest.datajoint.io'
        bkt = dj.Bucket(None, None)
        dj.bucket.bucket = bkt

        # create moto's virtual bucket
        bkt.connect()  # note: implicit test of connect(), which is trivial
        bkt._s3.create_bucket(Bucket='djtest.datajoint.io')
        self._bucket = bkt

        # todo:
        # - appropriate remote filename (e.g. mkstemp())
        # - appropriate local temp filename (e.g. mkstemp())
        self._lfile = __file__
        self._rfile = 'DjBucketTest-TEMP_NO_EDIT_WILL_ZAP.py'
        self._lfile_cpy = self._rfile

        self._zaptmpfile()

    def tearDown(self):
        self._zaptmpfile()

    def _zaptmpfile(self):
        # Remove the local working copy; a missing file is fine.
        try:
            os.remove(self._lfile_cpy)
        except FileNotFoundError:
            pass

    def test_bucket_methods(self):
        '''
        Test dj.Bucket.(put,stat,get,delete,)()
        Currently done in one test to simplify interdependencies.
        '''
        bkt = self._bucket

        # ensure no initial files
        self.assertIs(bkt.delete(self._rfile), True)
        self.assertIs(bkt.stat(self._rfile), False)
        self.assertIs(os.path.exists(self._lfile_cpy), False)

        # test put
        self.assertIs(bkt.put(self._lfile, self._rfile), True)

        # test stat
        self.assertIs(bkt.stat(self._rfile), True)

        # test get
        self.assertIs(bkt.get(self._rfile, self._lfile_cpy), True)
        self.assertIs(os.path.exists(self._lfile_cpy), True)

        # test delete
        self.assertIs(bkt.delete(self._rfile), True)

        # verify delete
        self.assertIs(bkt.stat(self._rfile), False)