In [1]:
from collections import ChainMap

import yaml
import torch
import fairseq_mod

import sys
sys.path.append("../..")

from wav2vec2_inference_pipeline import inference_pipeline
from data_loader import LibriSpeechDataLoader
from knowledge_distillation.kd_training import KnowledgeDistillationTraining
from fairseq_mod.models.wav2vec.teacher_wav2vec2 import TeacherWav2Vec2Model
from fairseq_mod.models.wav2vec.student_wav2vec2 import StudentWav2Vec2Model

### Load the configuration and create the letter dictionary

In [2]:
# Read the experiment configuration. The context manager closes the file handle
# instead of leaving it open for the lifetime of the kernel.
# NOTE(review): FullLoader is kept for compatibility with the original; if the
# config only uses plain YAML types (maps, lists, scalars), yaml.safe_load would
# be the safer choice.
with open('demo_config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# Letter dictionary; passed below as `target_dict` to the inference pipeline and
# to both the student and teacher models.
target_dict = fairseq_mod.data.Dictionary.load('ltr_dict.txt')

### Create data loaders for training and validation

In [3]:
# Build the LibriSpeech loaders from the "data_loader" section of the config,
# then fetch the training loader followed by the validation loaders.
libriSpeech_data_loader = LibriSpeechDataLoader(**config["data_loader"])
train_data_loader, val_data_loaders = (
    libriSpeech_data_loader.get_train_data_loader(),
    libriSpeech_data_loader.get_val_data_loaders(),
)

### Create the inference pipeline used to validate the student model

In [4]:
# Decoding pipeline later used to score the student model on the validation
# sets (its "inference_result" is reported as WER at the end of the notebook).
# input_half=False presumably keeps inputs in full precision — confirm in
# wav2vec2_inference_pipeline.
inference_pipeline_example = inference_pipeline(
    target_dict,
    use_cuda=True,
    input_half=False,
)

### Create the student and teacher models

In [5]:
kd_config = config["knowledge_distillation"]
pretrained_model_path = kd_config["general"]["fairseq_pretrained_model_path"]

# Both models are initialized from the same fairseq checkpoint; the student also
# receives the extra keyword arguments listed in the "student_model" config section.
student_model = StudentWav2Vec2Model.create_student_model(
    target_dict=target_dict,
    fairseq_pretrained_model_path=pretrained_model_path,
    **kd_config["student_model"],
)
teacher_model = TeacherWav2Vec2Model.create_teacher_model(
    target_dict=target_dict,
    fairseq_pretrained_model_path=pretrained_model_path,
)

student_model.proj_to_decoder.weight is not in student model state_dict
student_model.proj_to_decoder.bias is not in student model state_dict
Finished loading weights into the student model
w2v_encoder.proj.weight is not in teacher model state_dict
w2v_encoder.proj.bias is not in teacher model state_dict
Finished loading weights into the teacher model


### Initialize the projection layer of the student and teacher models from the pre-trained checkpoint (this layer produces the logits from which the probability distributions over tokens are computed)

In [6]:
def get_proj_layer(fairseq_pretrained_model_path, map_location="cpu"):
    """
    Get the projection layer's weight and bias from a wav2vec 2.0 fairseq checkpoint.

    Args:
        fairseq_pretrained_model_path: path to a fairseq checkpoint whose "model"
            state dict contains "w2v_encoder.proj.weight" and "w2v_encoder.proj.bias".
        map_location: forwarded to ``torch.load``; device the checkpoint tensors are
            restored onto. Defaults to "cpu" so loading works on machines that lack
            the device the checkpoint was saved from, and the whole checkpoint is not
            placed on a GPU just to extract two tensors.

    Returns:
        Tuple ``(weight, bias)`` of the projection layer tensors.

    Raises:
        KeyError: if the checkpoint has no "model" entry or lacks the projection
            layer keys.
    """
    weight_key = "w2v_encoder.proj.weight"
    bias_key = "w2v_encoder.proj.bias"
    # NOTE(review): newer torch releases default torch.load to weights_only=True,
    # which may reject checkpoints that pickle non-tensor objects — confirm against
    # the installed torch version.
    checkpoint = torch.load(fairseq_pretrained_model_path, map_location=map_location)
    model_state = checkpoint["model"]
    missing = [key for key in (weight_key, bias_key) if key not in model_state]
    if missing:
        raise KeyError(
            f"{fairseq_pretrained_model_path} has no projection layer entries: {missing}"
        )
    return model_state[weight_key], model_state[bias_key]

In [7]:
# Load the pre-trained projection layer once and use it to initialize both models.
proj_layer_weight, proj_layer_bias = get_proj_layer(fairseq_pretrained_model_path=config["knowledge_distillation"]["general"]["fairseq_pretrained_model_path"])
# Give each model its own copy of the tensors.
# NOTE(review): if init_proj_layer_to_decoder wraps the tensors in nn.Parameter
# without copying, passing the same tensor objects to both models would make them
# share storage, so training the student's projection would also change the
# teacher's. Cloning is cheap for this layer and avoids that aliasing either way.
student_model.init_proj_layer_to_decoder(proj_layer_weight.clone(), proj_layer_bias.clone())
teacher_model.init_proj_layer_to_decoder(proj_layer_weight.clone(), proj_layer_bias.clone())

### Train the student model with knowledge distillation and measure its performance on the dev set

In [8]:
kd_config = config["knowledge_distillation"]
kd_general = kd_config["general"]

# Parameters recorded with the run. ChainMap resolves duplicate keys in favour of
# the earliest mapping listed.
kd_logging_param = ChainMap(
    kd_config["general"],
    kd_config["optimization"],
    kd_config["final_loss_coeff"],
    kd_config["student_model"],
    kd_config["pytorch_lightning_trainer"],
)
# Remaining keyword arguments, taken from the "optimization",
# "pytorch_lightning_trainer" and "comet_info" config sections.
kd_extra_kwargs = ChainMap(
    kd_config["optimization"],
    kd_config["pytorch_lightning_trainer"],
    kd_config["comet_info"],
)

KD_wav2vec2 = KnowledgeDistillationTraining(
    train_data_loader=train_data_loader,
    val_data_loaders=val_data_loaders,
    inference_pipeline=inference_pipeline_example,
    student_model=student_model,
    teacher_model=teacher_model,
    num_gpu_used=kd_general["num_gpu_used"],
    temperature=kd_general["temperature"],
    final_loss_coeff_dict=kd_config["final_loss_coeff"],
    logging_param=kd_logging_param,
    **kd_extra_kwargs,
)
KD_wav2vec2.start_kd_training()

Global seed set to 32
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.

  | Name          | Type                 | Params
-------------------------------------------------------
0 | student_model | StudentWav2Vec2Model | 65.5 M
1 | teacher_model | TeacherWav2Vec2Model | 317 M 
-------------------------------------------------------
382 M     Trainable params
0         Non-trainable params
382 M     Total params
1,531.611 Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


dev_clean :0.9922171602126044

GPU 0 current active MB: 1567.2637439999999
GPU 0 current reserved MB: 1577.058304


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]


dev_clean :0.6119969627942293

GPU 0 current active MB: 2369.9584
GPU 0 current reserved MB: 2642.4115199999997


Validating: 0it [00:00, ?it/s]


dev_clean :0.5672930903568717

GPU 0 current active MB: 2369.9594239999997
GPU 0 current reserved MB: 2646.6058239999998


Validating: 0it [00:00, ?it/s]


dev_clean :0.5494495064540622

GPU 0 current active MB: 2369.9594239999997
GPU 0 current reserved MB: 2648.702976


Validating: 0it [00:00, ?it/s]


dev_clean :0.5127182991647684

GPU 0 current active MB: 2369.9594239999997
GPU 0 current reserved MB: 2646.6058239999998


Validating: 0it [00:00, ?it/s]


dev_clean :0.48395975702353833

GPU 0 current active MB: 2369.9594239999997
GPU 0 current reserved MB: 2648.702976


In [9]:
# Retrieve the distilled student, move it to the GPU (nn.Module.cuda moves the
# module in place and returns it) and evaluate it on the dev-clean set.
student_model = KD_wav2vec2.get_student_model().cuda()
dev_clean_loader = val_data_loaders["dev_clean"]
val_result = inference_pipeline_example.run_inference_pipeline(student_model, dev_clean_loader)

In [10]:
# "inference_result" is a fraction; report it as a percentage with two decimals.
final_wer_percent = val_result["inference_result"] * 100
print(f"final WER is {final_wer_percent:.2f}")

final WER is 48.40


#### As the output above shows, the dev-clean WER decreased from 99.2% (before training) to 48.4% after 5 epochs of training.