In [1]:
import os
os.environ["PATH"] += os.pathsep + "D:\\ProgramFiles\\Microsoft Visual Studio\\2022\\Community\\VC\\Tools\\MSVC\\14.44.35207\\bin\\Hostx64\\x64"
os.environ["CUDA_HOME"] = "D:\\ProgramFiles\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.9"
from datasets import load_dataset
from transformers import ASTFeatureExtractor
import torch
from torch.utils.data import DataLoader
import torch.profiler
import numpy as np
from datasets import Audio

Make sure to install Pytorch and CUDA Toolkit to make this work

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

In [2]:
import platform
print(f"Python version: {platform.python_version()}")
print(f"System: {platform.system()} {platform.release()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA device count: {torch.cuda.device_count()}")

Python version: 3.12.0
System: Windows 11
PyTorch version: 2.7.1+cu128
CUDA available: True
CUDA version: 12.8
GPU: NVIDIA GeForce RTX 2060
CUDA device count: 1


In [3]:
from transformers.utils import is_speech_available
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

if is_speech_available():
    import torchaudio.compliance.kaldi as ta_kaldi

# based on the following literature that uses the Hamming window:
    # https://arxiv.org/pdf/2505.15136
    # https://arxiv.org/pdf/2409.05924
    # https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=11007653
class ASTFeatureExtractorHamming(ASTFeatureExtractor):
    """
    Custom AST Feature Extractor that uses Hamming window instead of Hann/Hanning.
    """
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
        # Override the window for numpy-based processing (when torchaudio is not available)
        if not is_speech_available():
            # Recalculate mel filters and window with hamming
            mel_filters = mel_filter_bank(
                num_frequency_bins=257,
                num_mel_filters=self.num_mel_bins,
                min_frequency=20,
                max_frequency=self.sampling_rate // 2,
                sampling_rate=self.sampling_rate,
                norm=None,
                mel_scale="kaldi",
                triangularize_in_mel_space=True,
            )
            self.mel_filters = mel_filters
            # Use hamming window instead of hann
            self.window = window_function(400, "hamming", periodic=False)
    
    def _extract_fbank_features(self, waveform: np.ndarray, max_length: int) -> np.ndarray:
        """
        Override to use hamming window type in torchaudio.compliance.kaldi.fbank
        """
        if is_speech_available():
            waveform = torch.from_numpy(waveform).unsqueeze(0)
            fbank = ta_kaldi.fbank(
                waveform,
                sample_frequency=self.sampling_rate,
                window_type="hamming",  # Changed from "hanning" to "hamming"
                num_mel_bins=self.num_mel_bins,
            )
        else:
            # Use numpy implementation with hamming window
            waveform = np.squeeze(waveform)
            fbank = spectrogram(
                waveform,
                self.window,  # This is now hamming window from __init__
                frame_length=400, # this follows the 25 ms frame length used in the paper (16000mhz * 0.025 = 400)
                hop_length=160, # this follows the hop length used in the paper (16000mhz * 0.01 = 160)
                fft_length=512,
                power=2.0,
                center=False,
                preemphasis=0.97,
                mel_filters=self.mel_filters,
                log_mel="log",
                mel_floor=1.192092955078125e-07,
                remove_dc_offset=True,
            ).T
            fbank = torch.from_numpy(fbank)

        n_frames = fbank.shape[0]
        difference = max_length - n_frames

        # pad or truncate, depending on difference
        if difference > 0:
            pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
            fbank = pad_module(fbank)
        elif difference < 0:
            fbank = fbank[0:max_length, :]

        fbank = fbank.numpy()
        return fbank

In [4]:
from torch.nn import Dropout
from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTEmbeddings

class ASTEmbeddingsWithPatchout(ASTEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.patchout_prob = config.patchout_prob
        self.patchout_strategy = config.patchout_strategy
        self.patchout_dropout = Dropout(self.patchout_prob)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        batch_size = input_values.shape[0]
        embeddings = self.patch_embeddings(input_values)  # Shape: (batch_size, num_patches, hidden_size)

        # Apply Patchout during training
        if self.training:
            if self.patchout_strategy == "unstructured":
                embeddings = self.patchout_dropout(embeddings)
            elif self.patchout_strategy == "structured":
                embeddings = self.structured_patchout(embeddings)

        # Add CLS and distillation tokens + positional embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

    def structured_patchout(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Apply structured Patchout by dropping entire rows/columns of patches.
        """
        batch_size, num_patches, hidden_size = embeddings.shape
        freq_out_dim, time_out_dim = self.get_shape(self.config)
        
        # Reshape to 2D grid: (batch_size, freq_out_dim, time_out_dim, hidden_size)
        embeddings = embeddings.view(batch_size, freq_out_dim, time_out_dim, hidden_size)

        # Randomly drop frequency columns or time rows
        if self.patchout_prob > 0:
            # Drop frequency columns
            freq_mask = torch.rand(freq_out_dim, device=embeddings.device) > self.patchout_prob
            embeddings = embeddings * freq_mask.view(1, freq_out_dim, 1, 1)

            # Drop time rows
            time_mask = torch.rand(time_out_dim, device=embeddings.device) > self.patchout_prob
            embeddings = embeddings * time_mask.view(1, 1, time_out_dim, 1)

        # Flatten back to (batch_size, num_patches, hidden_size)
        embeddings = embeddings.view(batch_size, num_patches, hidden_size)
        return embeddings

In [5]:
import os
NUM_PROC = (os.cpu_count() - 1) * 4
dataset = load_dataset("audiofolder", data_dir="./dataset", num_proc=NUM_PROC)

Resolving data files:   0%|          | 0/253914 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/31738 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/31742 [00:00<?, ?it/s]

In [6]:
print(type(dataset))
print(type(dataset['train']))

<class 'datasets.dataset_dict.DatasetDict'>
<class 'datasets.arrow_dataset.Dataset'>


In [7]:
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"

NUM_MEL_BINS = 128 # based on https://arxiv.org/pdf/2409.05924
MAX_SEQUENCE_LENGTH = 507

global feature_extractor

feature_extractor = ASTFeatureExtractorHamming.from_pretrained(
        pretrained_model, 
        num_mel_bins=NUM_MEL_BINS, 
        max_length=MAX_SEQUENCE_LENGTH
    )

In [8]:
model_input_name = feature_extractor.model_input_names[0]
SAMPLING_RATE = feature_extractor.sampling_rate
print("Custom AST Feature Extractor with Hamming window created successfully!")
print(f"Sampling rate: {feature_extractor.sampling_rate}")
print(f"Mel bins: {feature_extractor.num_mel_bins}")

num_labels = len(np.unique(dataset["train"]["label"]))
# num_labels = 2


Custom AST Feature Extractor with Hamming window created successfully!
Sampling rate: 16000
Mel bins: 128


In [9]:
print('dataset:', dataset)
print('dataset train features:', dataset['train'].features)
print('dataset test features:', dataset['test'].features)
print('dataset validation features:', dataset['validation'].features)
print('model_input_name:', model_input_name)
print('SAMPLING_RATE:', SAMPLING_RATE)
print('dataset["train"][0]:', dataset['train'][0])

# For when using IterableDataset
# for example in dataset['train']:
#     print('dataset["train"][0]:', example)
#     break

print('num_labels:', num_labels)
print('dataset columns:', dataset['train'].column_names)


dataset: DatasetDict({
    train: Dataset({
        features: ['audio', 'label'],
        num_rows: 253914
    })
    validation: Dataset({
        features: ['audio', 'label'],
        num_rows: 31738
    })
    test: Dataset({
        features: ['audio', 'label'],
        num_rows: 31742
    })
})
dataset train features: {'audio': Audio(sampling_rate=None, mono=True, decode=True, id=None), 'label': ClassLabel(names=['fake', 'real'], id=None)}
dataset test features: {'audio': Audio(sampling_rate=None, mono=True, decode=True, id=None), 'label': ClassLabel(names=['fake', 'real'], id=None)}
dataset validation features: {'audio': Audio(sampling_rate=None, mono=True, decode=True, id=None), 'label': ClassLabel(names=['fake', 'real'], id=None)}
model_input_name: input_values
SAMPLING_RATE: 16000
dataset["train"][0]: {'audio': {'path': 'D:\\ProgramFiles\\dev\\repos\\thesis-testing\\dataset\\train\\fake\\0.wav', 'array': array([ 8.55924794e-04,  5.84704467e-05,  7.75483320e-04, ...,
       -1.

In [10]:
# calculate values for normalization
feature_extractor.do_normalize = False

def preprocess_audio(batch):
    wavs = [audio["array"] for audio in batch["input_values"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt", return_attention_mask=True, max_length=507)
    
    # print("wavs:", wavs)
    # print("inputs:", inputs)
    
    return {
        model_input_name: inputs.get(model_input_name),
        "labels": torch.tensor(batch["label"])
    }

dataset = dataset.rename_column("audio", "input_values")
# this can't be if we're going to use an iterable dataset:
dataset["train"].set_transform(preprocess_audio, output_all_columns=False)

# dataset["train"] = dataset["train"].map(
#     preprocess_audio,
#     batched=True,
# )
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['input_values', 'label'],
        num_rows: 253914
    })
    validation: Dataset({
        features: ['input_values', 'label'],
        num_rows: 31738
    })
    test: Dataset({
        features: ['input_values', 'label'],
        num_rows: 31742
    })
})


In [11]:
import torch
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
import gc

def compute_dataset_statistics_optimized(dataset, model_input_name, batch_size=None, device=None):
    """
    Optimized computation of dataset mean and standard deviation.
    
    Args:
        dataset: The dataset to compute statistics for
        model_input_name: The key name for model input in batch
        batch_size: Batch size (auto-optimized if None)
        device: Device to use ('cuda', 'cpu', or None for auto-detection)
    
    Returns:
        tuple: (mean, std)
    """
    
    # 1. OPTIMIZATION: Auto-detect optimal batch size and device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    if batch_size is None:
        # Start with a larger batch size for statistics computation
        if device == 'cuda':
            # Try to use more GPU memory for faster processing
            batch_size = min(512, len(dataset) // 10)  # Use 10% of dataset or 512, whichever is smaller
        else:
            batch_size = min(256, len(dataset) // 20)  # Conservative for CPU
    
    print(f"Using device: {device}, batch_size: {batch_size}")
    
    # 2. OPTIMIZATION: Configure DataLoader for maximum performance
    dataloader_kwargs = {
        'batch_size': batch_size,
        'shuffle': False,  # No need to shuffle for statistics
        'drop_last': False,  # Process all data
        'pin_memory': device == 'cuda',  # Only pin memory if using GPU
        'num_workers': 0,  # Start with 0 to avoid multiprocessing issues
    }
    
    # Try to enable multiprocessing if it works
    try:
        test_loader = DataLoader(dataset, batch_size=2, num_workers=2)
        next(iter(test_loader))  # Test if multiprocessing works
        dataloader_kwargs['num_workers'] = min(4, os.cpu_count() - 1)  # Conservative worker count
        print(f"Multiprocessing enabled with {dataloader_kwargs['num_workers']} workers")
        del test_loader
    except:
        print("Multiprocessing failed, using single process")
    
    dataloader = DataLoader(dataset, **dataloader_kwargs)
    
    # 3. OPTIMIZATION: Use appropriate tensor dtypes and device placement
    if device == 'cuda':
        sum_dtype = torch.float64  # Higher precision for accumulation
        working_dtype = torch.float32  # Working precision
    else:
        sum_dtype = torch.float64
        working_dtype = torch.float32
    
    # Initialize accumulators on the target device with appropriate dtype
    total_sum = torch.tensor(0.0, dtype=sum_dtype, device=device)
    total_squared_sum = torch.tensor(0.0, dtype=sum_dtype, device=device)
    total_count = 0
    
    # 4. OPTIMIZATION: Process with memory-efficient operations
    progress_bar = tqdm(dataloader, desc="Computing normalization stats", unit="batch")
    
    with torch.no_grad():  # Disable gradient computation
        for batch_idx, batch in enumerate(progress_bar):
            # Get batch data
            batch_data = batch[model_input_name]
            
            # 5. OPTIMIZATION: Efficient tensor conversion and device transfer
            if not isinstance(batch_data, torch.Tensor):
                batch_data = torch.tensor(batch_data, dtype=working_dtype)
            else:
                batch_data = batch_data.to(dtype=working_dtype)
            
            # Move to target device if needed
            if batch_data.device != torch.device(device):
                batch_data = batch_data.to(device, non_blocking=True)
            
            # 6. OPTIMIZATION: Efficient flattening and computation
            # Use view instead of flatten when possible (more memory efficient)
            flat_batch = batch_data.view(-1)
            
            # Compute statistics using vectorized operations
            batch_sum = flat_batch.sum(dtype=sum_dtype)
            batch_squared_sum = flat_batch.pow(2).sum(dtype=sum_dtype)
            batch_count = flat_batch.numel()
            
            # Update accumulators
            total_sum += batch_sum
            total_squared_sum += batch_squared_sum
            total_count += batch_count
            
            # 7. OPTIMIZATION: Memory cleanup for large datasets
            if batch_idx % 100 == 0:  # Periodic cleanup
                if device == 'cuda':
                    torch.cuda.empty_cache()
                gc.collect()
            
            # Update progress bar with current estimates
            if batch_idx % 10 == 0:
                current_mean = (total_sum / total_count).item()
                progress_bar.set_postfix({
                    'current_mean': f'{current_mean:.4f}',
                    'processed': f'{total_count:,}'
                })
    
    # 8. OPTIMIZATION: Numerically stable standard deviation calculation
    # Use the corrected formula to avoid numerical instability
    mean = total_sum / total_count
    variance = (total_squared_sum / total_count) - (mean * mean)
    
    # Clamp variance to avoid negative values due to floating point errors
    variance = torch.clamp(variance, min=0.0)
    std = torch.sqrt(variance)
    
    # Convert back to Python scalars
    final_mean = mean.item()
    final_std = std.item()
    
    # Final cleanup
    if device == 'cuda':
        torch.cuda.empty_cache()
    
    return final_mean, final_std

def compute_dataset_statistics_chunked(dataset, model_input_name, chunk_size=1000):
    """
    Alternative approach: Process dataset in chunks without DataLoader for maximum memory efficiency.
    Use this if the DataLoader approach still has issues.
    """
    print(f"Processing dataset in chunks of {chunk_size}")
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sum_dtype = torch.float64
    
    total_sum = torch.tensor(0.0, dtype=sum_dtype, device=device)
    total_squared_sum = torch.tensor(0.0, dtype=sum_dtype, device=device)
    total_count = 0
    
    dataset_size = len(dataset)
    num_chunks = (dataset_size + chunk_size - 1) // chunk_size
    
    with torch.no_grad():
        for chunk_idx in tqdm(range(num_chunks), desc="Processing chunks"):
            start_idx = chunk_idx * chunk_size
            end_idx = min(start_idx + chunk_size, dataset_size)
            
            # Process items in current chunk
            chunk_tensors = []
            for idx in range(start_idx, end_idx):
                item = dataset[idx][model_input_name]
                if not isinstance(item, torch.Tensor):
                    item = torch.tensor(item, dtype=torch.float32)
                chunk_tensors.append(item.flatten())
            
            # Concatenate chunk data
            if chunk_tensors:
                chunk_data = torch.cat(chunk_tensors).to(device)
                
                # Compute statistics
                chunk_sum = chunk_data.sum(dtype=sum_dtype)
                chunk_squared_sum = chunk_data.pow(2).sum(dtype=sum_dtype)
                chunk_count = chunk_data.numel()
                
                total_sum += chunk_sum
                total_squared_sum += chunk_squared_sum
                total_count += chunk_count
                
                # Cleanup
                del chunk_data, chunk_tensors
                if device == 'cuda':
                    torch.cuda.empty_cache()
    
    # Calculate final statistics
    mean = total_sum / total_count
    variance = (total_squared_sum / total_count) - (mean * mean)
    variance = torch.clamp(variance, min=0.0)
    std = torch.sqrt(variance)
    
    return mean.item(), std.item()

In [12]:
# Method 1: Optimized DataLoader approach (recommended)
try:
    # mean, std = compute_dataset_statistics_optimized(
    #     dataset=dataset["train"],
    #     model_input_name=model_input_name,
    #     batch_size=256,  # Start with this, increase if you have more memory
    #     device='cuda' if torch.cuda.is_available() else 'cpu'
    # )
    # Testing purposes
    mean = 0.31605425901883943
    std = 0.45787811188377187
    
    print(f'Optimized - Mean: {mean}, Std: {std}')
    
    feature_extractor.mean = mean
    feature_extractor.std = std
except Exception as e:
    print(f"DataLoader approach failed: {e}")
    print("Trying chunked approach...")
    
    # Method 2: Fallback chunked approach
    mean, std = compute_dataset_statistics_chunked(
        dataset=dataset["train"],
        model_input_name=model_input_name,
        chunk_size=500
    )
    print(f'Chunked - Mean: {mean}, Std: {std}')

Optimized - Mean: 0.31605425901883943, Std: 0.45787811188377187


In [13]:
from audiomentations import Compose, AddGaussianNoise, BandPassFilter, Mp3Compression, RoomSimulator, Gain, ClippingDistortion
import math  # Needed for azimuth/elevation parameters

audio_augmentations = Compose([
    # 1. Noise injection (both papers)
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    
    # 2. Bandpass filtering (continuous-learning.pdf)
    BandPassFilter(min_center_freq=300, max_center_freq=3000, 
                   min_bandwidth_fraction=0.1, max_bandwidth_fraction=0.3, p=0.5),
    
    # 3. MP3 compression (hybrid-audio.pdf)
    Mp3Compression(min_bitrate=16, max_bitrate=32, p=0.5),
    
    # 4. Room simulation - CORRECTED (continuous-learning.pdf)
    RoomSimulator(
        min_size_x=5.0, max_size_x=15.0,
        min_size_y=5.0, max_size_y=15.0,
        min_size_z=2.4, max_size_z=4.0,
        min_absorption_value=0.1, max_absorption_value=0.8,
        min_mic_distance=1.0, max_mic_distance=5.0,  # Far-field simulation
        min_mic_azimuth=-math.pi, max_mic_azimuth=math.pi,
        min_mic_elevation=-math.pi/4, max_mic_elevation=math.pi/4,
        calculation_mode="absorption",
        p=0.5
    ),
    
    # 5. Gain variation (hybrid-audio.pdf)
    Gain(min_gain_db=-6, max_gain_db=6, p=0.3),
    
    # 6. Clipping distortion (both papers)
    ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=10, p=0.3),
], p=0.8, shuffle=True)

In [14]:
def preprocess_audio_with_transforms(batch):
    # we apply augmentations on each waveform
    wavs = [audio_augmentations(audio["array"], sample_rate=SAMPLING_RATE) for audio in batch["input_values"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt", return_attention_mask=True, max_length=507)
    return {
        model_input_name: inputs.get(model_input_name),
        "labels": torch.tensor(batch["label"])
    }

print('dataset train features:', dataset['train'].features)
print('dataset test features:', dataset['test'].features)
print('dataset validation features:', dataset['validation'].features)
# Cast the audio column to the appropriate feature type and rename it
dataset = dataset.cast_column("input_values", Audio(sampling_rate=feature_extractor.sampling_rate))



dataset train features: {'input_values': Audio(sampling_rate=None, mono=True, decode=True, id=None), 'label': ClassLabel(names=['fake', 'real'], id=None)}
dataset test features: {'input_values': Audio(sampling_rate=None, mono=True, decode=True, id=None), 'label': ClassLabel(names=['fake', 'real'], id=None)}
dataset validation features: {'input_values': Audio(sampling_rate=None, mono=True, decode=True, id=None), 'label': ClassLabel(names=['fake', 'real'], id=None)}


In [15]:
# with augmentations on the training set
dataset["train"].set_transform(preprocess_audio_with_transforms, output_all_columns=False)
# w/o augmentations on the test set
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)
dataset["validation"].set_transform(preprocess_audio, output_all_columns=False)


In [16]:
temp_ds = dataset['train'].with_format(None)  # This removes any applied transforms
temp_ds_test = dataset['test'].with_format(None)  # This removes any applied transforms

print(temp_ds[0])
unique_labels = sorted(set(temp_ds["label"] + temp_ds_test["label"]))
print(unique_labels)

{'input_values': {'path': 'D:\\ProgramFiles\\dev\\repos\\thesis-testing\\dataset\\train\\fake\\0.wav', 'array': array([ 8.55924794e-04,  5.84704467e-05,  7.75483320e-04, ...,
       -1.02691360e-04, -2.82649242e-04,  0.00000000e+00], shape=(29105,)), 'sampling_rate': 16000}, 'label': 0}
[0, 1]


In [17]:
from transformers import ASTConfig, ASTForAudioClassification

# Load configuration from the pretrained model
config = ASTConfig.from_pretrained(pretrained_model)

# Update configuration with the number of labels in our dataset
config.num_mel_bins = NUM_MEL_BINS  # Make sure this matches your feature extractor
config.max_length = MAX_SEQUENCE_LENGTH   # Or whatever your sequence length is
config.num_labels = num_labels
config.label2id = {"fake": 0, "real": 1}
config.id2label = {0: "fake", 1: "real"}

# setting dropout to prevent overfitting. based on https://www.mdpi.com/2073-431X/13/10/256
config.hidden_dropout_prob = 0.10

# having a patch size of 16, and time and frequency stride of 16, allows us to have NO 
# overlaps. this prevents difficulties if the file has real and fake audio 
# (based on https://arxiv.org/pdf/2409.05924)
config.patch_size = 16
config.frequency_stride = 16
config.time_stride = 16

config.patchout_prob = 0.5  # Probability of dropping a patch
config.patchout_strategy = "structured"  # "unstructured" or "structured"

# Initialize the model with the updated configuration
model = ASTForAudioClassification.from_pretrained(
    pretrained_model,
    config=config,
    attn_implementation="sdpa",
    ignore_mismatched_sizes=True
)

print(model.audio_spectrogram_transformer.embeddings)  # See what this submodule is

model.audio_spectrogram_transformer.embeddings.embeddings = ASTEmbeddingsWithPatchout(config)

model.init_weights()

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- audio_spectrogram_transformer.embeddings.position_embeddings: found shape torch.Size([1, 1214, 768]) in the checkpoint and torch.Size([1, 250, 768]) in the model instantiated
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ASTEmbeddings(
  (patch_embeddings): ASTPatchEmbeddings(
    (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.1, inplace=False)
)


In [18]:
print(config.patchout_prob)
print(config.patchout_strategy)
sample_input = torch.randn(1, 128, 507)  # Batch of spectrograms (batch_size, num_mel_bins, max_length)
output = model(sample_input)
print("Output shape:", output.logits.shape)

0.5
structured
Output shape: torch.Size([1, 2])


In [19]:
from transformers import TrainingArguments, EarlyStoppingCallback
# weight decay formula based on https://arxiv.org/abs/1711.05101:
    # weight decay = ynorm(sqr(batch_size/(dataset_size * num_train_epochs)))
    # where ynorm is between 0.025 and 0.05 https://towardsdatascience.com/weight-decay-and-its-peculiar-effects-66e0aee3e7b8/

DATASET_SIZE = 317_394
YNORM = 0.025
BATCH_SIZE = 16
NUM_TRAIN_EPOCHS = 20
WEIGHT_DECAY = YNORM * np.sqrt(BATCH_SIZE / (DATASET_SIZE * NUM_TRAIN_EPOCHS))

print("Weight decay: ", WEIGHT_DECAY)

# Configure training run with TrainingArguments class
training_args = TrainingArguments(
    optim="adamw_torch_fused", # based on https://huggingface.co/docs/transformers/v4.35.2/en/perf_train_gpu_one#optimizer-choice
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    output_dir="./runs/ast_classifier",
    logging_dir="./logs/ast_classifier",
    report_to="tensorboard",
    learning_rate=2e-5, # based on https://arxiv.org/pdf/2505.15136
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,  # Start with 8, can increase if memory allows
    per_device_eval_batch_size=BATCH_SIZE,
    dataloader_num_workers=0, # based on https://huggingface.co/docs/transformers/v4.35.2/en/perf_train_gpu_one#data-preloading
    dataloader_pin_memory=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="loss", # based on https://www.mdpi.com/2073-431X/13/10/256 and https://arxiv.org/pdf/2505.15136
    logging_steps=20,
    remove_unused_columns=False,
    fp16=True,
    fp16_full_eval=True,
    no_cuda=not torch.cuda.is_available(),
    lr_scheduler_type="cosine", # based on https://arxiv.org/pdf/2505.15136
    weight_decay=WEIGHT_DECAY, # based on https://arxiv.org/abs/1711.05101 and https://towardsdatascience.com/weight-decay-and-its-peculiar-effects-66e0aee3e7b8/
    torch_compile=True,
    torch_compile_backend="aot_eager",
    torch_empty_cache_steps=4,
)



The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.


Weight decay:  3.9690415546402594e-05


In [20]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")
recall = evaluate.load("recall")
precision = evaluate.load("precision")
f1 = evaluate.load("f1")

AVERAGE = "macro" if config.num_labels > 2 else "binary"

def compute_metrics(eval_pred):
    logits = eval_pred.predictions
    predictions = np.argmax(logits, axis=1)
    metrics = accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
    metrics.update(precision.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    metrics.update(recall.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    metrics.update(f1.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    return metrics

In [21]:
from transformers import Trainer, DataCollatorWithPadding
from tqdm.auto import tqdm

# Initialize the data collator
data_collator = DataCollatorWithPadding(
    tokenizer=feature_extractor,  # Your feature extractor acts as the tokenizer
    padding=True,
    max_length=config.max_length,
    return_tensors="pt"
)

def init_model():
    return ASTForAudioClassification.from_pretrained(
    pretrained_model,
    config=config,
    attn_implementation="sdpa",
    ignore_mismatched_sizes=True
)

trainer = Trainer(
    model_init=init_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
    data_collator=data_collator,
    # if no improvements happened to validation loss within 5 epochs, stop training.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.00)]
)

  trainer = Trainer(
Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- audio_spectrogram_transformer.embeddings.position_embeddings: found shape torch.Size([1, 1214, 768]) in the checkpoint and torch.Size([1, 250, 768]) in the model instantiated
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
# Add this before training to verify shapes
sample = next(iter(dataset["train"]))
print(f"Sample input shape: {sample['input_values'].shape}")

# Forward pass to check for errors
with torch.no_grad():
    output = model(sample["input_values"].unsqueeze(0))
print("Forward pass successful!")



Sample input shape: torch.Size([507, 128])
Forward pass successful!


In [23]:
print(f"Model config: {model.config}")
print(f"Sample input shape: {sample['input_values'].shape}")
print(f"Model's expected input shape: {model.config.max_length}")
print(f"Feature extractor config: {feature_extractor}")

Model config: ASTConfig {
  "architectures": [
    "ASTForAudioClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "frequency_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "fake",
    "1": "real"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "fake": 0,
    "real": 1
  },
  "layer_norm_eps": 1e-12,
  "max_length": 507,
  "model_type": "audio-spectrogram-transformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_mel_bins": 128,
  "patch_size": 16,
  "patchout_prob": 0.5,
  "patchout_strategy": "structured",
  "qkv_bias": true,
  "time_stride": 16,
  "torch_dtype": "float32",
  "transformers_version": "4.54.1"
}

Sample input shape: torch.Size([507, 128])
Model's expected input shape: 507
Feature extractor config: ASTFeatureExtractorHamming {
  "do_normalize": false,
  "feature_extractor_type": "ASTFeatureExtractorHamming",
  "feature_size": 1,
  "max_length

In [24]:
# Check GPU availability
print(f"Using device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
if torch.cuda.is_available():
     print(f"GPU Memory Total: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
     print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")
     print(f"GPU Memory Cached: {torch.cuda.memory_reserved(0)/1e9:.2f} GB")


Using device: NVIDIA GeForce RTX 2060
GPU Memory Total: 6.44 GB
GPU Memory Allocated: 0.00 GB
GPU Memory Cached: 0.00 GB


In [25]:
import torch
from torch.serialization import safe_globals
import numpy as np
from accelerate.utils import get_grad_scaler

# Get the correct scaler type based on your hardware
scaler = get_grad_scaler()

CHECKPOINT_NUM = 63480
# Load the scaler state dict
with safe_globals([np.core.multiarray.scalar, np.dtype, np.dtypes.Float64DType]):
    try:
        scaler_state = torch.load(fr"runs\ast_classifier\checkpoint-{CHECKPOINT_NUM}\scaler.pt", weights_only=True)
        print("Scaler state type:", type(scaler_state))
        print("Scaler state contents:", scaler_state)
        print("Trainer accelerator scaler: ", trainer.accelerator.scaler)
        if not trainer.accelerator.scaler:
            print("Trainer accelerator scaler is None. Setting up the scaler")
            scaler.load_state_dict(scaler_state)
            trainer.accelerator.scaler = scaler
    except Exception as e:

        print(f"Error loading scaler: {e}")

Error loading scaler: [Errno 2] No such file or directory: 'runs\\ast_classifier\\checkpoint-63480\\scaler.pt'


In [None]:
import optuna
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"  # For internal stack traces
os.environ["TORCH_LOGS"] = "+dynamo"     # Additional context

print(torch._dynamo.list_backends())

def compute_objective(metrics: dict[str, float]) -> list[float]:
    return metrics["eval_loss"], metrics["eval_accuracy"]

def hp_space(trial):
    return {
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16]),
        "per_device_eval_batch_size": trial.suggest_categorical("per_device_eval_batch_size", [16, 32]),
        "num_train_epochs": trial.suggest_categorical("num_train_epochs", [5, 10, 15, 20]),
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 5e-5),
        "weight_decay": trial.suggest_float("weight_decay", 0.0, 0.3),
        "warmup_steps": trial.suggest_categorical("warmup_steps", [0, 500, 1000]),
        "lr_scheduler_type": trial.suggest_categorical("lr_scheduler_type", ["linear", "cosine", "cosine_with_restarts"]), 
    }

best_trials = trainer.hyperparameter_search(
    direction=["minimize", "maximize"],
    backend="optuna",
    hp_space=hp_space,
    n_trials=20,
    compute_objective=compute_objective,
)

[I 2025-08-03 12:45:52,791] A new study created in memory with name: no-name-fa005678-f308-4232-9234-a4621dc7a5cc
Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- audio_spectrogram_transformer.embeddings.position_embeddings: found shape torch.Size([1, 1214, 768]) in the checkpoint and torch.Size([1, 250, 768]) in the model instantiated
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']


Epoch,Training Loss,Validation Loss




In [None]:
%tensorboard --logdir=~/ray_results/audio_classification

In [None]:
with torch.serialization.safe_globals([np.core.multiarray.scalar, np.dtype, np.dtypes.Float64DType]):
    trainer.train(resume_from_checkpoint=False)

In [None]:
if torch.cuda.is_available():
     print(f"GPU Memory Total: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
     print(f"GPU Memory Allocated: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")
     print(f"GPU Memory Cached: {torch.cuda.memory_reserved(0)/1e9:.2f} GB")
else:
    print("No GPU available")

In [None]:
##### MODEL EVALUATION AND PREDICTION

import os
import json
from transformers import ASTForAudioClassification

# Define the checkpoint path
checkpoint_path = "runs/ast_classifier/checkpoint-158700"

# Load the model and feature extractor
model = ASTForAudioClassification.from_pretrained(checkpoint_path)
feature_extractor = ASTFeatureExtractor.from_pretrained(checkpoint_path)

# Look for training history
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.json")
if os.path.exists(trainer_state_path):
    with open(trainer_state_path, "r") as f:
        trainer_state = json.load(f)
    print("\nTraining metrics from trainer_state.json:")
    print(json.dumps(trainer_state, indent=2))

# Set model to evaluation mode
model.eval()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")

import librosa
def predict_audio(file_path, model, feature_extractor, device="cuda"):
    # Load audio file
    audio, sr = librosa.load(file_path, sr=feature_extractor.sampling_rate)

    # Preprocess the audio
    inputs = feature_extractor(
        audio,
        sampling_rate=feature_extractor.sampling_rate,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    )

    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)

    print(f"Raw logits: {logits}")
    print(f"Raw probabilities: {probabilities}")

    # Get predicted class (0 for fake, 1 for real)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0][predicted_class].item()

    print(f"P(fake): {probabilities[0][0].item():.4f}")
    print(f"P(real): {probabilities[0][1].item():.4f}")

    # Map class index to label
    label = "fake" if predicted_class == 0 else "real"

    return {
        "label": label,
        "confidence": confidence,
        "probabilities": {
            "fake": probabilities[0][0].item(),
            "real": probabilities[0][1].item()
        }
    }

audio_file_path = "" # Adjust the path as needed
result = predict_audio(audio_file_path, model, feature_extractor, device)

print(f"Prediction: {result['label']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Probabilities - Fake: {result['probabilities']['fake']:.4f}, Real: {result['probabilities']['real']:.4f}")

In [None]:
#### AUDIO CUTTING SCRIPT

import os
import sys
import librosa
import soundfile as sf

def cut_audio(input_path, output_dir="./cut", clip_duration=5):
     # Ensure output directory exists
     os.makedirs(output_dir, exist_ok=True)

     # Get file name and extension
     base_name = os.path.basename(input_path)
     name, ext = os.path.splitext(base_name)

     # Load audio
     audio, sr = librosa.load(input_path, sr=None)
     total_duration = librosa.get_duration(y=audio, sr=sr)
     clip_samples = int(clip_duration * sr)
     num_clips = int(total_duration // clip_duration) + (1 if total_duration % clip_duration > 0 else 0)

     for i in range(num_clips):
          start_sample = i * clip_samples
          end_sample = min((i + 1) * clip_samples, len(audio))
          clip_audio = audio[start_sample:end_sample]
          out_path = os.path.join(output_dir, f"{name}({i+1}){ext}")
          sf.write(out_path, clip_audio, sr)
          print(f"Saved: {out_path}")

if __name__ == "__main__":
     if len(sys.argv) < 2:
          print("Usage: python cut_audio.py <path_to_audio_file>")
          sys.exit(1)
     audio_path = ""  # Adjust the path as needed
     if not os.path.exists(audio_path):
          print(f"File not found: {audio_path}")
          sys.exit(1)
     cut_audio(audio_path)

In [None]:
#### SEGMENTED AUDIO PREDICTION SCRIPT

import os
from collections import defaultdict

# Directory containing segmented audio clips
segmented_dir = "./cut"

# Collect all audio files and group by base name (without segment index)
audio_groups = defaultdict(list)
for fname in os.listdir(segmented_dir):
     if fname.lower().endswith(('.wav', '.mp3', '.flac', '.ogg')):
          # Extract base name (e.g., "audio(1).wav" -> "audio")
          base = fname.split('(')[0]
          audio_groups[base].append(os.path.join(segmented_dir, fname))

# Function to predict for a single audio file
def predict_single(file_path):
     return predict_audio(file_path, model, feature_extractor, device)

# Aggregate predictions for each group
results = {}
for base, files in audio_groups.items():
     fake_count = 0
     real_count = 0
     for fpath in files:
          pred = predict_single(fpath)
          if pred["label"] == "fake":
               fake_count += 1
          else:
               real_count += 1
     total = fake_count + real_count
     results[base] = {
          "fake_ratio": fake_count / total if total > 0 else 0,
          "real_ratio": real_count / total if total > 0 else 0,
          "fake_count": fake_count,
          "real_count": real_count,
          "total": total
     }

# Print summary
for base, stats in results.items():
     print(f"{base}: Fake {stats['fake_count']}/{stats['total']} ({stats['fake_ratio']:.2f}), "
            f"Real {stats['real_count']}/{stats['total']} ({stats['real_ratio']:.2f})")