In [1]:
!pip install torch transformers datasets soundfile librosa jiwer



In [2]:
from datasets import load_dataset, Audio

from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer
    )

import torch
import torchaudio

import torch.nn as nn

import os

https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-russian

In [3]:
os.environ["WANDB_DISABLED"] = "true"

In [22]:
dataset = load_dataset(
    "bond005/sberdevices_golos_10h_crowd",
    split="train+validation[:50]"
    )

dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [23]:
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-russian"
processor = Wav2Vec2Processor.from_pretrained(model_name)

model = Wav2Vec2ForCTC.from_pretrained(
    model_name,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

In [24]:
def preprocess_text(batch):
    batch["transcription"] = batch["transcription"].lower().replace("ё", "е")
    return batch


dataset = dataset.map(preprocess_text)

In [25]:
def prepare_dataset(batch):
    input_values = processor(
        batch["audio"]["array"],
        sampling_rate=16000,
        padding="longest",
        return_tensors="pt"
    ).input_values.squeeze(0)

    labels = processor.tokenizer.encode(batch["transcription"])

    return {
        "input_values": input_values,
        "labels": torch.tensor(labels)
    }


dataset = dataset.map(
    prepare_dataset,
    remove_columns=["audio", "transcription"],
    #cache_file_name="./cache.golos"
    )

In [26]:
def collate_fn(batch):
    input_values = [torch.tensor(item["input_values"]) for item in batch]
    labels = [torch.tensor(item["labels"]) for item in batch]

    input_values = torch.nn.utils.rnn.pad_sequence(
        input_values,
        batch_first=True,
        padding_value=processor.feature_extractor.padding_value
    )

    labels = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=-100
    )

    return {
        "input_values": input_values,
        "labels": labels
        }

In [34]:
for param in model.wav2vec2.feature_extractor.parameters():
    param.requires_grad = False

In [38]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    fp16=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    logging_steps=100,
    save_steps=500,
    report_to="none",
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

In [44]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
)

In [45]:
torch.cuda.empty_cache()

In [46]:
print(f"memory: {torch.cuda.mem_get_info()[0]/1024**3:.2f}")  # gb

memory: 6.65


In [47]:
trainer.train()



Step,Training Loss
100,1.1458
200,0.8222
300,0.8488
400,1.1148
500,1.2428
600,1.1822
700,1.7285
800,1.6838
900,1.7388
1000,1.73




TrainOutput(global_step=2010, training_loss=1.2137901643022375, metrics={'train_runtime': 2978.4067, 'train_samples_per_second': 5.401, 'train_steps_per_second': 0.675, 'total_flos': 2.414740532868259e+18, 'train_loss': 1.2137901643022375, 'epoch': 1.9985082048731975})

In [48]:
model.save_pretrained("./stt_model")
processor.save_pretrained("./stt_model")

[]

In [49]:
import librosa

In [51]:
def transcribe_audio(file_path):
    model_path = "./stt_model"
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = Wav2Vec2ForCTC.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(torch.float32)  # Явно float32

    audio, sr = librosa.load(file_path, sr=16000)

    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    ).input_values.to(model.dtype)

    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(inputs).logits

    pred_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(pred_ids)[0]


print(transcribe_audio("Sound_08129.wav"))

внимания говорить и показывает москва работают все центральные каналы телевидения смотрете и слушатеи москву
