In [72]:
import os
from collections import Counter, defaultdict
import numpy as np
import torch
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from sklearn.metrics import confusion_matrix
import argparse

In [73]:
# Check if GPU
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Using device: {device}")

# Training hyperparameters (matching MATLAB NN_function.m)
MAX_EPOCHS = 20
INITIAL_LEARNING_RATE = 0.001
MOMENTUM = 0.9
LEARN_RATE_DROP_PERIOD = 5
LEARN_RATE_DROP_FACTOR = 10 ** (-0.5)

# Set a default image size for ResNet50
IMAGE_SIZE = (224, 224)


class ImageDataset(Dataset):
    """PyTorch Dataset for loading images from file paths."""

    def __init__(self, file_paths, labels=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transforms.Compose(
            [transforms.Resize(IMAGE_SIZE), transforms.ToTensor()]
        )

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        if self.labels is not None:
            label = self.labels[idx]
            return image, label
        else:
            return image


def balance_dataset(file_paths, labels):
    """Balance dataset by downsampling to the minimum class count."""
    # Count each label
    counter = Counter(labels)
    min_count = min(counter.values())
    balanced_files = []
    balanced_labels = []
    label_to_files = {}
    for fp, lab in zip(file_paths, labels):
        label_to_files.setdefault(lab, []).append(fp)
    for lab, fps in label_to_files.items():
        fps = np.array(fps)
        # Randomly choose min_count files for this label
        indices = np.random.choice(len(fps), min_count, replace=False)
        balanced_files.extend(fps[indices].tolist())
        balanced_labels.extend([lab] * min_count)
    return balanced_files, balanced_labels


class EarlyStopping:
    """Early stopping to stop training when validation loss doesn't improve."""

    def __init__(self, patience=7, verbose=False, delta=0, path="checkpoint.pt"):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model..."
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def nn_function(
    file_paths,
    labels,
    equalize_labels,
    minibatch_size,
    validation_patience,
    checkpoint_path,
    use_amp=False,
    skip_training=False
):
    """
    Equivalent to MATLAB NN_function.m using PyTorch
    Parameters:
        file_paths (list): List of image file paths (strings).
        labels (list): List of labels corresponding to each image.
        equalize_labels (bool): Whether to equalize label counts in the dataset.
        minibatch_size (int): Batch size for training.
        validation_patience (int): Patience for early stopping.
        checkpoint_path (str): Directory to save model checkpoints.
        use_amp (bool): Whether to use mixed precision training.

    Returns:
        model: The trained PyTorch model.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        val_labels: List of validation labels.
        val_label_counts: Dictionary with counts per label in the validation set.
        val_loader: Augmented validation data loader (same as val_loader here).
    """
    # If equalizing labels, balance the dataset
    if equalize_labels:
        file_paths, labels = balance_dataset(file_paths, labels)

    # Split dataset into training (70%) and validation (30%) sets
    train_files, val_files, train_labels, val_labels = train_test_split(
        file_paths, labels, test_size=0.3, stratify=labels, random_state=42
    )

    # Create PyTorch datasets and data loaders
    train_dataset = ImageDataset(train_files, train_labels)
    val_dataset = ImageDataset(val_files, val_labels)

    train_loader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=minibatch_size, shuffle=False)

    # Determine number of classes
    classes = np.unique(labels)
    num_classes = len(classes)

    # Load the pretrained ResNet50 model
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

    # Freeze the base model
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final classification layer
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Move model to device (CPU/GPU/MPS)
    model = model.to(device)

    # Set up loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=INITIAL_LEARNING_RATE, momentum=MOMENTUM)

    # Set up learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LEARN_RATE_DROP_PERIOD, gamma=LEARN_RATE_DROP_FACTOR)

    # Prepare checkpoint directory
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)

    # Setup early stopping
    early_stopping = EarlyStopping(
        patience=validation_patience,
        verbose=True,
        path=os.path.join(checkpoint_path, "model_checkpoint_best.pt"),
    )

    # Initialize GradScaler for mixed precision
    scaler = torch.amp.GradScaler(enabled=use_amp)

    # Training loop
    n_epochs = MAX_EPOCHS
    # Determine validation frequency: floor(number of training iterations per epoch)
    val_frequency = max(1, len(train_dataset) // minibatch_size)
    global_iteration = 0
    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        iteration = 0

        for inputs, target_labels in train_loader:
            iteration += 1
            global_iteration += 1
            inputs, target_labels = inputs.to(device), target_labels.to(device)
            optimizer.zero_grad()

            # Use mixed precision if enabled
            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, target_labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += target_labels.size(0)
            correct += (predicted == target_labels).sum().item()

            # Mid-epoch validation
            if iteration % val_frequency == 0:
                model.eval()
                val_loss = 0.0
                val_correct = 0
                val_total = 0
                with torch.no_grad():
                    for val_inputs, val_labels in val_loader:
                        val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                        val_outputs = model(val_inputs)
                        v_loss = criterion(val_outputs, val_labels)
                        val_loss += v_loss.item() * val_inputs.size(0)
                        _, val_predicted = torch.max(val_outputs.data, 1)
                        val_total += val_labels.size(0)
                        val_correct += (val_predicted == val_labels).sum().item()
                current_val_loss = val_loss / val_total
                print(f"[Epoch {epoch+1}, Iteration {iteration}] Validation Loss: {current_val_loss:.4f}")
                early_stopping(current_val_loss, model)
                if early_stopping.early_stop:
                    print("Early stopping triggered during epoch validation")
                    break
                model.train()  # switch back to training mode
        else:
            # Only executed if the inner loop did NOT break due to early stopping
            pass

        # Compute epoch statistics
        epoch_loss = running_loss / total
        epoch_acc = correct / total

        print(
            f"Epoch {epoch+1}/{n_epochs}, "
            f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}"
        )

        # Save checkpoint for this epoch
        torch.save(
            model.state_dict(),
            os.path.join(checkpoint_path, f"model_checkpoint_{epoch+1:02d}.pt"),
        )

        # Step the learning rate scheduler
        scheduler.step()
        print(f"Learning rate after epoch {epoch+1}: {scheduler.get_last_lr()[0]:.6f}")

        # If early stopping was triggered, break out of epoch loop
        if early_stopping.early_stop:
            break

    # Load the best model
    model.load_state_dict(
        torch.load(os.path.join(checkpoint_path, "model_checkpoint_best.pt"))
    )

    return (
        model,
        train_loader,
        val_loader,
        val_labels,
        dict(Counter(val_labels)),
        val_loader,
    )

Using device: cpu


In [74]:
# Fine-tuning hyperparameters (matching MATLAB NN_prep_for_classify.m)
FINE_TUNE_MAX_EPOCHS = 1
FINE_TUNE_MINIBATCH_SIZE = 32
FINE_TUNE_INITIAL_LEARNING_RATE = 1e-6


def nn_prep_for_classify(model, val_loader, train_loader):
    """Fine-tune the given model for one epoch with a low learning rate using the training and validation datasets.

    Parameters:
        model: PyTorch model.
        val_loader: Validation DataLoader.
        train_loader: Training DataLoader.

    Returns:
        model: The fine-tuned PyTorch model.
    """

    # Setup optimizer for fine-tuning with very low learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=FINE_TUNE_INITIAL_LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    # Determine validation frequency: floor(number of training iterations per epoch)
    val_frequency = max(1, len(train_loader.dataset) // FINE_TUNE_MINIBATCH_SIZE)

    model.train()
    iteration = 0
    running_loss = 0.0

    for inputs, target_labels in train_loader:
        iteration += 1
        inputs, target_labels = inputs.to(device), target_labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, target_labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

        # Mid-epoch validation
        if iteration % val_frequency == 0:
            model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                    val_outputs = model(val_inputs)
                    v_loss = criterion(val_outputs, val_labels)
                    val_loss += v_loss.item() * val_inputs.size(0)
                    _, val_predicted = torch.max(val_outputs.data, 1)
                    val_total += val_labels.size(0)
                    val_correct += (val_predicted == val_labels).sum().item()
            current_val_loss = val_loss / val_total
            current_val_acc = val_correct / val_total
            print(f"[Iteration {iteration}] Fine-tuning - Mid-epoch Val Loss: {current_val_loss:.4f}, Val Acc: {current_val_acc:.4f}")
            model.train()  # Switch back to training mode

    # End of epoch: Evaluate on validation set
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, target_labels in val_loader:
            inputs, target_labels = inputs.to(device), target_labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, target_labels)
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            val_total += target_labels.size(0)
            val_correct += (predicted == target_labels).sum().item()
    val_epoch_loss = val_loss / val_total
    val_epoch_acc = val_correct / val_total

    print(f'Fine-tuning - Final Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}')

    return model

In [75]:
def precision(M):
    """Calculate the precision vector from a confusion matrix M.
    Precision is defined as diag(M) / sum(M, axis=1).
    Parameters:
        M (np.ndarray): Confusion matrix.
    Returns:
        np.ndarray: The per-class precision.
    """
    M = np.array(M, dtype=float)
    # Avoid division by zero by adding eps
    eps = np.finfo(float).eps
    p = np.diag(M) / (np.sum(M, axis=1) + eps)
    return p


def recall(M):
    """Calculate the recall vector from a confusion matrix M.
    Recall is defined as diag(M) / sum(M, axis=0).
    Parameters:
        M (np.ndarray): Confusion matrix.
    Returns:
        np.ndarray: The per-class recall.
    """
    M = np.array(M, dtype=float)
    eps = np.finfo(float).eps
    r = np.diag(M) / (np.sum(M, axis=0) + eps)
    return r


def f1_score(p, r):
    """Calculate the F1 score given precision and recall vectors.
    The F1 score is defined as 2*(p*r)/(p+r). If (p+r)==0, F1 is set to 0.
    Parameters:
        p (np.ndarray): Precision vector.
        r (np.ndarray): Recall vector.
    Returns:
        np.ndarray: The per-class F1 score.
    """
    eps = np.finfo(float).eps
    f = np.where((p + r) > 0, 2 * p * r / (p + r + eps), 0)
    return f


def f1_metrics(M, label_counts):
    """Compute per-class precision, recall, F1 score and weighted metrics.

    Parameters:
        M (np.ndarray): Confusion matrix.
        label_counts (array-like): Counts for each label (should sum to the total).

    Returns:
        tuple: (precision_vector, recall_vector, f1_vector, weighted_precision, weighted_recall, weighted_f1)
    """
    label_counts = np.array(label_counts, dtype=float)
    weights = label_counts / (np.sum(label_counts) + np.finfo(float).eps)
    p = precision(M)
    r = recall(M)
    f = f1_score(p, r)
    weighted_p = np.sum(p * weights)
    weighted_r = np.sum(r * weights)
    weighted_f = np.sum(f * weights)
    return p, r, f, weighted_p, weighted_r, weighted_f

In [76]:
# Helper function to classify a dataset using the PyTorch model
def classify_dataset(model, dataloader):
    """Return predicted labels for the dataset."""
    model.eval()  # Set the model to evaluation mode
    all_preds = []

    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            if device.type == 'cuda':
                with torch.cuda.amp.autocast(device_type='cuda'):
                    outputs = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())

    return np.array(all_preds)


# Helper function to compute overall accuracy
def compute_accuracy(pred_labels, true_labels):
    import torch
    if torch.is_tensor(true_labels):
        true_labels = true_labels.cpu().numpy().astype(int)
    else:
        true_labels = np.array(true_labels, dtype=int)
    return np.sum(pred_labels == true_labels) / len(true_labels)


# Added helper function to run training, fine-tuning and evaluation for given label type
def run_evaluation(files, numeric_labels, equalize, minibatch_size, validation_patience, checkpoint_path, skip_training=False):
    model, train_loader, val_loader, lbls, lbl_counts, _ = nn_function(
        files, numeric_labels, equalize_labels=equalize, minibatch_size=minibatch_size,
        validation_patience=validation_patience, checkpoint_path=checkpoint_path, use_amp=(device.type=='cuda'), skip_training=skip_training
    )

    # Fine-tune for classification
    model = nn_prep_for_classify(model, val_loader, train_loader)

    # Get predictions
    preds = classify_dataset(model, val_loader)

    # Compute overall accuracy using true labels extracted from the validation dataset
    true_labels = np.array(val_loader.dataset.labels, dtype=int)
    acc = compute_accuracy(preds, true_labels)

    # Compute confusion matrix and metrics
    conf = confusion_matrix(true_labels, preds)
    p, r, f1, wp, wr, wf1 = f1_metrics(conf, list(lbl_counts.values()))

    metrics = {
        'accuracy': acc,
        'weighted_precision': wp,
        'weighted_recall': wr,
        'weighted_f1': wf1,
        'preds': preds,
        'labels': true_labels,
        'confusion': conf
    }
    return metrics


def train(label_type='valence'):
    # Read data from Excel, adjust file name and columns as needed
    # Assuming the Excel file is named 'SoundwelDatasetKey.xlsx' and is in the current directory
    excel_file = "/content/drive/MyDrive/Colab Notebooks/soundwel_sound_analysis/soundwel/SoundwelDatasetKey.xlsx"
    data = pd.read_excel(excel_file)

    # Extract columns (adjust column names/indexes based on actual data)
    # For this example, we assume the Excel has columns: 'File', 'Valence', 'Context', 'Site'
    Files = data['Spectrogram Filename'].tolist()

    # Append directory to Files
    Files = [os.path.join('/content/drive/MyDrive/Colab Notebooks/soundwel_sound_analysis/soundwel/Soundwel Dataset - Audio and Spectrograms', f) for f in Files]

    Valence = data['Valence'].tolist()
    Context = data['Context'].tolist()  # Uncommented to handle context classification
    Site = data['Recording Team'].tolist()

    # Convert labels to numeric categories
    # Create mappings for valence and context
    valence_categories = {cat: idx for idx, cat in enumerate(sorted(set(Valence)))}
    context_categories = {cat: idx for idx, cat in enumerate(sorted(set(Context)))}  # Added for context
    site_labels = sorted(set(Site))

    Valence_numeric = [valence_categories[x] for x in Valence]
    Context_numeric = [context_categories[x] for x in Context]  # Added for context

    # Define checkpoint directory bases
    checkpoint_base_val = os.path.join('checkpoints', 'Valence')
    checkpoint_base_con = os.path.join('checkpoints', 'Context')
    os.makedirs(checkpoint_base_val, exist_ok=True)
    os.makedirs(checkpoint_base_con, exist_ok=True)

    # Dictionaries to store metrics
    overall_metrics_val = {}
    overall_metrics_con = {}  # Added for context
    site_accuracy_val = defaultdict(dict)
    site_accuracy_con = defaultdict(dict)  # Added for context

    # Loop over 12 iterations (mimicking MATLAB for i = 1:12)
    for i in range(1, 2):
        print(f"Beginning Loop: {i}")

        # Define checkpoint directory for this iteration
        cp_val = os.path.join(checkpoint_base_val, f'Iter{i}')
        cp_con = os.path.join(checkpoint_base_con, f'Iter{i}')
        os.makedirs(cp_val, exist_ok=True)
        os.makedirs(cp_con, exist_ok=True)

        if label_type in ['both', 'valence']:
            # Run evaluation for valence
            val_metrics_dict = run_evaluation(Files, Valence_numeric, True, 32, 5, cp_val, skip_training=False)
            print(f"Iteration {i} - Valence Accuracy: {val_metrics_dict['accuracy']:.2f}")
            overall_metrics_val[i] = val_metrics_dict
            for site in site_labels:
                site_accuracy_val[site][i] = val_metrics_dict['accuracy']
        if label_type in ['both', 'context']:
            # Run evaluation for context
            con_metrics_dict = run_evaluation(Files, Context_numeric, False, 32, 5, cp_con, skip_training=False)
            print(f"Iteration {i} - Context Accuracy: {con_metrics_dict['accuracy']:.2f}")
            overall_metrics_con[i] = con_metrics_dict
            for site in site_labels:
                site_accuracy_con[site][i] = con_metrics_dict['accuracy']

    # Save metrics to file or print summary
    print("Overall Valence Metrics:", overall_metrics_val)
    print("Overall Context Metrics:", overall_metrics_con)
    print("Site-wise Valence Accuracies:", dict(site_accuracy_val))
    print("Site-wise Context Accuracies:", dict(site_accuracy_con))

    # New block to run site_validation_imds from the nn_function overall
    # try:
    #     from site_validation_imds import site_validation_imds
    #     # Define the directory containing the soundwel images
    #     image_dir = 'soundwel'
    #     # List image files in the directory with extensions png, jpg, jpeg
    #     image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    #     # Create dummy labels (e.g., all zeros)
    #     dummy_labels = [0] * len(image_files)

    #     print('Running site_validation_imds on image files:', image_files)
    #     dataset_site = site_validation_imds(image_files, dummy_labels, base_dir=image_dir)
    #     # Get the first batch to check
    #     images, labels = next(iter(dataset_site))
    #     print('Site validation batch images shape:', images.shape, 'labels:', labels)
    # except Exception as e:
    #     print('Error during site validation:', e)

In [None]:
train(label_type='valence')

Beginning Loop: 1
[Epoch 1, Iteration 101] Validation Loss: 0.5243
Validation loss decreased (inf --> 0.524257). Saving model...
Epoch 1/20, Train Loss: 0.5729, Train Acc: 0.7032
Learning rate after epoch 1: 0.001000
[Epoch 2, Iteration 101] Validation Loss: 0.4912
Validation loss decreased (0.524257 --> 0.491238). Saving model...
Epoch 2/20, Train Loss: 0.5020, Train Acc: 0.7689
Learning rate after epoch 2: 0.001000
