In [None]:
import gc
import os
import re

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import soundfile as sf
import torch
import torch.nn as nn
import wandb
from audiomentations import AddGaussianNoise, Compose, PitchShift, TimeStretch
from pydub import AudioSegment
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification

# 1. Setup and Install Dependencies

In [None]:
!pip install librosa soundfile pydub transformers datasets matplotlib seaborn wandb
print(f"GPU Available: {torch.cuda.is_available()}")
print(f"GPU Count: {torch.cuda.device_count()}")
print(f"GPU Name: {torch.cuda.get_device_name(0)}")

## Initialize W&B

In [None]:
# Start the W&B run that every later wandb.log(...) call in this notebook reports to.
WANDB_PROJECT = "deepfake-audio-detection"
wandb.init(project=WANDB_PROJECT)

# 2. Define Paths

In [None]:
# Input datasets (read-only Kaggle mounts) and working output locations.
FOR_PATH = '/kaggle/input/fake-or-real-dataset'
DFADD_PATH = '/kaggle/input/dfadd-dataset'
OUTPUT_DIR = '/kaggle/working/preprocessed'
MEL_DIR = '/kaggle/working/mel_spectrograms'

# Create <dir>/real and <dir>/fake for both the preprocessed audio and the mel outputs.
for base_dir in (OUTPUT_DIR, MEL_DIR):
    for class_name in ('real', 'fake'):
        os.makedirs(os.path.join(base_dir, class_name), exist_ok=True)

# 3. EDA (Exploratory Data Analysis)

In [None]:
def count_samples(dataset_path, real_folder='real', fake_folder='fake'):
    """Return ``(n_real, n_fake)``, the number of entries in each class sub-folder.

    Entries are counted with ``os.listdir``, so every directory entry (including
    subdirectories or non-audio files) contributes to the totals.
    """
    return tuple(
        len(os.listdir(os.path.join(dataset_path, folder)))
        for folder in (real_folder, fake_folder)
    )

for_real, for_fake = count_samples(FOR_PATH)
dfadd_real, dfadd_fake = count_samples(DFADD_PATH)
print(f"FoR: Real={for_real}, Fake={for_fake}")
print(f"DFADD: Real={dfadd_real}, Fake={dfadd_fake}")

# Visualize class distribution: one bar per (dataset, class) pair.
df = pd.DataFrame(
    [
        ('FoR', 'Real', for_real),
        ('FoR', 'Fake', for_fake),
        ('DFADD', 'Real', dfadd_real),
        ('DFADD', 'Fake', dfadd_fake),
    ],
    columns=['Dataset', 'Class', 'Count'],
)
fig, ax = plt.subplots()
sns.barplot(x='Dataset', y='Count', hue='Class', data=df, ax=ax)
ax.set_title("Class Distribution in FoR and DFADD")
fig.savefig('/kaggle/working/class_distribution.png')
plt.close(fig)

### Analyze and visualize sample audio

In [None]:
def analyze_audio(file_path, save_path='/kaggle/working/sample_audio.png', n_mels=128):
    """Plot the waveform and mel spectrogram of one clip side by side and save the figure.

    Requires ``librosa.display``, which must be imported explicitly.

    Args:
        file_path: audio file to analyze; loaded at its native sample rate.
        save_path: where the PNG figure is written.
        n_mels: number of mel bands in the plotted spectrogram.

    Returns:
        ``(sr, duration)``: the native sample rate and the clip length in seconds.
    """
    audio, sr = librosa.load(file_path, sr=None)  # sr=None keeps the file's own sample rate
    duration = librosa.get_duration(y=audio, sr=sr)

    fig, (ax_wave, ax_mel) = plt.subplots(1, 2, figsize=(12, 4))
    librosa.display.waveshow(audio, sr=sr, ax=ax_wave)
    ax_wave.set_title("Waveform")

    S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=n_mels)
    S_dB = librosa.power_to_db(S, ref=np.max)
    mesh = librosa.display.specshow(S_dB, sr=sr, x_axis='time', y_axis='mel', ax=ax_mel)
    fig.colorbar(mesh, ax=ax_mel, format='%+2.0f dB')
    ax_mel.set_title("Mel-Spectrogram")

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)  # close explicitly so repeated calls don't accumulate open figures
    return sr, duration

for_real_dir = os.path.join(FOR_PATH, 'real')
# Take the first file in sorted order so the plotted example is the same on every run;
# os.listdir returns entries in arbitrary, filesystem-dependent order.
sample_file = os.path.join(for_real_dir, sorted(os.listdir(for_real_dir))[0])
sr, duration = analyze_audio(sample_file)
print(f"Sample Rate: {sr}, Duration: {duration}s")

# 4. Preprocess Audio

In [None]:
def preprocess_audio(input_path, output_path, target_sr=16000, target_duration=3.0):
    """Resample, downmix, and trim/pad one clip to a fixed length, then export it as WAV.

    Args:
        input_path: any file pydub/ffmpeg can decode.
        output_path: destination; always written in WAV format regardless of its extension.
        target_sr: output sample rate in Hz.
        target_duration: output length in seconds. Longer clips are truncated and
            shorter clips are padded with trailing silence.

    Returns:
        True on success, False if the clip could not be processed. The error is
        printed and processing of other files can continue.
    """
    try:
        audio = AudioSegment.from_file(input_path)
        audio = audio.set_frame_rate(target_sr).set_channels(1)
        target_length = int(target_duration * 1000)  # pydub lengths and slices are in milliseconds
        if len(audio) > target_length:
            audio = audio[:target_length]
        else:
            # Build the padding at target_sr. AudioSegment.silent defaults to 11025 Hz, and
            # pydub syncs both segments to the higher rate when concatenating.
            padding = AudioSegment.silent(duration=target_length - len(audio), frame_rate=target_sr)
            audio = audio + padding
        # export() returns the opened output file object; close it explicitly instead of
        # relying on garbage collection while thousands of clips are processed.
        audio.export(output_path, format='wav').close()
        return True
    except Exception as e:
        print(f"Error processing {input_path}: {e}")
        return False

# Both datasets are written into the same OUTPUT_DIR/<class> folders. Prefix each output with
# the dataset name so identically named clips from FoR and DFADD do not overwrite each other.
# Use a .wav extension because preprocess_audio always exports WAV data.
for dataset_path, dataset_name in [(FOR_PATH, 'FoR'), (DFADD_PATH, 'DFADD')]:
    for class_name in ['real', 'fake']:
        input_dir = os.path.join(dataset_path, class_name)
        output_dir = os.path.join(OUTPUT_DIR, class_name)
        for file in tqdm(sorted(os.listdir(input_dir)), desc=f"Processing {dataset_name} {class_name}"):
            out_name = f"{dataset_name}_{os.path.splitext(file)[0]}.wav"
            preprocess_audio(os.path.join(input_dir, file), os.path.join(output_dir, out_name))

# 5. Handle Class Imbalance (Oversample the Real Class via Augmentation)

In [None]:
# Random augmentation chain for oversampling. Each step is configured with p=0.5:
# additive Gaussian noise, a 0.8x-1.2x time stretch, and a pitch shift of up to ±4 semitones.
augment = Compose(transforms=[
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.2, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
])

def augment_audio(input_path, output_path, augmentations, sr=16000):
    """Load a clip, apply a random augmentation chain, and write the result.

    Args:
        input_path: audio file to read; resampled to ``sr`` on load.
        output_path: destination file, written with ``soundfile``.
        augmentations: callable invoked as ``augmentations(audio, sample_rate=sr)``
            (e.g. an ``audiomentations.Compose``).
        sr: sample rate for loading and writing. The default of 16000 matches
            ``preprocess_audio``'s ``target_sr``.
    """
    audio, sr = librosa.load(input_path, sr=sr)
    augmented_audio = augmentations(audio, sample_rate=sr)
    # Additive noise can push samples past full scale. Clip explicitly rather than leave
    # out-of-range floats to the float-to-integer conversion of the WAV encoder.
    augmented_audio = np.clip(augmented_audio, -1.0, 1.0)
    sf.write(output_path, augmented_audio, sr)

real_dir = os.path.join(OUTPUT_DIR, 'real')
# Sorted so the same sources are augmented (and named) on every run; os.listdir order is arbitrary.
real_files = sorted(os.listdir(real_dir))
target_count = len(os.listdir(os.path.join(OUTPUT_DIR, 'fake')))

# NOTE(review): this runs before the train/val/test split. Augmented copies of a clip must
# be kept out of validation/test at split time, or evaluation leaks training variants.
if len(real_files) < target_count: # Only oversample if needed
    print(f"Oversampling real class from {len(real_files)} to target count {target_count}")
    # Compute the deficit once instead of re-listing the directory after every file.
    num_augmentations_needed = target_count - len(real_files)
    produced = 0  # augmented files successfully written in this run
    aug_idx = 0   # running counter that keeps output names unique
    # Cycle over the source clips until the deficit is filled.
    while produced < num_augmentations_needed and real_files:
        attempted = succeeded = 0
        for file in real_files:
            if produced >= num_augmentations_needed:
                break
            input_path = os.path.join(real_dir, file)
            output_path = os.path.join(real_dir, f"aug_{aug_idx}_{file}") # Unique name for augmented file
            aug_idx += 1
            if os.path.exists(output_path):
                # Left over from an earlier run; it is already counted in real_files, so keep it.
                continue
            attempted += 1
            try:
                augment_audio(input_path, output_path, augment)
            except Exception as e:
                print(f"Error during augmentation of {input_path}: {e}")
                continue
            produced += 1
            succeeded += 1
        if attempted and not succeeded:
            # Every attempt in a full pass failed; stop instead of looping forever.
            print("Every augmentation attempt in a full pass failed; stopping oversampling early.")
            break
else:
    print("Real class count is already greater than or equal to fake class count. No oversampling needed.")
print(f"Final Real samples after oversampling: {len(os.listdir(os.path.join(OUTPUT_DIR, 'real')))}")

# 6. Feature Extraction (Mel-Spectrograms)

In [None]:
def audio_to_melspectrogram(input_path, output_path, sr=16000, n_mels=128, n_fft=2048, hop_length=512):
    """Compute a dB-scaled mel spectrogram for one audio file and save it with ``np.save``.

    The clip is loaded (resampled to ``sr``) and converted to an ``n_mels``-band mel
    power spectrogram with the given STFT settings. Power is then expressed in dB
    relative to the clip's own maximum (``ref=np.max``), so each clip's loudest bin
    maps to 0 dB.
    """
    y, rate = librosa.load(input_path, sr=sr)
    mel_power = librosa.feature.melspectrogram(
        y=y, sr=rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
    )
    np.save(output_path, librosa.power_to_db(mel_power, ref=np.max))

# Convert every preprocessed clip to <MEL_DIR>/<class>/<stem>.npy; anything that is not
# a .wav/.mp3/.flac file is reported and skipped.
for class_name in ['real', 'fake']:
    src_dir = os.path.join(OUTPUT_DIR, class_name)
    dst_dir = os.path.join(MEL_DIR, class_name)
    for file in tqdm(os.listdir(src_dir), desc=f"Converting {class_name} to Mel"):
        stem, suffix = os.path.splitext(file)
        if suffix.lower() not in ['.wav', '.mp3', '.flac']:
            print(f"Skipping non-audio file: {file}")
            continue
        audio_to_melspectrogram(os.path.join(src_dir, file), os.path.join(dst_dir, f'{stem}.npy'))

# 7. Dataset Splitting

In [None]:
def _list_npy(class_name):
    """Sorted .npy paths under MEL_DIR/<class_name>; sorting makes the split reproducible."""
    class_dir = os.path.join(MEL_DIR, class_name)
    return sorted(os.path.join(class_dir, f) for f in os.listdir(class_dir) if f.endswith('.npy'))


# The oversampling step names augmented clips aug_<i>_<source file> (possibly nested).
# Stripping every such prefix recovers the name of the original clip.
AUG_PREFIX = re.compile(r'^(?:aug_\d+_)+')


def _source_path(path):
    """Path of the original clip an augmented file was derived from (identity for originals)."""
    directory, name = os.path.split(path)
    return os.path.join(directory, AUG_PREFIX.sub('', name))


real_files = _list_npy('real')
fake_files = _list_npy('fake')

all_files = real_files + fake_files
labels = [0] * len(real_files) + [1] * len(fake_files)

# Split only the original clips. Augmented copies are assigned afterwards so that no
# variant of a validation/test clip appears in training (otherwise metrics are inflated).
is_aug = [AUG_PREFIX.match(os.path.basename(p)) is not None for p in all_files]
orig_files = [p for p, aug in zip(all_files, is_aug) if not aug]
orig_labels = [y for y, aug in zip(labels, is_aug) if not aug]

train_files, temp_files, train_labels, temp_labels = train_test_split(
    orig_files, orig_labels, test_size=0.3, stratify=orig_labels, random_state=42
)
val_files, test_files, val_labels, test_labels = train_test_split(
    temp_files, temp_labels, test_size=0.5, stratify=temp_labels, random_state=42
)

# Keep an augmented copy for training unless its source is held out. Copies of val/test
# clips are dropped so evaluation sees only original recordings.
held_out = set(val_files) | set(test_files)
train_files, train_labels = list(train_files), list(train_labels)
num_dropped = 0
for p, y, aug in zip(all_files, labels, is_aug):
    if not aug:
        continue
    if _source_path(p) in held_out:
        num_dropped += 1
    else:
        train_files.append(p)
        train_labels.append(y)
print(f"Augmented clips: {sum(is_aug) - num_dropped} added to train, {num_dropped} dropped (source in val/test)")
print(f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")

# 8. Custom Dataset

In [None]:
class AudioDataset(Dataset):
    """Map-style dataset over mel-spectrogram ``.npy`` files with integer class labels.

    Each item is ``(mel, label)``. ``mel`` is the loaded array, replicated to three
    channels when it has one (see ``__getitem__``) and then passed through
    ``transform`` if one is given. ``label`` is a ``torch.long`` scalar tensor.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        spec = np.load(self.file_paths[idx])
        # ViT takes 3-channel images, so a 2-D (n_mels, time) array or a (1, n_mels, time)
        # array is replicated to (3, n_mels, time). Any other shape passes through unchanged.
        if spec.ndim == 2:
            spec = np.repeat(spec[np.newaxis], 3, axis=0)
        elif spec.ndim == 3 and spec.shape[0] == 1:
            spec = np.repeat(spec, 3, axis=0)

        if self.transform:
            spec = self.transform(spec)
        return spec, torch.tensor(self.labels[idx], dtype=torch.long)

def _mel_to_tensor(mel):
    """Convert a channel-first ``(C, n_mels, time)`` dB array to a float32 tensor in about [-1, 1].

    ``transforms.ToTensor`` must not be used here. It assumes numpy input is
    ``(H, W, C)`` and transposes it, which turns a ``(3, n_mels, time)`` array into
    ``(time, 3, n_mels)``. It also never rescales float arrays.
    """
    tensor = torch.from_numpy(np.ascontiguousarray(mel, dtype=np.float32))
    # The mels are saved with power_to_db(ref=np.max) and librosa's default top_db=80, so values
    # lie in [-80, 0] dB. Map that to [-1, 1], the range produced by mean=std=0.5 image
    # normalization (presumably what the pretrained ViT was trained with; verify against
    # the checkpoint's preprocessor config).
    return (tensor + 40.0) / 40.0


transform = transforms.Compose([
    transforms.Lambda(_mel_to_tensor),  # module-level function so DataLoader workers can pickle it
    transforms.Resize((224, 224), antialias=True) # ViT base expects 224x224 input
])

train_dataset = AudioDataset(train_files, train_labels, transform)
val_dataset = AudioDataset(val_files, val_labels, transform)
test_dataset = AudioDataset(test_files, test_labels, transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# 9. Define Models

In [None]:
class SimpleResNet(nn.Module):
    """Small convolutional classifier for 3-channel spectrogram images.

    Despite the name, there are no residual (skip) connections. The network is a
    plain stack:

    - conv → BN → ReLU (64 channels)
    - conv → BN → ReLU (64 channels)
    - conv → BN → ReLU → 2x2 max-pool (128 channels)
    - global average pooling and a linear head

    The adaptive pooling means any spatial input size is accepted.

    Args:
        num_classes: number of output logits (default 2).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # 3 input channels, matching the 3-channel replicated spectrograms built for ViT.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Global average pooling collapses each of the 128 feature maps to one value,
        # so the head size is independent of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.layer2(self.layer1(features))
        pooled = self.avgpool(features).flatten(1)
        return self.fc(pooled)

# Initialize models
# Build both models on the GPU and wrap them in nn.DataParallel to split each batch across
# visible GPUs. Requires CUDA.
# NOTE(review): cnn_model is allocated even though its entry in `models` below is commented out.
cnn_model = nn.DataParallel(SimpleResNet().cuda())

# Load the pretrained ViT-B/16 (224px) checkpoint with a new 2-way classification head.
# ignore_mismatched_sizes=True lets the checkpoint's original head be replaced.
# ViTForImageClassification does no input preprocessing of its own: inputs must already be
# 3-channel 224x224 tensors scaled the way the checkpoint expects.
vit_model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224', num_labels=2, ignore_mismatched_sizes=True
)
vit_model = nn.DataParallel(vit_model.cuda())

# 10. Training and Evaluation Functions

In [None]:
# Label 1 (fake) gets twice the loss weight of label 0 (real).
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]).cuda())
scaler = GradScaler()  # shared loss scaler for mixed-precision training


def train_epoch(model, loader, criterion, optimizer):
    """Run one epoch of mixed-precision training and return the mean batch loss.

    Args:
        model: module whose forward returns either a logits tensor or an object with a
            ``.logits`` field (Hugging Face models), possibly wrapped in ``nn.DataParallel``.
        loader: iterable of ``(inputs, labels)`` batches; both are moved to CUDA.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer stepped through the module-level ``scaler``.

    Returns:
        Mean of the per-batch losses, or 0.0 if ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in tqdm(loader, desc="Training"):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            # Inspect the output, not the model: HF models expose `.logits` on their output
            # object, while plain modules return the logits tensor directly.
            logits = getattr(outputs, 'logits', outputs)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / num_batches if num_batches else 0.0

def evaluate(model, loader):
    """Evaluate a binary classifier and return ``(precision, recall, f1, auc)``.

    Precision, recall and F1 are for the positive class (label 1 = fake) and are
    computed from argmax predictions. ROC AUC needs continuous scores, so it is computed
    from the softmax probability of class 1.

    Returns ``(0.0, 0.0, 0.0, 0.0)`` for an empty loader. AUC is NaN when the labels
    contain a single class, because ROC AUC is undefined there.
    """
    model.eval()
    preds, scores, true_labels = [], [], []
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Evaluating"):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            # HF models return an output object with `.logits`; plain modules return the tensor.
            logits = getattr(outputs, 'logits', outputs)
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            scores.extend(torch.softmax(logits.float(), dim=1)[:, 1].cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    if len(true_labels) == 0:
        return 0.0, 0.0, 0.0, 0.0

    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, preds, average='binary', zero_division=0)
    if len(set(true_labels)) < 2:
        auc = float('nan')  # roc_auc_score raises ValueError when only one class is present
    else:
        auc = roc_auc_score(true_labels, scores)
    return precision, recall, f1, auc

# 11. Train Models

In [None]:
# name -> (model, optimizer). Uncomment the CNN entry to train it as well.
models = {
    # 'CNN': (cnn_model, torch.optim.Adam(cnn_model.parameters(), lr=1e-4)),
    'ViT': (vit_model, torch.optim.Adam(vit_model.parameters(), lr=1e-4))
}

num_epochs = 10 # Define number of epochs

for model_name, (model, optimizer) in models.items():
    print(f"Training {model_name}...")
    checkpoint_path = f'/kaggle/working/best_{model_name}.pth'
    # Start below any achievable F1 so the first epoch always writes a checkpoint. With 0,
    # a run whose validation F1 never exceeds 0 would save nothing and torch.load would fail.
    best_f1 = float('-inf')
    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, criterion, optimizer)
        precision, recall, f1, auc = evaluate(model, val_loader)
        wandb.log({
            f"{model_name}_epoch": epoch + 1,
            f"{model_name}_train_loss": train_loss,
            f"{model_name}_val_precision": precision,
            f"{model_name}_val_recall": recall,
            f"{model_name}_val_f1": f1,
            f"{model_name}_val_auc": auc
        })
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved best {model_name} model with F1: {f1:.4f}")
        print(f"{model_name} Epoch {epoch+1}/{num_epochs}: Loss={train_loss:.4f}, Val F1={f1:.4f}")

    # Evaluate on test set, using the best checkpoint rather than the last epoch's weights.
    print(f"\nEvaluating {model_name} on Test Set...")
    model.load_state_dict(torch.load(checkpoint_path))
    precision, recall, f1, auc = evaluate(model, test_loader)
    print(f"{model_name} Test: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, AUC={auc:.4f}")
    wandb.log({
        f"{model_name}_test_precision": precision,
        f"{model_name}_test_recall": recall,
        f"{model_name}_test_f1": f1,
        f"{model_name}_test_auc": auc
    })

    # Confusion Matrix: re-run test inference to collect hard predictions.
    # Batch variables get their own names so the global `labels` list is not overwritten.
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs, batch_labels = batch_inputs.cuda(), batch_labels.cuda()
            outputs = model(batch_inputs)
            # HF models return an output object with `.logits`; plain modules return the tensor.
            logits = getattr(outputs, 'logits', outputs)
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            true_labels.extend(batch_labels.cpu().numpy())

    if len(true_labels) > 0: # Only plot if there's data
        # labels=[0, 1] keeps the matrix 2x2 (matching the tick labels) even if a class is absent.
        cm = confusion_matrix(true_labels, preds, labels=[0, 1])
        plt.figure(figsize=(6, 5))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'])
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'{model_name} Confusion Matrix')
        plt.savefig(f'/kaggle/working/{model_name}_confusion_matrix.png')
        plt.close()
        wandb.log({f"{model_name}_confusion_matrix": wandb.Image(f'/kaggle/working/{model_name}_confusion_matrix.png')})
    else:
        print(f"No data to generate confusion matrix for {model_name} test set.")

# 12. Cleanup

In [None]:
# Collect unreachable Python objects first so any CUDA tensors they held are freed.
# Then release PyTorch's now-unused cached GPU blocks. In the reverse order,
# empty_cache() cannot release memory that only gc would free.
gc.collect()
torch.cuda.empty_cache()

wandb.finish()  # close the W&B run and flush any pending logs