In [1]:
from torch.utils.data import Dataset
import numpy as np
import torch
from dataclasses import dataclass
import pytorch_lightning as pl
from s3prl.nn import S3PRLUpstream
import librosa
import pandas as pd

class LibrispeechDataset(Dataset):
    """Map-style dataset yielding one waveform and its label per metadata row.

    Args:
        metadata: Table with one row per utterance. Rows are fetched with
            ``metadata.iloc[idx]`` and must provide a ``'filename'`` entry
            (audio path) and a ``'has_word'`` entry (label).
        sample_rate: Target sampling rate in Hz handed to librosa when loading
            audio. Defaults to 16000, the value that used to be hard-coded.
    """

    def __init__(self, metadata, sample_rate=16000):
        super().__init__()
        self.metadata = metadata
        self.sample_rate = sample_rate

    def __getitem__(self, idx):
        """Load the audio of row ``idx`` and pair it with its label.

        Returns:
            dict with ``'wav'`` (float32 numpy array of samples) and
            ``'emotion'`` (``has_word`` converted to an integer). The
            ``'emotion'`` key name is kept as-is because code consuming this
            dataset reads the label under that key.
        """
        row = self.metadata.iloc[idx]
        # librosa returns (samples, rate); the rate equals the one requested,
        # so it is discarded.
        wav, _ = librosa.core.load(row['filename'], sr=self.sample_rate)
        # np.int_ accepts numpy and plain Python booleans alike, whereas
        # calling .astype on the value only works for numpy scalars.
        return {'wav': wav.astype(np.float32),
                'emotion': np.int_(row['has_word'])}

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.metadata)
    
class UpstreamDownstreamModel(pl.LightningModule):
    """S3PRL upstream (run without gradients) plus a small trainable head.

    The per-layer outputs of the upstream are combined with one learned,
    softmax-normalised weight per layer. The combined sequence is then either
    mean-pooled over the valid frames and fed to an MLP (``downstream='mlp'``)
    or run through an LSTM whose final hidden state is fed to an MLP
    (``downstream='lstm'``).

    Args:
        upstream: Name handed to ``S3PRLUpstream``.
        num_layers: Number of learned layer weights; expected to match the
            number of layer outputs the upstream returns.
        hidden_sizes: Widths of the hidden Linear+ReLU stages of the MLP head.
            The list is only read, never modified.
        num_classes: Size of the head's output.
        downstream: ``'mlp'`` or ``'lstm'``.
        lstm_size: Hidden size of the LSTM; required when
            ``downstream == 'lstm'``.

    Raises:
        ValueError: If ``downstream == 'lstm'`` and ``lstm_size`` is None.
    """

    def __init__(self, upstream, num_layers=13, hidden_sizes=[128], num_classes=4, downstream='mlp', lstm_size=None):
        super().__init__()
        self.upstream = S3PRLUpstream(upstream)
        self.num_features = self.upstream.hidden_sizes[-1]
        # NOTE(review): a later .train() call on this module (e.g. by a
        # trainer) would switch the upstream back to train mode -- confirm
        # whether it is meant to stay in eval mode during training.
        self.upstream.eval()
        self.layer_weights = torch.nn.Parameter(torch.randn(num_layers))
        self.downstream_type = downstream
        self.num_classes = num_classes
        if downstream == 'mlp':
            self.downstream = self._build_mlp(self.num_features, hidden_sizes, num_classes)
        elif downstream == 'lstm':
            if lstm_size is None:
                raise ValueError("lstm_size must be set when downstream == 'lstm'")
            self.downstream_lstm = torch.nn.LSTM(self.num_features, lstm_size, batch_first=True)
            self.downstream_mlp = self._build_mlp(lstm_size, hidden_sizes, num_classes)

    @staticmethod
    def _build_mlp(in_features, hidden_sizes, num_classes):
        """Build Linear+ReLU stages for each hidden size, then a final Linear.

        The nesting (a Sequential of Sequential(Linear, ReLU) stages followed
        by a bare Linear) is kept exactly so state_dict keys stay unchanged.
        """
        widths = [in_features] + list(hidden_sizes) + [num_classes]
        stages = [torch.nn.Sequential(torch.nn.Linear(n_in, n_out), torch.nn.ReLU())
                  for n_in, n_out in zip(widths[:-2], widths[1:-1])]
        stages.append(torch.nn.Linear(widths[-2], widths[-1]))
        return torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run upstream and head, storing intermediate results in ``x``.

        Args:
            x: dict with ``'wav'`` and ``'wav_lens'`` entries, passed as-is to
                the upstream (presumably a padded batch of waveforms and their
                lengths in samples -- confirm against the caller).

        Returns:
            The same dict ``x``, mutated in place. It gains
            ``'upstream_outs'``, ``'upstream_lens'`` and ``'y_pred'``; the MLP
            path additionally stores ``'pooled_upstream'`` and
            ``'padding_mask'``.

        Raises:
            ValueError: If ``downstream_type`` is neither ``'mlp'`` nor
                ``'lstm'``.
        """
        # The upstream is only used as a feature extractor: no gradients.
        with torch.no_grad():
            x['upstream_outs'], x['upstream_lens'] = self.upstream(x['wav'], x['wav_lens'])
        upstream_outs = torch.stack(x['upstream_outs'])
        # One softmax weight per stacked layer, broadcast over the other axes.
        layer_w = torch.softmax(self.layer_weights, dim=0)[:, None, None, None]
        upstream_pooled = torch.sum(layer_w*upstream_outs, dim=0)
        if self.downstream_type == 'mlp':
            # True for frames before each item's length. Lengths are taken
            # from the first layer (presumably identical for all layers).
            padding_mask = torch.arange(0, upstream_pooled.shape[1], device=upstream_outs.device)[None,:] < x['upstream_lens'][0][:, None]
            # Masked mean over axis 1, so padded frames do not contribute.
            upstream_pooled = torch.sum(upstream_pooled * padding_mask[:,:,None], dim=1)/torch.sum(padding_mask[:,:,None], dim=1)
            x['pooled_upstream'] = upstream_pooled
            x['padding_mask'] = padding_mask
            x['y_pred'] = self.downstream(x['pooled_upstream'])
        elif self.downstream_type == 'lstm':
            # pack_padded_sequence is given the lengths as an int64 CPU tensor.
            lstm_in = torch.nn.utils.rnn.pack_padded_sequence(upstream_pooled, 
                                                              x['upstream_lens'][0].to(device='cpu', dtype=torch.int64),
                                                              batch_first=True,
                                                              enforce_sorted=False)
            lstm_out, (hn, cn) = self.downstream_lstm(lstm_in)
            # hn[0]: final hidden state of the (single-layer) LSTM.
            x['y_pred'] = self.downstream_mlp(hn[0])
        else:
            raise ValueError('Unknown downstream_type {}'.format(self.downstream_type))
        # Returned as well as mutated so that `out = model(x)` is usable.
        return x

2025-02-08 02:03:13.922122: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-08 02:03:13.929851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-08 02:03:13.939311: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-08 02:03:13.941747: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-08 02:03:13.948720: I tensorflow/core/platform/cpu_feature_guar

In [2]:
METADATA_PATH = '/mnt/data/librilight_med_24k/little_metadata.csv'
CHECKPOINT_PATH = '/home/cbolanos/interpretability-benchmarks/checkpoints/librispeech-kws-step2200.ckpt'

# Keep only the rows whose 'split' column is 'test-clean'.
df_metadata = pd.read_csv(METADATA_PATH)
df_test = df_metadata.loc[df_metadata['split']=='test-clean']
dataset_test = LibrispeechDataset(df_test)

ud_model = UpstreamDownstreamModel(upstream='wav2vec2', 
                                   num_layers=13, 
                                   num_classes=1,
                                   hidden_sizes=[256])


# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; the model itself is never moved off the CPU in this notebook.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')['state_dict']
# Keys containing "mlp" are renamed to "downstream" so that they match the
# attribute name used by UpstreamDownstreamModel.
new_state_dict = {k.replace("mlp", "downstream"): v for k, v in checkpoint.items()}
ud_model.load_state_dict(new_state_dict)



<All keys matched successfully>

In [18]:
ud_model.eval()
dataframe = []  # metadata rows of the positives the model also scores above 0

# Inference only: disabling autograd avoids building a graph per utterance.
with torch.no_grad():
    for i, xin in enumerate(dataset_test):
        # Only positive-labelled utterances can be selected, so the forward
        # pass is skipped for everything else.
        if xin['emotion'] != 1:
            continue
        # Turn the single utterance into a batch of one.
        xin['wav_lens'] = torch.tensor([xin['wav'].shape[0]])
        xin['wav'] = torch.from_numpy(xin['wav'])[None,:]
        ud_model(xin)  # the prediction is written into xin['y_pred']
        # Output > 0 is treated as a positive prediction (presumably a logit
        # with a 0.5 probability threshold -- confirm against training code).
        if xin['y_pred'][0].item() > 0:
            dataframe.append(dataset_test.metadata.iloc[i])

In [19]:
# Fixed seed so that the shuffled row order is the same on every run.
SHUFFLE_SEED = 0

# One row per selected utterance, shuffled and re-indexed from 0.
kws_dataset = pd.DataFrame(dataframe)
kws_dataset = kws_dataset.sample(frac=1, random_state=SHUFFLE_SEED).reset_index(drop=True)
kws_dataset

Unnamed: 0.1,Unnamed: 0,id,transcription,has_word,trans_file,filename,alignment_filename,word_start,word_end,split,duration
0,1651,121-127105-0024,POOR DOUGLAS BEFORE HIS DEATH WHEN IT WAS IN S...,True,/mnt/data/LibriSpeech24K/test-clean/121/127105...,/mnt/data/LibriSpeech24K/test-clean/121/127105...,/mnt/data/LibriSpeech24K/test-clean/121/127105...,11.95,12.16,test-clean,14.450
1,2380,5683-32866-0030,A LITTLE BIT OF PLASTER TUMBLED DOWN THE CHIMN...,True,/mnt/data/LibriSpeech24K/test-clean/5683/32866...,/mnt/data/LibriSpeech24K/test-clean/5683/32866...,/mnt/data/LibriSpeech24K/test-clean/5683/32866...,0.36,0.62,test-clean,4.845
2,1418,260-123286-0025,FLIGHT WAS OUT OF THE QUESTION NOW THE REPTILE...,True,/mnt/data/LibriSpeech24K/test-clean/260/123286...,/mnt/data/LibriSpeech24K/test-clean/260/123286...,/mnt/data/LibriSpeech24K/test-clean/260/123286...,5.26,5.56,test-clean,9.205
3,2118,1995-1826-0026,NOW FOR ONE LITTLE HALF HOUR SHE HAD BEEN A WO...,True,/mnt/data/LibriSpeech24K/test-clean/1995/1826/...,/mnt/data/LibriSpeech24K/test-clean/1995/1826/...,/mnt/data/LibriSpeech24K/test-clean/1995/1826/...,1.25,1.51,test-clean,15.450
4,1709,1221-135767-0009,BUT PEARL WHO WAS A DAUNTLESS CHILD AFTER FROW...,True,/mnt/data/LibriSpeech24K/test-clean/1221/13576...,/mnt/data/LibriSpeech24K/test-clean/1221/13576...,/mnt/data/LibriSpeech24K/test-clean/1221/13576...,5.90,6.12,test-clean,13.340
...,...,...,...,...,...,...,...,...,...,...,...
80,2324,4970-29095-0031,HE DOESN'T SAY EXACTLY WHAT IT IS SAID RUTH A ...,True,/mnt/data/LibriSpeech24K/test-clean/4970/29095...,/mnt/data/LibriSpeech24K/test-clean/4970/29095...,/mnt/data/LibriSpeech24K/test-clean/4970/29095...,3.90,4.18,test-clean,15.050
81,1156,7021-85628-0026,NO MY LITTLE SON SHE SAID,True,/mnt/data/LibriSpeech24K/test-clean/7021/85628...,/mnt/data/LibriSpeech24K/test-clean/7021/85628...,/mnt/data/LibriSpeech24K/test-clean/7021/85628...,1.12,1.41,test-clean,2.740
82,1108,237-134500-0036,I CAN'T PLAY WITH YOU LIKE A LITTLE BOY ANY MO...,True,/mnt/data/LibriSpeech24K/test-clean/237/134500...,/mnt/data/LibriSpeech24K/test-clean/237/134500...,/mnt/data/LibriSpeech24K/test-clean/237/134500...,1.90,2.12,test-clean,6.600
83,2143,4446-2273-0024,BARTLEY STARTED WHEN HILDA RANG THE LITTLE BEL...,True,/mnt/data/LibriSpeech24K/test-clean/4446/2273/...,/mnt/data/LibriSpeech24K/test-clean/4446/2273/...,/mnt/data/LibriSpeech24K/test-clean/4446/2273/...,1.83,2.03,test-clean,4.825


In [20]:
# Write the selected utterances to disk; index=False leaves the row index out of the file.
kws_csv_path = '/home/cbolanos/explain_where/models/kws-librispeech/kws_dataset.csv'
kws_dataset.to_csv(path_or_buf=kws_csv_path, index=False)