In [2]:
# =================================================
# LIBRARIES
# =================================================
from glob import glob

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
%matplotlib inline


In [3]:
# =================================================
# DATA INDEXING --> DATASET
# =================================================

# ------- CONFIGS -------
DATASET_FOLDER_DIR = "/Users/pedropaiva/Documents/Dev/Research/CoBasE-Energy/cobas/Acoustic_Dataset_Collection/BeaconDataset/Test4/ffts"

label_map = {
    "control": 0,
    "0p": 1,
    "50p": 2,
    "100p": 3
}

# ------- IMPLEMENTATION -------
# function to get all sample labels given sample paths
def get_sample_labels(file_paths: list[str]) -> list[int]:
    labels = []

    for path in file_paths:
        if "control" in path:
            labels.append(label_map["control"])
        elif "0p" in path:
            labels.append(label_map["0p"])
        elif "50p" in path:
            labels.append(label_map["50p"])
        elif "100p" in path:
            labels.append(label_map["100p"])
        else:
            raise ValueError(f"[ERROR]: Could not infer label from path: {path}")

    return labels


# Dataset abstraction class
class SpectrogramsDataset(Dataset):
    TARGET_TIME_FRAMES = 400

    def __init__(self, file_paths: list[str], labels: list[int]):
        self.file_paths = file_paths
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # loads spec from disk
        spec = np.load(self.file_paths[idx])

        # ------- tensor verification -------
        # crops or pads tensor time_frames
        if spec.shape[1] > SpectrogramsDataset.TARGET_TIME_FRAMES:
            spec = spec[:, :SpectrogramsDataset.TARGET_TIME_FRAMES]
        elif spec.shape[1] < SpectrogramsDataset.TARGET_TIME_FRAMES:
            pad = SpectrogramsDataset.TARGET_TIME_FRAMES - spec.shape[1]
            spec = np.pad(spec, ((0, 0), (0, pad)), mode="constant")

        # turns numpy array into tensor
        spec = torch.from_numpy(spec).to(torch.float32)
        # adds dimension: (freq_bins, time_frames) --> (channels=1, freq_bins, time_frames)
        spec = spec.unsqueeze(0)

        # gets corresponding label
        label = self.labels[idx]

        return spec, label


In [4]:
# =================================================
# DATA SPLITTING & DATASET DEFINITION
# =================================================

# finds all file paths given dir
file_paths = sorted(glob(DATASET_FOLDER_DIR + "/*.npy"))

# gets all labels given file paths
labels = get_sample_labels(file_paths=file_paths)

# gets all file indices for data splitting
indices = list(range(len(file_paths)))

# train and temporary splits --> stratify=labels keeps the same proportion of indices across all labels
train_idxs, temp_idxs = train_test_split(
    indices, 
    train_size=0.7, 
    stratify=labels, 
    random_state=777)

# test and val splits --> stratification is based only on the label indices above selected at temp_idxs
val_idxs, test_idxs = train_test_split(
    temp_idxs,
    train_size=0.7,
    stratify=[labels[i] for i in temp_idxs],
    random_state=777
)

# defines train Dataset
train_dataset = SpectrogramsDataset(
    file_paths=[file_paths[i] for i in train_idxs],
    labels=[labels[i] for i in train_idxs]
)

# defines val Dataset
val_dataset = SpectrogramsDataset(
    file_paths=[file_paths[i] for i in val_idxs],
    labels=[labels[i] for i in val_idxs]
)

# defines test Dataset
test_dataset = SpectrogramsDataset(
    file_paths=[file_paths[i] for i in test_idxs],
    labels=[labels[i] for i in test_idxs]
)


In [5]:
# =================================================
# DATA LOADER
# =================================================

# ------- CONFIGS -------
BATCH_SIZE = 32
NUM_WORKERS = 0

# defines train data loader
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# defines val data loader
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# defines test data loader
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


In [6]:
# =================================================
# CNN MODEL ARCHITECTURE
# =================================================

class SpectrogramCNN(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.GroupNorm(                                                               ,
            nn.ReLU(),

            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2)
        )

        # computes the average over the entire spatial field, independently for each channel
        # e.g.: (batch_size=32, channels=128, freq_bins=16, time_frames=8) --> 
        # (batch_size=32, channels=128, freq_bins=1, time_frames=1)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = self.global_pool(x)

        # reshaping, no computation:
        # (batch_size, channels, freq_bins, time_frames) 4D tensor --> 
        # (batch_size, channels) 2D tensor
        x = x.view(x.size(0), -1)

        x = self.classifier(x)

        return x

In [7]:
# =================================================
# MODEL INSTANTIATION, LOSS FUNCTION, AND OPTIMIZER
# =================================================

NUM_OF_CLASSES = len(set(labels))

model = SpectrogramCNN(
    num_classes=NUM_OF_CLASSES
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3
)

In [8]:
# =================================================
# RESULTS PLOTTING
# =================================================

# plt.ion()  # interactive mode

# fig, ax1 = plt.subplots()
# ax2 = ax1.twinx()

# train_line, = ax1.plot([], [], label="Train Loss")
# val_line, = ax1.plot([], [], label="Val Loss")
# acc_line, = ax2.plot([], [], linestyle="--", label="Val Accuracy")

# ax1.set_xlabel("Epoch")
# ax1.set_ylabel("Loss")
# ax2.set_ylabel("Accuracy")

# ax1.legend(loc="upper left")
# ax2.legend(loc="upper right")

# plt.display(fig)

# def update_plot(epoch, train_losses, val_losses, val_accuracies):
#     epochs = range(1, epoch + 1)

#     train_line.set_data(epochs, train_losses)
#     val_line.set_data(epochs, val_losses)
#     acc_line.set_data(epochs, val_accuracies)

#     ax1.relim()
#     ax1.autoscale_view()
#     ax2.relim()
#     ax2.autoscale_view()

#     fig.canvas.draw()
#     fig.canvas.flush_events()


# =================================================
# TRAINING
# =================================================

# ------- CONFIGS -------
NUM_EPOCHS = 30

best_val_loss = float("inf")
patience = 5
patience_counter = 0
min_delta = 1e-4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ------- plotting params -------
train_loss_list = []
val_loss_list = []
val_acc_list = []

# ------- IMPLEMENTATION -------
for epoch in range(NUM_EPOCHS):

    # ------- training -------
    model.train()
    train_loss = 0.0

    for specs, labels in train_loader:
        specs = specs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(specs)
        loss = criterion(outputs, labels)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)
    train_loss_list.append(train_loss)

    # ------- validation -------
    model.eval()
    val_loss = 0.0
    correct = 0

    with torch.no_grad():
        for specs, labels in val_loader:
            specs = specs.to(device)
            labels = labels.to(device)

            outputs = model(specs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()

    val_loss /= len(val_loader)
    val_loss_list.append(val_loss)
    val_acc = (correct / (len(val_dataset)))
    val_acc_list.append(val_acc)

    # update_plot(epoch+1, train_loss_list, val_loss_list, val_acc_list)

    print(
        f"Epoch {epoch+1}/{NUM_EPOCHS} | "
        f"Train Loss: {train_loss:.3f} | "
        f"Val Loss: {val_loss:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

    # ------- early stopping -------
    if val_loss < (best_val_loss - min_delta):
        best_val_loss = val_loss
        patience_counter = 0

        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch
        }, "best_model.pt")

    else:

        if patience_counter >= patience:
            print(f"[WARNING]: Early Stopping Triggere d!")
            break
            
        patience_counter += 1
        print(f"Early Stopping Counter: {patience_counter}/{patience}")


Epoch 1/30 | Train Loss: 0.491 | Val Loss: 0.322 | Val Acc: 0.965
Epoch 2/30 | Train Loss: 0.169 | Val Loss: 3.944 | Val Acc: 0.331
Early Stopping Counter: 1/5


KeyboardInterrupt: 