In [97]:
import numpy as np
import scipy as sp
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
import torch
import fb_utils as fb
import time

In [87]:
def tight(w, ver='poly'):
    """Map a filterbank to its canonical tight counterpart.

    Every filter's DFT is divided, bin by bin, by sqrt(sum_m |W_m(k)|^2), so
    the returned filters satisfy sum_m |W_m(k)|^2 == 1 at every DFT bin k.

    Parameters
    ----------
    w : ndarray, shape (M, N)
        M filters of length N, one filter per row.
    ver : str, default 'poly'
        'poly' normalizes in the DFT domain (fast).  Any other value builds
        the frame operator S = sum_m C_m^H C_m from the circulant matrices
        C_m of the filters and applies S^{-1/2} explicitly.  This is
        O(N^3) and only practical for small N.  Both paths give the same
        filters up to round-off.

    Returns
    -------
    ndarray, shape (M, N)
        Tightened filters.  The result may be complex; for real input the
        imaginary part is round-off.  Callers wanting real filters should
        take ``np.real``.

    Notes
    -----
    A DFT bin where every filter has zero response cannot be normalized and
    yields NaN in the 'poly' branch.  In the matrix branch, S is singular
    there.
    """
    if ver == 'poly':
        # Row m holds the DFT of filter m.
        w_freqz = np.fft.fft(w, axis=1)
        # Norm over the filter axis, per frequency bin: sqrt(sum_m |W_m(k)|^2).
        norms = np.linalg.norm(w_freqz, axis=0, keepdims=True)
        # No conjugation here: conjugating the spectrum would circularly
        # time-reverse the filters, disagreeing with the matrix branch and
        # moving the taps out of a leading support window.
        w_tight = np.fft.ifft(w_freqz / norms, axis=1)
    else:
        N = w.shape[1]
        # Accumulate S = sum_m C_m^H C_m directly instead of stacking all
        # circulants into an (M*N, N) matrix first (same result, far less memory).
        S = np.zeros((N, N), dtype=np.result_type(w, np.float64))
        for w_k in w:
            C = sp.linalg.circulant(w_k)
            S += C.conj().T @ C
        S_inv_sqrt = np.linalg.inv(sp.linalg.sqrtm(S))
        w_tight = (S_inv_sqrt @ w.T).T
    return w_tight

In [109]:
def fir_tightener3000(w, supp, eps=1.1, print_kappa=False, max_iter=None):
    """Iteratively tighten a filterbank while keeping it FIR with ``supp`` taps.

    Each iteration applies ``tight(..., ver='poly')``, takes the real part,
    and zeros every tap from index ``supp`` onward.  It stops once
    B/A <= eps, where (A, B) = ``fb.frame_bounds_lp(...)``.
    NOTE(review): (A, B) is assumed to be (lower, upper) frame bounds;
    confirm in fb_utils.

    Parameters
    ----------
    w : ndarray, shape (M, N)
        Filters, one per row.  Not modified.
    supp : int
        Number of leading taps kept in each filter.
    eps : float, default 1.1
        Target condition number B/A.  Must be >= 1, since a lower frame
        bound never exceeds the upper one.
    print_kappa : bool, default False
        Print B/A and the Frobenius distance ||w - w_tight|| after each
        iteration.
    max_iter : int or None, default None
        Maximum number of iterations.  None keeps the original unbounded
        behavior.  If the limit is hit before reaching ``eps``, a
        RuntimeWarning is issued and the current iterate is returned.

    Returns
    -------
    ndarray, shape (M, N)
        Real filters whose taps at index >= ``supp`` are zero.  If ``w``
        already satisfies B/A <= eps, a copy of ``w`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``eps`` < 1, which would make the loop run forever.
    """
    import warnings

    if eps < 1:
        raise ValueError("eps must be >= 1 (B/A can never drop below 1)")

    A, B = fb.frame_bounds_lp(w)
    w_tight = w.copy()
    n_iter = 0
    while B / A > eps:
        if max_iter is not None and n_iter >= max_iter:
            warnings.warn(
                f"fir_tightener3000: stopped after {n_iter} iterations with "
                f"B/A = {B / A} > eps = {eps}",
                RuntimeWarning,
            )
            break
        # Project onto tight frames, then back onto real FIR filters of length supp.
        w_tight = np.real(tight(w_tight, ver='poly'))
        w_tight[:, supp:] = 0
        A, B = fb.frame_bounds_lp(w_tight)
        n_iter += 1
        if print_kappa:
            print('kappa:', B/A, 'error:', np.linalg.norm(w-w_tight))
    return w_tight

In [110]:
# Number of leading taps each tightened filter is allowed to keep.
T = 1024
# NOTE(review): the positional arguments 4096 and 96 are interpreted by
# fb_utils.random_filterbank (presumably signal length and filter count —
# confirm).  No random seed is set, so the filterbank changes between runs.
w = fb.random_filterbank(
    4096,
    96,
    T,
    tight=False,
    support_only=False,
    to_torch=False,
)
w_tight = fir_tightener3000(w, supp=T, eps=1.001)
