In [None]:
# Necessary to disable warnings.
# Must run before any tokenizer is used. The DataLoaders in this notebook use
# num_workers=4; presumably the `tokenizers` library warns when worker processes
# are forked unless this variable is set explicitly — TODO confirm.
%env TOKENIZERS_PARALLELISM=False

In [1]:
import torch
from pathlib import Path
from accelerate import Accelerator
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW, lr_scheduler

import sys
sys.path.append('..')
from cv8_en import prepare
from model.wav2vec_gpt2 import Wav2VecGPT2Model
from wer import calculate_wer

In [None]:
# Accelerator handles device placement and the backward pass (see
# `accelerator.prepare` / `accelerator.backward` further down); fp16=True
# requests half-precision training.
# NOTE(review): newer accelerate releases appear to configure this through
# `mixed_precision='fp16'` instead of `fp16=` — confirm against the installed version.
accelerator = Accelerator(fp16=True)
print(f'Using {accelerator.device}.')

In [None]:
# Where the best checkpoint is saved (model.save_pretrained) and where the
# TensorBoard SummaryWriter puts its event files.
OUTPUT_PATH = Path('./results/0')
LOG_PATH = OUTPUT_PATH / 'logs'

# Data preparation settings, all forwarded to `prepare(...)`.
# USE_*_PCT presumably select the fraction of each split to use — confirm in cv8_en.prepare.
SAMPLING_RATE = 16_000
SEED = 419
USE_TRAIN_PCT = 0.1
USE_VAL_PCT = 0.05

# Pretrained checkpoints: ENCODER_ID feeds the feature extractor and the model's
# encoder, DECODER_ID feeds the tokenizer and the model's decoder.
ENCODER_ID = 'facebook/wav2vec2-base-960h'
DECODER_ID = 'gpt2'
# Text prefix (plus one space) prepended to every transcription by AudioDataset;
# its tokens are also passed as the decoder prefix during evaluation.
PROMPT = 'Transcription:'
# Registered as the tokenizer's pad token and stripped from the right of decoded predictions.
PAD_TOKEN = '_'
# collate_fn drops clips whose audio array has this many samples or more.
MAX_AUDIO_LENGTH = 300_000
# Presumably the cap on the decoder sequence length during generation — confirm.
MAX_TOKEN_SEQ_LEN = 39

# Optimisation settings: AdamW learning rate, DataLoader batch size, number of
# passes over train_dl, and micro-batches accumulated per optimizer step.
LEARNING_RATE = 3e-4
BATCH_SIZE = 1
MAX_EPOCHS=4
ACCUMULATE_GRAD=8

def LR_SCHEDULER(optimizer):
    """Build the cosine-annealing learning-rate schedule for `optimizer`.

    The horizon is the expected number of optimizer updates over the whole run
    (MAX_EPOCHS * updates per epoch), stretched by 10% so the learning rate has
    not fully decayed to `eta_min` when training stops.

    Reads the notebook globals `train_ds`, `BATCH_SIZE`, `ACCUMULATE_GRAD` and
    `MAX_EPOCHS`, so it must be called after `train_ds` has been created.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        A `CosineAnnealingLR` instance bound to `optimizer`.
    """
    # Import under a private alias instead of relying on the global name
    # `lr_scheduler`: the optimizer cell rebinds that name to the scheduler
    # instance, which would make a second call to this function fail.
    from torch.optim import lr_scheduler as _schedulers

    updates_per_epoch = len(train_ds) // (BATCH_SIZE * ACCUMULATE_GRAD)
    # T_max is documented as an integer number of scheduler steps.
    num_steps = int(MAX_EPOCHS * updates_per_epoch * 1.1)
    return _schedulers.CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=1e-6)

In [None]:
# Prepare the training split first: besides the data, `prepare` returns
# `uncommon_chars`, which is handed to the validation call below (presumably so
# both splits are filtered with the same character set — confirm in cv8_en.prepare).
train, uncommon_chars = prepare('train', USE_TRAIN_PCT, SAMPLING_RATE, SEED)
# The second return value of the validation call is not needed.
val, _ = prepare('validation', USE_VAL_PCT, SAMPLING_RATE, SEED, uncommon_chars)

In [None]:
# Audio preprocessing taken from the encoder checkpoint; used in collate_fn to
# turn raw waveforms into padded `input_values`.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ENCODER_ID)

# Text tokenizer taken from the decoder checkpoint. PAD_TOKEN is registered as
# the pad token because collate_fn tokenizes with `padding=True`
# (presumably the checkpoint's tokenizer defines no pad token of its own — confirm).
tokenizer = AutoTokenizer.from_pretrained(DECODER_ID)
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})

In [None]:
class AudioDataset(Dataset):
    """Map-style dataset yielding `(audio_array, prompted_sentence)` pairs.

    Args:
        ds: indexable collection whose items are mappings with an
            `'audio'` entry (itself holding an `'array'`) and a `'sentence'` entry.
        prompt: text placed in front of every sentence; a single space is
            inserted between the prompt and the sentence.
    """

    def __init__(self, ds, prompt):
        self.ds = ds
        self.prompt = prompt + ' '

    def __len__(self):
        """Return the number of examples in the wrapped collection."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return `(audio_array, prompt + ' ' + sentence)` for example `idx`.

        The underlying example is left untouched. Writing the prompted text
        back into it would prepend the prompt again on every access whenever
        the wrapped collection hands out the same mapping object each time.
        """
        eg = self.ds[idx]
        return eg['audio']['array'], self.prompt + eg['sentence']
    
def collate_fn(examples):
    """Collate `(audio_array, sentence)` pairs into padded model inputs.

    Examples whose audio has MAX_AUDIO_LENGTH samples or more are dropped,
    because the longest clips may lead to out-of-memory errors.

    Args:
        examples: list of `(audio_array, sentence)` pairs as produced by
            AudioDataset.

    Returns:
        Tuple `(audio_features, input_ids)`: the feature extractor's
        `input_values` padded to the longest remaining clip, and the
        tokenizer's padded `input_ids`.

    Raises:
        ValueError: if the length filter removed every example of the batch
            (most likely with a batch size of 1).
    """
    # Remove the longest examples, these may lead to OOM-Errors.
    kept = [eg for eg in examples if len(eg[0]) < MAX_AUDIO_LENGTH]
    if not kept:
        # Fail with an actionable message rather than passing an empty list
        # on to the feature extractor / tokenizer.
        raise ValueError(
            f'All {len(examples)} example(s) in this batch have at least '
            f'{MAX_AUDIO_LENGTH} audio samples and were dropped; filter such '
            'clips out of the dataset or raise MAX_AUDIO_LENGTH.'
        )

    audio_features = feature_extractor(
        [eg[0] for eg in kept],
        sampling_rate=SAMPLING_RATE,  # same rate the data was prepared with
        return_tensors='pt',
        padding='longest',
    ).input_values

    input_ids = tokenizer(
        [eg[1] for eg in kept], return_tensors='pt', padding=True
    ).input_ids

    return audio_features, input_ids

In [None]:
# Wrap both prepared splits so that every transcription carries the prompt prefix.
train_ds = AudioDataset(train, PROMPT)
val_ds = AudioDataset(val, PROMPT)

# Options common to both loaders; only the training loader shuffles, so the
# validation order stays fixed between evaluations.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=4)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Combine the pretrained encoder and decoder checkpoints into one model.
model = Wav2VecGPT2Model.from_encoder_decoder_pretrained(ENCODER_ID, DECODER_ID)
# Tell the model which token id the tokenizer uses for padding.
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# AdamW over all model parameters, without weight decay.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0)
# NOTE: this rebinds the global name `lr_scheduler` (imported from torch.optim
# at the top of the notebook) to the scheduler instance; the training loop
# calls `.step()` on it.
lr_scheduler = LR_SCHEDULER(optimizer)

In [None]:
model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)

In [None]:
# TensorBoard writer for the loss, learning-rate and WER curves.
writer = SummaryWriter(LOG_PATH)
# Reference transcriptions with the prompt prefix (prompt + one space) removed.
# collate_fn drops clips with MAX_AUDIO_LENGTH samples or more, so the same
# filter is applied here; otherwise predictions and references would no longer
# line up one-to-one when a validation clip is dropped.
val_golds = [
    sentence[len(PROMPT) + 1:]
    for audio, sentence in val_ds
    if len(audio) < MAX_AUDIO_LENGTH
]
# Counters: micro-batches seen so far and number of evaluations run so far.
global_train_step, val_count = 0, 0
# Number of tokens the prompt occupies; used to slice the decoder prefix in evaluate().
prompt_token_count = len(tokenizer(PROMPT).input_ids)
# Start from infinity so the first evaluation always stores a checkpoint,
# whatever scale calculate_wer reports on.
best_wer = float('inf')

def evaluate():
    """Transcribe the validation set, log its WER and checkpoint on improvement.

    For every validation batch the decoder is primed with the prompt tokens and
    the model generates up to MAX_TOKEN_SEQ_LEN tokens. Decoded predictions get
    the prompt prefix and trailing pad characters removed before they are
    scored against `val_golds`.

    Side effects: leaves the model in eval mode, writes `val_wer` to
    TensorBoard, increments the global `val_count`, and — when the WER beats
    the global `best_wer` — updates it and saves the model to OUTPUT_PATH.

    Returns:
        List of predicted transcriptions, in validation-loader order.
    """
    global val_count, best_wer

    model.eval()
    val_preds = []
    with torch.no_grad():
        for audio_features, input_ids in val_dl:
            generated = model.generate(
                audio_features,
                decoder_input_ids=input_ids[:, :prompt_token_count],
                # The config cell defines MAX_TOKEN_SEQ_LEN; the previously
                # referenced MAX_LEN is not defined anywhere in the notebook.
                max_length=MAX_TOKEN_SEQ_LEN,
            )
            val_preds += tokenizer.batch_decode(generated)
    # Drop the prompt (plus its trailing space) and any padding at the end.
    val_preds = [pred[len(PROMPT) + 1:].rstrip(PAD_TOKEN) for pred in val_preds]
    wer = calculate_wer(val_preds, val_golds)
    writer.add_scalar('val_wer', wer, val_count)
    print('WER: ', wer)

    if wer < best_wer:
        best_wer = wer
        model.save_pretrained(OUTPUT_PATH)
        print('Saved new best model.')
    val_count += 1
    return val_preds


# Training loop: one forward/backward pass per micro-batch, one optimizer and
# scheduler step every ACCUMULATE_GRAD micro-batches.
for epoch in range(MAX_EPOCHS):
    model.train()
    for audio_features, input_ids in train_dl:
        global_train_step += 1
        # Teacher forcing: the decoder sees tokens [0, n-1) and is scored on tokens [1, n).
        out = model(audio_features,
                    decoder_input_ids=input_ids[:, :-1],
                    labels=input_ids[:, 1:].contiguous())
        # Scale the loss so the accumulated gradient is the mean, not the sum,
        # over the ACCUMULATE_GRAD micro-batches.
        accelerator.backward(out.loss / ACCUMULATE_GRAD)

        # Log the unscaled loss so the curve stays comparable across settings.
        train_loss = out.loss.item()
        writer.add_scalar('train_loss', train_loss, global_train_step)
        for i, group in enumerate(optimizer.param_groups):
            writer.add_scalar(f'learning_rate_group{i}', group['lr'], global_train_step)

        # global_train_step was already incremented above, so it is a 1-based
        # count: stepping on exact multiples gives full accumulation windows
        # (the former `+ 1` made the first window one micro-batch short).
        if global_train_step % ACCUMULATE_GRAD == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Periodic console progress.
        if global_train_step % 20 == 0:
            print(train_loss)

        # Periodic validation; evaluate() switches to eval mode, so switch back.
        if global_train_step % 5000 == 0:
            val_preds = evaluate()
            model.train()

# Final evaluation.
val_preds = evaluate()
# Make sure all pending TensorBoard events reach disk.
writer.flush()

In [None]:
# Show the first ten reference transcriptions for a side-by-side check with the predictions.
val_golds[:10]

In [None]:
# Show the first ten predictions from the most recent evaluate() call.
val_preds[:10]

In [None]:
# Copy ./train.ipynb into the log directory so the run's logs are stored next to
# the code that produced them (assumes this notebook is saved as train.ipynb in
# the working directory — confirm).
!cp ./train.ipynb {LOG_PATH}