In [1]:
import os

import torch
from torch import nn

from jamo import hangul_to_jamo

from transformer_torch import *
from preprocess import *
import configs as cf

In [2]:
# Checkpoint location; the model file is MODEL_PATH + MODEL_NAME + '.pt'.
MODEL_PATH, MODEL_NAME = '../model/', 'single_speaker_tts'

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Load the complete pickled model object. torch.load unpickles arbitrary Python
# objects, so only load checkpoints from a trusted source.
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint) onto
# the device chosen above, so the notebook also runs on CPU-only machines.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects full
# model objects; on such versions pass weights_only=False here.
transformer = torch.load(MODEL_PATH + MODEL_NAME + '.pt', map_location=device)

In [5]:
# Single-speaker dataset built from the config values, plus a shuffled
# (unseeded) DataLoader that batches it with the project's collate_fn.
ds = get_single_speaker_dataset(
    cf.SPEAKER,
    cf.WAV_PATH,
    cf.SCRIPT_FILE_NAME,
    cf.SR,
    cf.N_MELS,
    cf.N_FFT,
    cf.HOP_LENGTH,
    cf.WIN_LENGTH,
)
dl = DataLoader(
    ds,
    batch_size=cf.BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

Loading ['여1_소설1', '여1_자기계발2', '여1_동화1', '여1_자기계발1'] ...
소설1 Done!
자기계발2 Done!
동화1 Done!
자기계발1 Done!


In [6]:
# Known limitation: create_padding_mask(..., True) fails inside np.apply_along_axis
# when the second axis (index 1) has length 1, as this probe demonstrates.
# The ValueError is caught and shown as the cell result so that
# "Restart & Run All" does not stop at this demonstration cell.
t = torch.rand((2, 1, 80))

try:
    t_mask = transformer.create_padding_mask(t, True)
    mask_check = t_mask.shape
except ValueError as err:
    t_mask = None
    mask_check = f'ValueError: {err}'
mask_check

ValueError: Cannot apply_along_axis when any iteration dimensions are 0

In [7]:
class Chatbot():
    """Greedy autoregressive inference wrapper around a trained transformer.

    Despite its name (kept so existing callers still work), ``qna`` turns input
    text into a sequence of mel frames, one frame per decoding step.
    """

    def __init__(self, transformer, device):
        """Move ``transformer`` to ``device`` and switch it to eval mode.

        Args:
            transformer: trained model exposing ``encoder``, ``decoder``,
                ``create_padding_mask`` and ``speech_seq_len`` (as used by ``qna``).
            device: torch device (or device string) used for inference.
        """
        self.transformer = transformer.to(device)
        self.transformer.eval()
        self.device = device
        # Results of the most recent ``qna`` call. Initialised here so code that
        # checks ``call_qna`` before the first call does not hit AttributeError.
        self.question = None
        self.answer = None
        self.attention = None
        self.call_qna = False

    def _target_mask(self, target_tokens):
        """Build the decoder self-attention mask for ``target_tokens``.

        NOTE(review): a note in this notebook says ``create_padding_mask`` fails
        (via ``np.apply_along_axis``) when axis 1 has length 1, which is the case
        on the first decoding step. This workaround tiles the *batch* axis to 2 and
        keeps the first item; its guard tests the batch size, which is always 1 in
        ``qna``, so it runs on every step and the plain branch is never reached.
        Tiling the batch axis does not change the length of axis 1 — confirm the
        workaround actually avoids the failure.
        """
        if target_tokens.size(0) == 1:
            return self.transformer.create_padding_mask(target_tokens.tile(2, 1, 1), True)[:1]
        return self.transformer.create_padding_mask(target_tokens, True)

    def qna(self, question):
        """Generate mel frames for ``question`` by greedy autoregressive decoding.

        Args:
            question: input text. It is wrapped as ``'@' + text + '|'``, passed
                through ``normalize_text``, split into jamo and mapped through
                ``cf.char_to_id``.

        Returns:
            ``(answer, attention)``. ``answer`` is the all-zero start frame
            followed by the predicted frames, shape ``(1, 1 + speech_seq_len,
            cf.N_MELS)`` assuming the decoder returns one frame per input frame.
            ``attention`` is the decoder's third output from the last step, or
            ``None`` when ``speech_seq_len`` <= 0.

        Raises:
            KeyError: if the normalised text contains a character that is not in
                ``cf.char_to_id``.
        """
        speech_seq_len = self.transformer.speech_seq_len

        text = normalize_text('@' + question + '|')
        jamo = list(hangul_to_jamo(text))
        unknown = sorted({ch for ch in jamo if ch not in cf.char_to_id})
        if unknown:
            raise KeyError(f'characters not in vocabulary: {unknown}')
        question_tokens = torch.tensor(
            [[cf.char_to_id[ch] for ch in jamo]], dtype=torch.long, device=self.device)

        # Zero frame that starts every decoder input; built once, it is loop-invariant.
        go_frame = torch.zeros(1, 1, cf.N_MELS, device=self.device)
        output_tokens = go_frame
        attention = None  # stays None if the loop below does not run

        with torch.no_grad():
            question_mask = self.transformer.create_padding_mask(question_tokens)
            question_encd = self.transformer.encoder(question_tokens, question_mask)

            for i in range(speech_seq_len):
                target_mask = self._target_mask(output_tokens)
                output, _, attention = self.transformer.decoder(
                    output_tokens, question_encd, target_mask, question_mask)
                # Next decoder input: start frame + the first i+1 predicted frames.
                output_tokens = torch.cat((go_frame, output[:, :i + 1, :]), dim=1)

        self.question = jamo
        self.answer = output_tokens
        self.attention = attention
        self.call_qna = True

        return output_tokens, attention

#     def plot_attention_weights(self, draw_mean=False):
#         if not self.call_qna:
#             raise Exception('There is no `question`, `answer` and `attention`. Call `qna` first')
#         question_token = to_tokens(self.question, self.tokenizer, to_ids=False)
#         question_token = ['<sos>']+question_token+['<eos>']

#         answer_token = to_tokens(self.answer, self.tokenizer, to_ids=False)
#         answer_token = answer_token+['<eos>']

#         attention = self.attention.squeeze(0)
#         if draw_mean:
#             attention = torch.mean(attention, dim=0, keepdim=True)
#         attention = attention.cpu().detach().numpy()

#         n_col = 4
#         n_row = (attention.shape[0]-1)//n_col + 1
#         fig = plt.figure(figsize = (n_col*6, n_row*6))
#         for i in range(attention.shape[0]):
#             plt.subplot(n_row, n_col, i+1)
#             plt.matshow(attention[i], fignum=False)
#             plt.xticks(range(len(question_token)), question_token, rotation=45)
#             plt.yticks(range(len(answer_token)), answer_token)
#         plt.show()

In [8]:
# Inference wrapper; note that constructing it moves `transformer` to `device`
# and puts it in eval mode (Module.to / Module.eval act in place).
tts = Chatbot(transformer=transformer, device=device)

# Mel spectrogram to audio samples

In [9]:
# Output directory for the generated .wav files. sf.write does not create missing
# directories, so make sure it exists (no-op if it already does).
AUDIO_SAVE_PATH = '../audio_sample/'
os.makedirs(AUDIO_SAVE_PATH, exist_ok=True)

In [10]:
import soundfile as sf

# Sanity check of the feature pipeline: compute the mel spectrogram of one
# recording, invert it back to a waveform, and save both for listening.
fpath = '../data/wav/여1_동화1/1.wav'

origin, _ = librosa.load(fpath, sr=cf.SR)
mel = get_mel(fpath, cf.SR, cf.N_MELS, cf.N_FFT, cf.HOP_LENGTH, cf.WIN_LENGTH)

# mel_to_audio defaults to n_fft=2048; pass the FFT size that get_mel was given so
# the inversion uses the same STFT/mel configuration as the features.
# NOTE(review): assumes get_mel returns (frames, n_mels) (hence .T) and a linear
# power mel spectrogram (mel_to_audio's default power=2.0) — confirm in preprocess.
inversed = librosa.feature.inverse.mel_to_audio(
    mel.T, sr=cf.SR, n_fft=cf.N_FFT, hop_length=cf.HOP_LENGTH, win_length=cf.WIN_LENGTH)

sf.write(AUDIO_SAVE_PATH + 'test_origin.wav', origin, cf.SR)
sf.write(AUDIO_SAVE_PATH + 'test_inversed.wav', inversed, cf.SR)

In [11]:
# next(iter(dl)) works on every PyTorch version; the DataLoader iterator's
# .next() alias no longer exists in newer releases.
sample = next(iter(dl))

sample_text, sample_speech = sample[0].to(device), sample[1].to(device)

# Forward pass on one batch with the ground-truth mel frames as decoder input.
# Device placement and eval mode are set here explicitly rather than relying on
# the side effect of constructing Chatbot; no_grad avoids building an autograd graph.
transformer.to(device).eval()
with torch.no_grad():
    outputs, _, _ = transformer(sample_text, sample_speech)

mel_pred = outputs[0].cpu().numpy()

# Pass n_fft explicitly (mel_to_audio defaults to 2048) so the inversion matches
# the FFT size used to build the training features.
pred_speech = librosa.feature.inverse.mel_to_audio(
    mel_pred.T, sr=cf.SR, n_fft=cf.N_FFT, hop_length=cf.HOP_LENGTH, win_length=cf.WIN_LENGTH)

origin_text = ''.join(cf.id_to_char[token_id] for token_id in sample_text[0].tolist())

sf.write(AUDIO_SAVE_PATH + 'model_output_test.wav', pred_speech, cf.SR)

In [12]:
# Text before the first '|' marker, without its leading character (the '@' start marker).
origin_text.partition('|')[0][1:]

"이야기하다 보면 '왜 저 두 사람이 친구인지'를 이해하고 고개를 끄덕이게 되곤 한다."