In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from pathlib import Path

In [None]:
# Root folder for the extracted TIMIT features. The feature-extraction cell only runs
# when this folder does not exist yet; later cells read its 'train', 'val' and 'test'
# sub-folders.
DATA_PATH = Path('./data/timit')

# Feature Extraction

In [None]:
import soundfile as sf
import shutil
from datasets import load_dataset
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
from gpt2_s2t.speech_feature_extraction import extract_features_to_files

In [None]:
# The feature extraction is skipped entirely once DATA_PATH exists on disk.
if not DATA_PATH.exists():
    timit = load_dataset('timit_asr')

    # Pretrained wav2vec 2.0 checkpoint used for feature extraction (inference only).
    wave2vec_name = 'facebook/wav2vec2-large-960h-lv60-self'
    wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wave2vec_name)
    wave2vec = Wav2Vec2Model.from_pretrained(wave2vec_name)
    wave2vec.eval()
    wave2vec.cuda()

    def make_examples(ds_split):
        """Collect the examples of `ds_split` with audio, transcription and a unique id.

        Every example is updated in place and gets:
          * 'audio': the samples read from the file at `eg['file']`,
          * 'transcription': the text prefixed with 'Transcription: ',
          * 'id': dialect region, speaker id and the original id joined by '_'.

        Returns the updated examples as a list.
        """
        examples = []
        for eg in ds_split:
            samples = sf.read(eg['file'])[0]
            eg['audio'] = np.array(samples)

            # TODO: Temporary (?) helper for generation, since using empty `input_ids` leads to errors.
            eg['transcription'] = 'Transcription: ' + eg['text']

            # Must read the original `eg['id']` before it is overwritten.
            id_parts = (eg['dialect_region'], eg['speaker_id'], eg['id'])
            eg['id'] = '_'.join(id_parts)
            examples.append(eg)
        return examples

    train_examples = make_examples(timit['train'])
    test_examples = make_examples(timit['test'])

    # `max_len` is just the longest sample in the respective split (determined in advance).
    # A tenth of the training examples is split off for validation; the test split gets none.
    extract_features_to_files(
        wave2vec, wave2vec_extractor, train_examples,
        batch_size=8, max_len=124621, output_path=DATA_PATH, val_pct=0.1,
    )
    extract_features_to_files(
        wave2vec, wave2vec_extractor, test_examples,
        batch_size=8, max_len=121140, output_path=DATA_PATH / 'test', val_pct=0,
    )

In [None]:
# `wave2vec` only exists if the extraction cell actually ran in this kernel (i.e. the
# features were not on disk yet). Guard the clean-up so that a fresh kernel with cached
# features does not stop here with a NameError.
if 'wave2vec' in globals():
    wave2vec.cpu()
# Release cached GPU memory before the training section allocates its own.
torch.cuda.empty_cache()

# Training

In [None]:
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, GPT2Tokenizer
from gpt2_s2t.model import S2TModel
from gpt2_s2t.data_loading import S2TDataset, make_collate_fn

In [None]:
# fp16 training handled by accelerate; the chosen device is printed for reference.
accelerator = Accelerator(fp16=True)
print(f'Using {accelerator.device}.')

# '_' is registered as the padding token; decoded predictions strip it again later on.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens(dict(pad_token='_'))

# GPT-2 with cross-attention enabled, so that it can attend to the speech features.
gpt2_model = AutoModelForCausalLM.from_pretrained(
    'gpt2',
    add_cross_attention=True,
)

In [None]:
model = S2TModel(gpt2_model)
# Start from +inf so that the first validation pass always writes a checkpoint, whatever
# the scale of the loss; the test section relies on that checkpoint file existing.
best_val_loss = float('inf')

# Sometimes the loss suddenly increases by a lot. In that case the best saved model can be
# loaded here and trained further (note: the training loop below saves to 'test_timit.pt').
# model.load_state_dict(torch.load('test.pt'))
# best_val_loss = 0.71

In [None]:
def _uses_high_lr(param_name):
    """Return True if the parameter name contains 'crossattention' or 'projection'."""
    return 'crossattention' in param_name or 'projection' in param_name


# Two parameter groups: cross-attention / projection parameters get a large learning rate,
# every other parameter is only nudged with a very small one.
optimizer_grouped_parameters = [
    dict(
        params=[p for n, p in model.named_parameters() if _uses_high_lr(n)],
        lr=5e-4,
    ),
    dict(
        params=[p for n, p in model.named_parameters() if not _uses_high_lr(n)],
        lr=1e-6,
    ),
]
optimizer = AdamW(optimizer_grouped_parameters, weight_decay=0.)

In [None]:
# The same collate function (built from the tokenizer) is shared by all loaders.
collate_fn = make_collate_fn(tokenizer)

train_ds = S2TDataset(DATA_PATH / 'train')
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn, num_workers=8)

val_ds = S2TDataset(DATA_PATH / 'val')
# No shuffling for validation: the validation loss is an average of per-batch losses and
# is used for checkpoint selection, so the batches should be identical in every epoch.
val_dl = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn, num_workers=8)

In [None]:
# Hand model, optimizer and both loaders to the accelerator; the returned objects replace
# the originals under the same names.
model, optimizer, train_dl, val_dl = accelerator.prepare(
    model, optimizer, train_dl, val_dl
)

In [None]:
epochs = 40
# Number of batches whose gradients are summed before one optimizer step.
accumulate_gradients = 2

for i in range(epochs):

    model.train()
    # Make sure no gradients from a previous run / epoch are carried into this one.
    optimizer.zero_grad()
    num_train_steps = len(train_dl)
    for step, (encoder_hidden_states, input_ids) in enumerate(train_dl):
        out = model(encoder_hidden_states, input_ids)
        accelerator.backward(out.loss)
        # Lightweight progress logging: training loss of every 100th batch.
        if step % 100 == 0:
            print(out.loss.item())
        # Step after every `accumulate_gradients` batches and also after the very last
        # batch, so trailing gradients of an incomplete accumulation window are applied
        # instead of being dropped or mixed into the next epoch's first update.
        is_last_step = (step + 1) == num_train_steps
        if (step + 1) % accumulate_gradients == 0 or is_last_step:
            optimizer.step()
            optimizer.zero_grad()

    # Validation: average of the per-batch losses, computed without gradient tracking.
    model.eval()
    val_losses = []
    with torch.no_grad():
        for encoder_hidden_states, input_ids in val_dl:
            out = model(encoder_hidden_states, input_ids)
            val_losses.append(out.loss.item())
    val_loss = np.array(val_losses).mean()
    print('VAL: ', val_loss)
    # Keep only the weights of the best epoch so far (by validation loss).
    if val_loss < best_val_loss:
        torch.save(model.state_dict(), 'test_timit.pt')
        best_val_loss = val_loss

# Test

In [None]:
# Restore the weights written by the training loop whenever the validation loss improved,
# then switch to evaluation mode for testing.
checkpoint_path = 'test_timit.pt'
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

In [None]:
# Test data is read in a fixed order (no shuffling), so predictions stay aligned with
# the dataset order.
test_ds = S2TDataset(DATA_PATH / 'test')
test_dl = accelerator.prepare(
    DataLoader(test_ds, batch_size=16, shuffle=False, collate_fn=collate_fn, num_workers=8)
)

In [None]:
# Test loss: mean of the per-batch losses over the whole test loader.
test_losses = []
with torch.no_grad():
    for encoder_hidden_states, input_ids in test_dl:
        batch_loss = model(encoder_hidden_states, input_ids).loss
        test_losses.append(batch_loss.item())
test_loss = np.mean(np.array(test_losses))
print('TEST: ', test_loss)

In [None]:
def get_prediction(example, max_new_tokens=40):
    """Greedily decode a transcription for a single dataset example.

    Decoding starts from the prompt 'Transcription:' and appends the most likely next
    token `max_new_tokens` times; there is no early stopping. The full sequence is fed
    through the model again at every step.

    Args:
        example: Mapping with the key 'wave2vec_features' holding the encoder features of
            one utterance (without batch dimension; one is added here).
        max_new_tokens: Number of tokens to generate after the prompt (default 40).

    Returns:
        The decoded text, including the prompt, with all '_' characters (the padding
        token) removed.
    """
    # Run on whatever device the model lives on instead of hard-coding CUDA.
    device = next(model.parameters()).device
    with torch.no_grad():
        input_ids = tokenizer('Transcription:', return_tensors='pt')['input_ids'].to(device)
        encoder_states = example['wave2vec_features'][None, ...].to(device)

        for _ in range(max_new_tokens):
            out = model(encoder_states, input_ids)
            # Greedy choice at the last position; keepdim yields shape (batch, 1) for the concat.
            next_id = out.logits[:, -1].argmax(-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=-1)

    return tokenizer.decode(input_ids[0]).replace('_', '')

In [None]:
# Prediction on a randomly drawn test example.
# The helper prefix is cut off from both the prediction and the reference before printing.
prefix_len = len('Transcription: ')
eg = test_ds[np.random.randint(len(test_ds))]
predicted_text = get_prediction(eg)

print('PRED:', predicted_text[prefix_len:])
print('GOLD:', eg['transcription'][prefix_len:])

In [None]:
def get_prediction_batched(encoder_hidden_states, input_ids, prefix_len, max_new_tokens=35):
    """Greedily decode transcriptions for a whole batch.

    Starting from `input_ids`, the most likely next token is appended to every sequence
    `max_new_tokens` times; there is no early stopping. The full sequences are fed through
    the model again at every step.

    Args:
        encoder_hidden_states: Batch of encoder features passed to the model unchanged.
        input_ids: Prompt token ids, one row per batch element.
        prefix_len: Number of leading characters to drop from every decoded string.
        max_new_tokens: Number of tokens to generate per sequence (default 35).

    Returns:
        A list with one decoded string per batch element, with all '_' characters (the
        padding token) removed and the first `prefix_len` characters cut off.
    """
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(encoder_hidden_states, input_ids)
            # Only the last position matters; keepdim keeps the (batch, 1) shape for the concat.
            next_ids = out.logits[:, -1].argmax(-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_ids], dim=-1)
    return [tokenizer.decode(ids).replace('_', '')[prefix_len:] for ids in input_ids]

In [None]:
# Greedy predictions for the whole test set, in loader (= dataset) order.
test_preds = []
for encoder_hidden_states, _target_ids in test_dl:
    # One identical prompt per batch element.
    prompts = ['Transcription:'] * encoder_hidden_states.shape[0]
    input_ids = tokenizer(prompts, return_tensors='pt', padding=True).input_ids
    # Follow the device of the batch instead of hard-coding CUDA.
    input_ids = input_ids.to(encoder_hidden_states.device)

    test_preds += get_prediction_batched(encoder_hidden_states, input_ids, prefix_len)

In [None]:
# Reference transcriptions in dataset order, with the first `prefix_len` characters
# (the helper prefix) cut off.
test_golds = [sample['transcription'][prefix_len:] for sample in test_ds]

In [None]:
import jiwer

# Text normalization applied to references and predictions alike before scoring:
# lower-case, collapse all whitespace to single spaces, strip, then split into words.
normalization_steps = [
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
]
transformation = jiwer.Compose(normalization_steps)

# Word error rate on the test set (shown as the cell output).
wer_transforms = dict(truth_transform=transformation, hypothesis_transform=transformation)
jiwer.wer(test_golds, test_preds, **wer_transforms)