In [1]:
# Run from the challenge directory: meta.txt, lossy/, traces/, processed/ and the model
# checkpoint used below are all referenced by paths relative to it.
%cd ../challenge

/media/pips/Data/Projects/thesis-masters/verma-pytorch/challenge


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import numpy as np
import scipy.signal
import scipy.linalg
import warnings
from numba import njit


@njit
def apply_prediction_filter(past: np.ndarray, coeff: np.ndarray, steps: int) -> np.ndarray:
    """
    Recursively extrapolate a signal with a fixed linear predictor.

    Each output sample is ``dot(window, coeff)``, where ``window`` starts as ``past``.
    After every step the window slides left by one sample and the freshly predicted
    value becomes its newest element, so later outputs build on earlier predictions.

    Args:
        past: the most recent samples, oldest first; must have the same length as ``coeff``.
        coeff: predictor weights aligned element-wise with ``past`` (the last weight
            multiplies the newest sample).
        steps: number of samples to generate.

    Returns:
        A float64 array of ``steps`` predicted samples. ``past`` itself is not modified.
    """
    out = np.zeros(steps)
    window = past
    for k in range(steps):
        nxt = np.dot(window, coeff)
        out[k] = nxt
        # np.roll allocates a new array, so writing into `window` never touches the caller's `past`
        window = np.roll(window, -1)
        window[-1] = nxt
    return out


class ARModel:
    """
    AR model of order p.

    It finds the model parameters via the autocorrelation method and Levinson-Durbin recursion.
    It uses the Numba JIT compiler to accelerate sample-by-sample inference.

    Args:
        p: model order, i.e. how many past samples each predicted sample depends on.
        diagonal_load: value added to the zero-lag autocorrelation before solving, to
            improve the conditioning of the Toeplitz system (0 disables it).
    """

    def __init__(self, p: int, diagonal_load: float = 0.0):
        self.p = p
        self.diagonal_load = diagonal_load

        # Pre-compile the Numba-jitted filter to expedite future calls. Numba compiles one
        # specialization per argument dtype, and `predict` feeds it float32 arrays for
        # float32 audio (librosa's default) and float64 arrays otherwise, so warm up both.
        for dtype in (np.float32, np.float64):
            apply_prediction_filter(past=np.zeros(self.p, dtype=dtype),
                                    coeff=np.ones(self.p, dtype=dtype),
                                    steps=1)

    def autocorrelation_method(self, valid: np.ndarray) -> np.ndarray:
        """
        Estimate the AR coefficients of `valid` by solving the normal (Yule-Walker) equations.

        Args:
            valid: 1-D context signal; it should be longer than `p` samples so that
                `p + 1` autocorrelation lags are available.

        Returns:
            Array `a` of length `p` such that x[n] is predicted as
            sum_k a[k] * x[n - 1 - k]. An all-zero `valid` yields all-zero coefficients.
        """
        # Compute the (unnormalized) sample autocorrelation function
        acf = scipy.signal.correlate(valid, valid, mode='full', method='auto')

        # Find the zeroth lag index (the full correlation has 2 * len(valid) - 1 lags)
        zero_lag = len(acf) // 2

        # A silent context has zero energy: the Toeplitz system is singular without diagonal
        # loading and has the all-zero solution with it, so return that solution directly
        # instead of letting the Levinson recursion fail.
        if acf[zero_lag] == 0:
            return np.zeros(self.p)

        # First column of the autocorrelation matrix (a view into the local `acf`)
        c = acf[zero_lag:zero_lag + self.p]

        # Diagonal loading to improve conditioning (c[0] fills the whole main diagonal)
        c[0] += self.diagonal_load

        # Autocorrelation vector (lags 1..p)
        b = acf[zero_lag + 1:zero_lag + self.p + 1]

        # Solve the Toeplitz system of equations using the efficient Levinson-Durbin recursion
        ar_coeff = scipy.linalg.solve_toeplitz(c, b, check_finite=False)

        return ar_coeff

    def predict(self, valid: np.ndarray, steps: int) -> np.ndarray:
        """
        Fit the AR model on `valid` and extrapolate `steps` samples beyond its end.

        Args:
            valid: 1-D context signal; its last `p` samples seed the predictor.
            steps: number of samples to predict.

        Returns:
            float64 array with `steps` predicted samples.

        Warns:
            RuntimeWarning: if any predicted sample falls outside [-1, 1].
        """
        valid = np.asarray(valid)

        # Find AR model parameters
        ar_coeff = self.autocorrelation_method(valid)

        # Numba's np.dot only accepts float/complex operands sharing one dtype, so cast the
        # seed and the time-reversed coefficients to a common type: float32 when float32
        # represents `valid` exactly (float32, int16, ...), float64 otherwise.
        # Contiguous arrays are also required by the jitted filter.
        dtype = np.result_type(valid.dtype, np.float32)
        past = np.ascontiguousarray(valid[-self.p:], dtype=dtype)
        coeff = np.ascontiguousarray(ar_coeff[::-1], dtype=dtype)

        # Apply linear prediction
        pred = apply_prediction_filter(past=past, coeff=coeff, steps=steps)

        # Raise warning; helpful in case the AR model becomes numerically unstable.
        if np.any(np.abs(pred) > 1.0):
            warnings.warn(f'AR prediction exceeded the audio range [-1, 1]: found [{np.min(pred)}, {np.max(pred)}]',
                          RuntimeWarning)

        return pred

In [157]:
import torch
import numpy as np
from tqdm import tqdm
from copy import deepcopy


class PARCnet:
    """
    Packet-loss concealment that sums an AR (linear) prediction and a neural-network
    prediction for every lost packet of an audio signal.

    Args:
        packet_dim: packet length in samples.
        extra_dim: extra samples predicted past each lost packet, used to cross-fade into
            the following packet when that packet was received.
        ar_order: order of the AR model.
        ar_diagonal_load: diagonal loading applied when fitting the AR model.
        num_valid_ar_packets: number of past packets used as AR context.
        num_valid_nn_packets: number of past packets used as neural-network context.
        model_checkpoint: path of a TorchScript checkpoint (loading executes the archive's
            code, so only load trusted files).
        xfade_len_in: length in samples of the fade-in applied to the network prediction.
        device: torch device the network and its inputs live on.
    """

    def __init__(self,
                 packet_dim: int,
                 extra_dim: int,
                 ar_order: int,
                 ar_diagonal_load: float,
                 num_valid_ar_packets: int,
                 num_valid_nn_packets: int,
                 model_checkpoint: str,
                 xfade_len_in: int,
                 device: str = 'cpu',
                 ):

        self.packet_dim = packet_dim

        # Device the network inputs must be created on (must match `map_location` below)
        self.device = device

        # Define the prediction length, including the extra length
        self.pred_dim = packet_dim + extra_dim

        # Define the AR and neural network contexts in samples
        self.ar_context_len = num_valid_ar_packets * packet_dim
        self.nn_context_len = num_valid_nn_packets * packet_dim

        # Define fade-in modulation vector (neural network contribution only)
        self.fade_in = np.ones(self.pred_dim)
        self.fade_in[:xfade_len_in] = np.linspace(0, 1, xfade_len_in)

        # Define fade-out modulation vector
        self.fade_out = np.ones(self.pred_dim)
        if extra_dim > 0:
            # Guarded because `[-0:]` would select the whole vector instead of nothing
            self.fade_out[-extra_dim:] = np.linspace(1, 0, extra_dim)

        # Instantiate the linear predictor
        self.ar_model = ARModel(ar_order, ar_diagonal_load)

        # Load the pretrained neural network
        self.neural_net = torch.jit.load(model_checkpoint, map_location=device)

    @staticmethod
    def _past_context(signal: np.ndarray, idx: int, length: int) -> np.ndarray:
        """Return the `length` samples preceding `idx`, left zero-padded when fewer exist."""
        # Clamp at 0: a negative start index would wrap around to the end of the signal
        context = signal[max(0, idx - length):idx]
        return np.pad(context, (length - context.shape[0], 0))

    def __call__(self, input_signal: np.ndarray, trace: np.ndarray, **kwargs) -> np.ndarray:
        """
        Conceal every lost packet of `input_signal`.

        Args:
            input_signal: lossy audio signal (assumed 1-D / mono — TODO confirm).
            trace: per-packet loss flags; a truthy entry i marks samples
                [i * packet_dim, (i + 1) * packet_dim) as lost.
            **kwargs: accepted and ignored.

        Returns:
            A concealed copy of `input_signal`; the input itself is not modified.
        """
        self.neural_net.eval()
        output_signal = np.array(input_signal, copy=True)
        signal_len = output_signal.shape[0]

        for i, loss in tqdm(enumerate(trace), total=len(trace)):
            if not loss:
                continue

            # Start index of the ith packet
            idx = i * self.packet_dim
            if idx >= signal_len:
                # The trace extends past the end of the audio: nothing left to conceal
                continue

            # AR model inference on the (already concealed) past samples
            valid_ar_packets = self._past_context(output_signal, idx, self.ar_context_len)
            ar_pred = self.ar_model.predict(valid=valid_ar_packets, steps=self.pred_dim)

            # NN model context: past samples followed by zeros in place of the samples to predict
            nn_context = self._past_context(output_signal, idx, self.nn_context_len)
            nn_context = np.pad(nn_context, (0, self.pred_dim))
            nn_context = torch.tensor(nn_context[None, None, ...], device=self.device)

            with torch.no_grad():
                # NN model inference; keep only the predicted tail
                nn_pred = self.neural_net(nn_context)
                nn_pred = nn_pred[..., -self.pred_dim:]
                nn_pred = nn_pred.squeeze().cpu().numpy()

            # Apply fade-in to the neural network contribution (inbound fade-in)
            nn_pred *= self.fade_in

            # Combine the two predictions
            prediction = ar_pred + nn_pred

            # When the following packet was received, also predict `extra_dim` samples into it
            # and cross-fade; otherwise only the lost packet itself is filled.
            next_is_valid = i + 1 < len(trace) and not trace[i + 1]
            prediction_length = self.pred_dim if next_is_valid else self.packet_dim

            # Clip to the samples that actually exist (the last packet may be partial)
            n = min(prediction_length, signal_len - idx)

            if next_is_valid:
                # Cross-fade the compound prediction (outbound fade-out)
                prediction *= self.fade_out

                # Cross-fade the output signal (outbound fade-in)
                output_signal[idx:idx + n] *= 1 - self.fade_out[:n]

            # Conceal lost packet
            output_signal[idx:idx + n] += prediction[:n]

        return output_signal

In [5]:
import numpy as np
import librosa
import os
import soundfile as sf
from pathlib import Path

In [6]:
# Directory layout of the challenge data, relative to the current working directory
lossy_path = Path("lossy")             # lossy input recordings (<name>.wav)
traces_path = Path("traces")           # per-recording packet-loss traces (<name>.txt)
prediction_folder = Path("processed")  # concealed output recordings (<name>.wav)

In [7]:
# One recording name per line. Blank lines (e.g. an extra empty line at the end of the file)
# are skipped so they do not turn into empty filenames in the cells below.
with open("meta.txt") as f:
    filenames = [name for name in (line.rstrip('\n') for line in f) if name]

In [69]:
# Load each recording's packet-loss trace: one integer flag per line, where a non-zero
# value marks the corresponding packet as lost. Blank lines are skipped rather than
# crashing on int('').
traces = {}
for filename in filenames:
    with open(traces_path.joinpath(f"{filename}.txt")) as f:
        traces[filename] = np.asarray([int(line) for line in f if line.strip()])

In [161]:
model = PARCnet(
    # Packetization: 512-sample packets, plus 88 extra predicted samples
    # used to cross-fade into a following received packet
    packet_dim=512,
    extra_dim=88,
    # Linear predictor: order-128 AR model fitted on 10 past packets (5120 samples)
    ar_order=128,
    ar_diagonal_load=0.001,
    num_valid_ar_packets=10,
    # Neural network: TorchScript checkpoint fed 7 past packets (3584 samples),
    # its prediction faded in over the first 10 samples
    model_checkpoint='parcnet_tloss-version_2.pth',
    num_valid_nn_packets=7,
    xfade_len_in=10,
    device='cpu',
)

In [162]:
import shutil

# Start from an empty output folder. NOTE: this deletes any results from previous runs.
shutil.rmtree(prediction_folder, ignore_errors=True)
prediction_folder.mkdir(parents=True, exist_ok=True)

# Conceal every recording; keep (lossy, concealed) pairs for later inspection
pred = []
for filename, trace in traces.items():
    audio, sr = librosa.load(lossy_path.joinpath(f"{filename}.wav"), sr=44100)
    processed = model(audio, trace)
    pred.append((audio, processed))

    # Write at the sample rate librosa resampled to
    sf.write(prediction_folder.joinpath(f"{filename}.wav"), processed.T, sr)

100%|██████████| 1000/1000 [00:06<00:00, 160.70it/s]
100%|██████████| 1000/1000 [00:01<00:00, 911.48it/s]
100%|██████████| 1000/1000 [00:03<00:00, 302.88it/s]
100%|██████████| 1000/1000 [00:02<00:00, 393.81it/s]
100%|██████████| 1000/1000 [00:00<00:00, 7018.84it/s]
100%|██████████| 1000/1000 [00:04<00:00, 226.81it/s]
100%|██████████| 1000/1000 [00:03<00:00, 301.88it/s]
100%|██████████| 1000/1000 [00:02<00:00, 463.64it/s]
100%|██████████| 1000/1000 [00:02<00:00, 341.71it/s]
100%|██████████| 1000/1000 [00:02<00:00, 356.72it/s]
100%|██████████| 1000/1000 [00:02<00:00, 446.56it/s]
100%|██████████| 1000/1000 [00:01<00:00, 944.57it/s]
100%|██████████| 1000/1000 [00:01<00:00, 753.57it/s]
100%|██████████| 1000/1000 [00:03<00:00, 284.05it/s]
100%|██████████| 1000/1000 [00:02<00:00, 398.89it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1501.05it/s]
100%|██████████| 1000/1000 [00:01<00:00, 689.90it/s]
100%|██████████| 1000/1000 [00:02<00:00, 367.36it/s]
100%|██████████| 1000/1000 [00:01<00:00, 942