# Initial plan for training an SAE on spikes and trying to find interpretable SAE features

- Get neural + behavioral data
  - Get Allen visual coding neuropixels spikes data
  - Get Allen visual stim data + metadata
- Train SAEs 
  - Break spikes down into time windows
  - Train SAEs on all spikes for a particular region for all sessions for one animal
    - Hyperparameter sweeps: 
      - Time window size
      - Number of SAE features
      - L1_coeff values
      - Add a seq_len of time windows
      - Second layer to decoder (to capture nonlinear features)
      - Different optimizers
      - Different sparsity penalties / loss functions
  - Repeat this training but for particular regions for multiple animals
- Interpret SAE features
  - After training, feed in spikes for particular time window(s) and see which SAE feature(s) fire, and see if they correspond to the visual stim


In [None]:
%load_ext autoreload
%autoreload 2
# %flow mode reactive

In [1]:
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import einops
import jax
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import torch
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from rich import print as rprint
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm.notebook import tqdm

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

## Get Allen visual coding neuropixels data

In [2]:
out_dir = Path(r"C:\Users\jai\mini\data")
out_dir.mkdir(exist_ok=True)
print(f"{out_dir.exists()=}")

manifest_path = out_dir / "manifest.json"
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

out_dir.exists()=True


In [3]:
units = cache.get_units()
probes = cache.get_probes()
channels = cache.get_channels()
sessions = cache.get_session_table()

In [22]:
display(sessions.head(10))

Unnamed: 0_level_0,published_at,specimen_id,session_type,age_in_days,sex,full_genotype,unit_count,channel_count,probe_count,ecephys_structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
715093703,2019-10-03T00:00:00Z,699733581,brain_observatory_1.1,118.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,884,2219,6,"[CA1, VISrl, nan, PO, LP, LGd, CA3, DG, VISl, ..."
719161530,2019-10-03T00:00:00Z,703279284,brain_observatory_1.1,122.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,755,2214,6,"[TH, Eth, APN, POL, LP, DG, CA1, VISpm, nan, N..."
721123822,2019-10-03T00:00:00Z,707296982,brain_observatory_1.1,125.0,M,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,444,2229,6,"[MB, SCig, PPT, NOT, DG, CA1, VISam, nan, LP, ..."
732592105,2019-10-03T00:00:00Z,717038288,brain_observatory_1.1,100.0,M,wt/wt,824,1847,5,"[grey, VISpm, nan, VISp, VISl, VISal, VISrl]"
737581020,2019-10-03T00:00:00Z,718643567,brain_observatory_1.1,108.0,M,wt/wt,568,2218,6,"[grey, VISmma, nan, VISpm, VISp, VISl, VISrl]"
739448407,2019-10-03T00:00:00Z,716813543,brain_observatory_1.1,112.0,M,wt/wt,625,2221,6,"[grey, VISam, nan, VIS, VISp, VISl, VISrl]"
742951821,2019-10-03T00:00:00Z,723627604,brain_observatory_1.1,120.0,M,wt/wt,893,2219,6,"[VISal, nan, grey, VISl, VISrl, VISp, VISpm, VIS]"
743475441,2019-10-03T00:00:00Z,722882755,brain_observatory_1.1,121.0,M,wt/wt,553,2225,6,"[LP, LGd, HPF, DG, CA3, CA1, VISrl, nan, PP, P..."
744228101,2019-10-03T00:00:00Z,719817805,brain_observatory_1.1,122.0,M,wt/wt,659,2226,6,"[Eth, TH, LP, POL, APN, DG, CA1, VIS, nan, CA3..."
746083955,2019-10-03T00:00:00Z,726170935,brain_observatory_1.1,98.0,F,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,582,2216,6,"[VPM, TH, LGd, CA3, CA2, CA1, VISal, nan, grey..."


In [4]:
bo_df = sessions[sessions["session_type"] == "brain_observatory_1.1"]  # brain observatory data
id_counts = bo_df['specimen_id'].value_counts()
repeated_ids = id_counts[id_counts > 1]
print(f"{repeated_ids=}")

repeated_ids=Series([], Name: specimen_id, dtype: int64)


No animals with more than one session, so we'll just train on single animal-session for now.

In [5]:
# session = sessions[sessions["specimen_id"] == 742951821]
# session = sessions[sessions["specimen_id"] == 750332458]
session = sessions[sessions["specimen_id"] == 717038288]

In [6]:
display(session)
session_id = session.index.values[0]
print(f"{session_id=}")

Unnamed: 0_level_0,published_at,specimen_id,session_type,age_in_days,sex,full_genotype,unit_count,channel_count,probe_count,ecephys_structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
732592105,2019-10-03T00:00:00Z,717038288,brain_observatory_1.1,100.0,M,wt/wt,824,1847,5,"[grey, VISpm, nan, VISp, VISl, VISal, VISrl]"


session_id=732592105


In [7]:
session_data = cache.get_session_data(session_id)

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


some useful (not exhaustive) `session_data` attributes and methods

- attributes:
  - metadata
  - channels
  - stimulus_conditions
  - stimulus_presentations
  - spike_times
  - spike_amplitudes

- methods:
  - channel_structure_intervals
  - conditionwise_spike_statistics
  - get_pupil_data 
  - get_stimulus_epochs
  - get_stimulus_parameter_values
  - get_stimulus_table
  - get_lfp
  - get_screen_graze_data
  - get_invalid_times
  - presentationwise_spike_times
  - presentationwise_spike_counts
  - running_speed

In [8]:
session_data.metadata

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


{'specimen_name': 'C57BL/6J-404553',
 'session_type': 'brain_observatory_1.1',
 'full_genotype': 'wt/wt',
 'sex': 'M',
 'age_in_days': 100.0,
 'rig_equipment_name': 'NP.1',
 'num_units': 824,
 'num_channels': 1847,
 'num_probes': 5,
 'num_stimulus_presentations': 70388,
 'session_start_time': datetime.datetime(2019, 1, 8, 16, 26, 20, tzinfo=tzoffset(None, -28800)),
 'ecephys_session_id': 732592105,
 'structure_acronyms': ['VISpm',
  'grey',
  nan,
  'VISp',
  'VISl',
  'VISal',
  'VISrl'],
 'stimulus_names': ['spontaneous',
  'gabors',
  'flashes',
  'drifting_gratings',
  'natural_movie_three',
  'natural_movie_one',
  'static_gratings',
  'natural_scenes']}

In [8]:
region = "VISp"

session_channels = session_data.channels.drop(columns=["filtering"])
session_channels_visp = session_channels[session_channels["structure_acronym"] == region]

In [10]:
display(session_channels_visp)

Unnamed: 0_level_0,probe_channel_number,probe_horizontal_position,probe_id,probe_vertical_position,structure_acronym,ecephys_structure_id,ecephys_structure_acronym,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
850231405,307,27,733744649,3080,VISp,385.0,VISp,,,
850231359,284,43,733744649,2860,VISp,385.0,VISp,,,
850231385,297,11,733744649,2980,VISp,385.0,VISp,,,
850231343,276,43,733744649,2780,VISp,385.0,VISp,,,
850231389,299,27,733744649,3000,VISp,385.0,VISp,,,
...,...,...,...,...,...,...,...,...,...,...
850231349,279,27,733744649,2800,VISp,385.0,VISp,,,
850231323,266,59,733744649,2680,VISp,385.0,VISp,,,
850231281,245,11,733744649,2460,VISp,385.0,VISp,,,
850231297,253,11,733744649,2540,VISp,385.0,VISp,,,


In [9]:
units_df = session_data.units
region_units = units_df[units_df["structure_acronym"] == region]

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


In [14]:
display(region_units)

Unnamed: 0_level_0,waveform_PT_ratio,waveform_amplitude,amplitude_cutoff,cluster_id,cumulative_drift,d_prime,firing_rate,isi_violations,isolation_distance,L_ratio,...,ecephys_structure_id,ecephys_structure_acronym,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate,probe_description,location,probe_sampling_rate,probe_lfp_sampling_rate,probe_has_lfp_data
unit_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
915960297,0.410785,117.866385,0.001235,335,214.70,3.013238,3.051031,0.089814,50.575739,0.006127,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960294,0.210040,106.159365,0.000761,334,159.52,3.250551,13.198225,0.006399,74.417854,0.002733,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960290,0.248721,131.387490,0.053858,333,491.48,2.292449,5.642346,0.038768,58.246321,0.008143,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960287,0.560337,153.604815,0.034125,332,309.50,3.832008,14.034699,0.005457,67.692919,0.012434,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960283,0.401756,141.322935,0.046545,331,144.15,2.651925,4.370163,0.062538,49.513412,0.009947,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
915960975,0.374368,262.344420,0.002301,542,85.86,5.000994,0.543841,0.269217,59.157642,0.000969,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960972,0.487083,126.358635,0.027390,541,358.85,3.163878,0.980188,0.290065,47.377000,0.006576,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960968,0.372298,345.604740,0.001666,539,176.93,8.156430,0.634339,0.000000,95.795791,0.000010,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True
915960996,0.401461,227.728995,0.004779,550,191.13,6.107058,0.511019,0.000000,66.139560,0.000161,...,385.0,VISp,,,,probeC,See electrode locations,29999.991665,1249.999653,True


In [10]:
spike_times = session_data.spike_times
display(spike_times)

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


{915957951: array([  58.43389839,   68.84436109,   69.12766189, ..., 9369.23959097,
        9369.37909136, 9369.74325905]),
 915957946: array([1.02166981e+00, 1.17723691e+00, 2.50844067e+00, ...,
        9.41225195e+03, 9.41295741e+03, 9.41401995e+03]),
 915957685: array([1.24900378e+00, 1.25403713e+00, 1.83147209e+00, ...,
        9.41134111e+03, 9.41239791e+03, 9.41420385e+03]),
 915956508: array([  14.71984177,   22.41733015,   30.70085351, ..., 9414.77161938,
        9414.82191952, 9414.90375309]),
 915956502: array([1.46030438e+00, 1.46980440e+00, 1.48103777e+00, ...,
        9.41494015e+03, 9.41500052e+03, 9.41503145e+03]),
 915956668: array([7.26235640e-01, 1.11377007e+00, 1.37823748e+00, ...,
        9.41484879e+03, 9.41486355e+03, 9.41494369e+03]),
 915956662: array([7.46502364e-01, 8.28235927e-01, 9.13002833e-01, ...,
        9.41272578e+03, 9.41280891e+03, 9.41329988e+03]),
 915957581: array([3.27084282e+00, 7.83262235e+00, 7.83682236e+00, ...,
        9.41496709e+03, 9.4149

In [11]:
region_unit_ids = set(region_units.index)
region_spike_times = {
    unit_id: spikes for unit_id, spikes in spike_times.items() if unit_id in region_unit_ids
}

In [12]:
display(region_spike_times)

{915960921: array([  82.15947037,   82.20857039,   83.81880417, ..., 9414.41789664,
        9414.51213   , 9414.63093003]),
 915960683: array([7.43947753e-01, 7.61647758e-01, 1.01491450e+00, ...,
        9.41484863e+03, 9.41497826e+03, 9.41499916e+03]),
 915960678: array([6.36547724e-01, 7.61414425e-01, 8.23247775e-01, ...,
        9.41474240e+03, 9.41482016e+03, 9.41499213e+03]),
 915960674: array([1.77788137e+00, 2.66708162e+00, 2.80251499e+00, ...,
        9.41284170e+03, 9.41297856e+03, 9.41406206e+03]),
 915960812: array([3.83351528e+00, 3.95428198e+00, 7.57101632e+00, ...,
        9.41401743e+03, 9.41470813e+03, 9.41477036e+03]),
 915960810: array([5.40881030e-01, 7.20304955e+00, 7.32741625e+00, ...,
        9.36201422e+03, 9.39398782e+03, 9.40106503e+03]),
 915960586: array([6.35847723e-01, 7.76647762e-01, 8.40514447e-01, ...,
        9.17326470e+03, 9.17328940e+03, 9.17335493e+03]),
 915960835: array([6.30918263e+00, 2.68179550e+01, 2.69210884e+01, ...,
        9.41119550e+03, 

In [13]:
region_unit_spike_counts = {len(ts) for ts in region_spike_times.values()}

In [14]:
# See total number of spikes, and the maximum number of spikes in a single unit
sum(region_unit_spike_counts), max(region_unit_spike_counts)

(6910581, 458062)

- Options for feeding in spikes:
  - total spike counts per unit in a small time window

In [15]:
# Compute binned spike counts dataframe

win = 0.01  # 10 ms
min_max_pairs = [(ts_arr.min(), ts_arr.max()) for ts_arr in region_spike_times.values()]
first_spike_ts, last_spike_ts = (
    min(pair[0] for pair in min_max_pairs), max(pair[1] for pair in min_max_pairs)
)
stop_time = session_data.stimulus_presentations.iloc[-1]["stop_time"]

all_spike_ts = np.concatenate(list(region_spike_times.values()))
unit_indices = np.concatenate(
    [np.full(len(spikes), i) for i, spikes in enumerate(region_spike_times.values())]
)
bins = np.arange(0, last_spike_ts, win)
counts, _, _ = np.histogram2d(
    all_spike_ts, unit_indices, bins=[bins, range(len(region_spike_times) + 1)]
)
counts = torch.from_numpy(counts.astype(np.uint8))

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


In [16]:
# See what percentage of elements are non-zero,
# and what percentage of examples have at least one non-zero
frac_nonzero_bins = torch.sum(counts > 0) / counts.numel()
spike_counts_per_bin = torch.sum(counts > 0, axis=1)
frac_nonzero_examples = sum(spike_counts_per_bin > 0) / spike_counts_per_bin.size()[0]
print(f"{frac_nonzero_bins=}\n{frac_nonzero_examples=}")

frac_nonzero_bins=tensor(0.0614)
frac_nonzero_examples=tensor(0.9324)


In [17]:
indxs = torch.nonzero(counts).t()  # get nonzero indxs in appropriate format for `sparse_coo_tensor`
vals = counts[indxs[0], indxs[1]]
counts = torch.sparse_coo_tensor(indxs, vals, torch.Size(counts.shape))

In [18]:
print(counts)

tensor(indices=tensor([[    54,     54,     54,  ..., 941505, 941505, 941505],
                       [     5,     10,     12,  ...,     58,     72,     95]]),
       values=tensor([1, 1, 1,  ..., 1, 1, 1]),
       size=(941506, 110), nnz=6355897, dtype=torch.uint8,
       layout=torch.sparse_coo)


In [None]:
# counts_df = pd.DataFrame(
#     counts.to_dense(),
#     columns=region_spike_times.keys(),  # unit IDs as column names
#     index=bins[:-1],  # bin start times as index
# )

In [93]:
# display(counts_df)

Unnamed: 0,915960921,915960683,915960678,915960674,915960812,915960810,915960586,915960835,915960832,915960825,...,915960615,915960297,915960294,915960290,915960287,915960283,915960275,915960262,915960382,915960947
0.00,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
0.01,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
0.02,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
0.03,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
0.04,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9415.01,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,1,0,0,0
9415.02,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9415.03,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9415.04,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0


## Train the SAE

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))

device=device(type='cuda')
NVIDIA GeForce RTX 3090


In [20]:
@dataclass
class SaeConfig:
    """Config to set some params for (Toy) `AutoEncoder`."""

    n_input_ae: int  # number of input units to the autoencoder
    n_hidden_ae: int = 10_000  # number of hidden units in the autoencoder
    n_instances: int = 2  # number of model instances to optimize in parallel
    l1_coeff: float = 1.1  # relative weight of sparsity_loss : activations_reconstruction_loss
    

In [46]:
# Estimate gb required for largest tensor op

# batch_sz, n_instances, n_units, seq_len, n_hidden_ae = 8, 2, 110, 1, 10000
# [8, 2, 110] @ [2, 110, 10000] -> [8, 2, 10000]
8 * 2 * 110 * 2 * 110 * 10000 * 4 / 1e9

15.488

In [45]:
"""Create SAE for taking in Toy Model's activations and finding features."""

class Sae(nn.Module):
    """SAE for finding features in the activations of the Toy Model."""

    # Shapes of weights and biases for the encoder and decoder in the SAE.
    W_enc: Float[Tensor, "n_instances n_input_ae n_hidden_ae"]
    W_dec: Float[Tensor, "n_instances n_hidden_ae n_input_ae"]
    b_enc: Float[Tensor, "n_instances n_hidden_ae"]
    b_dec: Float[Tensor, "n_instances n_input_ae"]

    def __init__(self, cfg: SaeConfig):
        """Initializes model parameters."""
        super().__init__()
        self.cfg = cfg
        self.W_enc = nn.Parameter(
            nn.init.xavier_normal_(
                torch.empty(
                    (cfg.n_instances, cfg.n_input_ae, cfg.n_hidden_ae), dtype=torch.bfloat16
                )
            )
        )
        self.W_dec = nn.Parameter(
            nn.init.xavier_normal_(
                torch.empty(
                    (cfg.n_instances, cfg.n_hidden_ae, cfg.n_input_ae), dtype=torch.bfloat16
                )
            )
        )
        self.b_enc = nn.Parameter(
            torch.zeros((cfg.n_instances, cfg.n_hidden_ae), dtype=torch.bfloat16)
        )
        self.b_dec = nn.Parameter(
            torch.zeros((cfg.n_instances, cfg.n_input_ae), dtype=torch.bfloat16)
        )

    def forward(self, h: Float[Tensor, "n_instances batch_sz n_input_ae"]):
        """Computes loss as a function of SAE feature sparsity and spike_count reconstructions."""
        # Compute encoder hidden activations.
        z = F.relu(
            einops.einsum(
                h,
                self.W_enc,
                "n_instances batch_sz n_input_ae, n_instances n_input_ae n_hidden_ae "
                "-> n_instances batch_sz n_hidden_ae",
            )
            + self.b_enc.unsqueeze(1)
        )

        # Compute reconstructed input.
        h_prime = (
            einops.einsum(
                z,
                self.W_dec,
                "n_instances batch_sz n_hidden_ae, n_instances n_hidden_ae n_input_ae "
                "-> n_instances batch_sz n_input_ae",
            )
            + self.b_dec.unsqueeze(1)
        )

        # Compute loss (l1_loss and l2_loss shapes: [batch_sz n_instances]) and return values.
        l1_loss = z.abs().sum(-1)  # sparsity component of loss, over n_hidden_ae
        l2_loss = (h_prime - h).pow(2).mean(-1)  # activations reconstruction loss, over n_input_ae
        loss = (self.cfg.l1_coeff * l1_loss + l2_loss).mean(0).sum()  # scalar
        return l1_loss, l2_loss, loss, z, h_prime

    @torch.no_grad()
    def normalize_decoder(self) -> None:
        """Normalizes the decoder weights to have unit norm."""
        self.W_dec.data = self.W_dec.data / self.W_dec.data.norm(dim=2, keepdim=True)

In [46]:
@torch.no_grad()
def resample_neurons(
    self: Sae,
    frac_active_in_window: Float[Tensor, "window n_instances n_hidden_ae"],
) -> Tuple[List[List[str]], str]:  # -> (colors_for_neurons, title_with_resampling_info)
    """Resamples neurons dead for `dead_neuron_window` steps, according to `frac_active`."""
    # Get a tensor of dead neurons.
    dead_features_mask = frac_active_in_window.sum(0) < 1e-8  # -> [n_instances n_hidden_ae]
    n_dead = dead_features_mask.int().sum().item()

    # Get our random replacement values (-> [n_dead n_input_ae]).
    replacement_vals = torch.randn((n_dead, self.cfg.n_input_ae), device=self.W_enc.device)
    replacement_vals_norm = replacement_vals / replacement_vals.norm(dim=-1, keepdim=True) + 1e-8

    # Reset W_enc, W_dec, and b_enc (we transpose W_enc to return a view with correct shape).
    self.W_enc.data.transpose(-1, -2)[dead_features_mask] = replacement_vals_norm
    self.W_dec.data[dead_features_mask] = replacement_vals_norm
    self.b_enc.data[dead_features_mask] = 0.0

    # Return data for visualising the resampling process.
    colors = [
        ["red" if dead else "black" for dead in dead_neuron_mask_inst]
        for dead_neuron_mask_inst in dead_features_mask
    ]
    title = f"resampling {n_dead}/{dead_features_mask.numel()} neurons (shown in red)"
    return colors, title

# Add method to sae class.
Sae.resample_neurons = resample_neurons

In [47]:
def lr_schedule(*_):
    """Simple learning rate schedule."""
    return 1.0

def optimize(
    self: Sae,
    spike_counts: Int[Tensor, "n_timebins n_units"],
    seq_len: int = 1,  # number of timebins to use in each spike_count_seq
    batch_sz: int = 8,
    steps: int = 500_000,
    log_freq: int = 1000,
    lr: float = 1e-3,
    lr_scale: Callable[[int, int], float] = lr_schedule,
    neuron_resample_window: Optional[int] = None,  # in optimization steps
):
    """Optimizes the autoencoder using the given hyperparameters."""
    optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)

    # Create lists to store data we'll eventually be plotting.
    frac_active_list = []  # fraction of non-zero activations for each neuron (feature)
    data_log = {"frac_active": [], "W_enc": [], "W_dec": [], "titles": [], "colors": []}
    colors = None
    title = "No resampling"

    # Define valid samples for `spike_counts`.
    n_timebins, _n_units = spike_counts.shape
    valid_starts = n_timebins - seq_len + 1

    progress_bar = tqdm(range(steps))
    for step in progress_bar:
        # Normalize decoder weights at each step to prevent artificially small / sparse
        # features from large decoder weights.
        self.normalize_decoder()

        # # Check for dead neurons, and resample them if found.
        # if (neuron_resample_window is not None) and (
        #     (step + 1) % neuron_resample_window == 0
        # ):
        #     frac_active_in_window = torch.stack(
        #         frac_active_list[-neuron_resample_window:], dim=0
        #     )
        #     colors, title = self.resample_neurons(frac_active_in_window)

        # Update learning rate.
        step_lr = lr * lr_scale(step, steps)
        for group in optimizer.param_groups:
            group["lr"] = step_lr

        # Get batch of spikes.
        start_indxs = torch.randint(0, valid_starts, (self.cfg.n_instances, batch_sz))
        seq_indxs = start_indxs.unsqueeze(-1) + torch.arange(seq_len)
        spike_count_seqs = spike_counts[seq_indxs]  # -> [n_instances batch_sz seq_len n_units]
        spike_count_seqs = einops.rearrange(
            spike_count_seqs,
            'n_instances batch_sz seq_len n_units -> n_instances (batch_sz seq_len) n_units'
        )

        # Optimize.
        optimizer.zero_grad()
        l1_loss, l2_loss, loss, z, _ = self.forward(spike_count_seqs)
        loss.backward()
        optimizer.step()

        # Calculate the sparsities and them to the list.
        frac_active = einops.reduce(
            (z.abs() > 1e-8).float(),
            "batch_sz n_instances hidden_ae -> n_instances hidden_ae",
            "mean",
        )
        frac_active_list.append(frac_active)

        # Display progress bar, and append new values for plotting.
        if step % log_freq == 0 or (step + 1 == steps):
            progress_bar.set_postfix(
                l1_loss=self.cfg.l1_coeff * l1_loss.mean(0).sum().item(),
                l2_loss=l2_loss.mean(0).sum().item(),
                lr=step_lr,
            )
            data_log["W_enc"].append(self.W_enc.detach().cpu())
            data_log["W_dec"].append(self.W_dec.detach().cpu())
            data_log["titles"].append(f"Step {step}/{steps}: {title}")
            data_log["frac_active"].append(frac_active.detach().cpu())
            data_log["colors"].append(colors)

    return data_log


# Add method to SAE class.
Sae.optimize = optimize

In [48]:
"""Train the SAE."""

ae_cfg = SaeConfig(
    n_input_ae=counts.shape[1],
    n_instances=2,
    n_hidden_ae=10_000,
    l1_coeff=1.1,
)
sae = Sae(ae_cfg).to(device)

data_log = sae.optimize(spike_counts=counts.to_dense().bfloat16().to(device))

  0%|          | 0/10000 [00:00<?, ?it/s]

## Interpret the SAE features

In [27]:
session_stim_data = session_data.get_stimulus_table()
display(session_stim_data)

Unnamed: 0_level_0,stimulus_block,start_time,stop_time,spatial_frequency,temporal_frequency,x_position,stimulus_name,orientation,y_position,frame,size,phase,contrast,color,duration,stimulus_condition_id
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
0,,21.579074,81.645874,,,,spontaneous,,,,,,,,60.066800,0
1,0.0,81.645874,81.879397,0.08,4.0,-30.0,gabors,45.0,-20.0,,"[20.0, 20.0]","[3644.93333333, 3644.93333333]",0.8,,0.233523,1
2,0.0,81.879397,82.129600,0.08,4.0,20.0,gabors,90.0,-10.0,,"[20.0, 20.0]","[3644.93333333, 3644.93333333]",0.8,,0.250203,2
3,0.0,82.129600,82.379803,0.08,4.0,40.0,gabors,90.0,30.0,,"[20.0, 20.0]","[3644.93333333, 3644.93333333]",0.8,,0.250203,3
4,0.0,82.379803,82.630006,0.08,4.0,40.0,gabors,45.0,-40.0,,"[20.0, 20.0]","[3644.93333333, 3644.93333333]",0.8,,0.250203,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
70383,14.0,9142.164805,9142.415016,0.04,,,static_gratings,60.0,,,"[250.0, 250.0]",0.75,0.8,,0.250210,4806
70384,14.0,9142.415016,9142.665223,0.08,,,static_gratings,30.0,,,"[250.0, 250.0]",0.0,0.8,,0.250207,4857
70385,14.0,9142.665223,9142.915430,0.32,,,static_gratings,60.0,,,"[250.0, 250.0]",0.75,0.8,,0.250207,4876
70386,14.0,9142.915430,9143.165637,0.16,,,static_gratings,90.0,,,"[250.0, 250.0]",0.5,0.8,,0.250207,4790


## Scratchpad below here

In [None]:
def rotated_binary_search(arr, target, left, right):
    # Base case: If the range is invalid
    if left > right:
        return -1  # Target not found

    # Find the middle index
    mid = (left + right) // 2

    # If the middle element is the target
    if arr[mid] == target:
        return mid

    # Determine if the left half is sorted
    if arr[left] <= arr[mid]:
        # Check if the target is in the left sorted half
        if arr[left] <= target < arr[mid]:
            return rotated_binary_search(arr, target, left, mid - 1)
        else:  # Otherwise, search in the right half
            return rotated_binary_search(arr, target, mid + 1, right)
    else:
        # Right half must be sorted
        if arr[mid] < target <= arr[right]:
            return rotated_binary_search(arr, target, mid + 1, right)
        else:  # Otherwise, search in the left half
            return rotated_binary_search(arr, target, left, mid - 1)


# Wrapper function to simplify the call
def search_in_rotated_sorted_list(arr, target):
    return rotated_binary_search(arr, target, 0, len(arr) - 1)


# Example
arr = [7, 8, 9, 2, 3, 4]
target = 3
index = search_in_rotated_sorted_list(arr, target)
print(f"Index of {target} is: {index}")