## Importando módulos

In [1]:
import torch
from torch import nn

## Disponibilidad de dispositivo

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Usando {device}')

Usando cuda


## Carga de datos
* Depende de los scripts de run_training, config_parser, data_handler y batcher

In [3]:
from run_training import get_train_elements

In [4]:
# Build everything the rest of the notebook needs in a single call to the
# project helper imported from run_training. `config` is indexed like a
# mapping by later cells and `batcher` supplies the training batches.
# NOTE(review): `data_handler` is not used by any cell visible in this section.
config, data_handler, batcher = get_train_elements()

Checkpoint dir already exists ./training
Use existing config ./training/config.pkl
Loading data
Loading face vertices
Loading templates
Loading raw audio
Process audio
Loading index maps
Initialize data splits
Initialize training, validation, and test indices
sequence data missing FaceTalk_170811_03274_TA - sentence01
sequence data missing FaceTalk_170811_03274_TA - sentence02
sequence data missing FaceTalk_170811_03274_TA - sentence24
sequence data missing FaceTalk_170913_03279_TA - sentence12
sequence data missing FaceTalk_170913_03279_TA - sentence38
sequence data missing FaceTalk_170912_03278_TA - sentence11
sequence data missing FaceTalk_170809_00138_TA - sentence32


* Conversión de los datos a tensores de tipo `float32` (los arreglos originales son `float64`, que PyTorch denomina `double`)

In [5]:
# Fetch one training batch (NumPy arrays) from the project batcher.
processed_audio, face_vertices, face_templates, subject_idx = batcher.get_training_batch(config['batch_size'])

# The three floating-point arrays are converted to float32 tensors in one pass.
processed_audio, face_vertices, face_templates = (
    torch.from_numpy(array).type(torch.float32)
    for array in (processed_audio, face_vertices, face_templates)
)
# The subject indices keep their integer dtype (needed later for one-hot encoding).
subject_idx = torch.from_numpy(subject_idx)

# Report shape and dtype of every tensor in the batch.
print("processed audio: ", processed_audio.shape, processed_audio.dtype)
print("face vertices: ", face_vertices.shape, face_vertices.dtype)
print("face templates: ", face_templates.shape, face_templates.dtype)
print("subject index: ", subject_idx.shape, subject_idx.dtype)

processed audio:  torch.Size([128, 16, 29]) torch.float32
face vertices:  torch.Size([128, 5023, 3]) torch.float32
face templates:  torch.Size([128, 5023, 3]) torch.float32
subject index:  torch.Size([128]) torch.int64


In [6]:
# One-hot encode every sample's subject index; the encoding width is the
# number of training subjects reported by the batcher.
num_training_subjects = batcher.get_num_training_subjects()
condition = nn.functional.one_hot(subject_idx, num_classes=num_training_subjects)
print(condition.shape)

torch.Size([128, 8])


## Speech Encoder

In [7]:
# Speech-encoder hyper-parameters read from the training configuration.
speech_encoding_dim = config['expression_dim']
# NOTE(review): the next two settings are read here but are not used by any
# cell visible in this section of the notebook.
condition_speech_features = config['condition_speech_features']
speech_encoder_size_factor = config['speech_encoder_size_factor']

# Named sizes so the channel arithmetic of the layers below is explicit.
num_audio_features = 29                              # feature channels per audio frame
num_subjects = batcher.get_num_training_subjects()   # width of the one-hot subject condition
conv_out_channels = 64                               # channels produced by the last convolution

# Per-channel normalisation of the audio features.
# PyTorch updates running statistics as
#     running = (1 - momentum) * running + momentum * batch_stat
# which is the opposite convention to TensorFlow's decay/momentum. A
# TensorFlow momentum of 0.9 therefore corresponds to momentum=0.1 here.
batch_norm = nn.BatchNorm1d(num_features=num_audio_features, eps=1e-5, momentum=0.1)

# Four strided convolutions along the time axis. The input is the normalised
# audio features with the one-hot subject condition stacked on the channel
# axis. Shape comments assume a batch of 128 and 16 time steps.
time_convs = nn.Sequential(
    nn.Conv2d(in_channels=num_audio_features + num_subjects, out_channels=32,
              kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)),
    nn.ReLU(),  # [128, 32, 1, 8]
    nn.Conv2d(in_channels=32, out_channels=32,
              kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)),
    nn.ReLU(),  # [128, 32, 1, 4]
    nn.Conv2d(in_channels=32, out_channels=conv_out_channels,
              kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)),
    nn.ReLU(),  # [128, 64, 1, 2]
    nn.Conv2d(in_channels=conv_out_channels, out_channels=conv_out_channels,
              kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)),
    nn.ReLU(),  # [128, 64, 1, 1]
)

flatten = nn.Flatten()

# Fully connected head: the flattened convolution output is concatenated with
# the subject condition again before being mapped to the speech encoding.
fc_layers = nn.Sequential(
    nn.Linear(conv_out_channels + num_subjects, 128),
    nn.Tanh(),
    nn.Linear(128, speech_encoding_dim),
)

* `nn.BatchNorm1d` espera un tensor de la forma $[N, C, L]$ (aquí $[128, 29, 16]$, con las 29 características de audio como canales), por lo que es necesario permutar las dos últimas dimensiones del tensor original de $[128, 16, 29]$

In [8]:
# Swap the last two axes so the audio features become the channel axis.
# Note: the result is assigned back to the same name, so running this cell a
# second time would swap the axes back.
processed_audio = processed_audio.transpose(1, 2)
print("processed audio: ", processed_audio.shape)

processed audio:  torch.Size([128, 29, 16])


In [9]:
# Normalise the audio features channel-wise.
features_norm = batch_norm(processed_audio)
print("features: ", features_norm.shape)

# Sizes of the channel and time axes, reused by the reshapes below.
num_channels, num_frames = features_norm.shape[1:]

# Insert a singleton height axis so the features can go through 2-D convolutions.
speech_features_reshaped = features_norm.reshape(-1, num_channels, 1, num_frames)
print("features reshaped: ", speech_features_reshaped.shape)

# Give the one-hot condition the same rank: one channel per subject.
condition_reshaped = condition.reshape(-1, condition.shape[1], 1, 1)
print("condition reshaped: ", condition_reshaped.shape)

# PyTorch equivalent of tf.transpose for n-dimensional tensors.
speech_feature_condition = condition_reshaped.permute(0, 1, 3, 2)
print("feature condition: ", speech_feature_condition.shape)

# Repeat the condition along the time axis so it lines up with every frame.
speech_feature_condition = torch.tile(speech_feature_condition, dims=(1, 1, 1, num_frames))
print("feature condition: ", speech_feature_condition.shape)

# Stack the condition on top of the audio features along the channel axis.
speech_features_reshaped = torch.cat([speech_features_reshaped, speech_feature_condition], dim=1)
print("features reshaped: ", speech_features_reshaped.shape)

features:  torch.Size([128, 29, 16])
features reshaped:  torch.Size([128, 29, 1, 16])
condition reshaped:  torch.Size([128, 8, 1, 1])
feature condition:  torch.Size([128, 8, 1, 1])
feature condition:  torch.Size([128, 8, 1, 16])
features reshaped:  torch.Size([128, 37, 1, 16])


* Se utiliza padding, a diferencia de la versión en TensorFlow, ya que sin él las dimensiones de salida no coincidirían con las descritas en el paper:
    * Primera capa sin padding: $[128, 32, 1, 7]$
    * Segunda capa sin padding: $[128, 32, 1, 3]$
    * Tercera capa sin padding: $[128, 64, 1, 1]$
    * Cuarta capa sin padding: indefinido (no se puede aplicar un kernel de $1\times3$ sobre una entrada de ancho 1)

In [10]:
# Run the temporal convolutions, then collapse everything but the batch axis.
features = time_convs(speech_features_reshaped)
features_flat = flatten(features)
# Both shapes are printed one per line, convolution output first.
print(features.shape, features_flat.shape, sep="\n")

torch.Size([128, 64, 1, 1])
torch.Size([128, 64])


In [11]:
# Append the one-hot subject condition to the flattened convolution features
# (dim -1 is the feature axis of these 2-D tensors) and run the dense head.
concatenated = torch.cat([features_flat, condition], dim=-1)
fc_result = fc_layers(concatenated)
# Input width of the dense head first, resulting encoding shape second.
print(concatenated.shape, fc_result.shape, sep="\n")

torch.Size([128, 72])
torch.Size([128, 50])


## Speech Decoder

In [12]:
# Decoder settings taken from the training configuration.
# NOTE(review): `expression_basis_fname` and `init_expression` are read here
# but not used by any cell visible in this section; the decoder below keeps
# PyTorch's default initialisation.
expression_basis_fname = config['expression_basis_fname']
init_expression = config['init_expression']

num_vertices = config['num_vertices']
expression_dim = config['expression_dim']

# A single linear layer maps the expression code to one 3-D offset per vertex.
decoder = nn.Linear(in_features=expression_dim, out_features=num_vertices * 3)

In [13]:
# Decode the speech encoding into a flat vector of vertex offsets.
exp_offset = decoder(fc_result)
flat_offset_shape = exp_offset.shape
# Regroup the flat vector into (x, y, z) triples, one per vertex.
exp_offset = exp_offset.reshape(-1, num_vertices, 3)
# Flat shape first, per-vertex shape second.
print(flat_offset_shape, exp_offset.shape, sep="\n")

torch.Size([128, 15069])
torch.Size([128, 5023, 3])


In [14]:
# The predicted mesh is the subject template displaced by the decoded offsets.
predicted = face_templates + exp_offset
print(predicted.shape)

torch.Size([128, 5023, 3])
