In [35]:
import csv
import os
import torch
import torchaudio
import librosa
from cnn_model import CNNModel

from datapreprocessing import AudioProcessor
from torch.utils.data import Dataset



In [36]:
# audio_file_path = "/Users/zainhazzouri/projects/Bachelor_Thesis/Data/Kaggle/music_wav/bartok.wav"
audio_file_path = "/Users/zainhazzouri/Desktop/egp1.mp3"
# audio_file_path = "/Users/zainhazzouri/projects/Bachelor_Thesis/Data/Kaggle/music_wav/bagpipe.wav"

# Audio / feature-extraction settings.
SAMPLE_RATE = 22050  # sample rate (Hz) requested from librosa when loading .mp3 files
bit_depth = 16  # bit depth of the audio file
hop_length = 512
n_mfcc = 20  # number of MFCCs features
# NOTE: this used to read `n_fft=1024,` -- the trailing comma bound a 1-tuple
# `(1024,)` instead of the integer window size.
n_fft = 1024  # window size
n_mels = 256  # number of mel bands to generate
win_length = None  # window length



In [37]:
# Set device: prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
# `is_built()` only says PyTorch was compiled with MPS support; `is_available()`
# additionally checks that this machine can actually run on an MPS device.
elif torch.backends.mps.is_available():  # if you have apple silicon mac
    device = "mps"  # if it doesn't work try device = torch.device('mps')
else:
    device = "cpu"
print(f"Using {device}")


Using mps


In [38]:
# Build the network on the selected device and restore the trained weights.
model = CNNModel().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved on a different device (e.g. CUDA) still loads on this machine.
model.load_state_dict(
    torch.load("CNNModel_speech_music_discrimination.pth", map_location=device)
)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

CNNModel(
  (encoder): Sequential(
    (0): Conv2d(1, 40, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): ReLU()
    (2): Conv2d(40, 80, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (3): ReLU()
    (4): Conv2d(80, 160, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(160, 80, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(80, 40, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(40, 40, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
    (5): ReLU()
  )
  (avg_pool): AdaptiveAvgPool2d(output_size=1)
  (fc): Linear(in_features=40, out_features=4, bias=True)
)

In [39]:
def split_waveform(waveform, sample_rate):
    """Cut a waveform into consecutive, non-overlapping chunks.

    Each chunk holds ``sample_rate`` samples, taken with
    ``waveform[:, start:end]``. The number of chunks is derived from
    ``waveform.shape[-1]``; samples left over after the last full chunk
    are discarded.

    Returns:
        A list of the chunk slices, in order (empty if the waveform is
        shorter than one chunk).
    """
    chunk_size = sample_rate
    n_chunks = waveform.shape[-1] // chunk_size
    return [
        waveform[:, idx * chunk_size:(idx + 1) * chunk_size]
        for idx in range(n_chunks)
    ]

In [40]:
# Instantiate the project's AudioProcessor for the selected file. The
# classification cell below uses its `pad_waveform` method and `n_mfcc` attribute.
audio_processor = AudioProcessor(audio_file_path)


In [41]:

# TODO: maybe use only librosa for all file types; this function is causing problems.
def classify_audio_file_segments(audio_file_path, audio_processor):
    """Classify an audio file one second at a time.

    The file is loaded as a mono waveform at ``SAMPLE_RATE``, split into
    one-second segments with ``split_waveform``, and each segment is turned
    into MFCC features and passed through the module-level ``model``.

    Args:
        audio_file_path: Path to the audio file. ``.mp3`` files are read with
            librosa; every other extension is read with torchaudio.
        audio_processor: Object providing ``pad_waveform(segment,
            desired_length=...)`` and an ``n_mfcc`` attribute.

    Returns:
        A list with one predicted class index (``int``) per full one-second
        segment, in chronological order. Empty if the file is shorter than
        one second.
    """
    file_ext = os.path.splitext(audio_file_path)[1].lower()

    if file_ext == '.mp3':
        # librosa resamples to SAMPLE_RATE and returns a mono 1-D array by
        # default; add the channel axis the rest of the pipeline expects.
        waveform, sample_rate = librosa.load(audio_file_path, sr=SAMPLE_RATE)
        waveform = torch.from_numpy(waveform).unsqueeze(0)
    else:
        waveform, sample_rate = torchaudio.load(audio_file_path)
        # Keep this branch consistent with the librosa branch: the model's
        # first convolution takes a single input channel, so mix multi-channel
        # audio down to mono ...
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # ... and resample so both branches produce features at SAMPLE_RATE.
        if sample_rate != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, sample_rate, SAMPLE_RATE)
            sample_rate = SAMPLE_RATE

    segments = split_waveform(waveform, sample_rate)

    # Build the MFCC transform once instead of once per segment.
    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate, n_mfcc=audio_processor.n_mfcc
    )

    segment_classifications = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for segment in segments:
            # presumably pads the segment up to `desired_length` samples -- see AudioProcessor
            padded_segment = audio_processor.pad_waveform(segment, desired_length=sample_rate)
            mfcc = mfcc_transform(padded_segment)
            # Add the batch axis in front of the channel axis.
            mfcc = mfcc.to(device).unsqueeze(0)
            output = model(mfcc)
            _, predicted_class = torch.max(output, 1)
            segment_classifications.append(predicted_class.item())

    return segment_classifications


In [42]:

def format_time(seconds):
    """Format a duration in seconds as a zero-padded ``HH:MM:SS`` string.

    Args:
        seconds: Duration in seconds. Floats are accepted and truncated to
            whole seconds (the ``d`` format code rejects float values, so a
            fractional segment duration would otherwise raise ``ValueError``).

    Returns:
        The formatted string, e.g. ``"00:01:03"`` for 63 seconds.
    """
    total_seconds = int(seconds)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def generate_classification_table(classification_results, segment_duration=1):
    """Collapse per-segment labels into runs of identical consecutive labels.

    Args:
        classification_results: Sequence of labels, one per segment, in
            chronological order.
        segment_duration: Length of each segment in seconds.

    Returns:
        A list of ``[start_time, end_time, label]`` rows, where the times are
        strings produced by ``format_time``. Returns an empty list when there
        are no results (e.g. the audio was shorter than one segment).
    """
    # Nothing to summarise; indexing [0] below would raise IndexError.
    if len(classification_results) == 0:
        return []

    table = []
    current_label = classification_results[0]
    start_time = 0

    for i, label in enumerate(classification_results[1:], 1):
        if label != current_label:
            # The label changed at segment i: close the run that just ended.
            boundary = i * segment_duration
            table.append([format_time(start_time), format_time(boundary), current_label])
            start_time = boundary
            current_label = label

    # Close the final run, which extends to the end of the audio.
    table.append([format_time(start_time), format_time(len(classification_results) * segment_duration), current_label])

    return table


def save_classification_table_to_csv(table, output_file):
    """Write the classification table to ``output_file`` as CSV.

    A header row (``Start Time``, ``End Time``, ``Class``) is written first,
    followed by every row of ``table`` in order. An existing file at that
    path is overwritten.
    """
    header = ['Start Time', 'End Time', 'Class']
    with open(output_file, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(table)

# Classify every one-second segment of the file, then merge consecutive
# segments that share a label into time ranges.
classification_results = classify_audio_file_segments(audio_file_path, audio_processor)
table = generate_classification_table(classification_results)

# Persist the merged table next to the notebook.
save_classification_table_to_csv(table, "classification_table.csv")

# Echo the same table as plain text.
print("Start Time | End Time | Class")
print("-" * 28)
for start, end, label in table:
    print(f"{start} | {end} | {label}")




Start Time | End Time | Class
----------------------------
00:00:00 | 00:00:02 | 0
00:00:02 | 00:00:03 | 1
00:00:03 | 00:00:04 | 0
00:00:04 | 00:00:07 | 1
00:00:07 | 00:00:47 | 0
00:00:47 | 00:00:49 | 1
00:00:49 | 00:00:50 | 0
00:00:50 | 00:00:52 | 1
00:00:52 | 00:00:53 | 0
00:00:53 | 00:01:03 | 1
00:01:03 | 00:01:04 | 0
00:01:04 | 00:01:06 | 1
00:01:06 | 00:01:07 | 0
00:01:07 | 00:01:08 | 1
00:01:08 | 00:01:13 | 0
00:01:13 | 00:01:14 | 1
00:01:14 | 00:01:15 | 0
00:01:15 | 00:01:23 | 1
00:01:23 | 00:01:25 | 0
00:01:25 | 00:01:27 | 1
00:01:27 | 00:01:30 | 0
00:01:30 | 00:01:34 | 1
00:01:34 | 00:01:36 | 0
00:01:36 | 00:01:37 | 1
00:01:37 | 00:01:38 | 0
00:01:38 | 00:01:40 | 1
00:01:40 | 00:01:47 | 0
00:01:47 | 00:01:48 | 1
00:01:48 | 00:01:49 | 0
00:01:49 | 00:01:50 | 1
00:01:50 | 00:01:51 | 0
00:01:51 | 00:01:52 | 1
00:01:52 | 00:01:53 | 0
00:01:53 | 00:01:54 | 1
00:01:54 | 00:01:56 | 0
00:01:56 | 00:01:59 | 1
00:01:59 | 00:02:01 | 0
00:02:01 | 00:02:08 | 1
00:02:08 | 00:02:11 | 0
00:02

In [43]:
# Re-run the classification and table generation (the same two calls as the
# previous cell) and print the table as a raw list of [start, end, class] rows.
classification_results = classify_audio_file_segments(audio_file_path, audio_processor)
table = generate_classification_table(classification_results)
print(table)



[['00:00:00', '00:00:02', 0], ['00:00:02', '00:00:03', 1], ['00:00:03', '00:00:04', 0], ['00:00:04', '00:00:07', 1], ['00:00:07', '00:00:47', 0], ['00:00:47', '00:00:49', 1], ['00:00:49', '00:00:50', 0], ['00:00:50', '00:00:52', 1], ['00:00:52', '00:00:53', 0], ['00:00:53', '00:01:03', 1], ['00:01:03', '00:01:04', 0], ['00:01:04', '00:01:06', 1], ['00:01:06', '00:01:07', 0], ['00:01:07', '00:01:08', 1], ['00:01:08', '00:01:13', 0], ['00:01:13', '00:01:14', 1], ['00:01:14', '00:01:15', 0], ['00:01:15', '00:01:23', 1], ['00:01:23', '00:01:25', 0], ['00:01:25', '00:01:27', 1], ['00:01:27', '00:01:30', 0], ['00:01:30', '00:01:34', 1], ['00:01:34', '00:01:36', 0], ['00:01:36', '00:01:37', 1], ['00:01:37', '00:01:38', 0], ['00:01:38', '00:01:40', 1], ['00:01:40', '00:01:47', 0], ['00:01:47', '00:01:48', 1], ['00:01:48', '00:01:49', 0], ['00:01:49', '00:01:50', 1], ['00:01:50', '00:01:51', 0], ['00:01:51', '00:01:52', 1], ['00:01:52', '00:01:53', 0], ['00:01:53', '00:01:54', 1], ['00:01:54', 