Do not unpack sqlite3.Row to list
Florian Blanchet authored and fblanchetNaN committed Aug 2, 2022
1 parent b0b6528 commit 20c58d7
Showing 3 changed files with 112 additions and 80 deletions.
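The rationale for the change: with a row_factory of sqlite3.Row (which the query helpers below rely on, since they index rows by column name), a fetched row is already indexable by position and by name and converts cheaply to a tuple, so copying every row into a list buys nothing. A minimal standalone sketch of that behaviour using plain sqlite3, not the qcodes API:

# Standalone sketch (plain sqlite3, not the qcodes API) of why rows no
# longer need to be unpacked: with sqlite3.Row as row_factory, a fetched
# row supports positional and name-based access and converts cheaply to a
# tuple when an exact tuple is required.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("CREATE TABLE results (x REAL, y REAL)")
conn.execute("INSERT INTO results VALUES (1.0, 2.0)")

row = conn.execute("SELECT x, y FROM results").fetchone()
assert row["x"] == row[0] == 1.0          # name-based and positional access
assert tuple(row) == (1.0, 2.0)           # tuple conversion is a thin copy
assert len(row) == 2                      # sized and iterable like a tuple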
157 changes: 85 additions & 72 deletions qcodes/dataset/sqlite/queries.py
@@ -136,14 +136,17 @@ def _build_data_query(table_name: str,
return query


@deprecate('This method does not accurately represent the dataset.',
'Use `get_parameter_data` instead.')
def get_data(conn: ConnectionPlus,
table_name: str,
columns: List[str],
start: Optional[int] = None,
end: Optional[int] = None,
) -> List[List[Any]]:
@deprecate(
"This method does not accurately represent the dataset.",
"Use `get_parameter_data` instead.",
)
def get_data(
conn: ConnectionPlus,
table_name: str,
columns: List[str],
start: Optional[int] = None,
end: Optional[int] = None,
) -> List[Tuple[Any, ...]]:
"""
Get data from the columns of a table.
Allows to specify a range of rows (1-based indexing, both ends are
@@ -164,7 +167,7 @@ def get_data(conn: ConnectionPlus,
'get_data: requested data without specifying parameters/columns.'
'Returning empty list.'
)
return [[]]
return [tuple()]
query = _build_data_query(table_name, columns, start, end)
c = atomic_transaction(conn, query)
res = many_many(c, *columns)
@@ -326,7 +329,9 @@ def get_parameter_data_for_one_paramtree(
return param_data, n_rows


def _expand_data_to_arrays(data: List[List[Any]], paramspecs: Sequence[ParamSpecBase]) -> None:
def _expand_data_to_arrays(
data: List[Tuple[Any, ...]], paramspecs: Sequence[ParamSpecBase]
) -> None:
types = [param.type for param in paramspecs]
# if we have array type parameters expand all other parameters
# to arrays
@@ -335,55 +340,60 @@ def _expand_data_to_arrays(data: List[List[Any]], paramspecs: Sequence[ParamSpec
if ('numeric' in types or 'text' in types
or 'complex' in types):
first_array_element = types.index('array')
numeric_elms = [i for i, x in enumerate(types)
if x == "numeric"]
complex_elms = [i for i, x in enumerate(types)
if x == 'complex']
text_elms = [i for i, x in enumerate(types)
if x == "text"]
for row in data:
for element in numeric_elms:
row[element] = np.full_like(row[first_array_element],
row[element],
dtype=np.dtype(np.float64))
# todo should we handle int/float types here
# we would in practice have to perform another
# loop to check that all elements of a given column can be cast to
# int without losing precision before choosing an integer
# representation of the array
for element in complex_elms:
row[element] = np.full_like(row[first_array_element],
row[element],
dtype=np.dtype(np.complex128))
for element in text_elms:
strlen = len(row[element])
row[element] = np.full_like(row[first_array_element],
row[element],
dtype=np.dtype(f'U{strlen}'))

for row in data:
types_mapping: Dict[int, Callable[[str], np.dtype[Any]]] = {}
for i, x in enumerate(types):
if x == "numeric":
types_mapping[i] = lambda _: np.dtype(np.float64)
elif x == "complex":
types_mapping[i] = lambda _: np.dtype(np.complex128)
elif x == "text":
types_mapping[i] = lambda array: np.dtype(f"U{len(array)}")

for i_row, row in enumerate(data):
# todo should we handle int/float types here
# we would in practice have to perform another
# loop to check that all elements of a given column can be cast to
# int without losing precision before choosing an integer
# representation of the array
data[i_row] = tuple(
np.full_like(
row[first_array_element], array, dtype=types_mapping[i](array)
)
if i in types_mapping
else array
for i, array in enumerate(row)
)

for i_row, row in enumerate(data):
# now expand all one element arrays to match the expected size
# one element arrays are introduced if scalar values are stored
# with an explicit array storage type
sizes = tuple(array.size for array in row)
max_size = max(sizes)
max_index = sizes.index(max_size)

max_size = 0
for i, array in enumerate(row):
if array.size != max_size:
if array.size == 1:
row[i] = np.full_like(row[max_index],
row[i],
dtype=row[i].dtype)
else:
log.warning(f"Cannot expand array of size {array.size} "
f"to size {row[max_index].size}")


def _get_data_for_one_param_tree(conn: ConnectionPlus, table_name: str,
interdeps: InterDependencies_, output_param: str,
start: Optional[int], end: Optional[int]) \
-> Tuple[List[List[Any]], List[ParamSpecBase], int]:
if array.size > max_size:
if max_size > 1:
log.warning(
f"Cannot expand array of size {max_size} "
f"to size {array.size}"
)
max_size, max_index = array.size, i

data[i_row] = tuple(
np.full_like(row[max_index], array, dtype=array.dtype)
if array.size != max_size
else array
for array in row
)


def _get_data_for_one_param_tree(
conn: ConnectionPlus,
table_name: str,
interdeps: InterDependencies_,
output_param: str,
start: Optional[int],
end: Optional[int],
) -> Tuple[List[Tuple[Any, ...]], List[ParamSpecBase], int]:
output_param_spec = interdeps._id_to_paramspec[output_param]
# find all the dependencies of this param

@@ -400,11 +410,13 @@ def _get_data_for_one_param_tree(conn: ConnectionPlus, table_name: str,
return res, paramspecs, n_rows


@deprecate('This method does not accurately represent the dataset.',
'Use `get_parameter_data` instead.')
def get_values(conn: ConnectionPlus,
table_name: str,
param_name: str) -> List[List[Any]]:
@deprecate(
"This method does not accurately represent the dataset.",
"Use `get_parameter_data` instead.",
)
def get_values(
conn: ConnectionPlus, table_name: str, param_name: str
) -> List[Tuple[Any, ...]]:
"""
Get the not-null values of a parameter
@@ -426,12 +438,14 @@ def get_values(conn: ConnectionPlus,
return res


def get_parameter_tree_values(conn: ConnectionPlus,
result_table_name: str,
toplevel_param_name: str,
*other_param_names: str,
start: Optional[int] = None,
end: Optional[int] = None) -> List[List[Any]]:
def get_parameter_tree_values(
conn: ConnectionPlus,
result_table_name: str,
toplevel_param_name: str,
*other_param_names: str,
start: Optional[int] = None,
end: Optional[int] = None,
) -> List[Tuple[Any, ...]]:
"""
Get the values of one or more columns from a data table. The rows
retrieved are the rows where the 'toplevel_param_name' column has
@@ -491,9 +505,9 @@ def get_parameter_tree_values(conn: ConnectionPlus,


@deprecate(alternative="get_parameter_data")
def get_setpoints(conn: ConnectionPlus,
table_name: str,
param_name: str) -> Dict[str, List[List[Any]]]:
def get_setpoints(
conn: ConnectionPlus, table_name: str, param_name: str
) -> Dict[str, List[Tuple[Any, ...]]]:
"""
Get the setpoints for a given dependent parameter
@@ -546,7 +560,7 @@ def get_setpoints(conn: ConnectionPlus,
setpoint_names = cast(List[str], setpoint_names)

# get the actual setpoint data
output: Dict[str, List[List[Any]]] = {}
output: Dict[str, List[Tuple[Any, ...]]] = {}
for sp_name in setpoint_names:
sql = f"""
SELECT {sp_name}
@@ -783,8 +797,7 @@ def _get_dependents(conn: ConnectionPlus,
return res


def _get_dependencies(conn: ConnectionPlus,
layout_id: int) -> List[List[int]]:
def _get_dependencies(conn: ConnectionPlus, layout_id: int) -> List[Tuple[int, ...]]:
"""
Get the dependencies of a certain dependent variable (indexed by its
layout_id)
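For readers of the _expand_data_to_arrays hunk above: the rewrite keeps the same broadcasting idea, but builds a new tuple per row instead of mutating a list in place. A simplified, self-contained sketch of that idea with hypothetical data (not the qcodes function itself, which keys off each parameter's storage type — 'numeric', 'text', 'complex', 'array' — rather than isinstance checks):

# Simplified sketch of the broadcasting in _expand_data_to_arrays: scalar
# numeric/text values stored next to an array-type parameter are expanded
# with np.full_like so every column of a row ends up with the same size.
from typing import Any, List, Tuple
import numpy as np

def _as_array(value: Any, template: np.ndarray) -> np.ndarray:
    # Broadcast a scalar to the template's shape; strings get a unicode
    # dtype wide enough for the value, numbers become float64 arrays.
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, str):
        return np.full_like(template, value, dtype=np.dtype(f"U{len(value)}"))
    return np.full_like(template, value, dtype=np.float64)

rows: List[Tuple[Any, ...]] = [
    (np.array([0.0, 0.5, 1.0]), 7.0, "on"),  # (array param, numeric, text)
]
for i_row, row in enumerate(rows):
    rows[i_row] = tuple(_as_array(col, row[0]) for col in row)

print(rows[0][1])  # [7. 7. 7.]
print(rows[0][2])  # ['on' 'on' 'on']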
33 changes: 26 additions & 7 deletions qcodes/dataset/sqlite/query_helpers.py
@@ -42,7 +42,21 @@ def one(curr: sqlite3.Cursor, column: Union[int, str]) -> Any:
return res[0][column]


def many(curr: sqlite3.Cursor, *columns: str) -> List[Any]:
def _need_to_select(curr: sqlite3.Cursor, *columns: str) -> bool:
"""
Return True if the column description of the last query does not exactly match the requested columns
"""
return tuple(c[0] for c in curr.description) != columns


def _select_columns(row: sqlite3.Row, *columns: str) -> Tuple[Any, ...]:
"""
sqlite3.Row({key: value, key2: value2}), (key2,) -> (value2,)
"""
return tuple(row[c] for c in columns)


def many(curr: sqlite3.Cursor, *columns: str) -> Tuple[Any, ...]:
"""Get the values of many columns from one row
Args:
curr: cursor to operate on
@@ -55,10 +69,14 @@ def many(curr: sqlite3.Cursor, *columns: str) -> List[Any]:
if len(res) > 1:
raise RuntimeError("Expected only one row")
else:
return [res[0][c] for c in columns]
return (
_select_columns(res[0], *columns)
if _need_to_select(curr, *columns)
else res[0]
)


def many_many(curr: sqlite3.Cursor, *columns: str) -> List[List[Any]]:
def many_many(curr: sqlite3.Cursor, *columns: str) -> List[Tuple[Any, ...]]:
"""Get all values of many columns
Args:
curr: cursor to operate on
@@ -68,10 +86,11 @@ def many_many(curr: sqlite3.Cursor, *columns: str) -> List[List[Any]]:
list of lists of values
"""
res = curr.fetchall()
results = []
for r in res:
results.append([r[c] for c in columns])
return results
return (
[_select_columns(r, *columns) for r in res]
if _need_to_select(curr, *columns)
else res
)


def select_one_where(
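The new _need_to_select/_select_columns pair gives many and many_many a fast path: if the cursor's description already names exactly the requested columns, in order, the sqlite3.Row objects are returned unchanged; otherwise each row is re-projected into a tuple. A condensed, standalone sketch of that logic (not the qcodes helpers themselves):

# Condensed sketch of the fast path: compare the cursor description to the
# requested columns and only re-project rows when they differ.
import sqlite3
from typing import Any, List, Tuple

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("CREATE TABLE t (a INTEGER, b INTEGER)")
conn.executemany("INSERT INTO t VALUES (?, ?)", [(1, 2), (3, 4)])

def fetch_columns(curr: sqlite3.Cursor, *columns: str) -> List[Tuple[Any, ...]]:
    rows = curr.fetchall()
    if tuple(c[0] for c in curr.description) == columns:
        return rows  # fast path: Row objects already match the request
    return [tuple(r[c] for c in columns) for r in rows]

print(fetch_columns(conn.execute("SELECT a, b FROM t"), "a", "b"))  # Rows as-is
print(fetch_columns(conn.execute("SELECT a, b FROM t"), "b"))       # [(2,), (4,)]

Since callers such as get_data and get_parameter_tree_values typically select exactly the columns they name, the fast path presumably removes a per-row Python-level copy that was pure overhead in the old many_many.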
2 changes: 1 addition & 1 deletion qcodes/tests/dataset/test_sqlite_base.py
@@ -239,7 +239,7 @@ def test_get_data_no_columns(scalar_dataset):
with pytest.warns(QCoDeSDeprecationWarning) as record:
ref = mut_queries.get_data(ds.conn, ds.table_name, [])

assert ref == [[]]
assert ref == [tuple()]
assert len(record) == 2
assert str(record[0].message).startswith("The function <get_data>")
assert str(record[1].message).startswith("get_data")
