In [13]:
import torch
import torch.nn as nn
import torch.optim as optim

In [14]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [15]:
from interface.TimitInterface import TimitInterface
from phoneme import initialize_model, get_needed_data

In [16]:
# Pick the compute device first: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained processor and CTC model inline; the imported
# `initialize_model()` helper is intentionally not called here.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Bare last expression: moves the model to `device` and shows its layer summary.
model.to(device)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You sho

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder)

In [17]:
# Instantiate the project's TIMIT wrapper; later cells read its `.train` and `.valid` splits.
dataset = TimitInterface()

0it [00:00, ?it/s]

Map:   0%|          | 0/2688 [00:00<?, ? examples/s]

Map:   0%|          | 0/336 [00:00<?, ? examples/s]

Map:   0%|          | 0/336 [00:00<?, ? examples/s]

In [18]:
# Grab both splits from the dataset wrapper (validation first, then training).
val_data, train_data = dataset.valid, dataset.train

In [19]:
# Empty placeholder dict; no cell visible in this part of the notebook reads or fills it.
# NOTE(review): presumably a leftover — confirm it is unused before deleting.
data = {}

In [20]:
# For every record in each split keep only the three file entries used later,
# as (audio, word transcript, phonetic transcript) tuples.
train_output = [
    (record['audio_file'], record['word_file'], record['phonetic_file'])
    for record in train_data.values()
]
val_output = [
    (record['audio_file'], record['word_file'], record['phonetic_file'])
    for record in val_data.values()
]

In [21]:
# Sanity check: how many training utterances were collected.
print(len(train_output))

2688


In [22]:
# Sanity check: how many validation utterances were collected.
print(len(val_output))

336


In [23]:
# Materialise every example up front as one plain Python list per split.
# Each element is whatever `get_needed_data` returns; the training cell
# unpacks it into six fields (inputs, targets and target lengths among them).
train_dataloader = [
    get_needed_data(wav_path, txt_path, phn_path)
    for wav_path, txt_path, phn_path in train_output
]
val_dataloader = [
    get_needed_data(wav_path, txt_path, phn_path)
    for wav_path, txt_path, phn_path in val_output
]

# NOTE(review): nn.CTCLoss() defaults to blank=0, yet the phoneme ids printed
# by this cell include 0 and reach 60 — confirm that id 0 is not a real
# phoneme and that the model's output vocabulary covers every target id.
ctc_loss = nn.CTCLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training hyper-parameters consumed by the training cell.
max_grad_norm = 1.0
num_epochs = 1 #3

Audio shape: (79771,)
Token shape: torch.Size([1, 79771])
Phoneme abstract: [42, 13, 56, 1, 53, 21, 48, 9, 53, 21, 0, 40, 57, 32, 41, 18, 59, 11, 35, 55, 28, 40, 13, 49, 13, 46, 3, 42, 60, 46, 3, 23, 9, 39, 3, 33, 48, 11, 6]
Phoneme token: tensor([42, 13, 56,  1, 53, 21, 48,  9, 53, 21,  0, 40, 57, 32, 41, 18, 59, 11,
        35, 55, 28, 40, 13, 49, 13, 46,  3, 42, 60, 46,  3, 23,  9, 39,  3, 33,
        48, 11,  6], dtype=torch.int32)
Audio shape: (65537,)
Token shape: torch.Size([1, 65537])
Phoneme abstract: [40, 12, 49, 0, 33, 45, 22, 8, 59, 22, 13, 1, 37, 12, 41, 33, 10, 45, 40, 13, 41, 11, 37, 12, 34, 2, 35, 12, 49, 52, 19, 13, 57, 32, 2, 34, 12, 36, 40, 0, 58, 38, 27, 6, 40, 8, 49, 5, 33, 18, 23, 1, 57, 42, 26]
Phoneme token: tensor([40, 12, 49,  0, 33, 45, 22,  8, 59, 22, 13,  1, 37, 12, 41, 33, 10, 45,
        40, 13, 41, 11, 37, 12, 34,  2, 35, 12, 49, 52, 19, 13, 57, 32,  2, 34,
        12, 36, 40,  0, 58, 38, 27,  6, 40,  8, 49,  5, 33, 18, 23,  1, 57, 42,
        26], dtype

In [27]:
# Per-epoch caps and logging cadence (previously inline magic numbers).
max_train_steps = 200   # training examples processed per epoch
max_val_batches = 50    # validation examples scored per epoch
log_every = 10          # print diagnostics every N training steps


def _log_probs_and_lengths(batch_inputs):
    """Run the model on one example and prepare the CTC inputs.

    Args:
        batch_inputs: tensor accepted by `model(...)`; it is moved to `device`.

    Returns:
        A pair `(log_probs, input_lengths)` where `log_probs` is the
        log-softmax of the logits transposed to (seq_len, batch_size,
        vocab_size), and `input_lengths` is an int32 tensor with one entry
        per batch item, each equal to seq_len.
    """
    logits = model(batch_inputs.to(device)).logits
    logits = logits.transpose(0, 1)  # (seq_len, batch_size, vocab_size)
    frame_counts = torch.full(
        size=(logits.size(1),), fill_value=logits.size(0), dtype=torch.int32
    )
    return torch.log_softmax(logits, dim=-1), frame_counts


for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    for step, (_, _, inputs, targets, _, target_lengths) in enumerate(train_dataloader, start=1):
        # Enforce the cap at the top of the loop so a skipped (`continue`)
        # step can never bypass it.
        if step > max_train_steps:
            break

        optimizer.zero_grad()

        log_probs, input_lengths = _log_probs_and_lengths(inputs)
        if torch.isnan(log_probs).any():
            print("NaN detected in outputs")
            continue

        # Same int32 cast as the validation pass, for consistency.
        loss = ctc_loss(
            log_probs,
            targets.to(device),
            input_lengths,
            target_lengths.to(torch.int32),
        )

        # Never back-propagate an inf/NaN loss: it would corrupt the weights.
        if not torch.isfinite(loss):
            print(f"Non-finite loss at step {step}, skipping update")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        if step % log_every == 0:
            # Diagnostics only at the logging interval to avoid flooding the output.
            predicted_indices = torch.argmax(log_probs, dim=-1)
            print(f"Predicted indices: {predicted_indices}")
            print(f"Loss at step {step}: {loss.item()}")

    # ---- Validation ----
    model.eval()
    val_loss = 0
    if len(val_dataloader) == 0:
        print("Validation DataLoader is empty, skipping validation.")
    else:
        val_batches = 0
        with torch.no_grad():
            for val_batches, (_, _, inputs, targets, _, target_lengths) in enumerate(val_dataloader, start=1):
                log_probs, input_lengths = _log_probs_and_lengths(inputs)
                loss = ctc_loss(
                    log_probs,
                    targets.to(device),
                    input_lengths,
                    target_lengths.to(torch.int32),
                )
                val_loss += loss.item()

                if val_batches == max_val_batches:
                    break

        # Average over the examples actually scored, not the full list length.
        val_loss /= val_batches
        print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {val_loss:.4f}")


Validation DataLoader is empty, skipping validation.
Validation DataLoader is empty, skipping validation.
Validation DataLoader is empty, skipping validation.
