In [None]:
from pathlib import Path

import numpy as np
import torch
import torchaudio
from datasets import Audio
from transformers import ASTForAudioClassification, ASTFeatureExtractor

# Load the feature extractor and the fine-tuned classifier.
# NOTE(review): the feature extractor is rebuilt from the *base* AudioSet AST
# config plus the overrides below, not from `model_path`. Confirm that
# num_mel_bins / max_length (and the extractor's normalization statistics)
# match what the checkpoint was fine-tuned with — a mismatch silently degrades
# predictions rather than raising.
model_path = "./runs/ast_classifier/checkpoint-best"  # Update this path if your best model is saved elsewhere
BASE_EXTRACTOR_ID = "MIT/ast-finetuned-audioset-10-10-0.4593"
NUM_MEL_BINS = 64  # presumably the mel-bin count used during fine-tuning — verify
MAX_LENGTH = 507   # frame-length override for the extractor; presumably the training value — verify

feature_extractor = ASTFeatureExtractor.from_pretrained(
    BASE_EXTRACTOR_ID,
    num_mel_bins=NUM_MEL_BINS,
    max_length=MAX_LENGTH,
)
model = ASTForAudioClassification.from_pretrained(model_path)
model.eval();  # evaluation mode; trailing `;` keeps the very long model repr out of the cell output

In [None]:
def predict_audio(file_path, threshold=0.5):
    """Classify an audio clip as "Fake" or "Real" with the fine-tuned AST model.

    Uses the module-level ``feature_extractor`` and ``model`` created in the
    setup cell. Model inputs are moved to ``model.device``, so the model may
    live on CPU or GPU.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.
        threshold: Decision threshold on the "Fake" probability, in [0, 1].
            The clip is labelled "Fake" when P(Fake) >= threshold, otherwise
            "Real". The default 0.5 is equivalent to an argmax over the two
            classes (an exact tie goes to "Fake", as argmax picks index 0).

    Returns:
        dict with:
          - "prediction": "Fake" or "Real"
          - "confidence": probability of the predicted label as a
            percentage string, e.g. "87.31%"
          - "probabilities": {"Fake": "...%", "Real": "...%"}

    Raises:
        ValueError: if ``threshold`` is outside [0, 1], the file contains no
            samples, or the model does not output exactly two classes.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")

    target_sample_rate = 16000  # rate the waveform is resampled to and declared to the extractor

    # torchaudio.load returns a (channels, frames) tensor plus its sample rate.
    waveform, sample_rate = torchaudio.load(file_path)
    if waveform.numel() == 0:
        raise ValueError(f"{file_path!r} contains no audio samples")

    # Down-mix multi-channel audio to a single channel.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    # Drop only the channel axis so even a 1-frame clip stays 1-D
    # (a bare .squeeze() would collapse it to a 0-d array). No normalization
    # is applied here.
    waveform = waveform.squeeze(0).numpy()

    # Padding/truncation length comes from the extractor's configured
    # max_length. Only `input_values` is forwarded: AST's forward() takes no
    # attention mask, so splatting the whole extractor output is fragile.
    features = feature_extractor(waveform, sampling_rate=target_sample_rate, return_tensors="pt")
    input_values = features["input_values"].to(model.device)

    with torch.no_grad():
        logits = model(input_values=input_values).logits

    probabilities = torch.softmax(logits, dim=-1)[0]
    if probabilities.numel() != 2:
        raise ValueError(
            f"expected a 2-class (Fake/Real) model, got {probabilities.numel()} outputs"
        )

    # Class index 0 is "Fake", index 1 is "Real".
    prob_fake, prob_real = probabilities.tolist()
    label = "Fake" if prob_fake >= threshold else "Real"
    confidence = (prob_fake if label == "Fake" else prob_real) * 100  # percentage

    return {
        "prediction": label,
        "confidence": f"{confidence:.2f}%",
        "probabilities": {
            "Fake": f"{prob_fake * 100:.2f}%",
            "Real": f"{prob_real * 100:.2f}%",
        },
    }

In [None]:
# Example usage: point `audio_file` at a real clip before running this cell.
audio_file = "path/to/your/audio_file.wav"  # Replace with your audio file path

if Path(audio_file).is_file():
    result = predict_audio(audio_file)
    print(f"Prediction: {result['prediction']}")
    print(f"Confidence: {result['confidence']}")
    print("Probabilities:")
    for class_name, percent in result["probabilities"].items():
        print(f"  - {class_name}: {percent}")
else:
    # Skip rather than raise so "Restart & Run All" does not stop on the placeholder path.
    print(f"Skipping example: audio file not found: {audio_file}")