# In [24]:
from pathlib import Path
import pandas as pd
from itertools import count
from tqdm import tqdm
from collections import namedtuple
import random

# Position record: `hold` says whether a position is open, `price` is the entry price.
HoldInfo = namedtuple('HoldInfo', ['hold', 'price'])

class ActionSpace:
    """Discrete action space made of the integers ``0 .. n - 1``."""

    def __init__(self, n):
        """Store the number of available actions.

        Args:
            n: Number of distinct actions; valid actions are ``0 .. n - 1``.
        """
        self.n = n

    def sample(self):
        """Return one action index drawn uniformly at random from ``range(n)``.

        Raises:
            ValueError: If ``n`` is not positive (there is nothing to draw).
        """
        # randrange draws the index directly instead of building a one-element
        # sample list and unpacking it.
        return random.randrange(self.n)


class StockEnvironment:
    """Gym-style environment that replays one day of price rows per stock code.

    Every entry of ``./data`` is treated as a stock-code directory containing
    a ``<date>.csv`` file. An observation is a flattened window of
    ``window_sz`` consecutive rows, each row contributing the percentage
    disparity between the current price (``현재가``) and every column listed
    in ``_FEATURE_COLUMNS``.

    Actions: 0 (do nothing), 1 (buy, or keep holding), 2 (sell, or stay out).
    """

    # Columns compared against the current price to build an observation.
    # NOTE(review): '이평' columns are presumably moving averages - confirm.
    _FEATURE_COLUMNS = ('20이평', '60이평', '120이평', 'env3', 'env5')

    def __init__(self, date: str, window_sz=50):
        """Discover the stock-code directories and load the first one.

        Args:
            date: Day to replay, formatted like ``'20240223'``. It is both
                the CSV file stem and the prefix rows must carry in ``체결시간``.
            window_sz: Number of consecutive rows in one observation.
        """
        self.action_space = ActionSpace(3)

        self.date = date
        self.data_path = Path('.') / 'data'
        # iterdir() yields entries in arbitrary order and may include stray
        # files; keep directories only and sort so data_index is reproducible.
        self.codes_path = sorted(
            entry for entry in self.data_path.iterdir() if entry.is_dir()
        )
        self.window_sz = window_sz

        # Sets data, data_index, df, step_index and hold for the first code.
        self.next_data()

    def next_data(self, **kwargs):
        """Switch to another stock code's data and rewind the episode state.

        Keyword Args:
            next (bool): When truthy, move to the code right after the current
                one (wrapping around). Takes precedence over the other options.
            data_code (str): Directory name of the code to load, e.g. '001360'.
            data_index (int): Position in ``self.codes_path``; only consulted
                when ``data_code`` is not given.

        When none of the three is supplied, the first code is loaded.

        Returns:
            int: Index of the loaded code within ``self.codes_path``.

        Raises:
            ValueError: If ``data_code`` does not match an entry of
                ``self.codes_path``.
        """
        if kwargs.get('next'):
            target = self.codes_path[(self.data_index + 1) % len(self.codes_path)]
        elif 'data_code' in kwargs:
            target = self.data_path / kwargs['data_code']
        elif 'data_index' in kwargs:
            target = self.codes_path[kwargs['data_index']]
        else:
            target = self.codes_path[0]

        # Resolve and load before touching any attribute so that a failure
        # leaves the previously loaded dataset fully intact.
        index = self.codes_path.index(target)
        frame = self._load_frame(target)

        self.data = target
        self.data_index = index
        self.df = frame
        self.step_index = 0
        self.hold = HoldInfo(False, 0)

        return self.data_index

    def reset(self):
        """Close any open position and rewind to the first window.

        Returns:
            tuple: ``(observation, None, None, None, {})`` - same layout as
            ``step`` with reward, terminated and truncated left as ``None``.
        """
        self.hold = HoldInfo(False, 0)
        self.step_index = 0
        return (self._observation(), None, None, None, {})

    def step(self, action):
        """Apply ``action`` at the last price of the current window and advance.

        Args:
            action: 0 (do nothing), 1 (buy, or keep holding),
                2 (sell, or stay out). Any other value is ignored.

        Returns:
            tuple: ``(observation, reward, terminated, truncated, info)``.
            ``reward`` is non-zero only when an open position is sold:
            ``(sell_price / buy_price - 1) * 1000``. Once the current code has
            no full window left, the next code is loaded (dropping any open
            position) and its first window is returned. When that wraps back
            to the first code, the result is ``(None, reward, True, True, {})``.
        """
        # Only reachable when the loaded dataset is shorter than one window.
        if not self._has_full_window() and self._roll_over():
            return (None, None, True, True, {})

        reward = 0
        current_price = self.df['현재가'].iloc[self.step_index + self.window_sz - 1]
        if action == 1 and not self.hold.hold:
            self.hold = HoldInfo(True, current_price)
        elif action == 2 and self.hold.hold:
            reward = (current_price / self.hold.price - 1) * 1000
            self.hold = HoldInfo(False, 0)

        self.step_index += 1

        # Roll over right away: otherwise the observation returned here would
        # be a truncated window with fewer than window_sz rows.
        if not self._has_full_window() and self._roll_over():
            return (None, reward, True, True, {})

        return (self._observation(), reward, False, False, {})

    def _load_frame(self, code_dir):
        """Read ``<code_dir>/<date>.csv`` and return the rows for ``self.date``."""
        frame = pd.read_csv(
            code_dir / (self.date + '.csv'),
            index_col=0,
            dtype={'체결시간': 'str'},
        )
        # Row order is reversed before use (presumably the files are stored
        # newest-first - confirm), and rows with missing values are dropped.
        frame = frame.iloc[::-1].dropna()
        frame = frame[frame['체결시간'].str[:8] == self.date].reset_index(drop=True)
        # env3 / env5: the '120이평' column scaled down by 3 % and 5 %.
        frame['env3'] = frame['120이평'] * 0.97
        frame['env5'] = frame['120이평'] * 0.95
        return frame

    def _has_full_window(self):
        """Tell whether ``window_sz`` rows are available from ``step_index``."""
        return self.step_index + self.window_sz <= len(self.df)

    def _roll_over(self):
        """Advance to the next code with a full window; True if wrapped to index 0."""
        while True:
            if self.next_data(next=True) == 0:
                return True
            if self._has_full_window():
                return False

    def _observation(self):
        """Return the current window's percentage disparities as a flat array."""
        window = self.df.iloc[self.step_index:self.step_index + self.window_sz]
        features = window[list(self._FEATURE_COLUMNS)]
        disparity = (features.div(window['현재가'], axis=0) - 1) * 100
        # Row-major flattening: all features of row 0, then row 1, ...
        return disparity.to_numpy().reshape(-1)

# Environment for date '20240223' with 30-row observation windows.
env = StockEnvironment('20240223', window_sz=30)