## Importando módulos

In [1]:
import torch
from torch import nn

## Carga de datos
* Depende de los scripts de run_training, config_parser, data_handler y batcher

In [2]:
from run_training import get_train_elements

In [3]:
config, batcher = get_train_elements()

Checkpoint dir already exists ./training
Use existing config ./training/config.pkl
Loading data
Loading face vertices
Loading templates
Loading raw audio
Process audio
Loading index maps
Initialize data splits
Initialize training, validation, and test indices
sequence data missing FaceTalk_170811_03274_TA - sentence01
sequence data missing FaceTalk_170811_03274_TA - sentence02
sequence data missing FaceTalk_170811_03274_TA - sentence24
sequence data missing FaceTalk_170913_03279_TA - sentence12
sequence data missing FaceTalk_170913_03279_TA - sentence38
sequence data missing FaceTalk_170912_03278_TA - sentence11
sequence data missing FaceTalk_170809_00138_TA - sentence32


* Conversión de datos a tensores de tipo float32 (originalmente float64 o double para pytorch)

In [4]:
import numpy as np

In [5]:
processed_audio, face_vertices, face_templates, subject_idx = batcher.get_training_batch(config['batch_size'])

processed_audio = np.expand_dims(processed_audio, -1)
face_vertices = np.expand_dims(face_vertices, -1)
face_templates = np.expand_dims(face_templates, -1)

processed_audio = torch.from_numpy(processed_audio).type(torch.float32)
face_vertices = torch.from_numpy(face_vertices).type(torch.float32)
face_templates = torch.from_numpy(face_templates).type(torch.float32)
subject_idx = torch.from_numpy(subject_idx)

print("processed audio: ", processed_audio.shape, processed_audio.dtype)
print("face vertices: ", face_vertices.shape, face_vertices.dtype)
print("face templates: ", face_templates.shape, face_templates.dtype)
print("subject index: ", subject_idx.shape, subject_idx.dtype)

processed audio:  torch.Size([128, 16, 29, 1]) torch.float32
face vertices:  torch.Size([128, 5023, 3, 1]) torch.float32
face templates:  torch.Size([128, 5023, 3, 1]) torch.float32
subject index:  torch.Size([128]) torch.int64


In [6]:
num_training_subjects = batcher.get_num_training_subjects()
val_processed_audio, val_face_vertices, val_face_templates, _ = batcher.get_validation_batch(config['batch_size'])

val_processed_audio = np.expand_dims(np.tile(val_processed_audio, (num_training_subjects, 1, 1)), -1)
val_face_vertices = np.expand_dims(np.tile(val_face_vertices, (num_training_subjects, 1, 1)), -1)
val_face_templates = np.expand_dims(np.tile(val_face_templates, (num_training_subjects, 1, 1)), -1)

val_processed_audio = torch.from_numpy(val_processed_audio).type(torch.float32)
val_face_vertices = torch.from_numpy(val_face_vertices).type(torch.float32)
val_face_templates = torch.from_numpy(val_face_templates).type(torch.float32)

print("processed audio: ", val_processed_audio.shape, val_processed_audio.dtype)
print("face vertices: ", val_face_vertices.shape, val_face_vertices.dtype)
print("face templates: ", val_face_templates.shape, val_face_templates.dtype)

processed audio:  torch.Size([1024, 16, 29, 1]) torch.float32
face vertices:  torch.Size([1024, 5023, 3, 1]) torch.float32
face templates:  torch.Size([1024, 5023, 3, 1]) torch.float32


In [7]:
condition = nn.functional.one_hot(subject_idx, batcher.get_num_training_subjects())
print(condition.shape)

torch.Size([128, 8])


In [8]:
val_condition = np.reshape(np.repeat(np.arange(num_training_subjects)[:,np.newaxis],
                repeats=config['num_consecutive_frames']*config['batch_size'], axis=-1), [-1,])
val_condition = torch.from_numpy(val_condition)
val_condition = nn.functional.one_hot(val_condition, batcher.get_num_training_subjects())
print(val_condition.shape)

torch.Size([1024, 8])


## Speech Encoder

In [9]:
class FCLayer(nn.Module):
    def __init__(self, in_units, out_units, init_weights=None, weightini=0.1, bias=0.0):
        super().__init__()
        self.layer = nn.Linear(in_units, out_units)

        # inicialización de pesos
        if init_weights is not None:
            self.layer.weight.data = init_weights
        elif weightini == 0.0:
            nn.init.constant_(self.layer.weight, weightini)
        else:
            nn.init.normal_(self.layer.weight, std=weightini)
        
        # inicialización de bias
        nn.init.constant_(self.layer.bias, bias)
    
    def forward(self, x):
        return self.layer(x)

In [10]:
class CustomConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, k_size, stride=(0,0), padding=(0,0), std_dev=0.02):
        super().__init__()
        self.conv_layer = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=k_size, stride=stride, padding=padding)

        # inicialización de pesos y bias
        nn.init.normal_(self.conv_layer.weight, std=std_dev)
        nn.init.normal_(self.conv_layer.bias, std=std_dev)
    
    def forward(self, x):
        return self.conv_layer(x)

In [11]:
speech_encoding_dim = config['expression_dim']
condition_speech_features = config['condition_speech_features']
speech_encoder_size_factor = config['speech_encoder_size_factor']

# En tensorflow, el batchnorm usa el formato [N, H, W, C] y los autores ingresan el tensor de [N, 16, 29, 1] tal cual
# en tf, se usa decay en vez de momentum, momentum = 1 - decay segun el foro
# https://discuss.pytorch.org/t/convering-a-batch-normalization-layer-from-tf-to-pytorch/20407
batch_norm = nn.BatchNorm2d(num_features=1, eps=1e-5, momentum=0.1)

time_convs = nn.Sequential(
            CustomConv2d(in_ch=37, out_ch=32, k_size=(3,1), stride=(2,1), padding=(1,0)),
            nn.ReLU(), # [128, 32, 8, 1]
            CustomConv2d(in_ch=32, out_ch=32, k_size=(3,1), stride=(2,1), padding=(1,0)),
            nn.ReLU(), # [128, 32, 4, 1]
            CustomConv2d(in_ch=32, out_ch=64, k_size=(3,1), stride=(2,1), padding=(1,0)),
            nn.ReLU(), # [128, 64, 2, 1]
            CustomConv2d(in_ch=64, out_ch=64, k_size=(3,1), stride=(2,1), padding=(1,0)),
            nn.ReLU() # [128, 64, 1, 1]
        )

flatten = nn.Flatten()

fc_layers = nn.Sequential(
            FCLayer(72, 128),
            nn.Tanh(),
            FCLayer(128, speech_encoding_dim)
        )

* Debido a que el Batch Normalization espera un tensor de la forma $[N, C, H, W]$ es necesario cambiar las dimensiones del tensor original de $[N, H, W, C]$

In [12]:
processed_audio = processed_audio.permute(0,3,1,2)
print("processed audio: ", processed_audio.shape)

processed audio:  torch.Size([128, 1, 16, 29])


In [13]:
val_processed_audio = val_processed_audio.permute(0,3,1,2)
print("processed audio: ", val_processed_audio.shape)

processed audio:  torch.Size([1024, 1, 16, 29])


In [14]:
features_norm = batch_norm(processed_audio)
print("features: ", features_norm.shape)
# Regresar a la forma  el dato
features_norm = features_norm.permute(0, 2, 3, 1)
print("features: ", features_norm.shape)

speech_features_reshaped = torch.reshape(features_norm, (-1, features_norm.shape[1], 1, features_norm.shape[2]))
print("features reshaped: ", speech_features_reshaped.shape)

# función equivalente en pytorch a tf.transpose en tensores de n-dimensiones
speech_feature_condition = torch.reshape(condition, (-1, condition.shape[1], 1, 1)).permute(0, 2, 3, 1)#(0, 1, 3, 2)
print("feature condition: ", speech_feature_condition.shape)

speech_feature_condition = torch.tile(speech_feature_condition, (1, features_norm.shape[1], 1, 1))
print("feature condition: ", speech_feature_condition.shape)

speech_features_reshaped = torch.cat((speech_features_reshaped, speech_feature_condition), dim=-1)
print("features reshaped: ", speech_features_reshaped.shape)

# transformar el tensor a la forma de pytorch [N, C, H, W]
speech_features_reshaped = speech_features_reshaped.permute(0, 3, 1, 2)
print("features reshaped: ", speech_features_reshaped.shape)

features:  torch.Size([128, 1, 16, 29])
features:  torch.Size([128, 16, 29, 1])
features reshaped:  torch.Size([128, 16, 1, 29])
feature condition:  torch.Size([128, 1, 1, 8])
feature condition:  torch.Size([128, 16, 1, 8])
features reshaped:  torch.Size([128, 16, 1, 37])
features reshaped:  torch.Size([128, 37, 16, 1])


In [15]:
val_features_norm = batch_norm(val_processed_audio)
print("features: ", val_features_norm.shape)

val_features_norm = val_features_norm.permute(0, 2, 3, 1)
print("features: ", val_features_norm.shape)

val_speech_features_reshaped = torch.reshape(val_features_norm, (-1, val_features_norm.shape[1], 1, val_features_norm.shape[2]))
print("features reshaped: ", val_speech_features_reshaped.shape)

# función equivalente en pytorch a tf.transpose en tensores de n-dimensiones
val_speech_feature_condition = torch.reshape(val_condition, (-1, val_condition.shape[1], 1, 1)).permute(0, 2, 3, 1)
print("feature condition: ", val_speech_feature_condition.shape)

val_speech_feature_condition = torch.tile(val_speech_feature_condition, (1, val_features_norm.shape[1], 1, 1))
print("feature condition: ", val_speech_feature_condition.shape)

val_speech_features_reshaped = torch.cat((val_speech_features_reshaped, val_speech_feature_condition), dim=-1)
print("features reshaped: ", val_speech_features_reshaped.shape)

val_speech_features_reshaped = val_speech_features_reshaped.permute(0, 3, 1, 2)
print("features reshaped: ", speech_features_reshaped.shape)

features:  torch.Size([1024, 1, 16, 29])
features:  torch.Size([1024, 16, 29, 1])
features reshaped:  torch.Size([1024, 16, 1, 29])
feature condition:  torch.Size([1024, 1, 1, 8])
feature condition:  torch.Size([1024, 16, 1, 8])
features reshaped:  torch.Size([1024, 16, 1, 37])
features reshaped:  torch.Size([128, 37, 16, 1])


* Se utiliza padding ya que a diferencia de Tensorflow donde el padding se calcula si se utiliza "SAME", en Pytorch se debe especificar.

In [16]:
features = time_convs(speech_features_reshaped)
print("after convs: ", features.shape)
features_flat = flatten(features)
print("flatten: ", features_flat.shape)

after convs:  torch.Size([128, 64, 1, 1])
flatten:  torch.Size([128, 64])


In [17]:
val_features = time_convs(val_speech_features_reshaped)
print("after convs: ", val_features.shape)
val_features_flat = flatten(val_features)
print("flatten: ", val_features_flat.shape)

after convs:  torch.Size([1024, 64, 1, 1])
flatten:  torch.Size([1024, 64])


In [18]:
concatenated = torch.cat((features_flat, condition), dim=1)
print("concat: ", concatenated.shape)
fc_result = fc_layers(concatenated)
print("fc result: ", fc_result.shape)

concat:  torch.Size([128, 72])
fc result:  torch.Size([128, 50])


In [19]:
val_concatenated = torch.cat((val_features_flat, val_condition), dim=1)
print("concat: ", val_concatenated.shape)
val_fc_result = fc_layers(val_concatenated)
print("fc result: ", val_fc_result.shape)

concat:  torch.Size([1024, 72])
fc result:  torch.Size([1024, 50])


## Speech Decoder

In [20]:
expression_basis_fname = config['expression_basis_fname']
init_expression = config['init_expression']

num_vertices = config['num_vertices']
expression_dim = config['expression_dim']

In [21]:
init_exp_basis = np.zeros((3*num_vertices, expression_dim))

if init_expression:
    init_exp_basis[:, :min(expression_dim, 100)] = np.load(expression_basis_fname)[:, :min(expression_dim, 100)]

init_exp_basis = torch.from_numpy(init_exp_basis).type(torch.float32)
print(init_exp_basis.shape)

torch.Size([15069, 50])


In [22]:
decoder = FCLayer(in_units=expression_dim, out_units=3*num_vertices, init_weights=init_exp_basis)

In [23]:
exp_offset = decoder(fc_result)
print(exp_offset.shape)
exp_offset = torch.reshape(exp_offset, (-1, num_vertices, 3, 1))
print(exp_offset.shape)

torch.Size([128, 15069])
torch.Size([128, 5023, 3, 1])


In [24]:
val_exp_offset = decoder(val_fc_result)
print(val_exp_offset.shape)
val_exp_offset = torch.reshape(val_exp_offset, (-1, num_vertices, 3, 1))
print(val_exp_offset.shape)

torch.Size([1024, 15069])
torch.Size([1024, 5023, 3, 1])


In [25]:
predicted = exp_offset + face_templates
print(predicted.shape)

torch.Size([128, 5023, 3, 1])


In [26]:
val_predicted = val_exp_offset + val_face_templates
print(val_predicted.shape)

torch.Size([1024, 5023, 3, 1])
