<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/kpyopark/finance_session_hol/blob/main/s01.aiml_finance_analysis_cgan_example/cgan_example.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/kpyopark/finance_session_hol/blob/main/s01.aiml_finance_analysis_cgan_example/cgan_example.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [None]:
#! pip install --upgrade pip setuptools wheel
! pip install torch yfinance

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import yfinance as yf
from sklearn.preprocessing import MinMaxScaler

# Download the daily price history for Samsung Electronics (ticker 005930.KS).
samsung = yf.Ticker("005930.KS")
data = samsung.history(start="2023-10-01", end="2024-05-30")

# Fail early with a clear message instead of an obscure scaler error when the
# download returned nothing (e.g. network failure or an empty date range).
if data.empty or "Close" not in data.columns:
    raise RuntimeError(
        "No price data was returned for 005930.KS; check the network "
        "connection and the requested date range."
    )

# Drop rows without a closing price on the frame itself so that `data.index`
# stays aligned with the price series built below.
data = data.dropna(subset=["Close"])

# Extract the closing prices as a single-feature column and scale to [-1, 1].
close_prices = data["Close"].to_numpy().reshape(-1, 1)
scaler = MinMaxScaler(feature_range=(-1, 1))
normalized_data = scaler.fit_transform(close_prices)

# Convert to a float32 PyTorch tensor of shape (num_days, 1).
real_data = torch.tensor(normalized_data, dtype=torch.float32)

In [None]:
class Generator(nn.Module):
    """LSTM-based generator of the conditional GAN.

    The noise and the condition are joined along the feature axis, passed
    through a two-layer batch-first LSTM, and every hidden state is projected
    to a single output value.

    Args:
        input_dim: Feature size of the concatenated ``[z, c]`` input.
        hidden_dim: Hidden size of the LSTM.
        sequence_length: Accepted by the constructor but not used by any layer.
    """

    def __init__(self, input_dim, hidden_dim, sequence_length):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, z, c):
        """Map noise ``z`` and condition ``c`` to one value per time step.

        ``z`` and ``c`` are concatenated on their last dimension, so all of
        their other dimensions must match.
        """
        joint_input = torch.cat((z, c), dim=-1)
        hidden_states, _final_state = self.lstm(joint_input)
        return self.linear(hidden_states)

class Discriminator(nn.Module):
    """LSTM-based discriminator of the conditional GAN.

    The sequence and its condition are joined along the feature axis and passed
    through a two-layer batch-first LSTM; the hidden state of the last time
    step is projected to a single raw score (logit) per sample.

    Args:
        input_dim: Feature size of the concatenated ``[x, c]`` input.
        hidden_dim: Hidden size of the LSTM.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, x, c):
        """Score sequence ``x`` given condition ``c``; returns one logit per sample.

        ``x`` and ``c`` are concatenated on their last dimension, so all of
        their other dimensions must match.
        """
        joint_input = torch.cat((x, c), dim=-1)
        hidden_states, _final_state = self.lstm(joint_input)
        last_step = hidden_states[:, -1]
        return self.linear(last_step)

In [None]:
def _expand_condition(context, sequence_length):
    """Resample a (batch, cond_len) context to shape (batch, sequence_length, 1).

    Linear interpolation stretches the context over the target length so it can
    be concatenated, one condition value per time step, with a sequence of
    ``sequence_length`` steps.
    """
    resized = nn.functional.interpolate(
        context.unsqueeze(1).contiguous(),  # (batch, 1, cond_len)
        size=sequence_length,
        mode="linear",
        align_corners=True,
    )
    return resized.transpose(1, 2)  # (batch, sequence_length, 1)


def train_cgan(generator, discriminator, real_data, num_epochs, batch_size,
               sequence_length=None, noise_dim=None, condition_length=5):
    """Train a conditional GAN on sliding windows of a univariate series.

    Every training sample is a window of ``condition_length + sequence_length``
    consecutive values: the first ``condition_length`` values are the condition
    (the preceding days) and the remaining ``sequence_length`` values are the
    real sequence the discriminator sees. The condition is resampled to one
    value per time step so that ``generator(z, c)`` and ``discriminator(x, c)``
    can concatenate it along the feature axis.

    Args:
        generator: Module called as ``generator(z, c)`` with ``z`` of shape
            (batch, sequence_length, noise_dim) and ``c`` of shape
            (batch, sequence_length, 1).
        discriminator: Module called as ``discriminator(x, c)`` with ``x`` and
            ``c`` of shape (batch, sequence_length, 1); must return logits of
            shape (batch, 1).
        real_data: Float tensor holding the series, shape (N,) or (N, 1).
        num_epochs: Number of passes over all windows.
        batch_size: Maximum number of windows per optimisation step.
        sequence_length: Length of the generated sequences. When omitted, the
            module-level ``sequence_length`` is used (backward compatible).
        noise_dim: Noise features per time step. When omitted, the module-level
            ``noise_dim`` is used (backward compatible).
        condition_length: Number of preceding values used as the condition.

    Returns:
        The ``(generator, discriminator)`` pair, trained in place.

    Raises:
        ValueError: If a required size is missing or invalid, if ``real_data``
            has more than one feature column, or if it is too short to form a
            single window.
    """
    # Earlier callers relied on notebook-level globals; keep that working while
    # allowing the sizes to be passed explicitly.
    if sequence_length is None:
        sequence_length = globals().get("sequence_length")
    if noise_dim is None:
        noise_dim = globals().get("noise_dim")
    if sequence_length is None or noise_dim is None:
        raise ValueError(
            "sequence_length and noise_dim must be passed to train_cgan or "
            "defined at module level"
        )
    if sequence_length < 1 or condition_length < 1 or batch_size < 1:
        raise ValueError(
            "sequence_length, condition_length and batch_size must be >= 1"
        )
    if real_data.dim() > 2 or (real_data.dim() == 2 and real_data.size(1) != 1):
        raise ValueError("real_data must have shape (N,) or (N, 1)")

    window_length = condition_length + sequence_length
    series = real_data.reshape(-1)
    if series.numel() < window_length:
        raise ValueError(
            f"real_data has {series.numel()} points but at least "
            f"{window_length} are needed to form one training window"
        )

    # All overlapping windows with stride 1: shape (num_windows, window_length).
    windows = series.unfold(0, window_length, 1)
    device = windows.device

    criterion = nn.BCEWithLogitsLoss()
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

    for epoch in range(num_epochs):
        for start in range(0, windows.size(0), batch_size):
            batch = windows[start:start + batch_size]
            # The last batch may be smaller than batch_size.
            current_size = batch.size(0)

            # Condition: the condition_length values preceding the target.
            c = _expand_condition(batch[:, :condition_length], sequence_length)
            real_sequences = batch[:, condition_length:].unsqueeze(-1)

            real_labels = torch.ones(current_size, 1, device=device)
            fake_labels = torch.zeros(current_size, 1, device=device)

            # Train the discriminator.
            z = torch.randn(current_size, sequence_length, noise_dim, device=device)
            fake_data = generator(z, c)

            d_optimizer.zero_grad()
            real_loss = criterion(discriminator(real_sequences, c), real_labels)
            # detach() keeps the discriminator step from updating the generator.
            fake_loss = criterion(discriminator(fake_data.detach(), c), fake_labels)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            d_optimizer.step()

            # Train the generator: it wants its samples to be scored as real.
            g_optimizer.zero_grad()
            g_loss = criterion(discriminator(fake_data, c), real_labels)
            g_loss.backward()
            g_optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')

    return generator, discriminator

In [None]:
# Hyperparameters.
sequence_length = 30   # number of days generated per sample
noise_dim = 10         # noise features per time step
hidden_dim = 64
batch_size = 32
num_epochs = 500
condition_length = 5   # number of preceding days used as the condition
seed = 42

# Fix the seed so weight initialisation and sampled noise are reproducible.
torch.manual_seed(seed)

# Generator input = noise features + 1 condition channel per time step;
# discriminator input = 1 price channel + 1 condition channel per time step.
generator = Generator(noise_dim + 1, hidden_dim, sequence_length)
discriminator = Discriminator(2, hidden_dim)

trained_generator, _ = train_cgan(generator, discriminator, real_data, num_epochs, batch_size)

# Visualise the result: generate the last `sequence_length` days, conditioned on
# the `condition_length` days immediately before them, so the generated curve
# is compared with the real prices of the same dates.
assert real_data.size(0) >= sequence_length + condition_length, \
    "not enough data for one condition + sequence window"
context = real_data[-(sequence_length + condition_length):-sequence_length].reshape(1, 1, -1)
# The networks take one condition value per time step, so the context is
# linearly resampled to `sequence_length` steps: shape (1, sequence_length, 1).
c = nn.functional.interpolate(
    context, size=sequence_length, mode="linear", align_corners=True
).transpose(1, 2)
z = torch.randn(1, sequence_length, noise_dim)

trained_generator.eval()
with torch.no_grad():
    generated_data = trained_generator(z, c).numpy()

# Undo the [-1, 1] scaling to get prices back in KRW.
generated_prices = scaler.inverse_transform(generated_data[0])
real_prices = scaler.inverse_transform(real_data.numpy())

plot_dates = data.index[-sequence_length:]
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(plot_dates, real_prices[-sequence_length:], label='Real Data')
ax.plot(plot_dates, generated_prices, label='Generated Data')
ax.legend()
ax.set_title('Samsung Electronics: Real vs Generated Stock Prices')
ax.set_xlabel('Date')
ax.set_ylabel('Price (KRW)')
ax.tick_params(axis='x', rotation=45)
fig.tight_layout()
plt.show()