Skip to content
This repository has been archived by the owner on Feb 23, 2022. It is now read-only.

Commit

Permalink
Merge pull request #351 from multinet-app/specify_key_csv
Browse files Browse the repository at this point in the history
Specify key csv
  • Loading branch information
jjnesbitt committed Mar 26, 2020
2 parents 277f634 + e57ec10 commit e3c4e71
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 28 deletions.
63 changes: 56 additions & 7 deletions multinet/uploaders/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

from flask import Blueprint, request
from flask import current_app as app
from webargs import fields as webarg_fields
from webargs.flaskparser import use_kwargs

# Import types
from typing import Set, MutableMapping, Sequence, Any, List
from typing import Set, MutableMapping, Sequence, Any, List, Dict


bp = Blueprint("csv", __name__)
Expand All @@ -30,21 +32,45 @@ class InvalidRow(ValidationFailure):


@dataclass
class KeyFieldAlreadyExists(ValidationFailure):
    """CSV file has both existing _key field and specified key field."""

    # The user-specified key field that conflicts with the existing "_key" column.
    key: str


@dataclass
class KeyFieldDoesNotExist(ValidationFailure):
    """The specified key field does not exist."""

    # The user-specified key field that was not found among the CSV header fields.
    key: str


class MissingBody(ValidationFailure):
    """Missing body in a CSV file (no rows could be parsed from the upload)."""


def validate_csv(rows: Sequence[MutableMapping]) -> None:
def validate_csv(
rows: Sequence[MutableMapping], key_field: str = "_key", overwrite: bool = False
) -> None:
"""Perform any necessary CSV validation, and return appropriate errors."""
data_errors: List[ValidationFailure] = []

if not rows:
raise ValidationFailed([MissingBody()])

fieldnames = rows[0].keys()
if "_key" in fieldnames:

if key_field != "_key" and key_field not in fieldnames:
data_errors.append(KeyFieldDoesNotExist(key=key_field))
raise ValidationFailed(data_errors)

if "_key" in fieldnames and key_field != "_key" and not overwrite:
data_errors.append(KeyFieldAlreadyExists(key=key_field))
raise ValidationFailed(data_errors)

if key_field in fieldnames:
# Node Table, check for key uniqueness
keys = [row["_key"] for row in rows]
keys = [row[key_field] for row in rows]
unique_keys: Set[str] = set()
for key in keys:
if key in unique_keys:
Expand Down Expand Up @@ -75,9 +101,28 @@ def validate_csv(rows: Sequence[MutableMapping]) -> None:
raise ValidationFailed(data_errors)


def set_table_key(rows: List[Dict[str, str]], key: str) -> List[Dict[str, str]]:
    """Update the _key field in each row.

    Each returned row is a shallow copy of the input row with ``"_key"`` set to
    that row's value for ``key``; the input ``rows`` are left unmodified.
    """
    # Copy-then-override: the dict literal merge keeps every original column
    # and points "_key" at the chosen column's value.
    return [{**row, "_key": row[key]} for row in rows]


@bp.route("/<workspace>/<table>", methods=["POST"])
@use_kwargs(
{
"key": webarg_fields.Str(location="query"),
"overwrite": webarg_fields.Bool(location="query"),
}
)
@swag_from("swagger/csv.yaml")
def upload(workspace: str, table: str) -> Any:
def upload(
workspace: str, table: str, key: str = "_key", overwrite: bool = False
) -> Any:
"""
Store a CSV file into the database as a node or edge table.
Expand All @@ -91,10 +136,14 @@ def upload(workspace: str, table: str) -> Any:
# Read the request body into CSV format
body = decode_data(request.data)

rows = list(csv.DictReader(StringIO(body)))
# Type to a Dict rather than an OrderedDict
rows: List[Dict[str, str]] = list(csv.DictReader(StringIO(body)))

# Perform validation.
validate_csv(rows)
validate_csv(rows, key, overwrite)

if key != "_key" and overwrite:
rows = set_table_key(rows, key)

# Set the collection, paying attention to whether the data contains
# _from/_to fields.
Expand Down
20 changes: 19 additions & 1 deletion multinet/uploaders/swagger/csv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ consumes:
parameters:
- $ref: "#/parameters/workspace"
- $ref: "#/parameters/table"
- name: data
-
name: data
in: body
description: Raw CSV text
schema:
Expand All @@ -16,6 +17,23 @@ parameters:
0,picard,captain
1,riker,commander
2,data,lieutenant commander
-
name: key
in: query
description: Key Field
schema:
type: string
example: _key
-
name: overwrite
in: query
description: Overwrites the default key field if it exists
enum:
- true
- false
schema:
type: boolean
default: false

responses:
200:
Expand Down
4 changes: 3 additions & 1 deletion mypy_stubs/webargs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ class fields:
@staticmethod
def Int() -> Any: ...
@staticmethod
def Str() -> Any: ...
def Str(required: bool = False, location: str = "json") -> Any: ...
@staticmethod
def List(t: Any) -> Any: ...
@staticmethod
def Bool(required: bool = False, location: str = "json") -> Any: ...
4 changes: 4 additions & 0 deletions test/data/startrek.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_key,name,rank
0,picard,captain
1,riker,commander
2,data,lieutenant commander
4 changes: 4 additions & 0 deletions test/data/startrek_no_key_field.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name,rank
picard,captain
riker,commander
data,lieutenant commander
94 changes: 75 additions & 19 deletions test/test_csv_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,94 @@
import pytest

from multinet.errors import ValidationFailed, DecodeFailed
from multinet.uploaders.csv import validate_csv, decode_data, InvalidRow
from multinet.validation import DuplicateKey
from multinet.uploaders.csv import (
validate_csv,
decode_data,
InvalidRow,
KeyFieldAlreadyExists,
KeyFieldDoesNotExist,
)
from multinet.validation import DuplicateKey, UnsupportedTable

TEST_DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))


def test_validate_csv():
"""Tests the validate_csv function."""
duplicate_keys_file_path = os.path.join(
TEST_DATA_DIR, "clubs_invalid_duplicate_keys.csv"
)
def read_csv(filename: str):
    """Read in CSV files."""
    # Resolve the file relative to the shared test-data directory, then parse
    # it into a list of per-row dicts keyed by the CSV header.
    with open(os.path.join(TEST_DATA_DIR, filename)) as data_file:
        return list(csv.DictReader(StringIO(data_file.read())))

invalid_headers_file_path = os.path.join(
TEST_DATA_DIR, "membership_invalid_syntax.csv"
)

# Test duplicate keys
with open(duplicate_keys_file_path) as test_file:
test_file = test_file.read()
def test_missing_key_field():
    """Test that missing key fields are handled properly."""
    table_rows = read_csv("startrek_no_key_field.csv")

    with pytest.raises(ValidationFailed) as v_error:
        validate_csv(table_rows)

    # A table without a key column is reported as a single UnsupportedTable error.
    errors = v_error.value.errors
    assert len(errors) == 1
    assert errors[0] == UnsupportedTable().asdict()


def test_invalid_key_field():
    """Test that specifying a missing key field results in an error."""
    table_rows = read_csv("startrek.csv")
    bad_key = "invalid"

    with pytest.raises(ValidationFailed) as v_error:
        validate_csv(table_rows, key_field=bad_key)

    # Exactly one KeyFieldDoesNotExist error, naming the missing field.
    errors = v_error.value.errors
    assert len(errors) == 1
    assert errors[0] == KeyFieldDoesNotExist(key=bad_key).asdict()

rows = list(csv.DictReader(StringIO(test_file)))

def test_key_field_already_exists_a():
    """
    Test that specifying a key when one already exists results in an error.
    (overwrite = False)
    """
    table_rows = read_csv("startrek.csv")
    conflicting_key = "name"

    with pytest.raises(ValidationFailed) as v_error:
        validate_csv(table_rows, key_field=conflicting_key, overwrite=False)

    # Without overwrite, the clash with the existing _key column is an error.
    errors = v_error.value.errors
    assert len(errors) == 1
    assert errors[0] == KeyFieldAlreadyExists(key=conflicting_key).asdict()


def test_key_field_already_exists_b():
    """
    Test that specifying a key when one already exists doesn't result in an error.
    (overwrite = True).
    """
    # Must not raise: overwrite=True permits replacing the existing _key column.
    validate_csv(read_csv("startrek.csv"), key_field="name", overwrite=True)


def test_duplicate_keys():
    """Test that duplicate keys are handled properly."""
    table_rows = read_csv("clubs_invalid_duplicate_keys.csv")

    with pytest.raises(ValidationFailed) as v_error:
        validate_csv(table_rows)

    # Both duplicated keys ("2" and "5") must appear among the reported errors.
    expected = [DuplicateKey(key="2").asdict(), DuplicateKey(key="5").asdict()]
    errors = v_error.value.errors
    assert all(err in errors for err in expected)

# Test invalid syntax
with open(invalid_headers_file_path) as test_file:
test_file = test_file.read()

rows = list(csv.DictReader(StringIO(test_file)))
def test_invalid_headers():
"""Test that invalid headers are handled properly."""
rows = read_csv("membership_invalid_syntax.csv")
with pytest.raises(ValidationFailed) as v_error:
validate_csv(rows)

Expand All @@ -53,6 +107,8 @@ def test_validate_csv():
]
assert all(err in validation_resp for err in correct)

# Test unicode decode errors

def test_decode_failed():
    """Test that the DecodeFailed validation error is raised."""
    # UTF-16-LE bytes (BOM \xff\xfe) are not valid UTF-8, so decoding must fail.
    payload = b"\xff\xfe_\x00k\x00e\x00y\x00,\x00n\x00a\x00m\x00e\x00\n"
    with pytest.raises(DecodeFailed):
        decode_data(payload)

0 comments on commit e3c4e71

Please sign in to comment.