# Experiment: NAME HERE

**Goal**:

**Additional notes**:

**Related notebooks**:

* `XX`

> Written by Jess Breda

## Imports

In [1]:
import seaborn as sns
import matplotlib.pyplot as plt

from multiglm.experiments.experiment import load_experiment

from multiglm.experiments.experiment_tau_sweep import ExperimentTauSweep

from multiglm.features.design_matrix_generator import *
from multiglm.features.design_matrix_generator_PWM import *

from multiglm.models.binary_logistic_regression import BinaryLogisticRegression
from multiglm.models.multiclass_logistic_regression import MultiClassLogisticRegression

from multiglm.visualizations.model_visualizer import ModelVisualizerTauSweep

from multiglm.data import ANIMAL_IDS
from multiglm.experiments import STANDARD_SIGMAS


sns.set_context("talk")
%load_ext autoreload
%autoreload 2

## Model Configs

In [4]:
# Name of the feature column whose filter tau is swept. Defined once so the
# sweep config and the weight-plot `order` below cannot drift apart when this
# template is adapted to a different sweep column.
sweep_col_name = "filt_prev_viol"

# Design-matrix columns shared by every model in this experiment. Each key is
# an output column name mapped to a function of the trial dataframe `df`; the
# "labels" entry instead names the dataframe column used as the response.
standard_cols = {
    "session": lambda df: copy(df.session),
    "bias": lambda df: add_bias_column(df),
    "s_a": lambda df: standardize(df.s_a),
    "s_b": lambda df: standardize(df.s_b),
    "prev_avg_stim": lambda df: prev_avg_stim(df, mask_prev_violation=True),
    "prev_correct": lambda df: prev_correct_side(df),
    "prev_choice": lambda df: prev_choice(df),
    "labels": {"column_name": "choice"},
}

# Column whose tau is swept over the values in `taus`; `col_func` builds it
# from the trial dataframe and `col_name` is the name it gets in the design matrix.
# NOTE(review): "current_idx" is consumed by the design-matrix generator /
# ExperimentTauSweep; presumably the index into `taus` to start from — confirm.
sweep_col = {
    "tau_sweep": {
        "taus": [1, 2, 3, 4, 5],
        "col_name": sweep_col_name,
        "col_func": lambda df: prev_violation(df),
        "current_idx": 0,
    },
}

# One entry per model to fit: the model class plus its design-matrix config
# (shared columns merged with the swept column).
model_config = {
    "example_model_name": {
        "model_class": MultiClassLogisticRegression,
        "dmg_config": {**standard_cols, **sweep_col},
    }
}


params = {
    "animals": ANIMAL_IDS,  # all animals
    "data_type": "new_trained",
    "sigmas": STANDARD_SIGMAS,
    "random_state": 47,
    "eval_train": True,
    "model_config": model_config,
}
# Placeholders: fill in DATE and SWEEPCOL before running.
save_name = "DATE_tau_sweep_SWEEPCOL.pkl"

# Feature order used by the weight plots in the Visualize section.
order = [
    "bias",
    sweep_col_name,
    "s_a",
    "s_b",
    "prev_avg_stim",
    "prev_correct",
    "prev_choice",
]

## Run

In [None]:
# Build the experiment from `params`, fit it, and save it under `save_name`
# so the Visualize section can reload it without re-running this cell.
experiment = ExperimentTauSweep(params)
experiment.run()
experiment.save(save_name)

## Visualize

In [None]:
# Reload the experiment saved by the Run cell and wrap it in the tau-sweep
# visualizer used by all plotting cells below.
experiment = load_experiment(save_name)
mvt = ModelVisualizerTauSweep(experiment)

In [None]:
# Summary plot over sigmas (presumably the values in params["sigmas"]).
mvt.plot_sigma_summary()

In [None]:
# Summary plot over taus (presumably the values in sweep_col["tau_sweep"]["taus"]).
mvt.plot_tau_summary()

In [None]:
# Per-animal NLL across taus (per the method name), grouped by "tau" and drawn in gray.
mvt.plot_nll_over_taus_by_animal(group="tau", color="gray")

In [None]:
# Weight summary with features in the configured `order`, Set2 palette.
mvt.plot_weights_summary(palette="Set2", order=order)

In [None]:
# Per-animal weights, using the same feature `order` and palette as the summary.
mvt.plot_weights_by_animal(palette="Set2", order=order)

In [None]:
# Histogram of per-animal tau values, one bin per integer tau.
# NOTE(review): the swept column configured in the Model Configs cell is
# "filt_prev_viol", but this plots "prev_violation_tau"; confirm this is the
# tau column name the visualizer actually produces for that sweep.
mvt.plot_tau_histogram(column="prev_violation_tau", binwidth=1)

In [None]:
# NOTE(review): presumably persists the best-fit tau(s) from this experiment;
# check ModelVisualizerTauSweep.save_best_fit_tau for where it writes before
# re-running, since repeated runs may overwrite earlier results.
mvt.save_best_fit_tau()

## Special Functions

In [None]:
# Appears to be a placeholder call: replace both strings with a real feature
# column and weight class before running.
# NOTE(review): `plot_weight_by_tau` is not imported by name in the imports cell
# (it could only come from one of the `design_matrix_generator*` star imports);
# presumably this should be `mvt.plot_weight_by_tau` — confirm. The placeholder
# string also has a typo: "colume" -> "column".
plot_weight_by_tau("feature_colume_name", "weight_class")