In [None]:
import torch
# import librosa
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline  
# import librosa.display
import IPython.display as ipd
from scipy.io import wavfile
from scipy.io import wavfile
# from audio2numpy import open_audio
# import soundfile as sf
import sys
sys.path.append('../')
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import IPython.display as ipd

In [None]:
def print_metadata(metadata, src=None):
    """Print the fields of an audio metadata object.

    Args:
        metadata: Object exposing ``sample_rate``, ``num_channels``,
            ``num_frames``, ``bits_per_sample`` and ``encoding`` attributes.
        src: Optional label for the audio source. When truthy, a banner with
            the label is printed before the fields.
    """
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    # The fields are printed unconditionally; only the banner depends on `src`.
    print(" - sample_rate:", metadata.sample_rate)
    print(" - num_channels:", metadata.num_channels)
    print(" - num_frames:", metadata.num_frames)
    print(" - bits_per_sample:", metadata.bits_per_sample)
    print(" - encoding:", metadata.encoding)
    print()

def print_stats(waveform, sample_rate=None, src=None):
    """Print summary statistics of a waveform tensor, followed by the tensor.

    Args:
        waveform: Tensor to summarize; its ``max``, ``min``, ``mean`` and
            ``std`` reductions are reported.
        sample_rate: Optional sample rate; printed only when truthy.
        src: Optional source label; a banner is printed only when truthy.
    """
    if src:
        banner = "-" * 10
        print(banner)
        print("Source:", src)
        print(banner)
    if sample_rate:
        print("Sample Rate:", sample_rate)

    print("Shape:", tuple(waveform.shape))
    print("Dtype:", waveform.dtype)

    # Each reduction is evaluated right before its line is printed.
    peak = waveform.max().item()
    print(f" - Max:     {peak:6.3f}")
    trough = waveform.min().item()
    print(f" - Min:     {trough:6.3f}")
    average = waveform.mean().item()
    print(f" - Mean:    {average:6.3f}")
    spread = waveform.std().item()
    print(f" - Std Dev: {spread:6.3f}")

    print()
    print(waveform)
    print()

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    """Plot each channel of a waveform against time, one subplot per channel.

    Args:
        waveform: Tensor whose shape unpacks to ``(num_channels, num_frames)``.
            It is converted with ``.numpy()``.
        sample_rate: Divisor turning frame indices into the time axis.
        title: Figure-level title.
        xlim: Optional x-limits, applied to every channel's axes when truthy.
        ylim: Optional y-limits, applied to every channel's axes when truthy.
    """
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # With a single row `plt.subplots` hands back one Axes; wrap it so
        # the loop below can index it like the multi-row case.
        axes = [axes]
    for c in range(num_channels):
        ax = axes[c]
        ax.plot(time_axis, waveform[c], linewidth=1)
        ax.grid(True)
        # Label and limits are set inside the loop so every channel gets
        # them, not only the last one.
        if num_channels > 1:
            ax.set_ylabel(f'Channel {c+1}')
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)

def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    """Draw a spectrogram for each channel of a waveform, one subplot per channel.

    Args:
        waveform: Tensor whose shape unpacks to ``(num_channels, num_frames)``.
            It is converted with ``.numpy()``.
        sample_rate: Passed to ``Axes.specgram`` as ``Fs``.
        title: Figure-level title.
        xlim: Optional x-limits, applied to every channel's axes when truthy.
    """
    waveform = waveform.numpy()

    num_channels, _ = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # With a single row `plt.subplots` hands back one Axes; wrap it so
        # the loop below can index it like the multi-row case.
        axes = [axes]
    for c in range(num_channels):
        ax = axes[c]
        ax.specgram(waveform[c], Fs=sample_rate)
        # Label and limits are set inside the loop so every channel gets
        # them, not only the last one.
        if num_channels > 1:
            ax.set_ylabel(f'Channel {c+1}')
        if xlim:
            ax.set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)

def play_audio(waveform, sample_rate):
    """Display an inline audio player for a one- or two-channel waveform.

    Args:
        waveform: Tensor whose shape unpacks to ``(num_channels, num_frames)``.
            It is converted with ``.numpy()``.
        sample_rate: Forwarded to ``ipd.Audio`` as ``rate``.

    Raises:
        ValueError: If the channel count is neither 1 nor 2.
    """
    samples = waveform.numpy()
    channel_count, _ = samples.shape

    # Guard clause: anything other than one or two channels is rejected.
    if channel_count not in (1, 2):
        raise ValueError("Waveform with more than 2 channels are not supported.")

    if channel_count == 1:
        payload = samples[0]
    else:
        payload = (samples[0], samples[1])
    display(ipd.Audio(payload, rate=sample_rate))


In [None]:
# Path of the MP3 clip that the following cells load.
# NOTE(review): absolute, machine-specific path — the notebook needs editing to run elsewhere.
file_path = '/gpfs/data1/cmongp1/mpaliyam/raganet/data/audio/AbhEri/006a-Abheri-Raga_Alapanai.mp3'
# Alternative clip from the same directory (disabled):
# file_path = '/gpfs/data1/cmongp1/mpaliyam/raganet/data/audio/AbhEri/01-bhajarE_rE_mAnasa_shrI-AbhEri-mysore_vAsudevAcAr.mp3'


In [None]:
%%time
# Load the file at `file_path`; the call yields the audio tensor and its sample rate.
waveform, sample_rate = torchaudio.load(file_path)

In [None]:

# Average over the first axis when the tensor has more than one dimension,
# then reshape to a single row so the result is again 2-D.
if waveform.dim() >= 2:
    waveform = waveform.mean(dim=0).reshape(1, -1)

# Optional diagnostics, left disabled:
# print_stats(waveform, sample_rate=sample_rate)
# plot_waveform(waveform, sample_rate)
# plot_specgram(waveform, sample_rate)

# Render an inline player for the averaged signal.
play_audio(waveform, sample_rate)

In [None]:
# sox effect chain: a "pitch" step (arguments "-q", "0") followed by a "rate" step to 44100.
effects = [["pitch", "-q", "0"], [ "rate", "44100"]]
# Cell output: shape of the column slice 400:30000 of `waveform`.
waveform[:,400:30000].shape

In [None]:
%%time
# Run the chain in `effects` over the waveform; the call returns the processed
# tensor and a sample rate, stored as waveform2 / sample_rate2.
waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
    waveform, sample_rate, effects)

In [None]:
# Listen to the output of the effects call above.
play_audio(waveform2, sample_rate2)

In [None]:
# Display the sample rate returned by the effects call.
sample_rate2

In [None]:
%%time
# Rebind `effects` to a chain containing only the "rate" step (44100).
# Note: this overwrites `effects`, `waveform2` and `sample_rate2` from the earlier cells.
effects = [[ "rate", "44100"]]

waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
    waveform, sample_rate, effects)
# Report the returned sample rate and the shape of the processed tensor.
print(sample_rate2, waveform2.shape)

In [None]:
# Listen to the output of the rate-only effects call.
play_audio(waveform2, sample_rate2)

In [None]:
# Mel-spectrogram front end wrapped in a Sequential container.
# The transform is applied to `waveform2`, which the sox "rate" effect left at
# `sample_rate2`, so the mel filterbank is built for that rate rather than for
# the original file's `sample_rate`.
model = torch.nn.Sequential(
    T.MelSpectrogram(sample_rate=sample_rate2, n_mels=244, n_fft=1024))

# Move the transform's buffers to the GPU; the forward-pass cell feeds it a CUDA tensor.
model.to('cuda')

# Show the rate the transform was configured with, then the module itself.
print(sample_rate2)
model

In [None]:
%%time
# Copy the processed waveform to the GPU and run it through `model`.
m = model(waveform2.to('cuda'))

In [None]:
# Report the output shape, then show rows 0:128 and columns 0:400 of the
# first entry of `m` as an image (copied to the CPU for plotting).
print(m.shape)
plt.imshow(m[0,0:128,0:400].cpu().squeeze())

In [None]:
# Wide heat-map of the same crop of `m` (rows 0:128, columns 0:400 of the
# first entry), drawn with the row origin at the bottom and a colorbar.
fig, ax = plt.subplots(figsize=(20, 4))
cax = ax.matshow(
    m[0, :128, :400].cpu(),
    cmap=plt.cm.afmhot,
    origin='lower',
    aspect='auto',
    interpolation='nearest',
)
fig.colorbar(cax)
ax.set_title('Original Spectrogram')


In [None]:
# Convert `m` to the torch.FloatTensor type and display the resulting dtype.
t = (m.type(torch.FloatTensor))
t.dtype