In [1]:
import os
import re
from tqdm import tqdm

import torch

In [3]:
ignored_columns = [
    "EyeBlinkLeft", 
    "EyeBlinkRight", 
    "EyeLookDownLeft", 
    "EyeLookDownRight", 
    "EyeLookInLeft", 
    "EyeLookInRight", 
    "EyeLookOutLeft", 
    "EyeLookOutRight", 
    "EyeBlinkLeft", 
    "EyeBlinkRight", 
    "EyeLookDownLeft", 
    "EyeLookDownRight", 
    "EyeLookInLeft", 
    "EyeLookInRight", 
    "EyeLookOutLeft", 
    "EyeLookOutRight", 
    "EyeLookUpLeft", 
    "EyeLookUpRight", 
    "EyeSquintLeft", 
    "EyeSquintRight", 
    "EyeWideLeft", 
    "EyeWideRight", 
    "BrowDownLeft", 
    "BrowDownRight", 
    "BrowInnerUp", 
    "BrowOuterUpLeft", 
    "BrowOuterUpRight", 
    "CheekSquintLeft", 
    "CheekSquintRight", 
    "JawLeft", 
    "JawRight", 
    "MouthLeft", 
    "MouthRight", 
    "MouthUpperUpLeft", 
    "MouthUpperUpRight", 
    "MouthLowerDownLeft", 
    "MouthLowerDownRight", 
    "MouthSmileLeft", 
    "MouthSmileRight", 
    "MouthFrownLeft", 
    "MouthFrownRight", 
    "NoseSneerLeft", 
    "NoseSneerRight", 
    "HeadYaw", 
    "HeadPitch", 
    "HeadRoll", 
    "TongueOut", 
    "LeftEyeYaw", 
    "LeftEyePitch", 
    "LeftEyeRoll", 
    "RightEyeYaw", 
    "RightEyePitch", 
    "RightEyeRoll"
]

In [4]:
# audio_ggongggong
indices = []
spectrograms = []
spectrogram_lengths = []

# shape_ggongggong
timecodes = []
blendshapes = []
blendshape_lengths = []
f_names = []

for essential in tqdm(os.scandir('/shared/air/shared/youngkim/mediazen/preprocessed/test/essentials/')):
    
    f_name = os.path.splitext(essential.name)[0]
    idx = int(f_name.split('_')[0])
    speaker = re.sub(r'[0-9]+', '', f_name.split('_')[2])

    # spectrogram: torch.Tensor (audio_frame, 161)
    # sample_rate: Int, 16000
    # blendshape: Dict (Timecode, BlendShapeCount, *(BlendShapeColumns))
    spectrogram, sample_rate, blendshape = torch.load(essential.path)
    spectrogram_length = len(spectrogram)
    # ignored_columns에 대문자 오타때문에 다시 걸러줌
    for column in ignored_columns:
        if column in blendshape.keys():
            del(blendshape[column])
    timecode = blendshape.pop('Timecode') # List (shape_frame)
    blendshape_count = blendshape.pop('BlendShapeCount')[0] # Int, 61 -> 필요없음
    blendshape_columns = list(blendshape.keys()) # List (num. of blendshape)
    try:
        blendshape_tensor = torch.Tensor(list(blendshape.values())).T # torch.Tensor (shape_frame, num. of blendshape)
    except TypeError:
        print('blendshape type error: ', essential.path)
        continue
    blendshape_length = len(blendshape_tensor)
    

    # error check
    if sample_rate != 16000:
        print('sample rate error: ', essential.path)
        continue

    if torch.sum(spectrogram.isnan()):
        print('spectrogram nan error: ', essential.path)
        continue
    
    if torch.sum(blendshape_tensor.isnan()):
        print('blendshape nan error: ', essential.path)
        break

    # audio_ggongggong
    indices.append(idx)
    spectrograms.append(spectrogram)
    spectrogram_lengths.append(spectrogram_length)
    # shape_ggongggong
    timecodes.append(timecode)
    blendshapes.append(blendshape_tensor)
    blendshape_lengths.append(blendshape_length)
    f_names.append(f_name)
    

indices_tensor = torch.IntTensor(indices)
padded_spectrogram_tensors = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
spectrogram_length_tensors = torch.IntTensor(spectrogram_lengths)

stripped_timecodes = [torch.LongTensor([int(time.replace(':', '').replace('.', '')) for time in timecode]) for timecode in timecodes]
padded_timecodes = torch.nn.utils.rnn.pad_sequence(stripped_timecodes, batch_first=True)
padded_blendshape_tensors = torch.nn.utils.rnn.pad_sequence(blendshapes, batch_first=True)
blendshape_length_tensors = torch.IntTensor(blendshape_lengths)

if len(padded_spectrogram_tensors) != len(padded_blendshape_tensors):
    print("length error, audio and shape length doesn't match.")
else:
    print(len(padded_spectrogram_tensors), ' sentences loaded.')

spectrogram_data = (sample_rate, indices_tensor, padded_spectrogram_tensors, spectrogram_length_tensors)
blendshape_data = (padded_timecodes, blendshape_count, blendshape_columns, padded_blendshape_tensors, blendshape_length_tensors, f_names)

7it [00:00, 66.50it/s]

7  sentences loaded.





In [5]:
target_dir = '/shared/air/shared/youngkim/mediazen/preprocessed/test/ggongggong'

target_spec_data = os.path.join(target_dir, 'audio_ggongggong.pt')
target_shape_data = os.path.join(target_dir, 'shape_ggongggong.pt')

torch.save(spectrogram_data, target_spec_data)
torch.save(blendshape_data, target_shape_data)

In [7]:
len(blendshape_data[2])

16

In [8]:
import pandas as pd

timecode = blendshape_data[0][0]

recovered_timecode =  [f'{(s := str(time.item()))[:-9]}:{s[-9:-7]}:{s[-7:-5]}:{s[-5:-3]}.{s[-3:]}' for time in timecode]
timecode_index = pd.Index(recovered_timecode, name='Timecode')

In [12]:
column = blendshape_data[2]
recovered_column = ['BlendShapeCount', *column]
len(recovered_column)

17

In [10]:
import numpy as np

idx = 0
length = blendshape_data[4][idx]
index = spectrogram_data[1][idx]
timecode = blendshape_data[0][idx]
column = blendshape_data[2]


# def save_to_csv(self, length, index, timecode, column, prediction):
recovered_timecode =  [f'{(s := str(time.item()))[:-9]}:{s[-9:-7]}:{s[-7:-5]}:{s[-5:-3]}.{s[-3:]}' for time in timecode]
timecode_index = pd.Index(recovered_timecode, name='Timecode')

blendshape_count = np.expand_dims(np.full(len(timecode), len(column)), axis=1)
filtered_prediction = np.apply_along_axis(self.column_filter, 0, prediction)
recovered_content = np.hstack([blendshape_count, filtered_prediction])

recovered_column = ['BlendShapeCount', *column]

df = pd.DataFrame(recovered_content, index=timecode_index, columns=recovered_column)
chopped_df = df[:length.item()]
chopped_df.to_csv(os.path.join(self.target_dir, f'{index}_prediction.csv'))

Index(['23:51:40:02.131', '23:51:40:03.131', '23:51:40:04.130',
       '23:51:40:05.130', '23:51:40:06.129', '23:51:40:08.129',
       '23:51:40:09.128', '23:51:40:10.128', '23:51:40:11.127',
       '23:51:40:12.127',
       ...
       ':::.0', ':::.0', ':::.0', ':::.0', ':::.0', ':::.0', ':::.0', ':::.0',
       ':::.0', ':::.0'],
      dtype='object', name='Timecode', length=7012)

In [10]:
def tensor_delete(tensor, remove_idx):
    return torch.cat((tensor[:remove_idx], tensor[remove_idx+1:])).detach().clone()

In [11]:
remove_idx = 1270

In [12]:
new_spectrogram_data = list(spectrogram_data)

for i in [1, 2, 3]:
    print(new_spectrogram_data[i].shape)
    new_spectrogram_data[i] = tensor_delete(new_spectrogram_data[i], remove_idx)
    print(new_spectrogram_data[i].shape)

new_spectrogram_data = tuple(new_spectrogram_data)

new_blendshape_data = list(blendshape_data)

for i in [0, 3, 4]:
    print(new_blendshape_data[i].shape)
    new_blendshape_data[i] = tensor_delete(new_blendshape_data[i], remove_idx)
    print(new_blendshape_data[i].shape)

new_blendshape_data = tuple(new_blendshape_data)

torch.Size([0])
torch.Size([0])
torch.Size([1443, 11682, 161])
torch.Size([1442, 11682, 161])
torch.Size([1443])
torch.Size([1442])
torch.Size([1443, 7012])
torch.Size([1442, 7012])
torch.Size([1443, 7012, 16])
torch.Size([1442, 7012, 16])
torch.Size([1443])
torch.Size([1442])
