## 0. Preparation

### Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import statsmodels.api as sm
from sklearn.metrics import mean_squared_error as MSE

import sys
# NOTE(review): machine-specific absolute path. The two `models.*` imports below
# presumably resolve through this entry — confirm, and consider a configurable
# or repository-relative path so the notebook runs on other machines.
sys.path.append('/home/rachel/Documents/lfp2spikes/modularized_code') 
from models.fit_model import MODEL_FITTING
from models.testing import EVALUATION

### Data

In [2]:
# Seed NumPy's global generator so theta, and every later draw, is reproducible.
np.random.seed(42)

# Four ground-truth coefficients taken from the seeded generator.
theta = np.random.random(size=4)
print(f"Ground truth theta: \n{theta}")

Ground truth theta: 
[0.37454012 0.95071431 0.73199394 0.59865848]


In [3]:
# Design matrix: 1000 rows by 4 columns, one column per entry of theta.
X = np.random.random(size=(1000, 4))

# Response: exp(X @ theta) plus a random perturbation scaled by 1e-7.
y = 1e-7 * np.random.random(size=1000) + np.exp(X @ theta)

In [4]:
def check_model_pipeline(fitter, **fit_params):
    """Fit a model on the notebook's synthetic data and print diagnostics.

    Reads the notebook-level globals ``X``, ``y`` and ``theta``, so the data
    cells must have been executed first.

    Parameters
    ----------
    fitter
        Object exposing ``fit_model(X, y, **fit_params)`` (returning a
        3-tuple whose first two items are the model and the parameter
        vector), ``neg_log_lik(params, X, y)`` and
        ``predict_spike_rate(model, X, params)``.
    **fit_params
        Forwarded unchanged to ``fitter.fit_model``. If ``alpha`` is present
        it is shown in the printed header; otherwise ``None`` is shown.

    Returns
    -------
    The fitted coefficients ``params[1:]``, i.e. everything after the first
    entry, which this function reports as the intercept. Callers assign the
    result to ``theta_hat``; previously this function returned ``None``.
    """
    # MODEL FITTING
    model, params, _ = fitter.fit_model(X, y, **fit_params)
    nll = fitter.neg_log_lik(params, X, y)
    y_hat = fitter.predict_spike_rate(model, X, params)

    # First entry is reported as the intercept, the rest as the coefficients.
    intercept, theta_hat = params[0], params[1:]

    # EVALUATION
    # .get() avoids a KeyError after a successful fit when alpha is omitted.
    print(f"\nModel with regularization alpha {fit_params.get('alpha')}: ")
    print(f"  - Negative Log Likelihood: {nll}")
    print(f"  - theta: {theta_hat}")
    print(f"  - intercept: {intercept}")
    print(f"  - theta MSE: {MSE(theta, theta_hat)}")
    print(f"  - predction MSE: {MSE(y, y_hat)}")

    return theta_hat

## 1. Sklearn

In [5]:
# Fitter created with the "sklearn" option; run the pipeline check once per alpha.
fitter_sklearn = MODEL_FITTING("sklearn")
print(">> SKLEARN <<")
for alpha in (0.0, 0.3, 0.9):
    params = dict(alpha=alpha)
    theta_hat = check_model_pipeline(fitter_sklearn, **params)

>> SKLEARN <<

Model with regularization alpha 0.0: 
  - Negative Log Likelihood: 1.6019886259382192
  - theta: [0.37453926 0.95071613 0.73198913 0.59866142]
  - intercept: 5.432684976499334e-07
  - theta MSE: 8.972578303461632e-12
  - predction MSE: 5.2574857847723706e-11

Model with regularization alpha 0.3: 
  - Negative Log Likelihood: 1.674230474666313
  - theta: [0.19562545 0.49666452 0.39828789 0.3121517 ]
  - intercept: 0.6803977983537364
  - theta MSE: 0.10790438498585722
  - predction MSE: 0.6761286331266279

Model with regularization alpha 0.9: 
  - Negative Log Likelihood: 1.7743079356529874
  - theta: [0.09964016 0.25500097 0.20849153 0.15986081]
  - intercept: 1.0343100505068084
  - theta MSE: 0.2565463043007203
  - predction MSE: 1.5477557708572312


## 2. Statsmodels

In [6]:
# Fitter created with the "stats" option; L1_wt is held at 0.0 while alpha varies.
fitter_stats = MODEL_FITTING("stats")
print(">> STATSMODELS (L2 Reg) <<")
for alpha in (0.0, 0.3, 0.9):
    params = dict(alpha=alpha, L1_wt=0.0)
    theta_hat = check_model_pipeline(fitter_stats, **params)

>> STATSMODELS (L2 Reg) <<

Model with regularization alpha 0.0: 
  - Negative Log Likelihood: 1.6019886259382188
  - theta: [0.37453926 0.95071613 0.73198913 0.59866142]
  - intercept: 5.432684977739136e-07
  - theta MSE: 8.972578303565896e-12
  - predction MSE: 5.257485784879797e-11

Model with regularization alpha 0.3: 
  - Negative Log Likelihood: 1.6389523665108008
  - theta: [0.30063792 0.60721738 0.49984864 0.42085291]
  - intercept: 0.4223543603713865
  - theta MSE: 0.05223948318476876
  - predction MSE: 0.3890894025848563

Model with regularization alpha 0.9: 
  - Negative Log Likelihood: 1.7009122431875203
  - theta: [0.29500259 0.4538631  0.39700484 0.35811657]
  - intercept: 0.5078601861802435
  - theta MSE: 0.10581636082464548
  - predction MSE: 1.0322417602425813


In [7]:
# Fitter created with the "stats" option; L1_wt is held at 1.0 while alpha varies.
fitter_stats = MODEL_FITTING("stats")
print(">> STATSMODELS (L1 Reg) <<")
for alpha in (0.0, 0.3, 0.9):
    params = dict(alpha=alpha, L1_wt=1.0)
    theta_hat = check_model_pipeline(fitter_stats, **params)

>> STATSMODELS (L1 Reg) <<

Model with regularization alpha 0.0: 
  - Negative Log Likelihood: 1.6019891994466962
  - theta: [0.37447307 0.95051983 0.73152031 0.59705982]
  - intercept: 0.0015231045515916848
  - theta MSE: 7.055907789944597e-07
  - predction MSE: 4.42712223406949e-06

Model with regularization alpha 0.3: 
  - Negative Log Likelihood: 1.8209762396277083
  - theta: [0.         0.52706477 0.         0.        ]
  - intercept: 1.0481984902410366
  - theta MSE: 0.3034915859174859
  - predction MSE: 1.988469129868459

Model with regularization alpha 0.9: 
  - Negative Log Likelihood: 2.0461514705401376
  - theta: [0. 0. 0. 0.]
  - intercept: 1.147730553291428
  - theta MSE: 0.4845862761470941
  - predction MSE: 3.6325564747662193


