In [None]:
# Import modules

%matplotlib inline

import os
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(10)

import matplotlib.animation as animation

default_dir = os.path.dirname(os.getcwd())
os.chdir(default_dir)

import RLConn as rc
from RLConn import problem_definitions as problems

## Import target in-vivo modes

In [None]:
# Target in-vivo values for the two modes (m1, m2); passed to training and scoring below.
m1_invivo, m2_invivo = rc.neural_params.m1_target, rc.neural_params.m2_target

In [None]:
# Scatter the in-vivo targets in the (m1, m2) plane and mark their centroid.
# The centroid gets a distinct marker/colour so it is not mistaken for a data point.
centroid_coord = rc.utils.centeroidnp(np.vstack([m1_invivo, m2_invivo]))

fig, ax = plt.subplots(figsize=(5.5, 5))

ax.scatter(m1_invivo, m2_invivo, s=10, color='black', label='in-vivo target')
ax.scatter(centroid_coord[0], centroid_coord[1], s=80, color='red', marker='x',
           label='centroid')
ax.set(xlim=(-45, 45), ylim=(-45, 45), xlabel='m1', ylabel='m2',
       title='Target in-vivo modes')
ax.legend()
plt.show()

## Define problem statement params and DQN params

In [None]:
# Define the initial network connectivity and the external simulation parameters.

# Network size: shared by the random network, the input vector and the ablation
# mask (assumes the first argument of generate_random_network is the number of
# neurons — TODO confirm).
N_NEURONS = 10
SEED = 10  # same value as the global seed set in the import cell

# Re-seed so re-running this cell on its own regenerates the same initial network.
np.random.seed(SEED)

network_dict_init = rc.connectome_utils.generate_random_network(N_NEURONS, 3, 8)

# Input vector: 0.3 at index 5, zero everywhere else.
input_vec = np.zeros(N_NEURONS)
input_vec[5] = 0.3

external_params_dict = {
    "input_vec": input_vec,
    "ablation_mask": np.ones(N_NEURONS),  # all ones — presumably no neuron ablated
    "tf": 15,            # presumably final simulation time — TODO confirm
    "t_delta": 0.01,     # presumably integration time step — TODO confirm
    "cutoff_1": 400,
    "cutoff_2": 900,
}

batchsize = 1
num_epochs = 500
err_threshold = 10
weight_min = 0
weight_max = 8
# Integer division keeps this an int (300) rather than the float 300.0.
plotting_period = 100 * (3 * (3 - 1) // 2)

In [None]:
# Inspect the 'directionality' entry of the randomly generated initial network.
init_directionality = network_dict_init['directionality']
init_directionality

In [None]:
# Train the initial network against the in-vivo targets (m1_invivo, m2_invivo).
train_kwargs = dict(
    batchsize=batchsize,
    num_epochs=num_epochs,
    err_threshold=err_threshold,
    weight_min=weight_min,
    weight_max=weight_max,
    plotting_period=plotting_period,
)

training_result = rc.network_sim.train_network(
    network_dict_init, external_params_dict, m1_invivo, m2_invivo, **train_kwargs
)

In [None]:
# Index of the lowest recorded training error. np.nanargmin skips NaN entries
# (with `.index(np.min(...))` a single NaN error makes the lookup raise
# ValueError), works for lists and arrays alike, and — like list.index —
# returns the first occurrence on ties.
best_ind = int(np.nanargmin(training_result['err_list']))

In [None]:
# Show the best error alongside the epoch index it came from.
best_err = training_result['err_list'][best_ind]
best_err, best_ind

In [None]:
# Display results: score the weights from the best epoch with rc.utils.compute_score,
# using the same external parameters and in-vivo targets that training used.
Gg_trained = training_result['Gg_list'][best_ind]
Gs_trained = training_result['Gs_list'][best_ind]
E = training_result['E']

error_dist_flattened, error_frobenius, m1_test, m2_test = rc.utils.compute_score(
    Gg_trained, Gs_trained, E,
    external_params_dict['input_vec'],
    external_params_dict['ablation_mask'],
    external_params_dict['tf'],
    external_params_dict['t_delta'],
    external_params_dict['cutoff_1'],
    external_params_dict['cutoff_2'],
    m1_target=m1_invivo,
    m2_target=m2_invivo,
    plot_result=True,
    verbose=True,
)