Commit

Merge 257e910 into 0c52378
k1o0 committed Apr 7, 2022
2 parents 0c52378 + 257e910 commit f1bc21c
Showing 3 changed files with 207 additions and 44 deletions.
2 changes: 2 additions & 0 deletions alyx/alyx/settings_secret_template.py
@@ -4,6 +4,8 @@
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = '%SECRET_KEY%'

S3_ACCESS = {} # should include the keys (access_key, secret_key, region)

# Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases

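For orientation, a minimal sketch of what the new S3_ACCESS dict might contain, assuming the keys named in the comment above (all values are placeholders, not part of this commit):

S3_ACCESS = {
    'access_key': 'AKIAXXXXXXXXXXXXXXXX',  # AWS access key ID (placeholder)
    'secret_key': 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',  # AWS secret access key (placeholder)
    'region': 'eu-west-1',  # bucket region (placeholder)
}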
202 changes: 164 additions & 38 deletions alyx/misc/management/commands/one_cache.py
@@ -1,11 +1,16 @@
from time import time
import io
import socket
import json
import logging
from pathlib import Path
import urllib.parse
from datetime import datetime
from functools import wraps
from sys import getsizeof
import zipfile
import tempfile
import re

import numpy as np
import pandas as pd
@@ -33,13 +38,56 @@ def wrapper(*arg, **kwargs):
return wrapper


def _s3_filesystem(**kwargs) -> pa.fs.S3FileSystem:
"""
Get S3 FileSystem object. Order of credential precedence:
1. kwargs
2. S3_ACCESS dict in settings_secret.py
3. Default aws cli credentials
:param kwargs: see pyarrow.fs.S3FileSystem
:return: A FileSystem object with the given credentials
"""
try:
from alyx.settings_secret import S3_ACCESS
except ImportError:
S3_ACCESS = {}
S3_ACCESS.update(kwargs)
return pa.fs.S3FileSystem(**S3_ACCESS)
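As a usage sketch (not part of the diff; the bucket key and region are placeholders), keyword arguments passed here take precedence over S3_ACCESS and the AWS CLI defaults:

fs = _s3_filesystem(region='eu-west-1')  # explicit kwarg overrides settings_secret.py / aws cli
info = fs.get_file_info('my-bucket/one_cache/cache_info.json')  # hypothetical bucket/key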


def _get_s3_virtual_host(uri, region) -> str:
"""
Convert a given bucket URI to a URL, following the virtual host-style convention described in the
S3 documentation:
https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html#virtual-host-style-url-ex
:param uri: The bucket name or full path URI
:param region: The region, e.g. eu-west-1
:return: The Web URL (virtual host name and https scheme)
"""
assert region and re.match(r'\w{2}-\w+-[1-3]', region)
parsed = urllib.parse.urlparse(uri) # remove scheme if necessary
key = parsed.path.strip('/').split('/')
bucket = parsed.netloc or key.pop(0)
hostname = f"{bucket}.{parsed.scheme or 's3'}.{region}.amazonaws.com"
return 'https://' + '/'.join((hostname, *key))
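For example (hypothetical bucket and region, shown only to illustrate the conversion above):

_get_s3_virtual_host('s3://my-bucket/one_cache/cache.zip', 'eu-west-1')
# -> 'https://my-bucket.s3.eu-west-1.amazonaws.com/one_cache/cache.zip'
_get_s3_virtual_host('my-bucket', 'eu-west-1')
# -> 'https://my-bucket.s3.eu-west-1.amazonaws.com'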


def _save(filename: str, df: pd.DataFrame, metadata: dict = None, dry=False) -> pa.Table:
"""
Save pandas dataframe to parquet.
If using S3, by default the aws default credentials are used. These may be overridden by the
S3_ACCESS dict in settings_secret.py.
:param filename: Parquet save location, may be local file path or S3 location
:param df: A DataFrame to save as parquet table
:param metadata: A dict of optional metadata
:param dry: if True, return pyarrow table without saving to disk
:return: the saved pyarrow table
"""
# cf https://towardsdatascience.com/saving-metadata-with-dataframes-71f51f558d8e

@@ -52,8 +100,16 @@ def _save(filename: Path, df: pd.DataFrame, metadata: dict = None) -> None:
**table.schema.metadata
})

# Save to parquet.
pq.write_table(table, filename)
if not dry:
parsed = urllib.parse.urlparse(filename)
if parsed.scheme == 's3':
# Filename mustn't include scheme
pq.write_table(table, parsed.path, filesystem=_s3_filesystem())
elif parsed.scheme == '':
pq.write_table(table, filename)
else:
raise ValueError(f'Unsupported URI scheme "{parsed.scheme}"')
return table
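A minimal usage sketch (hypothetical paths and metadata, not part of the commit):

df = pd.DataFrame({'id': ['a', 'b'], 'exists': [True, False]})
_save('/var/tmp/datasets.pqt', df, metadata={'origin': 'example'})  # local parquet file
tbl = _save('datasets.pqt', df, metadata={'origin': 'example'}, dry=True)  # returns the pa.Table, writes nothing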


def _uuid2np(eids_uuid):
@@ -63,12 +119,13 @@ def _uuid2np(eids_uuid):

class Command(BaseCommand):
"""
NB: When compress flag is passed, all tables are expected to fit into memory together.
"""
help = "Generate ONE cache tables"
dst_dir = None
tables = None
metadata = None
compress = None

def add_arguments(self, parser):
parser.add_argument('-D', '--destination', default=TABLES_ROOT,
@@ -77,54 +134,123 @@ def add_arguments(self, parser):
help="List of tables to generate")
parser.add_argument('--int-id', action='store_true',
help="Save uuids as ints")
parser.add_argument('--compress', action='store_true',
help="Save files into compressed folder")

def handle(self, *_, **options):
if options['verbosity'] < 1:
logger.setLevel(logging.WARNING)
if options['verbosity'] > 1:
logger.setLevel(logging.DEBUG)
self.dst_dir = Path(options.get('destination'))
self.dst_dir = options.get('destination')
self.compress = options.get('compress')
tables, int_id = options.get('tables'), options.get('int_id')
self.generate_tables(tables, int_id=int_id)

def generate_tables(self, tables, **kwargs) -> None:
def generate_tables(self, tables, **kwargs) -> list:
"""
Generate and save a list of tables. Supported tables include 'sessions' and 'datasets'.
:param tables: A tuple of table names.
:param kwargs:
:return: A list of paths to the saved files
"""
self.metadata = create_metadata()
to_compress = {}
dry = self.compress
for table in tables:
if table.lower() == 'sessions':
logger.debug('Generating sessions DataFrame')
self._save_table(generate_sessions_frame(**kwargs), table)
tbl, filename = self._save_table(generate_sessions_frame(**kwargs), table, dry=dry)
to_compress[filename] = tbl
elif table.lower() == 'datasets':
logger.debug('Generating datasets DataFrame')
self._save_table(generate_datasets_frame(**kwargs), table)
tbl, filename = self._save_table(generate_datasets_frame(**kwargs), table, dry=dry)
to_compress[filename] = tbl
else:
raise ValueError(f'Unknown table "{table}"')
self._compress_tables()

def _save_table(self, table, name):
self.dst_dir.mkdir(exist_ok=True)
logger.info(f'Saving table "{name}" to {self.dst_dir}...')
filename = self.dst_dir / f'{name}.pqt' # Save to parquet
_save(filename, table, self.metadata)

def _compress_tables(self) -> None:
"""Write cache_info JSON and create zip file comprising parquet tables + JSON"""
from zipfile import ZipFile
zip = ZipFile(self.dst_dir / 'cache.zip', 'w')
jsonmeta = {}
logger.info(f'Compressing tables to {zip.filename}...')
for filename in self.dst_dir.glob('*.pqt'):
zip.write(filename, filename.name)
pqtinfo = pq.read_metadata(filename)
jsonmeta[filename.stem] = {'nrecs': pqtinfo.num_rows, 'size': pqtinfo.serialized_size}
# creates a json file containing metadata and add it to the zip file
tag_file = self.dst_dir / 'cache_info.json'
with open(tag_file, 'w') as fid:
json.dump({**self.metadata, 'tables': jsonmeta}, fid, indent=1)
zip.write(tag_file, tag_file.name)
write_fail = zip.testzip()
zip.close()
if write_fail:
logger.error(f'Failed to compress {write_fail}')

if self.compress:
return list(self._compress_tables(to_compress))
else:
return list(to_compress.keys())

def _save_table(self, table, name, **kwargs):
"""Save a given table to <dst_dir>/<name>.pqt.
Given a table name and a pandas DataFrame, save as a parquet table to disk. If the dst_dir
attribute is an S3 URI, the table is saved directly there.
:param table: the pandas DataFrame to save
:param name: table name
:param dry: If True, does not actually write to disk
:return: A PyArrow table and the full path to the saved file
"""
if not kwargs.get('dry'):
logger.info(f'Saving table "{name}" to {self.dst_dir}...')
scheme = urllib.parse.urlparse(self.dst_dir).scheme or 'file'
if scheme == 'file':
Path(self.dst_dir).mkdir(exist_ok=True)
filename = Path(self.dst_dir) / f'{name}.pqt' # Save to parquet
else:
filename = self.dst_dir.strip('/') + f'/{name}.pqt' # Save to parquet
pa_table = _save(str(filename), table, self.metadata, **kwargs)
return pa_table, str(filename)

def _compress_tables(self, table_map) -> tuple:
"""
Write cache_info JSON and create zip file comprising parquet tables + JSON
:param table_map: a dict mapping filenames to the corresponding pyarrow tables
:return: the full paths of the zip file and the cache info JSON file
"""
ZIP_NAME = 'cache.zip'
META_NAME = 'cache_info.json'

logger.info('Compressing tables...') # Write zip in memory
zip_buffer = io.BytesIO() # Mem buffer to store compressed table data
with tempfile.TemporaryDirectory() as tmp, \
zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip:
jsonmeta = {}
for filename, table in table_map.items():
tmp_filename = Path(tmp) / Path(filename).name # Table filename in temp dir
pq.write_table(table, tmp_filename) # Write table to tempdir
zip.write(tmp_filename, Path(filename).name) # Load and compress
pqtinfo = pq.read_metadata(tmp_filename) # Load metadata for cache_info file
jsonmeta[Path(filename).stem] = {
'nrecs': pqtinfo.num_rows,
'size': pqtinfo.serialized_size
}
metadata = {**self.metadata, 'tables': jsonmeta}
zip.writestr(META_NAME, json.dumps(metadata, indent=1)) # Compress cache info

logger.info('Writing to file...')
parsed = urllib.parse.urlparse(self.dst_dir)
scheme = parsed.scheme or 'file'
try:
if scheme == 's3':
zip_file = f'{parsed.path.strip("/")}/{ZIP_NAME}'
tag_file = f'{parsed.path.strip("/")}/{META_NAME}'
s3 = _s3_filesystem()
metadata['location'] = _get_s3_virtual_host(zip_file, s3.region) # Add URL
# Write cache info json to s3
with s3.open_output_stream(tag_file) as stream:
stream.writelines(json.dumps(metadata, indent=1))
# Write zip file to s3
with s3.open_output_stream(zip_file) as stream:
stream.write(zip_buffer.getbuffer())
elif scheme == 'file':
# Write the metadata JSON and the zip archive to the local destination
tag_file = Path(self.dst_dir) / META_NAME
zip_file = Path(self.dst_dir) / ZIP_NAME
with open(tag_file, 'w') as fid:
json.dump(metadata, fid, indent=1)
with open(zip_file, 'wb') as fid:
fid.write(zip_buffer.getbuffer())
else:
raise ValueError(f'Unsupported URI scheme "{scheme}"')
finally:
zip_buffer.close()
return zip_file, tag_file


@measure_time
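As a usage sketch of the new options (not part of the diff): the command might be driven from Django as below; the bucket is a placeholder, and the tables argument name is assumed from options.get('tables') in handle(), since its parser definition is collapsed above.

from django.core.management import call_command

call_command('one_cache', tables=['sessions', 'datasets'],
             destination='s3://my-bucket/one_cache', compress=True, verbosity=2)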
47 changes: 41 additions & 6 deletions alyx/misc/views.py
@@ -1,10 +1,12 @@
from pathlib import Path
import os.path as op
import json
import urllib.parse

import magic
import requests
from django.contrib.auth import get_user_model
from django.http import HttpResponse, FileResponse, JsonResponse
from django.http import HttpResponse, FileResponse, JsonResponse, HttpResponseRedirect

from rest_framework import viewsets, views
from rest_framework.response import Response
@@ -134,9 +136,39 @@ def get(self, request=None, format=None, img_url=''):


def _get_cache_info():
file_json_cache = Path(TABLES_ROOT).joinpath('cache_info.json')
with open(file_json_cache) as fid:
cache_info = json.load(fid)
"""
Load and return the cache info JSON file. Contains information such as cache table timestamp,
size and API version.
:return: dict of cache table information
"""
META_NAME = 'cache_info.json'
parsed = urllib.parse.urlparse(TABLES_ROOT)
scheme = parsed.scheme or 'file'
if scheme == 'file':
# Cache table is local
file_json_cache = Path(TABLES_ROOT).joinpath(META_NAME)
with open(file_json_cache) as fid:
cache_info = json.load(fid)
elif scheme.startswith('http'):
file_json_cache = TABLES_ROOT.strip('/') + f'/{META_NAME}'
resp = requests.get(file_json_cache)
resp.raise_for_status()
cache_info = resp.json()
if 'location' not in cache_info:
cache_info['location'] = TABLES_ROOT.strip('/') + '/cache.zip'
elif scheme == 's3':
# Use PyArrow to read file from s3
from misc.management.commands.one_cache import _s3_filesystem
s3 = _s3_filesystem()
file_json_cache = parsed.netloc + '/' + parsed.path.strip('/') + '/' + META_NAME
with s3.open_input_stream(file_json_cache) as stream:
cache_info = json.load(stream)
if 'location' not in cache_info:
cache_info['location'] = TABLES_ROOT.strip('/') + '/' + META_NAME
else:
raise ValueError(f'Unsupported URI scheme "{scheme}"')

return cache_info
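For orientation, a rough sketch of the cache_info.json shape this view expects, based only on the fields written by _compress_tables above (values are placeholders; whatever create_metadata() adds is not visible in this diff):

cache_info = {
    'tables': {
        'sessions': {'nrecs': 12345, 'size': 2345678},
        'datasets': {'nrecs': 678901, 'size': 34567890},
    },
    'location': 'https://my-bucket.s3.eu-west-1.amazonaws.com/one_cache/cache.zip',  # set for S3 destinations
    # ...plus the keys returned by create_metadata(), e.g. timestamp and API version (not shown in this diff)
}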


@@ -151,6 +183,9 @@ class CacheDownloadView(views.APIView):
permission_classes = rest_permission_classes()

def get(self, request=None, **kwargs):
cache_file = Path(TABLES_ROOT).joinpath('cache.zip')
response = FileResponse(open(cache_file, 'br'))
if TABLES_ROOT.startswith('http'):
response = HttpResponseRedirect(TABLES_ROOT.strip('/') + '/cache.zip')
else:
cache_file = Path(TABLES_ROOT).joinpath('cache.zip')
response = FileResponse(open(cache_file, 'br'))
return response
