<a href="https://colab.research.google.com/github/drmarcoaron/Swahili-asr/blob/main/Untitled2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ULTRA-OPTIMIZED ASR TRAINING - ALL BUGS FIXED
# Colab-only setup cell: mounts Google Drive and installs system/Python dependencies.
print("🚀 Ultra-Optimized ASR Training (Fixed)...")

# CELL 1: MINIMAL SETUP
# Mount Google Drive at /content/drive (requires the google.colab runtime).
from google.colab import drive
drive.mount('/content/drive')

# Install ffmpeg quietly (apt output is discarded).
!apt update -qq && apt install -y ffmpeg > /dev/null 2>&1
# Only `datasets` and `peft` are version-pinned here; the remaining packages
# resolve to whatever version pip selects at install time.
!pip install -q datasets[audio]==3.5.1 transformers accelerate evaluate peft==0.5.0 librosa soundfile torch torchaudio bitsandbytes jiwer

🚀 Ultra-Optimized ASR Training (Fixed)...
Mounted at /content/drive
40 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mSkipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m42.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.6/85.6 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# CELL 2: OPTIMIZED IMPORTS & CONFIG
import os

# AGGRESSIVE MEMORY OPTIMIZATION
# These variables are exported BEFORE torch / numpy / librosa are imported.
# Settings such as OMP_NUM_THREADS and PYTORCH_JIT may only be honoured when
# they are present in the environment at the time the library is first loaded,
# so assigning them after `import torch` can silently have no effect.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:32,expandable_segments:True",
    "TOKENIZERS_PARALLELISM": "false",
    "OMP_NUM_THREADS": "1",
    "PYTORCH_JIT": "0"
})

import gc, re, random, warnings, math
warnings.filterwarnings('ignore')
import torch, numpy as np, librosa
from tqdm.auto import tqdm
from datasets import load_dataset, Audio, IterableDataset
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from peft import LoraConfig, get_peft_model
import evaluate
from torch.utils.data import DataLoader
from huggingface_hub import notebook_login

# Interactive Hugging Face login widget (needed for gated datasets).
notebook_login()

# CONFIG
MODEL = "openai/whisper-small"
LANG, LANG_CODE = "Swahili", "sw"
DATASET = "mozilla-foundation/common_voice_17_0"

torch.cuda.empty_cache()
print(f"GPU: {torch.cuda.get_device_name()} | Memory: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f}GB")

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

GPU: Tesla T4 | Memory: 14.7GB


In [None]:
# CELL 3: ULTRA-FAST DATA LOADING
print("Loading datasets...")

# Columns the rest of the notebook reads; everything else is dropped for speed.
KEEP_COLS = ["sentence_id", "sentence", "audio"]


def _load_stream(split):
    """Open one streaming split of DATASET/LANG_CODE, keeping only KEEP_COLS.

    Args:
        split: split name passed to `load_dataset` (e.g. "train", "validation").

    Returns:
        The streaming dataset with every column not in KEEP_COLS removed.
    """
    # trust_remote_code=True: this dataset repository ships a loading script.
    # Without the flag `load_dataset` stops and asks "Do you wish to run the
    # custom code? [y/N]" on stdin, which blocks a non-interactive "Run All".
    # SECURITY: this executes code from the dataset repository; inspect it at
    # https://hf.co/datasets/mozilla-foundation/common_voice_17_0 before use.
    stream = load_dataset(
        DATASET, LANG_CODE, split=split, streaming=True, trust_remote_code=True
    )
    # `column_names` can be None when a streaming dataset has unknown features.
    unused = [c for c in (stream.column_names or []) if c not in KEEP_COLS]
    return stream.remove_columns(unused) if unused else stream


train_dataset = _load_stream("train")
val_dataset = _load_stream("validation")

Loading datasets...


README.md:   0%|          | 0.00/12.7k [00:00<?, ?B/s]

common_voice_17_0.py:   0%|          | 0.00/8.19k [00:00<?, ?B/s]

languages.py:   0%|          | 0.00/3.92k [00:00<?, ?B/s]

release_stats.py:   0%|          | 0.00/132k [00:00<?, ?B/s]

The repository for mozilla-foundation/common_voice_17_0 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/mozilla-foundation/common_voice_17_0.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


In [None]:
# CELL 4: LIGHTNING-FAST PREPROCESSING
# Load the Whisper preprocessing components from the MODEL checkpoint.
#   feature_extractor - called on raw waveforms in lightning_prepare / transcribe_audio
#   tokenizer         - encodes normalized transcripts and decodes predictions
#   processor         - handed to UltraFastCollator, which uses its
#                       .feature_extractor and .tokenizer for padding
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL)
tokenizer = WhisperTokenizer.from_pretrained(MODEL, language=LANG, task="transcribe")
processor = WhisperProcessor.from_pretrained(MODEL, language=LANG, task="transcribe")

def ultra_fast_filter(ex):
    """Return True when an example's audio duration and text length are in range.

    An example is kept when its audio lasts between 3.0 and 8.0 seconds and its
    `sentence` has between 10 and 80 characters (both ranges inclusive).

    Args:
        ex: mapping providing ex["audio"]["array"], ex["audio"]["sampling_rate"]
            and ex["sentence"].

    Returns:
        bool: False for out-of-range examples and for malformed ones (missing
        keys, None fields, zero sampling rate, ...).
    """
    try:
        audio = ex["audio"]
        audio_len = len(audio["array"]) / audio["sampling_rate"]
        text_len = len(ex["sentence"])
    except Exception:
        # Malformed example: reject it. `Exception` rather than a bare `except`
        # so that KeyboardInterrupt / SystemExit still stop a streaming run.
        return False
    return 3.0 <= audio_len <= 8.0 and 10 <= text_len <= 80

def speed_normalize_text(text):
    """Lower-case `text`, drop punctuation and collapse whitespace.

    Every character that is neither a word character (`\\w`) nor whitespace is
    removed. Whitespace is normalized AFTER punctuation removal, so inputs such
    as "habari , dunia !" no longer leave double or trailing spaces behind.

    Args:
        text: input string.

    Returns:
        str: normalized text with single spaces and no leading/trailing space.
    """
    # NOTE(review): apostrophes are removed too (e.g. "ng'ombe" -> "ngombe");
    # confirm this matches the normalization used by the target metric.
    without_punct = re.sub(r'[^\w\s]', '', text.lower())
    return ' '.join(without_punct.split())

def ultra_fast_augment(audio, prob=0.2):
    """Randomly apply a mild time-stretch to a waveform.

    The augmentation branch is entered unless `random.random() > prob`; inside
    it, a stretch with a rate drawn uniformly from [0.97, 1.03] is applied when
    a second draw is below 0.7. In every other case the input is returned
    unchanged.

    Args:
        audio: waveform passed to `librosa.effects.time_stretch`.
        prob: threshold for entering the augmentation branch.

    Returns:
        The stretched waveform, or `audio` itself when no augmentation was
        applied or the stretch failed.
    """
    if random.random() > prob:
        return audio
    # Only time-stretching is used - it is the cheapest augmentation here.
    if random.random() >= 0.7:
        return audio
    rate = random.uniform(0.97, 1.03)
    try:
        return librosa.effects.time_stretch(audio, rate=rate)
    except Exception:
        # Best effort: a failed stretch must never drop the example. `Exception`
        # (not a bare `except`) keeps Ctrl-C / SystemExit working.
        return audio

def lightning_prepare(batch, augment=True):
    """Turn one raw example into model inputs (log-mel features + label ids).

    Steps: resample to 16 kHz when needed, optionally apply
    `ultra_fast_augment`, run the module-level `feature_extractor`, then
    normalize and tokenize the transcript.

    Args:
        batch: a single example with batch["audio"]["array"],
            batch["audio"]["sampling_rate"] and batch["sentence"]. It is
            updated in place.
        augment: when True (default, the original behaviour) the random
            augmentation is applied. Pass False for evaluation/inference data
            so results are deterministic.

    Returns:
        The same `batch` mapping with "input_features" and "labels" added.
    """
    target_sr = 16000
    audio = batch["audio"]
    audio_array = audio["array"]
    sr = audio["sampling_rate"]

    # Fast resample
    if sr != target_sr:
        audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)

    # Minimal augmentation (training only)
    if augment:
        audio_array = ultra_fast_augment(audio_array)

    # Fast feature extraction
    batch["input_features"] = feature_extractor(
        audio_array, sampling_rate=target_sr, return_attention_mask=False
    ).input_features[0]

    # Fast tokenization with limits
    text = speed_normalize_text(batch["sentence"])
    batch["labels"] = tokenizer(text, max_length=128, truncation=True).input_ids

    return batch

class UltraFastCollator:
    """Pad a list of prepared examples into one training/eval batch.

    Each example must provide "input_features" and "labels". Padding is
    delegated to the processor's feature extractor and tokenizer.
    """

    def __init__(self, processor):
        # `processor` must expose `.feature_extractor` and `.tokenizer`.
        self.processor = processor

    def __call__(self, features):
        """Collate `features` into {"input_features": fp16 tensor, "labels": tensor}.

        Padded label positions (attention_mask != 1) are replaced by -100.
        """
        audio_rows = [{"input_features": row["input_features"]} for row in features]
        padded_audio = self.processor.feature_extractor.pad(audio_rows, return_tensors="pt")

        label_rows = [{"input_ids": row["labels"]} for row in features]
        padded_text = self.processor.tokenizer.pad(label_rows, return_tensors="pt")

        is_padding = padded_text["attention_mask"].ne(1)
        labels = padded_text["input_ids"].masked_fill(is_padding, -100)

        # Drop the first column only when every row starts with the BOS id.
        # NOTE(review): this compares against tokenizer.bos_token_id; confirm
        # that is really the id the encoded transcripts start with for this
        # tokenizer, otherwise this strip never fires.
        if labels.size(1) > 0 and (labels[:, 0] == self.processor.tokenizer.bos_token_id).all():
            labels = labels[:, 1:]

        collated = {}
        collated["input_features"] = padded_audio["input_features"].half()  # force FP16
        collated["labels"] = labels
        return collated

# Collator instance built around `processor`; it is later handed to the trainer.
data_collator = UltraFastCollator(processor)

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# CELL 6: SPEED-OPTIMIZED STREAMING - FIXED
class LightningDataset(torch.utils.data.IterableDataset):
    """Stream at most `max_samples` filtered, preprocessed examples.

    This is a *torch* IterableDataset so that `torch.utils.data.DataLoader`
    treats it as an iterable-style dataset by construction. Deriving from
    `datasets.IterableDataset` without running its initializer leaves that
    object half-built, so any DataLoader support would be incidental.

    Args:
        hf_ds: any iterable of raw examples (e.g. a streaming HF dataset).
        prep_fn: callable mapping a raw example to a prepared example.
        filter_fn: callable returning True for examples to keep.
        max_samples: maximum number of prepared examples to yield per pass.
    """

    def __init__(self, hf_ds, prep_fn, filter_fn, max_samples):
        super().__init__()
        self.hf_ds = hf_ds
        self.prep_fn = prep_fn
        self.filter_fn = filter_fn
        self.max_samples = max_samples

    def __iter__(self):
        """Yield prepared examples, skipping ones that fail filtering/preparation."""
        if self.max_samples <= 0:
            return
        count = 0
        for ex in self.hf_ds:
            try:
                if not self.filter_fn(ex):
                    continue
                item = self.prep_fn(ex)
            except Exception:
                # Best effort: one broken example must not stop the stream.
                # Not a bare `except`, so Ctrl-C still interrupts the run.
                continue
            # The yield sits outside the try block on purpose: closing the
            # generator raises GeneratorExit at the yield, and that must
            # propagate instead of being swallowed by the handler above.
            yield item
            count += 1
            if count >= self.max_samples:
                # Stop right away so no extra example is pulled (and decoded)
                # from the underlying stream once the quota is reached.
                return

def lightning_eval(model, eval_stream, collator, tokenizer, processor, max_samples=50):
    """Compute the word error rate (in percent) on up to `max_samples` examples.

    Examples are taken from `eval_stream`, filtered with `ultra_fast_filter`,
    prepared WITHOUT random augmentation, decoded greedily and scored with the
    `evaluate` "wer" metric after `speed_normalize_text` is applied to both
    predictions and references.

    Args:
        model: model exposing `parameters()`, `generate()` and train/eval modes.
        eval_stream: iterable of raw validation examples.
        collator: batch collate function (e.g. `UltraFastCollator`).
        tokenizer: tokenizer used to encode references and decode predictions.
        processor: processor providing `.feature_extractor` and
            `get_decoder_prompt_ids`.
        max_samples: maximum number of examples to score.

    Returns:
        float: 100 * WER over the scored examples.

    Raises:
        ValueError: if no example could be scored.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    def _prepare_for_eval(ex):
        # Same steps as lightning_prepare (resample -> features -> labels) but
        # without the random time-stretch, so the reported WER is deterministic
        # and not degraded by augmentation.
        audio_array = ex["audio"]["array"]
        sr = ex["audio"]["sampling_rate"]
        if sr != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
        ex["input_features"] = processor.feature_extractor(
            audio_array, sampling_rate=16000, return_attention_mask=False
        ).input_features[0]
        ex["labels"] = tokenizer(
            speed_normalize_text(ex["sentence"]), max_length=128, truncation=True
        ).input_ids
        return ex

    wer_metric = evaluate.load("wer")
    eval_ds = LightningDataset(eval_stream, _prepare_for_eval, ultra_fast_filter, max_samples)
    eval_dl = DataLoader(eval_ds, batch_size=1, collate_fn=collator, num_workers=0)

    forced_ids = processor.get_decoder_prompt_ids(language=LANG, task="transcribe")

    num_scored = 0
    try:
        with torch.no_grad():
            for batch in eval_dl:
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

                with torch.cuda.amp.autocast():
                    tokens = model.generate(
                        input_features=batch["input_features"],
                        forced_decoder_ids=forced_ids,
                        max_new_tokens=64,
                        num_beams=1,
                        do_sample=False
                    )

                preds = tokenizer.batch_decode(tokens.cpu(), skip_special_tokens=True)

                # -100 marks padded label positions; map them back to the pad
                # id so the tokenizer can decode the references.
                labels = batch["labels"].cpu().numpy()
                labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
                refs = tokenizer.batch_decode(labels, skip_special_tokens=True)

                preds = [speed_normalize_text(p) for p in preds]
                refs = [speed_normalize_text(r) for r in refs]

                wer_metric.add_batch(predictions=preds, references=refs)
                num_scored += len(preds)

                # Aggressive cleanup
                del batch, tokens, preds, labels, refs
                torch.cuda.empty_cache()
    finally:
        # Leave the model in the mode the caller had it in, even on failure.
        model.train(was_training)

    if num_scored == 0:
        raise ValueError("lightning_eval: no validation example passed the filter; WER is undefined")

    return 100 * wer_metric.compute()

In [None]:
# CELL 7: LIGHTNING TRAINER - FIXED
class LightningTrainer:
    """Minimal mixed-precision training loop with gradient accumulation.

    Each epoch streams up to `max_train_samples` examples through
    `LightningDataset`, accumulates gradients over `grad_accum` micro-batches,
    and every `eval_steps` / `save_steps` optimizer updates evaluates WER with
    `lightning_eval` / writes a checkpoint under ./lightning_checkpoints.

    Hyperparameters are plain attributes set in `__init__`. The LR schedule
    length is computed once in `__init__`; changing `epochs`, `grad_accum` or
    `max_train_samples` afterwards does not rebuild the scheduler.
    """

    def __init__(self, model, train_stream, eval_stream, collator, tokenizer, processor):
        """Store the collaborators and build optimizer, scheduler and grad scaler.

        Args:
            model: model to train; its parameters are handed to the optimizer.
            train_stream: iterable of raw training examples.
            eval_stream: iterable of raw validation examples.
            collator: batch collate function for the DataLoaders.
            tokenizer: tokenizer forwarded to `lightning_eval`.
            processor: processor forwarded to `lightning_eval`.
        """
        self.model = model
        self.train_stream = train_stream
        self.eval_stream = eval_stream
        self.collator = collator
        self.tokenizer = tokenizer
        self.processor = processor

        # Optimized hyperparameters
        self.lr = 1e-4
        self.epochs = 5
        self.batch_size = 1
        self.grad_accum = 32
        self.warmup_steps = 500
        self.max_train_samples = 15000
        self.eval_steps = 800
        self.save_steps = 1600
        self.max_eval_samples = 50

        # Optimizer with 8-bit precision for memory
        try:
            from bitsandbytes.optim import AdamW8bit
            self.optimizer = AdamW8bit(
                model.parameters(),
                lr=self.lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0.01
            )
            print("Using AdamW8bit optimizer")
        except ImportError:
            print("bitsandbytes not available, using standard AdamW")
            self.optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=self.lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=0.01
            )

        # Scheduler: `ceil` counts the final, partially filled accumulation
        # group of each epoch as an update (train() applies it accordingly).
        steps_per_epoch = math.ceil(self.max_train_samples / (self.batch_size * self.grad_accum))
        total_steps = steps_per_epoch * self.epochs

        from transformers import get_cosine_schedule_with_warmup
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer, self.warmup_steps, total_steps
        )

        self.scaler = torch.cuda.amp.GradScaler()
        print(f"Training config: {self.epochs} epochs, {steps_per_epoch} steps/epoch, {total_steps} total steps")

    def _apply_update(self):
        """Unscale and clip gradients, take one optimizer step, advance the LR schedule."""
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        self.scheduler.step()

    def _after_update(self, global_step, best_wer, output_dir, avg_loss):
        """Run the periodic logging / evaluation / checkpoint hooks for one update.

        Returns:
            tuple: (best_wer, target_reached) where `target_reached` is True
            when a new best WER below 15% was just measured.
        """
        if global_step % 50 == 0:
            lr = self.scheduler.get_last_lr()[0]
            mem_gb = torch.cuda.memory_allocated() / 1024**3
            print(f"Step {global_step} | Loss: {avg_loss:.3f} | LR: {lr:.1e} | VRAM: {mem_gb:.1f}GB")

        # Fast evaluation
        if global_step % self.eval_steps == 0:
            target_reached = False
            try:
                val_wer = lightning_eval(
                    self.model, self.eval_stream, self.collator,
                    self.tokenizer, self.processor, self.max_eval_samples
                )
                print(f"Step {global_step} | WER: {val_wer:.2f}%")

                if val_wer < best_wer:
                    best_wer = val_wer
                    self.model.save_pretrained(f"{output_dir}/best_model")
                    print(f"New best WER: {val_wer:.2f}%")

                    if val_wer < 15.0:  # Early success
                        print(f"TARGET ACHIEVED! WER < 15%")
                        target_reached = True

            except Exception as e:
                # A failed evaluation must not abort training.
                print(f"Eval failed: {e}")

            if target_reached:
                return best_wer, True

            self.model.train()
            torch.cuda.empty_cache()

        # Save checkpoint
        if global_step % self.save_steps == 0:
            self.model.save_pretrained(f"{output_dir}/checkpoint_{global_step}")
            print(f"Saved checkpoint at step {global_step}")

        return best_wer, False

    def train(self):
        """Run the training loop.

        Returns:
            float: the lowest validation WER (percent) observed, or
            float('inf') when no evaluation succeeded. Returns early as soon as
            a new best WER below 15% is measured.
        """
        device = next(self.model.parameters()).device
        self.model.train()
        # Drop gradients left over from a previous, interrupted train() call so
        # they cannot leak into the first update of this run.
        self.optimizer.zero_grad()

        global_step = 0
        best_wer = float('inf')

        OUTPUT_DIR = "./lightning_checkpoints"
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        for epoch in range(self.epochs):
            print(f"\n=== EPOCH {epoch+1}/{self.epochs} ===")

            # Create fresh dataset
            # NOTE(review): every epoch re-iterates `train_stream` from its
            # start; no shuffling is applied here - confirm that is intended.
            train_ds = LightningDataset(
                self.train_stream, lightning_prepare, ultra_fast_filter, self.max_train_samples
            )
            train_dl = DataLoader(train_ds, batch_size=self.batch_size, collate_fn=self.collator, num_workers=0)

            epoch_loss = 0.0   # sum of un-normalized per-micro-batch losses
            num_batches = 0    # micro-batches seen this epoch
            pending = 0        # micro-batches accumulated since the last update

            for batch_idx, batch in enumerate(tqdm(train_dl, desc=f"Training")):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

                with torch.cuda.amp.autocast():
                    outputs = self.model(**batch)
                    loss = outputs.loss / self.grad_accum

                self.scaler.scale(loss).backward()
                epoch_loss += loss.item() * self.grad_accum
                num_batches += 1
                pending += 1

                if pending >= self.grad_accum:
                    self._apply_update()
                    pending = 0
                    global_step += 1

                    # Average over micro-batches (not optimizer updates), so the
                    # reported loss is a true per-example mean.
                    best_wer, target_reached = self._after_update(
                        global_step, best_wer, OUTPUT_DIR, epoch_loss / num_batches
                    )
                    if target_reached:
                        return best_wer

                # Memory cleanup
                del batch, outputs, loss
                if batch_idx % 100 == 0:
                    torch.cuda.empty_cache()

            if pending > 0:
                # Apply the last, partially filled accumulation group instead of
                # silently carrying its gradients into the next epoch. Its
                # gradient is smaller by pending / grad_accum because each loss
                # was already divided by grad_accum.
                self._apply_update()
                global_step += 1
                best_wer, target_reached = self._after_update(
                    global_step, best_wer, OUTPUT_DIR, epoch_loss / num_batches
                )
                if target_reached:
                    return best_wer

            avg_loss = epoch_loss / max(1, num_batches)
            print(f"Epoch {epoch+1} done | Avg Loss: {avg_loss:.3f} | Best WER: {best_wer:.2f}%")

        print(f"Training finished! Best WER: {best_wer:.2f}%")
        return best_wer

In [None]:
# CELL 8: RUN LIGHTNING TRAINING
import traceback

# NOTE(review): `model` is not created in any cell shown above this one (there
# is no "CELL 5" in this notebook); it must exist before this cell runs.
trainer = LightningTrainer(model, train_dataset, val_dataset, data_collator, tokenizer, processor)

# Clear memory and launch
gc.collect()
torch.cuda.empty_cache()

print("Starting lightning training...")
try:
    best_wer = trainer.train()
    print(f"Success! Best WER: {best_wer:.2f}%")
except Exception as e:
    print(f"Training failed: {e}")
    # Show where the failure happened; the message alone is often not enough.
    traceback.print_exc()
    # Emergency fallback
    print("Trying emergency settings...")
    trainer.grad_accum = 64
    trainer.max_train_samples = 8000
    trainer.max_eval_samples = 25
    # The failed run may have stopped in the middle of an accumulation group;
    # release those partial gradients so they neither occupy memory nor get
    # mixed into the first update of the retry.
    trainer.optimizer.zero_grad(set_to_none=True)
    # NOTE: the optimizer state and LR schedule built in LightningTrainer.__init__
    # are reused as-is by the retry.
    gc.collect()
    torch.cuda.empty_cache()
    best_wer = trainer.train()

# Final model save
try:
    model.save_pretrained("./final_lightning_model")
    print("Final model saved!")
except Exception as e:
    print(f"Save failed: {e}")

print("Lightning training completed!")

In [None]:
# FINAL CELL: GENERATE SUBMISSION CSV
print("Loading test dataset and generating predictions...")

import os
import pandas as pd
from datasets import load_dataset, Audio
import torch
import numpy as np
import librosa

# Load test dataset
print("Loading test dataset...")
test_ds = load_dataset("sartifyllc/Sartify_ITU_Zindi_Testdataset", split="test")
# Decode every clip at 16 kHz, the rate passed to the feature extractor below.
test_ds = test_ds.cast_column("audio", Audio(sampling_rate=16000, decode=True))

# Load your best trained model for inference
print("Loading trained model for inference...")
BEST_MODEL_DIR = "./lightning_checkpoints/best_model"
if os.path.isdir(BEST_MODEL_DIR):
    try:
        # PeftModel.load_adapter requires an adapter name; calling it with only
        # the path raises TypeError, which the previous bare `except` hid.
        # NOTE(review): assumes `model` is a peft.PeftModel whose active adapter
        # is named "default" (what get_peft_model creates) - confirm against the
        # cell that builds `model`.
        model.load_adapter(BEST_MODEL_DIR, adapter_name="default")
        print("Loaded best model from training")
    except Exception as e:
        # Report why loading failed instead of silently falling back.
        print(f"Could not load best model ({type(e).__name__}: {e})")
        print("Using current model state")
else:
    print(f"No checkpoint found at {BEST_MODEL_DIR}")
    print("Using current model state")

# Prepare model for inference
model.eval()
device = next(model.parameters()).device

# Get forced decoder IDs for Swahili
forced_decoder_ids = processor.get_decoder_prompt_ids(language=LANG, task="transcribe")

def transcribe_audio(audio_array, sampling_rate=16000):
    """Return the normalized transcription of one waveform, or "" on any error.

    The waveform is down-mixed to mono when it has more than one dimension,
    peak-normalized, converted to features with the module-level
    `feature_extractor`, decoded with beam search (3 beams) and passed through
    `speed_normalize_text`.

    Args:
        audio_array: array-like waveform exposing `.shape` / `.mean`.
        sampling_rate: sampling rate forwarded to the feature extractor.

    Returns:
        str: the transcription; the empty string when anything fails (the
        error is printed).
    """
    try:
        # Multi-dimensional input is averaged over axis 0 to obtain mono audio.
        wave = audio_array.mean(axis=0) if len(audio_array.shape) > 1 else audio_array

        # Peak-normalize (skipped for all-zero audio to avoid dividing by 0).
        # lightning_prepare does not apply this step to training audio.
        peak = np.max(np.abs(wave))
        if peak > 0:
            wave = wave / peak

        encoded = feature_extractor(
            wave,
            sampling_rate=sampling_rate,
            return_tensors="pt"
        )
        input_features = encoded.input_features.to(device).half()

        with torch.no_grad(), torch.cuda.amp.autocast():
            generated_tokens = model.generate(
                input_features=input_features,
                forced_decoder_ids=forced_decoder_ids,
                max_new_tokens=200,
                num_beams=3,  # beam search for better quality
                do_sample=False,
                temperature=1.0,
                early_stopping=True,
                repetition_penalty=1.1
            )

        decoded = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        return speed_normalize_text(decoded)

    except Exception as e:
        print(f"Error transcribing audio: {e}")
        return ""

# Process all test samples
print("Generating transcriptions...")
results = []

# tqdm already reports progress, so no per-N "Processed ..." lines are printed;
# they only flooded the cell output with hundreds of lines.
for i, sample in enumerate(tqdm(test_ds, desc="Transcribing")):
    audio = sample["audio"]
    results.append({
        "filename": sample["filename"],
        "text": transcribe_audio(audio["array"], audio["sampling_rate"]),
    })

    # Clean GPU memory periodically
    if (i + 1) % 10 == 0:
        torch.cuda.empty_cache()

# Create submission DataFrame (explicit columns keep the schema stable even
# when `results` is empty).
submission_df = pd.DataFrame(results, columns=["filename", "text"])

# Save to CSV
submission_df.to_csv("submission.csv", index=False)

print(f"Submission file created successfully!")
print(f"Total samples: {len(results)}")
print(f"Sample predictions:")
print(submission_df.head())

# Verify the format
print(f"\nSubmission file format:")
print(f"Columns: {list(submission_df.columns)}")
print(f"Shape: {submission_df.shape}")

# Check for empty transcriptions (transcribe_audio returns "" on failure)
empty_transcriptions = int((submission_df['text'] == '').sum())
if empty_transcriptions > 0:
    print(f"Warning: {empty_transcriptions} empty transcriptions found")
else:
    print("All samples have transcriptions!")

print("Ready for submission: submission.csv")

Loading test dataset and generating predictions...
Loading test dataset...


README.md: 0.00B [00:00, ?B/s]

data/test-00000-of-00002.parquet:   0%|          | 0.00/387M [00:00<?, ?B/s]

data/test-00001-of-00002.parquet:   0%|          | 0.00/391M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4089 [00:00<?, ? examples/s]

Loading trained model for inference...
Using current model state
Generating transcriptions...


Transcribing:   0%|          | 0/4089 [00:00<?, ?it/s]

Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Processed 10 samples...
Processed 20 samples...
Processed 30 samples...
Processed 40 samples...
Processed 50 samples...
Processed 60 samples...
Processed 70 samples...
Processed 80 samples...
Processed 90 samples...
Processed 100 samples...
Processed 110 samples...
Processed 120 samples...
Processed 130 samples...
Processed 140 samples...
Processed 150 samples...
Processed 160 samples...
Processed 170 samples...
Processed 180 samples...
Processed 190 samples...
Processed 200 samples...
Processed 210 samples...
Processed 220 samples...
Processed 230 samples...
Processed 240 samples...
Processed 250 samples...
Processed 260 samples...
Processed 270 samples...
Processed 280 samples...
Processed 290 samples...
Processed 300 samples...
Processed 310 samples...
Processed 320 samples...
Processed 330 samples...
Processed 340 samples...
Processed 350 samples...
Processed 360 samples...
Processed 370 samples...
Processed 380 samples...
Processed 390 samples...
Processed 400 samples...
Processed