Skip to content

Commit

Permalink
Use a shortcut for globals to reduce time (#2909)
Browse files Browse the repository at this point in the history
  • Loading branch information
alkino committed Jun 19, 2024
1 parent b622a8f commit f436008
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/nmodl/deriv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ void solv_diffeq(Item* qsol,
// derivimplicit_thread
Sprintf(buf,
"%s %s%s_thread(%d, _slist%d, _dlist%d, neuron::scopmath::row_view{_ml, _iml}, %s, "
"_ml, _iml, _ppvar, _thread, _nt);\n%s",
"_ml, _iml, _ppvar, _thread, _globals, _nt);\n%s",
deriv1_advance,
ssprefix,
method->name,
Expand Down
8 changes: 8 additions & 0 deletions src/nmodl/noccout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ void c_out_vectorize() {
P("_ni = _ml_arg->_nodeindices;\n");
P("_cntml = _ml_arg->_nodecount;\n");
P("_thread = _ml_arg->_thread;\n");
P("double* _globals = nullptr;\n");
P("if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n");
/*check_tables();*/
P("for (_iml = 0; _iml < _cntml; ++_iml) {\n");
P(" _ppvar = _ml_arg->_pdata[_iml];\n");
Expand Down Expand Up @@ -598,6 +600,8 @@ void c_out_vectorize() {
P("_ni = _ml_arg->_nodeindices;\n");
P("_cntml = _ml_arg->_nodecount;\n");
P("_thread = _ml_arg->_thread;\n");
P("double* _globals = nullptr;\n");
P("if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n");
P("for (_iml = 0; _iml < _cntml; ++_iml) {\n");
P(" _ppvar = _ml_arg->_pdata[_iml];\n");
ext_vdef();
Expand Down Expand Up @@ -664,6 +668,8 @@ void c_out_vectorize() {
P("_ni = _ml_arg->_nodeindices;\n");
P("_cntml = _ml_arg->_nodecount;\n");
P("_thread = _ml_arg->_thread;\n");
P("double* _globals = nullptr;\n");
P("if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n");
P("for (_iml = 0; _iml < _cntml; ++_iml) {\n");
if (electrode_current) {
P(" _nd = _ml_arg->_nodelist[_iml];\n");
Expand Down Expand Up @@ -700,6 +706,8 @@ void c_out_vectorize() {
P("_ni = _ml_arg->_nodeindices;\n");
P("size_t _cntml = _ml_arg->_nodecount;\n");
P("_thread = _ml_arg->_thread;\n");
P("double* _globals = nullptr;\n");
P("if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n");
P("for (size_t _iml = 0; _iml < _cntml; ++_iml) {\n");
P(" _ppvar = _ml_arg->_pdata[_iml];\n");
P(" _nd = _ml_arg->_nodelist[_iml];\n");
Expand Down
82 changes: 50 additions & 32 deletions src/nmodl/nocpout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,12 @@ void parout() {
if (vectorize) {
Lappendstr(defs_list,
"\n\
#define _threadargscomma_ _ml, _iml, _ppvar, _thread, _nt,\n\
#define _threadargsprotocomma_ Memb_list* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, NrnThread* _nt,\n\
#define _internalthreadargsprotocomma_ _nrn_mechanism_cache_range* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, NrnThread* _nt,\n\
#define _threadargs_ _ml, _iml, _ppvar, _thread, _nt\n\
#define _threadargsproto_ Memb_list* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, NrnThread* _nt\n\
#define _internalthreadargsproto_ _nrn_mechanism_cache_range* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, NrnThread* _nt\n\
#define _threadargscomma_ _ml, _iml, _ppvar, _thread, _globals, _nt,\n\
#define _threadargsprotocomma_ Memb_list* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, double* _globals, NrnThread* _nt,\n\
#define _internalthreadargsprotocomma_ _nrn_mechanism_cache_range* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, double* _globals, NrnThread* _nt,\n\
#define _threadargs_ _ml, _iml, _ppvar, _thread, _globals, _nt\n\
#define _threadargsproto_ Memb_list* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, double* _globals, NrnThread* _nt\n\
#define _internalthreadargsproto_ _nrn_mechanism_cache_range* _ml, size_t _iml, Datum* _ppvar, Datum* _thread, double* _globals, NrnThread* _nt\n\
");
} else {
Lappendstr(defs_list,
Expand Down Expand Up @@ -544,11 +544,6 @@ extern void nrn_promote(Prop*, int, int);\n\
}
}

emit_check_table_thread = 0;
if (vectorize && check_tables_threads(defs_list)) {
emit_check_table_thread = 1;
}

/* per thread top LOCAL */
/* except those that are marked assigned_to_ == 2 stay static double */
if (vectorize && toplocal_) {
Expand Down Expand Up @@ -601,7 +596,7 @@ extern void nrn_promote(Prop*, int, int);\n\
}
/* per thread global data */
gind = 0;
if (vectorize)
if (vectorize) {
SYMLISTITER {
s = SYM(q);
if (s->nrntype & (NRNGLOBAL) && s->assigned_to_ == 1) {
Expand All @@ -612,8 +607,15 @@ extern void nrn_promote(Prop*, int, int);\n\
}
}
}
}
/* double scalars declared internally */
Lappendstr(defs_list, "/* declare global and static user variables */\n");
Sprintf(buf, "#define gind %d\n", gind);
Lappendstr(defs_list, buf);
if (!gind) {
Sprintf(buf, "#define _gth 0\n");
Lappendstr(defs_list, buf);
}
if (gind) {
Sprintf(buf,
"static int _thread1data_inuse = 0;\nstatic double _thread1data[%d];\n#define _gth "
Expand Down Expand Up @@ -644,7 +646,7 @@ extern void nrn_promote(Prop*, int, int);\n\
if (s->subtype & ARRAY) {
Sprintf(buf,
"#define %s%s (_thread1data + %d)\n\
#define %s (_thread[_gth].get<double*>() + %d)\n",
#define %s (_globals + %d)\n",
s->name,
suffix,
gind,
Expand All @@ -653,7 +655,7 @@ extern void nrn_promote(Prop*, int, int);\n\
} else {
Sprintf(buf,
"#define %s%s _thread1data[%d]\n\
#define %s _thread[_gth].get<double*>()[%d]\n",
#define %s _globals[%d]\n",
s->name,
suffix,
gind,
Expand Down Expand Up @@ -684,6 +686,11 @@ extern void nrn_promote(Prop*, int, int);\n\
}
}

emit_check_table_thread = 0;
if (vectorize && check_tables_threads(defs_list)) {
emit_check_table_thread = 1;
}

Lappendstr(defs_list, "/* some parameters have upper and lower limits */\n");
Lappendstr(defs_list, "static HocParmLimits _hoc_parm_limits[] = {\n");
SYMLISTITER {
Expand Down Expand Up @@ -1575,6 +1582,8 @@ void ldifusreg() {
"_nrn_model_sorted_token const& _sorted_token) {\n"
" _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _ml_arg->_type()};\n"
" auto* const _ml = &_lmr;\n"
" double* _globals = nullptr;\n"
" if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n"
" *_pdvol = ",
n,
n);
Expand Down Expand Up @@ -1929,10 +1938,13 @@ void bablk(int ba, int type, Item* q1, Item* q2) {
insertstr(q1, buf);
q = q1->next;
vectorize_substitute(insertstr(q, ""), "Datum* _ppvar;");
qv = insertstr(q,
"_nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, "
"_ml_arg->_type()}; auto* const "
"_ml = &_lmr;\n");
qv = insertstr(
q,
"_nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, "
"_ml_arg->_type()}; auto* const "
"_ml = &_lmr;\n"
"double* _globals = nullptr;\n"
"if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n");
qv = insertstr(q, "_ppvar = _ppd;\n");
movelist(qb, q2, procfunc);

Expand Down Expand Up @@ -2769,18 +2781,21 @@ void out_nt_ml_frag(List* p) {
vectorize_substitute(lappendstr(p, ""), " Datum* _ppvar;\n");
vectorize_substitute(lappendstr(p, ""), " size_t _iml;");
vectorize_substitute(lappendstr(p, ""), " _nrn_mechanism_cache_range* _ml;");
Lappendstr(p,
" Node* _nd{};\n"
" double _v{};\n"
" int _cntml;\n"
" _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};\n"
" _ml = &_lmr;\n"
" _cntml = _ml_arg->_nodecount;\n"
" Datum *_thread{_ml_arg->_thread};\n"
" for (_iml = 0; _iml < _cntml; ++_iml) {\n"
" _ppvar = _ml_arg->_pdata[_iml];\n"
" _nd = _ml_arg->_nodelist[_iml];\n"
" v = NODEV(_nd);\n");
Lappendstr(
p,
" Node* _nd{};\n"
" double _v{};\n"
" int _cntml;\n"
" _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml_arg, _type};\n"
" _ml = &_lmr;\n"
" _cntml = _ml_arg->_nodecount;\n"
" Datum *_thread{_ml_arg->_thread};\n"
" double* _globals = nullptr;\n"
" if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n"
" for (_iml = 0; _iml < _cntml; ++_iml) {\n"
" _ppvar = _ml_arg->_pdata[_iml];\n"
" _nd = _ml_arg->_nodelist[_iml];\n"
" v = NODEV(_nd);\n");
}

void cvode_emit_interface() {
Expand Down Expand Up @@ -3042,7 +3057,9 @@ void net_receive(Item* qarg, Item* qp1, Item* qp2, Item* qstmt, Item* qend) {
" auto* const _ml = &_ml_real;\n"
" size_t const _iml{};\n");
q = insertstr(qstmt, " _ppvar = _nrn_mechanism_access_dparam(_pnt->_prop);\n");
vectorize_substitute(insertstr(q, ""), " _thread = nullptr; _nt = (NrnThread*)_pnt->_vnt;");
vectorize_substitute(
insertstr(q, ""),
" _thread = nullptr; double* _globals = nullptr; _nt = (NrnThread*)_pnt->_vnt;");
if (debugging_) {
if (0) {
insertstr(qstmt, " assert(_tsav <= t); _tsav = t;");
Expand Down Expand Up @@ -3133,7 +3150,8 @@ void net_init(Item* qinit, Item* qp2) {
" auto* const _ml = &_ml_real;\n"
" size_t const _iml{};\n"
" Datum* _ppvar = _nrn_mechanism_access_dparam(_pnt->_prop);\n"
" Datum* _thread = (Datum*)0;\n"
" Datum* _thread = nullptr;\n"
" double* _globals = nullptr;\n"
" NrnThread* _nt = (NrnThread*)_pnt->_vnt;\n");
if (net_init_q1_) {
diag("NET_RECEIVE block can contain only one INITIAL block", (char*) 0);
Expand Down
52 changes: 31 additions & 21 deletions src/nmodl/parsact.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,15 @@ int check_tables_threads(List* p) {
Sprintf(buf, "\nstatic void %s(_internalthreadargsproto_);", STR(q));
lappendstr(p, buf);
}
lappendstr(p,
"\n"
"static void _check_table_thread(_threadargsprotocomma_ int _type, "
"_nrn_model_sorted_token const& _sorted_token) {\n"
" _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml, _type};\n"
" {\n"
" auto* const _ml = &_lmr;\n");
lappendstr(
p,
"\n"
"static void _check_table_thread(_threadargsprotocomma_ int _type, "
"_nrn_model_sorted_token const& _sorted_token) {\n"
" if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); } \n"
" _nrn_mechanism_cache_range _lmr{_sorted_token, *_nt, *_ml, _type};\n"
" {\n"
" auto* const _ml = &_lmr;\n");
ITERATE(q, check_table_thread_list) {
Sprintf(buf, " %s(_threadargs_);\n", STR(q));
lappendstr(p, buf);
Expand Down Expand Up @@ -760,13 +762,16 @@ static void funchack(Symbol* n, bool ishoc, int hack) {
" hoc_execerror(\"POINT_PROCESS data instance not valid\", NULL);\n"
" }\n");
q = lappendstr(procfunc, " _setdata(_p);\n");
vectorize_substitute(q,
" _nrn_mechanism_cache_instance _ml_real{_p};\n"
" auto* const _ml = &_ml_real;\n"
" size_t const _iml{};\n"
" _ppvar = _nrn_mechanism_access_dparam(_p);\n"
" _thread = _extcall_thread.data();\n"
" _nt = static_cast<NrnThread*>(_pnt->_vnt);\n");
vectorize_substitute(
q,
" _nrn_mechanism_cache_instance _ml_real{_p};\n"
" auto* const _ml = &_ml_real;\n"
" size_t const _iml{};\n"
" _ppvar = _nrn_mechanism_access_dparam(_p);\n"
" _thread = _extcall_thread.data();\n"
" double* _globals = nullptr;\n"
" if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n"
" _nt = static_cast<NrnThread*>(_pnt->_vnt);\n");
} else if (ishoc) {
hocfunc_setdata_item(n, lappendstr(procfunc, ""));
vectorize_substitute(
Expand All @@ -776,18 +781,23 @@ static void funchack(Symbol* n, bool ishoc, int hack) {
"size_t const _iml{};\n"
"_ppvar = _local_prop ? _nrn_mechanism_access_dparam(_local_prop) : nullptr;\n"
"_thread = _extcall_thread.data();\n"
"double* _globals = nullptr;\n"
"if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n"
"_nt = nrn_threads;\n");
} else { // _npy_...
q = lappendstr(procfunc,
" neuron::legacy::set_globals_from_prop(_prop, _ml_real, _ml, _iml);\n"
" _ppvar = _nrn_mechanism_access_dparam(_prop);\n");
vectorize_substitute(q,
"_nrn_mechanism_cache_instance _ml_real{_prop};\n"
"auto* const _ml = &_ml_real;\n"
"size_t const _iml{};\n"
"_ppvar = _nrn_mechanism_access_dparam(_prop);\n"
"_thread = _extcall_thread.data();\n"
"_nt = nrn_threads;\n");
vectorize_substitute(
q,
"_nrn_mechanism_cache_instance _ml_real{_prop};\n"
"auto* const _ml = &_ml_real;\n"
"size_t const _iml{};\n"
"_ppvar = _nrn_mechanism_access_dparam(_prop);\n"
"_thread = _extcall_thread.data();\n"
"double* _globals = nullptr;\n"
"if (gind != 0 && _thread != nullptr) { _globals = _thread[_gth].get<double*>(); }\n"
"_nt = nrn_threads;\n");
}
if (n == last_func_using_table) {
qp = lappendstr(procfunc, "");
Expand Down
2 changes: 1 addition & 1 deletion src/nmodl/simultan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ Item* mixed_eqns(Item* q2, Item* q3, Item* q4) /* name, '{', '}' */
Sprintf(buf,
"error = nrn_newton_thread(_newtonspace%d, %d, _slist%d, "
"neuron::scopmath::row_view{_ml, _iml}, %s, _dlist%d, _ml,"
" _iml, _ppvar, _thread, _nt);\n",
" _iml, _ppvar, _thread, _globals, _nt);\n",
numlist - 1,
counts,
numlist,
Expand Down
1 change: 1 addition & 0 deletions src/nrniv/kschan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ static void check_table_thread_(Memb_list*,
std::size_t,
Datum*,
Datum*,
double*,
NrnThread* vnt,
int type,
neuron::model_sorted_token const&) {
Expand Down
12 changes: 7 additions & 5 deletions src/nrniv/nrncore_write/callbacks/nrncore_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,11 @@ int nrnthread_dat2_corepointer_mech(int tid,
icnt = 0;
// data size and allocate
for (int i = 0; i < ml->nodecount; ++i) {
(*nrn_bbcore_write_[type])(NULL, NULL, &dcnt, &icnt, ml, i, ml->pdata[i], ml->_thread, &nt);
(*nrn_bbcore_write_[type])(
nullptr, nullptr, &dcnt, &icnt, ml, i, ml->pdata[i], ml->_thread, nullptr, &nt);
}
dArray = NULL;
iArray = NULL;
dArray = nullptr;
iArray = nullptr;
if (icnt) {
iArray = new int[icnt];
}
Expand All @@ -567,7 +568,7 @@ int nrnthread_dat2_corepointer_mech(int tid,
// data values
for (int i = 0; i < ml->nodecount; ++i) {
(*nrn_bbcore_write_[type])(
dArray, iArray, &dcnt, &icnt, ml, i, ml->pdata[i], ml->_thread, &nt);
dArray, iArray, &dcnt, &icnt, ml, i, ml->pdata[i], ml->_thread, nullptr, &nt);
}

return 1;
Expand All @@ -593,7 +594,8 @@ int core2nrn_corepointer_mech(int tid, int type, int icnt, int dcnt, int* iArray
int dk = 0;
// data values
for (int i = 0; i < ml->nodecount; ++i) {
(*nrn_bbcore_read_[type])(dArray, iArray, &dk, &ik, ml, i, ml->pdata[i], ml->_thread, &nt);
(*nrn_bbcore_read_[type])(
dArray, iArray, &dk, &ik, ml, i, ml->pdata[i], ml->_thread, nullptr, &nt);
}
assert(dk == dcnt);
assert(ik == icnt);
Expand Down
12 changes: 10 additions & 2 deletions src/nrniv/nrncore_write/io/nrncore_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ void write_uint32vec(std::vector<uint32_t>& vec, FILE* f);
#define writedbl(p, size) writedbl_(p, size, f)
// also for read
struct Memb_list;
using bbcore_write_t =
void (*)(double*, int*, int*, int*, Memb_list*, std::size_t, Datum*, Datum*, NrnThread*);
using bbcore_write_t = void (*)(double*,
int*,
int*,
int*,
Memb_list*,
std::size_t,
Datum*,
Datum*,
double*,
NrnThread*);

void write_nrnthread_task(const char*, CellGroup* cgs, bool append);
void nrnbbcore_vecplay_write(FILE* f, NrnThread& nt);
Expand Down
1 change: 1 addition & 0 deletions src/nrnoc/membfunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using nrn_thread_table_check_t = void (*)(Memb_list*,
std::size_t,
Datum*,
Datum*,
double*,
NrnThread*,
int,
neuron::model_sorted_token const&);
Expand Down
4 changes: 3 additions & 1 deletion src/nrnoc/multicore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,10 @@ void nrn_mk_table_check() {
void nrn_thread_table_check(neuron::model_sorted_token const& sorted_token) {
for (auto [id, tml]: table_check_) {
Memb_list* ml = tml->ml;
// here _globals cannot be guessed (missing _gth) so we give nullptr, and set the variable
// locally in _check_table_thread
memb_func[tml->index].thread_table_check_(
ml, 0, ml->pdata[0], ml->_thread, nrn_threads + id, tml->index, sorted_token);
ml, 0, ml->pdata[0], ml->_thread, nullptr, nrn_threads + id, tml->index, sorted_token);
}
}

Expand Down

0 comments on commit f436008

Please sign in to comment.