In [2]:
# setup
import os
import json
import gc
from typing import Optional
from datetime import datetime
from tqdm import tqdm
import numpy as np
import soundfile as sf
from glob import glob
import librosa
from sklearn.metrics import precision_score, f1_score, classification_report
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import warnings

from model import CNNClassifier, ShortChunkCNN_Res

# TODO: change the file path
# Dataset locations: one directory of clips and one JSON label file per split.
TRAIN_FILE_DIR = "./hw1/slakh/train"
VALID_FILE_DIR = "./hw1/slakh/validation"
TEST_FILE_DIR = "./hw1/slakh/test"
TRAIN_LABEL_PATH = "./hw1/slakh/train_labels.json"
VALID_LABEL_PATH = "./hw1/slakh/validation_labels.json"
TEST_LABEL_PATH = "./hw1/slakh/test_labels.json"

# Seed the RNGs that drive weight initialisation and DataLoader shuffling so
# that a Restart & Run All gives comparable results.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Class names, used as `target_names` for the classification reports below.
LABELS = ['Piano', 'Percussion', 'Organ', 'Guitar', 'Bass', 'Strings', 'Voice', 'Wind Instruments', 'Synth']
warnings.filterwarnings("ignore", category=UserWarning)



  from .autonotebook import tqdm as notebook_tqdm




  return self.fget.__get__(instance, owner)()


In [3]:
# Compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training schedule
EPOCHS = 30        # maximum number of passes over the training set
PATIENCE = 10      # non-improving epochs tolerated before early stopping
BATCH_SIZE = 32    # 64
LR = 1e-3          # 1e-5
THRESHOLD = 0.7543  # current best

# Classifier width
N_CHANNELS = 128   # 256


In [4]:
# Fetch the MERT feature extractor (waveform pre-processing) and encoder.
processor = Wav2Vec2FeatureExtractor.from_pretrained(
    "m-a-p/MERT-v1-330M",
    trust_remote_code=True,
)
MERT_model = AutoModel.from_pretrained(
    "m-a-p/MERT-v1-330M",
    trust_remote_code=True,
)

# MERT is used purely as a fixed feature extractor: disable gradients on every weight.
for param in MERT_model.parameters():
    param.requires_grad_(False)




In [5]:
class AudioDataset(Dataset):
    """Dataset of audio clips stored as ``.npy`` files, paired with JSON labels.

    Each item is ``(audio_wave, label)``: ``audio_wave`` is the array returned by
    ``np.load`` for the clip file, and ``label`` is the JSON entry keyed by that
    file name, converted to a ``float32`` NumPy array.
    """

    def __init__(self, wav_directory: str, label_directory: str):
        """
        Args:
            wav_directory (string): Path to the directory with all the .npy files.
            label_directory (string): Path to the JSON file that maps each file
                name in ``wav_directory`` to its label vector. Despite the
                parameter name, this is a file path, not a directory.
        """
        self.directory = wav_directory
        # os.listdir returns entries in arbitrary order and includes every entry,
        # so keep only the .npy clips and sort them to make indexing
        # deterministic across runs and machines.
        self.files = sorted(
            name for name in os.listdir(wav_directory) if name.endswith(".npy")
        )
        with open(label_directory, "r") as f:
            self.labels = json.load(f)

    def __len__(self):
        """Return the number of clips found in the directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load clip ``idx`` from disk and return it with its float32 label array."""
        file_name = self.files[idx]
        file_path = os.path.join(self.directory, file_name)
        audio_wave = np.load(file_path)
        label = np.array(self.labels[file_name], dtype=np.float32)
        return audio_wave, label


# One dataset per split, built from its (clip directory, label file) pair.
train_dataset, valid_dataset, test_dataset = (
    AudioDataset(wav_dir, label_path)
    for wav_dir, label_path in (
        (TRAIN_FILE_DIR, TRAIN_LABEL_PATH),
        (VALID_FILE_DIR, VALID_LABEL_PATH),
        (TEST_FILE_DIR, TEST_LABEL_PATH),
    )
)

# Only the training split is reshuffled; validation and test keep a fixed order.
train_dataloader, valid_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)


In [6]:
# GPU-only diagnostics. DEVICE may fall back to the CPU, so guard the CUDA calls
# instead of assuming a GPU is present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no GPU memory summary to report.")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [7]:
def save_checkpoint(model, path=None, verbose = False):
    """Save ``model.state_dict()`` to disk and return the path that was written.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        path: Destination file. When ``None``, a name of the form
            ``DL_model_<MMDD-HH-MM>.pt`` is generated in the working directory.
        verbose: When true, print the destination after saving.

    Returns:
        The path the checkpoint was written to, so callers that passed ``None``
        can find the generated file.
    """
    if path is None:
        # Only build the timestamp when it is actually needed.
        postfix = datetime.now().strftime("%m%d-%H-%M")
        path = f"DL_model_{postfix}.pt"
    torch.save(model.state_dict(), path)
    if verbose:
        print(f"model successfully saved to {path}")
    return path

### Train DL model (ShortChunkCNN_Res)

In [8]:
def train(
        train_dataloader: DataLoader, 
        valid_dataloader: DataLoader,
        MERT_model: AutoModel = MERT_model,
        processor: Wav2Vec2FeatureExtractor = processor,
        threshold: float = THRESHOLD,
        model_path: Optional[str] = None,
        verbose: bool = True
    ):
    """Train a ``ShortChunkCNN_Res`` classifier on frozen MERT features.

    Each batch of waveforms goes through ``processor`` and ``MERT_model``; the
    encoder's ``last_hidden_state`` is fed to the trainable classifier, which is
    optimised with BCE loss and Adam. After every epoch the macro F1 on the
    validation set (probabilities binarised at ``threshold``) is computed. The
    classifier is checkpointed whenever that score improves, and training stops
    after ``PATIENCE`` consecutive epochs without improvement.

    Args:
        train_dataloader: Yields ``(waveforms, labels)`` training batches.
        valid_dataloader: Yields ``(waveforms, labels)`` validation batches.
        MERT_model: Pretrained encoder used as a fixed feature extractor.
        processor: Feature extractor applied to the raw waveforms (24 kHz).
        threshold: Probability cut-off used to binarise predictions for F1.
        model_path: Where to store the best checkpoint. When ``None``, a
            timestamped ``DL_model_<MMDD-HH-MM>.pt`` name is generated.
        verbose: When false, the epoch progress bar is disabled.

    Returns:
        ``(best_model, train_loss_list, valid_loss_list)``: the classifier
        reloaded from the best checkpoint (on CPU), and the per-epoch sums of
        batch losses for the training and validation sets.
    """
    # Resolve the checkpoint path once so that saving and the final reload agree.
    # Otherwise, with model_path=None, the reload would be handed None.
    if model_path is None:
        model_path = f"DL_model_{datetime.now().strftime('%m%d-%H-%M')}.pt"

    # Training and Validation Record
    train_loss_list = []
    valid_loss_list = []

    # Early Stopping. Start below any attainable F1 so the first epoch always
    # writes a checkpoint; the final reload then never hits a missing or stale file.
    best_score = float("-inf")
    checkpoint_saved = False
    cnt = 0

    # Models: move to the device once instead of on every epoch.
    MERT_model = MERT_model.to(DEVICE)
    MERT_model.eval()  # frozen feature extractor stays in inference mode
    model = ShortChunkCNN_Res(n_channels=N_CHANNELS).to(DEVICE)

    # Optimizer
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Training Start!!!
    print("Training Start!!!")
    for epoch in tqdm(range(EPOCHS), disable=(not verbose)):
        model.train()

        total_loss = 0
        train_true = torch.tensor([])
        train_pred_p = torch.tensor([])

        # Training with batches
        for train_wavs, train_label in train_dataloader:
            train_wavs = train_wavs.cpu().numpy()
            inputs = processor(train_wavs, sampling_rate=24000, return_tensors="pt")
            inputs = inputs.to(DEVICE)
            train_label = train_label.to(DEVICE)

            # pre-trained model: no gradients are needed through the frozen encoder
            with torch.no_grad():
                outputs = MERT_model(**inputs)
                pretrained_output = outputs.last_hidden_state # [batch_size, 374 time, 1024 feature_dim]

            # Trainable classifier
            optimizer.zero_grad()
            output = model(pretrained_output)
            loss = loss_fn(output, train_label)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Calculate Score. detach() so the accumulated predictions do not
            # keep every batch's autograd graph alive for the whole epoch.
            train_label = train_label.cpu()
            output = output.detach().cpu()
            train_true = torch.cat([train_true, train_label])
            train_pred_p = torch.cat([train_pred_p, output])

            # Delete Var
            del train_wavs, train_label, inputs, outputs, pretrained_output, output
            gc.collect()

        train_pred = (train_pred_p > threshold).float()
        train_score = f1_score(train_true, train_pred, average="macro")
        train_loss_list.append(total_loss)

        # Validation
        model.eval()
        valid_total_loss = 0
        valid_true = torch.tensor([])
        valid_pred_p = torch.tensor([])

        with torch.no_grad():
            for valid_wavs, valid_label in valid_dataloader:
                valid_wavs = valid_wavs.cpu().numpy()
                inputs = processor(valid_wavs, sampling_rate=24000, return_tensors="pt")
                inputs = inputs.to(DEVICE)
                valid_label = valid_label.to(DEVICE)

                # pre-trained model
                outputs = MERT_model(**inputs)
                pretrained_output = outputs.last_hidden_state # [batch_size, 374 time, 1024 feature_dim]

                output = model(pretrained_output)
                loss = loss_fn(output, valid_label)
                valid_total_loss += loss.item()

                # Calculate Score
                valid_label = valid_label.cpu()
                output = output.cpu()
                valid_true = torch.cat([valid_true, valid_label])
                valid_pred_p = torch.cat([valid_pred_p, output])

                # Delete Var
                del valid_wavs, valid_label, inputs, outputs, pretrained_output, output
                gc.collect()

        valid_pred = (valid_pred_p > threshold).float()
        valid_score = f1_score(valid_true, valid_pred, average="macro")
        valid_loss_list.append(valid_total_loss)

        print(f"Epoch {epoch+1}: train loss: {total_loss:.4f}, train score: {train_score:.4f} || valid loss: {valid_total_loss:.4f}, valid score: {valid_score:.4f}. Threshold: {threshold}")

        # Delete Var
        del train_true, train_pred, train_pred_p, valid_true, valid_pred, valid_pred_p
        gc.collect()

        # for early stopping (criterion: validation macro F1)
        if valid_score <= best_score:
            cnt += 1
            if cnt >= PATIENCE:
                print(f"Early Stopping at epoch: {epoch+1}, the best scores = {best_score:.4f}")
                break
        else:
            best_score = valid_score
            cnt = 0
            save_checkpoint(model, model_path)
            checkpoint_saved = True

    print("Training complete!")
    if not checkpoint_saved:
        # Only reachable when no epoch ran (EPOCHS == 0): persist the current
        # weights so the reload below reads this run's file, not a stale one.
        save_checkpoint(model, model_path)

    best_model = ShortChunkCNN_Res(n_channels=N_CHANNELS)
    # best_model lives on the CPU, so load the tensors there directly.
    best_model.load_state_dict(torch.load(model_path, map_location="cpu"))

    # Clear GPU memory
    torch.cuda.empty_cache()

    return best_model, train_loss_list, valid_loss_list


def test(
        test_dataloader: DataLoader, 
        model: ShortChunkCNN_Res,
        processor: Wav2Vec2FeatureExtractor = processor,
        MERT_model: AutoModel = MERT_model,
        threshold: float = THRESHOLD,
        verbose: bool = True
    ):
    """Evaluate a trained classifier on a dataloader and print macro metrics.

    Every batch is passed through ``processor`` and ``MERT_model``; the
    encoder's ``last_hidden_state`` is scored by ``model``. The collected
    probabilities are binarised at ``threshold`` to compute macro precision and
    macro F1, which are always printed. The per-class report is printed only
    when ``verbose`` is true, which also enables the progress bar.

    Returns:
        ``(test_true, test_pred_p)``: the stacked ground-truth labels and the
        stacked, un-thresholded model outputs, both on the CPU.
    """
    bce = nn.BCELoss()
    MERT_model = MERT_model.to(DEVICE)
    model = model.to(DEVICE)
    model.eval()

    test_total_loss = 0
    # Seed each list with an empty tensor so the final concatenation also works
    # when the dataloader yields no batches.
    label_chunks = [torch.tensor([])]
    prob_chunks = [torch.tensor([])]

    with torch.no_grad():
        for wav_batch, label_batch in tqdm(test_dataloader, disable=(not verbose)):
            wav_batch = wav_batch.cpu().numpy()
            features = processor(wav_batch, sampling_rate=24000, return_tensors="pt").to(DEVICE)
            label_batch = label_batch.to(DEVICE)

            # Frozen encoder -> [batch_size, time, 1024 feature_dim]
            encoded = MERT_model(**features)
            hidden = encoded.last_hidden_state

            probs = model(hidden)
            test_total_loss += bce(probs, label_batch).item()

            # Keep CPU copies for metric computation after the loop
            label_chunks.append(label_batch.cpu())
            prob_chunks.append(probs.cpu())

            # Release per-batch references
            del wav_batch, label_batch, features, encoded, hidden, probs
            gc.collect()

    test_true = torch.cat(label_chunks)
    test_pred_p = torch.cat(prob_chunks)

    test_pred = (test_pred_p > threshold).float()
    macro_precision = precision_score(test_true, test_pred, average="macro")
    macro_f1 = f1_score(test_true, test_pred, average="macro")

    print(f"Macro Precision: {macro_precision:.4f}")
    print(f"Macro F1-score: {macro_f1:.4f}")
    if verbose:
        print("Classification Report:\n", classification_report(test_true, test_pred, target_names=LABELS))

    return test_true, test_pred_p


### Train Model

In [8]:
# Train the classifier; the best checkpoint (by validation macro F1) is written to model_path.
model_path = "DL_model_f1.pt"
model, train_loss, valid_loss = train(
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    model_path=model_path,
)

Training Start!!!


  3%|▎         | 1/30 [09:01<4:21:38, 541.33s/it]

Epoch 1: train loss: 229.6959, train score: 0.3252 || valid loss: 52.2525, valid score: 0.3816. Threshold: 0.7543
Epoch 2: train loss: 206.0148, train score: 0.3925 || valid loss: 50.3212, valid score: 0.4344. Threshold: 0.7543


  7%|▋         | 2/30 [18:01<4:12:24, 540.88s/it]

Epoch 3: train loss: 189.4072, train score: 0.4620 || valid loss: 52.9554, valid score: 0.4928. Threshold: 0.7543


 13%|█▎        | 4/30 [36:05<3:54:34, 541.31s/it]

Epoch 4: train loss: 173.1609, train score: 0.5460 || valid loss: 87.2296, valid score: 0.4624. Threshold: 0.7543
Epoch 5: train loss: 155.0965, train score: 0.6243 || valid loss: 52.8863, valid score: 0.5505. Threshold: 0.7543


 20%|██        | 6/30 [54:09<3:36:43, 541.81s/it]

Epoch 6: train loss: 133.4534, train score: 0.7077 || valid loss: 61.2099, valid score: 0.5437. Threshold: 0.7543
Epoch 7: train loss: 108.6969, train score: 0.7813 || valid loss: 60.6404, valid score: 0.5890. Threshold: 0.7543


 23%|██▎       | 7/30 [1:03:10<3:27:33, 541.46s/it]

Epoch 8: train loss: 86.1249, train score: 0.8418 || valid loss: 68.5253, valid score: 0.5963. Threshold: 0.7543


 27%|██▋       | 8/30 [1:12:12<3:18:36, 541.67s/it]

Epoch 9: train loss: 65.6533, train score: 0.8845 || valid loss: 78.1392, valid score: 0.5970. Threshold: 0.7543


 30%|███       | 9/30 [1:21:13<3:09:32, 541.54s/it]

Epoch 10: train loss: 52.7591, train score: 0.9109 || valid loss: 106.6895, valid score: 0.6073. Threshold: 0.7543


 37%|███▋      | 11/30 [1:39:18<2:51:36, 541.90s/it]

Epoch 11: train loss: 42.7240, train score: 0.9309 || valid loss: 99.7268, valid score: 0.5822. Threshold: 0.7543


 40%|████      | 12/30 [1:48:19<2:42:31, 541.77s/it]

Epoch 12: train loss: 37.5315, train score: 0.9417 || valid loss: 106.7952, valid score: 0.5855. Threshold: 0.7543
Epoch 13: train loss: 33.5149, train score: 0.9493 || valid loss: 167.3739, valid score: 0.6302. Threshold: 0.7543


 47%|████▋     | 14/30 [2:06:21<2:24:21, 541.37s/it]

Epoch 14: train loss: 30.1994, train score: 0.9546 || valid loss: 135.1796, valid score: 0.6265. Threshold: 0.7543


 50%|█████     | 15/30 [2:15:23<2:15:20, 541.37s/it]

Epoch 15: train loss: 26.3084, train score: 0.9620 || valid loss: 127.6240, valid score: 0.5916. Threshold: 0.7543
Epoch 16: train loss: 24.1216, train score: 0.9645 || valid loss: 125.7990, valid score: 0.6387. Threshold: 0.7543


 57%|█████▋    | 17/30 [2:33:25<1:57:15, 541.19s/it]

Epoch 17: train loss: 23.1786, train score: 0.9661 || valid loss: 146.8373, valid score: 0.6151. Threshold: 0.7543


 60%|██████    | 18/30 [2:42:26<1:48:14, 541.23s/it]

Epoch 18: train loss: 22.4957, train score: 0.9679 || valid loss: 163.7165, valid score: 0.6163. Threshold: 0.7543
Epoch 19: train loss: 20.1240, train score: 0.9720 || valid loss: 131.0082, valid score: 0.6388. Threshold: 0.7543


 67%|██████▋   | 20/30 [3:00:29<1:30:12, 541.26s/it]

Epoch 20: train loss: 19.0185, train score: 0.9733 || valid loss: 211.1457, valid score: 0.6373. Threshold: 0.7543
Epoch 21: train loss: 18.2295, train score: 0.9742 || valid loss: 145.0899, valid score: 0.6436. Threshold: 0.7543


 73%|███████▎  | 22/30 [3:18:31<1:12:08, 541.07s/it]

Epoch 22: train loss: 17.4674, train score: 0.9755 || valid loss: 151.3711, valid score: 0.6432. Threshold: 0.7543


In [9]:
# Reload the best checkpoint and score it on the validation split
# (test() falls back to the default THRESHOLD here).
model_path = "DL_model_f1.pt"
model = ShortChunkCNN_Res(n_channels=N_CHANNELS)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only host too.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))

y_true, y_pred_p = test(valid_dataloader, model)


100%|██████████| 118/118 [01:38<00:00,  1.20it/s]


Macro Precision: 0.6771
Macro F1-score: 0.6496
Classification Report:
                   precision    recall  f1-score   support

           Piano       0.90      0.97      0.93      3237
      Percussion       0.44      0.23      0.30       383
           Organ       0.47      0.48      0.47       671
          Guitar       0.92      0.84      0.88      3194
            Bass       0.96      0.99      0.97      3471
         Strings       0.63      0.91      0.74      1930
           Voice       0.66      0.33      0.44       939
Wind Instruments       0.56      0.67      0.61      1599
           Synth       0.57      0.45      0.50      1074

       micro avg       0.78      0.80      0.79     16498
       macro avg       0.68      0.65      0.65     16498
    weighted avg       0.79      0.80      0.79     16498
     samples avg       0.78      0.80      0.78     16498



In [10]:
# Search best threshold on valid set (selection metric: macro precision)
thresholds = np.logspace(-1, 0, 50)
# Start below any attainable precision so the first candidate is always accepted.
# best_threshold can then never stay None, which would break the comparison below.
best_score = -1.0
best_threshold = None
for t in thresholds:
    y_pred = (y_pred_p > t).float()
    precision_at_t = precision_score(y_true, y_pred, average="macro")
    if precision_at_t > best_score:
        best_threshold = t
        best_score = precision_at_t

# Re-binarise with the selected threshold and show the per-class breakdown.
y_pred = (y_pred_p > best_threshold).float()
print(f"Threshold: {best_threshold:.4f}, best score: {best_score:.4f}")
report = classification_report(y_true, y_pred, target_names=LABELS)
print("\nClassification Report:\n", report)


Threshold: 0.9541, best score: 0.7207

Classification Report:
                   precision    recall  f1-score   support

           Piano       0.90      0.95      0.92      3237
      Percussion       0.56      0.18      0.27       383
           Organ       0.54      0.39      0.45       671
          Guitar       0.94      0.77      0.84      3194
            Bass       0.96      0.98      0.97      3471
         Strings       0.66      0.88      0.76      1930
           Voice       0.72      0.25      0.37       939
Wind Instruments       0.58      0.57      0.58      1599
           Synth       0.63      0.38      0.47      1074

       micro avg       0.82      0.76      0.79     16498
       macro avg       0.72      0.60      0.63     16498
    weighted avg       0.81      0.76      0.77     16498
     samples avg       0.81      0.76      0.77     16498



In [11]:
# Reload the best checkpoint and score it on the held-out test split.
model_path = "DL_model_f1.pt"
model = ShortChunkCNN_Res(n_channels=N_CHANNELS)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only host too.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))

# 0.9541 is the threshold reported by the validation-set search.
test_true, test_pred_p = test(test_dataloader, model, threshold=0.9541)


100%|██████████| 71/71 [00:58<00:00,  1.21it/s]

Macro Precision: 0.7178
Macro F1-score: 0.6202
Classification Report:
                   precision    recall  f1-score   support

           Piano       0.88      0.96      0.92      1889
      Percussion       0.59      0.16      0.25       243
           Organ       0.56      0.31      0.40       461
          Guitar       0.93      0.80      0.86      1943
            Bass       0.96      0.99      0.97      2076
         Strings       0.72      0.87      0.79      1235
           Voice       0.69      0.30      0.42       485
Wind Instruments       0.55      0.55      0.55       889
           Synth       0.58      0.34      0.43       647

       micro avg       0.82      0.76      0.79      9868
       macro avg       0.72      0.59      0.62      9868
    weighted avg       0.81      0.76      0.77      9868
     samples avg       0.81      0.76      0.77      9868






In [11]:
# GPU-only diagnostics. DEVICE may fall back to the CPU, so guard the CUDA calls
# instead of assuming a GPU is present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no GPU memory summary to report.")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1299 MiB |   5925 MiB | 540431 GiB | 540430 GiB |
|       from large pool |   1293 MiB |   5909 MiB | 539032 GiB | 539031 GiB |
|       from small pool |      5 MiB |     24 MiB |   1399 GiB |   1399 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   1299 MiB |   5925 MiB | 540431 GiB | 540430 GiB |
|       from large pool |   1293 MiB |   5909 MiB | 539032 GiB | 539031 GiB |
|       from small pool |      5 MiB |     24 MiB |   1399 GiB |   1399 GiB |
|---------------------------------------------------------------

### For Evaluation

In [18]:
def get_prediction(
        flac_file_path: str,
        model: ShortChunkCNN_Res,
        processor: Wav2Vec2FeatureExtractor = processor,
        MERT_model: AutoModel = MERT_model,
        threshold: float = THRESHOLD,
        save_file: bool = True,
        output_dir: str = "./hw1/test_track",
    ):
    """Predict per-chunk instrument activations for one audio file.

    The file is read, resampled to 24 kHz, and cut into consecutive 5-second
    chunks; any trailing partial chunk is discarded. All chunks are passed as a
    single batch through ``processor``, ``MERT_model`` and ``model``, and the
    outputs are binarised at ``threshold``.

    Args:
        flac_file_path: Path of the audio file to analyse.
        model: Trained classifier applied to the MERT features.
        processor: Feature extractor applied to the raw chunks.
        MERT_model: Pretrained encoder used as the feature extractor.
        threshold: Probability cut-off for a positive prediction.
        save_file: When true, save the result as ``<output_dir>/<name>.npy``.
        output_dir: Directory for the saved prediction (defaults to the
            previously hard-coded location).

    Returns:
        Float array of 0/1 values, transposed so that rows are the model's
        outputs and columns are the chunks in order.

    Raises:
        ValueError: If the audio is shorter than one 5-second chunk.
    """
    target_sr = 24000
    chunk_len = 5 * target_sr  # 120000 samples = 5 seconds at 24 kHz

    # Strip directory and extension portably (no reliance on '/' or the first dot).
    name = os.path.splitext(os.path.basename(flac_file_path))[0]

    a, sr = sf.read(flac_file_path)
    if a.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel audio; downmix
        # to mono so resampling and chunking operate along the time axis.
        a = a.mean(axis=1)
    n = librosa.resample(a, orig_sr=sr, target_sr=target_sr)

    # Keep whole chunks only. Slicing by an explicit length avoids the
    # n[:-0] == empty pitfall when the length is an exact multiple of chunk_len.
    n_chunks = n.shape[0] // chunk_len
    if n_chunks == 0:
        raise ValueError(
            f"{flac_file_path} is shorter than one {chunk_len}-sample chunk after resampling"
        )
    n = n[:n_chunks * chunk_len].reshape(n_chunks, chunk_len)

    inputs = processor(n, sampling_rate=target_sr, return_tensors="pt")
    inputs = inputs.to(DEVICE)
    MERT_model = MERT_model.to(DEVICE)
    model = model.to(DEVICE)
    # A freshly constructed nn.Module is in training mode; switch to inference
    # mode, as test() does, before predicting.
    model.eval()

    # pre-trained model
    with torch.no_grad():
        outputs = MERT_model(**inputs)
        pretrained_output = outputs.last_hidden_state # [batch_size, time, 1024 feature_dim]
        output = model(pretrained_output)

    output = output.cpu()
    output = (output > threshold).float()
    output = output.numpy().T

    if save_file:
        np.save(os.path.join(output_dir, f"{name}.npy"), output)
        print(f"File {name}.npy successfully saved. dim={output.shape}")

    return output


In [19]:
# Reload the best checkpoint and write one prediction file per test track.
model_path = "DL_model_f1.pt"
model = ShortChunkCNN_Res(n_channels=N_CHANNELS)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only host too.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))

# glob order is filesystem-dependent; sort so the processing order is reproducible.
audio_path_list = sorted(glob(os.path.join("./hw1/test_track", "*.flac")))
for file in audio_path_list:
    # 0.9541 is the threshold reported by the validation-set search.
    get_prediction(file, model=model, threshold=0.9541)


File Track01937.npy successfully saved. dim=(9, 40)
File Track01876.npy successfully saved. dim=(9, 51)
File Track02100.npy successfully saved. dim=(9, 45)
File Track02078.npy successfully saved. dim=(9, 43)
File Track02024.npy successfully saved. dim=(9, 49)
