In [None]:
import os
import torch
import torchaudio
import IPython.display as ipd
# import matplotlib
# matplotlib.use('qt5agg')
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import soundfile as sf
import pandas as pd
from torchaudio.compliance.kaldi import spectrogram

In [None]:
# Close every matplotlib figure that is still open so the plots below start from a clean state
plt.close('all')

# CNN with images of spectrograms

## Explore NoWhale

In [None]:
# NoWhale: load a single example clip and inspect what torchaudio returns
noWhale_example_path = r'.\train\train\0.wav'
waveform_noWhale, sr_noWhale = torchaudio.load(noWhale_example_path)

# Printed one per line: waveform type, sample-rate type, waveform shape, sample rate
# (original note: 2000 data points per second, in total 4000 data points)
print(
    type(waveform_noWhale),
    type(sr_noWhale),
    waveform_noWhale.shape,
    sr_noWhale,
    sep='\n',
)


In [None]:
# Plot the raw NoWhale samples in a fresh figure.
# The waveform is transposed before being handed to plt.plot.
samples_noWhale = waveform_noWhale.t().numpy()
plt.figure()
plt.plot(samples_noWhale)

In [None]:
# Spectrogram of the NoWhale example, using the transform's default settings
noWhale_transform = torchaudio.transforms.Spectrogram()
spectrogram_noWhale = noWhale_transform(waveform_noWhale)

# Log2-scale the values and keep only the first channel for display
log_spectrogram_noWhale = spectrogram_noWhale.log2()[0, :, :].numpy()

# Show the log-scaled spectrogram as an image in a new figure
plt.figure()
plt.imshow(log_spectrogram_noWhale, cmap='viridis')

## Explore RightWhale

In [None]:
# RightWhale: load a single example clip and inspect what torchaudio returns
rightWhale_example_path = r'.\train\train\1.wav'
waveform_rightWhale, sr_rightWhale = torchaudio.load(rightWhale_example_path)

# Printed one per line: waveform type, sample-rate type, waveform shape, sample rate
# (original note: 2000 data points per second, in total 4000 data points)
print(
    type(waveform_rightWhale),
    type(sr_rightWhale),
    waveform_rightWhale.shape,
    sr_rightWhale,
    sep='\n',
)



In [None]:
# Plot the raw RightWhale samples in a fresh figure.
# The waveform is transposed before being handed to plt.plot.
samples_rightWhale = waveform_rightWhale.t().numpy()
plt.figure()
plt.plot(samples_rightWhale)


In [None]:
# Spectrogram of the RightWhale example, using the transform's default settings
rightWhale_transform = torchaudio.transforms.Spectrogram()
spectrogram_rightWhale = rightWhale_transform(waveform_rightWhale)

# Log2-scale the values and keep only the first channel for display
log_spectrogram_rightWhale = spectrogram_rightWhale.log2()[0, :, :].numpy()

# Show the log-scaled spectrogram as an image in a new figure
plt.figure()
plt.imshow(log_spectrogram_rightWhale, cmap='viridis')


## Load data

In [None]:
# List every entry of the training-audio folder and read the label table
train_audio_dir = r'.\train\train'
labels_csv_path = r'.\train.csv'

sound_files = os.listdir(train_audio_dir)
df_labels = pd.read_csv(labels_csv_path)

# Number of directory entries found (valid or not)
print(len(sound_files))

In [None]:
# Extract the 'class' column of the label table as a NumPy array
class_column = df_labels['class']
train_labels = class_column.to_numpy()

# Preview the first ten labels, then the total label count (one per line)
print(train_labels[:10], len(train_labels), sep='\n')

In [None]:
# Keep only the audio files whose name starts with an integer index ("<int>.<ext>").
# idx_labeled_audio_files holds the CSV 'idx' column; it is not used in this cell.
idx_labeled_audio_files = df_labels['idx'].to_numpy()
available_audio_files = []
for file in sound_files:
    # Everything before the first dot is expected to be the numeric index
    stem = file.split('.')[0]
    try:
        idx = int(stem)
    except ValueError:
        # Only a non-numeric stem is an expected failure; a bare `except:` would
        # also hide KeyboardInterrupt and genuine programming errors.
        print("Invalid file: ", file)
    else:
        available_audio_files.append(idx)

In [None]:
# Order the indices ascending (in place) and preview the first five
available_audio_files.sort()
print(available_audio_files[:5])

In [None]:
# split the valid audio files into noWhale and RightWhale
#
# Labels are looked up by the CSV's 'idx' value rather than by row position:
# positional indexing (train_labels[idx]) silently mislabels files whenever the
# CSV rows are not exactly 0..N-1 in order, and raises IndexError for an index
# beyond the last row.
label_by_idx = dict(zip(idx_labeled_audio_files.tolist(), train_labels))

noWhale_paths = []
rightWhale_paths = []
# assign each file index to the list matching its label
for idx in available_audio_files:
    if idx not in label_by_idx:
        # Audio file present on disk but absent from the label table: report and skip
        print("Missing label for file index: ", idx)
        continue
    label = label_by_idx[idx]
    if label == 'NoWhale':
        noWhale_paths.append(idx)
    elif label == 'RightWhale':
        rightWhale_paths.append(idx)
    else:
        print("Invalid label: ", label)

print(len(noWhale_paths))
print(len(rightWhale_paths))

In [None]:
# Turn the per-class index lists into full .wav paths inside the training folder.
# NOTE: both names are rebound from indices to paths here, so running this cell
# a second time would prefix the already-built paths again.
train_audio_dir = r'.\train\train'
noWhale_paths = [
    os.path.join(train_audio_dir, f'{file_idx}.wav') for file_idx in noWhale_paths
]
rightWhale_paths = [
    os.path.join(train_audio_dir, f'{file_idx}.wav') for file_idx in rightWhale_paths
]

# Preview the first five paths of each class (one list per line)
print(noWhale_paths[:5], rightWhale_paths[:5], sep='\n')

## Process data: get images

In [None]:
# Make sure both per-class output folders exist.
# exist_ok=True creates whichever folder is missing and is a no-op otherwise.
# Checking only the parent folder would skip creation when the parent exists
# but one or both class sub-folders do not.
for class_dir in (
    r'.\train_images_spectrogram\noWhale',
    r'.\train_images_spectrogram\rightWhale',
):
    os.makedirs(class_dir, exist_ok=True)

In [None]:
# save the spectrogram images of the noWhale

# Build the transform once; constructing it per file is loop-invariant work.
noWhale_to_spectrogram = torchaudio.transforms.Spectrogram()

for j, path in enumerate(noWhale_paths):
    # progress marker every 1000 files
    if j % 1000 == 0:
        print(j)
    # Recover the numeric file index from the file name; basename/splitext use the
    # platform's path rules instead of assuming a backslash separator.
    idx = int(os.path.splitext(os.path.basename(path))[0])
    waveform, sr = torchaudio.load(path)

    spectrogram_noWhale = noWhale_to_spectrogram(waveform)
    spectrogram_path = os.path.join(r'.\train_images_spectrogram\noWhale', str(idx) + '.png')
    # scale the values with log2 and then select the first channel
    # NOTE(review): log2 of an exactly-zero bin is -inf; confirm the saved image is acceptable in that case
    plt.imsave(spectrogram_path, spectrogram_noWhale.log2()[0,:,:].numpy(), cmap='viridis')


In [None]:

# save the spectrogram images of the rightWhale

# Build the transform once; constructing it per file is loop-invariant work.
rightWhale_to_spectrogram = torchaudio.transforms.Spectrogram()

for j, path in enumerate(rightWhale_paths):
    # progress marker every 1000 files
    if j % 1000 == 0:
        print(j)
    # Recover the numeric file index from the file name; basename/splitext use the
    # platform's path rules instead of assuming a backslash separator.
    idx = int(os.path.splitext(os.path.basename(path))[0])
    waveform, sr = torchaudio.load(path)

    spectrogram_rightWhale = rightWhale_to_spectrogram(waveform)
    spectrogram_path = os.path.join(r'.\train_images_spectrogram\rightWhale', str(idx) + '.png')
    # scale the values with log2 and then select the first channel
    # NOTE(review): log2 of an exactly-zero bin is -inf; confirm the saved image is acceptable in that case
    plt.imsave(spectrogram_path, spectrogram_rightWhale.log2()[0,:,:].numpy(), cmap='viridis')



In [None]:
# save in labeled directories to apply ImageFolder