# In [None]:
# Assumes a `ctgan` model object and `input_data` were prepared in an earlier cell.
# NOTE(review): each "epoch" below is a single set_input/optimize_parameters call
# on the same `input_data`; if `input_data` is one batch rather than the whole
# dataset, this is one optimization step per "epoch" — confirm.

# Imported here because this cell plots before any later cell's import runs;
# without it, a fresh kernel raises NameError on `plt`.
import matplotlib.pyplot as plt

# Training loop
num_epochs = 50
generator_losses = []
discriminator_losses = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    ctgan.set_input(input_data)  # feed the input data to the model
    ctgan.optimize_parameters()  # run the model's own optimization step

    # Record this epoch's losses as plain Python numbers via .item()
    generator_loss = ctgan.loss_G.item()
    discriminator_loss = ctgan.loss_D.item()
    generator_losses.append(generator_loss)
    discriminator_losses.append(discriminator_loss)

    # Print the current epoch's losses
    print(f"Generator Loss: {generator_loss:.4f}")
    print(f"Discriminator Loss: {discriminator_loss:.4f}")

# Plot the loss curves; the x-axis is derived from the recorded values so it
# always matches their length.
epoch_axis = range(1, len(generator_losses) + 1)
plt.figure(figsize=(12, 6))
plt.plot(epoch_axis, generator_losses, label="Generator Loss", marker='o')
plt.plot(epoch_axis, discriminator_losses, label="Discriminator Loss", marker='x')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Generator and Discriminator Loss Over Epochs")
plt.legend()
plt.grid(True)
plt.show()
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR

# Continued training of `ctgan` with a step learning-rate decay.
# Relies on `ctgan` and `input_data` from earlier cells and on `plt` / `StepLR`
# imported at the top of this cell.

# Training-loop settings
num_epochs = 80  # more epochs than the first run
scheduler_step_epoch = 60  # the LR is decayed once this many epochs have completed
learning_rate_decay_factor = 0.8  # multiplicative learning-rate decay
# NOTE(review): as far as this cell shows, the two weights below are never passed
# to `ctgan`, so they only rescale the *logged/plotted* generator loss; they do
# not change what optimize_parameters() minimizes. To actually re-weight
# training, change the model's own loss weights (confirm where ctgan defines
# them). Also note the reported value therefore differs from `ctgan.loss_G`,
# which the first training cell plots, so the two curves are not comparable.
gan_loss_weight = 0.5  # weight on the GAN term (reporting only)
reconstruction_loss_weight = 50.0  # weight on the L1 reconstruction term (reporting only)

# Learning-rate schedulers. StepLR multiplies the LR by `gamma` once every
# `step_size` calls to .step(); with one .step() per epoch, epochs
# 1..scheduler_step_epoch use the base LR and later epochs use base * gamma.
scheduler_G = StepLR(ctgan.optimizer_G, step_size=scheduler_step_epoch, gamma=learning_rate_decay_factor)
scheduler_D = StepLR(ctgan.optimizer_D, step_size=scheduler_step_epoch, gamma=learning_rate_decay_factor)

# Per-epoch loss history
generator_losses = []
discriminator_losses = []

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    ctgan.set_input(input_data)
    ctgan.optimize_parameters()

    # Reported generator loss: a re-weighted sum of the model's GAN and L1 terms
    # (see the NOTE above — this weighting is not what was optimized).
    generator_loss = (
        ctgan.loss_G_GAN.item() * gan_loss_weight +
        ctgan.loss_G_L1.item() * reconstruction_loss_weight
    )
    discriminator_loss = ctgan.loss_D.item()

    generator_losses.append(generator_loss)
    discriminator_losses.append(discriminator_loss)

    print(f"Generator Loss: {generator_loss:.4f}")
    print(f"Discriminator Loss: {discriminator_loss:.4f}")

    # Advance the schedulers every epoch, after optimize_parameters()
    # (which presumably steps the optimizers). StepLR counts .step() calls, so
    # gating this on `epoch >= scheduler_step_epoch` would push the first decay
    # to epoch 2 * scheduler_step_epoch, which this run never reaches.
    scheduler_G.step()
    scheduler_D.step()

# Plot the loss curves; the x-axis is derived from the recorded values.
epoch_axis = range(1, len(generator_losses) + 1)
plt.figure(figsize=(12, 6))
plt.plot(epoch_axis, generator_losses, label="Generator Loss", marker='o')
plt.plot(epoch_axis, discriminator_losses, label="Discriminator Loss", marker='x')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Generator and Discriminator Loss Over Epochs")
plt.legend()
plt.grid(True)
plt.show()
