Skip to content

Commit

Permalink
fix: no early file closing (#1024)
Browse files Browse the repository at this point in the history
A fix for a subtle bug introduced here: https://github.com/jina-ai/jina/pull/1009/files#diff-e47ef1c2a309b9935e22d5ff1e58d7dcR30

Beforehand, only the iterator was returned and thus the file closed. Since the iterator has not iterated the file yet, it would fail in the moment it starts iterating. With the `yield` solution, the file remains open.
  • Loading branch information
maximilianwerk committed Oct 5, 2020
1 parent 61d6b0f commit b73ab6b
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 98 deletions.
34 changes: 23 additions & 11 deletions jina/clients/python/io.py
Expand Up @@ -9,9 +9,14 @@
import numpy as np


def input_lines(lines: Iterator[str] = None, filepath: str = None, size: int = None, sampling_rate: float = None,
read_mode='r') -> Iterator[Union[str, bytes]]:
""" Input function that iterates over list of strings, it can be used in the Flow API
def input_lines(
lines: Iterator[str] = None,
filepath: str = None,
size: int = None,
sampling_rate: float = None,
read_mode='r',
) -> Iterator[Union[str, bytes]]:
"""Input function that iterates over list of strings, it can be used in the Flow API
:param filepath: a text file that each line contains a document
:param lines: a list of strings, each is considered as a document
Expand All @@ -28,16 +33,22 @@ def sample(iterable):

if filepath:
with open(filepath, read_mode) as f:
return it.islice(sample(f), size)
for line in it.islice(sample(f), size):
yield line
elif lines:
return it.islice(sample(lines), size)
else:
raise ValueError('"filepath" and "lines" can not be both empty')

def input_files(patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None,
read_mode: str = None) -> Iterator[Union[str, bytes]]:
""" Input function that iterates over files, it can be used in the Flow API

def input_files(
patterns: Union[str, List[str]],
recursive: bool = True,
size: int = None,
sampling_rate: float = None,
read_mode: str = None,
) -> Iterator[Union[str, bytes]]:
"""Input function that iterates over files, it can be used in the Flow API
:param patterns: The pattern may contain simple shell-style wildcards, e.g. '\*.py', '[\*.zip, \*.gz]'
:param recursive: If recursive is true, the pattern '**' will match any files and
Expand Down Expand Up @@ -69,9 +80,10 @@ def iter_file_exts(ps):
break


def input_numpy(array: 'np.ndarray', axis: int = 0, size: int = None,
shuffle: bool = False) -> Iterator[Any]:
""" Input function that iterates over a numpy array, it can be used in the Flow API
def input_numpy(
array: 'np.ndarray', axis: int = 0, size: int = None, shuffle: bool = False
) -> Iterator[Any]:
"""Input function that iterates over a numpy array, it can be used in the Flow API
:param array: the numpy ndarray data source
:param axis: iterate over that axis
Expand Down
180 changes: 93 additions & 87 deletions tests/unit/clients/python/test_client.py
@@ -1,96 +1,102 @@
import time

import numpy as np
import pytest
import requests

from jina.clients import py_client
from jina.clients.python import PyClient
from jina.clients.python.io import input_files, input_numpy
from jina.clients.python.io import input_files
from jina.enums import ClientMode
from jina.flow import Flow
from jina.parser import set_gateway_parser
from jina.peapods.gateway import RESTGatewayPea
from jina.proto.jina_pb2 import Document
from tests import JinaTestCase


class ClientTestCase(JinaTestCase):

def test_client(self):
f = Flow().add(uses='_pass')
with f:
print(py_client(port_expose=f.port_expose).call_unary(b'a1234', mode=ClientMode.INDEX))

def tearDown(self) -> None:
super().tearDown()
time.sleep(3)

def test_check_input(self):
input_fn = iter([b'1234', b'45467'])
PyClient.check_input(input_fn)
input_fn = iter([Document(), Document()])
PyClient.check_input(input_fn)
bad_input_fn = iter([b'1234', '45467', [12, 2, 3]])
self.assertRaises(TypeError, PyClient.check_input, bad_input_fn)
bad_input_fn = iter([Document(), None])
self.assertRaises(TypeError, PyClient.check_input, bad_input_fn)

def test_gateway_ready(self):
p = set_gateway_parser().parse_args([])
with RESTGatewayPea(p):
a = requests.get(f'http://0.0.0.0:{p.port_expose}/ready')
assert a.status_code == 200

with RESTGatewayPea(p):
a = requests.post(f'http://0.0.0.0:{p.port_expose}/api/ass')
assert a.status_code == 405

def test_gateway_index(self):
f = Flow(rest_api=True).add(uses='_pass')
with f:
a = requests.post(f'http://0.0.0.0:{f.port_expose}/api/index',
json={'data': [
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AvdGjTZeOlQq07xSYPgJjlWRwfWEBx2+CgAVrPrP+O5ghhOa+a0cocoWnaMJFAsBuCQCgiJOKDBcIQTiLieOrPD/cp/6iZ/Iu4HqAh5dGzggIQVJI3WqTxwVTDjs5XJOy38AlgHoaKgY+xJEXeFTyR7FOfF7JNWjs3b8evQE6B2dTDvQZx3n3Rz6rgOtVlaZRLvR9geCAxuY3G+0mepEAhrTISES3bwPWYYi48OUrQOc//IaJeij9xZGGmDIG9kc73fNI7eA8VMBAAD//0SxXMMT90UdAAAAAElFTkSuQmCC']})

j = a.json()
self.assertTrue('index' in j)
assert len(j['index']['docs']) == 2
assert j['index']['docs'][0][
'uri'] == 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC'
assert a.status_code == 200

def test_gateway_index_with_args(self):
f = Flow(rest_api=True).add(uses='_pass')
with f:
a = requests.post(f'http://0.0.0.0:{f.port_expose}/api/index',
json={'data': [
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AvdGjTZeOlQq07xSYPgJjlWRwfWEBx2+CgAVrPrP+O5ghhOa+a0cocoWnaMJFAsBuCQCgiJOKDBcIQTiLieOrPD/cp/6iZ/Iu4HqAh5dGzggIQVJI3WqTxwVTDjs5XJOy38AlgHoaKgY+xJEXeFTyR7FOfF7JNWjs3b8evQE6B2dTDvQZx3n3Rz6rgOtVlaZRLvR9geCAxuY3G+0mepEAhrTISES3bwPWYYi48OUrQOc//IaJeij9xZGGmDIG9kc73fNI7eA8VMBAAD//0SxXMMT90UdAAAAAElFTkSuQmCC'],
})
j = a.json()
self.assertTrue('index' in j)
assert len(j['index']['docs']) == 2
assert j['index']['docs'][0][
'uri'] == 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC'
assert a.status_code == 200

def test_io_files(self):
PyClient.check_input(input_files('*.*'))
PyClient.check_input(input_files('*.*', recursive=True))
PyClient.check_input(input_files('*.*', size=2))
PyClient.check_input(input_files('*.*', size=2, read_mode='rb'))
PyClient.check_input(input_files('*.*', sampling_rate=.5))

f = Flow().add(uses='- !URI2Buffer {}')

def validate_mime_type(req):
for d in req.index.docs:
assert d.mime_type == 'text/x-python'

with f:
f.index(input_files('*.py'), validate_mime_type)

def test_io_np(self):
PyClient.check_input(input_numpy(np.random.random([100, 4, 2])))
PyClient.check_input(['asda', 'dsadas asdasd'])


def test_client():
f = Flow().add(uses='_pass')
with f:
print(
py_client(port_expose=f.port_expose).call_unary(
b'a1234', mode=ClientMode.INDEX
)
)


def test_check_input():
input_fn = iter([b'1234', b'45467'])
PyClient.check_input(input_fn)
input_fn = iter([Document(), Document()])
PyClient.check_input(input_fn)
bad_input_fn_1 = iter([b'1234', '45467', [12, 2, 3]])
with pytest.raises(TypeError):
PyClient.check_input(bad_input_fn_1)
bad_input_fn_2 = iter([Document(), None])
with pytest.raises(TypeError):
PyClient.check_input(bad_input_fn_2)


def test_gateway_ready():
p = set_gateway_parser().parse_args([])
with RESTGatewayPea(p):
a = requests.get(f'http://0.0.0.0:{p.port_expose}/ready')
assert a.status_code == 200

with RESTGatewayPea(p):
a = requests.post(f'http://0.0.0.0:{p.port_expose}/api/ass')
assert a.status_code == 405


def test_gateway_index():
f = Flow(rest_api=True).add(uses='_pass')
with f:
a = requests.post(
f'http://0.0.0.0:{f.port_expose}/api/index',
json={
'data': [
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AvdGjTZeOlQq07xSYPgJjlWRwfWEBx2+CgAVrPrP+O5ghhOa+a0cocoWnaMJFAsBuCQCgiJOKDBcIQTiLieOrPD/cp/6iZ/Iu4HqAh5dGzggIQVJI3WqTxwVTDjs5XJOy38AlgHoaKgY+xJEXeFTyR7FOfF7JNWjs3b8evQE6B2dTDvQZx3n3Rz6rgOtVlaZRLvR9geCAxuY3G+0mepEAhrTISES3bwPWYYi48OUrQOc//IaJeij9xZGGmDIG9kc73fNI7eA8VMBAAD//0SxXMMT90UdAAAAAElFTkSuQmCC',
]
},
)

j = a.json()
assert 'index' in j
assert len(j['index']['docs']) == 2
assert (
j['index']['docs'][0]['uri']
== 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC'
)
assert a.status_code == 200


def test_gateway_index_with_args():
f = Flow(rest_api=True).add(uses='_pass')
with f:
a = requests.post(
f'http://0.0.0.0:{f.port_expose}/api/index',
json={
'data': [
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AvdGjTZeOlQq07xSYPgJjlWRwfWEBx2+CgAVrPrP+O5ghhOa+a0cocoWnaMJFAsBuCQCgiJOKDBcIQTiLieOrPD/cp/6iZ/Iu4HqAh5dGzggIQVJI3WqTxwVTDjs5XJOy38AlgHoaKgY+xJEXeFTyR7FOfF7JNWjs3b8evQE6B2dTDvQZx3n3Rz6rgOtVlaZRLvR9geCAxuY3G+0mepEAhrTISES3bwPWYYi48OUrQOc//IaJeij9xZGGmDIG9kc73fNI7eA8VMBAAD//0SxXMMT90UdAAAAAElFTkSuQmCC',
],
},
)
j = a.json()
assert 'index' in j
assert len(j['index']['docs']) == 2
assert (
j['index']['docs'][0]['uri']
== 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAA2ElEQVR4nADIADf/AxWcWRUeCEeBO68T3u1qLWarHqMaxDnxhAEaLh0Ssu6ZGfnKcjP4CeDLoJok3o4aOPYAJocsjktZfo4Z7Q/WR1UTgppAAdguAhR+AUm9AnqRH2jgdBZ0R+kKxAFoAME32BL7fwQbcLzhw+dXMmY9BS9K8EarXyWLH8VYK1MACkxlLTY4Eh69XfjpROqjE7P0AeBx6DGmA8/lRRlTCmPkL196pC0aWBkVs2wyjqb/LABVYL8Xgeomjl3VtEMxAeaUrGvnIawVh/oBAAD///GwU6v3yCoVAAAAAElFTkSuQmCC'
)
assert a.status_code == 200


def test_mime_type():

f = Flow().add(uses='- !URI2Buffer {}')

def validate_mime_type(req):
for d in req.index.docs:
assert d.mime_type == 'text/x-python'

with f:
f.index(input_files('*.py'), validate_mime_type)
28 changes: 28 additions & 0 deletions tests/unit/clients/python/test_io.py
@@ -0,0 +1,28 @@
import numpy as np
import os

from jina.clients.python import PyClient
from jina.clients.python.io import input_files, input_lines, input_numpy


def test_read_file(tmpdir):
input_filepath = os.path.join(tmpdir, 'input_file.csv')
with open(input_filepath, 'w') as input_file:
input_file.writelines(["1\n", "2\n", "3\n"])
result = list(input_lines(filepath=input_filepath, size=2))
assert len(result) == 2
assert result[0] == "1\n"
assert result[1] == "2\n"


def test_io_files():
PyClient.check_input(input_files('*.*'))
PyClient.check_input(input_files('*.*', recursive=True))
PyClient.check_input(input_files('*.*', size=2))
PyClient.check_input(input_files('*.*', size=2, read_mode='rb'))
PyClient.check_input(input_files('*.*', sampling_rate=0.5))


def test_io_np():
PyClient.check_input(input_numpy(np.random.random([100, 4, 2])))
PyClient.check_input(['asda', 'dsadas asdasd'])

0 comments on commit b73ab6b

Please sign in to comment.