In [7]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import librosa as lb
from sklearn.preprocessing import LabelEncoder

# Read the CSV that lists the audio clips to score; the prediction loop below reads its 'file' column.
annotated_data = pd.read_csv('/Users/rachelwang/Downloads/notes/models/csv/quality_labeled_empty.csv')

# Encoder that turns the model's class indices 0..4 back into quality labels '1'..'5'.
# It is fitted on a fixed label list here rather than restored from training, so the
# decoding is only correct if training used the same (sorted) label set — TODO confirm.
dummy_labels = [str(score) for score in range(1, 6)]
le = LabelEncoder().fit(dummy_labels)

# Define the model class (same as before)
class MFCCNet(nn.Module):
    """Small 2-D CNN that classifies an MFCC matrix treated as a 1-channel image.

    Architecture:
      * conv1..conv3: Conv2d -> BatchNorm2d -> ReLU -> 2x2 max pool.
        conv1/conv2 use strides (1, 3)/(1, 2), so they downsample only along
        the last (width) axis.
      * conv4: Conv2d -> BatchNorm2d -> ReLU -> global max pool to 1x1,
        giving one 128-d vector per example.
      * fc1 -> fc2 -> fc3: 128 -> 50 -> 25 -> ``num_classes`` logits, with
        ReLU + Dropout(0.3) after fc1 and fc2.

    Input: (batch, 1, height, width). conv1 takes a single input channel; the
    notebook's ``preprocess_audio`` yields (1, 1, n_mfcc, max_len).

    Layer attribute names are kept stable so saved state_dicts
    (e.g. 'best_model_updated.pth') keep loading.

    Args:
        num_classes: number of output logits. Defaults to 5, matching the
            original hard-coded head and the checkpoint used here.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 3), padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 2), padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)

        self.conv3 = nn.Conv2d(64, 96, kernel_size=(2, 2), padding=1)
        self.bn3 = nn.BatchNorm2d(96)
        self.pool3 = nn.MaxPool2d(2)

        self.conv4 = nn.Conv2d(96, 128, kernel_size=(2, 2), padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        # Collapse whatever spatial size remains to 1x1, so the head is independent of input width.
        self.gmp = nn.AdaptiveMaxPool2d((1, 1))

        self.fc1 = nn.Linear(128, 50)
        self.fc2 = nn.Linear(50, 25)
        self.fc3 = nn.Linear(25, num_classes)
        # One Dropout module is shared by fc1 and fc2; it is only active in train() mode.
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes)."""
        # torch.relu is stateless; the previous code built a fresh nn.ReLU module on every call.
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        x = self.gmp(torch.relu(self.bn4(self.conv4(x))))
        x = torch.flatten(x, 1)  # (batch, 128, 1, 1) -> (batch, 128)
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.dropout(torch.relu(self.fc2(x)))
        return self.fc3(x)

# Define the feature extraction and preprocessing functions
def pad_or_truncate(feature, max_len):
    """Return `feature` with exactly `max_len` columns (axis 1).

    Extra columns are cut off from the right; missing columns are filled with
    zeros on the right. The row count (axis 0) is never changed.
    """
    n_cols = feature.shape[1]
    if n_cols >= max_len:
        return feature[:, :max_len]
    shortfall = max_len - n_cols
    return np.pad(feature, ((0, 0), (0, shortfall)), mode='constant')

def getFeatures(path, max_len=259):
    """Load the audio at `path` and return its MFCC matrix fixed to `max_len` frames.

    librosa's default load and MFCC parameters are used (no sr / n_mfcc passed).
    The MFCC output is padded or truncated along axis 1 via pad_or_truncate.
    """
    signal, sr = lb.load(path)
    return pad_or_truncate(lb.feature.mfcc(y=signal, sr=sr), max_len)

def preprocess_audio(path, max_len=259):
    """Build a float32 model input tensor of shape (1, 1, n_mfcc, max_len) for one file."""
    features = getFeatures(path, max_len)
    # Prepend a batch axis and a channel axis: (n_mfcc, T) -> (1, 1, n_mfcc, T).
    batched = features[np.newaxis, np.newaxis]
    return torch.tensor(batched).float()

# Load the trained model: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MFCCNet().to(device)
# map_location remaps any CUDA tensors stored in the checkpoint onto `device`, so a
# checkpoint saved on a GPU machine still loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file; consider weights_only=True (torch >= 1.13)
# if the checkpoint could ever come from an untrusted source.
model.load_state_dict(
    torch.load('/Users/rachelwang/Downloads/notes/models/best_model_updated.pth', map_location=device)
)
# Inference mode: BatchNorm uses its stored running statistics and Dropout is disabled.
model.eval()

def predict_quality(model, audio_path):
    """Predict the quality label of a single audio file.

    Args:
        model: trained classifier (e.g. MFCCNet). It should already be in eval
            mode; this function does not change the model's mode.
        audio_path: path of the audio file to score.

    Returns:
        The decoded label from the module-level LabelEncoder ``le`` (one of
        ``le.classes_``) for the highest-scoring class.
    """
    # Put the input on whatever device holds the model's weights, so callers are not
    # tied to the module-level `device`. Fall back to that global for parameter-less modules.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    mfcc = preprocess_audio(audio_path).to(target_device)

    with torch.no_grad():
        logits = model(mfcc)
    # Index of the top-scoring class for each item in the batch (batch size is 1 here).
    predicted = logits.argmax(dim=1)

    return le.inverse_transform(predicted.cpu().numpy())[0]

# Score every audio file listed in the input table, in row order.
predicted_qualities = [predict_quality(model, audio_path) for audio_path in annotated_data['file']]

# Attach the predictions as the 'quality' column (replacing any existing column of that name).
annotated_data['quality'] = predicted_qualities

# Write the annotated table to disk and show the destination path as the cell output.
output_path = '/Users/rachelwang/Downloads/notes/models/csv/predicted_quality_label_empty.csv'
annotated_data.to_csv(output_path, index=False)
output_path

'/Users/rachelwang/Downloads/notes/models/csv/predicted_quality_label_empty.csv'

In [8]:
import pandas as pd

# Re-read the predictions file written by the previous cell (output_path is defined there).
data = pd.read_csv(output_path)

# Count how often each predicted quality label occurs, ordered by label value.
# value_counts() lists only labels that actually occur, so a label the model never
# predicted is missing from the table rather than shown with a count of 0.
quality_column = data['quality']
label_counts = quality_column.value_counts().sort_index()

# Show the per-label counts under a heading.
print("Quality label counts:", label_counts, sep="\n")

Quality label counts:
quality
1     11
2     27
4    100
5    661
Name: count, dtype: int64
