In [41]:
import torch
import torch.nn as nn
import librosa
import numpy as np
import itertools
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from typing import List, Tuple

In [42]:
class ModelConfig:
    """Hyperparameters read by PositionalEncoding and EENDModel."""

    # Width of every token vector inside the transformer.
    D_MODEL: int = 256

    # Encoder stack: attention heads per layer and number of stacked layers.
    N_HEADS: int = 4
    N_LAYERS: int = 4

    # Hidden size of the feed-forward sub-layer and the dropout rate of each layer.
    DIM_FEEDFORWARD: int = 1024
    DROPOUT: float = 0.1

    # Number of positions pre-computed in the positional-encoding table.
    MAX_LEN: int = 5000

In [43]:
class AudioConfig:
    """Audio front-end settings read by AudioProcessor and EENDModel."""

    # Rate that audio is loaded at; also passed to the mel-spectrogram call.
    SAMPLE_RATE: int = 16000

    # Mel-spectrogram parameters forwarded to librosa.feature.melspectrogram.
    N_FFT: int = 1024
    HOP_LENGTH: int = 256
    N_MELS: int = 80  # also the input feature size of EENDModel.mel_proj

class AudioProcessor:
    """Stateless helpers that turn an audio file into log-mel features."""

    @staticmethod
    def load_audio(path: str) -> torch.Tensor:
        """Load an audio file at ``AudioConfig.SAMPLE_RATE``.

        Args:
            path: Location of the audio file; forwarded to ``librosa.load``.

        Returns:
            Float32 tensor of shape ``(1, T)`` holding the loaded samples.
        """
        samples, _ = librosa.load(path, sr=AudioConfig.SAMPLE_RATE)
        return torch.tensor(samples, dtype=torch.float32).unsqueeze(0)  # (1, T)

    @staticmethod
    def extract_mel(waveform: torch.Tensor) -> torch.Tensor:
        """Compute a log-mel spectrogram of ``waveform``.

        Args:
            waveform: Audio samples, normally the ``(1, T)`` tensor returned by
                ``load_audio``. The tensor may live on any device and may be
                tracking gradients; it is detached and copied to host memory
                before being handed to librosa.

        Returns:
            Float32 tensor containing ``log(mel + 1e-8)``. For a ``(1, T)``
            input this is expected to be ``(N_MELS, frames)`` (librosa's
            output layout).
        """
        # Tensor.numpy() only works on CPU tensors that do not require grad,
        # so normalise the tensor first; both calls are no-ops for plain CPU input.
        samples = waveform.detach().cpu().squeeze(0).numpy()
        mel_power = librosa.feature.melspectrogram(
            y=samples,
            sr=AudioConfig.SAMPLE_RATE,
            n_mels=AudioConfig.N_MELS,
            n_fft=AudioConfig.N_FFT,
            hop_length=AudioConfig.HOP_LENGTH,
        )
        # The small offset keeps the logarithm finite for silent (all-zero) bins.
        return torch.tensor(np.log(mel_power + 1e-8), dtype=torch.float32)

In [44]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence.

    The table is computed once in ``__init__`` and stored as the buffer ``pe``
    of shape ``(max_len, d_model)``: even feature columns hold sines, odd
    columns hold cosines.
    """

    def __init__(self, d_model: int = ModelConfig.D_MODEL, max_len: int = ModelConfig.MAX_LEN):
        """Build the encoding table.

        Args:
            d_model: Feature size of the vectors the encoding is added to.
                Both even and odd sizes are supported.
            max_len: Number of positions to pre-compute.
        """
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with position encodings added.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model`` (``(batch, seq, d_model)`` as used by
                ``EENDModel``).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # pe[:seq] is (seq, d_model) and broadcasts over the leading batch axis.
        return x + self.pe[:x.size(1)]

In [45]:
class EENDModel(nn.Module):
    """Transformer encoder mapping mel frames to per-frame speaker probabilities."""

    def __init__(self, n_speakers: int = 2):
        """Create the network.

        Args:
            n_speakers: Number of output channels of the final linear layer,
                one per speaker.
        """
        super().__init__()
        self.n_speakers = n_speakers

        width = ModelConfig.D_MODEL

        # Frame-wise projection from mel bins to the transformer width.
        self.mel_proj = nn.Linear(AudioConfig.N_MELS, width)
        self.pos_encoder = PositionalEncoding()

        # The layer is built with its default sequence-first layout;
        # forward() swaps the first two axes around the encoder call.
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=width,
                nhead=ModelConfig.N_HEADS,
                dim_feedforward=ModelConfig.DIM_FEEDFORWARD,
                dropout=ModelConfig.DROPOUT,
            ),
            ModelConfig.N_LAYERS,
        )

        # One output per speaker for every frame.
        self.head = nn.Linear(width, n_speakers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of mel features.

        Args:
            x: Tensor of shape ``(batch, frames, AudioConfig.N_MELS)``.

        Returns:
            Tensor of shape ``(batch, frames, n_speakers)`` with values in
            ``(0, 1)`` produced by a sigmoid.
        """
        hidden = self.pos_encoder(self.mel_proj(x))      # (batch, frames, width)
        hidden = self.transformer(hidden.transpose(0, 1))  # (frames, batch, width)
        scores = self.head(hidden.transpose(0, 1))       # (batch, frames, n_speakers)
        return torch.sigmoid(scores)

In [46]:
class PITLoss:
    """Permutation-invariant binary cross-entropy for speaker activity outputs."""

    @staticmethod
    def compute(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the batch-mean BCE under the best speaker ordering per sample.

        Every ordering of the speaker axis of ``pred`` is scored against
        ``target``; for each batch element the smallest loss is kept, and the
        kept losses are averaged over the batch.

        Args:
            pred: Probabilities in ``[0, 1]`` with shape
                ``(batch, frames, n_speakers)``.
            target: Activity labels with the same shape as ``pred``. Float,
                integer and boolean tensors are accepted; they are cast to
                the dtype of ``pred``.

        Returns:
            Scalar tensor with the permutation-invariant loss.

        Note:
            The number of orderings grows factorially with ``n_speakers``.
        """
        # binary_cross_entropy rejects integer/bool targets, so align dtypes up front.
        target = target.to(dtype=pred.dtype)
        n_speakers = pred.shape[-1]

        per_permutation = [
            nn.functional.binary_cross_entropy(
                pred[:, :, list(order)], target, reduction='none'
            ).mean(dim=(1, 2))  # one loss per batch element
            for order in itertools.permutations(range(n_speakers))
        ]

        # (batch, n_permutations) -> best ordering per sample -> batch mean.
        return torch.stack(per_permutation, dim=1).min(dim=1).values.mean()

class DiarizationDataset(torch.utils.data.Dataset):
    """Pairs audio files with pre-computed label tensors.

    Audio is loaded and converted to log-mel features on access, so no
    features are held in memory between calls.
    """

    def __init__(self, audio_paths: List[str], labels: List[torch.Tensor]):
        """Store the file list and the matching labels.

        Args:
            audio_paths: Paths of the audio files, one per example.
            labels: Label tensor for each path, in the same order.
        """
        self.audio_paths = audio_paths
        self.labels = labels

    def __len__(self) -> int:
        """Return the number of audio files."""
        return len(self.audio_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load one example.

        Args:
            idx: Position of the example in ``audio_paths`` / ``labels``.

        Returns:
            ``(mel, label)`` where ``mel`` is the transposed output of
            ``AudioProcessor.extract_mel`` (time-major, i.e. ``(frames, n_mels)``
            for a 2-D spectrogram) and ``label`` is ``labels[idx]`` unchanged.
        """
        waveform = AudioProcessor.load_audio(self.audio_paths[idx])
        # AudioProcessor exposes `extract_mel`; the previous call used a
        # non-existent method name and raised AttributeError on every access.
        mel = AudioProcessor.extract_mel(waveform)
        mel = mel.squeeze(0).T  # (n_mels, frames) -> (frames, n_mels)
        label = self.labels[idx]
        return mel, label