# Figure 4
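Code for plotting Figure 4. The FORCE and NEF data are generated by the notebooks `FORCE oscillations.ipynb` and `NEF oscillations.ipynb`; the efficient-coding curves are computed here from the encoders stored for Figure 2 (`../generatedData/fig2.h5`).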

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
import seaborn
import pandas as pd
%matplotlib inline

In [None]:
# HDF5 file read by calcCorrelations; keys have the form
# "/<framework>/<frequency>/<variant>/<K|phi>" (see calcCorrelations).
filename = "../generatedData/oscillations.h5"

In [None]:
def calcCorrelations(framework, path=None):
    """Compute encoder/decoder and weight-matrix correlations for a framework.

    For each frequency f in 1..25 the matrices ``K`` (called eta here) and
    ``phi`` are read for three variants stored under
    ``/<framework>/<f>/<variant>/``: "original", "insideManifold" and
    "outsideManifold". Weights are formed as ``W = eta @ phi.T`` and the
    perturbed weights are compared with the original ones.

    Parameters
    ----------
    framework : str
        Top-level HDF5 group (e.g. "forceLong", "nef"). The special value
        "efficient" is delegated to ``calcCorrelationsEfficientCoding``,
        which builds its matrices instead of reading them from ``path``.
    path : str, optional
        HDF5 file to read. Defaults to the module-level ``filename``.

    Returns
    -------
    pandas.DataFrame
        One row per frequency with columns
        ``frequency``, ``etaVsPhiT`` (Pearson r between the entries of eta
        and phi), ``wPermCols`` (r between the original W and the
        insideManifold W) and ``wPermRows`` (r between the original W and
        the outsideManifold W).
    """
    if framework == "efficient":
        return calcCorrelationsEfficientCoding()
    if path is None:
        path = filename

    def key(f, variant, matrix):
        # Same key layout as "/%s/%d/original/K" etc.
        return "/%s/%d/%s/%s" % (framework, f, variant, matrix)

    res = []
    # Open the file once for all frequencies instead of letting
    # pd.read_hdf reopen it for each of the six reads per frequency.
    with pd.HDFStore(path, mode="r") as store:
        for f in range(1, 26):
            originalEta = pd.read_hdf(store, key(f, "original", "K")).values
            originalPhi = pd.read_hdf(store, key(f, "original", "phi")).values

            columnsEta = pd.read_hdf(store, key(f, "insideManifold", "K")).values
            columnsPhi = pd.read_hdf(store, key(f, "insideManifold", "phi")).values

            rowsEta = pd.read_hdf(store, key(f, "outsideManifold", "K")).values
            rowsPhi = pd.read_hdf(store, key(f, "outsideManifold", "phi")).values

            etaPhiCorr = scipy.stats.pearsonr(originalEta.ravel(), originalPhi.ravel())[0]

            origW = np.dot(originalEta, originalPhi.T)
            colsW = np.dot(columnsEta, columnsPhi.T)
            rowsW = np.dot(rowsEta, rowsPhi.T)

            colsWcorr = scipy.stats.pearsonr(origW.ravel(), colsW.ravel())[0]
            rowsWcorr = scipy.stats.pearsonr(origW.ravel(), rowsW.ravel())[0]
            res.append((f, etaPhiCorr, colsWcorr, rowsWcorr))
    return pd.DataFrame(res, columns=["frequency", "etaVsPhiT", "wPermCols", "wPermRows"])

In [None]:
def calcCorrelationsEfficientCoding(encoders=None, frequencies=range(1, 26)):
    """Compute the correlation table for the efficient-coding network.

    Decoders are constructed from the encoders as
    ``phi = eta @ (0.01 * A + I)`` with the 2x2 rotation generator
    ``A = [[0, -omega], [omega, 0]]`` and ``omega = 2 * pi * f``.
    Two perturbed networks are built from the same construction:
    one with the two encoder columns swapped, and one with the first and
    second halves of the encoder rows swapped.

    Parameters
    ----------
    encoders : array-like of shape (N, 2), optional
        Encoder matrix eta. Defaults to the encoders saved for Figure 2
        ("../generatedData/fig2.h5", key "/ec/original/K").
    frequencies : iterable of int, optional
        Frequencies f to evaluate; defaults to 1..25.

    Returns
    -------
    pandas.DataFrame
        Columns ``frequency``, ``etaVsPhiT``, ``wPermCols``, ``wPermRows``,
        with the same meaning as in ``calcCorrelations``.
    """
    if encoders is None:
        # Reuse encoders from figure 2 (read once, not once per frequency).
        encoders = pd.read_hdf("../generatedData/fig2.h5", "/ec/original/K").values
    originalEta = np.asarray(encoders)

    # Frequency-independent perturbations of the encoders.
    columnsEta = originalEta[:, (1, 0)]
    nRows = originalEta.shape[0]
    half = nRows // 2
    # Swap the first and second halves of the rows; for 1000 rows this is
    # [500..999, 0..499], matching the previously hard-coded permutation.
    permRowIndex = np.hstack((np.arange(half, nRows), np.arange(half)))
    rowsEta = originalEta[permRowIndex, :]

    res = []
    for f in frequencies:
        omega = 2 * np.pi * f
        A = np.array([[0, -omega], [omega, 0]])
        M = 0.01 * A + np.eye(2)

        originalPhi = np.dot(originalEta, M)
        columnsPhi = np.dot(columnsEta, M)
        rowsPhi = np.dot(rowsEta, M)

        etaPhiCorr = scipy.stats.pearsonr(originalEta.ravel(), originalPhi.ravel())[0]

        origW = np.dot(originalEta, originalPhi.T)
        colsW = np.dot(columnsEta, columnsPhi.T)
        rowsW = np.dot(rowsEta, rowsPhi.T)

        colsWcorr = scipy.stats.pearsonr(origW.ravel(), colsW.ravel())[0]
        rowsWcorr = scipy.stats.pearsonr(origW.ravel(), rowsW.ravel())[0]
        res.append((f, etaPhiCorr, colsWcorr, rowsWcorr))
    return pd.DataFrame(res, columns=["frequency", "etaVsPhiT", "wPermCols", "wPermRows"])

In [None]:
# Correlation table per framework, keyed by the framework name used in the plots.
correlations = dict(
    (framework, calcCorrelations(framework))
    for framework in ("forceLong", "nef", "efficient")
)

In [None]:
# Global seaborn settings for the figure below: "paper" context, "white" style.
seaborn.set_context(context="paper")
seaborn.set_style(style="white")

In [None]:
# Two stacked panels sharing the frequency axis:
#   top    - correlation between the entries of eta (K) and phi;
#   bottom - correlation between the original weights and the weights of
#            the insideManifold network (solid; encoder columns permuted)
#            and the outsideManifold network (dotted; encoder rows permuted).
frameworks = ["forceLong", "nef", "efficient"]
legendLabels = ["FORCE", "NEF", "Efficient\ncoding"]
palette = seaborn.color_palette()
fig, axs = plt.subplots(2, 1, figsize=(2.5, 2.5), sharex=True)
for i, c in enumerate(frameworks):
    # Index by frequency in BOTH panels: the default RangeIndex starts at 0,
    # which would shift the top curve one Hz left of the bottom panel.
    byFreq = correlations[c].set_index("frequency")
    axs[0].plot(byFreq.etaVsPhiT, c=palette[i], label=legendLabels[i])
    axs[1].plot(byFreq.wPermCols, c=palette[i])
    axs[1].plot(byFreq.wPermRows, c=palette[i], ls="dotted")
axs[0].set_ylabel("Correlation")
axs[0].set_ylim(0, 1)
axs[0].legend(loc="upper right", frameon=True, ncol=1, labelspacing=1)
axs[1].set_xlabel("Frequency [Hz]")
axs[1].set_ylabel("Correlation")
axs[1].set_ylim(-1, 1)
axs[1].set_xlim(1, 25)
seaborn.despine()
plt.savefig("fig4_plots.svg", dpi=300)