In [None]:
##### MODEL EVALUATION AND PREDICTION

import os
import json
import torch
from transformers import ASTFeatureExtractor,ASTForAudioClassification

# Checkpoint directory written by the training run
checkpoint_path = "runs/ast_classifier/checkpoint-95220"

# Restore the fine-tuned classifier together with its feature extractor
model = ASTForAudioClassification.from_pretrained(checkpoint_path)
feature_extractor = ASTFeatureExtractor.from_pretrained(checkpoint_path)

# Print the trainer state stored next to the checkpoint, when there is one
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.json")
if os.path.exists(trainer_state_path):
    with open(trainer_state_path, "r") as state_file:
        trainer_state = json.load(state_file)
    print(
        "\nTraining metrics from trainer_state.json:",
        json.dumps(trainer_state, indent=2),
        sep="\n",
    )

# Use the GPU when one is visible, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Inference only: switch to evaluation mode, then place the model on that device
model.eval()
model = model.to(device)
print(f"Model loaded on {device}")

import librosa
import matplotlib.pyplot as plt
def predict_audio(file_path, model, feature_extractor, device="cuda", show_plot=True):
    """Classify a single audio file as fake or real.

    The file is loaded with librosa at ``feature_extractor.sampling_rate``,
    optionally displayed as a mel spectrogram, run through ``model``, and the
    softmax over the logits is reported.

    Args:
        file_path: Path of the audio file to classify.
        model: Classifier called as ``model(**inputs)``; its output must expose
            ``.logits``. Class index 0 is reported as "fake", index 1 as "real".
        feature_extractor: Callable that turns the waveform into model inputs
            and exposes ``sampling_rate``.
        device: Device the inputs are moved to. It has to be the device the
            model lives on; the default "cuda" only works for a model on the GPU.
        show_plot: When True (default, as before) the mel spectrogram of the
            clip is displayed. Pass False when predicting many files in a loop.

    Returns:
        dict: ``{"label": "fake" | "real", "confidence": float,
        "probabilities": {"fake": float, "real": float}}`` where ``confidence``
        is the probability of the predicted class.
    """
    sampling_rate = feature_extractor.sampling_rate

    # Load audio file, resampled to the rate the feature extractor expects
    audio, _ = librosa.load(file_path, sr=sampling_rate)

    if show_plot:
        # Imported here because the surrounding cell imports neither numpy nor
        # the librosa.display submodule.
        import numpy as np
        from librosa import display as librosa_display

        spec = librosa.feature.melspectrogram(y=audio, sr=sampling_rate)
        spec_db = librosa.power_to_db(spec, ref=np.max)
        fig, ax = plt.subplots(nrows=1, ncols=1)
        # Pass the real sampling rate: specshow otherwise falls back to its own
        # default rate and mislabels the time/frequency axes.
        img = librosa_display.specshow(
            spec_db, sr=sampling_rate, x_axis='time', y_axis='mel', ax=ax
        )
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set_title('Spectrogram')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)

    # Preprocess the audio
    inputs = feature_extractor(
        audio,
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    )

    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Make prediction (no gradients needed for inference)
    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = torch.softmax(logits, dim=-1)

    print(f"Raw logits: {logits}")
    print(f"Raw probabilities: {probabilities}")

    # Get predicted class (0 for fake, 1 for real)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0][predicted_class].item()
    p_fake = probabilities[0][0].item()
    p_real = probabilities[0][1].item()

    print(f"P(fake): {p_fake:.4f}")
    print(f"P(real): {p_real:.4f}")

    # Map class index to label
    label = "fake" if predicted_class == 0 else "real"

    return {
        "label": label,
        "confidence": confidence,
        "probabilities": {
            "fake": p_fake,
            "real": p_real
        }
    }

# Build the path with os.path.join so the cell also works on systems where "\"
# is not a path separator (on Windows the resulting path is unchanged).
audio_file_path = os.path.join(
    "LibriSeVoc", "train", "fake", "diffwave_26_495_000028_000000_gen.wav"
)  # Adjust the path as needed
result = predict_audio(audio_file_path, model, feature_extractor, device)

print(f"Prediction: {result['label']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Probabilities - Fake: {result['probabilities']['fake']:.4f}, Real: {result['probabilities']['real']:.4f}")

In [None]:
#### AUDIO CUTTING SCRIPT

import os
import sys
import librosa
import soundfile as sf

def cut_audio(input_path, output_dir="./cut", clip_duration=5):
    """Split an audio file into consecutive clips written to ``output_dir``.

    Clips are named ``<name>(<index>)<ext>`` (1-based index, name and extension
    taken from ``input_path``) and keep the sampling rate of the source file.
    Every clip holds ``int(clip_duration * sr)`` samples except the last one,
    which holds whatever remains.

    The clip boundaries are computed in samples rather than from the floating
    point duration, so a fractional ``clip_duration`` can no longer produce an
    extra, empty trailing clip through rounding.

    Args:
        input_path: Path of the audio file to split.
        output_dir: Directory receiving the clips; created when missing.
        clip_duration: Length of one clip in seconds.

    Returns:
        list[str]: Paths of the written clips, in order (empty for empty audio).

    Raises:
        ValueError: If ``clip_duration`` is shorter than one sample.
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # File name and extension are reused for the clip names
    name, ext = os.path.splitext(os.path.basename(input_path))

    # sr=None keeps the file's native sampling rate
    audio, sr = librosa.load(input_path, sr=None)
    clip_samples = int(clip_duration * sr)
    if clip_samples < 1:
        raise ValueError(
            f"clip_duration must cover at least one sample, got {clip_duration!r} s at {sr} Hz"
        )

    written_paths = []
    for index, start_sample in enumerate(range(0, len(audio), clip_samples), start=1):
        # Slicing clamps at the end of the array, so the last clip may be shorter
        clip_audio = audio[start_sample:start_sample + clip_samples]
        out_path = os.path.join(output_dir, f"{name}({index}){ext}")
        sf.write(out_path, clip_audio, sr)
        print(f"Saved: {out_path}")
        written_paths.append(out_path)
    return written_paths

if __name__ == "__main__":
    # Set this when running the cell by hand. When it is left empty, the first
    # command-line argument is used instead, which is what the usage message
    # below promises (previously the argument was checked but never read).
    audio_path = ""  # Adjust the path as needed
    if not audio_path:
        if len(sys.argv) < 2:
            print("Usage: python cut_audio.py <path_to_audio_file>")
            sys.exit(1)
        audio_path = sys.argv[1]
    if not os.path.exists(audio_path):
        print(f"File not found: {audio_path}")
        sys.exit(1)
    cut_audio(audio_path)

In [None]:
#### SEGMENTED AUDIO PREDICTION SCRIPT

import os
from collections import defaultdict

# Directory containing segmented audio clips
segmented_dir = "./cut"

# Collect all audio files and group by base name (without segment index).
# The listing is sorted so the grouping order does not depend on the file system.
audio_groups = defaultdict(list)
for fname in sorted(os.listdir(segmented_dir)):
    if fname.lower().endswith(('.wav', '.mp3', '.flac', '.ogg')):
        # Clips are named "<name>(<index>)<ext>": split on the LAST "(" so a
        # source name that itself contains "(" stays whole,
        # e.g. "take(2)(1).wav" -> "take(2)" and "audio(1).wav" -> "audio".
        base = fname.rsplit('(', 1)[0]
        audio_groups[base].append(os.path.join(segmented_dir, fname))

# Function to predict for a single audio file
def predict_single(file_path):
     """Run ``predict_audio`` on ``file_path`` with the notebook-level ``model``,
     ``feature_extractor`` and ``device``, and return its result unchanged."""
     return predict_audio(file_path, model, feature_extractor, device)

# Per source recording: classify every clip, then tally the fake/real votes
results = {}
for base, files in audio_groups.items():
    clip_labels = [predict_single(clip_path)["label"] for clip_path in files]
    total = len(clip_labels)
    fake_count = sum(1 for clip_label in clip_labels if clip_label == "fake")
    real_count = total - fake_count

    # Ratios fall back to 0 for a group without clips
    if total > 0:
        fake_ratio = fake_count / total
        real_ratio = real_count / total
    else:
        fake_ratio = 0
        real_ratio = 0

    results[base] = dict(
        fake_ratio=fake_ratio,
        real_ratio=real_ratio,
        fake_count=fake_count,
        real_count=real_count,
        total=total,
    )

# One summary line per source recording
for base, stats in results.items():
    print(f"{base}: Fake {stats['fake_count']}/{stats['total']} ({stats['fake_ratio']:.2f}), "
          f"Real {stats['real_count']}/{stats['total']} ({stats['real_ratio']:.2f})")

In [1]:
import os

import numpy as np
import torch
from transformers import ASTFeatureExtractor
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from transformers.utils import is_speech_available

if is_speech_available():
    import torchaudio.compliance.kaldi as ta_kaldi

# Based on the following literature, which uses the Hamming window:
#   https://arxiv.org/pdf/2505.15136
#   https://arxiv.org/pdf/2409.05924
#   https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=11007653
class ASTFeatureExtractorHamming(ASTFeatureExtractor):
    """
    Custom AST Feature Extractor that uses Hamming window instead of Hann/Hanning.

    Two code paths are covered:

    * torchaudio available: ``ta_kaldi.fbank`` is called with
      ``window_type="hamming"``.
    * torchaudio missing: a Hamming window is built with ``window_function`` and
      handed to the NumPy ``spectrogram`` implementation.

    In both cases the filterbank features are zero-padded or truncated along the
    frame axis to ``max_length`` frames.
    """

    # Framing parameters, in samples, used by the NumPy fallback path.
    # The frame and hop lengths correspond to 25 ms and 10 ms at a 16000 Hz
    # sampling rate (16000 * 0.025 = 400, 16000 * 0.01 = 160).
    _FRAME_LENGTH = 400
    _HOP_LENGTH = 160
    _FFT_LENGTH = 512
    _WINDOW_TYPE = "hamming"

    def __init__(self, **kwargs):
        """Build the extractor; all keyword arguments are forwarded to the parent."""
        super().__init__(**kwargs)

        # Override the window for numpy-based processing (when torchaudio is not available)
        if not is_speech_available():
            # One frequency bin per non-negative FFT frequency: 512 // 2 + 1 = 257
            self.mel_filters = mel_filter_bank(
                num_frequency_bins=self._FFT_LENGTH // 2 + 1,
                num_mel_filters=self.num_mel_bins,
                min_frequency=20,
                max_frequency=self.sampling_rate // 2,
                sampling_rate=self.sampling_rate,
                norm=None,
                mel_scale="kaldi",
                triangularize_in_mel_space=True,
            )
            # Use hamming window instead of hann
            self.window = window_function(self._FRAME_LENGTH, self._WINDOW_TYPE, periodic=False)

    def _extract_fbank_features(self, waveform: np.ndarray, max_length: int) -> np.ndarray:
        """
        Override to use hamming window type in torchaudio.compliance.kaldi.fbank

        Args:
            waveform: Audio samples as a NumPy array.
            max_length: Number of frames of the returned array; shorter inputs are
                zero-padded at the end, longer ones are truncated.

        Returns:
            NumPy array of log-mel filterbank features with ``max_length`` rows.
        """
        if is_speech_available():
            # torchaudio expects a leading channel dimension
            waveform = torch.from_numpy(waveform).unsqueeze(0)
            fbank = ta_kaldi.fbank(
                waveform,
                sample_frequency=self.sampling_rate,
                window_type=self._WINDOW_TYPE,  # "hamming" instead of "hanning"
                num_mel_bins=self.num_mel_bins,
            )
        else:
            # Use numpy implementation with the hamming window built in __init__
            waveform = np.squeeze(waveform)
            fbank = spectrogram(
                waveform,
                self.window,
                frame_length=self._FRAME_LENGTH,
                hop_length=self._HOP_LENGTH,
                fft_length=self._FFT_LENGTH,
                power=2.0,
                center=False,
                preemphasis=0.97,
                mel_filters=self.mel_filters,
                log_mel="log",
                mel_floor=1.192092955078125e-07,
                remove_dc_offset=True,
            ).T  # transpose so that frames come first
            fbank = torch.from_numpy(fbank)

        n_frames = fbank.shape[0]
        difference = max_length - n_frames

        # pad or truncate, depending on difference
        if difference > 0:
            # (left, right, top, bottom): append `difference` zero frames at the end
            pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
            fbank = pad_module(fbank)
        elif difference < 0:
            fbank = fbank[0:max_length, :]

        fbank = fbank.numpy()
        return fbank

In [2]:
from datasets import load_dataset
from evaluate import evaluator, combine
import os
import json
import torch
from transformers import ASTFeatureExtractor,ASTForAudioClassification

# Leave one core free, but never drop below one worker: os.cpu_count() can
# return None (unknown) or 1, and neither "None - 1" nor 0 workers is usable.
NUM_PROC = max(1, (os.cpu_count() or 1) - 1)
# Load the ./LibriSeVoc folder with the "audiofolder" builder, using NUM_PROC worker processes.
dataset = load_dataset("audiofolder", data_dir="./LibriSeVoc", num_proc=NUM_PROC)

Resolving data files:   0%|          | 0/92407 [00:00<?, ?it/s]

In [None]:
# Show the loaded dataset object and the feature schema of its train split.
print(dataset)
print('dataset train features:', dataset['train'].features)


In [3]:
from transformers import ASTForAudioClassification

# Number of mel bins handed to the feature extractor in load_models.
NUM_MEL_BINS = 128 # based on https://arxiv.org/pdf/2409.05924
# Handed to the feature extractor as `max_sequence_length` in load_models.
# NOTE(review): presumably the number of spectrogram frames kept per clip - confirm.
MAX_SEQUENCE_LENGTH = 507

# Checkpoints to evaluate, in this order. Entries of earlier iterations are kept
# commented out for reference only.
checkpoints = [
    # # 1st iteration - test run (one dataset only)
    # "./runs2/ast_classifier/checkpoint-158700",

    # # 2nd iteration - first official run
    # "./runs2/ast_classifier/checkpoint-174570", # epoch 11

    # # 3rd iteration / checking if epoch 13 is really peak (test patchout)
    # "./runs2/ast_classifier/checkpoint-63480", # epoch 9
    # "./runs2/ast_classifier/checkpoint-71415", # epoch 10
    # "./runs2/ast_classifier/checkpoint-79350", # epoch 11
    # "./runs2/ast_classifier/checkpoint-87285", # epoch 12
    # "./runs2/ast_classifier/checkpoint-95220", # epoch 13

    # 4th iteration - cleaner dataset + using librespeech. Tuned on test set so use val set for testing
    # NOTE(review): the evaluation cells in this notebook read dataset["test"];
    # confirm that this is the split intended by the remark above.
    "./runs2/ast_classifier/checkpoint-31740", # epoch 1
    "./runs2/ast_classifier/checkpoint-63480", # epoch 2
    "./runs2/ast_classifier/checkpoint-95220", # epoch 3
    "./runs2/ast_classifier/checkpoint-126960", # epoch 4
    "./runs2/ast_classifier/checkpoint-158700", # epoch 5
    "./runs2/ast_classifier/checkpoint-190440", # epoch 6
    "./runs2/ast_classifier/checkpoint-222180", # epoch 7
]

def load_models(checkpoints:list):
    """Load a classifier and a feature extractor for every checkpoint path.

    Each path is printed before loading. Both objects are read from local files
    only; the extractor is created with ``NUM_MEL_BINS`` mel bins and
    ``MAX_SEQUENCE_LENGTH`` as maximum sequence length, and normalisation is
    switched on afterwards.

    Returns:
        dict: position of the checkpoint in ``checkpoints`` ->
        ``{"model": ..., "feature_extractor": ...}``.
    """
    loaded = {}
    for position, path in enumerate(checkpoints):
        print(path)
        classifier = ASTForAudioClassification.from_pretrained(path, local_files_only=True)
        # Note: a first-iteration checkpoint was once loaded with num_mel_bins=64
        # instead; every checkpoint now uses the same extractor settings.
        extractor = ASTFeatureExtractor.from_pretrained(
            path,
            local_files_only=True,
            num_mel_bins=NUM_MEL_BINS,
            max_sequence_length=MAX_SEQUENCE_LENGTH,
        )
        extractor.do_normalize = True

        loaded[position] = dict(model=classifier, feature_extractor=extractor)
    return loaded

# Evaluator object for the "audio-classification" task (from the `evaluate` import above).
task_evaluator = evaluator(task="audio-classification")
# Starts empty here. NOTE(review): the metrics cell further down assigns a new
# list to `results` before filling it, so this one looks unused - confirm.
results = []
# One model/feature-extractor pair per entry of `checkpoints`, keyed by position.
models = load_models(checkpoints)

./runs2/ast_classifier/checkpoint-31740
./runs2/ast_classifier/checkpoint-63480
./runs2/ast_classifier/checkpoint-95220
./runs2/ast_classifier/checkpoint-126960
./runs2/ast_classifier/checkpoint-158700
./runs2/ast_classifier/checkpoint-190440
./runs2/ast_classifier/checkpoint-222180


In [None]:
# NOTE(review): despite its name, this holds the array of distinct label values
# returned by np.unique for the train split, not how many there are.
num_labels = (np.unique(dataset["train"]["label"]))
print(num_labels)

In [None]:
import os

base_dir = "LibriSeVoc"  # Adjust path if needed


def prefix_files_with_dirname(root_dir):
    """Rename the files of every sub-directory of ``root_dir`` to ``<subdir>_<filename>``.

    Only regular files located directly inside a first-level sub-directory are
    touched. The operation is idempotent: a file whose name already starts with
    ``<subdir>_`` is left alone, so running the cell again does not stack
    prefixes (``x_x_file``). A file is also skipped, with a message, when the
    target name already exists.

    Args:
        root_dir: Directory whose sub-directories are processed.

    Returns:
        int: Number of files renamed by this call.
    """
    renamed = 0
    for subdir in sorted(os.listdir(root_dir)):
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue

        prefix = f"{subdir}_"
        for filename in sorted(os.listdir(subdir_path)):
            old_path = os.path.join(subdir_path, filename)
            if not os.path.isfile(old_path):
                continue
            # Already prefixed by an earlier run: nothing to do
            if filename.startswith(prefix):
                continue

            new_path = os.path.join(subdir_path, prefix + filename)
            # Avoid overwriting if file already exists
            if os.path.exists(new_path):
                print(f"Skipping {old_path}: {new_path} already exists.")
                continue

            os.rename(old_path, new_path)
            renamed += 1
    return renamed


if os.path.isdir(base_dir):
    prefix_files_with_dirname(base_dir)
else:
    print(f"Directory not found: {base_dir}")

In [None]:
# Count the samples per class from the label column alone instead of fetching
# every full row of the split one by one.
# Label ids follow the convention used throughout this notebook: 0 = fake, 1 = real
# (the previous version of this cell counted label 0 as "real").
train_labels = dataset["train"]["label"]
fake = sum(1 for label_id in train_labels if label_id == 0)
real = len(train_labels) - fake

print("number of real datasets:", real)
print("number of fake datasets:", fake)


In [None]:
import torch
import pandas as pd
import evaluate
from evaluate import combine
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm

# Configure which metrics to compute via Evaluator
standard_metrics = ["accuracy", "recall", "precision", "f1"]

# Label name -> integer id, shared by the Evaluator and the ROC AUC pass
label_mapping = {"fake": 0, "real": 1}
input_column = "audio"
label_column = "label"
# Identify the positive class name for ROC AUC (the class mapped to 1)
positive_label_name = next(k for k, v in label_mapping.items() if v == 1)

# Both inference passes run on this device (kept on the CPU, as before).
EVAL_DEVICE = "cpu"

# Loop-invariant work, done once instead of once per checkpoint
eval_split = dataset["test"]
# References (mapped to ints) in dataset order
references = [
    label_mapping[x] if x in label_mapping else int(x)
    for x in eval_split[label_column]
]
# Loading the metric first also makes a missing metric fail before any inference
roc_auc_metric = evaluate.load("roc_auc")

results = []

for i in range(len(models)):
    # 1) Compute prediction-based metrics with Evaluator
    eval_result = task_evaluator.compute(
        model_or_pipeline=models[i]["model"],
        feature_extractor=models[i]["feature_extractor"],
        data=eval_split,
        input_column=input_column,
        label_column=label_column,
        label_mapping=label_mapping,
        metric=combine(standard_metrics),
        device=EVAL_DEVICE,
    )

    # 2) Compute ROC AUC separately with a pipeline streamed over the Dataset.
    # Reuse a pipeline stored with the model entry, otherwise build one.
    pipe = models[i].get("pipeline")
    if pipe is None:
        pipe = pipeline(
            "audio-classification",
            model=models[i]["model"],
            feature_extractor=models[i]["feature_extractor"],
            device=EVAL_DEVICE,
        )

    # Per-sample score of the positive class, in dataset order
    scores = []
    for pred in tqdm(
        pipe(
            KeyDataset(eval_split, input_column),
            top_k=None,                   # return all class scores
        ),
        total=len(eval_split),
        desc=f"ROC AUC pass (model {i})",
    ):
        # pred is a list of {"label": str, "score": float}
        score_by_label = {p["label"]: p["score"] for p in pred}
        if positive_label_name not in score_by_label:
            # Explicit error instead of a bare StopIteration when the model's
            # label names do not match label_mapping
            raise KeyError(
                f"Label {positive_label_name!r} not found in pipeline output "
                f"(got {sorted(score_by_label)}) for {checkpoints[i]}"
            )
        scores.append(score_by_label[positive_label_name])

    # Compute roc_auc
    roc_auc_result = roc_auc_metric.compute(
        prediction_scores=scores,
        references=references,
    )

    # Merge and store
    eval_result.update(roc_auc_result)
    results.append(eval_result)

# Final DataFrame: one row per checkpoint
df2 = pd.DataFrame(results, index=checkpoints)
df2[standard_metrics + ["roc_auc"]].round(4)

In [None]:
# Diagnostics: confusion matrix, classification report, ROC and PR curves, threshold sweep
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    classification_report,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
)
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm

# 1) Pick the best checkpoint by ROC AUC (from df2 computed above)
best_i = int(np.argmax(df2["roc_auc"].to_numpy()))
print(f"Selected checkpoint for diagnostics: {checkpoints[best_i]} (ROC AUC={df2.loc[checkpoints[best_i], 'roc_auc']:.4f})")

# 2) Optional: limit to a subset of test samples for speed (set to None to use all)
N_SAMPLES = 10_000  # e.g., 5000 for a quick pass; set None for full test set
test_ds = dataset["test"]
if N_SAMPLES is not None and N_SAMPLES < len(test_ds):
    # Fixed seed so the same subset is drawn on every run
    rng = np.random.default_rng(42)
    idx = rng.choice(len(test_ds), size=N_SAMPLES, replace=False)
    idx = np.sort(idx)
    test_ds = test_ds.select(idx)
    print(f"Using a subset of {len(test_ds)} samples for plotting.")

# 3) Build pipeline for the selected model (GPU 0 when available, otherwise CPU)
device_index = 0 if torch.cuda.is_available() else -1
pipe = pipeline(
    "audio-classification",
    model=models[best_i]["model"],
    feature_extractor=models[best_i]["feature_extractor"],
    device=device_index,
)

# 4) Prepare references (0=fake, 1=real) and run predictions
references = [
    label_mapping[x] if x in label_mapping else int(x)
    for x in test_ds["label"]
]

scores_pos = []  # probability/score for the positive class (real==1)
y_pred = []      # hard predictions via argmax

for pred in tqdm(
    pipe(KeyDataset(test_ds, input_column), top_k=None),
    total=len(test_ds),
    desc=f"Inference for diagnostics (model {best_i})",
):
    # pred is a list of {"label": str, "score": float}
    score_map = {p["label"]: p["score"] for p in pred}
    missing = {"real", "fake"} - score_map.keys()
    if missing:
        # Fail loudly: defaulting a missing class to 0.0 would silently turn a
        # label-name mismatch into wrong metrics and curves.
        raise KeyError(
            f"Pipeline output lacks label(s) {sorted(missing)}; got {sorted(score_map)}"
        )
    real_score = score_map["real"]
    fake_score = score_map["fake"]

    scores_pos.append(real_score)
    y_pred.append(1 if real_score >= fake_score else 0)

# 5) Confusion Matrix
cm = confusion_matrix(references, y_pred, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["fake (0)", "real (1)"])
fig, ax = plt.subplots(figsize=(5, 5))
disp.plot(cmap="Blues", ax=ax, values_format="d", colorbar=False)
ax.set_title(f"Confusion Matrix\n{checkpoints[best_i]}")
plt.tight_layout()
plt.show()

# 6) Classification report
print(classification_report(references, y_pred, target_names=["fake", "real"], digits=4))

# 7) ROC Curve
fpr, tpr, thr_roc = roc_curve(references, scores_pos)
roc_auc_val = auc(fpr, tpr)
plt.figure(figsize=(5, 5))
plt.plot(fpr, tpr, label=f"AUC = {roc_auc_val:.4f}")
plt.plot([0, 1], [0, 1], "--", color="gray", alpha=0.7)  # chance diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.tight_layout()
plt.show()

# 8) Precision–Recall Curve
prec, rec, thr_pr = precision_recall_curve(references, scores_pos)
ap = average_precision_score(references, scores_pos)
plt.figure(figsize=(5, 5))
plt.plot(rec, prec, label=f"AP = {ap:.4f}")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision–Recall Curve")
plt.legend(loc="lower left")
plt.tight_layout()
plt.show()

# 9) Threshold sweep (optional) to see the best F1 on this test set
thr_grid = np.linspace(0, 1, 101)
scores_arr = np.asarray(scores_pos)  # vectorised thresholding instead of a list comprehension per threshold
best_f1 = -1.0
best_thr = 0.5
for t in thr_grid:
    y_hat = (scores_arr >= t).astype(int)
    # zero_division=0: thresholds that predict no positive score 0 without a
    # warning, consistent with the precision/recall calls below
    f1 = f1_score(references, y_hat, zero_division=0)
    if f1 > best_f1:
        best_f1 = f1
        best_thr = t

y_hat_best = (scores_arr >= best_thr).astype(int)
p_best = precision_score(references, y_hat_best, zero_division=0)
r_best = recall_score(references, y_hat_best, zero_division=0)

print(f"Best F1 on this set via threshold: F1={best_f1:.4f} at threshold={best_thr:.2f}")
print(f"Precision={p_best:.4f}, Recall={r_best:.4f} at that threshold")

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def plot_training_curves(log_dir, metrics=('loss',), smooth=0.9):
    """Draw train/validation curves read from TensorBoard event data.

    Among the entries of ``log_dir`` whose name starts with ``runs``, the one
    that sorts last by name is loaded. For each requested metric the
    ``train/<metric>`` scalars are drawn as a line (optionally smoothed) and the
    ``eval/<metric>`` scalars as scatter points; tags that are absent are skipped.

    Args:
        log_dir: Path to the directory containing TensorBoard logs
        metrics: Tuple of metric names to plot (e.g., ('loss', 'accuracy'))
        smooth: Weight of the running value in the exponential moving average
            applied to the training curve; values <= 0 disable smoothing

    Raises:
        ValueError: If ``log_dir`` has no entry starting with ``runs``.
    """
    candidates = sorted(entry for entry in os.listdir(log_dir) if entry.startswith('runs'))
    if len(candidates) == 0:
        raise ValueError(f"No run directories found in {log_dir}")

    # Last entry in name order is treated as the run to load
    accumulator = EventAccumulator(os.path.join(log_dir, candidates[-1]))
    accumulator.Reload()

    plt.figure(figsize=(10, 6))

    for metric in metrics:
        # Training curve: a line, smoothed with an exponential moving average
        train_tag = f'train/{metric}'
        if train_tag in accumulator.Tags()['scalars']:
            records = accumulator.Scalars(train_tag)
            steps = [record.step for record in records]
            values = [record.value for record in records]

            if smooth > 0:
                running = values[0]
                smoothed = [running]
                for raw in values[1:]:
                    running = smooth * running + (1 - smooth) * raw
                    smoothed.append(running)
                values = smoothed

            plt.plot(steps, values, label=f'Train {metric}', alpha=0.7)

        # Validation points: usually sparse, so they are drawn unsmoothed
        eval_tag = f'eval/{metric}'
        if eval_tag in accumulator.Tags()['scalars']:
            records = accumulator.Scalars(eval_tag)
            plt.scatter(
                [record.step for record in records],
                [record.value for record in records],
                label=f'Val {metric}',
                alpha=0.7,
                marker='o',
            )

    plt.title('Training and Validation Metrics')
    plt.xlabel('Steps')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Example usage - update the path to your logs directory
log_dir = "./runs/ast_classifier"  # Update this to your actual log directory
try:
    plot_training_curves(log_dir, metrics=('loss',))
except Exception as e:
    # Any failure is reported as a short message plus a hint instead of a traceback
    print(f"Error plotting: {e}")
    print("Make sure tensorboard is installed and the log directory is correct.")