In [1]:
# Importing libraries
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import seaborn as sns

from deepjr.inference import JRInvModel
from deepjr.simulation import jr_typical_param
from deepjr.simulation import JRSimulator, EventRelatedExp, SimResults
from deepjr.utils import reset_random_seeds

# General Information

These are the values for the Jansen-Rit model that are available in the literature, particularly from [this paper](https://mathematical-neuroscience.springeropen.com/articles/10.1186/s13408-017-0046-4/tables/1). The value of $v_{max}$ was corrected from 5 Hz to 50 Hz, a more reasonable value that is also compatible with other publications (e.g., [this one](https://link.springer.com/article/10.1007/s10827-013-0493-1#Tab1)). Minimum and maximum values are defined as per [the code of The Virtual Brain](https://docs.thevirtualbrain.org/_modules/tvb/simulator/models/jansen_rit.html).




| Parameter  | Description                                                                  | Typical value | min value | max value |
| ---------- | ---------------------------------------------------------------------------- | ------------- | --------- | --------- |
| $A_e$      | Average excitatory synaptic gain                                             | 3.25 mV       | 2.6 mV    | 9.75 mV   |
| $A_i$      | Average inhibitory synaptic gain                                             | 22 mV         | 17.6 mV   | 110.0 mV  |
| $b_e$      | Inverse of the time constant of excitatory postsynaptic potential            | 100 Hz        | 50 Hz     | 150 Hz    |
| $b_i$      | Inverse of the time constant of inhibitory postsynaptic potential            | 50 Hz         | 25 Hz     | 75 Hz     |
| $C$        | Average number of synapses between the populations                           | 135           | 65        | 1350      |
| $a_1$      | Average probability of synaptic contacts in the feedback excitatory loop     | 1.0           | 0.5       | 1.5       |
| $a_2$      | Average probability of synaptic contacts in the slow feedback excitatory loop| 0.8           | 0.4       | 1.2       |
| $a_3$      | Average probability of synaptic contacts in the feedback inhibitory loop     | 0.25          | 0.125     | 0.375     |
| $a_4$      | Average probability of synaptic contacts in the slow feedback inhibitory loop| 0.25          | 0.125     | 0.375     |
| $v_{max}$  | Maximum firing rate of the neural populations (maximum of the sigmoid function) | 50 Hz      |     -     |     -     |
| $v_0$      | Potential at which 50% of the maximum firing rate is attained                | 6 mV          | 3.12 mV   | 6.0 mV    |




In [None]:
# Default (typical) parameters of the Jansen-Rit model, as listed in the table above.
# `dict(...)` makes a shallow copy, so editing `parameters` does not modify
# `jr_typical_param` itself.
parameters = dict(jr_typical_param)


In [None]:
# Display the parameter values used for the example simulation below.
parameters

### Getting an info structure, a montage, and a noise covariance matrix for simulation

In [None]:
# Set up the simulator. This also creates a head model for EEG simulation.
jr_sim = JRSimulator()

# Set up the event-related experiment, built from the simulator's MNE `info`
# structure (its channel description).
er_exp = EventRelatedExp(jr_sim.info)

In [None]:
# Plot the sensor montage with channel names.
# Note: `scale_factor` controls the size of the sensor markers, not the font size.
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
mne.viz.plot_montage(jr_sim.montage, scale_factor=20, axes=ax,
                     show_names=True, show=False)
ax.set_title("EEG montage used for the simulations")

# Save before the figure is displayed (the inline backend shows it at the end
# of the cell), so the saved file and the displayed figure are identical.
fig.savefig('montage_plot.png', dpi=300, bbox_inches='tight')

### Generate stimulus

In [None]:
# Plot the stimulus defined by the event-related experiment.
er_exp.plot_stimulus()

### Run an example simulation

In [None]:
# Run a single Jansen-Rit simulation with the typical parameters and the
# JR-model noise standard deviation set to 0, then plot the model outputs.
jr_sim.run_simulation(er_exp, parameters, jr_noise_sd=0.0)
jr_sim.plot_jr_results();

In [None]:
# Noise scaling factor passed to `generate_raw` and reused below for the
# training and test simulations (presumably scales the noise added to the
# simulated EEG — confirm in deepjr). Alternative value left by the author: 1e3.
noise_fact = 1

# Generate simulated raw EEG (fixed seed for reproducibility) and plot it.
jr_sim.generate_raw(seed=0, noise_fact=noise_fact)
jr_sim.raw.plot();

In [None]:
# Build the evoked response from the simulated data for this experiment and plot it.
jr_sim.generate_evoked(er_exp)
jr_sim.evoked.plot();

In [None]:
# Scalp topographies of the simulated evoked response.
jr_sim.evoked.plot_topomap();

## EEG simulations

In [None]:
# Simulate the training set only once: if the results file
# (`sim_results.full_path`) already exists it is reused, unless `recompute`
# is set to True to force regeneration.
recompute = False

method = 'normal'  # parameters are drawn from a normal distribution
nb_sims = 1000
base_path = Path('deepjr_training_data')
base_path.mkdir(exist_ok=True)

sim_results = SimResults(nb_sims, noise_fact, base_path)

if recompute or not sim_results.full_path.exists():
    # Lower MNE's log level so its info messages do not flood the output.
    mne.set_log_level(verbose=False)
    jr_sim.simulate_for_parameter(
        er_exp,
        method=method,
        nb_sims=nb_sims,
        noise_fact=noise_fact,
        base_path=base_path,
        use_tqdm=True,
    )

In [None]:
# (Re)load the simulated training set from disk, clean it, and show a heatmap
# of the simulated evoked responses.
sim_results = SimResults(nb_sims, noise_fact, base_path)
sim_results.load()
sim_results.clean()
sim_results.plot_evoked_heatmap()

In [None]:
# Distribution of the signal-to-noise ratio of the training simulations
# (`sim_results.snr`, after `clean()`). The trailing `;` hides the FacetGrid repr.
snr_grid = sns.displot(sim_results.snr)
snr_grid.set_axis_labels("SNR", "Count");

# Training

In [None]:
# Train (or reload) the inverse model that estimates Jansen-Rit parameters from
# the simulated evoked responses.
# Uses its own flag rather than overwriting `recompute` from the simulation cell.
retrain_model = True  # set to False to reuse the saved model when it exists
epochs = 150
batch_size = 32

inv_model = JRInvModel(nb_sims=nb_sims, noise_fact=noise_fact, path=base_path)

if retrain_model or not inv_model.full_path_model.exists():
    reset_random_seeds()  # reset the seeds before training
    inv_model.train_model(epochs, batch_size=batch_size)
    inv_model.save()
else:
    inv_model.load()

# Assessment

In [None]:
# Create a small test set that is not saved to disk, using the same
# parameter-sampling method as the training set.
nb_test_sims = 50
mne.set_log_level(verbose=False)
jr_sim.simulate_for_parameter(er_exp, method=method,
                              nb_sims=nb_test_sims, noise_fact=noise_fact,
                              save=False, use_tqdm=True)
test_results = jr_sim.sim_results
test_results.clean()
# Plot the *test* set; the training-set heatmap (`sim_results`) is shown above.
test_results.plot_evoked_heatmap()

dataset = test_results.dataset

In [None]:
# Model inputs: test-set evoked responses with axes ordered (sim_no, time, ch_names).
X = dataset.evoked.transpose("sim_no", "time", "ch_names").values
# Targets: true values of the parameters estimated by the inverse model.
y = dataset.parameters.sel(param=inv_model.estim_params).values

In [None]:
# Evaluate the trained inverse model on the test set for all estimated parameters.
inv_model.assess_model(parameter='all', X=X, y=y)

In [None]:
# Regression plots of the model's predictions on the test set
# (presumably estimated vs. true parameter values — confirm in deepjr).
inv_model.plot_test_regressions(X=X, y=y)