In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

DATA_DIR = "../data/frog_clips"
OUTPUT_MODEL = "../model/vgg_frog_model.pth"
LABEL_MAP_FILE = "../model/label_mapping.json"

TARGET_SR = 22050
N_MELS = 128
DURATION = 5  # seconds
SAMPLES = TARGET_SR * DURATION

#### Dataset Class

In [2]:
class FrogAudioDataset(Dataset):
    def __init__(self):
        self.paths = []
        self.labels = []

        label_map = {"CONTROL": 0, "TOAD-WEST": 1}

        for cls in ["CONTROL", "TOAD-WEST"]:
            folder = os.path.join(DATA_DIR, cls)
            for f in os.listdir(folder):
                self.paths.append(os.path.join(folder, f))
                self.labels.append(label_map[cls])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        y, sr = librosa.load(path, sr=TARGET_SR)

        if len(y) < SAMPLES:
            y = np.pad(y, (0, SAMPLES - len(y)))
        else:
            y = y[:SAMPLES]

        # Convert audio to Mel spectrogram
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS)
        S_dB = librosa.power_to_db(S, ref=np.max)
        S_dB = np.expand_dims(S_dB, axis=0)

        return torch.tensor(S_dB, dtype=torch.float32), self.labels[idx]

#### Model Definition (VGG-like CNN)

In [7]:
class VGGSmall(nn.Module):
    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )

        # NEW — fixes mat1/mat2 issue
        self.pool = nn.AdaptiveAvgPool2d((4, 4))

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 2)  # binary classification
        )

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)   # ALWAYS outputs [batch, 64, 4, 4]
        x = self.classifier(x)
        return x

#### Training

In [8]:
dataset = FrogAudioDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VGGSmall().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

EPOCHS = 20

for epoch in range(EPOCHS):
    total_loss = 0
    correct = 0
    for x, y in tqdm(loader):
        x, y = x.to(device), torch.tensor(y).to(device)

        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (pred.argmax(1) == y).sum().item()

    print(f"Epoch {epoch+1}/{EPOCHS}  Loss: {total_loss:.4f}  Acc: {correct/len(dataset):.4f}")

  x, y = x.to(device), torch.tensor(y).to(device)
100%|██████████| 59/59 [00:13<00:00,  4.31it/s]


Epoch 1/20  Loss: 20.2835  Acc: 0.8699


100%|██████████| 59/59 [00:08<00:00,  6.66it/s]


Epoch 2/20  Loss: 16.2645  Acc: 0.8955


100%|██████████| 59/59 [00:08<00:00,  6.80it/s]


Epoch 3/20  Loss: 14.5176  Acc: 0.9062


100%|██████████| 59/59 [00:08<00:00,  6.95it/s]


Epoch 4/20  Loss: 12.3516  Acc: 0.9211


100%|██████████| 59/59 [00:08<00:00,  6.96it/s]


Epoch 5/20  Loss: 11.2231  Acc: 0.9190


100%|██████████| 59/59 [00:08<00:00,  7.04it/s]


Epoch 6/20  Loss: 7.3791  Acc: 0.9467


100%|██████████| 59/59 [00:08<00:00,  6.79it/s]


Epoch 7/20  Loss: 4.9849  Acc: 0.9744


100%|██████████| 59/59 [00:07<00:00,  7.49it/s]


Epoch 8/20  Loss: 5.3380  Acc: 0.9638


100%|██████████| 59/59 [00:07<00:00,  7.68it/s]


Epoch 9/20  Loss: 4.4400  Acc: 0.9808


100%|██████████| 59/59 [00:08<00:00,  7.05it/s]


Epoch 10/20  Loss: 4.0568  Acc: 0.9723


100%|██████████| 59/59 [00:08<00:00,  7.11it/s]


Epoch 11/20  Loss: 4.1220  Acc: 0.9680


100%|██████████| 59/59 [00:07<00:00,  7.61it/s]


Epoch 12/20  Loss: 3.5596  Acc: 0.9808


100%|██████████| 59/59 [00:07<00:00,  7.52it/s]


Epoch 13/20  Loss: 3.7834  Acc: 0.9765


100%|██████████| 59/59 [00:08<00:00,  7.04it/s]


Epoch 14/20  Loss: 3.7711  Acc: 0.9808


100%|██████████| 59/59 [00:11<00:00,  4.98it/s]


Epoch 15/20  Loss: 4.1956  Acc: 0.9723


100%|██████████| 59/59 [00:09<00:00,  6.54it/s]


Epoch 16/20  Loss: 3.6753  Acc: 0.9723


100%|██████████| 59/59 [00:09<00:00,  6.08it/s]


Epoch 17/20  Loss: 3.0996  Acc: 0.9829


100%|██████████| 59/59 [00:09<00:00,  6.49it/s]


Epoch 18/20  Loss: 4.2239  Acc: 0.9701


100%|██████████| 59/59 [00:08<00:00,  6.96it/s]


Epoch 19/20  Loss: 3.9390  Acc: 0.9723


100%|██████████| 59/59 [00:09<00:00,  6.41it/s]

Epoch 20/20  Loss: 2.8898  Acc: 0.9829





#### Save the Model

In [11]:
torch.save(model.state_dict(), OUTPUT_MODEL)
print("Saved:", OUTPUT_MODEL)

Saved: ../model/vgg_frog_model.pth
