In [1]:
# Import your existing types and functions
from typing import Optional, Union, Any, Callable
from dataclasses import dataclass
import random
from pathlib import Path
import math

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch

import pandas as pd
import torch
from torch import Tensor, einsum
from jaxtyping import Float, Int, Bool

from js_embedding_vis import write_inlined_config
from muutils.collect_warnings import CollateWarnings
from muutils.dbg import dbg, dbg_auto, dbg_tensor

# from spd.clustering.embed_vis import AnalysisConfig, coactivation_analysis, plot_embedding_label_grid
from spd.clustering.grouping import (
    CoactivationResults,
    CoactivationResultsGroup,
    get_coactivations,
)
from spd.utils.data_utils import SparseFeatureDataset
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset
from spd.clustering.merge_matrix import GroupMerge, BatchedGroupMerge

ModuleNotFoundError: No module named 'spd.data_utils'

In [None]:
# magic autoreload
# IPython autoreload: re-import edited modules (e.g. the spd.clustering.* helpers
# imported above) before each cell runs, so library edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Pick CUDA when available, otherwise fall back to the CPU; later cells read DEVICE.
DEVICE: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Autograd is disabled globally for every subsequent cell.
torch.set_grad_enabled(False)
print(f"Using device: {DEVICE = }")

In [None]:
# Coactivation statistics for the single module group (layer-0 mlp_in + mlp_out),
# computed on DEVICE. Only the "group_0" entry is kept.
COACTIVATIONS = get_coactivations(
    model_path=Path("../data/mlp-decomp/model_30000.pth"),
    dataset_cls=ResidualMLPDataset,
    dataset_kwargs={
        "calc_labels": False,  # Our labels will be the output of the target model
        "label_type": None,
        "act_fn_name": None,
        "label_fn_seed": None,
        "label_coeffs": None,
        "synced_inputs": None,
    },
    dataloader_kwargs={"batch_size": 999},
    coactivations_kwargs={
        "module_groups": [["layers.0.mlp_in", "layers.0.mlp_out"]],
        "n_samples": 999,
    },
    device=DEVICE,
)["group_0"]

dbg_auto(COACTIVATIONS);

The coactivation results above make sense: the last 100 components are the superimposed ones, which we expect to merge into a single super-component.

In [None]:
# Identity grouping: every component starts in its own group. The size comes from
# the loaded coactivation matrix rather than a hard-coded 200, so it cannot drift
# out of sync with the model being analysed.
n_components_total: int = COACTIVATIONS['co_occurrence_matrix'].shape[0]
gm_ident = GroupMerge.identity(n_components=n_components_total)
gm_ident.plot(figsize=(10, 2))

# Groupings reachable from gm_ident by merging one pair of groups
# (presumably one batch entry per merged pair -- TODO confirm).
gm_downstream: BatchedGroupMerge = gm_ident.all_downstream_merged()
dbg_tensor(gm_downstream.group_idxs);
gm_downstream[0].plot(figsize=(10, 2))


$$
	F_g := \frac{\alpha}{n}
	\Bigg[
		d(A(g)) \cdot Q^T 
		+ Q \cdot d(A(g))^T
		- \Big(
			R \mathbf{1}^T
			+ \mathbf{1} R^T + \alpha^{-1}
		\Big) 
		\odot A(g)
	\Bigg]
$$

In [None]:
def compute_merge_costs(
    coact: Float[Tensor, "k_groups k_groups"],
    merges: GroupMerge,
    alpha: float = 1.0,
    # rank_cost: Callable[[float], float] = lambda c: math.log(c),
    rank_cost: Callable[[float], float] = lambda _: 1.0,
) -> Float[Tensor, "k_groups k_groups"]:
    """Compute the pairwise MDL merge-cost matrix for the groups of `merges`.

    With ``c = coact``, ``r = merges.components_per_group`` and
    ``k = merges.k_groups``, entry ``[i, j]`` is::

        alpha * (c_ii * r_j + r_i * c_jj - (r_i + r_j + rank_cost(k) / alpha) * c_ij)

    The first two terms are the outer products ``d(A) Q^T + Q d(A)^T`` of the
    notebook's F_g formula. Callers in this notebook treat lower values as
    better merges. NOTE(review): the formula's ``1/n`` factor is not applied.

    Args:
        coact: square co-occurrence matrix, one row/column per group of `merges`.
        merges: current grouping; only `components_per_group` and `k_groups` are read.
        alpha: overall scale of the cost.
        rank_cost: penalty as a function of the current number of groups.

    Returns:
        float32 ``(k_groups, k_groups)`` tensor of merge costs.
    """
    device: torch.device = coact.device
    # work in float so the outer products / matmul-free arithmetic share one dtype
    coact_f: Float[Tensor, "k_groups k_groups"] = coact.float()
    ranks: Float[Tensor, " k_groups"] = merges.components_per_group.to(device=device).float()
    diag: Float[Tensor, " k_groups"] = torch.diagonal(coact_f)

    # d(A) Q^T + Q d(A)^T must be OUTER products. The previous `diag @ ranks.T`
    # on 1-D tensors was a dot product, i.e. a single scalar broadcast everywhere.
    size_terms: Float[Tensor, "k_groups k_groups"] = torch.outer(diag, ranks) + torch.outer(ranks, diag)
    # r_i + r_j for every pair of groups
    pair_ranks: Float[Tensor, "k_groups k_groups"] = ranks.unsqueeze(0) + ranks.unsqueeze(1)

    return alpha * (
        size_terms
        - (pair_ranks + rank_cost(merges.k_groups) / alpha) * coact_f
    )

In [None]:
# Pairwise merge costs for the identity grouping; a colorbar makes the scale readable.
costs = compute_merge_costs(
    coact=COACTIVATIONS['co_occurrence_matrix'],
    merges=gm_ident,
)
fig, ax = plt.subplots()
im = ax.matshow(costs.cpu(), cmap='viridis')
fig.colorbar(im, ax=ax, label='Merge Cost')
ax.set_title('Pairwise merge costs (identity grouping)');

In [None]:

@dataclass
class MergeData:
    """Coactivation matrix, activation mask, and current grouping, kept in sync."""

    coact: Float[Tensor, "k_groups k_groups"]
    activation_mask: Bool[Tensor, "samples k_groups"]
    merges: GroupMerge

    @classmethod
    def init(cls) -> "MergeData":
        """Load coactivations for the hard-coded model and pair them with an identity grouping.

        Reads the module-level ``DEVICE`` set up earlier in the notebook.
        """
        coact_results: CoactivationResultsGroup = get_coactivations(
            model_path=Path("../data/mlp-decomp/model_30000.pth"),
            dataset_cls=ResidualMLPDataset,
            dataset_kwargs=dict(
                calc_labels=False,  # Our labels will be the output of the target model
                label_type=None,
                act_fn_name=None,
                label_fn_seed=None,
                label_coeffs=None,
                synced_inputs=None,
            ),
            dataloader_kwargs=dict(
                batch_size=999,
            ),
            coactivations_kwargs=dict(
                module_groups=[["layers.0.mlp_in", "layers.0.mlp_out"]],
                n_samples=999,
            ),
            device=DEVICE,
        )["group_0"]

        coact: Float[Tensor, "k_groups k_groups"] = coact_results["co_occurrence_matrix"]
        return cls(
            coact=coact,
            activation_mask=coact_results["active_mask"],
            # one group per row of the coactivation matrix
            merges=GroupMerge.identity(n_components=coact.shape[0]),
        )


def recompute_coacts(
    coact: Float[Tensor, "k_groups k_groups"],
    merges: GroupMerge,
    merge_pair: tuple[int, int],
    activation_mask: Bool[Tensor, "samples k_groups"],
) -> tuple[
    GroupMerge,
    Float[Tensor, "k_groups-1 k_groups-1"],
    Bool[Tensor, "samples k_groups-1"],
]:
    """Merge two groups and update the coactivation matrix and activation mask to match.

    The merged group lives at ``min(merge_pair)``; row/column ``max(merge_pair)``
    is dropped. A sample activates the merged group if it activates either
    member. The merged group's coactivations are recomputed from the mask as
    ``merged_col @ activation_mask`` (sample counts), which assumes `coact`
    holds co-occurrence counts -- TODO confirm against get_coactivations.

    Args:
        coact: square coactivation matrix, one row/column per group of `merges`.
        merges: current grouping; its `merge_groups` result must report the
            merged group at ``min(merge_pair)`` via `old_to_new_idx`.
        merge_pair: indices of the two (distinct) groups to merge.
        activation_mask: per-sample activation mask, one column per group.

    Returns:
        ``(merge_new, coact_new, activation_mask_new)`` with one fewer group.
        The inputs are not modified.
    """
    k_groups: int = coact.shape[0]
    assert coact.shape[1] == k_groups, "Coactivation matrix must be square"
    assert merge_pair[0] != merge_pair[1], "Cannot merge a group with itself"

    new_group_idx: int = min(merge_pair)
    remove_idx: int = max(merge_pair)

    # activations of the new merged group (logical OR of the two members)
    activation_mask_grp: Bool[Tensor, " samples"] = (
        activation_mask[:, merge_pair[0]] | activation_mask[:, merge_pair[1]]
    )
    # Coactivation counts of the merged group with every current group. Cast to
    # coact's dtype instead of `.bool()`: collapsing counts to 0/1 left this row
    # on a different scale from the count written to its diagonal below.
    coact_with_merge: Tensor = (activation_mask_grp.float() @ activation_mask.float()).to(coact.dtype)
    new_group_self_coact: float = activation_mask_grp.float().sum().item()

    # assemble the merge pair
    merge_new: GroupMerge = merges.merge_groups(merge_pair[0], merge_pair[1])
    old_to_new_idx: dict[int | None, int | None] = merge_new.old_to_new_idx  # type: ignore
    assert old_to_new_idx[None] == new_group_idx, "New group index should be the minimum of the merge pair"
    assert old_to_new_idx[new_group_idx] is None
    assert old_to_new_idx[remove_idx] is None
    # TODO: check that the rest are in order? probably not necessary

    keep: Bool[Tensor, " k_groups"] = torch.ones(k_groups, dtype=torch.bool, device=coact.device)
    keep[remove_idx] = False

    # reindex coactivations: write the merged row/column, then drop remove_idx
    coact_temp: Tensor = coact.clone()
    coact_temp[new_group_idx, :] = coact_with_merge
    coact_temp[:, new_group_idx] = coact_with_merge
    coact_new: Tensor = coact_temp[keep][:, keep]
    coact_new[new_group_idx, new_group_idx] = new_group_self_coact

    # reindex the activation mask the same way
    activation_mask_new: Tensor = activation_mask.clone()
    activation_mask_new[:, new_group_idx] = activation_mask_grp
    activation_mask_new = activation_mask_new[:, keep]

    return merge_new, coact_new, activation_mask_new


# Smoke test: merge groups 0 and 1 of the identity grouping, then show the new
# merge, the reduced coactivation matrix, and the (transposed) activation mask.
rc_test = recompute_coacts(
    coact=COACTIVATIONS['co_occurrence_matrix'],
    merges=gm_ident,
    merge_pair=(0, 1),
    activation_mask=COACTIVATIONS['active_mask'],
)
rc_test[0].plot()
plt.show()
for rc_matrix in (rc_test[1], rc_test[2].T):
    plt.matshow(rc_matrix.cpu(), cmap='viridis')
    plt.show()


In [None]:
def merge_dead_components(
    coact: Float[Tensor, "k_groups k_groups"],
    activation_mask: Bool[Tensor, "samples k_groups"],
    threshold: float = 1e-6,
) -> tuple[
    GroupMerge,
    Float[Tensor, "k_new k_new"],
    Bool[Tensor, "samples k_new"],
]:
    """Merge dead components into a single group.

    A component is dead when its activation frequency (column mean of
    `activation_mask`) is below `threshold`. All dead components are merged into
    the lowest-index dead one; the other dead rows/columns are dropped. The
    merged group's activation is the OR of the dead columns, and its
    coactivations are recomputed as ``merged_col @ activation_mask_new``. This
    assumes `coact` holds co-occurrence counts -- TODO confirm.

    Args:
        coact: square coactivation matrix, one row/column per component.
        activation_mask: per-sample activation mask, one column per component.
        threshold: activation-frequency cutoff below which a component is dead.

    Returns:
        ``(merge, coact_new, activation_mask_new)``. When no component is dead,
        the identity merge and the unchanged `coact` / `activation_mask`.
    """
    n_components: int = activation_mask.shape[1]
    merge: GroupMerge = GroupMerge.identity(n_components=n_components)

    act_probs: Float[Tensor, " k_groups"] = activation_mask.float().mean(dim=0)
    dead_components: Bool[Tensor, " k_groups"] = act_probs < threshold
    if not dead_components.any():
        return merge, coact, activation_mask

    dead_idxs = torch.where(dead_components)[0]  # ascending
    first_dead_idx: int = int(dead_idxs[0].item())
    dead_idxs_remove = dead_idxs[1:]  # all but the first

    # Highest index first: every removed index is > first_dead_idx, so the
    # surviving lower indices (including first_dead_idx) stay valid.
    for i in reversed(dead_idxs_remove.tolist()):
        merge = merge.merge_groups(first_dead_idx, i)

    # Boolean keep-mask. The previous `~dead_idxs_remove` bit-flipped the integer
    # indices into negative ones instead of selecting the surviving rows.
    keep: Bool[Tensor, " k_groups"] = torch.ones(n_components, dtype=torch.bool, device=activation_mask.device)
    keep[dead_idxs_remove] = False

    merged_col: Bool[Tensor, " samples"] = activation_mask[:, dead_idxs].any(dim=1)
    activation_mask_new = activation_mask.clone()
    activation_mask_new[:, first_dead_idx] = merged_col
    activation_mask_new = activation_mask_new[:, keep]

    # first_dead_idx keeps its position since only higher indices were removed
    coact_new = coact[keep][:, keep]
    merged_coact = (merged_col.float() @ activation_mask_new.float()).to(coact.dtype)
    coact_new[first_dead_idx, :] = merged_coact
    coact_new[:, first_dead_idx] = merged_coact

    return merge, coact_new, activation_mask_new

# Merge dead components; also keep the reduced coactivation matrix / mask that
# are indexed by the resulting groups.
gm_dead, coact_dead, act_mask_dead = merge_dead_components(
    coact=COACTIVATIONS['co_occurrence_matrix'],
    activation_mask=COACTIVATIONS['active_mask'],
)
gm_dead.plot()

In [None]:
def merge_iteration(
    coact: Float[Tensor, "c_components c_components"],
    activation_mask: Bool[Tensor, "samples c_components"],
    initial_merge: GroupMerge | None = None,
    alpha: float = 1.0,
    iters: int = 100,
    check_threshold: float = 0.05,
    pop_component_prob: float = 0.0,
    plot_every: int = 50,
) -> GroupMerge:
    """Repeatedly merge one cheap pair of groups, chosen at random among the cheapest.

    Each iteration computes pairwise costs with `compute_merge_costs`, keeps the
    off-diagonal pairs whose cost lies in the bottom `check_threshold` fraction
    of the [min, max] off-diagonal range, picks one uniformly via
    ``random.randint`` and merges it with `recompute_coacts`.

    Args:
        coact: coactivation matrix, one row/column per group of the starting grouping.
        activation_mask: activation mask, one column per group of the starting grouping.
        initial_merge: starting grouping (identity if None); its `k_groups` must
            equal ``coact.shape[0]``.
        alpha: passed to `compute_merge_costs`.
        iters: maximum number of merges; stops early once one group remains.
        check_threshold: fraction of the off-diagonal cost range considered.
        pop_component_prob: currently ignored (component popping is not implemented).
        plot_every: plot the merge every this many iterations; ``<= 0`` disables.

    Returns:
        The final GroupMerge.
    """
    c_components: int = coact.shape[0]
    assert coact.shape[1] == c_components, "Coactivation matrix must be square"
    assert activation_mask.shape[1] == c_components, "Activation mask must match coactivation matrix shape"

    # TODO: component popping (pop_component_prob) is not implemented yet.

    current_merge: GroupMerge = (
        initial_merge
        if initial_merge is not None
        else GroupMerge.identity(n_components=c_components)
    )
    # coact / activation_mask are indexed by group, so they must match the starting
    # grouping; otherwise compute_merge_costs fails later with a shape mismatch.
    assert current_merge.k_groups == c_components, (
        "initial_merge must have exactly one group per row of coact"
    )

    k_groups: int = c_components
    current_coact: Float[Tensor, "k_groups k_groups"] = coact.clone()
    current_act_mask: Bool[Tensor, "samples k_groups"] = activation_mask.clone()

    for i in range(iters):
        if k_groups < 2:
            break  # a single group remains: nothing left to merge

        costs: Float[Tensor, "k_groups k_groups"] = compute_merge_costs(
            coact=current_coact,
            merges=current_merge,
            alpha=alpha,
        )

        # built on costs' device because it is combined elementwise with costs below
        off_diag: Bool[Tensor, "k_groups k_groups"] = ~torch.eye(k_groups, dtype=torch.bool, device=costs.device)
        off_diag_costs = costs[off_diag]
        cost_lo: float = off_diag_costs.min().item()
        cost_hi: float = off_diag_costs.max().item()
        max_considered_cost: float = (cost_hi - cost_lo) * check_threshold + cost_lo

        # candidate (i, j) pairs in row-major order; pick one uniformly at random
        considered_idxs = torch.nonzero((costs <= max_considered_cost) & off_diag)
        chosen = considered_idxs[random.randint(0, considered_idxs.shape[0] - 1)]
        merge_pair: tuple[int, int] = (int(chosen[0]), int(chosen[1]))

        current_merge, current_coact, current_act_mask = recompute_coacts(
            coact=current_coact,
            merges=current_merge,
            merge_pair=merge_pair,
            activation_mask=current_act_mask,
        )

        k_groups -= 1
        assert current_coact.shape == (k_groups, k_groups), "Coactivation matrix shape should match number of groups"
        assert current_act_mask.shape[1] == k_groups, "Activation mask shape should match number of groups"

        if plot_every > 0 and i % plot_every == 0:
            current_merge.plot()
            plt.show()

    return current_merge

# merge_iteration needs coact / activation_mask indexed by the groups of
# initial_merge, so start from the reduced tensors of the dead-component merge.
gm_dead_start, coact_dead_start, act_mask_dead_start = merge_dead_components(
    coact=COACTIVATIONS['co_occurrence_matrix'],
    activation_mask=COACTIVATIONS['active_mask'],
)
gm_iterated = merge_iteration(
    coact=coact_dead_start,
    activation_mask=act_mask_dead_start,
    initial_merge=gm_dead_start,
    alpha=1.0,
    iters=160,
)

In [None]:


def merge_iteration(
    coact: Float[Tensor, "k_groups k_groups"],
    active_mask: Bool[Tensor, "samples n_components"],
    merge: GroupMerge,
    alpha: float = 1.0,
) -> dict[str, Any]:
    """Placeholder for a single merge step: validates shapes and returns ``{}``.

    NOTE(review): this redefinition shadows the full `merge_iteration` defined in
    an earlier cell for every cell executed after this one. `merge` and `alpha`
    are not used yet.

    Raises:
        AssertionError: if `coact` is not square with one row per column of `active_mask`.
    """
    n_components: int = active_mask.shape[1]
    assert n_components == coact.shape[0] == coact.shape[1]

    # TODO: implement the merge step
    return {}

    


In [None]:
dbg_tensor(COACTIVATIONS['active_mask'])
# Pairwise costs of merging each pair of identity groups. compute_merge_costs takes
# the co-occurrence matrix (not the activation mask) and a GroupMerge.
merge_costs = compute_merge_costs(
    coact=COACTIVATIONS['co_occurrence_matrix'],
    merges=gm_ident,
    alpha=1.0,
)
dbg_tensor(merge_costs);

In [None]:
# compute_merge_costs already returns the (k_groups, k_groups) pair matrix for the
# identity grouping; mask the diagonal (a group cannot merge with itself) so the
# colour scale reflects only real pair merges.
merge_costs_grid = compute_merge_costs(
    coact=COACTIVATIONS['co_occurrence_matrix'],
    merges=gm_ident,
    alpha=1.0,
).cpu()
merge_costs_grid.fill_diagonal_(torch.nan)

fig, ax = plt.subplots()
im = ax.matshow(merge_costs_grid)
fig.colorbar(im, ax=ax, label='Merge Cost')
ax.set_title('Merge cost for each pair of groups');

In [None]:
def greedy_merge(
    activation_mask: Bool[Tensor, "n_samples n_components"],
    target_k_groups: int,
    alpha: float = 1.0,
    initial_guess: GroupMerge | None = None,
) -> GroupMerge:
    """Greedily merge the cheapest pair of groups until `target_k_groups` remain.

    Coactivations are built from the mask as ``mask^T @ mask`` (co-occurrence
    counts, the quantity `recompute_coacts` maintains for merged groups). Each
    step merges the off-diagonal pair with the minimum `compute_merge_costs`
    value; ties go to the first pair in row-major order.

    Args:
        activation_mask: per-sample activation mask, one column per starting group.
        target_k_groups: number of groups to stop at (must be >= 1).
        alpha: passed to `compute_merge_costs`.
        initial_guess: starting grouping (identity if None); its `k_groups` must
            equal the number of mask columns.

    Returns:
        The final GroupMerge.

    Raises:
        ValueError: if ``target_k_groups < 1``.
    """
    if target_k_groups < 1:
        raise ValueError(f"target_k_groups must be >= 1, got {target_k_groups}")

    n_components: int = activation_mask.shape[1]
    current_merge: GroupMerge = (
        GroupMerge.identity(n_components) if initial_guess is None else initial_guess
    )
    assert current_merge.k_groups == n_components, (
        "activation_mask must have one column per group of initial_guess"
    )

    current_mask: Bool[Tensor, "n_samples k_groups"] = activation_mask
    mask_f = current_mask.float()
    current_coact: Float[Tensor, "k_groups k_groups"] = mask_f.T @ mask_f

    while current_merge.k_groups > target_k_groups:
        costs = compute_merge_costs(coact=current_coact, merges=current_merge, alpha=alpha)
        costs.fill_diagonal_(float("inf"))  # never "merge" a group with itself
        flat_idx: int = int(torch.argmin(costs).item())
        i, j = divmod(flat_idx, costs.shape[1])
        current_merge, current_coact, current_mask = recompute_coacts(
            coact=current_coact,
            merges=current_merge,
            merge_pair=(i, j),
            activation_mask=current_mask,
        )

    return current_merge



greedy_merge(activation_mask=COACTIVATIONS['active_mask'], target_k_groups=10)

In [None]:


def generate_all_merge_matrices(n_components: int, k: int, device=None) -> Bool[Tensor, "n_matrices k n_components"]:
    """Generate every assignment of `n_components` components to `k` non-empty groups.

    Matrix ``m`` of the result has ``m[g, c] == True`` iff component ``c`` is in
    group ``g``. Assignment number ``i`` uses the base-``k`` digits of ``i``
    (least significant digit = component 0); only assignments in which every
    group is non-empty are kept, in increasing order of ``i``. Group labels are
    not deduplicated, so each partition into ``k`` blocks appears ``k!`` times.
    The candidate count is ``k ** n_components``, so this is for tiny inputs only.

    Args:
        n_components: number of components.
        k: number of groups.
        device: torch device for the result (CPU if None).

    Returns:
        Bool tensor of shape ``(n_matrices, k, n_components)``.
    """
    if device is None:
        device = torch.device('cpu')

    n_total: int = k ** n_components
    if n_total == 0:
        return torch.zeros(0, k, n_components, dtype=torch.bool, device=device)

    # digit j of i in base k is (i // k**j) % k -- vectorised over all (i, j)
    # instead of the previous O(k**n * n) Python double loop
    powers = torch.tensor([k ** j for j in range(n_components)], dtype=torch.long, device=device)
    assignments = (torch.arange(n_total, device=device).unsqueeze(1) // powers) % k  # [n_total, n_components]

    # one-hot: merge_matrices[i, assignments[i, c], c] = True
    merge_matrices = torch.zeros(n_total, k, n_components, dtype=torch.bool, device=device)
    batch_indices = torch.arange(n_total, device=device).unsqueeze(1)  # [n_total, 1]
    component_indices = torch.arange(n_components, device=device).unsqueeze(0)  # [1, n_components]
    merge_matrices[batch_indices, assignments, component_indices] = True

    # keep only assignments where every group has at least one component
    valid_mask = merge_matrices.any(dim=2).all(dim=1)  # [n_total]
    return merge_matrices[valid_mask]


def find_all_merge_costs(
    co_occurrence_matrix: Float[Tensor, "n_components n_components"],
    marginal_counts: Float[Tensor, "n_components"],
    k: int,
    alpha: float = 1.0,
) -> Float[Tensor, "n_matrices"]:
    """Score every assignment of the components to `k` non-empty groups.

    Candidates come from `generate_all_merge_matrices` (``k ** n_components``
    assignments before filtering, so only practical for tiny inputs) and are
    scored with a single `compute_merge_costs` call.

    NOTE(review): the call passes (co_occurrence_matrix, marginal_counts,
    merge_matrices, alpha) positionally, which does not match the
    `compute_merge_costs(coact, merges, alpha, rank_cost)` defined earlier in
    this notebook (`marginal_counts` would be bound to `merges`). It appears to
    target an older cost function -- TODO port or remove.
    """
    n_components: int = marginal_counts.shape[0]
    candidate_matrices = generate_all_merge_matrices(n_components, k, marginal_counts.device)
    return compute_merge_costs(co_occurrence_matrix, marginal_counts, candidate_matrices, alpha)


def greedy_merge_search(
    co_occurrence_matrix: Float[Tensor, "n_components n_components"],
    marginal_counts: Float[Tensor, "n_components"],
    k: int,
    alpha: float = 1.0,
    temperature: float = 0.0,
    seed: Optional[int] = None,
) -> tuple[Bool[Tensor, "k n_components"], float]:
    """Greedy search with optional temperature sampling.

    Starts from the identity merge matrix (one group per component). While more
    than `k` groups remain, it scores every candidate obtained by merging one
    pair of groups (i < j, merged into row i, row j removed). With
    ``temperature == 0`` it takes the lowest cost delta (first on ties);
    otherwise it samples from ``softmax(-delta / temperature)``.

    NOTE(review): `compute_merge_costs` is called as (co_occurrence_matrix,
    marginal_counts, merge_matrix, alpha), which does not match the
    `compute_merge_costs(coact, merges, alpha, rank_cost)` defined earlier in
    this notebook. It appears to target an older cost function -- TODO port.

    Args:
        co_occurrence_matrix: passed through to the cost function.
        marginal_counts: per-component counts; its length sets ``n_components``.
        k: target number of groups.
        alpha: passed through to the cost function.
        temperature: 0 for greedy argmin, > 0 for softmax sampling.
        seed: if given, passed to ``torch.manual_seed`` (reseeds the global RNG).

    Returns:
        ``(merge_matrix, final_cost)``: a bool merge matrix with one row per
        group, and whatever the cost function returns for it.
    """
    if seed is not None:
        torch.manual_seed(seed)

    n_components = marginal_counts.shape[0]
    device = marginal_counts.device

    # Start with identity
    merge_matrix = torch.eye(n_components, dtype=torch.bool, device=device)
    current_k = n_components

    while current_k > k:
        # Every candidate obtained by merging one pair (i, j), i < j, of the current groups
        candidates = []
        for i in range(current_k):
            for j in range(i + 1, current_k):
                merged = merge_matrix.clone()
                merged[i] = merged[i] | merged[j]
                # Drop row j by concatenation. The previous in-place shift
                # `candidate[j:k-1] = candidate[j+1:k]` copied between overlapping
                # slices of the same tensor, which torch does not support safely.
                candidates.append(torch.cat((merged[:j], merged[j + 1:]), dim=0))

        if not candidates:
            break

        candidate_stack = torch.stack(candidates)  # [n_candidates, current_k-1, n_components]
        current_cost = compute_merge_costs(co_occurrence_matrix, marginal_counts, merge_matrix, alpha)
        candidate_costs = compute_merge_costs(co_occurrence_matrix, marginal_counts, candidate_stack, alpha)
        cost_deltas = candidate_costs - current_cost

        # Select based on temperature
        if temperature == 0.0:
            best_idx = cost_deltas.argmin().item()
        else:
            probs = torch.softmax(-cost_deltas / temperature, dim=0)
            best_idx = torch.multinomial(probs, 1).item()

        merge_matrix = candidates[best_idx]
        current_k -= 1

    final_cost = compute_merge_costs(co_occurrence_matrix, marginal_counts, merge_matrix, alpha)
    return merge_matrix, final_cost