In [1]:
class_idx2MIDIClass_path = "hw1/class_idx2MIDIClass.json"

In [24]:
inference_audio_path = "hw1/test_track/"

In [108]:
import json
import numpy as np
import os

# huggingface
# from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
from datasets import load_dataset
import torchaudio
import nnAudio

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

import torch
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import classification_report, accuracy_score, f1_score

from glob import glob

import pickle

In [4]:
device = torch.device("mps" if torch.cuda.is_available() else "cpu")

In [5]:
SAMPLE_RATE = 24000

In [6]:
class EmbeddingDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label, embedding = self.data[idx]
        embedding = torch.tensor(embedding, dtype=torch.float32)  # Shape: [13, 768]
        return embedding, torch.tensor(label, dtype=torch.float32)

class MultiClassClassifier(nn.Module):
    def __init__(self, input_size, num_classes, thresholds=None):
        super(MultiClassClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size[0] * input_size[1], 512)  # Adjusted input size
        self.fc2 = nn.Linear(512, num_classes)
        self.sigmoid = nn.Sigmoid()
        self.thresholds = thresholds if thresholds is not None else [0.5] * num_classes

    def forward(self, x):
        # x shape: [batch_size, 13, 768]
        x = x.view(x.size(0), -1)  # Flatten, shape: [batch_size, 13 * 768]
        x = torch.relu(self.fc1(x))  # Fully connected layer with ReLU activation
        x = self.fc2(x)  # Output layer
        return self.sigmoid(x)  # Apply sigmoid to get probabilities

    def predict(self, x, thresholds=None):
        if thresholds is None:
            thresholds = self.thresholds
        with torch.no_grad():
            probabilities = self.forward(x)
            return (probabilities >= torch.tensor(thresholds).to(probabilities.device)).float()  # Apply thresholds to get binary output

In [7]:
def train(model, criterion, optimizer, train_loader, num_epochs):
    model.train()
    total_batches = len(train_loader)
    for epoch in range(num_epochs):
        for batch_num, (inputs, labels) in enumerate(train_loader, 1):
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to the device
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_num}/{total_batches}], Loss: {loss.item():.4f}')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [8]:
def evaluate(model, val_loader, thresholds):
    num_classes = len(thresholds)
    model.eval()
    best_thresholds = [0.5] * num_classes
    best_scores = [0] * num_classes
    
    for i in range(num_classes):
        for threshold in thresholds:
            all_labels = []
            all_preds = []
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                preds = model.predict(inputs, thresholds=[threshold] * num_classes)
                all_labels.append(labels.cpu().numpy()[:, i])
                all_preds.append(preds.cpu().numpy()[:, i])
            all_labels = np.concatenate(all_labels, axis=0)
            all_preds = np.concatenate(all_preds, axis=0)
            score = f1_score(all_labels, all_preds)  # Use F1-score for evaluation
            if score > best_scores[i]:
                best_scores[i] = score
                best_thresholds[i] = threshold
    
    return best_thresholds, best_scores

In [9]:
def test(model, test_loader, thresholds, class_idx2MIDIClass):
    model.eval()
    all_labels = []
    all_preds = []
    
    # Collect predictions and true labels
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model.predict(inputs, thresholds)
        all_labels.append(labels.cpu().numpy())
        all_preds.append(preds.cpu().numpy())
    
    # Concatenate all predictions and labels
    all_labels = np.concatenate(all_labels, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)

    
    # Calculate accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    
    # Generate classification report
    target_names = [class_idx2MIDIClass[str(i)] for i in range(len(class_idx2MIDIClass))]
    report = classification_report(all_labels, all_preds, target_names=target_names, zero_division=0)
    
    # Print accuracy and classification report
    print(f'Accuracy: {accuracy:.4f}')
    print(report)
    
    return accuracy, report

In [11]:
batch_size = 16
num_epochs = 10
input_size = (13, 768)
num_classes = 9

In [21]:
# read the best thresholds
with open('best_threshold.json', 'r') as f:
    thresholds = json.load(f)

# Instantiate the model architecture
model = MultiClassClassifier(input_size, num_classes, thresholds=thresholds).to(device)

# Load the model
model.load_state_dict(torch.load('MERT_model_different_threshold.pth'))

  model.load_state_dict(torch.load('MERT_model_different_threshold.pth'))


<All keys matched successfully>

In [22]:
print(model.thresholds)

[0.1, 0.5, 0.5, 0.1, 0.1, 0.1, 0.5, 0.1, 0.5]


In [57]:
midi_path_list = glob(os.path.join(inference_audio_path, '*.flac'))
print(midi_path_list)

# sort the midi_path_list
midi_path_list.sort()

['hw1/test_track/Track01937.flac', 'hw1/test_track/Track02024.flac', 'hw1/test_track/Track02100.flac', 'hw1/test_track/Track02078.flac', 'hw1/test_track/Track01876.flac']


In [84]:
# for each song, cut the song into 5 seconds and embed it

# Load the audio file

SAMPLE_RATE = 44100

wav_file = []
for file in midi_path_list:
    waveform, sr = torchaudio.load(file)
    song = []
    # cut the song into 5 seconds， ignore the last clip that is less than 5 seconds
    for i in range(waveform.shape[1]//SAMPLE_RATE//5):
        clip = waveform[:, i*SAMPLE_RATE*5:(i+1)*SAMPLE_RATE*5]
        # cast clip type back fro tensor
        clip = clip.numpy()
        # flatten the clip
        clip = clip.flatten()
        song.append(clip)
    wav_file.append(song)

print(wav_file[0][0])

(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)
(220500,)


In [86]:
# if GPU is available, use it, otherwise use CPU
device = torch.device("mps" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# torch.cuda.empty_cache()

# loading our model weights
# model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
MERT_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True).to(device)
# loading the corresponding preprocessor config
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M",trust_remote_code=True)

# happen to be 24kHz, the same as the dataset
resample_rate = processor.sampling_rate

resampler = T.Resample(SAMPLE_RATE, resample_rate)

# (label, embedding)
inference_embedding = []

# use tqdm to show the progress
# process the data in batches, or the kernel will die
for i in range(len(wav_file)):
    song_embedding = []
    for clip in tqdm(wav_file[i]):
        input_audio = resampler(torch.from_numpy(clip)).float().to(device)
        # input_audio = torch.tensor(audio).float()
        inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt").to(device)

        
        with torch.no_grad():
            outputs = MERT_model(**inputs, output_hidden_states=True)
        all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() # (13, 374, 768)
        time_reduced_hidden_states = all_layer_hidden_states.mean(-2) # (13, 768)

        song_embedding.append(time_reduced_hidden_states)
    inference_embedding.append(song_embedding)

Some weights of the model checkpoint at m-a-p/MERT-v1-95M were not used when initializing MERTModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing MERTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MERTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MERTModel were not initialized from the model checkpoint at m-a-p/MERT-v1-95M and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
10

In [90]:
# Inference with the best threshold
sample_input = torch.randn(1, *input_size).to(device)  # Ensure the input is on the same device as the model
print(sample_input.shape)
binary_output = model.predict(sample_input)
print(binary_output)

torch.Size([1, 13, 768])
tensor([[1., 0., 0., 1., 1., 0., 0., 1., 0.]])


In [102]:
embedding = []

for song in inference_embedding:
    song_embedding = []
    for clip in song:
        # inference 
        # make the input shape from (13, 768) to (1, 13, 768)
        clip = clip.unsqueeze(0)
        output = model.predict(clip)
        output = output.cpu().numpy()
        # turn from (1, 9) to (9)
        output = output[0]
        song_embedding.append(output)
    embedding.append(song_embedding)

In [109]:
# check if hw1/embedding folder exists
if not os.path.exists('hw1/embedding'):
    os.makedirs('hw1/embedding')

for i in range(len(embedding)):
    # dump the embedding into pickle
    filename = midi_path_list[i].split('/')[-1].split('.')[0]
    with open(f'hw1/embedding/{filename}.pkl', 'wb') as f:
        pickle.dump(embedding[i], f)