Merge pull request #311 from brian-team/update_for_2.5.4
Update for Brian 2.5.4
mstimberg committed Jul 10, 2023
2 parents 87a32ce + 2c6cee8 commit f240d34
Showing 15 changed files with 497 additions and 210 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -48,3 +48,8 @@ If you use this software in a published article, please cite
 ## License
 Brian2CUDA is free software licensed under the [GNU General Public License v3 (GPLv3)](https://www.gnu.org/licenses/gpl-3.0.en.html).
 
+## Testing
+To run the test suite on Google Colab (no installation or GPU required), click on the badge below:
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brian-team/brian2cuda/blob/master/brian2cuda/tools/test_suite/run_tests.ipynb)
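
The notebook linked by the badge (added below as `brian2cuda/tools/test_suite/run_tests.ipynb`) boils down to two steps, which also work in any local Python environment with a CUDA toolchain; a minimal sketch:

```python
# Minimal local equivalent of the Colab notebook added in this commit.
# Assumes brian2cuda and a working CUDA toolchain are already installed.
import brian2cuda

brian2cuda.test()  # runs the Brian2CUDA test suite
```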
14 changes: 9 additions & 5 deletions brian2cuda/cuda_generator.py
@@ -704,8 +704,11 @@ def determine_keywords(self):
         # set up the restricted pointers, these are used so that the compiler
         # knows there is no aliasing in the pointers, for optimisation
         pointers = []
-        # Add additional lines inside the kernel functions
+        # Add additional lines inside the kernel functions, also contains the clock_pointers
+        # lines from below for backwards compatibility
         kernel_lines = []
+        # Translate clock variables back into pointers
+        clock_pointers = []
         # It is possible that several different variable names refer to the
         # same array. E.g. in gapjunction code, v_pre and v_post refer to the
         # same array if a group is connected to itself
@@ -783,16 +786,17 @@ def determine_keywords(self):
                 # c_data_type(variable.dtype) is float, but we need double
                 dtype = "double"
             else:
-                dtype = dtype=c_data_type(variable.dtype)
+                dtype = c_data_type(variable.dtype)
             line = f"const {dtype}* _ptr{arrayname} = &_value{arrayname};"
-            if line not in kernel_lines:
-                kernel_lines.append(line)
+            if line not in clock_pointers:
+                clock_pointers.append(line)
 
         keywords = {'pointers_lines': stripped_deindented_lines('\n'.join(pointers)),
                     'support_code_lines': stripped_deindented_lines('\n'.join(support_code)),
                     'hashdefine_lines': stripped_deindented_lines('\n'.join(hash_defines)),
                     'denormals_code_lines': stripped_deindented_lines('\n'.join(self.denormals_to_zero_code())),
-                    'kernel_lines': stripped_deindented_lines('\n'.join(kernel_lines)),
+                    'kernel_lines': stripped_deindented_lines('\n'.join(kernel_lines + clock_pointers)),
+                    'clock_pointers': stripped_deindented_lines('\n'.join(clock_pointers)),
                     'uses_atomics': self.uses_atomics
                     }
         keywords.update(template_kwds)
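
This change splits the clock-variable pointer lines out of `kernel_lines` into a new `clock_pointers` list, while still appending them to `kernel_lines` for templates that have not been updated. A minimal sketch of that merge (the pointer strings are hypothetical examples, and `stripped_deindented_lines` is replaced by a plain join):

```python
# Simplified view of the backwards-compatible keyword construction above.
kernel_lines = ["const int32_t* _ptr_array_N = &_value_array_N;"]  # hypothetical
clock_pointers = ["const double* _ptr_array_defaultclock_t = &_value_array_defaultclock_t;"]  # hypothetical

keywords = {
    # older templates keep working: clock pointer lines remain part of 'kernel_lines'
    "kernel_lines": "\n".join(kernel_lines + clock_pointers),
    # updated templates (e.g. spatialstateupdate.cu below) request only the clock lines
    "clock_pointers": "\n".join(clock_pointers),
}
```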
2 changes: 1 addition & 1 deletion brian2cuda/device.py
@@ -88,7 +88,7 @@ def __init__(self):
         # only true during first run call (relevant for synaptic pre/post ID deletion)
         self.first_run = True
         # the minimal supported GPU compute capability
-        self.minimal_compute_capability = 3.5
+        self.minimal_compute_capability = 5.0
         # store the ID of the used GPU and its compute capability
         self.gpu_id = None
         self.compute_capability = None
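
This bumps the minimal supported compute capability from 3.5 (Kepler) to 5.0 (Maxwell). The targeted compute capability can still be pinned explicitly through the preference exercised in the tests below; a sketch, with 6.0 as an arbitrary example value:

```python
# Pinning the target GPU compute capability by hand (value is an example).
from brian2 import ms, prefs, run, set_device

set_device("cuda_standalone", directory=None)
prefs.devices.cuda_standalone.cuda_backend.compute_capability = 6.0
run(0*ms)
```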
3 changes: 3 additions & 0 deletions brian2cuda/templates/common_group.cu
@@ -17,6 +17,9 @@
 {% block before_run_defines %}
 {% endblock %}
 
+///// Support code /////
+{{support_code_lines|autoindent}}
+
 void _before_run_{{codeobj_name}}()
 {
     using namespace brian;
16 changes: 8 additions & 8 deletions brian2cuda/templates/spatialstateupdate.cu
@@ -140,8 +140,8 @@ __global__ void _tridiagsolve_kernel_{{codeobj_name}}(
     ///// KERNEL_CONSTANTS /////
     %KERNEL_CONSTANTS%
 
-    ///// kernel_lines /////
-    {{kernel_lines|autoindent}}
+    ///// translate clock variables into pointers /////
+    {{clock_pointers|autoindent}}
 
     // we need to run the kernel with 1 thread per block (to be changed by optimization)
     assert(tid == 0 && bid == _idx);
@@ -214,8 +214,8 @@ __global__ void _coupling_kernel_{{codeobj_name}}(
     ///// KERNEL_CONSTANTS /////
     %KERNEL_CONSTANTS%
 
-    ///// kernel_lines /////
-    {{kernel_lines|autoindent}}
+    ///// translate clock variables into pointers /////
+    {{clock_pointers|autoindent}}
 
     // we need to run the kernel with 1 thread, 1 block
     assert(_idx == 0);
@@ -328,8 +328,8 @@ __global__ void _combine_kernel_{{codeobj_name}}(
     ///// KERNEL_CONSTANTS /////
     %KERNEL_CONSTANTS%
 
-    ///// kernel_lines /////
-    {{kernel_lines|autoindent}}
+    ///// translate clock variables into pointers /////
+    {{clock_pointers|autoindent}}
 
     // we need to run the kernel with 1 thread per block (to be changed by optimization)
     assert(tid == 0 && bid == _idx);
@@ -371,8 +371,8 @@ __global__ void _currents_kernel_{{codeobj_name}}(
     ///// KERNEL_CONSTANTS /////
     %KERNEL_CONSTANTS%
 
-    ///// kernel_lines /////
-    {{kernel_lines|autoindent}}
+    ///// translate clock variables into pointers /////
+    {{clock_pointers|autoindent}}
 
     if(_idx >= _N)
     {
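
In all four kernels above, the generic `kernel_lines` placeholder is narrowed to the new `clock_pointers` placeholder. What it expands to follows the f-string in `cuda_generator.py` above; a sketch with an illustrative array name:

```python
# Illustration of a rendered clock pointer line ("_array_defaultclock_t" is
# a made-up example name, not taken from a real code generation run).
arrayname = "_array_defaultclock_t"
dtype = "double"  # some clock variables are promoted to double (see the hunk above)
print(f"const {dtype}* _ptr{arrayname} = &_value{arrayname};")
# -> const double* _ptr_array_defaultclock_t = &_value_array_defaultclock_t;
```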
6 changes: 3 additions & 3 deletions brian2cuda/templates/spikemonitor.cu
@@ -4,9 +4,9 @@


 {% block extra_headers %}
-// TODO: uncomment when thrust calls below are fixed
-//#include <thrust/count.h>
-//#include <thrust/execution_policy.h>
+#include <thrust/copy.h>
+#include <thrust/count.h>
+#include <thrust/execution_policy.h>
 {% endblock %}


4 changes: 2 additions & 2 deletions brian2cuda/tests/test_cpp_cuda_consistency.py
@@ -29,7 +29,7 @@ def test_stdp_example():
     dApre *= 0.1*gmax
 
     connectivity = numpy.random.randn(n_cells, n_cells)
-    sources = numpy.random.random_integers(0, n_cells-1, 10*n_cells)
+    sources = numpy.random.randint(0, n_cells, 10*n_cells)
     # Only use one spike per time step (to rule out that a single source neuron
     # has more than one spike in a time step)
     times = numpy.random.choice(numpy.arange(10*n_cells), 10*n_cells, replace=False)*ms
@@ -110,7 +110,7 @@ def test_stdp_heterog_delays_example():
     dApre *= 0.1*gmax
 
     connectivity = numpy.random.randn(n_cells, n_cells)
-    sources = numpy.random.random_integers(0, n_cells-1, 10*n_cells)
+    sources = numpy.random.randint(0, n_cells, 10*n_cells)
     # Only use one spike per time step (to rule out that a single source neuron
     # has more than one spike in a time step)
     times = numpy.random.choice(numpy.arange(10*n_cells), 10*n_cells, replace=False)*ms
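
Both hunks swap `numpy.random.random_integers` (inclusive upper bound, long deprecated and removed from recent NumPy releases) for `numpy.random.randint` (exclusive upper bound), so the bound changes from `n_cells-1` to `n_cells` while sampling exactly the same range:

```python
import numpy

n_cells = 100
# Old, removed API (upper bound inclusive):
#   numpy.random.random_integers(0, n_cells-1, 10*n_cells)
# New API samples from [0, n_cells), i.e. the same set of values:
sources = numpy.random.randint(0, n_cells, 10*n_cells)
assert sources.min() >= 0 and sources.max() < n_cells
```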
6 changes: 3 additions & 3 deletions brian2cuda/tests/test_gpu_detection.py
@@ -116,7 +116,7 @@ def test_wrong_cuda_path_warning(reset_cuda_detection, use_default_prefs, monkey
 @pytest.mark.standalone_only
 def test_manual_setting_compute_capability(reset_gpu_detection):
     set_device("cuda_standalone", directory=None)
-    compute_capability_pref = '3.5'
+    compute_capability_pref = '6.0'
     prefs.devices.cuda_standalone.cuda_backend.compute_capability = float(compute_capability_pref)
     with catch_logs(log_level=logging.INFO) as logs:
         run(0*ms)
@@ -141,8 +141,8 @@ def test_unsupported_compute_capability_error(reset_gpu_detection):
 @pytest.mark.standalone_only
 def test_warning_compute_capability_set_twice(reset_gpu_detection, use_default_prefs):
     set_device("cuda_standalone", directory=None)
-    prefs.devices.cuda_standalone.cuda_backend.compute_capability = 3.5
-    prefs.devices.cuda_standalone.cuda_backend.extra_compile_args_nvcc.append('-arch=sm_37')
+    prefs.devices.cuda_standalone.cuda_backend.compute_capability = 5.3
+    prefs.devices.cuda_standalone.cuda_backend.extra_compile_args_nvcc.append('-arch=sm_52')
     with catch_logs() as logs:
         run(0*ms)
 
4 changes: 2 additions & 2 deletions brian2cuda/tests/test_neurongroup.py
Expand Up @@ -24,9 +24,9 @@ def test_semantics_floor_division():
                        fvalue : 1
                        ivalue : integer''',
                    dtype={'a': np.int32, 'b': np.int64,
-                          'x': np.float, 'y': np.double})
+                          'x': np.float32, 'y': np.float64})
     int_values = np.arange(-5, 6)
-    float_values = np.arange(-5.0, 6.0, dtype=np.double)
+    float_values = np.arange(-5.0, 6.0, dtype=np.float64)
     G.ivalue = int_values
     G.fvalue = float_values
     with catch_logs() as l:
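
`np.float` was a deprecated alias for the Python builtin `float` and was removed in NumPy 1.24, so the test now spells out the precision explicitly (`np.double` is likewise replaced by the equivalent but more explicit `np.float64`):

```python
import numpy as np

# np.float used to mean the builtin float, i.e. 64-bit; explicit widths
# make the intended precision unambiguous and survive NumPy 1.24+.
x = np.float32(1.5)
y = np.arange(-5.0, 6.0, dtype=np.float64)
assert x.dtype == np.float32 and y.dtype == np.float64
```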
56 changes: 56 additions & 0 deletions brian2cuda/tools/test_suite/run_tests.ipynb
@@ -0,0 +1,56 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "173851c6",
"metadata": {},
"source": [
"# Run Brian2CUDA test suite\n",
"\n",
"Minimal notebook to run tests against latest version of Brian2CUDA from github. Meant to be run on Google Collab with a GPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3bca2f17",
"metadata": {},
"outputs": [],
"source": [
"%pip install git+https://github.com/brian-team/brian2cuda.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a80015a2",
"metadata": {},
"outputs": [],
"source": [
"import brian2cuda\n",
"brian2cuda.test()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (brian2cuda)",
"language": "python",
"name": "brian2cuda"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 0 additions & 1 deletion brian2cuda/utils/gputools.py
@@ -7,7 +7,6 @@
 import shutil
 import shlex
 import re
-import distutils
 
 from brian2.core.preferences import prefs, PreferenceError
 from brian2.codegen.cpp_prefs import get_compiler_and_args
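
The unused `distutils` import is dropped; `distutils` was deprecated by PEP 632 and is removed in Python 3.12. If an executable lookup is ever needed here, the already-imported `shutil` module covers the common case (a sketch, not code from this file):

```python
import shutil

# stdlib replacement for distutils.spawn.find_executable
nvcc_path = shutil.which("nvcc")  # None if nvcc is not on PATH
print(nvcc_path)
```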
2 changes: 1 addition & 1 deletion frozen_repos/brian2
Submodule brian2 updated 194 files
