Skip to content

Commit

Permalink
Merge 880cd2e into 761c9c0
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Dec 2, 2022
2 parents 761c9c0 + 880cd2e commit aee19c3
Show file tree
Hide file tree
Showing 85 changed files with 353 additions and 386 deletions.
1 change: 1 addition & 0 deletions .devcontainer/dev-requirements.txt
Expand Up @@ -6,3 +6,4 @@ ipympl
pre-commit == 2.20.*
black == 22.10.0
isort == 5.10.1
pyupgrade == 3.2.3
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
Expand Up @@ -10,3 +10,5 @@ d3ae59251c753ae0737d6ae6242b7e85b60908c4
1e9ea598491444fe7c4ee9ece2ec94ad7c5020ec
# Reformatting with isort
67bf6d3760fa3fb8b3aa121b1b972d6cf36ec048
# Update syntax to Python 3.8 with pyupgrade
28b02c51545298cb9a76d8295e64a5df391b9207
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Expand Up @@ -3,6 +3,13 @@ repos:
hooks:
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/asottile/pyupgrade
rev: v3.2.3
hooks:
- id: pyupgrade
args: [--py38-plus]
exclude: '^brian2/_version.py$'
files: '^brian2/.*\.pyi?$'
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
Expand Down
4 changes: 2 additions & 2 deletions brian2/__init__.py
Expand Up @@ -128,7 +128,7 @@ def _get_size_recursively(dirname):
try:
size = os.path.getsize(os.path.join(dirpath, fname))
total_size += size
except (OSError, IOError):
except OSError:
pass # ignore the file
return total_size

Expand Down Expand Up @@ -185,7 +185,7 @@ def clear_cache(target):
if f.endswith(ext):
break
else:
raise IOError(
raise OSError(
f"The cache directory for target '{target}' contains "
f"the file '{os.path.join(folder, f)}' of an unexpected type and "
"will therefore not be removed. Delete files in "
Expand Down
2 changes: 1 addition & 1 deletion brian2/codegen/codeobject.py
Expand Up @@ -342,7 +342,7 @@ def create_runner_codeobj(
device = get_device()

if override_conditional_write is None:
override_conditional_write = set([])
override_conditional_write = set()
else:
override_conditional_write = set(override_conditional_write)

Expand Down
18 changes: 9 additions & 9 deletions brian2/codegen/cpp_prefs.py
Expand Up @@ -46,13 +46,13 @@
hostname = socket.gethostname()
if os.path.isfile(flag_file):
try:
with open(flag_file, "r", encoding="utf-8") as f:
with open(flag_file, encoding="utf-8") as f:
previously_stored_flags = json.load(f)
if hostname not in previously_stored_flags:
logger.debug("Ignoring stored CPU flags for a different host")
else:
flags = previously_stored_flags[hostname]
except (IOError, OSError) as ex:
except OSError as ex:
logger.debug(
f'Opening file "{flag_file}" to get CPU flags failed with error'
f' "{str(ex)}".'
Expand All @@ -66,7 +66,7 @@
try:
output = subprocess.check_output(
[sys.executable, get_cpu_flags_script],
universal_newlines=True,
text=True,
encoding="utf-8",
)
flags = json.loads(output)
Expand All @@ -79,7 +79,7 @@
to_store = {hostname: flags}
with open(flag_file, "w", encoding="utf-8") as f:
json.dump(to_store, f)
except (IOError, OSError) as ex:
except OSError as ex:
logger.debug(
f'Writing file "{flag_file}" to store CPU flags failed with error'
f' "{str(ex)}".'
Expand Down Expand Up @@ -269,7 +269,7 @@ def _determine_flag_compatibility(compiler, flagname):
prefix="brian_flag_test_"
) as temp_dir, std_silent():
fname = os.path.join(temp_dir, "flag_test.cpp")
with open(fname, "wt") as f:
with open(fname, "w") as f:
f.write("int main (int argc, char **argv) { return 0; }")
try:
compiler.compile([fname], output_dir=temp_dir, extra_postargs=[flagname])
Expand Down Expand Up @@ -351,7 +351,7 @@ def get_msvc_env():
try:
_msvc_env = msvc.msvc14_get_vc_env(arch_name)
except distutils.errors.DistutilsPlatformError:
raise IOError(
raise OSError(
"Cannot find Microsoft Visual Studio, You "
"can try to set the path to vcvarsall.bat "
"via the codegen.cpp.msvc_vars_location "
Expand All @@ -370,11 +370,11 @@ def compiler_supports_c99():
fd, tmp_file = tempfile.mkstemp(suffix=".cpp")
os.write(
fd,
"""
b"""
#if _MSC_VER < 1800
#error
#endif
""".encode(),
""",
)
os.close(fd)
msvc_env, vcvars_cmd = get_msvc_env()
Expand All @@ -396,7 +396,7 @@ def compiler_supports_c99():
return _compiler_supports_c99


class C99Check(object):
class C99Check:
"""
Helper class to create objects that can be passed as an ``availability_check`` to
a `FunctionImplementation`.
Expand Down
30 changes: 14 additions & 16 deletions brian2/codegen/generators/GSL_generator.py
Expand Up @@ -64,7 +64,7 @@ def valid_gsl_dir(val):
)


class GSLCodeGenerator(object):
class GSLCodeGenerator:
"""
GSL code generator.
Expand Down Expand Up @@ -427,9 +427,9 @@ def write_dataholder_single(self, var_obj):
restrict = ""
if var_obj.scalar or var_obj.size == 1:
restrict = ""
return "%s* %s %s{end_statement}" % (dtype, restrict, pointer_name)
return f"{dtype}* {restrict} {pointer_name}{{end_statement}}"
else:
return "%s %s{end_statement}" % (dtype, var_obj.name)
return f"{dtype} {var_obj.name}{{end_statement}}"

def write_dataholder(self, variables_in_vector):
"""
Expand Down Expand Up @@ -530,7 +530,7 @@ def scale_array_code(self, diff_vars, method_options):
)

def find_undefined_variables(self, statements):
"""
r"""
Find identifiers that are not in ``self.variables`` dictionary.
Brian does not save the ``_lio_`` variables it uses anywhere. This is
Expand Down Expand Up @@ -737,9 +737,9 @@ def translate_vector_code(self, code_lines, to_replace):
# special substitute because of limitations of regex word boundaries with
# variable[_idx]
for from_sub, to_sub in list(to_replace.items()):
m = re.search("\[(\w+)\];?$", from_sub)
m = re.search(r"\[(\w+)\];?$", from_sub)
if m:
code = re.sub(re.sub("\[", "\[", from_sub), to_sub, code)
code = re.sub(re.sub(r"\[", r"\[", from_sub), to_sub, code)

if "_gsl" in code:
raise AssertionError(
Expand Down Expand Up @@ -775,7 +775,7 @@ def translate_scalar_code(
"""
code = []
for line in code_lines:
m = re.search("(\w+ = .*)", line)
m = re.search(r"(\w+ = .*)", line)
try:
new_line = m.group(1)
var, op, expr, comment = parse_statement(new_line)
Expand Down Expand Up @@ -935,14 +935,12 @@ def translate(
f"{len(vs)} lines of abstract code, first line is: '{vs[0]}'\n"
)
logger.warn(
(
"Came across an abstract code block that may not be "
"well-defined: the outcome may depend on the "
"order of execution. You can ignore this warning if "
"you are sure that the order of operations does not "
"matter. "
+ error_msg
)
"Came across an abstract code block that may not be "
"well-defined: the outcome may depend on the "
"order of execution. You can ignore this warning if "
"you are sure that the order of operations does not "
"matter. "
+ error_msg
)

# save function names because self.generator.translate_statement_sequence
Expand All @@ -960,7 +958,7 @@ def translate(
# first check if any indexing other than '_idx' is used (currently not supported)
for code_list in list(scalar_code.values()) + list(vector_code.values()):
for code in code_list:
m = re.search("\[(\w+)\]", code)
m = re.search(r"\[(\w+)\]", code)
if m is not None:
if m.group(1) != "0" and m.group(1) != "_idx":
from brian2.stateupdaters.base import (
Expand Down
40 changes: 19 additions & 21 deletions brian2/codegen/generators/base.py
Expand Up @@ -19,7 +19,7 @@
logger = get_logger(__name__)


class CodeGenerator(object):
class CodeGenerator:
"""
Base class for all languages.
Expand Down Expand Up @@ -187,30 +187,30 @@ def array_read_write(self, statements):
f"referring to vector variable '{name}'"
)
write.add(stmt.var)
read = set(
read = {
varname
for varname, var in list(variables.items())
if isinstance(var, ArrayVariable) and varname in read
)
write = set(
}
write = {
varname
for varname, var in list(variables.items())
if isinstance(var, ArrayVariable) and varname in write
)
}
# Gather the indices stored as arrays (ignore _idx which is special)
indices = set()
indices |= set(
indices |= {
variable_indices[varname]
for varname in read
if not variable_indices[varname] in ("_idx", "0")
and isinstance(variables[variable_indices[varname]], ArrayVariable)
)
indices |= set(
}
indices |= {
variable_indices[varname]
for varname in write
if not variable_indices[varname] in ("_idx", "0")
and isinstance(variables[variable_indices[varname]], ArrayVariable)
)
}
# don't list arrays that are read explicitly and used as indices twice
read -= indices
return read, write, indices
Expand All @@ -236,12 +236,12 @@ def arrays_helper(self, statements):
"""
read, write, indices = self.array_read_write(statements)
conditional_write_vars = self.get_conditional_write_vars()
read |= set(var for var in write if var in conditional_write_vars)
read |= set(
read |= {var for var in write if var in conditional_write_vars}
read |= {
conditional_write_vars[var]
for var in write
if var in conditional_write_vars
)
}
return read, write, indices, conditional_write_vars

def has_repeated_indices(self, statements):
Expand All @@ -255,7 +255,7 @@ def has_repeated_indices(self, statements):
# Check whether we potentially deal with repeated indices (which will
# be the case most importantly when we write to pre- or post-synaptic
# variables in synaptic code)
used_indices = set(variable_indices[var] for var in write)
used_indices = {variable_indices[var] for var in write}
all_unique = all(
variables[index].unique
for index in used_indices
Expand Down Expand Up @@ -293,14 +293,12 @@ def translate(self, code, dtype):
f"{len(vs)} lines of abstract code, first line is: '{vs[0]}'\n"
)
logger.warn(
(
"Came across an abstract code block that may not be "
"well-defined: the outcome may depend on the "
"order of execution. You can ignore this warning if "
"you are sure that the order of operations does not "
"matter. "
+ error_msg
)
"Came across an abstract code block that may not be "
"well-defined: the outcome may depend on the "
"order of execution. You can ignore this warning if "
"you are sure that the order of operations does not "
"matter. "
+ error_msg
)

translated = self.translate_statement_sequence(
Expand Down
2 changes: 1 addition & 1 deletion brian2/codegen/generators/cpp_generator.py
Expand Up @@ -160,7 +160,7 @@ class CPPCodeGenerator(CodeGenerator):
universal_support_code = _universal_support_code

def __init__(self, *args, **kwds):
super(CPPCodeGenerator, self).__init__(*args, **kwds)
super().__init__(*args, **kwds)
self.c_data_type = c_data_type

@property
Expand Down
8 changes: 4 additions & 4 deletions brian2/codegen/generators/cython_generator.py
Expand Up @@ -32,8 +32,8 @@
]
# fmt: on

cpp_dtype = dict((canonical, cpp) for canonical, cpp, np in data_type_conversion_table)
numpy_dtype = dict((canonical, np) for canonical, cpp, np in data_type_conversion_table)
cpp_dtype = {canonical: cpp for canonical, cpp, np in data_type_conversion_table}
numpy_dtype = {canonical: np for canonical, cpp, np in data_type_conversion_table}


def get_cpp_dtype(obj):
Expand All @@ -57,7 +57,7 @@ def render_BinOp(self, node):
right = self.render_node(node.right)
return f"((({left})%({right}))+({right}))%({right})"
else:
return super(CythonNodeRenderer, self).render_BinOp(node)
return super().render_BinOp(node)


class CythonCodeGenerator(CodeGenerator):
Expand All @@ -69,7 +69,7 @@ class CythonCodeGenerator(CodeGenerator):

def __init__(self, *args, **kwds):
self.temporary_vars = set()
super(CythonCodeGenerator, self).__init__(*args, **kwds)
super().__init__(*args, **kwds)

def translate_expression(self, expr):
expr = word_substitute(expr, self.func_name_replacements)
Expand Down

0 comments on commit aee19c3

Please sign in to comment.