In [None]:
# Reload imported modules automatically before each cell execution, so edits to
# the project packages imported further down take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from pathlib import Path

In [None]:
# Root directory of the extracted features ('train' and 'val' are read from it later).
# The feature-extraction cell only runs when this directory does not exist yet.
DATA_PATH = Path('./data/lj_speech')

# Feature Extraction

In [None]:
import soundfile as sf
from datasets import load_dataset
from torchaudio.transforms import Resample
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
from gpt2_asr.speech_feature_extraction import extract_features_to_files

In [None]:
# Feature extraction is expensive, so it is skipped when `DATA_PATH` already exists.
if not DATA_PATH.exists():

    # Sample rate the clips are expected to have on disk, and the rate they are
    # resampled to before being handed to the wav2vec2 feature extractor.
    source_sample_rate = 22050
    target_sample_rate = 16_000

    lj_speech = load_dataset('lj_speech', split='train')  # Only a train split is loaded.

    # Load model for feature extraction.
    wave2vec_name = 'facebook/wav2vec2-large-960h-lv60-self'
    wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wave2vec_name)
    wave2vec = Wav2Vec2Model.from_pretrained(wave2vec_name)
    wave2vec.eval().cuda()
    resampler = Resample(source_sample_rate, target_sample_rate)

    # Extract audio and transcriptions.
    examples = []
    for eg in lj_speech:
        waveform, sample_rate = sf.read(eg['file'])
        # `resampler` is built for one fixed input rate; a clip stored at any other
        # rate would be resampled incorrectly without raising, so fail loudly instead.
        if sample_rate != source_sample_rate:
            raise ValueError(
                f"{eg['file']}: expected {source_sample_rate} Hz audio, got {sample_rate} Hz."
            )
        eg['audio'] = resampler(torch.tensor(waveform).float()).numpy()

        # TODO: Temporary (?) helper for generation, since using empty `input_ids` leads to errors.
        eg['transcription'] = 'Transcription: ' + eg['normalized_text']

        examples.append(eg)

    # `max_len` is just the longest sample in the dataset (determined in advance).
    extract_features_to_files(wave2vec, wave2vec_extractor, examples,
                              batch_size=8, max_len=161540, output_path=DATA_PATH, val_pct=0.1)

In [None]:
# `wave2vec` only exists when the extraction branch actually ran. When cached
# features were reused there is no model to move off the GPU, and referencing
# the name unconditionally would raise NameError on a fresh kernel.
if 'wave2vec' in globals():
    wave2vec.cpu()
torch.cuda.empty_cache()

# Training

In [None]:
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, GPT2Tokenizer
from gpt2_s2t.model import S2TModel
from gpt2_s2t.data_loading import S2TDataset, make_collate_fn
from gpt2_s2t.evaluation_utils import get_predictions, calculate_mean_loss, calculate_wer

In [None]:
# The accelerator later wraps the model/optimizer/loaders and drives backward();
# fp16=True requests half-precision training.
# NOTE(review): newer accelerate releases spell this `mixed_precision='fp16'` —
# confirm against the installed version.
accelerator = Accelerator(fp16=True)
print(f'Using {accelerator.device}.')

# Register '_' as the padding token (presumably needed by `make_collate_fn` to
# pad batches — TODO confirm).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '_'})
# Load pretrained GPT-2 with `add_cross_attention=True` set in its config.
gpt2_model = AutoModelForCausalLM.from_pretrained('gpt2', add_cross_attention=True)

In [None]:
model = S2TModel(gpt2_model)
# Start at +inf so the first validation result is always checkpointed. A finite
# starting value could leave 'test.pt' unwritten if the loss never drops below it.
best_val_loss = float('inf')

# # Sometimes the loss suddenly increases by a lot. Then the best saved model can be loaded here and trained further.
# # However, training also seems to converge again, so this may not be necessary.
# model.load_state_dict(torch.load('test.pt'))
# best_val_loss = 1.56

In [None]:
# Two learning rates: parameters whose name contains 'crossattention' or
# 'projection' train at 5e-4, every other parameter at 1e-6.
def _uses_high_lr(param_name):
    """Return True for parameters that belong to the high learning-rate group."""
    return any(marker in param_name for marker in ('crossattention', 'projection'))


high_lr_params, low_lr_params = [], []
for param_name, param in model.named_parameters():
    target_group = high_lr_params if _uses_high_lr(param_name) else low_lr_params
    target_group.append(param)

optimizer_grouped_parameters = [
    {"params": high_lr_params, "lr": 5e-4},
    {"params": low_lr_params, "lr": 1e-6},
]
optimizer = AdamW(optimizer_grouped_parameters, weight_decay=0.)

In [None]:
# Both splits share one collate function and the same loader settings; only the
# training loader shuffles.
collate_fn = make_collate_fn(tokenizer)
loader_settings = dict(batch_size=16, collate_fn=collate_fn, num_workers=8)

train_ds = S2TDataset(DATA_PATH / 'train')
train_dl = DataLoader(train_ds, shuffle=True, **loader_settings)

val_ds = S2TDataset(DATA_PATH / 'val')
val_dl = DataLoader(val_ds, shuffle=False, **loader_settings)

In [None]:
# Hand model, optimizer and both loaders to the accelerator and keep working
# with the objects it returns.
prepared = accelerator.prepare(model, optimizer, train_dl, val_dl)
model, optimizer, train_dl, val_dl = prepared

## Training Loop

In [None]:
epochs = 40
accumulate_gradients = 2  # Number of batches whose gradients are summed per optimizer step.

for epoch in range(epochs):

    model.train()
    # True while backward() has produced gradients that no optimizer step has consumed yet.
    has_pending_grads = False
    for step, (encoder_hidden_states, input_ids) in enumerate(train_dl):
        out = model(encoder_hidden_states, input_ids)
        accelerator.backward(out.loss)
        has_pending_grads = True
        if step % 100 == 0:
            print(out.loss.item())
        if (step + 1) % accumulate_gradients == 0:
            optimizer.step()
            optimizer.zero_grad()
            has_pending_grads = False

    # When the number of batches is not a multiple of `accumulate_gradients`, the
    # trailing batches would otherwise leak their gradients into the next epoch
    # (or be dropped entirely after the final one). Apply them now.
    if has_pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    model.eval()
    val_loss = calculate_mean_loss(model, val_dl)
    print('VAL: ', val_loss)
    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        torch.save(model.state_dict(), 'test.pt')
        best_val_loss = val_loss

## Test

In [None]:
# Restore the weights saved to 'test.pt' by the training loop and switch the
# model to evaluation mode.
model.load_state_dict(torch.load('test.pt'))
model.eval()

In [None]:
# Report the loss of the currently loaded weights on the validation loader.
restored_val_loss = calculate_mean_loss(model, val_dl)
print('Validation loss:', restored_val_loss)

In [None]:
# Prediction on random validation example.
eg = val_ds[np.random.randint(len(val_ds))]
# Add a leading batch axis and move the features to the device the accelerator
# selected, instead of assuming a CUDA device is present.
features = eg['wave2vec_features'][None, ...].to(accelerator.device)
predicted_text = get_predictions(features, model, tokenizer)[0]

# The gold text carries the 'Transcription: ' prefix added during feature
# extraction; the prediction is sliced by the same length for display.
prefix_len = len('Transcription: ')
print('PRED:', predicted_text[prefix_len:])
print('GOLD:', eg['transcription'][prefix_len:])

In [None]:
# Word error rate over the whole validation loader.
val_preds = []
for encoder_hidden_states, _ in val_dl:
    val_preds.extend(get_predictions(encoder_hidden_states, model, tokenizer))

# Slice the prompt prefix off predictions and references before scoring
# (`prefix_len` comes from the single-example prediction cell).
val_preds = [text[prefix_len:] for text in val_preds]
val_golds = [example['transcription'][prefix_len:] for example in val_ds]

print('WER:', calculate_wer(val_preds, val_golds))