In [1]:
import torch
from utils import *
# Example usage:
from transformers import AutoModelForCausalLM, AutoTokenizer
# model_name = "gpt2"
# model_name = "HuggingFaceTB/SmolLM-360M"
model_name = "HuggingFaceTB/SmolLM-135M"
m_name_save= model_name.replace("/", "_")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_bos_token = True
# batch_size = 512
# batch_size = 32
batch_size =32
max_length = 128
learning_rate = 1e-3
if(model_name == "gpt2"):
    target_layer = 'transformer.h.5'
    d_name = None
else: 
    target_layer = "model.layers.18"
    d_name = "cosmopedia-v2"
debug = True
if(debug):
    if(model_name == "gpt2"):
        dataset_name = "Elriggs/openwebtext-100k"
    else: 
        dataset_name = "HuggingFaceTB/smollm-corpus"
    # num_datapoints = 1_000_000
    num_datapoints = 20_000
    # num_datapoints = 15_000
    # num_datapoints = 2_000
    total_batches = num_datapoints // batch_size
    print(f"total amount of tokens in dataset: {num_datapoints * max_length / 1e6}M")
else:    
    if(model_name == "gpt2"):
        dataset_name = "prithivMLmods/OpenWeb888K"
        num_datapoints = 888_000 # 880_000
        total_batches = 888_000 // batch_size
    else: 
        dataset_name = "HuggingFaceTB/smollm-corpus"
        num_datapoints = 2_000_000
        total_batches = num_datapoints // batch_size
        print(f"total amount of tokens in dataset: {num_datapoints * max_length / 1e6}M")

data_generator = TokenizedDataset(dataset_name, tokenizer, d_name, batch_size=batch_size, max_length=max_length, total_batches=total_batches)

val_data_generator = TokenizedDataset(dataset_name, tokenizer, d_name, batch_size=batch_size, max_length=max_length, total_batches=total_batches, shuffle_seed=98)

total amount of tokens in dataset: 2.56M


Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

In [2]:
# Now we want to download all our SAEs
import json
from utils import DotDict, AutoEncoderTopK
from huggingface_hub import hf_hub_download
huggingface_name = "Elriggs/seq_concat_HuggingFaceTB_SmolLM-135M_model.layers.18"
name_prefix = f"sae_k=30_tokBias=True"
sae_name_style = name_prefix + ".pt"
cfg_name_style = name_prefix + "_cfg.json"

model_path = hf_hub_download(
    repo_id=huggingface_name,
    filename=sae_name_style
)

# Download config file
config_path = hf_hub_download(
    repo_id=huggingface_name,
    filename=cfg_name_style
)
cfg = DotDict(json.load(open(config_path)))

sae = AutoEncoderTopK.from_pretrained(model_path, k=cfg.k, device = None, embedding=True)
print(f"Memory usage: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

sae.to(device)
# print memory  in GB
print(f"Memory usage: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

Memory usage: 0.51 GB
Memory usage: 0.65 GB


# Train SAE offset by 1 batch 
so we don't need to recompute the LLM's activations

# Logan's Attempt

In [3]:
class TemperatureScheduler:
    def __init__(self, 
                 high_temp=1e-3,
                 low_temp=0.0,
                 start_high_fraction=0.2,  # First 20% at high temp
                 end_high_fraction=0.6,    # High temp until 60% of training
                 end_transition_fraction=0.8):  # Low temp from 80% to end
        """
        Temperature scheduler that controls noise injection during training.
        
        Args:
            high_temp: The higher temperature value
            low_temp: The lower temperature value (typically 0)
            start_high_fraction: Fraction of training where temp reaches high_temp
            end_high_fraction: Fraction of training where temp starts to decrease
            end_transition_fraction: Fraction of training where temp reaches low_temp
        """
        self.high_temp = high_temp
        self.low_temp = low_temp
        self.start_high_fraction = start_high_fraction
        self.end_high_fraction = end_high_fraction
        self.end_transition_fraction = end_transition_fraction
        self.current_state = None
        
    def get_temperature(self, current_step, total_steps):
        """
        Calculate the temperature for the current training step.
        
        Args:
            current_step: Current training step (0-indexed)
            total_steps: Total number of training steps
            
        Returns:
            The temperature value for the current step
        """
        # Convert to fraction of total training
        progress = current_step / total_steps
        
        # Phase 1: Ramp up from low_temp to high_temp
        if progress < self.start_high_fraction:
            self.current_state = "ramping_up"
            # Linear interpolation from low_temp to high_temp
            alpha = progress / self.start_high_fraction
            return self.low_temp + (self.high_temp - self.low_temp) * alpha
        
        # Phase 2: Stay at high_temp
        elif progress < self.end_high_fraction:
            self.current_state = "high_temp"
            return self.high_temp
        
        # Phase 3: Ramp down from high_temp to low_temp
        elif progress < self.end_transition_fraction:
            self.current_state = "ramping_down"
            # Linear interpolation from high_temp back to low_temp
            alpha = (progress - self.end_high_fraction) / (self.end_transition_fraction - self.end_high_fraction)
            return self.high_temp - (self.high_temp - self.low_temp) * alpha
        
        # Phase 4: Stay at low_temp until the end
        else:
            self.current_state = "low_temp"
            return self.low_temp

In [4]:
import os
def calc_ce_loss(logits, batch, reduction="none"):
    # first shift each by one. No ground truth for last token, no prediction for first token
    shifted_logits = logits[:, :-1]  # Remove prediction for last token
    shifted_targets = batch[:, 1:]   # Remove ground truth for first token
    
    batch_size, seq_len, vocab_size = shifted_logits.shape
    #print batch size
    

    ce_loss = F.cross_entropy(
        shifted_logits.contiguous().view(-1, shifted_logits.size(-1)),
        shifted_targets.contiguous().view(-1),
        reduction=reduction
    )
    ce_loss = ce_loss.view(batch_size, seq_len)
    return ce_loss

def update_feature_sim(feature_sim, per_feature_MSE, local_ids, num_features, batch_idx):
    # find out how many times each feature index shows up in local_ids
    num_features = post_relu_feat_acts_BF.shape[-1]
    # Get unique indices and their counts
    unique_indices, counts = torch.unique(local_ids, return_counts=True)

    unique_counts = torch.zeros(num_features, device=local_ids.device)
    unique_counts.index_put_((unique_indices,), counts.float())

    feature_mse_sum = torch.zeros(num_features, device=local_ids.device)
    feature_mse_sum.scatter_add_(0, local_ids.flatten(), per_feature_MSE.flatten())

    # Compute the average MSE for each feature by dividing by its count
    # (avoiding division by zero for features that don't appear)
    feature_mask = unique_counts > 0
    feature_mse_avg = torch.zeros_like(feature_mse_sum)
    feature_mse_avg[feature_mask] = feature_mse_sum[feature_mask] / unique_counts[feature_mask]

    feature_sim[batch_idx, unique_indices] = feature_mse_avg[unique_indices].to("cpu")
    return feature_sim

def sub_bias_and_normalize(x, sae, batch, normalize=True):
    if normalize:
        x = (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-8)
    if sae.per_token_bias:
        x = x - sae.per_token_bias(batch)
    return x

import math
# lots of code used from https://github.com/bfpill/devinterp-1/blob/main/modular_addition/calc_lambda.ipynb
# feature_sim = torch.ones(total_batches, 9216)*-0.1 # num of features
num_features = sae.encoder.weight.shape[0]
dataset_size = num_datapoints* max_length
# beta_star_scalar = dataset_size / math.log(dataset_size) # 1/Temp (From App A)
beta_star_scalar = 1
lr = 1e-4  # "epsilon"
gamma = 0  # Elasticity Regularization (diff from original weights w*)

normalize = True


#temp 5.0e-4 went t. 2.0 & never recovered (also gamm=10)
# Temperature parameters
# max_temp = 2.0e-4
max_temp = 4.0e-4
lowest_temp = 0.0
peak_noise_fraction = 0.1
start_decreasing_fraction = 0.5
stay_at_zero_fraction = 0.6
tempScheduler = TemperatureScheduler(
    max_temp,
    lowest_temp,
    peak_noise_fraction,
    start_decreasing_fraction,
    stay_at_zero_fraction
)
# save_checkpoints_every = 10
# batches_to_run = 10 #for validation
save_checkpoints_every = 50
batches_to_run = 20 #for validation


from copy import deepcopy
orig_model = deepcopy(model)
orig_sae = deepcopy(sae)
all_ce_losses = []
# all_elasticity_losses = []
all_ce_diffs = []
fvus = []
sgd_opt = torch.optim.SGD(model.parameters(), lr=1.0)
adam_opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
sae_opt = torch.optim.Adam(sae.parameters(), lr=learning_rate)
restarted_adam=False
for batch_idx in tqdm(range(total_batches)):
    current_opt = adam_opt if tempScheduler.current_state == "low_temp" else sgd_opt
    current_opt.zero_grad() 
    sae_opt.zero_grad()
    if(cfg.norm_decoder):
        sae.set_decoder_norm_to_unit_norm()
    batch = data_generator.next().to(device)
    with Trace(model, target_layer) as original_trace:
        ce_loss = model(batch, labels=batch).loss
        new_x = original_trace.output[0] if isinstance(original_trace.output, tuple) else original_trace.output
        new_x = new_x.detach()

    # SAE
    # w = torch.nn.utils.parameters_to_vector(model.parameters())
    # elasticity_loss = gamma* torch.mean(((w_0 - w) ** 2))
    elasticity_loss = 0.0

    full_loss = lr/2 * (beta_star_scalar*ce_loss + elasticity_loss)
    full_loss.backward()

    if(tempScheduler.current_state == "low_temp" and not restarted_adam):
        # Restart Adam optimizer
        print("Restarting Adam optimizer")
        for state in adam_opt.state.values():
            state.clear()
        adam_opt = torch.optim.Adam(model.parameters(), lr=1e-5)
        restarted_adam = True
    # Choose optimizer based on temperature state
    current_opt.step()
    # Store losses
    all_ce_losses.append(ce_loss.item())
    # all_elasticity_losses.append(elasticity_loss.item())

    # SAE
    new_x = (new_x - new_x.mean(dim=-1, keepdim=True)) / (new_x.std(dim=-1, keepdim=True) + 1e-8)
    bias = sae.per_token_bias(batch)
    x_hat = sae(new_x - bias) + bias
    with torch.no_grad():
        fvu = calculate_fvu(new_x, x_hat)
    mse = torch.mean((x_hat - new_x) ** 2)
    mse.backward()
    if(cfg.norm_decoder):
        sae.remove_gradient_parallel_to_decoder_directions()

    sae_opt.step()

    # Add noise to the parameters
    with torch.no_grad():
        if tempScheduler.current_state != "low_temp":
            temperature = tempScheduler.get_temperature(batch_idx, total_batches)
            new_params = torch.nn.utils.parameters_to_vector(model.parameters()) 
            noise = torch.randn_like(new_params) * temperature
            torch.nn.utils.vector_to_parameters(new_params + noise, model.parameters())
        # if batch_idx % 1 ==0:
            # Check CE diff w/ original model 
    

        
    # SAE sims
    with torch.no_grad():
        # Get original model's acts & SAE features
        with Trace(orig_model, target_layer) as original_trace:
            original_ce_loss = orig_model(batch, labels=batch).loss
        #     old_x = original_trace.output[0] if isinstance(original_trace.output, tuple) else original_trace.output
        #     old_x = old_x.detach()
        # # x = x.normalize - bias (so ready for SAE.encode)
        # old_x = sub_bias_and_normalize(old_x, orig_sae, batch, normalize=normalize)

        # _, local_mags, local_ids = orig_sae.encode(old_x, return_topk=True)

             
        # # New x is already normalized & -bias
        # new_x = sub_bias_and_normalize(new_x, sae, batch, normalize=normalize)
        
        # post_relu_feat_acts_BF = nn.functional.relu(sae.encoder(new_x - sae.b_dec))
        # # get the activations at the feature_ids (feature_ids is = to top_indices_BK)
        # custom_feature_acts = post_relu_feat_acts_BF.gather(dim=-1, index=local_ids)
        # per_feature_MSE = ((custom_feature_acts - local_mags)**2)
        # feature_sim = update_feature_sim(feature_sim, per_feature_MSE, local_ids, num_features = post_relu_feat_acts_BF.shape[-1], batch_idx=batch_idx)
        #TODO above should be squared


        # Get CE diff
        ce_diff = ce_loss - original_ce_loss
        all_ce_diffs.append(ce_diff.item())
    if batch_idx % 25 == 0:
        # calculate fvu of the x & x_hat
        # fvu = calculate_fvu(new_x, x_hat)
        # fvus.append(fvu)
    # if batch_idx % 1 == 0:
        # do temp in einsteins notation
        # print(f"Batch {batch_idx}/{total_batches}, Loss: {ce_loss.item():.3f} + {elasticity_loss.item():.5f} | temperature: {temperature:.1e}")
        print(f"Batch {batch_idx}/{total_batches}, CE diff: {ce_diff.item():.3f}, fvu: {fvu:.3f} | temperature: {temperature:.1e}")
    with torch.no_grad():
        # save first batch info:
        d_model = model.config.hidden_size
        os.makedirs("checkpoints", exist_ok=True)
        # First batch
        if batch_idx == 0:
            all_tokens = torch.zeros(batches_to_run*batch_size, max_length-1, dtype=torch.long)
            val_original_ce_loss = torch.zeros(batches_to_run*batch_size, max_length-1, dtype=torch.float16)
            all_activations = torch.zeros(batches_to_run*batch_size, max_length-1, d_model, dtype=torch.float16)
            val_data_generator.reset_iterator()
            for i in range(batches_to_run):
                batch = val_data_generator.next().to(device)
                with Trace(model, target_layer) as original_trace:
                    original_logits = model(batch, labels=batch, reduction="none").logits
                    new_x = original_trace.output[0] if isinstance(original_trace.output, tuple) else original_trace.output
                    new_x = new_x.detach()
                ce_loss = calc_ce_loss(original_logits, batch)

                val_original_ce_loss[i*batch_size:(i+1)*batch_size] = ce_loss.cpu() # already max_len - 1
                all_activations[i*batch_size:(i+1)*batch_size] = new_x[:, 1:].cpu()
                all_tokens[i*batch_size:(i+1)*batch_size] = batch[:, 1:].cpu()
            torch.save({
                "ce_loss": val_original_ce_loss,
                "all_activations": all_activations,
                "all_tokens": all_tokens,
            }, f"checkpoints/first_batch_info_skip_first_token.pt")
            # DELETE all tokens
            del all_tokens
            del all_activations
            del val_original_ce_loss
            del original_logits
            import gc
            gc.collect()
            # clear cache
            torch.cuda.empty_cache()
            
        elif (batch_idx) % save_checkpoints_every == 0:
            # if we shoot for 1m tokens then,
            # batches_to_run = 1_000_000 // (batch_size*max_length)
            num_features = orig_sae.encoder.out_features

            original_feature_acts = torch.zeros(batches_to_run*batch_size, max_length-1, num_features, dtype=torch.float16)
            new_feature_acts = torch.zeros(batches_to_run*batch_size, max_length-1, num_features, dtype=torch.float16)
            val_new_ce_loss = torch.zeros(batches_to_run*batch_size, max_length-1, dtype=torch.float16)
            all_activations = torch.zeros(batches_to_run*batch_size, max_length-1, d_model, dtype=torch.float16)

            print("Saving feature activations")
            for i in tqdm(range(batches_to_run),  total=batches_to_run):
                batch = val_data_generator.next().to(device)
                with Trace(model, target_layer) as original_trace:
                    logits = model(batch, labels=batch).logits
                    new_x = original_trace.output[0] if isinstance(original_trace.output, tuple) else original_trace.output
                    new_x = new_x.detach()
                new_ce_loss = calc_ce_loss(logits, batch)

                with Trace(orig_model, target_layer) as original_trace:
                    # We could just get the SAE features in first batch, BUT we do need the local_ids of original-SAE for the new one
                    _ = orig_model(batch, labels=batch) # Original CE calculated in first batch
                    old_x = original_trace.output[0] if isinstance(original_trace.output, tuple) else original_trace.output
                    old_x = old_x.detach()

                # x = x.normalize - bias (so ready for SAE.encode)
                old_x = sub_bias_and_normalize(old_x, orig_sae, batch, normalize=normalize)
                new_normalized_x = sub_bias_and_normalize(new_x, sae, batch, normalize=normalize)

                original_features, local_mags, local_ids = orig_sae.encode(old_x, return_topk=True)

                post_relu_feat_acts_BF = nn.functional.relu(sae.encoder(new_normalized_x - sae.b_dec))
                # get the activations at the feature_ids (feature_ids is = to top_indices_BK)
                custom_feature_acts = post_relu_feat_acts_BF.gather(dim=-1, index=local_ids)
                buffer_BF = torch.zeros_like(post_relu_feat_acts_BF)
                custom_features = buffer_BF.scatter_(dim=-1, index=local_ids, src=custom_feature_acts)

                # feature_sim = update_feature_sim(feature_sim, per_feature_MSE, local_ids, num_features = post_relu_feat_acts_BF.shape[-1], batch_idx=batch_idx)
                start_idx = i*(batch_size)
                end_idx = start_idx + (batch_size)  
                # Let's ignore the first token for all features in general
                val_new_ce_loss[start_idx:end_idx] = new_ce_loss # first token ignored by default
                original_feature_acts[start_idx:end_idx] = original_features[:, 1:].cpu() 
                new_feature_acts[start_idx:end_idx] = custom_features[:, 1:].cpu()
                all_activations[start_idx:end_idx] = new_x[:, 1:].cpu()

            # save the feature_sim, original_feature_acts, new_feature_acts
            # make  dir if not exists
            torch.save({
                # "feature_sim": feature_sim,
                "ce_loss": val_new_ce_loss,
                "all_activations": all_activations,
                "original_feature_acts": original_feature_acts,
                "new_feature_acts": new_feature_acts}, f"checkpoints/feature_sim_{batch_idx}.pt")

  0%|          | 0/625 [00:00<?, ?it/s]

Batch 0/625, CE diff: 0.000, fvu: 0.117 | temperature: 0.0e+00


  4%|▍         | 26/625 [00:33<09:04,  1.10it/s] 

Batch 25/625, CE diff: -0.001, fvu: 0.113 | temperature: 1.6e-04


  8%|▊         | 50/625 [00:55<08:53,  1.08it/s]

Batch 50/625, CE diff: -0.000, fvu: 0.114 | temperature: 3.2e-04
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 12%|█▏        | 76/625 [01:42<08:31,  1.07it/s]  

Batch 75/625, CE diff: 0.011, fvu: 0.129 | temperature: 4.0e-04


 16%|█▌        | 100/625 [02:05<08:15,  1.06it/s]

Batch 100/625, CE diff: 0.018, fvu: 0.118 | temperature: 4.0e-04
Saving feature activations


100%|██████████| 20/20 [00:17<00:00,  1.16it/s]
 20%|██        | 126/625 [02:52<07:50,  1.06it/s]  

Batch 125/625, CE diff: 0.029, fvu: 0.123 | temperature: 4.0e-04


 24%|██▍       | 150/625 [03:15<07:30,  1.06it/s]

Batch 150/625, CE diff: 0.034, fvu: 0.116 | temperature: 4.0e-04
Saving feature activations


100%|██████████| 20/20 [00:17<00:00,  1.18it/s]
 28%|██▊       | 176/625 [04:02<07:05,  1.06it/s]  

Batch 175/625, CE diff: 0.043, fvu: 0.121 | temperature: 4.0e-04


 32%|███▏      | 200/625 [04:25<06:46,  1.05it/s]

Batch 200/625, CE diff: 0.058, fvu: 0.121 | temperature: 4.0e-04
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 36%|███▌      | 226/625 [05:11<06:15,  1.06it/s]

Batch 225/625, CE diff: 0.038, fvu: 0.117 | temperature: 4.0e-04


 40%|████      | 250/625 [05:34<05:59,  1.04it/s]

Batch 250/625, CE diff: 0.050, fvu: 0.125 | temperature: 4.0e-04
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 44%|████▍     | 276/625 [06:21<05:30,  1.05it/s]

Batch 275/625, CE diff: 0.073, fvu: 0.124 | temperature: 4.0e-04


 48%|████▊     | 300/625 [06:43<05:11,  1.04it/s]

Batch 300/625, CE diff: 0.071, fvu: 0.124 | temperature: 4.0e-04
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 52%|█████▏    | 326/625 [07:30<04:42,  1.06it/s]

Batch 325/625, CE diff: 0.062, fvu: 0.125 | temperature: 3.2e-04


 56%|█████▌    | 350/625 [07:53<04:22,  1.05it/s]

Batch 350/625, CE diff: 0.066, fvu: 0.127 | temperature: 1.6e-04
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 60%|██████    | 376/625 [08:40<03:54,  1.06it/s]

Batch 375/625, CE diff: 0.064, fvu: 0.126 | temperature: 0.0e+00
Restarting Adam optimizer


 64%|██████▍   | 400/625 [09:03<03:37,  1.03it/s]

Batch 400/625, CE diff: 0.067, fvu: 0.134 | temperature: 0.0e+00
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 68%|██████▊   | 426/625 [09:50<03:10,  1.05it/s]

Batch 425/625, CE diff: 0.043, fvu: 0.129 | temperature: 0.0e+00


 72%|███████▏  | 450/625 [10:13<02:48,  1.04it/s]

Batch 450/625, CE diff: 0.016, fvu: 0.130 | temperature: 0.0e+00
Saving feature activations


100%|██████████| 20/20 [00:17<00:00,  1.18it/s]
 76%|███████▌  | 476/625 [11:00<02:22,  1.05it/s]

Batch 475/625, CE diff: 0.029, fvu: 0.120 | temperature: 0.0e+00


 80%|████████  | 500/625 [11:23<02:00,  1.03it/s]

Batch 500/625, CE diff: 0.024, fvu: 0.128 | temperature: 0.0e+00
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
 84%|████████▍ | 526/625 [12:10<01:34,  1.05it/s]

Batch 525/625, CE diff: 0.008, fvu: 0.126 | temperature: 0.0e+00


 88%|████████▊ | 550/625 [12:34<01:12,  1.03it/s]

Batch 550/625, CE diff: 0.002, fvu: 0.119 | temperature: 0.0e+00
Saving feature activations


100%|██████████| 20/20 [00:17<00:00,  1.17it/s]
 92%|█████████▏| 576/625 [13:22<00:46,  1.05it/s]

Batch 575/625, CE diff: 0.008, fvu: 0.118 | temperature: 0.0e+00


 96%|█████████▌| 600/625 [13:45<00:24,  1.03it/s]

Batch 600/625, CE diff: -0.002, fvu: 0.131 | temperature: 0.0e+00
Saving feature activations


100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
100%|██████████| 625/625 [14:31<00:00,  1.39s/it]


576

In [1]:
(ce_loss - original_ce_loss).shape, start_idx - end_idx

NameError: name 'ce_loss' is not defined

In [10]:
all_ce_diffs[start_idx:end_idx] = ce_loss - original_ce_loss

RuntimeError: The expanded size of the tensor (14) must match the existing size (32) at non-singleton dimension 0.  Target sizes: [14, 127].  Tensor sizes: [32, 127]

In [12]:
all_ce_diffs.shape, end_idx

(torch.Size([1550, 127]), 1568)

In [None]:
per_feature_MSE.min()
# find out how many times each feature index shows up in local_ids
num_features = post_relu_feat_acts_BF.shape[-1]
# Get unique indices and their counts
unique_indices, counts = torch.unique(local_ids, return_counts=True)

unique_counts = torch.zeros(num_features, device=local_ids.device)
unique_counts.index_put_((unique_indices,), counts.float())

feature_mse_sum = torch.zeros(num_features, device=local_ids.device)
feature_mse_sum.scatter_add_(0, local_ids.flatten(), per_feature_MSE.flatten())
# set where local_mags=0 to -1 for feature_mse_sum

# Compute the average MSE for each feature by dividing by its count
# (avoiding division by zero for features that don't appear)
feature_mask = unique_counts > 0
feature_mse_avg = torch.zeros_like(feature_mse_sum)
feature_mse_avg[feature_mask] = feature_mse_sum[feature_mask] / unique_counts[feature_mask]

feature_sim[batch_idx, unique_indices] = feature_mse_avg[unique_indices].to("cpu")
feature_mse_avg[unique_indices].shape, feature_sim[batch_idx, :, unique_indices].shape

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)