# Face DCGAN 튜토리얼


In [70]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchaudio

In [71]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, negative_slope=0.01):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride, padding)
        self.skip_connection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        residual = self.skip_connection(x)
        x = self.leaky_relu(self.conv1(x))
        x = self.conv2(x)
        return self.leaky_relu(x + residual)

In [72]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv1d(80, 512, kernel_size=7, stride=1, padding=3)
        self.relu = nn.ReLU()

        # 첫 번째 업샘플링 레이어 (stride=4, kernel_size=16)
        self.deconv1 = nn.ConvTranspose1d(512, 512, kernel_size=10, stride=2, padding=4)

        # 두 번째 업샘플링 레이어 (stride=4, kernel_size=16)
        self.deconv2 = nn.ConvTranspose1d(512, 1, kernel_size=10, stride=2, padding=4)

        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.deconv1(x))
        x = self.deconv2(x)
        return self.tanh(x)

In [73]:
class MultiScaleDiscriminator(nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv1d(4, 16, kernel_size=15, stride=1, padding=7),  # 첫 번째 레이어의 입력 채널을 4로 설정
            nn.Conv1d(16, 32, kernel_size=15, stride=2, padding=7),
            nn.Conv1d(32, 64, kernel_size=15, stride=2, padding=7),
            nn.Conv1d(64, 128, kernel_size=15, stride=2, padding=7),
        ])
        self.final_conv = nn.Conv1d(128, 1, 3, 1, 1)

        # real_audio 텐서를 4개의 채널로 변환하는 Conv1d 추가
        self.channel_conv = nn.Conv1d(4, 4, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # real_audio의 채널을 4로 변환
        x = self.channel_conv(x)
        features = []
        for layer in self.layers:
            x = layer(x)
            features.append(x)
        x = self.final_conv(x)
        return x, features

In [74]:
class MultiPeriodDiscriminator(nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv1d(4, 16, kernel_size=15, stride=1, padding=7),  # 첫 번째 레이어
            nn.Conv1d(16, 32, kernel_size=16, stride=2, padding=8),  # 주기적 특징을 잡기 위해 kernel_size, stride 조정
            nn.Conv1d(32, 64, kernel_size=32, stride=4, padding=16),  # 주기적인 정보 캡처
            nn.Conv1d(64, 128, kernel_size=64, stride=8, padding=32),  # 더 큰 주기 정보 캡처
        ])
        self.final_conv = nn.Conv1d(128, 1, 3, 1, 1)
        
        self.channel_conv = nn.Conv1d(4, 4, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.channel_conv(x)
        features = []
        for layer in self.layers:
            x = layer(x)
            features.append(x)
        x = self.final_conv(x)
        return x, features

In [75]:
def l1_loss(x, y):
    return torch.mean(torch.abs(x - y))

In [76]:
class HiFiGANLoss(nn.Module):
    def __init__(self, device='cuda'):
        super(HiFiGANLoss, self).__init__()
        self.device = device
        self.l1_loss = nn.L1Loss()
        
        # Mel-spectrogram 계산을 위한 설정
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80, normalized=True
        ).to(device)

    def mel_spectrogram_loss(self, real_wave, fake_wave):
        """ Mel-spectrogram 손실 (L1 Loss 또는 L2 Loss) """
        # Mel-spectrogram 계산
        real_mel = self.mel_transform(real_wave)
        fake_mel = self.mel_transform(fake_wave)
        print(real_mel.size())
        
        # Mel-spectrogram 간의 L1 손실 계산
        mel_loss = self.l1_loss(fake_mel, real_mel)
        return mel_loss

    def feature_loss(self, real_features, fake_features):
        """ 특징 손실 (L1 Loss) """
        return self.l1_loss(fake_features, real_features)

    def generator_loss(self, fake_output):
        """ 제너레이터의 손실 함수 (LSGAN 손실) """
        return torch.mean((fake_output - 1) ** 2)

    def discriminator_loss(self, real_output, fake_output):
        """ 디스크리미네이터의 손실 함수 (LSGAN 손실) """
        real_loss = torch.mean((real_output - 1) ** 2)
        fake_loss = torch.mean(fake_output ** 2)
        return (real_loss + fake_loss) * 0.5

    def forward(self, real_wave, fake_wave, real_features, fake_features):
        """ 전체 손실 계산 """
        # 1. Generator Loss (fake_output은 디스크리미네이터의 출력)
        fake_output = fake_wave.detach()  # Generator가 생성한 오디오
        generator_loss = self.generator_loss(fake_output)

        # 2. Discriminator Loss
        real_output = real_wave  # 실제 오디오
        discriminator_loss = self.discriminator_loss(real_output, fake_output)

        # 3. Mel-spectrogram Loss
        mel_loss = self.mel_spectrogram_loss(real_wave, fake_wave)

        # 4. Feature Loss
        feature_loss = self.feature_loss(real_features, fake_features)

        # 종합적인 손실 (가중치를 조정할 수 있음)
        total_loss = generator_loss + discriminator_loss + mel_loss + feature_loss
        return total_loss, generator_loss, discriminator_loss, mel_loss, feature_loss


In [77]:

class MRFDiscriminator(nn.Module):
    def __init__(self):
        super(MRFDiscriminator, self).__init__()
        # 채널 수를 맞추기 위한 Conv1d 레이어 추가
        self.channel_adjustment = nn.Conv1d(128, 512, kernel_size=3, padding=1)
        
        # ResBlock 추가
        self.resblock1 = ResBlock(512, 256)
        self.resblock2 = ResBlock(256, 128)
        
        # Fusion layer
        self.fusion_layer = nn.Conv1d(128, 64, kernel_size=3, padding=1)
        self.output_layer = nn.Conv1d(64, 1, kernel_size=3, padding=1)

    def forward(self, msd_features, mpd_features):
        # MSD와 MPD에서 나온 특징을 결합
        fused_features = torch.cat([msd_features[-1], mpd_features[-1]], dim=1)  # 채널 수가 맞는지 확인
        
        # 채널 수를 맞추기 위해 adjustment
        fused_features = self.channel_adjustment(fused_features)
        
        # ResBlock을 통한 특성 추출
        x = self.resblock1(fused_features)
        x = self.resblock2(x)
        
        # Fusion layer
        x = self.fusion_layer(x)
        x = self.output_layer(x)
        return x


In [78]:
class RandomMelDataset(Dataset):
    def __init__(self, num_samples, mel_spec_len=80, audio_len=800):
        self.num_samples = num_samples
        self.mel_spec_len = mel_spec_len
        self.audio_len = audio_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 랜덤 Mel-spectrogram과 오디오 (여기선 예시로 랜덤 데이터 사용)
        mel_spec = torch.randn(self.mel_spec_len, 200)  # Mel spectrogram (4, 80, 200)
        audio = torch.randn(self.audio_len)  # 오디오 시퀀스 (1, 800)
        return mel_spec, audio

In [79]:
# 데이터셋과 DataLoader
dataset = RandomMelDataset(num_samples=1000)  # 1000개의 샘플 사용
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)  # 배치 크기 4로 설정

In [80]:
generator = Generator().cuda()
msd = MultiScaleDiscriminator().cuda()
mpd = MultiPeriodDiscriminator().cuda()
mrf_discriminator = MRFDiscriminator().cuda()  # MRF discriminator 추가

optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.9))
optimizer_d = optim.Adam(list(msd.parameters()) + list(mpd.parameters()) + list(mrf_discriminator.parameters()), lr=0.0001, betas=(0.5, 0.9))

In [81]:
num_epochs = 100

In [82]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [83]:
generator = Generator().to(device)
msd = MultiScaleDiscriminator().to(device)
mpd = MultiPeriodDiscriminator().to(device)

opt_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.8, 0.99))
opt_d = optim.Adam(list(msd.parameters()) + list(mpd.parameters()), lr=0.0002, betas=(0.8, 0.99))

mel_spec = torch.randn(4, 80, 1000).to(device)
real_audio = torch.randn(4, 1, 16000).to(device)
hifi_gan_loss = HiFiGANLoss(device=device)

In [84]:
fake_audio = generator(mel_spec)
print(fake_audio.size())
print(real_audio.size())

torch.Size([4, 1, 4000])
torch.Size([4, 1, 16000])


In [85]:
for epoch in range(num_epochs):
    generator.train()
    msd.train()
    mpd.train()
    mrf_discriminator.train()

    total_g_loss = 0
    total_d_loss = 0
    total_mel_loss = 0
    total_feature_loss = 0

    for batch_idx, (mel_spec, real_audio) in enumerate(train_loader):
        mel_spec, real_audio = mel_spec.cuda(), real_audio.cuda()

        fake_audio = generator(mel_spec)
        fake_audio = torch.squeeze(fake_audio,dim = 1)

        real_scores_msd, real_features_msd = msd(real_audio)
        fake_scores_msd, fake_features_msd = msd(fake_audio.detach())
        print(len(real_features_msd))

        real_scores_mpd, real_features_mpd = mpd(real_audio)
        fake_scores_mpd, fake_features_mpd = mpd(fake_audio.detach())

        # MRF 결합된 판별자 사용
        fake_mrf_scores = mrf_discriminator(fake_features_msd, fake_features_mpd)
        real_mrf_scores = mrf_discriminator(real_features_msd, real_features_mpd)

        # Generator Loss
        g_loss_msd = l1_loss(fake_scores_msd, real_scores_msd)
        g_loss_mpd = l1_loss(fake_scores_mpd, real_scores_mpd)
        g_loss_mrf = l1_loss(fake_mrf_scores, torch.ones_like(fake_mrf_scores))
        generator_loss = g_loss_msd + g_loss_mpd + g_loss_mrf

        # Discriminator Loss
        d_loss_msd_real = l1_loss(real_scores_msd, torch.ones_like(real_scores_msd))
        d_loss_msd_fake = l1_loss(fake_scores_msd, torch.zeros_like(fake_scores_msd))
        d_loss_msd = (d_loss_msd_real + d_loss_msd_fake) * 0.5

        d_loss_mpd_real = l1_loss(real_scores_mpd, torch.ones_like(real_scores_mpd))
        d_loss_mpd_fake = l1_loss(fake_scores_mpd, torch.zeros_like(fake_scores_mpd))
        d_loss_mpd = (d_loss_mpd_real + d_loss_mpd_fake) * 0.5

        d_loss_mrf_real = l1_loss(real_mrf_scores, torch.ones_like(real_mrf_scores))
        d_loss_mrf_fake = l1_loss(fake_mrf_scores, torch.zeros_like(fake_mrf_scores))
        d_loss_mrf = (d_loss_mrf_real + d_loss_mrf_fake) * 0.5

        discriminator_loss = d_loss_msd + d_loss_mpd + d_loss_mrf

        # Mel-spectrogram Loss 추가
        mel_loss = hifi_gan_loss.mel_spectrogram_loss(real_audio, fake_audio)

        # Feature Loss 추가

        # 최적화
        optimizer_g.zero_grad()
        (generator_loss + mel_loss).backward(retain_graph=True)  # Mel-spectrogram loss, Feature loss 포함
        optimizer_g.step()

        optimizer_d.zero_grad()
        discriminator_loss.backward()
        optimizer_d.step()

        total_g_loss += generator_loss.item() + mel_loss.item()
        total_d_loss += discriminator_loss.item()
        total_mel_loss += mel_loss.item()

    # 에폭마다 출력
    print(f"Epoch [{epoch}/{num_epochs}], Generator Loss: {total_g_loss / len(train_loader):.4f}, Discriminator Loss: {total_d_loss / len(train_loader):.4f}, Mel Loss: {total_mel_loss / len(train_loader):.4f}, Feature Loss: {total_feature_loss / len(train_loader):.4f}")

4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])
4
torch.Size([4, 80, 4])


KeyboardInterrupt: 