<a href="https://colab.research.google.com/github/dbstj1231/2023_AI_Academy_ASR/blob/main/7_WFST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NeMo + k2 ASR Demo

Materials:

- logits, indexed words, indexed tokens, and a lexicon from NeMo
- an n-gram language model in ARPA format


## k2-fsa (k2)
- A WFST package used from Python, designed as an improvement on OpenFst
- Built to combine end-to-end models with WFSTs
- Implements FSAs and FSTs as tensors, so WFST operations can run on the GPU
- https://github.com/k2-fsa/k2

## Installing k2 and setting up the environment

### Prerequisites (for pip)
- Python >= 3.6
- CUDA >= 10.1
- PyTorch == 1.7.1 (>= 1.7.1 when installing with conda)

### Installing from PyPI with pip (recommended, but not on Colab)
- pip install k2 (the prerequisites are installed automatically)

### pip installation that works on Colab
- ! pip install torch==1.7.1
- ! pip install k2==1.17.dev20220710+cuda10.2.torch1.7.1 -f https://k2-fsa.org/nightly/

### Installing from source (GitHub)
- git clone https://github.com/k2-fsa/k2.git
- cd k2
- python3 setup.py install

In [None]:
!pip install torch==1.13.1 torchaudio==0.13.1
!pip install k2==1.24.3.dev20230629+cpu.torch1.13.1 -f https://k2-fsa.org/nightly/

In [2]:
import k2, torch

In [None]:
import k2.version
k2.version.version.main()

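## Implementing WFSTs with k2

### Weighted Finite-State Acceptor (WFSA)
- One arc per line: StartState \t EndState \t Symbol(index) \t Weight (k2 calls this value a score; higher is better)
- The final state goes on a line of its own, and the arcs entering it carry the label -1
- Symbols are stored and computed as integer indices
- Symbol strings are only needed for visualization
- A symbol table mapping indices to symbols must be defined
- A WFSA is usually created with k2.Fsa.from_str()
- Visualize it with draw() or to_dot()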
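The text format above can be checked with a small pure-Python parser. This is a sketch, not part of k2: `parse_k2_acceptor` is a hypothetical helper, and it only checks two conventions k2 relies on (the final state has the largest state id, and exactly the arcs entering it carry label -1). The arc list is the WFSA example used in this notebook.

```python
def parse_k2_acceptor(text):
    """Parse 'src dst label score' arc lines followed by one line holding the final state."""
    rows = [line.split() for line in text.strip().splitlines()]
    *arc_rows, final_row = rows
    if len(final_row) != 1:
        raise ValueError("the last line must contain only the final state")
    final = int(final_row[0])
    arcs = [(int(s), int(d), int(l), float(w)) for s, d, l, w in arc_rows]

    # Convention 1: the final state has the largest state id
    largest = max(max(s, d) for s, d, _, _ in arcs)
    if final != largest:
        raise ValueError(f"final state {final} is not the largest state id ({largest})")
    # Convention 2: label -1 appears exactly on the arcs entering the final state
    for s, d, l, _ in arcs:
        if (l == -1) != (d == final):
            raise ValueError(f"arc {s}->{d} with label {l} breaks the -1/final-state rule")
    return arcs, final


wfsa_text = """
0 1 1 1
0 1 3 5
0 2 1 3
0 2 2 4
0 2 3 7
0 2 4 8
0 2 5 9
1 3 6 9
1 3 5 8
2 3 6 12
3 4 -1 0
4
"""
arcs, final = parse_k2_acceptor(wfsa_text)
print(len(arcs), final)  # 11 4
```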

## WFSA creation

In [4]:
s = '''
0 1 1 1
0 1 3 5
0 2 1 3
0 2 2 4
0 2 3 7
0 2 4 8
0 2 5 9
1 3 6 9
1 3 5 8
2 3 6 12
3 4 -1 0
4
'''

In [5]:
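sym_str = '''
 <eps> 0
 a 1
 b 2
 c 3
 d 4
 e 5
 f 6
'''

# Build the acceptor from the arcs in s and attach the symbol table used for drawing
a_fsa = k2.Fsa.from_str(s)
a_fsa.labels_sym = k2.SymbolTable.from_str(sym_str)
a_fsa.draw('a_fsa.svg')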
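To see how scores and the symbol table fit together, the sketch below finds the highest-scoring path through the same WFSA and maps its labels back to symbols, roughly what k2.shortest_path followed by a symbol lookup gives. It is plain Python (no k2), `best_path` is a hypothetical helper, and it assumes every arc goes from a lower to a higher state id, which holds for this example.

```python
def best_path(fsa_text, sym_text):
    """Highest-total-score path through an acyclic k2-style acceptor, labels mapped to symbols."""
    id2sym = {}
    for line in sym_text.strip().splitlines():
        sym, idx = line.split()
        id2sym[int(idx)] = sym

    lines = fsa_text.strip().splitlines()
    final = int(lines[-1])
    arcs = []
    for line in lines[:-1]:
        s, d, l, w = line.split()
        arcs.append((int(s), int(d), int(l), float(w)))
    assert all(s < d for s, d, _, _ in arcs), "sketch assumes arcs go from lower to higher ids"

    # best[state] = (score, labels) of the best partial path from start state 0
    best = {0: (0.0, [])}
    for s, d, l, w in sorted(arcs):  # sorting by source state gives a topological order here
        if s not in best:
            continue
        score, labels = best[s]
        candidate = (score + w, labels if l == -1 else labels + [l])
        if d not in best or candidate[0] > best[d][0]:
            best[d] = candidate
    score, labels = best[final]
    return score, [id2sym[l] for l in labels]


wfsa_text = """
0 1 1 1
0 1 3 5
0 2 1 3
0 2 2 4
0 2 3 7
0 2 4 8
0 2 5 9
1 3 6 9
1 3 5 8
2 3 6 12
3 4 -1 0
4
"""
sym_text = """
<eps> 0
a 1
b 2
c 3
d 4
e 5
f 6
"""
score, symbols = best_path(wfsa_text, sym_text)
print(score, symbols)  # 21.0 ['e', 'f']
```

The winning path takes `e` (score 9) into state 2 and then `f` (score 12), beating every path through state 1.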

### Weighted Finite-State Transducer (WFST)
- One arc per line: StartState \t EndState \t InputSymbol \t OutputSymbol \t Weight
- When visualized, each arc is labeled InputSymbol:OutputSymbol/weight
- Used to represent the CTC topology, the lexicon, and the language model in an optimized form
- Created with the acceptor=False option of k2.Fsa.from_str
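The InputSymbol:OutputSymbol/weight labels can be reproduced by hand from the text format. This plain-Python sketch uses a hypothetical helper, `arc_label_strings`, on the first transducer from the composition example in this notebook; k2's own drawing may format labels and scores differently.

```python
def arc_label_strings(fst_text, sym_text):
    """Render every arc of a k2-style transducer as 'input:output/score'."""
    id2sym = {}
    for line in sym_text.strip().splitlines():
        sym, idx = line.split()
        id2sym[int(idx)] = sym
    labels = []
    for line in fst_text.strip().splitlines()[:-1]:  # the last line is the final state
        _, _, ilabel, olabel, score = line.split()
        # labels without a symbol (e.g. -1 on the final arc) are shown as numbers
        i = id2sym.get(int(ilabel), ilabel)
        o = id2sym.get(int(olabel), olabel)
        labels.append(f"{i}:{o}/{score}")
    return labels


fst_text = """
0 1 1 2 0.1
1 2 2 2 0.3
1 3 2 2 0.4
2 3 1 2 0.5
3 3 1 1 0.6
3 4 -1 -1 0
4
"""
sym_text = """
<eps> 0
a 1
b 2
c 3
d 4
e 5
f 6
"""
print(arc_label_strings(fst_text, sym_text))
# ['a:b/0.1', 'b:b/0.3', 'b:b/0.4', 'a:b/0.5', 'a:a/0.6', '-1:-1/0']
```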

In [None]:
s = '''
0 1 1 1 1
0 1 3 3 5
0 2 1 1 3
0 2 2 2 4
0 2 3 3 7
0 2 4 4 8
0 2 5 5 9
1 3 6 6 9
1 3 5 5 8
2 3 6 6 12
3 4 -1 -1 0
4
'''

# acceptor=False: each arc line is "src dst input_label output_label score"
a_fst = k2.Fsa.from_str(s, acceptor=False)


sym_str = '''
 <eps> 0
 a 1
 b 2
 c 3
 d 4
 e 5
 f 6
'''
a_fst = k2.arc_sort(a_fst)
# The same symbol table labels the input side (labels) and the output side (aux_labels)
a_fst.labels_sym = k2.SymbolTable.from_str(sym_str)
a_fst.aux_labels_sym = k2.SymbolTable.from_str(sym_str)
a_fst.draw('a_fst.svg')


## Composition & Determinization

In [None]:
s = '''
0 1 1 2 0.1
1 2 2 2 0.3
1 3 2 2 0.4
2 3 1 2 0.5
3 3 1 1 0.6
3 4 -1 -1 0
4
'''


sym_str = '''
 <eps> 0
 a 1
 b 2
 c 3
 d 4
 e 5
 f 6
'''


In [None]:
s = '''
0 1 2 2 0.1
1 1 2 1 0.2
1 2 1 2 0.3
2 3 2 1 0.5
3 4 -1 -1 0.6
4
'''



sym_str = '''
 <eps> 0
 a 1
 b 2
 c 3
 d 4
 e 5
 f 6
'''

In [None]:
# compose
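What k2.compose(a, b) computes can be sketched in plain Python for epsilon-free transducers: pair up states, keep a pair of arcs only when the output label of `a`'s arc equals the input label of `b`'s arc, and add their scores; the -1 final arcs match each other. `parse_fst`, `compose_sketch` and `connect_sketch` are hypothetical helpers, and epsilons (which k2.compose handles via treat_epsilons_specially) are ignored here.

```python
from collections import deque


def parse_fst(text):
    """Parse 'src dst ilabel olabel score' arc lines plus a final-state line."""
    lines = text.strip().splitlines()
    arcs = []
    for line in lines[:-1]:
        s, d, i, o, w = line.split()
        arcs.append((int(s), int(d), int(i), int(o), float(w)))
    return arcs, int(lines[-1])


def compose_sketch(a, b):
    """Epsilon-free composition: a's output label must equal b's input label; scores add."""
    (a_arcs, a_final), (b_arcs, b_final) = a, b
    start = (0, 0)
    arcs, seen, queue = [], {start}, deque([start])
    while queue:
        p, q = queue.popleft()
        for sa, da, ia, oa, wa in a_arcs:
            if sa != p:
                continue
            for sb, db, ib, ob, wb in b_arcs:
                if sb == q and oa == ib:  # the -1 final arcs match each other too
                    nxt = (da, db)
                    arcs.append(((p, q), nxt, ia, ob, wa + wb))
                    if nxt not in seen:
                        seen.add(nxt)
                        queue.append(nxt)
    return arcs, (a_final, b_final)


def connect_sketch(arcs, final):
    """Keep only arcs on some path to the final state (all states are already reachable)."""
    useful, changed = {final}, True
    while changed:
        changed = False
        for src, dst, *_ in arcs:
            if dst in useful and src not in useful:
                useful.add(src)
                changed = True
    return [arc for arc in arcs if arc[1] in useful]


fst1 = parse_fst("0 1 1 2 0.1\n1 2 2 2 0.3\n1 3 2 2 0.4\n2 3 1 2 0.5\n3 3 1 1 0.6\n3 4 -1 -1 0\n4")
fst2 = parse_fst("0 1 2 2 0.1\n1 1 2 1 0.2\n1 2 1 2 0.3\n2 3 2 1 0.5\n3 4 -1 -1 0.6\n4")
print(connect_sketch(*compose_sketch(fst1, fst2)))  # []

t1 = parse_fst("0 1 1 2 0.5\n1 2 -1 -1 0\n2")   # hypothetical: a -> b
t2 = parse_fst("0 1 2 3 0.25\n1 2 -1 -1 0\n2")  # hypothetical: b -> c
print(connect_sketch(*compose_sketch(t1, t2)))
# [((0, 0), (1, 1), 1, 3, 0.75), ((1, 1), (2, 2), -1, -1, 0.0)]
```

Composed in that order, the two transducers defined above have no complete path: the first one only outputs 2s followed by 1s, while the second only accepts inputs that end in 1 2. The small hypothetical pair (a→b, then b→c) composes to a single a:c path.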


In [None]:
# determinization
a_fsa.draw('a_fsa.svg')

In [None]:
a_deter = k2.determinize(a_fsa)
a_deter = k2.arc_sort(a_deter)
a_deter.labels_sym = k2.SymbolTable.from_str(sym_str)
a_deter.draw('deter.svg')
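Weighted determinization can be sketched in plain Python for an acyclic acceptor in the tropical (max) semiring: each new state is a set of (original state, residual score) pairs; for each label, the new arc carries the best reachable score and the subset keeps what is left over. This is the textbook subset construction, not k2's implementation (k2.determinize also offers weight-pushing options, and its state numbering may differ). `determinize_sketch` is a hypothetical helper applied to the WFSA from the start of this notebook.

```python
from collections import defaultdict


def determinize_sketch(arcs, start=0):
    """Subset construction for a weighted acceptor in the max-plus (tropical) semiring.

    arcs: list of (src, dst, label, score). Returns (new_arcs, states), where each
    new state is a frozenset of (original_state, residual_score) pairs.
    """
    out = defaultdict(list)
    for s, d, l, w in arcs:
        out[s].append((d, l, w))

    init = frozenset({(start, 0.0)})
    states, todo, new_arcs = [init], [init], []
    while todo:
        subset = todo.pop()
        # For every label, the best score of reaching each destination state
        by_label = defaultdict(dict)
        for q, residual in subset:
            for d, l, w in out[q]:
                cand = residual + w
                if cand > by_label[l].get(d, float("-inf")):
                    by_label[l][d] = cand
        for l, dests in sorted(by_label.items()):
            weight = max(dests.values())  # score placed on the new arc
            nxt = frozenset((d, v - weight) for d, v in dests.items())  # leftovers stay in the subset
            if nxt not in states:
                states.append(nxt)
                todo.append(nxt)
            new_arcs.append((states.index(subset), states.index(nxt), l, weight))
    return new_arcs, states


wfsa_arcs = [(0, 1, 1, 1.0), (0, 1, 3, 5.0), (0, 2, 1, 3.0), (0, 2, 2, 4.0), (0, 2, 3, 7.0),
             (0, 2, 4, 8.0), (0, 2, 5, 9.0), (1, 3, 6, 9.0), (1, 3, 5, 8.0), (2, 3, 6, 12.0),
             (3, 4, -1, 0.0)]
new_arcs, states = determinize_sketch(wfsa_arcs)
print(len(states), len(new_arcs))  # 5 9
```

The input has two arcs labeled a (and two labeled c) leaving state 0; afterwards each label leaves each state at most once, and the best path still scores 9 + 12 + 0 = 21.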

# WFST

In [6]:
words_values = k2.SymbolTable.from_file('lang/lm/words.txt')
tokens_values = k2.SymbolTable.from_file('lang/lm/tokens.txt')

## Lexicon transducer

In [None]:
# CTC topology (blank = 0); C is composed with LG in the CLG step below
C = k2.ctc_topo(max_token=129, modified=False)
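C maps frame-level CTC outputs to tokens: repeated frames collapse and blanks (label 0) disappear. Below is a plain-Python sketch of a standard (non-modified) topology in the spirit of k2.ctc_topo(max_token, modified=False): one state per label, where moving to state j on label j outputs j and staying on a repeated label outputs epsilon (0). `ctc_topo_sketch` and `run_topo` are hypothetical helpers; k2's own arc layout may differ.

```python
def ctc_topo_sketch(max_token):
    """Arcs (src, dst, ilabel, olabel) of a CTC topology; state i = last frame label was i."""
    n = max_token + 1  # labels 0 (blank) .. max_token
    final = n
    arcs = []
    for i in range(n):
        for j in range(n):
            # repeating the same label outputs epsilon; a new label outputs itself
            # (a new blank outputs 0, i.e. epsilon, as well)
            arcs.append((i, j, j, 0 if i == j else j))
        arcs.append((i, final, -1, -1))  # any state may end the utterance
    return arcs, final


def run_topo(arcs, frame_labels):
    """Follow the (deterministic) topology over frame labels and keep non-epsilon outputs."""
    step = {(s, i): (d, o) for s, d, i, o in arcs}
    state, outputs = 0, []
    for label in frame_labels:
        state, out = step[(state, label)]
        if out != 0:
            outputs.append(out)
    return outputs


arcs, final = ctc_topo_sketch(max_token=2)
print(len(arcs))  # 12 = 3 * 3 label arcs + 3 final arcs
print(run_topo(arcs, [1, 1, 0, 1, 2, 2, 0]))  # [1, 1, 2]
```

The blank between the two 1s is what keeps them as two separate tokens.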

In [9]:
from utils import read_lexicon

lexicon = read_lexicon("lang/lm/lexicon.txt")

In [10]:
from utils import add_disambig_symbols

lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

In [None]:
max_disambig

1
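Disambiguation symbols keep the lexicon determinizable: when two words share a token sequence (homophones), or one word's tokens are a prefix of another's, the entries get distinct markers #1, #2, and so on, while #0 is left for the language model's backoff arcs (which is why the grammar below is built with --disambig-symbol='#0'). A plain-Python sketch of that rule, modeled on Kaldi's add_lex_disambig.pl; `add_disambig_sketch` and the toy lexicon are hypothetical, and the repository's `utils.add_disambig_symbols` may differ in detail.

```python
from collections import Counter, defaultdict


def add_disambig_sketch(lexicon):
    """lexicon: list of (word, [tokens]). Returns (new_lexicon, max_disambig)."""
    count = Counter(" ".join(tokens) for _, tokens in lexicon)
    # every proper prefix of some word's token sequence
    prefixes = {" ".join(tokens[:k]) for _, tokens in lexicon for k in range(1, len(tokens))}

    last_used = defaultdict(int)
    max_disambig = 0
    result = []
    for word, tokens in lexicon:
        key = " ".join(tokens)
        if count[key] == 1 and key not in prefixes:
            result.append((word, list(tokens)))  # unambiguous: no marker needed
            continue
        last_used[key] += 1  # #1 for the first entry, #2 for the next, ...
        max_disambig = max(max_disambig, last_used[key])
        result.append((word, list(tokens) + [f"#{last_used[key]}"]))
    return result, max_disambig


lex = [("the", ["the"]), ("red", ["r", "ed"]), ("read", ["r", "ed"]), ("re", ["r"])]
new_lex, max_disambig = add_disambig_sketch(lex)
print(new_lex)
print(max_disambig)  # 2
```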

In [11]:
tokens_values.add('#0', 128)
tokens_values.add('#1', 129)

print(tokens_values.get('#0'), tokens_values.get('#1'))

128 129


In [None]:
from utils import lexicon_to_fst

L = lexicon_to_fst(
    lexicon_disambig,
    token2id=tokens_values._sym2id,
    word2id=words_values._sym2id,
    need_self_loops=True
)
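A lexicon transducer maps token sequences to words. The sketch below builds a simplified one in k2's text format, similar to what icefall-style recipes do: every word is a chain of arcs leaving and returning to a loop state 0, with the word ID on the first arc's output side and epsilon (0) on the rest. With need_self_loops, a #0:#0 self-loop is added at the loop state (the only state with word-labeled outgoing arcs here) so that the #0 backoff symbols of G can pass through composition. `lexicon_to_fst_text` and the toy IDs are hypothetical, and the sketch has no optional silence; it is not the `utils.lexicon_to_fst` used above.

```python
def lexicon_to_fst_text(lexicon, token2id, word2id, need_self_loops=True):
    """Simplified lexicon transducer (no optional silence) in k2's text format."""
    loop_state, next_state = 0, 1
    arcs = []
    for word, tokens in lexicon:
        src = loop_state
        for i, token in enumerate(tokens):
            is_last = i == len(tokens) - 1
            dst = loop_state if is_last else next_state
            if not is_last:
                next_state += 1
            olabel = word2id[word] if i == 0 else 0  # word ID on the first arc, epsilon after
            arcs.append((src, dst, token2id[token], olabel, 0.0))
            src = dst
    if need_self_loops:
        # lets #0 (the backoff symbol of G) pass through L during composition
        arcs.append((loop_state, loop_state, token2id["#0"], word2id["#0"], 0.0))
    final_state = next_state  # the final state gets the largest id
    arcs.append((loop_state, final_state, -1, -1, 0.0))
    arcs.sort(key=lambda arc: arc[0])  # group arcs by source state (stable sort)
    lines = [f"{s} {d} {i} {o} {w}" for s, d, i, o, w in arcs]
    return "\n".join(lines + [str(final_state)])


token2id = {"<eps>": 0, "h": 1, "i": 2, "a": 3, "#0": 4}
word2id = {"<eps>": 0, "hi": 1, "a": 2, "#0": 3}
L_text = lexicon_to_fst_text([("hi", ["h", "i"]), ("a", ["a"])], token2id, word2id)
print(L_text)
```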

## Grammar transducer

In [None]:
!pip install kaldilm

In [None]:
!python3 -m kaldilm --read-symbol-table="lang/lm/words.txt" --disambig-symbol='#0' --max-order=3 lang/lm/libri_3_gram_1e-7.arpa > lang/G.fst.txt
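G encodes the n-gram model from the ARPA file: each line holds a log10 probability, an n-gram, and optionally a log10 backoff weight, and an unseen n-gram is scored by backing off to a shorter history. kaldilm, like Kaldi's arpa2fst, turns those backoff transitions into arcs labeled with the disambiguation symbol (#0 here). A plain-Python sketch of the scoring rule on a tiny, made-up bigram model; `read_arpa` and `log10_prob` are hypothetical helpers.

```python
ARPA = """
\\data\\
ngram 1=3
ngram 2=1

\\1-grams:
-99 <s> -0.5
-0.5 a -0.3
-0.7 </s>

\\2-grams:
-0.2 <s> a

\\end\\
"""


def read_arpa(text):
    """Return {ngram_tuple: (log10_prob, log10_backoff)} from an ARPA string."""
    table, order = {}, 0
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("ngram ") or line in ("\\data\\", "\\end\\"):
            continue
        if line.endswith("-grams:"):
            order = int(line[1:line.index("-")])  # "\2-grams:" -> 2
            continue
        fields = line.split()
        words = tuple(fields[1:1 + order])
        backoff = float(fields[1 + order]) if len(fields) > 1 + order else 0.0
        table[words] = (float(fields[0]), backoff)
    return table


def log10_prob(table, history, word):
    """log10 P(word | history) with standard ARPA backoff."""
    ngram = tuple(history) + (word,)
    if ngram in table:
        return table[ngram][0]
    if not history:
        raise KeyError(f"{word!r} is not in the vocabulary")
    backoff = table.get(tuple(history), (0.0, 0.0))[1]
    return backoff + log10_prob(table, history[1:], word)


table = read_arpa(ARPA)
print(log10_prob(table, ["<s>"], "a"))   # -0.2 (the bigram is listed)
print(log10_prob(table, ["a"], "</s>"))  # backoff(a) + P(</s>) = -0.3 + -0.7
```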

In [None]:
import torch

with open("lang/G.fst.txt") as f:
  G = k2.Fsa.from_openfst(f.read(), acceptor=False)
  torch.save(G.as_dict(), "lang/G.pt")

## Composition & Determinization

### Composition

In [None]:
L = k2.arc_sort(L)
G = k2.arc_sort(G)

LG = k2.compose(L, G)
#L_inv = L.invert()
#L_inv = k2.arc_sort(L_inv)
#L_inv.rename_tensor_attribute_('aux_labels', 'left_labels')
#LG = k2.intersect(L_inv, G, treat_epsilons_specially=True)
#LG.rename_tensor_attribute_('left_labels', 'labels')
LG = k2.connect(LG)

print(LG.shape)

### Determinization

In [None]:
LG = k2.determinize(LG)
LG = k2.connect(LG)
print(LG.shape)

### Epsilon removal

In [None]:
LG.labels[LG.labels >= tokens_values["#0"]] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None

assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= words_values["#0"]] = 0

LG = k2.remove_epsilon(LG)
#logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
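The cell above first turns disambiguation symbols into epsilons (label 0) and then calls k2.remove_epsilon. The same two steps on a tiny acceptor, in plain Python: every state inherits the non-epsilon arcs reachable from it through epsilon-only paths, with the best epsilon-path score added (max semiring). `remove_disambig_and_epsilon` and the toy arcs are hypothetical; the sketch handles acceptors only (no output labels) and assumes there are no epsilon cycles.

```python
from collections import defaultdict


def remove_disambig_and_epsilon(arcs, first_disambig, start=0):
    """arcs: (src, dst, label, score) of an acceptor; label -1 marks arcs into the final state.

    1) Labels >= first_disambig (the #0, #1, ... symbols) become epsilon (0).
    2) Epsilon arcs are removed: each state gets the non-epsilon arcs reachable through
       epsilon-only paths, plus the best epsilon-path score (max semiring).
    Assumes there are no epsilon cycles.
    """
    eps, non_eps = defaultdict(list), defaultdict(list)
    for s, d, l, w in arcs:
        l = 0 if l >= first_disambig else l
        (eps if l == 0 else non_eps)[s].append((d, l, w))

    def closure(s):
        best = {s: 0.0}  # best score of an epsilon-only path from s
        stack = [s]
        while stack:
            q = stack.pop()
            for d, _, w in eps[q]:
                if best[q] + w > best.get(d, float("-inf")):
                    best[d] = best[q] + w
                    stack.append(d)
        return best

    result, seen, todo = [], {start}, [start]
    while todo:  # only states reachable from the start survive
        s = todo.pop()
        for q, eps_score in closure(s).items():
            for d, l, w in non_eps[q]:
                result.append((s, d, l, eps_score + w))
                if d not in seen:
                    seen.add(d)
                    todo.append(d)
    return sorted(result)


# Hypothetical toy graph: label 5 plays the role of a disambiguation symbol
toy_arcs = [(0, 1, 5, 1.0), (1, 2, 1, 2.0), (0, 2, 2, 0.5), (2, 3, -1, 0.0)]
print(remove_disambig_and_epsilon(toy_arcs, first_disambig=5))
# [(0, 2, 1, 3.0), (0, 2, 2, 0.5), (2, 3, -1, 0.0)]
```

State 1 is only reachable through the former #-symbol arc, so once that arc is folded into state 0 it disappears, much like k2.connect trims unreachable states.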

In [None]:
print(LG.shape)

(3472138, None)


## CLG composition

In [None]:
torch.save(LG.as_dict(), "lang/LG.pt")

In [None]:
C = k2.arc_sort(C)
LG = k2.arc_sort(LG)

In [None]:
CLG = k2.compose(C, LG)

CLG = k2.connect(CLG)

print(CLG.shape)

(6944399, None)


In [None]:
torch.save(CLG.as_dict(), 'lang/CLG.pt')

In [None]:
import k2, torch
CLG = k2.Fsa.from_dict(torch.load('lang/CLG.pt', map_location="cpu"))

## Utterance transducer

In [None]:
import torch
nnet_outputs = torch.load('logits.pt')

In [None]:
print(len(nnet_outputs))
print(nnet_outputs[0].shape)

2620
torch.Size([303, 129])


In [None]:
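def rearrange_blksym(nnet_outputs):
    """Move NeMo's blank (last column) to index 0, where k2's CTC topology expects it.

    Column 0 of the NeMo output is dropped, so (T, C) becomes (1, T, C - 1).
    """
    blank = nnet_outputs[:, -1:]    # (T, 1): blank scores
    tokens = nnet_outputs[:, 1:-1]  # (T, C - 2): token columns 1 .. C-2
    logits = torch.cat([blank, tokens], dim=1)
    return logits.unsqueeze(0)      # add a batch dimension for k2.DenseFsaVec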


In [None]:
logits = rearrange_blksym(nnet_outputs[0])

In [None]:
logits.shape

torch.Size([1, 303, 128])

In [None]:
supervision_segments = torch.tensor([[0, 0, logits.shape[1]]], dtype=torch.int32)

In [None]:
dense_fsa_vec = k2.DenseFsaVec(
    logits,
    supervision_segments)

In [None]:
lattice = k2.intersect_dense_pruned(
    CLG,
    dense_fsa_vec,
    search_beam=30.0,
    output_beam=15,
    min_active_states=30,
    max_active_states=1000000,
)
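k2.intersect_dense_pruned matches the decoding graph against the frame-by-frame log-probabilities, keeping only hypotheses within a beam of the best one. A much simplified plain-Python picture: a frame-synchronous Viterbi search where every non-final arc consumes one frame (label 0 is the CTC blank, not epsilon) and states scoring below best - beam are dropped. `viterbi_beam` and the toy graph are hypothetical; k2 returns a whole lattice, uses separate search and output beams, and bounds the number of active states with min_active_states and max_active_states.

```python
from collections import defaultdict


def viterbi_beam(arcs, final_state, frames, beam):
    """Best path of an acceptor against per-frame log-probs, with beam pruning.

    arcs: (src, dst, label, score); every arc with label >= 0 consumes one frame,
    and arcs with label -1 (into final_state) are taken after the last frame.
    frames: list of per-frame log-prob lists indexed by label (0 = blank).
    Returns (score, frame_labels), or None if no path survives.
    """
    out = defaultdict(list)
    for s, d, l, w in arcs:
        out[s].append((d, l, w))

    active = {0: (0.0, [])}  # state -> (best score, labels along the best path)
    for logp in frames:
        nxt = {}
        for s, (score, labels) in active.items():
            for d, l, w in out[s]:
                if l < 0:
                    continue  # final arcs only after the last frame
                cand = score + w + logp[l]
                if d not in nxt or cand > nxt[d][0]:
                    nxt[d] = (cand, labels + [l])
        if not nxt:
            return None
        best = max(score for score, _ in nxt.values())
        active = {d: v for d, v in nxt.items() if v[0] >= best - beam}  # beam pruning

    finals = [
        (score + w, labels)
        for s, (score, labels) in active.items()
        for d, l, w in out[s]
        if l == -1 and d == final_state
    ]
    return max(finals) if finals else None


# Hypothetical graph: optional blanks, then token 1 or 2, then optional blanks
graph = [(0, 0, 0, 0.0), (0, 1, 1, 0.0), (0, 1, 2, 0.0), (1, 1, 0, 0.0), (1, 2, -1, 0.0)]
frames = [[-0.1, -3.0, -2.0], [-2.0, -0.2, -3.0]]
print(viterbi_beam(graph, final_state=2, frames=frames, beam=10.0))
# best path: blank, then token 1, with a score of about -0.3
```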

In [None]:
best_path = k2.shortest_path(lattice, use_double_scores=True)

In [None]:
from utils import get_texts

token_ids = get_texts(best_path)
hyp = [[words_values[i] for i in ids] for ids in token_ids]
print(" ".join(hyp[0]))

In [None]:
token_ids