In [1]:
import sys
sys.path.append('../')
import pandas as pd
import torch
from dataset import ASRDataset
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor
) 

In [2]:
class Config:
    df_fn = '../data/metadata.csv'
    vocab_fn = "../vocab/vocab.json"
    weight_fn = '/home/hyunseoki_rtx3090/ssd1/02_src/speech_recognition/K-wav2vec_finetune_v2/wav2vec2_baseline.pt'
    wav_dir = '/home/hyunseoki_rtx3090/ssd1/01_dataset/aihub/KsponSpeech/wav'

    device = 'cpu'


args = Config()

In [3]:
tokenizer = Wav2Vec2CTCTokenizer(
    args.vocab_fn, 
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="__"
)
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, 
    sampling_rate=16000, 
    padding_value=0.0, 
    do_normalize=True, 
    return_attention_mask=False
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = Wav2Vec2ForCTC.from_pretrained(
    'facebook/wav2vec2-base',
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size = len(tokenizer),
)
model.load_state_dict(torch.load(args.weight_fn, map_location='cpu'))
model.to(args.device)
model.eval();

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['quantizer.weight_proj.bias', 'quantizer.weight_proj.weight', 'project_q.bias', 'quantizer.codevectors', 'project_hid.bias', 'project_hid.weight', 'project_q.weight']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti

In [8]:
df = pd.read_csv(args.df_fn)
valid_df = df[3000:4000]
valid_df.reset_index(inplace=True)
valid_dataset = ASRDataset(wav_dir=args.wav_dir, df=valid_df, processor=processor)

def ctc_data_collator(batch):
    """
    Custom data collator function to dynamically pad the data
    """
    input_features = [{"input_values": sample["audio"]} for sample in batch]
    label_features = [{"input_ids": sample["label"]} for sample in batch]
    batch = processor.pad(
        input_features,
        padding=True,
        return_tensors="pt",
    )
    with processor.as_target_processor():
        labels_batch = processor.pad(
            label_features,
            padding=True,
            return_tensors="pt",
        )
        
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
    batch["labels"] = labels
    return batch

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=1,
    collate_fn=ctc_data_collator,
    num_workers=2,
    pin_memory=False if args.device == 'cpu' else True,
    persistent_workers=True,
)
iterator = iter(valid_loader)

In [9]:
data = next(iterator)

In [10]:
logits = model(data['input_values'].to(args.device)).logits

In [11]:
pred_ids = torch.argmax(logits, dim=-1)
pred_ids.shape

torch.Size([1, 1428])

In [19]:
len(pred_ids.squeeze()[50])

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc831807920>
Traceback (most recent call last):
  File "/home/hyunseoki_rtx3090/mambaforge/envs/speech_recognition/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/hyunseoki_rtx3090/mambaforge/envs/speech_recognition/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/hyunseoki_rtx3090/mambaforge/envs/speech_recognition/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hyunseoki_rtx3090/mambaforge/envs/speech_recognition/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hyunseoki_rtx3090/mambaforge/envs/speech_recognitio

TypeError: len() of a 0-d tensor

In [20]:
processor.batch_decode(pred_ids)

['']

In [21]:
processor.batch_decode(data['labels'])

['어[UNK]나한테[UNK]배워야[UNK]되겠네[UNK]난[UNK]스키를[UNK]군대[UNK]가서[UNK]처음[UNK]타가지고[UNK]어[UNK]한[UNK]세[UNK]번[UNK]타니까[UNK]그런[UNK]십[UNK]일[UNK]자[UNK]이런[UNK]거[UNK]다[UNK]하고[UNK]재미가[UNK]없더라고[UNK]그래서[UNK]이제[UNK]보드로[UNK]넘어왔지[UNK]보드[UNK]타면[UNK]확실히[UNK]젊은[UNK]사람들한테는[UNK]보드가[UNK]스릴[UNK]있고[UNK]재밌는[UNK]거[UNK]같애[UNK]그서[UNK]내가[UNK]보드[UNK]많은[UNK]사람들한테[UNK]쫌[UNK]알렸지[UNK]가르쳐[UNK]준[UNK]사람도[UNK]많고[UNK]그리고']

In [22]:
from IPython.display import Audio

print(processor.batch_decode(data['labels'])[0].replace('[UNK]', ' '))
Audio(data['input_values'].numpy().squeeze(), rate=16000)

어 나한테 배워야 되겠네 난 스키를 군대 가서 처음 타가지고 어 한 세 번 타니까 그런 십 일 자 이런 거 다 하고 재미가 없더라고 그래서 이제 보드로 넘어왔지 보드 타면 확실히 젊은 사람들한테는 보드가 스릴 있고 재밌는 거 같애 그서 내가 보드 많은 사람들한테 쫌 알렸지 가르쳐 준 사람도 많고 그리고
