In [53]:
## IMPORT LIBRARIES 
import os
import sys
import igl
import time
import torch
import numpy as np
import meshplot as mp
from typing import Tuple
import matplotlib.pyplot as plt

# Make the parent of the current working directory importable so that local
# modules such as smpl_torch_batch can be found; avoid adding it twice.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

## IMPORT SOURCE
from smpl_torch_batch import SMPLModel

In [54]:
## LOAD ANIMATION DATA (ORIGINAL SCANS OF SINGLE INSTANCE)

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
training_data = torch.load('../data/50004_dataset.pt')
data_loader = torch.utils.data.DataLoader(training_data, batch_size=1, shuffle=False)

device = "cpu"
smpl_model = SMPLModel(device=device, model_path='../body_models/smpl/female/model.pkl')

# Only the first sequence is analysed: take it directly instead of `for ...: break`.
# An empty dataset now fails right here (StopIteration) instead of silently leaving
# every variable below undefined until a later cell raises a NameError.
data = next(iter(data_loader))

# data[0] packs the per-frame SMPL parameters column-wise:
# [:10] betas, [10:82] pose, [82:] translation.
beta_pose_trans_seq = data[0].squeeze().type(torch.float64)
betas = beta_pose_trans_seq[:, :10]
pose = beta_pose_trans_seq[:, 10:82]
trans = beta_pose_trans_seq[:, 82:]

# data[1] holds the target (DFAUST) vertices for the same frames
# (presumably (num_frames, num_verts, 3) -- TODO confirm against the dataset builder).
target_verts = data[1].squeeze()
smpl_verts, joints = smpl_model(betas, pose, trans)

V_smpl = np.array(smpl_verts, dtype=float)
V_dfaust = np.array(target_verts, dtype=float)
F = np.array(smpl_model.faces, dtype=int)

In [55]:
def generate_sine_wave(
    amplitude: float,
    frequency: float,
    phase: float,
    sampling_rate: int,
    duration: float) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Generates a uniformly sampled sine wave y(t) = amplitude * sin(2*pi*frequency*t + phase).

    Samples are taken at t[n] = n / sampling_rate for n = 0 .. total_steps - 1,
    where total_steps = int(sampling_rate * duration). Consecutive samples are
    therefore exactly 1 / sampling_rate seconds apart.

    Args:
        amplitude (float): Amplitude
        frequency (float): Frequency [Hz]
        phase (float): Phase [rad]
        sampling_rate (int): Sampling rate [Hz]
        duration (float): Duration [s]

    Returns:
        Tuple[torch.FloatTensor, torch.FloatTensor]: t, y(t), each with
        int(sampling_rate * duration) samples.
    """
    total_steps = int(sampling_rate * duration)
    # torch.linspace(0, duration, total_steps) includes the end point, which gives a
    # spacing of duration / (total_steps - 1) -- not the requested sampling rate.
    t = torch.arange(total_steps, dtype=torch.get_default_dtype()) / sampling_rate
    y = amplitude * torch.sin(2 * np.pi * frequency * t + phase)

    return t, y

In [56]:
def get_fft_components(signal, fourier, sampling_rate = 1_000, plot=True):
    """Return the magnitude spectrum of an rfft and its frequency bins, optionally plotting it.

    Args:
        signal: 1-D time-domain signal that ``fourier`` was computed from; only its
            length is used here, to rebuild the rfft frequency bins.
        fourier: Complex spectrum, expected to be ``torch.fft.rfft(signal, ...)``.
        sampling_rate: Sampling rate of ``signal`` [Hz]; scales the frequency axis.
        plot: If True, scatter |F| against frequency and save the figure to
            ./plots/fourier_transform_sampling_<sampling_rate>_<timestamp>.png.

    Returns:
        (absolutes, freq): ``|fourier|`` and the matching rfft frequencies [Hz].
    """
    absolutes = fourier.abs()
    freq = torch.fft.rfftfreq(len(signal), 1/sampling_rate)

    if plot:
        fig, ax = plt.subplots()
        ax.set_xlabel('Frequencies')
        ax.set_ylabel('$|F(s)|$')
        ax.grid()
        ax.scatter(x=freq, y=absolutes, s=8)
        # savefig does not create missing directories; make sure ./plots exists.
        os.makedirs('./plots', exist_ok=True)
        fig.savefig('./plots/fourier_transform_sampling_{}_{}.png'.format(sampling_rate, time.time()))

    return absolutes, freq

In [57]:
def get_fft_of_mesh_animation(signal_batch, verbose_idx=-1, sampling_rate=1_000):
    """Per-vertex, per-axis magnitude spectra of a mesh animation.

    Args:
        signal_batch: torch tensor of shape (num_frames, num_verts, dim) holding the
            vertex trajectories over time (torch.fft does not accept numpy arrays).
        verbose_idx: If > 0, print a progress line every ``verbose_idx`` vertices.
        sampling_rate: Sampling rate of the animation [Hz]. It only affects the
            returned frequency axis. Defaults to 1_000, the value that was
            previously hard-coded.

    Returns:
        List of length num_verts. Entry i is a list of length dim, and entry j of
        that list is a tuple (absolutes, freqs):
        - absolutes: the magnitudes of the forward-normalised rfft of signal_batch[:, i, j].
        - freqs: the matching frequencies [Hz].
        All tuples share one ``freqs`` tensor, because the bins depend only on
        num_frames and sampling_rate.
    """
    time_dim, verts_dim, space_dim = signal_batch.shape

    # One batched rfft along the time axis instead of verts_dim * space_dim separate calls.
    absolutes = torch.fft.rfft(signal_batch, dim=0, norm='forward').abs()
    freqs = torch.fft.rfftfreq(time_dim, 1/sampling_rate)

    fft_batch = []
    for i in range(verts_dim):
        fft_batch.append([(absolutes[:, i, j], freqs) for j in range(space_dim)])
        if verbose_idx > 0 and (i + 1) % verbose_idx == 0:
            print(">> Step ", i+1 , "/", verts_dim)

    return fft_batch

In [58]:
# Magnitude spectra of every vertex trajectory: SMPL fit first, then the DFAUST target.
PROGRESS_EVERY = 1000
smpl_fft = get_fft_of_mesh_animation(smpl_verts, verbose_idx=PROGRESS_EVERY)
dfaust_fft = get_fft_of_mesh_animation(target_verts, verbose_idx=PROGRESS_EVERY)

>> Step  1000 / 6890
>> Step  2000 / 6890
>> Step  3000 / 6890
>> Step  4000 / 6890
>> Step  5000 / 6890
>> Step  6000 / 6890
>> Step  1000 / 6890
>> Step  2000 / 6890
>> Step  3000 / 6890
>> Step  4000 / 6890
>> Step  5000 / 6890
>> Step  6000 / 6890


In [59]:
# Per-vertex, per-axis difference of magnitude spectra: |FFT(DFAUST)| - |FFT(SMPL)|.
# zip() would silently drop vertices if the two batches differed in size.
assert len(smpl_fft) == len(dfaust_fft), "SMPL and DFAUST spectra cover different vertex counts"

fft_diff = []
for vert_idx, (smpl_vertex_fft, dfaust_vertex_fft) in enumerate(zip(smpl_fft, dfaust_fft)):
    fft_diff_xyz = []
    for axis, (smpl_xyz_fft, dfaust_xyz_fft) in enumerate(zip(smpl_vertex_fft, dfaust_vertex_fft)):

        smpl_abs, smpl_freqs = smpl_xyz_fft
        dfaust_abs, dfaust_freqs = dfaust_xyz_fft

        # Sanity check: both spectra must use identical frequency bins (same shape and values).
        assert torch.equal(smpl_freqs, dfaust_freqs), \
            f"frequency bins differ at vertex {vert_idx}, axis {axis}"

        fft_diff_xyz.append(dfaust_abs - smpl_abs)
    fft_diff.append(fft_diff_xyz)

In [67]:
def reconstruct_signal_from_fft(fft_tuple, duration, sampling_rate = 1000, 
                                reconst_all=True, lower_idx=0, num_freqs=20):
    """Resynthesise a signal as a sum of zero-phase sines from a magnitude spectrum.

    Computes y(t) = sum_k absolutes[k] * sin(2*pi*freqs[k]*t) on the uniform grid
    t[n] = n / sampling_rate. Only `duration` samples are returned: grid indices
    spread uniformly over [0.25, 0.75] * sampling_rate.

    Args:
        fft_tuple: (absolutes, freqs) -- magnitudes and their frequencies [Hz],
            both of the same length.
        duration: Length of the virtual dense signal [s]; also the number of
            returned samples.
        sampling_rate: Rate of the dense grid [Hz].
        reconst_all: If True, use every frequency bin and ignore lower_idx/num_freqs.
        lower_idx: First bin to use when reconst_all is False.
        num_freqs: Number of consecutive bins to use when reconst_all is False.

    Returns:
        1-D tensor with `duration` samples.

    NOTE(review): only magnitudes are available, so every phase is taken as 0.
    Because the basis is sin(), the DC bin (frequency 0) contributes nothing.
    This is therefore not an inverse FFT of the original signal -- confirm that
    this is intended.
    """
    absolutes, freqs = fft_tuple
    assert len(absolutes) == len(freqs)

    if reconst_all:
        num_freqs = len(freqs)
        lower_idx = 0

    # Indices into the dense grid (duration * sampling_rate samples) that are returned.
    sampled_idx = (torch.linspace(0.25, 0.75, duration) * sampling_rate).long()
    # Evaluate the sinusoids only at those instants. Synthesising all
    # duration * sampling_rate samples per bin and discarding almost all of them
    # is ~sampling_rate times more work for the same result.
    t = sampled_idx.to(torch.get_default_dtype()) / sampling_rate

    reconst_signal = torch.zeros(len(sampled_idx))
    ############# SUM OF SINUSOIDS ##########################################################################
    for i in range(lower_idx, lower_idx + num_freqs):
        reconst_signal += absolutes[i] * torch.sin(2 * np.pi * freqs[i] * t)
    #########################################################################################################

    return reconst_signal

In [76]:
duration, num_verts, space_dim = smpl_verts.shape
# The rfft bins depend only on signal length and sampling rate, so every vertex/axis
# shares them; take the first. TODO: store them once instead of per vertex.
freqs = smpl_fft[0][0][1]

# Resynthesise the spectrum difference of every vertex and axis into a time signal.
reconst_diff = [
    [reconstruct_signal_from_fft((fft_diff[vert_idx][dim], freqs), duration)
     for dim in range(space_dim)]
    for vert_idx in range(num_verts)
]

torch.Size([268])


In [71]:
# Stack the nested per-vertex/per-axis signals into one array and store it under the
# key 'arr_0' (the name np.savez gives a single positional array).
diff_anim_array = np.array(reconst_diff)
np.savez("FFT_diff_anim.npz", arr_0=diff_anim_array)

In [72]:
with np.load("FFT_diff_anim.npz") as npz_archive:
    # Indexing the archive reads the array into memory, so it stays valid after the file closes.
    V_diff_anim = npz_archive['arr_0']
    print(">> File loaded.")

>> File loaded.


In [82]:
print(V_diff_anim.shape)
# Move the last (time) axis to the front: (a, b, c) -> (c, a, b), i.e. frame-major
# like smpl_verts. NOTE: not idempotent -- re-running this cell permutes again.
V_diff_anim = np.transpose(V_diff_anim, (2, 0, 1))
print(V_diff_anim.shape)

(6890, 3, 268)
(268, 6890, 3)


In [None]:
# Replay the SMPL animation with the reconstructed FFT difference added per frame.
NUM_LOOPS = 3          # number of times the animation is replayed
FRAME_DELAY_S = 0.1    # pause between frames [s]

v_start = np.array(smpl_verts[0])
p = mp.plot(v_start, F)

# Guards against a mismatched or doubly-permuted V_diff_anim (e.g. after re-running
# the axis-permutation cell).
assert V_diff_anim.shape == (duration, *v_start.shape), V_diff_anim.shape

for _ in range(NUM_LOOPS):
    for i in range(duration):
        v_new = np.array(smpl_verts[i]) + V_diff_anim[i]
        p.update_object(vertices=v_new)
        time.sleep(FRAME_DELAY_S)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0253415…