In [None]:
# Import modules

%matplotlib inline

import os
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(10)

import matplotlib.animation as animation

default_dir = os.path.dirname(os.getcwd())
os.chdir(default_dir)

import RLConn as rc
from RLConn import problem_definitions as problems

## Load ground-truth modes

In [None]:
# Ground-truth (m1, m2) modes shipped with RLConn; these are the training targets.
m1_gt, m2_gt = rc.neural_params.m1_gt, rc.neural_params.m2_gt

In [None]:
# Scatter the ground-truth modes together with their centroid.
centroid_coord = rc.utils.centeroidnp(np.vstack([m1_gt, m2_gt]))

fig, ax = plt.subplots(figsize=(5.5, 5))
ax.scatter(m1_gt, m2_gt, s=10, color='black')
ax.scatter(centroid_coord[0], centroid_coord[1], s=10, color='black')
ax.set_ylim(-45, 45)
ax.set_xlim(-45, 45)

## Define problem-statement parameters and DQN training parameters

In [None]:
# Build noisy versions of the ground-truth 'gap' (Gg) and 'syn' (Gs) weight
# matrices; the noisy matrices become the initial network for training.

# NOTE(review): np.random.randint's upper bound is exclusive, so the raw noise is
# drawn from {-2, -1, 0, 1}, which is not symmetric around 0. It is left as-is so the
# seeded results stay reproducible; use randint(-2, 3) if +/-2 was intended.
Gg_noise = np.random.randint(-2, 2, (3, 3))
# Symmetrise so the gap matrix stays symmetric; astype('int') truncates the
# resulting half-integers toward zero (still symmetric).
Gg_noise = (Gg_noise + Gg_noise.T) / 2
Gg_noise = Gg_noise.astype('int')
np.fill_diagonal(Gg_noise, 0)

# Synaptic noise is not symmetrised; only self-connections are zeroed.
Gs_noise = np.random.randint(-2, 2, (3, 3))
np.fill_diagonal(Gs_noise, 0)

Gg_groundtruth = np.array([[0, 8, 5],
                           [8, 0, 2],
                           [5, 2, 0]]).astype('float')

# Add noise and clip negative weights to zero.
Gg_noised = Gg_groundtruth + Gg_noise
Gg_noised[Gg_noised < 0] = 0

# Size-independent invariants (no hard-coded element count).
assert np.array_equal(Gg_noised, Gg_noised.T), "Gg_noised must be symmetric"
assert np.all(np.diag(Gg_noised) == 0), "Gg_noised must have a zero diagonal"

Gs_groundtruth = np.array([[0, 2, 8],
                           [7, 0, 3],
                           [7, 7, 0]]).astype('float')

Gs_noised = Gs_groundtruth + Gs_noise
Gs_noised[Gs_noised < 0] = 0

assert np.all(np.diag(Gs_noised) == 0), "Gs_noised must have a zero diagonal"

# Directionality vector, one entry per node of the 3x3 matrices
# (passed to the network as "directionality").
E = np.array([1, 0, 0])

In [None]:
# Initial network: the noisy weight matrices plus the directionality vector.
network_dict_init = {"gap": Gg_noised, "syn": Gs_noised, "directionality": E}

# Simulation settings shared by training and evaluation.
external_params_dict = dict(
    input_vec=[0, 0.03, 0],
    ablation_mask=np.ones(3),
    tf=10,
    t_delta=0.01,
    cutoff_1=400,
    cutoff_2=900,
)

# Feel free to change the params
batchsize = 1
num_epochs = 10000
err_threshold = 10
weight_min = 0
# Upper weight bound: the largest entry across both initial matrices.
weight_max = np.max([network_dict_init['gap'], network_dict_init['syn']])
plotting_period = 300

## Training

In [None]:
# Train the network starting from the noisy initial weights.
training_result = rc.network_sim.train_network(
    network_dict_init,
    external_params_dict,
    m1_gt,
    m2_gt,
    batchsize=batchsize,
    num_epochs=num_epochs,
    err_threshold=err_threshold,
    weight_min=weight_min,
    weight_max=weight_max,
    plotting_period=plotting_period,
)

In [None]:
# Position of the lowest error in err_list. np.argmin returns the first minimum
# (same as list.index(min(...))) but also works when err_list is an ndarray,
# e.g. after reloading it from the saved .npz file.
best_ind = int(np.argmin(training_result['err_list']))

In [None]:
# Lowest training error and where it occurred.
(training_result['err_list'][best_ind], best_ind)

In [None]:
# Persist every entry of the training result under its own key.
np.savez('validation_train.npz', **dict(training_result))

In [None]:
# Reload the saved results; for a .npz file np.load returns a lazy NpzFile mapping.
validation = np.load(file='validation_train.npz')

In [None]:
# Histogram of the squashed per-step change in error
# (NOTE(review): presumably mirrors the training reward shaping; confirm).
err_steps = np.diff(validation['err_list'])
plt.hist(-np.tanh(0.005 * err_steps), bins=5)

In [None]:
# Per-step change in training error, zoomed to [-10, 10] on the y-axis.
fig, ax = plt.subplots()
ax.plot(np.diff(validation['err_list']))
ax.set_ylim(-10, 10)

In [None]:
# Display results: simulate the network with the initial and the best trained weights.
# To inspect previously saved results instead, read these entries from `validation`.
# NOTE(review): entry 2 of Gg_list/Gs_list is used as the "initial" weights;
# confirm why index 2 rather than 0 (or network_dict_init).
Gg_init = training_result['Gg_list'][2]
Gs_init = training_result['Gs_list'][2]
E = training_result['E']

Gg_trained = training_result['Gg_list'][best_ind]
Gs_trained = training_result['Gs_list'][best_ind]


def simulate_modes(Gg, Gs, E):
    """Run rc.utils.compute_score for one weight pair and return its last two outputs.

    Every other simulation setting is read from `external_params_dict`, and the
    targets are the ground-truth modes, so the runs differ only in the weights.
    Plotting and verbose output are enabled.
    """
    params = external_params_dict
    return rc.utils.compute_score(
        Gg, Gs, E,
        params['input_vec'], params['ablation_mask'],
        params['tf'], params['t_delta'],
        params['cutoff_1'], params['cutoff_2'],
        m1_target=m1_gt,
        m2_target=m2_gt,
        plot_result=True,
        verbose=True,
    )[-2:]


m1_init, m2_init = simulate_modes(Gg_init, Gs_init, E)
m1_trained, m2_trained = simulate_modes(Gg_trained, Gs_trained, E)

In [None]:
# Scatter the trained modes together with their centroid.
centroid_coord = rc.utils.centeroidnp(np.vstack([m1_trained, m2_trained]))

plt.figure(figsize=(5.5, 5))

plt.scatter(m1_trained, m2_trained, s=10, color='black')
# The centroid's y coordinate is centroid_coord[1] (not the literal list [1]).
plt.scatter(centroid_coord[0], centroid_coord[1], s=10, color='black')
plt.ylim(-45, 45)
plt.xlim(-45, 45)

In [None]:
def l2_err(m1_target, m2_target, m1_test, m2_test):
    """Mean Euclidean distance between target and test points in (m1, m2) space.

    Point i is (m1[i], m2[i]); the result is the average over i of the L2 norm
    of (target_i - test_i). Inputs may be lists or arrays of matching length.
    """
    deltas = np.vstack([np.subtract(m1_target, m1_test),
                        np.subtract(m2_target, m2_test)])
    per_point = np.sqrt(np.sum(np.power(deltas, 2), axis=0))
    return np.mean(per_point)

In [None]:
# Mean L2 distance to the ground truth before and after training.
init_err, trained_err = [
    l2_err(m1_gt, m2_gt, m1_run, m2_run)
    for m1_run, m2_run in ((m1_init, m2_init), (m1_trained, m2_trained))
]

In [None]:
# Reference values (init, trained), presumably from an earlier run:
# (63.87632835483021, 13.947343895884398)
(init_err, trained_err)

In [None]:
# Initial gap weights. vmax=8 matches the largest ground-truth weight, so any
# larger entry saturates. ylim(3, 0) flips the y-axis so row 0 is drawn at the top.
plt.figure(figsize=(5, 5))

plt.pcolor(Gg_init, cmap='Reds', vmin=0, vmax=8)
plt.title('Gg (gap) - initial')
plt.ylim(3, 0)

In [None]:
# Trained gap weights. vmax=8 matches the largest ground-truth weight, so any
# larger entry saturates. ylim(3, 0) flips the y-axis so row 0 is drawn at the top.
plt.figure(figsize=(5, 5))

plt.pcolor(Gg_trained, cmap='Reds', vmin=0, vmax=8)
plt.title('Gg (gap) - trained')
plt.ylim(3, 0)

In [None]:
# Initial synaptic weights. vmax=8 matches the largest ground-truth weight, so
# any larger entry saturates. ylim(3, 0) flips the y-axis so row 0 is drawn at the top.
plt.figure(figsize=(5, 5))

plt.pcolor(Gs_init, cmap='Blues', vmin=0, vmax=8)
plt.title('Gs (syn) - initial')
plt.ylim(3, 0)

In [None]:
# Trained synaptic weights. vmax=8 matches the largest ground-truth weight, so
# any larger entry saturates. ylim(3, 0) flips the y-axis so row 0 is drawn at the top.
plt.figure(figsize=(5, 5))

plt.pcolor(Gs_trained, cmap='Blues', vmin=0, vmax=8)
plt.title('Gs (syn) - trained')
plt.ylim(3, 0)

In [None]:
# Save the initial weight matrices, one .npy file each.
for out_path, weights in (('Gg_init.npy', Gg_init), ('Gs_init.npy', Gs_init)):
    np.save(out_path, weights)

In [None]:
# Side-by-side: initial gap weights vs. the ground truth.
(Gg_init, Gg_groundtruth)

In [None]:
# Overlay the trained modes (red) on the ground-truth modes (black).
fig, ax = plt.subplots(figsize=(5.5, 5))
# To compare the untrained network instead, scatter (m1_init, m2_init) in red here.
ax.scatter(m1_trained, m2_trained, s=0.75, color='red')
ax.scatter(m1_gt, m2_gt, s=0.75, color='black')
ax.set_ylim(-60, 60)
ax.set_xlim(-60, 60)