# Training Notebook

## Imports

In [None]:
!pip3 install -r requirements.txt

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from common.model import VoiceAutoencoder
from common.mel_spectrogram import MelSpectrogram
from common.speaker_embed import DummySpeakerEmbedder
import torchaudio
import os
from torch.utils.data import DataLoader, Dataset
import random
from common.speaker_embed import ECAPASpeakerEmbedder

## Dataset Structure
Wrapping the audio files in a `Dataset` lets `DataLoader` handle batching and shuffling, which should make training faster.

In [None]:
class VoiceDataset(Dataset):
    """Map-style dataset pairing each audio file with its speaker label.

    Items are loaded from disk on access, resampled to
    ``TARGET_SAMPLE_RATE`` and converted to a mel spectrogram.
    """

    # Every clip is resampled to this rate before the mel transform.
    # NOTE(review): presumably the rate MelSpectrogram() is configured for
    # -- confirm against common.mel_spectrogram.
    TARGET_SAMPLE_RATE = 22050

    def __init__(self, file_paths, speaker_ids):
        """Store the parallel path/label sequences and build the mel transform.

        Args:
            file_paths: Audio file paths, indexable by integer position.
            speaker_ids: Label for each entry of ``file_paths``, in the
                same order.
        """
        self.paths = file_paths
        self.ids = speaker_ids
        self.mel_transform = MelSpectrogram()

    def __getitem__(self, index):
        """Return ``(mel, speaker_id)`` for the file at position ``index``."""
        audio, sr = torchaudio.load(self.paths[index])

        # torchaudio.load returns (channels, time) by default. Average
        # multi-channel files down to one channel; otherwise squeeze() below
        # keeps the channel axis and transpose(0, 1) swaps the wrong axes.
        if audio.dim() == 2 and audio.size(0) > 1:
            audio = audio.mean(dim=0, keepdim=True)

        audio = torchaudio.functional.resample(audio, sr, self.TARGET_SAMPLE_RATE)

        # assumes the transform returns (1, n_mels, frames) for mono input,
        # so the result here is (frames, n_mels) -- TODO confirm
        mel = self.mel_transform(audio).squeeze().transpose(0, 1)
        return mel, self.ids[index]

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.paths)

## Initialize

In [None]:
def _pad_collate(batch):
    """Collate ``(mel, speaker_id)`` pairs whose mels differ in length.

    The default collate function stacks the mels and raises as soon as two
    clips in a batch have a different number of frames. Here shorter mels
    are zero-padded along their first axis to the longest one in the batch;
    when all lengths already match the result equals plain stacking.

    NOTE(review): padded frames are not masked, so they contribute to any
    loss computed over the full batch tensor.
    """
    # Imported locally so this cell does not depend on extra top-level
    # imports. default_collate is public in torch >= 1.11.
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import default_collate

    mels, speaker_ids = zip(*batch)
    return pad_sequence(mels, batch_first=True), default_collate(list(speaker_ids))


# Fall back to the CPU when no GPU is present instead of crashing on .cuda().
device = "cuda" if torch.cuda.is_available() else "cpu"

model = VoiceAutoencoder().to(device)
embedder = ECAPASpeakerEmbedder(device=device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Load data
# TODO Add dataset and dataset download
paths = ["data/speaker0_001.wav", "data/speaker1_002.wav"]  # example
ids = [0, 1]  # speaker label for each entry of `paths`, same order
dataset = VoiceDataset(paths, ids)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=_pad_collate)

## Training Loop

In [None]:
# Train the autoencoder to reconstruct each mel batch, conditioned on a
# speaker embedding, and report the mean batch loss after every epoch.

# Follow whatever device the model lives on rather than hard-coding CUDA.
train_device = next(model.parameters()).device

for epoch in range(20):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for mel, ref_audio_path in loader:
        mel = mel.to(train_device)

        # Extract speaker embedding from reference audio.
        # NOTE(review): the dataset yields the entries of `ids` here, not
        # file paths -- confirm what extract_embedding() actually accepts.
        # The embedder is not in the optimizer, so skip autograd tracking.
        with torch.no_grad():
            speaker_embedding = embedder.extract_embedding(ref_audio_path)  # presumably shape [dim] -- confirm
        # NOTE(review): a single embedding is broadcast to every item in the
        # batch, even when the batch mixes speakers -- verify this is intended.
        speaker_embedding = speaker_embedding.unsqueeze(0).expand(mel.size(0), -1)  # match batch size
        speaker_embedding = speaker_embedding.to(train_device)

        out = model(mel, speaker_embedding)
        loss = criterion(out, mel)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device; converting to float only once per epoch
        # avoids a host sync on every step.
        running_loss = running_loss + loss.detach()
        num_batches += 1

    if num_batches:
        # Mean over the epoch; the last batch alone is noisy with shuffling.
        print(f"Epoch {epoch}: Loss = {float(running_loss) / num_batches}")
    else:
        # An empty loader would otherwise leave `loss` undefined.
        print(f"Epoch {epoch}: no batches to train on")