In [3]:
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from omegaconf import OmegaConf
from tqdm import tqdm
import importlib
import numpy as np
import pickle 
def load_dataset(configs):
    # call dataset, build the set by config
    dataset = getattr(importlib.import_module('dataset'), f'{configs.name}')(**configs)
    return dataset

# loading our model weights
device = torch.device('cuda')
model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True).to(device)
# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M" ,trust_remote_code=True)
configs = OmegaConf.load('configs/MERT.yaml')
# load demo audio and set processor
configs.data['batch_size'] = 1
configs.data['name'] = 'GTZAN'

dataset = load_dataset(configs.data)
train_loader, valid_loader, test_loader = dataset.train_loader, dataset.valid_loader, dataset.test_loader
sampling_rate = 22050
outdir = '../data/MERT_extracted/'
resample_rate = processor.sampling_rate
# make sure the sample_rate aligned
for batch in tqdm(test_loader):
    input_values, attn_mask, label, fns = batch
    input_values, attn_mask, label = input_values.to(device), attn_mask.to(device), label.to(device)
    with torch.no_grad():
        # batch['inputs']['input_values'] = batch['inputs']['input_values'].squeeze(1)
        # outputs = model(**batch['inputs'], output_hidden_states=True)
        input_values = input_values.squeeze(1)
        outputs = model(input_values= input_values, attention_mask=attn_mask, output_hidden_states=True)
        outputs['last_hidden_state'] = outputs['last_hidden_state'].mean(-2).squeeze().detach().cpu().numpy()
        outputs['hidden_states'] = np.array([h.mean(-2).detach().cpu().numpy() for h in  outputs['hidden_states']])
        outputs['filename'] = fns[0].split('/')[-1]
        outputs['label'] = label.detach().cpu().numpy()
        
        with open(f"{outdir}/{outputs['filename'].strip('.wav')}.pkl", 'wb') as f:
            pickle.dump(dict(outputs), f)
        # torch.save(outputs, f"{outdir}/{outputs['filename'].strip('.wav')}.pkl")
 

Some weights of the model checkpoint at m-a-p/MERT-v1-330M were not used when initializing MERTModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing MERTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MERTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MERTModel were not initialized from the model checkpoint at m-a-p/MERT-v1-330M and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


FileNotFoundError: [Errno 2] No such file or directory: '../data/MERT_extracted/blues.00005.pkl'

In [3]:
dict(outputs)

{'last_hidden_state': array([[-0.14504881,  0.22710714, -0.24867591, ...,  0.16390459,
          0.28312564,  0.21599707],
        [-0.19075   ,  0.27931428, -0.16414267, ...,  0.07543774,
         -0.02058431,  0.00545884],
        [-0.12118477,  0.4595118 , -0.06940861, ...,  0.0622366 ,
         -0.12402634, -0.0389233 ],
        ...,
        [-0.13360098, -0.0475815 , -0.19085309, ...,  0.03405258,
         -0.14633861, -0.13172099],
        [ 0.11046387, -0.00421953, -0.15599903, ...,  0.16261217,
         -0.17315088,  0.06167391],
        [ 0.05564832,  0.09329486, -0.01084341, ...,  0.22233868,
          0.069498  ,  0.0477774 ]], dtype=float32),
 'hidden_states': array([[[[-7.95065939e-01, -3.71167541e-01,  8.14275682e-01, ...,
            3.53135252e+00, -2.85495853e+00,  6.79931545e+00],
          [-1.29813686e-01,  3.25349689e+00,  2.81486225e+00, ...,
            1.94283485e+00,  3.86267811e-01,  7.45594311e+00],
          [ 1.84039307e+00,  3.69022441e+00, -8.72681797e-01

In [2]:
outputs

NameError: name 'outputs' is not defined