In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import numpy as np
from jamo import hangul_to_jamo
import librosa
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm
import torchaudio
import torch.optim as optim
from pathlib import Path
from tqdm import tqdm

In [16]:
class Trainer:
    """Joint trainer for the Tacotron2 acoustic model and the GAN vocoder.

    Builds the train/valid data loaders, the models (Tacotron2, Generator and
    the MSD / MPD / MRF discriminators), the loss helper and two AdamW
    optimizers, and provides checkpointing plus the epoch loop.

    Attributes read from ``config``: ``device_num``, ``batch_size``,
    ``learning_rate``, ``checkpoint_dir``, ``resume_checkpoint``,
    ``num_epochs`` and ``save_interval``. Two optional attributes are also
    honoured: ``num_workers`` (default 4) and ``debug`` (default False; when
    true, tensor shapes are printed on every training step).
    """

    def __init__(self, config):
        """Create datasets, loaders, models, loss, optimizers and the checkpoint directory."""
        self.config = config

        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            torch.cuda.set_device(config.device_num)
            self.device = torch.device("cuda")

        # Datasets and data loaders
        self.train_dataset = KSSTTSDataset(split='train')
        self.valid_dataset = KSSTTSDataset(split='valid')

        num_workers = getattr(config, 'num_workers', 4)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            collate_fn=KSSTTSDataset.collate_fn,
            num_workers=num_workers
        )

        self.valid_loader = DataLoader(
            self.valid_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            collate_fn=KSSTTSDataset.collate_fn,
            num_workers=num_workers
        )

        # Models
        self.tacotron2 = Tacotron2(
            n_mel_channels=80,
            vocab_size=len(self.train_dataset.vocab),
            embedding_dim=256,
            encoder_n_convolutions=3,
            encoder_kernel_size=5,
            attention_rnn_dim=512,
            attention_dim=128
        ).to(self.device)

        self.generator = Generator().to(self.device)
        self.msd = MultiScaleDiscriminator().to(self.device)
        self.mpd = MultiPeriodDiscriminator().to(self.device)
        self.mrf = MRFDiscriminator().to(self.device)

        # Loss helper
        self.criterion = TTSLoss(device=self.device)

        # Optimizers: one for Tacotron2 + generator, one for all discriminators
        self.optimizer_g = optim.AdamW([
            {'params': self.tacotron2.parameters()},
            {'params': self.generator.parameters()}
        ], lr=config.learning_rate, betas=(0.8, 0.99), weight_decay=0.01)

        self.optimizer_d = optim.AdamW([
            {'params': self.msd.parameters()},
            {'params': self.mpd.parameters()},
            {'params': self.mrf.parameters()}
        ], lr=config.learning_rate, betas=(0.8, 0.99), weight_decay=0.01)

        # Checkpoint directory
        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def save_checkpoint(self, epoch, loss):
        """Write model and optimizer state to ``checkpoint_epoch_{epoch}.pt``.

        Tacotron2's parameters are optimized by ``optimizer_g``, so there is no
        separate Tacotron2 optimizer state to store.
        """
        checkpoint = {
            'epoch': epoch,
            'tacotron2_state_dict': self.tacotron2.state_dict(),
            'generator_state_dict': self.generator.state_dict(),
            'msd_state_dict': self.msd.state_dict(),
            'mpd_state_dict': self.mpd.state_dict(),
            'mrf_state_dict': self.mrf.state_dict(),
            'optimizer_g_state_dict': self.optimizer_g.state_dict(),
            'optimizer_d_state_dict': self.optimizer_d.state_dict(),
            'loss': loss
        }

        path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, path)

    def load_checkpoint(self, checkpoint_path):
        """Restore model and optimizer state; return the epoch stored in the checkpoint."""
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.tacotron2.load_state_dict(checkpoint['tacotron2_state_dict'])
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.msd.load_state_dict(checkpoint['msd_state_dict'])
        self.mpd.load_state_dict(checkpoint['mpd_state_dict'])
        self.mrf.load_state_dict(checkpoint['mrf_state_dict'])

        self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
        self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

        return checkpoint['epoch']

    @staticmethod
    def _describe_batch(batch):
        """Print a per-key summary of a collated batch (used when a step fails)."""
        if isinstance(batch, dict):
            for key, value in batch.items():
                shape = getattr(value, 'shape', None)
                described = tuple(shape) if shape is not None else '-'
                print(f"{key}: {type(value).__name__}, shape={described}")
        else:
            print(f"Batch type: {type(batch).__name__}")

    def train_step(self, batch):
        """Run one optimization step for the generator side and the discriminators.

        Args:
            batch: dict produced by ``KSSTTSDataset.collate_fn`` with the keys
                ``text_padded``, ``mel_padded``, ``gate_padded``,
                ``audio_padded``, ``text_lengths`` and ``mel_lengths``.

        Returns:
            dict of loss tensors: every key of the Tacotron2 and HiFi-GAN loss
            dicts (on a name clash the HiFi-GAN value wins, e.g. ``mel_loss``),
            ``tacotron2_mel_loss`` / ``tacotron2_total_loss`` so the Tacotron2
            values remain reachable, and ``total_loss`` = generator-side loss +
            discriminator loss (detached).

        Raises:
            Exception: any error from the step is re-raised after a summary of
                the batch has been printed.
        """
        try:
            # Unpack the batch onto the training device
            text_padded = batch['text_padded'].to(self.device)
            mel_padded = batch['mel_padded'].to(self.device)
            gate_padded = batch['gate_padded'].to(self.device)
            audio_padded = batch['audio_padded'].to(self.device)
            text_lengths = batch['text_lengths'].to(self.device)
            mel_lengths = batch['mel_lengths'].to(self.device)

            # Shape dump is opt-in so it does not flood the tqdm progress bar
            if getattr(self.config, 'debug', False):
                print("\nInput shapes:")
                print(f"text_padded: {text_padded.shape}")
                print(f"mel_padded: {mel_padded.shape}")
                print(f"gate_padded: {gate_padded.shape}")
                print(f"audio_padded: {audio_padded.shape}")

            # Tacotron2 forward
            mel_outputs_postnet, mel_outputs, gate_outputs, _ = self.tacotron2(
                text_padded, text_lengths, mel_padded, mel_lengths)

            # Generator forward (the mel is transposed before being fed to the vocoder)
            mel_for_generator = mel_outputs_postnet.transpose(1, 2)
            fake_audio = self.generator(mel_for_generator)

            # Discriminator forward.
            # NOTE(review): the discriminators only ever see fake_audio.detach(),
            # so the adversarial and feature-matching terms that end up in the
            # generator loss carry no gradient back to the generator/Tacotron2;
            # only the mel-based terms train them. Confirm whether this is intended.
            msd_real_outputs, msd_real_features = self.msd(audio_padded)
            msd_fake_outputs, msd_fake_features = self.msd(fake_audio.detach())

            mpd_real_outputs, mpd_real_features = self.mpd(audio_padded)
            mpd_fake_outputs, mpd_fake_features = self.mpd(fake_audio.detach())

            # Tacotron2 loss
            tacotron2_losses = self.criterion.tacotron2_loss(
                mel_outputs, mel_outputs_postnet, gate_outputs,
                mel_padded, gate_padded, mel_lengths
            )

            # HiFi-GAN loss over the concatenated MSD + MPD outputs/features
            real_outputs = msd_real_outputs + mpd_real_outputs
            fake_outputs = msd_fake_outputs + mpd_fake_outputs
            real_features = msd_real_features + mpd_real_features
            fake_features = msd_fake_features + mpd_fake_features

            hifigan_losses = self.criterion.hifi_gan_loss(
                audio_padded, fake_audio,
                real_outputs, fake_outputs,
                real_features, fake_features
            )

            # Generator-side update (Tacotron2 + generator).
            # retain_graph=True: the discriminator loss below reuses the same
            # discriminator forward graph.
            self.optimizer_g.zero_grad()
            g_loss = tacotron2_losses['total_loss'] + hifigan_losses['generator_loss']
            g_loss.backward(retain_graph=True)
            self.optimizer_g.step()

            # Discriminator update
            self.optimizer_d.zero_grad()
            d_loss = hifigan_losses['discriminator_loss']
            d_loss.backward()
            self.optimizer_d.step()

            # 'total_loss' must be written last: tacotron2_losses also carries a
            # 'total_loss' key that would otherwise overwrite the combined value.
            return {
                **tacotron2_losses,
                **hifigan_losses,
                'tacotron2_mel_loss': tacotron2_losses['mel_loss'],
                'tacotron2_total_loss': tacotron2_losses['total_loss'],
                'total_loss': (g_loss + d_loss).detach(),
            }

        except Exception as e:
            print(f"\nError in train_step: {str(e)}")
            print("\nBatch information:")
            self._describe_batch(batch)
            raise

    def train(self):
        """Run the epoch loop, optionally resuming from ``config.resume_checkpoint``.

        A checkpoint is written every ``config.save_interval`` epochs and the
        average of the per-step ``total_loss`` is printed after each epoch.
        """
        start_epoch = 0
        if self.config.resume_checkpoint:
            start_epoch = self.load_checkpoint(self.config.resume_checkpoint)

        for epoch in range(start_epoch, self.config.num_epochs):
            self.tacotron2.train()
            self.generator.train()
            self.msd.train()
            self.mpd.train()
            self.mrf.train()

            total_loss = 0
            progress_bar = tqdm(self.train_loader)

            for batch in progress_bar:
                losses = self.train_step(batch)
                step_loss = losses['total_loss'].item()
                total_loss += step_loss

                progress_bar.set_description(
                    f"Epoch {epoch+1}, Loss: {step_loss:.4f}"
                )

            # Guard against an empty loader so the average never divides by zero
            avg_loss = total_loss / max(len(self.train_loader), 1)

            if (epoch + 1) % self.config.save_interval == 0:
                self.save_checkpoint(epoch + 1, avg_loss)

            print(f'Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}')

In [17]:
class KSSTTSDataset(Dataset):
    """KSS (Korean Single Speaker) text/audio dataset for TTS training.

    Each item provides a token-id sequence (Hangul syllables decomposed into
    jamo, every other character kept as-is), a normalized log-mel spectrogram
    and the waveform resampled to ``target_sr``.

    The vocabulary is always built from the *training* partition so that the
    'train' and 'valid' instances map tokens to the same ids, and ``<pad>`` is
    id 0 so that it matches the zero padding applied by ``collate_fn``.
    """

    def __init__(self, split='train', valid_size=0.1, seed=42, target_sr=22050):
        """Load the dataset, split it deterministically and build the vocabulary.

        Args:
            split: 'train' selects the training partition; any other value
                selects the validation partition.
            valid_size: forwarded to ``train_test_split`` as ``test_size``.
            seed: forwarded to ``train_test_split`` as ``random_state``.
            target_sr: sampling rate the audio is resampled to.
        """
        super().__init__()

        # Basic settings
        self.target_sr = target_sr
        self.vocab = None

        # Load the dataset
        dataset = load_dataset("Bingsu/KSS_Dataset")
        full_dataset = dataset['train']

        # train/valid split
        train_idx, valid_idx = train_test_split(
            range(len(full_dataset)),
            test_size=valid_size,
            random_state=seed
        )

        self.indices = train_idx if split == 'train' else valid_idx
        self.dataset = full_dataset
        # Both splits derive the vocabulary from the same (training) rows.
        self._vocab_indices = train_idx

        # Build the vocabulary
        self._initialize_vocab()

    def _initialize_vocab(self):
        """Build ``self.vocab`` (token -> id) with ``<pad>`` = 0 and ``<unk>`` last."""
        # Column access reads only the text column instead of materializing
        # every row (including its audio) just to get the script.
        scripts = self.dataset['original_script']
        vocab_indices = getattr(self, '_vocab_indices', self.indices)

        unique_tokens = set()
        for idx in vocab_indices:
            unique_tokens.update(self._text_to_sequence(scripts[idx]))

        # <pad> must be id 0 because collate_fn pads text with zeros.
        self.vocab = {'<pad>': 0}
        for token in sorted(unique_tokens):
            self.vocab[token] = len(self.vocab)
        self.vocab['<unk>'] = len(self.vocab)

    def _text_to_sequence(self, text):
        """Split ``text`` into tokens: jamo for Hangul syllables, the character otherwise."""
        sequence = []
        for char in text:
            if '가' <= char <= '힣':
                jamos = list(hangul_to_jamo(char))
                sequence.extend(jamos)
            else:
                sequence.append(char)
        return sequence

    def _wav_to_mel(self, wav):
        """Convert a waveform array to a log-mel spectrogram scaled to [0, 1].

        Steps: STFT (n_fft=1024, hop=256) -> power spectrum -> 80 mel bands
        (0-8000 Hz) -> natural log with a 1e-5 floor -> clip to [-4, 4] ->
        rescale with ``(x + 4) / 8``.
        """
        stft = librosa.stft(wav, n_fft=1024, hop_length=256, win_length=1024)
        mel_spec = librosa.feature.melspectrogram(
            S=np.abs(stft)**2,
            sr=self.target_sr,
            n_mels=80,
            fmin=0,
            fmax=8000
        )
        mel_spec = np.log(np.clip(mel_spec, a_min=1e-5, a_max=None))
        mel_spec = np.clip(mel_spec, a_min=-4, a_max=4)
        mel_spec = (mel_spec + 4) / 8
        return mel_spec

    def __len__(self):
        """Number of examples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return one example as a dict with 'text', 'text_length', 'mel', 'mel_length', 'audio'."""
        real_idx = self.indices[idx]
        item = self.dataset[real_idx]

        # Text: tokens -> ids (unknown tokens map to <unk>)
        text = item['original_script']
        unk_id = self.vocab['<unk>']
        sequence = [self.vocab.get(token, unk_id) for token in self._text_to_sequence(text)]
        sequence = torch.LongTensor(sequence)

        # Audio: resample to the target rate when needed
        audio = torch.from_numpy(item['audio']['array']).float()
        original_sr = item['audio']['sampling_rate']

        if original_sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr,
                new_freq=self.target_sr
            )
            audio = resampler(audio)

        # Mel spectrogram, transposed so that dim 0 is what collate_fn treats as length
        mel = self._wav_to_mel(audio.numpy())
        mel = torch.FloatTensor(mel).transpose(0, 1)

        return {
            'text': sequence,
            'text_length': sequence.size(0),
            'mel': mel,
            'mel_length': mel.size(0),
            'audio': audio
        }

    @staticmethod
    def collate_fn(batch):
        """Zero-pad a list of examples into batch tensors.

        Returns a dict with:
            text_padded  -- (B, max_text_len) long, padded with 0 (= <pad>)
            mel_padded   -- (B, n_mels, max_mel_len) float
            gate_padded  -- (B, max_mel_len) float, 1 from the last real frame onwards
            audio_padded -- (B, 1, max_audio_len) float
            text_lengths, mel_lengths -- (B,) long
        """
        batch_size = len(batch)

        # Text padding
        text_lengths = [x['text'].size(0) for x in batch]
        max_text_len = max(text_lengths)
        text_padded = torch.zeros(batch_size, max_text_len, dtype=torch.long)
        for i, x in enumerate(batch):
            text = x['text']
            text_padded[i, :len(text)] = text

        # Mel padding (channel count taken from the data instead of a literal)
        mel_lengths = [x['mel'].size(0) for x in batch]
        max_mel_len = max(mel_lengths)
        n_mels = batch[0]['mel'].size(1)
        mel_padded = torch.zeros(batch_size, n_mels, max_mel_len)
        for i, x in enumerate(batch):
            mel = x['mel']
            mel_padded[i, :, :mel.size(0)] = mel.transpose(0, 1)

        # Audio padding
        audio_lengths = [x['audio'].size(0) for x in batch]
        max_audio_len = max(audio_lengths)
        audio_padded = torch.zeros(batch_size, 1, max_audio_len)
        for i, x in enumerate(batch):
            audio = x['audio']
            audio_padded[i, 0, :audio.size(0)] = audio

        # Gate (stop-token) targets
        gate_padded = torch.zeros(batch_size, max_mel_len)
        for i, length in enumerate(mel_lengths):
            gate_padded[i, length-1:] = 1

        # Lengths are returned alongside the padded tensors
        text_lengths = torch.LongTensor(text_lengths)
        mel_lengths = torch.LongTensor(mel_lengths)

        return {
            'text_padded': text_padded,
            'mel_padded': mel_padded,
            'gate_padded': gate_padded,
            'audio_padded': audio_padded,
            'text_lengths': text_lengths,
            'mel_lengths': mel_lengths
        }

In [18]:
class ResBlock(nn.Module):
    """Two dilated 1-D convolutions with a residual (skip) connection.

    The skip path is a 1x1 convolution when the channel count changes and an
    identity otherwise. The ``padding`` argument is accepted but not used: the
    padding is always derived from ``kernel_size`` and ``dilation``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, negative_slope=0.01):
        super().__init__()
        # Padding is computed from kernel size and dilation
        self.padding = dilation * (kernel_size - 1) // 2

        shared_conv_args = dict(stride=stride, padding=self.padding, dilation=dilation)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, **shared_conv_args)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, **shared_conv_args)
        self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)

        if in_channels == out_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        """Apply conv -> LeakyReLU -> conv, add the skip path, then LeakyReLU."""
        shortcut = self.skip_connection(x)
        out = self.conv2(self.leaky_relu(self.conv1(x)))

        # When the two paths differ in length, crop both to the shorter one
        out_len, shortcut_len = out.size(-1), shortcut.size(-1)
        if out_len != shortcut_len:
            keep = min(out_len, shortcut_len)
            out, shortcut = out[..., :keep], shortcut[..., :keep]

        return self.leaky_relu(out + shortcut)

In [19]:
class Generator(nn.Module):
    """Vocoder generator: upsamples a mel spectrogram into a waveform.

    Structure: a pre-convolution, four transposed-convolution upsampling
    stages (strides 8, 8, 2, 2, i.e. 256x in total), each followed by a group
    of ``ResBlock``s whose outputs are averaged, and a final convolution with
    ``tanh``.

    Args:
        input_size: number of input (mel) channels.
        debug: when True, tensor shapes are printed during ``forward``.
            Defaults to False so the forward pass stays silent in training.
    """

    def __init__(self, input_size=80, debug=False):
        super().__init__()
        self.debug = debug

        # Initial conv
        self.conv_pre = nn.Conv1d(input_size, 512, 7, 1, 3)

        # Upsampling layers with kernel size and stride adjustments
        self.ups = nn.ModuleList([
            nn.ConvTranspose1d(512, 256, 16, 8, 4),
            nn.ConvTranspose1d(256, 128, 16, 8, 4),
            nn.ConvTranspose1d(128, 64, 4, 2, 1),
            nn.ConvTranspose1d(64, 32, 4, 2, 1),
        ])

        # Multi-Receptive Field Fusion: one group of ResBlocks per upsampling stage
        self.mrf_blocks = nn.ModuleList([
            # First MRF block
            nn.ModuleList([
                ResBlock(256, 256, kernel_size=3, dilation=1),
                ResBlock(256, 256, kernel_size=3, dilation=3),
                ResBlock(256, 256, kernel_size=3, dilation=5)
            ]),
            # Second MRF block
            nn.ModuleList([
                ResBlock(128, 128, kernel_size=3, dilation=1),
                ResBlock(128, 128, kernel_size=3, dilation=3),
                ResBlock(128, 128, kernel_size=3, dilation=5)
            ]),
            # Third MRF block
            nn.ModuleList([
                ResBlock(64, 64, kernel_size=3, dilation=1),
                ResBlock(64, 64, kernel_size=3, dilation=3)
            ]),
            # Fourth MRF block
            nn.ModuleList([
                ResBlock(32, 32, kernel_size=3, dilation=1),
                ResBlock(32, 32, kernel_size=3, dilation=2)
            ])
        ])

        # Final conv
        self.conv_post = nn.Conv1d(32, 1, 7, 1, 3)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def _log(self, message):
        """Print ``message`` only when debug output was requested."""
        if self.debug:
            print(message)

    def forward(self, x):
        """Map ``x`` of shape (B, input_size, T) to a waveform of shape (B, 1, T_out) in [-1, 1]."""
        self._log(f"Generator input shape: {x.shape}")
        x = self.conv_pre(x)
        self._log(f"After conv_pre: {x.shape}")

        for i, (upsample, resblocks) in enumerate(zip(self.ups, self.mrf_blocks)):
            x = upsample(self.leaky_relu(x))
            self._log(f"After up {i}: {x.shape}")

            # Sum the ResBlock branches, then average
            fused = None
            for resblock in resblocks:
                branch = resblock(x)
                if fused is None:
                    fused = branch
                    continue
                # Branches may differ in length; crop both to the shorter one
                if branch.size(-1) != fused.size(-1):
                    keep = min(branch.size(-1), fused.size(-1))
                    fused = fused[..., :keep]
                    branch = branch[..., :keep]
                fused = fused + branch
            x = fused / len(resblocks)

        self._log(f"Generator output shape: {x.shape}")
        x = self.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

In [20]:
class MultiScaleDiscriminator(nn.Module):
    """Three ``DiscriminatorS`` networks applied to progressively average-pooled audio.

    The first discriminator uses spectral norm and sees the raw input; each
    following one sees the input after one more ``AvgPool1d(4, 2, padding=2)``.
    """

    def __init__(self):
        super().__init__()
        first = DiscriminatorS(use_spectral_norm=True)
        second = DiscriminatorS()
        third = DiscriminatorS()
        self.discriminators = nn.ModuleList([first, second, third])

        pools = [nn.AvgPool1d(4, 2, padding=2) for _ in range(2)]
        self.meanpools = nn.ModuleList(pools)

    def forward(self, x):
        """
        x: input audio (fake or real)

        Returns a list of per-discriminator outputs and a list of their feature maps.
        """
        scores = []        # discriminator outputs
        feature_maps = []  # feature maps

        for index, discriminator in enumerate(self.discriminators):
            # Every discriminator after the first works on a further down-sampled signal
            if index > 0:
                pool = self.meanpools[index - 1]
                x = pool(x)
            score, fmap = discriminator(x)
            scores.append(score)
            feature_maps.append(fmap)

        return scores, feature_maps

In [21]:
class DiscriminatorS(nn.Module):
    """Scale discriminator: a stack of normalized (grouped) 1-D convolutions.

    Args:
        use_spectral_norm: wrap the convolutions in ``spectral_norm`` instead
            of ``weight_norm``.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = (
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        )
        layers = []
        for c_in, c_out, kernel, stride, groups, pad in layer_specs:
            conv = nn.Conv1d(c_in, c_out, kernel, stride, groups=groups, padding=pad)
            layers.append(norm_f(conv))
        self.convs = nn.ModuleList(layers)

        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened score, feature maps)``; every layer output is recorded."""
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            feature_maps.append(x)

        x = self.conv_post(x)
        feature_maps.append(x)

        score = torch.flatten(x, 1, -1)
        return score, feature_maps

In [22]:
class MultiPeriodDiscriminator(nn.Module):
    """Bank of ``DiscriminatorP`` networks, one per period in (2, 3, 5, 7, 11)."""

    def __init__(self):
        super().__init__()
        periods = (2, 3, 5, 7, 11)
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in periods]
        )

    def forward(self, x):
        """
        x: input audio (fake or real)

        Returns a list of per-discriminator outputs and a list of their feature maps.
        """
        # Every period discriminator sees the same, unmodified input
        results = [discriminator(x) for discriminator in self.discriminators]

        scores = [score for score, _ in results]      # discriminator outputs
        feature_maps = [fmap for _, fmap in results]  # feature maps

        return scores, feature_maps

In [23]:
class DiscriminatorP(nn.Module):
    """Period discriminator: folds the waveform into a 2-D grid of width ``period``.

    The folded signal is processed by weight-normalized 2-D convolutions whose
    kernels span only the folded time axis (kernel width 1).
    """

    def __init__(self, period):
        super().__init__()
        self.period = period

        # Four stride-3 layers followed by one stride-1 layer
        channels = (1, 32, 128, 512, 1024)
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers.append(
                weight_norm(nn.Conv2d(c_in, c_out, (5, 1), (3, 1), padding=(2, 0)))
            )
        layers.append(
            weight_norm(nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0)))
        )
        self.convs = nn.ModuleList(layers)

        self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """
        x: [B, 1, T]

        Returns ``(flattened score, feature maps)``.
        """
        feature_maps = []

        # 1D -> 2D: reflect-pad T up to a multiple of the period, then fold
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            length += n_pad
        x = x.view(batch, channels, length // self.period, self.period)

        for conv in self.convs:
            x = self.leaky_relu(conv(x))
            feature_maps.append(x)

        x = self.conv_post(x)
        feature_maps.append(x)

        return torch.flatten(x, 1, -1), feature_maps

In [24]:
class MRFDiscriminator(nn.Module):
    """Discriminator head that fuses the last MSD and MPD feature maps.

    The last feature map of the last scale discriminator (3-D) and of the last
    period discriminator (4-D) are each projected to 64 channels, resampled to
    a common length, concatenated to 128 channels and passed through a shared
    convolution stack that ends in a single-channel output.
    """

    def __init__(self):
        super().__init__()
        # Channel adjustment for the MSD feature
        self.msd_adjust = nn.Conv1d(1, 64, kernel_size=3, padding=1)

        # 2-D processing of the MPD feature before it is flattened to 1-D
        self.mpd_adjust = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=(1, 1)),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=(1, 1)),
            nn.LeakyReLU(0.1)
        )

        # Shared layers applied to the fused features
        self.shared_conv = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv1d(256, 512, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, 512, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1)
        )

        # Final output layer
        self.output_layer = nn.Conv1d(512, 1, kernel_size=3, padding=1)

    def forward(self, msd_features, mpd_features):
        """
        Args:
            msd_features: List[List[Tensor]] - feature maps of MultiScaleDiscriminator
            mpd_features: List[List[Tensor]] - feature maps of MultiPeriodDiscriminator

        Returns:
            Tensor of shape [B, 1, T], where T is the shorter of the two
            processed feature lengths.

        Raises:
            ValueError: if the selected MPD feature map is not 4-D. Previously
                this case failed later with an unbound local variable.
        """
        # Use the feature map of the last layer of the last discriminator
        msd_feat = msd_features[-1][-1]  # [B, 1, T]
        mpd_feat = mpd_features[-1][-1]  # [B, 1, H, W]

        if mpd_feat.dim() != 4:
            raise ValueError(
                "MRFDiscriminator expects a 4-D MPD feature map [B, C, H, W], "
                f"got a tensor with {mpd_feat.dim()} dimensions"
            )

        # MPD feature: 2-D processing, then flatten H and W into one axis
        batch_size = mpd_feat.size(0)
        mpd_processed = self.mpd_adjust(mpd_feat)                 # [B, 64, H, W]
        mpd_processed = mpd_processed.reshape(batch_size, 64, -1)  # [B, 64, H*W]

        # MSD feature
        msd_processed = self.msd_adjust(msd_feat)  # [B, 64, T]

        # Align the time axes by resampling both to the shorter length
        target_length = min(msd_processed.size(-1), mpd_processed.size(-1))
        msd_processed = F.interpolate(msd_processed, size=target_length, mode='linear', align_corners=False)
        mpd_processed = F.interpolate(mpd_processed, size=target_length, mode='linear', align_corners=False)

        # Fuse the feature maps
        combined = torch.cat([msd_processed, mpd_processed], dim=1)  # [B, 128, T]

        # Shared processing
        x = self.shared_conv(combined)

        # Final output
        output = self.output_layer(x)

        return output

In [25]:
class TTSLoss(nn.Module):
    """Loss helper for the Tacotron2 model and the HiFi-GAN style vocoder.

    Args:
        device: device the mel-spectrogram transform is moved to.
        debug: when True, tensor shapes are printed while the vocoder losses
            are computed. Defaults to False so training output stays quiet.
    """

    def __init__(self, device='cuda', debug=False):
        super().__init__()
        self.device = device
        self.debug = debug
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()

        # Loss weights from the paper
        self.lambda_mel = 45.0
        self.lambda_fm = 2.0
        self.lambda_adv = 1.0

        # Mel-spectrogram transform configuration
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050,
            n_fft=1024,
            win_length=1024,
            hop_length=256,
            f_min=0,
            f_max=8000,
            n_mels=80,
            power=1.0,
            normalized=True
        ).to(device)

    def tacotron2_loss(self, mel_output, mel_output_postnet, gate_out,
                      mel_target, gate_target, mel_lengths):
        """Tacotron2 loss: L1 on both mel predictions plus BCE-with-logits on the gate.

        ``mel_target`` arrives as [B, n_mel_channels, T] and is transposed to
        [B, T, n_mel_channels] to match the predictions. Positions past each
        sequence's length are zeroed in predictions and targets before the
        (mean-reduced) losses are computed.

        Returns:
            dict with 'mel_loss', 'gate_loss' and 'total_loss'.
        """
        # mel_target: [B, n_mel_channels, T]
        # mel_output: [B, T, n_mel_channels]

        # Match the layout of the predictions
        mel_target = mel_target.transpose(1, 2)  # [B, T, n_mel_channels]

        # gate_out: [B, T, 1] -> [B, T]
        gate_out = gate_out.squeeze(-1)

        n_channels = mel_target.size(-1)

        # True where a position is padding, shape (B, T)
        mask = ~self.get_mask_from_lengths(mel_lengths)

        # Broadcast the padding mask over the mel channels
        mel_mask = mask.unsqueeze(-1).expand(-1, -1, n_channels)  # [B, T, C]

        # Zero out padded frames
        mel_target_masked = mel_target.masked_fill(mel_mask, 0)
        mel_output_masked = mel_output.masked_fill(mel_mask, 0)
        mel_output_postnet_masked = mel_output_postnet.masked_fill(mel_mask, 0)

        # Same masking for the gate
        gate_target = gate_target.masked_fill(mask, 0)
        gate_out = gate_out.masked_fill(mask, 0)

        # Losses
        mel_loss = self.l1_loss(mel_output_masked, mel_target_masked) + \
                  self.l1_loss(mel_output_postnet_masked, mel_target_masked)
        gate_loss = self.bce_loss(gate_out, gate_target)

        return {
            'mel_loss': mel_loss,
            'gate_loss': gate_loss,
            'total_loss': mel_loss + gate_loss
        }

    def mel_spectrogram_loss(self, real_wave, fake_wave):
        """L1 distance between the mel spectrograms of both waveforms, times ``lambda_mel``."""
        # Crop both waveforms to the shorter length
        min_length = min(real_wave.size(-1), fake_wave.size(-1))
        real_wave = real_wave[..., :min_length]
        fake_wave = fake_wave[..., :min_length]

        # Mel spectrograms
        real_mel = self.mel_transform(real_wave)
        fake_mel = self.mel_transform(fake_wave)

        if getattr(self, 'debug', False):
            print(f"Real wave shape: {real_wave.shape}")
            print(f"Fake wave shape: {fake_wave.shape}")
            print(f"Real mel shape: {real_mel.shape}")
            print(f"Fake mel shape: {fake_mel.shape}")

        return self.l1_loss(fake_mel, real_mel) * self.lambda_mel

    def feature_matching_loss(self, fmap_r, fmap_g):
        """Sum of mean absolute differences between real and generated feature maps.

        Args:
            fmap_r: per-discriminator lists of feature maps for real audio.
            fmap_g: matching lists for generated audio.

        Pairs whose entries are not both tensors, or whose ranks differ, are
        skipped. 3-D maps are cropped along the last axis and 4-D maps along
        the second-to-last axis to the shorter of the two, so that maps from
        inputs of different length can be compared.

        Returns:
            The summed loss multiplied by ``lambda_fm``.
        """
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                if not (isinstance(rl, torch.Tensor) and isinstance(gl, torch.Tensor)):
                    continue
                if rl.dim() != gl.dim():
                    continue

                if rl.dim() == 3:  # MSD features: [B, C, T]
                    min_length = min(rl.size(-1), gl.size(-1))
                    rl = rl[..., :min_length]
                    gl = gl[..., :min_length]
                elif rl.dim() == 4:  # MPD features: [B, C, H, W]
                    # The length must come from the axis being cropped (H),
                    # not from the last axis (W).
                    min_length = min(rl.size(-2), gl.size(-2))
                    rl = rl[..., :min_length, :]
                    gl = gl[..., :min_length, :]

                loss += torch.mean(torch.abs(rl - gl))

        return loss * self.lambda_fm

    def generator_loss(self, disc_outputs):
        """Least-squares generator loss: sum over outputs of mean((1 - D(fake))^2)."""
        loss = 0
        for dg in disc_outputs:
            loss += torch.mean((1-dg)**2)
        return loss

    def discriminator_loss(self, disc_real_outputs, disc_generated_outputs):
        """Least-squares discriminator loss: mean((1 - D(real))^2) + mean(D(fake)^2), summed."""
        loss = 0
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean((1-dr)**2)
            g_loss = torch.mean(dg**2)
            loss += (r_loss + g_loss)
        return loss

    def hifi_gan_loss(self, real_wave, fake_wave, real_outputs, fake_outputs, real_feats, fake_feats):
        """Compute all vocoder losses.

        Returns:
            dict with 'generator_loss' (mel + feature matching + adversarial),
            'discriminator_loss', and the individual terms 'mel_loss',
            'feature_matching_loss', 'adversarial_loss_g', 'adversarial_loss_d'.
        """
        # Ensure a channel dimension
        if real_wave.dim() == 2:
            real_wave = real_wave.unsqueeze(1)
        if fake_wave.dim() == 2:
            fake_wave = fake_wave.unsqueeze(1)

        # Crop both waveforms to the shorter length
        min_length = min(real_wave.size(-1), fake_wave.size(-1))
        real_wave = real_wave[..., :min_length]
        fake_wave = fake_wave[..., :min_length]

        if getattr(self, 'debug', False):
            print(f"\nFeature matching shapes:")
            print(f"Real features length: {[[(f.shape) for f in disc] for disc in real_feats]}")
            print(f"Fake features length: {[[(f.shape) for f in disc] for disc in fake_feats]}")

        # Individual terms
        mel_loss = self.mel_spectrogram_loss(real_wave, fake_wave)
        fm_loss = self.feature_matching_loss(real_feats, fake_feats)
        gen_loss = self.generator_loss(fake_outputs)
        disc_loss = self.discriminator_loss(real_outputs, fake_outputs)

        # Totals
        g_loss = mel_loss + fm_loss + gen_loss
        d_loss = disc_loss

        return {
            'generator_loss': g_loss,
            'discriminator_loss': d_loss,
            'mel_loss': mel_loss,
            'feature_matching_loss': fm_loss,
            'adversarial_loss_g': gen_loss,
            'adversarial_loss_d': disc_loss
        }

    @staticmethod
    def get_mask_from_lengths(lengths):
        """Return a (B, max(lengths)) bool mask that is True at valid (non-padded) positions."""
        max_len = torch.max(lengths).item()
        ids = torch.arange(0, max_len, device=lengths.device)
        mask = (ids < lengths.unsqueeze(1)).bool()
        return mask

In [26]:
class LocationLayer(nn.Module):
    """Location features for attention: a bias-free Conv1d followed by a bias-free Linear.

    The convolution takes 2 input channels and produces ``attention_n_filters``
    channels; the linear layer projects them to ``attention_dim``.
    """

    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super().__init__()
        same_padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=same_padding,
            bias=False,
        )
        self.location_dense = nn.Linear(
            attention_n_filters, attention_dim, bias=False
        )

    def forward(self, attention_weights_cat):
        """Convolve over the last axis, move channels last, then project to ``attention_dim``."""
        conv_features = self.location_conv(attention_weights_cat)
        channels_last = conv_features.transpose(1, 2)
        return self.location_dense(channels_last)

In [27]:
class Attention(nn.Module):
    """Additive attention with location features.

    Energies are ``v(tanh(query_proj + processed_memory + location_proj))``;
    masked positions are set to ``-inf`` before the softmax.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super().__init__()
        self.query_layer = nn.Linear(attention_rnn_dim, attention_dim, bias=False)
        self.memory_layer = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters,
            attention_location_kernel_size,
            attention_dim,
        )
        self.score_mask_value = float("-inf")

    def get_alignment_energies(self, query, processed_memory,
                             attention_weights_cat):
        """Return the un-normalized attention energies with the trailing unit axis removed."""
        query_proj = self.query_layer(query.unsqueeze(1))
        location_proj = self.location_layer(attention_weights_cat)
        summed = query_proj + processed_memory + location_proj
        return self.v(torch.tanh(summed)).squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """Return ``(attention_context, attention_weights)`` for one decoder step.

        ``mask`` may be None; otherwise positions where it is True are
        excluded from the softmax.
        """
        energies = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            # In-place fill on .data, exactly as before
            energies.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(energies, dim=1)
        weighted_sum = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = weighted_sum.squeeze(1)

        return attention_context, attention_weights


In [28]:
class Prenet(nn.Module):
    """Stack of linear layers, each followed by ReLU and dropout.

    Args:
        in_dim: Feature size of the input.
        sizes: List with the output size of every linear layer, in order.
    """

    def __init__(self, in_dim, sizes):
        super().__init__()
        # Each layer consumes the previous layer's output size.
        fan_ins = [in_dim] + sizes[:-1]
        stack = []
        for fan_in, fan_out in zip(fan_ins, sizes):
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Run ``x`` through every layer and return the final activation.

        Dropout is called with ``training=True``, so it stays active even when
        the module is in eval mode.
        """
        hidden = x
        for layer in self.layers:
            activated = F.relu(layer(hidden))
            hidden = F.dropout(activated, p=0.5, training=True)
        return hidden

In [29]:
class Postnet(nn.Module):
    """Stack of 1-D convolution blocks applied to a mel spectrogram.

    The first block maps ``n_mel_channels`` to ``postnet_embedding_dim``, the
    middle blocks keep ``postnet_embedding_dim``, and the last block maps back
    to ``n_mel_channels``. Every block is Conv1d -> BatchNorm1d -> (Tanh) ->
    Dropout(0.5); the last block has no Tanh.

    Args:
        n_mel_channels: Channels of the input and of the output.
        postnet_embedding_dim: Channels used between blocks.
        postnet_kernel_size: Kernel size of every convolution.
        postnet_n_convolutions: Nominal number of blocks; a first and a last
            block are always created, plus ``postnet_n_convolutions - 2``
            middle blocks when that is positive.
    """

    def __init__(self, n_mel_channels, postnet_embedding_dim,
                 postnet_kernel_size, postnet_n_convolutions):
        super().__init__()

        same_padding = (postnet_kernel_size - 1) // 2

        def conv_bn(in_channels, out_channels):
            # Convolution + batch norm shared by every block.
            return [
                nn.Conv1d(in_channels, out_channels,
                          kernel_size=postnet_kernel_size, stride=1,
                          padding=same_padding),
                nn.BatchNorm1d(out_channels),
            ]

        # First block: mel channels -> embedding channels.
        blocks = [
            nn.Sequential(*conv_bn(n_mel_channels, postnet_embedding_dim),
                          nn.Tanh(), nn.Dropout(0.5))
        ]

        # Middle blocks: embedding channels -> embedding channels.
        for _ in range(1, postnet_n_convolutions - 1):
            blocks.append(
                nn.Sequential(*conv_bn(postnet_embedding_dim, postnet_embedding_dim),
                              nn.Tanh(), nn.Dropout(0.5)))

        # Last block: embedding channels -> mel channels, without Tanh.
        blocks.append(
            nn.Sequential(*conv_bn(postnet_embedding_dim, n_mel_channels),
                          nn.Dropout(0.5)))

        self.convolutions = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every block in order.

        Args:
            x: Mel spectrogram ``[batch_size, n_mel_channels, time]``.

        Returns:
            Tensor produced by the last block.
        """
        for block in self.convolutions:
            x = block(x)
        return x

In [30]:
class Encoder(nn.Module):
    """Convolution stack followed by a single bidirectional LSTM.

    Args:
        encoder_embedding_dim: Channels of the input and of every convolution.
        encoder_n_convolutions: Number of Conv1d/BatchNorm/ReLU/Dropout blocks.
        encoder_kernel_size: Kernel size of every convolution.
        encoder_lstm_dim: Hidden size per LSTM direction; the output feature
            size is twice this value because the LSTM is bidirectional.
    """

    def __init__(self, encoder_embedding_dim, encoder_n_convolutions,
                 encoder_kernel_size, encoder_lstm_dim):
        super().__init__()

        self.convolutions = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(encoder_embedding_dim, encoder_embedding_dim,
                          encoder_kernel_size, stride=1,
                          padding=(encoder_kernel_size - 1) // 2),
                nn.BatchNorm1d(encoder_embedding_dim),
                nn.ReLU(),
                nn.Dropout(0.5)
            )
            for _ in range(encoder_n_convolutions)
        ])

        self.lstm = nn.LSTM(encoder_embedding_dim, encoder_lstm_dim,
                            num_layers=1, batch_first=True, bidirectional=True)

    def _apply_convolutions(self, x):
        """Run the conv stack and return the result as ``[B, T, embed_dim]``."""
        for conv in self.convolutions:
            x = conv(x)
        return x.transpose(1, 2)

    def forward(self, x, input_lengths):
        """Encode a padded batch.

        Args:
            x: ``[B, embed_dim, T]`` embedded input.
            input_lengths: ``[B]`` valid length of every sequence, in any order.

        Returns:
            ``[B, max(input_lengths), 2 * encoder_lstm_dim]`` tensor in the
            original batch order; positions past a sequence's length are zero.
        """
        x = self._apply_convolutions(x)

        # `enforce_sorted=False` makes PyTorch sort the batch for packing and
        # restore the original order on unpacking, so no manual sort/unsort is
        # required. The lengths must live on the CPU.
        x_packed = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths.cpu(), batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x_packed)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        """Encode ``x`` (``[B, embed_dim, T]``) without packing.

        Every time step is treated as valid, so this presumably expects
        unpadded input (e.g. a single utterance) — confirm against callers.
        """
        x = self._apply_convolutions(x)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs

In [31]:
class Decoder(nn.Module):
    """Autoregressive mel-spectrogram decoder with attention over the encoder.

    One decoding step runs: prenet -> attention LSTM cell -> attention over the
    encoder outputs -> decoder LSTM cell -> linear projections to one mel frame
    and one stop ("gate") logit.

    Args:
        n_mel_channels: Size of one mel frame.
        encoder_embedding_dim: Feature size of the encoder outputs.
        attention_dim: Size of the attention energy space.
        attention_location_n_filters: Forwarded to ``Attention``.
        attention_location_kernel_size: Forwarded to ``Attention``.
        attention_rnn_dim: Hidden size of the attention LSTM cell.
        decoder_rnn_dim: Hidden size of the decoder LSTM cell.
        prenet_dim: Width of both prenet layers.
        max_decoder_steps: Upper bound on steps taken by ``inference``.
        gate_threshold: ``inference`` stops once ``sigmoid(gate)`` exceeds this
            value for every item in the batch.
        p_attention_dropout: Dropout on the attention LSTM hidden state.
        p_decoder_dropout: Dropout on the decoder LSTM hidden state.
    """

    def __init__(self, n_mel_channels, encoder_embedding_dim,
                 attention_dim, attention_location_n_filters,
                 attention_location_kernel_size, attention_rnn_dim,
                 decoder_rnn_dim, prenet_dim, max_decoder_steps,
                 gate_threshold, p_attention_dropout, p_decoder_dropout):
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.max_decoder_steps = max_decoder_steps
        self.gate_threshold = gate_threshold
        self.p_attention_dropout = p_attention_dropout
        self.p_decoder_dropout = p_decoder_dropout

        # Prenet applied to the previous mel frame.
        self.prenet = Prenet(n_mel_channels, [prenet_dim, prenet_dim])

        # Attention RNN: consumes prenet output + previous attention context.
        self.attention_rnn = nn.LSTMCell(
            prenet_dim + encoder_embedding_dim,
            attention_rnn_dim)

        # Attention layer, built exactly once.
        self.attention_layer = Attention(
            attention_rnn_dim, encoder_embedding_dim,
            attention_dim, attention_location_n_filters,
            attention_location_kernel_size)

        # Decoder RNN: consumes attention hidden state + attention context.
        self.decoder_rnn = nn.LSTMCell(
            attention_rnn_dim + encoder_embedding_dim,
            decoder_rnn_dim)

        # Projection to one mel frame.
        self.linear_projection = nn.Linear(
            decoder_rnn_dim + encoder_embedding_dim,
            n_mel_channels)

        # Projection to one stop-token logit.
        self.gate_layer = nn.Linear(
            decoder_rnn_dim + encoder_embedding_dim, 1,
            bias=True)

        # Projects the encoder outputs once per utterance for the attention.
        self.memory_layer = nn.Linear(
            encoder_embedding_dim,
            attention_dim,
            bias=False
        )

        # Not used by any method of this class; kept so that parameter names
        # (and saved state dicts) stay unchanged.
        self.attention_projection = nn.Linear(
            encoder_embedding_dim,
            decoder_rnn_dim,
            bias=False
        )

    def parse_decoder_inputs(self, decoder_inputs):
        """Reorder teacher-forcing mels to time-major layout.

        Args:
            decoder_inputs: ``(B, n_mel_channels, T)`` mel spectrograms.

        Returns:
            ``(T, B, n_mel_channels)`` tensor.
        """
        # (B, n_mel_channels, T) -> (B, T, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.contiguous()

        # (B, T, n_mel_channels) -> (T, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """Convert stacked, time-major step outputs to batch-major layout.

        Args:
            mel_outputs: ``(T, B, n_mel_channels)`` stacked mel frames.
            gate_outputs: ``(T, B, 1)`` stacked gate logits.
            alignments: ``(T, B, n_text)`` stacked attention weights.

        Returns:
            The three tensors with the first two dimensions swapped.
        """
        # (T, B, n_mel_channels) -> (B, T, n_mel_channels)
        mel_outputs = mel_outputs.transpose(0, 1).contiguous()

        # (T, B, 1) -> (B, T, 1)
        gate_outputs = gate_outputs.transpose(0, 1).contiguous()

        # (T, B, n_text) -> (B, T, n_text)
        alignments = alignments.transpose(0, 1).contiguous()

        return mel_outputs, gate_outputs, alignments

    def _init_state(self, memory):
        """Return the all-zero recurrent and attention state for ``memory``."""
        batch_size, n_text = memory.size(0), memory.size(1)
        return {
            "attention_hidden": memory.new_zeros((batch_size, self.attention_rnn_dim)),
            "attention_cell": memory.new_zeros((batch_size, self.attention_rnn_dim)),
            "decoder_hidden": memory.new_zeros((batch_size, self.decoder_rnn_dim)),
            "decoder_cell": memory.new_zeros((batch_size, self.decoder_rnn_dim)),
            "attention_weights": memory.new_zeros((batch_size, n_text)),
            "attention_weights_cum": memory.new_zeros((batch_size, n_text)),
            "attention_context": memory.new_zeros((batch_size, self.encoder_embedding_dim)),
        }

    def _decode_step(self, frame, state, memory, processed_memory, mask):
        """Run one decoding step and update ``state`` in place.

        Args:
            frame: ``(B, n_mel_channels)`` previous mel frame.
            state: Dict produced by ``_init_state``.
            memory: ``(B, n_text, encoder_embedding_dim)`` encoder outputs.
            processed_memory: ``memory_layer(memory)``.
            mask: Boolean ``(B, n_text)`` padding mask or ``None``.

        Returns:
            ``(mel_frame, gate_logit, attention_weights)`` for this step.
        """
        prenet_out = self.prenet(frame)

        # Attention RNN
        cell_input = torch.cat((prenet_out, state["attention_context"]), -1)
        attention_hidden, attention_cell = self.attention_rnn(
            cell_input, (state["attention_hidden"], state["attention_cell"]))
        attention_hidden = F.dropout(
            attention_hidden, self.p_attention_dropout, self.training)

        # Location features use the previous weights AND their running sum,
        # so the attention can tell which positions were already covered.
        attention_weights_cat = torch.cat(
            (state["attention_weights"].unsqueeze(1),
             state["attention_weights_cum"].unsqueeze(1)),
            dim=1)
        attention_context, attention_weights = self.attention_layer(
            attention_hidden, memory,
            processed_memory, attention_weights_cat,
            mask=mask)

        # Decoder RNN
        decoder_input = torch.cat((attention_hidden, attention_context), -1)
        decoder_hidden, decoder_cell = self.decoder_rnn(
            decoder_input, (state["decoder_hidden"], state["decoder_cell"]))
        decoder_hidden = F.dropout(
            decoder_hidden, self.p_decoder_dropout, self.training)

        # Output projections
        projection_input = torch.cat((decoder_hidden, attention_context), dim=1)
        mel_frame = self.linear_projection(projection_input)
        gate_logit = self.gate_layer(projection_input)

        state["attention_hidden"] = attention_hidden
        state["attention_cell"] = attention_cell
        state["decoder_hidden"] = decoder_hidden
        state["decoder_cell"] = decoder_cell
        state["attention_context"] = attention_context
        state["attention_weights"] = attention_weights
        state["attention_weights_cum"] = state["attention_weights_cum"] + attention_weights

        return mel_frame, gate_logit, attention_weights

    def forward(self, encoder_outputs, decoder_inputs, memory_lengths=None):
        """Teacher-forced decoding.

        Args:
            encoder_outputs: ``[batch_size, n_text, encoder_embedding_dim]``.
            decoder_inputs: ``[batch_size, n_mel_channels, max_time]`` target
                mel spectrogram.
            memory_lengths: Optional ``[batch_size]`` valid encoder lengths
                used to mask padded encoder positions.

        Returns:
            ``(mel_outputs, gate_outputs, alignments)`` with shapes
            ``[batch_size, max_time, n_mel_channels]``,
            ``[batch_size, max_time, 1]`` and ``[batch_size, max_time, n_text]``.
        """
        # (B, n_mel_channels, T) -> (B, T, n_mel_channels)
        target_frames = decoder_inputs.transpose(1, 2)
        max_time = target_frames.size(1)

        # Shift the targets right by one and start from an all-zero "go" frame:
        # step t must see frame t-1, never the frame t it is asked to predict.
        # NOTE(review): assumes the loss compares output t with target t and
        # that the caller does not pre-shift `decoder_inputs` — confirm.
        go_frame = target_frames.new_zeros(
            (target_frames.size(0), 1, target_frames.size(2)))
        teacher_frames = torch.cat((go_frame, target_frames[:, :-1, :]), dim=1)

        state = self._init_state(encoder_outputs)

        # Both are loop-invariant, so compute them once.
        processed_memory = self.memory_layer(encoder_outputs)
        mask = None
        if memory_lengths is not None:
            mask = ~get_mask_from_lengths(memory_lengths).to(encoder_outputs.device)

        mel_outputs, gate_outputs, alignments = [], [], []
        for step in range(max_time):
            mel_frame, gate_logit, attention_weights = self._decode_step(
                teacher_frames[:, step, :], state,
                encoder_outputs, processed_memory, mask)
            mel_outputs.append(mel_frame)
            gate_outputs.append(gate_logit)
            alignments.append(attention_weights)

        # Lists of per-step tensors -> time-major tensors -> batch-major.
        return self.parse_decoder_outputs(
            torch.stack(mel_outputs),
            torch.stack(gate_outputs),
            torch.stack(alignments))

    def inference(self, encoder_outputs):
        """Free-running decoding: every predicted frame is fed back as input.

        Decoding starts from an all-zero frame and stops when the gate of every
        batch item has exceeded ``gate_threshold`` at some step, or after
        ``max_decoder_steps`` steps (which must be at least 1). No padding mask
        is applied, so all encoder positions are treated as valid.

        Args:
            encoder_outputs: ``[batch_size, n_text, encoder_embedding_dim]``.

        Returns:
            Same layout as ``forward``: ``(mel_outputs, gate_outputs,
            alignments)``, batch-major.
        """
        batch_size = encoder_outputs.size(0)
        frame = encoder_outputs.new_zeros((batch_size, self.n_mel_channels))
        state = self._init_state(encoder_outputs)
        processed_memory = self.memory_layer(encoder_outputs)
        finished = torch.zeros(
            batch_size, dtype=torch.bool, device=encoder_outputs.device)

        mel_outputs, gate_outputs, alignments = [], [], []
        for _ in range(self.max_decoder_steps):
            mel_frame, gate_logit, attention_weights = self._decode_step(
                frame, state, encoder_outputs, processed_memory, None)
            mel_outputs.append(mel_frame)
            gate_outputs.append(gate_logit)
            alignments.append(attention_weights)

            finished = finished | (
                torch.sigmoid(gate_logit.squeeze(-1)) > self.gate_threshold)
            if bool(finished.all()):
                break
            frame = mel_frame

        return self.parse_decoder_outputs(
            torch.stack(mel_outputs),
            torch.stack(gate_outputs),
            torch.stack(alignments))


In [32]:
class Tacotron2(nn.Module):
    """Text-to-mel model: embedding -> Encoder -> Decoder -> Postnet residual.

    The class attributes below are fixed hyper-parameters forwarded to the
    decoder and the postnet.

    Args:
        n_mel_channels: Size of one mel frame.
        vocab_size: Number of rows of the embedding table.
        embedding_dim: Embedding size; also used as the encoder output size.
        encoder_n_convolutions: Number of encoder convolution blocks.
        encoder_kernel_size: Kernel size of the encoder convolutions.
        attention_rnn_dim: Hidden size of the decoder's attention LSTM cell.
        attention_dim: Size of the attention energy space.
    """

    attention_location_n_filters = 32
    attention_location_kernel_size = 31
    decoder_rnn_dim = 512
    prenet_dim = 256
    max_decoder_steps = 1000
    gate_threshold = 0.5
    p_attention_dropout = 0.1
    p_decoder_dropout = 0.1
    postnet_embedding_dim = 256
    postnet_kernel_size = 5

    def __init__(self, n_mel_channels, vocab_size, embedding_dim,
                 encoder_n_convolutions, encoder_kernel_size,
                 attention_rnn_dim, attention_dim):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.encoder = Encoder(
            encoder_embedding_dim=embedding_dim,
            encoder_n_convolutions=encoder_n_convolutions,
            encoder_kernel_size=encoder_kernel_size,
            # Half per direction so the bidirectional output is embedding_dim.
            encoder_lstm_dim=embedding_dim // 2
        )

        self.decoder = Decoder(
            n_mel_channels=n_mel_channels,
            encoder_embedding_dim=embedding_dim,
            attention_dim=attention_dim,
            attention_location_n_filters=self.attention_location_n_filters,
            attention_location_kernel_size=self.attention_location_kernel_size,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=self.decoder_rnn_dim,  # 512
            prenet_dim=self.prenet_dim,  # 256
            max_decoder_steps=self.max_decoder_steps,
            gate_threshold=self.gate_threshold,
            p_attention_dropout=self.p_attention_dropout,
            p_decoder_dropout=self.p_decoder_dropout
        )

        self.postnet = Postnet(
            n_mel_channels=n_mel_channels,
            postnet_embedding_dim=self.postnet_embedding_dim,
            postnet_kernel_size=self.postnet_kernel_size,
            postnet_n_convolutions=5
        )

    def _refine_with_postnet(self, mel_outputs):
        """Add the postnet residual to decoder mels given as ``(B, T, n_mel)``.

        The postnet works on ``(B, n_mel, T)``, so the input is transposed
        before the call and the sum is transposed back.
        """
        mel_channels_first = mel_outputs.transpose(1, 2)
        refined = mel_channels_first + self.postnet(mel_channels_first)
        return refined.transpose(1, 2)

    def forward(self, text_inputs, text_lengths, mel_inputs, mel_lengths):
        """Teacher-forced text-to-mel pass.

        Args:
            text_inputs: ``(B, T_text)`` token ids for the embedding table.
            text_lengths: ``(B,)`` valid text lengths.
            mel_inputs: ``(B, n_mel_channels, T_mel)`` target mels fed to the
                decoder.
            mel_lengths: Accepted for interface compatibility; not used here.

        Returns:
            ``(mel_outputs_postnet, mel_outputs, gate_outputs, alignments)``;
            both mel tensors are ``(B, T_mel, n_mel_channels)``.
        """
        # (B, T_text, embed) -> (B, embed, T_text) for the encoder convolutions.
        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mel_inputs, text_lengths)

        mel_outputs_postnet = self._refine_with_postnet(mel_outputs)

        return mel_outputs_postnet, mel_outputs, gate_outputs, alignments

    def inference(self, text_inputs):
        """Free-running text-to-mel pass used at synthesis time.

        NOTE(review): relies on ``self.decoder.inference`` returning mels in
        the same ``(B, T, n_mel_channels)`` layout as ``Decoder.forward`` —
        confirm against the decoder.

        Returns:
            ``(mel_outputs_postnet, mel_outputs, gate_outputs, alignments)``
            in the same layout as ``forward``.
        """
        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(encoder_outputs)
        # Same layout handling as `forward`: the postnet needs channels first.
        mel_outputs_postnet = self._refine_with_postnet(mel_outputs)

        return mel_outputs_postnet, mel_outputs, gate_outputs, alignments


In [33]:
def get_mask_from_lengths(lengths):
    """Build a boolean validity mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: ``(batch,)`` tensor of lengths.

    Returns:
        Boolean ``(batch, max(lengths))`` tensor on the same device as
        ``lengths``; entry ``[b, t]`` is ``True`` when ``t < lengths[b]``.
    """
    longest = lengths.max().item()
    positions = torch.arange(0, longest, device=lengths.device)
    # Broadcast positions (1, T) against lengths (B, 1).
    return (positions[None, :] < lengths[:, None]).bool()

In [34]:
from argparse import Namespace

# Run settings handed to `Trainer`, kept in one mapping so they are easy to scan.
config = Namespace(**{
    "device_num": 0,
    "batch_size": 4,
    "learning_rate": 0.0002,
    "num_epochs": 100,
    "start_epoch": 0,
    "save_interval": 10,
    "checkpoint_dir": 'checkpoints',
    "resume_checkpoint": None,
})

In [35]:
# Build the trainer from `config` and start training.
# NOTE(review): the recorded output of this cell ends in a
# DatasetGenerationError while generating the train split, so training
# presumably never started in the saved run — confirm the dataset download.
trainer = Trainer(config)
trainer.train()     

README.md:   0%|          | 0.00/4.45k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


dataset_infos.json:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

train-00000-of-00007.parquet:   0%|          | 0.00/556M [00:00<?, ?B/s]

train-00001-of-00007.parquet:   0%|          | 0.00/600M [00:00<?, ?B/s]

train-00002-of-00007.parquet:   0%|          | 0.00/586M [00:00<?, ?B/s]

train-00003-of-00007.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

train-00004-of-00007.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

train-00005-of-00007.parquet:   0%|          | 0.00/530M [00:00<?, ?B/s]

train-00006-of-00007.parquet:   0%|          | 0.00/544M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12854 [00:00<?, ? examples/s]

DatasetGenerationError: An error occurred while generating the dataset