In [None]:
from src.requirements import *
from src.ssl_model import *
from src.asr_model import *
from src.tokenizer import *
from src.audio_handler import *

In [None]:
class InferenceModel(nn.Module):
    """Single-utterance CTC speech-to-text wrapper around a trained ASR model.

    Raw audio is down-mixed to mono, silence-trimmed with librosa, peak
    normalised, resampled to ``TARGET_SR``, passed through the acoustic model,
    and the result is decoded with a lexicon-free torchaudio CTC beam search.

    Args:
        asr_model: acoustic model called as ``asr_model(waveform, input_length)``;
            its output is handed to the CTC decoder as emissions.
        tokenizer: object whose ``get_vocab()`` returns the token list given to
            the CTC decoder (it must include ``'<blank>'`` and ``sil_token``).
        device: device the ASR model and its inputs are moved to.
        beam_size: beam width of the CTC decoder.
        sil_token: token the decoder treats as silence / word boundary.
            Defaults to U+0964 (the danda sentence mark).
    """

    # Sample rate (Hz) the acoustic model is fed at.
    TARGET_SR = 16_000
    # Input samples per output frame, used to derive the frame count passed to
    # the model. NOTE(review): confirm this matches the model's total stride.
    DOWNSAMPLING_FACTOR = 320

    def __init__(self, asr_model, tokenizer, device='cpu', beam_size=50, sil_token='\u0964'):
        super().__init__()
        self.asr = asr_model
        self.vocab = tokenizer.get_vocab()
        self.device = device

        self.asr.to(device)
        self.asr.eval()

        # The danda default is written as an escape so no file-encoding round
        # trip can garble it into an unknown multi-character token.
        self.decoder = ctc_decoder(
            lexicon=None,  # lexicon-free: decode directly over vocabulary tokens
            tokens=self.vocab,
            blank_token='<blank>',
            sil_token=sil_token,
            unk_word=None,
            nbest=1,
            beam_size=beam_size,
        )

    def forward(self, waveform, sr):
        """Transcribe one utterance.

        Args:
            waveform: audio as a tensor, NumPy array or nested list, either
                1-D ``(samples,)`` or 2-D multichannel in ``(channels, samples)``
                or ``(samples, channels)`` order; the shorter axis is taken as
                the channel axis.
            sr: sample rate of ``waveform`` in Hz.

        Returns:
            The best decoded hypothesis as a string, or ``''`` when there are
            no samples to transcribe.

        Raises:
            ValueError: if ``waveform`` has more than two dimensions.
        """
        audio = self._preprocess(waveform, sr)
        if audio.numel() == 0:
            return ''

        # Model input layout is (1, 1, samples).
        model_input = audio.reshape(1, 1, -1).to(self.device)
        # Expected number of output frames; at least 1 so very short clips decode.
        n_frames = max(1, audio.shape[-1] // self.DOWNSAMPLING_FACTOR)
        input_length = torch.tensor([n_frames], device=self.device)

        with torch.no_grad():
            log_probs = self.asr(model_input, input_length)
            # NOTE(review): assumes the model emits (time, batch, vocab) while the
            # decoder wants (batch, time, vocab) — confirm against ASRModel.
            emissions = log_probs.transpose(0, 1).contiguous().cpu()
            best = self.decoder(emissions)[0][0]

        return ''.join(self.vocab[idx] for idx in best.tokens.tolist())

    def _preprocess(self, waveform, sr):
        """Return a 1-D float32 CPU tensor: mono, trimmed, peak-normalised, at TARGET_SR."""
        if isinstance(waveform, torch.Tensor):
            # Detach and move to CPU so `.numpy()` also works for CUDA /
            # grad-tracking tensors; float32 so integer PCM can be averaged.
            audio = waveform.detach().to(device='cpu', dtype=torch.float32)
        else:
            audio = torch.tensor(waveform, dtype=torch.float32)

        audio = self._to_mono(audio)
        if audio.numel() == 0:
            return audio

        trimmed, _ = librosa.effects.trim(audio.numpy(), top_db=TOP_DB)
        audio = torch.tensor(trimmed, dtype=torch.float32)
        if audio.numel() == 0:
            return audio

        peak = audio.abs().max()
        if peak > 0:
            audio = audio / peak

        if sr != self.TARGET_SR:
            audio = torchaudio.functional.resample(audio, sr, self.TARGET_SR)
        return audio

    @staticmethod
    def _to_mono(audio):
        """Down-mix a 1-D or 2-D tensor to a 1-D tensor by averaging channels.

        For 2-D input the shorter axis is the channel axis; a square input is
        read as channels-first.
        """
        if audio.ndim == 1:
            return audio
        if audio.ndim != 2:
            raise ValueError(f"expected 1-D or 2-D audio, got shape {tuple(audio.shape)}")
        if audio.numel() == 0:
            # Averaging over an empty axis would yield NaNs; keep it empty instead.
            return audio.reshape(-1)
        if audio.shape[0] > audio.shape[1]:
            audio = audio.T  # (samples, channels) -> (channels, samples)
        return audio.mean(dim=0)

    def transcribe_file(self, audio_path):
        """Load ``audio_path`` with torchaudio and transcribe it."""
        waveform, sr = torchaudio.load(audio_path)
        return self.forward(waveform, sr)

    def transcribe_batch(self, waveforms, sample_rates):
        """Transcribe paired waveforms / sample rates one at a time.

        Raises:
            ValueError: if the two sequences have different lengths (``zip``
                would otherwise silently drop the unmatched tail).
        """
        waveforms = list(waveforms)
        sample_rates = list(sample_rates)
        if len(waveforms) != len(sample_rates):
            raise ValueError(
                f"got {len(waveforms)} waveforms but {len(sample_rates)} sample rates"
            )
        return [self.forward(wav, rate) for wav, rate in zip(waveforms, sample_rates)]

In [None]:
def calculate_metrics(references, hypotheses, num_samples=5):
    """Print and return corpus-level WER and CER for paired transcripts.

    Args:
        references: ground-truth transcripts.
        hypotheses: model transcripts, aligned one-to-one with ``references``.
        num_samples: number of leading reference/hypothesis pairs to print
            for a quick visual check (default 5).

    Returns:
        ``(wer, cer)`` exactly as returned by ``jiwer.wer`` / ``jiwer.cer``
        (fractions; printed here as percentages).
    """
    wer = jiwer.wer(references, hypotheses)
    cer = jiwer.cer(references, hypotheses)

    print(f"WER: {wer*100:.2f}%")
    print(f"CER: {cer*100:.2f}%")

    print("\n--- Sample Predictions ---")
    n_show = max(0, min(num_samples, len(references)))
    for ref, hyp in zip(references[:n_show], hypotheses[:n_show]):
        print(f"\nRef: {ref}")
        print(f"Hyp: {hyp}")

    return wer, cer

In [None]:
# Build tokenizer + acoustic model and load the fine-tuned checkpoint.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
token_path = os.path.join("data", "tokenizer.json")
tokenizer = Tokenizer.load(token_path)
ssl_model = SSLModel()
asr_model = ASRModel(ssl_model, tokenizer.vocab_size)
update_ver = 5_000  # checkpoint suffix; presumably the training update/step number

checkpoint_path = os.path.join('models', 'asr_model', f'asr_model_prototype_{update_ver}.pth')
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# kernel; InferenceModel moves the weights to `device` afterwards.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
asr_state_dict = checkpoint_dict['model_state_dict']
asr_model.load_state_dict(asr_state_dict, strict=True)

inf_model = InferenceModel(asr_model, tokenizer, device)

In [None]:
# Transcribe a file straight from disk (transcribe_file loads it with torchaudio).
# text = inf_model.transcribe_file('test.flac')
# print(f"Transcription: {text}")

# Transcribe from an in-memory waveform.
# sf.read(..., always_2d=True) returns a (frames, channels) array; transposing gives
# (channels, frames), the layout forward() down-mixes for mono or stereo audio.
waveform, sr = sf.read('test.flac', always_2d=True)
waveform = torch.tensor(waveform.T, dtype=torch.float32)
text = inf_model(waveform, sr)
print(f"Transcription: {text}")

# Batch transcription (disabled; the file names below look like placeholders).
# audio_files = ['audio1.wav', 'audio2.wav', 'audio3.wav']
# waveforms = []
# sample_rates = []

# for file in audio_files:
#     wav, sr = torchaudio.load(file)
#     waveforms.append(wav)
#     sample_rates.append(sr)

# texts = inf_model.transcribe_batch(waveforms, sample_rates)
# for file, text in zip(audio_files, texts):
#     print(f"{file}: {text}")

In [None]:
# Calculate metrics.
# NOTE(review): nothing in this notebook builds `references` (ground-truth
# transcripts) or `hypotheses` (model outputs, e.g. from
# `inf_model.transcribe_batch`). Populate both before this cell; otherwise
# scoring is skipped instead of stopping Restart & Run All with a NameError.
references = globals().get('references', [])
hypotheses = globals().get('hypotheses', [])
if len(references) > 0:
    wer, cer = calculate_metrics(references, hypotheses)
else:
    print("Skipping WER/CER: `references` / `hypotheses` are empty or undefined.")