In [None]:
import os
from functools import partial
from glob import glob
from typing import Optional, Tuple, Union

import torch
import torchaudio
import wandb
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import (
    SpeechT5ForSpeechToSpeech,
    SpeechT5HifiGan,
    SpeechT5PreTrainedModel,
    SpeechT5Processor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet

# Authenticate against Weights & Biases (key read from the environment, empty
# string when the variable is unset) and open the run that receives the metrics.
_wandb_api_key = os.environ.get("WANDB_API_KEY", "")
wandb.login(key=_wandb_api_key)
wandb.init(project="clarion-ai-t5-speech-to-speech")

# # Set CUDA memory configuration
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"


def get_spectrogram_first_part(
    model: SpeechT5PreTrainedModel,
    input_values: torch.FloatTensor,
    speaker_embeddings: Optional[torch.FloatTensor],
    attention_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor], int]:
    """Run the encoder and prepare everything the decoding loop needs.

    Args:
        model: Model exposing ``speecht5.encoder``, ``config.pad_token_id`` and ``device``.
        input_values: Encoder input; moved to ``model.device`` and cast to ``torch.half``.
        speaker_embeddings: Optional speaker embedding; moved/cast the same way when given.
        attention_mask: Optional encoder attention mask. When omitted, a mask is derived
            by marking every position equal to ``model.config.pad_token_id`` as padding.

    Returns:
        A 4-tuple ``(encoder_last_hidden_state, encoder_attention_mask, speaker_embeddings, bsz)``
        where ``bsz`` is ``input_values.size(0)`` and ``speaker_embeddings`` is the
        moved/cast tensor (or ``None`` when none was supplied).
    """
    # NOTE(review): the inputs are cast to half precision unconditionally; the only
    # caller in this file invokes this under autocast() -- confirm before reusing elsewhere.
    input_values = input_values.to(model.device, dtype=torch.half)
    if speaker_embeddings is not None:
        speaker_embeddings = speaker_embeddings.to(model.device, dtype=torch.half)

    bsz = input_values.size(0)

    if attention_mask is None:
        # 1 for real positions, 0 where the input equals the pad token id.
        encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
    else:
        encoder_attention_mask = attention_mask

    encoder_out = model.speecht5.encoder(
        input_values=input_values,
        attention_mask=encoder_attention_mask,
        return_dict=True,
    )

    encoder_last_hidden_state = encoder_out.last_hidden_state

    # The speech prenet changes the sequence length, so the mask has to be
    # recomputed for the encoder's output length.
    if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
        encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
            encoder_last_hidden_state.shape[1], encoder_attention_mask
        )

    return encoder_last_hidden_state, encoder_attention_mask, speaker_embeddings, bsz


def get_spectrogram(
    model: SpeechT5PreTrainedModel,
    input_values: torch.LongTensor,
    speaker_embeddings: Optional[torch.FloatTensor],
    attention_mask: Optional[torch.LongTensor] = None,
    threshold: float = 0.5,
    minlenratio: float = 0.0,
    maxlenratio: float = 20.0,
    output_cross_attentions: bool = False,
    return_output_lengths: bool = False,
) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
    """Encode ``input_values`` and decode a spectrogram from the encoder output.

    The encoder pass is delegated to ``get_spectrogram_first_part`` and the
    step-by-step decoding to ``process_spectrogram_loop``; this function only
    derives the step bounds and the initial decoder state, and returns whatever
    the decoding loop returns.
    """
    first_part = get_spectrogram_first_part(model, input_values, speaker_embeddings, attention_mask)
    encoder_states, encoder_mask, speaker_embeddings, batch_size = first_part

    # Bounds on the number of decoding steps, relative to the encoder output length.
    source_length = encoder_states.size(1)
    reduction = model.config.reduction_factor
    max_steps = int(source_length * maxlenratio / reduction)
    min_steps = int(source_length * minlenratio / reduction)

    # Decoding is seeded with a single all-zero mel frame per batch item.
    initial_frame = encoder_states.new_zeros(batch_size, 1, model.config.num_mel_bins)

    return process_spectrogram_loop(
        model=model,
        encoder_last_hidden_state=encoder_states,
        encoder_attention_mask=encoder_mask,
        output_sequence=initial_frame,
        speaker_embeddings=speaker_embeddings,
        bsz=batch_size,
        spectrogram=[],
        cross_attentions=[],
        past_key_values=None,
        idx=0,
        result_spectrogram={},
        minlen=min_steps,
        maxlen=max_steps,
        threshold=threshold,
        output_cross_attentions=output_cross_attentions,
        return_output_lengths=return_output_lengths,
    )


def process_spectrogram_loop(
    model,
    encoder_last_hidden_state,
    encoder_attention_mask,
    output_sequence,
    speaker_embeddings,
    bsz,
    spectrogram,
    cross_attentions,
    past_key_values,
    idx,
    result_spectrogram,
    minlen,
    maxlen,
    threshold,
    output_cross_attentions,
    return_output_lengths,
):
    """Autoregressively decode mel frames until every batch item has finished.

    Args:
        model: Model exposing ``speecht5.decoder`` (``prenet``, ``wrapped_decoder``),
            ``speech_decoder_postnet`` (``feat_out``, ``prob_out``, ``postnet``) and
            ``config`` (``reduction_factor``, ``num_mel_bins``).
        encoder_last_hidden_state: Encoder output fed to the decoder cross-attention.
        encoder_attention_mask: Mask matching ``encoder_last_hidden_state``.
        output_sequence: Decoder input so far, shaped ``(bsz, steps, num_mel_bins)``.
        speaker_embeddings: Passed through to the decoder prenet.
        bsz: Batch size.
        spectrogram: List that receives one ``(bsz, reduction_factor, num_mel_bins)``
            chunk per step (mutated in place).
        cross_attentions: List that receives per-step cross-attentions when
            ``output_cross_attentions`` is true (mutated in place).
        past_key_values: Initial decoder cache (``None`` to start fresh).
        idx: Initial step counter.
        result_spectrogram: Dict mapping batch index -> finished spectrogram
            (mutated in place).
        minlen: No stop check is performed while ``idx < minlen``.
        maxlen: Once ``idx >= maxlen`` every unfinished item is force-finished.
        threshold: An item finishes when its summed stop probability reaches this value.
        output_cross_attentions: Whether to ask the decoder for attentions.
        return_output_lengths: Whether to also return per-item spectrogram lengths.

    Returns:
        * ``return_output_lengths`` false: the spectrogram of the single item when
          ``bsz == 1``, otherwise the zero-padded batch ``(bsz, max_len, num_mel_bins)``.
        * ``return_output_lengths`` true: ``(padded_batch, lengths)`` where ``lengths``
          is a list with the unpadded frame count of every item.
    """
    while True:
        idx += 1

        # Run the decoder prenet on the entire output sequence.
        decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
        # Run the decoder layers on the last element of the prenet output only;
        # earlier positions are covered by the key/value cache.
        decoder_out = model.speecht5.decoder.wrapped_decoder(
            hidden_states=decoder_hidden_states[:, -1:],
            attention_mask=None,
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=output_cross_attentions,
            return_dict=True,
        )

        if output_cross_attentions:
            # NOTE(review): these are accumulated for the caller-supplied list but are
            # not part of the return value -- confirm whether they should be returned.
            cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))

        last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
        past_key_values = decoder_out.past_key_values

        # Predict the new mel spectrum chunk for this step in the sequence.
        spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
        spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
        spectrogram.append(spectrum)

        # Extend the decoder input with the last frame of the new chunk.
        new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
        output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1)
        # Predict the probability that this is the stop token.
        prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))

        if idx < minlen:
            continue

        # Below the maximum length, only items whose stop probability met the threshold
        # finish. At the maximum length every remaining item is treated as finished.
        if idx < maxlen:
            meet_thresholds = torch.sum(prob, dim=-1) >= threshold
            meet_indexes = torch.where(meet_thresholds)[0].tolist()
        else:
            meet_indexes = range(len(prob))
        # Items that already finished keep their earlier (shorter) spectrogram.
        meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]

        if meet_indexes:
            # (steps, bsz, r, mel) -> (bsz, steps * r, mel), then refine with the postnet.
            spectrograms = torch.stack(spectrogram)
            spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
            spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
            for meet_index in meet_indexes:
                result_spectrogram[meet_index] = spectrograms[meet_index]

        if len(result_spectrogram) >= bsz:
            break

    spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]

    if return_output_lengths:
        # Previously this path fell through and returned the raw list of per-step
        # chunks; return the padded batch together with the true lengths instead.
        spectrogram_lengths = [item.size(0) for item in spectrograms]
        padded = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
        return padded, spectrogram_lengths

    if bsz == 1:
        return spectrograms[0]
    return torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)


def clip_gradients(model, max_norm=1.0):
    """Rescale the gradients of ``model`` in place so their total norm is at most ``max_norm``."""
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)


class SpeechToSpeechTrainer:
    """Trains the SpeechT5 speech encoder on single audio files.

    Each step converts one file to speech with SpeechT5 + HiFi-GAN, scores the
    result with Whisper (probability of the ``<|en|>`` token at the first decoder
    position) and combines a waveform MSE against the input with that score.
    """

    def __init__(self, model_path: str, vocoder_path: str, whisper_path: str, device: str):
        """Load the SpeechT5 model/processor, the vocoder and the Whisper scorer onto ``device``."""
        self.device = device
        self.processor = SpeechT5Processor.from_pretrained(model_path)
        self.model = SpeechT5ForSpeechToSpeech.from_pretrained(model_path).to(device)
        self.model.config.use_cache = False
        # Force use_cache=True on every call of the decoder's forward().
        self.model.speecht5.decoder.forward = partial(self.model.speecht5.decoder.forward, use_cache=True)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_path).to(device)
        self.whisper_processor = WhisperProcessor.from_pretrained(whisper_path)
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_path).to(device)
        self.whisper_tokenizer = WhisperTokenizer.from_pretrained(whisper_path)
        self.en_token_id = self.whisper_tokenizer.convert_tokens_to_ids("<|en|>")
        # Looked up from the tokenizer instead of hard-coding the id, so the decoder
        # prompt stays consistent with whichever Whisper checkpoint is loaded.
        self.sot_token_id = self.whisper_tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    def compute_en_score(self, audio: torch.Tensor, sample_rate: int = 16000) -> float:
        """Return Whisper's probability of ``<|en|>`` right after the start-of-transcript token.

        Args:
            audio: Waveform tensor handed to the Whisper processor (as a NumPy array).
            sample_rate: Sampling rate reported to the Whisper processor.

        Returns:
            A Python float; it carries no gradient information.
        """
        with torch.no_grad():
            # Detach explicitly: `.cpu()` returns the same tensor when it already lives
            # on the CPU, and `.numpy()` refuses tensors that still require grad.
            audio_array = audio.detach().cpu().numpy()
            inputs = self.whisper_processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
            input_features = inputs.input_features.to(self.device)
            decoder_input_ids = torch.full(
                (input_features.shape[0], 1), self.sot_token_id, dtype=torch.long, device=self.device
            )
            with autocast():  # Mixed precision inference
                outputs = self.whisper_model(input_features=input_features, decoder_input_ids=decoder_input_ids)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # First batch item, first decoder position.
            en_score = probabilities[0, 0, self.en_token_id].item()
        return en_score

    def train(self, audio_files, optimizer, num_epochs=3, checkpoint_dir="./checkpoints"):
        """Run the training loop and return the model.

        Args:
            audio_files: Sequence of audio file paths; one file is one training step.
            optimizer: Optimizer stepped every two files (gradient accumulation).
            num_epochs: Number of passes over ``audio_files``.
            checkpoint_dir: Directory (created if missing) that receives an encoder
                ``state_dict`` every 10 steps.

        Returns:
            ``self.model`` after training.

        Raises:
            ValueError: If ``audio_files`` is empty.
        """
        num_files = len(audio_files)
        if num_files == 0:
            # Fail early with a clear message instead of a ZeroDivisionError at epoch end.
            raise ValueError("audio_files is empty; nothing to train on.")

        self.model.speecht5.encoder.train()
        os.makedirs(checkpoint_dir, exist_ok=True)
        scaler = GradScaler()
        accumulation_steps = 2
        mean_steps = 10
        top_models = []  # Three best (mean en_score, checkpoint path) pairs seen so far
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

        for epoch in range(num_epochs):
            torch.cuda.empty_cache()
            epoch_loss = 0.0
            current_scores = []  # Scores since the last checkpoint, for the running mean

            for step, file_path in enumerate(tqdm(audio_files, desc=f"Epoch {epoch + 1}/{num_epochs}")):
                waveform, sample_rate = torchaudio.load(file_path)
                waveform = waveform.to(self.device, dtype=torch.float)

                inputs = self.processor(
                    audio=waveform.squeeze().cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt"
                )
                input_values = inputs["input_values"].to(self.device, dtype=torch.float)
                speaker_embeddings = torch.zeros((1, 512), device=self.device, dtype=torch.float)

                with autocast():
                    spectrogram = get_spectrogram(self.model, input_values, speaker_embeddings)
                    generated_audio = self.vocoder(spectrogram).unsqueeze(0).to(self.device)

                    # Compute English score
                    en_score = self.compute_en_score(generated_audio)
                    current_scores.append(en_score)

                    # Waveform MSE over the overlapping prefix of generated and input audio.
                    target_waveform = waveform.to(self.device)
                    target_length = min(generated_audio.size(1), target_waveform.size(1))
                    mse_loss = torch.nn.functional.mse_loss(
                        generated_audio[:, :target_length], target_waveform[:, :target_length]
                    )
                    # NOTE: en_score is a plain float computed under no_grad, so this term
                    # shifts the reported loss but contributes no gradient.
                    en_loss = 1 - en_score**2
                    loss = (mse_loss + 5 * en_loss) / accumulation_steps  # Balanced loss

                scaler.scale(loss).backward()

                if (step + 1) % accumulation_steps == 0 or (step + 1) == num_files:
                    # Gradients are still multiplied by the loss scale at this point;
                    # unscale them first so the clipping threshold applies to the true
                    # gradient norm, and clip once per optimizer step.
                    scaler.unscale_(optimizer)
                    clip_gradients(self.model.speecht5.encoder)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()

                loss_value = loss.item()
                epoch_loss += loss_value * accumulation_steps

                wandb.log(
                    {
                        "epoch": epoch + 1,
                        "step": step + 1,
                        "loss": loss_value,
                        "en_score": en_score,
                        "mse_loss": mse_loss.item(),
                    }
                )

                if (step + 1) % mean_steps == 0:
                    mean_score = sum(current_scores) / len(current_scores)
                    wandb.log({f"{mean_steps}_step_mean_en_score": mean_score, "step": step + 1})
                    checkpoint_path = os.path.join(
                        checkpoint_dir, f"model_epoch_{epoch + 1}_step_{step + 1}_mean_score_{mean_score:.4f}.pt"
                    )
                    torch.save(self.model.speecht5.encoder.state_dict(), checkpoint_path)
                    top_models.append((mean_score, checkpoint_path))
                    top_models = sorted(top_models, key=lambda x: x[0], reverse=True)[:3]
                    current_scores = []

            avg_loss = epoch_loss / num_files
            scheduler.step()
            print(f"Epoch {epoch + 1} completed. Avg Loss: {avg_loss:.4f}")
            wandb.log({"epoch_loss": avg_loss, "epoch": epoch + 1, "learning_rate": scheduler.get_last_lr()[0]})

        return self.model


# Main script
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_epochs = 1

    # Load audio files. glob() returns matches in arbitrary order, so sort them to
    # keep the training order (and therefore the checkpoints) reproducible.
    audio_files = sorted(glob("/input/speechocean762/train/*.wav"))
    if not audio_files:
        raise ValueError("No audio files found in the specified path.")

    # Initialize components
    model_path = "microsoft/speecht5_vc"
    vocoder_path = "microsoft/speecht5_hifigan"
    whisper_path = "openai/whisper-tiny"
    trainer = SpeechToSpeechTrainer(
        model_path=model_path, vocoder_path=vocoder_path, whisper_path=whisper_path, device=device
    )

    # Only the encoder's parameters are handed to the optimizer.
    optimizer = torch.optim.AdamW(trainer.model.speecht5.encoder.parameters(), lr=1e-4, weight_decay=0.01)

    # Train the model
    trained_model = trainer.train(audio_files, optimizer, num_epochs=num_epochs, checkpoint_dir="./checkpoints")

    # Save the full model's state dict (not just the encoder).
    final_model_path = "./final_speech_to_speech_model.pt"
    torch.save(trained_model.state_dict(), final_model_path)
    print(f"Final model saved at {final_model_path}")

In [None]:
# from IPython.display import Audio, display

# display(Audio("/working/bef.wav", rate=16000))
# display(Audio("/working/gen.wav", rate=16000))

In [None]:
# glob() returns matches in arbitrary order; sort so that positional indexing
# into this list picks the same files on every run.
checkpoints = sorted(glob("/working/checkpoints/*"))
checkpoints

In [None]:
import torch

torch.cuda.empty_cache()

if len(checkpoints) < 2:
    # Indexing [0] and [1] below would otherwise fail with an opaque IndexError.
    print(f"Need at least two checkpoints to compare, found {len(checkpoints)}.")
else:
    # Load the state dictionaries
    state_dict_1 = torch.load(checkpoints[0], weights_only=True)
    state_dict_2 = torch.load(checkpoints[1], weights_only=True)

    # Equal only if both hold the same keys AND every tensor matches; comparing the
    # key sets first avoids a KeyError and catches extra entries in the second dict.
    are_equal = state_dict_1.keys() == state_dict_2.keys() and all(
        torch.equal(state_dict_1[key], state_dict_2[key]) for key in state_dict_1
    )

    print(f"Are the state dictionaries equal? {are_equal}")

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device