In [None]:
from src.requirements import *
from src.audio_handler import ASRDataset, Tokenizer, collate_padding_asr, load_text
from src.ssl_model import *
from src.asr_model import *

In [None]:
# Build the plain-text corpus used by the tokenizer once; later runs reuse the file on disk.
text_path = os.path.join("data", "corpus.txt")
if not os.path.exists(text_path):
    corpus_text = load_text(os.path.join("data", "text"))
    # text_path is exactly data/corpus.txt, so write straight to it.
    with open(text_path, "w", encoding="utf-8") as corpus_file:
        corpus_file.write(corpus_text)

In [None]:
# Character inventory of the training transcripts, used to size the ASR output layer.
df = pd.read_csv(os.path.join("data", "metadata.tsv"), sep="\t")
# Drop missing transcripts so "".join never receives a float NaN.
transcripts = df["transcript"].dropna().astype(str).tolist()
all_chars = set("".join(transcripts))
# sorted() gives a run-to-run deterministic order; plain set iteration order of
# str values changes between interpreter runs (hash randomization).
unique_vocabs = sorted(all_chars)
# The training setup uses nn.CTCLoss(blank=0), so the output layer needs one
# extra class for the CTC blank on top of the characters. An unused extra class
# is harmless if the model/tokenizer already account for it.
# NOTE(review): the Tokenizer is built from data/corpus.txt, not these
# transcripts; ideally vocab_size should be derived from the tokenizer itself
# (confirm its API) so output classes and target indices cannot drift apart.
vocab_size = len(unique_vocabs) + 1

In [None]:
# Training configuration plus model, tokenizer, optimizer, loss and data setup.
data_path = os.path.join("data", "metadata.tsv")
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 2
epochs = 10
learning_rate = 1e-4
seed = 42

# Seed torch so head initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(seed)

ssl_model = SSLAutoregressiveModel()
ssl_model.to(device)
print("SSl Model created")
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa) instead of failing inside torch.load.
ssl_model.load_state_dict(
    torch.load(os.path.join("models", "model_prototype_II.pth"), map_location=device)
)
print("Encoder parameters loaded")

tokenizer = Tokenizer(text_path)
print("Tokenizer created")
asr_model = ASRModel(ssl_model, vocab_size)
asr_model.to(device)
print("ASR model created")
asr_optimizer = torch.optim.AdamW(asr_model.parameters(), lr=learning_rate)
# Index 0 is the CTC blank; zero_infinity guards against impossible alignments.
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

asr_dataset = ASRDataset(metadata_path=data_path, tokenizer=tokenizer)
print("ASR dataset loaded")
asr_dl = DataLoader(
    dataset=asr_dataset,
    batch_size=batch_size,
    # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
    pin_memory=(device == "cuda"),
    collate_fn=collate_padding_asr,
    shuffle=True,
)
print("ASR dataloader created")

In [None]:
# Sanity check: pull the first batch (if any) and show the padded target shape.
first_batch = next(iter(asr_dl), None)
if first_batch is not None:
    _, first_targets, _, _ = first_batch
    print(first_targets.shape)

In [None]:
def train_asr(asr_model, asr_dl, optimizer, loss_fn, epochs, device):
    """Train ``asr_model`` with a CTC-style loss for a fixed number of epochs.

    Args:
        asr_model: module whose forward maps a batch of waveforms to per-frame
            scores of shape (B, T, C). These are presumably log-probabilities,
            as ``nn.CTCLoss`` expects (TODO confirm).
        asr_dl: iterable of ``(waveforms, targets, input_lengths, target_lengths)``
            batches. It does not need ``__len__``.
        optimizer: optimizer over ``asr_model``'s parameters.
        loss_fn: callable ``loss_fn(log_probs_TBC, targets, input_lengths,
            target_lengths)``, e.g. ``nn.CTCLoss``.
        epochs: number of full passes over ``asr_dl``.
        device: device that waveforms and targets are moved to.

    Returns:
        list[float]: average training loss for each epoch. The value is ``nan``
        for an epoch in which ``asr_dl`` yielded no batches.
    """
    asr_model.train()
    epoch_losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        print(f"Epoch [{epoch+1}/{epochs}]")

        for waveforms, targets, input_lengths, target_lengths in tqdm(asr_dl):
            waveforms = waveforms.to(device)
            targets = targets.to(device)
            # Lengths stay on the host; nn.CTCLoss accepts CPU length tensors.
            # NOTE(review): CTC requires each input_lengths entry to be <= T, the
            # model's output frame count (log_probs.size(0) below). Confirm that
            # collate_padding_asr returns frame counts, not raw sample counts.

            optimizer.zero_grad()

            # (B, T, C) -> (T, B, C), the layout CTC loss expects.
            log_probs = asr_model(waveforms).transpose(0, 1)

            loss = loss_fn(log_probs, targets, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader, which would otherwise divide by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Avg Loss: {avg_loss:.4f}")

    return epoch_losses

In [None]:
# Run the CTC fine-tuning loop with the objects configured in the setup cell.
train_asr(
    asr_model=asr_model,
    asr_dl=asr_dl,
    optimizer=asr_optimizer,
    loss_fn=ctc_loss,
    epochs=epochs,
    device=device,
)