In [None]:
import argparse
import json
import os
import random

import IPython.display as ipd
import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
from scipy.io.wavfile import read
from torchvision import transforms
from tqdm import tqdm

from text import cleaners

# Model / audio hyper-parameters. They must match the values the checkpoint was
# trained with. parse_args(args=[]) below means only the defaults are used when
# this runs inside a notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--filter_length', default=1024, type=int, help= 'filter_length')
parser.add_argument('--segment_size', default=8192, type=int, help= 'segment_size')
parser.add_argument('--hop_length', default=256, type=int, help= 'hop_length')
parser.add_argument('--inter_channels', default=192, type=int, help= 'inter_channels')
parser.add_argument('--hidden_channels', default=192, type=int, help= 'hidden_channels')
parser.add_argument('--filter_channels', default=768, type=int, help= 'filter_channels')
parser.add_argument('--n_heads', default=2, type=int, help= 'n_heads')
parser.add_argument('--n_layers', default=6, type=int, help= 'n_layers')
parser.add_argument('--kernel_size', default=3, type=int, help= 'kernel_size')
parser.add_argument('--p_dropout', default=0.1, type=float, help= 'p_dropout')
parser.add_argument('--resblock', default="1", type=str, help= 'resblock')
# List-valued options use nargs='+' so `--upsample_rates 8 8 2 2` parses to a
# list of ints; the previous type=list split a CLI string into characters.
parser.add_argument('--resblock_kernel_sizes', default=[3,7,11], type=int, nargs='+', help= 'resblock_kernel_sizes')
# Nested list: pass as JSON, e.g. --resblock_dilation_sizes "[[1,3,5],[1,3,5],[1,3,5]]".
parser.add_argument('--resblock_dilation_sizes', default=[[1,3,5], [1,3,5], [1,3,5]], type=json.loads, help= 'resblock_dilation_sizes')
parser.add_argument('--upsample_rates', default=[8,8,2,2], type=int, nargs='+', help= 'upsample_rates')
parser.add_argument('--upsample_initial_channel', default=512, type=int, help= 'upsample_initial_channel')
parser.add_argument('--upsample_kernel_sizes', default=[16,16,4,4], type=int, nargs='+', help= 'upsample_kernel_sizes')
parser.add_argument('--gin_channels', default=128, type=int, help= 'gin_channels')
parser.add_argument('--model_name', default='FVTTS', help= 'model_name')
parser.add_argument('--GPU', default='0', type=str, help= 'GPU')
# Both of these are read later (args.model_dir below, args.sampling_rate when
# writing the wav) but were never defined, which raised AttributeError.
parser.add_argument('--save_path', default='checkpoints', type=str, help= 'directory containing <model_name>/G_bestDloss.pth')
# NOTE(review): 16000 Hz is an assumption; it must equal the rate the
# checkpoint was trained at, otherwise pitch and speed are wrong. TODO confirm.
parser.add_argument('--sampling_rate', default=16000, type=int, help= 'output sampling rate in Hz')

args = parser.parse_args(args= [])
args.model_dir = f'{args.save_path}/{args.model_name}'
device = torch.device(f'cuda:{args.GPU}' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Text-symbol inventory for the synthesizer's input embedding.
# len(symbols) is passed to SynthesizerTrn as its first argument, so the
# contents and order of these strings must stay exactly as the checkpoint expects.
_pad        = '_'  # index 0 in `symbols`; id 0 is also what gets interspersed between tokens before inference
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# IPA characters (presumably emitted by phonemizing cleaners such as
# english_cleaners2 — not verifiable from here). Note: the apostrophe (')
# occurs twice in this string, so it is counted twice in len(symbols) while
# _symbol_to_id maps it to its later index.
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
# Character -> integer id (its position in `symbols`); a later duplicate overwrites an earlier one.
_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def load_checkpoint(checkpoint_path, model, optimizer=None):
  """Load model weights (and optionally optimizer state) from a checkpoint file.

  Parameters are matched by name. Any parameter that is missing from the
  checkpoint keeps its current value, and a message is printed for it. This
  works for plain modules and for ``nn.DataParallel``-wrapped modules, whether
  or not the checkpoint keys carry a leading ``module.`` prefix.

  Args:
    checkpoint_path: path to a ``torch.save`` file holding a dict with a
      ``'model'`` state dict and, optionally, ``'iteration'``,
      ``'learning_rate'`` and ``'optimizer'`` entries.
    model: module to load into (modified in place).
    optimizer: optional optimizer. Its state is restored when the checkpoint
      has an ``'optimizer'`` entry.

  Returns:
    ``(model, optimizer, learning_rate, iteration)``. The last two are
    ``None`` when the checkpoint does not contain them.

  Raises:
    AssertionError: if ``checkpoint_path`` is not an existing file.
  """
  assert os.path.isfile(checkpoint_path), f'checkpoint not found: {checkpoint_path}'
  # NOTE: torch.load unpickles the file; only load trusted checkpoints.
  checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')

  iteration = checkpoint_dict.get('iteration')
  learning_rate = checkpoint_dict.get('learning_rate')
  print(f'ITER {iteration} - LR : {learning_rate}')

  # Normalise saved keys: drop a DataParallel "module." prefix when every key
  # carries it, so checkpoints saved from wrapped and unwrapped models both match.
  prefix = 'module.'
  saved_state_dict = dict(checkpoint_dict['model'])
  if saved_state_dict and all(k.startswith(prefix) for k in saved_state_dict):
    saved_state_dict = {k[len(prefix):]: v for k, v in saved_state_dict.items()}

  # Load into the unwrapped module so target keys never carry the prefix either.
  # (Previously `state_dict` was unbound for models without `.module`.)
  target = model.module if hasattr(model, 'module') else model
  new_state_dict = {}
  for k, v in target.state_dict().items():
    if k in saved_state_dict:
      new_state_dict[k] = saved_state_dict[k]
    else:
      print("%s is not in the checkpoint" % k)
      new_state_dict[k] = v
  target.load_state_dict(new_state_dict)

  if optimizer is not None and 'optimizer' in checkpoint_dict:
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
  return model, optimizer, learning_rate, iteration

In [None]:
import face_model
import torch.nn as nn
# Positional constructor arguments for face_model.SynthesizerTrn, kept in the
# exact order the original call used: symbol count, filter_length // 2 + 1,
# segment_size // hop_length, then the remaining hyper-parameters from `args`.
synth_hparams = (
    len(symbols),
    args.filter_length // 2 + 1,
    args.segment_size // args.hop_length,
    args.inter_channels,
    args.hidden_channels,
    args.filter_channels,
    args.n_heads,
    args.n_layers,
    args.kernel_size,
    args.p_dropout,
    args.resblock,
    args.resblock_kernel_sizes,
    args.resblock_dilation_sizes,
    args.upsample_rates,
    args.upsample_initial_channel,
    args.upsample_kernel_sizes,
    args.gin_channels,
)
net_g_face = nn.DataParallel(face_model.SynthesizerTrn(*synth_hparams).to(device))

# Put the network in eval mode, then restore the saved generator weights.
c_path = f'{args.model_dir}/G_bestDloss.pth'
net_g_face.eval()
load_checkpoint(c_path, net_g_face, None)

In [None]:
def train_transform():
    """Return the face-image preprocessing pipeline.

    The pipeline resizes a PIL image to 128x128 and then applies
    ``transforms.ToTensor``. Despite the name, this notebook uses it for
    inference preprocessing.
    """
    resize = transforms.Resize(size=(128, 128))
    to_tensor = transforms.ToTensor()
    return transforms.Compose([resize, to_tensor])

def _clean_text(text, cleaner_names):
  """Apply the named cleaner functions from the ``cleaners`` module in order.

  Args:
    text: raw input string.
    cleaner_names: iterable of function names looked up on ``cleaners``.

  Returns:
    The text after every cleaner has been applied, in the given order.

  Raises:
    Exception: if a name does not resolve to a cleaner in ``cleaners``.
  """
  for name in cleaner_names:
    # getattr without a default raised AttributeError, so the "Unknown
    # cleaner" branch below was unreachable; a None default makes it reachable.
    cleaner = getattr(cleaners, name, None)
    if not cleaner:
      raise Exception('Unknown cleaner: %s' % name)
    text = cleaner(text)
  return text
def cleaned_text_to_sequence(cleaned_text):
  """Convert already-cleaned text into a list of integer symbol ids.

  Each character is looked up in ``_symbol_to_id``. A KeyError is raised for
  any character that is not in the symbol table.
  """
  return list(map(_symbol_to_id.__getitem__, cleaned_text))

def intersperse(lst, item):
  """Return a new list with ``item`` placed before, between and after elements.

  Example: ``intersperse([1, 2], 0) -> [0, 1, 0, 2, 0]``. An empty ``lst``
  yields ``[item]``.
  """
  out = [item]
  for element in lst:
    out.extend((element, item))
  return out

# Face encoder (InceptionResnetV1 with 'vggface2' weights, eval mode, no
# classification head); its output is passed to the synthesizer as `content`.
resnet = InceptionResnetV1(classify=False, pretrained='vggface2').eval()
# Image preprocessing applied to the face before encoding.
img_tf = train_transform()

In [None]:
# --- Inference inputs --------------------------------------------------------
# i_path must be a single image FILE; the previous value 'image/' is a
# directory, which Image.open cannot read.
i_path = 'image/face.png'  # TODO: point at the face image to condition on
text = 'AND HE WAS TALKING ABOUT THE IMPORTANCE OF COACHING BOYS INTO MEN AND CHANGING THE CULTURE OF THE LOCKER ROOM AND GIVING'
save_path = 'voice.wav'

# Sampling settings forwarded to infer().
NOISE_SCALE = 0.667
NOISE_SCALE_W = 0.8
LENGTH_SCALE = 1.2
OUTPUT_GAIN = 2.0  # applied before writing; the result is clipped to [-1, 1]
SEED = 1234        # infer() samples noise, so seed for reproducible audio

torch.manual_seed(SEED)
with torch.no_grad():
    # convert('RGB') guarantees 3 channels: grayscale, RGBA and palette images
    # would otherwise produce a tensor the face encoder cannot consume.
    # The context manager also closes the underlying file handle.
    with Image.open(i_path) as img:
        image_emb = img_tf(img.convert('RGB'))
    content = resnet(image_emb.unsqueeze(0)).T

    clean_text = _clean_text(text, ["english_cleaners2"])
    text_norm = intersperse(cleaned_text_to_sequence(clean_text), 0)
    text_norm = torch.LongTensor(text_norm)
    x_tst_ = text_norm.unsqueeze(0).to(device)
    x_tst_lengths_ = torch.LongTensor([text_norm.size(0)]).to(device)

    face_audio = net_g_face.module.infer(
        x_tst_,
        x_tst_lengths_,
        img=image_emb.to(device),
        content=content.to(device),
        noise_scale=NOISE_SCALE,
        noise_scale_w=NOISE_SCALE_W,
        length_scale=LENGTH_SCALE,
    )[0][0, 0].data.cpu().float().numpy()

# ipd.Audio(normalize=False) raises ValueError when samples fall outside
# [-1, 1], which the gain can cause, so clip after amplifying.
wav = (face_audio * OUTPUT_GAIN).clip(-1.0, 1.0)
with open(save_path, 'wb') as f:
    f.write(ipd.Audio(wav, rate=args.sampling_rate, normalize=False).data)