# Figure 6 : Temporal generalization with 96 labels

# Params

In [1]:
# Execute the shared boilerplate notebook in this namespace (IPython `%run`; `-n` runs it
# without setting __name__ to '__main__'). Presumably it provides np, plt, timesteps,
# N_thetas, N_B_thetas, B_thetas, opts_LR, n_splits, ... — TODO confirm.
%run -n boillerplate.ipynb

In [2]:
sys.path += ['../experimental']  # make the experimental helper modules importable

# Temporal transpositions with 96 labels

# Not Kfolded

In [3]:
# Temporal generalization from a single train/test split: a decoder trained at
# each time bin is tested at every time bin. Results are cached in ./data.
try:
    transpo_cms = np.load('./data/fig_6_tempo_transpo96_cms.npy')
    transpo_accs = np.load('./data/fig_6_tempo_transpo96_accs.npy')
except (OSError, ValueError):  # missing or unreadable cache -> recompute
    transpo_cms = np.zeros((len(timesteps), len(timesteps), N_thetas*N_B_thetas, N_thetas*N_B_thetas))
    transpo_accs = np.zeros((len(timesteps), len(timesteps)))

    # Data (cached independently of the results)
    try:
        data = np.load('./data/data_all_t_bt.npy')
        labels = np.load('./data/labels_all_t_bt.npy')
    except (OSError, ValueError):
        data, labels, le = par_load_temporal_data(timesteps = timesteps, target_btheta = None,
                                                  target_theta = None, data_type = 'all_t_bt',
                                                  cluster_list = cluster_list)
        np.save('./data/data_all_t_bt.npy', data)
        np.save('./data/labels_all_t_bt.npy', labels)

    # `le` only exists when the data was just rebuilt, so take the class set from
    # the labels themselves. A fixed class list also keeps every confusion matrix
    # n_classes x n_classes even when a test split misses a class.
    class_ids = np.unique(labels)
    n_classes = len(class_ids)

    # Classifying
    for ibin_train in tqdm(range(data.shape[0]), desc = 'Training and testing') :
        xtrain, _, ytrain, _ = train_test_split(data[ibin_train,:,:], labels,
                                                test_size = test_size, random_state = 42)
        logreg = LogisticRegression(**opts_LR)
        logreg.fit(xtrain, ytrain)
        for ibin_test in range(data.shape[0]):
            # Same random_state and trial count -> same trial partition as for
            # training, so only held-out trials are scored.
            _, xtest, _, ytest = train_test_split(data[ibin_test,:,:], labels,
                                                  test_size = test_size, random_state = 42)
            ypred = logreg.predict(xtest)

            # 'all'-normalized confusion matrix, rescaled by the number of classes
            cm = metrics.confusion_matrix(ytest, ypred, labels = class_ids, normalize = 'all')
            transpo_cms[ibin_train, ibin_test, :, :] = cm * n_classes
            transpo_accs[ibin_train, ibin_test] = metrics.balanced_accuracy_score(ytest, ypred)

    np.save('./data/fig_6_tempo_transpo96_cms.npy', transpo_cms)
    np.save('./data/fig_6_tempo_transpo96_accs.npy', transpo_accs)
    


In [4]:
np.shape(data)  # (time bins, trials, ...) of the loaded data

NameError: name 'data' is not defined

In [None]:
# Temporal-generalization matrix (train time x test time) from the single split.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

# Seven evenly spaced bins, labelled with timestep + window size (s)
xticks = np.linspace(0, len(timesteps) - 1, 7, dtype=np.int16, endpoint=True)
xticklabs = np.round(timesteps[xticks] + win_size, 2)
colors = plt.cm.inferno(np.linspace(.8, .3, len(B_thetas)))  # tc colormap

chance = 1 / 96
im = ax.imshow(transpo_accs, origin='lower', interpolation='None', vmin=chance)

for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)

ax.set_ylabel('Training time (s)', fontsize=18)
ax.set_xlabel('Generalization time (s)', fontsize=18)
ax.tick_params(axis='both', which='major', labelsize=14)

# (title disabled) ax.set_title(r'B$_\theta$ = %.2f°' % (B_thetas[::-1][i] * 180/np.pi),
#                               color=colors[::-1][i], fontsize=18, x=.55, y=1.01)
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar_ticks = np.linspace(chance, np.max(transpo_accs), 5)
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels(np.round(cbar_ticks, 2))
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_ylabel('Classification accuracy', rotation=270, labelpad=20, fontsize=16)

# Dashed identity line plus the two solid reference segments
line_kw = dict(c='w', alpha=.8, linewidth=1)
ax.plot([0, 60], [0, 60], linestyle='--', **line_kw)
ax.plot([10, 10], [40, 10], linestyle='-', **line_kw)
ax.plot([10.4, 40], [10, 10], linestyle='-', **line_kw)

fig.tight_layout()
# fig.savefig('./output/fig_5_tempo_gen_96.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

# Unused, pvals

# K-folded

In [None]:
# Trying a new kfold
try:
    kf_cms = np.load('./data/fig_6_tempo_transpo96_cms_kf.npy')
    kf_accs = np.load('./data/fig_6_tempo_transpo96_accs_kf.npy')

except:  
    
    # Data
    data, labels, le = par_load_temporal_data(timesteps = timesteps, target_btheta = None,
                                            target_theta = None, data_type = 'all_t_bt',
                                            cluster_list = cluster_list)
    
    kf_accs = np.zeros((data.shape[0], data.shape[0], n_splits))
    kf_cms = np.zeros((data.shape[0], data.shape[0], n_splits, 96, 96))

    # Classifying
    logreg = LogisticRegression(**opts_LR)
    
    kf = KFold(n_splits = n_splits)
    accs = np.zeros((data.shape[0],data.shape[0]))
    for ibin_train in tqdm(range(data.shape[0]), desc = 'Training and testing') :
        for i_kf, (train_index, test_index) in enumerate(kf.split(data[ibin_train,:,:])) :
            xtrain, ytrain = data[ibin_train,train_index,:], labels[train_index] #is train on axis 1 or 2
            logreg.fit(xtrain, ytrain)
            for ibin_test in range(data.shape[0]):
                xtest, ytest = data[ibin_test,test_index,:], labels[test_index]

                cm = metrics.confusion_matrix(ytest, logreg.predict(xtest), normalize = 'all')
                cm *= len(le.classes_)
                
                kf_accs[ibin_train, ibin_test, i_kf] = metrics.balanced_accuracy_score(ytest, logreg.predict(xtest))
                kf_cms[ibin_train, ibin_test, i_kf, :, :] = cm
        
    np.save('./data/fig_6_tempo_transpo96_cms_kf.npy', kf_cms)
    np.save('./data/fig_6_tempo_transpo96_accs_kf.npy', kf_accs)
    


In [None]:
# Permutation test of train/test asymmetry: for every cell (i, j) on or above the
# diagonal, compare the per-fold accuracies at (train i, test j) with those at the
# mirrored cell (train j, test i). Results are cached in ./data.
try:
    pval_map = np.load('./data/fig_6_transpo96pvals.npy')
except (OSError, ValueError):  # missing or unreadable cache -> recompute
    n_bins = kf_accs.shape[0]
    n_folds_used = 6  # only the first 6 folds enter the test
    itriu = np.triu_indices(n_bins)

    pval_map = np.zeros((n_bins, n_bins))
    for itrain, itest in tqdm(zip(*itriu), total = len(itriu[0])) :
        pval_map[itrain, itest] = permutation_test(list(kf_accs[itrain, itest, :n_folds_used]),
                                                   list(kf_accs[itest, itrain, :n_folds_used]),
                                                   num_rounds = num_rounds)

    np.save('./data/fig_6_transpo96pvals.npy', pval_map)
    


In [None]:
# K-folded temporal-generalization matrix (mean over folds); cells whose train/test
# asymmetry is significant (p < .05, strictly above the diagonal) are outlined.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

colors = plt.cm.inferno(np.linspace(.8, .3, len(B_thetas)))  # tc colormap

matacc = kf_accs.mean(axis=-1)
im = ax.imshow(matacc, origin='lower', interpolation='None', vmin=1/96)

# Blank the lower triangle and diagonal of the significance mask
itril = np.tril_indices(61)
pvals = pval_map < .05
pvals[itril] = False

pval_edges(pvals.T, ax=ax, lw=1)
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

# xticks / xticklabs are defined by an earlier plotting cell
for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)

ax.set_ylabel('Training time (s)', fontsize=18)
ax.set_xlabel('Generalization time (s)', fontsize=18)
ax.tick_params(axis='both', which='major', labelsize=14)

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar_ticks = np.linspace(1/96, np.max(matacc), 5)
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels(np.round(cbar_ticks, 2))
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_ylabel('Classification accuracy', rotation=270, labelpad=20, fontsize=16)

# Dashed identity line plus the two solid reference segments
line_kw = dict(c='w', alpha=.8, linewidth=1)
ax.plot([0, 60], [0, 60], linestyle='--', **line_kw)
ax.plot([10, 10], [40, 10], linestyle='-', **line_kw)
ax.plot([10.4, 40], [10, 10], linestyle='-', **line_kw)

fig.tight_layout()
fig.savefig('./output/fig_6_tempo_gen_96.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

In [None]:
# Asymmetry map for the 96-label decoder: for each upper-triangle cell (train i,
# test j), the fold-averaged accuracy difference acc(i, j) - acc(j, i), in percent
# of the best accuracy; significant cells (p < .05) are outlined.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

xticks = np.linspace(0, len(timesteps) - 1, 7, dtype=np.int16, endpoint=True)
xticklabs = np.round(timesteps[xticks] + win_size, 2)
colors = plt.cm.inferno(np.linspace(.8, .3, len(B_thetas)))  # tc colormap

matacc = np.mean(kf_accs, axis=-1)

itril = np.tril_indices(61)
pvals = pval_map < .05
pvals[itril] = False

mat = matacc
triu = np.triu(mat)
itriu = np.triu_indices(mat.shape[-1])
tril = np.tril(mat).T
diff = (triu - tril) / np.max(mat) * 100
tril2 = np.tril_indices(diff.shape[-1])
diff[tril2] = np.nan  # lower triangle and diagonal are not drawn

im = ax.imshow(diff, origin='lower', interpolation='None', cmap='RdBu_r',
               norm=mcols.TwoSlopeNorm(vmin=-30, vmax=30, vcenter=0))

# Horizontal reference segment only (the vertical one is disabled)
ax.plot([10.4, 40], [10, 10], c='k', linestyle='-', alpha=.8, linewidth=1)

pval_edges(pvals.T, ax=ax, lw=1, c='k')
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

xtick_idxs = [int(t) for t in ax.get_xticks() if t >= 0][:-1]
for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)

for side in ('left', 'top'):
    ax.spines[side].set_visible(False)
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

# The final figure is drawn without ticks
ax.set_xticks([])
ax.set_yticks([])
ax.tick_params(axis='both', which='major', labelsize=14)

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_ylabel(r'$\Delta$ accuracy (% of max)', rotation=270, labelpad=20, fontsize=16)

fig.tight_layout()
fig.savefig('./output/fig_6_asym.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

In [None]:
# Fraction of significant upper-triangle cells once the first 15 bins are ignored
itril = np.tril_indices(61)
pvals = pval_map < 0.05
pvals[itril] = False

# hide FPs, for plot sake, careful not to count them later on
pvals[:15, :] = False
pvals[:, :15] = False
# NOTE(review): the 31*31 denominator is not derived from the mask — confirm the intended region
print(np.count_nonzero(pvals) / (31 * 31))

# For the theta decoder only

In [None]:

# K-folded temporal generalization for the theta-only decoder (data_type='all_bt'):
# same procedure as the 96-label K-fold cell. Results are cached in ./data.
try:
    # Load from the same paths the np.save calls below write to (the cms path
    # previously lacked the 'fig_6_' prefix, so the cache was never hit).
    kf_cms_t_only = np.load('./data/fig_6_tempo_transpo96_cms_kf_t_only.npy')
    kf_accs_t_only = np.load('./data/fig_6_tempo_transpo96_accs_kf_t_only.npy')
except (OSError, ValueError):  # missing or unreadable cache -> recompute
    # Loading data
    data, labels, le = par_load_temporal_data(timesteps = timesteps, target_btheta = None,
                                            target_theta = None, data_type = 'all_bt',
                                            cluster_list = cluster_list)
    # Fixed class set: keeps every confusion matrix n_classes x n_classes even if a
    # fold's test trials miss some classes.
    class_ids = np.unique(labels)
    n_classes = len(class_ids)
    n_bins = data.shape[0]

    kf_accs_t_only = np.zeros((n_bins, n_bins, n_splits))
    kf_cms_t_only = np.zeros((n_bins, n_bins, n_splits, n_classes, n_classes))

    # Classifying
    logreg = LogisticRegression(**opts_LR)
    kf = KFold(n_splits = n_splits)
    for ibin_train in tqdm(range(n_bins), desc = 'Training and testing') :
        # data[bin] is (trials, ...): the split is over trials, i.e. axis 1 of data
        for i_kf, (train_index, test_index) in enumerate(kf.split(data[ibin_train,:,:])) :
            logreg.fit(data[ibin_train, train_index, :], labels[train_index])
            ytest = labels[test_index]
            for ibin_test in range(n_bins):
                ypred = logreg.predict(data[ibin_test, test_index, :])
                cm = metrics.confusion_matrix(ytest, ypred, labels = class_ids, normalize = 'all')
                kf_cms_t_only[ibin_train, ibin_test, i_kf, :, :] = cm * n_classes
                kf_accs_t_only[ibin_train, ibin_test, i_kf] = metrics.balanced_accuracy_score(ytest, ypred)

    np.save('./data/fig_6_tempo_transpo96_cms_kf_t_only.npy', kf_cms_t_only)
    np.save('./data/fig_6_tempo_transpo96_accs_kf_t_only.npy', kf_accs_t_only)
    


In [None]:
# Permutation test of train/test asymmetry for the theta-only decoder: for every
# cell (i, j) on or above the diagonal, compare the per-fold accuracies at
# (train i, test j) with those at (train j, test i). Results are cached in ./data.
try:
    pval_map_t_only = np.load('./data/fig_6_transpo96pvals_t_only.npy')
except (OSError, ValueError):  # missing or unreadable cache -> recompute
    n_bins = kf_accs_t_only.shape[0]
    n_folds_used = 6  # only the first 6 folds enter the test
    itriu = np.triu_indices(n_bins)

    pval_map_t_only = np.zeros((n_bins, n_bins))
    for itrain, itest in tqdm(zip(*itriu), total = len(itriu[0])) :
        pval_map_t_only[itrain, itest] = permutation_test(
            list(kf_accs_t_only[itrain, itest, :n_folds_used]),
            list(kf_accs_t_only[itest, itrain, :n_folds_used]),
            num_rounds = num_rounds)

    np.save('./data/fig_6_transpo96pvals_t_only.npy', pval_map_t_only)
    


In [None]:
# Theta-only decoder: K-folded temporal-generalization matrix (mean over folds)
# with significant asymmetries (p < .05, strictly above the diagonal) outlined.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

colors = plt.cm.inferno(np.linspace(.8, .3, len(B_thetas)))  # tc colormap

matacc = kf_accs_t_only.mean(axis=-1)
im = ax.imshow(matacc, origin='lower', interpolation='None', vmin=1/12)

# Blank the lower triangle and diagonal of the significance mask
itril = np.tril_indices(61)
pvals = pval_map_t_only < .05
pvals[itril] = False

pval_edges(pvals.T, ax=ax, lw=1)
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

# xticks / xticklabs are defined by an earlier plotting cell
for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)

ax.set_ylabel('Training time (s)', fontsize=18)
ax.set_xlabel('Generalization time (s)', fontsize=18)
ax.tick_params(axis='both', which='major', labelsize=14)

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar_ticks = np.linspace(1/12, np.max(matacc), 5)
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels(np.round(cbar_ticks, 2))
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_ylabel('Classification accuracy', rotation=270, labelpad=20, fontsize=16)

# Dashed identity line plus the two solid reference segments
line_kw = dict(c='w', alpha=.8, linewidth=1)
ax.plot([0, 60], [0, 60], linestyle='--', **line_kw)
ax.plot([10, 10], [40, 10], linestyle='-', **line_kw)
ax.plot([10.4, 40], [10, 10], linestyle='-', **line_kw)

fig.tight_layout()
fig.savefig('./output/fig_6_tempo_gen_96_theta_only.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

In [None]:
# Asymmetry map for the theta-only decoder: for each upper-triangle cell (train i,
# test j), acc(i, j) - acc(j, i) in percent of the best mean accuracy; significant
# cells (p < .05) are outlined.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

xticks = np.linspace(0, len(timesteps) - 1, 7, dtype=np.int16, endpoint=True)
xticklabs = np.round(timesteps[xticks] + win_size, 2)
colors = plt.cm.inferno(np.linspace(.8, .3, len(B_thetas)))  # tc colormap

matacc = np.mean(kf_accs_t_only, axis=-1)

itril = np.tril_indices(61)
pvals = pval_map_t_only < .05
pvals[itril] = False

mat = matacc
triu = np.triu(mat)
itriu = np.triu_indices(mat.shape[-1])
tril = np.tril(mat).T
diff = (triu - tril) / np.max(mat) * 100
tril2 = np.tril_indices(diff.shape[-1])
diff[tril2] = np.nan  # lower triangle and diagonal are not drawn

im = ax.imshow(diff, origin='lower', interpolation='None', cmap='RdBu_r',
               norm=mcols.TwoSlopeNorm(vmin=-30, vmax=30, vcenter=0))

# Horizontal reference segment only (the vertical one is disabled)
ax.plot([10.4, 40], [10, 10], c='k', linestyle='-', alpha=.8, linewidth=1)

pval_edges(pvals.T, ax=ax, lw=1, c='k')
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

xtick_idxs = [int(t) for t in ax.get_xticks() if t >= 0][:-1]
for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)

for side in ('left', 'top'):
    ax.spines[side].set_visible(False)
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

# The final figure is drawn without ticks
ax.set_xticks([])
ax.set_yticks([])
ax.tick_params(axis='both', which='major', labelsize=14)

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_ylabel(r'$\Delta$ accuracy (% of max)', rotation=270, labelpad=20, fontsize=16)

fig.tight_layout()
fig.savefig('./output/fig_6_asym_theta_only.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

# Diff between the two ?

In [None]:
# Theta only
matacc = np.mean(kf_accs_t_only, axis = -1)
#im = ax.imshow(matacc, origin = 'lower', interpolation = 'None', vmin = 1/96)

itril = np.tril_indices(61)
pvals_theta = pval_map_t_only < .05
for i1, _ in enumerate(range(len(itril[0]))) :
    pvals_theta[itril[0][i1], itril[1][i1]] = False
    
mat = matacc
triu = np.triu(mat)
itriu = np.triu_indices(mat.shape[-1])
tril = np.tril(mat).T
diff_theta = (triu-tril)/np.max(mat)
diff_theta*=100
tril2 = np.tril_indices(diff_theta.shape[-1])
diff_theta[tril2]= None

In [None]:
# Btheta
matacc = np.mean(kf_accs, axis = -1)
#im = ax.imshow(matacc, origin = 'lower', interpolation = 'None', vmin = 1/96)

itril = np.tril_indices(61)
pvals = pval_map < .05
for i1, _ in enumerate(range(len(itril[0]))) :
    pvals[itril[0][i1], itril[1][i1]] = False
    
mat = matacc
triu = np.triu(mat)
itriu = np.triu_indices(mat.shape[-1])
tril = np.tril(mat).T
diff = (triu-tril)/np.max(mat)
diff*=100
tril2 = np.tril_indices(diff.shape[-1])
diff[tril2]= None

In [None]:
# Difference between the two asymmetry maps (96-label decoder minus theta-only decoder)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

imtest = diff - diff_theta
im = ax.imshow(imtest, origin='lower', interpolation='None', cmap='RdBu_r',
               norm=mcols.TwoSlopeNorm(vmin=-20, vmax=20, vcenter=0))

xtick_idxs = [int(t) for t in ax.get_xticks() if t >= 0][:-1]
for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)

for side in ('left', 'top'):
    ax.spines[side].set_visible(False)
ax.set_xlim(0, 60)
ax.set_ylim(0, 60)

# Making the polar plots from the kfold

In [None]:
def make_ranged_histo(a, valrange, N_thetas=N_thetas, N_B_thetas=N_B_thetas):
    """Histogram of correct-orientation rates per predicted B_theta.

    ``a`` is one cell of the transposition confusion-matrix map: a square matrix
    whose rows (ground truth) and columns (prediction) are sliced in blocks of
    ``N_thetas`` labels per B_theta index. For each ground-truth B_theta index in
    ``valrange`` and each predicted B_theta index ``bt``, the diagonal of the
    (val, bt) sub-block gives, per orientation, the rate at which the orientation
    was decoded correctly while B_theta was decoded as ``bt``.

    Returns an array of shape (N_thetas, N_B_thetas): these histograms averaged
    over the ground-truth B_theta indices in ``valrange``.

    NOTE: a later cell redefines ``make_ranged_histo``; running it shadows this one.
    """
    out_histo = np.zeros((len(valrange), N_thetas, N_B_thetas))
    for i0, val in enumerate(valrange):  # ground-truth B_theta index
        truth_rows = slice(val * N_thetas, (val + 1) * N_thetas)
        # Use the N_B_thetas parameter (not the global B_thetas) so that the loop
        # and the output allocation agree.
        for bt in range(N_B_thetas):  # predicted B_theta index
            submat = a[truth_rows, bt * N_thetas:(bt + 1) * N_thetas]
            out_histo[i0, :, bt] = np.diag(submat)

    # Average over the ground-truth B_theta values in valrange
    return np.mean(out_histo, axis=0)

In [None]:
def make_ranged_histo(a, valrange, n_thetas=12, n_b_thetas=8, center=5):
    """Delta-theta aligned (theta offset x predicted B_theta) histogram.

    Parameters
    ----------
    a : (n_b_thetas*n_thetas, n_b_thetas*n_thetas) array
        Confusion matrix, rows = ground truth, columns = prediction. Labels are
        assumed B_theta-major, i.e. ``label = bt * n_thetas + theta``. This is the
        layout implied by using ``label % n_thetas`` as the orientation and by
        reshaping columns to (n_b_thetas, n_thetas). ``a`` is not modified.
    valrange : sequence of int
        Ground-truth B_theta indices to average over.
    n_thetas, n_b_thetas : int
        Number of orientations / B_theta values (defaults: 12 x 8 = 96 labels).
    center : int
        Position, within each B_theta block, to which the true orientation is
        rolled.

    Returns
    -------
    (n_thetas, n_b_thetas) array
        Realigned prediction rates averaged over the ground-truth B_theta indices
        in ``valrange``. Row ``center`` holds correct-orientation predictions.

    Fixes relative to the previous version:
      * the input matrix was modified in place;
      * ``np.roll`` on the full 96-long row leaked values across B_theta blocks;
      * rows were grouped by ``k % 8`` (``reshape(12, 8)``), which does not match
        the B_theta-major layout used everywhere else in the function.
    """
    n_labels = n_thetas * n_b_thetas
    a = np.asarray(a)
    if a.shape != (n_labels, n_labels):
        raise ValueError(f"expected a {n_labels}x{n_labels} matrix, got {a.shape}")

    # Per ground-truth row: (predicted B_theta, predicted theta)
    blocks = a.reshape(n_labels, n_b_thetas, n_thetas)
    true_theta = np.arange(n_labels) % n_thetas

    aligned = np.empty(blocks.shape, dtype=float)
    for k in range(n_labels):
        # Roll inside every B_theta block so that the true theta lands on `center`
        aligned[k] = np.roll(blocks[k], center - true_theta[k], axis=-1)

    # Average the rows sharing a ground-truth B_theta:
    # (truth bt, truth theta, pred bt, theta offset) -> (truth bt, pred bt, theta offset)
    per_bt = aligned.reshape(n_b_thetas, n_thetas, n_b_thetas, n_thetas).mean(axis=1)

    selected = per_bt[np.asarray(valrange, dtype=int)]  # (len(valrange), pred bt, offset)
    return selected.mean(axis=0).T  # (theta offset, pred bt)

In [None]:
def plot_polar(a, ax, vmin, vmax) :
    """Draw a 2-D histogram as a polar fan on ``ax`` and return the QuadMesh.

    Rows of ``a`` map to angular bins evenly spanning [0, pi]. Columns map to
    radial bins evenly spanning [0.1, B_thetas[0]] (``B_thetas`` is a notebook
    global). Colours use the diverging 'RdBu_r' map with a TwoSlopeNorm centred
    on 0, so ``vmin < 0 < vmax`` is required.
    """
    n_rows, n_cols = np.shape(a)
    # Flat shading needs cell *edges*, i.e. one more edge than cells on each axis.
    # The previous 12 angular edges for 12 rows either dropped the last row or,
    # on matplotlib >= 3.5, raised a dimension-mismatch TypeError.
    abins = np.linspace(0.1, B_thetas[0], n_cols + 1)
    rbins = np.linspace(0, np.pi, n_rows + 1)
    A, R = np.meshgrid(abins, rbins)

    pc = ax.pcolormesh(R, A, a,
                       cmap="RdBu_r", edgecolors='k', linewidth=.5,
                       antialiased=True,
                       norm = mcols.TwoSlopeNorm(vmin = vmin, vmax = vmax, vcenter = 0))
    return pc

In [None]:
# (train bin, test bin) cells at which the fans are sampled
coordinates = [(40, 20), (40, 40), (20, 40)]
print(*(timesteps[ibin] for ibin in (40, 35, 20)))

In [None]:
timesteps[20] + win_size  # time label used for bin 20 (timestep + window size)

In [None]:
# Sanity check: highlight the sampled cells on a copy of the single-split accuracy map
a = np.copy(transpo_accs)
a[tuple(zip(*coordinates))] = .7
plt.imshow(a, origin='lower')

In [None]:
# Ground-truth B_theta index ranges.
# NOTE(review): plot_polar uses B_thetas[0] as the outer radius, which suggests
# B_thetas is sorted in decreasing order. If so, indices 4-7 are the *low* B_theta
# values, which would explain why `histol` uses `highrange` — confirm.
lowrange = np.arange(0,4)
highrange = np.arange(4,8)
cor = 10e-5  # small offset added to every bin (note: 10e-5 == 1e-4)

low_histos, high_histos = [], []
for c in coordinates :
    # Average over folds only for the cell needed (kf_cms[c] is (folds, labels, labels)).
    # Previously the fold mean of the whole kf_cms map was recomputed on every call.
    # np.mean returns a fresh array, so an in-place make_ranged_histo cannot leak.
    cell_cm = np.mean(kf_cms[c], axis = 0)
    histol = make_ranged_histo(a = cell_cm.copy(), valrange = highrange) + cor
    histoh = make_ranged_histo(a = cell_cm.copy(), valrange = lowrange) + cor

    low_histos.append(histol)
    high_histos.append(histoh)

In [None]:
def rainbow_histo(a1, a2):
    """Return the element-wise difference ``a1 - a2`` between two fan histograms.

    Earlier variants were computed and then discarded, and have been removed:
    a polynomial of the ratio (which used ``^``, i.e. XOR, not a power) and a
    log ratio ``np.log(a1 / a2)`` (which could emit divide/log warnings).
    """
    return a1 - a2

In [None]:
# Polar fans: middle coordinate minus upper coordinate, for the "low" (left) and
# "high" (right) ground-truth histograms, on a shared diverging colour scale.
fig, axs = plt.subplots(ncols=2, subplot_kw=dict(projection='polar', aspect='equal'))

a1 = rainbow_histo(low_histos[1], low_histos[0])
a2 = rainbow_histo(high_histos[1], high_histos[0])
vmin, vmax = np.nanmin([a1[:, :], a2[:, :]]), np.nanmax([a1[:, :], a2[:, :]])

for ax, a, panel_title in zip(axs, (a1, a2), ('Middle to upper, low', 'Middle to upper, high')):
    print(np.min(a), np.max(a))
    mesh = plot_polar(a, ax=ax, vmin=vmin, vmax=vmax)
    if ax is axs[0]:
        pl = mesh  # the left panel's mesh drives the colorbar
    ax.set_title(panel_title)

for i, ax in enumerate(axs):
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)
    if i == 0:
        ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
        ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
        ax.tick_params(labelsize=14)
    else:
        ax.set_xticklabels([])
    ax.set_yticklabels([])

# Horizontal colorbar below the panels; ticks: vmin, 0 and thirds of vmax
cax = fig.add_axes([(1/4) + .05, 0, (1/2), 0.025])
cticks = np.concatenate((np.linspace(vmin, 0, 3), np.linspace(0, vmax, 4)))[[0, 2, 4, 5, 6]]
cb = fig.colorbar(pl, cax=cax, orientation='horizontal', ticks=cticks)
cb.ax.tick_params(labelsize=14)
cb.ax.set_xticklabels(np.round(cticks, 2))
# cax.set_xlabel(r'$\Delta$ accuracy', labelpad=5, fontsize=14)  # this is a log OR not a delta acc

fig.savefig('./output/fig_6_topfans.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

In [None]:
# Polar fans: middle coordinate minus lower coordinate, for the "low" (left) and
# "high" (right) ground-truth histograms, on a shared diverging colour scale.
fig, axs = plt.subplots(ncols=2, subplot_kw=dict(projection='polar', aspect='equal'))

a1 = rainbow_histo(low_histos[1], low_histos[2])
a2 = rainbow_histo(high_histos[1], high_histos[2])
vmin, vmax = np.nanmin([a1[:, :], a2[:, :]]), np.nanmax([a1[:, :], a2[:, :]])

for ax, a, panel_title in zip(axs, (a1, a2), ('Middle to lower, low', 'Middle to lower, high')):
    print(np.min(a), np.max(a))
    mesh = plot_polar(a, ax=ax, vmin=vmin, vmax=vmax)
    if ax is axs[0]:
        pl = mesh  # the left panel's mesh drives the colorbar
    ax.set_title(panel_title)

for i, ax in enumerate(axs):
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)
    if i == 0:
        ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
        ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
        ax.tick_params(labelsize=14)
    else:
        ax.set_xticklabels([])
    ax.set_yticklabels([])

# Horizontal colorbar below the panels; ticks: vmin, vmin/2, 0, vmax/2, vmax
cax = fig.add_axes([(1/4) + .05, 0, (1/2), 0.025])
cticks = np.concatenate((np.linspace(vmin, 0, 3), np.linspace(0, vmax, 3)))[[0, 1, 2, 4, 5]]
cb = fig.colorbar(pl, cax=cax, orientation='horizontal', ticks=cticks)
cb.ax.tick_params(labelsize=14)
cb.ax.set_xticklabels(np.round(cticks, 2))
cax.set_xlabel(r'$\Delta$ accuracy', labelpad=5, fontsize=14)

fig.savefig('./output/fig_6_botfans.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

# Log odds ratio

In [None]:
def make_LOR_ranged_histo(a, valrange, n_thetas=None, n_b_thetas=None):
    """Histogram of correct-orientation rates per predicted B_theta.

    ``a`` is a square confusion matrix (rows = ground truth, columns = prediction)
    sliced in blocks of ``n_thetas`` labels per B_theta index. For every
    ground-truth B_theta index ``val`` in ``valrange`` and every predicted index
    ``bt``, the diagonal of the (val, bt) sub-block is taken.

    Parameters
    ----------
    n_thetas, n_b_thetas : int, optional
        Block sizes; default to ``len(thetas)`` / ``len(B_thetas)`` (notebook
        globals). Previously the slicing hard-coded a block width of 12 while
        the output used ``len(thetas)``, so the two could disagree.

    Returns
    -------
    (n_thetas, n_b_thetas) array averaged over the indices in ``valrange``.
    """
    if n_thetas is None:
        n_thetas = len(thetas)
    if n_b_thetas is None:
        n_b_thetas = len(B_thetas)

    out_histo = np.zeros((len(valrange), n_thetas, n_b_thetas))
    for i0, val in enumerate(valrange):  # ground-truth B_theta index
        truth_rows = slice(val * n_thetas, (val + 1) * n_thetas)
        for bt in range(n_b_thetas):  # predicted B_theta index
            submat = a[truth_rows, bt * n_thetas:(bt + 1) * n_thetas]
            out_histo[i0, :, bt] = np.diag(submat)

    # Average over the ground-truth B_theta values in valrange
    return np.mean(out_histo, axis=0)

In [None]:
# For every (train, test) cell: split the fold-averaged confusion matrix by
# ground-truth B_theta index range (0-3 vs 4-7) and, within each histogram, by
# predicted B_theta index range (columns 0-3 vs 4-7). Store the four mean rates
# and their log odds ratio.
n_train, n_test = transpo_cms.shape[0], transpo_cms.shape[1]
logoddmap = np.zeros((n_train, n_test))

# First letter: ground-truth index range, second letter: predicted index range
ll = np.zeros((n_train, n_test))
lh = np.zeros((n_train, n_test))
hl = np.zeros((n_train, n_test))
hh = np.zeros((n_train, n_test))

for itrain in tqdm(range(n_train)) :
    for itest in range(n_test) :
        # Average over folds once per cell (was computed twice)
        cm = np.mean(kf_cms[itrain, itest], axis = 0)

        # Polar histograms
        histo_low = make_LOR_ranged_histo(a = cm, valrange = np.arange(0,4))
        histo_high = make_LOR_ranged_histo(a = cm, valrange = np.arange(4,8))

        ll[itrain, itest] = np.mean(histo_low[:,:4])
        lh[itrain, itest] = np.mean(histo_low[:,4:])
        hl[itrain, itest] = np.mean(histo_high[:,:4])
        hh[itrain, itest] = np.mean(histo_high[:,4:])  # was histo_high[:4:] (first 4 rows, all columns)

        # Log odds ratio, computed without normalization
        oddratio = (ll[itrain, itest] / lh[itrain, itest]) / (hl[itrain, itest] / hh[itrain, itest])
        logoddmap[itrain, itest] = np.log(oddratio)
        

In [None]:
# Log odds-ratio map over (train, test) bins, built from the four per-cell mean
# rates ll, lh, hl, hh.
# NOTE(review): this uses (ll/lh) / (hh/hl), whereas `logoddmap` uses
# (ll/lh) / (hl/hh) — the second ratio is inverted; confirm which is intended.
norm_logodd = np.log((ll / lh) / (hh / hl))  # not normed now
norm_logodd[np.isinf(norm_logodd)] = 0  # +/-inf -> 0 (NaNs are left untouched)
fig, ax = plt.subplots(figsize=(8, 8))

a = norm_logodd
a_min, a_max = np.min(a), np.max(a)
im = ax.imshow(a, cmap='RdBu_r', origin='lower', interpolation='None',
               norm=mcols.TwoSlopeNorm(vmin=a_min, vmax=a_max, vcenter=0))

# Dashed identity line plus the two solid reference segments
line_kw = dict(c='k', alpha=.8, linewidth=1)
ax.plot([0, 60], [0, 60], linestyle='--', **line_kw)
ax.plot([10, 10], [40, 10], linestyle='-', **line_kw)
ax.plot([10.4, 40], [10, 10], linestyle='-', **line_kw)

cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
# Three steps from the minimum up to (excluding) 0, then 0 and three steps to the maximum
cticks = np.concatenate((np.linspace(a_min, 0, 4)[:-1], np.linspace(0, a_max, 4)))
cbar.set_ticks(cticks)
cbar.set_ticklabels(np.round(cticks, 2))
cbar.ax.tick_params(labelsize=12)
cbar.ax.set_ylabel(r'Log OR', rotation=270, labelpad=40, fontsize=16)

xtick_idxs = [int(t) for t in ax.get_xticks() if t >= 0][:-1]
for set_ticks, set_labels in ((ax.set_xticks, ax.set_xticklabels),
                              (ax.set_yticks, ax.set_yticklabels)):
    set_ticks(xticks)
    set_labels(xticklabs)
ax.set_ylabel('Training time (s)', fontsize=18)
ax.set_xlabel('Generalization time (s)', fontsize=18)
ax.tick_params(axis='both', which='major', labelsize=14)

fig.tight_layout()
fig.savefig('./output/fig_6_nobelmap.pdf', bbox_inches='tight', dpi=200, transparent=True)
plt.show()

# Pred proba realignement - video
How does each fan evolve over time when moving along a line? We consider two scenarios: evolution along the X axis and evolution along the Y axis.

The decoder is trained at a time point on the identity line, then tested at multiple time points along an orthogonal axis. Each test set is used to predict class probabilities. These are rolled according to each trial's true θ so that they form a Δθ-aligned fan, and are then averaged across folds and predictions.

In [None]:
def pred_proba_time_gen(evo_idxs, N_thetas=N_thetas, N_B_thetas=N_B_thetas) :
    """Delta-theta realigned predicted probabilities along a line of test bins.

    For each K-fold split, the decoder is trained on bin ``evo_idxs[0][1]`` and
    evaluated on the held-out trials at every test bin ``evo_idxs[:, 0]``. Each
    trial's probability vector is rolled *within every B_theta block* so that the
    trial's true orientation lands at index ``N_B_thetas//2 + 1``.

    The function reads the notebook globals ``data`` (indexed (bin, trial, ...)),
    ``labels``, ``kf``, ``logreg`` and ``n_splits``.
    NOTE(review): labels are assumed to be B_theta-major (bt*N_thetas + theta) and
    ``logreg.classes_`` to be 0..N_thetas*N_B_thetas-1. The theta-only K-fold cell
    rebinds ``data``/``labels`` when it recomputes — confirm which data is meant.

    Returns
    -------
    array of shape (len(evo_idxs), n_splits, data.shape[1]//n_splits, N_thetas*N_B_thetas)
    """
    evo_idxs = np.asarray(evo_idxs)
    ibin_train = evo_idxs[0][1]
    n_labels = N_thetas * N_B_thetas
    center = N_B_thetas//2 + 1  # target position of the true theta (kept from the original)

    # NOTE: assumes the trial count is divisible by n_splits (equal-sized folds)
    proba_preds = np.zeros((len(evo_idxs), n_splits, data.shape[1]//n_splits, n_labels))
    for i_kf, (train_index, test_index) in tqdm(enumerate(kf.split(data[ibin_train,:,:])),
                                                desc = 'Kfolding', total=n_splits) :
        logreg.fit(data[ibin_train, train_index, :], labels[train_index])

        # True orientation of each *held-out* trial. The old code used
        # labels[i_proba], i.e. the labels of the first trials of the whole dataset.
        shifts = center - (labels[test_index] % N_thetas)

        for itest, ibin_test in enumerate(evo_idxs[:, 0]) :
            probas = logreg.predict_proba(data[ibin_test, test_index, :])
            # (trials, B_theta, theta) so the roll cannot leak across B_theta blocks
            blocks = probas.reshape(len(probas), N_B_thetas, N_thetas)
            rolled = np.stack([np.roll(block, shift, axis = -1)
                               for block, shift in zip(blocks, shifts)])
            proba_preds[itest, i_kf, :, :] = rolled.reshape(len(probas), n_labels)

    return proba_preds

In [None]:

def plot_all_probas(probas, title, N_thetas=N_thetas, N_B_thetas=N_B_thetas) :
    """Draw one polar probability fan per row of ``probas`` on a shared colour scale.

    Parameters
    ----------
    probas : ndarray of shape (n_panels, N_B_thetas * N_thetas)
        Each row is reshaped to (N_B_thetas, N_thetas) and drawn with
        orientations on the angle and bandwidths on the radius.
    title : str
        Figure suptitle.
    N_thetas, N_B_thetas : int
        Grid size of each fan.

    Panel titles read the global ``timesteps`` and ``gen_evolution``; the
    radial extent reads the global ``B_thetas``.

    Returns
    -------
    (fig, axs)
    """
    n_ax = probas.shape[0]
    fig, axs = plt.subplots(figsize = (n_ax*3, n_ax), ncols = n_ax,
                            subplot_kw = dict(projection = 'polar', aspect = 'equal'))

    # One colour scale for every panel.
    vmin, vmax = np.min(probas), np.max(probas)
    # Cell edges: N_B_thetas+1 radial edges and N_thetas+1 angular edges over
    # [0, pi], matching the 0-180 degree span of the axes.
    EL, AZ = np.meshgrid(np.linspace(0.1, B_thetas[0], N_B_thetas+1),
                         np.linspace(0, np.pi, N_thetas+1))

    # atleast_1d: plt.subplots returns a bare Axes when n_ax == 1.
    for i, ax in enumerate(np.atleast_1d(axs).ravel()):
        # (N_B_thetas, N_thetas) -> (N_thetas, N_B_thetas) to match the edge grid.
        a = np.swapaxes(probas[i].reshape(N_B_thetas, N_thetas), 0, 1)

        pc = ax.pcolormesh(AZ, EL, a,
                           cmap="viridis", edgecolors='k', linewidth=.5,
                           antialiased=True,
                           vmin = vmin,
                           vmax = vmax)

        ax.set_thetamin(0)
        ax.set_thetamax(180)

        ax.spines['polar'].set_visible(False)
        ax.grid(False)

        if i == 0 :
            ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
            ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
            ax.tick_params(labelsize=14)

            # A single horizontal colorbar shared by all panels.
            cax = fig.add_axes([(1/4)+.05, 0, (1/2), 0.025])
            cticks = np.linspace(vmin, vmax, 6)

            cb = fig.colorbar(
                pc, cax=cax, orientation='horizontal', ticks=cticks)
            cb.ax.tick_params(labelsize=14)
            cb.ax.set_xticklabels(np.round(cticks, 2))
            cax.set_xlabel(r'Mean pred proba', labelpad=5, fontsize=14)

        else :
            ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_title('%.2f s' % (timesteps[gen_evolution[i][0]]))

    fig.suptitle(title, fontsize = 14)
    fig.tight_layout()
    return fig, axs

In [None]:
# Shared decoder and K-fold splitter; both are read as globals by pred_proba_time_gen.
# KFold is built without a shuffle argument, so folds are contiguous trial blocks.
logreg = LogisticRegression(**opts_LR)
kf = KFold(n_splits = n_splits)

In [None]:
# Loading data
# data is indexed downstream as data[time_bin, trial, feature]; labels holds one
# class per trial and le is the label encoder returned by the loader.
data, labels, le = par_load_temporal_data(timesteps = timesteps, target_btheta = None,
                                        target_theta = None, data_type = 'all_t_bt',
                                        cluster_list = cluster_list)

In [None]:
def plot_ani_proba(a, ax, title,
                  vmin, vmax, init = False) :
    """Draw one frame of the animated probability fan on the polar axes ``ax``.

    Parameters
    ----------
    a : ndarray of 96 floats
        Probabilities laid out as 8 bandwidth rows x 12 orientation columns.
    ax : polar Axes
    title : float
        Time (s) shown in the axes title.
    vmin, vmax : float
        Colour limits of the diverging map, centred on chance (1/96).
    init : bool
        Unused; kept for call compatibility.

    A horizontal colorbar is added to the global ``fig`` on every call.
    """
    # (8, 12) -> (12, 8): rows = orientations (angle), columns = bandwidths (radius).
    a = np.swapaxes(a.reshape(8,12), 0, 1)

    # Flat-shaded pcolormesh needs one more edge than cells along each axis:
    # 13 angular edges for 12 orientations and 9 radial edges for 8 bandwidths.
    abins = np.linspace(0.1, B_thetas[0], a.shape[1] + 1)
    rbins = np.linspace(0, np.pi, a.shape[0] + 1)
    A, R = np.meshgrid(abins, rbins)

    pc = ax.pcolormesh(R, A, a,
                       cmap="RdBu_r", edgecolors='k', linewidth=.5,
                       antialiased=True,
                       norm = mcols.TwoSlopeNorm(vmin = vmin,
                                                 vmax = vmax,
                                                 vcenter = 1/96))
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)

    ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
    ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
    ax.tick_params(labelsize=14)

    cax = fig.add_axes([.62, .25, .3, 0.025])
    cticks = np.linspace(vmin, vmax, 4)

    cb = fig.colorbar(
        pc, cax=cax, orientation='horizontal', ticks=cticks)
    cb.ax.tick_params(labelsize=14)
    cb.ax.set_xticklabels(np.round(cticks, 3))
    cax.set_xlabel(r'Pred. proba.', labelpad=5, fontsize=14)

    ax.set_yticklabels([])
    ax.set_title(r'$t =%.2f $s' % title, fontsize = 15)
    
def init():
    """FuncAnimation init: draw the accuracy map on ``ax`` and the first fan on ``ax_p``.

    Reads globals kf_accs, ax, ax_p, fig, gen_evolution, xticks, xticklabs,
    avg_across_tests_gen, timesteps, tests and win_size.
    """
    # Plot the transpo map on the left
    
    
    # Accuracy averaged over the last axis of kf_accs.
    matacc = np.mean(kf_accs, axis = -1)
    im = ax.imshow(matacc, origin = 'lower', interpolation = 'None', vmin = 1/96)
    # Dashed identity line plus two white reference segments.
    ax.plot([0,60],[0,60], c = 'w', linestyle = '--', alpha = .8, linewidth = 1) 
    ax.plot([10,10],[40,10], c = 'w', linestyle = '-', alpha = .8, linewidth = 1) 
    ax.plot([10.4,40],[10,10], c = 'w', linestyle = '-', alpha = .8, linewidth = 1)
    # Black segment: the range of generalization times animated at the fixed training time.
    ax.plot([gen_evolution[0][0], gen_evolution[-1][0]],
           [gen_evolution[0][1], gen_evolution[0][1]], color = 'k', linewidth = 1)
    
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabs)
    ax.set_yticks(xticks)
    ax.set_yticklabels(xticklabs)
    ax.set_ylabel('Training time (s)', fontsize = 14)
    ax.set_xlabel('Generalization time (s)', fontsize = 14)
    ax.tick_params(axis='both', which='major', labelsize=14)
    
    # Colorbar ticks run from chance (1/96) to the best accuracy.
    cbar = fig.colorbar(im, ax = ax,fraction=0.046, pad=0.04)
    cbar.set_ticks(np.linspace(1/96, np.max(matacc), 5))
    cbar.set_ticklabels(np.round(np.linspace(1/96, np.max(matacc), 5), 3) )
    cbar.ax.tick_params(labelsize = 12)
    cbar.ax.set_ylabel('Classification accuracy', rotation = 270, labelpad = 20,
              fontsize = 14)
    
    
    # Plot the fan on the right
    # Colour limits span all frames so the scale stays fixed during the video.
    plot_ani_proba(a = avg_across_tests_gen[0],
                  ax = ax_p, title = timesteps[tests[0]]+win_size,
                  vmin = avg_across_tests_gen.min(),
                  vmax = avg_across_tests_gen.max())
    return ax, ax_p

def animate(i) :
    """FuncAnimation frame ``i``: move the marker on the map and redraw the fan."""
    # Scatter on the left
    # sc holds the previous frame's marker; remove it before drawing a new one.
    global sc
    if i>0 : sc.remove()
    # Marker at (generalization time of frame i, fixed training time).
    sc = ax.scatter(gen_evolution[i][0], gen_evolution[0][1],
              marker = '|', c = 'k')
    
    # Update the fan
    plot_ani_proba(a = avg_across_tests_gen[i],
                  ax = ax_p, title = timesteps[tests[i]]+win_size,
                  vmin = avg_across_tests_gen.min(),
                  vmax = avg_across_tests_gen.max())
    fig.tight_layout()
    return ax, ax_p

In [None]:
# Toggle: the second assignment wins; comment it out to skip rendering.
do_video = False
do_video = True


if do_video :
    # One video per training bin: 10, 15, ..., 55.
    for traintime in np.arange(10, 60, 5) :
        tests = np.arange(10, 60, 1) # evolution along either axis
        gen_evolution = np.asarray([(x, traintime) for x in tests]) # Generalization on the X axis

        proba_preds_gen = pred_proba_time_gen(gen_evolution)
        # Mean over folds, then over test trials -> one 96-vector per test bin.
        avg_across_folds_gen = np.mean(proba_preds_gen, axis = 1)
        avg_across_tests_gen = np.mean(avg_across_folds_gen, axis = 1)

        # fig, ax and ax_p are globals read by init() and animate().
        fig = plt.figure(figsize = (12,6))
        ax = fig.add_axes([.1, .1, .4, 1.],polar = False)
        ax_p = fig.add_axes([.6, 0, .35, 1.],polar = True)

        anim = animation.FuncAnimation(fig, animate, init_func = init,
                                      frames = np.arange(len(tests)),
                                      interval = 150)
        # NOTE(review): the filename only encodes the training bin; later cells
        # that save to the same pattern will overwrite these files.
        anim.save(filename = './output/video_%s.mp4'%gen_evolution[0][1],
                  writer = 'ffmpeg')

# Static

In [None]:
# Time-bin indices: ys are training bins, xs are generalization (test) bins.
ys, xs = [20, 35, 50], [20, 35, 50]

In [None]:
def plot_static_proba(a, ax, title,
                  vmin, vmax, init = False, do_title = False,
                     do_bar = False, do_label = False) :
    """Draw a static polar fan of centred predicted probabilities on ``ax``.

    Parameters
    ----------
    a : ndarray of 96 floats
        Probabilities laid out as 8 bandwidth rows x 12 orientation columns.
    ax : polar Axes
    title : float
        Time (s) used in the title when ``do_title`` is True.
    vmin, vmax : float
        Colour limits of the diverging map, centred on chance (1/96).
    init : bool
        Unused; kept for call compatibility.
    do_title, do_bar, do_label : bool
        Draw the title, a horizontal colorbar on the global ``fig``, and the
        angular tick labels, respectively.
    """
    # (8, 12) -> (12, 8): rows = orientations (angle), columns = bandwidths (radius).
    a = np.swapaxes(a.reshape(8,12), 0, 1)

    # Flat-shaded pcolormesh needs one more edge than cells along each axis:
    # 13 angular edges for 12 orientations and 9 radial edges for 8 bandwidths.
    abins = np.linspace(0.1, B_thetas[0], a.shape[1] + 1)
    rbins = np.linspace(0, np.pi, a.shape[0] + 1)
    A, R = np.meshgrid(abins, rbins)

    pc = ax.pcolormesh(R, A, a,
                       cmap="RdBu_r", edgecolors='k', linewidth=.5,
                       antialiased=True,
                       norm = mcols.TwoSlopeNorm(vmin = vmin,
                                                 vmax = vmax,
                                                 vcenter = 1/96))

    if do_label :
        ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
        ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
        ax.tick_params(labelsize=14)
    else :
        ax.set_xticks([])
    ax.set_yticklabels([])
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)

    if do_bar :
        cax = fig.add_axes([.62, .25, .3, 0.025])
        cticks = np.linspace(vmin, vmax, 4)

        cb = fig.colorbar(
            pc, cax=cax, orientation='horizontal', ticks=cticks)
        cb.ax.tick_params(labelsize=14)
        cb.ax.set_xticklabels(np.round(cticks, 2))
        cax.set_xlabel(r'Pred. proba.', labelpad=5, fontsize=14)

    if do_title :
        ax.set_title(r'$t =%.2f $s' % title, fontsize = 15)

In [None]:
# Toggle: the second assignment wins.
do_static = False
do_static = True
if do_static :
    # One row of three fans per training bin, latest training bin first.
    for i0, y in enumerate(ys[::-1]) : #train times
        print(y)
        tests = xs #generalization times
        gen_evolution = np.asarray([(x, y) for x in tests]) # Generalization on the X axis

        proba_preds_gen = pred_proba_time_gen(gen_evolution)
        # Mean over folds, then over test trials -> one 96-vector per test bin.
        avg_across_folds_gen = np.mean(proba_preds_gen, axis = 1)
        avg_across_tests_gen = np.mean(avg_across_folds_gen, axis = 1)

        fig, axs = plt.subplots(figsize = (9, 3), ncols = 3, nrows = 1,
                               subplot_kw = dict(projection = 'polar', aspect = 'equal'))
        for i1, ax in enumerate(axs):
            # Titles on the first row only; colorbar and angle labels on the first panel only.
            plot_static_proba(a = avg_across_tests_gen[i1],
                             ax = ax,
                             title =  timesteps[xs[i1]]+win_size,
                             do_title = True if i0 == 0 else False,
                             vmin = 0.0, vmax = .03,
                             do_bar = True if i1 == 0 and i0 == 0 else False,
                             do_label = True if i0 == 0 and i1 == 0 else False)
        plt.show()

# Static, with multiple B_thetas

In [None]:
def pred_proba_time_gen(evo_idxs) :
    """K-fold temporal generalization of centred probabilities, split by bandwidth.

    For each fold of the global ``kf``, the global ``logreg`` is trained on
    time bin ``evo_idxs[0][1]`` and ``predict_proba`` is evaluated on the
    held-out trials of every test bin in ``evo_idxs[:, 0]``. Each trial's
    probability vector is rolled, within each 12-wide bandwidth row, so its
    true orientation sits on column 5. Trials are then grouped by their
    bandwidth row. Assumes class index = ``btheta_row * 12 + theta`` --
    NOTE(review): confirm against the label encoder.

    Returns
    -------
    object ndarray of shape (n_tests, n_splits, 3)
        Each cell holds an (n_trials_in_group, 96) array. Slot 2 is bandwidth
        row 0 (labels 0-11), slot 1 is row 3 (labels 36-47) and slot 0 is
        row 7 (labels 84-95).
    """
    n_thetas, center = 12, 5
    # Bandwidth row (label // n_thetas) -> slot in the returned array.
    slot_of_row = {0: 2, 3: 1, 7: 0}
    ibin_train = evo_idxs[0][1]

    proba_preds = np.zeros((len(evo_idxs), n_splits, 3), dtype = object)
    for i_kf, (train_index, test_index) in tqdm(enumerate(kf.split(data[ibin_train,:,:])), desc = 'Kfolding',
                                               total = n_splits) :
        xtrain, ytrain = data[ibin_train, train_index,:], labels[train_index]
        logreg.fit(xtrain, ytrain)

        for itest, ibin_test in enumerate(evo_idxs[:,0]) :
            xtest, ytest = data[ibin_test, test_index,:], labels[test_index]
            probas = logreg.predict_proba(xtest)

            grouped = {slot: [] for slot in slot_of_row.values()}
            # Use the true label of each *test* trial (ytest), not labels[i].
            for lab, proba in zip(ytest, probas) :
                slot = slot_of_row.get(lab // n_thetas)
                if slot is None :
                    continue
                # Roll inside each bandwidth row so mass never wraps into the next row.
                rows = proba.reshape(-1, n_thetas)
                grouped[slot].append(np.roll(rows, -lab % n_thetas + center, axis=1).ravel())

            for slot, rolled in grouped.items() :
                proba_preds[itest, i_kf, slot] = np.asarray(rolled)

    return proba_preds

In [None]:
def plot_static_proba(a, ax, title, title_bt,
                  vmin, vmax, init = False, do_title = False,
                     do_bar = False, do_label = False) :
    """Draw a static polar fan of centred predicted probabilities on ``ax``.

    Parameters
    ----------
    a : ndarray of 96 floats
        Probabilities laid out as 8 bandwidth rows x 12 orientation columns.
    ax : polar Axes
    title : float
        Time (s) used in the title when ``do_title`` is True.
    title_bt : float
        Bandwidth value shown next to the time in the title.
    vmin, vmax : float
        Colour limits of the diverging map, centred on chance (1/96).
    init : bool
        Unused; kept for call compatibility.
    do_title, do_bar, do_label : bool
        Draw the title, a horizontal colorbar on the global ``fig``, and the
        angular tick labels, respectively.
    """
    # (8, 12) -> (12, 8): rows = orientations (angle), columns = bandwidths (radius).
    a = np.swapaxes(a.reshape(8,12), 0, 1)

    # Flat-shaded pcolormesh needs one more edge than cells along each axis:
    # 13 angular edges for 12 orientations and 9 radial edges for 8 bandwidths.
    abins = np.linspace(0.1, B_thetas[0], a.shape[1] + 1)
    rbins = np.linspace(0, np.pi, a.shape[0] + 1)
    A, R = np.meshgrid(abins, rbins)

    pc = ax.pcolormesh(R, A, a,
                       cmap="RdBu_r", edgecolors='k', linewidth=.5,
                       antialiased=True,
                       norm = mcols.TwoSlopeNorm(vmin = vmin,
                                                 vmax = vmax,
                                                 vcenter = 1/96))

    if do_label :
        ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
        ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
        ax.tick_params(labelsize=14)
    else :
        ax.set_xticks([])
    ax.set_yticklabels([])
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)

    if do_bar :
        cax = fig.add_axes([.62, .25, .3, 0.025])
        cticks = np.linspace(vmin, vmax, 4)

        cb = fig.colorbar(
            pc, cax=cax, orientation='horizontal', ticks=cticks)
        cb.ax.tick_params(labelsize=14)
        cb.ax.set_xticklabels(np.round(cticks, 2))
        cax.set_xlabel(r'Pred. proba.', labelpad=5, fontsize=14)

    if do_title :
        ax.set_title(r'$t =%.2f $s, bt = %.2f' % (title, title_bt), fontsize = 15)

In [None]:
# Toggle: the second assignment wins.
do_static = False
do_static = True
if do_static :
    # One figure per (training bin, bandwidth group), latest training bin first.
    for i0, y in enumerate(ys[::-1]) : # training times
        print(y)
        tests = xs # generalization times
        gen_evolution = np.asarray([(x, y) for x in tests]) # Generalization on the X axis

        proba_preds_gen = pred_proba_time_gen(gen_evolution)

        for i1 in range(3) : # bandwidth groups
            proba_bthetas = proba_preds_gen[:,:,i1]

            fig, axs = plt.subplots(figsize = (9, 3), ncols = 3, nrows = 1,
                                   subplot_kw = dict(projection = 'polar', aspect = 'equal'))

            for i2, ax in enumerate(axs): # one panel per generalization time
                # Pool this group's trials over all folds, then average them.
                # This equals the mean of per-fold means when folds are the same
                # size, and it still works when folds hold different trial counts.
                a2 = np.mean(np.concatenate(list(proba_bthetas[i2, :]), axis = 0), axis = 0)
                # Colorbar and angle labels are drawn once, on the very first
                # panel. Keying them on i1 put three overlapping colorbars on
                # the first figure.
                first_panel = i0 == 0 and i1 == 0 and i2 == 0

                plot_static_proba(a = a2,
                                 ax = ax,
                                 title =  timesteps[xs[i2]]+win_size,
                                 do_title = i0 == 0,
                                 title_bt = (B_thetas[[-1, 4, 0]]*180/np.pi)[i1],
                                 vmin = 0.0, vmax = .03,
                                 do_bar = first_panel,
                                 do_label = first_panel)

        plt.show()

# Animated, with multiple B_thetas

In [None]:
# Recreate the shared decoder and K-fold splitter read as globals by pred_proba_time_gen.
logreg = LogisticRegression(**opts_LR)
kf = KFold(n_splits = n_splits)

In [None]:
# Loading data
# Reload data (indexed as data[time_bin, trial, feature]), labels and the label encoder.
data, labels, le = par_load_temporal_data(timesteps = timesteps, target_btheta = None,
                                        target_theta = None, data_type = 'all_t_bt',
                                        cluster_list = cluster_list)

In [None]:
def pred_proba_time_gen(evo_idxs) :
    """K-fold temporal generalization of centred probabilities, split by bandwidth.

    For each fold of the global ``kf``, the global ``logreg`` is trained on
    time bin ``evo_idxs[0][1]`` and ``predict_proba`` is evaluated on the
    held-out trials of every test bin in ``evo_idxs[:, 0]``. Each trial's
    probability vector is rolled, within each 12-wide bandwidth row, so its
    true orientation sits on column 5. Trials are then grouped by their
    bandwidth row. Assumes class index = ``btheta_row * 12 + theta`` --
    NOTE(review): confirm against the label encoder.

    Returns
    -------
    object ndarray of shape (n_tests, n_splits, 3)
        Each cell holds an (n_trials_in_group, 96) array. Slot 2 is bandwidth
        row 0 (labels 0-11), slot 1 is row 3 (labels 36-47) and slot 0 is
        row 7 (labels 84-95).
    """
    n_thetas, center = 12, 5
    # Bandwidth row (label // n_thetas) -> slot in the returned array.
    slot_of_row = {0: 2, 3: 1, 7: 0}
    ibin_train = evo_idxs[0][1]

    proba_preds = np.zeros((len(evo_idxs), n_splits, 3), dtype = object)
    for i_kf, (train_index, test_index) in tqdm(enumerate(kf.split(data[ibin_train,:,:])), desc = 'Kfolding',
                                               total = n_splits) :
        xtrain, ytrain = data[ibin_train, train_index,:], labels[train_index]
        logreg.fit(xtrain, ytrain)

        for itest, ibin_test in enumerate(evo_idxs[:,0]) :
            xtest, ytest = data[ibin_test, test_index,:], labels[test_index]
            probas = logreg.predict_proba(xtest)

            grouped = {slot: [] for slot in slot_of_row.values()}
            # Use the true label of each *test* trial (ytest), not labels[i].
            for lab, proba in zip(ytest, probas) :
                slot = slot_of_row.get(lab // n_thetas)
                if slot is None :
                    continue
                # Roll inside each bandwidth row so mass never wraps into the next row.
                rows = proba.reshape(-1, n_thetas)
                grouped[slot].append(np.roll(rows, -lab % n_thetas + center, axis=1).ravel())

            for slot, rolled in grouped.items() :
                proba_preds[itest, i_kf, slot] = np.asarray(rolled)

    return proba_preds

In [None]:
def plot_ani_proba(a, ax, title, title_bt,
                  vmin, vmax, init = False,
                  do_title = False, do_bar = False, n_bar = 0,
                  do_ticks  = False) :
    """Draw one animated fan (one bandwidth group) on the polar axes ``ax``.

    Parameters
    ----------
    a : ndarray of 96 floats
        Probabilities laid out as 8 bandwidth rows x 12 orientation columns.
    ax : polar Axes
    title : float
        Time (s) shown in the title when ``do_title`` is True.
    title_bt : float
        Bandwidth value (degrees) shown in the title.
    vmin, vmax : float
        Colour limits of the diverging map, centred on chance (1/96).
    init : bool
        Unused; kept for call compatibility.
    do_title, do_ticks : bool
        Show the time in the title / show angular tick labels.
    do_bar : bool
        Add a horizontal colorbar to the global ``fig``.
    n_bar : int
        Horizontal slot of that colorbar (x = 0.4 + 0.2 * n_bar).
    """
    # (8, 12) -> (12, 8): rows = orientations (angle), columns = bandwidths (radius).
    a = np.swapaxes(a.reshape(8,12), 0, 1)

    # Flat-shaded pcolormesh needs one more edge than cells along each axis:
    # 13 angular edges for 12 orientations and 9 radial edges for 8 bandwidths.
    abins = np.linspace(0.1, B_thetas[0], a.shape[1] + 1)
    rbins = np.linspace(0, np.pi, a.shape[0] + 1)
    A, R = np.meshgrid(abins, rbins)

    pc = ax.pcolormesh(R, A, a,
                       cmap="RdBu_r", edgecolors='k', linewidth=.5,
                       antialiased=True,
                       norm = mcols.TwoSlopeNorm(vmin = vmin,
                                                 vmax = vmax,
                                                 vcenter = 1/96))
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)
    ax.set_yticklabels([])

    if do_ticks :
        ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
        ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
        ax.tick_params(labelsize=14)
    else :
        ax.set_xticks([])

    if do_bar :
        cax = fig.add_axes([.4 + (.2*n_bar), 0.3, .15, 0.025])
        cticks = np.linspace(vmin, vmax, 4)
        cb = fig.colorbar(
            pc, cax=cax, orientation='horizontal', ticks=cticks)
        cb.ax.tick_params(labelsize=14)
        cb.ax.set_xticklabels(np.round(cticks, 3))
        cax.set_xlabel(r'Pred. proba.', labelpad=5, fontsize=14)

    # The bandwidth is in degrees in both title variants.
    if do_title  :
        ax.set_title(r'$t =%.2fs ; B_\theta = %.2f°$' % (title, title_bt), fontsize = 15)
    else :
        ax.set_title(r'$B_\theta = %.2f°$' % title_bt, fontsize = 15)
    
def init():
    """FuncAnimation init: draw the accuracy map on ``ax`` and the first frame of the three fans.

    Reads globals kf_accs, ax, ax_p0/ax_p1/ax_p2, fig, gen_evolution, xticks,
    xticklabs, array_movie (indexed [frame, group, class]), B_thetas,
    timesteps, tests and win_size.
    """
    # Plot the transpo map on the left
    matacc = np.mean(kf_accs, axis = -1)
    im = ax.imshow(matacc, origin = 'lower', interpolation = 'None', vmin = 1/96)
    # Dashed identity line plus two white reference segments.
    ax.plot([0,60],[0,60], c = 'w', linestyle = '--', alpha = .8, linewidth = 1) 
    ax.plot([10,10],[40,10], c = 'w', linestyle = '-', alpha = .8, linewidth = 1) 
    ax.plot([10.4,40],[10,10], c = 'w', linestyle = '-', alpha = .8, linewidth = 1)
    # Black segment: generalization times animated at the fixed training time.
    ax.plot([gen_evolution[0][0], gen_evolution[-1][0]],
           [gen_evolution[0][1], gen_evolution[0][1]], color = 'k', linewidth = 1)
    
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabs)
    ax.set_yticks(xticks)
    ax.set_yticklabels(xticklabs)
    ax.set_ylabel('Training time (s)', fontsize = 14)
    ax.set_xlabel('Generalization time (s)', fontsize = 14)
    ax.tick_params(axis='both', which='major', labelsize=14)
    
    # Colorbar ticks run from chance (1/96) to the best accuracy.
    cbar = fig.colorbar(im, ax = ax,fraction=0.046, pad=0.04)
    cbar.set_ticks(np.linspace(1/96, np.max(matacc), 5))
    cbar.set_ticklabels(np.round(np.linspace(1/96, np.max(matacc), 5), 3) )
    cbar.ax.tick_params(labelsize = 12)
    cbar.ax.set_ylabel('Classification accuracy', rotation = 270, labelpad = 20,
              fontsize = 13)
    
    
    # Plot the fans on the right
    # One fan per bandwidth group; colour limits span all frames of that group,
    # and colorbars are drawn here only (animate passes do_bar=False).
    plot_ani_proba(a = array_movie[0,0,:],
                  ax = ax_p0, title = timesteps[tests[0]]+win_size,
                  vmin = array_movie[:,0,:].min(),
                  vmax = array_movie[:,0,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[0],
                  do_title = True, do_bar = True, do_ticks = True, n_bar = 0)
    plot_ani_proba(a = array_movie[0,1,:],
                  ax = ax_p1, title = timesteps[tests[0]]+win_size,
                  vmin = array_movie[:,1,:].min(),
                  vmax = array_movie[:,1,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[1],
                  do_title = False, do_bar = True, do_ticks = False, n_bar = 1)
    plot_ani_proba(a = array_movie[0,2,:],
                  ax = ax_p2, title = timesteps[tests[0]]+win_size,
                  vmin = array_movie[:,2,:].min(),
                  vmax = array_movie[:,2,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[2],
                  do_title = False, do_bar = True, do_ticks = False, n_bar = 2)
    return ax, ax_p0, ax_p1, ax_p2

def animate(i) :
    """FuncAnimation frame ``i``: move the marker on the map and redraw the three fans."""
    # Scatter on the left
    # sc holds the previous frame's marker; remove it before drawing a new one.
    global sc
    if i>0 : sc.remove()
    # Marker at (generalization time of frame i, fixed training time).
    sc = ax.scatter(gen_evolution[i][0], gen_evolution[0][1],
              marker = '|', c = 'k')
    
    # Update the fan
    plot_ani_proba(a = array_movie[i,0,:],
                  ax = ax_p0, title = timesteps[tests[i]]+win_size,
                  vmin = array_movie[:,0,:].min(),
                  vmax = array_movie[:,0,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[0],
                  do_title = True, do_bar = False, do_ticks = True, n_bar = 0)
    plot_ani_proba(a = array_movie[i,1,:],
                  ax = ax_p1, title = timesteps[tests[i]]+win_size,
                  vmin = array_movie[:,1,:].min(),
                  vmax = array_movie[:,1,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[1],
                  do_title = False, do_bar = False, do_ticks = False, n_bar = 1)
    plot_ani_proba(a = array_movie[i,2,:],
                  ax = ax_p2, title = timesteps[tests[i]]+win_size,
                  vmin = array_movie[:,2,:].min(),
                  vmax = array_movie[:,2,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[2],
                  do_title = False, do_bar = False, do_ticks = False, n_bar= 2)
    fig.tight_layout()
    return ax, ax_p0, ax_p1, ax_p2

In [None]:
# Toggle: the second assignment wins.
do_video = False
do_video = True
if do_video :
    # One video per training bin: 10, 15, ..., 55.
    for traintime in np.arange(10, 60, 5) :
        tests = np.arange(10, 60, 1) # evolution along either axis
        gen_evolution = np.asarray([(x, traintime) for x in tests]) # Generalization on the X axis

        proba_preds_gen = pred_proba_time_gen(gen_evolution)
        
        # array_movie[test, group, class]: mean centred probabilities per frame.
        array_movie = np.zeros((len(tests), 3, 96)) #test, bt, proba
        for i1 in range(3) :
            proba_bthetas = proba_preds_gen[:,:,i1]
            for i2 in range(len(tests)) :
                proba_t = proba_bthetas[i2,:]
                # Mean over folds (elementwise over the object array), then over trials.
                # NOTE(review): this requires every fold to hold the same number of
                # trials in the group; otherwise the object-array mean fails.
                a = np.mean(proba_t, axis = 0)
                a2 = np.mean(a, axis = 0)
                array_movie[i2,i1,:] = a2
            


        # fig, ax and ax_p0..ax_p2 are globals read by init() and animate().
        fig = plt.figure(figsize = (12,6))
        ax = fig.add_axes([.1, .1, .2, 1.],polar = False)
        ax_p0 = fig.add_axes([.4, 0, .15, 1.],polar = True)
        ax_p1 = fig.add_axes([.6, 0, .15, 1.],polar = True)
        ax_p2 = fig.add_axes([.8, 0, .15, 1.],polar = True)

        anim = animation.FuncAnimation(fig, animate, init_func = init,
                                      frames = np.arange(len(tests)),
                                      interval = 150)
        # NOTE(review): same filename pattern as the single-fan video cell, so
        # these files overwrite that cell's output for the same training bin.
        anim.save(filename = './output/video_%s.mp4'%gen_evolution[0][1],
                  writer = 'ffmpeg')

# Moving along the diagonal (training time = testing time)

In [None]:
def pred_proba_time_gen(evo_idxs) :
    """Diagonal decoding: train and test at the same time bin, per fold.

    For every fold of the global ``kf`` and every time bin ``ibin`` in
    ``evo_idxs[:, 0]``, the global ``logreg`` is fitted on the training trials
    of ``ibin``, and ``predict_proba`` is evaluated on the held-out trials of
    that same bin. Previously the inner test loop ran for every training bin
    and overwrote earlier results, so only the last training bin was kept.
    Each probability vector is rolled within its 12-wide bandwidth rows so the
    true orientation sits on column 5. Trials are then grouped by bandwidth
    row. Assumes class index = ``btheta_row * 12 + theta`` -- NOTE(review):
    confirm against the label encoder.

    Returns
    -------
    object ndarray of shape (n_times, n_splits, 3)
        Each cell holds an (n_trials_in_group, 96) array. Slot 2 is bandwidth
        row 0 (labels 0-11), slot 1 is row 3 (labels 36-47) and slot 0 is
        row 7 (labels 84-95).
    """
    n_thetas, center = 12, 5
    # Bandwidth row (label // n_thetas) -> slot in the returned array.
    slot_of_row = {0: 2, 3: 1, 7: 0}

    proba_preds = np.zeros((len(evo_idxs), n_splits, 3), dtype = object)
    for i_kf, (train_index, test_index) in tqdm(enumerate(kf.split(data[0,:,:])), desc = 'Kfolding',
                                               total = n_splits) :

        for itime, ibin in enumerate(evo_idxs[:,0]) :
            xtrain, ytrain = data[ibin, train_index,:], labels[train_index]
            logreg.fit(xtrain, ytrain)

            xtest, ytest = data[ibin, test_index,:], labels[test_index]
            probas = logreg.predict_proba(xtest)

            grouped = {slot: [] for slot in slot_of_row.values()}
            # Use the true label of each *test* trial (ytest), not labels[i].
            for lab, proba in zip(ytest, probas) :
                slot = slot_of_row.get(lab // n_thetas)
                if slot is None :
                    continue
                # Roll inside each bandwidth row so mass never wraps into the next row.
                rows = proba.reshape(-1, n_thetas)
                grouped[slot].append(np.roll(rows, -lab % n_thetas + center, axis=1).ravel())

            for slot, rolled in grouped.items() :
                proba_preds[itime, i_kf, slot] = np.asarray(rolled)

    return proba_preds

In [None]:
def plot_ani_proba(a, ax, title, title_bt,
                  vmin, vmax, init = False,
                  do_title = False, do_bar = False, n_bar = 0,
                  do_ticks  = False) :
    """Draw one animated fan (one bandwidth group) on the polar axes ``ax``.

    Parameters
    ----------
    a : ndarray of 96 floats
        Probabilities laid out as 8 bandwidth rows x 12 orientation columns.
    ax : polar Axes
    title : float
        Time (s) shown in the title when ``do_title`` is True.
    title_bt : float
        Bandwidth value (degrees) shown in the title.
    vmin, vmax : float
        Colour limits of the diverging map, centred on chance (1/96).
    init : bool
        Unused; kept for call compatibility.
    do_title, do_ticks : bool
        Show the time in the title / show angular tick labels.
    do_bar : bool
        Add a horizontal colorbar to the global ``fig``.
    n_bar : int
        Horizontal slot of that colorbar (x = 0.4 + 0.2 * n_bar).
    """
    # (8, 12) -> (12, 8): rows = orientations (angle), columns = bandwidths (radius).
    a = np.swapaxes(a.reshape(8,12), 0, 1)

    # Flat-shaded pcolormesh needs one more edge than cells along each axis:
    # 13 angular edges for 12 orientations and 9 radial edges for 8 bandwidths.
    abins = np.linspace(0.1, B_thetas[0], a.shape[1] + 1)
    rbins = np.linspace(0, np.pi, a.shape[0] + 1)
    A, R = np.meshgrid(abins, rbins)

    pc = ax.pcolormesh(R, A, a,
                       cmap="RdBu_r", edgecolors='k', linewidth=.5,
                       antialiased=True,
                       norm = mcols.TwoSlopeNorm(vmin = vmin,
                                                 vmax = vmax,
                                                 vcenter = 1/96))
    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.spines['polar'].set_visible(False)
    ax.grid(False)
    ax.set_yticklabels([])

    if do_ticks :
        ax.set_xticks(np.radians([0, 45, 90, 135, 180]))
        ax.set_xticklabels(['+90°', '+45°', '0°', '-45°', '-90°'])
        ax.tick_params(labelsize=14)
    else :
        ax.set_xticks([])

    if do_bar :
        cax = fig.add_axes([.4 + (.2*n_bar), 0.3, .15, 0.025])
        cticks = np.linspace(vmin, vmax, 4)
        cb = fig.colorbar(
            pc, cax=cax, orientation='horizontal', ticks=cticks)
        cb.ax.tick_params(labelsize=14)
        cb.ax.set_xticklabels(np.round(cticks, 3))
        cax.set_xlabel(r'Pred. proba.', labelpad=5, fontsize=14)

    if do_title  :
        ax.set_title(r'$t =%.2fs ; B_\theta = %.2f°$' % (title, title_bt), fontsize = 15)
    else :
        ax.set_title(r'$B_\theta = %.2f°$' % title_bt, fontsize = 15)
    
def init():
    """FuncAnimation init for the diagonal video: accuracy map plus the first frame of the three fans.

    Reads globals kf_accs, ax, ax_p0/ax_p1/ax_p2, fig, xticks, xticklabs,
    array_movie (indexed [frame, group, class]), B_thetas, timesteps, tests
    and win_size.
    """
    # Plot the transpo map on the left
    matacc = np.mean(kf_accs, axis = -1)
    im = ax.imshow(matacc, origin = 'lower', interpolation = 'None', vmin = 1/96)
    # Dashed identity line (the path followed by this video) plus two white reference segments.
    ax.plot([0,60],[0,60], c = 'w', linestyle = '--', alpha = .8, linewidth = 1) 
    ax.plot([10,10],[40,10], c = 'w', linestyle = '-', alpha = .8, linewidth = 1) 
    ax.plot([10.4,40],[10,10], c = 'w', linestyle = '-', alpha = .8, linewidth = 1)
    
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabs)
    ax.set_yticks(xticks)
    ax.set_yticklabels(xticklabs)
    ax.set_ylabel('Training time (s)', fontsize = 14)
    ax.set_xlabel('Generalization time (s)', fontsize = 14)
    ax.tick_params(axis='both', which='major', labelsize=14)
    
    # Colorbar ticks run from chance (1/96) to the best accuracy.
    cbar = fig.colorbar(im, ax = ax,fraction=0.046, pad=0.04)
    cbar.set_ticks(np.linspace(1/96, np.max(matacc), 5))
    cbar.set_ticklabels(np.round(np.linspace(1/96, np.max(matacc), 5), 3) )
    cbar.ax.tick_params(labelsize = 12)
    cbar.ax.set_ylabel('Classification accuracy', rotation = 270, labelpad = 20,
              fontsize = 13)
    
    
    # Plot the fans on the right
    # One fan per bandwidth group; colour limits span all frames of that group.
    plot_ani_proba(a = array_movie[0,0,:],
                  ax = ax_p0, title = timesteps[tests[0]]+win_size,
                  vmin = array_movie[:,0,:].min(),
                  vmax = array_movie[:,0,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[0],
                  do_title = True, do_bar = True, do_ticks = True, n_bar = 0)
    plot_ani_proba(a = array_movie[0,1,:],
                  ax = ax_p1, title = timesteps[tests[0]]+win_size,
                  vmin = array_movie[:,1,:].min(),
                  vmax = array_movie[:,1,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[1],
                  do_title = False, do_bar = True, do_ticks = False, n_bar = 1)
    plot_ani_proba(a = array_movie[0,2,:],
                  ax = ax_p2, title = timesteps[tests[0]]+win_size,
                  vmin = array_movie[:,2,:].min(),
                  vmax = array_movie[:,2,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[2],
                  do_title = False, do_bar = True, do_ticks = False, n_bar = 2)
    return ax, ax_p0, ax_p1, ax_p2

def animate(i) :
    """FuncAnimation frame ``i`` for the diagonal video: move the marker and redraw the fans."""
    # Scatter on the left
    # sc holds the previous frame's marker; remove it before drawing a new one.
    global sc
    if i>0 : sc.remove()
    # Marker on the diagonal: x and y are both the time bin of frame i.
    sc = ax.scatter(gen_evolution[i][0], gen_evolution[i][0],
              marker = 'x', c = 'k')
    
    # Update the fan
    plot_ani_proba(a = array_movie[i,0,:],
                  ax = ax_p0, title = timesteps[tests[i]]+win_size,
                  vmin = array_movie[:,0,:].min(),
                  vmax = array_movie[:,0,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[0],
                  do_title = True, do_bar = False, do_ticks = True, n_bar = 0)
    plot_ani_proba(a = array_movie[i,1,:],
                  ax = ax_p1, title = timesteps[tests[i]]+win_size,
                  vmin = array_movie[:,1,:].min(),
                  vmax = array_movie[:,1,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[1],
                  do_title = False, do_bar = False, do_ticks = False, n_bar = 1)
    plot_ani_proba(a = array_movie[i,2,:],
                  ax = ax_p2, title = timesteps[tests[i]]+win_size,
                  vmin = array_movie[:,2,:].min(),
                  vmax = array_movie[:,2,:].max(),
                   title_bt = (B_thetas[::-1][[0,4,-1]]*180/np.pi)[2],
                  do_title = False, do_bar = False, do_ticks = False, n_bar = 2)
    fig.tight_layout()
    return ax, ax_p0, ax_p1, ax_p2

In [None]:
# Toggle: the second assignment wins.
do_video = False
do_video = True
if do_video :

    # Time bins 10..59; the second column (0) is ignored by the diagonal decoder.
    tests = np.arange(10, 60, 1) # evolution along either axis
    gen_evolution = np.asarray([(x, 0) for x in tests]) # Generalization on the X axis

    proba_preds_gen = pred_proba_time_gen(gen_evolution)

    # array_movie[time, group, class]: mean centred probabilities per frame.
    array_movie = np.zeros((len(tests), 3, 96)) #test, bt, proba
    for i1 in range(3) :
        proba_bthetas = proba_preds_gen[:,:,i1]
        for i2 in range(len(tests)) :
            proba_t = proba_bthetas[i2,:]
            # Mean over folds (elementwise over the object array), then over trials.
            # NOTE(review): requires equal per-fold trial counts in each group.
            a = np.mean(proba_t, axis = 0)
            a2 = np.mean(a, axis = 0)
            array_movie[i2,i1,:] = a2


In [None]:
if do_video :
    # fig, ax and ax_p0..ax_p2 are globals read by init() and animate().
    fig = plt.figure(figsize = (12,6))
    ax = fig.add_axes([.1, .1, .2, 1.],polar = False)
    ax_p0 = fig.add_axes([.4, 0, .15, 1.],polar = True)
    ax_p1 = fig.add_axes([.6, 0, .15, 1.],polar = True)
    ax_p2 = fig.add_axes([.8, 0, .15, 1.],polar = True)

    # One frame per time bin in tests, 150 ms apart.
    anim = animation.FuncAnimation(fig, animate, init_func = init,
                                  frames = np.arange(len(tests)),
                                  interval = 150)
    anim.save(filename = './output/video_diag.mp4',
              writer = 'ffmpeg')

# Without k-folding (single train/test split)

In [None]:
def pred_proba_time_gen(evo_idxs) :
    """Decode class probabilities bin by bin and re-centre them on the true class.

    For each row of ``evo_idxs``, the notebook-global ``logreg`` is refitted on
    a stratified train split of ``data[ibin_train]``. It is then applied to the
    held-out split of that same bin (``random_state=42`` keeps the split
    reproducible). Each test trial's probability vector is rolled with
    ``np.roll(proba, -(lab % 12) + 5)``. This moves the true class ``lab`` to
    index ``12 * (lab // 12) + 5``, so trials of one group can be averaged with
    their true class aligned.

    Parameters
    ----------
    evo_idxs : ndarray, shape (n_steps, 2)
        Only column 0 (the time bin used for both training and testing) is
        read. Column 1 is currently ignored.

    Returns
    -------
    ndarray of dtype object, shape (n_steps, 3)
        ``proba_preds[step, g]`` is a (n_trials_in_group, n_classes) array of
        rolled probability vectors, one row per test trial whose label falls in
        group ``g``. The groups are labels 0-11, 36-47 and >= 84 respectively.
    """
    proba_preds = np.zeros((len(evo_idxs), 3), dtype = object)
    for itest, ibin_train in tqdm(enumerate(evo_idxs[:,0]), total = len(evo_idxs), desc='Decoding') :
        xtrain, xtest, ytrain, ytest = train_test_split(data[ibin_train,:,:], labels, test_size =test_size, random_state = 42,
                                                       stratify = labels)
        logreg.fit(xtrain, ytrain)
        probas = logreg.predict_proba(xtest)

        rolled = ([], [], [])
        # Pair each probability row with its OWN test label. train_test_split
        # shuffles, so the former `labels[i_proba]` was an unrelated trial's label.
        for lab, proba in zip(ytest, probas):
            # Labels are assumed to encode 12 * bt_index + theta_index (the roll
            # below relies on lab % 12). TODO confirm against the label encoder.
            # Bounds are half-open: the former `<= 12` / `<= 48` also pulled in
            # the first label of the following 12-label group.
            if lab < 12:
                igroup = 0
            elif 36 <= lab < 48:
                # NOTE(review): plot titles use B_thetas[::-1][[0, 4, -1]];
                # confirm the middle group is labels 36-47 rather than 48-59.
                igroup = 1
            elif lab >= 84:
                igroup = 2
            else:
                continue
            # -(lab % 12) + 5, NOT `-lab%12+5`. The latter parses as
            # ((-lab) % 12) + 5 and misplaces the true class by 12 positions
            # whenever lab % 12 != 0.
            rolled[igroup].append(np.roll(proba, -(lab % 12) + 5))

        for igroup, rows in enumerate(rolled):
            proba_preds[itest, igroup] = np.asarray(rows)

    return proba_preds

In [None]:
do_video = False
do_video = True

if do_video:
    # Time-bin indices 10..59, one animation frame per bin; each row is (bin, 0).
    tests = np.arange(10, 60, 1)  # evolution along either axis
    gen_evolution = np.asarray([(t, 0) for t in tests])  # Generalization on the X axis

    proba_preds_gen = pred_proba_time_gen(gen_evolution)

    # array_movie[frame, group, :] holds the trial-averaged rolled probability
    # vector of that frame for that B_theta group.
    array_movie = np.zeros((len(tests), 3, 96))  # test, bt, proba
    for i1 in range(3):
        proba_bthetas = proba_preds_gen[:, i1]
        for i2, group_trials in enumerate(proba_bthetas):
            array_movie[i2, i1, :] = np.mean(group_trials, axis=0)

In [None]:
if do_video:
    # Layout: one Cartesian panel on the left, then three polar panels side
    # by side (one per B_theta group plotted by the frame function).
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_axes([.1, .1, .2, 1.], polar=False)
    ax_p0, ax_p1, ax_p2 = (fig.add_axes([left, 0, .15, 1.], polar=True)
                           for left in (.4, .6, .8))

    # One frame per tested bin, 150 ms between frames, encoded with ffmpeg.
    # NOTE(review): the kfolded video cell earlier in the notebook saves to
    # the same path; running this cell overwrites that video.
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=np.arange(len(tests)), interval=150)
    anim.save(filename='./output/video_diag.mp4', writer='ffmpeg')