In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchaudio.transforms as T
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from time import time
from tqdm import tqdm
from torchvision import transforms

In [2]:
class SpectrogramDataset(Dataset):
    """Map-style dataset of saved spectrogram tensors indexed by a CSV label table.

    The CSV must provide a ``spectrogram`` column (tensor file path relative to
    ``data_dir``) and a ``speaker_count`` column (value convertible to ``int``).
    """

    def __init__(self, csv_file, data_dir):
        """Read the label table and remember where the tensor files live.

        Args:
            csv_file: Path of the CSV file, read with ``pandas.read_csv``.
            data_dir: Directory that each ``spectrogram`` entry is joined onto.
        """
        self.data = pd.read_csv(csv_file)
        self.data_dir = data_dir

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample stored at positional row ``idx``.

        Args:
            idx: Positional row index into the label table.

        Returns:
            A ``(spectrogram, label)`` tuple: the loaded tensor on the CPU, with a
            new leading dimension of size 1 and cast to float32, and the row's
            ``speaker_count`` as a Python ``int``.
        """
        row = self.data.iloc[idx]
        tensor_path = os.path.join(self.data_dir, row['spectrogram'])

        # NOTE(review): torch.load unpickles the file, so data_dir must only
        # contain trusted files.
        # map_location="cpu" lets tensors that were saved from a CUDA device load
        # on CPU-only hosts; moving batches to the training device is the
        # caller's job.
        stored = torch.load(tensor_path, map_location="cpu")

        # Result is [1, H, W] when the stored tensor is 2-D (assumed, not checked).
        spectrogram = stored.unsqueeze(0).float()
        label = int(row['speaker_count'])
        return spectrogram, label
    



In [3]:
# Architecture hyperparameters
conv1_out, conv2_out, conv3_out = 16, 32, 64  # output channels of the three conv stages
fc_hidden = 128      # width of the hidden fully connected layer
num_classes = 4      # number of logits produced by the final layer
dropout_prob = 0.3   # dropout probability applied to the flattened features


class SpeakerCountCNN(nn.Module):
    """Three conv -> batch-norm -> ReLU -> max-pool stages plus a two-layer head."""

    def __init__(self, input_height, input_width):
        """Create the layers for single-channel inputs of a fixed spatial size.

        Args:
            input_height: Height of the input feature map, used to size ``fc1``.
            input_width: Width of the input feature map, used to size ``fc1``.
        """
        super().__init__()

        # Convolutional stages: 3x3 kernels with padding 1 keep the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=conv1_out, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(conv1_out)

        self.conv2 = nn.Conv2d(in_channels=conv1_out, out_channels=conv2_out, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(conv2_out)

        self.conv3 = nn.Conv2d(in_channels=conv2_out, out_channels=conv3_out, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(conv3_out)

        # One pooling / dropout module is shared by every stage that needs it.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=dropout_prob)

        # The 2x2 pool is applied three times, halving (with floor) both
        # spatial dimensions each time, hence the division by 8.
        pooled_height = input_height // 8
        pooled_width = input_width // 8
        self.fc1 = nn.Linear(conv3_out * pooled_height * pooled_width, fc_hidden)
        self.fc2 = nn.Linear(fc_hidden, num_classes)

    def forward(self, x):
        """Return the ``num_classes`` raw logits for each item of the batch ``x``."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            x = self.pool(F.relu(bn(conv(x))))

        # Flatten everything except the batch dimension, then classify.
        features = self.dropout(x.flatten(start_dim=1))
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

In [4]:
# Build the dataset from the label CSV and the directory of saved tensors.
dataset = SpectrogramDataset(
    csv_file="data/spectrogram_labels.csv",
    data_dir="data/spectrograms",
)

# Size the network from the first sample; everything after the first
# dimension of its shape is unpacked as (height, width).
input_height, input_width = dataset[0][0].shape[1:]

model = SpeakerCountCNN(input_height, input_width)


In [5]:
# 80/20 train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed a dedicated generator so the same samples land in the validation split
# on every run. Without it the split follows the global RNG state, so a
# Restart & Run All evaluates on a different validation set each time.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training batches are shuffled; validation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)


In [6]:
# Prefer the GPU when CUDA is usable, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)  # Module.to returns the module itself

# Cross-entropy loss on the logits, optimized with Adam at learning rate 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [7]:
NUM_EPOCHS = 10  # number of full passes over train_loader

for epoch in range(NUM_EPOCHS):
    model.train()  # training mode: dropout active, batch-norm statistics updated
    total_loss = 0.0   # sum of per-sample losses seen this epoch
    total_samples = 0  # number of samples seen this epoch
    start_time = time()

    # Show a progress bar across all batches
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # With its default reduction the criterion returns the batch mean.
        # Weight it by the batch size so a smaller final batch does not count
        # as much as a full one in the epoch average.
        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    end_time = time()
    epoch_duration = end_time - start_time
    # An empty loader reports NaN instead of raising ZeroDivisionError.
    avg_loss = total_loss / total_samples if total_samples else float("nan")

    print(f"Epoch {epoch+1} — Loss: {avg_loss:.4f} — Time: {epoch_duration:.2f}s")


                                                          

Epoch 1 — Loss: 0.5061 — Time: 818.33s


                                                          

Epoch 2 — Loss: 0.4417 — Time: 851.07s


                                                          

Epoch 3 — Loss: 0.4139 — Time: 892.35s


                                                          

Epoch 4 — Loss: 0.3952 — Time: 957.60s


                                                          

Epoch 5 — Loss: 0.3830 — Time: 155.00s


                                                          

Epoch 6 — Loss: 0.3711 — Time: 140.42s


                                                          

Epoch 7 — Loss: 0.3607 — Time: 95.77s


                                                          

Epoch 8 — Loss: 0.3506 — Time: 96.87s


                                                          

Epoch 9 — Loss: 0.3407 — Time: 98.85s


                                                           

Epoch 10 — Loss: 0.3284 — Time: 100.08s




In [9]:
# Top-1 accuracy over the validation loader.
model.eval()  # inference mode for dropout and batch norm
correct, total = 0, 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        output = model(inputs)
        preds = torch.argmax(output, dim=1)  # index of the largest logit per sample
        correct += preds.eq(targets).sum().item()
        total += targets.shape[0]

print(f"Validation Accuracy: {correct / total:.2%}")


Validation Accuracy: 82.39%
