<a href="https://colab.research.google.com/github/acdc2019/algo-trading/blob/main/python/notebooks/strategies/SuperTrendStrategy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SuperTrend Strategy**
## **Buy Signal**
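* Step 1: A 15-minute candle has to close above the 31-period EMA of Close.
* Step 2: On the same candle, the SuperTrend indicator has to flip from **above the close to below the close**.
* The position stays open until the stop loss is hit or an opposite (reversal) signal fires; any position still open is squared off on the last candle of expiry day.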

## **Sell Signal**
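* Step 1: A 15-minute candle has to close below the 31-period EMA of Close.
* Step 2: On the same candle, the SuperTrend indicator has to flip from **below the close to above the close**.
* The position stays open until the stop loss is hit or an opposite (reversal) signal fires; any position still open is squared off on the last candle of expiry day.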

## **Signal Reversal**
### **Buy -> Sell**
* A Buy position is open and the Sell signal condition is satisfied.
* The Buy position is closed at the candle's close, its PnL is calculated, and a new Sell position is opened.

### **Sell -> Buy**
* A Sell position is open and the Buy signal condition is satisfied.
* The Sell position is closed at the candle's close, its PnL is calculated, and a new Buy position is opened.

#### **Strategy Parameters**
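* window_start, window_end: Dates between which to look for signals (both ends inclusive)
* stop_loss = 2500: maximum loss per position in PnL terms; it is converted to a price distance of stop_loss / lot_size points from the entry price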

In [None]:
!pip install ta
!pip install pandas==1.3.5
!pip install pandas_ta
!pip install plotly
!pip install numpy

Collecting ta
  Downloading ta-0.9.0.tar.gz (25 kB)
Building wheels for collected packages: ta
  Building wheel for ta (setup.py) ... [?25l[?25hdone
  Created wheel for ta: filename=ta-0.9.0-py3-none-any.whl size=28908 sha256=ad51748f59523ca2359b81c08b62bfe2d0c6cd08cce4093369d8bf85bab0f5cd
  Stored in directory: /root/.cache/pip/wheels/72/78/64/cc1c01506a1010a9845e9bd7c69333730f7174661228ea4f98
Successfully built ta
Installing collected packages: ta
Successfully installed ta-0.9.0
Collecting pandas_ta
  Downloading pandas_ta-0.3.14b.tar.gz (115 kB)
[K     |████████████████████████████████| 115 kB 4.4 MB/s 
Building wheels for collected packages: pandas-ta
  Building wheel for pandas-ta (setup.py) ... [?25l[?25hdone
  Created wheel for pandas-ta: filename=pandas_ta-0.3.14b0-py3-none-any.whl size=218923 sha256=9f00561561bb225d2a1138dbb27f1de075dfa242b2b020c3d6745c61f86b66ab
  Stored in directory: /root/.cache/pip/wheels/0b/81/f0/cca85757840e4616a2c6b9fe12569d97d324c27cac60724c58
Suc

In [None]:
import pandas as pd
import pandas_ta as ta
from ta.momentum import RSIIndicator
from ta.trend import ADXIndicator
from plotly.subplots import make_subplots
from datetime import date
import numpy as np

In [None]:
class Signal():
    '''
    A single trade position opened by a strategy.

    Attributes:
    strategy (str): strategy tag ('ST_Buy' / 'ST_Sell' in this notebook);
        an empty string is used as the "no open position" sentinel
    sym (str): traded symbol
    lot_size (int): quantity; PnL is computed as price difference * lot_size
    ts: timestamp of the entry candle
    exit_ts: timestamp of the exit candle; equals ts until the position is closed
    entry_price (float): entry price
    stop_loss (float): stop-loss price level
    exit_price (float | None): exit price; None until the position is closed
    pnl (float): realised PnL; 0 until the position is closed
    comment (str): reason the position was closed
    '''
    def __init__(self, strategy: str, sym: str, lot_size: int, ts: date, entry_price: float, stop_loss: float) -> None:
        self.strategy = strategy
        self.sym = sym
        self.lot_size = lot_size
        self.ts = ts
        self.exit_ts = ts
        self.entry_price = entry_price
        self.stop_loss = stop_loss
        # Initialised here so positions that are never closed (e.g. still open
        # when the backtest window ends) can still be reported and plotted.
        self.exit_price = None
        self.pnl = 0
        self.comment = ""

    def __str__(self) -> str:
        return "Strategy: {}, Sym: {}, TS: {}, Entry: {}, StopLoss: {}, PnL: {}, Comment: {}".format(
            self.strategy, self.sym, self.ts, self.entry_price, self.stop_loss, self.pnl, self.comment)

class BackTestResult():
    '''
    Aggregate PnL statistics for a list of signals of one symbol.

    Only the `pnl` attribute of each signal is read.

    Attributes:
    sym (str): symbol the statistics refer to
    signals (list): the signals passed in
    total_pnl (float): sum of all signal PnLs, rounded to 2 decimals
    gross_profit (float): sum of positive PnLs, rounded to 2 decimals
    gross_loss (float): magnitude (positive number) of the sum of negative PnLs
    profit_factor (float): gross_profit / gross_loss, rounded to 2 decimals
    '''
    def __init__(self, sym: str, signals: list) -> None:
        self.signals = signals
        self.sym = sym
        self.total_pnl = self._calc_total_pnl()
        self.gross_profit = self._calc_gross_profit()
        self.gross_loss = self._calc_gross_loss()
        self.profit_factor = self._calc_profit_factor()

    def _calc_profit_factor(self) -> float:
        '''Return gross profit divided by gross loss, rounded to 2 decimals.'''
        gross_profit = self._calc_gross_profit()
        gross_loss = abs(self._calc_gross_loss())
        if gross_loss == 0:
            # No losing trades: divide by a small constant instead of zero
            # (existing convention; yields gross_profit * 10).
            gross_loss = 0.1

        return round(gross_profit / gross_loss, 2)

    def _calc_gross_profit(self) -> float:
        '''Return the sum of all positive PnLs, rounded to 2 decimals.'''
        return round(sum(s.pnl for s in self.signals if s.pnl > 0), 2)

    def _calc_gross_loss(self) -> float:
        '''Return the sum of all negative PnLs as a positive number, rounded to 2 decimals.'''
        return round(-sum(s.pnl for s in self.signals if s.pnl < 0), 2)

    def _calc_strike_rate(self) -> float:
        '''
        Return the share of absolute PnL that came from winning signals.

        Note: this is PnL-weighted (profit / sum(|pnl|)), not a win-count ratio.
        Returns 0 when there are no signals or every signal has zero PnL.
        '''
        notional = sum(abs(s.pnl) for s in self.signals)
        if notional == 0:
            # Guards both the empty list and the all-zero-PnL case, which
            # previously raised ZeroDivisionError.
            return 0

        profit = sum(s.pnl for s in self.signals if s.pnl > 0)
        return round(profit / notional, 2)

    def _calc_total_pnl(self) -> float:
        '''Return the sum of all PnLs, rounded to 2 decimals.'''
        return round(sum(s.pnl for s in self.signals), 2)

    def __str__(self) -> str:
        return 'Sym,{},Total PnL,{},ProfitFactor,{},GrossProfit,{},GrossLoss,{},Total Signals,{}'.format(
            self.sym, self.total_pnl, self.profit_factor, self.gross_profit, self.gross_loss, len(self.signals))

In [None]:
def get_previous_candles(df: pd.DataFrame, index, n: int, include_index=False):
    '''
    Returns previous n candles from the given index in the DataFrame

    If fewer than n candles precede `index`, only the available ones are
    returned. The window is clamped at the first row instead of wrapping
    around to the end of the DataFrame.

    Parameters:
    df (DataFrame): DataFrame from which to return the previous candles
    index (DataFrame Index): DataFrame Index label from which to return the previous candles
    n (int): Number of previous candles to return from index
    include_index (bool): If current index should be included in returned DataFrame
    Returns:
    DataFrame: Pandas dataframe with (up to) the previous n candles
    '''
    loc = df.index.get_loc(index)
    # A negative start position would make iloc count from the END of the
    # frame, producing an empty or wrong slice near the start of the data.
    from_idx = max(loc - n, 0)
    to_idx = loc + 1 if include_index else loc
    return df.iloc[from_idx:to_idx]


def get_next_candles(df: pd.DataFrame, index, n: int):
    '''
    Slice out the candles that come right after a given row.

    Parameters:
    df (DataFrame): DataFrame to slice
    index (DataFrame Index): label of the reference row; that row itself is excluded
    n (int): How many following candles to take (fewer if the frame ends sooner)
    Returns:
    DataFrame: Pandas dataframe with the next n candles
    '''
    first = df.index.get_loc(index) + 1
    last = first + n
    return df.iloc[first:last]

In [None]:
# Strategy params
# Path to the 15-minute candle CSV; read in the next cell with 'Date' as the index.
file_15min = '/content/NIFTY22JANFUT-HIST-15M.csv'
# NOTE(review): not referenced in the visible cells; signals are tagged 'ST_Buy' / 'ST_Sell'.
strategy = 'STI'
# Symbol recorded on every generated Signal.
sym = 'NIFTY'
# Backtest window, applied as a label slice on the 15-minute DataFrame (both ends inclusive).
window_start = '2022-01-01 00:00:00'
window_end = '2022-01-26 00:00:00'
# Expiry date; an open position is squared off on this day's 15:15 candle.
expiry = '2022-01-25'
# Span of the EMA of Close used as the trend filter.
ema_interval = 31
# Passed positionally to ta.supertrend as its length and multiplier arguments.
supertrend_window = 10
supertrend_multiple = 2
# NOTE(review): not referenced in the visible cells (RSI is computed but not used for signals).
rsi_15min = 70
# Quantity per position; PnL = price difference * lot_size.
lot_size = 50
# Maximum loss per position in PnL terms; the stop price is entry -/+ stop_loss / lot_size.
stop_loss = 2500
# NOTE(review): not referenced in the visible cells.
back_candles = 5


In [None]:
df_15min = pd.read_csv(file_15min, index_col=['Date'], parse_dates=['Date'])

# RSI of Close (ta library, default settings)
rsi = RSIIndicator(df_15min['Close']).rsi()
df_15min = df_15min.assign(rsi=rsi.values)

# SuperTrend via pandas_ta; its output columns are copied by position:
# 0 -> trend line, 1 -> direction (+1 / -1), 2 -> long band, 3 -> short band.
sti = ta.supertrend(df_15min['High'], df_15min['Low'], df_15min['Close'],
                    supertrend_window, supertrend_multiple)
for sti_pos, sti_col in enumerate(['sti_trend', 'sti_dir', 'sti_long', 'sti_short']):
    df_15min[sti_col] = sti.iloc[:, sti_pos].values

# EMA of Close in its recursive (adjust=False) form; used as the trend filter.
close_ema = df_15min['Close'].ewm(span=ema_interval, adjust=False).mean()
df_15min['ema_close'] = close_ema.values

df_15min.tail()

Unnamed: 0_level_0,Open,High,Low,Close,Volume,rsi,sti_trend,sti_dir,sti_long,sti_short,ema_close
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2022-01-25 15:15:00+05:30,17297.0,17309.0,17264.0,17264.0,791450,58.42193,17162.871432,1,17162.871432,,17194.055133
2022-01-27 09:15:00+05:30,16995.0,17046.4,16917.9,17001.9,1674200,37.065579,17162.635711,-1,,17162.635711,17182.045437
2022-01-27 09:30:00+05:30,17000.6,17002.1,16922.0,16949.9,1036100,34.380344,17140.50714,-1,,17140.50714,17167.536347
2022-01-27 09:45:00+05:30,16949.6,17037.65,16927.05,17023.35,795450,40.893885,17140.50714,-1,,17140.50714,17158.524701
2022-01-27 10:00:00+05:30,17020.7,17026.0,17003.35,17014.9,125800,40.397083,17140.50714,-1,,17140.50714,17149.548157


In [None]:
# Candles inside [window_start, window_end]; label slicing on the
# DatetimeIndex includes both ends.
curr_window_df = df_15min.loc[window_start:window_end]

# Independent copy of the window (same columns as df_15min), so later
# assignments on `df` never touch df_15min or trigger SettingWithCopyWarning.
# The empty DataFrame previously built here was dead code: it was
# overwritten on the very next statement.
df = curr_window_df[df_15min.columns].copy()
df.tail()

Unnamed: 0_level_0,Open,High,Low,Close,Volume,rsi,sti_trend,sti_dir,sti_long,sti_short,ema_close
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2022-01-25 14:15:00+05:30,17209.75,17238.65,17174.5,17206.55,476100,54.899332,17077.741052,1,17077.741052,,17173.09987
2022-01-25 14:30:00+05:30,17206.55,17259.95,17206.55,17235.75,397900,57.53644,17099.762252,1,17099.762252,,17177.015503
2022-01-25 14:45:00+05:30,17234.6,17297.0,17233.0,17259.3,560600,59.588737,17132.061027,1,17132.061027,,17182.158285
2022-01-25 15:00:00+05:30,17260.65,17297.9,17260.65,17297.9,422200,62.765267,17151.909924,1,17151.909924,,17189.392142
2022-01-25 15:15:00+05:30,17297.0,17309.0,17264.0,17264.0,791450,58.42193,17162.871432,1,17162.871432,,17194.055133


In [None]:
results = list()
# `signal` always holds the currently open position; a Signal with an empty
# strategy string is the "flat / no position" sentinel.
signal = Signal('','',0,None,0,0)

temp_df = pd.DataFrame(columns=df.columns)
temp_df.index.name = 'Date'

# Last 15-minute candle of expiry day, in the form str(Timestamp) renders for
# this +05:30 index. It is loop-invariant, so it is built once.
expiry_ts = expiry + ' 15:15:00+05:30'

for index, row in df.iterrows():
    # The previous candle comes from the full df_15min, so the first candle of
    # the window still has a predecessor for the SuperTrend flip check.
    prev_row = get_previous_candles(df_15min, index, 1).iloc[0]

    # Expiry: square off any open position at the close and stay flat.
    # Resetting `signal` stops the stop-loss and reversal checks below from
    # overwriting this exit on the same candle, and stops later rows from
    # mutating a closed position. No new position is opened on the expiry
    # candle of an expiring contract.
    if str(row.name) == expiry_ts:
        if signal.strategy != '':
            signal.exit_ts = row.name
            signal.exit_price = row['Close']
            signal.pnl = round(
                (signal.entry_price - row['Close'])*signal.lot_size, 2)
            if signal.strategy == 'ST_Buy':
                signal.pnl = -1 * signal.pnl
            signal.comment = 'Position squared off at expiry'
            signal = Signal('','',0,None,0,0)
        continue

    # Stop Loss Checks. The exit is booked at the stop price even if the
    # candle gapped through it (optimistic fill). No new signal is taken on a
    # candle that hit the stop.
    if (signal.strategy == 'ST_Buy'):
        if(row['Low'] < signal.stop_loss):
            signal.exit_ts = row.name
            signal.exit_price = signal.stop_loss
            signal.pnl = round(-1 *
                                (signal.entry_price - signal.stop_loss)*signal.lot_size, 2)
            signal.comment = 'StopLoss breached'
            signal = Signal('','',0,None,0,0)
            continue

    if (signal.strategy == 'ST_Sell'):
        if(row['High'] > signal.stop_loss):
            signal.exit_ts = row.name
            signal.exit_price = signal.stop_loss
            signal.pnl = round(-1 *
                                (signal.stop_loss - signal.entry_price)*signal.lot_size, 2)
            signal.comment = 'StopLoss breached'
            signal = Signal('','',0,None,0,0)
            continue

    # Buy/Sell signal checks: a SuperTrend direction flip on this candle,
    # confirmed by the close being on the matching side of the EMA.
    sti_buy_passed = row['sti_dir'] == 1 and prev_row['sti_dir'] == -1
    sti_sell_passed = row['sti_dir'] == -1 and prev_row['sti_dir'] == 1
    ema_close_buy_passed = row['Close'] > row['ema_close']
    ema_close_sell_passed = row['Close'] < row['ema_close']

    # A new entry is allowed when flat or when it reverses the open position.
    buy_passed = ema_close_buy_passed and sti_buy_passed and (
        signal.strategy == '' or signal.strategy == 'ST_Sell')
    sell_passed = ema_close_sell_passed and sti_sell_passed and (
        signal.strategy == '' or signal.strategy == 'ST_Buy')

    if buy_passed:
        if(signal.strategy == 'ST_Sell'):
            # Close the short at this candle's close.
            signal.exit_ts = row.name
            signal.exit_price = row['Close']
            signal.pnl = round((signal.entry_price -
                                row['Close'])*signal.lot_size, 2)
            signal.comment = 'STI Reversal'

        # Stop price sits stop_loss / lot_size points below the entry.
        sl = round(row['Close'] - stop_loss / lot_size, 2)
        buy_signal = Signal('ST_Buy', sym, lot_size, row.name, row['Close'], sl)
        signal = buy_signal
        results.append(buy_signal)
        continue

    if sell_passed:
        if(signal.strategy == 'ST_Buy'):
            # Close the long at this candle's close.
            signal.exit_ts = row.name
            signal.exit_price = row['Close']
            signal.pnl = round((row['Close'] - signal.entry_price)*signal.lot_size, 2)
            signal.comment = 'STI Reversal'

        # Stop price sits stop_loss / lot_size points above the entry.
        sl = round(row['Close'] + stop_loss / lot_size, 2)
        sell_signal = Signal('ST_Sell', sym, lot_size, row.name, row['Close'], sl)
        signal = sell_signal
        results.append(sell_signal)
        continue

# NOTE(review): a position still open when the window ends is never closed
# here, so its pnl stays 0.
total_pnl = 0
for result in results:
    total_pnl = total_pnl + result.pnl
    print(result)

# Rounded like the per-signal PnLs, to hide float accumulation noise.
print('Total PnL: {}'.format(round(total_pnl, 2)))

Strategy: ST_Sell, Sym: NIFTY, TS: 2022-01-06 09:15:00+05:30, Entry: 17760.2, StopLoss: 17810.2, PnL: -2500.0, Comment: StopLoss breached
Strategy: ST_Sell, Sym: NIFTY, TS: 2022-01-07 11:45:00+05:30, Entry: 17833.1, StopLoss: 17883.1, PnL: -2500.0, Comment: StopLoss breached
Strategy: ST_Buy, Sym: NIFTY, TS: 2022-01-13 15:00:00+05:30, Entry: 18294.85, StopLoss: 18244.85, PnL: -2500.0, Comment: StopLoss breached
Strategy: ST_Buy, Sym: NIFTY, TS: 2022-01-14 13:45:00+05:30, Entry: 18268.0, StopLoss: 18218.0, PnL: 597.5, Comment: STI Reversal
Strategy: ST_Sell, Sym: NIFTY, TS: 2022-01-18 09:30:00+05:30, Entry: 18279.95, StopLoss: 18329.95, PnL: -767.5, Comment: STI Reversal
Strategy: ST_Buy, Sym: NIFTY, TS: 2022-01-18 12:45:00+05:30, Entry: 18295.3, StopLoss: 18245.3, PnL: -1617.5, Comment: STI Reversal
Strategy: ST_Sell, Sym: NIFTY, TS: 2022-01-18 14:15:00+05:30, Entry: 18262.95, StopLoss: 18312.95, PnL: 49947.5, Comment: Position squared off at expiry
Total PnL: 40660.0


In [None]:
import plotly.graph_objects as go
import plotly.offline as py

# Candles to show before (back) and after (forward) each signal's entry candle.
# The forward count was previously named `next`, which shadowed the builtin next().
back = 10
forward = 50
sig_param_col1 = ['Name','Entry Time','Exit Time','Entry INR','Exit INR','StopLoss INR','PnL','Comments']
for signal in results:
    # Entry candle plus up to `back` candles before it, then `forward` candles after.
    prev_candles = get_previous_candles(df_15min, signal.ts, back, True)
    next_candles = get_next_candles(df_15min, signal.ts, forward)

    candles = pd.concat([prev_candles, next_candles])
    candles['DateStr'] = candles.index.strftime('%d-%m %H:%M')

    # Position of the entry candle on the category x-axis. It equals `back`
    # when there is enough history before the signal.
    signal_pos = len(prev_candles) - 1
    # Right end of the horizontal entry / stop-loss lines, just past the last candle.
    line_end = len(candles) + 1

    # Get all strategy params for this signal
    sig_stop_loss = signal.stop_loss
    sig_entry = signal.entry_price
    # A position still open when the backtest window ended may have no exit price.
    sig_exit = getattr(signal, 'exit_price', None)

    fig = make_subplots(rows=1, cols=2, shared_xaxes=False,
                        subplot_titles=('OHLC', ''),
                        vertical_spacing=0.1,
                        horizontal_spacing=0.01,
                        column_widths=[0.8, 0.2],
                        specs=[[{"secondary_y": False, "type": "candlestick"},
                                {"secondary_y": False, "type": "table"}]])

    fig.add_trace(go.Candlestick(x=candles['DateStr'],
                                 open=candles['Open'],
                                 high=candles['High'],
                                 low=candles['Low'],
                                 close=candles['Close'],
                                 name='Signal Chart',
                                 increasing_line_color='yellow',
                                 increasing_fillcolor='yellow',
                                 decreasing_line_color='red',
                                 decreasing_fillcolor='red'),
                  row=1, col=1)

    # EMA Close
    fig.add_trace(go.Scatter(x=candles['DateStr'], y=candles['ema_close'], name='EMA Close',
                             marker_color='Blue'),
                  row=1, col=1)

    # Close
    fig.add_trace(go.Scatter(x=candles['DateStr'], y=candles['Close'], name='Close',
                             marker_color='Yellow'),
                  row=1, col=1)

    # SuperTrend
    fig.add_trace(go.Scatter(x=candles['DateStr'], y=candles['sti_trend'], name='SuperTrend',
                             marker_color='Cyan'),
                  row=1, col=1)

    fig.add_annotation(x=signal_pos, y=sig_entry,
                       text='Signal')

    # Position Entry Point
    fig.add_shape(type='line',
                  x0=-1, x1=line_end,
                  y0=sig_entry, y1=sig_entry,
                  line=dict(color='Green'),
                  row=1, col=1)

    # Position Stop Loss
    fig.add_shape(type='line',
                  x0=-1, x1=line_end,
                  y0=sig_stop_loss, y1=sig_stop_loss,
                  line=dict(color='Red'),
                  row=1, col=1)

    # Signal Parameters Table
    fig.add_trace(go.Table(header=dict(values=['Param','Value'],
                                       line_color='white',
                                       fill_color='darkslategray',
                                       align='left'),
                           cells=dict(values=[sig_param_col1,
                                              [signal.strategy, signal.ts.strftime('%d-%m %H:%M'), signal.exit_ts.strftime('%d-%m %H:%M'),
                                               sig_entry, sig_exit,
                                               sig_stop_loss, signal.pnl, signal.comment]],
                                      line_color='white',
                                      fill_color='black',
                                      align='left')),
                  row=1, col=2)

    fig.update_xaxes(type='category', rangeslider=dict(visible=False))
    fig.update_xaxes(showgrid=False, nticks=5)
    fig.update_yaxes(showgrid=False)
    fig.update_layout(
        title='Signal generated for SuperTrend 15mins Strategy',
        title_x=0.5,
        autosize=False,
        width=1450,
        height=650,
        plot_bgcolor='rgb(5,5,5)',
        paper_bgcolor='rgb(0,0,0)',
        font_color='white')

    py.iplot(fig)