In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import h5py
import json
import os

from cell_inference.config import paths, params
from cell_inference.utils.plotting.plot_results import plot_lfp_traces, plot_lfp_heatmap
from cell_inference.utils.feature_extractors.SummaryStats2D import calculate_stats, build_lfp_grid
from cell_inference.utils.feature_extractors.parameterprediction import ClassifierTypes, ClassifierBuilder

%matplotlib inline

## Load simulation data

In [2]:
# Locations of the simulated trial: its configuration, LFP file and summary statistics.
DATA_PATH = 'cell_inference/resources/simulation_data'
TRIAL_PATH = os.path.join(DATA_PATH, 'neuronal_model_491766131_Loc5')

CONFIG_PATH = os.path.join(TRIAL_PATH, 'config.json')  # trial configuration
LFP_PATH = os.path.join(TRIAL_PATH, 'lfp.npz')  # LFP and labels
STATS_PATH = os.path.join(TRIAL_PATH, 'summ_stats.npz')  # summary statistics

with open(CONFIG_PATH, 'r') as f:
    config_dict = json.load(f)

# Names of the inferred parameters; they label the columns of `labels`.
inference_list = config_dict['Trial_Parameters']['inference_list']
print(inference_list)

STATS = np.load(STATS_PATH)
summ_stats = STATS['x']
labels = STATS['y']

# When 'y' is inferred and the file carries a non-empty 'ys' array, the 'y'
# label column is replaced by 'ys'.
yshift = 'y' in inference_list and STATS['ys'].size != 0
if yshift:
    labels[:, inference_list.index('y')] = STATS['ys']

# Build the table only after the 'y' column is final. Creating it earlier and
# relying on the DataFrame sharing memory with `labels` breaks whenever pandas
# copies the array in the constructor (e.g. with Copy-on-Write enabled).
df_la = pd.DataFrame(labels, columns=inference_list)
if yshift:
    # Sorting keeps the original row numbers as the index; the bounds cell
    # uses them to select rows of `labels` and `summ_stats`.
    df_la = df_la.sort_values(by='y')

with pd.option_context('display.max_rows',10):
    display(df_la)

['y', 'd', 'theta', 'h', 'phi']


Unnamed: 0,y,d,theta,h,phi
3,-132.540830,141.725841,-0.032450,0.369646,-1.861283
6,-111.224290,127.695576,0.047458,-0.142611,0.045752
2,-102.280518,163.525782,0.443308,-0.503671,1.759341
12,-79.567411,140.102742,0.185093,0.740001,2.408614
9,-65.990438,189.524516,0.024754,-0.642074,-1.370801
...,...,...,...,...,...
5,5.826070,79.906507,-0.386840,0.750147,-0.835700
10,16.722668,64.684229,0.151374,-0.980576,-2.605126
14,20.456311,99.531140,0.215845,-0.116415,-2.969536
8,53.338696,141.096088,0.021343,-0.373789,-1.365200


### Normalizing labels

#### Set bounds for y shift

In [3]:
# Merge location and geometry parameter ranges into a single lookup table.
# `ranges` is the very dict stored in config_dict, so the update mutates it too.
sim_params = config_dict['Simulation_Parameters']
ranges = sim_params['loc_param_range']
ranges.update(sim_params['geo_param_range'])
print(json.dumps(ranges))

if yshift:
    # Override the 'y' range and keep only samples whose 'y' lies inside it.
    ranges['y'] = [-150, 150]
    y_lo, y_hi = ranges['y']
    within_bounds = df_la['y'].between(y_lo, y_hi)
    df_la_idx = df_la.loc[within_bounds].index.values
    # Index values are row positions of the unsorted arrays, so this both
    # filters and reorders labels/summ_stats to follow df_la's row order.
    labels = labels[df_la_idx, :]
    summ_stats = summ_stats[df_la_idx, :]
    with pd.option_context('display.max_rows', 10):
        display(df_la.loc[df_la_idx])

{"x": [-50, 50], "y": [-1400, 1400], "z": [20.0, 200.0], "alpha": [0, 3.141592653589793], "h": [-1.0, 1.0], "phi": [-3.141592653589793, 3.141592653589793], "d": [20.0, 200.0], "theta": [-1.0471975511965976, 1.0471975511965976], "r_s": [5.0, 12.0], "l_t": [20.0, 800.0], "r_t": [0.25, 0.8], "r_d": [0.15, 0.45], "r_tu": [0.15, 0.45], "l_d": [100.0, 300.0], "r_a": [0.15, 0.45]}


Unnamed: 0,y,d,theta,h,phi
3,-132.540830,141.725841,-0.032450,0.369646,-1.861283
6,-111.224290,127.695576,0.047458,-0.142611,0.045752
2,-102.280518,163.525782,0.443308,-0.503671,1.759341
12,-79.567411,140.102742,0.185093,0.740001,2.408614
9,-65.990438,189.524516,0.024754,-0.642074,-1.370801
...,...,...,...,...,...
5,5.826070,79.906507,-0.386840,0.750147,-0.835700
10,16.722668,64.684229,0.151374,-0.980576,-2.605126
14,20.456311,99.531140,0.215845,-0.116415,-2.969536
8,53.338696,141.096088,0.021343,-0.373789,-1.365200


#### Normalization

In [4]:
# Min-max scale every label column, in place, from its configured range
# onto `feature_range`.
feature_range = (-1, 1)
feat_lo, feat_hi = feature_range[0], feature_range[1]
for col, name in enumerate(inference_list):
    bounds = ranges[name]
    # Position of each value inside its parameter range, as a fraction.
    unit_scaled = (labels[:, col] - bounds[0]) / (bounds[1] - bounds[0])
    labels[:, col] = unit_scaled * (feat_hi - feat_lo) + feat_lo

df_la = pd.DataFrame(labels, columns=inference_list)
with pd.option_context('display.max_rows', 10):
    display(df_la)

Unnamed: 0,y,d,theta,h,phi
0,-0.883606,0.352509,-0.030987,0.369646,-0.592465
1,-0.741495,0.196618,0.045319,-0.142611,0.014563
2,-0.681870,0.594731,0.423328,-0.503671,0.560016
3,-0.530449,0.334475,0.176751,0.740001,0.766686
4,-0.439936,0.883606,0.023638,-0.642074,-0.436339
...,...,...,...,...,...
15,0.038840,-0.334372,-0.369405,0.750147,-0.266012
16,0.111484,-0.503509,0.144552,-0.980576,-0.829237
17,0.136375,-0.116321,0.206117,-0.116415,-0.945233
18,0.355591,0.345512,0.020381,-0.373789,-0.434557


## Load model

In [8]:
import torch
from cell_inference.utils.feature_extractors.fullyconnectednetwork import FullyConnectedNetwork, ActivationTypes
# from cell_inference.utils.feature_extractors.convolutionalnetwork import ConvolutionalNetwork

# One network input per summary-statistics column, one output per label column.
n_inputs, n_outputs = summ_stats.shape[1], labels.shape[1]
model = FullyConnectedNetwork(in_features=n_inputs, out_features=n_outputs)

In [None]:
# Trained weights are stored next to the trial data.
model_name = 'FCN_batch128.pth'
MODEL_PATH = os.path.join(TRIAL_PATH,model_name)
# The notebook evaluates on CPU, so remap every stored tensor to CPU; without
# map_location a checkpoint written from a GPU fails to load on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))

In [None]:
from sklearn.metrics import r2_score
from cell_inference.utils.transform.geometry_transformation import hphi2unitsphere, unitsphere2hphi, trivarnorm2unitsphere
from cell_inference.utils.feature_extractors.helperfunctions import build_dataloader_from_numpy
from cell_inference.utils.metrics.corrcoef import corrcoef

device = torch.device("cpu")

train_loader, test_loader = build_dataloader_from_numpy(input_arr=summ_stats, labels_arr=labels, batch_size=512, shuffle=True)

# Score the network on the first batch produced by the test loader.
x, y = next(iter(test_loader))
model.eval()
x = x.to(device)
with torch.no_grad():  # inference only, no autograd graph needed
    output = model(x)
output = output.to("cpu").detach().numpy()
y = y.to("cpu").detach().numpy()

# Report one R2 score per inferred parameter. Iterating inference_list instead
# of hard-coded column numbers keeps this working for any number of parameters
# (fixed indices 0..7 raise IndexError when fewer than eight are inferred).
r2_display_names = {
    'y': 'Y-Shift',
    'd': 'D',
    'theta': 'Theta',
    'h': 'h',
    'phi': 'Phi',
    'r_s': 'Soma Radius',
    'l_t': 'Trunk Length',
    'r_t': 'Trunk Width',
}
for col, name in enumerate(inference_list):
    shown_name = r2_display_names.get(name, name)
    print('R2 Score {}: {}'.format(shown_name, r2_score(y[:, col], output[:, col])))

df_la = pd.DataFrame(y, columns=inference_list)
if 'y' in df_la.columns:
    df_la = df_la.sort_values(by='y')
display(df_la)

In [None]:
# Map targets and predictions from `feature_range` back to each parameter's
# own range. Both arrays are overwritten in place, so running this cell a
# second time would apply the inverse scaling twice.
feat_lo, feat_hi = feature_range[0], feature_range[1]
for col in range(y.shape[1]):
    bounds = ranges[inference_list[col]]
    span = bounds[1] - bounds[0]
    y[:, col] = ((y[:, col] - feat_lo) / (feat_hi - feat_lo)) * span + bounds[0]
    output[:, col] = ((output[:, col] - feat_lo) / (feat_hi - feat_lo)) * span + bounds[0]

In [None]:
%matplotlib inline

idx = 0

plt.figure(figsize=(20,20))

plt.suptitle("Stylized Cell Actual VS Predicted Parameters", fontsize=30)
fontsize = 25
labelsize = 25

# Axis captions per inferred parameter; unknown names fall back to the raw name.
axis_names = {
    'y': 'y-shift',
    'd': 'd',
    'theta': r'$\theta$',
    'h': 'h',
    'phi': r'$\varphi$',
    'r_s': 'soma radius',
    'l_t': 'trunk length',
    'r_t': 'trunk radius',
}

# One panel per label column instead of eight hand-copied panels, so the cell
# no longer indexes columns that do not exist when fewer parameters are inferred.
n_panels = y.shape[1]
n_cols = 3
n_rows = max(3, -(-n_panels // n_cols))  # keep the 3x3 layout, grow only if needed
for lab_ax in range(n_panels):
    caption = axis_names.get(inference_list[lab_ax], inference_list[lab_ax])
    ax = plt.subplot(n_rows, n_cols, lab_ax + 1)
    ax.scatter(y[:,lab_ax], output[:,lab_ax], c='red', marker='.')
    # Least-squares line through (actual, predicted); slope and intercept go in the legend.
    m, b = np.polyfit(y[:,lab_ax], output[:,lab_ax], 1)
    ax.plot(y[:,lab_ax], m*y[:,lab_ax]+b, label='y = '+ str(round(m,2)) + 'x +' + str(round(b,2)))
    ax.set_xlabel(caption + ' actual', fontsize=fontsize)
    ax.set_ylabel(caption + ' predicted', fontsize=fontsize)
    ax.tick_params(labelsize=labelsize)
    ax.legend(fontsize=labelsize)

plt.tight_layout(pad=3., rect=[0, 0.03, 1, 1])
plt.show()

In [None]:
from cell_inference.utils.feature_extractors.SummaryStats2D import get_y_window
from tqdm.notebook import tqdm
from cell_inference.utils.spike_window import first_pk_tr, get_spike_window

# In vivo recordings (HDF5) that the trained network is applied to.
# NOTE: this rebinds DATA_PATH, which earlier pointed at the simulation data.
DATA_PATH = 'cell_inference/resources/invivo'
INVIVO_PATH = os.path.join(DATA_PATH, 'all_cell_LFP_2D.h5')

with h5py.File(INVIVO_PATH, "r") as f:
    print(f.keys())
    c = f['coord'][:]
    d = f['data'][:]  # time x channels x samples
    ids = f['ID'][:]

t = np.arange(d.shape[0])

# Constant divisor applied to the raw recording before windowing.
scaler = 7720.0
filtered_lfp = d / scaler

# Cut one spike window out of every sample (last axis of the data).
pk_tr_idx_in_window = 16  # 16*0.025=0.4 ms
lfp_list = []
for i in range(d.shape[2]):
    sample_lfp = filtered_lfp[:, :, i]
    # NOTE(review): fst_idx is never read below; kept so the call still happens.
    fst_idx = first_pk_tr(sample_lfp)
    start, end = get_spike_window(sample_lfp, win_size=params.WINDOW_SIZE, align_at=pk_tr_idx_in_window)
    lfp_list.append(sample_lfp[start:end, :])

windowed_lfp = np.stack(lfp_list, axis=0)  # (samples x time window x channels)

# Grid every windowed sample and compute its summary statistics. Samples for
# which build_lfp_grid raises ValueError are recorded and left out.
# NOTE: `summ_stats` is rebound here and no longer holds the simulation statistics.
test_data, summ_stats, y_pos = [], [], []
bad_indices = []
for i in tqdm(range(windowed_lfp.shape[0])):
    try:
        g_lfp, g_coords, y_i = build_lfp_grid(windowed_lfp[i], params.ELECTRODE_POSITION[:, :2], y_window_size=960.0)
    except ValueError:
        bad_indices.append(i)
    else:
        test_data.append(g_lfp)
        summ_stats.append(calculate_stats(g_lfp))
        y_pos.append(y_i)

# Drop the ids of rejected samples so ids stays aligned with the kept data.
ids = np.delete(ids, bad_indices, axis=0)
test_data = np.stack(test_data, axis=0)
summ_stats = np.array(summ_stats)
y_pos = np.stack(y_pos, axis=0)
print(test_data.shape, summ_stats.shape, y_pos.shape)

In [None]:
y_pos = y_pos.reshape((-1,))

np.set_printoptions(suppress=True)

# Run the trained network on the in vivo summary statistics.
model.eval()
summ_stats = torch.Tensor(summ_stats)
summ_stats_tensor = summ_stats.to(device)
with torch.no_grad():  # inference only, no autograd graph needed
    pred = model(summ_stats_tensor)
pred = pred.to("cpu").detach().numpy()

# Map each output column from `feature_range` back to its parameter's range.
for i in range(pred.shape[1]):
    min_max_range = ranges[inference_list[i]]
    pred[:, i] = (((pred[:, i] - feature_range[0]) / (feature_range[1] - feature_range[0]))
                  * (min_max_range[1] - min_max_range[0]) + min_max_range[0])

# Turn the predicted 'y' into y_pos minus the prediction. The column is looked
# up by name (as the data-loading cell does) rather than assumed to be column 0,
# and the step is skipped when 'y' is not inferred.
# NOTE(review): this is applied whether or not `yshift` was set — confirm that
# is intended when the labels were not replaced by the stored y shifts.
if 'y' in inference_list:
    y_col = inference_list.index('y')
    pred[:, y_col] = y_pos - pred[:, y_col]

y_pred = pred
df = pd.DataFrame(y_pred, columns=inference_list)
df['cell_id'] = ids.flatten()
display(df)

In [None]:
# Clamp the predicted geometry parameters to their configured ranges.
# Columns the model does not infer are skipped instead of raising KeyError.
for geo_param in ('r_s', 'l_t', 'r_t'):
    if geo_param in df.columns:
        df[geo_param] = df[geo_param].clip(ranges[geo_param][0], ranges[geo_param][1])

In [None]:
# Show the whole prediction table: lift pandas' row and column display limits
# for this one display call only.
unlimited_view = ('display.max_rows', None, 'display.max_columns', None)
with pd.option_context(*unlimited_view):
    display(df)

In [None]:
# Index along axis 1 of the gridded in vivo data.
t = np.arange(test_data.shape[1])

# Boolean mask over the rows of g_coords: rows whose column-0 value equals the
# ix-th unique column-0 value and whose column-1 value lies inside ylim.
# g_coords is whatever the last successful build_lfp_grid call returned.
ix = 1
ylim = [-1900,1900]
x_dist = np.unique(g_coords[:,0])
in_selected_column = g_coords[:,0] == x_dist[ix]
inside_ylim = (g_coords[:,1] >= ylim[0]) & (g_coords[:,1] <= ylim[1])
e_idx = in_selected_column & inside_ylim

In [None]:
from cell_inference.utils.metrics.prediction_verification import InVivoParamSimulator

# Simulate LFP from the predicted parameters so it can be compared with the
# in vivo recordings (save=False is passed to verify_and_save).
simu = InVivoParamSimulator(df)
lfp, t = simu.verify_and_save(save=False)

print(lfp.shape)
data_set = []
bad_indices = []
coordinates = []
for i in tqdm(range(lfp.shape[0])):
    try:
        g_lfp, g_coords, y_i = build_lfp_grid(lfp[i, :, :], params.ELECTRODE_POSITION[:, :2], y_window_size=960.0)
    except ValueError:
        # Remember the position so the matching in vivo entries can be dropped.
        bad_indices.append(i)
        continue
    data_set.append(g_lfp)
    coordinates.append(g_coords)
data_set = np.stack(data_set, axis=0)
coordinates = np.stack(coordinates, axis=0)

# Remove the rejected positions everywhere so test_data, ids, df and data_set
# stay aligned row by row.
test_data = np.delete(test_data, bad_indices, axis=0)
ids = np.delete(ids, bad_indices, axis=0)
# bad_indices are positions, while DataFrame.drop expects index labels.
# Translate them first so the right rows are removed even when df's index is
# no longer 0..n-1 (for instance when this cell is run a second time).
bad_labels = df.index[np.asarray(bad_indices, dtype=int)]
df = df.drop(index=bad_labels)

In [None]:
%matplotlib inline
from importlib import reload
import cell_inference.utils.plotting.plot_results
reload(cell_inference.utils.plotting.plot_results)
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from cell_inference.utils.plotting.plot_results import plot_multiple_lfp_heatmaps


# savefig does not create missing folders, so make sure the target exists.
os.makedirs('lfp_heatmaps', exist_ok=True)

# One figure per cell: in vivo LFP on the left, LFP simulated from the
# predicted parameters on the right. `j` is the row position in the aligned
# arrays; `i` is the DataFrame index label.
for j, (i, row) in enumerate(df.iterrows()):
    fig=plt.figure(figsize=(15,4))
    outer=GridSpec(1,2)

    # In vivo plot; its returned vlim is reused for the predicted panel.
    vlim = plot_multiple_lfp_heatmaps(t,
                                   coordinates[j, e_idx, 1],
                                   np.transpose(test_data[j,:,e_idx]),
                                   vlim='auto',
                                   fontsize=18,labelpad=0,ticksize=15,nbins=5,
                                   fig=fig, outer=outer, col=0, cell_num=0, title='In Vivo Cell {}'.format(row['cell_id']))

    # Predicted in vivo from params
    plot_multiple_lfp_heatmaps(t,
                                   coordinates[j, e_idx, 1],
                                   np.transpose(data_set[j,:,e_idx]),
                                   vlim=vlim,
                                   fontsize=18,labelpad=0,ticksize=15,nbins=5,
                                   fig=fig, outer=outer, col=1, cell_num=0, title='Predicted Cell {}'.format(row['cell_id']))

    plt.tight_layout()

    plt.savefig('lfp_heatmaps/' + str(row['cell_id']) + '.jpg')
    # Render the figure now and release it; leaving every figure open until the
    # end of the loop grows memory with the number of cells.
    plt.show()
    plt.close(fig)

In [None]:
# Persist the prediction table (parameters plus cell_id) for later reuse.
predictions_pickle = 'invivo_df.pkl'
df.to_pickle(predictions_pickle)

In [None]:
from cell_inference.utils.metrics.corrcoef import corrcoef

# Correlate each in vivo grid with the grid simulated from its predicted
# parameters, one coefficient per cell.
n_cells = data_set.shape[0]
corrcoef_list = np.array([corrcoef(test_data[k], data_set[k]) for k in range(n_cells)])

# Positions of the best and worst matching cells.
maidx = np.argmax(corrcoef_list)
miidx = np.argmin(corrcoef_list)
print(f'Max Index: {maidx}')
print(f'Min Index: {miidx}')

plt.hist(corrcoef_list, bins=20)
plt.title('In Vivo vs Predicted Correlation Coefficients')
plt.xlabel('Correlation Coefficient')
plt.ylabel('Count')
plt.show()

In [None]:
# fig=plt.figure(figsize=(15,10))

# Side-by-side heatmaps for the worst (miidx) and best (maidx) correlated cells.
for i in [miidx, maidx]:
    # miidx/maidx are row positions in the aligned arrays, so the id must be
    # looked up by position. df.loc would treat them as index labels, which
    # stops matching (or raises KeyError) once rows have been dropped from df.
    cell_id = df['cell_id'].iloc[i]

    fig=plt.figure(figsize=(15,4))
    outer=GridSpec(1,2)

    # In vivo plot; its returned vlim is reused for the predicted panel.
    vlim = plot_multiple_lfp_heatmaps(t,
                                   coordinates[i, e_idx, 1],
                                   np.transpose(test_data[i,:,e_idx]),
                                   vlim='auto',
                                   fontsize=18,labelpad=0,ticksize=15,nbins=5,
                                   fig=fig, outer=outer, col=0, cell_num=0, title='In Vivo Cell {}'.format(cell_id))

    # Predicted in vivo from params
    plot_multiple_lfp_heatmaps(t,
                                   coordinates[i, e_idx, 1],
                                   np.transpose(data_set[i,:,e_idx]),
                                   vlim=vlim,
                                   fontsize=18,labelpad=0,ticksize=15,nbins=5,
                                   fig=fig, outer=outer, col=1, cell_num=0, title='Predicted Cell {}'.format(cell_id))

    plt.tight_layout()

plt.show()

In [None]:
%matplotlib notebook

from importlib import reload
# import cell_inference.utils.plotting.plot_all_cells
# reload(cell_inference.utils.plotting.plot_all_cells)
from cell_inference.utils.plotting.plot_all_cells import plot_all_cells

# Draw every predicted cell in one figure, then let the axes rescale to the data.
all_cells_figsize = (15., 15.)
fig, ax = plot_all_cells(df, figsize=all_cells_figsize)
ax.autoscale()

In [None]:
%matplotlib inline
import pandas as pd
from matplotlib.gridspec import GridSpec

# Per-epoch training and validation losses logged by one training run.
LOSS_PATH = 'cell_inference/resources/results/pytorch_losses/13_43_39__02_14_2022.csv'
loss_df = pd.read_csv(LOSS_PATH)

t_loss = loss_df['Training_Loss'].to_numpy()
v_loss = loss_df['Validation_Loss'].to_numpy()
epochs = np.arange(t_loss.shape[0])

# Layout: both curves together on the top row, each curve alone underneath.
fig = plt.figure(figsize=(10,10))
fig.suptitle("Loss Graphs", fontsize=40)
gs = GridSpec(2,2)
ax1 = fig.add_subplot(gs[0,:])
ax2 = fig.add_subplot(gs[1,0])
ax3 = fig.add_subplot(gs[1,1])

# Top panel: training and validation loss overlaid.
ax1.plot(epochs, t_loss, label='Training Loss')
ax1.plot(epochs, v_loss, label='Validation Loss')
ax1.set_ylim(bottom=15, top=120)
ax1.set_ylabel('Loss', fontsize=30)
ax1.legend(fontsize=15)

# Bottom left: training loss only.
ax2.plot(epochs, t_loss, label='Training Loss')
ax2.set_ylim(bottom=40, top=120)
ax2.set_xlabel('Epoch', fontsize=30)
ax2.set_ylabel('Loss', fontsize=30)
ax2.set_title('Training', fontsize=20)

# Bottom right: validation loss only (y limits left automatic).
ax3.plot(epochs, v_loss, label='Validation Loss')
ax3.set_xlabel('Epoch', fontsize=30)
ax3.set_title('Validation', fontsize=20)

# All three panels share the same tick label size.
for panel in (ax1, ax2, ax3):
    panel.tick_params(labelsize=20)

plt.tight_layout()
plt.show()