In [None]:
import os


class CFG:
    """Static settings read by the training and inference cells of this notebook."""

    # --- Experiment tracking (Weights & Biases) ---
    project_name = "clarion-ai-003"
    # Taken from the environment so the key never lives in the notebook;
    # falls back to an empty string when WANDB_API_KEY is not set.
    wandb_api_key = os.environ.get("WANDB_API_KEY", "")

    # --- Autoencoder architecture (forwarded to SpeechAutoencoder) ---
    compression_ratio = 0.9
    channels = 32

    # --- Optimisation (DataLoader batch size, AdamW settings, epoch count) ---
    batch_size = 8
    learning_rate = 0.001
    weight_decay = 1e-4
    epochs = 500

    # --- Inputs ---
    dataset_path = "/input/speechocean762/train/*.wav"  # glob pattern for training audio
    pretrained_model_path = "/input/clarion-ai-002-ds/speech_autoencoder.pth"  # starting weights
    whisper_path = "openai/whisper-base"  # identifier handed to the Whisper from_pretrained loaders

    # --- Outputs ---
    model_save_path = "speech_autoencoder_en_score.pth"

In [None]:
%%time

import inspect
import os
from glob import glob

import torch
import torch.nn as nn
import torchaudio
import wandb
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader, Dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer


def collate_fn(batch):
    """Zero-pad a list of waveforms into a single ``(batch, 1, max_time)`` tensor.

    Args:
        batch: list of 2-D tensors indexed as ``(channels, time)``.

    Returns:
        A tensor of shape ``(len(batch), 1, max_time)`` with the dtype of the
        first item. Each non-empty item is copied to the start of its row and
        the remainder of the row stays zero. Items with no elements produce an
        all-zero row. An empty ``batch`` yields a ``(0, 1, 0)`` tensor.

    Notes:
        Multi-channel items are averaged down to one channel (the same downmix
        ``process_audio`` applies at inference time) because the output has a
        single channel slot per row.
    """
    if not batch:
        return torch.zeros(0, 1, 0)

    # Only non-empty items contribute to the padded length; default=0 keeps a
    # batch made entirely of empty tensors from raising in max().
    lengths = [item.size(1) for item in batch if item.numel() > 0]
    max_length = max(lengths, default=0)
    padded_batch = torch.zeros(len(batch), 1, max_length, dtype=batch[0].dtype)

    for i, item in enumerate(batch):
        if item.numel() == 0:
            continue
        if item.size(0) > 1:
            # A (C, T) item with C > 1 cannot be written into a (1, T) slot.
            item = item.mean(dim=0, keepdim=True)
        padded_batch[i, :, : item.size(1)] = item

    return padded_batch


class SpeechDataset(Dataset):
    """Map-style dataset that lazily loads one audio file per index.

    Args:
        file_paths: sequence of paths readable by ``torchaudio.load``.
        transform: optional callable applied to the loaded waveform.
    """

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th file and return its (optionally transformed) waveform.

        The sample rate reported by ``torchaudio.load`` is discarded.
        """
        waveform, _sample_rate = torchaudio.load(self.file_paths[idx])
        return self.transform(waveform) if self.transform else waveform


class Encoder(nn.Module):
    """Convolutional front half of the speech autoencoder.

    The input passes through three kernel-3 / stride-1 / padding-1 convolutions
    (1 -> channels -> channels -> 1), each followed by single-group GroupNorm
    and ReLU. The time axis is then shrunk by linear interpolation to
    ``int(T * compression_ratio)`` samples.

    Args:
        compression_ratio: factor applied to the time dimension in ``forward``.
        channels: number of feature maps in the hidden convolutions.
    """

    def __init__(self, compression_ratio=0.7, channels=16):
        super().__init__()
        self.compression_ratio = compression_ratio

        def conv_stage(in_ch, out_ch):
            # conv -> norm -> activation, built in this order so the layer
            # indices inside the Sequential stay 0,1,2 / 3,4,5 / 6,7,8.
            return [
                nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(1, out_ch),
                nn.ReLU(inplace=True),
            ]

        stages = conv_stage(1, channels) + conv_stage(channels, channels) + conv_stage(channels, 1)
        self.encoder_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` (indexed as batch, channel, time) and shorten its time axis."""
        features = self.encoder_layers(x)
        target_length = int(features.size(2) * self.compression_ratio)
        return nn.functional.interpolate(features, size=target_length, mode="linear", align_corners=False)


class Decoder(nn.Module):
    """Back half of the speech autoencoder, mirroring ``Encoder``.

    The time axis is first stretched by linear interpolation to
    ``int(T * expansion_ratio)`` samples. The result then passes through three
    kernel-3 / stride-1 / padding-1 transposed convolutions
    (1 -> channels -> channels -> 1) with single-group GroupNorm after each and
    ReLU after the first two. The last stage has no activation.

    Args:
        expansion_ratio: factor applied to the time dimension in ``forward``.
        channels: number of feature maps in the hidden layers.
    """

    def __init__(self, expansion_ratio=1 / 0.7, channels=16):
        super().__init__()
        self.expansion_ratio = expansion_ratio

        def deconv(in_ch, out_ch):
            return nn.ConvTranspose1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

        # Order matters: the Sequential indices (0..7) are part of the
        # checkpoint's parameter names.
        stages = [
            deconv(1, channels),
            nn.GroupNorm(1, channels),
            nn.ReLU(inplace=True),
            deconv(channels, channels),
            nn.GroupNorm(1, channels),
            nn.ReLU(inplace=True),
            deconv(channels, 1),
            nn.GroupNorm(1, 1),
        ]
        self.decoder_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Stretch the time axis of ``x`` (batch, channel, time) and decode it."""
        target_length = int(x.size(2) * self.expansion_ratio)
        stretched = nn.functional.interpolate(x, size=target_length, mode="linear", align_corners=False)
        return self.decoder_layers(stretched)


class SpeechAutoencoder(nn.Module):
    """Waveform autoencoder: ``Encoder`` followed by the mirrored ``Decoder``.

    Args:
        input_channels: accepted for interface compatibility; not used, the
            encoder is built for a single input channel.
        output_channels: accepted for interface compatibility; not used, the
            decoder emits a single channel.
        compression_ratio: time-axis factor used by the encoder; the decoder
            is given its reciprocal.
        channels: hidden feature maps in both halves.
    """

    def __init__(self, input_channels=1, output_channels=1, compression_ratio=0.7, channels=16):
        super().__init__()
        self.encoder = Encoder(compression_ratio=compression_ratio, channels=channels)
        self.decoder = Decoder(expansion_ratio=1 / compression_ratio, channels=channels)

    def forward(self, x):
        """Encode then decode ``x`` (batch, channel, time).

        The output length is not forced to equal the input length: both halves
        truncate their target length with ``int()``, so callers that compare
        input and output crop them to a common length themselves.
        """
        latent = self.encoder(x)
        return self.decoder(latent)

    @classmethod
    def from_pretrained(cls, model_path, device="cpu", **kwargs):
        """Build a model and load weights saved with ``torch.save(state_dict)``.

        Args:
            model_path: path to the checkpoint file.
            device: device the model is created on and the weights are mapped to.
            **kwargs: constructor overrides; anything omitted uses the
                constructor's own default.

        Returns:
            The model on ``device`` with the checkpoint loaded (strict).

        Raises:
            RuntimeError: from ``load_state_dict`` when the checkpoint does not
                match the architecture implied by ``kwargs``.
        """
        # Omitted arguments fall back to the __init__ defaults on their own,
        # so there is no need to read them out of the signature first.
        model = cls(**kwargs).to(device)
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def _load_state_dict_into_model(cls, model: nn.Module, state_dict: dict):
        """Strictly load ``state_dict`` into ``model``, module by module.

        Args:
            model: module whose parameters and buffers are overwritten in place.
            state_dict: mapping from parameter name to tensor; not modified.

        Returns:
            ``model``, for chaining.

        Raises:
            RuntimeError: if any key is missing, unexpected, or fails to load
                (e.g. shape mismatch).
        """
        state_dict = state_dict.copy()  # Avoid modifying the original state_dict
        # These lists are shared across the whole traversal. Handing each
        # module fresh lists would throw the reports away, and a checkpoint
        # with missing or extra keys would then load partially with no error.
        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        def load(module: torch.nn.Module, prefix: str = ""):
            # Positional order: state_dict, prefix, local_metadata, strict,
            # missing_keys, unexpected_keys, error_msgs.
            module._load_from_state_dict(state_dict, prefix, {}, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(model)

        if unexpected_keys:
            error_msgs.insert(0, f"Unexpected key(s) in state_dict: {unexpected_keys}")
        if missing_keys:
            error_msgs.insert(0, f"Missing key(s) in state_dict: {missing_keys}")
        if len(error_msgs) > 0:
            raise RuntimeError(f"Error(s) in loading state_dict: {error_msgs}")
        return model


# EnScorePredictor Class
class EnScorePredictor:
    """Scores audio by the probability a Whisper model assigns to the ``<|en|>`` token.

    Args:
        whisper_path: identifier or path passed to the Whisper ``from_pretrained`` loaders.
        device: device (string or ``torch.device``) the Whisper model runs on.

    Attributes:
        en_token_id: vocabulary id of ``<|en|>``.
        sot_token_id: vocabulary id of ``<|startoftranscript|>``, used as the
            only decoder input when scoring.
    """

    def __init__(self, whisper_path, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(whisper_path)
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_path).to(self.device)
        self.whisper_tokenizer = WhisperTokenizer.from_pretrained(whisper_path)
        self.en_token_id = self.whisper_tokenizer.convert_tokens_to_ids("<|en|>")
        # Looked up rather than hard-coded so it follows the loaded vocabulary.
        # NOTE(review): expected to be 50258 for openai/whisper-base — confirm.
        self.sot_token_id = self.whisper_tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    def compute_en_score(self, audio: torch.Tensor, sample_rate: int = 16000) -> float:
        """Return the ``<|en|>`` probability for the first item in ``audio``.

        Args:
            audio: waveform tensor handed to the Whisper processor as a NumPy
                array (the training loop passes a ``(1, time)`` tensor).
            sample_rate: sampling rate reported to the processor.

        Returns:
            Softmax probability of ``<|en|>`` at the first decoder position,
            as a Python float. The value carries no gradient.
        """
        # Mixed precision is only enabled on CUDA; on CPU the context is a no-op.
        use_amp = torch.device(self.device).type == "cuda"
        with torch.no_grad():
            # detach() so tensors that are part of an autograd graph can be
            # converted to NumPy.
            inputs = self.whisper_processor(
                audio.detach().cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt"
            )
            input_features = inputs.input_features.to(self.device)
            decoder_input_ids = torch.full(
                (input_features.shape[0], 1), self.sot_token_id, dtype=torch.long, device=self.device
            )
            with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
                outputs = self.whisper_model(input_features=input_features, decoder_input_ids=decoder_input_ids)
            # The logits may be half precision when autocast is active;
            # normalise in float32.
            probabilities = torch.nn.functional.softmax(outputs.logits.float(), dim=-1)
            en_score = probabilities[0, 0, self.en_token_id].item()
        return en_score


def _differentiable_en_score(en_score_predictor, waveforms):
    """Return the ``<|en|>`` probability of each waveform with gradients attached.

    The Whisper processor works on NumPy arrays, which cuts the autograd
    graph, so the log-mel features are rebuilt here with torch ops only.
    NOTE(review): the recipe below (pad/trim to ``n_samples``, Hann-window STFT,
    mel projection, log10 clamped to 8 below the per-clip peak, ``(x + 4) / 4``)
    is intended to mirror the Whisper feature extractor — verify against the
    installed transformers version.

    Args:
        en_score_predictor: object exposing ``whisper_processor``,
            ``whisper_model``, ``whisper_tokenizer``, ``en_token_id`` and ``device``.
        waveforms: float tensor indexed as (batch, time); assumed to be sampled
            at the extractor's rate (16 kHz) — TODO confirm for new datasets.

    Returns:
        Tensor of shape (batch,) holding probabilities in [0, 1].
    """
    extractor = en_score_predictor.whisper_processor.feature_extractor
    whisper = en_score_predictor.whisper_model

    wave = waveforms.to(en_score_predictor.device, dtype=torch.float32)

    # Whisper scores a fixed-length window: trim long clips, zero-pad short ones.
    n_samples = extractor.n_samples
    if wave.size(-1) > n_samples:
        wave = wave[..., :n_samples]
    elif wave.size(-1) < n_samples:
        wave = nn.functional.pad(wave, (0, n_samples - wave.size(-1)))

    window = torch.hann_window(extractor.n_fft, device=wave.device)
    spectrum = torch.stft(wave, extractor.n_fft, extractor.hop_length, window=window, return_complex=True)
    power = spectrum[..., :-1].abs() ** 2  # drop the trailing frame added by centred STFT

    mel_filters = torch.as_tensor(extractor.mel_filters, dtype=torch.float32, device=wave.device)
    if mel_filters.shape[0] != power.shape[-2]:
        # Accept either (n_freq, n_mels) or (n_mels, n_freq) filter layouts.
        mel_filters = mel_filters.T
    mel = mel_filters.T @ power  # (batch, n_mels, frames)

    log_mel = torch.clamp(mel, min=1e-10).log10()
    peak = log_mel.amax(dim=(-2, -1), keepdim=True)
    log_mel = torch.maximum(log_mel, peak - 8.0)
    log_mel = (log_mel + 4.0) / 4.0

    sot_token_id = en_score_predictor.whisper_tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
    decoder_input_ids = torch.full((log_mel.size(0), 1), sot_token_id, dtype=torch.long, device=log_mel.device)

    model_dtype = next(whisper.parameters()).dtype
    outputs = whisper(input_features=log_mel.to(model_dtype), decoder_input_ids=decoder_input_ids)
    probabilities = torch.softmax(outputs.logits[:, 0, :].float(), dim=-1)
    return probabilities[:, en_score_predictor.en_token_id]


def train_autoencoder(autoencoder, dataloader, optimizer, device, en_score_predictor, epochs=CFG.epochs):
    """Fine-tune ``autoencoder`` so Whisper rates its output as more likely English.

    The loss is the negative mean ``<|en|>`` probability of the reconstructed
    batch. The score is computed with differentiable torch ops so that
    ``loss.backward()`` reaches the autoencoder parameters; a score turned into
    a Python float and re-wrapped in a new tensor has no connection to the
    model, which leaves ``optimizer.step()`` with nothing to apply.

    Args:
        autoencoder: model to train; called on (batch, 1, time) tensors.
        dataloader: yields padded waveform batches.
        optimizer: optimizer over ``autoencoder.parameters()``.
        device: device the batches are moved to.
        en_score_predictor: ``EnScorePredictor`` supplying the Whisper scorer.
        epochs: number of passes over ``dataloader``.

    Side effects:
        Freezes the Whisper parameters (``requires_grad_(False)``) and puts
        Whisper in eval mode. Logs per-batch and per-epoch metrics to W&B
        (``en_score`` is the batch mean) and prints one summary line per epoch.
    """
    autoencoder.train()

    # Whisper is only a fixed scorer: gradients must flow through it to the
    # waveform but never accumulate in its own weights.
    whisper = en_score_predictor.whisper_model
    whisper.eval()
    for param in whisper.parameters():
        param.requires_grad_(False)

    for epoch in range(epochs):
        total_loss = 0
        total_en_score = 0
        for batch_idx, audio in enumerate(dataloader):
            audio = audio.to(device)
            optimizer.zero_grad()

            # Forward pass
            reconstructed = autoencoder(audio)

            # The decoder can overshoot the input length by a sample or so;
            # never score more samples than the input had.
            min_length = min(reconstructed.size(2), audio.size(2))
            reconstructed = reconstructed[:, :, :min_length]

            # Maximise the en_score of every reconstructed clip in the batch.
            en_scores = _differentiable_en_score(en_score_predictor, reconstructed[:, 0, :])
            loss = -en_scores.mean()

            # Backward pass
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            en_score = -batch_loss
            total_loss += batch_loss
            total_en_score += en_score

            # Log the batch loss and en_score to W&B
            wandb.log({"epoch": epoch + 1, "batch": batch_idx + 1, "batch_loss": batch_loss, "en_score": en_score})

        # Log average loss and en_score for the epoch
        avg_loss = total_loss / len(dataloader)
        avg_en_score = total_en_score / len(dataloader)
        wandb.log({"epoch": epoch + 1, "epoch_loss": avg_loss, "avg_en_score": avg_en_score})
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}, Avg En Score: {avg_en_score:.4f}")


if __name__ == "__main__":
    # Resolve the training files first so a wrong dataset path fails with a
    # clear message before a W&B run is created.
    file_paths = glob(CFG.dataset_path)
    if not file_paths:
        raise FileNotFoundError(f"No audio files matched CFG.dataset_path={CFG.dataset_path!r}")
    # For a quick smoke test, slice file_paths here (e.g. the first few batches).

    # Initialize W&B
    wandb.login(key=CFG.wandb_api_key)
    wandb.init(
        project=CFG.project_name,
        config={
            "compression_ratio": CFG.compression_ratio,
            "channels": CFG.channels,
            "batch_size": CFG.batch_size,
            "learning_rate": CFG.learning_rate,
            "epochs": CFG.epochs,
            "weight_decay": CFG.weight_decay,
            "dataset_path": CFG.dataset_path,
            "pretrained_model_path": CFG.pretrained_model_path,
        },
    )

    try:
        # Dataset and DataLoader
        transform = None
        dataset = SpeechDataset(file_paths, transform=transform)
        dataloader = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True, collate_fn=collate_fn)

        # Model, Optimizer, EnScorePredictor
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        autoencoder = SpeechAutoencoder.from_pretrained(
            CFG.pretrained_model_path, compression_ratio=CFG.compression_ratio, channels=CFG.channels
        ).to(device)
        optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=CFG.learning_rate, weight_decay=CFG.weight_decay)
        en_score_predictor = EnScorePredictor(CFG.whisper_path, device=device)

        # Train
        train_autoencoder(autoencoder, dataloader, optimizer, device, en_score_predictor, epochs=CFG.epochs)

        # Save model (only reached when training completed without raising)
        torch.save(autoencoder.state_dict(), CFG.model_save_path)
    finally:
        # Close the W&B run even when setup or training raises, so the run is
        # not left open in the kernel.
        wandb.finish()

In [None]:
from glob import glob

import matplotlib.pyplot as plt
import torch
from IPython.display import Audio, display


# Function to load and process an audio file with the autoencoder
def process_audio(file_path, autoencoder_path, save_path="generated_speech.wav"):
    """Run one audio file through a saved autoencoder and show before/after players.

    Args:
        file_path: audio file to load with ``torchaudio.load``.
        autoencoder_path: checkpoint passed to ``SpeechAutoencoder.from_pretrained``;
            the architecture comes from ``CFG.compression_ratio`` / ``CFG.channels``.
        save_path: where the processed audio is written.

    Returns:
        ``(input_waveform, processed_waveform, sample_rate)`` with both
        waveforms on the CPU and the batch dimension removed. The input
        waveform is the mono version that was fed to the model.
    """
    source_wave, sample_rate = torchaudio.load(file_path)

    # Player for the untouched file
    print("Original Audio:")
    display(Audio(file_path, rate=sample_rate))

    # More than one channel: average them into a single channel
    has_multiple_channels = source_wave.size(0) > 1
    if has_multiple_channels:
        source_wave = torch.mean(source_wave, dim=0, keepdim=True)

    # Add the batch dimension and place the input on the inference device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_input = source_wave.unsqueeze(0).to(device)

    # Restore the trained autoencoder in evaluation mode
    autoencoder = SpeechAutoencoder.from_pretrained(
        autoencoder_path, compression_ratio=CFG.compression_ratio, channels=CFG.channels
    ).to(device)
    autoencoder.eval()

    # Inference only, no gradient tracking
    with torch.no_grad():
        model_output = autoencoder(model_input)

    # Write the result to disk
    processed_waveform = model_output.squeeze(0).cpu()
    torchaudio.save(save_path, processed_waveform, sample_rate=sample_rate)

    # Player for the processed file
    print("Processed Audio:")
    display(Audio(save_path, rate=sample_rate))

    input_waveform = model_input.squeeze(0).cpu()
    return input_waveform, processed_waveform, sample_rate


# Plot waveforms and spectrograms
def plot_waveforms_and_spectrograms(waveform_before, waveform_after, sample_rate):
    """Plot waveform and log-mel spectrogram of two signals side by side.

    Draws a 2x2 figure: waveforms on the top row, 64-bin mel spectrograms
    (log2 scale) on the bottom row; "before" on the left, "after" on the right.

    Args:
        waveform_before: CPU tensor indexed as (channel, time).
        waveform_after: CPU tensor indexed as (channel, time).
        sample_rate: sampling rate used to build the mel filterbank.
    """
    transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=64)
    panels = (
        ("Before Processing", waveform_before),
        ("After Processing", waveform_after),
    )

    fig, axes = plt.subplots(2, 2, figsize=(14, 8))

    for column, (label, waveform) in enumerate(panels):
        # Top row: raw waveform
        wave_ax = axes[0, column]
        wave_ax.plot(waveform.t().numpy())
        wave_ax.set_title(f"Waveform ({label})")
        wave_ax.set_xlabel("Time (samples)")
        wave_ax.set_ylabel("Amplitude")
        wave_ax.grid()

        # Bottom row: log-mel spectrogram of the first channel.
        # Clamp before the log: a mel bin that is exactly zero (digital
        # silence, zero padding) would otherwise become -inf and break the
        # colour scaling of imshow and the colorbar.
        log_mel = torch.clamp(transform(waveform), min=1e-10).log2()
        spec_ax = axes[1, column]
        image = spec_ax.imshow(log_mel[0, :, :].numpy(), cmap="viridis", origin="lower", aspect="auto")
        spec_ax.set_title(f"Mel Spectrogram ({label})")
        spec_ax.set_xlabel("Time (frames)")
        spec_ax.set_ylabel("Mel Frequency (bins)")
        fig.colorbar(image, ax=spec_ax, orientation="vertical", fraction=0.046, pad=0.04)

    # Adjust layout
    plt.tight_layout()
    plt.show()


# Main Script
if __name__ == "__main__":
    # Same glob the training cell uses. Sorted because glob() returns matches
    # in arbitrary order and the demo should pick the same file on every run.
    audio_files = sorted(glob(CFG.dataset_path))
    if not audio_files:
        raise FileNotFoundError(f"No audio files matched {CFG.dataset_path!r}")
    audio_file = audio_files[0]
    waveform_before, waveform_after, sample_rate = process_audio(audio_file, CFG.model_save_path)

    # Plot waveforms and spectrograms
    plot_waveforms_and_spectrograms(waveform_before, waveform_after, sample_rate)

In [None]:
# Same demo on a held-out file from the test split.
# Sorted because glob() returns matches in arbitrary order; the explicit check
# replaces a bare IndexError when the pattern matches nothing.
audio_files = sorted(glob("/input/speechocean762/test/*.wav"))
if not audio_files:
    raise FileNotFoundError("No audio files matched '/input/speechocean762/test/*.wav'")
audio_file = audio_files[0]
waveform_before, waveform_after, sample_rate = process_audio(audio_file, CFG.model_save_path)
plot_waveforms_and_spectrograms(waveform_before, waveform_after, sample_rate)