# Fake Audio Detection
### We use the state-of-the-art model from https://github.com/piotrkawa/deepfake-whisper-features, which is based on the Interspeech 2023 paper https://www.isca-speech.org/archive/interspeech_2023/kawa23b_interspeech.html

In [None]:
# Fetch the deepfake-whisper-features repository referenced in the notebook header.
!git clone https://github.com/piotrkawa/deepfake-whisper-features.git

In [None]:
# Move the cloned repository's contents into the current directory so that the
# later cells can import `src` and run `download_whisper.py` from here.
# Note: the `*` glob does not match dotfiles, so hidden files stay in the clone.
!mv deepfake-whisper-features/* .

In [None]:
# Install the libsox development package; -y answers apt's prompts automatically.
# presumably needed by the audio loading stack (torchaudio) — TODO confirm
!apt-get -y install libsox-dev

In [None]:
# Install pinned Python dependencies: Whisper at a fixed commit, plus exact
# versions of asteroid-filterbanks and librosa.
!pip install git+https://github.com/openai/whisper.git@7858aa9c08d98f75575035ecd6481f462d66ca27
!pip install asteroid-filterbanks==0.4.0
!pip install librosa==0.9.2

In [None]:
# Run the repository's download_whisper.py script from the current directory.
# presumably fetches the Whisper weights used by the model front-end — TODO confirm
!python download_whisper.py

In [None]:
import argparse
import numpy as np
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
import sys
import os
import fnmatch
import pandas as pd
import csv

import torch
import torchaudio
import yaml
#from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader
from torch.utils.data.dataset import T_co

from src import metrics, commons
from src.models import models
from src.datasets.base_dataset import SimpleAudioFakeDataset,APPLY_NORMALIZATION,apply_preprocessing


In [None]:
class SwissHacksDataset(SimpleAudioFakeDataset):
    """Dataset over the SwissHacks ``.wav`` recordings.

    Every ``*.wav`` file found in the data directory becomes one sample
    ``[file_path, label, recording_id]``, where ``recording_id`` is the file
    name without its ``.wav`` suffix and ``label`` is

    * ``0`` if the id is listed in ``real_labels``,
    * ``1`` if the id is listed in ``fake_labels``,
    * ``2`` if it appears in neither list (unknown).
    """

    # Directory scanned when the constructor receives an empty ``path``.
    DEFAULT_DATA_DIR = "/kaggle/input/swisshacks/"

    # Recording ids with a provided "real" label.
    real_labels = ["PASOTPBNLM",
                   "H01FH3KEY8",
                   "BJ53WB0WQB",
                   "VWEQSW9GQY",
                   "R18W797V9Q",
                   "6W361V5VV9",
                   "ZBKI0P43EK",
                   "ROHDD0Z6CG",
                   "2IM42LTT5R",
                   "Y6K2JU2H4B",
                   "CBDX295MEZ",
                   "ZCB53KC2PC",
                   "3162VQ31V7",
                   "YDDKTODLGX",
                   "XKG8C7QFXT",
                   "1Z2W0U9OU8",
                   "ZGZHPG1TS8",
                   "244F8XZK0E",
                   "7FPDGERPRV",
                   "IIBWPCAJFZ"]
    # Recording ids with a provided "fake" label.
    fake_labels = ["NSIOUFFN5C",
                   "MIRV2AHSDH",
                   "GE90UVYAIC",
                   "GINYUH6NU7",
                   "ASLJ66JRJL",
                   "2TT75RT0RO",
                   "FG1GU97VU5",
                   "FI3U0S0S6X",
                   "RLL2WGXJRT",
                   "IU50O8RY55",
                   "0D8CAOL7XN",
                   "541T0I3AUW",
                   "B5PN7WKKMI",
                   "F68PGID9TU",
                   "TC1N3OMAN3",
                   "H6XNGJ7SCM",
                   "9PS130EZ8T",
                   "X8L6WJ0NDN",
                   "7A8PVRXFLV",
                   "12MINIG2V7",
                   ]

    def __init__(
            self,
            path,
            subset="train",
            transform=None,
            seed=None,
            partition_ratio=(1., 0.),
            split_strategy="random"
    ):
        """Scan the data directory and build the sample list.

        Args:
            path: Directory containing the ``.wav`` files. An empty value
                (``""`` or ``None``) selects ``DEFAULT_DATA_DIR``.
            subset: Forwarded to the base class constructor.
            transform: Forwarded to the base class constructor.
            seed: Stored on the instance; not used by this class itself.
            partition_ratio: Stored on the instance; not used by this class
                itself.
            split_strategy: Accepted for signature compatibility; ignored.
        """
        super().__init__(subset=subset, transform=transform)
        self.path = path
        self.partition_ratio = partition_ratio
        self.seed = seed
        self.read_samples()

    def read_samples(self):
        """Populate ``self.samples`` from the ``*.wav`` files in the data directory.

        Files are visited in sorted name order so that sample indices are
        reproducible across runs (``os.listdir`` order is arbitrary).
        Prints one ``real:``/``fake:`` line per file that has a provided label.
        """
        data_dir = str(self.path) if self.path else self.DEFAULT_DATA_DIR

        # Build the lookup sets once instead of scanning the lists per file.
        real_ids = set(self.real_labels)
        fake_ids = set(self.fake_labels)

        # 0 = real, 1 = fake, 2 = unknown
        self.samples = []
        for file in sorted(os.listdir(data_dir)):
            if not fnmatch.fnmatch(file, '*.wav'):
                continue

            # Strip the 4-character ".wav" suffix to get the recording id.
            recording_id = file[:-4]
            if recording_id in real_ids:
                label = 0
                print("real:", file)
            elif recording_id in fake_ids:
                label = 1
                print("fake:", file)
            else:
                label = 2
            self.samples.append([os.path.join(data_dir, file), label, recording_id])

    def __getitem__(self, index) -> T_co:
        """Load and preprocess the recording at ``index``.

        Returns:
            A list starting with ``[waveform, sample_rate]`` as returned by
            ``apply_preprocessing``. If ``self.return_label`` is truthy, the
            label and the sample's third field are appended (for samples built
            by ``read_samples`` that field is the recording id). If
            ``self.return_meta`` is truthy, a tuple
            ``(third_field, path, self.subset, duration_in_seconds)`` is
            appended as well.

        Note:
            ``return_label``, ``return_meta`` and ``subset`` are not set in
            this class; presumably the base class provides them — confirm.
        """
        if isinstance(self.samples, pd.DataFrame):
            sample = self.samples.iloc[index]

            path = str(sample["path"])
            label = sample["label"]
            attack_type = sample["attack_type"]
            # A missing attack type shows up as NaN; pd.isna avoids depending
            # on the `math` module, which this notebook never imports.
            if not isinstance(attack_type, str) and pd.isna(attack_type):
                attack_type = "N/A"
        else:
            path, label, attack_type = self.samples[index]

        waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
        # Duration of the raw file, measured on the first channel before
        # preprocessing changes the waveform.
        real_sec_length = len(waveform[0]) / sample_rate

        waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

        return_data = [waveform, sample_rate]
        if self.return_label:
            return_data.append(label)
            return_data.append(attack_type)

        if self.return_meta:
            return_data.append(
                (
                    attack_type,
                    path,
                    self.subset,
                    real_sec_length,
                )
            )
        return return_data

In [None]:
# Build the dataset. An empty path is passed; the recordings are picked up from
# the /kaggle/input/swisshacks/ directory named in the SwissHacksDataset class.
test_ds = SwissHacksDataset("")
# Smoke test: load the first sample. As the cell's last expression, its value
# is also displayed by the notebook.
test_ds[0]

# Load Model

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Parse the YAML configuration for the Whisper-frontend MesoNet model.
config_file = "/kaggle/working/configs/finetuning/whisper_frontend_mesonet.yaml"
with open(config_file, "r") as f:
    config = yaml.safe_load(f)

# Use the seed from the config's "data" section, defaulting to 42, and
# fix all seeds - this should not actually change anything.
seed = config["data"].get("seed", 42)
commons.set_seed(seed)

In [None]:
# NOTE: batch_size is not read by any of the visible cells.
batch_size = 1

# Pull the architecture name and its parameters out of the loaded config.
model_config = config["model"]
model_name = model_config["name"]
model_parameters = model_config["parameters"]

# Instantiate the architecture; the weights are loaded from the checkpoint below.
model = models.get_model(model_name=model_name, config=model_parameters, device=device)

# Checkpoint location is hard-coded here instead of being read from
# config["checkpoint"]["path"].
model_paths = "/kaggle/input/mesonet_whisper_mfcc_finetuned/pytorch/mesonet_whisper_mfcc_finetuned/1/mesonet_whisper_mfcc_finetuned/mesonet_whisper_mfcc_finetuned.pth"
model.load_state_dict(torch.load(model_paths, map_location=device))
model = model.to(device)

# Export CSV

In [None]:
# Classify every recording in `test_ds` and write one CSV row per recording:
#   ID              recording id (third field of the dataset sample)
#   PROVIDED_LABEL  0 = real, 1 = fake, 2 = unknown (as assigned by the dataset)
#   PREDICTION      model output thresholded at 0.5
#   FINAL_LABEL     the provided label when known, otherwise the prediction
model.eval()

# Placeholders for an accuracy tally; nothing in this cell reads them.
same_results = []
result_matrix = np.zeros([40,2])
selected_result = 0

csv_file = "fake_audio_classification_results.csv"
# torch.no_grad(): inference only, so skip building the autograd graph.
with open(csv_file, 'w', newline='') as file, torch.no_grad():
    writer = csv.writer(file, delimiter=',')
    writer.writerow(["ID","PROVIDED_LABEL","PREDICTION","FINAL_LABEL"])

    for sample_idx in range(len(test_ds)):
        batch_x, _, label, recording_id = test_ds[sample_idx]
        # Add the batch dimension and move the input to the model's device;
        # without the move the forward pass fails when the model is on CUDA.
        batch_x = batch_x.unsqueeze(0).to(device)
        batch_pred = torch.sigmoid(model(batch_x).squeeze(1))
        # NOTE(review): assumes a score >= 0.5 means "fake" (label 1 in this
        # notebook's convention) — confirm against the labels used in training.
        pred_label = int(batch_pred[0].item() >= 0.5)
        # Trust the provided label when there is one; fall back to the model
        # only for unknown (2) recordings.
        final_label = label if label != 2 else pred_label
        writer.writerow([recording_id, label, pred_label, final_label])