In [None]:
import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
import filterbanks as fb
from IPython.display import Audio, display, HTML

## **1. Compute subband decompositions**

In [None]:
# Load audio to filter ---------------
audio_path = "your_sound.wav"
sr = 44100
audio, _ = librosa.load(audio_path, sr=sr)  # resampled to `sr` on load

# Chop audio ------------------------
# The filterbank below is constructed for signals of `size` samples. Clips
# shorter than `seconds` are zero-padded up to `size` so the input length always
# matches the filterbank (NOTE(review): presumably generate_subbands requires an
# input of exactly `size` samples — confirm in filterbanks).
seconds = 2
size = sr * seconds
audio = audio[:size]
if len(audio) < size:
    audio = np.pad(audio, (0, size - len(audio)))  # pad with zeros at the end
print("og audio")
display(Audio(audio, rate=sr))
audio = torch.tensor(audio).float()

# Create filterbank -----------------
# 20 and sr // 2 are presumably the lower/upper frequency limits in Hz
# (sr // 2 is Nyquist) — TODO confirm against filterbanks' signature.
N_filter_bank = 16
erb_bank = fb.EqualRectangularBandwidth(size, sr, N_filter_bank, 20, sr // 2) # you may as well use Linear or Logarithmic filterbanks

# Apply filterbank ------------------
# Drop the first and last rows returned by generate_subbands. They are assumed to
# be extra edge (low-pass / high-pass) bands added by the filterbank — TODO confirm.
subbands_signal = erb_bank.generate_subbands(audio)[1:-1, :]

def plot_signals(matrix, sample_rate=None, n_cols=4):
    """Plot each row of ``matrix`` as a time-domain waveform on a grid of subplots.

    Parameters
    ----------
    matrix : torch.Tensor
        Tensor with one signal per row (indexed as ``matrix[i]``). Each row is
        detached, moved to CPU and converted to NumPy for plotting.
    sample_rate : int, optional
        Sample rate in Hz used to convert sample indices to seconds. Defaults to
        the notebook-level ``sr`` (looked up at call time), so existing calls
        ``plot_signals(x)`` behave as before.
    n_cols : int, optional
        Number of subplot columns (default 4). The number of rows is derived from
        the number of signals, so more than 16 signals no longer raises IndexError.
    """
    rate = sr if sample_rate is None else sample_rate
    num_signals = matrix.shape[0]
    # Ceil-divide to get enough rows. Keep at least one so plt.subplots is valid.
    n_rows = max(1, (num_signals + n_cols - 1) // n_cols)
    # squeeze=False keeps `axes` 2-D even for a single row. The figure height
    # scales with the row count (4 rows -> the original 15x10 figure).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 2.5 * n_rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i >= num_signals:
            # Hide any unused subplots
            ax.axis('off')
            continue
        signal = matrix[i].detach().cpu().numpy()
        ax.plot(np.arange(len(signal)) / rate, signal)
        ax.set_title(f'Subband {i+1}')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')

    plt.tight_layout()
    plt.show()

plot_signals(subbands_signal)

# Generate an HTML table of playable audio widgets, `grid_cols` per row, one per
# subband. The row count follows the number of subbands, so none are silently
# dropped when there are more than 16, and no empty rows appear when there are fewer.
grid_cols = 4
grid_rows = (len(subbands_signal) + grid_cols - 1) // grid_cols

rows_html = []
for i in range(grid_rows):
    cells_html = []
    for j in range(grid_cols):
        index = i * grid_cols + j
        if index < len(subbands_signal):
            # Embed each audio widget in a cell
            audio_html = Audio(subbands_signal[index].detach().cpu().numpy(), rate=sr)._repr_html_()
            cells_html.append(f"<td style='text-align:center; padding:10px;'>{audio_html}<br>Subband {index+1}</td>")
        else:
            # Empty cell pads out the last row
            cells_html.append("<td></td>")
    rows_html.append("<tr>" + "".join(cells_html) + "</tr>")

html_code = "<table style='width:100%; border-spacing:10px;'>" + "".join(rows_html) + "</table>"

# Display the matrix
display(HTML(html_code))

## **2. Visualize the Filterbanks**

In [None]:
# Create one filterbank per frequency spacing, all with identical settings ----
# Note: this rebinds sr, size, N_filter_bank and erb_bank from section 1.
sr, size, N_filter_bank = 44100, 2 ** 15, 16

# Same positional arguments for every bank: length, sample rate, filter count,
# then (presumably) the low/high frequency limits in Hz.
erb_bank = fb.EqualRectangularBandwidth(size, sr, N_filter_bank, 20, sr // 2)
linear_bank = fb.Linear(size, sr, N_filter_bank, 20, sr // 2)
logarithmic_bank = fb.Logarithmic(size, sr, N_filter_bank, 20, sr // 2)

def plot_filter(filter_bank, title):
    """Overlay the frequency response of every filter in ``filter_bank`` on one figure.

    Parameters
    ----------
    filter_bank : filterbanks object
        Must expose ``freqs`` (x-axis values, plotted as Hz) and ``filters``,
        a tensor whose columns are the individual filters.
    title : str
        Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Convert the whole filter matrix once; each column is one filter's response.
    responses = filter_bank.filters.detach().cpu().numpy()
    for idx, response in enumerate(responses.T):
        ax.plot(filter_bank.freqs, response, label=f'Filter {idx+1}')

    ax.set_title(title)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Amplitude")
    ax.grid(True)
    ax.legend()
    plt.show()

plot_filter(filter_bank=linear_bank, title="Linear Filterbank")
plot_filter(filter_bank=erb_bank, title="ERB Filterbank")
plot_filter(filter_bank=logarithmic_bank, title="Logarithmic Filterbank")
