In [1]:
%run Constant.ipynb
%run DataSource.ipynb
%run DataSourceManager.ipynb

import numpy as np
import logging as log

class Metric:
    """Technical-analysis indicators over one symbol's K-line (bar) data.

    Bar data comes from ``DataSourceManager().get_data_source(symbol, level)``
    and is read as a DataFrame with ``date``, ``high``, ``low``, ``close`` and
    ``volume`` columns (``datetime`` is used by the date-range helpers).

    Several methods use a bar's index label as a positional slice offset, so
    the frame is assumed to carry a default 0..n-1 integer index.
    NOTE(review): confirm this against DataSource.get_data().
    """

    # symbol example: 'SZ#002335'
    def __init__(self, symbol, context):
        """Load the bar data for ``symbol``.

        ``context`` must provide ``level`` (passed to the data source) and the
        ``short``/``mid`` moving-average periods used by the cross/MA checks.
        """
        self.context = context
        self.level = context['level']
        self.short = context['short']
        self.mid = context['mid']
        self.symbol = symbol

        self.dataSource = DataSourceManager().get_data_source(symbol, self.level)
        self.data = self.dataSource.get_data()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _index_of(self, cur_date):
        """Return the integer index label of the bar dated ``cur_date``.

        Raises IndexError when no bar has that date.
        """
        return int(self.data[self.data['date'] == cur_date].index.values[0])

    def _previous_excluding(self, cur_date, previous_days):
        """Return up to ``previous_days`` bars strictly before ``cur_date``."""
        part = self.previous_series(cur_date, previous_days + 1)
        return part[part['date'] != cur_date]

    def _ma_cross_values(self, cur_date):
        """Return (cur_short, cur_mid, last_short, last_mid) MA values.

        Returns None when ``cur_date`` has no previous bar to compare against.
        """
        previous_part = self.previous_series(cur_date, 2)
        if len(previous_part) < 2:
            return None
        last_day = previous_part['date'].values[0]
        return (
            self.ma(cur_date, self.short),
            self.ma(cur_date, self.mid),
            self.ma(last_day, self.short),
            self.ma(last_day, self.mid),
        )

    # ------------------------------------------------------------------
    # Bar navigation
    # ------------------------------------------------------------------
    def get_fixed_day(self, cur_date, fixed_day):
        """Return the one-row frame ``fixed_day`` bars away from ``cur_date``.

        The frame is empty when the target falls outside the data. Negative
        offsets before the first bar no longer wrap around to the end.
        """
        end_index = self._index_of(cur_date) + fixed_day
        if end_index < 0:
            return self.data[0:0]
        return self.data[end_index:end_index + 1]

    def yesterday(self, cur_date):
        """Return the one-row frame of the bar before ``cur_date``, or None for the first bar."""
        cur_index = self._index_of(cur_date)
        if cur_index <= 0:
            return None
        return self.data[cur_index - 1:cur_index]

    def next_series(self, cur_date, latest_n):
        """Return up to ``latest_n`` bars starting at (and including) ``cur_date``."""
        cur_index = self._index_of(cur_date)
        return self.data[cur_index:cur_index + latest_n]

    def previous_series(self, cur_date, latest_n):
        """Return up to ``latest_n`` bars ending at (and including) ``cur_date``.

        Near the start of the data, fewer than ``latest_n`` bars are returned.
        The start is clamped at 0 because a negative slice start would wrap
        to the end of the frame and usually produce an empty result.
        """
        cur_index = self._index_of(cur_date)
        start_index = max(0, cur_index - latest_n + 1)
        return self.data[start_index:cur_index + 1]

    def get_k_series_between(self, start_date, end_date):
        """Return bars whose ``datetime`` lies in [start_date, end_date]."""
        start_datetime = to_datetime(start_date)
        end_datetime = to_datetime(end_date)

        mask = (self.data['datetime'] >= start_datetime) & (self.data['datetime'] <= end_datetime)
        return self.data[mask]

    # ------------------------------------------------------------------
    # Price helpers
    # ------------------------------------------------------------------
    def get_price_previous_decrease_percent(self, cur_date, previous_days):
        """Percent drop of today's close from the highest close of the prior ``previous_days`` bars.

        Returns 0 when fewer than 3 prior bars are available.
        """
        previous_part = self._previous_excluding(cur_date, previous_days)
        if previous_part.shape[0] < 3:
            return 0

        max_close_price = max(previous_part['close'])
        cur_price = self.get_cur_price(cur_date)
        return 100.0 * (max_close_price - cur_price) / max_close_price

    def get_previous_highest_price(self, cur_date, previous_days):
        """Highest ``high`` of the prior ``previous_days`` bars (0 if fewer than 3 exist)."""
        log.debug('cur_date:%s get previous highest price', cur_date)
        previous_part = self._previous_excluding(cur_date, previous_days)
        if previous_part.shape[0] < 3:
            return 0
        return max(previous_part['high'])

    def get_cur_price(self, cur_date):
        """Close price on ``cur_date``, or None if there is no bar that day."""
        cur_part = self.data[self.data['date'] == cur_date]
        if cur_part.shape[0] > 0:
            return cur_part['close'].values[0]
        return None

    def get_cur_low(self, cur_date):
        """Low price on ``cur_date``, or None if there is no bar that day."""
        cur_part = self.data[self.data['date'] == cur_date]
        if cur_part.shape[0] > 0:
            return cur_part['low'].values[0]
        return None

    def get_latest_price(self, cur_date):
        """Close of the last bar whose ``datetime`` is at or before ``cur_date``."""
        cur_datetime = to_datetime(cur_date)

        part_df = self.data[self.data['datetime'] <= cur_datetime]
        return part_df[part_df.shape[0] - 1:]['close'].values[0]

    def get_stock_start_date(self):
        """Date of the first bar."""
        return self.data['date'][:1].values[0]

    def get_stock_end_date(self):
        """Date of the last bar."""
        return self.data['date'][-1:].values[0]

    def list_stock_all_dates(self):
        """All bar dates as an array."""
        return self.data['date'].values

    def is_today_open(self, cur_date):
        """True when a bar exists for ``cur_date``."""
        return self.data[self.data['date'] == cur_date].shape[0] > 0

    def get_max_price_date(self, start_date, end_date):
        """Row of the bar with the highest ``high`` between the two dates."""
        part = self.get_k_series_between(start_date, end_date)
        return part.loc[part['high'].idxmax()]

    def get_min_price_date(self, start_date, end_date):
        """Row of the bar with the lowest ``low`` between the two dates."""
        part = self.get_k_series_between(start_date, end_date)
        return part.loc[part['low'].idxmin()]

    # ------------------------------------------------------------------
    # Volatility
    # ------------------------------------------------------------------
    def atr(self, cur_date, N=14):
        """Average true range over the ``N`` bars ending at ``cur_date``.

        Bars without a previous bar (no TR) are skipped. The result is NaN
        if no TR value is available.
        """
        part = self.previous_series(cur_date, N)
        values = [tr for tr in (self.tr(d) for d in part['date']) if tr is not None]
        return np.mean(values)

    # TR : MAX(MAX((HIGH-LOW),ABS(REF(CLOSE,1)-HIGH)),ABS(REF(CLOSE,1)-LOW));
    def tr(self, cur_date):
        """True range of the bar on ``cur_date``; None for the first bar."""
        cur = self.data[self.data['date'] == cur_date]

        last_day = self.yesterday(cur_date)
        if last_day is None:
            return None

        high = cur['high'].values[0]
        low = cur['low'].values[0]
        prev_close = last_day['close'].values[0]
        return np.max([high - low, abs(prev_close - high), abs(prev_close - low)])

    # ------------------------------------------------------------------
    # Breakouts / stops
    # ------------------------------------------------------------------
    def be_killed(self, start_date, end_date, loss_per):
        """Return True if, within the swing from start_date to end_date, the position would be stopped out.

        A stop-out happens when any bar's low falls more than ``loss_per``
        percent below the close on ``start_date``.
        """
        part = self.get_k_series_between(start_date, end_date)

        cur_price = self.get_cur_price(start_date)
        expect_lowest_price = cur_price * (1 - 1.0 * loss_per / 100.0)

        return bool(min(part['low']) < expect_lowest_price)

    def does_break_lowest(self, cur_date):
        """True when today's close is below the lowest low of the prior ``self.short`` bars."""
        previous_part = self._previous_excluding(cur_date, self.short)
        if previous_part.shape[0] < 3:
            return False

        lowest_price = min(previous_part['low'])
        return bool(self.get_cur_price(cur_date) < lowest_price)

    def does_break_highest(self, cur_date):
        """True when today's close breaks above the highest high of the prior ``self.mid`` bars."""
        log.debug('cur_date:%s does break highest', cur_date)
        previous_part = self._previous_excluding(cur_date, self.mid)
        if previous_part.shape[0] < 3:
            return False

        highest_price = max(previous_part['high'])
        return bool(self.get_cur_price(cur_date) > highest_price)

    def next_break_lowest_date(self, cur_date):
        """First date from ``cur_date`` onward whose close breaks the recent low, else None."""
        for date in self.next_series(cur_date, 10000)['date']:
            if self.does_break_lowest(date):
                return date
        return None

    # ------------------------------------------------------------------
    # Moving averages / Bollinger
    # ------------------------------------------------------------------
    def ma(self, cur_date, periods):
        """Mean close over the ``periods`` bars ending at ``cur_date``, rounded to 2 places."""
        periods_part = self.previous_series(cur_date, periods)['close']
        return round(periods_part.mean(), 2)

    def std(self, cur_date, periods):
        """Population std (numpy default, ddof=0) of closes over ``periods`` bars."""
        periods_part = self.previous_series(cur_date, periods)['close']
        return np.std(periods_part)

    # Parameters N: 2, 250, 20
    # BOLL:MA(CLOSE,M); UB:BOLL+2*STD(CLOSE,M); LB:BOLL-2*STD(CLOSE,M);
    def boll(self, cur_date, periods=20):
        """Bollinger middle band (the ``periods``-bar close MA) at ``cur_date``.

        Only the middle band is returned. The original also computed UB/LB
        but never used them.
        """
        return self.ma(cur_date, periods)

    def boll99(self, cur_date):
        """Bollinger middle band with a 99-bar window."""
        return self.boll(cur_date, 99)

    def is_down_boll(self, cur_date):
        """True when the close is below the 20-bar Bollinger middle band."""
        price = self.data[self.data['date'] == cur_date]['close'].values[0]
        return bool(price < self.boll(cur_date))

    def is_on_boll99(self, cur_date):
        """True when the close is above the 99-bar Bollinger middle band."""
        price = self.data[self.data['date'] == cur_date]['close'].values[0]
        return bool(price > self.boll99(cur_date))

    def is_on_boll(self, cur_date):
        """True when the close is above the 20-bar Bollinger middle band."""
        price = self.data[self.data['date'] == cur_date]['close'].values[0]
        return bool(price > self.boll(cur_date))

    def is_latest_n_keep_on_boll(self, cur_date, latest_n):
        """True when every one of the last ``latest_n`` bars closed above the middle band."""
        pre_dates = self.previous_series(cur_date, latest_n)['date']
        return all(self.is_on_boll(date) for date in pre_dates)

    def is_down_cross(self, cur_date):
        """True when the short MA crosses below the mid MA on ``cur_date``."""
        values = self._ma_cross_values(cur_date)
        if values is None:
            return False
        cur_short, cur_mid, last_short, last_mid = values
        return bool(cur_short < cur_mid and last_short >= last_mid)

    def is_up_cross(self, cur_date):
        """True when the short MA crosses above the mid MA on ``cur_date``."""
        values = self._ma_cross_values(cur_date)
        if values is None:
            return False
        cur_short, cur_mid, last_short, last_mid = values
        return bool(cur_short > cur_mid and last_short <= last_mid)

    def k_on_ma(self, cur_date, latest_n_k=3):
        """True when none of the last ``latest_n_k`` closes is below today's short MA."""
        ma_short = self.ma(cur_date, self.short)
        closes = self.previous_series(cur_date, latest_n_k)['close']
        return not any(c < ma_short for c in closes)

    def latest_n_period_up_crossed(self, cur_date, latest_n=3):
        """True when an up-cross happened on any of the last ``latest_n`` bars.

        Tolerates fewer than ``latest_n`` available bars.
        """
        latest_dates = self.previous_series(cur_date, latest_n)['date'].values
        return any(self.is_up_cross(date) for date in latest_dates)

    def latest_n_k_break_ma(self, cur_date, latest_n_k=3):
        """True when none of the last ``latest_n_k`` closes is above today's mid MA."""
        ma_mid = self.ma(cur_date, self.mid)
        closes = self.previous_series(cur_date, latest_n_k)['close']
        return not any(c > ma_mid for c in closes)

    def deep_k_break_ma(self, cur_date, deep_percent=3):
        """True when the close sits more than ``deep_percent``% (of price) below the mid MA."""
        price = self.data[self.data['date'] == cur_date]['close'].values[0]
        ma_price = self.ma(cur_date, self.mid)
        return bool((ma_price - price) * 100 / price > deep_percent)

    # ------------------------------------------------------------------
    # Volume
    # ------------------------------------------------------------------
    def get_vol(self, cur_date):
        """Volume of the bar on ``cur_date``."""
        return self.data[self.data['date'] == cur_date]['volume'].values[0]

    def ma_vol(self, cur_date, periods=60):
        """Mean volume over the ``periods`` bars ending at ``cur_date``."""
        return self.previous_series(cur_date, periods)['volume'].mean()

    def is_vol_outburst_point(self, cur_date, threshold=2, periods=60):
        """True when today's volume exceeds ``threshold`` times its ``periods``-bar average."""
        vol_rate = self.get_vol(cur_date) / self.ma_vol(cur_date, periods)
        return bool(vol_rate > threshold)

    # LARGE_VOL = SUM(VOL,5)/5/MA(VOL,60)
    def is_latest_n_vol_outburst(self, cur_date, latest_n=5, threshold=2, periods=60):
        """True when the ``latest_n``-bar average volume exceeds ``threshold`` times the ``periods``-bar average."""
        latest_part = self.previous_series(cur_date, latest_n)
        sum_vol = sum(latest_part['volume'])
        vol_rate = sum_vol / latest_n / self.ma_vol(cur_date, periods)
        return bool(vol_rate > threshold)

    # LARGE_VOL:=SUM(VOL,5)/5/MA(VOL,60)
    # COUNT(LARGE_VOL>2,10)>5
    # Abnormal volume activity
    def is_outburst(self, cur_date, latest_n=10, periods=60):
        """True when more than 5 of the last ``latest_n`` bars are volume outburst points.

        NOTE(review): the formula above counts days whose 5-day average volume
        ratio exceeds 2, but this counts single-day ratios via
        is_vol_outburst_point. Confirm which is intended.
        """
        dates = self.previous_series(cur_date, latest_n)['date']
        count = sum(1 for date in dates if self.is_vol_outburst_point(date, periods=periods))
        return count > 5

    def is_up_cross_outburst(self, cur_date, latest_n=10):
        """True when a volume outburst occurs within ``latest_n`` bars before or after ``cur_date``."""
        for date in self.previous_series(cur_date, latest_n)['date']:
            if self.is_outburst(date):
                return True

        for date in self.next_series(cur_date, latest_n)['date']:
            if self.is_outburst(date):
                return True

        return False

    # ------------------------------------------------------------------
    # Event search over a given section (return the string 'None' when absent)
    # ------------------------------------------------------------------
    def next_down_ma(self, periods, next_sections):
        """Date of the earliest bar (after the first) that breaks below the MA, else 'None'.

        NOTE(review): ``periods`` is unused; latest_n_k_break_ma always uses
        ``self.mid``.
        """
        for date in next_sections['date'].values[1:]:
            if self.latest_n_k_break_ma(date):
                return date
        return 'None'

    def next_down_boll(self, next_sections):
        """Date of the first bar in the section closing below the middle band, else 'None'."""
        for date in next_sections['date']:
            if self.is_down_boll(date):
                return date
        return 'None'

    def next_down_cross(self, next_sections):
        """Date of the first MA down-cross in the section, else 'None'."""
        for date in next_sections['date']:
            if self.is_down_cross(date):
                return date
        return 'None'

    def previous_down_cross(self, next_sections):
        """Date of the last MA down-cross in the section (scanning backward), else 'None'."""
        for date in next_sections.iloc[::-1]['date']:
            if self.is_down_cross(date):
                return date
        return 'None'

    # ------------------------------------------------------------------
    # Direction helpers
    # ------------------------------------------------------------------
    def is_always_decrease(self, part):
        """True when no element of the sequence is greater than its predecessor."""
        return all(not (cur > prev) for prev, cur in zip(part, part[1:]))

    def is_always_increase(self, part):
        """True when no element of the sequence is smaller than its predecessor."""
        return all(not (cur < prev) for prev, cur in zip(part, part[1:]))

    def ma_direction(self, cur_date, periods, latest_n=5):
        """Classify the ``periods`` MA over the last ``latest_n`` bars as up/down/unknown."""
        pre_dates = self.previous_series(cur_date, latest_n)['date'].values
        if len(pre_dates) < 3:
            return ma_unknow_direction

        previous_ma = [self.ma(date, periods) for date in pre_dates]

        if self.is_always_increase(previous_ma):
            return ma_upturn_direction
        if self.is_always_decrease(previous_ma):
            return ma_downturn_direction
        return ma_unknow_direction

    def boll_direction(self, cur_date, latest_n=5):
        """Classify the Bollinger middle band over the last ``latest_n`` bars as up/down/unknown."""
        pre_dates = self.previous_series(cur_date, latest_n)['date'].values
        if len(pre_dates) < 3:
            return boll_unknow_direction

        previous_boll = [self.boll(date) for date in pre_dates]

        if self.is_always_increase(previous_boll):
            return boll_upturn_direction
        if self.is_always_decrease(previous_boll):
            return boll_downturn_direction
        return boll_unknow_direction

    def clean_metric(self):
        """Drop the loaded bar data to free memory."""
        self.data = None
        
        