In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
from sklearn.metrics import confusion_matrix, accuracy_score
from einops import rearrange
from icecream import ic
from tqdm.autonotebook import tqdm

from CLAPWrapper import CLAPWrapper
from utils.dataset import *
from utils.interventions import Intervention

In [3]:
def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [4]:
seed_everything(42)

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

module_activation_dict = {
    # Conv blocks
    'audio_encoder.base.conv_block1': nn.Identity(),    # 0
    'audio_encoder.base.conv_block2': nn.Identity(),    # 1
    'audio_encoder.base.conv_block3': nn.Identity(),    # 2
    'audio_encoder.base.conv_block4': nn.Identity(),    # 3
    'audio_encoder.base.conv_block5': nn.Identity(),    # 4
    'audio_encoder.base.conv_block6': nn.Identity(),    # 5
    'audio_encoder.base.fc1': F.relu,                   # 6
    'audio_encoder.projection.linear1': F.gelu,         # 7
    'audio_encoder.projection.linear2': nn.Identity(),  # 8
}

module_list = list(module_activation_dict.keys())

In [6]:
weights_path = "/scratch/pratyaksh.g/clap/CLAP_weights_2022_microsoft.pth"
clap_model = CLAPWrapper(weights_path, use_cuda=True if DEVICE == "cuda" else False)

In [7]:
clap_model.clap.eval()

CLAP(
  (audio_encoder): AudioEncoder(
    (base): Cnn14(
      (spectrogram_extractor): Spectrogram(
        (stft): STFT(
          (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
          (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
        )
      )
      (logmel_extractor): LogmelFilterBank()
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_block1): ConvBlock(
        (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_block2): ConvBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [8]:
probing_dataset = ESC50Dataset()
testing_dataset = ESC50Dataset()

Using downloaded and verified file: /scratch/pratyaksh.g/esc50/ESC-50-master.zip


2000it [00:00, 11199.04it/s]

Loading audio files





Using downloaded and verified file: /scratch/pratyaksh.g/esc50/ESC-50-master.zip


2000it [00:00, 12301.74it/s]

Loading audio files





In [9]:
layer_idx = 6
layer_name = module_list[layer_idx]
partition_idx = 2
cluster_idx = 1
treatment = 'random' # 'intervened' or 'random' (control using group of random neurons)
invert_mask = False # False will *remove* the activations of the neurons in the cluster

In [10]:
activations = torch.load(f'/scratch/pratyaksh.g/{testing_dataset.path_name}/activations/{layer_name}.pt')
clusters = torch.load(f'/scratch/pratyaksh.g/{probing_dataset.path_name}/clusters/{layer_name}.pt')
n_clusters = torch.load(f'/scratch/pratyaksh.g/{probing_dataset.path_name}/clusters/{layer_name}_n.pt')

In [11]:
if len(activations.shape) > 2:
    n, c, w, h = activations.shape

In [12]:
pdf_path = f'/scratch/pratyaksh.g/{probing_dataset.path_name}/cluster-plots/{layer_name}/partition-{partition_idx}/'

for file in os.listdir(pdf_path):
    if file.endswith(f'cluster-{cluster_idx}.pdf'):
        file_path = os.path.join(pdf_path, file)
        break

print('Cluster embedding pdf: sftp://gnode060' + file_path)

Cluster embedding pdf: sftp://gnode060/scratch/pratyaksh.g/esc50/cluster-plots/audio_encoder.base.fc1/partition-2/1-cluster-1.pdf


In [13]:
# Computing text embeddings
prompt = 'this is a sound of '
y = [prompt + x for x in testing_dataset.classes]
text_embeddings = clap_model.get_text_embeddings(y)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.




In [14]:
performance = pd.DataFrame(
    columns=['accuracy', 'treatment', 'label']
)

In [15]:
intervention = Intervention(clap_model, module_activation_dict)

In [30]:
intervention.clear_handles()

cluster_mask = (clusters[:, partition_idx] == cluster_idx).bool()
if invert_mask:
    cluster_mask = ~cluster_mask

if treatment == 'random':
    num_neurons = clusters.shape[0]
    num_neurons_in_cluster = cluster_mask.sum().item()
    random_neurons = np.random.randint(0, num_neurons, num_neurons_in_cluster)
    random_cluster_mask = torch.zeros_like(cluster_mask.flatten())
    random_cluster_mask[random_neurons] = 1
    random_cluster_mask = random_cluster_mask.bool()
    cluster_mask = random_cluster_mask

if len(cluster_mask.shape) > 2:
    cluster_mask = rearrange(cluster_mask, '(c w h) -> c w h', c=c, w=w, h=h)
intervention.set_intervention(activations, cluster_mask, layer_name, replace_with='random')
# replace_with='random' here means that the activations of random instances are used to 'remove information',
# and has nothing to do with intervention_type. The alternative type is replace_with='zero', which would
# zero out the activations to 'remove information'

In [25]:
audio_tensor_path = f"/scratch/pratyaksh.g/{testing_dataset.path_name}/audio-tensors"
audio_tensors = torch.load(f"{audio_tensor_path}/{testing_dataset.path_name}-audio-tensors.pt").to(DEVICE)

In [16]:
clap_model.clap.to(DEVICE)

CLAP(
  (audio_encoder): AudioEncoder(
    (base): Cnn14(
      (spectrogram_extractor): Spectrogram(
        (stft): STFT(
          (conv_real): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
          (conv_imag): Conv1d(1, 513, kernel_size=(1024,), stride=(320,), bias=False)
        )
      )
      (logmel_extractor): LogmelFilterBank()
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_block1): ConvBlock(
        (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_block2): ConvBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [17]:
loader = torch.utils.data.DataLoader(list(zip(audio_tensors, testing_dataset.one_hot)), batch_size=16, shuffle=False)

In [18]:
next(clap_model.clap.parameters()).device

device(type='cuda', index=0)

In [19]:
# Computing audio embeddings
y_preds, y_labels = [], []
for batch in tqdm(loader):
    audio_tensor, one_hot_target = batch

    audio_embeddings = clap_model.clap.audio_encoder(audio_tensor.to(DEVICE))[0]
    audio_embeddings = audio_embeddings / torch.norm(audio_embeddings, dim=-1, keepdim=True)

    similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings.to(DEVICE))

    y_pred = F.softmax(similarity.detach().cpu(), dim=-1).numpy()
    y_pred = np.argmax(y_pred, axis=-1)
    y_preds.append(y_pred)

    y_label = np.argmax(one_hot_target.detach().cpu().numpy(), axis=-1)
    y_labels.append(y_label)

  0%|          | 0/125 [00:00<?, ?it/s]

In [20]:
score = accuracy_score(np.concatenate(y_labels), np.concatenate(y_preds))

In [21]:
score

0.827

In [18]:
treatment

'random'

In [23]:
confusion = confusion_matrix(np.concatenate(y_labels), np.concatenate(y_preds))
class_wise_acc = confusion.diagonal() / confusion.sum(axis=1)

for label_idx, label in enumerate(testing_dataset.classes):
    performance = performance.append({
        'accuracy': class_wise_acc[label_idx],
        'treatment': treatment,
        'label': label,
    }, ignore_index=True)

In [24]:
performance

Unnamed: 0,accuracy,treatment,label
0,0.925,random,dog
1,1.0,random,rooster
2,0.9,random,pig
3,0.625,random,cow
4,0.975,random,frog
5,0.95,random,cat
6,0.975,random,hen
7,0.725,random,insects
8,0.925,random,sheep
9,0.35,random,crow


In [20]:
csv_path = f"/scratch/pratyaksh.g/{testing_dataset.path_name}/intervened-performance/{layer_name}/"
os.makedirs(csv_path, exist_ok=True)
expt_id = f"partition-{partition_idx}-cluster-{cluster_idx}-{treatment}-invert_mask={str(invert_mask)}.csv"
performance.to_csv(csv_path + expt_id, index=False)

In [21]:
performance

Unnamed: 0,accuracy,treatment,label
0,0.05,random,dog
1,0.2,random,rooster
2,0.375,random,pig
3,0.175,random,cow
4,0.3,random,frog
5,0.225,random,cat
6,0.05,random,hen
7,0.05,random,insects
8,0.225,random,sheep
9,0.025,random,crow


: 