In [1]:
from src.requirements import *
from src.models import *
from src.tokenizer import Tokenizer

In [2]:
# Version suffixes selecting which saved prototype checkpoints to load.
ssl_ver = 50_000
asr_ver = 80_000

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Artifact locations, relative to the notebook's working directory.
ssl_path = os.path.join("models", "ssl_model", f"ssl_model_prototype_{ssl_ver}.pth")
asr_path = os.path.join("models", "asr_model", f"asr_model_prototype_{asr_ver}.pth")
tokenizer_path = os.path.join("data", "tokenizer.json")

In [3]:
# Build the end-to-end inference wrapper from the two checkpoints, the tokenizer
# file plus the Tokenizer class used to load it, and the target device.
infer_model = InferenceModel(
    ssl_path,
    asr_path,
    tokenizer_path,
    Tokenizer,
    device,
)

In [4]:
path = "test.flac"

# soundfile returns float64 samples shaped [T_raw] for mono files and
# [T_raw, Channels] otherwise; always_2d=True yields [T_raw, Channels] in both
# cases, so a single code path covers mono and multichannel audio.
waveform, sr = sf.read(path, always_2d=True)

# 1. Cast to float32, go [T_raw, Channels] -> [Channels, T_raw], then down-mix
#    by averaging the channels (a no-op for mono): result is [1, T_raw].
waveform = torch.tensor(waveform, dtype=torch.float32).T.mean(dim=0, keepdim=True)

# 2. Peak-normalize so the loudest sample has magnitude 1; an all-zero clip is
#    left untouched to avoid dividing by zero.
max_val = waveform.abs().max()
waveform = waveform / max_val if max_val > 0 else waveform

# 3. Prepend the batch axis: [Batch=1, Channel=1, Time=T_raw].
waveform = waveform[None]

In [5]:
# Transcribe the prepared clip. The file's native sample rate `sr` is passed
# along with the audio — presumably so InferenceModel can resample; confirm.
text = infer_model(
    waveform,
    sr,
)

In [6]:
# Show the decoded transcript (displayed via repr, so trailing whitespace is visible).
text

'अ  '

In [None]:
# Rebuild the SSL encoder and ASR model by hand (instead of via InferenceModel)
# so the raw logits can be inspected in the next cell.
tokenizer = Tokenizer.load(tokenizer_path)
vocab_size = len(tokenizer.vocab)

# map_location="cpu": checkpoints saved from a GPU run would otherwise fail to
# load on a CPU-only machine. Nothing in this cell moves the modules to `device`,
# so the weights end up on the CPU regardless.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
encoder = SSLModel()
encoder_checkpoint = torch.load(ssl_path, map_location="cpu")
encoder_state = encoder_checkpoint['model_state_dict']
# strict=False tolerates key mismatches, but it also silently leaves parameters at
# their initial values when names don't line up (e.g. a prefix difference). Report
# what was skipped instead of discarding load_state_dict's result.
encoder_load = encoder.load_state_dict(encoder_state, strict=False)
if encoder_load.missing_keys or encoder_load.unexpected_keys:
    print(f"SSL encoder: {len(encoder_load.missing_keys)} missing / "
          f"{len(encoder_load.unexpected_keys)} unexpected state-dict keys")
    print("  missing (first 10):", encoder_load.missing_keys[:10])
    print("  unexpected (first 10):", encoder_load.unexpected_keys[:10])
encoder.eval()

# NOTE(review): ASRModel is given vocab_size - 1 — presumably its output-class
# count with one tokenizer entry excluded; confirm against the training code.
decoder = ASRModel(encoder, vocab_size-1)
decoder_checkpoint = torch.load(asr_path, map_location="cpu")
decoder_state = decoder_checkpoint['model_state_dict']
decoder.load_state_dict(decoder_state)
decoder.eval()

In [None]:
# Inspect the raw ASR outputs: how confident the model is and which classes it favours.
# NOTE(review): unlike InferenceModel, the decoder gets no sample rate here — confirm
# the clip is already at the rate the model was trained on.
with torch.no_grad():  # eval() alone does not stop autograd from recording the graph
    logits = decoder(waveform)
probs = logits.softmax(dim=-1)

max_probs, predicted_indices = probs.max(dim=-1)

# bincount alone stops at the largest predicted index, so a collapsed model that only
# predicts a few low indices would make topk(5) fail; size it to every output class
# and never request more entries than exist.
class_counts = torch.bincount(predicted_indices.flatten(), minlength=probs.shape[-1])
top_counts = class_counts.topk(min(5, class_counts.numel()))

print(f"Average Max Probability: {max_probs.mean().item():.3f}")
print(f"Top 5 most frequent indices: {top_counts}")