In [None]:
import torch
import torchaudio

# Record the library versions this run was produced with (one per line).
print(torch.__version__, torchaudio.__version__, sep="\n")

# Seed torch's global RNG so stochastic steps are repeatable.
torch.random.manual_seed(0)
# Use the MPS backend when torch reports it as available, otherwise stay on CPU.
# (A CUDA-first selection was previously tried here and disabled.)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

In [None]:
# Path of the sample MP3 recording that is converted and loaded in the cells below.
filepath = "../datasets/audio/208-192.mp3"
from pydub import AudioSegment

# convert mp3 to wav
def convert_mp3_to_wav(filepath):
    """Decode an MP3 file and export it as a WAV file next to the source.

    Args:
        filepath: Path of the MP3 file to convert.

    Returns:
        Path of the exported WAV file: ``filepath`` with its final extension
        replaced by ``.wav``.
    """
    import os  # local import so the function does not depend on another cell

    audio = AudioSegment.from_mp3(filepath)
    # Swap only the trailing extension. A plain str.replace(".mp3", ".wav")
    # would also rewrite ".mp3" inside directory names, and would leave names
    # such as "X.MP3" unchanged so the export would overwrite the source file.
    wav_filepath = os.path.splitext(filepath)[0] + ".wav"
    audio.export(wav_filepath, format='wav')
    return wav_filepath

# Convert the sample recording and keep the path of the WAV file that was written.
wav_filepath = convert_mp3_to_wav(filepath)

In [None]:
# Select the "soundfile" I/O backend before reading the file.
# NOTE(review): recent torchaudio releases deprecate set_audio_backend — confirm
# against the installed version.
torchaudio.set_audio_backend("soundfile")
# Load the converted WAV at its native sample rate; resampling happens in the next cell.
waveform, sample_rate = torchaudio.load(wav_filepath)

In [None]:
# Resample the loaded audio to 16 kHz.
# The guard makes this cell safe to re-run: previously `waveform` was replaced
# while `sample_rate` still held the original rate, so a second run resampled
# audio that was already at 16 kHz.
TARGET_SAMPLE_RATE = 16000
if sample_rate != TARGET_SAMPLE_RATE:
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)(waveform)
    # Keep the recorded rate in sync with the data it describes.
    sample_rate = TARGET_SAMPLE_RATE

In [None]:
# Show the waveform tensor as the cell output.
waveform

In [None]:
from datasets import load_dataset

# The parquet file name is keyed by the id stored in parentNaId.
parentNaId = "653144"
filepath = "../datasets/" + parentNaId + "_transcriptions_with_audio.parquet"
# Read the parquet file (not CSV) through the Hugging Face "parquet" loader.
dataset = load_dataset("parquet", data_files=filepath)

In [None]:
# Derive helper columns from the raw parquet columns. The bare lookup that used
# to open this cell was removed: it was not the last expression of the cell, so
# its value was never displayed.

def _first_audio_path(batch):
    """Take the first path of each row's 'audio_filepaths' list and rewrite "./" to "../datasets/"."""
    return {'audio_filepath_1st': [paths[0].replace("./", "../datasets/") for paths in batch['audio_filepaths']]}


def _transcription_text(batch):
    """Pull the 'transcription' field out of the first value of each row's 'transcription' mapping."""
    return {'transcription_str': [next(iter(entry.values()))['transcription'] for entry in batch['transcription']]}


def _first_audio_as_wav(batch):
    """Convert every selected audio file to WAV and store the path returned by the converter."""
    return {'audio_filepath_1st': [convert_mp3_to_wav(path) for path in batch['audio_filepath_1st']]}


# Order matters: the WAV conversion reads the column produced by the first step.
dataset = dataset.map(_first_audio_path, batched=True)
dataset = dataset.map(_transcription_text, batched=True)
dataset = dataset.map(_first_audio_as_wav, batched=True)

In [None]:
# Show the first row's audio path after the WAV conversion step.
dataset["train"][0]['audio_filepath_1st']

In [None]:
# Show the first row's extracted transcription text.
dataset["train"][0]['transcription_str']

In [None]:
from transformers import Wav2Vec2CTCTokenizer

# Text tokenizer loaded from the pretrained checkpoint; preprocess_text below calls it.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-large-960h")

def preprocess_text(batch):
    """Tokenize a batch of transcriptions and attach the resulting token ids.

    Uses the module-level ``tokenizer`` with padding and truncation enabled.

    Args:
        batch: Mapping with a "transcription_str" entry holding the texts to encode.

    Returns:
        The same ``batch`` object with an added "input_ids" entry taken from
        the tokenizer output.
    """
    # The unreachable statements that followed the original `return` (they
    # referenced an undefined `texts` variable) have been removed.
    encoded = tokenizer(batch["transcription_str"], padding=True, truncation=True)
    batch["input_ids"] = encoded.input_ids
    return batch

# Apply preprocess_text in batched mode, adding an "input_ids" column to the dataset.
dataset = dataset.map(preprocess_text, batched=True)

In [None]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC

# Load tokenizer and model from the same pretrained checkpoint.
# Wav2Vec2CTCTokenizer replaces the deprecated Wav2Vec2Tokenizer so that the
# `tokenizer` name keeps referring to a text tokenizer; the deprecated class
# processes raw speech when called, which would break preprocess_text if that
# cell were re-run after this one.
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-large-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")

In [None]:
# Copy the tokenizer's padding id and vocabulary size onto the model config.
# NOTE(review): this only edits config fields; nothing here resizes the model's
# output layer — confirm that is sufficient if the vocabulary ever differs from
# the checkpoint's.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = len(tokenizer)

In [None]:
# Both names point at the same "train" split.
# NOTE(review): evaluation therefore runs on the training examples — confirm
# whether a held-out split is needed.
train_dataset = dataset['train']
eval_dataset = dataset['train']


In [None]:
from torch.utils.data import DataLoader

# Define a custom collator
def data_collator(batch):
    """Collate dataset rows into padded tensors the model can consume.

    Args:
        batch: List of rows, each providing "input_values" (audio samples) and
            "labels" (token ids).
            NOTE(review): assumes both are 1-D sequences per row, and that the
            dataset actually carries these two columns — confirm against the
            preprocessing cells.

    Returns:
        Dict with "input_values", a float32 tensor zero-padded to the longest
        row, and "labels", an int64 tensor padded with -100.
    """
    from torch.nn.utils.rnn import pad_sequence

    audio_features = [torch.as_tensor(item["input_values"], dtype=torch.float32) for item in batch]
    labels = [torch.as_tensor(item["labels"], dtype=torch.long) for item in batch]
    return {
        # Rows differ in length; plain Python lists cannot be fed to the model.
        "input_values": pad_sequence(audio_features, batch_first=True, padding_value=0.0),
        # -100 marks padding: Wav2Vec2ForCTC drops negative label ids before computing the loss.
        "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
    }

# Create the DataLoaders: batches of 16 rows assembled by data_collator.
# Only the training loader shuffles; the evaluation loader keeps dataset order.
train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, collate_fn=data_collator
)
eval_loader = DataLoader(
    eval_dataset, batch_size=16, shuffle=False, collate_fn=data_collator
)

In [None]:
# Sanity check: print the first collated batch, then stop iterating.
for i in train_loader:
    print(i)
    break

In [None]:
# Display the "train" split as the cell output.
dataset['train']

In [None]:
import torch
from torch.optim import AdamW

# Define optimizer and scheduler. StepLR is stepped once per batch below, so the
# learning rate is multiplied by 0.1 every 500 optimizer steps.
optimizer = AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

# Training loop
for epoch in range(5):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        optimizer.zero_grad()

        # Read the keys data_collator emits. The previous code read
        # batch["input_ids"] for both arguments, which the collator never
        # produces and which would have fed token ids to the model as audio.
        outputs = model(input_values=batch["input_values"], labels=batch["labels"])
        loss = outputs.loss

        # Backward pass and parameter / learning-rate update
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the epoch mean instead of the last batch's loss; nan signals an
    # empty loader instead of failing on an undefined `loss`.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch + 1} completed with loss {epoch_loss:.4f}")