# Building the Pipeline on Local Spark

1. Setting up
2. Helpers
3. Pull stock info
4. Pull short interests
5. Pull stock prices
6. Combine datasets

## 1. Setting Up

We want to have a logging feature that works for both Jupyter notebook and Spark environments.

1. As it turned out, Spark has "WARN" but does not have "WARNING" level, while in current Python (3.6.x), "WARN" is deprecated, "WARNING" should be used instead.
2. Therefore, we create a custom "WARN" level as well as function `logger.warn` for Jupyter notebook.
3. As shown in [this StackOverflow post](https://stackoverflow.com/questions/35326814/change-level-logged-to-ipython-jupyter-notebook), this is not straightforward due to a Jupyter notebook bug. We need to workaround this by specifying an invalid value first, which we do in the code cell below.

In [None]:

# Run this, but don't copy into etl scripts
# workaround via specifying an invalid value first
%config Application.log_level='WORKAROUND'
import logging
logging.WARN = 21
logging.addLevelName(logging.WARN, 'WARN')

def warn(self, message, *args, **kws):
    if self.isEnabledFor(logging.WARN):
        # Yes, logger takes its '*args' as 'args'.
        self._log(logging.WARN, message, args, **kws) 
logging.Logger.warn = warn


logger = logging.getLogger()
logger.setLevel(logging.WARN)
logger.warn('hello')

# ------------------
from pyspark.sql import SparkSession

spark = SparkSession \
        .builder \
        .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:2.7.0") \
        .getOrCreate()
#         .config("spark.eventLog.enabled", "true") \
#         .config("spark.eventLog.dir" "test_data/spark-logs") \

import pandas as pd
import configparser
config = configparser.ConfigParser()
config.read('airflow/config.cfg')
import tqdm

In [None]:
# Run this cell, and also copy to all etl scripts, or simply include in common.py

import requests
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Row

from py4j.java_gateway import java_import

## 2. Helpers

Include this code as helpers in all next etl scripts.

### Code

In [None]:
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']

In [None]:
from py4j.protocol import Py4JJavaError
from pyspark.sql.utils import AnalysisException


sc = spark.sparkContext
sc._jsc.hadoopConfiguration().set("fs.s3a.access.key", AWS_ACCESS_KEY_ID)
sc._jsc.hadoopConfiguration().set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY)


def delete_path(spark, host, path):
    sc = spark.sparkContext
    java_import(sc._gateway.jvm, "java.net.URI")
    uri = sc._gateway.jvm.java.net.URI
    fs = (sc._jvm.org
          .apache.hadoop
          .fs.FileSystem
          .get(uri(host), sc._jsc.hadoopConfiguration())
          )
    fs.delete(sc._jvm.org.apache.hadoop.fs.Path(host+path), True)


def copyMerge(spark, host, src_dir, dst_file, overwrite=False, deleteSource=False, debug=False):
    
    sc = spark.sparkContext
    
    hadoop = sc._jvm.org.apache.hadoop
#     conf = hadoop.conf.Configuration()
    conf = sc._jsc.hadoopConfiguration()
#     fs = hadoop.fs.FileSystem.get(conf)
    java_import(sc._gateway.jvm, "java.net.URI")
    uri = sc._gateway.jvm.java.net.URI
    fs = (sc._jvm.org
          .apache.hadoop
          .fs.FileSystem
          .get(uri(host), sc._jsc.hadoopConfiguration())
          )

    # check files that will be merged
    files = []
    for f in fs.listStatus(hadoop.fs.Path(src_dir)):
        if f.isFile():
            files.append(f.getPath())
    if not files:
        raise ValueError("Source directory {} is empty".format(src_dir))
    files.sort(key=lambda f: str(f))

    # dst_permission = hadoop.fs.permission.FsPermission.valueOf(permission)      # , permission='-rw-r-----'
    out_stream = fs.create(hadoop.fs.Path(dst_file), overwrite)

    try:
        # loop over files in alphabetical order and append them one by one to the target file
        for file in files:
            if debug: 
                print("Appending file {} into {}".format(file, dst_file))

            in_stream = fs.open(file)   # InputStream object
            try:
                hadoop.io.IOUtils.copyBytes(in_stream, out_stream, conf, False)     # False means don't close out_stream
            finally:
                in_stream.close()
    finally:
        out_stream.close()

    if deleteSource:
        fs.delete(hadoop.fs.Path(src_dir), True)    # True=recursive
        if debug:
            print("Source directory {} removed.".format(src_dir))


def spark_table_exists(host, table_path):
    URI           = sc._gateway.jvm.java.net.URI
    Path          = sc._gateway.jvm.org.apache.hadoop.fs.Path
    FileSystem    = sc._gateway.jvm.org.apache.hadoop.fs.FileSystem
    # Configuration = sc._gateway.jvm.org.apache.hadoop.conf.Configuration
    Configuration = sc._jsc.hadoopConfiguration


    fs = FileSystem.get(URI(host), Configuration())

    try:
        status = fs.listStatus(Path(table_path))
        spark.read.csv(host+table_path, header=True)

        return True
    except Py4JJavaError as e:
        if 'FileNotFoundException' in str(e):
            return False
        else:
            raise
    except AnalysisException as e:
        if 'Unable to infer schema' in str(e):
            return False
        else:
            raise


def check_basic_quality(logger, host, table_path, table_type='csv'):
    """ Checks quality of DAG.
    
    We do this by checking if the table exists and is not empty.
    
    Args:
        - table_type(str): 'parquet' or 'csv'
    """
    if not spark_table_exists(host, table_path):
        logger.warn("(FAIL) Table {} does not exist".format(host+table_path))
        return None
    else:
        if table_type == 'parquet':
            sdf = spark.read.parquet(host+table_path)
            count = sdf.rdd.countApprox(timeout=1000, confidence=0.9)
        elif table_type == 'csv':
            sdf = spark.read.csv(host+table_path, header=True)
            count = sdf.rdd.countApprox(timeout=1000, confidence=0.9)
            
        if count == 0:
            logger.warn("(FAIL) Table {} is empty.".format(host+table_path))
        else:
            logger.warn("(SUCCESS) Table {} has {} rows.".format(host+table_path, count))
        return sdf


### Test

In [None]:
print(spark_table_exists('s3a://short-interest-effect', 'data/test_table')) # Fails due to lack of '/' before the table path
print(spark_table_exists('s3a://short-interest-effect', '/data/test_table'))
print(spark_table_exists('', 'test_data/test_table'))

check_basic_quality(logger, 's3a://short-interest-effect', '/data/test_table')

## 3. Pull Stock Info

### Code

In [None]:
# Pass into `args` argument

URL_NASDAQ = 'https://old.nasdaq.com/screening/companies-by-name.aspx?letter=0&exchange=nasdaq&render=download'
URL_NYSE = 'https://old.nasdaq.com/screening/companies-by-name.aspx?letter=0&exchange=nyse&render=download'

DB_HOST = ''
# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'

In [None]:
def pull_stock_info(url, db_host, table_path):
    response = requests.get(url)
    if response.status_code == 200 or response.status_code == 201:
        content = response.content.decode('utf-8')
        content = content.replace('Summary Quote', 'SummaryQuote')
        delete_path(spark, db_host, table_path)
        df = spark.createDataFrame([[content]], ['info_csv'])
        df.rdd.map(lambda x: x['info_csv'].replace("[","").replace("]", "")).saveAsTextFile(db_host+table_path)
        logger.warn("Stored data from {} to {}".format(url, db_host+table_path))
    else:
        logger.warn("Failed to connect to {}. We will use existing stock info data if they have been created.".format(url))
        
    
pull_stock_info(URL_NASDAQ, DB_HOST, TABLE_STOCK_INFO_NASDAQ)
pull_stock_info(URL_NYSE, DB_HOST, TABLE_STOCK_INFO_NYSE)

### Test

In [None]:
# DB_HOST = 's3a://short-interest-effect'
# TABLE_STOCK_INFO_NASDAQ = '/data/raw/stock_info_nasdaq'

df = spark.read.csv(DB_HOST+TABLE_STOCK_INFO_NASDAQ, header=True) \
               .drop('_c8').toPandas()
df.describe()

In [None]:
df = spark.read.csv(DB_HOST+TABLE_STOCK_INFO_NYSE,
                    header=True, ignoreLeadingWhiteSpace=True, inferSchema=True) \
               .drop('_c8').toPandas()
df.describe()

In [None]:
df.head(5)

### Quality-check

In [None]:
check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NASDAQ, table_type='csv')
check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NYSE, table_type='csv')

## 4. Pull Short Interest

#### Parallelize based on stocks or parallelize based on returned data points?

At the time of writing (2020-01-15), we have 3582 stocks from NASDAQ and 3092 stocks from NYSE. The earliest date is 2013-04-01, which accounts for nearly 1700 data points (261 working days each year).

For each stock, we will need to connect to an external API (Quandl or QuoteMedia). This will take more of the processing time rather than data processing. Therefore, we parallelize based on the stocks rather than returned data points. This way, multiple Spark nodes can connect to different URLs and pull the data. The downside is, obviously, for each node we will have to iteratively process the data, but this is still faster as there are fewer data points than the stocks, at least until several years in the future (There might be a solution that allows each spark node to parallelize...).

### Code

In [None]:
# exchange_map = {
#     'nasdaq': 'FNSQ',
#     'nyse': 'FNYX'
# }

In [None]:
url = "https://www.quandl.com/api/v3/datasets/FINRA/FNYX_FB?api_key={}".format(config['Quandl']['API_KEY'])
result = requests.get(url).json()
print(result['dataset']['data'][0])
print(result['dataset']['column_names'])
col_names = [result['dataset']['column_names']] * len(result['dataset']['data'])
newdata = []
for i, cols in enumerate(col_names):
    newdata.append(dict(zip(cols, result['dataset']['data'][i])))
newdata[:2]

#### Write to single tables

In [None]:
# Pass into `args` argument

START_DATE = config['App']['START_DATE']
QUANDL_API_KEY = config['Quandl']['API_KEY']
YESTERDAY_DATE = '2019-12-16'
LIMIT = 100
# STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
STOCKS = []
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_SHORT_INTERESTS_NASDAQ = 'test_data/raw/short_interests_nasdaq' 
TABLE_SHORT_INTERESTS_NYSE = 'test_data/raw/short_interests_nyse'

In [None]:
# %%timeit -n 1 -r 1

from datetime import datetime

def a_before_b(a, b):
    date_format = "%Y-%m-%d"

    # create datetime objects from the strings
    da = datetime.strptime(a, date_format)
    db = datetime.strptime(b, date_format)

    if da < db:
        return True
    else:
        return False
    

def rowlist2dict(rowlist):
    obj = {}
    for row in rowlist:
        obj[row['Symbol']] = row['last_date']
    return obj


def convert_data(olddata, symbol, url):
    col_names = olddata['dataset']['column_names']
    col_names.append('Symbol')
    col_names.append('SourceURL')
    col_names_multiplied = [col_names] * len(olddata['dataset']['data'])
    newdata = []
    for i, cols in enumerate(col_names_multiplied):
        datum = olddata['dataset']['data'][i]
        datum.append(symbol)
        datum.append(url)
        newdata.append(Row(**dict(zip(cols, datum))))
    return newdata


def pull_short_interests(exchange, host, info_table_path, short_interests_table_path, log_every_n=100):
        
    def pull_exchange_short_interests_by_symbol(symbol, start_date, end_date):
        """
        Return:
            list of dicts [{'colname': value, ...}, ...]
        """
        url = 'https://www.quandl.com/api/v3/datasets/FINRA/'+exchange+'_{}?start_date='+start_date+'&end_date='+end_date+'&api_key='+QUANDL_API_KEY
        url = url.format(symbol)
        response = requests.get(url)
        newdata = []
        if response.status_code in [200, 201]:
            newdata = convert_data(response.json(), symbol, url)
        return newdata

    # Prepare list of stocks
    if STOCKS is not None and len(STOCKS) > 0:
        rdd1 = spark.sparkContext.parallelize(STOCKS)
        row_rdd = rdd1.map(lambda x: Row(x))
        df = spark.createDataFrame(row_rdd,['Symbol'])
    else:
        df = spark.read.csv(host+info_table_path, header=True)
        if LIMIT is not None:
            df = df.limit(LIMIT)
    symbols = df.select('Symbol').rdd.map(lambda r: r['Symbol']).collect()
    
    table_exists = spark_table_exists(host, short_interests_table_path)

    last_dates = None
    if table_exists:
        short_sdf = spark.read.csv(host+short_interests_table_path, header=True)
        last_dates = short_sdf.groupBy('Symbol').agg(F.max('Date').alias('last_date')).collect()
        last_dates = rowlist2dict(last_dates)
        
    total_rows = 0
    data_to_write = []
    for i, symbol in enumerate(symbols):
        data = []
        if table_exists:
            # Get the last date of a stock. If this last date >= PULL_DATE, don't do anything.
            if last_dates != None:
                if symbol in last_dates:
                    date = last_dates[symbol]
                    if a_before_b(date, PULL_DATE):
                        data = pull_exchange_short_interests_by_symbol(symbol, date, PULL_DATE)
                        if len(data)==0:
                            logger.warn("{}: last date ({}) is > pull date ({}) and data exist Keep the data for storing".format(symbol, date, PULL_DATE))
                        else:
                            logger.warn("{}: last date ({}) is > pull date ({}) but no data is available in Quandl".format(symbol, date, PULL_DATE))
                    else:
                        logger.warn("{}: last date ({}) is <= pull date ({}), so do nothing".format(symbol, date, PULL_DATE))
                else:
                    logger.warn("{}: pull data from all dates".format(symbol))
                    data = pull_exchange_short_interests_by_symbol(symbol, START_DATE, PULL_DATE)
            else:
                logger.warn("{}: pull data from all dates".format(symbol))
                data = pull_exchange_short_interests_by_symbol(symbol, START_DATE, PULL_DATE)
        else:
            data = pull_exchange_short_interests_by_symbol(symbol, START_DATE, PULL_DATE)
        
        if len(data) > 0:
            data_to_write += data

        total_rows += len(data)
        if (i%log_every_n == 0 or (i+1) == len(symbols)):
            logger.warn("storing data downloaded from exchange {} - {}/{} - total rows in this batch: {}".format(exchange, i+1, len(symbols), total_rows))
            if len(data_to_write) > 0:
                sdf_to_write = spark.createDataFrame(data_to_write)
                sdf_to_write.write.mode('append').format('csv').save(host+short_interests_table_path, header=True)
                logger.warn("Written {} rows to {}".format(len(data_to_write), host+short_interests_table_path))
                data_to_write = []


    logger.warn("done!")

pull_short_interests('FNSQ', DB_HOST, TABLE_STOCK_INFO_NASDAQ, TABLE_SHORT_INTERESTS_NASDAQ)
pull_short_interests('FNYX', DB_HOST, TABLE_STOCK_INFO_NYSE, TABLE_SHORT_INTERESTS_NYSE)

### Test

In [None]:
sdf = spark.read.csv(DB_HOST+TABLE_SHORT_INTERESTS_NASDAQ, header=True).limit(5).dropDuplicates(['Date', 'Symbol'])
sdf = sdf.orderBy(sdf.Date.desc())
print(sdf.count())
df = sdf.toPandas()
print(df['SourceURL'][0])
df.head(5)

In [None]:
df.sort_values(by=['Date', 'Symbol'], ascending=True).head(5)

### Quality-check

In [None]:
# Pass into `args` argument

STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_SHORT_INTERESTS_NASDAQ = 'test_data/raw/short_interests_nasdaq' 
TABLE_SHORT_INTERESTS_NYSE = 'test_data/raw/short_interests_nyse'

In [None]:
if STOCKS is None or len(STOCKS) == 0:
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NASDAQ, table_type='csv')
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NYSE, table_type='csv')
check_basic_quality(logger, DB_HOST, TABLE_SHORT_INTERESTS_NASDAQ)
check_basic_quality(logger, DB_HOST, TABLE_SHORT_INTERESTS_NYSE)

### Test

In [None]:
pdf_temp = spark.read.csv(DB_HOST+TABLE_STOCK_PRICES+'-temp', header=True, ignoreLeadingWhiteSpace=True, inferSchema=True).toPandas()
print(pdf_temp.info())
pdf_temp.head(5)

In [None]:
pdf = spark.read.csv(DB_HOST+TABLE_STOCK_PRICES, header=True, inferSchema=True) \
    .dropDuplicates(['date', 'symbol']).toPandas()
print(pdf.info())
pdf.head(5)

In [None]:
pdf.sort_values(by=['date', 'symbol'], ascending=True).head(5)

### Quality-check

In [None]:
# Pass into `args` argument

STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_STOCK_PRICES = 'test_data/raw/prices'

In [None]:
if STOCKS is None or len(STOCKS) == 0:
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NASDAQ)
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NYSE)
check_basic_quality(logger, DB_HOST, TABLE_STOCK_PRICES, table_type='csv')

## 6. Combine Datasets

### Code

In [None]:
# Pass into `args` argument

YESTERDAY_DATE = '2019-12-12'
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_PRICES = 'test_data/raw/prices'
TABLE_SHORT_INTERESTS_NASDAQ = 'test_data/raw/short_interests_nasdaq' 
TABLE_SHORT_INTERESTS_NYSE = 'test_data/raw/short_interests_nyse'
TABLE_SHORT_ANALYSIS = 'test_data/processed/short_analysis'

In [None]:
# ----------------
# Combine short interest tables
# ----------------

# Without a schema (even with inferSchema=True), the SUM aggregation fails.
schema = T.StructType([
    T.StructField("Date", T.StringType(), True),
    T.StructField("ShortExemptVolume", T.FloatType(), True),
    T.StructField("ShortVolume", T.FloatType(), True),
    T.StructField("SourceURL", T.StringType(), True),
    T.StructField("Symbol", T.StringType(), True),
    T.StructField("TotalVolume", T.FloatType(), True),
])

sdf_shorts = spark.read.format('csv') \
                  .option('header', True) \
                  .option('schema', schema) \
                  .option('mode', 'DROPMALFORMED') \
                  .load([DB_HOST+TABLE_SHORT_INTERESTS_NASDAQ, DB_HOST+TABLE_SHORT_INTERESTS_NYSE])

rows = sdf_shorts.where((F.col('Symbol') == 'SPY') & (F.col('Date') == '2020-02-21')).collect()
logger.warn("Rows: {}".format(rows))

sdf_shorts = sdf_shorts.groupBy('Date', 'Symbol') \
                 .agg(F.sum('ShortExemptVolume').alias('short_exempt_volume'),
                      F.sum('ShortVolume').alias('short_volume'),
                      F.sum('TotalVolume').alias('total_volume')
                     ) \
                 .withColumnRenamed('Date', 'date') \
                 .withColumnRenamed('Symbol', 'symbol')

# ----------------
# Prepare for Quantopian
# ----------------

# DataFrame[short_exempt_volume: string, short_volume: string, total_volume: string, date: string, open: string, 
# high: string, low: string, close: string, volume: string, changed: string, changep: string, adjclose: string, 
# tradeval: string, tradevol: string, symbol: string]

# Correct all Quantopian errors here.
# -----------
sdf = sdf_shorts.withColumn('symbol', F.when(F.col('symbol')=='GECCL', 'GECC_L').otherwise(F.col('symbol')))
# -----------

sdf.select(['date', 'symbol', 'short_exempt_volume', 'short_volume', 'total_volume']) \
   .coalesce(1).write.mode('overwrite').csv(DB_HOST+TABLE_SHORT_ANALYSIS, header=True)

delete_path(spark, DB_HOST, TABLE_SHORT_ANALYSIS+".csv")
copyMerge(spark, DB_HOST, DB_HOST+TABLE_SHORT_ANALYSIS, DB_HOST+TABLE_SHORT_ANALYSIS+".csv")
delete_path(spark, DB_HOST, TABLE_SHORT_ANALYSIS)

logger.warn("done!")

### Test

In [None]:
df_short = sdf_shorts.toPandas()
df_prices = sdf_prices.toPandas()
df_short_analysis = sdf_short_analysis.toPandas()

print(df_short.info())
print(df_prices.info())
print(df_short_analysis.info())

In [None]:
spark.read.parquet(DB_HOST+TABLE_SHORT_ANALYSIS).toPandas().head(5)

In [None]:
spark.read.parquet(DB_HOST+TABLE_SHORT_ANALYSIS).repartition(1).write.mode('overwrite').format('csv').save(DB_HOST+TABLE_SHORT_ANALYSIS+".csv", header=True)

### Quality-check

In [None]:
# Pass into `args` argument

DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_SHORT_ANALYSIS = 'test_data/processed/short_analysis'

In [None]:
check_basic_quality(logger, DB_HOST, TABLE_SHORT_ANALYSIS)

## Final Dataset

In [None]:
df = sdf.toPandas()

In [None]:
df.describe()

In [None]:
df.info()

## References

- Best practices: https://towardsdatascience.com/apache-airflow-tips-and-best-practices-ff64ce92ef8
- Write/store dataframe as textfile: https://stackoverflow.com/questions/44537889/write-store-dataframe-in-text-file
- Delete hdfs path: https://stackoverflow.com/a/55952480/278191
- Delete hdfs path in S3: http://bigdatatech.taleia.software/2015/12/28/deleting-a-amazon-s3-path-from-apache-spark/
- Import java class in python: https://stackoverflow.com/questions/33544105/running-custom-java-class-in-pyspark