In [None]:
from src.requirements import *
from src.ssl_model import *
from src.asr_model import *
from src.audio_handler import *
from src.tokenizer import *

In [None]:
# --- Paths -----------------------------------------------------------------
data_path = os.path.join("data", "metadata_normal.tsv")
cache_path = os.path.join("data", "cache_mmap", "asr")
text_path = os.path.join("data", "text")
token_path = os.path.join("data", "tokenizer.json")

# --- Runtime ---------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed torch's generators (CPU and CUDA) so weight initialisation and
# DataLoader shuffling are reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed)

# --- Training hyper-parameters ---------------------------------------------
batch_size = 4
# NOTE(review): not referenced by any later cell -- the optimizer/scheduler
# cell uses its own peak LR of 3e-4. Confirm which value is intended.
learning_rate = 1e-4

# Update count of the pre-trained SSL checkpoint to load
# (models/ssl_model/ssl_model_prototype_{update_ver}.pth).
update_ver = 50_000

In [None]:
# Reuse the saved tokenizer when one exists; otherwise build the vocabulary
# from the raw text corpus and persist it so later runs skip this step.
if os.path.exists(token_path):
    tokenizer = Tokenizer.load(token_path)
else:
    text = load_text(text_path)
    tokenizer = Tokenizer()
    tokenizer.build_vocab(text)
    tokenizer.save(token_path)

vocab = tokenizer.get_vocab()
vocab_size = len(tokenizer)
print("Vocab size:", vocab_size)

In [None]:
# SSL encoder, shared by reference with the ASR model that wraps it (so the
# SSL checkpoint loaded into `ssl_model` later is seen by `asr_model`).
ssl_model = SSLModel()
ssl_model = ssl_model.to(device)

# ASR head on top of the SSL encoder; one output per tokenizer symbol.
asr_model = ASRModel(ssl_model=ssl_model, vocab_size=vocab_size, hidden_dim=256, num_layers=4, dropout=0.2)
asr_model = asr_model.to(device)

In [None]:
# Restore the pre-trained SSL encoder saved after `update_ver` updates.
ssl_ckpt_path = os.path.join('models', 'ssl_model', f'ssl_model_prototype_{update_ver}.pth')
# map_location remaps tensors saved on a GPU onto the current device, so the
# checkpoint also loads on a CPU-only machine.
checkpoint_dict = torch.load(ssl_ckpt_path, map_location=device)
ssl_state_dict = checkpoint_dict['model_state_dict']
ssl_model.load_state_dict(ssl_state_dict, strict=True)

In [None]:
asr_dataset = ASRDataset(
    metadata_path=data_path,
    tokenizer=tokenizer,
    cache_dir=cache_path,
    top_db=TOP_DB,
)

# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only run it
# has no effect and PyTorch emits a warning, so enable it only for CUDA.
asr_dl = DataLoader(
    dataset=asr_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=(device == "cuda"),
    collate_fn=collate_padding_asr,
)

In [None]:
# Bookkeeping for the optimizer-update schedule.
dataset_size = len(asr_dataset)
accum = 8    # micro-batches accumulated per optimizer update
epochs = 5
# The training loop updates once every `accum` loader batches, i.e.
# len(asr_dl) // accum times per epoch. The loader (drop_last=False) rounds the
# final partial batch up, so dataset_size // (batch_size * accum) can undercount.
steps_per_epoch = len(asr_dl) // accum
T_max = epochs * steps_per_epoch
# 10% warm-up, matching pct_start=0.1 passed to OneCycleLR.
warmup = int(0.1 * T_max)
print("Dataset size: ", dataset_size)
print("Batch size: ", batch_size)
print("Steps per epoch: ", steps_per_epoch)
print("Tmax: ", T_max)
print("Warmup steps: ", warmup)

In [None]:
# The smoke-test cell below calls freeze_ssl() only *after* this cell, so
# without this call the requires_grad filter would still see any SSL weights
# that are trainable by default. They would land in the optimizer and in the
# printed count.
# NOTE(review): assumes freeze_ssl() is idempotent (it is called again later).
asr_model.freeze_ssl()

max_lr = 3e-4          # peak LR shared by AdamW and the one-cycle schedule
max_updates = 20_000   # keep in sync with train_asr's max_updates

optimizer = torch.optim.AdamW(
    [p for p in asr_model.parameters() if p.requires_grad],
    lr=max_lr,
    weight_decay=0.01,
)
# Check optimizer has params
n_optim_params = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
print(f"\nOptimizer managing {n_optim_params:,} parameters")

loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

# OneCycleLR raises if stepped more than total_steps times, and only finishes
# annealing if stepped exactly that many times. train_asr steps it once per
# `accum` micro-batches (len(asr_dl) // accum per epoch) and stops at max_updates.
total_steps = max(1, min(max_updates, epochs * (len(asr_dl) // accum)))
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=total_steps,
    pct_start=0.1,
    anneal_strategy='cos',
)
scaler = torch.GradScaler(device)

In [None]:
# Lexicon-free CTC beam-search decoder over the tokenizer vocabulary.
# NOTE(review): ctc_decoder expects `tokens` as a list of token strings (or a
# path to a tokens file); confirm tokenizer.get_vocab() returns that, not a dict.
beam_decoder = torchaudio.models.decoder.ctc_decoder(
    lexicon=None,
    tokens=vocab,
    blank_token='<blank>',
    # Silence / word-separator token. This literal previously read 'ред', which is
    # the UTF-8 bytes of the Devanagari danda '।' (U+0964) mis-decoded as cp866,
    # so it could never match a vocabulary entry.
    sil_token='\u0964',
    unk_word=None,
    nbest=1,
    beam_size=50,
)

In [None]:
def save_checkpoint(model, optimizer, scheduler, num_updates, path):
    """Save model, optimizer and scheduler state to ``path`` with ``torch.save``.

    The saved object is a dict with keys ``num_updates``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``scheduler_state_dict``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        num_updates: Update counter stored alongside the states.
        path: Destination file. Missing parent directories are created, because
            ``torch.save`` does not create them.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {
        'num_updates' : num_updates,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
        'scheduler_state_dict' : scheduler.state_dict()
    }

    torch.save(checkpoint, path)

In [None]:
def test_forward_pass_features_lstm(asr_model, val_dl, device):
    """Print per-stage feature shapes and effective ranks for one batch.

    Replays the ASR head stage by stage on the first batch of ``val_dl``:
    SSL ``encoder`` -> ``context`` -> ``multiscale`` (+ residual ``diversify``)
    -> packed ``lstm`` -> ``projection`` -> ``symbol``. For every stage except
    the final logits it prints the shape and the 95%-energy effective rank
    (via ``check_rank_single``) out of that stage's feature width.

    Args:
        asr_model: Model exposing the submodules named above.
        val_dl: Iterable of ``(waveforms, targets, input_lengths, target_lengths)``
            batches; only the first batch is used.
        device: Device the waveforms and input lengths are moved to.

    The model's train/eval mode is restored before returning, so calling this
    before training does not leave dropout silently disabled.
    """
    def _report(header, feats):
        # Rank is reported out of the real feature width (last dim) instead of a
        # hard-coded number that goes stale when hidden sizes change.
        print(f"\n{header}")
        print(f"   Shape: {feats.shape}")
        rank = check_rank_single(feats)
        print(f"   Effective rank (95%): {rank}/{feats.size(-1)}")

    was_training = asr_model.training
    asr_model.eval()
    try:
        batch = next(iter(val_dl))
        waveforms, targets, input_lengths, target_lengths = batch
        waveforms = waveforms.to(device)
        input_lengths = input_lengths.to(device)

        print(f"\n{'='*60}")
        print("BI-LSTM FEATURE ANALYSIS")
        print(f"{'='*60}")

        with torch.no_grad():
            # SSL features; transpose(1, 2) swaps dims 1 and 2 -- presumably to
            # (batch, time, feature) for the context network. TODO confirm.
            z = asr_model.encoder(waveforms).transpose(1, 2)
            c = asr_model.context(z)
            _report("1. SSL Context output:", c)

            # Multi-scale features plus the residual "diversify" branch.
            z_multi = asr_model.multiscale(c)
            z_div = z_multi + asr_model.diversify(z_multi)
            _report("2. After diversify:", z_div)

            # Pack so the recurrent layer skips padded frames.
            z_packed = nn.utils.rnn.pack_padded_sequence(
                z_div, input_lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = asr_model.lstm(z_packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
            _report("3. After Bi-LSTM:", lstm_out)

            z_proj = asr_model.projection(lstm_out)
            _report("4. After projection:", z_proj)

            logits = asr_model.symbol(z_proj)
            print("\n5. Final logits:")
            print(f"   Shape: {logits.shape}")
    finally:
        asr_model.train(was_training)

def check_rank_single(features):
    """Return the 95%-energy effective rank of a batch of feature vectors.

    All leading dimensions are flattened, so each row is one feature vector of
    width ``features.size(-1)``. The result is the smallest ``k`` such that the
    top ``k`` squared singular values hold at least 95% of their total.

    Args:
        features: Tensor whose last dimension is the feature dimension.

    Returns:
        int: Effective rank in ``[1, min(num_rows, feature_dim)]``, or 0 when the
        features carry no energy (all zeros, or no rows).
    """
    # torch.linalg SVD requires float32/float64 input, so upcast half-precision
    # features. Only singular values are needed, so skip computing U and V.
    features_flat = features.reshape(-1, features.size(-1)).detach().cpu().float()
    singular_values = torch.linalg.svdvals(features_flat)
    energy = singular_values ** 2
    total_energy = energy.sum()
    if total_energy == 0:
        # Avoid 0/0 -> NaN, which would otherwise be reported as rank 1.
        return 0
    explained = energy / total_energy
    return int((explained.cumsum(0) < 0.95).sum().item()) + 1

In [None]:
# Freeze the pre-trained SSL encoder so only the ASR head is trained.
asr_model.freeze_ssl()
asr_model.to(device)

# Parameter counts as reported by the model.
params = asr_model.get_num_params()
print(f"Total params: {params['total']:,}")
print(f"Trainable params: {params['trainable']:,}")

# Smoke-test one forward pass on a single batch.
waveforms, targets, input_lengths, target_lengths = test_batch = next(iter(asr_dl))
waveforms, input_lengths = waveforms.to(device), input_lengths.to(device)

with torch.no_grad():
    log_probs = asr_model(waveforms, input_lengths)
print(f"Output shape: {log_probs.shape}")  # (seq_len, batch, vocab)

# Per-stage effective-rank diagnostics.
test_forward_pass_features_lstm(asr_model, asr_dl, device)

In [None]:
def train_asr(asr_model, asr_dl, optimizer, scaler, scheduler, loss_fn, epochs, device,
              accum_steps=None, max_updates=20_000, save_every=1_000, save_dir=None):
    """Train ``asr_model`` with CTC loss, AMP and gradient accumulation.

    One optimizer + scheduler update is taken every ``accum_steps`` micro-batches
    within an epoch. A trailing partial group is discarded at the start of the
    next epoch. A checkpoint is written with ``save_checkpoint`` every
    ``save_every`` updates, and training stops after ``max_updates`` updates.

    Args:
        asr_model: Model called as ``asr_model(waveforms, input_lengths)``; its
            output is fed to ``loss_fn`` as log-probabilities.
        asr_dl: Loader yielding ``(waveforms, targets, input_lengths, target_lengths)``,
            where ``targets`` is a sequence of tensors (concatenated here).
        optimizer: Optimizer over the trainable parameters.
        scaler: GradScaler used for loss scaling / unscaling.
        scheduler: LR scheduler stepped once per optimizer update.
        loss_fn: CTC-style loss ``loss_fn(log_probs, targets, input_lengths, target_lengths)``.
        epochs: Maximum number of passes over ``asr_dl``.
        device: Target device; fp16 autocast is enabled only when it is CUDA.
        accum_steps: Micro-batches per update; defaults to the notebook-level ``accum``.
        max_updates: Stop once this many optimizer updates have been taken.
        save_every: Checkpoint interval, in optimizer updates.
        save_dir: Checkpoint directory; defaults to ``models/asr_model``.

    Returns:
        The trained ``asr_model`` (same object).
    """
    if accum_steps is None:
        accum_steps = accum  # notebook-level setting from the schedule cell
    if save_dir is None:
        save_dir = os.path.join('models', 'asr_model')
    os.makedirs(save_dir, exist_ok=True)

    # autocast needs a bare device type ("cuda"/"cpu"), not e.g. "cuda:0".
    # fp16 autocast targets CUDA; on CPU run in full precision.
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    num_updates = 0

    for epoch in range(epochs):
        print(f"Epoch [{epoch+1}/{epochs}]")
        total_loss = 0.0
        num_batches = 0

        asr_model.train()
        # Drop any partial accumulation left over from the previous epoch.
        optimizer.zero_grad(set_to_none=True)

        for i, batch in enumerate(tqdm(asr_dl)):
            waveforms, targets, input_lengths, target_lengths = batch
            waveforms = waveforms.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)
            flat_targets = torch.cat(targets).to(device)

            with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
                log_probs = asr_model(waveforms, input_lengths)
            # CTC is numerically sensitive; evaluate it in float32 on the
            # unmodified log-probabilities.
            loss = loss_fn(log_probs.float(), flat_targets, input_lengths, target_lengths)

            # Divide the *loss* (not the log-probs) so the accumulated gradient
            # is the mean over the micro-batches of one update.
            scaler.scale(loss / accum_steps).backward()
            total_loss += loss.item()
            num_batches += 1

            if (i + 1) % accum_steps != 0:
                continue

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(asr_model.parameters(), max_norm=2.0)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            num_updates += 1

            # Checked only right after an update, so each checkpoint is written
            # exactly once (and never for num_updates == 0).
            if num_updates % save_every == 0:
                save_path = os.path.join(save_dir, f'asr_model_prototype_{num_updates}.pth')
                save_checkpoint(asr_model, optimizer, scheduler, num_updates, save_path)

            if num_updates >= max_updates:
                break

        if device_type == "cuda":
            torch.cuda.empty_cache()
        # Average over batches actually processed (the epoch may end early).
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}")

        if num_updates >= max_updates:
            print(f"Reached max updates ({max_updates})")
            break

    return asr_model

In [None]:
# asr_model = train_asr(asr_model, asr_dl, optimizer, scaler, scheduler, loss_fn, epochs, device)