In [1]:
import datetime
import functorch
import scipy.signal
import kymatio
from kymatio.torch import TimeFrequencyScattering1D
import numpy as np
import os
import pandas as pd
from pnp_synth.physical import ftm
from pnp_synth.perceptual import jtfs
import scipy.signal
import sklearn
from sklearn.preprocessing import MinMaxScaler
import sys
import time
import torch


# Data directory inside the user's home folder.
# NOTE(review): `csv_path` is reassigned to per-fold CSV file paths by the
# loading loop further down, so this value is not what later cells see.
csv_path = os.path.expanduser("~/perceptual_neural_physical/data")

# Report the version of each key dependency so results can be traced back
# to a concrete environment.
version_line = "{} version: {:s}"
for lib in (functorch, kymatio, np, pd, sklearn, torch):
    print(version_line.format(lib.__name__, lib.__version__))
print("")

functorch version: 0.2.1
kymatio version: 0.3.dev0
numpy version: 1.21.5
pandas version: 1.3.5
sklearn version: 1.1.2
torch version: 1.12.1



In [2]:
# Read one parameter log per fold from ../data, keyed by fold name.
folds = ["train", "test", "val"]
fold_dfs = {}
for fold in folds:
    csv_name = fold + "_param_log_v2.csv"
    csv_path = os.path.join("..", "data", csv_name)
    fold_df = fold_dfs[fold] = pd.read_csv(csv_path)

# Stack all folds into a single frame ordered by "ID"; the original row
# labels are kept (ignore_index=False).
full_df = pd.concat(list(fold_dfs.values()))
full_df = full_df.sort_values(by="ID", ignore_index=False)

# Every row must carry a distinct ID across the three folds.
n_distinct_ids = len(set(full_df["ID"]))
assert n_distinct_ids == len(full_df)

def preprocess_gt(full_df):
    """Min-max scale the parameter columns of ``full_df``.

    The scaler is fitted only on the rows whose "set" column equals
    "train", then applied to every row. Only the columns selected by the
    positional slice ``3:-1`` are scaled.

    Args:
        full_df: DataFrame holding all folds, with a "set" column.

    Returns:
        A ``(full_df_norm, scaler)`` pair: the scaled values as an array
        (not a DataFrame) and the fitted ``MinMaxScaler``.
    """
    is_train = full_df["set"] == "train"
    train_values = full_df.loc[is_train].values[:, 3:-1]
    all_values = full_df.values[:, 3:-1]

    # Fit on the training rows only, then transform every row.
    scaler = MinMaxScaler()
    scaler.fit(train_values)
    return scaler.transform(all_values), scaler

# Scale the parameter columns of all folds; `scaler` (fitted on the "train"
# rows) is reused below to map rescaled parameters back to their raw range.
full_df_norm, scaler = preprocess_gt(full_df)

def pnp_forward(Phi, g, scaler, rescaled_param):
    """Map rescaled parameters to a representation of the synthesized signal.

    The chain is: undo the min-max scaling, synthesize with ``g``, then
    analyse with ``Phi``.

    Args:
        Phi: callable applied to the synthesized signal.
        g: callable turning raw parameters ``theta`` into a signal.
        scaler: object exposing ``data_min_`` and ``data_max_`` (as set by
            a fitted ``MinMaxScaler``).
        rescaled_param: tensor of min-max scaled parameters.

    Returns:
        ``Phi(g(theta))`` where
        ``theta = rescaled_param * (data_max_ - data_min_) + data_min_``.
    """
    # Undo the min-max scaling.
    lower = torch.tensor(scaler.data_min_)
    upper = torch.tensor(scaler.data_max_)
    theta = rescaled_param * (upper - lower) + lower

    # Synthesis followed by spectral analysis.
    return Phi(g(theta))

def icassp23_synth(rescaled_param):
    """Synthesize a rectangular-drum signal and left-pad it with zeros.

    Despite its name, the argument is forwarded unchanged to
    ``ftm.rectangular_drum``; when called through ``pnp_forward`` it
    receives the inverse-scaled ``theta``.

    Args:
        rescaled_param: parameter tensor passed to ``ftm.rectangular_drum``
            together with ``ftm.constants``.

    Returns:
        The synthesized signal with 2**16 zeros prepended along its last
        dimension and nothing appended.
    """
    drum = ftm.rectangular_drum(rescaled_param, **ftm.constants)
    n_leading_zeros = 2**16
    return torch.nn.functional.pad(
        drum, (n_leading_zeros, 0), mode='constant', value=0)

# Settings of the joint time-frequency scattering operator.
jtfs_params = {
    "J": 14,  # scale
    # 2**17 samples: twice the 2**16 zeros prepended by icassp23_synth
    # (presumably 2**16 of padding plus 2**16 of signal -- TODO confirm).
    "shape": (2**17, ),
    "Q": 12,  # filters per octave, frequency resolution
    "T": 2**14,  # time averaging in samples
    "F": 2,  # frequency averaging in octaves
    "max_pad_factor": 1,
    "max_pad_factor_fr": 1,
    "average": True,
    "average_fr": True,
}
jtfs_operator = TimeFrequencyScattering1D(**jtfs_params)

def icassp23_pnp_forward(rescaled_param):
    """Run the full parameter-to-JTFS chain and return a flat vector.

    Uses the module-level ``jtfs_operator``, ``icassp23_synth`` and
    ``scaler``.

    Args:
        rescaled_param: tensor of min-max scaled parameters.

    Returns:
        The second half of the operator output along axis 1, flattened.
    """
    S = pnp_forward(jtfs_operator, icassp23_synth, scaler, rescaled_param)
    # Keep only the second half along axis 1.
    # NOTE(review): this is meant to discard the part that corresponds to the
    # zeros prepended by icassp23_synth -- confirm that axis 1 of the JTFS
    # output is the time axis.
    midpoint = S.shape[1] // 2
    return S[:, midpoint:, :].flatten()

# Forward-mode Jacobian of the flattened output of icassp23_pnp_forward
# with respect to its rescaled parameter vector.
icassp23_pnp_jacobian = functorch.jacfwd(icassp23_pnp_forward)

In [3]:
# Time a single forward pass at the midpoint (0.5) of the rescaled range.
# The start timestamp is kept as a float: truncating it to an integer would
# inflate the measured duration by up to one second.
start_time = time.time()

S = icassp23_pnp_forward(torch.ones((5,))*0.5)

elapsed_time = time.time() - start_time
elapsed_hours = int(elapsed_time / (60 * 60))
elapsed_minutes = int((elapsed_time % (60 * 60)) / 60)
elapsed_seconds = elapsed_time % 60.
elapsed_str = "{:>02}:{:>02}:{:>05.2f}".format(elapsed_hours,
                                               elapsed_minutes,
                                               elapsed_seconds)
print("Total elapsed time: " + elapsed_str + ".")

Total elapsed time: 00:00:36.93.


In [4]:
# Time the Jacobian at the midpoint (0.5) of the rescaled range and form the
# Gram matrix J^T J. The start timestamp is kept as a float: truncating it
# to an integer would inflate the measured duration by up to one second.
start_time = time.time()

grads = icassp23_pnp_jacobian(torch.ones((5,))*0.5)
JTJ = torch.matmul(grads.T, grads)

elapsed_time = time.time() - start_time
elapsed_hours = int(elapsed_time / (60 * 60))
elapsed_minutes = int((elapsed_time % (60 * 60)) / 60)
elapsed_seconds = elapsed_time % 60.
elapsed_str = "{:>02}:{:>02}:{:>05.2f}".format(elapsed_hours,
                                               elapsed_minutes,
                                               elapsed_seconds)
print("Total elapsed time: " + elapsed_str + ".")



Total elapsed time: 00:03:28.28.


In [10]:
# Singular value decomposition of the Gram matrix J^T J; the result exposes
# the fields U, S and V.
# NOTE(review): torch.svd is documented as deprecated in favour of
# torch.linalg.svd, which returns Vh rather than V -- confirm before migrating.
torch.svd(JTJ)

torch.return_types.svd(
U=tensor([[-1.0000e+00, -1.7002e-03,  2.3553e-04,  2.1577e-04,  1.1748e-16],
        [ 1.6808e-03, -9.9849e-01, -4.9252e-02, -2.4251e-02,  1.6975e-16],
        [-3.0662e-04,  4.7945e-02, -9.9093e-01,  3.8417e-02,  1.1950e-01],
        [-2.6269e-04,  2.5268e-02, -2.3423e-02, -9.9279e-01,  1.1480e-01],
        [ 6.7733e-05, -8.7510e-03,  1.2280e-01,  1.1092e-01,  9.8618e-01]],
       dtype=torch.float64),
S=tensor([1.3816e+04, 6.8258e+01, 2.8593e+00, 2.1038e-02, 4.4271e-17],
       dtype=torch.float64),
V=tensor([[-1.0000e+00, -1.7002e-03,  2.3553e-04,  2.1577e-04, -1.1734e-16],
        [ 1.6808e-03, -9.9849e-01, -4.9252e-02, -2.4251e-02, -2.1224e-16],
        [-3.0662e-04,  4.7945e-02, -9.9093e-01,  3.8417e-02, -1.1950e-01],
        [-2.6269e-04,  2.5268e-02, -2.3423e-02, -9.9279e-01, -1.1480e-01],
        [ 6.7733e-05, -8.7510e-03,  1.2280e-01,  1.1092e-01, -9.8618e-01]],
       dtype=torch.float64))

In [11]:
# Sanity check: rebuild JTJ from its factors as U @ diag(S) @ V^T.
# The decomposition is computed once and reused instead of being recomputed
# for each of the three factors.
JTJ_svd = torch.svd(JTJ)
JTJ_svd.U @ torch.diag(JTJ_svd.S) @ JTJ_svd.V.T

tensor([[ 1.3816e+04, -2.3106e+01,  4.2300e+00,  3.6264e+00, -9.3470e-01],
        [-2.3106e+01,  6.8098e+01, -3.1353e+00, -1.7244e+00,  5.8064e-01],
        [ 4.2300e+00, -3.1353e+00,  2.9659e+00,  1.4937e-01, -3.7677e-01],
        [ 3.6264e+00, -1.7244e+00,  1.4937e-01,  6.6837e-02, -2.5879e-02],
        [-9.3470e-01,  5.8064e-01, -3.7677e-01, -2.5879e-02,  4.8666e-02]],
       dtype=torch.float64)

In [12]:
# Display JTJ itself, for comparison with the product U @ diag(S) @ V^T
# shown by the previous cell.
JTJ

tensor([[ 1.3816e+04, -2.3106e+01,  4.2300e+00,  3.6264e+00, -9.3470e-01],
        [-2.3106e+01,  6.8098e+01, -3.1353e+00, -1.7244e+00,  5.8064e-01],
        [ 4.2300e+00, -3.1353e+00,  2.9659e+00,  1.4937e-01, -3.7677e-01],
        [ 3.6264e+00, -1.7244e+00,  1.4937e-01,  6.6837e-02, -2.5879e-02],
        [-9.3470e-01,  5.8064e-01, -3.7677e-01, -2.5879e-02,  4.8666e-02]],
       dtype=torch.float64)

In [26]:
# Remove row `exclude` and column `exclude` from JTJ, then decompose the
# remaining square sub-matrix.
exclude = 2
keep = [idx for idx in range(JTJ.shape[0]) if idx != exclude]
sub = JTJ[keep][:, keep]

torch.svd(sub)

torch.return_types.svd(
U=tensor([[-1.0000e+00, -1.6873e-03,  1.3530e-04, -1.8170e-04],
        [ 1.6807e-03, -9.9964e-01, -1.7719e-02,  1.9901e-02],
        [-2.6269e-04,  2.5248e-02, -3.9122e-01,  9.1995e-01],
        [ 6.7725e-05, -8.5149e-03,  9.2013e-01,  3.9152e-01]],
       dtype=torch.float64),
S=tensor([1.3816e+04, 6.8107e+01, 4.8350e-02, 1.7804e-02], dtype=torch.float64),
V=tensor([[-1.0000e+00, -1.6873e-03,  1.3530e-04, -1.8170e-04],
        [ 1.6807e-03, -9.9964e-01, -1.7719e-02,  1.9901e-02],
        [-2.6269e-04,  2.5248e-02, -3.9122e-01,  9.1995e-01],
        [ 6.7725e-05, -8.5149e-03,  9.2013e-01,  3.9152e-01]],
       dtype=torch.float64))

##### import matplotlib
# NOTE(review): this cell has no `In [ ]:` prompt in the export. The `S`
# assigned by the forward-pass timing cell is the return value of
# icassp23_pnp_forward, which ends with `.flatten()`, so the three-index
# access `S[0,:,:]` below looks incompatible with that `S` -- confirm which
# (unflattened) `S` this plot is meant to use.
%matplotlib inline
from matplotlib import pyplot as plt

# Plot, with line and circle markers, the maximum of S[0] taken along its
# first axis.
plt.plot(np.max(S[0,:,:].numpy(), axis=0), '-o')

In [None]:
from matplotlib import pyplot as plt

# For each row of the Jacobian, take the largest absolute entry along the
# last axis; sort these peaks in ascending order, accumulate them, and plot
# the running sum in reversed order.
row_peaks = torch.max(np.abs(grads), axis=-1).values
running_sum = np.cumsum(np.sort(row_peaks))
plt.plot(running_sum[::-1])

In [None]:
# Same decomposition of JTJ as in the earlier SVD cell (this cell is
# unexecuted in the export).
torch.svd(JTJ)