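## Compute Laplacian smoothing and analyse the residual

Each frame of a scanned animation is smoothed with `igl.per_vertex_attribute_smoothing`. The per-vertex residual (original − smoothed) is then examined in the frequency domain and replayed on top of both the smoothed meshes and the SMPL body model.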

In [None]:
## IMPORT LIBRARIES 
import os
import sys
import igl
import time
import torch
import numpy as np
import meshplot as mp
from typing import Tuple
import matplotlib.pyplot as plt

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

## IMPORT SOURCE
from smpl_torch_batch import SMPLModel

In [None]:
## LOAD ANIMATION DATA (ORIGINAL SCANS OF SINGLE INSTANCE)

# NOTE: torch.load unpickles the file -- only load trusted datasets.
training_data = torch.load('../data/50004_dataset.pt')
data_loader = torch.utils.data.DataLoader(training_data, batch_size=1, shuffle=False)

device = "cpu"
smpl_model = SMPLModel(device=device, model_path='../body_models/smpl/female/model.pkl')

# Only the first sequence is used, so fetch it directly instead of looping and breaking.
# If the loader is empty nothing below gets assigned (same as the original loop).
first_batch = next(iter(data_loader), None)
if first_batch is not None:
    # Element 1 of the batch holds the scanned vertices; squeeze drops the batch dim of 1.
    # Downstream cells treat V as (frames, vertices, 3) -- TODO confirm against dataset.
    target_verts = first_batch[1].squeeze()
    V = np.array(target_verts, dtype=float)
    F = np.array(smpl_model.faces, dtype=int)
    

In [None]:
## Get Smoothed Animation

smoothing_steps = 10
verbose = False


def _smooth_frame(frame_verts, faces, steps):
    """Apply igl.per_vertex_attribute_smoothing `steps` times to one frame's vertices."""
    smoothed = frame_verts
    for _ in range(steps):
        smoothed = igl.per_vertex_attribute_smoothing(smoothed, faces)
    return smoothed


# Kept as a list of per-frame arrays rather than one stacked ndarray: per the original
# author, meshplot did not work with the stacked numpy version.
v_smooth_arr = []
anim_length = V.shape[0]
for frame_idx, frame in enumerate(V):
    v_smooth_arr.append(_smooth_frame(frame, F, smoothing_steps))

    if verbose and (frame_idx + 1) % 10 == 0:
        print(">> Step ", frame_idx + 1, "/", anim_length)



In [None]:
# PLOT THE SMOOTHED ANIMATION
# Disabled by default. (Previously this code sat inside a string literal, which Jupyter
# echoed as the cell output and never syntax-checked.) Set PLOT_SMOOTHED = True to play
# the smoothed animation.
PLOT_SMOOTHED = False

if PLOT_SMOOTHED:
    p = mp.plot(v_smooth_arr[0], F)

    for _ in range(2):  # play the animation twice
        for i in range(anim_length):
            p.update_object(vertices=v_smooth_arr[i])
            time.sleep(0.1)  # pause between frames [s]

## Experiments

In [None]:
def generate_sine_wave(
    amplitude: float,
    frequency: float,
    phase: float,
    sampling_rate: int,
    duration: float) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Generates a uniformly sampled sine wave y(t) = amplitude * sin(2*pi*frequency*t + phase).

    Samples are taken at t = k / sampling_rate for k = 0 .. total_steps - 1, so the
    spacing between consecutive samples is exactly 1 / sampling_rate.

    Args:
        amplitude (float): Amplitude
        frequency (float): Frequency [Hz]
        phase (float): Phase [rad]
        sampling_rate (int): Sampling rate [Hz]
        duration (float): Duration [s]

    Returns:
        Tuple[torch.FloatTensor, torch.FloatTensor]: t, y(t), each with
        round(sampling_rate * duration) samples.
    """
    # round() rather than int(): e.g. int(100 * 0.57) == 56 because of float error.
    total_steps = round(sampling_rate * duration)
    # arange / rate gives a spacing of exactly 1/sampling_rate. torch.linspace(0, duration, n)
    # includes the endpoint, which would stretch the spacing to duration / (n - 1).
    t = torch.arange(total_steps) / sampling_rate
    y = amplitude * torch.sin(2 * np.pi * frequency * t + phase)

    return t, y

In [None]:
def get_fft_components(signal, fourier, sampling_rate = 1_000, plot=True):
    """Return the magnitude spectrum of `fourier` and its matching frequency axis.

    Args:
        signal: the time-domain signal that was transformed. Only its length is used,
            to build the one-sided frequency axis with torch.fft.rfftfreq.
        fourier: the transform of `signal`. Expected to be torch.fft.rfft output so its
            length matches the rfftfreq axis.
        sampling_rate: sampling rate of `signal` in Hz (sample spacing 1 / sampling_rate).
        plot: if True, scatter-plot |F(s)| against frequency and save the figure to
            ./plots/fourier_transform_sampling_<rate>_<timestamp>.png.

    Returns:
        (absolutes, freq): |fourier| and the frequency [Hz] of each bin.
    """
    absolutes = fourier.abs()
    freq = torch.fft.rfftfreq(len(signal), 1/sampling_rate)

    if plot:
        fig, ax = plt.subplots()
        ax.set_xlabel('Frequencies')
        ax.set_ylabel('$|F(s)|$')
        ax.grid()
        ax.scatter(x=freq, y=absolutes, s=8)
        # savefig does not create missing directories; make sure ./plots exists first.
        os.makedirs('./plots', exist_ok=True)
        fig.savefig('./plots/fourier_transform_sampling_{}_{}.png'.format(sampling_rate, time.time()))

    return absolutes, freq

In [None]:
# Wrap the original and smoothed animations as tensors for the FFT cells.
# Presumably indexed as [frame, vertex, coord] -- the FFT cells read shape[0..2] that way.
V_tensor = torch.from_numpy(V)
V_smooth_tensor = torch.from_numpy(np.array(v_smooth_arr))

print(V_smooth_tensor.shape, V_tensor.shape)


In [None]:
###### FFT Computation (batched over all vertices and coordinates)
# For every vertex coordinate, the residual over time (original - smoothed) is transformed
# with a one-sided FFT. The magnitudes of bins [band_start, band_stop) are re-synthesised as
# a sum of zero-phase sine waves, which is then resampled back to `period` frames.
#
# NOTE(review): only magnitudes are used (phase is discarded), and the frequency axis
# assumes frames are sampled at `sampling_rate` Hz. Both are kept exactly as the original
# per-vertex loop did -- confirm they are intended.

period = V_tensor.shape[0]
num_verts = V_tensor.shape[1]
dims = V_tensor.shape[2]

sampling_rate = 1000
band_start, band_stop = 10, 80  # frequency bins to re-synthesise (stop is exclusive)

# A single rfft along the time axis covers every (vertex, coord) signal at once. This is
# identical to calling torch.fft.rfft on each V_tensor[:, vert, dim] - V_smooth_tensor[:, vert, dim].
V_diff_fourier = torch.fft.rfft(V_tensor - V_smooth_tensor, dim=0, norm='forward')
absolutes = V_diff_fourier.abs()                       # (num_bins, num_verts, dims)
freqs = torch.fft.rfftfreq(period, 1 / sampling_rate)  # same axis for every vertex

# Slicing clamps to the available bins. The old `i > len(freqs)` guard was off by one and
# let freqs[len(freqs)] raise IndexError on short animations.
band = slice(band_start, band_stop)
band_freqs = freqs[band]

# The reconstruction is only ever read at these indices, so each unit sine only needs to
# be kept there. The duration is `period` seconds at `sampling_rate` Hz, as before.
sampled_idx = (torch.linspace(0.25, 0.75, period) * sampling_rate).long()

if band_freqs.numel() > 0:
    # Unit-amplitude sine per band frequency, sampled at sampled_idx: (band_len, period).
    # The sine basis is identical for every vertex, so it is built once, not per signal.
    basis = torch.stack([
        generate_sine_wave(amplitude=1.0, frequency=f, phase=0,
                           sampling_rate=sampling_rate, duration=period)[1][sampled_idx]
        for f in band_freqs
    ])
    # V_diff_anim[s, v, d] = sum_k |X_k[v, d]| * basis[k, s]
    V_diff_anim = torch.einsum('ks,kvd->svd', basis.to(absolutes.dtype), absolutes[band])
else:
    # No bins in the band: the reconstruction is identically zero.
    V_diff_anim = torch.zeros_like(V_tensor)

assert V_diff_anim.shape == V_tensor.shape

        

In [None]:
## BATCHED FFT: one transform over the time axis for all vertices and coordinates.
# The earlier attempt used torch.fft.fftn(..., dim=2). That is a full complex FFT across the
# 3 coordinates of each vertex rather than across frames, so its bins did not line up with
# the one-sided rfftfreq axis. Transforming along dim=0 with rfft, using the same
# norm='forward' as the per-vertex version, fixes that.
num_frames = V_tensor.shape[0]
sampling_rate = 1000
freq = torch.fft.rfftfreq(num_frames, 1/sampling_rate)  # freqs spectrum is the same for every vertex

V_diff_signal = (V_tensor - V_smooth_tensor)
V_diff_fourier = torch.fft.rfft(V_diff_signal, dim=0, norm='forward')

absolutes = V_diff_fourier.abs()
# absolutes[k, vert, dim] is the magnitude at frequency freq[k].
assert absolutes.shape[0] == freq.shape[0]

print(freq.shape)
print(V_diff_fourier.shape, absolutes.shape)
 

In [None]:
# Persist the reconstructed residual animation under the default positional key 'arr_0'.
# np.asarray accepts both a CPU torch tensor and an ndarray, so re-running after the
# playback cell's numpy conversion still works.
np.savez("V_diff_anim.npz", arr_0=np.asarray(V_diff_anim))

In [None]:
# PLOT THE ADDED JIGGLING ANIMATION
# Optional experiment: zero the residual for vertices with index >= 300 (axis 1 is the
# vertex axis, not the frame axis).
#V_diff_anim[:,300:,:] = 0.0

# Convert (copying) to numpy so it can be added to the per-frame numpy vertex arrays.
V_diff_anim = np.array(V_diff_anim)

NUM_LOOPS = 3        # how many times to replay the animation
FRAME_DELAY_S = 0.1  # pause between frames [s]

p = mp.plot(v_smooth_arr[0], F)

for _ in range(NUM_LOOPS):
    for i in range(anim_length):
        # Smoothed surface plus the band-limited residual reconstructed above.
        p.update_object(vertices=v_smooth_arr[i] + V_diff_anim[i])
        time.sleep(FRAME_DELAY_S)

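## Visualize SMPL

Run the SMPL model on the dataset's shape/pose/translation parameters, then replay the result with the saved residual animation added on top (scaled by 2).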

In [None]:
# Reload the residual animation; np.savez stored it under the key 'arr_0'.
npz_archive = np.load("V_diff_anim.npz")
try:
    V_diff_anim = npz_archive['arr_0']
    print(">> File loaded.")
finally:
    # Release the underlying file handle, as the context manager did.
    npz_archive.close()

In [None]:

# Evaluate the SMPL model on the first sequence only (fetched directly instead of
# looping and breaking). If the loader is empty, nothing is assigned, as before.
first_batch = next(iter(data_loader), None)
if first_batch is not None:
    beta_pose_trans_seq = first_batch[0].squeeze().to(torch.float64)
    # Per-frame column layout as sliced here: [0:10] betas, [10:82] pose (72 values),
    # [82:] translation.
    betas, pose, trans = (
        beta_pose_trans_seq[:, :10],
        beta_pose_trans_seq[:, 10:82],
        beta_pose_trans_seq[:, 82:],
    )

    target_verts = first_batch[1].squeeze()
    smpl_verts, joints = smpl_model(betas, pose, trans)

In [None]:
# Play the SMPL animation with the reconstructed residual added on top.
JIGGLE_SCALE = 2     # multiplier on the residual to exaggerate it
NUM_LOOPS = 3        # how many times to replay the animation
FRAME_DELAY_S = 0.1  # pause between frames [s]

# Convert once, instead of once per frame on every replay.
smpl_verts_np = np.array(smpl_verts)

p = mp.plot(smpl_verts_np[0], F)

for _ in range(NUM_LOOPS):
    for i in range(anim_length):
        # The former `[np.array(...)][0]` "meshplot workaround" was a no-op: wrapping a value
        # in a one-element list and taking element 0 returns the same object.
        p.update_object(vertices=smpl_verts_np[i] + JIGGLE_SCALE * V_diff_anim[i])
        time.sleep(FRAME_DELAY_S)