<a href="https://colab.research.google.com/github/grzegorzj/drone-detection/blob/main/drone.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/saraalemadi/DroneAudioDataset.git

Cloning into 'DroneAudioDataset'...
remote: Enumerating objects: 10649, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Total 10649 (delta 5), reused 5 (delta 5), pack-reused 10643 (from 1)[K
Receiving objects: 100% (10649/10649), 274.31 MiB | 15.28 MiB/s, done.
Resolving deltas: 100% (181/181), done.
Updating files: 100% (23409/23409), done.


# Attempt number one
Unconstrained, simple classifier. At the moment, not paying too much attention to size of the model or inference time.

In [None]:
pip install torch torchaudio numpy scikit-learn

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import os
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split

# Constants
SAMPLE_RATE = 44100  # Standard audio sample rate
MAX_DURATION = 5  # Maximum duration in seconds to consider
N_MELS = 64  # Number of mel bands
BATCH_SIZE = 32
NUM_EPOCHS = 30
LEARNING_RATE = 0.001

class AudioDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_mels=N_MELS,
            normalized=True
        )

        # Get all file paths and labels
        self.files = []
        self.labels = []

        # Positive examples (yes_drone)
        pos_dir = self.root_dir / 'yes_drone'
        for file in pos_dir.glob('*.wav'):
            self.files.append(file)
            self.labels.append(1)

        # Negative examples (unknown)
        neg_dir = self.root_dir / 'unknown'
        for file in neg_dir.glob('*.wav'):
            self.files.append(file)
            self.labels.append(0)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        label = self.labels[idx]

        # Load audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample if necessary
        if sample_rate != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)
            waveform = resampler(waveform)

        # Pad or trim to MAX_DURATION
        target_length = MAX_DURATION * SAMPLE_RATE
        if waveform.shape[1] < target_length:
            waveform = torch.nn.functional.pad(waveform, (0, target_length - waveform.shape[1]))
        else:
            waveform = waveform[:, :target_length]

        # Convert to mel spectrogram
        mel_spec = self.mel_spectrogram(waveform)

        # Log scale mel spectrogram
        mel_spec = torch.log(mel_spec + 1e-9)

        if self.transform:
            mel_spec = self.transform(mel_spec)

        return mel_spec, label

class DroneClassifier(nn.Module):
    def __init__(self):
        super(DroneClassifier, self).__init__()

        # Lightweight CNN architecture
        self.features = nn.Sequential(
            # First conv block
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Second conv block
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Third conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Dropout for regularization
            nn.Dropout(0.3)
        )

        # Adaptive pooling to handle variable input sizes
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Fully connected layers
        self.classifier = nn.Sequential(
            nn.Linear(64 * 4 * 4, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device).float().view(-1, 1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss = train_loss / len(train_loader)
        train_acc = 100 * train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device).float().view(-1, 1)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print('-' * 60)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_drone_classifier.pth')



In [None]:
def run_training():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create dataset
    dataset = AudioDataset('DroneAudioDataset/Binary_Drone_Audio')

    # Split dataset
    train_idx, val_idx = train_test_split(
        range(len(dataset)),
        test_size=0.2,
        stratify=dataset.labels,
        random_state=42
    )

    # Calculate class weights for imbalanced dataset
    labels = np.array(dataset.labels)
    class_weights = torch.tensor([
        1.0 / (np.sum(labels == 0) / len(labels)),
        1.0 / (np.sum(labels == 1) / len(labels))
    ], device=device)

    # Data loaders
    train_loader = DataLoader(
        torch.utils.data.Subset(dataset, train_idx),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4
    )

    val_loader = DataLoader(
        torch.utils.data.Subset(dataset, val_idx),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4
    )

    # Training
    model = DroneClassifier().to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, device)

In [None]:
run_training()

Using device: cuda




Epoch [1/30]
Train Loss: 0.7009, Train Acc: 84.59%
Val Loss: 0.4721, Val Acc: 96.33%
------------------------------------------------------------
Epoch [2/30]
Train Loss: 0.3569, Train Acc: 94.48%
Val Loss: 0.5458, Val Acc: 96.75%
------------------------------------------------------------
Epoch [3/30]
Train Loss: 0.2606, Train Acc: 96.49%
Val Loss: 0.1120, Val Acc: 98.80%
------------------------------------------------------------
Epoch [4/30]
Train Loss: 0.1686, Train Acc: 98.03%
Val Loss: 0.4785, Val Acc: 77.53%
------------------------------------------------------------
Epoch [5/30]
Train Loss: 0.1443, Train Acc: 98.62%
Val Loss: 0.1246, Val Acc: 97.35%
------------------------------------------------------------
Epoch [6/30]
Train Loss: 0.1049, Train Acc: 98.90%
Val Loss: 0.0356, Val Acc: 99.27%
------------------------------------------------------------
Epoch [7/30]
Train Loss: 0.0787, Train Acc: 99.21%
Val Loss: 0.2158, Val Acc: 96.24%
---------------------------------------

In [None]:
def print_model_params(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total trainable parameters: {total_params:,}')

    # Print size of each layer
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f'{name}: {param.numel():,} parameters')
            print(f'Shape: {list(param.shape)}')

In [None]:
model = DroneClassifier()
print_model_params(model)

Total trainable parameters: 89,185
features.0.weight: 144 parameters
Shape: [16, 1, 3, 3]
features.0.bias: 16 parameters
Shape: [16]
features.1.weight: 16 parameters
Shape: [16]
features.1.bias: 16 parameters
Shape: [16]
features.4.weight: 4,608 parameters
Shape: [32, 16, 3, 3]
features.4.bias: 32 parameters
Shape: [32]
features.5.weight: 32 parameters
Shape: [32]
features.5.bias: 32 parameters
Shape: [32]
features.8.weight: 18,432 parameters
Shape: [64, 32, 3, 3]
features.8.bias: 64 parameters
Shape: [64]
features.9.weight: 64 parameters
Shape: [64]
features.9.bias: 64 parameters
Shape: [64]
classifier.0.weight: 65,536 parameters
Shape: [64, 1024]
classifier.0.bias: 64 parameters
Shape: [64]
classifier.3.weight: 64 parameters
Shape: [1, 64]
classifier.3.bias: 1 parameters
Shape: [1]


# Detection under the noise level
The first attempt used a straightforward dataset with relatively easy to distinguish samples. Real-life scenario may be harder than that, especially given refractions, noise from other military vehicles, etc. Therefore, we augument the sounds with noise.

In [None]:
import os
import zipfile
import shutil
import pandas as pd
from pathlib import Path
import urllib.request
from tqdm import tqdm

class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)

def download_esc50(target_dir='/content/datasets'):
    """
    Downloads and prepares the ESC-50 dataset
    Returns the path to the noise samples directory
    """
    target_dir = Path(target_dir)
    target_dir.mkdir(exist_ok=True)

    url = "https://github.com/karoldvl/ESC-50/archive/master.zip"
    zip_path = target_dir / "ESC-50-master.zip"

    if not zip_path.exists():
        print("Downloading ESC-50 dataset...")
        with DownloadProgressBar(unit='B', unit_scale=True,
                               miniters=1, desc="ESC-50") as t:
            urllib.request.urlretrieve(url, filename=zip_path,
                                     reporthook=t.update_to)

    extract_dir = target_dir / "ESC-50-master"
    if not extract_dir.exists():
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(target_dir)

    noise_dir = target_dir / "noise_samples"
    noise_dir.mkdir(exist_ok=True)

    meta_file = extract_dir / "meta" / "esc50.csv"
    df = pd.read_csv(meta_file)

    noise_categories = [
        'rain', 'sea_waves', 'crackling_fire', 'crickets',  # nature
        'engine', 'train', 'airplane',  # transportation
        'wind', 'thunderstorm',  # weather
        'crowd', 'footsteps',  # human
        'helicopter', 'chainsaw', 'siren'  # mechanical/urban
    ]

    print("Preparing noise samples...")
    audio_dir = extract_dir / "audio"

    for category in noise_categories:
        category_files = df[df['category'] == category]['filename']
        for filename in category_files:
            src = audio_dir / filename
            dst = noise_dir / filename
            if src.exists() and not dst.exists():
                shutil.copy2(src, dst)

    print(f"\nNoise dataset prepared at: {noise_dir}")
    print(f"Total noise samples: {len(list(noise_dir.glob('*.wav')))}")

    noise_files = df[df['category'].isin(noise_categories)]
    category_stats = noise_files['category'].value_counts()
    print("\nSamples per category:")
    for category, count in category_stats.items():
        print(f"{category}: {count}")

    return str(noise_dir)

if __name__ == "__main__":
    # Download and prepare the dataset
    noise_dir = download_esc50()
    print(f"\nDataset preparation completed. Use this path in your training script:")
    print(f"noise_dir = '{noise_dir}'")

Downloading ESC-50 dataset...


ESC-50: 646MB [00:45, 14.3MB/s]


Extracting dataset...
Preparing noise samples...

Noise dataset prepared at: /content/datasets/noise_samples
Total noise samples: 520

Samples per category:
thunderstorm: 40
chainsaw: 40
airplane: 40
train: 40
wind: 40
footsteps: 40
crackling_fire: 40
helicopter: 40
rain: 40
engine: 40
sea_waves: 40
siren: 40
crickets: 40

Dataset preparation completed. Use this path in your training script:
noise_dir = '/content/datasets/noise_samples'


In [None]:
import torch
import torchaudio
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm
import random
import json
from datetime import datetime
import shutil

# Constants
SAMPLE_RATE = 44100
MAX_DURATION = 5
N_MELS = 64

class CachedAudioDataset(torch.utils.data.Dataset):
    def __init__(self, cache_dir, transform=None):
        """
        Dataset that loads pre-augmented and cached spectrograms
        """
        self.cache_dir = Path(cache_dir)
        self.transform = transform

        # Load metadata
        with open(self.cache_dir / 'metadata.json', 'r') as f:
            self.metadata = json.load(f)

        self.files = list(self.metadata.keys())

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        # Load cached spectrogram
        spec_path = self.cache_dir / f"{filename}.pt"
        mel_spec = torch.load(spec_path)
        label = self.metadata[filename]['label']

        if self.transform:
            mel_spec = self.transform(mel_spec)

        return mel_spec, label

def prepare_augmented_dataset(
    root_dir,
    noise_dir,
    cache_dir,
    num_augmentations=3,
    min_snr_db=5,
    max_snr_db=20,
    force_rebuild=False
):
    """
    Prepare and cache augmented dataset

    Args:
        root_dir: Directory containing original audio files
        noise_dir: Directory containing noise samples
        cache_dir: Directory to store augmented spectrograms
        num_augmentations: Number of augmented versions per original file
        min_snr_db: Minimum signal-to-noise ratio
        max_snr_db: Maximum signal-to-noise ratio
        force_rebuild: If True, rebuild cache even if it exists
    """
    cache_dir = Path(cache_dir)

    # Check if cache exists and is complete
    if not force_rebuild and cache_dir.exists():
        if (cache_dir / 'metadata.json').exists():
            print("Found existing cached dataset")
            return cache_dir

    # Create cache directory
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Initialize mel spectrogram transform
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        normalized=True
    )

    # Load noise files
    print("Loading noise files...")
    noise_files = list(Path(noise_dir).glob('*.wav'))

    metadata = {}

    # Process original files
    print("Processing original files and creating augmentations...")

    # Process positive examples (yes_drone)
    pos_dir = Path(root_dir) / 'yes_drone'
    neg_dir = Path(root_dir) / 'unknown'

    all_files = list(pos_dir.glob('*.wav')) + list(neg_dir.glob('*.wav'))
    total_files = len(all_files) * (num_augmentations + 1)  # +1 for original

    with tqdm(total=total_files, desc="Processing audio files") as pbar:
        for audio_path in all_files:
            is_positive = audio_path.parent.name == 'yes_drone'

            # Load and process original file
            waveform, sample_rate = torchaudio.load(audio_path)

            # Process original version
            processed_waveform = process_waveform(waveform, sample_rate)
            mel_spec = create_spectrogram(processed_waveform, mel_spectrogram)

            # Save original
            filename = f"orig_{audio_path.stem}"
            torch.save(mel_spec, cache_dir / f"{filename}.pt")
            metadata[filename] = {
                'source': str(audio_path),
                'label': 1 if is_positive else 0,
                'augmented': False
            }
            pbar.update(1)

            # Create augmented versions
            for aug_idx in range(num_augmentations):
                # Select random noise file
                noise_path = random.choice(noise_files)
                noise_waveform, noise_sr = torchaudio.load(noise_path)

                # Process noise
                processed_noise = process_waveform(noise_waveform, noise_sr)

                # Apply noise
                snr_db = random.uniform(min_snr_db, max_snr_db)
                noisy_waveform = add_noise(processed_waveform, processed_noise, snr_db)

                # Create and save spectrogram
                mel_spec = create_spectrogram(noisy_waveform, mel_spectrogram)
                filename = f"aug_{audio_path.stem}_{aug_idx}"
                torch.save(mel_spec, cache_dir / f"{filename}.pt")

                metadata[filename] = {
                    'source': str(audio_path),
                    'label': 1 if is_positive else 0,
                    'augmented': True,
                    'noise_source': str(noise_path),
                    'snr_db': snr_db
                }
                pbar.update(1)

    # Save metadata
    with open(cache_dir / 'metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nDataset cached at: {cache_dir}")
    print(f"Total files: {len(metadata)}")
    return cache_dir

def process_waveform(waveform, sample_rate):
    """Process waveform to standard format"""
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if necessary
    if sample_rate != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)
        waveform = resampler(waveform)

    # Pad or trim to MAX_DURATION
    target_length = MAX_DURATION * SAMPLE_RATE
    if waveform.shape[1] < target_length:
        waveform = torch.nn.functional.pad(waveform, (0, target_length - waveform.shape[1]))
    else:
        waveform = waveform[:, :target_length]

    return waveform

def create_spectrogram(waveform, mel_spectrogram):
    """Create mel spectrogram from waveform"""
    mel_spec = mel_spectrogram(waveform)
    mel_spec = torch.log(mel_spec + 1e-9)
    return mel_spec

def add_noise(signal, noise, target_snr_db):
    """Add noise to signal at specified SNR"""
    signal_power = torch.mean(signal ** 2)
    noise_power = torch.mean(noise ** 2)

    snr = 10 ** (target_snr_db / 10)
    scale = torch.sqrt(signal_power / (noise_power * snr))

    noisy_signal = signal + scale * noise

    # Normalize to prevent clipping
    max_val = torch.max(torch.abs(noisy_signal))
    if max_val > 1:
        noisy_signal = noisy_signal / max_val

    return noisy_signal

if __name__ == "__main__":
    cache_dir = prepare_augmented_dataset(
        root_dir='DroneAudioDataset/Binary_Drone_Audio',
        noise_dir='/content/datasets/noise_samples',
        cache_dir='augmented_data',
        num_augmentations=3,  # Number of augmented versions per original file
        force_rebuild=False
    )

    # Create dataset from cache
    dataset = CachedAudioDataset(cache_dir)
    print(f"Dataset size: {len(dataset)}")



Loading noise files...
Processing original files and creating augmentations...


Processing audio files: 0it [00:00, ?it/s]


Dataset cached at: augmented_data
Total files: 0
Dataset size: 0





In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import os
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import random
import matplotlib.pyplot as plt
from datetime import datetime

# Constants
SAMPLE_RATE = 44100
MAX_DURATION = 5
N_MELS = 64
BATCH_SIZE = 32
NUM_EPOCHS = 30
LEARNING_RATE = 0.001
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class NoiseAugment:
    def __init__(self, noise_dir, min_snr_db=5, max_snr_db=20):
        self.noise_dir = Path(noise_dir)
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db
        self.noise_files = list(self.noise_dir.glob('*.wav'))

        if not self.noise_files:
            raise ValueError(f"No .wav files found in {noise_dir}")

    def load_random_noise(self, target_length):
        noise_path = random.choice(self.noise_files)
        noise, sample_rate = torchaudio.load(noise_path)

        if noise.shape[0] > 1:
            noise = torch.mean(noise, dim=0, keepdim=True)

        if sample_rate != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)
            noise = resampler(noise)

        if noise.shape[1] < target_length:
            num_repeats = (target_length + noise.shape[1] - 1) // noise.shape[1]
            noise = noise.repeat(1, num_repeats)

        start = random.randint(0, noise.shape[1] - target_length)
        noise = noise[:, start:start + target_length]

        return noise

    def apply_noise(self, waveform):
        noise = self.load_random_noise(waveform.shape[1])

        signal_power = torch.mean(waveform ** 2)
        noise_power = torch.mean(noise ** 2)

        target_snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
        snr = 10 ** (target_snr_db / 10)
        scale = torch.sqrt(signal_power / (noise_power * snr))

        noisy_waveform = waveform + scale * noise

        max_val = torch.max(torch.abs(noisy_waveform))
        if max_val > 1:
            noisy_waveform = noisy_waveform / max_val

        return noisy_waveform

class AugmentedAudioDataset(Dataset):
    def __init__(self, root_dir, noise_dir, transform=None, aug_probability=0.5):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.aug_probability = aug_probability
        self.noise_augmenter = NoiseAugment(noise_dir)

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_mels=N_MELS,
            normalized=True
        )

        # Get all file paths and labels
        self.files = []
        self.labels = []

        # Positive examples (yes_drone)
        pos_dir = self.root_dir / 'yes_drone'
        for file in pos_dir.glob('*.wav'):
            self.files.append(file)
            self.labels.append(1)

        # Negative examples (unknown)
        neg_dir = self.root_dir / 'unknown'
        for file in neg_dir.glob('*.wav'):
            self.files.append(file)
            self.labels.append(0)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        label = self.labels[idx]

        waveform, sample_rate = torchaudio.load(audio_path)

        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)
            waveform = resampler(waveform)

        target_length = MAX_DURATION * SAMPLE_RATE
        if waveform.shape[1] < target_length:
            waveform = torch.nn.functional.pad(waveform, (0, target_length - waveform.shape[1]))
        else:
            waveform = waveform[:, :target_length]

        if random.random() < self.aug_probability:
            waveform = self.noise_augmenter.apply_noise(waveform)

        mel_spec = self.mel_spectrogram(waveform)
        mel_spec = torch.log(mel_spec + 1e-9)

        if self.transform:
            mel_spec = self.transform(mel_spec)

        return mel_spec, label

class DroneClassifier(nn.Module):
    def __init__(self):
        super(DroneClassifier, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Dropout(0.3)
        )

        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        self.classifier = nn.Sequential(
            nn.Linear(64 * 4 * 4, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    best_val_loss = float('inf')

    # Initialize lists to store metrics
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device).float().view(-1, 1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_loss = train_loss / len(train_loader)
        train_acc = 100 * train_correct / train_total

        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device).float().view(-1, 1)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total

        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print('-' * 60)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
            }, f'best_drone_classifier_{timestamp}.pth')

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(f'training_history_{timestamp}.png')
    plt.close()

def main():
    # Set random seeds for reproducibility
    torch.manual_seed(234)
    random.seed(542)
    np.random.seed(7463)

    # Create dataset
    dataset = AugmentedAudioDataset(
        root_dir='DroneAudioDataset/Binary_Drone_Audio',
        noise_dir='/content/datasets/noise_samples',
        aug_probability=0.5
    )

    # Split dataset
    train_indices, val_indices = train_test_split(
        range(len(dataset)),
        test_size=0.2,
        random_state=52,
        stratify=dataset.labels
    )

    # Create data loaders
    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=torch.utils.data.SubsetRandomSampler(train_indices)
    )

    val_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=torch.utils.data.SubsetRandomSampler(val_indices)
    )

    # Initialize model, criterion, and optimizer
    model = DroneClassifier().to(DEVICE)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train model
    train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, DEVICE)

if __name__ == "__main__":
    main()



Epoch [1/30]
Train Loss: 0.2475, Train Acc: 88.61%
Val Loss: 0.1595, Val Acc: 88.64%
------------------------------------------------------------
Epoch [2/30]
Train Loss: 0.1633, Train Acc: 88.61%
Val Loss: 0.1801, Val Acc: 88.64%
------------------------------------------------------------
Epoch [3/30]
Train Loss: 0.1453, Train Acc: 88.61%
Val Loss: 0.1689, Val Acc: 88.64%
------------------------------------------------------------


KeyboardInterrupt: 