In [1]:
import yaml
import torch
import torchaudio
import fairseq_mod
from IPython.display import Audio, display

import sys
sys.path.append("../..")

from wav2vec2_inference_pipeline import W2lViterbiDecoder, process_batch_element
from wav2vec2_compression_demo import get_proj_layer
from data_loader import LibriSpeechDataLoader
from knowledge_distillation.kd_training import KnowledgeDistillModel
from fairseq_mod.models.wav2vec.student_wav2vec2 import StudentWav2Vec2Model
from fairseq_mod.models.wav2vec.teacher_wav2vec2 import TeacherWav2Vec2Model

# Load dataset

In [2]:
# Directory handed to torchaudio as the dataset root (absolute, machine-specific path).
DATA_PATH = "/home/Knowledge-Distillation-Toolkit/examples/wav2vec2_compression_demo/datasets"

# Open the "dev-clean" split. download=False means nothing is fetched, so the
# data is expected to be present under DATA_PATH already.
dev_clean_dataset = torchaudio.datasets.LIBRISPEECH(
    root=DATA_PATH,
    url="dev-clean",
    download=False,
)

# Load one audio sample

In [3]:
def play_audio(waveform, sample_rate):
    """Show an inline audio player for a mono or stereo waveform.

    Args:
        waveform: ``torch.Tensor`` laid out as ``(num_channels, num_frames)``.
            A 1-D tensor is treated as a single mono channel. The tensor may
            live on any device and may be tracking gradients.
        sample_rate: Sampling rate passed to ``IPython.display.Audio`` as
            ``rate``.

    Raises:
        ValueError: If the waveform does not have exactly 1 or 2 channels, or
            has more than two dimensions.
    """
    # Tensor.numpy() refuses tensors that live on the GPU or require grad, so
    # detach from the autograd graph and copy to host memory first.
    waveform = waveform.detach().cpu().numpy()

    # Accept a bare 1-D signal by giving it an explicit channel axis.
    if waveform.ndim == 1:
        waveform = waveform[None, :]

    # The two-name unpack doubles as a check that the array is exactly 2-D.
    num_channels, _num_frames = waveform.shape
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        # Audio takes a (left, right) pair for stereo playback.
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")

In [60]:
# Index of the dev-clean utterance used for the rest of the demo; change this
# value (e.g. to 10) to try a different utterance.
SAMPLE_INDEX = 1

# Later cells use sample[0] as the waveform, sample[1] as the sample rate and
# sample[2] as the reference transcript.
sample = dev_clean_dataset[SAMPLE_INDEX]

In [61]:
# Listen to the selected utterance before transcribing it.
play_audio(waveform=sample[0], sample_rate=sample[1])

# Load student and teacher models

In [6]:
# Checkpoint file later passed to KnowledgeDistillModel.load_from_checkpoint
# (absolute, machine-specific path).
MODEL_LOAD_PATH = "/home/Knowledge-Distillation-Toolkit/examples/wav2vec2_compression_demo/speech-processing/retrain_exp9/checkpoints/student-epoch=042-train_final_loss=0.08857.ckpt"

# Read the demo configuration inside a context manager so the file handle is
# closed as soon as parsing finishes instead of lingering in the kernel.
# NOTE(review): yaml.FullLoader can construct more than plain data types; if
# demo_config.yaml only holds scalars/lists/dicts, prefer yaml.safe_load.
with open('demo_config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Dictionary shared by both models and the decoder.
# NOTE(review): presumably the letter-level vocabulary — confirm.
target_dict = fairseq_mod.data.Dictionary.load('ltr_dict.txt')

In [7]:
# Every call below reads the same pretrained fairseq checkpoint path, so look
# it up once.
kd_config = config["knowledge_distillation"]
pretrained_path = kd_config["general"]["fairseq_pretrained_model_path"]

# Student: built with the extra keyword arguments from the config's
# "student_model" section.
student_model = StudentWav2Vec2Model.create_student_model(
    target_dict=target_dict,
    fairseq_pretrained_model_path=pretrained_path,
    **kd_config["student_model"],
)

# Teacher: built from the pretrained checkpoint and the dictionary only.
teacher_model = TeacherWav2Vec2Model.create_teacher_model(
    target_dict=target_dict,
    fairseq_pretrained_model_path=pretrained_path,
)

# The teacher's projection-to-decoder layer is initialised from the pretrained
# checkpoint's projection weight and bias.
proj_layer_weight, proj_layer_bias = get_proj_layer(fairseq_pretrained_model_path=pretrained_path)
teacher_model.init_proj_layer_to_decoder(proj_layer_weight, proj_layer_bias)

# The student's projection layer gets uniformly random parameters of the same
# shapes (weight is drawn before bias).
student_proj_weight = torch.nn.Parameter(torch.rand(proj_layer_weight.shape))
student_proj_bias = torch.nn.Parameter(torch.rand(proj_layer_bias.shape))
student_model.init_proj_layer_to_decoder(student_proj_weight, student_proj_bias)

# Only the two models are real objects here; every other constructor argument
# is filled with a None/0 placeholder.
# NOTE(review): presumably these are training-only settings that inference
# never reads — confirm against KnowledgeDistillModel.
placeholder_hparams = dict(
    num_gpu_used=None,
    max_epoch=0,
    temperature=0,
    optimize_method=None,
    scheduler_method=None,
    learning_rate=0,
    num_lr_warm_up_epoch=0,
    final_loss_coeff_dict=None,
    train_data_loader=None,
    val_data_loaders=None,
    inference_pipeline=None,
)

# Restore the knowledge-distillation module from the checkpoint file.
KD_module = KnowledgeDistillModel.load_from_checkpoint(
    MODEL_LOAD_PATH,
    **placeholder_hparams,
    student_model=student_model,
    teacher_model=teacher_model,
)

student_model.proj_to_decoder.weight is not in student model state_dict
student_model.proj_to_decoder.bias is not in student model state_dict
Finished loading weights into the student model
w2v_encoder.proj.weight is not in teacher model state_dict
w2v_encoder.proj.bias is not in teacher model state_dict
Finished loading weights into the teacher model


# Student and teacher model outputs

In [62]:
# Put the student on the GPU (a CUDA device is required) and in eval mode.
KD_module.student_model.cuda()
KD_module.student_model.eval()
generator = W2lViterbiDecoder(target_dict)

# Inference needs no gradients, so run under no_grad to avoid building an
# autograd graph. Only the first returned element (the decoded text) is kept;
# indexing it avoids a 10-way unpack into `_`, which would rebind the name
# IPython uses for the previous cell output.
with torch.no_grad():
    prediction = process_batch_element(
        # unsqueeze adds a leading batch dimension to the waveform.
        element=(torch.unsqueeze(sample[0], dim=0), sample[1]),
        model=KD_module.student_model,
        generator=generator,
        target_dict=target_dict,
        use_cuda=True,
        input_half=False,
    )[0]


In [63]:
# Put the teacher on the GPU (a CUDA device is required) and in eval mode.
KD_module.teacher_model.cuda()
KD_module.teacher_model.eval()
generator = W2lViterbiDecoder(target_dict)

# Inference needs no gradients, so run under no_grad to avoid building an
# autograd graph. Only the first returned element (the decoded text) is kept;
# indexing it avoids a 10-way unpack into `_`, which would rebind the name
# IPython uses for the previous cell output.
with torch.no_grad():
    t_prediction = process_batch_element(
        # unsqueeze adds a leading batch dimension to the waveform.
        element=(torch.unsqueeze(sample[0], dim=0), sample[1]),
        model=KD_module.teacher_model,
        generator=generator,
        target_dict=target_dict,
        use_cuda=True,
        input_half=False,
    )[0]


In [64]:
# Reference transcript stored with the selected utterance.
sample[2]

"TONY'S FOUND THE MARTIANS"

In [65]:
# Text decoded from the student model's output.
prediction

'TOYS FOUND THE MARTNS'

In [66]:
# Text decoded from the teacher model's output.
t_prediction

"TOADY'S FOUND THE MARTIANS"