# Distilling convolutions in Tacotron 2

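First, download the Tacotron 2 checkpoint from https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view and save it as `../../checkpoints/tacotron2_statedict.pt` (relative to this notebook), which is the path the loading cell below reads.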

Then install the requirements from the project root folder: `pip install -r requirements.txt`

In [1]:
# Silence every warning category for the rest of the notebook session.
import warnings
warnings.filterwarnings('ignore')

# Put the directory two levels up and its tacotron2/ subfolder at the front of
# the module search path, so the project-local imports below
# (audio, tacotron2, module, ...) can be resolved.
import sys
sys.path.insert(0, "../../")
sys.path.insert(0, "../../tacotron2/")

import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from IPython.display import Audio
from IPython.display import display

from audio.vocoders import griffin_lim
from tacotron2.model import Tacotron2
from tacotron2.text import text_to_sequence, sequence_to_text
from module import ConvModule

___
## **1 Loading ground truth model**
Tacotron 2:

In [2]:
# Tacotron 2 hyper-parameters, read from the JSON config shipped with the model code.
# The file is opened in a `with` block so the handle is closed deterministically
# instead of being left for the garbage collector.
with open('./../../tacotron2/config.json', 'r') as config_file:
    TACOTRON_CONFIG = json.load(config_file)

# Location of the pretrained Tacotron 2 checkpoint loaded in the next cell.
TACOTRON_CHECKPT = './../../checkpoints/tacotron2_statedict.pt'

In [3]:
# Build the model from the config, then restore the pretrained weights.
tacotron2 = Tacotron2(TACOTRON_CONFIG)

# The identity map_location keeps every deserialized storage where torch.load
# first puts it (CPU), so the checkpoint loads on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(TACOTRON_CHECKPT, map_location=lambda storage, loc: storage)
checkpt_state_dict = checkpoint['state_dict']
tacotron2.load_state_dict(checkpt_state_dict)

# Run on CPU in evaluation mode.
tacotron2.cpu()
tacotron2.eval()

print('Number of parameters:', tacotron2.nparams())

Number of parameters: 28193153


Ground truth convolutional module:

In [4]:
# Set to True to dump the ground-truth module's weights to 'conv_module.pt'.
SAVE_MODULE = False

# The ground-truth module reuses the pretrained embedding and encoder
# convolution layers taken directly from the loaded Tacotron 2 model.
conv_module = ConvModule(TACOTRON_CONFIG)
conv_module.embedding, conv_module.convolutions = (
    tacotron2.embedding,
    tacotron2.encoder.convolutions,
)

if SAVE_MODULE:
    torch.save(conv_module.state_dict(), 'conv_module.pt')

What shape does the output have?

In [5]:
# Dummy batch: two identical rows holding the token ids 0..9.
dummy_tokens = torch.arange(10, dtype=torch.long).repeat(2, 1)
conv_module.forward(dummy_tokens).shape

torch.Size([2, 512, 10])

___
## **2 Defining compressed version**

In [6]:
# Placeholder (Ellipsis): assign the compressed module to be distilled here.
# Later cells call dist_module.parameters(), dist_module.zero_grad() and
# dist_module(x), so this must be replaced before they are run
# (presumably with a torch.nn.Module -- confirm against module.py).
dist_module = ...

___
## **3 Distilling**

In [7]:
from data import TextDataset, TextCollate
from torch.utils.data import DataLoader

# Training schedule.
N_EPOCHS = 100
# (sic) The misspelled name is kept because the DataLoader cell refers to it.
BACTH_SIZE = 64

# Print the loss every this many iterations.
ITER_PER_VERBOSE_TRAIN = 5
ITER_PER_VERBOSE_TEST = 50

# Device used by the training loop: CPU, index 0.
device = torch.device('cpu', 0)

In [8]:
# Train / test splits built from the filelists next to this notebook.
train_dataset = TextDataset('./filelists/ljs_audio_text_train_filelist.txt', TACOTRON_CONFIG)
test_dataset = TextDataset('./filelists/ljs_audio_text_test_filelist.txt', TACOTRON_CONFIG)

# Both loaders share the same batching options; drop_last=True discards a
# trailing batch that would be smaller than BACTH_SIZE.
loader_kwargs = dict(batch_size=BACTH_SIZE, num_workers=1, drop_last=True)

train_loader = DataLoader(dataset=train_dataset, collate_fn=TextCollate(), **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, collate_fn=TextCollate(), **loader_kwargs)

# Backward-compatible alias for the original, misspelled variable name.
tets_loader = test_loader

In [None]:
# Adam with its default hyper-parameters updates the distilled module's weights.
optimizer = torch.optim.Adam(params=dist_module.parameters())
# Mean squared error is the training loss.
criterion = torch.nn.MSELoss()

In [None]:
# Distillation loop: for every batch, run the distilled module, compare its
# output with the target via `criterion` and take one optimizer step.
# Assumes each batch from train_loader is an (x, y) pair of tensors where y is
# the target for dist_module(x) -- TODO confirm against TextCollate.

# Batches only need to be moved when the configured device is a GPU.
ON_GPU = device.type == 'cuda'

scores = []  # per-iteration training losses across all epochs
for epoch in range(N_EPOCHS):
    epoch_scores = []  # per-iteration training losses of the current epoch
    for iteration, (x, y) in enumerate(train_loader):
        # Clear gradients left over from the previous step.
        dist_module.zero_grad()
        if ON_GPU:
            x = x.to(device)
            y = y.to(device)

        y_hat = dist_module(x)

        # Conventional (prediction, target) order; the MSE value is symmetric.
        loss = criterion(y_hat, y)
        loss_value = loss.item()
        epoch_scores.append(loss_value)
        loss.backward()

        # Optional gradient clipping:
        # torch.nn.utils.clip_grad_norm_(dist_module.parameters(), 1.)

        optimizer.step()

        if iteration % ITER_PER_VERBOSE_TRAIN == 0:
            print('Epoch: {} | Step: {} | Loss: {}'.format(
                epoch + 1,
                iteration,
                loss_value
            ))
        if iteration % ITER_PER_VERBOSE_TEST == 0:
            # NOTE(review): this reports the *training* loss of the current
            # batch; no evaluation on the test set is performed here.
            print('Epoch: {} | Step: {} | Loss: {}'.format(
                epoch + 1,
                iteration,
                loss_value
            ))

    print('######################')
    if epoch_scores:
        # The score is a loss, so the best value is the smallest one.
        print('Best score for epoch #{} = {}'.format(epoch + 1, min(epoch_scores)))
    else:
        # With drop_last=True a dataset smaller than one batch yields no iterations.
        print('Epoch #{} produced no batches'.format(epoch + 1))
    print('######################')
    scores.extend(epoch_scores)

___
## **4 Quick synthesis check**

Replace the ground-truth convolutions in Tacotron 2 with the distilled ones (the exact wiring depends on how the compressed module was defined above).

In [None]:
from audio import MelTransformer

In [None]:
# Placeholder (Ellipsis): assign the distilled convolution layers here.
# NOTE(review): running this cell unchanged overwrites the encoder's
# convolutions with Ellipsis, which presumably breaks the inference cell below.
tacotron2.encoder.convolutions = ...

In [None]:
# Sentences to synthesize; each one becomes a separate input sequence.
texts = ["Implicit learning of the likelihood makes normalizing flows very strong generative tool."]
assert len(texts) > 0, 'provide at least one sentence to synthesize'
texts = [text.strip() for text in texts]

# Encode every sentence with the 'english_cleaners' pipeline and add a leading
# batch axis of size 1. The deprecated torch.autograd.Variable wrapper is
# dropped: since PyTorch 0.4 it simply returns the tensor it is given.
sequences = [
    torch.from_numpy(
        np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
    ).long()
    for text in texts
]

In [None]:
# Index of the sentence in `sequences` to synthesize.
TEXT_IDX = 0
input_sequence = sequences[TEXT_IDX]

# Inference needs no gradients, so autograd is switched off for the call.
with torch.no_grad():
    inference_outputs = tacotron2.inference(input_sequence)

mel_outputs, mel_outputs_postnet, gate_outputs, alignments = inference_outputs

In [None]:
# Turn the post-net mel output into a waveform with Griffin-Lim.
# STFT parameters come from the Tacotron 2 config.
mel_transform = MelTransformer(
    TACOTRON_CONFIG['filter_length'],
    TACOTRON_CONFIG['hop_length'],
    TACOTRON_CONFIG['win_length'],
    sampling_rate=22050
).cpu()

# Undo the spectral normalization applied to the mel output, then swap the last
# two axes (presumably to (batch, frames, n_mels) -- TODO confirm).
# .detach() replaces the legacy .data attribute.
mel_decompress = mel_transform._spectral_de_normalize(mel_outputs_postnet.cpu())
mel_decompress = mel_decompress.transpose(1, 2).detach().cpu()

# Multiply the first batch element by the mel basis, restore a leading batch
# axis and apply a fixed scale factor.
spec_from_mel_scaling = 1000
spec_from_mel = torch.mm(mel_decompress[0], mel_transform.mel_basis)
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
spec_from_mel = spec_from_mel * spec_from_mel_scaling

# The last index along the final axis is dropped before Griffin-Lim. The
# deprecated torch.autograd.Variable wrapper is removed (a no-op since
# PyTorch 0.4). The trailing 50 is presumably the number of Griffin-Lim
# iterations -- confirm against audio.vocoders.griffin_lim.
waveform = griffin_lim(spec_from_mel[:, :, :-1], mel_transform.stft_fn, 50)

In [None]:
# Take the first waveform in the batch as a NumPy array and play it at 22050 Hz.
audio_samples = waveform[0].data.cpu().numpy()
Audio(audio_samples, rate=22050)