In [19]:
# Imports 
from datasets import load_dataset
import os
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, MBart50Tokenizer, MBartForConditionalGeneration



In [20]:
# Load the europarl dataset training split
dataset = load_dataset("tj-solergibert/Europarl-ST", split="train")


In [21]:
# Filter the dataset where the speech is english, there is a french transcription, and the audio file is present 
dataset = dataset.filter(lambda example: example['original_language'] == 'en' and 'fr' in example['transcriptions'].keys() and example["transcriptions"]['fr'] is not None and os.path.exists(example["audio_path"]))
print(dataset)

# Select a subset of the filtered dataset for training
dataset = dataset.select(range(20))
print(dataset)

# Display some examples from the filtered dataset
for example in dataset.select(range(5)):
    print(example)

Dataset({
    features: ['original_speech', 'original_language', 'audio_path', 'segment_start', 'segment_end', 'transcriptions'],
    num_rows: 31777
})
Dataset({
    features: ['original_speech', 'original_language', 'audio_path', 'segment_start', 'segment_end', 'transcriptions'],
    num_rows: 20
})
{'original_speech': 'Mr President, I know that I will not be popular for making a long speech at this time, but my two fellow-rapporteurs, with whom I have worked very closely as a team, have made short statements so I want to keep the team spirit together.', 'original_language': 'en', 'audio_path': 'en/audios/en.20080924.23.3-123.m4a', 'segment_start': 0.0, 'segment_end': 14.470000267028809, 'transcriptions': {'de': 'Herr Präsident! Ich weiß, dass ich mir keine Freunde mache, wenn ich um diese Uhrzeit eine lange Rede halte, doch meine beiden Mitberichterstatter, mit denen ich sehr eng im Team zusammengearbeitet habe, haben kurze Stellungnahmen abgegeben, sodass ich den Teamgeist zusammen

In [22]:
# Load models and processors/tokenizers for wav2vec and marian
wav2vec2_processor_name = "facebook/wav2vec2-base-960h"
wav2vec2_processor = Wav2Vec2Processor.from_pretrained(wav2vec2_processor_name)
print("Loaded wav2vec2 processor")

# Load wav2vec2.0 model
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained(wav2vec2_processor_name)
print("Loaded wav2vec2 model")

# Load mBART tokenizer
mbart_tokenizer_name = "facebook/mbart-large-50"
tokenizer = MBart50Tokenizer.from_pretrained(mbart_tokenizer_name, src_lang="en_XX", tgt_lang="fr_XX")
print("Loaded mBART tokenizer")

# Load mBART model
mbart_model = MBartForConditionalGeneration.from_pretrained(mbart_tokenizer_name)
print("Loaded mBART model")


Loaded wav2vec2 processor


Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You sho

Loaded wav2vec2 model
Loaded mBART tokenizer
Loaded mBART model


In [23]:
# device config 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wav2vec2_model.to(device)
mbart_model.to(device)

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): Embedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (fi

In [26]:
import librosa 

audio_directory = "."

# preprocesses the dataset for training 
def preprocess(batch):
    # gets audio meta data
    audio_path = batch["audio_path"]
    segment_start = batch["segment_start"]
    segment_end = batch["segment_end"]
    
    # load the audio
    audio, sr = librosa.load(audio_path, sr=16000, offset=segment_start, duration=segment_end-segment_start)
    
    # process the audio using wav2vec
    inputs = wav2vec2_processor(audio, sampling_rate=sr, return_tensors="pt")
    
    # pad the input to be length 3000 for wav2vec 
    input_features = inputs.input_features.squeeze(0)
    # if input_features.shape[-1] < 3000:
    #     padding = torch.zeros((input_features.shape[0], 3000 - input_features.shape[-1]))
    #     input_features = torch.cat([input_features, padding], dim=-1)
    
    batch["input_features"] = input_features
    
    # tokenize the transcription of the original speech for wav2vec labels 
    original_transcription = batch["original_speech"]
    tokenized_originals = wav2vec2_processor.tokenizer(
        original_transcription,
        return_tensors="pt",
        padding="longest"
    ).input_ids.squeeze(0)
    
    batch["english_text"] = tokenized_originals
    
    french_transcription = batch["transcriptions"]["fr"]
    
    # Tokenize the target French text translation for marian
    tokenized_labels = tokenizer(
        text_target=french_transcription,
        return_tensors="pt",
        padding="longest"
    ).input_ids.squeeze(0)
    
    batch["labels"] = tokenized_labels
    
    return batch
print(dataset)
# Apply the preprocessing function to the dataset
dataset = dataset.map(preprocess, remove_columns=["audio_path", "original_speech", "original_language", "segment_start", "segment_end", "transcriptions"])


Dataset({
    features: ['original_speech', 'original_language', 'audio_path', 'segment_start', 'segment_end', 'transcriptions'],
    num_rows: 20
})


  audio, sr = librosa.load(audio_path, sr=16000, offset=segment_start, duration=segment_end-segment_start)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
Map:   0%|          | 0/20 [00:00<?, ? examples/s]


AttributeError: 

In [None]:
# collate function for tensor and padding before dataloader
def collate_fn(batch):
    input_features = [item['input_features'] for item in batch]
    labels = [item['labels'] for item in batch]
    english_text = [item['english_text'] for item in batch]
    
    # Convert all inputs to tensors if they are not already
    input_features = [torch.tensor(f) if not isinstance(f, torch.Tensor) else f for f in input_features]
    labels = [torch.tensor(l) if not isinstance(l, torch.Tensor) else l for l in labels]
    english_text = [torch.tensor(l) if not isinstance(l, torch.Tensor) else l for l in english_text]
    
    
    # pad sequences
    input_features = pad_sequence(input_features, batch_first=True)
    labels = pad_sequence(labels, batch_first=True, padding_value=mt_tokenizer.pad_token_id)
    english_text = pad_sequence(english_text, batch_first=True)
    
    
    return {
        'input_features': input_features,
        'labels': labels, 
        'english_text': english_text
    }

# create data loader using collate function
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
print(dataloader)
# optimizers
wav2vec_optimizer = AdamW(wav2vec2_model.parameters(), lr=5e-5)
mt_optimizer = AdamW(mbart_model.parameters(), lr=5e-5)

# training loop
num_epochs = 3
output_dir = "./fine_tuned_models"

        
for epoch in range(3):
    wav2vec2_model.train()
    mbart_model.train()
    
    epoch_loss = 0.0
    for batch in tqdm(dataloader):
        
    
        input_features = (batch["input_features"])
        transcription = batch["english_text"]
        target_ids = batch["labels"]
    
        # forward pass through wav2vec model with english transcription labels
        wav2vec_outputs = wav2vec2_model(input_features, labels=transcription)
        
        predicted_ids = wav2vec_outputs.logits.argmax(dim=-1)
        
        # decode the predicted ids to text
        predicted_texts = wav2vec2_processor.batch_decode(predicted_ids, skip_special_tokens=True)
        print(predicted_texts)

        # tokenize the decoded text for the translation model
        translation_inputs = mbart_tokenizer(predicted_texts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)

        # forward pass through Translation model with wav2vec's transcription as input
        translation_outputs = mbart_model(input_ids=translation_inputs, labels=target_ids)

        # combine the losses for end-to-end training
        # combined_loss = wav2vec_outputs.loss + translation_outputs.loss
        # ombined_loss.backward()
       
        wav2vec_outputs.loss.backward()
        translation_outputs.loss.backward()
        
        wav2vec_optimizer.step()
        mt_optimizer.step()
        
        wav2vec_optimizer.zero_grad()
        mt_optimizer.zero_grad()

        epoch_loss += wav2vec_outputs.loss.item() + translation_outputs.loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss}")

    # save the models
    wav2vec2_model.save_pretrained(f"./fine_tuned_models/wav2vec_epoch_{epoch + 1}")
    wav2vec2_processor.save_pretrained(f"./fine_tuned_models/wav2vec_epoch_{epoch + 1}")
    mbart_model.save_pretrained(f"./fine_tuned_models/mt_epoch_{epoch + 1}")
    mbart_tokenizer.save_pretrained(f"./fine_tuned_models/mt_epoch_{epoch + 1}")
    

In [None]:
test_dataset = load_dataset("tj-solergibert/Europarl-ST", split="test")
test_dataset = test_dataset.filter(lambda example: example['original_language'] == 'en' and 'fr' in example['transcriptions'].keys() and example["transcriptions"]['fr'] is not None)
print(len(test_dataset))
print(test_dataset)
test_dataset = test_dataset.select(range(10))
test_dataset = test_dataset.map(preprocess, remove_columns=["audio_path", "original_speech", "original_language", "segment_start", "segment_end", "transcriptions"])

# Create DataLoader with custom collate_fn
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)


# Paths to the saved models
wav2vec_model_path = "./fine_tuned_models/wav2vec_epoch_3"  # Update with your final epoch
mt_model_path = "./fine_tuned_models/mt_epoch_3"  # Update with your final epoch

# Load the models
# wav2vec_processor = wav2vecProcessor.from_pretrained(wav2vec_model_path)
# wav2vec_model = wav2vecForConditionalGeneration.from_pretrained(wav2vec_model_path)
# mt_tokenizer = MarianTokenizer.from_pretrained(mt_model_path)
# mt_model = MarianMTModel.from_pretrained(mt_model_path)

wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
wav2vec_model =  Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
mt_tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50")
mt_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wav2vec_model.to(device)
mt_model.to(device)

wav2vec_model.eval()
mt_model.eval()
    
epoch_loss = 0.0
for batch in tqdm(test_dataloader):
    input_features = batch["input_features"].to(device)

    original_speech = wav2vec_processor.batch_decode(batch["english_text"].to(device), skip_special_tokens=True)[0]
    print("Original Speech", original_speech)
    # Generate transcription using wav2vec model
    with torch.no_grad():
        wav2vec_outputs = wav2vec_model.generate(
            input_features,
            num_beams=5, 
            repetition_penalty=1.2, 
            no_repeat_ngram_size=2, 
            temperature=0.7,  
            top_k=50,  
            top_p=0.95  
        )
        
        transcription = wav2vec_processor.batch_decode(wav2vec_outputs, skip_special_tokens=True)[0]

    print("Generated Transcription:", transcription)

    # Translate the transcription using the translation model
    tokenized_transcription = mt_tokenizer(transcription, return_tensors="pt", padding="longest", truncation=True)
    tokenized_transcription = tokenized_transcription.input_ids.to(device)

    with torch.no_grad():
        translated_tokens = mt_model.generate(tokenized_transcription)
        translation = mt_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    print("Translation to French:", translation)

avg_loss = epoch_loss / len(dataloader)
print(f"Loss: {avg_loss}")
    


In [None]:
# Calculate BLEU score of the original model and the trained model 
import sacrebleu
import jiwer

bleu_scores = []
wer_scores = []

for batch in tqdm(test_dataloader):
    input_features = batch["input_features"].to(device)
    target_ids = batch["labels"].to(device)

    original_speech = wav2vec_processor.batch_decode(batch["english_text"].to(device), skip_special_tokens=True)[0]
    
    print("Original Speech", original_speech)
    # Generate transcription using wav2vec model
    with torch.no_grad():
        wav2vec_outputs = wav2vec_model.generate(
            input_features,
            num_beams=5, 
            repetition_penalty=1.2, 
            no_repeat_ngram_size=2, 
            temperature=0.7,  
            top_k=50,  
            top_p=0.95  
        )
        
        transcription = wav2vec_processor.batch_decode(wav2vec_outputs, skip_special_tokens=True)[0]

    # print("Generated Transcription:", transcription)

    # Translate the transcription using the translation model
    tokenized_transcription = mt_tokenizer(transcription, return_tensors="pt", padding="longest", truncation=True)
    tokenized_transcription = tokenized_transcription.input_ids.to(device)

    with torch.no_grad():
        translated_tokens = mt_model.generate(tokenized_transcription)
        translation = mt_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

    # print("Translation to French:", translation)

    # Calculate BLEU score
    reference = mt_tokenizer.decode(target_ids[0], skip_special_tokens=True)
    hypothesis = translation 
    bleu = sacrebleu.corpus_bleu([hypothesis], [[reference]])
    bleu_score = bleu.score
    bleu_scores.append(bleu_score)
    wer = jiwer.wer(reference, hypothesis)
    wer_scores.append(wer)

    print(f"Transcription: {transcription}")
    print(f"Generated Translation: {translation}")
    print(f"Reference Translation: {reference}")
    print(f"BLEU Score: {bleu_score}\n")
    print(f"WER: {wer}\n")

avg_bleu_score = sum(bleu_scores) / len(bleu_scores)
avg_wer_score = sum(wer_scores) / len(wer_scores)

print(f"Average BLEU Score: {avg_bleu_score}")
print(f"Average WER: {avg_wer_score}")