In [None]:
from src.requirements import *
from src.audio_handler import ASRDataset, collate_padding_asr, load_text, flatten_targets
from src.tokenizer import Tokenizer
from src.ssl_model import *
from src.asr_model import *

In [None]:
# Location of the combined text corpus.
text_path = os.path.join("data", "corpus.txt")

# The corpus is written only when it is missing, so re-running this cell
# never overwrites an existing corpus file.
if os.path.exists(text_path) is False:
    filename = "corpus.txt"
    path = os.path.join("data", "text")
    # load_text is a project helper; it is given the data/text path and its
    # return value is written out verbatim as the corpus.
    text = load_text(path)
    corpus_target = os.path.join("data", filename)
    with open(corpus_target, "w", encoding="utf-8") as f:
        f.write(text)

In [None]:
# Input files: dataset metadata and the serialized tokenizer.
data_path = os.path.join("data", "metadata.tsv")
token_path = os.path.join("data", "tokenizer.json")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Training hyperparameters.
batch_size, epochs, learning_rate = 16, 999, 1e-4

# Update count that selects which SSL checkpoint file is loaded; see the
# models/ directory for the checkpoint versions that exist.
update_ver = 160_000

In [None]:
# Instantiate the SSL model on the active device and restore the weights
# saved in the checkpoint selected by `update_ver`.
ssl_model = SSLModel().to(device)
ssl_checkpoint_path = os.path.join("models", "ssl_model", f"ssl_model_prototype_{update_ver}.pth")
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a CUDA run can still be loaded on a CPU-only machine.
ssl_model.load_state_dict(torch.load(ssl_checkpoint_path, map_location=device))

In [None]:
# Reuse the serialized tokenizer when it is already on disk; otherwise
# construct one from the corpus file and save it for later runs.
if os.path.exists(token_path):
    tokenizer = Tokenizer.load(token_path)
else:
    tokenizer = Tokenizer(text_path)
    tokenizer.save(token_path)

# Number of entries in the tokenizer vocabulary.
num_classes = len(tokenizer.vocab)

In [None]:
# Build the ASR model around the SSL model and move it to the active device.
# NOTE(review): vocab_size is one less than the tokenizer vocabulary size --
# presumably ASRModel accounts for the CTC blank itself; confirm in src.asr_model.
asr_model = ASRModel(ssl_model, vocab_size=num_classes - 1)
asr_model = asr_model.to(device)

# Adam over every parameter exposed by the ASR model.
asr_optimizer = torch.optim.Adam(params=asr_model.parameters(), lr=learning_rate)

# CTC loss with index 0 as the blank label; zero_infinity=True replaces
# infinite losses (and their gradients) with zero.
ctc_loss = nn.CTCLoss(zero_infinity=True, blank=0)

In [None]:
# Dataset built from the metadata file and the tokenizer.
asr_dataset = ASRDataset(tokenizer=tokenizer, metadata_path=data_path)

# Shuffled loader; collate_padding_asr assembles each batch and pinned host
# memory is requested for the returned tensors.
asr_dl = DataLoader(
    asr_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_padding_asr,
    pin_memory=True,
)

In [None]:
# Spot check: display the vocabulary entry stored at index 20.
tokenizer.vocab[20]

In [None]:
def train_asr(asr_model, asr_dl, optimizer, loss_fn, epochs, device,
              max_updates=150_000, downsampling_factor=5 * 4 * 4 * 4,
              checkpoint_every=10_000, checkpoint_dir=None, max_grad_norm=2.0):
    """Train the ASR model, checkpointing periodically.

    Training stops after `epochs` passes over `asr_dl` or once `max_updates`
    optimizer steps have been taken, whichever comes first.

    Args:
        asr_model: Model called as `asr_model(waveforms)`; its output is
            transposed on dims 0 and 1 before being handed to `loss_fn`
            (assumes the model returns (batch, time, classes) -- TODO confirm).
        asr_dl: Iterable of batches, each unpacking to
            `(waveforms, targets, input_lengths, target_lengths)`.
        optimizer: Optimizer stepping the parameters of `asr_model`.
        loss_fn: Callable invoked as
            `loss_fn(log_probs, flat_targets, input_lengths, target_lengths)`.
        epochs: Maximum number of passes over `asr_dl`.
        device: Device the batch tensors are moved to.
        max_updates: Optimizer-step budget after which training stops.
        downsampling_factor: Divisor applied to `input_lengths` before the
            loss (presumably the model's total stride -- verify against the
            encoder definition).
        checkpoint_every: Save the model state dict every this many updates.
        checkpoint_dir: Directory for checkpoints; defaults to
            `models/asr_model`. Created on demand.
        max_grad_norm: Maximum gradient norm used for clipping.

    Raises:
        ValueError: If any target id is negative or not smaller than the
            size of the last dimension of the model output.
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.path.join("models", "asr_model")

    num_updates = 0
    asr_model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        # Batches that actually contributed to total_loss; skipped batches and
        # an early stop on max_updates must not dilute the epoch average.
        batches_used = 0
        print(f"Epoch [{epoch+1}/{epochs}]")

        for batch in tqdm(asr_dl):
            waveforms, targets, input_lengths, target_lengths = batch
            waveforms = waveforms.to(device)

            targets = targets.to(device)
            # Scale the input lengths down by the downsampling factor so they
            # can be compared with the target lengths and passed to the loss.
            input_lengths = input_lengths // downsampling_factor
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            optimizer.zero_grad()

            log_probs = asr_model(waveforms)
            log_probs = log_probs.transpose(0, 1)

            # flatten_targets is a project helper; presumably it concatenates
            # the unpadded target sequences into one 1-D tensor -- verify.
            flat_targets = flatten_targets(targets, target_lengths).to(device)

            if (flat_targets < 0).any() or (flat_targets >= log_probs.size(2)).any():
                raise ValueError(f"Target IDs out of range=> min={flat_targets.min()}, max={flat_targets.max()}, num_classes={log_probs.size(2)}")
            # A batch is skipped when any sample has fewer (scaled) input
            # steps than target tokens.
            if (input_lengths < target_lengths).any():
                print("Skipping batch: input_lengths < target_lengths")
                continue

            loss = loss_fn(log_probs, flat_targets, input_lengths, target_lengths)

            loss.backward()
            # gradient clipping to prevent gradient explosion (i.e. getting loss = nan)
            torch.nn.utils.clip_grad_norm_(asr_model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
            num_updates += 1
            batches_used += 1
            total_loss += loss.item()

            if num_updates % checkpoint_every == 0:
                # Make sure the target directory exists so a missing folder
                # cannot abort a long run at the first checkpoint.
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(asr_model.state_dict(), os.path.join(checkpoint_dir, f"asr_model_prototype_{num_updates}.pth"))

            if num_updates >= max_updates:
                break

        torch.cuda.empty_cache()
        # Average over the batches that produced a loss; NaN when none did
        # (empty loader or every batch skipped) instead of dividing by zero.
        avg_loss = total_loss / batches_used if batches_used else float("nan")
        print(f"Avg Loss: {avg_loss:.4f}")

        if num_updates >= max_updates:
            break

In [None]:
train_asr(asr_model, asr_dl, asr_optimizer, ctc_loss, epochs, device)