
[STC] STC loss ascends while training #21

Closed
LEECHOONGHO opened this issue Mar 21, 2022 · 3 comments


LEECHOONGHO commented Mar 21, 2022

Hello, I'm training an ASR model with STC loss and a letter-to-word encoder, as shown below. But as training progresses, the STC loss ascends and becomes 'Inf' after 12,000 steps.

Is there a mistake in my implementation?
Any help would be appreciated.
Thank you.

training args:

  • num_gpu : 4
  • audio_length_per_gpu : 160s
  • lr : 0.0001~0.001
  • use FullyShardedDataParallel with mixed precision off
  • max_word_length : 10
  • n_letter_symbols : 69 (blank + pad + Korean letters)
  • n_word_symbols : 12158 (blank + Korean morphs covering 97% of the corpus)
  • blank_idx=0, p0=0.05, plast=0.15, thalf=16000

self.criterion = STC(
    blank_idx=self.cfg.blank_idx,
    p0=self.cfg.p0,
    plast=self.cfg.plast,
    thalf=self.cfg.thalf,
    reduction="mean",
)

#   model_output : Tensor[batch_size, max_frame_length, n_letter_symbols*max_word_length]
#   self.l2w_matrix : Tensor[n_letter_symbols*max_word_length, n_word_symbols]

word_level_output = model_output @ self.l2w_matrix

#   word_level_output : Tensor[batch_size, max_frame_length, n_word_symbols]

word_level_output = F.log_softmax(word_level_output.transpose(1, 0), dim=-1)

loss = self.criterion(word_level_output, word_labels)
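For reference, a minimal, self-contained shape check of the projection above (random stand-in tensors, dimensions from the config above). Note that the matmul requires model_output and self.l2w_matrix to share a dtype, so the .half() matrix has to match the model output's precision:

import torch
import torch.nn.functional as F

batch_size, max_frame_length = 2, 50
n_letter_symbols, max_word_length, n_word_symbols = 69, 10, 12158

# random stand-ins for the real model output and letter-to-word matrix
model_output = torch.randn(batch_size, max_frame_length, n_letter_symbols * max_word_length)
l2w_matrix = (torch.rand(n_letter_symbols * max_word_length, n_word_symbols) < 0.01).float()

word_level_output = model_output @ l2w_matrix                                  # [B, T, n_word_symbols]
word_level_output = F.log_softmax(word_level_output.transpose(1, 0), dim=-1)  # [T, B, n_word_symbols]

assert word_level_output.shape == (max_frame_length, batch_size, n_word_symbols)
assert torch.isfinite(word_level_output).all()  # log-probs must stay finite before the loss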

[plot: stcloss]


vineelpratap commented Mar 21, 2022

Hi,
Could you give a few details about the dataset you are using and how the partial labels are created? Also, have you tried CTC with the same training setup?

As a sanity check, you can output Tensor[batch_size, max_frame_length, n_word_symbols] from the model directly. That will rule out errors in the l2w_matrix creation.

Also, have you made sure word_labels are 1-indexed to account for the blank symbol?
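For instance, something like the following (a sketch; assumes the raw word ids are 0-indexed):

import torch

raw_labels = torch.tensor([[4, 17, 293]])  # hypothetical 0-indexed word ids
word_labels = raw_labels + 1               # shift so that index 0 stays reserved for blank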


LEECHOONGHO commented Mar 21, 2022

Hello @vineelpratap,
I'm using a mixed Korean ASR dataset (num_audio=3,180,000) from AIHub.
I build the l2w_matrix as shown below.

I'm now trying to train with CTC (gtn_applications) as well. Thank you for your advice.

def get_l2w_matrix(self):
    # letters = {'BLANK': 0, 'PAD': 1, 'letter1': 2, ...}
    # words   = {'BLANK': 0, 'word1': 1, ...}

    with open(self.config.letter_dict_path, 'r', encoding='utf8') as j:
        letters = json.load(j)
    with open(self.config.word_dict_path, 'r', encoding='utf8') as j:
        words = json.load(j)

    # E_matrix[pos, letter, word] is True when `word` has `letter` at position `pos`
    E_matrix = torch.zeros(self.config.max_word_length, len(letters), len(words)).bool()

    # set the one-hot letter vector for each position of each word
    for word, word_idx in words.items():
        if word == 'BLANK':
            E_matrix[0, letters['BLANK'], word_idx] = True
        else:
            for letter_idx, letter in enumerate(word2letter(word)):
                E_matrix[letter_idx, letters[letter], word_idx] = True

    # fill the unused positions of short words with the PAD symbol
    padding_location = ~torch.any(E_matrix, 1)
    E_matrix[:, letters['PAD'], :][padding_location] = True

    torch.save(E_matrix, self.config.e_matrix_path)

    return E_matrix.half().view(-1, E_matrix.shape[-1])
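A quick self-check on the saved matrix (a sketch reusing the shapes above): after the padding step, every (position, word) pair should activate exactly one letter symbol:

E_matrix = torch.load(self.config.e_matrix_path)  # bool, [max_word_length, n_letters, n_words]
assert (E_matrix.sum(dim=1) == 1).all()           # exactly one active letter per (position, word)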


LEECHOONGHO commented Mar 22, 2022

When I changed thalf from 16000 to 8000 and made the model output word-level symbols directly, the STC loss still increases slightly after 8k steps.

More details about my model:

  1. I feed normalized waveforms directly into a wav2vec 2.0-style model (WavLM).
  2. I changed the total stride of the wav feature extractor from 320 to 1280 (80ms per output frame; see the frame-budget check after this list).
  3. The number of words per second in the audio is 0.5~6.5 (there are many unvoiced segments).
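For reference, the frame budget implied by points 2 and 3 (just the arithmetic, assuming 16 kHz input as in wav2vec 2.0/WavLM):

frames_per_second = 16000 / 1280  # 12.5 frames/s at the 1280-sample total stride
for words_per_second in (0.5, 6.5):
    print(f"{words_per_second} words/s -> {frames_per_second / words_per_second:.1f} frames per word")
# 0.5 words/s -> 25.0 frames per word
# 6.5 words/s -> 1.9 frames per word

At the fast end that is fewer than two output frames per word, which is worth checking against the minimum sequence-length requirement of an alignment-based loss.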

[plot: loss2]
