Skip to content

Commit

Permalink
Add versioning to Azure Blob CSV Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
mzjp2 committed Sep 27, 2019
1 parent f8c85cd commit 8db124b
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions kedro/contrib/io/azure/csv_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@
""" ``AbstractDataSet`` implementation to access CSV files directly from
Microsoft's Azure blob storage.
"""
import copy
import io
from pathlib import PurePath
from typing import Any, Dict, Optional

import pandas as pd
from azure.storage.blob import BlockBlobService

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet
from kedro.io import AbstractVersionedDataSet, DataSetError, Version


class CSVBlobDataSet(DefaultArgumentsMixIn, AbstractDataSet):
# pylint: disable=too-many-instance-attributes
class CSVBlobDataSet(AbstractVersionedDataSet):
"""``CSVBlobDataSet`` loads and saves csv files in Microsoft's Azure
blob storage. It uses azure storage SDK to read and write in azure and
pandas to handle the csv file locally.
Expand All @@ -62,6 +64,7 @@ class CSVBlobDataSet(DefaultArgumentsMixIn, AbstractDataSet):
>>> assert data.equals(reloaded)
"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
Expand All @@ -72,18 +75,20 @@ def _describe(self) -> Dict[str, Any]:
blob_from_text_args=self._blob_from_text_args,
load_args=self._load_args,
save_args=self._save_args,
version=self._version,
)

# pylint: disable=too-many-arguments
def __init__(
self,
filepath: str,
filepath: PurePath,
container_name: str,
credentials: Dict[str, Any],
blob_to_text_args: Optional[Dict[str, Any]] = None,
blob_from_text_args: Optional[Dict[str, Any]] = None,
load_args: Optional[Dict[str, Any]] = None,
save_args: Optional[Dict[str, Any]] = None,
version: Version = None,
) -> None:
"""Creates a new instance of ``CSVBlobDataSet`` pointing to a
concrete csv file on Azure blob storage.
Expand All @@ -107,30 +112,51 @@ def __init__(
Here you can find all available arguments:
https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_csv.html
All defaults are preserved, but "index", which is set to False.
version: If specified, should be an instance of
``kedro.io.core.Version``. If its ``load`` attribute is
None, the latest version will be loaded. If its ``save``
attribute is None, save version will be autogenerated.
"""
self._filepath = filepath
self._container_name = container_name
self._credentials = credentials if credentials else {}
self._blob_to_text_args = blob_to_text_args if blob_to_text_args else {}
self._blob_from_text_args = blob_from_text_args if blob_from_text_args else {}
super().__init__(load_args, save_args)
self._blob_service = BlockBlobService(**self._credentials)
super().__init__(filepath, version)

# Handle default load and save arguments
self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

def _load(self) -> pd.DataFrame:
blob_service = BlockBlobService(**self._credentials)
blob = blob_service.get_blob_to_text(
load_path = str(self._get_load_path())
blob = self._blob_service.get_blob_to_text(
container_name=self._container_name,
blob_name=self._filepath,
blob_name=load_path,
**self._blob_to_text_args
)
csv_content = io.StringIO(blob.content)
return pd.read_csv(csv_content, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
blob_service = BlockBlobService(**self._credentials)
blob_service.create_blob_from_text(
save_path = str(self._get_save_path())

self._blob_service.create_blob_from_text(
container_name=self._container_name,
blob_name=self._filepath,
blob_name=save_path,
text=data.to_csv(**self._save_args),
**self._blob_from_text_args
)

def _exists(self) -> bool:
try:
load_path = self._get_load_path()
except DataSetError:
return False
return self._blob_service.exists(str(load_path))

0 comments on commit 8db124b

Please sign in to comment.