In [2]:
from data.libre_speech import LibreSpeechDataset

ds_train = LibreSpeechDataset(data_dir="hf-internal-testing/librispeech_asr_dummy", split="validation", streaming=False)
ds_test = ds_train

ds = ds_train.get_dataset()

print(ds_train)
print(ds_test)
print(ds)

<data.libre_speech.LibreSpeechDataset object at 0x7074dc1577d0>
<data.libre_speech.LibreSpeechDataset object at 0x7074dc1577d0>
Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})


In [3]:
import torch
from torch.utils.data import DataLoader

data_loader_train = DataLoader(ds_train, batch_size=16, collate_fn=ds_train.collate, num_workers=1)
data_loader_test = DataLoader(ds_test, batch_size=16, collate_fn=ds_test.collate, num_workers=1)

for data, mel, text in data_loader_train:
  print('data', data)
  print('mel', mel)
  print('text', text)
  break


data tensor([[ 2.3804e-03,  2.0752e-03,  1.9836e-03,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.5259e-04, -9.1553e-05, -1.8311e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-6.7139e-04,  6.1035e-05,  5.1880e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 7.0190e-04,  5.7983e-04,  3.3569e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.5259e-04, -9.1553e-05, -7.0190e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-1.8311e-04, -3.3569e-04, -2.1362e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])
mel tensor([[[ 1.1933e-01, -9.4576e-02, -1.0978e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         [ 4.9347e-04, -8.9271e-02, -6.7290e-02,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         [-1.5326e-01, -2.0804e-01, -2.2227e-01,  ..., -8.0603e-01,
          -8.0603e-01, -8.0603e-01],
         ...,
         [-8.0603e-01, -8.0603

In [4]:
import whisper

# Load the model
model = whisper.load_model("tiny.en")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

In [18]:
def evaluate_model(model, data_loader_test):
  model.eval()

  tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
  print(tokenizer)
  start_token = torch.tensor([[tokenizer.sot]], dtype=torch.long, device=device)
  pad_token = tokenizer.eot

  accuracy = []

  for i, (_, mel, text) in enumerate(data_loader_test):
    # Encode all texts in batch
    target_ids = [tokenizer.encode(t) for t in text]
    
    # Convert to padded tensor in using pad_sequence
    target_ids = [torch.tensor(ids, dtype=torch.long, device=device) for ids in target_ids]
    target_ids = torch.nn.utils.rnn.pad_sequence(target_ids, batch_first=True, padding_value=pad_token)
    
    input_tks = torch.cat([start_token.repeat(len(text), 1), target_ids], dim=1)

    #Forward pass
    predictions = model(tokens=input_tks, mel=mel)
    remove_sot = input_tks[:, 1:]
    predictions = predictions[:, :-1, :]

    # Calculate accuracy
    pred_tokens = predictions.argmax(dim=-1)
    correct = (pred_tokens == remove_sot).sum().item()
    total = remove_sot.numel()
    accuracy.append(correct / total)
    if i > 128:
      break
  average_accuracy = sum(accuracy) / len(accuracy)
  print(f"Average accuracy: {average_accuracy * 100:.2f}%")

evaluate_model(model, data_loader_test)


Tokenizer(encoding=<Encoding 'gpt2.tiktoken'>, num_languages=99, language=None, task=None, sot_sequence=(50257,), special_tokens={'<|16.72|>': 51199, '<|0.54|>': 50390, '<|8.86|>': 50806, '<|14.62|>': 51094, '<|8.68|>': 50797, '<|5.46|>': 50636, '<|21.46|>': 51436, '<|1.84|>': 50455, '<|15.38|>': 51132, '<|0.36|>': 50381, '<|12.32|>': 50979, '<|14.96|>': 51111, '<|4.40|>': 50583, '<|2.48|>': 50487, '<|12.80|>': 51003, '<|18.66|>': 51296, '<|13.10|>': 51018, '<|23.82|>': 51554, '<|24.92|>': 51609, '<|2.38|>': 50482, '<|2.92|>': 50509, '<|22.82|>': 51504, '<|28.94|>': 51810, '<|15.74|>': 51150, '<|26.78|>': 51702, '<|15.32|>': 51129, '<|5.06|>': 50616, '<|19.86|>': 51356, '<|1.50|>': 50438, '<|12.72|>': 50999, '<|5.60|>': 50643, '<|nn|>': 50341, '<|6.98|>': 50712, '<|29.56|>': 51841, '<|16.00|>': 51163, '<|27.66|>': 51746, '<|21.26|>': 51426, '<|26.98|>': 51712, '<|12.66|>': 50996, '<|fa|>': 50299, '<|27.96|>': 51761, '<|21.38|>': 51432, '<|14.36|>': 51081, '<|ru|>': 50262, '<|9.94|>': 5

In [14]:
def train_model(model, data_loader_train, epoch=0):
  model.train()

  tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
  start_token = tokenizer.sot
  pad_token = tokenizer.eot

  # Define the optimizer and criterion
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
  criterion = torch.nn.CrossEntropyLoss()

  for i, (_, mel, text) in enumerate(data_loader_train):
    target_ids = [tokenizer.encode(t) for t in text]
    target_ids = [torch.tensor(ids, dtype=torch.long, device=device) for ids in target_ids]
    target_ids = torch.nn.utils.rnn.pad_sequence(target_ids, batch_first=True, padding_value=pad_token)
    start_token_tensor = torch.tensor([start_token], dtype=torch.long, device=device).repeat(len(text), 1)
    input_tks = torch.cat([start_token_tensor, target_ids], dim=1)

    #Forward pass
    predictions = model(tokens=input_tks, mel=mel)
    remove_sot = input_tks[:, 1:]
    predictions = predictions[:, :-1, :]
    loss = criterion(predictions.transpose(1, 2), remove_sot)

    #Backward pass
    optimizer.zero_grad()
    loss = criterion(predictions.transpose(1, 2), target_ids)
    loss.backward()
    optimizer.step()
    
    print(f"{i}/{len(data_loader_train)}")
    if i > 128:
      break
  print(f"Epoch {epoch+1}, Loss: {loss.item()}")

train_model(model, data_loader_train)


0/5
1/5
2/5
3/5
4/5
Epoch 1, Loss: 1.387174367904663


In [15]:
evaluate_model(model, data_loader_test)

Tokenizer(encoding=<Encoding 'gpt2.tiktoken'>, num_languages=99, language=None, task=None, sot_sequence=(50257,), special_tokens={'<|16.72|>': 51199, '<|0.54|>': 50390, '<|8.86|>': 50806, '<|14.62|>': 51094, '<|8.68|>': 50797, '<|5.46|>': 50636, '<|21.46|>': 51436, '<|1.84|>': 50455, '<|15.38|>': 51132, '<|0.36|>': 50381, '<|12.32|>': 50979, '<|14.96|>': 51111, '<|4.40|>': 50583, '<|2.48|>': 50487, '<|12.80|>': 51003, '<|18.66|>': 51296, '<|13.10|>': 51018, '<|23.82|>': 51554, '<|24.92|>': 51609, '<|2.38|>': 50482, '<|2.92|>': 50509, '<|22.82|>': 51504, '<|28.94|>': 51810, '<|15.74|>': 51150, '<|26.78|>': 51702, '<|15.32|>': 51129, '<|5.06|>': 50616, '<|19.86|>': 51356, '<|1.50|>': 50438, '<|12.72|>': 50999, '<|5.60|>': 50643, '<|nn|>': 50341, '<|6.98|>': 50712, '<|29.56|>': 51841, '<|16.00|>': 51163, '<|27.66|>': 51746, '<|21.26|>': 51426, '<|26.98|>': 51712, '<|12.66|>': 50996, '<|fa|>': 50299, '<|27.96|>': 51761, '<|21.38|>': 51432, '<|14.36|>': 51081, '<|ru|>': 50262, '<|9.94|>': 5