Inference ESC-50 fine-tuned model #12

Closed
myatmyintzuthin opened this issue Apr 1, 2022 · 2 comments

@myatmyintzuthin

Hello, authors.
Thank you for sharing the great work.

I fine-tuned the AudioSet-pretrained model passt-s-f128-p16-s10-ap.476-swa.pt on the ESC-50 dataset using ex_esc50.py.
The checkpoint was saved to output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt.
I want to load the checkpoint and run inference on an audio file. I tried to load the checkpoint and use passt_hear21 for inference, but lost track of the process.

Could you please share how to run inference on an audio file with the saved checkpoint?

@kkoutini
Owner

kkoutini commented Apr 5, 2022

Hi, thanks!
You can use passt_hear21 like this:

import torch
from hear21passt.base import load_model

# Loading the weights
p = "output/esc50/_None/checkpoints/epoch=4-step=2669.ckpt"
ckpt = torch.load(p, map_location="cpu")
net_statedict = {k[len("net."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
net_swa = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # swa weights

# getting the model
model = load_model(mode="logits").cuda()
model.net.load_state_dict(net_statedict)  # loading the fine-tuned weights
model.eval()

# example: a batch of 3 waveforms of 32000 * 5 samples, on the same device as the model
wave_example = torch.ones((3, 32000 * 5)).cuda() * 0.5
with torch.no_grad():
    logits = model(wave_example)
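
The prefix filtering above can be checked in isolation. The following is a minimal sketch, not part of the repository: `split_state_dict` is a hypothetical helper, the key names are made up for illustration, and plain strings stand in for tensors. It shows that keys starting with "net_swa." are not picked up by the "net." filter, because the fourth character differs.

```python
def split_state_dict(state_dict):
    """Split a checkpoint state dict into main ("net.") and SWA ("net_swa.") weights.

    Hypothetical helper for illustration; the matching prefix is stripped from each key.
    """
    net, net_swa = {}, {}
    for key, value in state_dict.items():
        if key.startswith("net."):
            net[key[len("net."):]] = value
        elif key.startswith("net_swa."):
            net_swa[key[len("net_swa."):]] = value
    return net, net_swa


# Made-up keys; strings stand in for the tensors a real checkpoint would hold.
example = {
    "net.cls_token": "w0",
    "net.head.1.weight": "w1",
    "net_swa.cls_token": "s0",
    "net_swa.head.1.weight": "s1",
    "mel.window": "m0",
}

net, net_swa = split_state_dict(example)
print(sorted(net))
print(sorted(net_swa))
```

Keys that carry neither prefix (here the made-up "mel.window") are dropped from both dictionaries.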

@myatmyintzuthin
Author

myatmyintzuthin commented Apr 6, 2022

Thank you so much for the reply.
This is my first time writing an inference script in PyTorch, so it was a great help.
I am sharing my inference script here in case someone wants to use it.

# References 
# 1) https://github.com/YuanGongND/ast/blob/master/egs/audioset/inference.py
# 2) https://github.com/kkoutini/passt_hear21

import csv
import argparse
import numpy as np
import torch
import torchaudio
from pytorch_lightning import Trainer as plTrainer
from hear21passt.base import load_model

def load_label(label_csv):
    """Return the label names (column 2) from the label CSV, skipping the header row."""
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)
    # Column 1 holds each label's unique id such as "/m/068hy"; only the names are returned
    return [line[2] for line in lines[1:]]


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Example usage: '
                                                 'python inference --audio_path ESC-50-master/audio_32k/1-5996-A-6.wav '
                                                 '--model_path checkpoints/epoch=2-step=4799.ckpt')

    parser.add_argument("--model_path", required= True,type=str,
                        help="the trained model you want to test")
    parser.add_argument("--audio_path", required= True,
                        help='the audio you want to predict, sample rate 32k.',
                        type=str)

    args = parser.parse_args()

    label_csv = './esc50/esc_class_labels_indices.csv'       # label and indices for ESC-50 data

    # 1. load audio file
    audio_path = args.audio_path
    waveform, _ = torchaudio.load(audio_path)

    # 2. load checkpoint
    checkpoint_path = args.model_path

    ckpt = torch.load(checkpoint_path, map_location="cpu")
    net_statedict = {k[len("net."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
    net_swa = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # swa weights (not used below)
    print(f'[*INFO] load checkpoint: {checkpoint_path}')

    # 3. loading the fine-tuned weights
    device = torch.device("cuda:0")
    passt_model = load_model(mode="logits")
    passt_model.net.load_state_dict(net_statedict)
    passt_model = passt_model.to(device).eval()

    # waveform has shape (channels, samples); each channel is passed as one batch item
    waveform = waveform.to(device)

    with torch.no_grad():
        output = passt_model(waveform)
        output = torch.sigmoid(output)
    result_output = output.cpu().numpy()[0]  # scores for the first channel

    # 4. map the post-prob to label
    labels = load_label(label_csv)

    sorted_indexes = np.argsort(result_output)[::-1]

    # Print audio tagging top probabilities
    print('[*INFO] predict results:')
    for k in range(5):
        print('{}: {:.4f}'.format(np.array(labels)[sorted_indexes[k]],
                                  result_output[sorted_indexes[k]]))
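
The label-mapping step can also be tried without a GPU or a checkpoint. This is a minimal sketch using only the standard library, under stated assumptions: the CSV is assumed to have a header row with a `display_name` column (the script above only relies on the name being in the third column), and the three classes and scores are made up. `load_labels` and `top_k` are hypothetical helper names.

```python
import csv
import io


def load_labels(csv_file):
    """Return label names in row order.

    Assumes a header row with a `display_name` column; adjust to the real file.
    """
    return [row["display_name"] for row in csv.DictReader(csv_file)]


def top_k(scores, labels, k=5):
    """Return the k highest-scoring (label, score) pairs, best first."""
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up three-class label file and scores, standing in for the real CSV and model output.
fake_csv = io.StringIO(
    "index,mid,display_name\n"
    "0,/m/00,dog\n"
    "1,/m/01,rooster\n"
    "2,/m/02,rain\n"
)
labels = load_labels(fake_csv)
results = top_k([0.10, 0.75, 0.15], labels, k=2)

for label, score in results:
    print('{}: {:.4f}'.format(label, score))
```

With the real script, `scores` would be `result_output` and the file object would come from opening the label CSV.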
