In [5]:
! pip install transformers accelerate datasets torch pydub torchinfo soundfile librosa



In [6]:
from transformers import VitsModel, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import librosa
from torchinfo import summary
import torch.nn.functional as F

# Checkpoint shared by the model and its tokenizer
MODEL_ID = "facebook/mms-tts-swe"
# Fixed seed so the train/validation split is the same on every Restart & Run All
SPLIT_SEED = 42

# Load the model and tokenizer
model = VitsModel.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the dataset (records are expected to carry "text" and "audio" fields,
# which the preprocessing step below reads)
dataset = load_dataset("json", data_files='./training_data/01_data.json')

# Split the dataset into train and validation sets (90% / 10%), reproducibly
dataset = dataset["train"].train_test_split(test_size=0.1, seed=SPLIT_SEED)

# Define the maximum length for tokenization
max_length = 512  # Adjust this value based on your needs

# Define a preprocessing function to tokenize the inputs
def preprocess_function(examples, target_sr=16000):
    """Tokenize a batch of texts and load the matching audio clips.

    Args:
        examples: Batch dict with a "text" list of strings and an "audio"
            list of audio file paths, one per text.
        target_sr: Sampling rate in Hz that every clip is resampled to
            while loading.

    Returns:
        Dict with the tokenizer's "input_ids" and "attention_mask"
        (padded/truncated to the module-level ``max_length``) and
        "audio_values", a list holding one float32 tensor per audio path.
    """
    # Tokenize the text inputs; fixed-length padding keeps every row the same size
    inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

    audio_values = []
    for audio_path in examples["audio"]:
        # Passing ``sr`` makes librosa resample during loading, so no separate
        # resample step is needed.
        audio, _ = librosa.load(audio_path, sr=target_sr)
        audio_values.append(torch.tensor(audio, dtype=torch.float32))

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "audio_values": audio_values
    }

# Get a summary of the model (torchinfo table; depth=3 limits how many levels
# of nested submodules are listed)
model_summary = summary(model, depth=3)
# Print the returned summary object explicitly so the table appears in the cell output
print(model_summary)

# Apply the preprocessing function to the training and validation datasets
def preprocess_batch(batch):
    """Adapter passed to ``map``: forwards the batch to ``preprocess_function`` unchanged."""
    processed = preprocess_function(batch)
    return processed

# Run the preprocessing in batches over the split dataset; the raw "audio" and
# "text" columns are removed, leaving only the columns the preprocessing returns.
tokenized_datasets = dataset.map(preprocess_batch, batched=True, remove_columns=["audio", "text"])

# Define custom data collator
def data_collator(features):
    """Collate preprocessed examples into one batch of tensors.

    Args:
        features: List of example dicts. Each has "input_ids" and
            "attention_mask" (same length for every example) and optionally
            "audio_values", a 1-D sequence of audio samples whose length may
            differ between examples.

    Returns:
        Dict with "input_ids" and "attention_mask" stacked along a new batch
        dimension, and "labels": the audio clips right-padded with zeros to
        the longest clip in the batch, or None when the examples carry no
        "audio_values".
    """
    input_ids = torch.stack([torch.tensor(f["input_ids"]) for f in features])
    attention_mask = torch.stack([torch.tensor(f["attention_mask"]) for f in features])

    if "audio_values" in features[0]:
        # Clips differ in duration, so torch.stack would raise on them;
        # pad every clip with zeros up to the longest one instead.
        clips = [torch.as_tensor(f["audio_values"], dtype=torch.float32) for f in features]
        audio_values = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True)
    else:
        audio_values = None

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": audio_values
    }

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",          # output directory
    num_train_epochs=3,              # number of training epochs
    per_device_train_batch_size=4,   # batch size for training
    per_device_eval_batch_size=4,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir="./logs",            # directory for storing logs
    logging_steps=10,                # log training metrics every 10 steps
    save_steps=500,                  # save model every 500 steps
    eval_strategy="steps",           # evaluation strategy to use
    eval_steps=500,                  # evaluation step interval
    # Keep the "audio_values" column: it is not an argument of the model's
    # forward(), so the Trainer's default column pruning would drop it before
    # the data collator can turn it into "labels".
    remove_unused_columns=False,
)

# Custom Trainer class
class CustomTrainer(Trainer):
    """Trainer whose loss compares the synthesized waveform with the reference audio."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute a sample-wise MSE between generated and reference audio.

        Args:
            model: The model being trained; called with ``input_ids`` and
                ``attention_mask`` only.
            inputs: Batch dict with "input_ids", "attention_mask" and
                "labels" (reference audio, batch first, samples last).
            return_outputs: When True, also return the model outputs.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it; not used by this loss.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: If the batch carries no audio labels.
        """
        labels = inputs["labels"]
        if labels is None:
            raise ValueError(
                "Batch has no audio labels; make sure the 'audio_values' column "
                "reaches the data collator."
            )

        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

        # The labels are raw audio samples, so they are compared with the generated
        # waveform; outputs.spectrogram has a different shape and cannot be used here.
        waveform = outputs.waveform

        # Generated and reference audio generally differ in duration; score only
        # the leading part both tensors share.
        n_samples = min(waveform.shape[-1], labels.shape[-1])
        # NOTE(review): plain MSE is kept "for simplicity" as in the original; it is
        # not the full VITS training objective.
        loss = F.mse_loss(waveform[..., :n_samples], labels[..., :n_samples].to(waveform.dtype))

        return (loss, outputs) if return_outputs else loss

# Initialize Custom Trainer: wires the model, arguments, both dataset splits,
# the tokenizer and the custom collator together
trainer = CustomTrainer(
    model=model,                      # the instantiated Transformers model to be trained
    args=training_args,               # training arguments, defined above
    train_dataset=tokenized_datasets['train'],  # training split produced by train_test_split
    eval_dataset=tokenized_datasets['test'],    # held-out split used for evaluation
    tokenizer=tokenizer,              # tokenizer
    data_collator=data_collator       # custom data collator that builds the "labels" tensor
)

# Train the model (the loss comes from CustomTrainer.compute_loss)
trainer.train()

# Save the final model to ./final_model
trainer.save_model("./final_model")


Some weights of the model checkpoint at facebook/mms-tts-swe were not used when initializing VitsModel: ['flow.flows.0.wavenet.in_layers.0.weight_g', 'flow.flows.0.wavenet.in_layers.0.weight_v', 'flow.flows.0.wavenet.in_layers.1.weight_g', 'flow.flows.0.wavenet.in_layers.1.weight_v', 'flow.flows.0.wavenet.in_layers.2.weight_g', 'flow.flows.0.wavenet.in_layers.2.weight_v', 'flow.flows.0.wavenet.in_layers.3.weight_g', 'flow.flows.0.wavenet.in_layers.3.weight_v', 'flow.flows.0.wavenet.res_skip_layers.0.weight_g', 'flow.flows.0.wavenet.res_skip_layers.0.weight_v', 'flow.flows.0.wavenet.res_skip_layers.1.weight_g', 'flow.flows.0.wavenet.res_skip_layers.1.weight_v', 'flow.flows.0.wavenet.res_skip_layers.2.weight_g', 'flow.flows.0.wavenet.res_skip_layers.2.weight_v', 'flow.flows.0.wavenet.res_skip_layers.3.weight_g', 'flow.flows.0.wavenet.res_skip_layers.3.weight_v', 'flow.flows.1.wavenet.in_layers.0.weight_g', 'flow.flows.1.wavenet.in_layers.0.weight_v', 'flow.flows.1.wavenet.in_layers.1.wei

In [7]:
from IPython.display import Audio

text = "Ahmad satt på bussen. Han var på väg till sin prao-plats. Skolan hade fixat det. Det var något trist företag han aldrig hade hört talats om."
# Training may have moved the model to an accelerator; put the inputs on the
# same device so the forward pass does not fail with a device mismatch.
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Switch to inference mode (disables dropout) before synthesizing
model.eval()
with torch.no_grad():
    output = model(**inputs).waveform

# Hand the audio player a host-memory NumPy array rather than a (possibly GPU) tensor
Audio(output.cpu().numpy(), rate=model.config.sampling_rate)