In [1]:
from notebook_utils import load_tinymodel, load_tinydataset, load_saes, load_module_names
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
llm = load_tinymodel()
dataset = load_tinydataset(batch_size=32, max_seq_length=128, num_datapoints=1000)
all_saes = load_saes(k=30)
module_names = load_module_names(llm)

100%|██████████| 1000/1000 [00:01<00:00, 533.81it/s]
  state_dict = torch.load(path)


Failed to tokenize 0 tokens
Number of datapoints w/ 129 tokens: 952
Total Tokens: 0.122808M


In [15]:
def calculate_gpu_memory(b, i, j, dtype=torch.float32):
    bytes_per_element = torch.tensor([], dtype=dtype).element_size()
    
    input_features_memory = b * i * bytes_per_element
    virtual_weights_memory = i * j * bytes_per_element
    attribution_memory = b * i * j * bytes_per_element
    
    total_memory = input_features_memory + virtual_weights_memory + attribution_memory
    
    return {
        "input_features_memory": input_features_memory,
        "virtual_weights_memory": virtual_weights_memory,
        "attribution_memory": attribution_memory,
        "total_memory": total_memory,
        "total_memory_gb": total_memory / (1024**3)
    }

# Example usage
b = 1*128
i, j =  6500, 6500  # Your dimensions
memory_info = calculate_gpu_memory(b, i, j)

print(f"Input Features Memory: {memory_info['input_features_memory'] / (1024**2):.2f} MB")
print(f"Virtual Weights Memory: {memory_info['virtual_weights_memory'] / (1024**2):.2f} MB")
print(f"Attribution Memory: {memory_info['attribution_memory'] / (1024**2):.2f} MB")
print(f"Total Memory: {memory_info['total_memory_gb']:.2f} GB")

Input Features Memory: 3.17 MB
Virtual Weights Memory: 161.17 MB
Attribution Memory: 20629.88 MB
Total Memory: 20.31 GB


In [37]:
from einops import rearrange, einsum


target_sae_names = ['torso_1_mlp_out_transcoder', 'torso_1_res_final']
saes = [all_saes[name].to(device) for name in target_sae_names]
resid_mid = llm.torso[1].res_mlp
resid_final = llm.torso[1].res_final
mlp_out = llm.torso[1].mlp
for batch_ind, batch in enumerate(dataset):
    batch = batch.to(device)
    with torch.no_grad():
        with llm.trace(batch) as tracr:
            act_res_mid = resid_mid.output.save()
            act_res_final = resid_final.output.save()
            act_mlp_out = mlp_out.output.save()
        # Now we want to run through the saes
        transcoder = saes[0].to(device)
        sae_final = saes[1].to(device)
        # mlp_out_hat = transcoder(act_res_mid)

        # sae_final_features_hat = sae_final.encode(mlp_out_hat+act_res_mid)
        # maybe figure out a way to fold in the decoder bias?
        tr_dec = transcoder.decoder.weight
        #TODO: we removed the last weight to help w/ knowing .T and shape. 
        final_enc = sae_final.encoder.weight
        virtual_weights = tr_dec.T @ final_enc.T

        act_res_mid = act_res_mid.to(device)
        act_res_mid = rearrange(act_res_mid, 'b s d_model -> (b s) d_model')
        input_features, input_acts, input_indices = transcoder.encode(act_res_mid, return_topk=True)
        # input_features = rearrange(input_features, 'b s f -> (b s) f')
        # input_acts = rearrange(input_acts, 'b s f-> (b s) f')
        # input_indices = rearrange(input_indices, 'b s f -> (b s) f')
        mlp_out_hat = transcoder.decoder(input_features)

        output_features, output_acts, output_indices = sae_final.encode(mlp_out_hat + act_res_mid, return_topk=True)

        # Gradient equals the weights
        # attribution = torch.einsum('bi,ij->bij', input_features, virtual_weights)

        break

In [34]:
act_res_mid.shape

torch.Size([32, 129, 768])

In [39]:
output_acts[output_indices==0]
input_acts[output_indices==0]
(output_indices==0).sum(-1) != 0, input_acts.shape

(tensor([False, False, False,  ..., False, False, False], device='cuda:0'),
 torch.Size([4128, 30]))

In [40]:
((output_indices==0).sum(-1) != 0).nonzero()

tensor([[2369],
        [3193],
        [3194],
        [3204],
        [3207],
        [3798]], device='cuda:0')

In [90]:
current_output_feature = 0
num_input_features = input_features.shape[-1]
num_output_features = output_features.shape[-1]
feature_by_feature_attribution = torch.zeros(num_input_features, num_output_features)
features_set_yet = torch.zeros(num_output_features, dtype=torch.bool)
for current_output_feature in range(num_output_features):
    # Get the batch indices where the output feature is non-zero
    nz_batch_indices = (output_indices==current_output_feature).sum(-1).nonzero()[:, 0]
    output_virtual_weights = virtual_weights[:, current_output_feature]

    # Index into the virtual weights & input indices ie find the inputs that activated the output feature
    nz_input_ind = input_indices[nz_batch_indices]
    batched_virtual_weights = output_virtual_weights[nz_input_ind]
    nz_input_acts = input_acts[nz_batch_indices]

    # Calculate the attribution ie act*gradient
    current_output_attribution = nz_input_acts * batched_virtual_weights 

    # Normalize the attributions (by abs value cause negative gradients)
    total_abs_value = current_output_attribution.abs().sum(dim=-1)
    normed_current_output_attribution = current_output_attribution / total_abs_value[:, None]

    # Set the feature by feature attribution (average w/ existing attributions)    

In [106]:
normed_current_output_attribution.abs().sum(dim=-1), normed_current_output_attribution.mean(dim=0).abs().sum(dim=-1)

(tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0'),
 tensor(0.7414, device='cuda:0'))

In [110]:
normed_current_output_attribution.abs().mean(dim=0).abs().sum()

tensor(1., device='cuda:0')

In [105]:
normed_current_output_attribution.shape

torch.Size([6, 30])

In [96]:
total_abs_value = current_output_attribution.abs().sum(dim=-1)
# Normalize: 
normed_current_output_attribution = current_output_attribution / total_abs_value[:, None]

AssertionError: 

In [100]:
normed_current_output_attribution

tensor([[-3.6227e-02, -2.6679e-03, -2.5494e-02, -1.2283e-03,  1.7044e-01,
         -4.9746e-02, -1.8890e-02, -6.3689e-02, -3.3854e-02,  1.7237e-01,
         -6.0978e-03,  2.1763e-02, -4.3416e-02,  7.5457e-02,  2.9024e-02,
         -6.1067e-03,  5.2058e-03, -3.8862e-02, -6.2616e-03, -1.0429e-02,
         -2.1213e-02, -3.2377e-02, -2.1631e-02,  1.1801e-02, -4.8583e-02,
          7.9568e-03,  3.5445e-03, -9.7144e-03, -1.8767e-02, -7.1831e-03],
        [-1.3719e-02,  2.5239e-03,  4.5559e-01,  6.0995e-03,  1.5561e-01,
         -4.0269e-02, -1.2523e-02, -4.5647e-02,  6.8993e-03, -1.0656e-02,
         -5.3215e-02, -1.6520e-02, -9.7398e-03,  2.4028e-03, -1.1964e-02,
         -5.8181e-03, -1.9840e-02, -1.1012e-03, -1.8136e-02,  2.4025e-03,
          2.2144e-04,  1.3957e-02, -2.2930e-02,  9.7942e-03, -1.6925e-02,
         -5.8936e-03, -6.2075e-03,  1.3838e-02, -1.9537e-02,  2.1156e-05],
        [-9.3610e-02,  2.0581e-03,  2.8441e-03, -1.7609e-02,  3.2889e-01,
          2.9626e-03, -5.8919e-03, -

In [97]:
normed_current_output_attribution.norm(dim=-1)

tensor([0.2906, 0.4918, 0.3760, 0.3835, 0.3477, 0.5571], device='cuda:0')

In [95]:
current_output_attribution[0].abs().sum()

tensor(4.4699, device='cuda:0')

In [7]:
unique_input_indices = torch.unique(input_indices)
input_features[:, unique_input_indices].isnan().any()
unique_output_indices = torch.unique(output_indices)

# output_features.shape, unique_output_indices.shape
# virtual_weights[unique_input_indices][:, unique_output_indices]

In [9]:
virtual_weights.shape, unique_output_indices.max()

(torch.Size([6144, 6143]), tensor(6143, device='cuda:0'))

In [12]:
import torch
from einops import rearrange, einsum


def sparse_attribution(input_features, virtual_weights, input_indices, output_indices):
    # Find unique input and output indices across the batch
    unique_input_indices = torch.unique(input_indices)
    unique_output_indices = torch.unique(output_indices)

    # combine batch and sequence dimensions
    input_features = rearrange(input_features, 'b s i -> (b s) i')

    # Extract relevant slices of input_features and virtual_weights
    sparse_input_features = input_features[:, unique_input_indices]
    sparse_virtual_weights = virtual_weights[unique_input_indices][:, unique_output_indices]

    # Perform the sparse matrix multiplication
    sparse_attribution = torch.einsum('bi,ij->bij', sparse_input_features, sparse_virtual_weights)

    # Create a tensor to hold the full attribution
    full_attribution = torch.zeros(input_features.shape[0], input_features.shape[1], virtual_weights.shape[1], device=input_features.device)

    # Place the sparse attribution results in the correct positions in the full attribution tensor
    full_attribution[:, unique_input_indices[:, None], unique_output_indices] = sparse_attribution

    return full_attribution

# Usage
# Assuming input_features, virtual_weights, input_indices, and output_indices are defined
# attribution = sparse_attribution(input_features, virtual_weights, input_indices, output_indices)

unique_input_indices = torch.unique(input_indices)
unique_output_indices = torch.unique(output_indices)

# Extract relevant slices of input_features and virtual_weights
sparse_input_features = input_features[:, unique_input_indices]
sparse_virtual_weights = virtual_weights[unique_input_indices][:, unique_output_indices]

# Perform the sparse matrix multiplication
spar_attr = torch.einsum('bi,ij->bij', sparse_input_features, sparse_virtual_weights)

# # Create a tensor to hold the full attribution
# full_attribution = torch.zeros(input_features.shape[0], input_features.shape[1], virtual_weights.shape[1], device=input_features.device)

# # Place the sparse attribution results in the correct positions in the full attribution tensor
# full_attribution[:, unique_input_indices[:, None], unique_output_indices] = spar_attr


OutOfMemoryError: CUDA out of memory. Tried to allocate 795.99 GiB. GPU 0 has a total capacity of 15.73 GiB of which 14.37 GiB is free. Process 2007258 has 1.35 GiB memory in use. Of the allocated memory 1.07 GiB is allocated by PyTorch, and 93.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [5]:
input_features.shape, unique_input_indices.shape, unique_output_indices.shape

(torch.Size([8256, 6144]), torch.Size([4292]), torch.Size([6030]))

In [6]:
sparse_input_features.shape, sparse_virtual_weights.shape

(torch.Size([8256, 4292]), torch.Size([4292, 6030]))

In [4]:
print("input_features shape:", input_features.shape)
unique_input_indices = torch.unique(input_indices)
print("unique_input_indices shape:", unique_input_indices.shape)
print("Max value in unique_input_indices:", unique_input_indices.max().item())
print("Min value in unique_input_indices:", unique_input_indices.min().item())

input_features shape: torch.Size([64, 129, 6144])
unique_input_indices shape: torch.Size([4292])
Max value in unique_input_indices: 6142
Min value in unique_input_indices: 0


In [11]:

unique_input_indices = input_indices.unique()
unique_output_indices = output_indices.unique()

# # Extract relevant slices of input_features and virtual_weights
sparse_input_features = input_features[:, unique_input_indices]
sparse_virtual_weights = virtual_weights[unique_input_indices][:, unique_output_indices]
# sparse_input_features

In [7]:
uniq = input_indices.unique().cpu().numpy()
uniq

(4292,)

In [8]:
input_features.index([0,1], dim=-1)

AttributeError: 'Tensor' object has no attribute 'index'

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [38]:
from einops import rearrange, einsum
input_features = rearrange(input_features, 'b s f -> (b s) f')
attribution = einsum(input_features, virtual_weights, "b f1, f1 f2 -> b f1 f2")

OutOfMemoryError: CUDA out of memory. Tried to allocate 1160.81 GiB. GPU 0 has a total capacity of 15.73 GiB of which 13.95 GiB is free. Process 1969018 has 1.78 GiB memory in use. Of the allocated memory 1.55 GiB is allocated by PyTorch, and 37.74 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [31]:
attribution = torch.dot(input_features, virtual_weights)
# RuntimeError: self must be a matrix
# attribution = input_features * virtual_weights

RuntimeError: 1D tensors expected, but got 2D and 2D tensors

In [32]:
input_features.shape, virtual_weights.shape
attribution = torch.einsum('bi,ij->bij', input_features, virtual_weights)


(torch.Size([8256, 6144]), torch.Size([6144, 6143]))

In [5]:
attribution = rearrange(attribution, 'b s f -> (b s) f')

In [15]:
attribution.shape

torch.Size([64, 129, 6143])

In [6]:
output_indices = rearrange(output_indices, 'b s f -> (b s) f')

In [13]:
output_indices[:, 0].shape

torch.Size([64, 30])

In [8]:
with torch.no_grad():
    attribution = input_features @ virtual_weights

In [9]:
input_features.shape, virtual_weights.shape

(torch.Size([64, 129, 6144]), torch.Size([6144, 6143]))

In [10]:
# We want to calculate attribution = act*gradient

# I believe this is equivalent to the weights of the activations (ignore biases)
# It'd be good to actually verify this is the case



with torch.no_grad():
    tr_dec = transcoder.decoder.weight
    #TODO: we removed the last weight to help w/ knowing .T and shape. 
    final_enc = sae_final.encoder.weight[:-1]
    virtual_weights = tr_dec.T @ final_enc.T

    act_res_mid = act_res_mid.to(device)
    input_features, input_acts, input_indices = transcoder.encode(act_res_mid, return_topk=True)
    mlp_out_hat = transcoder.decoder(input_features)

    output_features, output_acts, output_indices = sae_final.encode(mlp_out_hat + act_res_mid, return_topk=True)

    # For efficient gradient calculation, we can get the nonzero_indices of both input & output feature

    # W_input = transcoder.decoder.weight[input_indices]    

In [12]:
transcoder.decoder.weight.shape, input_indices.shape

(torch.Size([768, 6144]), torch.Size([64, 129, 30]))

In [None]:
# def encode(self, x: torch.Tensor, return_topk: bool = False):
#     post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))
#     post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)

#     # We can't split immediately due to nnsight
#     tops_acts_BK = post_topk.values
#     top_indices_BK = post_topk.indices

#     buffer_BF = torch.zeros_like(post_relu_feat_acts_BF)
#     encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK)

#     if return_topk:
#         return encoded_acts_BF, tops_acts_BK, top_indices_BK
#     else:
#         return encoded_acts_BF

# def decode(self, x: torch.Tensor) -> torch.Tensor:
#     return self.decoder(x) + self.b_dec