In [1]:
import os
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

import torch

import neurox.interpretation.utils as utils
import neurox.interpretation.ablation as ablation
import neurox.interpretation.linear_probe as linear_probe

from ttsxai.utils.utils import read_ljs_metadata


# Root directory of this probing experiment; the cells below read
# 'data/train_data.npz' and 'models/probe.pth' from underneath it.
# NOTE(review): absolute, user-specific path -- adjust when running elsewhere.
log_dir = '/nas/users/dahye/kw/tts/ttsxai/logs/probe_tacotron2_articulatory_features'
# Directory for the LJSpeech tacotron2/waveglow activations (judging by the path);
# not referenced by any of the cells visible in this notebook section.
data_activation_dir = "/nas/users/dahye/kw/tts/ttsxai/data_activation/LJSpeech/tacotron2_waveglow"


In [2]:
# Load the cached probe training set from <log_dir>/data/train_data.npz.
# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only use it
# on archives you produced yourself.
train_npz_path = os.path.join(log_dir, 'data', 'train_data.npz')
data_dict = np.load(train_npz_path, allow_pickle=True)

# Feature matrix and labels used to build / evaluate the probe.
X, y = data_dict['X'], data_dict['y']

# The lookup tables are stored as single-element arrays (presumably wrapping
# dicts); .item() unwraps each one into the plain Python object.
label2idx, idx2label, src2idx, idx2src, neuronidx2name = (
    data_dict[key].item()
    for key in ('label2idx', 'idx2label', 'src2idx', 'idx2src', 'neuronidx2name')
)

In [3]:
# Load the pre-trained probe.
# The NeuroX training helper is called with num_epochs=0, presumably so that it
# only constructs a model of the right shape for (X, y) without training it;
# the saved weights are then restored from <log_dir>/models/probe.pth.
probe = linear_probe.train_logistic_regression_probe(
    X,
    y,
    lambda_l1=0.001,
    lambda_l2=0.001,
    num_epochs=0,
)
probe_checkpoint_path = os.path.join(log_dir, 'models', 'probe.pth')
# Kept as the last expression so the cell displays the key-matching result.
probe.load_state_dict(torch.load(probe_checkpoint_path))

Training classification probe
Creating model...
Number of training instances: 859266
Number of classes: 10


<All keys matched successfully>

In [4]:
# Ask NeuroX for a neuron ordering derived from the probe; the call returns the
# ordering together with a list of cutoffs (both are inspected in later cells).
ordering, cutoffs = linear_probe.get_neuron_ordering(probe, label2idx)

  0%|          | 0/101 [00:00<?, ?it/s]

In [10]:
# Display the class-name -> class-index mapping loaded from train_data.npz.
label2idx

{'Bilabial': 0,
 'Glottal': 1,
 'Labio-Velar': 2,
 'Dental': 3,
 'Velar': 4,
 'Vowel': 5,
 'Alveolar': 6,
 'Labiodental': 7,
 'Palatal': 8,
 'Post-Alveolar': 9}

In [15]:
# Peek at the first ten entries of the ordering returned by get_neuron_ordering.
np.array(ordering[:10])

array([1186, 1158, 1678, 1166, 1459, 1235, 1142, 1240, 1531, 1340])

In [20]:
# Print the ten highest-ranked neurons for every class.
# NOTE: get_neuron_ordering_for_class is defined in a later cell of this
# notebook; that cell must be executed before this one on a fresh kernel.
for class_name in label2idx:
    best_10 = get_neuron_ordering_for_class(probe, class_name, label2idx)[:10]
    print(class_name, best_10)
    # Optional: show neuron names instead of raw indices.
    # print([neuronidx2name[i] for i in best_10])

Bilabial [1531 1236 1347 1427 1353 1222 1035 1029 1472 1525]
Glottal [1459 1033 1188 1445 1270 1421 1183 1226 1480 1489]
Labio-Velar [1340 1207 1145 1440 1077 1159 1257 1526 1477 1308]
Dental [1166 1116 1322 1389 1484 1227 1711 1645   81  737]
Velar [1240 1024 1412 1426 1280 1330 1037 1385 1364 1043]
Vowel [1678 1681 1710 1605 2031 1608 1567 1722 1517 1733]
Alveolar [1235 1517 1424 1300 2043 1204 1176 1290 1605 1651]
Labiodental [1158 1283 1358 1428 1268 1309 1472 1369 1075 1467]
Palatal [1142 1110 1521 1042 1355 1122 1511 1429 1281 1115]
Post-Alveolar [1186 1342 1495 1493 1075 1417 1138 1031 1089 1444]


In [11]:
# Full neuron ordering for a single class ('Bilabial').
# NOTE: the helper is defined in a later cell; run that cell first.
get_neuron_ordering_for_class(probe, 'Bilabial', label2idx)

array([1531, 1236, 1347, ..., 1314,  551, 1740])

In [9]:
def get_neuron_ordering_for_class(probe, class_name, class_to_idx, search_stride=100):
    """
    Get neuron ordering for a specific class from a trained probe.

    Neurons are ranked by the absolute value of the probe weight that connects
    them to the requested class, largest first.

    Parameters
    ----------
    probe : interpretation.linear_probe.LinearProbe
        Trained probe model. Its first parameter is read as the weight matrix
        and is indexed by class index along the first axis.
    class_name : str
        The name of the class for which to find the neuron ordering
    class_to_idx : dict
        Class to class index mapping
    search_stride : int, optional
        Number of steps to divide the weight mass percentage. Retained for
        backward compatibility only: every weight-mass threshold selects a
        prefix of the same descending order, so for any value >= 1 the merged
        result is that order itself. Values below 1 yield an empty ordering,
        as before.

    Returns
    -------
    neuron_ordering : numpy.ndarray
        Array of neurons ordered by their importance for the specified class.
        Every neuron appears exactly once.

    Raises
    ------
    KeyError
        If ``class_name`` is not a key of ``class_to_idx``.
    """
    class_idx = class_to_idx[class_name]
    weights = list(probe.parameters())[0].data.cpu().numpy()
    abs_weights = np.abs(weights[class_idx])

    # Descending order of absolute weight. The previous implementation rebuilt
    # this ordering once per stride step and kept the neurons whose cumulative
    # mass was <= percentage * np.sum(abs_weights). Because the cumulative sum
    # and np.sum accumulate in different orders, float rounding could make the
    # last cumulative value exceed the total and silently drop the
    # lowest-ranked neurons at the 100% step. Returning the sorted order
    # directly avoids that comparison (and the repeated sorting) altogether.
    sorted_idx = np.argsort(abs_weights)[::-1]

    if search_stride < 1:
        # No threshold steps were ever evaluated in this case -> empty ordering.
        return sorted_idx[:0].copy()

    # .copy() turns the reversed view into a standalone, contiguous array.
    return sorted_idx.copy()

In [8]:
# Display the full list of cutoffs returned alongside the neuron ordering.
cutoffs

[10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 11,
 11,
 11,
 11,
 12,
 12,
 12,
 12,
 12,
 13,
 13,
 15,
 15,
 15,
 17,
 17,
 17,
 18,
 18,
 18,
 18,
 19,
 21,
 21,
 21,
 21,
 23,
 25,
 26,
 26,
 29,
 30,
 30,
 30,
 32,
 32,
 32,
 32,
 33,
 33,
 34,
 35,
 35,
 35,
 35,
 36,
 37,
 37,
 39,
 41,
 41,
 44,
 44,
 44,
 46,
 46,
 47,
 50,
 50,
 51,
 53,
 54,
 56,
 56,
 58,
 60,
 63,
 63,
 65,
 67,
 70,
 72,
 77,
 83,
 89,
 100,
 119,
 189,
 346,
 555,
 832,
 1167,
 1457,
 1686,
 1856,
 1965,
 2018,
 2039,
 2048,
 2048]

In [7]:
# Number of entries in the ordering returned by get_neuron_ordering.
len(ordering)

2048