In [None]:
import torch
import torch.nn.functional as F

In [None]:
# Language subsets to analyse; each name selects one `data/activation.<name>.train.llama` file.
languages = [
    'c', 'c++', 'c#', 'java', 'rust', 'python',
    'javascript', 'php', 'html', 'go', 'ruby', 'wiki',
]

In [None]:
# Load each language's saved statistics and stack them so the language axis is last:
# `n` holds one normaliser per language (presumably token counts -- confirm against the
# script that produced the files) and `over_zero` is (layers, intermediate, languages).
n_per_lang, over_zero_per_lang = [], []
for language in languages:
    stats = torch.load(f'data/activation.{language}.train.llama')
    n_per_lang.append(stats['n'])
    over_zero_per_lang.append(stats['over_zero'])

n = torch.tensor(n_per_lang)
over_zero = torch.stack(over_zero_per_lang, dim=-1)

num_layers, intermediate_size, lang_num = over_zero.size()

In [None]:
import os

# The mask-building functions below torch.save their results into this directory.
os.makedirs("./activation_mask", exist_ok=True)

In [None]:
def original_activation(top_rate = 0.01, filter_rate = 0.95, activation_bar_ratio = 0.95,
                        activation_counts = None, token_counts = None, save_dir = "activation_mask"):
    """Select low-entropy (language-specific) neurons and save one mask per language.

    Activation probabilities are ``activation_counts / token_counts`` (layer x inter x lang).
    A neuron's entropy is computed over its per-language probabilities normalised to sum to 1.
    Neurons for which no language exceeds the ``filter_rate`` quantile of all probabilities
    are dismissed. Of the rest, the ``top_rate`` fraction of all neurons with the lowest
    entropy is kept. Each kept neuron is then assigned to every language whose probability
    exceeds the ``activation_bar_ratio`` quantile.

    Args:
        top_rate: fraction of all (layer, neuron) pairs to keep, lowest entropy first.
        filter_rate: quantile of all activation probabilities used to dismiss neurons.
        activation_bar_ratio: quantile a language's probability must exceed for a kept
            neuron to be assigned to that language.
        activation_counts: (layers, intermediate, languages) counts; defaults to the
            notebook global ``over_zero``.
        token_counts: per-language divisor broadcast over the last axis; defaults to the
            notebook global ``n``.
        save_dir: directory the mask file is written to (created if missing).

    Raises:
        ValueError: if any entropy value is NaN.

    Side effects:
        Prints diagnostics. Writes, via ``torch.save``, a list with exactly one entry per
        language. Each entry is a list with one long tensor of sorted neuron indices per layer.
    """
    import os

    if activation_counts is None:
        activation_counts = over_zero
    if token_counts is None:
        token_counts = n

    activation_probs = activation_counts / token_counts  # layer x inter x lang_num
    # derive sizes from the data instead of relying on notebook globals
    num_layers, _, lang_num = activation_probs.size()
    flattened_probs = activation_probs.flatten()

    def _quantile_value(rate):
        """Value at the ``rate`` quantile of all probabilities (kthvalue is 1-indexed, so clamp k)."""
        k = min(max(round(flattened_probs.numel() * rate), 1), flattened_probs.numel())
        return flattened_probs.kthvalue(k).values.item()

    normed_activation_probs = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    # neurons that never fire give 0/0 = NaN; treat them as all-zero distributions
    normed_activation_probs[torch.isnan(normed_activation_probs)] = 0
    log_probs = torch.where(normed_activation_probs > 0, normed_activation_probs.log(), 0)
    entropy = - torch.sum(normed_activation_probs * log_probs, dim=-1)

    if torch.isnan(entropy).any():
        print(torch.isnan(entropy).sum())
        raise ValueError("entropy contains NaN values")

    top_prob_value = _quantile_value(filter_rate)
    print(top_prob_value)
    # dismiss a neuron if no language's activation probability exceeds the filter_rate quantile
    top_position = (activation_probs > top_prob_value).sum(dim=-1)
    entropy[top_position == 0] = torch.inf

    flattened_entropy = entropy.flatten()
    top_entropy_value = round(len(flattened_entropy) * top_rate)
    print(flattened_entropy)
    _, index = flattened_entropy.topk(top_entropy_value, largest=False)
    # never keep dismissed (infinite-entropy) neurons, even if fewer than requested survive
    index = index[torch.isfinite(flattened_entropy[index])]
    row_index = index // entropy.size(1)
    col_index = index % entropy.size(1)
    selected_probs = activation_probs[row_index, col_index]  # n x lang
    print(selected_probs)

    print(selected_probs.size(0), torch.bincount(selected_probs.argmax(dim=-1)))
    selected_probs = selected_probs.transpose(0, 1)  # lang x n
    activation_bar = _quantile_value(activation_bar_ratio)
    print((selected_probs > activation_bar).sum(dim=1).tolist())
    # torch.where yields row-major order, so `lang` is sorted and can be split by counts
    lang, indice = torch.where(selected_probs > activation_bar)
    print(lang, indice)
    merged_index = torch.stack((row_index, col_index), dim=-1)

    final_indice = []
    # minlength guarantees one entry per language even when trailing languages select nothing
    for lang_positions in indice.split(torch.bincount(lang, minlength=lang_num).tolist()):
        lang_index = sorted(tuple(row.tolist()) for row in merged_index[lang_positions])
        layer_index = [[] for _ in range(num_layers)]
        for layer, neuron in lang_index:
            layer_index[layer].append(neuron)
        final_indice.append([torch.tensor(neurons).long() for neurons in layer_index])

    os.makedirs(save_dir, exist_ok=True)
    torch.save(final_indice, f"{save_dir}/top-{top_rate}-filter-{filter_rate}-activation-{activation_bar_ratio}")


In [None]:
def custom_activation(filter_rate = 0.95, activation_bar_ratio = 0.95, num_neurons_per_lang = 400,
                      activation_counts = None, token_counts = None, save_dir = "activation_mask"):
    """Pick up to ``num_neurons_per_lang`` low-entropy neurons per language and save the masks.

    Activation probabilities are ``activation_counts / token_counts`` (layer x inter x lang).
    A neuron's entropy is computed over its per-language probabilities normalised to sum to 1.
    Neurons for which no language exceeds the ``filter_rate`` quantile are dismissed. For each
    language, the candidates are neurons whose probability for that language exceeds the
    ``activation_bar_ratio`` quantile. The ``num_neurons_per_lang`` candidates with the lowest
    entropy are kept.

    Args:
        filter_rate: quantile of all activation probabilities used to dismiss neurons.
        activation_bar_ratio: quantile a language's probability must exceed for a neuron to
            be a candidate for that language.
        num_neurons_per_lang: maximum number of neurons kept per language.
        activation_counts: (layers, intermediate, languages) counts; defaults to the
            notebook global ``over_zero``.
        token_counts: per-language divisor broadcast over the last axis; defaults to the
            notebook global ``n``.
        save_dir: directory the mask file is written to (created if missing).

    Raises:
        ValueError: if any entropy value is NaN.

    Side effects:
        Prints diagnostics. Writes, via ``torch.save``, a list with exactly one entry per
        language (entry i is language i). Each entry is a list with one long tensor of sorted
        neuron indices per layer.
    """
    import os

    if activation_counts is None:
        activation_counts = over_zero
    if token_counts is None:
        token_counts = n

    activation_probs = activation_counts / token_counts  # layer x inter x lang_num
    # derive sizes from the data instead of relying on notebook globals
    num_layers, _, lang_num = activation_probs.size()
    flattened_probs = activation_probs.flatten()

    def _quantile_value(rate):
        """Value at the ``rate`` quantile of all probabilities (kthvalue is 1-indexed, so clamp k)."""
        k = min(max(round(flattened_probs.numel() * rate), 1), flattened_probs.numel())
        return flattened_probs.kthvalue(k).values.item()

    normed_activation_probs = activation_probs / activation_probs.sum(dim=-1, keepdim=True)
    # neurons that never fire give 0/0 = NaN; treat them as all-zero distributions
    normed_activation_probs[torch.isnan(normed_activation_probs)] = 0
    log_probs = torch.where(normed_activation_probs > 0, normed_activation_probs.log(), 0)
    entropy = - torch.sum(normed_activation_probs * log_probs, dim=-1)

    if torch.isnan(entropy).any():
        print(torch.isnan(entropy).sum())
        raise ValueError("entropy contains NaN values")

    top_prob_value = _quantile_value(filter_rate)
    print(top_prob_value)
    # dismiss a neuron if no language's activation probability exceeds the filter_rate quantile
    top_position = (activation_probs > top_prob_value).sum(dim=-1)
    entropy[top_position == 0] = torch.inf
    print(entropy.flatten())

    activation_bar = _quantile_value(activation_bar_ratio)

    # every (layer, neuron, language) triple whose probability clears the bar
    layer_idx, neuron_idx, lang_idx = torch.where(activation_probs > activation_bar)
    merged_index = torch.stack((layer_idx, neuron_idx), dim=-1)  # (num_candidates, 2)

    final_indice = []
    # iterate over all languages (not only those with candidates) so that
    # final_indice[i] always belongs to language i
    for lang_id in range(lang_num):
        lang_indices = merged_index[lang_idx == lang_id]
        lang_entropies = entropy[lang_indices[:, 0], lang_indices[:, 1]]

        # drop dismissed (infinite-entropy) neurons, then keep the lowest-entropy ones
        finite = torch.isfinite(lang_entropies)
        lang_indices, lang_entropies = lang_indices[finite], lang_entropies[finite]
        order = torch.argsort(lang_entropies)
        selected_indices = lang_indices[order[:num_neurons_per_lang]]

        # group the selected neurons by layer, sorted like original_activation's masks
        layer_index = [[] for _ in range(num_layers)]
        for layer, neuron in sorted(tuple(row.tolist()) for row in selected_indices):
            layer_index[layer].append(neuron)
        final_indice.append([torch.tensor(neurons).long() for neurons in layer_index])

    os.makedirs(save_dir, exist_ok=True)
    torch.save(final_indice, f"{save_dir}/neurons-{num_neurons_per_lang}-filter-{filter_rate}-activation-{activation_bar_ratio}")

In [None]:
original_activation(top_rate=0.01, filter_rate=0.95, activation_bar_ratio=0.95)

In [None]:
custom_activation(filter_rate=0.95, activation_bar_ratio=0.95, num_neurons_per_lang=400)