# HIFIGAN+ BWE COLAB FORK

Added features:
* phase accurate chunking
* batch processing
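* designed for audio input whose high-frequency cutoff lies between 12 kHz and 20 kHz
* multiband ensembling with the original file: the original audio is kept below the crossover frequency and only the generated content above it is added ("keeping only what's needed")
* gentle low-pass filter applied to the generated high frequencies (more realistic output)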


<br/>
<br/>

*Original work: https://github.com/brentspell/hifi-gan-bwe*

*Tweaks by jarredou*

In [None]:
# @title Installation { display-mode: "form" }
# Mount Google Drive at /content/drive; the default input/output paths used by
# the inference cell point into it.
from google.colab import drive
drive.mount('/content/drive')
# Clone the `fix_stream` branch of the fork.
!git clone -b fix_stream https://github.com/jarredou/hifi-gan-bwe
# Installing the repository's requirements is currently disabled.
#!pip install -r hifi-gan-bwe/requirements.txt

In [None]:
# @title Inference { display-mode: "form" }
# Work from inside the cloned repository -- presumably so that `hifi_gan_bwe`
# can be imported without a pip install (TODO confirm).
%cd /content/hifi-gan-bwe

import librosa
import google.colab.files
import numpy as np
import torch
import soundfile as sf
from IPython.display import Audio
from hifi_gan_bwe import BandwidthExtender

import gc
from tqdm import tqdm
from scipy import signal
from pathlib import Path
import glob
import warnings
#warnings.filterwarnings("ignore")


def match_array_shapes(array_1:np.ndarray, array_2:np.ndarray):
    """Fit the second axis of `array_1` to the length of `array_2`'s second axis.

    `array_1` is truncated when it is longer than `array_2` along axis 1 and
    right-padded with zeros when it is shorter. When both lengths already
    agree, `array_1` is returned untouched.

    Args:
        array_1: Array to trim or pad (2-D, e.g. ``(channels, samples)``).
        array_2: Array whose axis-1 length is the target.

    Returns:
        `array_1` adjusted to ``array_2.shape[1]`` columns.
    """
    target_len = array_2.shape[1]
    missing = target_len - array_1.shape[1]

    if missing < 0:
        # Too long: drop the surplus trailing columns.
        return array_1[:, :target_len]
    if missing > 0:
        # Too short: append zero-valued columns at the end.
        return np.pad(array_1, ((0, 0), (0, missing)), 'constant', constant_values=0)
    return array_1

def lr_filter(audio, cutoff, filter_type, order=20, sr=48000):
    """Apply a zero-phase Butterworth filter to `audio`.

    A Butterworth filter of order ``order // 2`` is run forward and backward
    (`sosfiltfilt`), so its magnitude response is applied twice and the result
    has no phase shift.

    Args:
        audio: Array with time on the first axis, e.g. ``(samples, channels)``.
        cutoff: Cutoff frequency in Hz (must be below ``sr / 2``).
        filter_type: Butterworth band type passed to `scipy.signal.butter`,
            e.g. ``'lowpass'`` or ``'highpass'``.
        order: Effective order after the forward-backward pass; the underlying
            filter is designed with ``order // 2``.
        sr: Sample rate of `audio` in Hz.

    Returns:
        The filtered audio, in the same layout as the input.
    """
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    # Design directly as second-order sections: going through (b, a)
    # polynomial coefficients and converting afterwards loses precision for
    # higher orders or low cutoffs and can yield an inaccurate/unstable filter.
    sos = signal.butter(order // 2, normal_cutoff, btype=filter_type,
                        analog=False, output='sos')
    # sosfiltfilt filters along the last axis, hence the transposes.
    filtered_audio = signal.sosfiltfilt(sos, audio.T)
    return filtered_audio.T

def process_file(input_path,output_path):
    """Bandwidth-extend one audio file and write the result as a 48 kHz WAV.

    The file is resampled to the model input rate, run through the model in
    fixed-size chunks (one channel at a time), and the generated signal is
    combined with the original through a crossover: the original audio is kept
    below the crossover frequency and the generated audio above it.

    Relies on the module-level globals `model`, `sr_ft` (model input sample
    rate), `chunk_size` (chunk length in seconds) and `crossover_freq` (Hz).

    Args:
        input_path: Path of the audio file to process.
        output_path: Directory receiving ``<stem>_upsampled.wav``; created if
            it does not exist.
    """
    out_sr = 48000  # sample rate of the written file

    Path(output_path).mkdir(parents=True, exist_ok=True)
    filename = Path(input_path).stem
    print(f"Processing {input_path}")

    audio, sr = librosa.load(input_path, sr=None, mono=False)
    # librosa returns a 1-D array for mono files; promote it to
    # (channels, samples) so the per-channel code below works for any layout.
    audio = np.atleast_2d(audio)

    if audio.shape[1] == 0:
        print(f"Skipping {input_path}: file contains no audio samples")
        return

    # resample to the model input rate before processing
    audio_rs = librosa.resample(audio, orig_sr=sr, target_sr=sr_ft, res_type='soxr_hq')

    # chunking parameters
    hop_length = chunk_size * sr_ft
    total_samples = audio_rs.shape[1]

    # Process the audio chunk by chunk. Stepping through the start offsets
    # never produces an empty trailing chunk, even when the length is an
    # exact multiple of hop_length.
    output_chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, total_samples, hop_length), unit='chunk'):
            chunk = audio_rs[:, start:start + hop_length]

            # Extend bandwidth using the model for each channel
            chunk = np.stack([
                model(torch.from_numpy(x).cuda(), sr_ft).cpu()
                for x in chunk])

            output_chunks.append(chunk)

    # concatenate output chunks along the time axis
    output = np.concatenate(output_chunks, axis=1)

    # low band comes from the original input, resampled to the output rate
    low = librosa.resample(audio, orig_sr=sr, target_sr=out_sr, res_type='soxr_hq')

    # trim/pad the generated signal so both bands have the same length
    output = match_array_shapes(output, low)

    # Linkwitz-Riley style crossover; lr_filter expects (samples, channels)
    low = lr_filter(low.T, crossover_freq, 'lowpass', sr=out_sr)
    high = lr_filter(output.T, crossover_freq, 'highpass', sr=out_sr)

    # add smoothing filter to high frequencies (more realistic)
    high = lr_filter(high, 18000, 'lowpass', order=2, sr=out_sr)

    # multiband ensemble
    output = low + high

    sf.write(f"{output_path}/{filename}_upsampled.wav", output, out_sr, subtype='PCM_16')
    print(f"Processing done !\nFile exported to : {output_path}/{filename}_upsampled.wav")


# --- User parameters (Colab form fields) ---
# input_path may be a single audio file or a folder of files to batch process.
input_path = "/content/drive/MyDrive/audio.wav" #@param {type:"string"}
output_path = "/content/drive/MyDrive/output" #@param {type:"string"}

# chunksize in seconds
chunk_size = "90" #@param [120, 90, 60, 30, 10, 1]
chunk_size = int(chunk_size)

#hifi-gan-bwe-10-42890e3-vctk-48kHz
model = BandwidthExtender.from_pretrained("hifi-gan-bwe-13-59f00ca-vctk-24kHz-48kHz").cuda()

# Frequency (Hz) above which the input file has no content.
input_cutoff = "14000" #@param [20000, 19000, 18000, 17000, 16000, 15000, 14000, 13000, 12000]
input_cutoff = int(input_cutoff)

# Cross over slightly below the input cutoff.
crossover_freq = input_cutoff - 500
# Sample rate the audio is resampled to before being fed to the model.
sr_ft = 24000


try:
    if Path(input_path).is_file():
        process_file(input_path, output_path)
    else:
        # Treat input_path as a folder and process every file with an extension.
        file_paths = sorted(glob.glob(input_path + "/*.*"))
        if not file_paths:
            print(f"No input files found at {input_path}")
        for file_path in file_paths:
            process_file(file_path, output_path)
finally:
    # Release GPU memory even when processing fails part-way through.
    model = model.cpu()
    del model
    torch.cuda.empty_cache()
#gc.collect()
