In [None]:
import pandas as pd
import numpy as np
import os
import sys
import scipy.io as sio
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
from sklearn.decomposition import PCA
import scipy
import torch
import torch.nn as nn
from torch import autograd

In [None]:
# functions for individual unit analysis
def tuning_calculate(firingdata, angle_list, divisions=50):
    """Mean activity of one unit in equal-width head-direction bins.

    Args:
        firingdata: 1-D array with one activity value per sample.
        angle_list: 1-D array of head direction per sample, in degrees; samples
            outside [0, 360) fall in no bin.
        divisions: number of bins over 360 degrees (default 50, i.e. 7.2 deg/bin).

    Returns:
        1-D float array of length ``divisions``. Entry i is the mean of ``firingdata``
        over samples with angle in the half-open bin [i*dtheta, (i+1)*dtheta).
        Bins without samples are NaN.
    """
    firingdata = np.asarray(firingdata)
    angle_list = np.asarray(angle_list)
    dtheta = 360 / divisions
    tuning_fr = np.full(divisions, np.nan)
    for i in range(divisions):
        # Half-open bins: a sample lying exactly on an edge (e.g. 0 deg) belongs to
        # exactly one bin instead of being dropped from both neighbours.
        in_bin = (angle_list >= i * dtheta) & (angle_list < (i + 1) * dtheta)
        if np.any(in_bin):
            tuning_fr[i] = np.mean(firingdata[in_bin])
    return tuning_fr
def fwhm_calculate(tuning_curve):
    """Full width at half maximum (in bins) of a tuning curve over a circular variable.

    The curve is shifted so its minimum is 0; "half maximum" is half of the shifted
    peak value. Starting at the peak (np.argmax), the code walks left and right to the
    first bin whose value is strictly below half maximum. If the walk on one side reaches
    the array boundary without finding such a bin, the peak is treated as wrapping around
    the circular ends and the crossing is searched for from the opposite end instead.

    Args:
        tuning_curve: 1-D array of binned responses (e.g. the output of tuning_calculate).
            NOTE(review): NaN bins propagate through np.min/np.max and make every
            comparison False, so the result is not meaningful; presumably callers pass
            NaN-free curves — confirm.

    Returns:
        (half_peak_width, left_half_real, right_half_real):
            half_peak_width -- distance in bins between the two below-half-maximum bins
                that bracket the peak (when both crossings are found without wrapping,
                this is the number of bins at/above half maximum plus one);
            left_half_real, right_half_real -- indices of those bracketing bins. In the
                two wrap-around branches left_half_real >= right_half_real.

    NOTE(review): index 0 / len-1 double as "no crossing found" sentinels, but they are
    also produced when the crossing genuinely lies on bin 0 / the last bin. In that case a
    wrap-around branch still runs and the width can be off by one or absorb an unrelated
    supra-threshold region at the other end of the array — confirm before relying on it.
    """
    # Shift so the minimum is 0; half maximum is then max_value / 2.
    data = tuning_curve-np.min(tuning_curve)
    max_indx = np.argmax(data)
    max_value = np.max(data)
    if max_indx >0:
        # Walk left from the bin next to the peak. If nothing drops below half maximum,
        # the loop ends with i == max_indx-1, so lefthalf_indx becomes 0 (sentinel).
        for i in np.arange(max_indx):
            valuehere = data[max_indx-i-1]
            if valuehere < max_value/2:
                break
        lefthalf_indx = max_indx-i-1
    else:
        # Peak is in the first bin: there is nothing to its left.
        lefthalf_indx = 0
    if len(data)-max_indx-1>0:
        # Walk right starting at the peak itself (i == 0). If nothing drops below half
        # maximum, i ends at len(data)-max_indx-1, so righthalf_indx becomes len(data)-1.
        for i in np.arange(len(data)-max_indx):
            valuehere = data[max_indx+i]
            if valuehere < max_value/2:
                break
        righthalf_indx = max_indx+i
    else:
        # Peak is in the last bin: there is nothing to its right.
        righthalf_indx = len(data)-1

    if lefthalf_indx == 0 and righthalf_indx!=len(data)-1:
        # Left side ran into the start of the array: the peak wraps past index 0.
        # The right crossing becomes the inner edge; search backwards from the last bin
        # for the crossing on the far (wrapped) side.
        lefthalf_indx = righthalf_indx
        for i in np.arange(len(data)-1):
            valuehere = data[len(data)-i-1]
            if valuehere < max_value/2:
                break
        # Both arms give len(data)-1 when i == 0.
        if i == 0: 
            righthalf_indx = len(data)-1
        else:
            righthalf_indx = len(data)-i-1
        # Width measured around the circle, through the array ends.
        half_peak_width =len(data)-(righthalf_indx-lefthalf_indx)
        left_half_real = righthalf_indx
        right_half_real = lefthalf_indx
    elif lefthalf_indx != 0 and righthalf_indx==len(data)-1:
        # Right side ran into the end of the array: the peak wraps past the last bin.
        # The left crossing becomes the inner edge; search forward from bin 0 for the
        # crossing on the far (wrapped) side.
        righthalf_indx = lefthalf_indx
        for i in np.arange(len(data)-1):
            valuehere = data[i]
            if valuehere < max_value/2:
                break
        lefthalf_indx = i
        # Width measured around the circle, through the array ends.
        half_peak_width =len(data)-(righthalf_indx-lefthalf_indx)
        left_half_real = righthalf_indx
        right_half_real = lefthalf_indx
    else:
        # Both crossings found inside the array (or neither found, giving len(data)-1).
        half_peak_width = righthalf_indx-lefthalf_indx
        left_half_real = lefthalf_indx
        right_half_real = righthalf_indx
    return half_peak_width,left_half_real,right_half_real

In [None]:
# Plot colours as RGB triples scaled to [0, 1]:
# color1 is used for the MP (multimodal) group, color2 for the SP (unimodal) group.
color1 = [channel / 255 for channel in (254, 129, 126)]
color2 = [channel / 255 for channel in (129, 184, 223)]

In [None]:
# Effective dimension of 7 SP (unimodal) subgroups of increasing tuning width and of the
# MP (multimodal) population across 20 simulated test sessions (Fig. 5C).
# Effective dimension = participation ratio 1 / sum_i(r_i^2) of the PCA explained-variance
# ratios r_i. Leaves `singlelist`, `non_singlelist`, `increase_width_group` and
# `peak_width_tick` (values from the last session) defined at module level.
non_singlelist = [54, 67, 84, 13, 4, 6, 88, 20, 97, 35, 83]
non_selective_list = [12,42,44,45,51,68,69,71,79,82,94,95]
singlelist = np.array(list(set(np.arange(100))-set(non_singlelist)-set(non_selective_list)))
non_singlelist = np.array(non_singlelist)
singlelist = singlelist.astype(int)
non_singlelist = non_singlelist.astype(int)

n_groups = 7
n_sessions = 20
tuning_bins = 50  # bin count tuning_calculate uses by default; converts FWHM (bins) to rad


def _effective_dimension(samples):
    """Participation ratio of a (samples, units) matrix: 1 / sum of squared PCA explained-variance ratios."""
    explained = PCA().fit(samples).explained_variance_ratio_
    return 1 / np.sum(explained ** 2)


trials_all = np.arange(n_sessions)
ed_unimodal_all = np.zeros((n_groups, n_sessions))
ed_multimodal_all = np.zeros(n_sessions)
for trialnum in trials_all:
    print(trialnum)
    varname = '../simulated_data/trials_test_1000/output_simulated_200trials_1000t_epoch_500_testtrial_'+str(trialnum+1)+'_0609.mat'
    data1 = np.array(sio.loadmat(varname)['data'])
    neuraldata = data1[0:100, :]
    angle_list_rnnout = data1[-2, :]

    # FWHM (in tuning bins) of every unit's HD tuning curve.
    # tuning_calculate returns a single array: the binned tuning curve.
    half_peak_width_all = np.zeros(100)
    for i in range(100):
        tuning_i = tuning_calculate(neuraldata[i, :], angle_list_rnnout)
        half_peak_width_all[i] = fwhm_calculate(tuning_i)[0]

    # Sort SP units by tuning width and split them into n_groups equal groups.
    sorted_index = np.argsort(half_peak_width_all[singlelist])
    list_index = singlelist[sorted_index]
    increase_width_group = list_index.reshape(n_groups, -1)

    ed_all = np.zeros(n_groups)
    peak_width_tick = np.zeros(n_groups)
    for i, group_units in enumerate(increase_width_group):
        peak_width_tick[i] = np.mean(half_peak_width_all[group_units]) / tuning_bins * 2 * np.pi
        ed_all[i] = _effective_dimension(neuraldata[group_units, :].T)
    ed_unimodal_all[:, trialnum] = ed_all
    ed_multimodal_all[trialnum] = _effective_dimension(neuraldata[non_singlelist, :].T)

ed_unimodal_mean = np.mean(ed_unimodal_all, axis=1)
ed_unimodal_std = np.std(ed_unimodal_all, axis=1)
ed_multimodal_mean = np.mean(ed_multimodal_all)
ed_multimodal_std = np.std(ed_multimodal_all)

# Plot effective dimension: SP groups at x = 2..8 (labelled by mean width), MP at x = 1.
x_index = np.arange(2, 2 + n_groups)
fig, ax = plt.subplots(figsize=(4, 3))
# `color` must be a keyword: as a bare positional argument after the format string,
# matplotlib treats it as another data series and draws a spurious line.
_ = ax.plot(x_index, ed_unimodal_mean, 'o-', color=color2, label='unimodal')
_ = ax.fill_between(x_index, ed_unimodal_mean - ed_unimodal_std, ed_unimodal_mean + ed_unimodal_std, alpha=0.5, color=color2)
_ = ax.set_xticks(x_index)
_ = ax.set_xticklabels(['{:.1f}'.format(xi) for xi in peak_width_tick])
_ = ax.plot(1, ed_multimodal_mean, 'o', color=color1, label='multimodal')
_ = ax.errorbar(1, ed_multimodal_mean, yerr=ed_multimodal_std, color=color1)
_ = ax.set_xlabel('half peak width (rad)')
_ = ax.set_ylabel('effective dimension')
_ = ax.legend()
fig.savefig('../../figures/effective_dimension_unimodal_multimodal.pdf', bbox_inches='tight', transparent=True, format='pdf')


In [None]:
# Neural geometry (Fig. 5C): summed length of the HD-binned mean population trajectory for
# the 7 SP (unimodal) subgroups of increasing tuning width vs. the MP (multimodal)
# population, across 20 simulated test sessions.
# Needs `singlelist`, `non_singlelist`, `tuning_calculate`, `fwhm_calculate`, `color1` and
# `color2` from earlier cells. The width-sorted SP groups are recomputed for every session
# (same procedure as the effective-dimension cell) rather than reusing the grouping left
# over from the last session of another cell.
divisions = 360
dtheta = 360 / divisions
down_ratio = 1
n_groups = 7
n_sessions = 20
tuning_bins = 50  # bin count tuning_calculate uses by default; converts FWHM (bins) to rad


def _hd_trajectory_length(activity, angles, n_bins, bin_width, step):
    """Summed Euclidean step length of the HD-binned mean population state.

    activity: (units, samples) array; every unit is z-scored over samples first.
    angles: HD per sample in degrees, grouped into `n_bins` half-open bins of
        `bin_width` degrees; both inputs are subsampled with stride `step`.
    Only steps between consecutive bins are summed; the closing step from the last bin
    back to the first is not included. An empty bin yields NaN.
    """
    z = scipy.stats.zscore(activity, axis=1)[:, ::step].T
    angles = angles[::step]
    binned = np.zeros((n_bins, z.shape[1]))
    for b in range(n_bins):
        in_bin = (angles >= b * bin_width) & (angles < (b + 1) * bin_width)
        binned[b, :] = np.mean(z[in_bin, :], 0)
    return np.sum(np.linalg.norm(np.diff(binned, axis=0), axis=1))


trials_all = np.arange(n_sessions)
trajecotry_len_unimodal_all = np.zeros((n_groups, n_sessions))
trajecotry_len_multimodal_all = np.zeros(n_sessions)
for trialnum in trials_all:
    print(trialnum)
    varname = '../simulated_data/trials_test_1000/output_simulated_200trials_1000t_epoch_500_testtrial_'+str(trialnum+1)+'_0609.mat'
    data1 = np.array(sio.loadmat(varname)['data'])
    neuraldata = data1[0:100, :]
    angle_list_rnnout = data1[-2, :]

    # Width-sorted SP groups for this session.
    half_peak_width_all = np.array(
        [fwhm_calculate(tuning_calculate(neuraldata[u, :], angle_list_rnnout))[0] for u in range(100)],
        dtype=float)
    sorted_index = np.argsort(half_peak_width_all[singlelist])
    increase_width_group = singlelist[sorted_index].reshape(n_groups, -1)
    peak_width_tick = np.array([np.mean(half_peak_width_all[g]) for g in increase_width_group]) / tuning_bins * 2 * np.pi

    for groupnum, group_units in enumerate(increase_width_group):
        trajecotry_len_unimodal_all[groupnum, trialnum] = _hd_trajectory_length(
            neuraldata[group_units, :], angle_list_rnnout, divisions, dtheta, down_ratio)
    trajecotry_len_multimodal_all[trialnum] = _hd_trajectory_length(
        neuraldata[non_singlelist, :], angle_list_rnnout, divisions, dtheta, down_ratio)

trajecotry_len_unimodal_mean = np.mean(trajecotry_len_unimodal_all, axis=1)
trajecotry_len_unimodal_std = np.std(trajecotry_len_unimodal_all, axis=1)
trajecotry_len_multimodal_mean = np.mean(trajecotry_len_multimodal_all)
trajecotry_len_multimodal_std = np.std(trajecotry_len_multimodal_all)

x_index = np.arange(2, 2 + n_groups)
fig, ax = plt.subplots(figsize=(4, 3))
_ = ax.plot(x_index, trajecotry_len_unimodal_mean, 'o-', color=color2, label='unimodal')
_ = ax.fill_between(x_index, trajecotry_len_unimodal_mean - trajecotry_len_unimodal_std, trajecotry_len_unimodal_mean + trajecotry_len_unimodal_std, alpha=0.5, color=color2)
_ = ax.set_xticks(x_index)
_ = ax.set_xticklabels(['{:.1f}'.format(xi) for xi in peak_width_tick])
_ = ax.plot(1, trajecotry_len_multimodal_mean, 'o', color=color1, label='multimodal')
_ = ax.errorbar(1, trajecotry_len_multimodal_mean, yerr=trajecotry_len_multimodal_std, color=color1)
# Horizontal caps (+/-0.2 in data units) on the MP error bar, drawn on this figure's axes.
for cap_y in (trajecotry_len_multimodal_mean + trajecotry_len_multimodal_std,
              trajecotry_len_multimodal_mean - trajecotry_len_multimodal_std):
    _ = ax.hlines(cap_y, 1 - 0.2, 1 + 0.2, colors=color1)
_ = ax.set_xlabel('half peak width (rad)')
_ = ax.set_ylabel('Sum of trajectory length')
_ = ax.legend()
fig.savefig('../../figures/trajectory_length_unimodal_multimodal.pdf', bbox_inches='tight', transparent=True, format='pdf')

In [None]:
# Load the example session (epoch-500 network, '_0228' simulation) and build the
# SP (single-peak) and MP (non-single-peak) unit index sets.
non_singlelist = [54, 67, 84, 13, 4, 6, 88, 20, 97, 35, 83]
non_selective_list = [12, 42, 44, 45, 51, 68, 69, 71, 79, 82, 94, 95]
singlelist = np.array(list(set(np.arange(100)) - set(non_singlelist) - set(non_selective_list))).astype(int)
non_singlelist = np.array(non_singlelist).astype(int)

varname = '../simulated_data/epoch_test/output_simulated_200trials_1000t_epoch_' + str(500) + '_0228.mat'
data = sio.loadmat(varname)
data1 = np.array(data['data'])
# The first 100 rows are read as unit activity; the last three as target HD,
# RNN-output HD and angular velocity.
neuraldata = data1[0:100, :]
angle_list_target, angle_list_rnnout, angular_velocity = data1[-3, :], data1[-2, :], data1[-1, :]

In [None]:
# divide SP units into 7 subgroups and PCA neural space visualization (Fig. 5D)
with open('../../../results/All_metrics_individual_units_epoch500.pickle', 'rb') as f:
    data = pickle.load(f)
half_peak_width_all = np.array(data['half_peak_width_all'])

half_peak_width_all_new = half_peak_width_all[singlelist]
sorted_index = np.argsort(half_peak_width_all_new)
list_index = singlelist[sorted_index]
increase_width_group = list_index.reshape(7,11)
for groupi in range(7):
    data = neuraldata[increase_width_group[groupi],:] # Notice: increase_width_group should be calculated also from the same file:str(500)+'_0228.mat
    data = scipy.stats.zscore(data,axis=1)
    down_ratio = 10
    data = data[:,::down_ratio].T
    pca = PCA(n_components=3)
    x_embd = pca.fit_transform(data)
    x_embd = 2*x_embd/np.max(np.abs(x_embd)) # normalize
    axis_vectors = pca.components_
    divisions = 360
    dtheta = 1/divisions*360
    angle_list = angle_list_rnnout[::down_ratio]
    average_manifold = np.zeros((divisions,np.shape(data)[1]))
    for i in np.arange(0,divisions):
        idx = (angle_list>i*dtheta)&(angle_list<(i+1)*dtheta)
        average_manifold[i,:] = np.mean(data[idx,:],0)
    projections = np.matmul(average_manifold,axis_vectors.T)
    projections_new = projections*1
    %matplotlib widget
    fig = plt.figure(figsize=(8,8))
    plt.set_cmap('hsv') # circular cmap
    ax = fig.add_subplot(111, projection='3d')
    ax.azim = -12.66 # finally use this angle to visualize
    ax.elev = 19.18
    ds_plt = 3
    cmap = angle_list[::ds_plt]
    scat = ax.scatter(x_embd[:,0][::ds_plt], x_embd[:,1][::ds_plt], x_embd[:,2][::ds_plt], c=cmap, alpha=.7)
    cbar = plt.colorbar(scat)
    cbar.set_label('HD')
    ax.plot(projections_new[:,0],projections_new[:,1],projections_new[:,2],c='k')
    cmap = np.arange(0,divisions)
    ax.scatter(projections_new[:,0],projections_new[:,1],projections_new[:,2],c=cmap,alpha =.5)
    fig.savefig('../../figures/low_d_manifold_uni_inc_group_IV.pdf', bbox_inches='tight', transparent=True, format='pdf')


In [None]:
# functions for population ANN decoding
class VelocityNet(nn.Module):
    """MLP that decodes angular head velocity (AHV) from population rates.

    ``forward`` returns a mean prediction and a non-negative variance, the pair of
    outputs expected by ``nn.GaussianNLLLoss``.
    """

    def __init__(self, in_features) -> None:
        """Build the MLP: in_features -> 64 -> 128 -> 64 -> 2 with ReLU activations."""
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(in_features, 64),
                                    nn.ReLU(),
                                    nn.Linear(64, 128),
                                    nn.ReLU(),
                                    nn.Linear(128, 64),
                                    nn.ReLU(),
                                    nn.Linear(64, 2))

    def forward(self, x):
        """Map a (batch, in_features) tensor to ``(v_hat, v_var)``, each of shape (batch,)."""
        y = self.layers(x)
        v_hat = y[:, 0]
        # softplus keeps the variance non-negative. It is taken from nn.functional
        # because a bare `softplus` name is never imported in this notebook.
        v_var = nn.functional.softplus(y[:, 1])
        return v_hat, v_var


class AngleNet(nn.Module):
    """MLP that decodes head direction as a unit vector plus a von Mises concentration.

    ``forward`` returns ``(a_cos, a_sin, a_kap)``. ``(a_cos, a_sin)`` is the first two
    network outputs normalised to (approximately) unit length, and ``a_kap`` is
    ``10 * sigmoid`` of the third output, so it lies in [0, 10].
    """

    def __init__(self, in_features) -> None:
        """Build the MLP: in_features -> 64 -> 128 -> 64 -> 3 with ReLU activations."""
        super().__init__()
        widths = [in_features, 64, 128, 64]
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(n_in, n_out), nn.ReLU()]
        modules.append(nn.Linear(widths[-1], 3))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Map a (batch, in_features) tensor to (cos, sin, kappa), each of shape (batch,)."""
        out = self.layers(x)
        direction = out[:, :2]
        # Normalise to unit length; 1e-8 guards against division by zero.
        length = torch.sqrt(direction.pow(2).sum(-1, keepdim=True))
        direction = direction / (length + 1e-8)
        # Bounded concentration keeps log(I0(kappa)) in the loss finite.
        kappa = 10 * torch.sigmoid(out[:, 2])
        return direction[:, 0], direction[:, 1], kappa
    

def make_dataloader(x, y, device=None, batch_size=512, shuffle=True, drop_last=True):
    """Wrap paired inputs/targets in a mini-batch DataLoader.

    Args:
        x: inputs, converted with ``torch.Tensor`` (float32).
        y: targets, converted the same way; first dimension must match ``x``.
        device: optional device to move both tensors to (left in place when None).
        batch_size: samples per batch (default 512).
        shuffle: reshuffle samples every epoch (default True).
        drop_last: drop the final incomplete batch (default True). With this set, a
            dataset smaller than ``batch_size`` yields no batches at all.

    Returns:
        ``torch.utils.data.DataLoader`` yielding ``(x_batch, y_batch)`` tensors.
    """
    # `Data` was referenced but never imported in this notebook.
    from torch.utils import data as Data

    x = torch.Tensor(x)
    y = torch.Tensor(y)
    if device is not None:
        x, y = x.to(device), y.to(device)
    data_set = Data.TensorDataset(x, y)
    return Data.DataLoader(dataset=data_set,
                           batch_size=batch_size,
                           shuffle=shuffle,
                           drop_last=drop_last)


def make_optim(net, weight_dacay=1e-2):
    """Create an Adam optimizer (default learning rate) over ``net.parameters()``.

    Args:
        net: module whose parameters are optimised.
        weight_dacay: L2 weight-decay coefficient (parameter name kept as-is because
            callers pass it by keyword).
    """
    # `Adam` was referenced but never imported; torch.optim is available via `import torch`.
    return torch.optim.Adam(net.parameters(), weight_decay=weight_dacay)


def make_vloss():
    """Return the Gaussian negative log-likelihood loss used for velocity decoding.

    The returned module is called as ``loss(v_hat, v, v_var)``: predicted mean, target
    and predicted variance.
    """
    criterion = nn.GaussianNLLLoss()
    return criterion


def make_aloss():
    """Return a von Mises negative log-likelihood loss for angle decoding.

    The returned callable is ``loss(a_cos_hat, a_sin_hat, angle, kappa)``. It computes
    the mean over samples of ``-(kappa * (cos(angle)*a_cos_hat + sin(angle)*a_sin_hat)
    - log I0(kappa))``. When ``(a_cos_hat, a_sin_hat)`` is a unit vector, the bracket is
    ``kappa * cos(angle - angle_hat)``. The constant ``log(2*pi)`` term is omitted.
    ``angle`` is used in radians.
    """
    from torch.special import i0

    def von_Mises_loss(a_cos_hat, a_sin_hat, angle, kappa):
        alignment = torch.cos(angle) * a_cos_hat + torch.sin(angle) * a_sin_hat
        log_likelihood = kappa * alignment - torch.log(i0(kappa))
        return -log_likelihood.mean()

    return von_Mises_loss


def cal_angle_tderiv(rates, rates_tderiv, angle_net, optim):
    """Project a time-derivative of the rates onto the gradient of the decoded angle.

    For each sample this returns sum_j d(angle_hat)/d(rate_j) * rates_tderiv_j / 50, where
    angle_hat = atan2(sin_hat, cos_hat) from ``angle_net``. The sum runs over the last axis.
    NOTE(review): the meaning of the constant 50 is not visible here (perhaps a time
    step / sampling-rate factor) — confirm.

    ``optim.zero_grad()`` is called before and after the backward pass. Presumably
    ``optim`` holds ``angle_net``'s parameters, so the parameter gradients produced
    here are discarded.

    Returns:
        NumPy array of the projected angular derivative.
    """
    inputs = rates.clone().detach().requires_grad_(True)
    cos_hat, sin_hat, _ = angle_net(inputs)
    angle_hat = torch.arctan2(sin_hat, cos_hat)
    optim.zero_grad()
    angle_hat.backward(torch.ones_like(angle_hat))
    optim.zero_grad()
    projected = (inputs.grad * rates_tderiv).sum(-1)
    return projected.numpy() / 50

def cum(a, window=9):
    """Sliding-window sum over the trailing ``window`` elements.

    Element k of the result is sum(a[max(0, k - window + 1) : k + 1]); the first
    ``window`` entries are therefore plain cumulative sums over partial windows.
    ``a`` must be a NumPy array (``a.cumsum()`` is called); it is not modified.
    """
    totals = a.cumsum()
    lagged = totals[:-window].copy()
    totals[window:] -= lagged
    return totals


def train_and_test(session, data, nmin):
    """Decode HD and AHV from the SP and MP populations of one session.

    For each population, the decoders are trained on samples [n//4:] and scored on the
    held-out first n//4 samples.

    Args:
        session: unused; kept for call-site compatibility.
        data: dict with keys 'single_rates' and 'nonsingle_rates' (2-D arrays indexed
            (sample, unit)), 'velocity' and 'angle' (per-sample targets; cos/sin are
            applied to 'angle', so it is used in radians).
        nmin: minimum number of units a population needs in order to be decoded.

    Returns:
        (r1, r2, r3, r4): held-out Pearson r for SP cos(HD), SP AHV, MP cos(HD) and MP AHV.
        A population with fewer than ``nmin`` units gets 0 for both of its scores.
    """
    single_rates = data['single_rates']
    nonsingle_rates = data['nonsingle_rates']
    velocity, angle = data['velocity'], data['angle']
    r1, r2 = _decode_population(single_rates, angle, velocity, nmin)
    r3, r4 = _decode_population(nonsingle_rates, angle, velocity, nmin)
    return r1, r2, r3, r4


def _decode_population(rates, angle, velocity, nmin):
    """Train HD/AHV decoders on samples [n//4:] and return their Pearson r on [:n//4], or (0, 0) if too few units."""
    shape = np.shape(rates)
    n_units = 0 if shape[0] == 0 else shape[1]
    if n_units < nmin:
        return 0, 0
    n_test = int(np.floor(len(rates) / 4))
    a_net = _fit_angle_net(rates[n_test:], angle[n_test:], n_units)
    v_net = _fit_velocity_net(rates[n_test:], velocity[n_test:], n_units)
    with torch.no_grad():  # evaluation only: no autograd graph needed
        rates_test = torch.Tensor(rates[:n_test])
        a_cos_hat, _, _ = a_net(rates_test)
        velocity_hat, _ = v_net(rates_test)
    # The angle decoder is scored on cos(HD) only.
    r_angle = np.corrcoef(np.cos(angle[:n_test]), a_cos_hat.numpy())[0, 1]
    r_velocity = np.corrcoef(velocity[:n_test], velocity_hat.numpy())[0, 1]
    return r_angle, r_velocity


def _fit_angle_net(rates, angle, n_units, epochs=10):
    """Train a fresh AngleNet with the von Mises loss (Adam, weight decay 5e-2)."""
    loader = make_dataloader(rates, angle)
    net = AngleNet(in_features=n_units)
    optim = make_optim(net, weight_dacay=5e-2)
    loss_fn = make_aloss()
    for _ in range(epochs):
        for r, a in loader:
            a_cos_hat, a_sin_hat, a_kap = net(r)
            loss = loss_fn(a_cos_hat, a_sin_hat, a, a_kap)
            optim.zero_grad()
            loss.backward()
            optim.step()
    return net


def _fit_velocity_net(rates, velocity, n_units, epochs=20):
    """Train a fresh VelocityNet with the Gaussian NLL loss (Adam, weight decay 5e-3)."""
    loader = make_dataloader(rates, velocity)
    net = VelocityNet(in_features=n_units)
    optim = make_optim(net, 5e-3)
    loss_fn = make_vloss()
    for _ in range(epochs):
        for r, v in loader:
            v_hat, v_var = net(r)
            loss = loss_fn(v_hat, v, v_var)
            optim.zero_grad()
            loss.backward()
            optim.step()
    return net


def get_cor(rate, angle, velocity, n_folds=5):
    """Cross-validated decoding accuracy of HD and AHV from one population.

    The samples are split into ``n_folds`` contiguous blocks. Each block is held out once,
    while a fresh AngleNet (10 epochs, von Mises loss) and VelocityNet (20 epochs,
    Gaussian NLL) are trained on the remaining samples.

    Args:
        rate: 2-D array indexed (sample, unit); ``rate.shape[1]`` is the decoder input width.
        angle: per-sample HD (cos/sin are applied to it, so it is used in radians).
        velocity: per-sample AHV, indexed with the same sample indices as ``rate``.
        n_folds: number of contiguous folds (default 5).

    Returns:
        (r1_all, r2_all): per-fold Pearson r between cos(HD) and the decoded cosine, and
        between AHV and the decoded AHV.
    """
    numT = rate.shape[0]
    infeature = rate.shape[1]
    # Integer fold edges. The last edge is exactly numT, so every sample is held out once;
    # deriving the fold size as numT - numT*0.8 in floating point could leave the final
    # sample(s) out of every test fold.
    edges = [int(k * numT / n_folds) for k in range(n_folds + 1)]
    r1_all = []
    r2_all = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        test_index = np.arange(lo, hi)
        train_index = np.concatenate([np.arange(0, lo), np.arange(hi, numT)])

        # HD decoder. Losses are not accumulated: keeping the loss tensors would keep
        # every batch's autograd graph alive.
        dl_a = make_dataloader(rate[train_index], angle[train_index])
        a_net = AngleNet(in_features=infeature)
        optim_a = make_optim(a_net, weight_dacay=5e-2)
        loss_a = make_aloss()
        for _ in range(10):
            for r, a in dl_a:
                a_cos_hat, a_sin_hat, a_kap = a_net(r)
                batch_loss = loss_a(a_cos_hat, a_sin_hat, a, a_kap)
                optim_a.zero_grad()
                batch_loss.backward()
                optim_a.step()

        # AHV decoder.
        dl_v = make_dataloader(rate[train_index], velocity[train_index])
        v_net = VelocityNet(in_features=infeature)
        optim_v = make_optim(v_net, 5e-3)
        loss_v = make_vloss()
        for _ in range(20):
            for r, v in dl_v:
                v_hat, v_var = v_net(r)
                batch_loss = loss_v(v_hat, v, v_var)
                optim_v.zero_grad()
                batch_loss.backward()
                optim_v.step()

        with torch.no_grad():  # evaluation only
            rates_sample = torch.Tensor(rate[test_index])
            a_cos_hat, _, _ = a_net(rates_sample)
            velocity_hat, _ = v_net(rates_sample)
        r1_all.append(np.corrcoef(np.cos(angle[test_index]), a_cos_hat.numpy())[0, 1])
        r2_all.append(np.corrcoef(velocity[test_index], velocity_hat.numpy())[0, 1])
    return r1_all, r2_all

In [None]:
# Example decoding of HD / AHV from the SP and MP populations (Fig. 3A).
# Load the example session and build the SP (single-peak) and MP (non-single-peak) unit sets.
non_singlelist = [54, 67, 84, 13, 4, 6, 88, 20, 97, 35, 83]
non_selective_list = [12, 42, 44, 45, 51, 68, 69, 71, 79, 82, 94, 95]
singlelist = np.array(list(set(np.arange(100)) - set(non_singlelist) - set(non_selective_list))).astype(int)
non_singlelist = np.array(non_singlelist).astype(int)

varname = '../simulated_data/epoch_test/output_simulated_200trials_1000t_epoch_' + str(500) + '_0228.mat'
data = sio.loadmat(varname)
data1 = np.array(data['data'])
# The first 100 rows are read as unit activity; the last three as target HD,
# RNN-output HD and angular velocity.
neuraldata = data1[0:100, :]
angle_list_target, angle_list_rnnout, angular_velocity = data1[-3, :], data1[-2, :], data1[-1, :]
data_single, data_nonsinglelist = neuraldata[singlelist, :], neuraldata[non_singlelist, :]
rangeuse = np.arange(100000)  # sample indices used for decoding

# SP (unimodal) population: train HD and AHV decoders on samples [temp:] and plot their
# predictions on the held-out first `temp` samples (the first 1/5 of the session).
# Rates are z-scored with one global mean/std over all units, not per unit.
data = data_single[:, rangeuse]
rates = data.T
rates = rates - rates.mean()
rates = rates / rates.std()
angle = angle_list_rnnout[rangeuse] / 180 * np.pi  # degrees -> radians
print('angle_min_max', np.min(angle), np.max(angle))
infeature = rates.shape[1] if rates.shape[0] != 0 else 0
temp = int(np.floor(len(rates) / 5))
angular_velocity_use = angular_velocity[rangeuse]

# HD decoder (von Mises loss). Losses are stored as floats: keeping the tensors would
# keep every batch's autograd graph alive.
dl_a = make_dataloader(rates[temp:], angle[temp:])
a_net = AngleNet(in_features=infeature)
optim_a = make_optim(a_net, weight_dacay=5e-2)
loss_a = make_aloss()
epochs = 10
epoch_loss = []
for epoch in range(epochs):
    for r, a in dl_a:
        a_cos_hat, a_sin_hat, a_kap = a_net(r)
        batch_loss = loss_a(a_cos_hat, a_sin_hat, a, a_kap)
        optim_a.zero_grad()
        batch_loss.backward()
        optim_a.step()
        epoch_loss.append(batch_loss.item())

# AHV decoder (Gaussian NLL loss).
dl_v = make_dataloader(rates[temp:], angular_velocity_use[temp:])
v_net = VelocityNet(in_features=infeature)
optim_v = make_optim(v_net, 5e-3)
loss_v = make_vloss()
epochs = 20
epoch_loss = []
for epoch in range(epochs):
    for r, v in dl_v:
        v_hat, v_var = v_net(r)
        batch_loss = loss_v(v_hat, v, v_var)
        optim_v.zero_grad()
        batch_loss.backward()
        optim_v.step()
        epoch_loss.append(batch_loss.item())

# Predictions on the held-out samples.
angle_sample = angle[:temp]
velocity_sample = angular_velocity_use[:temp]
with torch.no_grad():
    rates_sample = torch.Tensor(rates[:temp])
    a_cos_hat, a_sin_hat, _ = a_net(rates_sample)
    velocity_hat, _ = v_net(rates_sample)
a_cos_hat, a_sin_hat = a_cos_hat.numpy(), a_sin_hat.numpy()
velocity_hat = velocity_hat.numpy()
r2 = np.corrcoef(velocity_sample, velocity_hat)
angle_hat = np.arctan2(a_sin_hat, a_cos_hat) % (2 * np.pi)

# Joint scatter + marginal histograms + dashed regression line for each decoded variable.
# jointplot creates its own figure, so no separate plt.figure() is needed.
sp_plots = [
    (angle_sample, angle_hat, ('angle_sample', 'angle_predicted'), 'HD', 'HD_predicted',
     '../../figures/representative_scaatter_plot_angle_r_unimodal.pdf'),
    (velocity_sample, velocity_hat, ('velocity_sample', 'velocity_predicted'), 'AHV', 'AHV_predicted',
     '../../figures/representative_scaatter_plot_AHV_r_unimodal.pdf'),
]
for x_vals, y_vals, (x_key, y_key), xlabel, ylabel, fname in sp_plots:
    data = {x_key: x_vals, y_key: y_vals}
    g = sns.jointplot(x=x_key, y=y_key, data=data, marginal_kws=dict(bins=20, color=color2, alpha=0.5))
    g.ax_joint.clear()
    _ = g.plot_joint(sns.scatterplot, color=color2, s=20, alpha=0.5)  # scatter plot
    _ = sns.regplot(x=x_key, y=y_key, data=data, ax=g.ax_joint, scatter=False, color=color2, line_kws={'linestyle': '--'})  # regression line
    ax = g.ax_joint
    _ = ax.set_title('unimodal')
    _ = ax.title.set_position([1.15, 1.15])
    _ = ax.set_ylabel(ylabel)
    _ = ax.set_xlabel(xlabel)
    g.savefig(fname, bbox_inches='tight', transparent=True, format='pdf')

# MP (multimodal) population: train the same HD and AHV decoders and plot predictions on
# the held-out first `temp` samples ("testing") and on samples [temp:2*temp], which are
# part of the training set ("training"). These figures are displayed but not saved.
data = data_nonsinglelist[:, rangeuse]
rates = data.T
rates = rates - rates.mean()
rates = rates / rates.std()
angle = angle_list_rnnout[rangeuse] / 180 * np.pi  # degrees -> radians
print('angle_min_max', np.min(angle), np.max(angle))
infeature = rates.shape[1] if rates.shape[0] != 0 else 0
temp = int(np.floor(len(rates) / 5))
angular_velocity_use = angular_velocity[rangeuse]

# HD decoder (von Mises loss); losses stored as floats so no autograd graphs are retained.
dl_a = make_dataloader(rates[temp:], angle[temp:])
a_net = AngleNet(in_features=infeature)
optim_a = make_optim(a_net, weight_dacay=5e-2)
loss_a = make_aloss()
epochs = 10
epoch_loss = []
for epoch in range(epochs):
    for r, a in dl_a:
        a_cos_hat, a_sin_hat, a_kap = a_net(r)
        batch_loss = loss_a(a_cos_hat, a_sin_hat, a, a_kap)
        optim_a.zero_grad()
        batch_loss.backward()
        optim_a.step()
        epoch_loss.append(batch_loss.item())

# AHV decoder (Gaussian NLL loss).
dl_v = make_dataloader(rates[temp:], angular_velocity_use[temp:])
v_net = VelocityNet(in_features=infeature)
optim_v = make_optim(v_net, 5e-3)
loss_v = make_vloss()
epochs = 20
epoch_loss = []
for epoch in range(epochs):
    for r, v in dl_v:
        v_hat, v_var = v_net(r)
        batch_loss = loss_v(v_hat, v, v_var)
        optim_v.zero_grad()
        batch_loss.backward()
        optim_v.step()
        epoch_loss.append(batch_loss.item())

# Predictions on the held-out samples and on a slice of the training samples.
angle_sample = angle[:temp]
angle_sample2 = angle[temp:2 * temp]
velocity_sample = angular_velocity_use[:temp]
velocity_sample2 = angular_velocity_use[temp:2 * temp]
with torch.no_grad():
    rates_sample = torch.Tensor(rates[:temp])
    rates_sample2 = torch.Tensor(rates[temp:2 * temp])
    a_cos_hat, a_sin_hat, _ = a_net(rates_sample)
    a_cos_hat2, a_sin_hat2, _ = a_net(rates_sample2)
    velocity_hat, _ = v_net(rates_sample)
    velocity_hat2, _ = v_net(rates_sample2)
a_cos_hat, a_sin_hat = a_cos_hat.numpy(), a_sin_hat.numpy()
a_cos_hat2, a_sin_hat2 = a_cos_hat2.numpy(), a_sin_hat2.numpy()
velocity_hat = velocity_hat.numpy()
velocity_hat2 = velocity_hat2.numpy()
r2 = np.corrcoef(velocity_sample, velocity_hat)
angle_hat = np.arctan2(a_sin_hat, a_cos_hat) % (2 * np.pi)

# Training-slice HD plot: points whose error wraps past +/-pi are shifted by 2*pi (a random
# half on the x side, the rest on the y side) so they sit near the diagonal.
# Work on a copy: `angle_sample2` is a view into `angle` and must not be mutated.
rng = np.random.default_rng(0)  # fixed seed so the figure is reproducible
x_wrapped = angle_sample2.copy()
y_wrapped = np.arctan2(a_sin_hat2, a_cos_hat2) % (2 * np.pi)
residual = y_wrapped - x_wrapped
for wrapped, shift in ((np.where(residual > np.pi)[0], 2 * np.pi),
                       (np.where(residual < -np.pi)[0], -2 * np.pi)):
    if len(wrapped) > 0:
        move_x = rng.choice(wrapped, int(len(wrapped) * 0.5), replace=False)
        move_y = np.setdiff1d(wrapped, move_x)
        x_wrapped[move_x] += shift
        y_wrapped[move_y] -= shift

# Joint scatter + marginal histograms + dashed regression line, in the MP colour.
mp_plots = [
    (angle_sample, angle_hat, ('angle_sample', 'angle_predicted'), 'HD', 'HD_predicted_on_testing_data'),
    (x_wrapped, y_wrapped, ('angle_sample', 'angle_predicted'), 'HD', 'HD_predicted_on_training_data'),
    (velocity_sample, velocity_hat, ('velocity_sample', 'velocity_predicted'), 'AHV', 'AHV_predicted_on_testing_data'),
    (velocity_sample2, velocity_hat2, ('velocity_sample', 'velocity_predicted'), 'AHV', 'AHV_predicted_on_training_data'),
]
for x_vals, y_vals, (x_key, y_key), xlabel, ylabel in mp_plots:
    data = {x_key: x_vals, y_key: y_vals}
    g = sns.jointplot(x=x_key, y=y_key, data=data, marginal_kws=dict(bins=20, color=color1, alpha=0.5))
    g.ax_joint.clear()
    _ = g.plot_joint(sns.scatterplot, color=color1, s=5, alpha=0.5)  # scatter plot
    _ = sns.regplot(x=x_key, y=y_key, data=data, ax=g.ax_joint, scatter=False, color=color1, line_kws={'linestyle': '--'})  # regression line
    ax = g.ax_joint
    _ = ax.set_title('multimodal')
    _ = ax.title.set_position([1.2, 1.2])
    _ = ax.set_ylabel(ylabel)
    _ = ax.set_xlabel(xlabel)


# 5-fold correlation calculation and average
# Decode angle and angular velocity separately from the single-peak and the
# non-single-peak cell populations with get_cor (defined in an earlier cell).
# get_cor presumably returns one correlation per cross-validation fold for angle
# and for velocity — TODO confirm against its definition.
# Rates, angle and angular velocity are all restricted to the same `rangeuse`
# samples so the three inputs stay aligned in time.
angle = angle_list_rnnout[rangeuse]/180*np.pi
velocity_use = angular_velocity[rangeuse]
for iterationn, population in ((1, data_single), (2, data_nonsinglelist)):
    data = population[:,rangeuse]
    rates = data.T
    # NOTE(review): normalised with one global mean/std over all cells and samples,
    # not per cell — confirm this is intended.
    rates = rates-rates.mean()
    rates = rates/rates.std()
    r_angle_5_fold, r_v_5_fold = get_cor(rates,angle,velocity_use)
    if iterationn == 1:
        print('r for single_peak cells are',[r_angle_5_fold, r_v_5_fold])
        r_single = [r_angle_5_fold, r_v_5_fold]
    else:
        print('r for non_single_peak cells are',[r_angle_5_fold, r_v_5_fold])
        r_non_single = [r_angle_5_fold, r_v_5_fold]

# Mean and std across folds; row 0 = angle decoding, row 1 = velocity decoding.
r_single = np.array(r_single)
r_non_single = np.array(r_non_single)

r_angle_single, r_v_single = np.mean(r_single[0,:]), np.mean(r_single[1,:])
r_angle_non_single, r_v_non_single = np.mean(r_non_single[0,:]), np.mean(r_non_single[1,:])

print('r for single_peak cells are',[r_angle_single, r_v_single])
print('r for non_single_peak cells are',[r_angle_non_single, r_v_non_single])

r_angle_single_std, r_v_single_std = np.std(r_single[0,:]), np.std(r_single[1,:])
r_angle_non_single_std, r_v_non_single_std = np.std(r_non_single[0,:]), np.std(r_non_single[1,:])

print('r-std for single_peak cells are',[r_angle_single_std, r_v_single_std])
print('r-std for non_single_peak cells are',[r_angle_non_single_std, r_v_non_single_std])

In [None]:
# decoding correlations across 50 simulated sessions (Fig. 3C)
# For each simulated test session: load its .mat file, take the single-peak
# (`singlelist`) and non-single-peak (`non_singlelist`) units, decode angle and
# angular velocity from each population with get_cor, and store the mean and std
# of the per-fold correlations. Results are pickled so the next cells can reload
# them without re-running this loop.
N_SESSIONS = 50
N_TRAIN_SAMPLES = 100000   # leading samples of each session passed to get_cor
trials_all = np.arange(N_SESSIONS)
r_a_single = np.zeros_like(trials_all,dtype=float)
r_v_single = np.zeros_like(trials_all,dtype=float)
r_a_non_single = np.zeros_like(trials_all,dtype=float)
r_v_non_single = np.zeros_like(trials_all,dtype=float)
r_a_single_std = np.zeros_like(trials_all,dtype=float)
r_v_single_std = np.zeros_like(trials_all,dtype=float)
r_a_non_single_std = np.zeros_like(trials_all,dtype=float)
r_v_non_single_std = np.zeros_like(trials_all,dtype=float)

# Loop-invariant: the unit index lists and the sample range are the same for every session.
non_singlelist = non_singlelist.astype(int)
rangeuse = np.arange(N_TRAIN_SAMPLES)

for i,trialnum in enumerate(trials_all):
    print(i)
    varname = '../simulated_data/trials_test_1000/output_simulated_200trials_1000t_epoch_500_testtrial_'+str(trialnum+1)+'_0609.mat'
    data1 = np.array(sio.loadmat(varname)['data'])
    # Row layout as used here: first 100 rows = neural data; the last three rows
    # are (by their variable names) target angle, RNN-output angle, angular velocity.
    neuraldata = data1[0:100,:]
    angle_list_target = data1[-3,:]
    angle_list_rnnout = data1[-2,:]
    angular_velocity = data1[-1,:]
    data_single = neuraldata[singlelist,:]
    data_nonsinglelist = neuraldata[non_singlelist,:]

    # Restrict angle and velocity to the same samples as the rates so they stay aligned.
    angle = angle_list_rnnout[rangeuse]/180*np.pi
    velocity_use = angular_velocity[rangeuse]
    for iterationn, population in ((1, data_single), (2, data_nonsinglelist)):
        data = population[:,rangeuse]
        rates = data.T
        # NOTE(review): global (not per-unit) normalisation — confirm intended.
        rates = rates-rates.mean()
        rates = rates/rates.std()
        r_angle_5_fold, r_v_5_fold = get_cor(rates,angle,velocity_use)
        if iterationn == 1:
            r_a_single[i] = np.mean(r_angle_5_fold)
            r_v_single[i] = np.mean(r_v_5_fold)
            r_a_single_std[i] = np.std(r_angle_5_fold)
            r_v_single_std[i] = np.std(r_v_5_fold)
        else:
            r_a_non_single[i] = np.mean(r_angle_5_fold)
            r_v_non_single[i] = np.mean(r_v_5_fold)
            r_a_non_single_std[i] = np.std(r_angle_5_fold)
            r_v_non_single_std[i] = np.std(r_v_5_fold)


r_results = {'r_a_single':r_a_single,'r_v_single':r_v_single,'r_a_non_single':r_a_non_single,
             'r_v_non_single':r_v_non_single,'r_a_single_std':r_a_single_std,'r_v_single_std':r_v_single_std,
             'r_a_non_single_std':r_a_non_single_std,'r_v_non_single_std':r_v_non_single_std}

# NOTE(review): results go to '../../results' while inputs are read from
# '../simulated_data' — confirm both relative roots are intended.
with open('../../results/trials_test_50_traindatasize_100000_all_r_save.pkl', 'wb') as f:
    pickle.dump(r_results, f)

# Reload the cached per-session decoding correlations written by the session loop,
# so the figure and statistics cells below can run without recomputing it.
# pickle.load can execute arbitrary code: only load this locally produced file.
with open('../../results/trials_test_50_traindatasize_100000_all_r_save.pkl', 'rb') as f:
    data = pickle.load(f)

r_a_single = data['r_a_single']
r_v_single = data['r_v_single']
r_a_non_single = data['r_a_non_single']
r_v_non_single = data['r_v_non_single']
r_a_single_std = data['r_a_single_std']
r_v_single_std = data['r_v_single_std']
r_a_non_single_std = data['r_a_non_single_std']
r_v_non_single_std = data['r_v_non_single_std']

# Box plot of per-session decoding correlations (angle / velocity x single-peak /
# non-single-peak) with jittered points, followed by Mann-Whitney U tests between
# the two cell groups. Filled markers = angle decoding, hollow = velocity decoding.
from scipy.stats import mannwhitneyu

jitter_rng = np.random.default_rng(0)   # seeded so the jittered x positions are reproducible
box_groups = [
    (r_a_single, color2, True),
    (r_a_non_single, color1, True),
    (r_v_single, color2, False),
    (r_v_non_single, color1, False),
]

fig, ax = plt.subplots()
_ = ax.boxplot([values for values, _, _ in box_groups], positions = [1,2,3,4], widths = 0.3, showfliers=False)
for position, (values, colour, filled) in enumerate(box_groups, start=1):
    x_jitter = position + jitter_rng.normal(0.0, 0.1, len(values))
    _ = ax.scatter(x_jitter, values, s=30, marker='o', alpha=0.5,
                   facecolor=colour if filled else 'none', edgecolor=colour)
_ = ax.set_xticks([1.5,3.5])
_ = ax.set_xticklabels(['Angle decoding', 'Velocity decoding'])
# Double quotes: 'Pearson''s ...' is implicit concatenation and would drop the apostrophe.
_ = ax.set_ylabel("Pearson's Correlation coefficient")

# Explicit two-sided test (older SciPy versions used a different default).
# NOTE(review): both groups come from the same 50 sessions, so a paired test
# (e.g. scipy.stats.wilcoxon) may be more appropriate — confirm the design.
stat, p = mannwhitneyu(r_a_single, r_a_non_single, alternative='two-sided')
print('r_a_single vs r_a_non_single: p = ', p)
stat, p = mannwhitneyu(r_v_single, r_v_non_single, alternative='two-sided')
print('r_v_single vs r_v_non_single: p = ', p)
print('mean of r_a_single: ', np.mean(r_a_single))
print('mean of r_a_non_single: ', np.mean(r_a_non_single))
print('mean of r_v_single: ', np.mean(r_v_single))
print('mean of r_v_non_single: ', np.mean(r_v_non_single))
# Output filename kept unchanged (including 'multimodl').
fig.savefig('../../figures/unimodal_vs_multimodl_correlations_results_50_trials_mlv.pdf', bbox_inches='tight', transparent=True, format='pdf')


# ANOVA Variance analysis
# Two-way ANOVA of the per-session correlation Y with factor A = decoded variable
# ('angle' / 'velocity') and factor B = cell group ('unimodal' for the single-peak
# cells, 'multimodal' for the non-single-peak cells), including the A x B interaction.
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm

# (values, level of A, level of B), in the row order of the resulting table.
anova_groups = (
    (r_a_single, 'angle', 'unimodal'),
    (r_v_single, 'velocity', 'unimodal'),
    (r_a_non_single, 'angle', 'multimodal'),
    (r_v_non_single, 'velocity', 'multimodal'),
)

df = pd.DataFrame({
    'A': np.concatenate([np.repeat(variable, len(values)) for values, variable, _ in anova_groups]),
    'B': np.concatenate([np.repeat(modality, len(values)) for values, _, modality in anova_groups]),
    'Y': np.concatenate([values for values, _, _ in anova_groups]),
})

formula = 'Y ~ C(A) + C(B) + C(A):C(B)'
model = ols(formula, df).fit()
aov_table = anova_lm(model)
print(aov_table)


In [None]:
# 