Skip to content

Commit

Permalink
Merge pull request #332 from DimitriPlotnikov/master
Browse files Browse the repository at this point in the history
Increase the robustness of the sympy infrastructure. Add bounded `min`, `max` functions.
  • Loading branch information
Plotnikov committed Dec 14, 2016
2 parents c3d9c00 + d97cb22 commit dba38d7
Show file tree
Hide file tree
Showing 42 changed files with 195 additions and 117 deletions.
2 changes: 1 addition & 1 deletion models/aeif_cond_alpha.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ neuron aeif_cond_alpha_neuron:
end

equations:
V_bounded mV = min(V_m, V_peak) # prevent exponential divergence
V_bounded mV = bounded_min(V_m, V_peak) # prevent exponential divergence
shape g_in = (e/tau_syn_in) * t * exp(-1/tau_syn_in*t)
shape g_ex = (e/tau_syn_ex) * t * exp(-1/tau_syn_ex*t)

Expand Down
8 changes: 4 additions & 4 deletions models/aeif_cond_alpha_implicit.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ neuron aeif_cond_alpha_implicit:
equations:
V_bounded mV = min(V_m, V_peak) # prevent exponential divergence
# alpha function for the g_in
g_in'' = -g_in'/tau_syn_in
g_in' = g_in' - g_in/tau_syn_in
g_in'' = (-2/tau_syn_in) * g_in'-(1/tau_syn_in**2) * g_in
g_in' = g_in'

# alpha function for the g_ex
g_ex'' = -g_ex'/tau_syn_ex
g_ex' = g_ex' - g_ex/tau_syn_ex
g_ex'' = (-2/tau_syn_ex) * g_ex'-(1/tau_syn_ex**2) * g_ex
g_ex' = g_ex'

# Add aliases to simplify the equation definition of V_m
exp_arg real = (V_bounded-V_th)/Delta_T
Expand Down
2 changes: 1 addition & 1 deletion models/aeif_cond_exp.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ neuron aeif_cond_exp_neuron:
end

equations:
V_bounded mV = min(V_m, V_peak) # prevent exponential divergence
V_bounded mV = bounded_min(V_m, V_peak) # prevent exponential divergence
shape g_in = exp(-1/tau_syn_in*t)
shape g_ex = exp(-1/tau_syn_ex*t)

Expand Down
13 changes: 8 additions & 5 deletions models/hh_cond_exp_traub.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ neuron hh_cond_exp_traub_neuron:
state:
V_m mV = E_L # Membrane potential

# TODO: it should be possible, to define these variables in the internal block
# equilibrium values for (in)activation variables
alias alpha_n_init real = 0.032 * ( 15. - V_m ) / ( exp( ( 15. - V_m ) / 5. ) - 1. )
alias beta_n_init real = 0.5 * exp( ( 10. - V_m ) / 40. )
alias alpha_m_init real = 0.32 * ( 13. - V_m ) / ( exp( ( 13. - V_m ) / 4. ) - 1. )
Expand All @@ -50,19 +50,22 @@ neuron hh_cond_exp_traub_neuron:
end

equations:
shape g_in = exp(-1/tau_syn_in*t)
shape g_ex = exp(-1/tau_syn_ex*t)

# Add aliases to simplify the equation definition of V_m
# ionic currents
I_Na pA = g_Na * Act_m * Act_m * Act_m * Act_h * ( V_m - E_Na )
I_K pA = g_K * Inact_n * Inact_n * Inact_n * Inact_n * ( V_m - E_K )
I_L pA = g_L * ( V_m - E_L )

I_syn_exc pA = cond_sum(g_ex, spikeExc) * ( V_m - E_ex )
I_syn_inh pA = cond_sum(g_in, spikeInh) * ( V_m - E_in )

shape g_in = exp(-1/tau_syn_in*t)
shape g_ex = exp(-1/tau_syn_ex*t)

# membrane potential
V_m' =( -I_Na - I_K - I_L - I_syn_exc - I_syn_inh + I_stim + I_e ) / C_m

# equilibrium values for (in)activation variables
# channel dynamics
V_rel mV = V_m - V_T
alpha_n real = 0.032 * ( 15. - V_rel ) / ( exp( ( 15. - V_rel ) / 5. ) - 1. )
beta_n real = 0.5 * exp( ( 10. - V_rel ) / 40. )
Expand Down
16 changes: 5 additions & 11 deletions models/hh_psc_alpha.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ neuron hh_psc_alpha_neuron:

equations:
# synapses: alpha functions
shape I_in = 1 pA * (e/tau_syn_in) * t * exp(-1/tau_syn_in*t)
shape I_ex = 1 pA * (e/tau_syn_ex) * t * exp(-1/tau_syn_ex*t)
shape I_in = (e/tau_syn_in) * t * exp(-1/tau_syn_in*t)
shape I_ex = (e/tau_syn_ex) * t * exp(-1/tau_syn_ex*t)

I_syn_exc pA = curr_sum(I_ex, spikeExc)
I_syn_inh pA = curr_sum(I_in, spikeInh)
I_Na pA = g_Na * Act_m * Act_m * Act_m * Act_h * ( V_m - E_Na )
I_K pA = g_K * Inact_n * Inact_n * Inact_n * Inact_n * ( V_m - E_K )
I_L pA = g_L * ( V_m - E_L )
V_m' =( -( I_Na + I_K + I_L ) + I_stim + I_e + I_syn_inh + I_syn_exc ) / C_m
V_m' =( -( I_Na + I_K + I_L ) + currents + I_e + I_syn_inh + I_syn_exc ) / C_m

# Inact_n
alpha_n real = ( 0.01 * ( V_m + 55. ) ) / ( 1. - exp( -( V_m + 55. ) / 10. ) )
Expand All @@ -83,7 +83,7 @@ neuron hh_psc_alpha_neuron:
beta_m real = 4. * exp( -( V_m + 65. ) / 18. )
Act_m' = alpha_m * ( 1 - Act_m ) - beta_m * Act_m # m-variable

# Act_h'
# Act_h
alpha_h real = 0.07 * exp( -( V_m + 65. ) / 20. )
beta_h real = 1. / ( 1. + exp( -( V_m + 35. ) / 10. ) )
Act_h' = alpha_h * ( 1 - Act_h ) - beta_h * Act_h # h-variable
Expand Down Expand Up @@ -115,11 +115,6 @@ neuron hh_psc_alpha_neuron:

RefractoryCounts integer = steps(t_ref) # refractory time in steps
r integer # number of steps in the current refractory phase

# Input current injected by CurrentEvent.
# This variable is used to transport the current applied into the
# _dynamics function computing the derivative of the state vector.
I_stim pA = 0pA
end

input:
Expand All @@ -136,12 +131,11 @@ neuron hh_psc_alpha_neuron:
# sending spikes: crossing 0 mV, pseudo-refractoriness and local maximum...
if r > 0: # is refractory?
r -= 1
elif V_m > 0 and U_old > V_m: # threshold && maximum
elif V_m > 0mV and U_old > V_m: # threshold && maximum
r = RefractoryCounts
emit_spike()
end

I_stim = currents.get_sum()
end

end
8 changes: 4 additions & 4 deletions models/iaf_cond_alpha_implicit.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ neuron iaf_cond_alpha_implicit:
g_ex'' = -g_ex'/tau_syn_ex
g_ex' = g_ex' -g_ex/tau_syn_ex

I_syn_exc pA = cond_sum(g_ex, spikeExc) * ( V_m - E_ex )
I_syn_inh pA = cond_sum(g_in, spikeInh) * ( V_m - E_in )
I_syn_exc pA = cond_sum(g_ex, spikeExc) * ( V_m - E_ex )
I_syn_inh pA = cond_sum(g_in, spikeInh) * ( V_m - E_in )
I_leak pA = g_L * ( V_m - E_L )

V_m' = ( -I_leak - I_syn_exc - I_syn_inh + I_stim + I_e ) / C_m
Expand Down Expand Up @@ -106,8 +106,8 @@ neuron iaf_cond_alpha_implicit:
end

# add incoming spikes
g_ex' += spikeExc.get_sum() * PSConInit_E
g_in' += spikeInh.get_sum() * PSConInit_I
g_ex' += spikeExc * PSConInit_E
g_in' += spikeInh * PSConInit_I
# set new input current
I_stim = currents.get_sum()
end
Expand Down
19 changes: 19 additions & 0 deletions src/main/java/org/nest/codegeneration/NestCodeGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
*/
package org.nest.codegeneration;

import com.google.common.collect.Lists;
import com.google.common.io.Files;
import de.monticore.generating.GeneratorEngine;
import de.monticore.generating.GeneratorSetup;
import de.monticore.generating.templateengine.GlobalExtensionManagement;
import de.se_rwth.commons.logging.Log;
import org.nest.codegeneration.converters.*;
import org.nest.codegeneration.helpers.*;
import org.nest.codegeneration.sympy.ODETransformer;
import org.nest.codegeneration.sympy.OdeProcessor;
import org.nest.codegeneration.sympy.TransformerBase;
import org.nest.nestml._ast.ASTBody;
import org.nest.nestml._ast.ASTNESTMLCompilationUnit;
import org.nest.nestml._ast.ASTNeuron;
Expand All @@ -23,8 +27,11 @@
import org.nest.utils.AstUtils;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -84,6 +91,7 @@ private ASTNeuron solveODESInNeuron(
if (odesBlock.isPresent()) {
if (odesBlock.get().getShapes().size() == 0) {
info("The model will be solved numerically with GSL solver.", LOG_NAME);
markNumericSolver(astNeuron.getName(), outputBase);
return astNeuron;
}
else {
Expand All @@ -98,6 +106,17 @@ private ASTNeuron solveODESInNeuron(

}

private void markNumericSolver(final String neuronName, final Path outputBase) {
try {
Files.write("numeric",
Paths.get(outputBase.toString(), neuronName + "." + TransformerBase.SOLVER_TYPE).toFile(),
Charset.defaultCharset());
}
catch (IOException e) {
Log.error("Cannot write status file. Check you permissions.", e);
}
}

private void generateNestCode(
final ASTNeuron astNeuron,
final Path outputBase) {
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/nest/codegeneration/SolverType.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.NoSuchFileException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ public String convertFunctionCall(final ASTFunctionCall astFunctionCall) {
if (PredefinedFunctions.POW.equals(functionName)) {
return "pow(%s)";
}
if (PredefinedFunctions.MAX.equals(functionName)) {
if (PredefinedFunctions.MAX.equals(functionName) || PredefinedFunctions.BOUNDED_MAX.equals(functionName)) {
return "std::max(%s)";
}
if (PredefinedFunctions.MIN.equals(functionName)) {
if (PredefinedFunctions.MIN.equals(functionName)|| PredefinedFunctions.BOUNDED_MIN.equals(functionName)) {
return "std::min(%s)";
}
if (functionName.contains(PredefinedFunctions.EMIT_SPIKE)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ public String convertFunctionCall(final ASTFunctionCall astFunctionCall) {
if (PredefinedFunctions.POW.equals(functionName)) {
return "std::pow(%s)";
}
if (PredefinedFunctions.MAX.equals(functionName)) {
if (PredefinedFunctions.MAX.equals(functionName) || PredefinedFunctions.BOUNDED_MAX.equals(functionName)) {
return "std::max(%s)";
}
if (PredefinedFunctions.MIN.equals(functionName)) {
if (PredefinedFunctions.MIN.equals(functionName) || PredefinedFunctions.BOUNDED_MIN.equals(functionName) ) {
return "std::min(%s)";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
*/
class DeltaSolutionTransformer extends TransformerBase {
final static String PROPAGATOR_STEP = "propagator.step.tmp";
final static String ODE_TYPE = "solverType.tmp";
final static String P30_FILE = "P30.tmp";

ASTNeuron addExactSolution(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class NESTMLASTCreator {
}

static List<ASTAliasDecl> createAliases(final Path declarationFile) {
checkArgument(Files.exists(declarationFile));
checkArgument(Files.exists(declarationFile), declarationFile.toString());

try {
return Files.lines(declarationFile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ private static Path generateSympyScript(

glex.setGlobalValue("variables", variables);
glex.setGlobalValue("aliases", aliases);
glex.setGlobalValue("neuronName", neuron.getName());

final ExpressionsPrettyPrinter expressionsPrinter = new ExpressionsPrettyPrinter();
glex.setGlobalValue("printer", expressionsPrinter);
Expand Down
37 changes: 32 additions & 5 deletions src/main/java/org/nest/codegeneration/sympy/ODETransformer.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.nest.codegeneration.sympy;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import de.monticore.ast.ASTNode;
import org.nest.commons._ast.ASTExpr;
import org.nest.commons._ast.ASTFunctionCall;
Expand All @@ -19,23 +20,49 @@
* @author plotnikov
*/
public class ODETransformer {
private static final List<String> functions = Lists.newArrayList(
PredefinedFunctions.CURR_SUM,
PredefinedFunctions.COND_SUM,
PredefinedFunctions.BOUNDED_MIN,
PredefinedFunctions.BOUNDED_MAX);

private static final List<String> sumFunctions = Lists.newArrayList(
PredefinedFunctions.CURR_SUM,
PredefinedFunctions.COND_SUM);


// this function is used in freemarker templates und must be public
public static <T extends ASTNode> T replaceFunctions(final T astOde) {
// since the transformation replaces the call inplace, make a copy to preserve the information for further steps
final List<ASTFunctionCall> functionsCalls = getFunctionCalls(astOde, functions);

final T workingCopy = (T) astOde.deepClone(); // IT is OK, since the deepClone returns T
functionsCalls.forEach(functionCall -> replaceFunctionCallThroughFirstArgument(astOde, functionCall)); // TODO deepClone
return astOde;
}

// this function is used in freemarker templates und must be public
public static <T extends ASTNode> T replaceSumCalls(final T astOde) {
// since the transformation replaces the call inplace, make a copy to preserve the information for further steps
final List<ASTFunctionCall> functions = get_sumFunctionCalls(astOde);
final List<ASTFunctionCall> functionsCalls = get_sumFunctionCalls(astOde);

final T workingCopy = (T) astOde.deepClone(); // IT is OK, since the deepClone returns T
functions.forEach(node -> replaceFunctionCallThroughFirstArgument(astOde, node)); // TODO deepClone
functionsCalls.forEach(functionCall -> replaceFunctionCallThroughFirstArgument(astOde, functionCall)); // TODO deepClone
return astOde;
}



// this function is used in freemarker templates und must be public
static List<ASTFunctionCall> get_sumFunctionCalls(final ASTNode workingCopy) {
return getFunctionCalls(workingCopy, sumFunctions);
}

// this function is used in freemarker templates und must be public
private static List<ASTFunctionCall> getFunctionCalls(final ASTNode workingCopy, final List<String> functionNames) {
return AstUtils.getAll(workingCopy, ASTFunctionCall.class)
.stream()
.filter(astFunctionCall ->
astFunctionCall.getCalleeName().equals(PredefinedFunctions.CURR_SUM) ||
astFunctionCall.getCalleeName().equals(PredefinedFunctions.COND_SUM))
.filter(astFunctionCall -> functionNames.contains(astFunctionCall.getCalleeName()))
.collect(Collectors.toList());
}

Expand Down
28 changes: 14 additions & 14 deletions src/main/java/org/nest/codegeneration/sympy/OdeProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private ASTNeuron handleDeltaShape(

checkState(successfulExecution, "Error during solver script evaluation.");

final Path odeTypePath = Paths.get(outputBase.toString(), DeltaSolutionTransformer.ODE_TYPE);
final Path odeTypePath = Paths.get(outputBase.toString(), astNeuron.getName() + "." + DeltaSolutionTransformer.SOLVER_TYPE);
final SolverType solutionType = SolverType.fromFile(odeTypePath);

if (solutionType.equals(SolverType.EXACT)) {
Expand All @@ -109,8 +109,8 @@ private ASTNeuron handleDeltaShape(
LOG_NAME);
deltaSolutionTransformer.addExactSolution(
astNeuron,
Paths.get(outputBase.toString(), DeltaSolutionTransformer.P30_FILE),
Paths.get(outputBase.toString(), DeltaSolutionTransformer.PROPAGATOR_STEP));
Paths.get(outputBase.toString(), astNeuron.getName() + "." + DeltaSolutionTransformer.P30_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + DeltaSolutionTransformer.PROPAGATOR_STEP));
}
else {
Log.warn(astNeuron.getName() + " has a delta shape function with a non-linear ODE.");
Expand All @@ -137,29 +137,29 @@ protected ASTNeuron handleNeuronWithODE(

checkState(successfulExecution, "Error during solver script evaluation.");

final Path odeTypePath = Paths.get(outputBase.toString(), TransformerBase.SOLVER_TYPE);
final Path odeTypePath = Paths.get(outputBase.toString(), astNeuron.getName() + "." + TransformerBase.SOLVER_TYPE);
final SolverType solutionType = SolverType.fromFile(odeTypePath);

if (solutionType.equals(SolverType.EXACT)) {
info("ODE is solved exactly.", LOG_NAME);

return linearSolutionTransformer.addExactSolution(
astNeuron,
Paths.get(outputBase.toString(), LinearSolutionTransformer.P30_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.PSC_INITIAL_VALUE_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VARIABLES_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.PROPAGATOR_MATRIX_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.PROPAGATOR_STEP_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VECTOR_TMP_DECLARATIONS_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VECTOR_UPDATE_STEPS_FILE),
Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VECTOR_TMP_BACK_ASSIGNMENTS_FILE));
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.P30_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.PSC_INITIAL_VALUE_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VARIABLES_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.PROPAGATOR_MATRIX_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.PROPAGATOR_STEP_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VECTOR_TMP_DECLARATIONS_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VECTOR_UPDATE_STEPS_FILE),
Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VECTOR_TMP_BACK_ASSIGNMENTS_FILE));
}
else if (solutionType.equals(SolverType.NUMERIC)) {
info("ODE is solved numerically.", LOG_NAME);
return implicitFormTransformer.transformToImplicitForm(
astNeuron,
Paths.get(outputBase.toString(),ImplicitFormTransformer.PSC_INITIAL_VALUE_FILE),
Paths.get(outputBase.toString(),ImplicitFormTransformer.EQUATIONS_FILE));
Paths.get(outputBase.toString(), astNeuron.getName() + "." + ImplicitFormTransformer.PSC_INITIAL_VALUE_FILE),
Paths.get(outputBase.toString(),astNeuron.getName() + "." + ImplicitFormTransformer.EQUATIONS_FILE));
}
else {
warn(astNeuron.getName() + ": ODEs could not be solved. The model remains unchanged.");
Expand Down
Loading

0 comments on commit dba38d7

Please sign in to comment.