In [None]:

import matplotlib.pyplot as plt
import torch
from muutils.dbg import dbg_auto

from spd.clustering.activations import component_activations, process_activations
from spd.clustering.merge import compute_merge_costs, merge_iteration
from spd.clustering.merge_matrix import GroupMerge
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset
from spd.models.component_model import ComponentModel
from spd.utils.data_utils import DatasetGeneratedDataLoader

DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so the synthetic dataset sampled below is the same on every
# Restart & Run All. Assumes ResidualMLPDataset draws from torch's default generator
# (TODO confirm). The trailing `;` hides the returned Generator repr.
SEED: int = 0
torch.manual_seed(SEED);


In [None]:
# Pretrained SPD run to analyse (W&B run path passed to `from_pretrained`).
WANDB_RUN_PATH: str = "wandb:goodfire/spd/runs/dcjm9g2n"

# `from_pretrained` returns the component model, the run config `cfg` (whose
# `task_config` parameterizes the dataset below), and `path` (presumably the local
# location of the downloaded run; not used further here).
component_model, cfg, path = ComponentModel.from_pretrained(WANDB_RUN_PATH)
component_model.to(DEVICE);

In [None]:

# grep_repr((component_model, cfg, path, dir(component_model)), "_features")
# cfg.task_config
# grep_repr(, "_features")

In [None]:
# Batch size of the dataloader below, i.e. how many synthetic samples each batch holds.
N_SAMPLES: int = 1000

# Synthetic residual-MLP inputs, using the feature settings from the pretrained
# run's task config so the inputs match what the SPD run was configured with.
dataset = ResidualMLPDataset(
    n_features=component_model.model.config.n_features,
    feature_probability=cfg.task_config.feature_probability,
    data_generation_type=cfg.task_config.data_generation_type,
    device=DEVICE,
    # No labels are generated here: the target model's outputs act as the labels,
    # so every label-related option is switched off.
    calc_labels=False,
    label_type=None,
    act_fn_name=None,
    label_fn_seed=None,
    label_coeffs=None,
    # synced_inputs=synced_inputs,
)

dataloader = DatasetGeneratedDataLoader(dataset, shuffle=False, batch_size=N_SAMPLES)


In [None]:
# Run the component model over the dataloader and collect its component activations.
# No `threshold=` is passed here; e.g. `threshold=0.1` can be supplied to override it.
ci = component_activations(component_model, dataloader, device=DEVICE)

dbg_auto(ci);

In [None]:
# Post-process the activations (with diagnostic plots). `filter_dead_threshold`
# presumably drops components whose activation stays below 0.1. The resulting dict
# is read below via its "n_components_alive" and "coactivations" entries.
coa = process_activations(ci, plots=True, filter_dead_threshold=0.1);

In [None]:
# Starting point for merging: the identity grouping over the alive components
# (presumably each component in its own singleton group). It is visualized as a wide strip.
gm_ident = GroupMerge.identity(n_components=coa["n_components_alive"])
gm_ident.plot(figsize=(10, 2))

$$
	F_g := \frac{\alpha}{n}
	\Bigg[
		d(A(g)) \cdot Q^\top
		+ Q \cdot d(A(g))^\top
		- \Big(
			R \mathbf{1}^\top
			+ \mathbf{1} R^\top
			+ \alpha^{-1} \mathbf{1} \mathbf{1}^\top
		\Big)
		\odot A(g)
	\Bigg]
$$

In [None]:
# Merge costs for the identity grouping, computed from the coactivation statistics.
costs = compute_merge_costs(
    coact=coa["coactivations"],
    merges=gm_ident,
)

# Heatmap of the cost matrix. The axis labels assume `costs` is square over groups
# (TODO confirm). `.detach()` is a no-op without grad, but is needed before the
# numpy conversion if `costs` tracks gradients.
fig, ax = plt.subplots()
im = ax.matshow(costs.detach().cpu(), cmap="viridis")
fig.colorbar(im, ax=ax, label="merge cost")
ax.set(title="Merge costs (identity grouping)", xlabel="group index", ylabel="group index")
# plt.show() renders the figure without echoing a Colorbar repr as cell output.
plt.show()

In [None]:

# Binarize the coactivation matrix at 0.002 and run the iterative merge on it.
# NOTE(review): `activation_mask` receives the thresholded *coactivation* matrix, and
# `coact` is that matrix's Gram product (X^T X). If `merge_iteration` expects a
# per-sample (samples x components) activation mask, both should instead be built
# from per-sample activations. TODO confirm against merge_iteration's contract.
coact_bool = coa['coactivations'] > 0.002
merge_iteration(
    coact=torch.matmul(coact_bool.float().T, coact_bool.float()),
    activation_mask=coact_bool,
    check_threshold=0.1,
    rank_cost=lambda _: 1e-1,  # constant: ignores its argument
    alpha=1e-1,
    iters=100,
    plot_every=10,
    plot_every_min=0,
    # Other settings tried: alpha=0.001, plot_every=None, initial_merge=?
)