In [1]:
from datasets import Dataset, DatasetDict, Audio, ClassLabel, Features
import numpy as np
import soundfile as sf
import os
import torch
import pandas as pd
from tqdm import tqdm
from transformers import ASTFeatureExtractor, ASTForAudioClassification, Trainer, TrainingArguments
from torch.utils.data import DataLoader
from datasets import Audio
import torchaudio

In [None]:
# Weights & Biases settings, supplied through environment variables.
# NOTE: 'your key' is a placeholder; inject the real key at run time (e.g. from a
# secret store) rather than committing it to the notebook.
os.environ.update({
    "WANDB_API_KEY": 'your key',
    "WANDB_PROJECT": "AST AIRI 228",
    "WANDB_NOTES": "ASTSPOOF AIRI 228",
    "WANDB_NAME": "astspoof-airi-228",
})

In [3]:
def get_ast_dataset(
    test_dir='/kaggle/input/safe-speak-2024-audio-spoof-detection-hackathon/wavs',
    min_duration=0.5,
):
    """Collect the ``.wav`` files in ``test_dir`` and split them by duration.

    Args:
        test_dir: Directory that is scanned (non-recursively) for ``.wav`` files.
        min_duration: Threshold in seconds. Only clips strictly longer than
            this are put into the dataset; the rest are reported and returned
            separately.

    Returns:
        A ``(test_dataset, bad_names)`` tuple. ``test_dataset`` is a
        ``datasets.Dataset`` with one ``audio`` column holding the accepted
        file paths, in ``os.listdir`` order. ``bad_names`` is the list of
        paths whose duration was ``<= min_duration``.
    """
    valid_audio_files = []
    bad_names = []

    # Keep the raw os.listdir order: downstream code pairs predictions with
    # file IDs by position, based on its own os.listdir call.
    for file in tqdm(os.listdir(test_dir)):
        if not file.endswith('.wav'):
            continue

        file_path = os.path.join(test_dir, file)

        # Only the duration is needed, so read the header metadata instead of
        # decoding the whole waveform into memory.
        info = sf.info(file_path)
        duration = info.frames / info.samplerate

        if duration > min_duration:
            valid_audio_files.append(file_path)
        else:
            # tqdm.write prints the message without corrupting the progress bar.
            tqdm.write(f"File: {file}, Duration: {duration:.2f} seconds")
            bad_names.append(file_path)

    test_dataset = Dataset.from_dict(
        {'audio': valid_audio_files},
        features=Features({'audio': Audio()}),
    )

    return test_dataset, bad_names

In [4]:
# test_dataset, bad_names = get_ast_dataset()  # disabled here: the call is made in the inference cell below

In [5]:
# Where the feature-extractor configuration and the classifier weights are loaded from.
AST_EXTRACTOR_ID = 'MIT/ast-finetuned-audioset-10-10-0.4593'
AST_CHECKPOINT_DIR = '/kaggle/input/ast-airi-train/runs/ast_classifier/checkpoint-6834'

feature_extractor = ASTFeatureExtractor.from_pretrained(AST_EXTRACTOR_ID)
model = ASTForAudioClassification.from_pretrained(AST_CHECKPOINT_DIR)

# Key under which the extractor emits its features; also used as the dataset column name.
INPUT_NAME = feature_extractor.model_input_names[0]

def preprocess_audio(batch):
    """Turn a batch of decoded audio entries into model-ready features.

    Reads the ``'array'`` field of every entry under ``batch[INPUT_NAME]``,
    runs the module-level ``feature_extractor`` on those waveforms at the
    extractor's own sampling rate, and returns a dict that contains only the
    ``INPUT_NAME`` tensor.
    """
    target_rate = feature_extractor.sampling_rate

    waveforms = []
    for decoded in batch[INPUT_NAME]:
        waveforms.append(decoded['array'])

    encoded = feature_extractor(
        waveforms,
        sampling_rate=target_rate,
        return_tensors='pt',
    )
    return {INPUT_NAME: encoded[INPUT_NAME]}

# Scan the test directory; clips rejected by the duration filter come back in bad_names.
test_dataset, bad_names = get_ast_dataset()

# Cast the audio column to the extractor's sampling rate and expose it under the
# column name the model expects, then featurise on the fly on access.
test_dataset = (
    test_dataset
    .cast_column('audio', Audio(sampling_rate=feature_extractor.sampling_rate))
    .rename_column('audio', INPUT_NAME)
)
test_dataset.set_transform(preprocess_audio, output_all_columns=False)

# The Trainer is used purely as a batched prediction loop here.
inference_args = TrainingArguments(
    output_dir='./results',
    per_device_eval_batch_size=128,
)
trainer = Trainer(model=model, args=inference_args)

test_predictions = trainer.predict(test_dataset)
logits = test_predictions.predictions

# Softmax over the last axis, then split into per-class probability columns.
probabilities = torch.tensor(logits).softmax(dim=-1).numpy()
class_0, class_1 = probabilities[:, 0], probabilities[:, 1]

preprocessor_config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

  1%|          | 1301/144693 [00:21<37:57, 62.97it/s]

File: 69402.wav, Duration: 0.44 seconds


  5%|▍         | 7216/144693 [01:58<35:42, 64.16it/s]

File: 28216.wav, Duration: 0.35 seconds


 10%|█         | 14669/144693 [04:02<32:22, 66.95it/s]

File: 42427.wav, Duration: 0.44 seconds


 13%|█▎        | 18399/144693 [05:04<36:00, 58.47it/s]

File: 36657.wav, Duration: 0.37 seconds


 32%|███▏      | 45699/144693 [12:40<28:35, 57.71it/s]

File: 48776.wav, Duration: 0.27 seconds


 36%|███▌      | 51413/144693 [14:14<25:58, 59.86it/s]

File: 81539.wav, Duration: 0.47 seconds


 37%|███▋      | 53446/144693 [14:49<24:39, 61.68it/s]

File: 29774.wav, Duration: 0.45 seconds


 46%|████▌     | 66216/144693 [18:19<22:53, 57.14it/s]

File: 38063.wav, Duration: 0.38 seconds


 55%|█████▌    | 79617/144693 [22:01<16:40, 65.07it/s]

File: 58122.wav, Duration: 0.31 seconds


 67%|██████▋   | 96954/144693 [26:48<12:39, 62.87it/s]

File: 73102.wav, Duration: 0.47 seconds


 69%|██████▉   | 100351/144693 [27:44<11:29, 64.32it/s]

File: 23368.wav, Duration: 0.01 seconds


 70%|██████▉   | 101060/144693 [27:55<10:52, 66.83it/s]

File: 63228.wav, Duration: 0.40 seconds


 70%|██████▉   | 101088/144693 [27:56<11:31, 63.02it/s]

File: 66857.wav, Duration: 0.46 seconds


 77%|███████▋  | 111651/144693 [30:47<08:31, 64.63it/s]

File: 30206.wav, Duration: 0.43 seconds


 83%|████████▎ | 120407/144693 [33:10<06:25, 62.93it/s]

File: 26036.wav, Duration: 0.48 seconds


 89%|████████▉ | 129162/144693 [35:40<03:57, 65.42it/s]

File: 45656.wav, Duration: 0.01 seconds


 94%|█████████▍| 135958/144693 [37:32<02:33, 56.79it/s]

File: 52907.wav, Duration: 0.47 seconds


 95%|█████████▌| 137980/144693 [38:05<01:36, 69.31it/s]

File: 28268.wav, Duration: 0.11 seconds


100%|██████████| 144693/144693 [39:59<00:00, 60.30it/s]


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mlightsource-[0m ([33mlightsource-unk[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.18.3
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20241127_123234-lp6xo82n[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33m./results[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/lightsource-unk/AST%20AIRI%20228[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/lightsource-unk/AST%20AIRI%20228/runs/lp6xo82n[0m


In [6]:
import os

In [None]:
# Build the submission file: model scores for the clips that were run through
# the classifier, plus a fixed score for the clips rejected as too short.
test_dir = '/kaggle/input/safe-speak-2024-audio-spoof-detection-hackathon/wavs/'
test_audio_files = [os.path.join(test_dir, file) for file in os.listdir(test_dir) if file.endswith('.wav')]

# The numeric file stem ("<id>.wav") is used as the submission ID.
idxs = [int(os.path.splitext(os.path.basename(path))[0]) for path in test_audio_files]
idxs_bad = [int(os.path.splitext(os.path.basename(bad_name))[0]) for bad_name in bad_names]

# Set membership keeps this filter linear instead of scanning idxs_bad per ID.
bad_id_set = set(idxs_bad)
scored_ids = [x for x in idxs if x not in bad_id_set]

# NOTE(review): IDs are matched to scores purely by position. This assumes the
# os.listdir call above returns files in the same order as the one made when
# the dataset was built -- confirm, or carry the file paths alongside the
# predictions instead.
if len(scored_ids) != len(class_0):
    raise ValueError(
        f"Got {len(class_0)} predictions for {len(scored_ids)} scored files; "
        "the directory listing no longer matches the dataset used for inference."
    )

df = pd.DataFrame({
    'ID': scored_ids,
    'score': class_0
})

df_bad = pd.DataFrame({
    'ID': idxs_bad,
    'score': [0] * len(idxs_bad)  # very short clips are most likely spoofed, so they get a score of zero
})

df = pd.concat([df, df_bad]).sort_values(by=['ID'])

df.to_csv('submission_ast_class1.csv', index=False)