In [None]:
# Linear versus quadratic model estimation
# hollander et al., 2021
# test whether decoding accuracy
# increases linearly or shows a clear peak at a specific cortical depth
# perform bayesian model comparison

# fit two hierarchical linear models using pymc
# NUTS sampler

# dependent variable (decoding accuracy) was first modeled as a linear function of cortical depth
# y_n = theta_n,0 + theta_n,1 * d_n

# and then also as a quadratic function of cortical depth
# y_n = theta_n,0 + theta_n,1 * d_n + theta_n,2 * d_n^2

# where y_n are all observations for subject n
# d is a corresponding vector of cortical depths
# theta_n is the parameter vector for subject n that was estimated

# the subject-wise parameters were modeled as coming from a Gaussian group distribution with mean mu and standard deviation sigma
# theta ~ N(mu, sigma)
# mu ~ N(0, 1)
# sigma ~ HalfCauchy(5) -> 5 subjects?

# (1) use the state-of-the-art Watanabe-Akaike information criterion to do Bayesian model comparison between the linear and quadratic models
# (2) estimate the posterior of the peak of the quadratic function y = theta_0 + theta_1 * d + theta_2 * d^2 using the vertex formula d_peak = -theta_1 / (2 * theta_2)

In [1]:
from pathlib import Path
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
from pymc.model_graph import model_to_graphviz
from src.config import SUBJECTS, SESSION, N_LAYER
from palettable.matplotlib import Inferno_3 as ColMap
plt.style.use(os.path.join(module_path, "src", "default.mplstyle"))


# Session label. Not referenced in the cells visible here
# (presumably passed to LinearModel.from_data -- confirm against later cells).
SESS = "VASO"
# Root directory of the decoding results. LinearModel.from_data reads
# <DIR_DATA>/<subject>/<session folder>/bandpass_none/accuracy.csv below it.
DIR_DATA = "/data/pt_01880/Experiment1_ODC/paper/decoding"



In [2]:
class LinearModel:
    """Pooled Bayesian linear regression ``y ~ Normal(b0 + b1 * x, y_sigma)``.

    The intercept ``b0`` and slope ``b1`` get wide Normal priors and the noise
    scale ``y_sigma`` a HalfCauchy prior. Sampling is done with ``pm.sample``.

    Parameters
    ----------
    x : numpy.ndarray
        Predictor values, one per observation (layer indices when built with
        `from_data`).
    y : numpy.ndarray
        Observed values, one per observation (accuracy in percent when built
        with `from_data`).
    group : numpy.ndarray
        Subject index per observation. It is only stored; the model defined in
        `init_model` does not use it.

    Attributes
    ----------
    model : pm.Model
        The PyMC model; populated by `init_model`.
    trace : InferenceData or None
        Sampling result, ``None`` until `sample` has run.
    init : bool
        True once `init_model` has registered the model variables.
    """

    def __init__(self, x, y, group):
        self.x = x
        self.y = y
        self.group = group
        self.model = pm.Model()
        self.trace = None
        self.init = False

    @property
    def mean_data(self):
        """Return ``(x_data, y_data)``: layer indices and the mean of ``y`` per layer.

        ``x_data`` is ``0 .. N_LAYER-1`` as floats; ``y_data[i]`` is the mean of
        all observations whose ``x`` equals ``i`` (NaN if there are none).
        """
        x_obs = np.asarray(self.x)
        y_obs = np.asarray(self.y)
        x_data = np.arange(N_LAYER, dtype=float)
        y_data = np.array(
            [np.mean(y_obs[x_obs == layer]) for layer in range(N_LAYER)],
            dtype=float,
        )
        return x_data, y_data

    def init_model(self):
        """Register priors and likelihood on ``self.model``.

        Calling it again is a no-op: re-registering the same variable names on
        one ``pm.Model`` would raise.
        """
        if self.init:
            return

        with self.model:
            b0 = pm.Normal("b0", mu=0, sigma=100)
            b1 = pm.Normal("b1", mu=0, sigma=100)

            # linear predictor
            yest = b0 + b1 * self.x

            # Normal likelihood with a HalfCauchy prior on the noise scale
            y_sigma = pm.HalfCauchy("y_sigma", beta=10)
            pm.Normal("likelihood", mu=yest, sigma=y_sigma, observed=self.y)

        self.init = True

    def sample(self, draws=10000, tune=500, chains=4, target_accept=0.95, random_seed=None):
        """Draw posterior samples and store them in ``self.trace``.

        Parameters
        ----------
        draws, tune, chains, target_accept
            Forwarded to ``pm.sample``; defaults reproduce the previous
            hard-coded settings.
        random_seed : int, sequence of int or None
            Forwarded to ``pm.sample``; set it to make a run reproducible.
        """
        self.init_model()
        with self.model:
            self.trace = pm.sample(
                draws=draws,
                tune=tune,
                chains=chains,
                target_accept=target_accept,
                random_seed=random_seed,
                progressbar=True,
                return_inferencedata=True,
                # pointwise log-likelihood is required by az.waic / az.loo
                idata_kwargs={'log_likelihood': True},
            )

    def plot_prior(self):
        """Plot regression lines implied by 500 prior draws over the layer means."""
        self.init_model()
        with self.model:
            prior_checks = pm.sample_prior_predictive(samples=500)
        b0 = prior_checks.prior.b0.to_numpy()[0]
        b1 = prior_checks.prior.b1.to_numpy()[0]
        self._plot_helper(b0, b1)

    def plot_fit(self):
        """Plot regression lines implied by posterior draws over the layer means.

        Raises
        ------
        ValueError
            If `sample` has not been run.
        """
        self._require_trace()
        posterior = self.trace.posterior
        # Pool all chains and draws. Averaging across chains (as done before)
        # mixes independent samples and shrinks the displayed spread.
        b0 = posterior.b0.to_numpy().reshape(-1)
        b1 = posterior.b1.to_numpy().reshape(-1)
        self._plot_helper(b0, b1)

    def plot_traces(self):
        """Show ArviZ trace plots with each variable's posterior mean marked.

        Raises
        ------
        ValueError
            If `sample` has not been run.
        """
        self._require_trace()
        summary = az.summary(self.trace, stat_funcs={"mean": np.mean}, extend=False)
        mean_lines = tuple((name, {}, row["mean"]) for name, row in summary.iterrows())
        axes = az.plot_trace(self.trace, lines=mean_lines)

        # write the numeric mean next to its marker in the left-hand panels
        for row_idx, mean_value in enumerate(summary["mean"].values):
            axes[row_idx, 0].annotate(
                f"{mean_value:.2f}",
                xy=(mean_value, 0),
                xycoords="data",
                xytext=(5, 10),
                textcoords="offset points",
                rotation=90,
                va="bottom",
                fontsize=16,
                color="C0",
            )

    def evaluate_model(self, method="waic"):
        """Print WAIC or LOO of the fitted model on the deviance scale.

        Parameters
        ----------
        method : {"waic", "loo"}

        Raises
        ------
        ValueError
            If `sample` has not been run or ``method`` is not recognised.
        """
        self._require_trace()
        if method == "waic":
            result = az.waic(self.trace, scale="deviance")
        elif method == "loo":
            result = az.loo(self.trace, scale="deviance")
        else:
            raise ValueError("Unknown method!")
        print(result)

    def visualize_model(self, filename="graphname", fmt="png"):
        """Render the model graph to ``<filename>.<fmt>`` via graphviz."""
        self.init_model()
        graph = model_to_graphviz(self.model)
        graph.render(filename, format=fmt)

    def _require_trace(self):
        """Raise ``ValueError`` unless `sample` has produced a trace."""
        if self.trace is None:
            raise ValueError("No sampling done!")

    @staticmethod
    def _thin_index(n_samples, max_lines):
        """Return at most ``max_lines`` evenly spaced indices into ``n_samples`` items.

        ``max_lines=None`` or ``max_lines >= n_samples`` keeps every index.
        """
        if max_lines is None or n_samples <= max_lines:
            return np.arange(n_samples)
        return np.round(np.linspace(0, n_samples - 1, max_lines)).astype(int)

    def _plot_helper(self, *args, max_lines=500):
        """Plot sampled lines, their mean line and the per-layer data means.

        Parameters
        ----------
        *args
            ``args[0]`` holds intercept samples and ``args[1]`` slope samples;
            both are flattened and must have the same number of elements.
        max_lines : int or None
            Upper bound on individually drawn sample lines (evenly thinned);
            the mean line always uses every sample.
        """
        x_data, y_data = self.mean_data
        intercepts = np.asarray(args[0]).reshape(-1)
        slopes = np.asarray(args[1]).reshape(-1)
        color = ColMap.hex_colors

        # span the same layer range as the data instead of a fixed 0..10
        x_max = x_data[-1] if x_data.size else 0.0
        _x = np.linspace(0, x_max, 1000)

        fig, ax = plt.subplots()
        for k in self._thin_index(len(slopes), max_lines):
            ax.plot(_x, slopes[k] * _x + intercepts[k], color="gray", alpha=0.05)
        ax.plot(_x, np.mean(slopes) * _x + np.mean(intercepts), color=color[1], linestyle="-")
        ax.plot(x_data, y_data, color="black", linestyle="--")
        ax.set_xlabel(r"GM/WM $\rightarrow$ GM/CSF")
        ax.set_ylabel("Accuracy in %")
        plt.show()

    @classmethod
    def from_data(cls, sess):
        """Build a model from the per-subject ``accuracy.csv`` files.

        For every subject in ``SUBJECTS`` and both days, the file
        ``DIR_DATA/<subj>/<sess><SESSION[subj][sess][day]>/bandpass_none/accuracy.csv``
        is read, averaged over columns and scaled by 100.

        Parameters
        ----------
        sess : str
            Session key used both in the folder name and in ``SESSION``.

        Raises
        ------
        ValueError
            If a file does not provide exactly ``N_LAYER`` rows; ``x`` and ``y``
            would otherwise end up with different lengths.
        """
        x = []
        y = []
        group = []
        for subj_idx, subj in enumerate(SUBJECTS):
            for day in (0, 1):
                path = Path(DIR_DATA) / subj / f"{sess}{SESSION[subj][sess][day]}"
                file = path / "bandpass_none" / "accuracy.csv"
                data = np.genfromtxt(file, delimiter=',')
                data = np.mean(data, axis=1)
                if len(data) != N_LAYER:
                    raise ValueError(
                        f"{file}: expected {N_LAYER} rows, got {len(data)}"
                    )

                x.extend(np.arange(N_LAYER))
                y.extend(data * 100)
                group.extend(np.full(len(data), subj_idx, dtype=np.int64))

        return cls(np.array(x), np.array(y), np.array(group, dtype=np.int64))

In [None]:
class CosModel(LinearModel):
    """Placeholder subclass of LinearModel; none of the overrides does anything yet.

    NOTE(review): `init_model` never sets ``self.init`` and registers no
    variables, so the inherited `sample` would run on an empty model --
    confirm this class is not used before it is implemented.
    """

    def init_model(self):
        """No-op: no priors or likelihood are registered."""
        return None

    def plot_prior(self):
        """No-op: nothing is sampled or plotted."""
        return None

    def plot_fit(self):
        """No-op: nothing is plotted."""
        return None