In [None]:
##### MODEL EVALUATION AND PREDICTION

import os
import json
import torch
from transformers import ASTFeatureExtractor,ASTForAudioClassification

# Checkpoint directory that holds the fine-tuned weights and preprocessing config
checkpoint_path = "runs/ast_classifier/checkpoint-95220"

# Restore the classifier and its feature extractor from that checkpoint
model = ASTForAudioClassification.from_pretrained(checkpoint_path)
feature_extractor = ASTFeatureExtractor.from_pretrained(checkpoint_path)

# If a trainer_state.json sits next to the checkpoint, dump it for inspection
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.json")
if os.path.exists(trainer_state_path):
    with open(trainer_state_path, "r") as state_file:
        trainer_state = json.load(state_file)
    print("\nTraining metrics from trainer_state.json:")
    print(json.dumps(trainer_state, indent=2))

# Use the GPU when one is visible, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference only: switch to evaluation mode, then place the model on the chosen device
model = model.eval().to(device)
print(f"Model loaded on {device}")

import librosa
import matplotlib.pyplot as plt
def predict_audio(file_path, model, feature_extractor, device="cuda"):
    """Classify one audio file as fake or real and plot its mel spectrogram.

    Args:
        file_path: Path of the audio file to load.
        model: Audio classification model; its logits are read from
            ``outputs.logits`` and index 0 is treated as "fake", index 1 as "real".
        feature_extractor: Feature extractor providing ``sampling_rate`` and
            turning the waveform into model inputs.
        device: Device the input tensors are moved to; it must match the device
            the model already lives on.

    Returns:
        dict with keys ``"label"`` ("fake" or "real"), ``"confidence"``
        (probability of the predicted class) and ``"probabilities"``
        (``{"fake": float, "real": float}``).
    """
    # Function-scope imports keep this function usable on a fresh kernel:
    # the enclosing cell imports neither numpy nor the librosa.display submodule.
    import numpy as np
    import librosa.display

    target_sr = feature_extractor.sampling_rate

    # Load the audio resampled to the rate the feature extractor expects
    audio, _ = librosa.load(file_path, sr=target_sr)

    # Plot a dB-scaled mel spectrogram of the clip
    spec = librosa.feature.melspectrogram(y=audio, sr=target_sr)
    spec_db = librosa.power_to_db(spec, ref=np.max)
    fig, ax = plt.subplots(nrows=1, ncols=1)
    img = librosa.display.specshow(spec_db, x_axis='time', y_axis='mel', ax=ax)
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set_title('Spectrogram')
    plt.show()
    # Close the figure so repeated calls (one per clip) do not accumulate open figures
    plt.close(fig)

    # Preprocess the audio into model inputs
    inputs = feature_extractor(
        audio,
        sampling_rate=target_sr,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    )

    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)

    print(f"Raw logits: {logits}")
    print(f"Raw probabilities: {probabilities}")

    # Read each value once; class 0 is fake, class 1 is real
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_class].item()
    p_fake = probabilities[0][0].item()
    p_real = probabilities[0][1].item()

    print(f"P(fake): {p_fake:.4f}")
    print(f"P(real): {p_real:.4f}")

    # Map class index to label
    label = "fake" if predicted_class == 0 else "real"

    return {
        "label": label,
        "confidence": confidence,
        "probabilities": {
            "fake": p_fake,
            "real": p_real
        }
    }

# Built with os.path.join so the same cell resolves the file on Windows and POSIX alike
# (the previous raw string used backslashes, which are not separators outside Windows).
audio_file_path = os.path.join(
    "LibriSeVoc", "train", "fake", "diffwave_26_495_000028_000000_gen.wav"
)  # Adjust the path as needed
result = predict_audio(audio_file_path, model, feature_extractor, device)

# Report the predicted label, its probability, and both class probabilities
print(f"Prediction: {result['label']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Probabilities - Fake: {result['probabilities']['fake']:.4f}, Real: {result['probabilities']['real']:.4f}")

In [None]:
#### AUDIO CUTTING SCRIPT

import os
import sys
import librosa
import soundfile as sf

def cut_audio(input_path, output_dir="./cut", clip_duration=5):
    """Split an audio file into consecutive clips and write them to ``output_dir``.

    Clips are named ``<name>(<index>)<ext>`` with a 1-based index, where
    ``<name>`` and ``<ext>`` come from ``input_path``. The last clip holds
    whatever samples remain and may be shorter than ``clip_duration``.

    Args:
        input_path: Path of the audio file to split.
        output_dir: Directory the clips are written to; created if missing.
        clip_duration: Length of each clip in seconds; must be positive.

    Raises:
        ValueError: If ``clip_duration`` is not positive.
    """
    if clip_duration <= 0:
        raise ValueError(f"clip_duration must be positive, got {clip_duration!r}")

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # File name and extension are reused for every clip
    name, ext = os.path.splitext(os.path.basename(input_path))

    # sr=None keeps the file's native sampling rate
    audio, sr = librosa.load(input_path, sr=None)

    # Work purely in samples: deriving the clip count from float durations can
    # disagree with the truncated per-clip sample count and either drop trailing
    # samples or emit an empty final clip.
    clip_samples = max(1, int(clip_duration * sr))

    for clip_index, start_sample in enumerate(range(0, len(audio), clip_samples), start=1):
        clip_audio = audio[start_sample:start_sample + clip_samples]
        out_path = os.path.join(output_dir, f"{name}({clip_index}){ext}")
        sf.write(out_path, clip_audio, sr)
        print(f"Saved: {out_path}")

if __name__ == "__main__":
    # Set this to the file to cut; when left empty, the first command-line
    # argument is used instead (previously argv was checked but never read).
    audio_path = ""  # Adjust the path as needed
    if not audio_path and len(sys.argv) >= 2:
        audio_path = sys.argv[1]

    if not audio_path:
        print("Usage: python cut_audio.py <path_to_audio_file>")
        sys.exit(1)
    if not os.path.exists(audio_path):
        print(f"File not found: {audio_path}")
        sys.exit(1)
    cut_audio(audio_path)

In [None]:
#### SEGMENTED AUDIO PREDICTION SCRIPT

import os
from collections import defaultdict

# Directory containing segmented audio clips
segmented_dir = "./cut"

# Collect all audio files and group them by base name (segment index stripped).
# Sorted so the grouping and the later summary come out in a reproducible order.
audio_groups = defaultdict(list)
for fname in sorted(os.listdir(segmented_dir)):
    if fname.lower().endswith(('.wav', '.mp3', '.flac', '.ogg')):
        # Strip only the trailing "(index)ext" part, e.g. "audio(1).wav" -> "audio".
        # Splitting from the right keeps base names that contain "(" themselves intact,
        # e.g. "take (live)(2).wav" -> "take (live)".
        base = fname.rsplit('(', 1)[0]
        audio_groups[base].append(os.path.join(segmented_dir, fname))

# Convenience wrapper around predict_audio for a single clip
def predict_single(file_path):
    """Run predict_audio on ``file_path`` with the notebook-level model, feature extractor and device."""
    return predict_audio(
        file_path,
        model=model,
        feature_extractor=feature_extractor,
        device=device,
    )

# Tally per-clip predictions for every source recording
results = {}
for base, files in audio_groups.items():
    # One prediction per clip, in listing order
    clip_labels = [predict_single(clip_path)["label"] for clip_path in files]

    total = len(clip_labels)
    fake_count = sum(1 for clip_label in clip_labels if clip_label == "fake")
    # Anything not labelled "fake" counts as real
    real_count = total - fake_count

    results[base] = {
        "fake_ratio": fake_count / total if total > 0 else 0,
        "real_ratio": real_count / total if total > 0 else 0,
        "fake_count": fake_count,
        "real_count": real_count,
        "total": total
    }

# One summary line per recording: counts and ratios for both classes
for base, stats in results.items():
    print(f"{base}: Fake {stats['fake_count']}/{stats['total']} ({stats['fake_ratio']:.2f}), "
          f"Real {stats['real_count']}/{stats['total']} ({stats['real_ratio']:.2f})")

In [None]:
import os
import numpy as np
from transformers import ASTFeatureExtractor
from transformers.utils import is_speech_available
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

if is_speech_available():
    import torchaudio.compliance.kaldi as ta_kaldi

# Based on the following literature, which uses the Hamming window:
#   https://arxiv.org/pdf/2505.15136
#   https://arxiv.org/pdf/2409.05924
#   https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=11007653
class ASTFeatureExtractorHamming(ASTFeatureExtractor):
    """
    AST feature extractor variant that computes filter-bank features with a
    Hamming window instead of the Hann/Hanning window.
    """

    def __init__(self, **kwargs):
        """Build the extractor; all keyword arguments are forwarded to ASTFeatureExtractor."""
        super().__init__(**kwargs)

        if is_speech_available():
            # torchaudio path: the window type is chosen per call in
            # _extract_fbank_features, so nothing needs overriding here.
            return

        # numpy path: rebuild the mel filter bank and swap in a Hamming window
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=257,
            num_mel_filters=self.num_mel_bins,
            min_frequency=20,
            max_frequency=self.sampling_rate // 2,
            sampling_rate=self.sampling_rate,
            norm=None,
            mel_scale="kaldi",
            triangularize_in_mel_space=True,
        )
        self.window = window_function(400, "hamming", periodic=False)

    def _extract_fbank_features(self, waveform: np.ndarray, max_length: int) -> np.ndarray:
        """
        Compute Hamming-windowed filter-bank features and fit them to ``max_length`` frames.

        Args:
            waveform: Audio samples as a numpy array.
            max_length: Number of frames the result is zero-padded or truncated to.

        Returns:
            The filter-bank features as a numpy array with ``max_length`` rows.
        """
        if not is_speech_available():
            # numpy implementation, using the Hamming window prepared in __init__
            samples = np.squeeze(waveform)
            log_mel = spectrogram(
                samples,
                self.window,
                frame_length=400,  # 25 ms frame as in the paper (16000 Hz * 0.025 = 400)
                hop_length=160,  # 10 ms hop as in the paper (16000 Hz * 0.01 = 160)
                fft_length=512,
                power=2.0,
                center=False,
                preemphasis=0.97,
                mel_filters=self.mel_filters,
                log_mel="log",
                mel_floor=1.192092955078125e-07,
                remove_dc_offset=True,
            )
            # Transpose so frames run along the first axis
            features = torch.from_numpy(log_mel.T)
        else:
            # torchaudio's Kaldi-compatible fbank, with the window type set to "hamming"
            batched = torch.from_numpy(waveform).unsqueeze(0)
            features = ta_kaldi.fbank(
                batched,
                sample_frequency=self.sampling_rate,
                window_type="hamming",
                num_mel_bins=self.num_mel_bins,
            )

        # Positive: frames missing (pad with zeros at the end); negative: too many (truncate)
        missing_frames = max_length - features.shape[0]
        if missing_frames > 0:
            features = torch.nn.functional.pad(features, (0, 0, 0, missing_frames))
        elif missing_frames < 0:
            features = features[:max_length, :]

        return features.numpy()

In [None]:
from datasets import load_dataset
from evaluate import evaluator, combine
import os
import json
import torch
from transformers import ASTFeatureExtractor,ASTForAudioClassification

# Leave one core free, but never drop below one worker.
# os.cpu_count() may return None when the core count cannot be determined.
NUM_PROC = max(1, (os.cpu_count() or 1) - 1)
# Load ./LibriSeVoc with the "audiofolder" builder, using NUM_PROC worker processes.
dataset = load_dataset("audiofolder", data_dir="./LibriSeVoc", num_proc=NUM_PROC)

In [None]:
# Show the loaded dataset, then the column schema of its training split
print(dataset)
train_split = dataset['train']
print('dataset train features:', train_split.features)


In [None]:
from transformers import ASTForAudioClassification

# Number of mel bins passed to the feature extractor for every checkpoint except
# the first; value based on https://arxiv.org/pdf/2409.05924
NUM_MEL_BINS = 128 # based on https://arxiv.org/pdf/2409.05924
# Passed to the feature extractor as `max_sequence_length` when loading checkpoints.
MAX_SEQUENCE_LENGTH = 507

# Checkpoint directories to evaluate, in order. Their position in this list is the
# key under which load_models() stores each model / feature-extractor pair.
# NOTE(review): the 3rd-iteration step counts are exact multiples of 7935
# (63480 = 8 * 7935 ... 95220 = 12 * 7935), which would correspond to epochs 8-12
# rather than the 9-13 noted below — confirm against each trainer_state.json.
checkpoints = [
    # 1st iteration - test run (one dataset only)
    "./runs/ast_classifier/checkpoint-158700",
    # 2nd iteration - first official run
    "./runs/ast_classifier/checkpoint-174570", # epoch 11
    # 3rd iteration / checking if epoch 13 is really peak (test patchout)
    "./runs/ast_classifier/checkpoint-63480", # epoch 9
    "./runs/ast_classifier/checkpoint-71415", # epoch 10
    "./runs/ast_classifier/checkpoint-79350", # epoch 11
    "./runs/ast_classifier/checkpoint-87285", # epoch 12
    "./runs/ast_classifier/checkpoint-95220", # epoch 13
]

# Normalization statistics assigned to every loaded feature extractor
# (presumably computed on the training data — TODO confirm their origin).
mean = 0.31605425901883943
std = 0.45787811188377187

def load_models(checkpoints:list):
    """Load a classifier and a feature extractor for each checkpoint directory.

    Args:
        checkpoints: Checkpoint paths; each is loaded with ``local_files_only=True``.

    Returns:
        dict mapping the checkpoint's position in ``checkpoints`` to
        ``{"model": ..., "feature_extractor": ...}``.
    """
    loaded = {}
    for index, checkpoint in enumerate(checkpoints):
        print(checkpoint)
        classifier = ASTForAudioClassification.from_pretrained(checkpoint, local_files_only=True)

        # The first checkpoint (1st iteration) uses 64 mel bins; all others use NUM_MEL_BINS
        mel_bins = 64 if index == 0 else NUM_MEL_BINS
        # NOTE(review): the extractor is given `max_sequence_length`; confirm this is the
        # keyword that actually controls the padded/truncated frame count.
        extractor = ASTFeatureExtractor.from_pretrained(
            checkpoint,
            num_mel_bins=mel_bins,
            max_sequence_length=MAX_SEQUENCE_LENGTH,
            local_files_only=True,
        )

        # Override the normalization statistics with the notebook-level values
        extractor.mean = mean
        extractor.std = std

        loaded[index] = {"model": classifier, "feature_extractor": extractor}
    return loaded

# Evaluator for the audio-classification task, shared by all checkpoints
task_evaluator = evaluator(task="audio-classification")

# Accumulator for per-checkpoint metric dicts, filled by the evaluation cell
results = []

# Model / feature-extractor pairs, keyed by position in `checkpoints`
models = load_models(checkpoints)

In [None]:
# Distinct label ids present in the training split
train_labels = dataset["train"]["label"]
num_labels = np.unique(train_labels)
print(num_labels)

In [None]:
import os

base_dir = "LibriSeVoc"  # Adjust path if needed

# Prefix every file with the name of the sub-directory it lives in
# ("<subdir>/<file>" -> "<subdir>/<subdir>_<file>").
for subdir in os.listdir(base_dir):
    subdir_path = os.path.join(base_dir, subdir)
    if not os.path.isdir(subdir_path):
        continue

    prefix = f"{subdir}_"
    for filename in os.listdir(subdir_path):
        old_path = os.path.join(subdir_path, filename)
        if not os.path.isfile(old_path):
            continue

        # Already carries the prefix: skip it so re-running this cell is idempotent
        # instead of producing "<subdir>_<subdir>_<file>".
        if filename.startswith(prefix):
            continue

        new_path = os.path.join(subdir_path, prefix + filename)

        # Avoid overwriting if file already exists
        if os.path.exists(new_path):
            print(f"Skipping {old_path}: {new_path} already exists.")
            continue
        os.rename(old_path, new_path)

In [None]:
# Count training examples per class.
# Label ids are resolved through the dataset's own label feature rather than
# hard-coded, so the counts agree with the fake=0 / real=1 convention used by
# predict_audio and by the evaluator's label_mapping (the earlier version counted
# label 0 as "real").
train_label_feature = dataset["train"].features["label"]
fake_id = train_label_feature.str2int("fake")
real_id = train_label_feature.str2int("real")

# Read only the label column; indexing whole rows would load every audio entry too.
train_label_ids = list(dataset["train"]["label"])
real = sum(1 for label_id in train_label_ids if label_id == real_id)
fake = sum(1 for label_id in train_label_ids if label_id == fake_id)

print("number of real datasets:", real)
print("number of fake datasets:", fake)


In [None]:
import pandas as pd

# Score every loaded checkpoint with the shared evaluator.
# `results` is rebuilt from scratch on every run of this cell: appending to a list
# created in another cell made re-runs accumulate duplicates, after which the
# DataFrame below no longer lined up with `checkpoints`.
# NOTE(review): evaluation runs on dataset["train"]; confirm whether a held-out
# split was intended.
results = [
    task_evaluator.compute(
        model_or_pipeline=models[i]["model"],
        feature_extractor=models[i]["feature_extractor"],
        data=dataset["train"],
        input_column="audio",
        label_column="label",
        label_mapping={"fake": 0, "real": 1},
        metric=combine(["accuracy", "recall", "precision", "f1"]),
    )
    for i in range(len(models))
]

# One row per checkpoint; the last expression displays the rounded metrics table
df2 = pd.DataFrame(results, index=checkpoints)
df2[["accuracy", "recall", "precision", "f1"]].round(4)