Skip to content

Commit

Permalink
RF/OPT boosting: use Cython mp (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Aug 5, 2022
1 parent e81c439 commit cfa2957
Show file tree
Hide file tree
Showing 12 changed files with 552 additions and 437 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ doc/auto_examples
*.ipynb
.ipynb_checkpoints
.eggs

*.html
6 changes: 6 additions & 0 deletions doc/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ Changes

.. currentmodule:: eelbrain

New in 0.38
-----------

* :func:`boosting` optimized (as a consequence, the progress bar has been disabled).


New in 0.37
-----------

Expand Down
6 changes: 5 additions & 1 deletion eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,11 @@ def names(self):

class Named:

def __init__(self, name, info):
def __init__(
self,
name: str = None,
info: Dict = None,
):
self.info = {} if info is None else dict(info)
self._name = name

Expand Down
12 changes: 6 additions & 6 deletions eelbrain/_ndvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, Callable, Literal, Sequence, Union

import mne
from numba import njit, prange
import numba
import numpy as np
import numpy
from scipy import linalg, ndimage, signal, stats
Expand All @@ -28,7 +28,7 @@
from ._mne import complete_source_space
from ._stats.connectivity import Connectivity
from ._stats.connectivity import find_peaks as _find_peaks
from ._trf._boosting_opt import l1
from ._trf._fit_metrics import error_for_indexes
from ._utils.numpy_utils import aslice, newaxis


Expand Down Expand Up @@ -272,7 +272,7 @@ def convolve(h, x, ds=None, name=None):
return NDVar(out, dims, *op_name(x, name=name))


@njit(parallel=True)
@numba.njit(nogil=True, cache=True, parallel=True)
def parallel_convolve(
h_flat: np.ndarray, # n_h_only, n_shared, n_h_times
x_flat: np.ndarray, # n_x_only, n_shared, n_x_times
Expand All @@ -282,12 +282,12 @@ def parallel_convolve(
):
# loop through x and h dimensions
out_indexes = [(ix, ih) for ix in range(len(x_flat)) for ih in range(len(h_flat))]
for i_out in prange(len(out_indexes)):
for i_out in numba.prange(len(out_indexes)):
i_x, i_h = out_indexes[i_out]
convolve_jit(h_flat[i_h], x_flat[i_x], out_flat[i_x, i_h], i_start, i_stop)


@njit
@numba.njit(nogil=True, cache=True)
def convolve_jit(
h: np.ndarray, # n_h, n_h_times
x: np.ndarray, # n_h, n_x_times
Expand Down Expand Up @@ -805,7 +805,7 @@ def label_operator(labels, operation='mean', exclude=None, weights=None,
if weights is not None:
xs *= weights
if operation == 'mean':
xs /= l1(xs, np.array(((0, len(xs)),), np.int64))
xs /= error_for_indexes(xs, np.array(((0, len(xs)),), np.int64), 1)
return NDVar(x, (label_dim, dim), labels.name)


Expand Down

0 comments on commit cfa2957

Please sign in to comment.