In [6]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from datetime import datetime
import os
import wandb
from pathlib import Path

# The submitter works inside an Ubuntu Docker container, so the project root
# defaults to the path below. Set the DL_STUDY_BASE_PATH environment variable
# to point at a different checkout when running elsewhere.
BASE_PATH = os.environ.get("DL_STUDY_BASE_PATH", "/home/Deep-Learning-study")
import sys
# Guard so re-running this cell in the same kernel does not keep appending
# duplicate entries to sys.path.
if BASE_PATH not in sys.path:
  sys.path.append(BASE_PATH)

# Checkpoints are written under the current working directory.
CURRENT_FILE_PATH = os.getcwd()
CHECKPOINT_FILE_PATH = os.path.join(CURRENT_FILE_PATH, "checkpoints")

# exist_ok=True makes this idempotent and avoids the check-then-create race;
# reuse CHECKPOINT_FILE_PATH instead of rebuilding the same path.
os.makedirs(CHECKPOINT_FILE_PATH, exist_ok=True)

In [7]:
from _01_code._08_fcn_best_practice.c_trainer import ClassificationTrainer
from _01_code._15_lstm_and_its_application.f_arg_parser import get_parser
from _01_code._15_lstm_and_its_application.g_crypto_currency_regression_train_lstm import get_btc_krw_data

In [8]:
def get_model(n_input=5, n_output=2, hidden_size=256, num_layers=3):
  """Build the LSTM classifier used for the BTC/KRW experiment.

  Args:
    n_input: Number of features per time step fed to the LSTM (default 5).
    n_output: Number of output logits produced by the final linear layer
      (default 2).
    hidden_size: LSTM hidden-state width; also the input width of the final
      linear layer, so the two always agree (default 256).
    num_layers: Number of stacked LSTM layers (default 3).

  Returns:
    An ``nn.Module`` mapping a ``(batch, seq_len, n_input)`` tensor to a
    ``(batch, n_output)`` tensor of logits.
  """
  class MyModel(nn.Module):
    def __init__(self, n_input, n_output, hidden_size=256, num_layers=3):
      super().__init__()

      self.lstm = nn.LSTM(
        input_size=n_input, hidden_size=hidden_size, num_layers=num_layers, batch_first=True
      )
      # Head input width is tied to the LSTM width instead of a second magic number.
      self.fcn = nn.Linear(in_features=hidden_size, out_features=n_output)

    def forward(self, x):
      # batch_first=True: x is (batch, seq_len, n_input);
      # LSTM output is (batch, seq_len, hidden_size). Hidden/cell states are unused.
      x, _ = self.lstm(x)
      # Keep only the last time step: (batch, hidden_size).
      x = x[:, -1, :]
      x = self.fcn(x)  # (batch, n_output)
      return x

  my_model = MyModel(
    n_input=n_input, n_output=n_output, hidden_size=hidden_size, num_layers=num_layers
  )

  return my_model

In [9]:
def main(args):
  """Train the LSTM classifier on BTC/KRW data and log the run to Weights & Biases.

  Args:
    args: Namespace-like object exposing ``epochs``, ``batch_size``,
      ``validation_intervals``, ``learning_rate``, ``early_stop_patience``,
      ``early_stop_delta``, ``weight_decay`` and ``wandb`` (truthy -> W&B runs
      in "online" mode, otherwise W&B is "disabled"). An optional ``seed``
      attribute, when present and not None, seeds torch's RNG.
  """
  run_time_str = datetime.now().astimezone().strftime('%Y-%m-%d_%H-%M-%S')

  # Optional reproducibility hook; a no-op when args has no (or a None) seed.
  seed = getattr(args, "seed", None)
  if seed is not None:
    torch.manual_seed(seed)

  config = {
    'epochs': args.epochs,
    'batch_size': args.batch_size,
    'validation_intervals': args.validation_intervals,
    'learning_rate': args.learning_rate,
    'early_stop_patience': args.early_stop_patience,
    'early_stop_delta': args.early_stop_delta,
    'weight_decay': args.weight_decay
  }

  project_name = "lstm_classification_btc_krw"
  wandb.init(
    mode="online" if args.wandb else "disabled",
    project=project_name,
    notes="btc_krw experiment with lstm",
    tags=["lstm", "classification", "btc_krw"],
    name=run_time_str,
    config=config
  )

  try:
    print(args)
    print(wandb.config)

    # The third returned loader (presumably the test split) is not used here.
    train_data_loader, validation_data_loader, _ = get_btc_krw_data(is_regression=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Training on device {device}.")

    model = get_model()
    model.to(device)

    optimizer = optim.Adam(
      model.parameters(), lr=wandb.config.learning_rate, weight_decay=wandb.config.weight_decay
    )

    classification_trainer = ClassificationTrainer(
      project_name, model, optimizer, train_data_loader, validation_data_loader, None,
      run_time_str, wandb, device, CHECKPOINT_FILE_PATH
    )
    classification_trainer.train_loop()
  finally:
    # Close the run even if loading/training raises or the cell is interrupted,
    # so a failed run is not left open in the notebook kernel.
    wandb.finish()

In [12]:
if __name__ == "__main__":
  import argparse
  import sys
  if 'ipykernel' in sys.modules:  # Running inside a Jupyter kernel
    # Use fixed settings instead of parsing the kernel's own command line.
    # argparse.Namespace gives the same attribute access as a hand-rolled class
    # but prints readably (print(args) in main) instead of "<Args object at 0x...>".
    args = argparse.Namespace(
      wandb=True,
      batch_size=64,
      epochs=200,
      learning_rate=0.0001,
      weight_decay=0.00001,
      validation_intervals=1,
      early_stop_patience=10,
      early_stop_delta=0.001,
    )
  else:
    # Plain script execution: read options from the command line.
    parser = get_parser()
    args = parser.parse_args()

  main(args)

<__main__.Args object at 0x7744bb0ed4d0>
{'epochs': 200, 'batch_size': 64, 'validation_intervals': 1, 'learning_rate': 0.0001, 'early_stop_patience': 10, 'early_stop_delta': 0.001, 'weight_decay': 1e-05}
Training on device cuda:0.
[Epoch   1] T_loss: 0.69205, T_accuracy: 51.8293 | V_loss: 0.69065, V_accuracy: 55.0000 | Early stopping is stated! | T_time: 00:00:00, T_speed: 0.000
[Epoch   2] T_loss: 0.69105, T_accuracy: 53.6863 | V_loss: 0.69643, V_accuracy: 44.0000 | Early stopping counter: 1 out of 10 | T_time: 00:00:00, T_speed: 0.000
[Epoch   3] T_loss: 0.69078, T_accuracy: 53.2982 | V_loss: 0.69525, V_accuracy: 44.0000 | Early stopping counter: 2 out of 10 | T_time: 00:00:00, T_speed: 0.000
[Epoch   4] T_loss: 0.69145, T_accuracy: 53.2151 | V_loss: 0.69714, V_accuracy: 44.0000 | Early stopping counter: 3 out of 10 | T_time: 00:00:00, T_speed: 0.000
[Epoch   5] T_loss: 0.69120, T_accuracy: 53.0488 | V_loss: 0.69689, V_accuracy: 44.0000 | Early stopping counter: 4 out of 10 | T_time:

0,1
Epoch,▁▂▂▃▄▅▅▆▇▇█
Training accuracy (%),▁█▇▆▆▇▄▆▇▆▆
Training loss,█▃▁▅▃▂▃▂▂▂▂
Training speed (epochs/sec.),▁▁▁▁▁▅▅▆▇▇█
Validation accuracy (%),█▁▁▁▁▁▁▁▁▁▁
Validation loss,▁▆▅▆▆█▅▇▆▅▅

0,1
Epoch,11.0
Training accuracy (%),53.0765
Training loss,0.69093
Training speed (epochs/sec.),11.0
Validation accuracy (%),44.0
Validation loss,0.69543
