Increase the robustness of the sympy infrastructure. Add bounded min, max functions. #332

Merged · 3 commits · Dec 14, 2016
models/aeif_cond_alpha.nestml (1 addition, 1 deletion)
@@ -38,7 +38,7 @@ neuron aeif_cond_alpha_neuron:
end

equations:
-    V_bounded mV = min(V_m, V_peak) # prevent exponential divergence
+    V_bounded mV = bounded_min(V_m, V_peak) # prevent exponential divergence
shape g_in = (e/tau_syn_in) * t * exp(-1/tau_syn_in*t)
shape g_ex = (e/tau_syn_ex) * t * exp(-1/tau_syn_ex*t)

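A note on the rename: `bounded_min` prints exactly like `min` in the generated C++ (both become `std::min`, see the reference-converter changes below), but unlike `min` it is stripped out before an equation is handed to the symbolic analysis, which cannot process min/max expressions (see `ODETransformer.replaceFunctions` below). The clamp itself matters because the AdEx membrane equation contains an exponential term that diverges once `V_m` crosses threshold; the standard form (Brette & Gerstner, 2005) is quoted here for context, with symbols matching the model's parameters:

```latex
C_m \frac{dV_m}{dt} = -g_L (V_m - E_L)
  + g_L \Delta_T \exp\!\left(\frac{V_\mathrm{bounded} - V_\mathrm{th}}{\Delta_T}\right)
  - w + I,
\qquad V_\mathrm{bounded} = \min(V_m, V_\mathrm{peak})
```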
models/aeif_cond_alpha_implicit.nestml (4 additions, 4 deletions)
@@ -42,12 +42,12 @@ neuron aeif_cond_alpha_implicit:
equations:
V_bounded mV = min(V_m, V_peak) # prevent exponential divergence
# alpha function for the g_in
-    g_in'' = -g_in'/tau_syn_in
-    g_in' = g_in' - g_in/tau_syn_in
+    g_in'' = (-2/tau_syn_in) * g_in'-(1/tau_syn_in**2) * g_in
+    g_in' = g_in'

# alpha function for the g_ex
-    g_ex'' = -g_ex'/tau_syn_ex
-    g_ex' = g_ex' - g_ex/tau_syn_ex
+    g_ex'' = (-2/tau_syn_ex) * g_ex'-(1/tau_syn_ex**2) * g_ex
+    g_ex' = g_ex'

# Add aliases to simplify the equation definition of V_m
exp_arg real = (V_bounded-V_th)/Delta_T
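The corrected pair is the canonical second-order ODE satisfied by the alpha shape `g(t) = (e/tau_syn) * t * exp(-t/tau_syn)` that the non-implicit model variants declare directly; the old pair was inconsistent with that shape. The seemingly redundant `g_in' = g_in'` line just registers the first derivative as its own state variable so the system can be integrated in first-order form. Differentiating the alpha function twice confirms the new coefficients:

```latex
g(t) = \frac{e}{\tau}\, t\, e^{-t/\tau}
\;\Rightarrow\;
g'(t) = \frac{e}{\tau}\left(1 - \frac{t}{\tau}\right) e^{-t/\tau}
\;\Rightarrow\;
g''(t) = -\frac{2}{\tau}\, g'(t) - \frac{1}{\tau^{2}}\, g(t)
```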
models/aeif_cond_exp.nestml (1 addition, 1 deletion)
@@ -41,7 +41,7 @@ neuron aeif_cond_exp_neuron:
end

equations:
-    V_bounded mV = min(V_m, V_peak) # prevent exponential divergence
+    V_bounded mV = bounded_min(V_m, V_peak) # prevent exponential divergence
shape g_in = exp(-1/tau_syn_in*t)
shape g_ex = exp(-1/tau_syn_ex*t)

models/hh_cond_exp_traub.nestml (8 additions, 5 deletions)
@@ -36,7 +36,7 @@ neuron hh_cond_exp_traub_neuron:
state:
V_m mV = E_L # Membrane potential

-    # TODO: it should be possible, to define these variables in the internal block
+    # equilibrium values for (in)activation variables
alias alpha_n_init real = 0.032 * ( 15. - V_m ) / ( exp( ( 15. - V_m ) / 5. ) - 1. )
alias beta_n_init real = 0.5 * exp( ( 10. - V_m ) / 40. )
alias alpha_m_init real = 0.32 * ( 13. - V_m ) / ( exp( ( 13. - V_m ) / 4. ) - 1. )
@@ -50,19 +50,22 @@
end

equations:
-    shape g_in = exp(-1/tau_syn_in*t)
-    shape g_ex = exp(-1/tau_syn_ex*t)

-    # Add aliases to simplify the equation definition of V_m
+    # ionic currents
I_Na pA = g_Na * Act_m * Act_m * Act_m * Act_h * ( V_m - E_Na )
I_K pA = g_K * Inact_n * Inact_n * Inact_n * Inact_n * ( V_m - E_K )
I_L pA = g_L * ( V_m - E_L )

I_syn_exc pA = cond_sum(g_ex, spikeExc) * ( V_m - E_ex )
I_syn_inh pA = cond_sum(g_in, spikeInh) * ( V_m - E_in )

+    shape g_in = exp(-1/tau_syn_in*t)
+    shape g_ex = exp(-1/tau_syn_ex*t)

# membrane potential
V_m' =( -I_Na - I_K - I_L - I_syn_exc - I_syn_inh + I_stim + I_e ) / C_m

-    # equilibrium values for (in)activation variables
+    # channel dynamics
V_rel mV = V_m - V_T
alpha_n real = 0.032 * ( 15. - V_rel ) / ( exp( ( 15. - V_rel ) / 5. ) - 1. )
beta_n real = 0.5 * exp( ( 10. - V_rel ) / 40. )
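The relabelled comments follow the standard Hodgkin–Huxley gating formalism that both `hh_*` models use: each activation/inactivation variable `x` relaxes under voltage-dependent rate functions, and the `*_init` aliases in the state block are the corresponding steady states. The standard relations, quoted here for context:

```latex
\frac{dx}{dt} = \alpha_x(V)\,(1 - x) - \beta_x(V)\, x,
\qquad
x_\infty(V) = \frac{\alpha_x(V)}{\alpha_x(V) + \beta_x(V)}
```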
models/hh_psc_alpha.nestml (5 additions, 11 deletions)
@@ -63,15 +63,15 @@ neuron hh_psc_alpha_neuron:

equations:
# synapses: alpha functions
-    shape I_in = 1 pA * (e/tau_syn_in) * t * exp(-1/tau_syn_in*t)
-    shape I_ex = 1 pA * (e/tau_syn_ex) * t * exp(-1/tau_syn_ex*t)
+    shape I_in = (e/tau_syn_in) * t * exp(-1/tau_syn_in*t)
+    shape I_ex = (e/tau_syn_ex) * t * exp(-1/tau_syn_ex*t)

I_syn_exc pA = curr_sum(I_ex, spikeExc)
I_syn_inh pA = curr_sum(I_in, spikeInh)
I_Na pA = g_Na * Act_m * Act_m * Act_m * Act_h * ( V_m - E_Na )
I_K pA = g_K * Inact_n * Inact_n * Inact_n * Inact_n * ( V_m - E_K )
I_L pA = g_L * ( V_m - E_L )
-    V_m' =( -( I_Na + I_K + I_L ) + I_stim + I_e + I_syn_inh + I_syn_exc ) / C_m
+    V_m' =( -( I_Na + I_K + I_L ) + currents + I_e + I_syn_inh + I_syn_exc ) / C_m

# Inact_n
alpha_n real = ( 0.01 * ( V_m + 55. ) ) / ( 1. - exp( -( V_m + 55. ) / 10. ) )
@@ -83,7 +83,7 @@
beta_m real = 4. * exp( -( V_m + 65. ) / 18. )
Act_m' = alpha_m * ( 1 - Act_m ) - beta_m * Act_m # m-variable

-    # Act_h'
+    # Act_h
alpha_h real = 0.07 * exp( -( V_m + 65. ) / 20. )
beta_h real = 1. / ( 1. + exp( -( V_m + 35. ) / 10. ) )
Act_h' = alpha_h * ( 1 - Act_h ) - beta_h * Act_h # h-variable
@@ -115,11 +115,6 @@ neuron hh_psc_alpha_neuron:

RefractoryCounts integer = steps(t_ref) # refractory time in steps
r integer # number of steps in the current refractory phase

-    # Input current injected by CurrentEvent.
-    # This variable is used to transport the current applied into the
-    # _dynamics function computing the derivative of the state vector.
-    I_stim pA = 0pA
end

input:
@@ -136,12 +131,11 @@
# sending spikes: crossing 0 mV, pseudo-refractoriness and local maximum...
if r > 0: # is refractory?
r -= 1
-    elif V_m > 0 and U_old > V_m: # threshold && maximum
+    elif V_m > 0mV and U_old > V_m: # threshold && maximum
r = RefractoryCounts
emit_spike()
end

-    I_stim = currents.get_sum()
end

end
models/iaf_cond_alpha_implicit.nestml (4 additions, 4 deletions)
@@ -48,8 +48,8 @@ neuron iaf_cond_alpha_implicit:
g_ex'' = -g_ex'/tau_syn_ex
g_ex' = g_ex' -g_ex/tau_syn_ex

-    I_syn_exc pA = cond_sum(g_ex, spikeExc) * ( V_m - E_ex )
-    I_syn_inh pA = cond_sum(g_in, spikeInh) * ( V_m - E_in )
+    I_syn_exc pA = cond_sum(g_ex, spikeExc) * ( V_m - E_ex )
+    I_syn_inh pA = cond_sum(g_in, spikeInh) * ( V_m - E_in )
I_leak pA = g_L * ( V_m - E_L )

V_m' = ( -I_leak - I_syn_exc - I_syn_inh + I_stim + I_e ) / C_m
@@ -106,8 +106,8 @@ neuron iaf_cond_alpha_implicit:
end

# add incoming spikes
-    g_ex' += spikeExc.get_sum() * PSConInit_E
-    g_in' += spikeInh.get_sum() * PSConInit_I
+    g_ex' += spikeExc * PSConInit_E
+    g_in' += spikeInh * PSConInit_I
# set new input current
I_stim = currents.get_sum()
end
src/main/java/org/nest/codegeneration/NestCodeGenerator.java (19 additions)
@@ -5,13 +5,17 @@
*/
package org.nest.codegeneration;

import com.google.common.collect.Lists;
import com.google.common.io.Files;
import de.monticore.generating.GeneratorEngine;
import de.monticore.generating.GeneratorSetup;
import de.monticore.generating.templateengine.GlobalExtensionManagement;
import de.se_rwth.commons.logging.Log;
import org.nest.codegeneration.converters.*;
import org.nest.codegeneration.helpers.*;
import org.nest.codegeneration.sympy.ODETransformer;
import org.nest.codegeneration.sympy.OdeProcessor;
import org.nest.codegeneration.sympy.TransformerBase;
import org.nest.nestml._ast.ASTBody;
import org.nest.nestml._ast.ASTNESTMLCompilationUnit;
import org.nest.nestml._ast.ASTNeuron;
@@ -23,8 +27,11 @@
import org.nest.utils.AstUtils;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.List;
import java.util.Optional;

@@ -84,6 +91,7 @@ private ASTNeuron solveODESInNeuron(
if (odesBlock.isPresent()) {
if (odesBlock.get().getShapes().size() == 0) {
info("The model will be solved numerically with GSL solver.", LOG_NAME);
+        markNumericSolver(astNeuron.getName(), outputBase);
return astNeuron;
}
else {
@@ -98,6 +106,17 @@

}

+  private void markNumericSolver(final String neuronName, final Path outputBase) {
+    try {
+      Files.write("numeric",
+          Paths.get(outputBase.toString(), neuronName + "." + TransformerBase.SOLVER_TYPE).toFile(),
+          Charset.defaultCharset());
+    }
+    catch (IOException e) {
+      Log.error("Cannot write status file. Check you permissions.", e);
+    }
+  }

private void generateNestCode(
final ASTNeuron astNeuron,
final Path outputBase) {
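The new `markNumericSolver` closes a gap in the solver handshake: a model without shapes skips the SymPy run entirely, so previously no solver-type marker was written for it. Writing `numeric` into `<neuronName>.<SOLVER_TYPE>` guarantees the marker exists on every code path. The reading side lives in `SolverType.fromFile`, which is not part of this diff; the following is a hedged sketch of what it plausibly does, with the missing-file fallback inferred only from the `NoSuchFileException` import added below (the constant `NONE` is an assumption):

```java
// Sketch only: the real SolverType.fromFile is not shown in this PR.
public static SolverType fromFile(final Path markerFile) throws IOException {
  try {
    final String marker = new String(
        java.nio.file.Files.readAllBytes(markerFile)).trim();
    return SolverType.valueOf(marker.toUpperCase()); // "numeric" -> NUMERIC
  }
  catch (NoSuchFileException e) {
    return SolverType.NONE; // assumption: no marker file means nothing to rewrite
  }
}
```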
src/main/java/org/nest/codegeneration/SolverType.java (1 addition)
@@ -9,6 +9,7 @@

import java.io.IOException;
import java.nio.file.Files;
+import java.nio.file.NoSuchFileException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
@@ -81,10 +81,10 @@ public String convertFunctionCall(final ASTFunctionCall astFunctionCall) {
if (PredefinedFunctions.POW.equals(functionName)) {
return "pow(%s)";
}
-    if (PredefinedFunctions.MAX.equals(functionName)) {
+    if (PredefinedFunctions.MAX.equals(functionName) || PredefinedFunctions.BOUNDED_MAX.equals(functionName)) {
return "std::max(%s)";
}
-    if (PredefinedFunctions.MIN.equals(functionName)) {
+    if (PredefinedFunctions.MIN.equals(functionName)|| PredefinedFunctions.BOUNDED_MIN.equals(functionName)) {
return "std::min(%s)";
}
if (functionName.contains(PredefinedFunctions.EMIT_SPIKE)) {
@@ -72,10 +72,10 @@ public String convertFunctionCall(final ASTFunctionCall astFunctionCall) {
if (PredefinedFunctions.POW.equals(functionName)) {
return "std::pow(%s)";
}
-    if (PredefinedFunctions.MAX.equals(functionName)) {
+    if (PredefinedFunctions.MAX.equals(functionName) || PredefinedFunctions.BOUNDED_MAX.equals(functionName)) {
return "std::max(%s)";
}
-    if (PredefinedFunctions.MIN.equals(functionName)) {
+    if (PredefinedFunctions.MIN.equals(functionName) || PredefinedFunctions.BOUNDED_MIN.equals(functionName) ) {
return "std::min(%s)";
}

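Both reference converters now print the bounded variants identically to the plain ones, so the distinction is invisible in the generated C++ and only matters for the symbolic pipeline. The returned value is a format string whose `%s` is filled with the printed argument list elsewhere; roughly (the surrounding printer plumbing shown here is assumed, not part of this diff):

```java
// Illustrative only: how the "%s" pattern is presumably consumed.
final String pattern = converter.convertFunctionCall(call); // "std::min(%s)"
final String printedArgs = "V_m, V_peak";                   // printed argument list
final String cpp = String.format(pattern, printedArgs);     // "std::min(V_m, V_peak)"
```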
@@ -32,7 +32,6 @@
*/
class DeltaSolutionTransformer extends TransformerBase {
final static String PROPAGATOR_STEP = "propagator.step.tmp";
-  final static String ODE_TYPE = "solverType.tmp";
final static String P30_FILE = "P30.tmp";

ASTNeuron addExactSolution(
@@ -38,7 +38,7 @@ public class NESTMLASTCreator {
}

static List<ASTAliasDecl> createAliases(final Path declarationFile) {
-    checkArgument(Files.exists(declarationFile));
+    checkArgument(Files.exists(declarationFile), declarationFile.toString());

try {
return Files.lines(declarationFile)
@@ -161,6 +161,7 @@ private static Path generateSympyScript(

glex.setGlobalValue("variables", variables);
glex.setGlobalValue("aliases", aliases);
glex.setGlobalValue("neuronName", neuron.getName());

final ExpressionsPrettyPrinter expressionsPrinter = new ExpressionsPrettyPrinter();
glex.setGlobalValue("printer", expressionsPrinter);
src/main/java/org/nest/codegeneration/sympy/ODETransformer.java (32 additions, 5 deletions)
@@ -1,6 +1,7 @@
package org.nest.codegeneration.sympy;

import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
import de.monticore.ast.ASTNode;
import org.nest.commons._ast.ASTExpr;
import org.nest.commons._ast.ASTFunctionCall;
@@ -19,23 +20,49 @@
* @author plotnikov
*/
public class ODETransformer {
+  private static final List<String> functions = Lists.newArrayList(
+      PredefinedFunctions.CURR_SUM,
+      PredefinedFunctions.COND_SUM,
+      PredefinedFunctions.BOUNDED_MIN,
+      PredefinedFunctions.BOUNDED_MAX);
+
+  private static final List<String> sumFunctions = Lists.newArrayList(
+      PredefinedFunctions.CURR_SUM,
+      PredefinedFunctions.COND_SUM);
+
+
+  // this function is used in freemarker templates und must be public
+  public static <T extends ASTNode> T replaceFunctions(final T astOde) {
+    // since the transformation replaces the call inplace, make a copy to preserve the information for further steps
+    final List<ASTFunctionCall> functionsCalls = getFunctionCalls(astOde, functions);
+
+    final T workingCopy = (T) astOde.deepClone(); // IT is OK, since the deepClone returns T
+    functionsCalls.forEach(functionCall -> replaceFunctionCallThroughFirstArgument(astOde, functionCall)); // TODO deepClone
+    return astOde;
+  }
+
// this function is used in freemarker templates und must be public
public static <T extends ASTNode> T replaceSumCalls(final T astOde) {
// since the transformation replaces the call inplace, make a copy to preserve the information for further steps
-    final List<ASTFunctionCall> functions = get_sumFunctionCalls(astOde);
+    final List<ASTFunctionCall> functionsCalls = get_sumFunctionCalls(astOde);

final T workingCopy = (T) astOde.deepClone(); // IT is OK, since the deepClone returns T
-    functions.forEach(node -> replaceFunctionCallThroughFirstArgument(astOde, node)); // TODO deepClone
+    functionsCalls.forEach(functionCall -> replaceFunctionCallThroughFirstArgument(astOde, functionCall)); // TODO deepClone
return astOde;
}



// this function is used in freemarker templates und must be public
static List<ASTFunctionCall> get_sumFunctionCalls(final ASTNode workingCopy) {
+    return getFunctionCalls(workingCopy, sumFunctions);
+  }
+
+  // this function is used in freemarker templates und must be public
+  private static List<ASTFunctionCall> getFunctionCalls(final ASTNode workingCopy, final List<String> functionNames) {
return AstUtils.getAll(workingCopy, ASTFunctionCall.class)
.stream()
-      .filter(astFunctionCall ->
-        astFunctionCall.getCalleeName().equals(PredefinedFunctions.CURR_SUM) ||
-        astFunctionCall.getCalleeName().equals(PredefinedFunctions.COND_SUM))
+      .filter(astFunctionCall -> functionNames.contains(astFunctionCall.getCalleeName()))
.collect(Collectors.toList());
}

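The new `replaceFunctions` generalizes the existing `replaceSumCalls`: before an equation is handed to the SymPy script, every call in the `functions` list (now including `bounded_min` and `bounded_max`) is replaced by its first argument, so the symbolic solver sees plain `V_m` where the generated C++ will see `std::min(V_m, V_peak)`. Below is a self-contained, string-level illustration of that rewrite rule; the real implementation rewrites the MontiCore AST in place, not strings:

```java
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

public class ReplaceThroughFirstArgumentDemo {

  // same names as in ODETransformer's new "functions" list
  private static final List<String> FUNCTIONS = Arrays.asList(
      "curr_sum", "cond_sum", "bounded_min", "bounded_max");

  static String replaceFunctions(String ode) {
    for (final String name : FUNCTIONS) {
      // name(firstArg, rest) -> firstArg; nested parentheses are not handled
      final Pattern call = Pattern.compile(name + "\\(([^,()]+)(?:,[^()]*)?\\)");
      ode = call.matcher(ode).replaceAll("$1");
    }
    return ode;
  }

  public static void main(String[] args) {
    System.out.println(replaceFunctions("V_bounded = bounded_min(V_m, V_peak)"));
    // prints: V_bounded = V_m
    System.out.println(replaceFunctions("V_m' = -V_m/tau_m + curr_sum(I_ex, spikeExc)"));
    // prints: V_m' = -V_m/tau_m + I_ex
  }
}
```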
src/main/java/org/nest/codegeneration/sympy/OdeProcessor.java (14 additions, 14 deletions)
@@ -100,7 +100,7 @@ private ASTNeuron handleDeltaShape(

checkState(successfulExecution, "Error during solver script evaluation.");

-    final Path odeTypePath = Paths.get(outputBase.toString(), DeltaSolutionTransformer.ODE_TYPE);
+    final Path odeTypePath = Paths.get(outputBase.toString(), astNeuron.getName() + "." + DeltaSolutionTransformer.SOLVER_TYPE);
final SolverType solutionType = SolverType.fromFile(odeTypePath);

if (solutionType.equals(SolverType.EXACT)) {
@@ -109,8 +109,8 @@
LOG_NAME);
deltaSolutionTransformer.addExactSolution(
astNeuron,
-        Paths.get(outputBase.toString(), DeltaSolutionTransformer.P30_FILE),
-        Paths.get(outputBase.toString(), DeltaSolutionTransformer.PROPAGATOR_STEP));
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + DeltaSolutionTransformer.P30_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + DeltaSolutionTransformer.PROPAGATOR_STEP));
}
else {
Log.warn(astNeuron.getName() + " has a delta shape function with a non-linear ODE.");
@@ -137,29 +137,29 @@ protected ASTNeuron handleNeuronWithODE(

checkState(successfulExecution, "Error during solver script evaluation.");

-    final Path odeTypePath = Paths.get(outputBase.toString(), TransformerBase.SOLVER_TYPE);
+    final Path odeTypePath = Paths.get(outputBase.toString(), astNeuron.getName() + "." + TransformerBase.SOLVER_TYPE);
final SolverType solutionType = SolverType.fromFile(odeTypePath);

if (solutionType.equals(SolverType.EXACT)) {
info("ODE is solved exactly.", LOG_NAME);

return linearSolutionTransformer.addExactSolution(
astNeuron,
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.P30_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.PSC_INITIAL_VALUE_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VARIABLES_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.PROPAGATOR_MATRIX_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.PROPAGATOR_STEP_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VECTOR_TMP_DECLARATIONS_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VECTOR_UPDATE_STEPS_FILE),
-        Paths.get(outputBase.toString(), LinearSolutionTransformer.STATE_VECTOR_TMP_BACK_ASSIGNMENTS_FILE));
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.P30_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.PSC_INITIAL_VALUE_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VARIABLES_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.PROPAGATOR_MATRIX_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.PROPAGATOR_STEP_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VECTOR_TMP_DECLARATIONS_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VECTOR_UPDATE_STEPS_FILE),
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + LinearSolutionTransformer.STATE_VECTOR_TMP_BACK_ASSIGNMENTS_FILE));
}
else if (solutionType.equals(SolverType.NUMERIC)) {
info("ODE is solved numerically.", LOG_NAME);
return implicitFormTransformer.transformToImplicitForm(
astNeuron,
-        Paths.get(outputBase.toString(),ImplicitFormTransformer.PSC_INITIAL_VALUE_FILE),
-        Paths.get(outputBase.toString(),ImplicitFormTransformer.EQUATIONS_FILE));
+        Paths.get(outputBase.toString(), astNeuron.getName() + "." + ImplicitFormTransformer.PSC_INITIAL_VALUE_FILE),
+        Paths.get(outputBase.toString(),astNeuron.getName() + "." + ImplicitFormTransformer.EQUATIONS_FILE));
}
else {
warn(astNeuron.getName() + ": ODEs could not be solved. The model remains unchanged.");
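The renamed artifact files are the "robustness" half of this PR in a single pattern: every file exchanged with the SymPy script is now prefixed with the neuron's name (the fixed `solverType.tmp` constant was dropped from `DeltaSolutionTransformer` above, and the script generator now receives `neuronName`), so several neurons generated into the same output directory can no longer pick up each other's solver artifacts. A hypothetical helper, not part of the PR, that would centralize the repeated convention:

```java
// Hypothetical refactoring sketch; the PR inlines this expression at each call site.
static Path solverArtifact(final Path outputBase, final String neuronName,
                           final String artifactName) {
  return Paths.get(outputBase.toString(), neuronName + "." + artifactName);
}

// e.g. solverArtifact(outputBase, astNeuron.getName(), TransformerBase.SOLVER_TYPE)
```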