In [None]:
from src.requirements import *
from src.audio_handler import ASRDataset, collate_padding_asr, load_text
from src.tokenizer import Tokenizer
from src.models import FeatureEncoder, ContextModule, ContrastivePredictor, SSLModel, ASRModel, compute_mask_indices, flatten_targets

In [None]:
text_path = os.path.join("data", "corpus.txt")

# Build the plain-text corpus from data/text only once; later runs reuse the file on disk.
if not os.path.exists(text_path):
    raw_text_dir = os.path.join("data", "text")
    corpus_text = load_text(raw_text_dir)
    with open(text_path, "w", encoding="utf-8") as corpus_file:
        corpus_file.write(corpus_text)

In [None]:
data_path = os.path.join("data", "metadata.tsv")
token_path = os.path.join("data", "tokenizer.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 4
epochs = 10
learning_rate = 1e-4
weight_decay = 1e-6

# Which pretrained SSL checkpoint to fine-tune from: the number in its filename
# (models/ssl_model/ssl_model_prototype_<update_ver>.pth). See that directory for available versions.
update_ver = 50_000

# Seed torch's RNGs (CPU and all CUDA devices) so DataLoader shuffling and other
# torch-based randomness are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [None]:
# Load the pretrained self-supervised model selected by `update_ver`.
ssl_checkpoint_path = os.path.join('models', 'ssl_model', f'ssl_model_prototype_{update_ver}.pth')
# map_location remaps stored tensors onto `device`, so a checkpoint saved on GPU
# can still be restored when `device` fell back to "cpu".
checkpoint_dict = torch.load(ssl_checkpoint_path, map_location=device)
ssl_state_dict = checkpoint_dict['model_state_dict']
ssl_model = SSLModel().to(device)
ssl_model.load_state_dict(ssl_state_dict, strict=True)

In [None]:
# Reuse a saved tokenizer when present; otherwise build it from the corpus and save it.
if os.path.exists(token_path):
    tokenizer = Tokenizer.load(token_path)
else:
    tokenizer = Tokenizer(text_path)
    tokenizer.save(token_path)

num_classes = len(tokenizer.vocab)

In [None]:
# ASR head on top of the pretrained SSL encoder.
# NOTE(review): vocab_size is num_classes - 1 while CTC uses blank=0; presumably
# ASRModel adds the blank class itself. Confirm against ASRModel.
asr_model = ASRModel(ssl_model, vocab_size=num_classes-1).to(device)
asr_optimizer = torch.optim.Adam(asr_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
# CosineAnnealingLR's third positional argument is eta_min. Passing learning_rate there
# made eta_min equal to the base LR, so the schedule never changed the LR. Anneal to 0 over `epochs`.
scheduler = CosineAnnealingLR(asr_optimizer, T_max=epochs, eta_min=0.0)

In [None]:
asr_dataset = ASRDataset(metadata_path=data_path, tokenizer=tokenizer)

# Shuffled mini-batches; collate_padding_asr presumably pads variable-length samples per batch.
asr_dl = DataLoader(
    asr_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_padding_asr,
    pin_memory=True,
)

In [None]:
def save_checkpoint(model, optimizer, scheduler, num_updates, path):
    """Write a training checkpoint to `path` with torch.save.

    The saved dict contains `num_updates` plus the state dicts of the model,
    optimizer and LR scheduler. Missing parent directories of `path` are
    created first.

    Args:
        model: module whose `state_dict()` is saved.
        optimizer: optimizer whose `state_dict()` is saved.
        scheduler: LR scheduler whose `state_dict()` is saved.
        num_updates: number of optimizer steps taken so far.
        path: destination file path.
    """
    import os

    directory = os.path.dirname(path)
    if directory:
        # torch.save does not create directories; without this the first write
        # into e.g. models/asr_model/ fails if that folder does not exist yet.
        os.makedirs(directory, exist_ok=True)

    checkpoint = {
        'num_updates': num_updates,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }

    torch.save(checkpoint, path)

In [None]:
def train_asr(asr_model, asr_dl, optimizer, loss_fn, epochs, device,
              max_updates=150_000, save_every=10_000):
    """Fine-tune the ASR model with a CTC-style loss.

    Runs up to `epochs` passes over `asr_dl`, stopping early once `max_updates`
    optimizer steps have been taken. Every `save_every` steps a checkpoint is
    written to models/asr_model/asr_model_prototype_<num_updates>.pth.
    Prints the mean loss of each epoch.

    NOTE(review): this reads the notebook-global `scheduler`. It is stepped once
    per epoch and stored in checkpoints rather than being passed in as an argument.

    Args:
        asr_model: model mapping a waveform batch to per-frame class scores; its
            output is transposed (0, 1) before the loss, so presumably it returns
            (batch, time, classes). Confirm.
        asr_dl: iterable of (waveforms, targets, input_lengths, target_lengths).
        optimizer: optimizer over asr_model's parameters.
        loss_fn: called as loss_fn(log_probs, flat_targets, input_lengths, target_lengths).
        epochs: maximum number of passes over asr_dl.
        device: device every batch tensor is moved to.
        max_updates: total optimizer-step budget across all epochs.
        save_every: checkpoint interval in optimizer steps; 0 or None disables saving.
    """
    # DOWNSAMPLING_FACTOR = 5 * 4 * 4 * 4
    num_updates = 0
    asr_model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        print(f"Epoch [{epoch+1}/{epochs}]")

        for batch in tqdm(asr_dl):
            waveforms, targets, input_lengths, target_lengths = batch
            waveforms = waveforms.to(device)
            targets = targets.to(device)
            # NOTE(review): CTC needs input_lengths in output frames (<= the logits'
            # time dimension), not raw samples. This assumes collate_padding_asr already
            # accounts for the encoder's downsampling. Confirm.
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            optimizer.zero_grad()

            logits = asr_model(waveforms)
            # CTC loss expects log-probabilities. log_softmax is idempotent, so this is
            # correct whether the model returns raw logits or already-normalised log-probs.
            log_probs = logits.log_softmax(dim=-1).transpose(0, 1)

            flat_targets = flatten_targets(targets, target_lengths).to(device)

            loss = loss_fn(log_probs, flat_targets, input_lengths, target_lengths)
            loss.backward()
            total_loss += loss.item()
            num_batches += 1

            torch.nn.utils.clip_grad_norm_(asr_model.parameters(), max_norm=5.0)
            optimizer.step()
            num_updates += 1

            if save_every and num_updates % save_every == 0:
                save_path = os.path.join('models', 'asr_model', f'asr_model_prototype_{num_updates}.pth')
                save_checkpoint(asr_model, optimizer, scheduler, num_updates, save_path)

            if num_updates >= max_updates:
                break

        scheduler.step()
        torch.cuda.empty_cache()
        # Average over the batches actually processed: the epoch may have stopped
        # early at max_updates, so len(asr_dl) would understate the mean.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Avg Loss: {avg_loss:.4f}")

        if num_updates >= max_updates:
            break

In [None]:
# Launch fine-tuning (checkpoints are written under models/asr_model/).
train_asr(
    asr_model=asr_model,
    asr_dl=asr_dl,
    optimizer=asr_optimizer,
    loss_fn=ctc_loss,
    epochs=epochs,
    device=device,
)