In [1]:
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
from tqdm import tqdm

import librosa
import numpy as np
import torch

from panns_inference import AudioTagging, SoundEventDetection, labels
import whisper


def find_audios(parent_dir, exts=['.wav', '.mp3', '.flac', '.webm', '.mp4', '.m4a']):
    audio_files = []
    for root, dirs, files in os.walk(parent_dir):
        for file in files:
            if os.path.splitext(file)[1] in exts:
                audio_files.append(os.path.join(root, file))
    return audio_files


#################### PANNs ####################

def load_panns(device='cuda'):
    model = AudioTagging(checkpoint_path=None, device=device)
    return model

@torch.no_grad()
def tag_audio(model, audio_path):
    (audio, _) = librosa.core.load(audio_path, sr=32000, mono=True)
    # only use the first 30 seconds
    audio = audio[None, :30*32000]
    (clipwise_output, embedding) = model.inference(audio)
    tags, probs = get_audio_tagging_result(clipwise_output[0])
    return tags, probs


def get_audio_tagging_result(clipwise_output):
    """Visualization of audio tagging result.
    Args:
      clipwise_output: (classes_num,)
    """
    sorted_indexes = np.argsort(clipwise_output)[::-1]

    tags = []
    probs = []
    for k in range(10):
        tag = np.array(labels)[sorted_indexes[k]]
        prob = clipwise_output[sorted_indexes[k]]
        tags.append(tag)
        probs.append(float(prob))

    return tags, probs 


def is_vocal(tags, probs, threshold=0.08):
    pos_tags = {'Speech', 'Singing', 'Rapping'}
    for tag, prob in zip(tags, probs):
        if tag in pos_tags and prob > threshold:
            return True
    return False


#################### Whisper ####################


def load_whisper(model="large"):
    model = whisper.load_model(model, in_memory=True)
    return model


def transcribe_and_save(whisper_model, panns_model, args):
    """transcribe the audio, and save the result with the same relative path in the output_dir
    """
    audio_files = find_audios(args.input_dir)

    if args.n_shard > 1:
        print(f'processing shard {args.shard_rank} of {args.n_shard}')
        audio_files.sort() # make sure no intersetction
        audio_files = audio_files[args.shard_rank * len(audio_files) // args.n_shard : (args.shard_rank + 1) * len(audio_files) // args.n_shard] 

    for file in tqdm(audio_files):
        output_file = os.path.join(args.output_dir, os.path.relpath(file, args.input_dir))
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        results = []
        try:
            tags, probs = tag_audio(panns_model, file)

            if args.threshold == 0. or is_vocal(tags, probs, threshold=args.threshold):
                if args.debug:
                    print(file)
                    for tag, prob in zip(tags, probs):
                        print(f'{tag}: {prob}')
                    continue

                ## generate 5 different transcription by varying the temperature
                for i in range(args.top_n_sample):
                    result = whisper.transcribe(whisper_model, file, language=args.language, initial_prompt=args.prompt,
                                               temperature=(0.5 + 0.1 * i))
                    result['tags_with_probs'] = [{'tag': tag, 'prob': prob} for tag, prob in zip(tags, probs)]
                    results.append(result)
                with open(output_file + '.json', 'w') as f:
                    json.dump(results, f, indent=4, ensure_ascii=False)
            else:
                print(f'no vocal in {file}')
                if args.debug:
                    for tag, prob in zip(tags, probs):
                            print(f'{tag}: {prob}')
        except Exception as e:
            print(e)
            continue

In [2]:
class args:
    model = 'large-v3'
    prompt = 'lyrics: '
    language = 'vi'
    input_dir = './sample'
    output_dir = './results'
    n_shard = 1
    shard_rank = 0
    threshold = 0
    debug = False
    top_n_sample = 2

In [3]:
whisper_model = load_whisper(args.model)
panns_model = load_panns()

Checkpoint path: /home/kelvinsoh/panns_data/Cnn14_mAP=0.431.pth
GPU number: 1


In [4]:
transcribe_and_save(whisper_model, panns_model, args)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:02<00:00, 31.40s/it]


In [5]:
import json

with open('results/en_sample1.mp3.json', 'r') as file:
    # Load the JSON data
    data = json.load(file)

predictions_dict = {}
for i in range(len(data)):
    predictions_dict[f'prediction_{i}'] = data[i]['text']
    print(f"sample #{i}:{data[i]['text']}")

sample #0: Shall we go for a walk?
sample #1: Shall we go for a walk?


### OpenAI

In [6]:
from openai import OpenAI
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
import re
import os

from dotenv import load_dotenv
load_dotenv('.env')
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)

In [7]:
instruction_prompt = """
Task: As a GPT-4 based lyrics transcription post-processor,
your task is to analyze multiple ASR model-generated versions
of a Vietnamese song’s lyrics and determine the most accurate version
closest to the true lyrics. Also filter out invalid lyrics
when all predictions are nonsense.
Input: The input is in JSON format:

{“prediction_1”: “line1;line2;...”, ...}
Output: Your output must be strictly in readable JSON format
without any extra text:
{
“reasons”: “reason1;reason2;...”,
“closest_prediction”: <key_of_prediction>
“output”: “line1;line2...”
}
Requirements: For the "reasons" field, you have to provide
a reason for the choice of the "closest_prediction" field. For
the "closest_prediction" field, choose the prediction key that
is closest to the true lyrics. Only when all predictions greatly
differ from each other or are completely nonsense or meaningless,
which means that none of the predictions is valid,
fill in "None" in this field. For the "output" field, you need
to output the final lyrics of closest_prediction. If the "closest_
prediction" field is "None", you should also output "None"
in this field. The language of the input lyrics is English.
"""

In [8]:
context = f'''{instruction_prompt}

{predictions_dict}
'''.strip()

In [9]:
print(context)

Task: As a GPT-4 based lyrics transcription post-processor,
your task is to analyze multiple ASR model-generated versions
of a Vietnamese song’s lyrics and determine the most accurate version
closest to the true lyrics. Also filter out invalid lyrics
when all predictions are nonsense.
Input: The input is in JSON format:

{“prediction_1”: “line1;line2;...”, ...}
Output: Your output must be strictly in readable JSON format
without any extra text:
{
“reasons”: “reason1;reason2;...”,
“closest_prediction”: <key_of_prediction>
“output”: “line1;line2...”
}
Requirements: For the "reasons" field, you have to provide
a reason for the choice of the "closest_prediction" field. For
the "closest_prediction" field, choose the prediction key that
is closest to the true lyrics. Only when all predictions greatly
differ from each other or are completely nonsense or meaningless,
which means that none of the predictions is valid,
fill in "None" in this field. For the "output" field, you need
to output the 

In [10]:
chat_completion = client.chat.completions.create(
    model="gpt-4o",#"gpt-4-0613",
    messages=[{"role": "user", "content": context}],
    stream=False)
response = chat_completion.choices[0].message.content

In [20]:
json.loads(response.replace("```json", "").replace("```", "").strip())

{'reasons': 'All predictions are identical and make sense.',
 'closest_prediction': 'prediction_0',
 'output': 'Shall we go for a walk?'}