# In [None]:
import torch
import torch.nn as nn


# GAN loss (LSGAN / vanilla GAN)
class GANLoss(nn.Module):
    """Adversarial loss comparing discriminator predictions to constant labels.

    With ``use_lsgan=True`` the criterion is ``nn.MSELoss`` (least-squares GAN);
    otherwise it is ``nn.BCELoss``, which requires predictions in [0, 1]
    (presumably the discriminator ends with a sigmoid in that mode — confirm).

    Args:
        use_lsgan: Use MSE (LSGAN) instead of binary cross-entropy.
        target_real_label: Label value used when ``target_is_real`` is True.
        target_fake_label: Label value used when ``target_is_real`` is False.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Registered as buffers so the labels follow the module across .to()/.cuda().
        self.register_buffer('real_label_tensor', torch.tensor(target_real_label))
        self.register_buffer('fake_label_tensor', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()
        print(f"GAN loss initialized with {'LSGAN' if use_lsgan else 'Normal GAN'}")

    def get_target_tensor(self, input, target_is_real):
        """Return the real/fake label broadcast to ``input``'s shape.

        The result is an ``expand_as`` view of the scalar label buffer (no copy).
        It is always expanded, not only when element counts differ, because
        ``nn.BCELoss`` rejects a target whose size differs from the input's
        (e.g. a 0-d label against a ``(1, 1)`` prediction).
        """
        target_tensor = self.real_label_tensor if target_is_real else self.fake_label_tensor
        return target_tensor.expand_as(input)

    def _prediction_loss(self, pred, target_is_real):
        """Compute the criterion for a single prediction tensor."""
        target_tensor = self.get_target_tensor(pred, target_is_real)
        # Match the prediction's device and dtype: the module itself may not have
        # been moved, and the prediction may be float16/float64 (or the labels ints).
        target_tensor = target_tensor.to(device=pred.device, dtype=pred.dtype)
        return self.loss(pred, target_tensor)

    def forward(self, input, target_is_real):
        """Compute the adversarial loss.

        Args:
            input: Either
                * a sequence of per-scale outputs, each itself a sequence whose last
                  element is that scale's prediction (multi-scale: losses are summed), or
                * a sequence whose last element is the prediction (single scale).
                NOTE(review): if a bare tensor is passed, ``input[-1]`` selects the
                last entry along its first dimension, as in the original code.
            target_is_real: Compare against the real label if True, else the fake one.

        Returns:
            Scalar loss tensor.
        """
        if isinstance(input[0], (list, tuple)):  # multi-scale input
            loss = 0
            for scale_output in input:
                loss += self._prediction_loss(scale_output[-1], target_is_real)
            return loss
        # Single-scale input: the last element is the prediction. Use the
        # prediction's device, since `input` may be a plain list without `.device`.
        return self._prediction_loss(input[-1], target_is_real)


# WGAN loss (optionally with gradient penalty)
class WGANLoss(nn.Module):
    """Wasserstein GAN generator/critic loss with an optional WGAN-GP penalty.

    Generator loss: ``-mean(D(fake))``.
    Critic loss: ``mean(D(fake)) - mean(D(real))``, plus
    ``lambda_gp * mean((||grad_x D(x_hat)||_2 - 1) ** 2)`` when ``grad_penalty`` is set,
    where ``x_hat`` is a random per-sample interpolation between real and fake inputs.

    NOTE(review): with ``grad_penalty=False`` nothing else is applied here; weight
    clipping (named in the init message) is presumably done by the training loop.

    Args:
        grad_penalty: Add the gradient penalty to the critic loss.
        lambda_gp: Weight of the gradient penalty.
    """

    def __init__(self, grad_penalty=False, lambda_gp=10):
        super(WGANLoss, self).__init__()
        self.grad_penalty = grad_penalty
        self.lambda_gp = lambda_gp
        print(f"WGAN loss initialized with {'Gradient Penalty' if grad_penalty else 'Weight Clipping'}")

    def compute_gradient_penalty(self, real_samples, fake_samples, discriminator):
        """Return ``lambda_gp`` times the WGAN-GP gradient penalty.

        Args:
            real_samples: Batch of real critic inputs, batch dimension first.
            fake_samples: Batch of generated critic inputs, same shape as ``real_samples``.
            discriminator: Callable critic. If it returns a list/tuple, the last
                element is taken as the score (same convention as ``forward``).

        Raises:
            ValueError: If ``discriminator`` is None.
        """
        if discriminator is None:
            raise ValueError("A discriminator is required to compute the gradient penalty.")
        batch_size = real_samples.size(0)
        # One mixing coefficient per sample, broadcast over all remaining dims, so
        # this works for (B, C, H, W) images as well as flat (B, F) inputs.
        alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_samples.device)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = discriminator(interpolates)
        if isinstance(d_interpolates, (list, tuple)):
            d_interpolates = d_interpolates[-1]
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,  # the penalty itself must be differentiable
            retain_graph=True,
            only_inputs=True,
        )[0]
        # reshape (not view): autograd may return a non-contiguous gradient.
        gradients = gradients.reshape(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return self.lambda_gp * gradient_penalty

    def forward(self, input_fake, input_real=None, is_G=True, discriminator=None,
                real_samples=None, fake_samples=None):
        """Compute the generator loss (``is_G=True``) or the critic loss.

        Args:
            input_fake: Critic output(s) for generated data; the last element is the score.
            input_real: Critic output(s) for real data; required when ``is_G`` is False.
            is_G: Return the generator loss if True, the critic loss otherwise.
            discriminator: Critic used for the gradient penalty.
            real_samples, fake_samples: Real/generated critic *inputs* to interpolate
                for the gradient penalty. If omitted, ``input_real[-1]`` /
                ``input_fake[-1]`` are used, as in the original code.

        Raises:
            ValueError: If ``input_real`` is missing for the critic loss.
        """
        if is_G:  # generator loss
            return -torch.mean(input_fake[-1])
        if input_real is None:
            raise ValueError("input_real is required to compute the discriminator loss.")
        disc_loss = torch.mean(input_fake[-1]) - torch.mean(input_real[-1])
        if self.grad_penalty:
            # WGAN-GP interpolates between real and generated critic *inputs*. Callers
            # that pass only critic outputs keep the original fallback.
            # NOTE(review): that fallback interpolates critic outputs, which is
            # probably not intended.
            real_x = input_real[-1] if real_samples is None else real_samples
            fake_x = input_fake[-1] if fake_samples is None else fake_samples
            disc_loss = disc_loss + self.compute_gradient_penalty(real_x, fake_x, discriminator)
        return disc_loss


# Reconstruction loss
class RestructionLoss(nn.Module):
    """Element-wise reconstruction loss (L1 or MSE) between ground truth and prediction.

    The class name misspells "Reconstruction". A correctly spelled alias,
    ``ReconstructionLoss``, is defined right after the class; this name is kept
    for backward compatibility.

    Args:
        distance: 'l1' for ``nn.L1Loss`` or 'mse' for ``nn.MSELoss``.
        reduction: Forwarded to the underlying criterion ('mean', 'sum' or 'none').

    Raises:
        ValueError: If ``distance`` is not 'l1' or 'mse'.
    """

    def __init__(self, distance='l1', reduction='mean'):
        super(RestructionLoss, self).__init__()
        criteria = {'l1': nn.L1Loss, 'mse': nn.MSELoss}
        if distance not in criteria:
            raise ValueError(f"Unsupported distance type: {distance}. Use 'l1' or 'mse'.")
        self.loss = criteria[distance](reduction=reduction)

    def forward(self, gt, pred):
        """Return the loss between ``gt`` and ``pred``.

        The shapes must match exactly. Broadcasting mismatched tensors would
        silently produce a plausible-looking but wrong loss.

        Raises:
            ValueError: If the shapes differ.
        """
        if gt.shape != pred.shape:
            raise ValueError(f"Shape mismatch: ground truth shape {gt.shape} and prediction shape {pred.shape}")
        # L1 and MSE are symmetric, so the argument order does not change the value.
        return self.loss(gt, pred)


# Correctly spelled alias of RestructionLoss.
ReconstructionLoss = RestructionLoss
