# Distilling convolutions in Tacotron 2

First, download the Tacotron 2 checkpoint: https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view.

Then install the requirements from the project root folder: `pip install -r requirements.txt`

In [1]:
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.insert(0, "../../")
sys.path.insert(0, "../../tacotron2/")

import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from IPython.display import Audio

from audio.vocoders import griffin_lim
from tacotron2.model import Tacotron2
from tacotron2.text import text_to_sequence, sequence_to_text
from module import ConvModule

___
## **1 Loading the model**
Tacotron 2:

In [2]:
# Read the model config with a context manager so the file handle is closed
with open('../../tacotron2/config.json', 'r') as f:
    TACOTRON_CONFIG = json.load(f)
TACOTRON_CHECKPT = '../../checkpoints/tacotron2_statedict.pt'
ON_GPU = False

In [3]:
tacotron2 = Tacotron2(TACOTRON_CONFIG)
checkpt_state_dict = torch.load(TACOTRON_CHECKPT,
                                map_location=lambda storage, loc: storage)['state_dict']
tacotron2.load_state_dict(checkpt_state_dict)
_ = tacotron2.cuda().eval() if ON_GPU else tacotron2.cpu().eval()

print('Number of parameters:', tacotron2.nparams())

Number of parameters: 28193153
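`nparams()` is a method of this project's `Tacotron2` class. To get a comparable count for any `nn.Module` (for instance, to compare the ground-truth convolutional module with a smaller student), the usual PyTorch idiom is to sum `numel()` over `parameters()`. A minimal sketch: `count_parameters` is a hypothetical helper, and whether `nparams()` counts only trainable parameters is an assumption not checked here.

```python
import torch.nn as nn


def count_parameters(module: nn.Module, trainable_only: bool = True) -> int:
    """Sum the number of scalar weights in a module (optionally only trainable ones)."""
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


# One encoder-style convolution: 512 -> 512 channels, kernel size 5.
conv = nn.Conv1d(512, 512, kernel_size=5, padding=2)
n_conv = count_parameters(conv)
print(n_conv)  # 1311232 = 512 * 512 * 5 weights + 512 biases
```

A single such layer already holds about 1.3M weights, which is why the encoder convolutions are a natural target for compression.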


Ground-truth convolutional module, assembled from the pretrained model's embedding and encoder convolutions:

In [4]:
SAVE_MODULE=False

conv_module = ConvModule(TACOTRON_CONFIG)
conv_module.embedding = tacotron2.embedding
conv_module.convolutions = tacotron2.encoder.convolutions

if SAVE_MODULE:
    torch.save(conv_module.state_dict(), 'conv_module.pt')

What shape does the output have?

In [9]:
conv_module(torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] * 2)).shape

torch.Size([2, 512, 11])
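The shape `[2, 512, 11]` reads as (batch, channels, time): two sequences of 11 symbols, each position mapped to a 512-dimensional feature. A minimal sketch of why, assuming `ConvModule` follows the encoder described in the Tacotron 2 paper (a 512-dimensional character embedding, then three convolutions with 512 filters of width 5, each followed by batch normalization, ReLU and dropout 0.5). `EncoderConvSketch` and `n_symbols=148` are hypothetical, not this project's code.

```python
import torch
import torch.nn as nn


class EncoderConvSketch(nn.Module):
    """Hypothetical stand-in for ConvModule: embedding + 3 x (Conv1d, BatchNorm1d, ReLU, Dropout)."""

    def __init__(self, n_symbols=148, dim=512, n_layers=3, kernel_size=5, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, dim)
        self.convolutions = nn.ModuleList([
            nn.Sequential(
                # "Same" padding keeps the time dimension unchanged
                nn.Conv1d(dim, dim, kernel_size, padding=(kernel_size - 1) // 2),
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            for _ in range(n_layers)
        ])

    def forward(self, tokens):  # tokens: (batch, time) LongTensor
        # Embedding gives (batch, time, dim); Conv1d expects (batch, channels, time)
        x = self.embedding(tokens).transpose(1, 2)
        for conv in self.convolutions:
            x = conv(x)
        return x  # (batch, dim, time)


module = EncoderConvSketch().eval()
tokens = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] * 2)
with torch.no_grad():
    out = module(tokens)
print(out.shape)  # torch.Size([2, 512, 11])
```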

## **2 Distilling**

In [12]:
from data import TextDataset, TextCollate
from torch.utils.data import DataLoader
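`TextDataset` and `TextCollate` are project modules whose exact outputs are not shown here. A collate function for text typically pads a list of variable-length symbol-id sequences into one `(batch, max_len)` tensor and returns the true lengths. A minimal sketch, assuming each dataset item is a 1-D `LongTensor` and that id `0` is padding; `pad_collate` is hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Hypothetical collate_fn: pad 1-D LongTensors of symbol ids into one batch tensor.

    Sorting longest-first matches what pack_padded_sequence expects by default
    (enforce_sorted=True), which recurrent encoders often rely on.
    """
    batch = sorted(batch, key=len, reverse=True)
    lengths = torch.LongTensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)  # (batch, max_len)
    return padded, lengths


items = [torch.LongTensor([5, 6]), torch.LongTensor([1, 2, 3, 4]), torch.LongTensor([7, 8, 9])]
padded, lengths = pad_collate(items)
print(padded)   # rows [1, 2, 3, 4], [7, 8, 9, 0], [5, 6, 0, 0]
print(lengths)  # tensor([4, 3, 2])
```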

In [None]:
train_dataset = TextDataset('./filelists/ljs_audio_text_train_filelist.txt', TACOTRON_CONFIG)
test_dataset = TextDataset('./filelists/ljs_audio_text_test_filelist.txt', TACOTRON_CONFIG)

# Loader over the training filelist
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    num_workers=1,
    collate_fn=TextCollate(),
    drop_last=True
)
# Separate loader over the held-out test filelist
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    num_workers=1,
    collate_fn=TextCollate(),
    drop_last=True
)
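The data loaders are set up here, but the training loop is not part of this section. A minimal sketch of one distillation step, assuming the student is a smaller convolution stack trained to reproduce the ground-truth module's output with an MSE loss that ignores padded time steps. `ConvEncoder`, `distill_step`, the student's depth and kernel size, the toy sizes (40 symbols, 64 channels) and the masked-MSE choice are all hypothetical, not taken from this project.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ConvEncoder(nn.Module):
    """Hypothetical embedding + Conv1d/BatchNorm1d/ReLU stack; output is (batch, dim, time)."""

    def __init__(self, n_symbols, dim, n_layers, kernel_size):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, dim)
        layers = []
        for _ in range(n_layers):
            layers += [nn.Conv1d(dim, dim, kernel_size, padding=kernel_size // 2),
                       nn.BatchNorm1d(dim), nn.ReLU()]
        self.convs = nn.Sequential(*layers)

    def forward(self, tokens):
        return self.convs(self.embedding(tokens).transpose(1, 2))


def distill_step(teacher, student, optimizer, tokens, lengths):
    teacher.eval()  # frozen teacher: BatchNorm uses running statistics
    student.train()
    with torch.no_grad():
        target = teacher(tokens)  # (batch, dim, time)
    pred = student(tokens)
    # 1.0 for real time steps, 0.0 for padding; shape (batch, 1, time) broadcasts over channels
    mask = (torch.arange(tokens.size(1))[None, :] < lengths[:, None]).unsqueeze(1).float()
    loss = ((pred - target) ** 2 * mask).sum() / (mask.sum() * pred.size(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


teacher = ConvEncoder(n_symbols=40, dim=64, n_layers=3, kernel_size=5).requires_grad_(False)
student = ConvEncoder(n_symbols=40, dim=64, n_layers=1, kernel_size=3)
# Start the student from the teacher's embedding so only the convolutions must be learned
student.embedding.load_state_dict(teacher.embedding.state_dict())
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

tokens = torch.randint(1, 40, (8, 20))  # toy batch of symbol ids
lengths = torch.randint(5, 21, (8,))    # toy true lengths (<= 20)
losses = [distill_step(teacher, student, optimizer, tokens, lengths) for _ in range(50)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Freezing the teacher and keeping it in `eval()` mode fixes its BatchNorm statistics, so the target for a given batch does not move between steps. With the real modules, `teacher` would be `conv_module` and batches would come from `train_dataloader`, assuming `TextCollate` yields padded ids together with their lengths.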

## **3 Quick synthesis check**

In [14]:
([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] * 2)

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
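A note on how this test batch is built: in Python, `[row] * 2` repeats a reference to the same inner list rather than copying it. That is harmless above, because `torch.LongTensor(...)` copies the values into a new tensor, but it matters if the rows are ever mutated. A small demonstration in plain Python:

```python
batch = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] * 2
print(batch[0] is batch[1])  # True: both rows are the same list object

batch[0].append(11)
print(batch[1][-1])  # 11: the "other" row changed too

# Build independent rows when they may be modified later
safe_batch = [list(range(11)) for _ in range(2)]
safe_batch[0].append(11)
print(len(safe_batch[1]))  # 11: the second row is unaffected
```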