In [1]:
from utils.datasets import ASVSpoof21Dataset
import torch
import numpy as np

In [39]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import os
import librosa
import yaml
from transformers import AutoTokenizer, DebertaV2Tokenizer

# File extensions (lower-case, without the leading dot) accepted by load_audio.
SUPPORTED_FORMATS = ["wav", "mp3", "flac"]


def load_audio(filename, sampling_rate=None):
    """Load an audio file from disk with librosa.

    Args:
        filename: Path to the audio file. Its extension (compared
            case-insensitively) must be listed in SUPPORTED_FORMATS.
        sampling_rate: Target sampling rate forwarded to ``librosa.load`` as
            ``sr``. ``None`` (the default) keeps the file's native rate.

    Returns:
        Tuple ``(audio, sr)`` exactly as returned by ``librosa.load``.

    Raises:
        AssertionError: If the file does not exist or its extension is not
            supported.
    """
    assert os.path.exists(filename), f"File {filename} does not exist"
    # Lower-case the extension so "clip.FLAC" is treated like "clip.flac".
    extension = filename.rsplit(".", 1)[-1].lower()
    assert extension in SUPPORTED_FORMATS, f"File {filename} is not supported"

    # Honour the caller's requested rate; previously sr=None was hard-coded
    # here and the sampling_rate argument was silently ignored.
    audio, sr = librosa.load(filename, sr=sampling_rate)
    return audio, sr


def pad(x, max_len):
    """Force a 1-D array to exactly ``max_len`` samples.

    Longer inputs are truncated; shorter inputs are repeated end-to-end and
    then cut to length (wrap-around padding rather than zero padding).

    Adapted from
    src/baselines/asvspoof2021/DF/Baseline-RawNet2/data_utils.py
    """
    n_samples = x.shape[0]
    if n_samples < max_len:
        # Repeat often enough that the tiled signal covers max_len samples.
        repeats = int(max_len / n_samples) + 1
        tiled = np.tile(x, (1, repeats))  # shape (1, n_samples * repeats)
        return tiled[0, :max_len]
    return x[:max_len]


def get_spoof_list(meta_dir, is_train=False, is_eval=False):
    """Parse a protocol/metadata text file into labels and utterance keys.

    Args:
        meta_dir: Path to the metadata text file (despite the name, this is a
            file path, not a directory).
        is_train: Parse labelled metadata. Takes precedence over ``is_eval``
            when both are set.
        is_eval: When set (and ``is_train`` is not), treat every non-blank
            line as a bare utterance key with no label.

    Returns:
        ``(d_meta, file_list)``. ``file_list`` holds the utterance keys in file
        order. ``d_meta`` maps key -> 1 for "bonafide" and 0 for anything
        else; it is ``None`` in eval mode.

    Raises:
        IndexError: If a labelled line has fewer than six whitespace-separated
            fields.
    """
    with open(meta_dir, "r") as f:
        # Strip newlines up front so a label in the last column is not read
        # as "bonafide\n", and drop blank lines (e.g. a trailing newline).
        lines = [raw.strip() for raw in f]
    lines = [line for line in lines if line]

    if is_eval and not is_train:
        return None, lines

    # The train and default modes share the same labelled layout:
    # field 1 is the utterance key, field 5 is the bonafide/spoof label.
    d_meta = {}
    file_list = []
    for line in lines:
        fields = line.split()
        key, label = fields[1], fields[5]
        file_list.append(key)
        d_meta[key] = 1 if label == "bonafide" else 0
    return d_meta, file_list


class ASVSpoof21Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding fixed-length waveforms with their metadata.

    Audio is read from ``<root_dir>/flac/<key>.flac`` and padded or truncated
    to ``sampling_rate * max_duration`` samples.

    Args:
        root_dir: Directory containing the ``flac/`` folder.
        meta_dir: Path to the metadata file parsed by ``get_spoof_list``.
        is_train: Forwarded to ``get_spoof_list``.
        is_eval: Forwarded to ``get_spoof_list``; in that mode no labels are
            available and items carry only their key.
        sampling_rate: Rate forwarded to ``load_audio`` and used to size the
            fixed-length window.
        max_duration: Window length in seconds.
        get_transcription: Accepted for interface compatibility; currently
            unused by this class.
    """

    def __init__(
        self,
        root_dir,
        meta_dir,
        is_train=False,
        is_eval=False,
        sampling_rate=16000,
        max_duration=4,
        get_transcription=False,
    ):
        self.root_dir = root_dir
        self.sampling_rate = sampling_rate
        self.max_duration = max_duration
        # Target length in samples; int() keeps it a valid slice bound even
        # when max_duration is fractional (e.g. 2.5 seconds).
        self.cut = int(self.sampling_rate * self.max_duration)
        self.meta, self.list_IDs = get_spoof_list(meta_dir, is_train, is_eval)

    def __len__(self):
        """Return the number of utterances listed in the metadata file."""
        return len(self.list_IDs)

    def load_audio_tensor(self, key):
        """Load utterance ``key`` and return it as a float tensor of length ``self.cut``."""
        filename = os.path.join(self.root_dir, f"flac/{key}.flac")
        audio_arr, _ = load_audio(filename, self.sampling_rate)
        audio_tensor = torch.tensor(pad(audio_arr, self.cut)).float()
        return audio_tensor

    def __getitem__(self, idx):
        """Return ``(waveform, meta)`` for the utterance at position ``idx``.

        ``meta`` always contains ``"key"``. ``"label"`` (1 = bonafide,
        0 = spoof) is included only when labels were parsed; in eval mode
        ``self.meta`` is ``None``, so the label is omitted instead of failing
        on a ``None`` lookup.
        """
        key = self.list_IDs[idx]
        x = self.load_audio_tensor(key)
        meta = {"key": key}
        if self.meta is not None:
            meta["label"] = self.meta[key]
        return x, meta

In [40]:
# Dataset over the ASVspoof2021 DF evaluation audio. is_train=True selects the
# labelled parsing path of get_spoof_list, so every item carries a
# bonafide/spoof label next to its key.
ds = ASVSpoof21Dataset(
    root_dir="/data/amathur-23/DADA/ASVspoof2021_DF_eval",
    meta_dir="/data/amathur-23/DADA/ASVspoof2021_DF_eval/keys/DF/CM/trial_metadata.txt",
    is_train=True,
)

In [53]:
# Sequential (unshuffled), single-process batches of 16 so the key order
# follows the order of the metadata file.
loader = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

In [54]:
# Pull only the first batch as a smoke test: x is the batched waveform tensor,
# y is the collated metadata for the same items.
x, y = next(iter(loader))

In [None]:
# Sanity check on the batch shape; expected (16, 64000) given batch_size=16
# and 4 s at 16 kHz — TODO confirm against the rendered output.
x.shape

In [None]:
# Inspect the collated metadata; presumably a dict with "key" and "label"
# entries batched by the default collate function — verify in the output.
y

In [57]:
def get_files(meta, root_dir=None):
    """Map a batch of metadata to the FLAC paths of its utterances.

    Args:
        meta: Mapping whose ``"key"`` entry is an iterable of utterance keys.
        root_dir: Directory containing the ``flac/`` folder. Defaults to the
            notebook-level dataset's ``ds.root_dir`` so existing
            ``get_files(meta)`` calls keep working; pass it explicitly to
            avoid depending on that global.

    Returns:
        List of ``<root_dir>/flac/<key>.flac`` paths, in the order of the keys.
    """
    if root_dir is None:
        root_dir = ds.root_dir
    return [os.path.join(root_dir, f"flac/{key}.flac") for key in meta["key"]]

In [None]:
# Preview the FLAC paths derived from the first batch's keys.
get_files(y)

In [None]:
import warnings

# Silence FutureWarning before importing/instantiating the pipeline below;
# presumably this hides library deprecation noise — the filter must stay
# ahead of the transformers import to cover import-time warnings.
warnings.simplefilter("ignore", category=FutureWarning)

import torch
from transformers import pipeline

# Speech-recognition pipeline backed by Whisper large-v3, run in half
# precision. NOTE(review): the device is hard-coded to CUDA index 2; adjust
# this on machines with a different GPU layout.
whisper = pipeline(
    "automatic-speech-recognition",
    "openai/whisper-large-v3",
    torch_dtype=torch.float16,
    device="cuda:2",
)

# Smoke test: transcribe the files of the first batch (y comes from the
# earlier next(iter(loader)) cell) by passing their paths to the pipeline.
transcription = whisper(get_files(y))

# The result is used further down as a list of dicts with a "text" entry.
print(transcription)

In [None]:
from tqdm import tqdm

# Reset the accumulators so re-running this cell does not append duplicates.
keys = []
transcriptions = []

# Whisper reads the audio itself from the file paths, so only the utterance
# keys are needed here. Iterating the DataLoader would decode and pad every
# file via the dataset just to throw the tensor away; instead, walk the key
# list in the same order and with the same batch size as the loader.
all_keys = loader.dataset.list_IDs
batch_size = loader.batch_size

for start in tqdm(range(0, len(all_keys), batch_size)):
    batch_keys = list(all_keys[start : start + batch_size])
    keys.extend(batch_keys)
    # One pipeline call per batch of paths; each result dict carries the
    # decoded text under "text".
    transcription_list_dict = whisper(get_files({"key": batch_keys}))
    transcriptions.extend([item["text"] for item in transcription_list_dict])

In [None]:
import pandas as pd

# One row per utterance: its key and the Whisper transcription, written to a
# CSV in the working directory without the index column.
df = pd.DataFrame(data={"key": keys, "transcription": transcriptions})
df.to_csv(path_or_buf="asvspoof21_df_eval_transcriptions.csv", index=False)