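# Fit SPUD spline to dimensionality-reduction results
This notebook reproduces part of the `run_SPUD_multiple_tests.py` script. Running it will:
- load the embedding (projection) produced by the dimensionality-reduction step, and the parameters used there, from pickle files
- repeatedly fit a spline to a random training subset of the embedded points and decode the held-out points (`nTests` repetitions)
- save the decoding errors and plot their root-mean-squared error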

## Setup

### Imports and loading general params

In [5]:
# General imports
import numpy as np
import numpy.linalg as la
import sys, os 
import time, datetime
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pickle as pkl

# For responsive plot
%matplotlib widget

# Import shared modules (sys.path must be extended before these imports)
gen_fn_dir = os.path.abspath('.') + '/shared_scripts'
sys.path.append(gen_fn_dir)
import general_file_fns as gff
from binned_spikes_class import spike_counts
from dim_red_fns import run_dim_red
import manifold_fit_and_decode_fns as mff  # provides fit_manifold / decode_from_passed_fit used below

# Date stamp of the form 'YYYY_MM_DD_' (note the trailing underscore)
curr_date = '{:%Y_%m_%d}_'.format(datetime.datetime.now())

# Load the general (pipeline-wide) params and echo them so the notebook output
# records which settings were in effect for this session
gen_params = gff.load_pickle_file('./general_params/general_params.pkl')
print("General params used for this session:\n" + str(gen_params))

Reading data from ./general_params/general_params.pkl...
General params used for this session:
{'raw_data_dir': './data/raw_data/', 'processed_data_dir': './data/processed/', 'kernel_rates_dir': './data/analyses/kernel_rates/', 'results_dir': './data/analyses/', 'cols': {'REM': (0.392, 0.549, 0.0784), 'SWS': (0.824, 0.627, 0.0392), 'Wake': (0.0118, 0.235, 0.392), 'measured': (0.3, 0.3, 0.3), 'fit': (0.49, 0.961, 0.961)}}


### Load projections

In [6]:
# Load projections from pickle files created by dim_red
embeddings_path = "./data/analyses/dim_red/Mouse28-140313_Wake_iso_3_embeddings_2023_06_12_.pkl"
with open(embeddings_path, "rb") as f:
    embeddings = pkl.load(f)

print(f"Loaded embedding dict with keys:")
print(embeddings.keys())

Loaded embedding dict with keys:
dict_keys(['Wake', 'meas_angles'])


### Load the parameters used by the dim_red step

In [8]:
# Load the parameter dict the dim_red step saved alongside its embeddings.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
params_path = "./data/analyses/dim_red/Mouse28-140313_Wake_iso_3_used_params_2023_06_12_.pkl"
with open(params_path, "rb") as params_file:
    params = pkl.load(params_file)

print("Loaded params used in dim_red part:", params, sep="\n")

Loaded params used in dim_red part:
{'session': 'Mouse28-140313', 'fit_dim': 3, 'nKnots': 15, 'knot_order': 'wt_per_len', 'penalty_type': 'mult_len', 'nTests': 10, 'train_frac': 0.8, 'area': 'ADn', 'state': 'Wake', 'dt_kernel': 0.1, 'sigma': 0.1, 'method': 'iso', 'n_neighbors': 5, 'dt': 0.1, 'target_dim': 3, 'desired_nSamples': 15000}


## Fitting the spline

In [None]:
# Repeatedly fit a spline (mff.fit_manifold) to a random training subset of the
# embedded points for the chosen state, decode the held-out points against the
# measured angles (mff.decode_from_passed_fit), and collect
# [mse, fit_err, final_knots] for each of the nTests repetitions.

# Unpack the run settings saved by the dim_red step; the results key, the output
# filename and the plotting cell below refer to these names.
session = params['session']
area = params['area']
state = params['state']
fit_dim = params['fit_dim']
train_frac = params['train_frac']
nTests = params['nTests']

current_manifold = embeddings[state]
nPoints = len(current_manifold)
# `params` has no 'nPoints' entry, so the training-set size is derived from the
# number of embedded points.
nTrain = np.round(train_frac * nPoints).astype(int)

# Use measured angles to set origin and direction of coordinate increase.
# Assumes meas_angles is a numpy array aligned index-for-index with the
# embedding (it is fancy-indexed with test_idx below) -- TODO confirm.
ref_angles = embeddings['meas_angles']
fit_params = {'dalpha' : 0.005, 'knot_order' : params['knot_order'],
    'penalty_type' : params['penalty_type'], 'nKnots' : params['nKnots']}

# Fixed seed so Restart & Run All reproduces the same train/test splits; change
# it to draw different splits. The seed is also recorded in the output filename.
# The global numpy RNG is seeded so any randomness inside mff is covered too.
sd = 0
np.random.seed(sd)

k = (session, fit_dim, params['nKnots'], params['knot_order'], params['penalty_type'], train_frac)
results = {k: []}
all_idx = np.arange(nPoints)
tic = time.time()
print('Fitting manifold')
for _ in range(nTests):
    # Fresh copy per test so fit_manifold cannot alter the shared fit_params
    curr_fit_params = dict(fit_params)

    train_idx = np.random.choice(nPoints, size=nTrain, replace=False)
    # Sorted complement of the training indices: same result as testing each
    # index for membership in train_idx, without the O(nPoints * nTrain) scan.
    test_idx = np.setdiff1d(all_idx, train_idx)
    data_to_fit = current_manifold[train_idx].copy()
    data_to_decode = current_manifold[test_idx].copy()

    curr_fit_result = mff.fit_manifold(data_to_fit, curr_fit_params)
    # The last entry of 'tt'/'curve' is dropped before decoding -- presumably
    # because the closed curve repeats its first point (TODO confirm in mff).
    _, mse = mff.decode_from_passed_fit(data_to_decode, curr_fit_result['tt'][:-1],
        curr_fit_result['curve'][:-1], ref_angles[test_idx])
    results[k].append([mse, curr_fit_result['fit_err'],
        np.array(curr_fit_result['final_knots'])])
print('Time ', time.time()-tic)

# Save the per-test results under the general results directory
dir_to_save = os.path.join(gen_params['results_dir'], 'spud_fits')
os.makedirs(dir_to_save, exist_ok=True)
to_save = {'fit_results' : results, 'session' : session, 'area' : area, 'state' : state,
    'embeddings_file' : embeddings_path}
gff.save_pickle_file(to_save, os.path.join(dir_to_save,
    '%s_%s_dim%d_trainfrac%.2f_decode_errors_sd%d.pkl' % (session, state, fit_dim, train_frac, sd)))

# Plot the root-mean-squared decoding error of each held-out test: individual
# points when there are few tests, a violin plot (with mean and median) otherwise.
rmse_to_plot = np.sqrt([x[0] for x in results[k]])
n_tests_done = len(rmse_to_plot)

fig, ax = plt.subplots(figsize=(8, 8))
if n_tests_done < 20:
    ax.scatter(np.ones_like(rmse_to_plot), rmse_to_plot)
    ax.set_xticks([1])
else:
    ax.violinplot([rmse_to_plot], positions=[1], points=100,
        widths=0.75, showmeans=True, showextrema=False, showmedians=True)
ax.set_xlim([0, 2])
ax.set_ylabel('Root mean squared error (rad)')
# Keep the usual 0-1.8 rad range, but never clip points that exceed it
ax.set_ylim([0, max(1.8, 1.05 * np.max(rmse_to_plot, initial=0.0))])
ax.set_title(f'{k[0]}: decoding RMSE over {n_tests_done} tests')
plt.show()
