In [2]:
from torch.utils.data import Dataset
import numpy as np
import torch
from dataclasses import dataclass
import pytorch_lightning as pl
from s3prl.nn import S3PRLUpstream
import librosa
import pandas as pd

class DrumDataset(Dataset):
    """Map-style dataset that loads one audio clip per row of a metadata table.

    `metadata` is stored unchanged; it has to support `len()` and `.iloc[idx]`,
    and each row is read through the keys 'filename' and 'num_kicks'.
    """

    def __init__(self, metadata):
        super().__init__()
        self.metadata = metadata

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the clip described by row `idx`.

        Returns a dict with two entries:
          'wav'     -- waveform loaded with sr=16000 and cast to float32
          'emotion' -- the row's 'num_kicks' value after `.astype(int)`
        """
        record = self.metadata.iloc[idx]
        # The sample rate returned by the loader is not needed afterwards.
        waveform, _sample_rate = librosa.core.load(record['filename'], sr=16000)
        sample = {}
        sample['wav'] = waveform.astype(np.float32)
        sample['emotion'] = record['num_kicks'].astype(int)
        return sample
    
class UpstreamDownstreamModel(pl.LightningModule):
    """S3PRL upstream (run without gradients) followed by a small classifier head.

    The per-layer upstream outputs are mixed with a learned, softmax-normalised
    weight per layer (`layer_weights`) and then classified by one of two heads:

    * 'mlp'  -- masked mean over the time axis, then an MLP (`self.downstream`).
    * 'lstm' -- an LSTM over the valid frames (`self.downstream_lstm`) whose
                final hidden state feeds an MLP (`self.downstream_mlp`).

    Submodule attribute names and the nesting of the MLP layers determine the
    state-dict keys, so they are kept stable for existing checkpoints.
    """

    _DOWNSTREAM_TYPES = ('mlp', 'lstm')

    def __init__(self, upstream, num_layers=13, hidden_sizes=[128], num_classes=4, downstream='mlp', lstm_size=None):
        """Build the upstream, the layer-mixing weights and the selected head.

        Args:
            upstream: name handed to `S3PRLUpstream`.
            num_layers: number of layer-mixing weights; expected to match the
                number of layer outputs the upstream returns.
            hidden_sizes: widths of the hidden MLP layers. Any sequence is
                accepted; it is copied and never mutated.
            num_classes: width of the output layer.
            downstream: 'mlp' or 'lstm'.
            lstm_size: LSTM hidden size; required when `downstream == 'lstm'`.

        Raises:
            ValueError: if `downstream` is not a known head type, or if the
                LSTM head is requested without `lstm_size`.
        """
        super().__init__()
        # Fail fast on a bad configuration, before the upstream is instantiated,
        # instead of only discovering it on the first forward pass.
        if downstream not in self._DOWNSTREAM_TYPES:
            raise ValueError('Unknown downstream_type {}'.format(downstream))
        if downstream == 'lstm' and lstm_size is None:
            raise ValueError("lstm_size must be set when downstream == 'lstm'")
        hidden_sizes = list(hidden_sizes)

        self.upstream = S3PRLUpstream(upstream)
        # Last entry of the hidden sizes reported by the upstream.
        self.num_features = self.upstream.hidden_sizes[-1]
        # NOTE(review): nn.Module.train() recurses into children, so a later
        # .train() call on this module switches the upstream back to training mode.
        self.upstream.eval()
        self.layer_weights = torch.nn.Parameter(torch.randn(num_layers))
        self.downstream_type = downstream
        self.num_classes = num_classes
        if downstream == 'mlp':
            self.downstream = self._build_mlp(self.num_features, hidden_sizes, num_classes)
        else:
            self.downstream_lstm = torch.nn.LSTM(self.num_features, lstm_size, batch_first=True)
            self.downstream_mlp = self._build_mlp(lstm_size, hidden_sizes, num_classes)

    @staticmethod
    def _build_mlp(in_features, hidden_sizes, num_classes):
        """Return Sequential[(Linear, ReLU) per hidden width ..., Linear to num_classes]."""
        widths = [in_features] + hidden_sizes
        layers = [torch.nn.Sequential(torch.nn.Linear(chi, cho), torch.nn.ReLU())
                  for chi, cho in zip(widths[:-1], widths[1:])]
        layers.append(torch.nn.Linear(widths[-1], num_classes))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run the upstream and the head on a batch dict, filling it in place.

        Args:
            x: dict holding 'wav' and 'wav_lens', which are passed straight to
                the upstream (presumably a (batch, samples) tensor and a
                (batch,) length tensor -- confirm against the caller).

        Returns:
            The same dict `x`, with 'upstream_outs', 'upstream_lens' and
            'y_pred' added; the 'mlp' head also adds 'pooled_upstream' and
            'padding_mask'.

        Raises:
            ValueError: if `self.downstream_type` is not 'mlp' or 'lstm'.
        """
        # The upstream is used as a fixed feature extractor: no gradients flow into it.
        with torch.no_grad():
            x['upstream_outs'], x['upstream_lens'] = self.upstream(x['wav'], x['wav_lens'])
        # Stack the per-layer outputs on a new leading axis and mix them with
        # one softmax-normalised weight per layer.
        upstream_outs = torch.stack(x['upstream_outs'])
        layer_w = torch.softmax(self.layer_weights, dim=0)[:, None, None, None]
        upstream_pooled = torch.sum(layer_w*upstream_outs, dim=0)
        if self.downstream_type == 'mlp':
            # True for frames inside each item's valid length (lengths of the first layer).
            padding_mask = torch.arange(0, upstream_pooled.shape[1], device=upstream_outs.device)[None,:] < x['upstream_lens'][0][:, None]
            # Mean over valid frames only.
            upstream_pooled = torch.sum(upstream_pooled * padding_mask[:,:,None], dim=1)/torch.sum(padding_mask[:,:,None], dim=1)
            x['pooled_upstream'] = upstream_pooled
            x['padding_mask'] = padding_mask
            x['y_pred'] = self.downstream(x['pooled_upstream'])
        elif self.downstream_type == 'lstm':
            # pack_padded_sequence needs the lengths as an int64 tensor on the CPU.
            lstm_in = torch.nn.utils.rnn.pack_padded_sequence(upstream_pooled,
                                                              x['upstream_lens'][0].to(device='cpu', dtype=torch.int64),
                                                              batch_first=True,
                                                              enforce_sorted=False)
            lstm_out, (hn, cn) = self.downstream_lstm(lstm_in)
            # hn[0]: final hidden state of the first LSTM layer.
            x['y_pred'] = self.downstream_mlp(hn[0])
        else:
            # Only reachable if downstream_type was reassigned after construction.
            raise ValueError('Unknown downstream_type {}'.format(self.downstream_type))
        return x

In [3]:
# Absolute location of the drum dataset's metadata CSV on this machine.
METADATA_PATH = '/mnt/data/drum_dataset/metadata.csv'

# Read the whole metadata table into memory.
df_metadata = pd.read_csv(METADATA_PATH)

def partition(x):
    """Map a file name to its split using the 7th character from the end.

    '0' -> 'test', '1' -> 'validation', anything else -> 'train'.
    For a name such as '.../5057.wav' that character is the hundreds digit
    of the four-digit stem.
    """
    split_by_marker = {'0': 'test', '1': 'validation'}
    marker = x[-7]
    return split_by_marker.get(marker, 'train')
# Tag every row with its split, derived from the file name by `partition`.
df_metadata['partition'] = df_metadata['filename'].apply(partition)

# Keep the rows assigned to the test split and wrap them in a dataset.
df_test = df_metadata.loc[df_metadata['partition']=='test']
dataset_test = DrumDataset(df_test)

# The architecture arguments have to match the checkpoint loaded below,
# otherwise load_state_dict reports missing or unexpected keys.
ud_model = UpstreamDownstreamModel(upstream='wav2vec2',
                                   downstream='lstm',
                                   num_layers=13,
                                   num_classes=6,
                                   hidden_sizes=[128],
                                   lstm_size=256)

CHECKPOINT_PATH = '/home/cbolanos/interpretability-benchmarks/checkpoints/drums-step15000.ckpt'

# map_location='cpu': the model stays on the CPU in this notebook, so the
# checkpoint tensors are mapped there too. Without it, a checkpoint saved from
# GPU tensors cannot be loaded on a machine without CUDA.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
ud_model.load_state_dict(checkpoint['state_dict'])



<All keys matched successfully>

In [20]:
ud_model.eval()
dataframe = []

# Index the dataset explicitly instead of relying on the implicit
# __getitem__/IndexError iteration protocol; the visited rows are the same.
for i in range(len(dataset_test)):
    xin = dataset_test[i]

    # Clips labelled with zero kicks are left out of the collected rows.
    if xin['emotion'] == 0:
        continue

    # Turn the single clip into a batch of one: lengths first (taken from the
    # numpy array), then the waveform with a leading batch axis.
    xin['wav_lens'] = torch.tensor([xin['wav'].shape[0]])
    xin['wav'] = torch.from_numpy(xin['wav'])[None,:]

    # Inference only: nothing is back-propagated, so skip autograd bookkeeping.
    # NOTE(review): the prediction left in xin['y_pred'] is never inspected, so
    # every non-zero-kick row is kept regardless of the model output -- confirm
    # that no filtering on the prediction was intended.
    with torch.no_grad():
        ud_model(xin)
    dataframe.append(dataset_test.metadata.iloc[i])

In [23]:
# Fixed seed so the shuffled order -- and therefore the exported CSV -- is the
# same on every run; random_state=None produced a different order each time.
SHUFFLE_SEED = 0

# Assemble the collected rows into one frame, shuffle all of them (frac=1)
# and renumber the index from zero.
drums_dataset = (
    pd.DataFrame(dataframe)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)
drums_dataset

Unnamed: 0.1,Unnamed: 0,filename,hit_filenames,durations,counts,num_kicks,pattern,partition
0,5057,/mnt/data/drum_dataset/5057.wav,"[PosixPath('drumkit/kick/Bass Sample 12.wav'),...","[4080, 2464, 2704, 13920, 13728, 2080, 8960, 9...","{'K': 5, 'S': 0, 'C': 5, 'T': 0}",5,KKCKKCCCKC,test
1,1049,/mnt/data/drum_dataset/1049.wav,[PosixPath('drumkit/overheads/Overhead Sample ...,"[8512, 4416, 8256, 9760, 12224, 15408, 6944, 6...","{'K': 1, 'S': 4, 'C': 3, 'T': 2}",1,CKSSCSCTST,test
2,5059,/mnt/data/drum_dataset/5059.wav,[PosixPath('drumkit/overheads/Overhead Sample ...,"[13616, 1632, 12304, 14800, 1856, 5184, 7824, ...","{'K': 5, 'S': 1, 'C': 2, 'T': 2}",5,CSTCKKKTKK,test
3,3033,/mnt/data/drum_dataset/3033.wav,"[PosixPath('drumkit/kick/Bass Sample 29.wav'),...","[9088, 14496, 12816, 9712, 7296, 12976, 9488, ...","{'K': 3, 'S': 3, 'C': 2, 'T': 2}",3,KSKSCTSCTK,test
4,5079,/mnt/data/drum_dataset/5079.wav,"[PosixPath('drumkit/kick/Bass Sample 10.wav'),...","[8176, 4608, 13120, 9392, 1952, 9600, 14208, 1...","{'K': 5, 'S': 2, 'C': 1, 'T': 2}",5,KKTKSKTSKC,test
...,...,...,...,...,...,...,...,...
495,4024,/mnt/data/drum_dataset/4024.wav,[PosixPath('drumkit/overheads/Overhead Sample ...,"[2528, 5632, 4512, 8096, 6000, 13552, 13920, 1...","{'K': 4, 'S': 2, 'C': 2, 'T': 2}",4,CSTCKKKTKS,test
496,3072,/mnt/data/drum_dataset/3072.wav,[PosixPath('drumkit/overheads/Overhead Sample ...,"[6208, 5488, 8416, 3248, 12800, 2480, 12640, 1...","{'K': 3, 'S': 5, 'C': 1, 'T': 1}",3,CKSKSTSKSS,test
497,2036,/mnt/data/drum_dataset/2036.wav,[PosixPath('drumkit/overheads/Overhead Sample ...,"[14000, 5952, 15280, 6816, 9568, 15776, 6144, ...","{'K': 2, 'S': 3, 'C': 4, 'T': 1}",2,CKCKSCTSSC,test
498,3012,/mnt/data/drum_dataset/3012.wav,[PosixPath('drumkit/overheads/Overhead Sample ...,"[13024, 7664, 7440, 14256, 8656, 2672, 14048, ...","{'K': 3, 'S': 3, 'C': 3, 'T': 1}",3,CKSSKCSCKT,test


In [24]:
# Write the frame to disk; index=False keeps the row index out of the file.
drums_dataset.to_csv('/home/cbolanos/explain_where/models/drums/drums_dataset.csv', index=False)