In [None]:
import torchaudio
import os
import pandas as pd
import matplotlib.pyplot as plt
import torch
import math
import random
from torchaudio import transforms
from IPython.display import Audio
import librosa

# Raw tracks are read from input_dir; the segments produced later in the notebook
# are written to output_dir, one sub-directory per class.
input_dir, output_dir = "./data_music", "./data_split"

In [None]:
def split_audio(audio_file, output_dir, num_segments=5, original_name=""):
    """Split one audio file into ``num_segments`` consecutive clips saved as WAV.

    Segment boundaries are computed in whole samples, so the clips tile the
    recording exactly: every sample lands in exactly one clip, and clip lengths
    differ by at most one sample when the length is not divisible by
    ``num_segments``.

    Args:
        audio_file: path of the file to load with ``torchaudio.load``.
        output_dir: directory for the clips; created if it does not exist.
        num_segments: number of clips to produce; must be >= 1.
        original_name: prefix of the output names, which have the form
            ``f"{original_name}_segment_{i}.wav"`` with ``i`` starting at 1.

    Returns:
        None. If loading raises ``RuntimeError``, a message is printed and
        nothing is written.

    Raises:
        ValueError: if ``num_segments`` is less than 1.
    """
    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")

    try:
        waveform, sample_rate = torchaudio.load(audio_file)
    except RuntimeError:
        # Best-effort: report unreadable files so a batch run can continue.
        print(f"Failed to load: {audio_file}")
        return

    num_samples = waveform.size(1)
    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_segments):
        # Integer arithmetic instead of float seconds: e.g. 0.29 * 100 evaluates to
        # 28.999..., which int() truncates, silently dropping the final sample.
        start = i * num_samples // num_segments
        end = (i + 1) * num_samples // num_segments

        segment_waveform = waveform[:, start:end]
        segment_name = f"{original_name}_segment_{i + 1}.wav"
        output_file = os.path.join(output_dir, segment_name)

        torchaudio.save(output_file, segment_waveform, sample_rate)

def split_audio_files(input_dir, output_dir, num_segments=5):
    """Split every audio file found under ``input_dir/<class>/`` into segments.

    Each immediate sub-directory of ``input_dir`` is treated as one class. Its
    regular files are passed to ``split_audio``, which writes the segments into
    the matching ``output_dir/<class>/`` directory (created here).

    Entries are processed in sorted order so runs are reproducible. Hidden
    entries, i.e. names starting with "." such as macOS ``.DS_Store`` or Jupyter's
    ``.ipynb_checkpoints``, are skipped because they are not audio classes or clips.

    Args:
        input_dir: root directory containing one folder per class.
        output_dir: root directory for the per-class segment folders.
        num_segments: forwarded to ``split_audio``.
    """
    for class_folder in sorted(os.listdir(input_dir)):
        class_path = os.path.join(input_dir, class_folder)
        if class_folder.startswith(".") or not os.path.isdir(class_path):
            continue

        output_class_path = os.path.join(output_dir, class_folder)
        os.makedirs(output_class_path, exist_ok=True)

        for audio_file in sorted(os.listdir(class_path)):
            audio_file_path = os.path.join(class_path, audio_file)
            if audio_file.startswith(".") or not os.path.isfile(audio_file_path):
                continue
            split_audio(audio_file_path, output_class_path, num_segments, audio_file)

# Split every track under input_dir into 60 consecutive segments, mirroring the
# class folders under output_dir.
split_audio_files(input_dir=input_dir, output_dir=output_dir, num_segments=60)

In [None]:
GTZAN = output_dir

# Remove macOS Finder metadata files first, so they can be neither mistaken for a
# genre folder nor listed as an audio clip below.
for root, dirs, files in os.walk(GTZAN):
    for file in files:
        if file == ".DS_Store":
            os.remove(os.path.join(root, file))

# List genre folders only AFTER the cleanup above. Listing before the cleanup
# would keep a now-deleted ".DS_Store" entry and crash os.listdir on it. Plain
# files are skipped, and names are sorted for a reproducible row order.
gtzan_directory_list = sorted(
    entry for entry in os.listdir(GTZAN)
    if os.path.isdir(os.path.join(GTZAN, entry))
)

file_genre = []  # genre label (= parent folder name) for each segment file
file_path = []   # path of each segment file, aligned with file_genre

for folder in gtzan_directory_list:
    files_path = os.path.join(GTZAN, folder)
    for audio in sorted(os.listdir(files_path)):
        file_genre.append(folder)
        file_path.append(os.path.join(files_path, audio))

In [None]:
# One row per segment file: its genre (the parent folder name) and its path.
genre_df = pd.DataFrame({"Genre": file_genre})
path_df = pd.DataFrame({"Path": file_path})
gtzan_df = genre_df.join(path_df)

# Show a reproducible random sample of the index.
gtzan_df.sample(n=10, random_state=42)

In [None]:
class AudioUtil():
    """Static helpers that load and preprocess audio as ``(sig, sr)`` pairs.

    ``sig`` is a 2-D ``(channels, frames)`` tensor as returned by
    ``torchaudio.load``, and ``sr`` is its sample rate in Hz. Every transform
    takes such a pair and returns a new one (or the input unchanged).
    """

    @staticmethod
    def open(audio_file):
        """Load ``audio_file`` with torchaudio and return ``(sig, sr)``."""
        sig, sr = torchaudio.load(audio_file)
        return (sig, sr)

    @staticmethod
    def rechannel(aud, new_channel):
        """Convert ``aud`` to ``new_channel`` channels.

        Returns ``aud`` itself if it already has that many channels. For
        ``new_channel == 1`` only the first channel is kept. Otherwise the
        channels are concatenated with themselves, which turns mono into stereo.
        """
        sig, sr = aud
        if sig.shape[0] == new_channel:
            return aud
        if new_channel == 1:
            resig = sig[:1, :]
        else:
            resig = torch.cat([sig, sig])
        # Single return for both branches, so the mono path cannot fall through to None.
        return (resig, sr)

    @staticmethod
    def resample(aud, newsr):
        """Resample ``aud`` to ``newsr`` Hz; returns ``aud`` unchanged if already at that rate."""
        sig, sr = aud
        if sr == newsr:
            return aud
        # Resample acts on the last (time) dimension, so all channels go through one call.
        resig = torchaudio.transforms.Resample(sr, newsr)(sig)
        return (resig, newsr)

    @staticmethod
    def pad_trunc(aud, max_ms):
        """Truncate or zero-pad ``aud`` to exactly ``max_ms`` milliseconds of samples.

        Longer signals are cut at the end. Shorter signals get zero padding
        split randomly (``random.randint``) between the start and the end, so
        the content lands at a random offset. A signal of exactly the target
        length is returned unchanged.
        """
        sig, sr = aud
        num_rows, sig_len = sig.shape
        # Multiply before dividing: ``sr // 1000 * max_ms`` truncates rates that are not
        # multiples of 1000 (22050 Hz would give only 22000 samples per second).
        max_len = int(sr * max_ms // 1000)

        if sig_len > max_len:
            sig = sig[:, :max_len]
        elif sig_len < max_len:
            pad_begin_len = random.randint(0, max_len - sig_len)
            pad_end_len = max_len - sig_len - pad_begin_len

            # new_zeros keeps the padding on the same device and dtype as ``sig``.
            pad_begin = sig.new_zeros((num_rows, pad_begin_len))
            pad_end = sig.new_zeros((num_rows, pad_end_len))

            sig = torch.cat((pad_begin, sig, pad_end), 1)

        return (sig, sr)

    @staticmethod
    def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None, top_db=80):
        """Return the mel spectrogram of ``aud`` converted to decibels.

        Args:
            aud: ``(sig, sr)`` pair.
            n_mels, n_fft: forwarded to ``torchaudio.transforms.MelSpectrogram``.
            hop_len: forwarded as ``hop_length``; ``None`` lets torchaudio pick its default.
            top_db: forwarded as ``AmplitudeToDB(top_db=...)`` (default 80, as before).
        """
        sig, sr = aud

        spec = transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)
        spec = transforms.AmplitudeToDB(top_db=top_db)(spec)

        return spec

In [None]:
# Report which CUDA devices, if any, torch can see in this kernel.
if not torch.cuda.is_available():
    print("GPU is not available")
else:
    device_count = torch.cuda.device_count()
    print("Number of available GPUs:", device_count)
    for i in range(device_count):
        device_name = torch.cuda.get_device_name(i)
        print("GPU", i, ":", device_name)

In [None]:
def plot_waveform(waveform, sr, title="Waveform", ax=None):
    """Plot the first channel of a waveform against time in seconds.

    Args:
        waveform: 2-D ``(channels, frames)`` tensor or array-like; only
            channel 0 is drawn.
        sr: sample rate in Hz, used to convert frame indices to seconds.
        title: axes title.
        ax: matplotlib Axes to draw into; a new figure is created when None.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can be plotted;
    # as_tensor also accepts NumPy arrays.
    samples = torch.as_tensor(waveform).detach().cpu()
    _, num_frames = samples.shape
    time_axis = (torch.arange(num_frames) / sr).numpy()

    if ax is None:
        _, ax = plt.subplots()
    ax.plot(time_axis, samples[0].numpy(), linewidth=1)
    ax.grid(True)
    # Skip the x-limits for 0/1-frame input: there is no last sample or the range is empty.
    if num_frames > 1:
        ax.set_xlim([0, float(time_axis[-1])])
    ax.set_xlabel("Time [s]")
    ax.set_title(title)
    
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
    """Draw a spectrogram as an image with row 0 at the bottom.

    Args:
        specgram: tensor or array passed to ``imshow``; presumably 2-D
            ``(freq_bins, frames)``. TODO confirm: a leading channel dimension
            from ``AudioUtil.spectro_gram`` must be indexed away by the caller.
        title: optional axes title.
        ylabel: y-axis label.
        ax: matplotlib Axes to draw into; a new figure is created when None.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel(ylabel)
    # Convert tensors to NumPy so ones that require grad or live on a GPU can be drawn.
    if isinstance(specgram, torch.Tensor):
        specgram = specgram.detach().cpu().numpy()
    ax.imshow(specgram, origin="lower", aspect="auto", interpolation="nearest")