In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
from sklearn.metrics import confusion_matrix, accuracy_score
from torch.utils.data import DataLoader
from einops import rearrange
from icecream import ic
from tqdm.autonotebook import tqdm

from CLAPWrapper import CLAPWrapper
from utils.dataset import *
from utils.interventions import Intervention

In [2]:
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [3]:
seed_everything(42)

In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

module_activation_dict = {
    # conv blocks
    'audio_encoder.base.conv_block1': nn.Identity(),    # 0
    'audio_encoder.base.conv_block2': nn.Identity(),    # 1
    'audio_encoder.base.conv_block3': nn.Identity(),    # 2
    'audio_encoder.base.conv_block4': nn.Identity(),    # 3
    'audio_encoder.base.conv_block5': nn.Identity(),    # 4
    'audio_encoder.base.conv_block6': nn.Identity(),    # 5
    'audio_encoder.base.fc1': F.relu,                   # 6
    # 'audio_encoder.projection.linear1': F.gelu,         # 7
    # 'audio_encoder.projection.linear2': nn.Identity(),  # 8
}

module_list = list(module_activation_dict.keys())

In [5]:
weights_path = "/scratch/pratyaksh.g/clap/CLAP_weights_2022_microsoft.pth"
clap_model = CLAPWrapper(weights_path, use_cuda=True if DEVICE == "cuda" else False)

In [6]:
clap_model.clap.eval()

CLAP(
  (audio_encoder): AudioEncoder(
    (base): Cnn14(
      (spectrogram_extractor): Spectrogram(
        (stft): STFT(
          (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
          (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
        )
      )
      (logmel_extractor): LogmelFilterBank()
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_block1): ConvBlock(
        (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_block2): ConvBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [7]:
probing_dataset = ESC50Dataset()
testing_dataset = ESC50Dataset()

Using downloaded and verified file: /scratch/pratyaksh.g/esc50/ESC-50-master.zip


927it [00:00, 9260.34it/s]

Loading audio files


2000it [00:00, 9408.61it/s]


Using downloaded and verified file: /scratch/pratyaksh.g/esc50/ESC-50-master.zip


2000it [00:00, 12379.28it/s]

Loading audio files





In [8]:
# Computing text embeddings
prompt = 'this is a sound of '
y = [prompt + x for x in testing_dataset.classes]
text_embeddings = clap_model.get_text_embeddings(y)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [9]:
layer_name = module_list[5]
layer_name

'audio_encoder.base.conv_block6'

In [10]:
activations = torch.load(f'/scratch/pratyaksh.g/{testing_dataset.path_name}/activations/{layer_name}.pt')
clusters = torch.load(f'/scratch/pratyaksh.g/{probing_dataset.path_name}/clusters/{layer_name}.pt')
n_clusters = torch.load(f'/scratch/pratyaksh.g/{probing_dataset.path_name}/clusters/{layer_name}_n.pt')
audio_tensors = torch.load(f'/scratch/pratyaksh.g/{testing_dataset.path_name}/audio-tensors/{testing_dataset.path_name}-audio-tensors.pt')

if len(activations.shape) > 2:
    n, c, w, h = activations.shape

In [11]:
clap_model.clap.to(DEVICE)
audio_tensors.to(DEVICE)

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.2059e-02, -1.0316e-01, -1.4158e-01,  ...,  6.9915e-02,
          4.0415e-02,  2.8286e-03],
        [-6.9483e-03, -1.2520e-02, -1.1253e-02,  ...,  2.1651e-01,
         -1.0526e-02, -2.8688e-01],
        ...,
        [-8.8183e-02, -7.4015e-02, -6.9245e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-4.9859e-01, -2.2502e-01,  5.4859e-02,  ...,  6.0001e-02,
          6.0360e-02,  5.7563e-02],
        [-3.9407e-04, -1.2344e-04, -4.5801e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')

In [12]:
df = pd.read_csv(f'/scratch/pratyaksh.g/esc50/cluster-stats/{layer_name}.csv')
threshold = 5.26
df = df[df['entropy'] < threshold]
df = df.head(100)
df

Unnamed: 0,layer_name,partition_idx,cluster_idx,entropy,neuron_count,top_class_idx,top_class
0,audio_encoder.base.conv_block6,5,1,5.145218,2342,8,sheep
1,audio_encoder.base.conv_block6,5,4,5.217937,3206,19,thunderstorm
5,audio_encoder.base.conv_block6,4,37,4.878267,274,25,footsteps
6,audio_encoder.base.conv_block6,4,18,4.959192,1096,8,sheep
7,audio_encoder.base.conv_block6,4,32,4.988885,1109,38,clock tick
...,...,...,...,...,...,...,...
115,audio_encoder.base.conv_block6,3,73,4.978487,1239,19,thunderstorm
116,audio_encoder.base.conv_block6,3,104,4.979952,499,46,church bells
117,audio_encoder.base.conv_block6,3,18,4.980650,1360,0,dog
118,audio_encoder.base.conv_block6,3,95,4.981133,65,20,crying baby


In [13]:
treatment = 'intervened'
invert_mask = False

In [14]:
record = next(df.iterrows())
record

(0,
 layer_name       audio_encoder.base.conv_block6
 partition_idx                                 5
 cluster_idx                                   1
 entropy                                5.145218
 neuron_count                               2342
 top_class_idx                                 8
 top_class                                 sheep
 Name: 0, dtype: object)

In [15]:
seed_everything(42)

In [16]:
partition_idx = record[1]['partition_idx']
cluster_idx = record[1]['cluster_idx']

In [17]:
performance = pd.DataFrame(
    columns=['accuracy', 'treatment', 'label']
)

In [18]:
intervention = Intervention(clap_model, module_activation_dict)

In [19]:
intervention.clear_handles()

In [20]:
cluster_mask = (clusters[:, partition_idx] == cluster_idx).bool()
if invert_mask:
    cluster_mask = ~cluster_mask

if treatment == 'random':
    num_neurons = clusters.shape[0]
    num_neurons_in_cluster = cluster_mask.sum().item()
    random_state = np.random.RandomState(42)
    random_neurons = random_state.choice(num_neurons, num_neurons_in_cluster, replace=False)
    random_cluster_mask = torch.zeros_like(cluster_mask.flatten())
    random_cluster_mask[random_neurons] = 1
    random_cluster_mask = random_cluster_mask.bool()
    cluster_mask = random_cluster_mask

if len(cluster_mask.shape) < 2:
    cluster_mask = rearrange(cluster_mask, '(c w h) -> c w h', c=c, w=w, h=h)
intervention.set_intervention(activations, cluster_mask, layer_name, replace_with='zero')

In [21]:
cluster_mask.shape

torch.Size([2048, 21, 2])

In [22]:
loader = DataLoader(list(zip(audio_tensors, testing_dataset.one_hot)), batch_size=4, shuffle=False)

In [23]:
repeats = 1
for run in range(repeats):
    y_preds, y_labels = [], []
    for batch in tqdm(loader, desc=f"Run {run + 1}/{repeats}"):
        audio_tensor, one_hot_target = batch

        audio_embeddings = clap_model.clap.audio_encoder(audio_tensor.to(DEVICE))[0]
        audio_embeddings = audio_embeddings / torch.norm(audio_embeddings, dim=-1, keepdim=True)

        similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings.to(DEVICE))

        y_pred = F.softmax(similarity.detach().cpu(), dim=-1).numpy()
        y_pred = np.argmax(y_pred, axis=-1)
        y_preds.append(y_pred)

        y_label = np.argmax(one_hot_target.detach().cpu().numpy(), axis=-1)
        y_labels.append(y_label)

    # %%
    confusion = confusion_matrix(np.concatenate(y_labels), np.concatenate(y_preds))
    class_wise_acc = confusion.diagonal() / confusion.sum(axis=1)

    for label_idx, label in enumerate(testing_dataset.classes):
        performance = performance.append({
            'accuracy': class_wise_acc[label_idx],
            'treatment': treatment,
            'label': label,
            'run': run,
        }, ignore_index=True)

intervention.clear_handles()
del intervention
torch.cuda.empty_cache()

Run 1/1:   0%|          | 0/500 [00:00<?, ?it/s]

In [24]:
performance.mean()

accuracy    0.828
run         0.000
dtype: float64

In [None]:
# %%
csv_path = f"/scratch/pratyaksh.g/{testing_dataset.path_name}/intervened-performance/{layer_name}/"
os.makedirs(csv_path, exist_ok=True)
expt_id = f"partition-{partition_idx}-cluster-{cluster_idx}-{treatment}-invert_mask={str(invert_mask)}.csv"
performance.to_csv(csv_path + expt_id, index=False)