## 降维

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

In [None]:
def plot_dim_reduction(resp, label, ax=None, n_pcs=300, perplexity=40, random_state=None):
    """Plot population responses embedded in 2D: PCA first, then tSNE.

    PCA reduces the dimensionality first so that tSNE runs faster; tSNE then
    maps the PCA scores to 2 dimensions. Each point is coloured by its label.

    Args:
      resp (numpy array): n_stimuli x n_neurons matrix with population responses
      label (numpy array): 1D array of length n_stimuli with one label per
        stimulus; used as the colour of each point
      ax (matplotlib axes): axes onto which to plot (default: current axes)
      n_pcs (int): maximum number of principal components kept before tSNE;
        additionally capped at min(n_stimuli, n_neurons)
      perplexity (float): tSNE perplexity
      random_state (int or None): seed passed to tSNE; None leaves it unseeded

    Returns:
      nothing
    """
    if ax is None:
        ax = plt.gca()

    # PCA cannot return more components than min(n_samples, n_features),
    # so cap the requested number by both dimensions of resp.
    n_components = min(n_pcs, *resp.shape)
    resp_pca = PCA(n_components=n_components).fit_transform(resp)

    # tSNE on the PCA scores, down to 2 dimensions
    resp_2d = TSNE(n_components=2, perplexity=perplexity, n_jobs=-1,
                   random_state=random_state).fit_transform(resp_pca)

    # Scatter the 2D embedding, colouring each point by its label
    scat = ax.scatter(resp_2d[:, 0], resp_2d[:, 1], c=label, cmap='twilight')

    plt.colorbar(scat, ax=ax, label='stimulus labels')
    ax.set_xlabel('dimension 1')
    ax.set_ylabel('dimension 2')
    # tSNE coordinates have no meaningful scale, so hide the tick values
    ax.set_xticks([])
    ax.set_yticks([])

# 2D PCA -> tSNE embedding of V1_response, coloured by training_labels
plot_dim_reduction(resp=V1_response, label=training_labels)

## RDM
RDM 我写了两种计算方法：一种基于相关（`methods='corr'`，即 1 − Pearson 相关系数），一种基于欧氏距离（`methods='dist'`）。

In [None]:
def plot_corr_matrix(rdm, ax=None, vmax=2.0, tick_labels=None,
                     cbar_label='dissimilarity'):
    """Plot a representational dissimilarity matrix as an image.

    Args:
      rdm (numpy array): n_stimuli x n_stimuli representational dissimilarity
        matrix
      ax (matplotlib axes): axes onto which to plot (default: current axes)
      vmax (float): upper end of the colour scale; the lower end is fixed at 0.
        The default 2.0 covers the full range of a correlation-based RDM
        (1 - r lies in [0, 2]).
      tick_labels (sequence of str or None): one label per row/column of rdm.
        If None and rdm is 4 x 4, the first four keys of the notebook-level
        `labels2num` dict are used; otherwise the ticks are hidden.
      cbar_label (str): label of the colour bar

    Returns:
      nothing

    Raises:
      ValueError: if the tick labels do not have one entry per row of rdm
    """
    if ax is None:
        ax = plt.gca()

    image = ax.imshow(rdm, vmin=0.0, vmax=vmax)
    ax.set_aspect('auto')

    n_items = rdm.shape[0]
    if tick_labels is None and n_items == 4:
        # Fall back to the label dictionary defined elsewhere in the notebook
        tick_labels = list(labels2num.keys())[:4]

    if tick_labels is not None:
        tick_labels = list(tick_labels)
        if len(tick_labels) != n_items:
            raise ValueError(
                'expected %d tick labels, got %d' % (n_items, len(tick_labels)))
        positions = list(range(n_items))
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=30)
        ax.set_yticks(positions)
        ax.set_yticklabels(tick_labels, rotation=40)
    else:
        ax.set_xticks([])
        ax.set_yticks([])

    # The matrix holds dissimilarities (e.g. 1 - r or distances), not correlations
    plt.colorbar(image, ax=ax, label=cbar_label)

def RDM(resp, methods='corr'):
    """Compute the representational dissimilarity matrix (RDM).

    Args:
      resp: S x N matrix with population responses to
        each stimulus in each row
      methods (str): 'corr' for 1 - Pearson correlation between rows, or
        'dist' for the Euclidean distance between rows

    Returns:
      np.ndarray: S x S representational dissimilarity matrix

    Raises:
      ValueError: if methods is neither 'corr' nor 'dist'
    """
    from scipy.spatial.distance import pdist, squareform
    from scipy.stats import zscore

    if methods == 'corr':
        # After z-scoring each row (ddof=0), the mean of the elementwise product
        # of two rows equals their Pearson correlation. A row with zero variance
        # yields NaNs.
        zresp = zscore(resp, axis=1)
        return 1 - (zresp @ zresp.T) / zresp.shape[1]
    if methods == 'dist':
        # pdist returns only the condensed upper triangle of the pairwise
        # Euclidean distances; squareform expands it to the full S x S matrix.
        return squareform(pdist(resp))
    raise ValueError("methods must be 'corr' or 'dist', got %r" % (methods,))

In [None]:
# Correlation-based RDM (1 - Pearson r between rows) of V1_avg, on a 0-2 colour scale
distance = RDM(V1_avg, methods='corr')
plot_corr_matrix(rdm=distance, vmax=2)

## decoding

### 1.1 划分训练集和测试集

In [None]:
## 

# violinplot

In [None]:
# Template: replace every "..." with real column names / labels before running.
fig, a = plt.subplots(figsize=(2, 2))

sns.violinplot(x="...",          # column for the x axis
               y="...",          # column for the y axis
               data=open_delta,  # dataset to plot
               hue="...",        # grouping variable
               linewidth=0.8,
               width=1,
               scale='area',     # NOTE: renamed to density_norm in seaborn >= 0.13
               # order=['ds_forward', 'ds_backward', 'ds'],  # order of x tick labels
               ax=a)
a.set_xlabel('')
a.set_xticks([])
a.set_ylabel('...', fontdict={'weight': 'normal', 'size': 20})
a.set_ylim(-1, 1)
a.set_title('...')
# Alternative legend placements, if the legend should be kept:
#   a.legend(loc='lower center', ncol=3)
#   a.legend(bbox_to_anchor=(1, 1), loc='center left', ncol=3)
# Drop the hue legend; guard because legend_ is None when no legend was drawn.
if a.legend_ is not None:
    a.legend_.remove()
# fig.tight_layout()  # trim surrounding whitespace
# Save before showing. `quality` is omitted: it only ever affected JPEG output
# and was deprecated (Matplotlib 3.3) and later removed from savefig.
fig.savefig("violinplot.png", dpi=800)
plt.show()

# pointplot

In [None]:
i = 3  # meant to be formatted into the title; the '...' placeholder has no {} yet
sns.set_style("white")
fig, ax = plt.subplots(figsize=(3, 3))
sns.pointplot(x=x, y=y, data=data, alpha=0.3, capsize=0.1, ax=ax)
ax.spines['top'].set_visible(False)    # hide the top border
ax.spines['right'].set_visible(False)  # hide the right border
ax.set_title('...'.format(i), fontdict={'weight': 'normal', 'size': 20})
# Axis labels are set once (e.g. y could be 'relative power'); the earlier
# ax.set(xlabel=..., ylabel=...) call was immediately overwritten.
ax.set_xlabel('...', fontdict={'weight': 'normal', 'size': 15})
ax.set_ylabel('...', fontdict={'weight': 'normal', 'size': 15})
ax.set_ylim(8, 11)
# Save before showing; `quality` is dropped (JPEG-only, removed from savefig).
fig.savefig("pointplot.png", dpi=800)
plt.show()

# scatterplot

In [None]:
fig, a = plt.subplots(figsize=(3, 2))
# x / y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.regplot(x=x, y=y, scatter_kws={'s': 8, "alpha": 0.9}, ax=a)
a.spines['top'].set_visible(False)     # hide the top border
a.spines['right'].set_visible(False)   # hide the right border
a.spines['bottom'].set_linewidth(1.5)  # thicker x axis
a.spines['left'].set_linewidth(1.5)    # thicker y axis
a.set_xlabel("...", labelpad=3, fontdict={'weight': 'normal', 'size': 13})  # x-axis title
a.set_ylabel('')  # y-axis title
a.set_xlim(-3, 3)
a.set_ylim(-2.5, 2.5)
# Inward ticks with thicker tick marks and larger tick labels
a.tick_params(axis='both', pad=5, direction='in', width=1.5, labelsize=12)
# fig.tight_layout()
# Save before showing; `quality` is dropped (JPEG-only, removed from savefig).
fig.savefig("....png", dpi=800)
plt.show()

# heatmap

In [None]:
f, ax = plt.subplots(figsize=(15, 15))
electrodes = ["F3","F4","FZ","F7","F8","C3","CZ","C4","P3","P4","PZ","P7","P8","O1","O2"]
frequency = ["open_delta","open_theta","open_alpha","open_beta","close_delta","close_theta","close_alpha","close_beta","alpha_peak"]
# Give heatmap the labels directly so every row/column is labelled in place,
# and draw explicitly on `ax`.
sns.heatmap(cor_vocabulary, vmin=-0.5, vmax=0.5, fmt='.3f', linewidths=0.05,
            annot=True, cmap=plt.cm.RdYlBu, square=True,
            xticklabels=electrodes, yticklabels=frequency, ax=ax)
# Show every row: the previous fixed limit [8, 0] cut off the last of the
# nine rows. Top-to-bottom order (row 0 at the top) is kept.
ax.set_ylim(len(cor_vocabulary), 0)
ax.set_xlabel('electrodes')
ax.set_ylabel('frequency')
plt.setp(ax.get_yticklabels(), rotation=0)
# No title is set in this cell, so fall back to a fixed name instead of '.png'.
fig_name = ax.get_title() or 'heatmap'
# Save from the figure object before plt.show(): after show the pyplot
# "current figure" may be a fresh, empty one. `quality` is dropped (JPEG-only).
f.savefig('%s.png' % fig_name, dpi=800)
plt.show()

# Barplot with scatter and error bar

In [None]:
i = 3  # used only by the commented-out title template below
sns.set(style="white")
fig, ax = plt.subplots(figsize=(3, 3))
# Bars with standard-deviation error bars (ci="sd"; newer seaborn spells this
# errorbar="sd") and black error-bar lines.
sns.barplot(x="...", y="...", data=data, capsize=.12, ci="sd", errcolor='k', ax=ax)
# Overlay the individual observations on the same axes as semi-transparent black dots
sns.swarmplot(x="...", y="...", data=data, color="0", alpha=.3, ax=ax)
ax.set_ylim(-20, 50)
# ax.set_title('classification accuracy in open & close& all'.format(i))
# ax.set(xlabel='state', ylabel='accuracy')
ax.spines['bottom'].set_linewidth(1.5)  # thicker x axis
ax.spines['left'].set_linewidth(1.5)    # thicker y axis
ax.spines['top'].set_visible(False)     # hide the top border
ax.spines['right'].set_visible(False)   # hide the right border
ax.set_xlabel('...', labelpad=5, fontdict={'weight': 'normal', 'size': 13})  # x-axis title
ax.set_ylabel('...', labelpad=5, fontdict={'weight': 'normal', 'size': 13})  # y-axis title

# `quality` is dropped: it only affected JPEG output and was removed from savefig.
fig.savefig("....png", dpi=800)