In [None]:
# !cp /content/drive/MyDrive/dls.tar.gz .
# !tar -xzf dls.tar.gz
# !pip install evaluate
# !pip install jiwer
# !pip install resampy

In [None]:
from functools import cached_property
from typing import Any, Dict, List, Optional, Union
import torch
import resampy
import numpy as np
import pandas as pd
import soundfile as sf
from transformers import (
    WhisperTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    BatchFeature
)


class MultiligualTokenizer:
    """Builds Whisper-style label sequences for a given language.

    Wraps a ``WhisperTokenizer`` and produces token id lists of the form
    ``<|startoftranscript|> <|lang|> <|transcribe|> <|notimestamps|> text <|endoftext|>``.

    Language names are resolved through the module-level ``TO_LANGUAGE_CODE``
    mapping, which must be in scope (and already contain any custom languages)
    at the moment an instance is created.
    """

    def __init__(self, tokenizer: WhisperTokenizer):
        """Snapshot the language table and the tokenizer vocabulary.

        Args:
            tokenizer: tokenizer whose vocabulary already contains every
                special token that will be requested later.
        """
        self._tokenizer = tokenizer
        # Copy of language name -> language code, taken at construction time.
        self._lang2code = dict(TO_LANGUAGE_CODE)
        # Token string -> token id, taken at construction time; tokens added to
        # the tokenizer afterwards are not visible here.
        self.vocab = self._tokenizer.get_vocab()

    def language_token(self, language: str) -> str:
        """Return the special token string (e.g. ``<|ru|>``) for a language name.

        The lookup is case-insensitive.

        Raises:
            KeyError: if the language name is not in the language table.
        """
        language = language.lower()
        if language not in self._lang2code:
            raise KeyError(f"Language {language} not found in tokenizer.")
        return f"<|{self._lang2code[language]}|>"

    def language_id(self, language: str) -> int:
        """Return the vocabulary id of the language token for ``language``.

        Raises:
            KeyError: if the language is unknown or its token is missing from
                the vocabulary snapshot.
        """
        return self.vocab[self.language_token(language)]

    @cached_property
    def sot(self) -> int:
        """Id of ``<|startoftranscript|>``."""
        return self.vocab["<|startoftranscript|>"]

    @cached_property
    def eot(self) -> int:
        """Id of ``<|endoftext|>``."""
        return self.vocab["<|endoftext|>"]

    @cached_property
    def no_timestamps(self) -> int:
        """Id of ``<|notimestamps|>``."""
        return self.vocab["<|notimestamps|>"]

    @cached_property
    def transcribe(self) -> int:
        """Id of ``<|transcribe|>``."""
        return self.vocab["<|transcribe|>"]

    def tokenize(self, text: str, language: str) -> List[int]:
        """Encode ``text`` as a full label sequence for ``language``.

        Args:
            text: transcription; surrounding whitespace is stripped and a
                single leading space is prepended before encoding.
            language: language name understood by ``language_token``.

        Returns:
            A flat list of token ids: the four-token prefix, the text tokens
            and the end-of-text id.
        """
        text_tokens = self._tokenizer.encode(" " + text.strip(), add_special_tokens=False)
        sot_sequence = [self.sot, self.language_id(language), self.transcribe, self.no_timestamps]
        return sot_sequence + text_tokens + [self.eot]



class WhisperDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding Whisper input features and label token ids.

    Each manifest is a tab-separated file with at least the columns ``path``
    (audio file) and ``transcription`` (reference text). All manifests are
    concatenated into one flat list of samples.
    """

    def __init__(
        self,
        manifests_files: List[str],
        languages: List[str],
        processor: WhisperProcessor,
        group_weights: Optional[List[float]] = None,
        dataset_name: Optional[str] = None,
        sampling_rate: Optional[int] = None
    ):
        """Load every manifest into memory.

        Args:
            manifests_files: paths to the TSV manifests.
            languages: language name for each manifest, in the same order.
            processor: supplies the tokenizer and the feature extractor.
            group_weights: accepted for interface compatibility; currently
                not used by this class.
            dataset_name: label copied into every returned sample.
            sampling_rate: target sampling rate in Hz; defaults to 16000.
        """
        assert len(manifests_files) == len(languages), (
            "manifests_files and languages must have the same length"
        )
        if sampling_rate is None:
            sampling_rate = 16000
        self.sampling_rate = sampling_rate
        self.dataset_name = dataset_name
        self.tokenizer = MultiligualTokenizer(tokenizer=processor.tokenizer)
        self.feature_extractor = processor.feature_extractor
        self.data = []
        for dataset_id, (lang, manifest_path) in enumerate(zip(languages, manifests_files)):
            manifest = pd.read_csv(manifest_path, delimiter="\t")
            # Iterate the two needed columns directly instead of materialising
            # a Series per row with iterrows().
            self.data.extend(
                {
                    "dataset_id": dataset_id,
                    "path": audio_path,
                    "transcription": transcription,
                    "lang": lang,
                }
                for audio_path, transcription in zip(manifest["path"], manifest["transcription"])
            )

    def __getitem__(self, idx) -> Dict[str, Any]:
        """Read, featurize and tokenize the sample at ``idx``.

        Returns:
            A dict with ``labels`` (list of token ids), ``input_features``
            (first element of the feature extractor output), ``language`` and
            ``dataset_name``.
        """
        item = self.data[idx]
        audio = self._read_audio(item["path"])
        features = self.feature_extractor(
            audio, sampling_rate=self.sampling_rate, padding="max_length"
        )
        return {
            "labels": self.tokenizer.tokenize(text=item["transcription"], language=item["lang"]),
            "input_features": features.input_features[0],
            "language": item["lang"],
            "dataset_name": self.dataset_name,
        }

    def _read_audio(self, audio_file) -> np.ndarray:
        """Load an audio file as a 1-D signal at ``self.sampling_rate``.

        2-D audio is averaged over axis 1 (presumably the channel axis of the
        soundfile output) and resampled when the file's rate differs from the
        target rate.
        """
        audio, sr = sf.read(audio_file)
        if audio.ndim == 2:
            audio = np.mean(audio, axis=1)
        if sr != self.sampling_rate:
            audio = resampy.resample(audio, sr, self.sampling_rate)
        return audio

    def __len__(self):
        """Number of samples across all manifests."""
        return len(self.data)


class WhisperDataCollator:
    """Turns a list of ``WhisperDataset`` samples into one padded batch."""

    def __call__(
        self, inputs: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> BatchFeature:
        """Stack features, right-pad labels and gather per-sample metadata.

        Args:
            inputs: samples with the keys ``input_features``, ``labels``,
                ``language`` and ``dataset_name``.

        Returns:
            A ``BatchFeature`` holding a float tensor of stacked features, a
            long tensor of labels padded with -100, and the lists of languages
            and dataset names in sample order.
        """
        # Feature arrays are stacked as-is, so they must all share one shape.
        stacked_features = torch.FloatTensor(
            np.stack([sample["input_features"] for sample in inputs], axis=0)
        )

        # Pad every label sequence on the right up to the longest one.
        # -100 is the padding marker; the evaluation code maps it back to a
        # real token id before decoding.
        label_seqs = [sample["labels"] for sample in inputs]
        longest = max(len(seq) for seq in label_seqs)
        padded_labels = torch.LongTensor(
            [seq + [-100] * (longest - len(seq)) for seq in label_seqs]
        )

        sample_languages = [sample["language"] for sample in inputs]
        sample_datasets = [sample["dataset_name"] for sample in inputs]

        batch = {
            "input_features": stacked_features,
            "labels": padded_labels,
            "language": sample_languages,
            "dataset_name": sample_datasets,
        }
        return BatchFeature(batch)

In [None]:
from tqdm.notebook import tqdm
from collections import Counter
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import WordErrorRate



class TrainingConfig:
    """Paths and hyperparameters for the Whisper fine-tuning run.

    All values are plain class attributes, so they can be read either from the
    class or from an instance.
    """

    # Paths
    model_name: str = "openai/whisper-small"
    dataset_path: str = "./data"
    output_dir: str = "./whisper-finetuned"

    # Hyperparameters
    batch_size: int = 8
    eval_batch_size: int = 4
    learning_rate: float = 1e-4
    num_epochs: int = 3
    warmup_steps: int = 500
    max_grad_norm: float = 1.0
    weight_decay: float = 0.01

    # Data settings
    max_length: int = 448
    max_target_length: int = 128
    sampling_rate: int = 16000


class Trainer:
    """Minimal fine-tuning loop for Whisper with per-dataset WER evaluation.

    Attributes:
        grads: one list of per-step gradient norms (floats) per training epoch.
        langs: one ``Counter`` of sample languages per training epoch.
        datasets: one ``Counter`` of sample dataset names per training epoch.
        train_losses: average loss per epoch from the latest ``train`` call.
        eval_wers: WER dicts collected by the latest ``train`` call.
    """

    def __init__(
            self,
            model: WhisperForConditionalGeneration,
            processor: WhisperProcessor,
            train_dataloader: DataLoader,
            test_dataloaders: Dict[str, DataLoader],
            config: TrainingConfig,
        ):
        """Move the model to the available device and build optimizer/scheduler.

        Args:
            model: model to fine-tune; moved to CUDA when available, else CPU.
            processor: used to decode predictions and references.
            train_dataloader: yields training batches.
            test_dataloaders: evaluation loaders keyed by a display name.
            config: supplies learning rate, weight decay, epoch count,
                gradient clipping norm and generation length.
        """
        self.grads = []
        self.langs = []
        self.datasets = []
        self.train_losses = []
        self.eval_wers = []

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.processor = processor
        self.config = config
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        # The cosine schedule spans batches-per-epoch * config.num_epochs
        # optimizer steps; it is stepped once per batch in train_step.
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=len(train_dataloader) * self.config.num_epochs
        )
        self.train_dataloader = train_dataloader
        self.test_dataloaders = test_dataloaders
        self.wer_metric = WordErrorRate()

    def train_step(self) -> float:
        """Run one pass over the training loader.

        Appends this epoch's gradient norms and language/dataset counts to
        ``grads``, ``langs`` and ``datasets`` and prints the mean gradient norm.

        Returns:
            The mean loss over the batches of this epoch (0.0 when the loader
            yields no batches).
        """
        epoch_loss = 0.
        self.model.train()
        languages = []
        dataset_names = []
        grad_norms = []
        for batch in tqdm(self.train_dataloader):
            input_features = batch["input_features"].to(self.device)
            labels = batch["labels"].to(self.device)
            languages += batch["language"]
            dataset_names += batch["dataset_name"]
            outputs = self.model(
                    input_features=input_features,
                    labels=labels,
                    return_dict=True
            )
            loss = outputs.loss

            self.optimizer.zero_grad()
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
            )
            self.optimizer.step()
            self.scheduler.step()
            epoch_loss += loss.item()
            # Store a plain float so the history does not keep device tensors alive.
            grad_norms.append(float(grad_norm))

        num_batches = len(grad_norms)
        if num_batches:
            print(sum(grad_norms) / num_batches)
        self.grads.append(grad_norms)
        cnt_lang = Counter(languages)
        cnt_dataset = Counter(dataset_names)
        self.langs.append(cnt_lang)
        self.datasets.append(cnt_dataset)
        # Disabled sanity checks on the composition of the training data:
        # assert cnt_lang["kyrgyz"] / len(languages) > 0.3, "Kyrgyz must make up > 30% of the training data"
        # assert cnt_lang["kyrgyz"] / len(languages) < 0.5, "Kyrgyz must make up < 50% of the training data"

        # assert cnt_dataset["fleurs_ky"] / cnt_dataset["common_voice_ky"] > 0.4, "fleurs_ky should be about half of the Kyrgyz part"
        # assert cnt_dataset["fleurs_ky"] / cnt_dataset["common_voice_ky"] < 0.6, "fleurs_ky should be about half of the Kyrgyz part"

        # assert cnt_dataset["fleurs_ru"] / cnt_dataset["common_voice_ru"] > 0.2, "fleurs_ru share of the Russian part must be > 0.2"
        # assert cnt_dataset["fleurs_ru"] / cnt_dataset["common_voice_ru"] < 0.3, "fleurs_ru share of the Russian part must be < 0.3"

        # Guard against an empty loader instead of dividing by zero.
        return epoch_loss / num_batches if num_batches else 0.0

    @torch.no_grad()
    def eval_step(self) -> Dict[str, Any]:
        """Greedy-decode every evaluation loader and compute its WER.

        Returns:
            A dict mapping each loader name to the value returned by the
            ``WordErrorRate`` metric for that loader's predictions.
        """
        self.model.eval()
        # Id substituted for the -100 padding marker so labels can be decoded;
        # it is removed again by skip_special_tokens=True.
        pad_token_id = self.processor.tokenizer.pad_token_id
        res = {}
        for name, test_dataloader in self.test_dataloaders.items():
            all_predictions = []
            all_references = []
            for batch in tqdm(test_dataloader):
                input_features = batch["input_features"].to(self.device)
                # Out-of-place replacement: the batch tensor is left untouched.
                labels = batch["labels"]
                labels = labels.masked_fill(labels == -100, pad_token_id)

                # Generation
                generated_ids = self.model.generate(
                    input_features=input_features,
                    max_length=self.config.max_target_length,
                    language=batch["language"],
                    num_beams=1
                )
                # Decoding
                predictions = self.processor.batch_decode(
                    generated_ids,
                    skip_special_tokens=True
                )
                references = self.processor.batch_decode(
                    labels,
                    skip_special_tokens=True
                )
                all_predictions.extend(predictions)
                all_references.extend(references)
            wer = self.wer_metric(all_predictions, all_references)
            res[name] = wer
        return res

    def train(self, epoch: int) -> List[Dict[str, Any]]:
        """Evaluate once, then alternate training epochs and evaluations.

        Progress is also written to ``self.train_losses`` and ``self.eval_wers``
        as it happens, so partial results survive an interrupted run.

        Args:
            epoch: number of training epochs to run.

        Returns:
            ``epoch + 1`` WER dicts: the baseline followed by one per epoch.
        """
        self.train_losses = []
        self.eval_wers = []
        self.eval_wers.append(self.eval_step())
        for _ in tqdm(range(epoch)):
            self.train_losses.append(self.train_step())
            self.eval_wers.append(self.eval_step())
        return self.eval_wers


In [None]:
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES


def update_vocab(model: WhisperForConditionalGeneration, processor: WhisperProcessor) -> int:
    """Register the language tokens from ``NEW_LANGUAGES`` with tokenizer and model.

    For every language code in the module-level ``NEW_LANGUAGES`` mapping the
    token ``<|code|>`` is added to the tokenizer as a special token and its id
    is stored in ``model.generation_config.lang_to_id``. The model's token
    embeddings are then resized to the tokenizer length.

    Args:
        model: model whose generation config and embeddings are updated in place.
        processor: processor whose tokenizer is extended in place.

    Returns:
        The number of tokens that were newly added to the tokenizer.
    """
    tokenizer = processor.tokenizer
    cnt_new_tokens = 0
    for code in NEW_LANGUAGES:
        token = f"<|{code}|>"
        # add_tokens reports how many tokens were actually added (0 if present).
        cnt_new_tokens += tokenizer.add_tokens(token, special_tokens=True)
        # Register the id even when the token already existed, so a freshly
        # loaded model paired with an already-extended tokenizer is set up too.
        model.generation_config.lang_to_id[token] = tokenizer.convert_tokens_to_ids(token)
    model.resize_token_embeddings(len(tokenizer))
    return cnt_new_tokens


# Kyrgyz is registered in the transformers language tables, which are
# updated in place below.
NEW_LANGUAGES = {"ky": "kyrgyz"}
NEW_TO_LANGUAGE_CODE = {"kyrgyz": "ky"}

LANGUAGES.update(NEW_LANGUAGES)
TO_LANGUAGE_CODE.update(NEW_TO_LANGUAGE_CODE)

config = TrainingConfig()

# Single source of truth for the checkpoint name.
MODEL_NAME = config.model_name

processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
cnt_new_tokens = update_vocab(model, processor)

# Training set: all four train manifests concatenated into one dataset.
dataset = WhisperDataset(
    manifests_files=[
        "dls/fleurs/ru/train/manifest.tsv",
        "dls/fleurs/ky/train/manifest.tsv",
        "dls/common_voice/ky/train/manifest.tsv",
        "dls/common_voice/ru/train/manifest.tsv",
    ],
    languages=[
        "russian",
        "kyrgyz",
        "kyrgyz",
        "russian",
    ],
    processor=processor,
)

# One evaluation dataset per (corpus, language) pair, keyed by display name.
eval_datasets = {
    "fleurs_ru": WhisperDataset(["dls/fleurs/ru/test/manifest.tsv"], ["russian"], processor),
    "fleurs_ky": WhisperDataset(["dls/fleurs/ky/test/manifest.tsv"], ["kyrgyz"], processor),
    "common_voice_ru": WhisperDataset(["dls/common_voice/ru/test/manifest.tsv"], ["russian"], processor),
    "common_voice_ky": WhisperDataset(["dls/common_voice/ky/test/manifest.tsv"], ["kyrgyz"], processor),
}

train_dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=WhisperDataCollator(),
    num_workers=1,
)

# Evaluation order is fixed (shuffle=False) so batches are reproducible
# between runs; shuffling would only change which samples share a batch.
test_dataloaders = {
    name: DataLoader(
        eval_dataset,
        batch_size=config.eval_batch_size,
        shuffle=False,
        collate_fn=WhisperDataCollator(),
        num_workers=1,
    )
    for name, eval_dataset in eval_datasets.items()
}

trainer = Trainer(model, processor, train_dataloader, test_dataloaders, config)



In [None]:
# Run the full schedule: a baseline evaluation followed by config.num_epochs
# rounds of training + evaluation. `res` is the list of per-dataset WER dicts.
res = trainer.train(config.num_epochs)

In [None]:
# Display the WER history returned by trainer.train().
res

In [None]:
# Decode a single Kyrgyz FLEURS test sample without forcing a language or task.
input_features = test_dataloaders["fleurs_ky"].dataset[212]["input_features"]
# Follow the model's device instead of assuming CUDA is available.
input_features = torch.from_numpy(input_features).unsqueeze(0).to(model.device)
generated_ids = model.generate(
                    input_features=input_features,
                    # max_length=config.max_target_length,
                    # language="ky",
                    # task="transcribe", 
                    num_beams=1
                )

In [None]:
# Inspect the raw token ids from the most recent model.generate call.
generated_ids

In [None]:
# Take one collated batch from the Kyrgyz FLEURS test loader for manual inspection.
batch = next(iter(test_dataloaders["fleurs_ky"]))

In [None]:
# Follow the model's device instead of assuming CUDA is available.
input_features = batch["input_features"].to(model.device)
labels = batch["labels"]
# Replace the -100 padding marker with the tokenizer's pad id so the labels
# can be decoded. This intentionally mutates batch["labels"] in place: the
# cell that decodes batch["labels"] relies on it.
labels[labels == -100] = processor.tokenizer.pad_token_id

# Generation
generated_ids = model.generate(
    input_features=input_features,
    max_length=config.max_target_length,
    language=batch["language"],
    num_beams=1)

In [None]:
# Decode the generated token ids to text, dropping special tokens.
processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Decode the reference labels of the same batch. This assumes the -100 padding
# in batch["labels"] has already been replaced in place by the generation cell.
processor.tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)

In [None]:
# Show the per-sample language names carried by the batch.
batch["language"]

In [None]:
# Sample audio file from the Kyrgyz FLEURS training set.
path = "dls/fleurs/ky/train/files/1.wav"

In [None]:
# The processor expects a waveform, not a file path, so the clip is loaded
# (mono, resampled to the dataset rate) before feature extraction. The former
# language="kyrgyz" argument is dropped: no text is tokenized in this call.
inputs = processor(
    dataset._read_audio(path),
    sampling_rate=dataset.sampling_rate,
    return_tensors="pt",
)