# IMPORT

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)


# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

device(type='cuda')

# STORAGE

In [3]:
!df -h /kaggle/working

Filesystem      Size  Used Avail Use% Mounted on
/dev/loop1       20G  156K   20G   1% /kaggle/working


# DATA

## dataset memory size

In [None]:
from huggingface_hub import HfApi
from huggingface_hub import login

login("")

# print out total memory size
def print_dataset_file_sizes(repo_id):
    api = HfApi()
    dataset_info = api.dataset_info(repo_id=repo_id, files_metadata=True)

    total_size_bytes = 0  
    print(f"File sizes for dataset '{repo_id}/en':")  
    for sibling in dataset_info.siblings:  
        #filename = sibling.rfilename  
        #print(filename.split('/'))
        #filename_array = filename.split('/')
        #if len(filename_array) > 1 and filename_array[1] == 'en':
        size_in_bytes = sibling.size or 0  
        total_size_bytes += size_in_bytes  
        #size_mb = size_in_bytes / (1024 * 1024)  
        #print(f"  {filename}: {size_mb:.2f} MiB")
        """else:
            size_in_bytes = sibling.size or 0  
            total_size_bytes = size_in_bytes"""

    total_size_gb = total_size_bytes / (1024 ** 3)  
    print(f"\nTotal size: {total_size_gb:.2f} GiB")

In [5]:
print_dataset_file_sizes('mozilla-foundation/common_voice_17_0')

File sizes for dataset 'mozilla-foundation/common_voice_17_0/en':

Total size: 967.10 GiB


## dataset loading

In [6]:
from datasets import load_dataset

ds_train = load_dataset(
    "mozilla-foundation/common_voice_17_0", 'en', split="train", 
    streaming=True, trust_remote_code=True)

README.md:   0%|          | 0.00/12.7k [00:00<?, ?B/s]

common_voice_17_0.py:   0%|          | 0.00/8.19k [00:00<?, ?B/s]

languages.py:   0%|          | 0.00/3.92k [00:00<?, ?B/s]

release_stats.py:   0%|          | 0.00/132k [00:00<?, ?B/s]

## dataset preprocess

In [7]:
%%capture
import torchaudio.transforms as T
import librosa
from transformers import Wav2Vec2Processor
from transformers.utils import move_cache

move_cache()

TARGET_SR = 16000

mfcc_transform = T.MFCC(
    sample_rate=TARGET_SR, 
    n_mfcc=128,
    melkwargs={ 
        'n_fft':1024,
        'hop_length':160,
        'n_mels':128,
    },
)
facebook_proc = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-large-960h')

def preprocess_ds(example):
    audio_array = example['audio']['array']
    sample_rate = example['audio']['sampling_rate']
    if sample_rate != TARGET_SR:
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=TARGET_SR)
    audio_tensor = torch.tensor(audio_array).unsqueeze(0).float()
    audio_shape = audio_tensor.shape
    if len(audio_shape) != 2 or audio_shape[0] != 1:
        raise ValueError(
            f'Unexpected audio shape: {audio_shape} -> Expected audio shape: (1, num_samples)'
        )
    mfcc_features = mfcc_transform(audio_tensor).squeeze(0)
    mel_dim_shape = mfcc_features.shape[0]
    if mel_dim_shape != 128:
        raise ValueError(
            f'Unexpected mel shape: {mel_dim_shape} -> Expected mel shape: (128, time_steps)'
        )
    sent_tok = facebook_proc.tokenizer(example['sentence']).input_ids
    return {
        'mfcc_features': mfcc_features,
        'sent_tok': torch.tensor(sent_tok),
    }

In [8]:
ds_train_processed = ds_train.map(preprocess_ds)

## dataset vocabulary

In [9]:
vocab = facebook_proc.tokenizer.get_vocab()

print("Number of classes (vocabulary size): ", len(vocab))
print("Samples words (first 50 words): ", list(vocab.keys())[:50])

Number of classes (vocabulary size):  32
Samples words (first 50 words):  ['<pad>', '<s>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']


## dataloader settings

In [10]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from datasets.utils.logging import disable_progress_bar

disable_progress_bar()

def collate_fn(batch):
    
    mfcc_batch = []
    sent_tok_batch = []
    
    for item in batch:
        mfcc_batch.append((item['mfcc_features']).permute(1, 0).clone().detach())
        sent_tok_batch.append((item['sent_tok']).clone().detach())
        
    mfcc_length = torch.tensor([v.shape[0] for v in mfcc_batch], dtype=torch.long)
    tok_length = torch.tensor([len(v) for v in sent_tok_batch], dtype=torch.long)
    
    mfcc_padded = pad_sequence(mfcc_batch, batch_first=True, padding_value=0).permute(0, 2, 1)
    sent_tok_padded = pad_sequence(sent_tok_batch, batch_first=True, padding_value=0)
    
    return mfcc_padded, sent_tok_padded, mfcc_length, tok_length

In [11]:
data_loader = DataLoader(ds_train_processed, batch_size=8, collate_fn=collate_fn)

## dataloader check

In [12]:
%%capture
"""
# iterate over one batch
for batch_idx, (mfcc_padded, sent_tok_padded, mfcc_length, tok_length) in enumerate(data_loader):
    print(f"BATCH NUMBER {batch_idx + 1}.")
    
    # inspect the shape of the mel spectrogram and tokenized sentences
    print(f"-> mfcc_padded, {mfcc_padded.shape}")
    print(f"-> sent_tok_padded, {sent_tok_padded.shape}")

    # inspect each element shape
    print(f'-> mfcc_length, {mfcc_length}')
    print(f'-> tok_length, {tok_length}')

    # inspect the first sample (if batch_size > 1)
    print(f"-> mfcc first sample, {mfcc_padded[0]}")
    print(f"-> tokens first sample, {sent_tok_padded[0]}")

    break
"""

 # NETWORK

## model

In [13]:
import torch.nn as nn
import math

class ResidualBlock(nn.Module):
    """dilated residual block"""
    def __init__(self, in_channels, out_channels, kernel_size=7, dilation=1):
        super().__init__()
        self.conv_filter = nn.Conv1d(
            in_channels, 
            out_channels, 
            kernel_size=kernel_size, 
            padding=((kernel_size-1) * dilation // 2),
            dilation=dilation,
        )
        self.conv_gate = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=((kernel_size-1) * dilation // 2),
            dilation=dilation
        )
        self.conv_out = nn.Conv1d(
            out_channels, 
            out_channels, 
            kernel_size=1
        )
        self.bn_filter = nn.BatchNorm1d(out_channels)
        self.bn_gate = nn.BatchNorm1d(out_channels)
        self.bn_out = nn.BatchNorm1d(out_channels)
    def forward(self, x):
        filter_out = torch.tanh(self.bn_filter(self.conv_filter(x)))
        gate_out = torch.sigmoid(self.bn_gate(self.conv_gate(x)))
        out = filter_out * gate_out # element wise multiplication
        out = torch.tanh(self.bn_out(self.conv_out(out)))
        return x + out, out # residual and skip connection

class FeatureEncoder(nn.Module):
    """cnn to extract mfcc features and incorporated residual blocks"""
    def __init__(self, in_channels=128, out_channels=512, num_blocks=3):
        super().__init__()
        self.front_conv = nn.Sequential(
            nn.Conv1d(in_channels, 128, kernel_size=1),
            nn.BatchNorm1d(128),
            nn.Tanh(),
        )
        self.num_blocks = num_blocks
        self.res_blocks = nn.ModuleList()
        for i in range(num_blocks):
            for ratio in [1, 2, 4, 8, 16]:
                self.res_blocks.append(ResidualBlock(128, 128, kernel_size=7, dilation=ratio))
        self.final_conv = nn.Sequential(
            nn.Conv1d(128, out_channels, kernel_size=1),
            nn.BatchNorm1d(out_channels),
            nn.Tanh()
        )
    def forward(self, x):
        x = self.front_conv(x)
        skip_connection = 0
        for res_block in self.res_blocks:
            x, skip = res_block(x)
            skip_connection += skip
        return self.final_conv(skip_connection)

class PositionalEncoder(nn.Module):
    """sin and cos positional embedding"""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2)) * (-math.log(10000.0)/d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:x.size(0), :, :]
        return self.dropout(x)

class TransformerEncoder(nn.Module):
    """transformer encoder"""
    def __init__(self, input_dim=512, num_head=8, ff_dim=2048, n_layers=12, dropout=0.1):
        super().__init__()
        self.pos_encoder = PositionalEncoder(d_model=input_dim, dropout=dropout)
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_head,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
            activation=nn.GELU(),
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=n_layers)
    def forward(self, x, src_key_padding_mask=None):
        x = x.permute(0, 2, 1)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
        return x

class SpeechTextNeural(nn.Module):
    """neural"""
    def __init__(
        self, 
        vocab_size, 
        input_dim=512, 
        ff_dim=2048, 
        transformer_layers=12, 
        transformer_heads=8):
        super().__init__()
        self.encoder = FeatureEncoder()
        self.transformer = TransformerEncoder(
            input_dim, 
            transformer_heads, 
            ff_dim, 
            transformer_layers
        )
        self.fc1 = nn.Linear(input_dim, vocab_size)
    def forward(self, x):
        #print(f'before encoder: {x.shape}')
        x = self.encoder(x)
        #print(f'after encoder: {x.shape}')
        #print(f'before transformer: {x.shape}')
        x = self.transformer(x)
        #print(f'after transformer: {x.shape}')
        #print(f'before full connector: {x.shape}')
        x = self.fc1(x)
        #print(f'after full connector: {x.shape}')
        return x.permute(1, 0, 2)

In [14]:
vocab_length = len(facebook_proc.tokenizer.get_vocab())
model = SpeechTextNeural(vocab_length)

In [15]:
%%capture
"""# test one batch and one prediction
mfcc, tokens, mfcc_lengths, token_lengths = next(iter(data_loader))
print(f"TOTAL MFCC: {mfcc.shape}, TOTAL TOKENS: {tokens.shape}")
print(f"LIST OF MFCC: {mfcc_lengths}, LIST OF TOKENS: {token_lengths}")

logits = model(mfcc)
print(f"PREDICTION: {logits.shape}")  # should be [16, time, vocab_size]"""

In [16]:
%%capture
"""# test training
from torchinfo import summary

batch_size=16
mfcc_features = 128
samples = 16_000

summary(model, input_size=(batch_size, mfcc_features, samples))"""

# TRAIN

## save model

In [17]:
from datetime import datetime

def save_model(path_v, epoch, b_idx, model, optimizer, training_losses):
    
    os.makedirs(path_v, exist_ok=True)
    existing_files = [
        f for f in os.listdir(path_v) if f.endswith('.pth') and f.startswith('model')]
    versions = []
    
    for file in existing_files:
        filename = file.replace('.pth', '').split('_')
        if len(filename) > 1 and filename[-1].isdigit():
            versions.append(int(filename[-1]))
    new_version = max(versions)+1 if versions else 1

    checkpoint = {
        'epoch': epoch,
        'batch_index': b_idx,
        'time': datetime.now(),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'training_losses': training_losses,
    }
    
    new_path = os.path.join(path_v, f'model_100percent_{new_version}.pth')
    torch.save(checkpoint, new_path)
    print(f'MODEL SAVED AT {new_path}')

## training setting

In [18]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

## start training

In [19]:
from itertools import islice
import time
from datetime import datetime

def train(model, dataloader, epochs=10):
    
    model.train().to(device)
    training_losses = []
    
    for epoch in range(epochs):

        es_time = time.perf_counter()
        
        total_loss = 0.
        total_accuracies = 0.
        total_pred_label = 0.
        b_idx = 0
        
        # calculate the number of batches to process (5% of the dataset)
        num_batches = int(
            0.05 * len(dataloader.dataset)) if hasattr(dataloader.dataset, '__len__') else 6000
        
        for idx, (mfcc, tok, input_length, label_length) in enumerate(islice(dataloader, num_batches)):
            
            b_idx = idx
            bs_time = time.perf_counter()
            
            optimizer.zero_grad()
            outputs = model(mfcc.to(device))
            log_probs = outputs.log_softmax(2)
            loss = criterion(
                log_probs, 
                tok.to(device), 
                input_length.to(device), 
                label_length.to(device)
            )
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            total_loss += batch_loss
            training_losses.append(batch_loss)
            
            be_time = time.perf_counter() - bs_time
            
            if idx % 1000 == 0:
                print(f'date time {datetime.now()} --------------------------')
                print(f'epoch: {epoch}, batch -> {idx}, loss: {batch_loss:.4f}, batch_time: {be_time:.4f}s')
            elif idx % 2999 == 0:
                save_model('model', epoch, b_idx, model, optimizer, training_losses)
                
        ee_time = time.perf_counter() - es_time
        print(f'EPOCH {epoch+1}, LOSS: {total_loss:.4f}, EPOCH_TIME: {ee_time}.')
        
        save_model('model', epoch, b_idx, model, optimizer, training_losses)

In [20]:
train(model, data_loader, 5)

Reading metadata...: 1101170it [00:26, 41168.39it/s]


date time 2025-03-31 00:37:42.753729 --------------------------
epoch: 0, batch -> 0, loss: 49.0111, batch_time: 1.7898s
date time 2025-03-31 00:51:08.279279 --------------------------
epoch: 0, batch -> 1000, loss: 0.7722, batch_time: 0.6048s
date time 2025-03-31 01:04:46.167335 --------------------------
epoch: 0, batch -> 2000, loss: 0.8102, batch_time: 0.7257s
MODEL SAVED AT model/model_100percent_1.pth
date time 2025-03-31 01:18:20.530048 --------------------------
epoch: 0, batch -> 3000, loss: 0.7547, batch_time: 0.6620s
date time 2025-03-31 01:31:44.099249 --------------------------
epoch: 0, batch -> 4000, loss: 0.9027, batch_time: 0.8022s


Reading metadata...: 1101170it [00:25, 43870.39it/s]


date time 2025-03-31 01:45:30.431434 --------------------------
epoch: 0, batch -> 5000, loss: 0.8703, batch_time: 0.6552s


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 7570d8a6-cea0-4d85-8256-0daf036a2481)')' thrown while requesting GET https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0/resolve/main/audio/en/train/en_train_1.tar
Retrying in 1s [Retry 1/5].


MODEL SAVED AT model/model_100percent_2.pth
EPOCH 1, LOSS: 5376.7031, EPOCH_TIME: 4921.242153649.
MODEL SAVED AT model/model_100percent_3.pth


Reading metadata...: 1101170it [00:25, 42603.10it/s]


date time 2025-03-31 01:59:29.933775 --------------------------
epoch: 1, batch -> 0, loss: 1.1387, batch_time: 0.5985s
date time 2025-03-31 02:13:00.559405 --------------------------
epoch: 1, batch -> 1000, loss: 0.7765, batch_time: 0.6078s
date time 2025-03-31 02:26:31.334407 --------------------------
epoch: 1, batch -> 2000, loss: 0.7918, batch_time: 0.7281s
MODEL SAVED AT model/model_100percent_4.pth
date time 2025-03-31 02:40:03.374124 --------------------------
epoch: 1, batch -> 3000, loss: 0.7872, batch_time: 0.6640s
date time 2025-03-31 02:53:27.121006 --------------------------
epoch: 1, batch -> 4000, loss: 0.8914, batch_time: 0.8041s


Reading metadata...: 1101170it [00:26, 40784.14it/s]


date time 2025-03-31 03:07:14.184570 --------------------------
epoch: 1, batch -> 5000, loss: 0.8443, batch_time: 0.6637s
MODEL SAVED AT model/model_100percent_5.pth
EPOCH 2, LOSS: 5197.1465, EPOCH_TIME: 4888.579419648.
MODEL SAVED AT model/model_100percent_6.pth


Reading metadata...: 1101170it [00:24, 44803.88it/s]


date time 2025-03-31 03:20:57.943917 --------------------------
epoch: 2, batch -> 0, loss: 1.1711, batch_time: 0.5944s
date time 2025-03-31 03:34:21.751482 --------------------------
epoch: 2, batch -> 1000, loss: 0.7758, batch_time: 0.6077s
date time 2025-03-31 03:47:46.016159 --------------------------
epoch: 2, batch -> 2000, loss: 0.7791, batch_time: 0.7286s
MODEL SAVED AT model/model_100percent_7.pth
date time 2025-03-31 04:01:08.586308 --------------------------
epoch: 2, batch -> 3000, loss: 0.7702, batch_time: 0.6638s
date time 2025-03-31 04:14:21.267208 --------------------------
epoch: 2, batch -> 4000, loss: 0.8848, batch_time: 0.8042s


Reading metadata...: 1101170it [00:25, 42870.35it/s]


date time 2025-03-31 04:28:06.110165 --------------------------
epoch: 2, batch -> 5000, loss: 0.8438, batch_time: 0.6589s
MODEL SAVED AT model/model_100percent_8.pth
EPOCH 3, LOSS: 5191.3586, EPOCH_TIME: 4845.234505643.
MODEL SAVED AT model/model_100percent_9.pth


Reading metadata...: 1101170it [00:25, 43029.21it/s]


date time 2025-03-31 04:41:45.172276 --------------------------
epoch: 3, batch -> 0, loss: 1.1699, batch_time: 0.5921s
date time 2025-03-31 04:55:08.776427 --------------------------
epoch: 3, batch -> 1000, loss: 0.7770, batch_time: 0.6078s
date time 2025-03-31 05:08:40.208184 --------------------------
epoch: 3, batch -> 2000, loss: 0.7653, batch_time: 0.7278s
MODEL SAVED AT model/model_100percent_10.pth
date time 2025-03-31 05:22:10.965845 --------------------------
epoch: 3, batch -> 3000, loss: 0.7482, batch_time: 0.6638s
date time 2025-03-31 05:35:29.573890 --------------------------
epoch: 3, batch -> 4000, loss: 0.8831, batch_time: 0.8039s


Reading metadata...: 1101170it [00:25, 42583.72it/s]


date time 2025-03-31 05:49:15.550962 --------------------------
epoch: 3, batch -> 5000, loss: 0.8374, batch_time: 0.6577s
MODEL SAVED AT model/model_100percent_11.pth
EPOCH 4, LOSS: 5188.2343, EPOCH_TIME: 4868.969827072002.
MODEL SAVED AT model/model_100percent_12.pth


Reading metadata...: 1101170it [00:29, 37415.30it/s]


date time 2025-03-31 06:02:58.710008 --------------------------
epoch: 4, batch -> 0, loss: 1.1704, batch_time: 0.5905s
date time 2025-03-31 06:16:22.459485 --------------------------
epoch: 4, batch -> 1000, loss: 0.7766, batch_time: 0.6079s
date time 2025-03-31 06:29:55.186472 --------------------------
epoch: 4, batch -> 2000, loss: 0.7613, batch_time: 0.7280s
MODEL SAVED AT model/model_100percent_13.pth
date time 2025-03-31 06:43:22.338727 --------------------------
epoch: 4, batch -> 3000, loss: 0.7403, batch_time: 0.6649s
date time 2025-03-31 06:56:33.087211 --------------------------
epoch: 4, batch -> 4000, loss: 0.8841, batch_time: 0.8045s


Reading metadata...: 1101170it [00:26, 42188.09it/s]


date time 2025-03-31 07:10:12.661855 --------------------------
epoch: 4, batch -> 5000, loss: 0.8375, batch_time: 0.6582s
MODEL SAVED AT model/model_100percent_14.pth
EPOCH 5, LOSS: 5186.7219, EPOCH_TIME: 4847.760792031.
MODEL SAVED AT model/model_100percent_15.pth
