In [1]:
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import numpy as np
from typing import Dict, Any, List, Callable, Optional, Union
import torch
from dataclasses import dataclass
import pytorch_lightning as pl
from s3prl.nn import S3PRLUpstream
import librosa
import torchmetrics

class IEMOCAPDataset(Dataset):
    """Map-style dataset that reads one audio file per metadata row.

    Args:
        metadata: Table indexed positionally with ``.iloc``; each row must
            provide a ``'filename'`` (path readable by ``soundfile``) and an
            ``'emotion'`` label.
        class_map: Ordered label names; a label's position in this sequence
            becomes its integer class id.
    """

    def __init__(self, metadata, class_map=['neutral','happiness','anger','sadness']):
        super().__init__()
        self.metadata = metadata
        # label name -> integer id, following the order of `class_map`
        self.class_map = {label: index for index, label in enumerate(class_map)}

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``{'wav': float32 samples, 'emotion': class id}`` for row ``idx``."""
        record = self.metadata.iloc[idx]
        # The sample rate reported by soundfile is not used.
        samples, _sample_rate = sf.read(record['filename'])
        waveform = samples.astype(np.float32)
        label_id = self.class_map[record['emotion']]
        return {'wav': waveform, 'emotion': label_id}
    
class UpstreamDownstreamModel(pl.LightningModule):
    """Frozen S3PRL upstream + learned layer-weighted sum + classification head.

    The upstream is always run under ``torch.no_grad()`` and kept in eval mode.
    Its per-layer outputs are combined with softmax-normalised learnable
    weights and fed to one of two heads:

    * ``'mlp'``: masked mean over time followed by an MLP.
    * ``'lstm'``: an LSTM over the frames; the final hidden state goes to an MLP.

    Args:
        upstream: Name passed to ``S3PRLUpstream``.
        num_layers: Number of upstream layers to weight; must match the number
            of layer outputs the upstream returns.
        hidden_sizes: Sizes of the hidden layers of the MLP head (any sequence).
        num_classes: Output size of the head. Values > 1 use multiclass
            accuracy metrics, otherwise binary accuracy metrics.
        downstream: ``'mlp'`` or ``'lstm'``.
        lstm_size: Hidden size of the LSTM; required when ``downstream='lstm'``.

    Raises:
        ValueError: If ``downstream='lstm'`` and ``lstm_size`` is ``None``.
    """

    def __init__(self, upstream, num_layers=13, hidden_sizes=[128], num_classes=4, downstream='mlp', lstm_size=None):
        super().__init__()
        self.upstream = S3PRLUpstream(upstream)
        self.num_features = self.upstream.hidden_sizes[-1]
        self.upstream.eval()
        self.layer_weights = torch.nn.Parameter(torch.randn(num_layers))
        self.downstream_type = downstream
        self.num_classes = num_classes
        # Copy so that tuples are accepted and the default list is never aliased.
        hidden_sizes = list(hidden_sizes)
        if downstream == 'mlp':
            self.downstream = self._build_mlp([self.num_features] + hidden_sizes + [num_classes])
        elif downstream == 'lstm':
            if lstm_size is None:
                raise ValueError("lstm_size must be provided when downstream='lstm'")
            self.downstream_lstm = torch.nn.LSTM(self.num_features, lstm_size, batch_first=True)
            self.downstream_mlp = self._build_mlp([lstm_size] + hidden_sizes + [num_classes])

        if num_classes>1:
            self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
            self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        else:
            self.train_acc = torchmetrics.Accuracy(task='binary')
            self.val_acc = torchmetrics.Accuracy(task='binary')

    @staticmethod
    def _build_mlp(channels):
        """Build Linear+ReLU blocks for every hidden size, then a final Linear.

        The module layout (a Sequential of ``Sequential(Linear, ReLU)`` blocks
        followed by a bare ``Linear``) is kept exactly as before so existing
        checkpoints keep loading with the same parameter names.
        """
        hidden_blocks = [torch.nn.Sequential(torch.nn.Linear(chi, cho), torch.nn.ReLU())
                         for chi, cho in zip(channels[:-2], channels[1:-1])]
        return torch.nn.Sequential(*hidden_blocks, torch.nn.Linear(channels[-2], channels[-1]))

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen upstream in eval mode.

        ``Module.train()`` recurses into every submodule, which would otherwise
        undo the ``self.upstream.eval()`` done in ``__init__`` as soon as
        training starts.
        """
        super().train(mode)
        self.upstream.eval()
        return self

    def forward(self, x):
        """Run upstream and head on a batch dict, storing results in ``x``.

        Args:
            x: Dict with ``'wav'`` and ``'wav_lens'``, passed as-is to the upstream.

        Returns:
            The same dict ``x`` (mutated in place), with ``'upstream_outs'``,
            ``'upstream_lens'`` and ``'y_pred'`` added; the ``'mlp'`` head also
            adds ``'pooled_upstream'`` and ``'padding_mask'``.

        Raises:
            ValueError: If ``self.downstream_type`` is not ``'mlp'`` or ``'lstm'``.
        """
        with torch.no_grad():
            x['upstream_outs'], x['upstream_lens'] = self.upstream(x['wav'], x['wav_lens'])
        # Stack the per-layer outputs on a new leading axis and mix them with
        # softmax-normalised layer weights.
        upstream_outs = torch.stack(x['upstream_outs'])
        layer_w = torch.softmax(self.layer_weights, dim=0)[:, None, None, None]
        upstream_pooled = torch.sum(layer_w*upstream_outs, dim=0)
        if self.downstream_type == 'mlp':
            # Mean over valid frames only; lengths are taken from the first layer.
            padding_mask = torch.arange(0, upstream_pooled.shape[1], device=upstream_outs.device)[None,:] < x['upstream_lens'][0][:, None]
            upstream_pooled = torch.sum(upstream_pooled * padding_mask[:,:,None], dim=1)/torch.sum(padding_mask[:,:,None], dim=1)
            x['pooled_upstream'] = upstream_pooled
            x['padding_mask'] = padding_mask
            x['y_pred'] = self.downstream(x['pooled_upstream'])
        elif self.downstream_type == 'lstm':
            # pack_padded_sequence needs the lengths as an int64 CPU tensor.
            lstm_in = torch.nn.utils.rnn.pack_padded_sequence(upstream_pooled,
                                                              x['upstream_lens'][0].to(device='cpu', dtype=torch.int64),
                                                              batch_first=True,
                                                              enforce_sorted=False)
            lstm_out, (hn, cn) = self.downstream_lstm(lstm_in)
            # hn[0]: final hidden state of the (single) LSTM layer.
            x['y_pred'] = self.downstream_mlp(hn[0])
        else:
            raise ValueError('Unknown downstream_type {}'.format(self.downstream_type))
        return x

2025-02-08 00:07:27.774465: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-08 00:07:27.782209: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-08 00:07:27.790388: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-08 00:07:27.792812: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-08 00:07:27.799771: I tensorflow/core/platform/cpu_feature_guar

In [2]:
import torch
import pandas as pd

METADATA_PATH = '/mnt/data/IEMOCAP-happy-cough/metadata.csv'
COUGH_CKPT_PATH = '/home/cbolanos/interpretability-benchmarks/checkpoints/iemocap-cough-1340.ckpt'
NOCOUGH_CKPT_PATH = '/home/cbolanos/interpretability-benchmarks/checkpoints/iemocap-nocough-1340.ckpt'


def load_ud_model(checkpoint_path):
    """Build a wav2vec2 `UpstreamDownstreamModel` and load weights from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file whose top-level dict has a
            ``'state_dict'`` entry.

    Returns:
        The model with the checkpoint weights loaded, in eval mode.
    """
    model = UpstreamDownstreamModel(upstream='wav2vec2',
                                    num_layers=13,
                                    num_classes=4,
                                    hidden_sizes=[256])
    # map_location='cpu' lets checkpoints saved on a GPU load on any machine.
    state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    # Keys in the checkpoint use "mlp" where this model's head is named
    # "downstream" (presumably saved from a class version with an `mlp` attribute).
    state_dict = {k.replace("mlp", "downstream"): v for k, v in state_dict.items()}
    # Strict loading: raises if any key is missing or unexpected.
    model.load_state_dict(state_dict)
    return model.eval()


df_metadata = pd.read_csv(METADATA_PATH)
# Session number = last character of the 5th path component from the end.
df_metadata['session'] = df_metadata['filename'].apply(lambda x: int(x.split('/')[-5][-1]))
# Sessions below 5 are the training partition, session 5 is held out for testing.
df_metadata['partition'] = df_metadata['session'].apply(lambda x: 'Train' if x<5 else 'Test')
df_test = df_metadata.loc[df_metadata['partition']=='Test']
dataset_test = IEMOCAPDataset(df_test)

ud_model_cough = load_ud_model(COUGH_CKPT_PATH)
ud_model_nocough = load_ud_model(NOCOUGH_CKPT_PATH)




<All keys matched successfully>

In [7]:
# Class id compared against both labels and predictions; 'happiness' is at
# index 1 in the default IEMOCAPDataset class_map.
HAPPINESS_ID = 1

# Metadata rows of happy test utterances that the "cough" model classifies as
# happiness while the "nocough" model does not.
dataframe = []

# Inference only: no_grad avoids building an autograd graph for the heads.
with torch.no_grad():
    for i in range(len(dataset_test)):
        row = dataset_test.metadata.iloc[i]
        # Check the label from the metadata first so audio is only read from
        # disk for the utterances that are actually evaluated.
        if dataset_test.class_map[row['emotion']] != HAPPINESS_ID:
            continue

        sample = dataset_test[i]
        wav = torch.from_numpy(sample['wav'])[None,:]
        wav_lens = torch.tensor([sample['wav'].shape[0]])

        # forward() stores its outputs in the dict it receives, so each model
        # gets its own dict (the input tensors themselves are shared).
        xin = {'wav': wav, 'wav_lens': wav_lens, 'emotion': sample['emotion']}
        xtin = dict(xin)

        ud_model_cough(xin)
        ud_model_nocough(xtin)

        cough_pred = int(xin['y_pred'][0].argmax())
        nocough_pred = int(xtin['y_pred'][0].argmax())
        if cough_pred == HAPPINESS_ID and nocough_pred != HAPPINESS_ID:
            dataframe.append(row)

In [10]:
# One row per selected utterance; each collected Series becomes a row and keeps
# its original metadata index label.
cough_happy = pd.DataFrame(data=dataframe)
cough_happy

Unnamed: 0.1,Unnamed: 0,filename,speech_filename,emotion,duration,cough_filename,cough_start,cough_end,session,partition
4335,4335,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,1.738125,3-145487-A-24.wav,1289.0,6851.0,5,Test
4340,4340,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,1.910000,1-52266-A-24.wav,9840.0,15952.0,5,Test
4353,4353,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,2.939938,3-151213-A-24.wav,32316.0,41723.0,5,Test
4355,4355,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,2.039938,4-171396-A-24.wav,6105.0,12632.0,5,Test
4361,4361,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,2.300000,1-19118-A-24.wav,11129.0,18489.0,5,Test
...,...,...,...,...,...,...,...,...,...,...
5468,5468,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,5.430000,3-145487-A-24.wav,23574.0,36374.0,5,Test
5469,5469,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,4.259938,3-145487-A-24.wav,37286.0,50086.0,5,Test
5472,5472,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,6.899938,1-19111-A-24.wav,75429.0,90277.0,5,Test
5474,5474,/mnt/data/IEMOCAP-happy-cough/Session5/sentenc...,/mnt/data/IEMOCAP/Session5/sentences/wav/Ses05...,happiness,2.699937,1-58792-A-24.wav,6269.0,14908.0,5,Test


In [11]:
# Persist the selected rows; the index is dropped from the CSV.
COUGH_HAPPY_CSV = '/home/cbolanos/explain_where/models/cough/cough_happy.csv'
cough_happy.to_csv(COUGH_HAPPY_CSV, index=False)