## Compare Binomial & Multinomial Logistic Regression

> Written by Jess Breda, September 2023, following a lab meeting

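A question that came up (from Carlos) was whether the multi-class model would do better on L/R trials than a binomial model trained on L/R trials only, if the multi-class model were trained on L, R, and V (violation) trials. The goal of this notebook is to implement this comparison.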

**Initial Steps**:

[ ] working with simulated data, figure out the dimensions of the multi-class cost

* follow up questions here if needed

[ ] working with simulated data, create a binomial class for fitting (from previous code)

[ ] validate that the binomial class finds Athena/Nick-like results with the base regressors

* start with a single animal, then expand

[ ] see what the `prev_violation` regressor does for the binomial model

[ ] figure out how to make a train/test split for the models

* follow up on whether a different number of training trials for each model might be an issue

* probably don't need to make this perfect now, but long term it's good to think about and to have this information easily stored

[ ] determine what the null model comparison would be (if any)
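For the binomial item above, here is a minimal sketch of an L-vs-R logistic regression that mirrors the multi-class conventions used later in this notebook: bias column first in `X`, and a Gaussian prior with standard deviation `sigma` on the non-bias weights only. This is not the "previous code" mentioned in the checklist; the class name `BinomialLogisticRegressionSketch` and the coding of `y` (1 = R, 0 = L) are hypothetical choices. It also runs `scipy.optimize.check_grad` as a sanity check that the analytic gradient matches the cost.

```python
import numpy as np
from scipy.optimize import check_grad, minimize
from scipy.special import expit


class BinomialLogisticRegressionSketch:
    """Hypothetical L-vs-R logistic regression; X has a bias column first."""

    def __init__(self, sigma=None, method="BFGS"):
        self.sigma = sigma  # std of Gaussian prior on non-bias weights (None = no prior)
        self.method = method
        self.w = None

    def cost(self, w, X, y):
        """Negative log-likelihood (+ L2 penalty); y is 1 for R, 0 for L (assumed)."""
        logits = X @ w
        # log(1 + exp(z)) computed stably via logaddexp(0, z)
        nll = np.sum(np.logaddexp(0, logits) - y * logits)
        if self.sigma:
            nll += np.sum(w[1:] ** 2) / (2 * self.sigma**2)  # bias is not penalized
        return nll

    def gradient(self, w, X, y):
        grad = X.T @ (expit(X @ w) - y)
        if self.sigma:
            penalty = w / self.sigma**2
            penalty[0] = 0.0  # bias is not penalized
            grad = grad + penalty
        return grad

    def fit(self, X, y):
        result = minimize(
            fun=self.cost,
            x0=np.zeros(X.shape[1]),
            args=(X, y),
            method=self.method,
            jac=self.gradient,
        )
        self.w = result.x
        return self.w

    def predict_proba(self, X):
        """Probability of the class coded as 1."""
        return expit(X @ self.w)


# Simulated L/R data (violations already excluded)
rng = np.random.default_rng(0)
N, D = 2000, 3
X_sim = np.c_[np.ones(N), rng.normal(size=(N, D))]
true_w = np.array([0.5, 1.0, -2.0, 0.0])
y_sim = rng.binomial(1, expit(X_sim @ true_w))

binom = BinomialLogisticRegressionSketch(sigma=None)
w_hat = binom.fit(X_sim, y_sim)
print(np.round(w_hat, 2))  # should be close to true_w

# Finite-difference check of the analytic gradient (with the prior switched on)
reg = BinomialLogisticRegressionSketch(sigma=1.0)
w_test = np.full(D + 1, 0.1)
grad_err = check_grad(reg.cost, reg.gradient, w_test, X_sim, y_sim)
grad_norm = np.linalg.norm(reg.gradient(w_test, X_sim, y_sim))
print(grad_err, grad_norm)
```

Using `np.logaddexp` and `expit` keeps the cost and gradient numerically stable for large logits, playing the same role as the max-shifted log-sum-exp in the multi-class class.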


In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import pathlib
import sys
from scipy.optimize import minimize

# add each subfolder of ../src to the import path
for folder in pathlib.Path("../src/").iterdir():
    if folder.is_dir():
        sys.path.append(str(folder))
from get_rat_data import *

### Dimensions of multi-class cost

In the test evaluation of the cost function, I want to squash out the third (violation) dimension so the multi-class model isn't penalized for probability it assigns to violations when it is scored on L/R trials. E.g., [1/3, 1/3, 1/3] should become [1/2, 1/2] (or [1/2, 1/2, 0], not sure yet). Need to focus on figuring out the cost code.
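One way to implement the [1/2, 1/2] option: on L/R test trials, drop the violation column of the predicted probabilities and renormalize the remaining two so they sum to 1. This sketch assumes the column order [L, R, V] from the cost docstring below; the function names `renormalize_lr` and `lr_nll` are hypothetical. A useful consequence is that the renormalized left probability, exp(a_L) / (exp(a_L) + exp(a_R)), equals the sigmoid of the logit difference a_L - a_R, which is the same functional form a binomial model predicts, so the two models end up on a comparable scale.

```python
import numpy as np
from scipy.special import expit, softmax


def renormalize_lr(P):
    """Drop the violation column (assumed index 2) and renormalize L/R to sum to 1."""
    P_lr = P[:, :2]
    return P_lr / P_lr.sum(axis=1, keepdims=True)


def lr_nll(P, Y):
    """NLL on L/R trials only, using renormalized probabilities.

    P : (N, 3) predicted probabilities; Y : (N, 3) one-hot labels [L, R, V].
    """
    lr_trials = Y[:, 2] == 0  # keep trials where the choice was L or R
    P_lr = renormalize_lr(P[lr_trials])
    return -np.sum(Y[lr_trials, :2] * np.log(P_lr))


print(renormalize_lr(np.array([[1 / 3, 1 / 3, 1 / 3]])))  # → [[0.5 0.5]]

# Renormalized softmax over (L, R) equals a sigmoid of the logit difference
logits_demo = np.array([[0.2, -1.0, 0.7]])
P_demo = softmax(logits_demo, axis=1)
print(np.isclose(renormalize_lr(P_demo)[0, 0], expit(logits_demo[0, 0] - logits_demo[0, 1])))
```

The [1/2, 1/2, 0] variant would give the same L/R NLL, since the zeroed violation column only multiplies the zero violation entries of L/R trials.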

In [7]:
class MultiClassLogisticRegressionComp:
    def __init__(self, sigma=None, method="BFGS", disp=True):
        self.W = None
        self.sigma = sigma
        self.method = method
        self.disp = disp
        self.stored_fits = []

    def fit(self, X: pd.DataFrame, Y: np.ndarray):
        N, D_w_bias = X.shape
        _, C = Y.shape
        initial_W_flat = np.zeros(D_w_bias * C)

        result = minimize(
            fun=self.cost,
            x0=initial_W_flat,
            args=(X.to_numpy(), Y, self.sigma),
            method=self.method,
            jac=self._gradient,
            options={"disp": self.disp},
        )

        self.W = result.x.reshape(D_w_bias, C)
        return self.W

    def eval(self, X: pd.DataFrame, Y: np.ndarray):
        return self.cost(self.W, X.to_numpy(), Y, sigma=None)

    def cost(self, W, X, Y, sigma):
        """
        Compute the negative log-likelihood for multi-class
        logistic regression with L2 regularization (or MAP).

        params
        ------
        W : np.ndarray, shape (D + 1, C) or flattened ((D + 1) * C,)
            weight matrix; in flattened form when called by the
            minimize() function
        X : np.ndarray, shape (N, D + 1)
            design matrix with bias column
        Y : np.ndarray, shape (N, C), where C = 3
            one-hot encoded choice labels for each trial as left,
            right or violation
        sigma : float (default=None)
            standard deviation of Gaussian prior, if None no
            regularization is applied

        returns
        -------
        - nll : float
            negative log-likelihood
        """
        if len(W.shape) == 1:
            W = W.reshape(X.shape[1], Y.shape[1])

        logits = X @ W
        penalty = (
            (1 / (2 * (sigma**2))) * np.trace(W[1:, :].T @ W[1:, :]) if sigma else 0
        )
        nll = (-np.sum(Y * logits) + np.sum(self.log_sum_exp(logits))) + penalty
        return nll

    def _gradient(self, W, X, Y, sigma):
        """
        Compute the gradient of the negative log-likelihood for
        multi-class logistic regression with L2 regularization (or MAP).

        params
        ------
        W : np.ndarray, shape (D + 1, C) or flattened ((D + 1) * C,)
            weight matrix; in flattened form when called by the
            minimize() function
        X : np.ndarray, shape (N, D + 1)
            design matrix with bias column
        Y : np.ndarray, shape (N, C), where C = 3
            one-hot encoded choice labels for each trial as left,
            right or violation
        sigma : float (default=None)
            standard deviation of Gaussian prior, if None no
            regularization is applied

        returns
        -------
        gradient : np.ndarray, shape ((D + 1) * C,)
            flattened gradient of the negative log-likelihood

        """
        if len(W.shape) == 1:
            W = W.reshape(X.shape[1], Y.shape[1])

        logits = X @ W
        P = self.stable_softmax(logits)  # method is named stable_softmax (no underscore)

        if sigma:
            penalty_gradient = W / (sigma**2)
        else:
            penalty_gradient = np.zeros_like(W)

        penalty_gradient[0, :] = 0  # No penalty for bias

        gradient = X.T @ (P - Y) + penalty_gradient
        return gradient.flatten()

    @staticmethod
    def log_sum_exp(logits):
        max_logits = np.max(logits, axis=1, keepdims=True)
        return (
            np.log(np.sum(np.exp(logits - max_logits), axis=1, keepdims=True))
            + max_logits
        )

    @staticmethod
    def stable_softmax(logits):
        exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        sum_exp = np.sum(exp_logits, axis=1, keepdims=True)
        return exp_logits / sum_exp
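For reference, the quantities computed by `cost` and `_gradient` above can be written out as follows. Here $A = XW$ is the $N \times C$ logit matrix with entries $a_{nc}$, $P$ is the row-wise softmax of $A$, $Y$ holds the one-hot labels $y_{nc}$, and $W_{1:}$ is $W$ without its bias row. This is transcribed from the code, not from a separate source; `log_sum_exp` evaluates the middle term as $m_n + \log\sum_c e^{a_{nc} - m_n}$ with $m_n = \max_c a_{nc}$ for numerical stability.

$$
\mathrm{NLL}(W) = -\sum_{n=1}^{N}\sum_{c=1}^{C} y_{nc}\,a_{nc} + \sum_{n=1}^{N}\log\sum_{c=1}^{C} e^{a_{nc}} + \frac{1}{2\sigma^2}\operatorname{tr}\!\left(W_{1:}^\top W_{1:}\right)
$$

$$
\nabla_W \mathrm{NLL} = X^\top (P - Y) + \frac{1}{\sigma^2}\begin{bmatrix}\mathbf{0}^\top \\ W_{1:}\end{bmatrix}
$$

With $X$ of shape $(N, D+1)$ and $W$ of shape $(D+1, C)$, the gradient has shape $(D+1, C)$ and is flattened to $((D+1)C,)$ for `minimize()`. When `sigma` is `None`, the penalty terms are dropped.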

In [15]:
model = MultiClassLogisticRegressionComp(sigma=1)

In [16]:
N = 1000  # Number of samples
D = 4  # Number of features
C = 3  # Number of classes

# Generate random feature values
X = np.random.normal(size=(N, D))
X_with_bias = np.c_[np.ones(N), X]  # bias column

# Generate random true weights (including the bias coefficient)
true_W = np.random.normal(loc=0, scale=1, size=(D + 1, C))

# Generate one-hot encoded target labels by sampling from the softmax (multinomial logistic) probabilities
A = X_with_bias @ true_W
P = model.stable_softmax(A)
Y = np.array([np.random.multinomial(1, n) for n in P])

print(f"Generated {N} samples with {D} features and {C} classes")
print(f"W is {true_W.shape} \nX is {X_with_bias.shape} \nY is {Y.shape}")
print(f"W has mean {np.mean(true_W):.3f} and std {np.std(true_W):.3f}")

Generated 1000 samples with 4 features and 3 classes
W is (5, 3) 
X is (1000, 5) 
Y is (1000, 3)
W has mean -0.046 and std 1.146
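For the train/test split item, a minimal sketch on simulated labels like the ones above: draw one shared random split of trial indices, train the multi-class model on all training trials (L, R, and V), train the binomial model on only the L/R trials of that same training set, and score both on the same L/R test trials. The helper name `split_indices`, the 80/20 split, and the assumption that column 2 of `Y` is the violation class are illustrative choices. The printed counts make the difference in training-set size between the two models explicit, which is the issue flagged in the checklist.

```python
import numpy as np


def split_indices(n_trials, test_frac=0.2, seed=0):
    """Hypothetical helper: shuffle trial indices and split into train/test."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_trials)
    n_test = int(round(test_frac * n_trials))
    return idx[n_test:], idx[:n_test]


# Simulated one-hot labels with columns [L, R, V] (assumed order)
rng = np.random.default_rng(1)
Y_sim = rng.multinomial(1, [0.45, 0.45, 0.10], size=1000)

train_idx, test_idx = split_indices(len(Y_sim))

# Multi-class model: all training trials
multi_train = train_idx
# Binomial model: only the L/R trials from the same training set
binom_train = train_idx[Y_sim[train_idx, 2] == 0]
# Both models are evaluated on the same L/R test trials
lr_test = test_idx[Y_sim[test_idx, 2] == 0]

print(len(multi_train), len(binom_train), len(lr_test))
```

Storing `train_idx` and `test_idx` (or the seed) alongside each fit would be one way to keep this information easily retrievable later.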


In [18]:
P[0]

array([0.36529901, 0.04899202, 0.58570898])