In [None]:
##### MODEL EVALUATION AND PREDICTION

import json
import os

import torch
from transformers import ASTFeatureExtractor
from transformers import ASTForAudioClassification

# Define the checkpoint path
checkpoint_path = "runs/ast_classifier/checkpoint-158700"

# Load the model and feature extractor from the same checkpoint directory
model = ASTForAudioClassification.from_pretrained(checkpoint_path)
feature_extractor = ASTFeatureExtractor.from_pretrained(checkpoint_path)

# Look for training history stored next to the checkpoint; it is optional,
# so a missing file is simply skipped.
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.json")
if os.path.exists(trainer_state_path):
    # Read the JSON as UTF-8 explicitly instead of relying on the platform's
    # default text encoding.
    with open(trainer_state_path, "r", encoding="utf-8") as f:
        trainer_state = json.load(f)
    print("\nTraining metrics from trainer_state.json:")
    print(json.dumps(trainer_state, indent=2))

# Set model to evaluation mode
model.eval()

# Move model to GPU if available, otherwise keep it on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")

import librosa
def predict_audio(file_path, model, feature_extractor, device="cuda"):
    """Classify one audio file as "fake" or "real".

    The file is loaded with librosa at ``feature_extractor.sampling_rate``,
    turned into model inputs by ``feature_extractor`` and passed through
    ``model`` with gradient tracking disabled. The raw logits, the softmax
    probabilities and the two per-class probabilities are printed as a side
    effect.

    Args:
        file_path: Path of the audio file to classify.
        model: Classifier whose output exposes ``.logits``. Class index 0 is
            reported as "fake" and index 1 as "real".
        feature_extractor: Callable preprocessor exposing ``sampling_rate``
            and returning a mapping of tensors.
        device: Device the inputs are moved to; it has to be the device the
            model is on. If a CUDA device is requested while CUDA is not
            available, the CPU is used instead.

    Returns:
        A dict with:
            "label": "fake" or "real",
            "confidence": probability of the predicted class,
            "probabilities": {"fake": float, "real": float}.
    """
    # The default device is "cuda"; without CUDA, moving the inputs there
    # would fail, so fall back to the CPU the same way the caller's setup does.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        device = torch.device("cpu")

    # Load audio file, resampled to the rate the feature extractor expects
    sampling_rate = feature_extractor.sampling_rate
    audio, _ = librosa.load(file_path, sr=sampling_rate)

    # Preprocess the audio
    inputs = feature_extractor(
        audio,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )

    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Make prediction
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.softmax(logits, dim=-1)

    print(f"Raw logits: {logits}")
    print(f"Raw probabilities: {probabilities}")

    # Get predicted class (0 for fake, 1 for real)
    predicted_class = torch.argmax(probabilities, dim=-1).item()

    # Convert the first (only) row to Python floats once instead of calling
    # .item() separately for every use below.
    class_probs = probabilities[0].tolist()
    fake_prob, real_prob = class_probs[0], class_probs[1]
    confidence = class_probs[predicted_class]

    print(f"P(fake): {fake_prob:.4f}")
    print(f"P(real): {real_prob:.4f}")

    # Map class index to label
    label = "fake" if predicted_class == 0 else "real"

    return {
        "label": label,
        "confidence": confidence,
        "probabilities": {
            "fake": fake_prob,
            "real": real_prob,
        },
    }

audio_file_path = "" # Adjust the path as needed

# Fail early with a clear message while the placeholder above is still empty
# or does not point at a file, rather than letting the audio loader raise.
if not os.path.isfile(audio_file_path):
    raise FileNotFoundError(
        f"Audio file not found: {audio_file_path!r}. "
        "Set audio_file_path to an existing audio file."
    )

result = predict_audio(audio_file_path, model, feature_extractor, device)

# Summarize the prediction returned by predict_audio
print(f"Prediction: {result['label']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Probabilities - Fake: {result['probabilities']['fake']:.4f}, Real: {result['probabilities']['real']:.4f}")