In [120]:
import numpy as np
from sklearn.linear_model import LogisticRegression
from pathlib import Path
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view
import matplotlib.pyplot as plt
from neuropy import plotting
from scipy import stats


# Directory that holds one CSV of trial-by-trial data per file to analyse.
basepath = Path("D:\\Data")
# files = ["gronckle.csv", "grump.csv"]
files = sorted(basepath.glob("*.csv"))

fig = plotting.Fig(6, 3, size=(12, 5), num=1)

npast = 10  # number of preceding trials used as regressors
params_pooled = []  # one (3, npast) coefficient array per file
# NOTE: despite the name, this collects the absolute Pearson correlation
# between the two reward-probability columns of each file (floats, not bools).
task_type_bool = []

# Plot settings shared by every per-file panel (hoisted: they do not change
# between iterations).
colors = ["orange", "purple", "blue"]
titles = ["Reward Seeking", "Choice Preservation", "Main effect of Outcome"]
lags = np.arange(1, npast + 1)  # x-axis: 1 = most recent past trial
lag_ticks = sorted({1, max(1, npast // 2), npast})  # [1, 5, 10] for npast=10

for i, file in enumerate(files):
    # `basepath / file` also works when `files` is a list of bare file names
    # (see the commented-out alternative above); for absolute glob results the
    # join simply returns `file`.
    csv_path = basepath / file
    data_df = pd.read_csv(csv_path)
    prob_corr = np.abs(
        stats.pearsonr(data_df["rewprobfull1"], data_df["rewprobfull2"])[0]
    )

    # Files whose two reward probabilities are only weakly correlated are
    # labelled "unstructured"; everything else is "structured".
    task_type = "unstructured" if prob_corr < 0.2 else "structured"
    task_type_bool.append(prob_corr)

    # Recode to +1/-1: port 2 -> -1, reward 0 -> -1. np.where builds new
    # arrays instead of writing into the result of to_numpy(), which may be a
    # view of data_df (and is read-only under pandas Copy-on-Write).
    port = data_df["port"].to_numpy()
    choices = np.where(port == 2, -1, port)
    reward = data_df["reward"].to_numpy()
    outcomes = np.where(reward == 0, -1, reward)
    n_trials = choices.size

    # Row j holds trials j .. j+npast-1 (oldest first); it is used to predict
    # trial j+npast. The last window has no following trial, so it is dropped.
    past_choices = sliding_window_view(choices, npast)[:-1, :]
    past_outcomes = sliding_window_view(outcomes, npast)[:-1, :]
    actual_choices = choices[npast:]

    # Design matrix: [choice x outcome | choice | outcome], npast columns each.
    x = np.hstack(
        (
            past_choices * past_outcomes,
            past_choices,
            past_outcomes,
        )
    )
    clf = LogisticRegression(random_state=0).fit(x, actual_choices)

    # One row per regressor family; flip columns so index 0 is the most
    # recent past trial.
    params = np.fliplr(clf.coef_.squeeze().reshape(3, npast))
    params_pooled.append(params)

    subfig = fig.add_subfigure(fig.gs[i])
    subfig.suptitle(f"{csv_path.name[:-4]}, {task_type}")
    sub_axs = subfig.subplots(1, 3, width_ratios=[1, 1, 1], sharey=True, sharex=True)

    # `k` rather than `_`: in IPython `_` is the last-output variable.
    for k, ax in enumerate(sub_axs):
        ax.plot(lags, params[k], ".-", color=colors[k], zorder=1)
        ax.set_title(titles[k])
        ax.axhline(0, color="gray", zorder=0, lw=0.8)
        ax.set_xticks(lag_ticks)

    if i == 0:
        sub_axs[0].set_xlabel("Trials in the past")
        sub_axs[0].set_ylabel("Influence on current choice")

# Pool the per-file coefficients and average them within each task type.
task_type_bool = np.array(task_type_bool)  # per-file |correlation| values
params_pooled = np.array(params_pooled)  # (n_files, 3, n_lags)

# Same rule as the per-file titles: |r| < 0.2 is "unstructured", everything
# else is "structured". Using one mask and its complement keeps the two groups
# consistent with that labelling and ensures no file is left out of both.
is_unstructured = task_type_bool < 0.2
mean_unstruc = params_pooled[is_unstructured].mean(axis=0)
mean_struc = params_pooled[~is_unstructured].mean(axis=0)

subfig = fig.add_subfigure(fig.gs[4:, 0:2])
subfig.suptitle("Mean across animals by task type")
sub_axs = subfig.subplots(1, 3, width_ratios=[1, 1, 1], sharey=True, sharex=True)

colors = ["orange", "purple", "blue"]
titles = ["Reward Seeking", "Choice Preservation", "Main effect of Outcome"]

# Derive the lag axis from the data instead of assuming 10 past trials.
n_lags = params_pooled.shape[-1]
lags = np.arange(1, n_lags + 1)
lag_ticks = sorted({1, max(1, n_lags // 2), n_lags})

for k, ax in enumerate(sub_axs):
    # Solid line: structured files; faded line: unstructured files.
    ax.plot(lags, mean_struc[k], ".-", color=colors[k], zorder=1, label="structured")
    ax.plot(
        lags,
        mean_unstruc[k],
        ".-",
        color=colors[k],
        alpha=0.5,
        zorder=1,
        label="unstructured",
    )
    ax.set_title(titles[k])
    ax.axhline(0, color="gray", zorder=0, lw=0.8)
    ax.set_xticks(lag_ticks)

sub_axs[0].legend()

In [114]:
# Sanity check: averaging the pooled coefficients over files should leave one
# (regressor family, lag) array.
np.mean(params_pooled, axis=0).shape

(3, 10)

In [None]:
# Coefficients of the most recently fitted model, one row per regressor family
# (choice x outcome, choice, outcome). -1 lets NumPy infer the number of lags
# instead of hard-coding 10.
clf.coef_.reshape(3, -1)

array([[ 1.42566825e-01,  1.36219690e-01,  1.06658575e-01,
         1.64560054e-01,  1.54191088e-01,  1.82923636e-01,
         3.25072185e-01,  4.06793141e-01,  2.84806566e-01,
         1.22912753e+00],
       [-2.78163829e-04,  4.49484669e-03, -1.65199260e-02,
        -2.08843506e-02,  1.11621426e-02,  6.72401248e-03,
         3.04537672e-04, -2.29415835e-02, -5.01471543e-02,
        -6.26810326e-02],
       [ 2.94850012e-02,  2.89742085e-03,  3.21116694e-02,
         5.38890822e-02,  6.67156585e-02,  9.83085270e-02,
         1.05871632e-01,  1.37174049e-01,  2.22335994e-01,
         4.92329831e-01]])

In [66]:
# Shape of the fitted coefficient array of the most recent model.
np.shape(clf.coef_)

(1, 30)

In [64]:
# Quick look at the flat coefficient vector of the most recently fitted model.
# The figure handle gets a real name: binding `_` would shadow IPython's
# last-output variable.
coef_fig, ax = plt.subplots()

ax.plot(clf.coef_.squeeze())
ax.set_xlabel("Coefficient index")
ax.set_ylabel("Logistic regression weight");

[<matplotlib.lines.Line2D at 0x1d517801450>]