In [1]:
import sys

sys.path.insert(0, '../')

from modules.spec_dataset import *
from modules.train_prep import *
from modules.plot_results import *

In [2]:
# Define a Convolutional Neural Network class using PyTorch's nn.Module as the base class.
class SpectrogramCNN(nn.Module):
    """Single-convolution CNN that classifies single-channel spectrograms.

    Pipeline: Conv2d(1 -> 32, kernel (3, 10), padding (1, 5)) -> BatchNorm2d
    -> LeakyReLU -> MaxPool2d((2, 4)) -> Dropout(0.25) -> flatten -> Linear.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        fc_in_features: Size of the flattened feature vector fed to the final
            linear layer. For an input of height ``H`` and width ``W`` the
            convolution yields ``(H, W + 1)`` (the even kernel width of 10
            with padding 5 adds one column) and the pooling floors that to
            ``(H // 2, (W + 1) // 4)``, so the required value is
            ``32 * (H // 2) * ((W + 1) // 4)``. The default (430080) is the
            value this model was originally hard-coded with.
    """

    def __init__(self, num_classes: int = 4, fc_in_features: int = 430080):
        super().__init__()

        # Convolution over the (height, width) plane of a 1-channel input.
        # Padding (1, 5) preserves the height; the width grows by one column
        # because the kernel width (10) is even.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 10), padding=(1, 5))

        # Per-channel normalisation of the 32 convolution feature maps.
        self.bn1 = nn.BatchNorm2d(32)

        # Down-sample by 2 along the height and by 4 along the width.
        self.pool = nn.MaxPool2d((2, 4))

        # Randomly zero 25% of the activations while in training mode.
        self.dropout = nn.Dropout(0.25)

        # Classifier head; its input size must match the flattened pooling
        # output for the spectrogram size in use (see the class docstring).
        self.fc1 = nn.Linear(fc_in_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of spectrograms.

        Args:
            x: Tensor without a channel axis; the channel axis is inserted
                here at position 1, so a batched input is expected to be
                shaped ``(batch, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            logits (no softmax is applied).
        """
        # Insert the single channel dimension expected by Conv2d.
        x = x.unsqueeze(1)

        # Convolution -> batch norm -> LeakyReLU, then spatial down-sampling.
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.pool(x)

        x = self.dropout(x)

        # Collapse everything except the batch dimension for the linear head.
        x = torch.flatten(x, 1)

        return self.fc1(x)

In [None]:
# Accelerator selection: bind the configured CUDA device when CUDA is present;
# otherwise create an MPS device handle when the MPS backend is available.
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device)
elif torch.backends.mps.is_available():
    mps_device = torch.device("mps")

# Both splits share the same preprocessing: a single ToTensor conversion.
# Each split still gets its own Compose instance.
data_transforms = {
    phase: transforms.Compose([transforms.ToTensor()])
    for phase in ('train', 'val')
}

# Spectrogram datasets, one per split, each built from its directory and
# behavior file and wired to the matching transform.
dsets = {}
dsets['train'] = SpectroDataset(train_dir, train_behav_file, data_transforms['train'])
dsets['val'] = SpectroDataset(val_dir, val_behav_file, data_transforms['val'])

# Number of items in each split.
dset_sizes = {name: len(dset) for name, dset in dsets.items()}

# One DataLoader per split, each sampling items with probability inversely
# proportional to the frequency of the item's class.
# NOTE(review): the 'val' split is also drawn through the replacement sampler,
# so validation batches are a class-rebalanced resample rather than a single
# pass over the data -- confirm this is intended.
dset_loaders = {}
for split in ('train', 'val'):
    current_dset = dsets[split]

    # Label of every item, in dataset order.
    targets = np.array([current_dset.get_label(idx) for idx in range(len(current_dset))])

    # Per-class counts, indexed below by label.
    class_counts = current_dset.get_class_counts()

    # Per-item weight: 1 / (count of the item's class), or 0 for a zero count.
    per_sample = []
    for label in targets:
        count = class_counts[label]
        per_sample.append(1.0 / count if count > 0 else 0)
    sample_weights = np.array(per_sample)

    # Draw as many items per epoch as there are weights, with replacement.
    sampler = WeightedRandomSampler(
        sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

    dset_loaders[split] = torch.utils.data.DataLoader(
        current_dset,
        batch_size=b_size,
        num_workers=0,
        sampler=sampler,
    )
    print('done making loader:', split)

# Model, loss and optimizer.
model_ft = SpectrogramCNN(num_classes=4)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.0001, weight_decay=1e-5)

# Resolve the compute device (CUDA first, then MPS, then CPU) and move the
# model and the loss module onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model_ft.to(device)
criterion.to(device)
print(f"Training on {device}")

# Run training. train_model is defined outside this notebook; it is handed the
# model, loss, optimizer, scheduler, loaders and split sizes, and returns the
# model together with the accuracy/loss histories, predictions and labels.
model_ft, accuracies, losses, preds, labels = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    dset_loaders,
    dset_sizes,
    num_epochs=n_epochs,
)


In [None]:
# Local import: this cell needs os only to prepare the output directory.
import os

# Print the per-epoch accuracy and loss histories for both splits.
for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])

# Convert tensor-valued accuracies to NumPy (moving them to the CPU first) so
# they can be used outside PyTorch; values that are not tensors are kept as-is
# instead of raising AttributeError on `.cpu()`.
for key, value in accuracies.items():
    accuracies[key] = [v.cpu().numpy() if torch.is_tensor(v) else v for v in value]

# Persist the trained weights. The target directory is created first so that a
# missing folder cannot make torch.save fail after a full training run.
model_path = '../../models/CNN_spec_best_model.pt'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model_ft.state_dict(), model_path)

# Plot the accuracy/loss histories (helper defined outside this notebook).
plot_training_history(accuracies, losses, 'CNN_spec')

# Plot the confusion matrix for the labels/predictions returned by train_model
# (helper defined outside this notebook).
plot_confusion_matrix(labels, preds, 'CNN_spec')
