In [None]:
import torch
from pathlib import Path
from accelerate import Accelerator

In [None]:
# Base data directory; the cached encoder outputs live under DATA_PATH / 'encoder_outputs'.
DATA_PATH = Path('./data/cv_8_0_en')

In [None]:
# The accelerator supplies the target device and is later used to prepare the model,
# optimizer and data loaders and to run the backward pass.
# NOTE(review): newer accelerate releases replaced the `fp16=True` keyword with
# `mixed_precision='fp16'` — confirm against the installed version.
accelerator = Accelerator(fp16=True)
print(f'Using {accelerator.device}.')

# Encoder Output Extraction

In [None]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model, HubertModel
from wav2vec_feature_extraction import extract_features_to_files

import sys
sys.path.append('..')
from cv8_en import prepare

In [None]:
# Settings for the one-off encoder output extraction.
# NOTE: OUTPUT_PATH and BATCH_SIZE are assigned again in the training configuration further
# down, so re-running the extraction cell after that one would pick up the training values.
OUTPUT_PATH = DATA_PATH / "encoder_outputs"  # Extraction is skipped when this path exists.
SAMPLING_RATE = 16_000  # Passed to prepare() and extract_features_to_files(); presumably Hz.
SEED = 419  # Passed to prepare(); presumably seeds the subset selection — TODO confirm.
USE_TRAIN_PCT = 0.35  # Passed to prepare('train', ...); presumably the fraction of the split to use.
USE_VAL_PCT = 0.25  # Same, for the 'validation' split.
ENCODER_MDL = 'facebook/hubert-xlarge-ls960-ft'  # Checkpoint of the feature extractor and encoder.
BATCH_SIZE = 6  # Batch size handed to extract_features_to_files().
MAX_AUDIO_LENGTH = 300_000  # Handed to extract_features_to_files(); presumably a limit in samples — TODO confirm.

In [None]:
# The extraction only runs when no previously written encoder outputs are found.
if not OUTPUT_PATH.exists():
    # The second value returned for the training split is passed on to the validation split.
    train, uncommon_chars = prepare('train', USE_TRAIN_PCT, SAMPLING_RATE, SEED)
    val, _ = prepare('validation', USE_VAL_PCT, SAMPLING_RATE, SEED, uncommon_chars)

    # Load the feature extractor and the encoder, and put the encoder in eval mode on the device.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ENCODER_MDL)
    mdl = HubertModel.from_pretrained(ENCODER_MDL)
    mdl.eval()
    mdl.to(accelerator.device)

    # Write the model outputs of each split into its own subdirectory.
    for split_dir, split_data in (('train', train), ('val', val)):
        extract_features_to_files(
            mdl, feature_extractor, split_data, BATCH_SIZE,
            OUTPUT_PATH / split_dir, MAX_AUDIO_LENGTH, SAMPLING_RATE,
        )

    # Move the encoder back to the CPU and release cached GPU memory.
    mdl.cpu()
    torch.cuda.empty_cache()

# Training

In [None]:
# Necessary to disable warnings (presumably the ones the tokenizer emits once the
# DataLoader worker processes are started — TODO confirm).
%env TOKENIZERS_PARALLELISM=False

In [None]:
import torch
from transformers import AutoTokenizer
from data_loading import Wav2VecFeaturesDataset, make_collate_fn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW, lr_scheduler
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2BaseModelOutput

import sys
sys.path.append('..')
from model.wav2vec_gpt2 import Wav2VecGPT2Model
from wer import calculate_wer

In [None]:
# The best checkpoint is saved to OUTPUT_PATH and the TensorBoard logs go to LOG_PATH.
OUTPUT_PATH = Path('./results/1')
LOG_PATH = OUTPUT_PATH / 'logs'

ENCODER_ID = 'facebook/hubert-xlarge-ls960-ft'
DECODER_ID = 'gpt2'
# Given to the datasets and stripped from the reference sentences; its tokens seed generation.
PROMPT = 'Transcription:'
PAD_TOKEN = '_'
BATCH_SIZE = 16
# Learning rate per parameter group. A tuple key lists substrings of parameter names:
# parameters whose name contains any of them use that rate, all others use 'default'.
LEARNING_RATES = {
    'default': 1e-6,
    ('cross_attn', 'crossattention', 'enc_to_dec_proj', 'encoder_outputs_pos_emb'): 6e-4
}
MAX_EPOCHS = 3
ACCUMULATE_GRAD = 2  # Number of batches per optimizer step.
MAX_LEN = 39  # `max_length` passed to model.generate() during evaluation.


def LR_SCHEDULER(optimizer):
    """Create the cosine-annealing learning-rate schedule for ``optimizer``.

    ``T_max`` is 1.5 times ``MAX_EPOCHS * (len(train_ds) // (BATCH_SIZE * ACCUMULATE_GRAD))``
    and the minimum learning rate is 1e-6. The global ``train_ds`` is read at call time,
    so it has to be defined before this function is called.

    The scheduler class is imported inside the function on purpose: the notebook later
    binds the name ``lr_scheduler`` to the scheduler instance, which hides the imported
    module of the same name and would otherwise break any further call.

    Args:
        optimizer: The optimizer whose learning rates are scheduled.

    Returns:
        A ``torch.optim.lr_scheduler.CosineAnnealingLR`` instance.
    """
    from torch.optim.lr_scheduler import CosineAnnealingLR

    num_steps = MAX_EPOCHS * (len(train_ds) // (BATCH_SIZE * ACCUMULATE_GRAD)) * 1.5
    return CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=1e-6)

In [None]:
# Load the tokenizer of the decoder checkpoint and register PAD_TOKEN as its padding token;
# the resulting pad token id is copied into the model config further down.
tokenizer = AutoTokenizer.from_pretrained(DECODER_ID)
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})

In [None]:
# Both loaders share the collate function and loading settings; only the training data is shuffled.
collate_fn = make_collate_fn(tokenizer)
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=4)

# The datasets read from the train/val subdirectories of the encoder output directory.
train_ds = Wav2VecFeaturesDataset(DATA_PATH / 'encoder_outputs/train', PROMPT)
val_ds = Wav2VecFeaturesDataset(DATA_PATH / 'encoder_outputs/val', PROMPT)

train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
# Build the encoder-decoder model from the two pretrained checkpoints.
model = Wav2VecGPT2Model.from_encoder_decoder_pretrained(ENCODER_ID, DECODER_ID)
# Make the model config use the padding token id registered on the tokenizer.
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Build one optimizer parameter group per entry of LEARNING_RATES.
# Every key other than 'default' is a substring (or a tuple of substrings) of parameter names;
# a parameter whose name contains any of them is trained with that key's learning rate.
named_params = list(model.named_parameters())
optimizer_groups = []

# Names that already belong to a group; a set keeps the membership tests cheap.
assigned_modules = set()
for modules, lr in LEARNING_RATES.items():
    if modules == 'default':
        continue
    patterns = (modules,) if isinstance(modules, str) else tuple(modules)
    # A parameter matched by several keys stays in the first group that claimed it,
    # because an optimizer rejects parameters that appear in more than one group.
    matched = [
        (name, param) for name, param in named_params
        if name not in assigned_modules and any(pattern in name for pattern in patterns)
    ]
    if not matched:
        # Fail with a clear message instead of a cryptic unpacking error; silently skipping
        # the group would train these modules with the default learning rate.
        raise ValueError(f'No model parameters match {patterns!r}.')
    assigned_modules.update(name for name, _ in matched)
    optimizer_groups.append({'params': [param for _, param in matched], 'lr': lr})

# Everything that was not matched above is trained with the default learning rate.
optimizer_groups.append({
    'params': [param for name, param in named_params if name not in assigned_modules],
    'lr': LEARNING_RATES['default']
})

optimizer = AdamW(optimizer_groups, weight_decay=0.0)
# NOTE: this rebinds the name of the imported `lr_scheduler` module to the scheduler
# instance; the training loop steps the scheduler through this name.
lr_scheduler = LR_SCHEDULER(optimizer)

In [None]:
# Hand the model, optimizer and both data loaders to the accelerator.
model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)
# Training and evaluation pass pre-computed encoder outputs to the model, so the encoder
# itself is moved back to the CPU.
model.encoder.cpu()  # Does not need to be on GPU.

In [None]:
writer = SummaryWriter(LOG_PATH)
# Reference transcriptions: every sentence without the leading PROMPT and the character after it.
val_golds = [eg['sentence'][len(PROMPT) + 1:] for eg in val_ds]
global_train_step, val_count = 0, 0
# Number of tokens PROMPT is tokenized into; generation is seeded with that many tokens.
prompt_token_count = len(tokenizer(PROMPT).input_ids)
# Start from infinity so the first evaluation is always saved as the best model so far,
# whatever scale calculate_wer() reports.
best_wer = float('inf')

def evaluate():
    """Transcribe the validation set, log its WER and keep the best model.

    Puts the model in eval mode and generates a continuation of the prompt tokens for
    every validation batch. The decoded texts, with the prompt prefix and trailing
    PAD_TOKEN characters removed, are scored against ``val_golds``. The WER is logged
    under 'val_wer' and printed, and the model is saved to OUTPUT_PATH whenever the WER
    is lower than ``best_wer``. The model is left in eval mode.

    Returns:
        The list of cleaned predictions, in validation order.
    """
    global val_count, best_wer

    model.eval()
    decoded = []
    with torch.no_grad():
        for encoder_hidden_states, _, input_ids in val_dl:
            prompt_ids = input_ids[:, :prompt_token_count]
            generated = model.generate(
                decoder_input_ids=prompt_ids,
                encoder_outputs=Wav2Vec2BaseModelOutput(encoder_hidden_states),
                max_length=MAX_LEN,
            )
            decoded.extend(tokenizer.batch_decode(generated))

    # Drop the prompt plus the character after it, then the padding at the end.
    prefix_len = len(PROMPT) + 1
    val_preds = [text[prefix_len:].rstrip(PAD_TOKEN) for text in decoded]

    wer = calculate_wer(val_preds, val_golds)
    writer.add_scalar('val_wer', wer, val_count)
    print('WER: ', wer)

    improved = wer < best_wer
    if improved:
        best_wer = wer
        model.save_pretrained(OUTPUT_PATH)
        print('Saved new best model.')

    val_count += 1
    return val_preds


# Train the model on the pre-computed encoder outputs.
for epoch in range(MAX_EPOCHS):
    model.train()
    for encoder_hidden_states, _, input_ids in train_dl:
        global_train_step += 1
        # The decoder input drops the last token and the labels drop the first one.
        out = model(decoder_input_ids=input_ids[:, :-1],
                    labels=input_ids[:, 1:].contiguous(),
                    encoder_outputs=Wav2Vec2BaseModelOutput(encoder_hidden_states))
        # NOTE: the loss is not divided by ACCUMULATE_GRAD, so the gradients of the
        # accumulated batches are summed rather than averaged.
        accelerator.backward(out.loss)

        loss_value = out.loss.item()
        writer.add_scalar('train_loss', loss_value, global_train_step)
        for i, group in enumerate(optimizer.param_groups):
            writer.add_scalar(f'learning_rate_group{i}', group['lr'], global_train_step)

        # global_train_step was already incremented for this batch, so it counts the
        # batches seen so far; stepping on its multiples gives every update, including
        # the first one, exactly ACCUMULATE_GRAD batches of gradients.
        if global_train_step % ACCUMULATE_GRAD == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if global_train_step % 20 == 0:
            print(loss_value)

        if global_train_step % 5000 == 0:
            val_preds = evaluate()
            # evaluate() leaves the model in eval mode.
            model.train()

# Final evaluation.
val_preds = evaluate()

In [None]:
# First ten reference transcriptions of the validation set.
val_golds[:10]

In [None]:
# First ten predictions returned by the final evaluation.
val_preds[:10]

In [None]:
# Keep a copy of this notebook next to the logs of the run it produced.
# shutil is used instead of a shell command so the copy also works on systems without `cp`,
# copes with spaces in LOG_PATH, and raises if the notebook cannot be copied.
import shutil

shutil.copy('./train.ipynb', LOG_PATH);