In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import librosa
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperTokenizer
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import traceback
import sys

# Train on CUDA when a GPU is visible to PyTorch; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class WhisperConfig:
    """Plain hyper-parameter container for the Whisper-style model.

    Attributes (all set verbatim from the constructor arguments):
        n_mels: number of mel bins in the input features.
        n_ctx: maximum decoder (text) sequence length.
        n_heads: attention heads per layer.
        n_audio_layers / n_text_layers: encoder / decoder depth.
        n_embed: model width.
        n_audio_ctx: maximum encoder (audio) sequence length.
        vocab_size: tokenizer vocabulary size.
        sample_rate: audio sample rate expected by the feature extractor.
        dropout: dropout probability used in the feed-forward layers.
    """

    def __init__(self,
                 n_mels=80,
                 n_ctx=1500,
                 n_heads=8,
                 n_audio_layers=4,
                 n_text_layers=4,
                 n_embed=512,
                 n_audio_ctx=1500,
                 vocab_size=51865,
                 sample_rate=16000,
                 dropout=0.1):
        # Keyword order fixes attribute insertion order, matching one-by-one assignment.
        vars(self).update(
            n_mels=n_mels,
            n_ctx=n_ctx,
            n_heads=n_heads,
            n_audio_layers=n_audio_layers,
            n_text_layers=n_text_layers,
            n_embed=n_embed,
            n_audio_ctx=n_audio_ctx,
            vocab_size=vocab_size,
            sample_rate=sample_rate,
            dropout=dropout,
        )

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention, usable for self- or cross-attention.

    Queries come from ``x``; keys and values come from ``kv`` when given,
    otherwise from ``x``. Where ``mask == 0`` the score is set to ``-inf``
    before the softmax.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.n_embed = config.n_embed
        assert self.n_embed % self.n_heads == 0

        self.head_dim = self.n_embed // self.n_heads
        self.wq = nn.Linear(config.n_embed, config.n_embed)
        self.wk = nn.Linear(config.n_embed, config.n_embed)
        self.wv = nn.Linear(config.n_embed, config.n_embed)
        self.wo = nn.Linear(config.n_embed, config.n_embed)

    def _split_heads(self, projected, batch, length):
        """(batch, length, n_embed) -> (batch, n_heads, length, head_dim)."""
        return projected.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None, kv=None):
        batch, q_len, width = x.size()
        source = x if kv is None else kv
        src_len = source.size(1)

        queries = self._split_heads(self.wq(x), batch, q_len)
        keys = self._split_heads(self.wk(source), batch, src_len)
        values = self._split_heads(self.wv(source), batch, src_len)

        scale = 1.0 / np.sqrt(keys.size(-1))
        scores = (queries @ keys.transpose(-2, -1)) * scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.wo(merged)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embed -> 4*n_embed) -> GELU -> Linear(back) -> Dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_size = 4 * config.n_embed
        self.w1 = nn.Linear(config.n_embed, hidden_size)
        self.w2 = nn.Linear(hidden_size, config.n_embed)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = F.gelu(self.w1(x))
        projected = self.w2(expanded)
        return self.dropout(projected)

class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder layer.

    Computes ``x + attn(ln1(x))`` and then adds ``ffwd(ln2(.))`` of that result.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embed)
        self.ffwd = FeedForward(config)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ffwd(self.ln2(after_attn))

class DecoderBlock(nn.Module):
    """Pre-norm transformer decoder layer.

    Applies, each with a residual connection: causal self-attention,
    cross-attention over the encoder output, then the feed-forward MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embed)
        self.mask_attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embed)
        self.cross_attn = MultiHeadAttention(config)
        self.ln3 = nn.LayerNorm(config.n_embed)
        self.ffwd = FeedForward(config)

    def forward(self, x, encoder_out):
        """Run the layer.

        Args:
            x: decoder hidden states, ``(batch, t, n_embed)``.
            encoder_out: encoder states used as cross-attention keys/values,
                presumably ``(batch, s, n_embed)``.
        """
        _, t, _ = x.size()
        # Lower-triangular mask (1 = may attend). Allocating it directly on
        # x's device avoids a host allocation + host-to-device copy in every
        # layer on every forward pass.
        causal_mask = torch.tril(torch.ones(t, t, device=x.device)).view(1, 1, t, t)
        x = x + self.mask_attn(self.ln1(x), causal_mask)
        x = x + self.cross_attn(self.ln2(x), kv=encoder_out)
        x = x + self.ffwd(self.ln3(x))
        return x

class WhisperEncoder(nn.Module):
    """Audio encoder: two GELU Conv1d layers, learned positions, transformer blocks.

    Input may be ``(batch, n_mels, frames)`` or ``(batch, frames, n_mels)``.
    It is transposed whenever dim 1 differs from ``n_mels``. Output is
    ``(batch, frames, n_embed)``. Both convs use stride 1, so ``frames`` must
    not exceed ``n_audio_ctx`` (the positional table length).
    """

    def __init__(self, config):
        super().__init__()
        self.conv1 = nn.Conv1d(config.n_mels, config.n_embed, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.n_embed, config.n_embed, kernel_size=3, padding=1)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.n_audio_ctx, config.n_embed))
        self.blocks = nn.ModuleList([EncoderBlock(config) for _ in range(config.n_audio_layers)])
        self.ln_final = nn.LayerNorm(config.n_embed)

    def forward(self, x):
        channels_first = x.size(1) == self.conv1.in_channels
        if not channels_first:
            x = x.transpose(1, 2)
        for conv in (self.conv1, self.conv2):
            x = F.gelu(conv(x))
        hidden = x.transpose(1, 2)
        hidden = hidden + self.pos_embed[:, :hidden.size(1), :]
        for block in self.blocks:
            hidden = block(hidden)
        return self.ln_final(hidden)

class WhisperDecoder(nn.Module):
    """Text decoder: token + learned position embeddings, decoder blocks, LM head.

    The token embedding shares its weight with the output projection
    (``token_embed.weight is head.weight``).
    """

    def __init__(self, config):
        super().__init__()
        self.token_embed = nn.Embedding(config.vocab_size, config.n_embed)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.n_ctx, config.n_embed))
        self.blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_text_layers)])
        self.ln_final = nn.LayerNorm(config.n_embed)
        self.head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the head's (vocab_size, n_embed) matrix.
        self.token_embed.weight = self.head.weight

    def forward(self, x, encoder_out):
        """Map token ids ``(batch, t)`` to logits ``(batch, t, vocab_size)``."""
        _, seq_len = x.size()
        hidden = self.token_embed(x) + self.pos_embed[:, :seq_len, :]
        for block in self.blocks:
            hidden = block(hidden, encoder_out)
        return self.head(self.ln_final(hidden))

In [None]:
class WhisperModel(nn.Module):
    """Encoder-decoder speech-to-text model: audio features in, token logits out."""

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperEncoder(config)
        self.decoder = WhisperDecoder(config)

    def forward(self, audio_features, decoder_input_ids):
        """Encode ``audio_features`` and decode ``decoder_input_ids`` against it."""
        return self.decoder(decoder_input_ids, self.encoder(audio_features))

class AudioProcessor:
    """Turns audio files into per-clip normalised log-mel spectrograms.

    ``extract_fbank`` returns an ``(n_mels, frames)`` tensor. It uses a 400-sample
    FFT/window and a 160-sample hop at ``config.sample_rate``, log-compresses the
    result, and normalises it to zero mean / unit std over the whole clip.
    """

    def __init__(self, config):
        self.config = config
        self.sample_rate = config.sample_rate
        self.n_mels = config.n_mels
        # The mel filterbank depends only on these fixed settings, so build the
        # transform once rather than once per clip. If sample_rate / n_mels must
        # change, construct a new AudioProcessor.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=400,
            hop_length=160,
            win_length=400,
            n_mels=self.n_mels
        )
        # Resampling kernels keyed by source sample rate, built on first use.
        self._resamplers = {}

    def _resample(self, waveform, sr):
        """Resample ``waveform`` from ``sr`` Hz to ``self.sample_rate``, reusing kernels."""
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            self._resamplers[sr] = resampler
        return resampler(waveform)

    def extract_fbank(self, audio_path):
        """Load ``audio_path`` and return its normalised log-mel spectrogram.

        Multi-channel audio is averaged to mono. Audio is resampled when its
        rate differs from ``self.sample_rate``.

        On any error the message is printed and an all-zero ``(n_mels, 400)``
        tensor is returned (best-effort). Callers therefore still receive the
        item's real transcript paired with silence.
        """
        try:
            waveform, sr = torchaudio.load(audio_path)
            if waveform.size(0) > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            if sr != self.sample_rate:
                waveform = self._resample(waveform, sr)
            mel_spec = self.mel_transform(waveform)
            # Small epsilons keep log() and the division finite on silent input.
            log_mel = torch.log(mel_spec + 1e-9)
            mean = log_mel.mean()
            std = log_mel.std()
            log_mel = (log_mel - mean) / (std + 1e-9)
            return log_mel.squeeze(0)
        except Exception as e:
            print(f"Error processing audio file {audio_path}: {e}")
            return torch.zeros(self.n_mels, 400)

In [None]:
class CommonVoiceDataset(Dataset):
    """Common Voice dataset of fixed-size log-mel features and shifted token targets.

    Reads a tab-separated file and keeps rows with ``valid == 1``. Uses the
    ``path`` (relative to ``audio_dir``) and ``sentence`` columns.

    Each item is a dict:
        audio_features: ``(1, n_mels, max_audio_len)``, cropped or zero-padded.
        decoder_input_ids: tokenized transcript ``[:-1]`` (length 99).
        labels: tokenized transcript ``[1:]``. Positions the tokenizer marks as
            padding (attention mask 0) are set to -100 so losses with
            ``ignore_index=-100`` skip them.
        text, audio_path: raw transcript and resolved clip path.
    If anything fails, a placeholder item is returned instead. Its labels are
    all -100, so it contributes nothing to the loss.
    """

    def __init__(self, tsv_file, audio_dir, tokenizer, processor, max_audio_len=1500):
        self.df = pd.read_csv(tsv_file, sep='\t', on_bad_lines='skip')
        self.df = self.df[self.df['valid'] == 1].reset_index(drop=True)
        self.audio_dir = audio_dir
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_audio_len = max_audio_len
        print(f"Dataset loaded with {len(self.df)} valid entries")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            audio_path = os.path.join(self.audio_dir, row['path'])
            text = row['sentence']
            mel_features = self.processor.extract_fbank(audio_path)
            encoded_text = self.tokenizer(text, return_tensors="pt", padding="max_length",
                                          max_length=100, truncation=True)
            input_ids = encoded_text.input_ids.squeeze(0)
            # Teacher forcing: the decoder sees tokens [:-1] and predicts tokens [1:].
            decoder_input_ids = input_ids[:-1].clone()
            labels = input_ids[1:].clone()
            # Find padding via the attention mask, not the pad id. For the
            # Whisper tokenizer the pad and EOS tokens appear to share an id
            # (verify), so an id comparison would also drop the real EOS target.
            attention_mask = encoded_text.get("attention_mask")
            if attention_mask is not None:
                labels[attention_mask.squeeze(0)[1:] == 0] = -100
            if mel_features.size(1) > self.max_audio_len:
                mel_features = mel_features[:, :self.max_audio_len]
            elif mel_features.size(1) < self.max_audio_len:
                pad_len = self.max_audio_len - mel_features.size(1)
                mel_features = F.pad(mel_features, (0, pad_len))
            mel_features = mel_features.unsqueeze(0)
            return {
                "audio_features": mel_features,
                "decoder_input_ids": decoder_input_ids,
                "labels": labels,
                "text": text,
                "audio_path": audio_path
            }
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            return {
                "audio_features": torch.zeros(1, self.processor.n_mels, self.max_audio_len),
                "decoder_input_ids": torch.zeros(99, dtype=torch.long),
                # -100 = ignored by the loss; don't train on a fabricated target.
                "labels": torch.full((99,), -100, dtype=torch.long),
                "text": "",
                "audio_path": ""
            }

def collate_batch(batch):
    """Combine dataset item dicts into one batch dict.

    ``audio_features`` tensors are concatenated along dim 0. Token tensors
    are right-padded to the longest item: ``decoder_input_ids`` with 0 and
    ``labels`` with -100. ``text`` / ``audio_path`` become the lists ``texts``
    / ``audio_paths``.

    If batching fails, the traceback is printed and a placeholder batch is
    returned: zeros shaped ``(B, 80, 1500)`` for audio, ``(B, 99)`` for tokens,
    and empty strings for text and paths.
    """
    try:
        def column(key):
            return [item[key] for item in batch]

        return {
            "audio_features": torch.cat(column("audio_features"), dim=0),
            "decoder_input_ids": pad_sequence(column("decoder_input_ids"),
                                              batch_first=True, padding_value=0),
            "labels": pad_sequence(column("labels"),
                                   batch_first=True, padding_value=-100),
            "texts": column("text"),
            "audio_paths": column("audio_path"),
        }
    except Exception as e:
        print(f"Error in collate_batch: {e}")
        traceback.print_exc()
        size = len(batch)
        return {
            "audio_features": torch.zeros(size, 80, 1500),
            "decoder_input_ids": torch.zeros(size, 99, dtype=torch.long),
            "labels": torch.zeros(size, 99, dtype=torch.long),
            "texts": [""] * size,
            "audio_paths": [""] * size
        }

In [None]:
def train_epoch(model, dataloader, optimizer, device, scaler=None, gradient_accumulation_steps=4):
    """Train ``model`` for one epoch with cross-entropy and gradient accumulation.

    Args:
        model: callable ``model(audio_features, decoder_input_ids) -> logits``.
        dataloader: sized iterable of batches from ``collate_batch``.
        optimizer: torch optimizer over the model parameters.
        device: device the batch tensors are moved to.
        scaler: optional ``GradScaler``. When given, forward runs under CUDA
            autocast and backward/step go through the scaler.
        gradient_accumulation_steps: the optimizer steps every this many
            batches, and always on the final batch.

    CUDA out-of-memory batches are reported, their partially accumulated
    gradients are dropped, and the epoch continues.

    Returns:
        Mean per-batch (un-divided) loss over the batches that completed, or
        ``nan`` if none did.
    """
    model.train()
    total_loss = 0.0
    completed_batches = 0
    num_batches = len(dataloader)

    with tqdm(total=num_batches, desc="Training") as pbar:
        for i, batch in enumerate(dataloader):
            # Step at the end of each accumulation window and on the last batch
            # so no gradients are left unapplied.
            should_step = (i + 1) % gradient_accumulation_steps == 0 or i == num_batches - 1
            try:
                audio_features = batch["audio_features"].to(device)
                decoder_input_ids = batch["decoder_input_ids"].to(device)
                labels = batch["labels"].to(device)

                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        logits = model(audio_features, decoder_input_ids)
                        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                               labels.reshape(-1)) / gradient_accumulation_steps

                    scaler.scale(loss).backward()

                    if should_step:
                        # Unscale first so clipping sees the true gradient norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                else:
                    logits = model(audio_features, decoder_input_ids)
                    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                           labels.reshape(-1)) / gradient_accumulation_steps
                    loss.backward()

                    if should_step:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()
                        optimizer.zero_grad()

                # Undo the accumulation division so the reported loss is per batch.
                total_loss += loss.item() * gradient_accumulation_steps
                completed_batches += 1

                del audio_features, decoder_input_ids, labels, logits, loss

            except torch.cuda.OutOfMemoryError:
                print(f"CUDA OOM error for batch {i}, skipping...")
                torch.cuda.empty_cache()
                optimizer.zero_grad()
            finally:
                # Advance the bar for skipped batches too, so it ends at num_batches.
                pbar.update(1)

    return total_loss / completed_batches if completed_batches else float('nan')

In [None]:
!pip install jiwer

In [None]:
def evaluate(model, dataloader, device, tokenizer):
    """Teacher-forced validation: average loss, WER and decoded predictions.

    The model receives the ground-truth ``decoder_input_ids``. The argmax of
    its logits at each position is the prediction (no autoregressive decoding).
    Only positions whose label is not -100 are decoded, so ignored / padded
    positions do not leak tokens into the prediction text.

    Args:
        model: callable ``model(audio_features, decoder_input_ids) -> logits``.
        dataloader: iterable of batches from ``collate_batch``.
        device: device the batch tensors are moved to.
        tokenizer: object with ``decode(ids, skip_special_tokens=True)``.

    Returns:
        A tuple ``(avg_loss, error_rate, all_predictions, all_references)``:
        - ``avg_loss``: mean over batches that produced a loss, or ``nan`` if
          none did.
        - ``error_rate``: jiwer WER over pairs with a non-empty reference; 0 when
          jiwer is not installed; ``nan`` when no pair can be scored.
        Batches that raise are reported and skipped.
    """
    model.eval()
    total_loss = 0.0
    loss_batches = 0
    all_predictions = []
    all_references = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            try:
                audio_features = batch["audio_features"].to(device)
                decoder_input_ids = batch["decoder_input_ids"].to(device)
                labels = batch["labels"].to(device)
                texts = batch["texts"]

                logits = model(audio_features, decoder_input_ids)
                scored = labels != -100
                # cross_entropy over zero non-ignored targets is NaN; keep it
                # out of the running average.
                if scored.any():
                    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
                    total_loss += loss.item()
                    loss_batches += 1

                predictions = torch.argmax(logits, dim=-1)
                for pred, keep, ref in zip(predictions, scored, texts):
                    pred_text = tokenizer.decode(pred[keep], skip_special_tokens=True)
                    all_predictions.append(pred_text)
                    all_references.append(ref)

                del audio_features, decoder_input_ids, labels, logits
            except Exception as e:
                print(f"Error in evaluation: {e}")
                continue

    # Placeholder items carry text "", and some jiwer versions raise ValueError
    # on empty references, so score only pairs with a real reference.
    scorable = [(ref, hyp) for ref, hyp in zip(all_references, all_predictions)
                if isinstance(ref, str) and ref.strip()]
    try:
        from jiwer import wer
    except ImportError:
        print("jiwer not installed, skipping WER calculation")
        error_rate = 0
    else:
        if scorable:
            error_rate = wer([ref for ref, _ in scorable], [hyp for _, hyp in scorable])
        else:
            print("No non-empty references to score, WER is undefined")
            error_rate = float('nan')

    avg_loss = total_loss / loss_batches if loss_batches else float('nan')
    return avg_loss, error_rate, all_predictions, all_references

In [None]:
def train(model, train_dataloader, val_dataloader, optimizer, device, tokenizer,
          num_epochs=30, checkpoint_dir="checkpoints", use_amp=True):
    """Full training loop: train, validate, checkpoint, then plot the history.

    Each epoch:
    - runs ``train_epoch`` and then ``evaluate``;
    - prints up to three sample predictions;
    - saves ``whisper_model_epoch_{n}.pt`` to ``checkpoint_dir``;
    - saves ``whisper_model_best.pt`` whenever validation loss improves.

    A GradScaler is created only when ``use_amp`` is truthy and CUDA is
    available. If matplotlib is installed, a loss/WER figure is saved at the end.

    Returns:
        ``(train_losses, val_losses, val_wers)``, one entry per epoch.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    scaler = None
    if use_amp and torch.cuda.is_available():
        scaler = torch.cuda.amp.GradScaler()

    best_val_loss = float('inf')
    train_losses, val_losses, val_wers = [], [], []

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss = train_epoch(model, train_dataloader, optimizer, device, scaler)
        train_losses.append(train_loss)
        print(f"Train Loss: {train_loss:.4f}")

        val_loss, val_wer, predictions, references = evaluate(model, val_dataloader, device, tokenizer)
        val_losses.append(val_loss)
        val_wers.append(val_wer)
        print(f"Validation Loss: {val_loss:.4f}, WER: {val_wer:.4f}")

        shown = min(3, len(predictions))
        for n in range(1, shown + 1):
            print(f"\nExample {n}:")
            print(f"Reference: {references[n - 1]}")
            print(f"Prediction: {predictions[n - 1]}")

        checkpoint = dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            train_loss=train_loss,
            val_loss=val_loss,
            val_wer=val_wer,
        )
        torch.save(checkpoint, os.path.join(checkpoint_dir, f"whisper_model_epoch_{epoch+1}.pt"))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, os.path.join(checkpoint_dir, "whisper_model_best.pt"))
            print(f"New best model saved with validation loss: {val_loss:.4f}")

    try:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        for series, name in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
            plt.plot(series, label=name)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(val_wers, label='WER')
        plt.xlabel('Epoch')
        plt.ylabel('Word Error Rate')
        plt.legend()

        plt.tight_layout()
        history_path = os.path.join(checkpoint_dir, 'training_history.png')
        plt.savefig(history_path)
        print(f"Training history saved to {history_path}")
    except ImportError:
        print("matplotlib not installed, skipping plot generation")

    return train_losses, val_losses, val_wers

In [None]:
!pip install torchcodec

In [None]:
import random
def augment_audio_features(mel_features, augment_prob=0.5):
    """Randomly perturb a spectrogram tensor.

    - With probability ``augment_prob``, add Gaussian noise with std 0.005.
    - Independently, with probability 0.3, scale by a uniform gain in [0.8, 1.2].

    Randomness comes from Python's ``random`` module (coin flips and gain) and
    torch's RNG (noise). The input tensor is not modified in place.
    """
    augmented = mel_features

    add_noise = random.random() < augment_prob
    if add_noise:
        augmented = augmented + torch.randn_like(augmented) * 0.005

    apply_gain = random.random() < 0.3
    if apply_gain:
        augmented = augmented * random.uniform(0.8, 1.2)

    return augmented



class AugmentedCommonVoiceDataset(Dataset):
    """Common Voice dataset with optional spectrogram augmentation for training.

    Reads a tab-separated file and keeps rows with ``valid == 1``. Uses the
    ``path`` (relative to ``audio_dir``) and ``sentence`` columns. When
    ``augment`` is true, features go through ``augment_audio_features`` before
    being cropped or padded.

    Each item is a dict:
        audio_features: ``(1, n_mels, max_audio_len)``, cropped or zero-padded.
        decoder_input_ids: tokenized transcript ``[:-1]`` (length 99).
        labels: tokenized transcript ``[1:]``. Positions the tokenizer marks as
            padding (attention mask 0) are set to -100 so losses with
            ``ignore_index=-100`` skip them.
        text, audio_path: raw transcript and resolved clip path.
    If anything fails, a placeholder item with all -100 labels is returned.
    """

    def __init__(self, tsv_file, audio_dir, tokenizer, processor, max_audio_len=1500, augment=False):
        # sep='\\t' is the two-character regex "\t" (matches a tab). pandas parses
        # regex separators with the Python engine; that engine is requested
        # explicitly here. The separator is kept as-is because the regex path
        # may treat quote characters differently from the C engine's sep='\t'.
        self.df = pd.read_csv(tsv_file, sep='\\t', engine='python', on_bad_lines='skip')
        self.df = self.df[self.df['valid'] == 1].reset_index(drop=True)
        self.audio_dir = audio_dir
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_audio_len = max_audio_len
        self.augment = augment
        print(f"Dataset loaded with {len(self.df)} valid entries (augment={augment})")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]
            audio_path = os.path.join(self.audio_dir, row['path'])
            text = row['sentence']

            mel_features = self.processor.extract_fbank(audio_path)

            if self.augment:
                mel_features = augment_audio_features(mel_features, augment_prob=0.5)

            encoded_text = self.tokenizer(text, return_tensors="pt", padding="max_length",
                                          max_length=100, truncation=True)
            input_ids = encoded_text.input_ids.squeeze(0)
            # Teacher forcing: the decoder sees tokens [:-1] and predicts tokens [1:].
            decoder_input_ids = input_ids[:-1].clone()
            labels = input_ids[1:].clone()
            # Find padding via the attention mask, not the pad id. For the
            # Whisper tokenizer the pad and EOS tokens appear to share an id
            # (verify), so an id comparison would also drop the real EOS target.
            attention_mask = encoded_text.get("attention_mask")
            if attention_mask is not None:
                labels[attention_mask.squeeze(0)[1:] == 0] = -100

            if mel_features.size(1) > self.max_audio_len:
                mel_features = mel_features[:, :self.max_audio_len]
            elif mel_features.size(1) < self.max_audio_len:
                pad_len = self.max_audio_len - mel_features.size(1)
                mel_features = F.pad(mel_features, (0, pad_len))

            mel_features = mel_features.unsqueeze(0)

            return {
                "audio_features": mel_features,
                "decoder_input_ids": decoder_input_ids,
                "labels": labels,
                "text": text,
                "audio_path": audio_path
            }
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            return {
                "audio_features": torch.zeros(1, self.processor.n_mels, self.max_audio_len),
                "decoder_input_ids": torch.zeros(99, dtype=torch.long),
                # -100 = ignored by the loss; don't train on a fabricated target.
                "labels": torch.full((99,), -100, dtype=torch.long),
                "text": "",
                "audio_path": ""
            }


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    For each non-ignored target, the true class gets ``1 - smoothing`` and
    every other class gets ``smoothing / (vocab_size - 1)``. The loss is
    averaged over non-ignored positions.

    Targets equal to ``ignore_index`` contribute nothing. If every target is
    ignored, a zero loss that is still attached to the graph is returned, so
    ``backward()`` remains valid.
    """

    def __init__(self, smoothing=0.1, ignore_index=-100):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        """``logits``: ``(..., vocab)``; ``targets``: matching leading shape."""
        vocab_size = logits.size(-1)
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        mask = targets_flat != self.ignore_index
        if not mask.any():
            # Mean over zero elements would be NaN; return a graph-connected 0.
            return (logits_flat.float() * 0.0).sum()

        # Select valid rows first. This avoids scatter_ with the (negative)
        # ignore_index, which is out of range, and skips log_softmax work on
        # padded rows.
        log_probs = F.log_softmax(logits_flat[mask], dim=-1)
        valid_targets = targets_flat[mask]

        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.smoothing / (vocab_size - 1))
            true_dist.scatter_(1, valid_targets.unsqueeze(1), 1.0 - self.smoothing)

        return (-true_dist * log_probs).sum(dim=-1).mean()


def train_epoch_improved(model, dataloader, optimizer, scheduler, device,
                        scaler=None, gradient_accumulation_steps=4,
                        label_smoothing=0.1):
    """One training epoch with label smoothing, gradient accumulation and optional AMP.

    Args:
        model: callable ``model(audio_features, decoder_input_ids) -> logits``.
        dataloader: sized iterable of batches from ``collate_batch``.
        optimizer: torch optimizer; its first param group's LR is shown.
        scheduler: accepted for interface compatibility but not stepped here.
            The caller steps it on the validation metric.
        device: device the batch tensors are moved to.
        scaler: optional ``GradScaler``. When given, forward runs under CUDA
            autocast and backward/step go through the scaler.
        gradient_accumulation_steps: the optimizer steps every this many
            batches, and always on the final batch.
        label_smoothing: smoothing passed to ``LabelSmoothingLoss``.

    CUDA out-of-memory batches are reported, their partially accumulated
    gradients are dropped, and the epoch continues. Gradients are clipped to
    norm 0.5 before each step.

    Returns:
        Mean per-batch (un-divided) loss over the batches that completed, or
        ``nan`` if none did.
    """
    model.train()
    total_loss = 0.0
    completed_batches = 0
    num_batches = len(dataloader)

    criterion = LabelSmoothingLoss(smoothing=label_smoothing)

    with tqdm(total=num_batches, desc="Training") as pbar:
        for i, batch in enumerate(dataloader):
            # Step at the end of each accumulation window and on the last batch
            # so no gradients are left unapplied.
            should_step = (i + 1) % gradient_accumulation_steps == 0 or i == num_batches - 1
            try:
                audio_features = batch["audio_features"].to(device)
                decoder_input_ids = batch["decoder_input_ids"].to(device)
                labels = batch["labels"].to(device)

                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        logits = model(audio_features, decoder_input_ids)
                        loss = criterion(logits, labels) / gradient_accumulation_steps

                    scaler.scale(loss).backward()

                    if should_step:
                        # Unscale first so clipping sees the true gradient norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                else:
                    logits = model(audio_features, decoder_input_ids)
                    loss = criterion(logits, labels) / gradient_accumulation_steps
                    loss.backward()

                    if should_step:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                        optimizer.step()
                        optimizer.zero_grad()

                # One host sync per batch; undo the accumulation division for reporting.
                batch_loss = loss.item() * gradient_accumulation_steps
                total_loss += batch_loss
                completed_batches += 1

                current_lr = optimizer.param_groups[0]['lr']
                pbar.set_postfix({'loss': f'{batch_loss:.4f}',
                                  'lr': f'{current_lr:.6f}'})

                del audio_features, decoder_input_ids, labels, logits, loss

            except torch.cuda.OutOfMemoryError:
                print(f"CUDA OOM error for batch {i}, skipping...")
                torch.cuda.empty_cache()
                optimizer.zero_grad()
            finally:
                # Advance the bar for skipped batches too, so it ends at num_batches.
                pbar.update(1)

    return total_loss / completed_batches if completed_batches else float('nan')

In [None]:
def _plot_resumed_history(checkpoint_dir, start_epoch, train_losses, val_losses, val_wers):
    """Save loss and WER curves of a resumed run to ``training_history_improved.png``.

    Best-effort: prints a message and returns if matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\n⚠ matplotlib not installed, skipping plot")
        return

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # The resumed run's first epoch is numbered start_epoch + 1 (1-based).
    epochs_range = range(start_epoch + 1, start_epoch + 1 + len(train_losses))

    axes[0].plot(epochs_range, train_losses, 'b-', label='Train Loss', linewidth=2)
    axes[0].plot(epochs_range, val_losses, 'r-', label='Val Loss', linewidth=2)
    axes[0].axvline(x=start_epoch, color='g', linestyle='--', label='Resume Point')
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training and Validation Loss (Improved)', fontsize=14)
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(epochs_range, val_wers, 'g-', linewidth=2)
    axes[1].axvline(x=start_epoch, color='g', linestyle='--', label='Resume Point')
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Word Error Rate', fontsize=12)
    axes[1].set_title('Validation WER (Improved)', fontsize=14)
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)

    fig.tight_layout()
    plot_path = os.path.join(checkpoint_dir, 'training_history_improved.png')
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Training history saved to {plot_path}")
    plt.close(fig)


def main():
    """Resume training of the from-scratch Whisper-style model from an epoch checkpoint.

    Steps:
    - Load model weights from ``whisper_model_epoch_{resume_epoch}.pt``.
    - Continue training with a fresh AdamW optimizer (LR 1e-5), label
      smoothing, training-set augmentation, and a ReduceLROnPlateau scheduler
      driven by validation WER. The optimizer/scheduler state in the checkpoint
      is not restored.
    - Each epoch, save a checkpoint, plus best-WER and best-loss copies.
    - Stop early after ``patience`` epochs without improvement.
    - Finally, plot the history and print a summary against the resumed
      checkpoint's WER.

    All paths are hard-coded Google Drive locations.
    """
    config = WhisperConfig(
        n_mels=80,
        n_ctx=1500,
        n_heads=8,
        n_audio_layers=4,
        n_text_layers=4,
        n_embed=512,
        n_audio_ctx=1500,
        vocab_size=51865,
        sample_rate=16000,
        dropout=0.2
    )

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")
    processor = AudioProcessor(config)

    use_amp = torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    train_tsv = "/content/drive/MyDrive/whisper/splits/train.tsv"
    val_tsv = "/content/drive/MyDrive/whisper/splits/valid.tsv"
    audio_dir = "/content/drive/MyDrive/whisper/cv-corpus-21.0-2025-03-14/mn/clips"
    checkpoint_dir = "/content/drive/MyDrive/whisper/whisper_checkpoints/"
    # 1-based epoch number of the checkpoint to resume from (as in the file names).
    resume_epoch = 24

    print("Loading datasets...")
    train_dataset = AugmentedCommonVoiceDataset(
        train_tsv, audio_dir, tokenizer, processor, augment=True
    )
    val_dataset = AugmentedCommonVoiceDataset(
        val_tsv, audio_dir, tokenizer, processor, augment=False
    )

    train_dataloader = DataLoader(
        train_dataset, batch_size=4, shuffle=True,
        collate_fn=collate_batch, num_workers=2, pin_memory=True
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=4, shuffle=False,
        collate_fn=collate_batch, num_workers=2, pin_memory=True
    )

    model = WhisperModel(config).to(device)

    checkpoint_path = os.path.join(checkpoint_dir, f"whisper_model_epoch_{resume_epoch}.pt")

    if not os.path.exists(checkpoint_path):
        print(f"ERROR: Checkpoint not found at {checkpoint_path}")
        print("Available checkpoints:")
        # listdir raises if the directory itself is missing (e.g. Drive not mounted).
        if os.path.isdir(checkpoint_dir):
            for f in os.listdir(checkpoint_dir):
                if f.endswith('.pt'):
                    print(f"  - {f}")
        return

    print(f"Loading checkpoint from epoch {resume_epoch}: {checkpoint_path}")
    # weights_only=False unpickles arbitrary Python objects (these checkpoints
    # store `config`); only load checkpoint files you produced yourself.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    start_epoch = resume_epoch
    # Baseline WER of the resumed checkpoint, used in the final summary.
    initial_val_wer = checkpoint.get('val_wer')
    print(f"Successfully loaded checkpoint from epoch {start_epoch}")
    print(f"Previous WER: {checkpoint.get('val_wer', 'N/A')}")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=1e-5,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )

    # Stepped once per epoch on validation WER (mode='min').
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-7
    )

    print("\nModel Statistics:")
    print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Initial learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    num_epochs = 40
    patience = 8
    best_val_loss = checkpoint.get('val_loss', float('inf'))
    best_val_wer = checkpoint.get('val_wer', float('inf'))
    patience_counter = 0

    train_losses = []
    val_losses = []
    val_wers = []

    print(f"\n{'='*60}")
    print(f"Starting IMPROVED training from epoch {start_epoch+1}")
    print("Improvements: Lower LR, Label Smoothing, Augmentation, LR Scheduler")
    print(f"{'='*60}")

    for epoch in range(start_epoch, num_epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*60}")

        train_loss = train_epoch_improved(
            model, train_dataloader, optimizer, scheduler, device,
            scaler=scaler, gradient_accumulation_steps=4,
            label_smoothing=0.1
        )
        train_losses.append(train_loss)
        print(f"Train Loss: {train_loss:.4f}")

        val_loss, val_wer, predictions, references = evaluate(
            model, val_dataloader, device, tokenizer
        )
        val_losses.append(val_loss)
        val_wers.append(val_wer)
        print(f"Validation Loss: {val_loss:.4f}, WER: {val_wer:.4f}")

        scheduler.step(val_wer)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning Rate: {current_lr:.6f}")

        print("\nSample Predictions:")
        for i in range(min(3, len(predictions))):
            print(f"\n  Example {i+1}:")
            print(f"    Reference:  {references[i]}")
            print(f"    Prediction: {predictions[i]}")

        # Distinct name so the resumed `checkpoint` stays intact for the summary.
        epoch_checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_wer': val_wer,
            'config': config
        }

        torch.save(epoch_checkpoint, os.path.join(checkpoint_dir, f"whisper_model_epoch_{epoch+1}.pt"))

        improved = False

        if val_wer < best_val_wer:
            previous_best_wer = best_val_wer
            best_val_wer = val_wer
            torch.save(epoch_checkpoint, os.path.join(checkpoint_dir, "whisper_model_best_wer.pt"))
            # A relative improvement needs a finite, positive previous best.
            if 0 < previous_best_wer < float('inf'):
                improvement = ((previous_best_wer - val_wer) / previous_best_wer) * 100
                print(f"\n✓ NEW BEST WER! {val_wer:.4f} (improved by {improvement:.2f}%)")
            else:
                print(f"\n✓ NEW BEST WER! {val_wer:.4f}")
            improved = True

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(epoch_checkpoint, os.path.join(checkpoint_dir, "whisper_model_best_loss.pt"))
            print(f"✓ New best loss! {val_loss:.4f}")
            improved = True

        if improved:
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"\nNo improvement ({patience_counter}/{patience} patience)")

        if patience_counter >= patience:
            print(f"\n⚠ Early stopping triggered after {patience} epochs without improvement")
            break

    _plot_resumed_history(checkpoint_dir, start_epoch, train_losses, val_losses, val_wers)

    print("\n" + "="*60)
    print("TRAINING COMPLETED!")
    print(f"Best Validation Loss: {best_val_loss:.4f}")
    print(f"Best WER: {best_val_wer:.4f}")
    print(f"Starting WER (epoch {start_epoch}): {initial_val_wer if initial_val_wer is not None else 'N/A'}")
    # Relative improvement is undefined without a non-zero baseline.
    if initial_val_wer:
        print(f"Improvement: {((initial_val_wer - best_val_wer) / initial_val_wer) * 100:.2f}%")
    print("="*60)

# Start training when this code runs as the top-level program. IPython/Jupyter
# kernels also execute cells with __name__ == "__main__", so running this cell
# calls main() directly.
if __name__ == "__main__":
    main()