Skip to content

Commit

Permalink
Merge pull request #824 from jenshnielsen/mypy_exp
Browse files Browse the repository at this point in the history
Run mypy type checking on the codebase with Travis
  • Loading branch information
WilliamHPNielsen committed Apr 4, 2018
2 parents c8d8588 + e1003d9 commit a5645fb
Show file tree
Hide file tree
Showing 48 changed files with 283 additions and 214 deletions.
5 changes: 4 additions & 1 deletion .travis.yml
Expand Up @@ -35,7 +35,10 @@ script:
- py.test --cov=qcodes --cov-report xml --cov-config=.coveragerc
# build docs with warnings as errors
- |
cd ../docs
cd ..
mypy qcodes --ignore-missing-imports
- |
cd docs
make SPHINXOPTS="-W" html-api
- cd ..

Expand Down
4 changes: 2 additions & 2 deletions qcodes/__init__.py
Expand Up @@ -10,7 +10,7 @@
# we dont want spyder to reload qcodes as this will overwrite the default station
# instrument list and running monitor
add_to_spyder_UMR_excludelist('qcodes')
config = Config()
config = Config() # type: Config

from qcodes.version import __version__

Expand Down Expand Up @@ -88,7 +88,7 @@
del _c

try:
get_ipython() # Check if we are in iPython
get_ipython() # type: ignore # Check if we are in iPython
from qcodes.utils.magic import register_magic_class
_register_magic = config.core.get('register_magic', False)
if _register_magic is not False:
Expand Down
5 changes: 3 additions & 2 deletions qcodes/config/config.py
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path

import jsonschema
from typing import Dict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,8 +88,8 @@ class Config():
defaults = None
defaults_schema = None

_diff_config = {}
_diff_schema = {}
_diff_config: Dict[str, dict] = {}
_diff_schema: Dict[str, dict] = {}

def __init__(self):
self.defaults, self.defaults_schema = self.load_default()
Expand Down
3 changes: 2 additions & 1 deletion qcodes/data/data_set.py
Expand Up @@ -5,6 +5,7 @@
from traceback import format_exc
from copy import deepcopy
from collections import OrderedDict
from typing import Dict, Callable

from .gnuplot_format import GNUPlotFormat
from .io import DiskIO
Expand Down Expand Up @@ -168,7 +169,7 @@ class DataSet(DelegateAttributes):
default_formatter = GNUPlotFormat()
location_provider = FormatLocation()

background_functions = OrderedDict()
background_functions: Dict[str, Callable] = OrderedDict()

def __init__(self, location=None, arrays=None, formatter=None, io=None,
write_period=5):
Expand Down
4 changes: 2 additions & 2 deletions qcodes/data/location.py
Expand Up @@ -3,7 +3,7 @@
import re
import string

from qcodes import config
import qcodes.config

class SafeFormatter(string.Formatter):

Expand Down Expand Up @@ -83,7 +83,7 @@ class FormatLocation:
as '{date:%Y-%m-%d}' or '{counter:03}'
"""

default_fmt = config['core']['default_fmt']
default_fmt = qcodes.config['core']['default_fmt']

def __init__(self, fmt=None, fmt_date=None, fmt_time=None,
fmt_counter=None, record=None):
Expand Down
6 changes: 3 additions & 3 deletions qcodes/dataset/data_set.py
Expand Up @@ -179,8 +179,8 @@ def _new(self, name, exp_id, specs: SPECS = None, values=None,
Actually perform all the side effects needed for
the creation of a new dataset.
"""
_, run_id, _ = create_run(self.conn, exp_id, name,
specs, values, metadata)
_, run_id, __ = create_run(self.conn, exp_id, name,
specs, values, metadata)

# this is really the UUID (an ever increasing count in the db)
self.run_id = run_id
Expand Down Expand Up @@ -440,7 +440,7 @@ def modify_results(self, start_index: int,
flattened_keys,
flattened_values)

def add_parameter_values(self, spec: ParamSpec, values: List[VALUES]):
def add_parameter_values(self, spec: ParamSpec, values: VALUES):
"""
Add a parameter to the DataSet and associates result values with the
new parameter.
Expand Down
4 changes: 2 additions & 2 deletions qcodes/dataset/experiment_container.py
Expand Up @@ -236,8 +236,8 @@ def load_experiment_by_name(name: str,
for row in rows:
s = f"exp_id:{row['exp_id']} ({row['name']}-{row['sample_name']}) started at({row['start_time']})"
_repr.append(s)
_repr = "\n".join(_repr)
raise ValueError(f"Many experiments matching your request found {_repr}")
_repr_str = "\n".join(_repr)
raise ValueError(f"Many experiments matching your request found {_repr_str}")
else:
e.exp_id = rows[0]['exp_id']
return e
19 changes: 12 additions & 7 deletions qcodes/dataset/measurements.py
Expand Up @@ -11,7 +11,7 @@

import qcodes as qc
from qcodes import Station
from qcodes.instrument.parameter import ArrayParameter, _BaseParameter
from qcodes.instrument.parameter import ArrayParameter, _BaseParameter, Parameter
from qcodes.dataset.experiment_container import Experiment
from qcodes.dataset.param_spec import ParamSpec
from qcodes.dataset.data_set import DataSet
Expand All @@ -29,7 +29,7 @@ class DataSaver:
datasaving to the database
"""

default_callback = None
default_callback: Optional[dict] = None

def __init__(self, dataset: DataSet, write_period: float,
parameters: Dict[str, ParamSpec]) -> None:
Expand All @@ -45,7 +45,7 @@ def __init__(self, dataset: DataSet, write_period: float,
self._known_parameters = list(parameters.keys())
self._results: List[dict] = [] # will be filled by addResult
self._last_save_time = monotonic()
self._known_dependencies: Dict[str, str] = {}
self._known_dependencies: Dict[str, List[str]] = {}
for param, parspec in parameters.items():
if parspec.depends_on != '':
self._known_dependencies.update({str(param):
Expand Down Expand Up @@ -152,6 +152,7 @@ def add_result(self,
# For compatibility with the old Loop, setpoints are
# tuples of numbers (usually tuple(np.linspace(...))
if hasattr(value, '__len__') and not(isinstance(value, str)):
value = cast(Union[Sequence,np.ndarray], value)
res_dict.update({param: value[index]})
else:
res_dict.update({param: value})
Expand Down Expand Up @@ -398,6 +399,7 @@ def register_parameter(
name = str(parameter)

if isinstance(parameter, ArrayParameter):
parameter = cast(ArrayParameter, parameter)
if parameter.setpoint_names:
spname = (f'{parameter._instrument.name}_'
f'{parameter.setpoint_names[0]}')
Expand All @@ -416,8 +418,10 @@ def register_parameter(
label=splabel, unit=spunit)

self.parameters[spname] = sp
setpoints = setpoints if setpoints else ()
setpoints += (spname,)
my_setpoints: Tuple[Union[_BaseParameter, str], ...] = setpoints if setpoints else ()
my_setpoints += (spname,)
else:
my_setpoints = setpoints

# We currently treat ALL parameters as 'numeric' and fail to add them
# to the dataset if they can not be unraveled to fit that description
Expand All @@ -426,12 +430,13 @@ def register_parameter(
# requirement later and start saving binary blobs with the datasaver,
# but for now binary blob saving is referred to using the DataSet
# API directly
parameter = cast(Union[Parameter, ArrayParameter], parameter)
paramtype = 'numeric'
label = parameter.label
unit = parameter.unit

if setpoints:
sp_strings = [str(sp) for sp in setpoints]
if my_setpoints:
sp_strings = [str(sp) for sp in my_setpoints]
else:
sp_strings = []
if basis:
Expand Down
9 changes: 4 additions & 5 deletions qcodes/dataset/plotting.py
Expand Up @@ -13,11 +13,10 @@
log = logging.getLogger(__name__)
DB = qc.config["core"]["db_location"]

mplaxes = matplotlib.axes.Axes

def plot_by_id(run_id: int,
axes: Optional[Union[mplaxes,
Sequence[mplaxes]]]=None) -> List[mplaxes]:
axes: Optional[Union[matplotlib.axes.Axes,
Sequence[matplotlib.axes.Axes]]]=None) -> List[matplotlib.axes.Axes]:
def set_axis_labels(ax, data):
if data[0]['label'] == '':
lbl = data[0]['name']
Expand Down Expand Up @@ -50,7 +49,7 @@ def set_axis_labels(ax, data):
"""
alldata = get_data_by_id(run_id)
nplots = len(alldata)
if isinstance(axes, mplaxes):
if isinstance(axes, matplotlib.axes.Axes):
axes = [axes]

if axes is None:
Expand Down Expand Up @@ -115,7 +114,7 @@ def set_axis_labels(ax, data):

def plot_on_a_plain_grid(x: np.ndarray, y: np.ndarray,
z: np.ndarray,
ax: mplaxes) -> mplaxes:
ax: matplotlib.axes.Axes) -> matplotlib.axes.Axes:
"""
Plot a heatmap of z using x and y as axes. Assumes that the data
are rectangular, i.e. that x and y together describe a rectangular
Expand Down
6 changes: 3 additions & 3 deletions qcodes/dataset/sqlite_base.py
Expand Up @@ -107,7 +107,7 @@ def _convert_array(text: bytes) -> ndarray:
return np.load(out)


def one(curr: sqlite3.Cursor, column: str) -> Any:
def one(curr: sqlite3.Cursor, column: Union[int, str]) -> Any:
"""Get the value of one column from one row
Args:
curr: cursor to operate on
Expand Down Expand Up @@ -408,11 +408,11 @@ def insert_many_values(conn: sqlite3.Connection,
# According to the SQLite changelog, the version number
# to check against below
# ought to be 3.7.11, but that fails on Travis
if LooseVersion(version) <= LooseVersion('3.8.2'):
if LooseVersion(str(version)) <= LooseVersion('3.8.2'):
max_var = qc.SQLiteSettings.limits['MAX_COMPOUND_SELECT']
else:
max_var = qc.SQLiteSettings.limits['MAX_VARIABLE_NUMBER']
rows_per_transaction = int(max_var/no_of_columns)
rows_per_transaction = int(int(max_var)/no_of_columns)

_columns = ",".join(columns)
_values = "(" + ",".join(["?"] * len(values[0])) + ")"
Expand Down
5 changes: 4 additions & 1 deletion qcodes/dataset/sqlite_settings.py
Expand Up @@ -2,7 +2,7 @@
from typing import Tuple, Dict, Union


def _read_settings() -> Tuple[Dict[str, str],
def _read_settings() -> Tuple[Dict[str, Union[str,int]],
Dict[str, Union[bool, int, str]]]:
"""
Function to read the local SQLite settings at import time.
Expand All @@ -19,6 +19,7 @@ def _read_settings() -> Tuple[Dict[str, str],
"""
# For the limits, there are known default values
# (known from https://www.sqlite.org/limits.html)
DEFAULT_LIMITS: Dict[str, Union[str, int]]
DEFAULT_LIMITS = {'MAX_ATTACHED': 10,
'MAX_COLUMN': 2000,
'MAX_COMPOUND_SELECT': 500,
Expand All @@ -35,6 +36,7 @@ def _read_settings() -> Tuple[Dict[str, str],
opt_num = 0
resp = ''

limits: Dict[str, Union[str,int]]
limits = DEFAULT_LIMITS.copy()
settings = {}

Expand All @@ -47,6 +49,7 @@ def _read_settings() -> Tuple[Dict[str, str],
opt_num += 1
lst = resp.split('=')
if len(lst) == 2:
val: Union[str,int]
(param, val) = lst
if val.isnumeric():
val = int(val)
Expand Down
32 changes: 18 additions & 14 deletions qcodes/instrument/base.py
Expand Up @@ -3,14 +3,15 @@
import time
import warnings
import weakref
from typing import Sequence, Optional, Dict, Union, Callable, Any, List
from typing import Sequence, Optional, Dict, Union, Callable, Any, List, TYPE_CHECKING, cast

import numpy as np

if TYPE_CHECKING:
from qcodes.instrumet.channel import ChannelList
from qcodes.utils.helpers import DelegateAttributes, strip_attrs, full_class
from qcodes.utils.metadata import Metadatable
from qcodes.utils.validators import Anything
from .parameter import Parameter
from .parameter import Parameter, _BaseParameter
from .function import Function

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,9 +46,9 @@ def __init__(self, name: str,
metadata: Optional[Dict]=None, **kwargs) -> None:
self.name = str(name)

self.parameters = {}
self.functions = {}
self.submodules = {}
self.parameters = {} # type: Dict[str, _BaseParameter]
self.functions = {} # type: Dict[str, Function]
self.submodules = {} # type: Dict[str, Union['InstrumentBase', 'ChannelList']]
super().__init__(**kwargs)

def add_parameter(self, name: str,
Expand Down Expand Up @@ -109,7 +110,7 @@ def add_function(self, name: str, **kwargs) -> None:
func = Function(name=name, instrument=self, **kwargs)
self.functions[name] = func

def add_submodule(self, name: str, submodule: Metadatable) -> None:
def add_submodule(self, name: str, submodule: Union['InstrumentBase', 'ChannelList']) -> None:
"""
Bind one submodule to this instrument.
Expand Down Expand Up @@ -360,7 +361,9 @@ class Instrument(InstrumentBase):

shared_kwargs = ()

_all_instruments = {}
_all_instruments = {} # type: Dict[str, weakref.ref[Instrument]]
_type = None
_instances = [] # type: List[weakref.ref]

def __init__(self, name: str,
metadata: Optional[Dict]=None, **kwargs) -> None:
Expand All @@ -377,7 +380,7 @@ def __init__(self, name: str,

self.record_instance(self)

def get_idn(self) -> Dict:
def get_idn(self) -> Dict[str, Optional[str]]:
"""
Parse a standard VISA '\*IDN?' response into an ID dict.
Expand All @@ -399,6 +402,7 @@ def get_idn(self) -> Dict:
idstr = self.ask('*IDN?')
# form is supposed to be comma-separated, but we've seen
# other separators occasionally
idparts = [] # type: List[Optional[str]]
for separator in ',;:':
# split into no more than 4 parts, so we don't lose info
idparts = [p.strip() for p in idstr.split(separator, 3)]
Expand Down Expand Up @@ -580,14 +584,14 @@ def find_instrument(cls, name: str,
if ins is None:
del cls._all_instruments[name]
raise KeyError('Instrument {} has been removed'.format(name))

inst = cast('Instrument', ins)
if instrument_class is not None:
if not isinstance(ins, instrument_class):
if not isinstance(inst, instrument_class):
raise TypeError(
'Instrument {} is {} but {} was requested'.format(
name, type(ins), instrument_class))
name, type(inst), instrument_class))

return ins
return inst

# `write_raw` and `ask_raw` are the interface to hardware #
# `write` and `ask` are standard wrappers to help with error reporting #
Expand Down Expand Up @@ -658,7 +662,7 @@ def ask(self, cmd: str) -> str:
e.args = e.args + ('asking ' + repr(cmd) + ' to ' + inst,)
raise e

def ask_raw(self, cmd: str) -> None:
def ask_raw(self, cmd: str) -> str:
"""
Low level method to write to the hardware and return a response.
Expand Down

0 comments on commit a5645fb

Please sign in to comment.