# Building the Pipeline on Local Spark

1. Setting up
2. Helpers
3. Pull stock info
4. Pull short interests
5. Pull stock prices
6. Combine datasets

## 1. Setting Up

We want a logging setup that works in both Jupyter notebook and Spark environments.

1. As it turns out, Spark has a "WARN" level but no "WARNING" level, whereas in current Python (3.6.x) "WARN" is deprecated and "WARNING" should be used instead.
2. Therefore, we create a custom "WARN" level as well as a `logger.warn` function for the Jupyter notebook.
3. As shown in [this StackOverflow post](https://stackoverflow.com/questions/35326814/change-level-logged-to-ipython-jupyter-notebook), this is not straightforward due to a Jupyter notebook bug. We work around it by specifying an invalid value first, which we do in the code cell below.

In [2]:

# Run this, but don't copy into etl scripts
# workaround via specifying an invalid value first
%config Application.log_level='WORKAROUND'
import logging
logging.WARN = 21
logging.addLevelName(logging.WARN, 'WARN')

def warn(self, message, *args, **kws):
    if self.isEnabledFor(logging.WARN):
        # Yes, logger takes its '*args' as 'args'.
        self._log(logging.WARN, message, args, **kws) 
logging.Logger.warn = warn


logger = logging.getLogger()
logger.setLevel(logging.WARN)
logger.warn('hello')

# ------------------
from pyspark.sql import SparkSession

spark = SparkSession \
        .builder \
        .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:2.7.0") \
        .getOrCreate()

import pandas as pd
import configparser
config = configparser.ConfigParser()
config.read('airflow/config.cfg')

ERROR:root:The 'log_level' trait of an IPKernelApp instance must be any of (0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL'), but a value of 'WORKAROUND' <class 'str'> was specified.
WARN:root:hello


['airflow/config.cfg']
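
Outside Jupyter (for example in the ETL scripts, which do not need the `%config` workaround), one may prefer not to patch `logging.Logger` globally. A minimal sketch, assuming the same level number 21 as above: a `Logger` subclass registered with `logging.setLoggerClass` provides the `warn` method only for the loggers we create with it. `WarnLogger` and `make_logger` are hypothetical names, not part of the pipeline.

```python
import io
import logging

# Assumption: reuse the custom level number 21 from the cell above.
WARN_LEVEL = 21
logging.addLevelName(WARN_LEVEL, 'WARN')


class WarnLogger(logging.Logger):
    """Logger whose `warn` method logs at the custom WARN level."""

    def warn(self, message, *args, **kwargs):
        if self.isEnabledFor(WARN_LEVEL):
            # Logger._log takes the positional args as a single tuple.
            self._log(WARN_LEVEL, message, args, **kwargs)


def make_logger(name, stream):
    """Hypothetical helper: create a WarnLogger that writes to `stream`."""
    previous_class = logging.getLoggerClass()
    logging.setLoggerClass(WarnLogger)
    try:
        logger = logging.getLogger(name)
    finally:
        # Restore the previous class so other loggers are unaffected.
        logging.setLoggerClass(previous_class)
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
    logger.addHandler(handler)
    logger.setLevel(WARN_LEVEL)
    logger.propagate = False  # avoid duplicate output through the root logger
    return logger


buffer = io.StringIO()
etl_logger = make_logger('etl_sketch', buffer)
etl_logger.info('not shown')  # INFO (20) is below WARN (21)
etl_logger.warn('hello')
print(buffer.getvalue(), end='')  # WARN:etl_sketch:hello
```

Note that `make_logger` only affects loggers created through it; an existing logger with the same name keeps its original class.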

In [3]:
# Run this cell, and also copy to all etl scripts, or simply include in common.py

import requests
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Row

from py4j.java_gateway import java_import

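def delete_path(spark, host, path):
    """Recursively delete host+path (e.g. an S3 prefix or a local directory) via the Hadoop FileSystem API."""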
    sc = spark.sparkContext
    java_import(sc._gateway.jvm, "java.net.URI")
    uri = sc._gateway.jvm.java.net.URI
    fs = (sc._jvm.org
          .apache.hadoop
          .fs.FileSystem
          .get(uri(host), sc._jsc.hadoopConfiguration())
          )
    fs.delete(sc._jvm.org.apache.hadoop.fs.Path(host+path), True)

### Test

In [152]:
delete_path(spark, 's3a://short-interest-effect', '/data/raw/stock_info_nasdaq')

## 2. Helpers

Include this code as helpers in all subsequent ETL scripts.

### Code

In [6]:
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']

In [7]:
from py4j.protocol import Py4JJavaError

sc = spark.sparkContext
sc._jsc.hadoopConfiguration().set("fs.s3a.access.key", AWS_ACCESS_KEY_ID)
sc._jsc.hadoopConfiguration().set("fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY)

def spark_table_exists(host, table_path):
    """Return True if table_path exists on the filesystem at host, False if it does not."""
    URI           = sc._gateway.jvm.java.net.URI
    Path          = sc._gateway.jvm.org.apache.hadoop.fs.Path
    FileSystem    = sc._gateway.jvm.org.apache.hadoop.fs.FileSystem
    # Configuration = sc._gateway.jvm.org.apache.hadoop.conf.Configuration
    Configuration = sc._jsc.hadoopConfiguration


    fs = FileSystem.get(URI(host), Configuration())

    try:
        fs.listStatus(Path(table_path))
        return True
    except Py4JJavaError as e:
        if 'FileNotFoundException' in str(e):
            return False
        # Any other error (credentials, network, ...) must not be mistaken
        # for a missing table, or callers would overwrite existing data.
        raise
    
    
def check_basic_quality(logger, host, table_path, table_type='parquet'):
    """ Checks the basic quality of a table produced by the DAG.
    
    We do this by checking that the table exists and is not empty.
    
    Args:
        - logger: logger with a `warn` method
        - host(str): filesystem host, e.g. 's3a://short-interest-effect', or '' for local
        - table_path(str): table path relative to host
        - table_type(str): 'parquet' or 'csv'
    """
    if not spark_table_exists(host, table_path):
        logger.warn("(FAIL) Table {} does not exist".format(host+table_path))
    else:
        if table_type == 'parquet':
            count = spark.read.parquet(host+table_path).count()
        elif table_type == 'csv':
            count = spark.read.csv(host+table_path, header=True).count()
        else:
            raise ValueError("Unsupported table_type: {}".format(table_type))
            
        if count == 0:
            logger.warn("(FAIL) Table {} is empty.".format(host+table_path))
        else:
            logger.warn("(SUCCESS) Table {} has {} rows.".format(host+table_path, count))

### Test

In [91]:
print(spark_table_exists('s3a://short-interest-effect', 'data/test_table')) # Returns False: the table path lacks a leading '/'
print(spark_table_exists('s3a://short-interest-effect', '/data/test_table'))
print(spark_table_exists('', 'test_data/test_table'))

check_basic_quality(logger, 's3a://short-interest-effect', '/data/test_table')

False
True
True


WARN:root:(SUCCESS) Table s3a://short-interest-effect/data/test_table has 1706 rows.
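
Throughout the pipeline, table locations are built as `host + table_path`, so a path without a leading '/' produces something like `s3a://short-interest-effectdata/test_table`. A minimal sketch of a hypothetical `join_table_path` helper that inserts exactly one '/' between host and path, and leaves local paths untouched when the host is empty (as with `DB_HOST = ''`):

```python
def join_table_path(host, table_path):
    """Hypothetical helper: join a filesystem host and a table path.

    With an empty host (local filesystem) the path is returned unchanged;
    otherwise exactly one '/' separates host and path.
    """
    if not host:
        return table_path
    return host.rstrip('/') + '/' + table_path.lstrip('/')


print(join_table_path('s3a://short-interest-effect', 'data/test_table'))
# s3a://short-interest-effect/data/test_table
print(join_table_path('s3a://short-interest-effect', '/data/test_table'))
# s3a://short-interest-effect/data/test_table
print(join_table_path('', 'test_data/test_table'))
# test_data/test_table
```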


## 3. Pull Stock Info

### Code

In [82]:
# Pass into `args` argument

URL_NASDAQ = 'https://old.nasdaq.com/screening/companies-by-name.aspx?letter=0&exchange=nasdaq&render=download'
URL_NYSE = 'https://old.nasdaq.com/screening/companies-by-name.aspx?letter=0&exchange=nyse&render=download'

DB_HOST = ''
# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'

In [135]:
def pull_stock_info(url, db_host, table_path):
    response = requests.get(url)
    if response.status_code in (200, 201):
        content = response.content.decode('utf-8')
        content = content.replace('Summary Quote', 'SummaryQuote')
        # saveAsTextFile fails if the output path already exists, so clear it first.
        delete_path(spark, db_host, table_path)
        df = spark.createDataFrame([[content]], ['info_csv'])
        df.rdd.map(lambda x: x['info_csv'].replace("[","").replace("]", "")).saveAsTextFile(db_host+table_path)
        logger.warn("Stored data from {} to {}".format(url, db_host+table_path))
    else:
        logger.warn("Request to {} failed with status {}. Existing stock info data, if any, will be used instead.".format(url, response.status_code))
        
    
pull_stock_info(URL_NASDAQ, DB_HOST, TABLE_STOCK_INFO_NASDAQ)
pull_stock_info(URL_NYSE, DB_HOST, TABLE_STOCK_INFO_NYSE)

WARN:root:Stored data from https://old.nasdaq.com/screening/companies-by-name.aspx?letter=0&exchange=nasdaq&render=download to test_data/raw/stock_info_nasdaq
WARN:root:Stored data from https://old.nasdaq.com/screening/companies-by-name.aspx?letter=0&exchange=nyse&render=download to test_data/raw/stock_info_nyse


### Test

In [138]:
df = spark.read.csv(DB_HOST+TABLE_STOCK_INFO_NASDAQ, header=True, inferSchema=True) \
               .drop('_c8').toPandas()
df.describe()

| | Symbol | Name | LastSale | MarketCap | IPOyear | Sector | industry | SummaryQuote |
|---|---|---|---|---|---|---|---|---|
| count | 3582 | 3582 | 3582.0 | 3582.0 | 3582.0 | 3582 | 3582.0 | 3582 |
| unique | 3582 | 3151 | 2845.0 | 2839.0 | 44.0 | 13 | 128.0 | 3582 |
| top | FPXI | Barclays PLC | 1.6 | | | Health Care | | https://old.nasdaq.com/symbol/siga |
| freq | 1 | 10 | 9.0 | 307.0 | 1889.0 | 785 | 522.0 | 1 |


In [137]:
df = spark.read.csv(DB_HOST+TABLE_STOCK_INFO_NYSE,
                    header=True, ignoreLeadingWhiteSpace=True, inferSchema=True) \
               .drop('_c8').toPandas()
df.describe()

| | Symbol | Name | LastSale | MarketCap | IPOyear | Sector | industry | SummaryQuote |
|---|---|---|---|---|---|---|---|---|
| count | 3092 | 3092 | 3092.0 | 3092.0 | 3092.0 | 3092.0 | 3092.0 | 3092 |
| unique | 3092 | 2438 | 2385.0 | 1924.0 | 36.0 | 13.0 | 133.0 | 3092 |
| top | CS | Bank of America Corporation | | | | | | https://old.nasdaq.com/symbol/hpf |
| freq | 1 | 14 | 106.0 | 700.0 | 1658.0 | 1010.0 | 1010.0 | 1 |


In [112]:
df.head(5)

| | Symbol | Name | LastSale | MarketCap | IPOyear | Sector | industry | SummaryQuote |
|---|---|---|---|---|---|---|---|---|
| 0 | DDD | 3D Systems Corporation | 11.9 | $1.41B | | Technology | Computer Software: Prepackaged Software | https://old.nasdaq.com/symbol/ddd |
| 1 | MMM | 3M Company | 179.78 | $103.38B | | Health Care | Medical/Dental Instruments | https://old.nasdaq.com/symbol/mmm |
| 2 | WBAI | 500.com Limited | 8.0 | $343.99M | 2013.0 | Consumer Services | Services-Misc. Amusement & Recreation | https://old.nasdaq.com/symbol/wbai |
| 3 | WUBA | 58.com Inc. | 69.18 | $10.34B | 2013.0 | Technology | Computer Software: Programming, Data Processing | https://old.nasdaq.com/symbol/wuba |
| 4 | EGHT | 8x8 Inc | 20.1 | $2.01B | | Technology | EDP Services | https://old.nasdaq.com/symbol/eght |


### Quality-check

In [None]:
check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NASDAQ, table_type='csv')
check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NYSE, table_type='csv')

## 4. Pull Short Interest

#### Parallelize based on stocks or parallelize based on returned data points?

At the time of writing (2020-01-15), we have 3582 stocks from NASDAQ and 3092 stocks from NYSE. The earliest date is 2013-04-01, which gives roughly 1700 data points per stock (about 261 weekdays per year, minus market holidays).
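
The figure of roughly 1700 data points can be sanity-checked by counting weekdays between the two dates. A small sketch using only the standard library; it ignores exchange holidays, so it gives an upper bound on the number of trading days:

```python
from datetime import date


def count_weekdays(start, end):
    """Count Monday-Friday dates in [start, end], ignoring market holidays."""
    n_days = (end - start).days + 1
    full_weeks, leftover = divmod(n_days, 7)
    count = full_weeks * 5
    # The leftover days start on the same weekday as `start`.
    for offset in range(leftover):
        if (start.weekday() + offset) % 7 < 5:
            count += 1
    return count


print(count_weekdays(date(2013, 4, 1), date(2020, 1, 15)))  # 1773
```

That is about 261 weekdays per year over roughly 6.8 years; subtracting market holidays brings the trading-day count closer to 1700.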

For each stock, we need to call an external API (Quandl or QuoteMedia), and these network requests, rather than the data processing, take up most of the run time. Therefore, we parallelize over stocks rather than over returned data points, so that multiple Spark workers can request different URLs concurrently. The downside is that each worker has to process its stock's data points sequentially, but this is still faster because each stock has fewer data points (about 1700) than there are stocks (over 3000 per exchange), and this should hold for several more years (there might be a way to let each Spark worker parallelize this step as well...).
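
To illustrate why parallelizing over stocks pays off when network latency dominates, here is a local analogue that uses threads instead of Spark executors. `fake_fetch` is a hypothetical stand-in for one API request, and the 0.2 s latency is an assumption; each task handles one symbol and processes that symbol's rows sequentially, as the per-symbol UDF in the pipeline does.

```python
import time
from concurrent.futures import ThreadPoolExecutor

LATENCY_S = 0.2  # assumed per-request network latency


def fake_fetch(symbol):
    """Hypothetical stand-in for one API request returning a symbol's history."""
    time.sleep(LATENCY_S)  # waiting on the network, not computing
    # The rows of one symbol are then processed sequentially by the same worker.
    return [(symbol, day) for day in range(3)]


symbols = ['FB', 'GOOG', 'AMZN', 'TSLA']

start = time.perf_counter()
# One task per symbol, analogous to one Spark row (one UDF call) per symbol.
with ThreadPoolExecutor(max_workers=len(symbols)) as pool:
    results = list(pool.map(fake_fetch, symbols))
elapsed = time.perf_counter() - start

rows = [row for per_symbol in results for row in per_symbol]
print(len(rows), 'rows')  # 12 rows
print('elapsed: {:.2f}s vs. {:.2f}s if run one by one'.format(elapsed, LATENCY_S * len(symbols)))
```

The total time stays close to a single request's latency because the workers spend their time waiting, which is the situation the paragraph above describes.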

### Code

In [47]:
# exchange_map = {
#     'nasdaq': 'FNSQ',
#     'nyse': 'FNYX'
# }

In [12]:
url = "https://www.quandl.com/api/v3/datasets/FINRA/FNYX_FB?api_key={}".format(config['Quandl']['API_KEY'])
result = requests.get(url).json()
print(result['dataset']['data'][0])
print(result['dataset']['column_names'])
cols = result['dataset']['column_names']
newdata = [dict(zip(cols, row)) for row in result['dataset']['data']]
newdata[:2]

['2020-01-15', 763014.0, 15.0, 1201785.0]
['Date', 'ShortVolume', 'ShortExemptVolume', 'TotalVolume']


[{'Date': '2020-01-15',
  'ShortVolume': 763014.0,
  'ShortExemptVolume': 15.0,
  'TotalVolume': 1201785.0},
 {'Date': '2020-01-14',
  'ShortVolume': 918212.0,
  'ShortExemptVolume': 1.0,
  'TotalVolume': 1539251.0}]

#### Write to single tables

In [141]:
# Pass into `args` argument

START_DATE = config['App']['START_DATE']
QUANDL_API_KEY = config['Quandl']['API_KEY']
YESTERDAY_DATE = '2019-12-12'
LIMIT = 1
STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_SHORT_INTERESTS_NASDAQ = 'test_data/raw/short_interests_nasdaq' 
TABLE_SHORT_INTERESTS_NYSE = 'test_data/raw/short_interests_nyse'

In [140]:
%%timeit -n 1 -r 1

def convert_data(olddata, symbol, url):
    """Flatten a Quandl dataset payload into a list of dicts, one per data row,
    tagging each row with its stock symbol and source URL."""
    col_names = olddata['dataset']['column_names'] + ['Symbol', 'SourceURL']
    return [dict(zip(col_names, row + [symbol, url]))
            for row in olddata['dataset']['data']]


def pull_short_interests(exchange, host, info_table_path, short_interests_table_path):

    create_table = not(spark_table_exists(host, short_interests_table_path))
        
    def pull_exchange_short_interests_by_symbol(symbol):
        """
        Return:
            list of dicts [{'colname': value, ...}, ...]
        """
        if create_table:
            # If the table does not exist, pull all data.
            url = 'https://www.quandl.com/api/v3/datasets/FINRA/'+exchange+'_{}?start_date='+START_DATE+'&end_date='+YESTERDAY_DATE+'&api_key='+QUANDL_API_KEY
        else:
            # If the table already exists, pull only yesterday's data.
            url = 'https://www.quandl.com/api/v3/datasets/FINRA/'+exchange+'_{}?start_date='+YESTERDAY_DATE+'&end_date='+YESTERDAY_DATE+'&api_key='+QUANDL_API_KEY

        url = url.format(symbol)
        response = requests.get(url)
        newdata = []
        if response.status_code in [200, 201]:
            newdata = convert_data(response.json(), symbol, url)
        return newdata

    
    # [{'colname': value, ...}, ...]
    schema = T.ArrayType(
                T.MapType(
                    T.StringType(), T.StringType()
                )
             )
    udf_pull_exchange_short_interests = F.udf(pull_exchange_short_interests_by_symbol, schema)

    # Prepare list of stocks
    if STOCKS is not None and len(STOCKS) > 0:
        rdd1 = spark.sparkContext.parallelize(STOCKS)
        row_rdd = rdd1.map(lambda x: Row(x))
        df = spark.createDataFrame(row_rdd,['Symbol'])
    else:
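        # The stock info tables are stored as CSV text (see pull_stock_info), not Parquet.
        df = spark.read.csv(host+info_table_path, header=True, ignoreLeadingWhiteSpace=True) \
                  .select('Symbol')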
        if LIMIT is not None:
            df = df.limit(LIMIT)

    df = df.withColumn('short_interests', udf_pull_exchange_short_interests('Symbol'))

    # Convert [short_interests: [{col: val, ...}, ...]] to
    # [{col: val, ...}, ...]
    df = df.select(F.explode(df['short_interests']).alias('col')) \
         .rdd.map(lambda x: x['col'])

    df_schema = T.StructType([T.StructField('Date', T.StringType(), False),
                              T.StructField('ShortExemptVolume', T.StringType(), True),
                              T.StructField('ShortVolume', T.StringType(), True),
                              T.StructField('Symbol', T.StringType(), False),
                              T.StructField('TotalVolume', T.StringType(), True),
                              T.StructField('SourceURL', T.StringType(), True),
                             ])
    df = spark.createDataFrame(df, df_schema)
    df = df.withColumn('Date', df['Date'].cast(T.DateType())) \
         .withColumn('ShortExemptVolume', df['ShortExemptVolume'].cast(T.DoubleType())) \
         .withColumn('ShortVolume', df['ShortVolume'].cast(T.DoubleType())) \
         .withColumn('TotalVolume', df['TotalVolume'].cast(T.DoubleType()))

    if create_table:
        logger.warn("Creating table {}".format(host+short_interests_table_path))
        df.write.mode('overwrite').parquet(host+short_interests_table_path)
    else:
        logger.warn("Appending to table {}".format(host+short_interests_table_path))
        df.write.mode('append').parquet(host+short_interests_table_path)
        
        # Drop duplicates later, when we combine the datasets:
        # 1. We do not want to waste S3 bandwidth.
        # 2. Raw data are meant to be dirty. We are going to use only the final dataset for analysis.
        # 3. If we really want to clean the datasets, we should create another DAG for that.
        # Sketch of such a cleanup (deduplicate on both keys and write to a separate
        # path, since Spark cannot safely overwrite a table it is still reading):
#         spark.read.parquet(host+short_interests_table_path).dropDuplicates(['Date', 'Symbol']) \
#         .write.mode('overwrite').parquet(host+short_interests_table_path+'-clean')
        
    logger.warn("done!")

pull_short_interests('FNSQ', DB_HOST, TABLE_STOCK_INFO_NASDAQ, TABLE_SHORT_INTERESTS_NASDAQ)
pull_short_interests('FNYX', DB_HOST, TABLE_STOCK_INFO_NYSE, TABLE_SHORT_INTERESTS_NYSE)

WARN:root:Appending to table test_data/raw/short_interests_nasdaq
WARN:root:done!
WARN:root:Appending to table test_data/raw/short_interests_nyse
WARN:root:done!


9.34 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
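
`pull_short_interests` switches between the full history (first run) and only yesterday's data (later runs) by building different Quandl URLs. A minimal sketch of that choice, with the query string built by `urllib.parse.urlencode` instead of string concatenation; `quandl_date_range` and `build_quandl_url` are hypothetical helpers, and the URL layout is copied from the code above:

```python
from urllib.parse import urlencode

# URL layout copied from pull_short_interests above.
QUANDL_BASE = 'https://www.quandl.com/api/v3/datasets/FINRA/'


def quandl_date_range(create_table, start_date, yesterday_date):
    """Full history when the table is new, otherwise only yesterday."""
    if create_table:
        return start_date, yesterday_date
    return yesterday_date, yesterday_date


def build_quandl_url(exchange, symbol, start_date, end_date, api_key):
    """Build the dataset URL for one exchange (e.g. 'FNSQ') and symbol."""
    query = urlencode({'start_date': start_date,
                       'end_date': end_date,
                       'api_key': api_key})
    return '{}{}_{}?{}'.format(QUANDL_BASE, exchange, symbol, query)


start, end = quandl_date_range(False, '2013-04-01', '2019-12-12')
print(build_quandl_url('FNSQ', 'FB', start, end, 'YOUR_API_KEY'))
# https://www.quandl.com/api/v3/datasets/FINRA/FNSQ_FB?start_date=2019-12-12&end_date=2019-12-12&api_key=YOUR_API_KEY
```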


### Test

In [143]:
sdf = spark.read.parquet(DB_HOST+TABLE_SHORT_INTERESTS_NASDAQ).dropDuplicates(['Date', 'Symbol'])
sdf = sdf.orderBy(sdf.Date.desc())
print(sdf.count())
df = sdf.toPandas()
print(df['SourceURL'][0])
df.head(5)

9080
https://www.quandl.com/api/v3/datasets/FINRA/FNSQ_GOOG?start_date=2013-04-01&api_key=REDACTED


| | Date | ShortExemptVolume | ShortVolume | Symbol | TotalVolume | SourceURL |
|---|---|---|---|---|---|---|
| 0 | 2020-01-15 | 119.0 | 135083.0 | GOOG | 234811.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 1 | 2020-01-15 | 8177.0 | 1137057.0 | FB | 2244132.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 2 | 2020-01-15 | 7813.0 | 643374.0 | NFLX | 1124373.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 3 | 2020-01-15 | 2312.0 | 191325.0 | AMZN | 609106.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 4 | 2020-01-15 | 0.0 | 38432.0 | TRMT | 66421.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |


In [144]:
df.sort_values(by=['Date', 'Symbol'], ascending=True).head(5)

| | Date | ShortExemptVolume | ShortVolume | Symbol | TotalVolume | SourceURL |
|---|---|---|---|---|---|---|
| 9079 | 2013-04-01 | 1200.0 | 321547.0 | AMZN | 627308.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 9076 | 2013-04-01 | 5500.0 | 3097667.0 | FB | 8349211.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 9077 | 2013-04-01 | 0.0 | 246063.0 | GOOG | 494596.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 9078 | 2013-04-01 | 100.0 | 409150.0 | NFLX | 1078168.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |
| 9075 | 2013-04-01 | 0.0 | 1780043.0 | TSLA | 5786214.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... |


### Quality-check

In [145]:
# Pass into `args` argument

STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_SHORT_INTERESTS_NASDAQ = 'test_data/raw/short_interests_nasdaq' 
TABLE_SHORT_INTERESTS_NYSE = 'test_data/raw/short_interests_nyse'

In [146]:
if STOCKS is None or len(STOCKS) == 0:
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NASDAQ, table_type='csv')
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NYSE, table_type='csv')
check_basic_quality(logger, DB_HOST, TABLE_SHORT_INTERESTS_NASDAQ)
check_basic_quality(logger, DB_HOST, TABLE_SHORT_INTERESTS_NYSE)

WARN:root:(SUCCESS) Table test_data/raw/short_interests_nasdaq has 9085 rows.
WARN:root:(SUCCESS) Table test_data/raw/short_interests_nyse has 10544 rows.


## 5. Pull Stock Prices

### Code (1)

In [147]:
# Pass into `args` argument

START_DATE = config['App']['START_DATE']
QUANDL_API_KEY = config['Quandl']['API_KEY']
YESTERDAY_DATE = '2019-12-12'
LIMIT = 10
STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_STOCK_PRICES = 'test_data/raw/prices'

URL = """http://app.quotemedia.com/quotetools/getHistoryDownload.csv?&webmasterId=501&startDay={sd}&startMonth={sm}&startYear={sy}&endDay={ed}&endMonth={em}&endYear={ey}&isRanged=true&symbol={sym}"""

In [148]:
# Include on top of the ETL script

START_DAY = START_DATE.split('-')[2]
# In QuoteMedia, months start from 0, so we adjust this variable.
START_MONTH = int(START_DATE.split('-')[1]) - 1
START_YEAR = START_DATE.split('-')[0]

YST_DAY = YESTERDAY_DATE.split('-')[2]
# In QuoteMedia, months start from 0, so we adjust this variable.
YST_MONTH = int(YESTERDAY_DATE.split('-')[1]) - 1
YST_YEAR = YESTERDAY_DATE.split('-')[0]
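
Because QuoteMedia counts months from 0, every date has to be split and shifted before it goes into `URL`. A sketch that keeps this conversion in one place, so the day, month and year variables are harder to mix up; `quotemedia_date_parts` and `quotemedia_url` are hypothetical names, and `datetime.strptime` additionally rejects malformed dates early:

```python
from datetime import datetime

# URL template copied from the cell above (split over several lines).
URL = ('http://app.quotemedia.com/quotetools/getHistoryDownload.csv?&webmasterId=501'
       '&startDay={sd}&startMonth={sm}&startYear={sy}'
       '&endDay={ed}&endMonth={em}&endYear={ey}&isRanged=true&symbol={sym}')


def quotemedia_date_parts(iso_date):
    """Return (day, zero-based month, year) for a 'YYYY-MM-DD' date.

    Day and year stay zero-padded strings, as in the cell above; strptime
    raises ValueError for malformed dates instead of building a bad URL.
    """
    parsed = datetime.strptime(iso_date, '%Y-%m-%d')
    # QuoteMedia counts months from 0.
    return parsed.strftime('%d'), parsed.month - 1, parsed.strftime('%Y')


def quotemedia_url(symbol, start_date, end_date):
    """Fill the URL template for one symbol and date range."""
    sd, sm, sy = quotemedia_date_parts(start_date)
    ed, em, ey = quotemedia_date_parts(end_date)
    return URL.format(sd=sd, sm=sm, sy=sy, ed=ed, em=em, ey=ey, sym=symbol)


print(quotemedia_date_parts('2013-04-01'))  # ('01', 3, '2013')
print(quotemedia_url('SPY', '2013-04-01', '2019-12-12'))
```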

#### 5.1. Test: Get data from one source

In [18]:
%%timeit -n 1 -r 1

import timeit
start_time = timeit.default_timer()

# %%timeit -n 1 -r 1
# 8.67 s ± 1.53 s per loop (mean ± std. dev. of 3 runs, 3 loops each)

import csv

# response = requests.get(URL.format(sd=YST_DAY, sm=YST_MONTH, sy=YST_YEAR,
response = requests.get(URL.format(sd=START_DAY, sm=START_MONTH, sy=START_YEAR,
                                   ed=YST_DAY, em=YST_MONTH, ey=YST_YEAR,
                                   sym='SPY'))
# content = response.content.decode('utf-8')
# data = [{k: v for k, v in row.items()}
#         for row in csv.DictReader(content.splitlines(), skipinitialspace=True)]
# print(len(data))
# print(data[0])

elapsed = timeit.default_timer() - start_time
print("elapsed time: {}s".format(elapsed))

elapsed time: 2.2370002679999743s
2.24 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [19]:
START_MONTH

3

#### 5.2. Test: Get data from all sources

#### 5.3. Test: Get data from one source and parallelize with Spark

In [21]:
%%timeit -n 1 -r 1
# response = requests.get(URL.format(sd=YST_DAY, sm=YST_MONTH, sy=YST_YEAR,
response = requests.get(URL.format(sd=START_DAY, sm=START_MONTH, sy=START_YEAR,
                                   ed=YST_DAY, em=YST_MONTH, ey=YST_YEAR,
                                   sym='SPY'))
content = response.content.decode('utf-8')
lines = spark.sparkContext.parallelize(content.splitlines())
# spark.read.csv also accepts an RDD of CSV lines
spark.read.csv(lines, header=True) \
     .write.mode('overwrite').parquet('test/data/raw/test_prices')

3.04 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


#### 5.4. Get data from all sources in parallel, then store them (also in parallel)

```
## Writing 369 rows, 2 requests:
# 6.8 s ± 125 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
## Writing 2 rows, 2 requests:
# 5.86 s ± 133 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
## writing 6468 rows, 10 requests:
# 36.7 s ± 3.06 s per loop (mean ± std. dev. of 3 runs, 3 loops each)
## No partition, 10 requests (incorrect rows, 6477):
# 43.8 s ± 1.3 s per loop (mean ± std. dev. of 3 runs, 3 loops each)


## Using temp table, writing 367 rows, 2 requests:
# 7.81 s ± 386 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
## Using temp table, writing 2 rows, 2 requests:
# 7.05 s ± 289 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
## Using temp table, writing 9 rows, 10 requests:
# 20.3 s ± 85.8 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)
## Using temp table, writing 6495 rows, 10 requests:
# 23.5 s ± 3.82 s per loop (mean ± std. dev. of 3 runs, 3 loops each)
# Write as parquet
# 41.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```

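Conclusion: it is better to write through a temporary table and store the result as CSV. With 10 requests, the temporary-table approach took about 23.5 s versus 36.7 s without it, while writing as Parquet instead took 41.9 s (single run).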
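
Part of the reason for the temporary table is that each symbol's QuoteMedia payload starts with its own header line, so payloads cannot simply be concatenated. A plain-Python illustration (not the Spark implementation) of merging such payloads while keeping a single header; `merge_csv_payloads` is a hypothetical name and the payloads are toy data:

```python
import csv
import io


def merge_csv_payloads(payloads):
    """Merge CSV payloads that share a header, writing the header only once."""
    out = io.StringIO()
    writer = csv.writer(out, lineterminator='\n')
    header = None
    for payload in payloads:
        # Skip blank lines (csv.reader yields [] for them).
        rows = [row for row in csv.reader(io.StringIO(payload)) if row]
        if not rows:
            continue  # e.g. a failed request that returned an empty string
        if header is None:
            header = rows[0]
            writer.writerow(header)
        elif rows[0] != header:
            raise ValueError('payloads have different headers')
        writer.writerows(rows[1:])
    return out.getvalue()


payloads = ['date,close,symbol\n2019-12-13,1.0,AAA\n',
            'date,close,symbol\n2019-12-13,2.0,BBB\n']
print(''.join(payloads).count('date,close,symbol'))  # 2: naive concatenation repeats the header
print(merge_csv_payloads(payloads).count('date,close,symbol'))  # 1
```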

### Code (2)

In [80]:
%%timeit -n 1 -r 1

START_DAY = START_DATE.split('-')[2]
# In QuoteMedia, months start from 0, so we adjust this variable.
START_MONTH = int(START_DATE.split('-')[1]) - 1
START_YEAR = START_DATE.split('-')[0]

YST_DAY = YESTERDAY_DATE.split('-')[2]
# In QuoteMedia, months start from 0, so we adjust this variable.
YST_MONTH = int(YESTERDAY_DATE.split('-')[1]) - 1
YST_YEAR = YESTERDAY_DATE.split('-')[0]

create_table = not(spark_table_exists(DB_HOST, TABLE_STOCK_PRICES))

def pull_prices_by_symbol(symbol):
    """
    Return:
        list of dicts [{'colname': value, ...}, ...]
    """
    if create_table:
        # If the table does not exist, pull all data.
        url = URL.format(sd=START_DAY, sm=START_MONTH, sy=START_YEAR,
                         ed=YST_DAY, em=YST_MONTH, ey=YST_YEAR,
                         sym=symbol)
    else:
        # If the table already exists, pull only yesterday's data.
        url = URL.format(sd=YST_DAY, sm=YST_MONTH, sy=YST_YEAR,
                         ed=YST_DAY, em=YST_MONTH, ey=YST_YEAR,
                         sym=symbol)
        
    # Code for always overwrite without temp table
#     url = URL.format(sd=START_DAY, sm=START_MONTH, sy=START_YEAR,
#                      ed=YST_DAY, em=YST_MONTH, ey=YST_YEAR,
#                      sym=symbol)

    response = requests.get(url)
    newdata = ""
    if response.status_code in [200, 201]:
        newdata = response.content.decode('utf-8')
        newdata = newdata.replace('\n', ','+symbol+'\n')
        newdata = newdata.replace('tradevol,'+symbol+'\n', 'tradevol,symbol\n')
    return newdata

schema = T.StringType()
udf_pull_prices = F.udf(pull_prices_by_symbol, schema)
    
# Prepare list of stocks
if STOCKS is not None and len(STOCKS) > 0:
    rdd1 = spark.sparkContext.parallelize(STOCKS)
    row_rdd = rdd1.map(lambda x: Row(x))
    df = spark.createDataFrame(row_rdd,['Symbol'])
else:
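    # The stock info tables are stored as CSV text (see pull_stock_info), not Parquet.
    df = spark.read.csv([DB_HOST+TABLE_STOCK_INFO_NASDAQ, DB_HOST+TABLE_STOCK_INFO_NYSE],
                        header=True, ignoreLeadingWhiteSpace=True) \
         .select('Symbol').dropDuplicates()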
    if LIMIT is not None:
        df = df.limit(LIMIT)

df = df.withColumn('prices_csv', udf_pull_prices('Symbol'))

df = df.select('prices_csv').where(df['prices_csv'] != '')

table_name = DB_HOST+TABLE_STOCK_PRICES
mode = 'overwrite'
if create_table:
    logger.warn("Creating table {}".format(table_name))    
else:
    logger.warn("Appending to table {}".format(table_name))
    mode = 'append'

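# Repartitioning here is important: it aims to put each symbol's CSV payload
# (header included) in its own partition, so that each payload is written to
# its own CSV-like file. Without repartitioning, several payloads would be
# concatenated into a single CSV file, repeating the header inside it.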
tempdir = DB_HOST+TABLE_STOCK_PRICES+'-temp'
logger.warn("    Creating temporary table {}".format(tempdir))

numrows = df.count()
df \
    .repartition(numrows).write.mode('overwrite') \
    .csv(tempdir, header=False, quote=" ")


if create_table:
    logger.warn("    done! Now creating table {}".format(table_name))
else:
    logger.warn("    done! Now appending to table {}".format(table_name))

spark.read.csv(tempdir, header=True, ignoreLeadingWhiteSpace=True, inferSchema=True) \
.write.mode(mode).csv(table_name, header=True)
# .write.mode(mode).parquet(DB_HOST+TABLE_STOCK_PRICES)


logger.warn("done!")

WARN:root:Appending to table test_data/raw/prices
WARN:root:    Creating temporary table test_data/raw/prices-temp
WARN:root:    done! Now appending to table test_data/raw/prices
WARN:root:done!


10.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
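
`pull_prices_by_symbol` adds the `symbol` column by replacing every newline with `,<symbol>\n`. If a payload ever lacks a trailing newline (we have not checked whether QuoteMedia guarantees one), its last row is left untagged. A sketch contrasting that string replacement with a `csv`-module version; both functions are hypothetical illustrations, run here on toy data:

```python
import csv
import io


def tag_with_replace(payload, symbol):
    """Same string replacement as pull_prices_by_symbol."""
    tagged = payload.replace('\n', ',' + symbol + '\n')
    return tagged.replace('tradevol,' + symbol + '\n', 'tradevol,symbol\n')


def tag_with_csv(payload, symbol):
    """Append a symbol column to every non-empty row, header included."""
    reader = csv.reader(io.StringIO(payload))
    out = io.StringIO()
    writer = csv.writer(out, lineterminator='\n')
    header = next(reader, None)
    if header is None:
        return ''  # empty payload, e.g. a failed request
    writer.writerow(header + ['symbol'])
    for row in reader:
        if row:
            writer.writerow(row + [symbol])
    return out.getvalue()


payload = 'date,close,tradevol\n2019-12-13,1.0,10\n2019-12-12,2.0,20'  # toy data, no trailing newline
print(tag_with_replace(payload, 'AAA'))  # last row has no symbol
print(tag_with_csv(payload, 'AAA'))      # every row has a symbol
```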


### Test

In [46]:
pdf_temp = spark.read.csv(DB_HOST+TABLE_STOCK_PRICES+'-temp', header=True, ignoreLeadingWhiteSpace=True, inferSchema=True).toPandas()
print(pdf_temp.info())
pdf_temp.head(5)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10458 entries, 0 to 10457
Data columns (total 12 columns):
date        10455 non-null datetime64[ns]
open        10455 non-null object
high        10455 non-null object
low         10455 non-null object
close       10455 non-null float64
volume      10455 non-null float64
changed     10455 non-null float64
changep     10455 non-null object
adjclose    10455 non-null float64
tradeval    10455 non-null object
tradevol    10455 non-null float64
symbol      10455 non-null object
dtypes: datetime64[ns](1), float64(5), object(6)
memory usage: 980.5+ KB
None


| | date | open | high | low | close | volume | changed | changep | adjclose | tradeval | tradevol | symbol |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2019-12-13 | 196.4 | 196.8 | 193.17 | 194.11 | 18806020.0 | -2.64 | -1.34% | 194.11 | 3657544992.09 | 192331.0 | FB |
| 1 | 2019-12-12 | 202.35 | 203.66 | 194.1 | 196.75 | 23766986.0 | -5.51 | -2.72% | 196.75 | 4710540903.58 | 208246.0 | FB |
| 2 | 2019-12-11 | 200.28 | 202.63 | 200.28 | 202.26 | 8041827.0 | 1.39 | 0.69% | 202.26 | 1622090474.6 | 77975.0 | FB |
| 3 | 2019-12-10 | 201.66 | 202.05 | 200.15 | 200.87 | 9485568.0 | -0.47 | -0.23% | 200.87 | 1905136360.62 | 88428.0 | FB |
| 4 | 2019-12-09 | 200.65 | 203.1418 | 200.21 | 201.34 | 12013218.0 | 0.29 | 0.14% | 201.34 | 2427805092.89 | 102640.0 | FB |


In [22]:
pdf = spark.read.csv(DB_HOST+TABLE_STOCK_PRICES, header=True, inferSchema=True) \
    .dropDuplicates(['date', 'symbol']).toPandas()
print(pdf.info())
pdf.head(5)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10456 entries, 0 to 10455
Data columns (total 12 columns):
date        10455 non-null datetime64[ns]
open        10455 non-null object
high        10455 non-null object
low         10455 non-null object
close       10455 non-null float64
volume      10455 non-null float64
changed     10455 non-null float64
changep     10455 non-null object
adjclose    10455 non-null float64
tradeval    10455 non-null object
tradevol    10455 non-null float64
symbol      10455 non-null object
dtypes: datetime64[ns](1), float64(5), object(6)
memory usage: 980.3+ KB
None


| | date | open | high | low | close | volume | changed | changep | adjclose | tradeval | tradevol | symbol |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2013-04-19 | 25.62 | 25.96 | 25.33 | 25.73 | 20353547.0 | 0.04 | 0.16% | 25.73 | 523595607.18 | 66068.0 | FB |
| 1 | 2013-04-22 | 99.35 | 99.66 | 98.38 | 99.32 | 5609009.0 | -0.493 | -0.60% | 81.59 | 555207178.14 | 32582.0 | MCD |
| 2 | 2013-04-26 | 100.83 | 100.99 | 100.4 | 100.89 | 3113532.0 | -0.041 | -0.05% | 82.8798 | 313733616.64 | 16412.0 | MCD |
| 3 | 2013-05-13 | 26.6 | 27.325 | 26.531 | 26.82 | 29009648.0 | 0.14 | 0.52% | 26.82 | 782579320.46 | 90803.0 | FB |
| 4 | 2013-07-10 | 291.41 | 293.34 | 289.4 | 292.33 | 1822877.0 | 0.8 | 0.27% | 292.33 | 531832345.31 | 11812.0 | AMZN |


In [23]:
pdf.sort_values(by=['date', 'symbol'], ascending=True).head(5)

| | date | open | high | low | close | volume | changed | changep | adjclose | tradeval | tradevol | symbol |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5227 | 2013-04-02 | 262.4 | 265.89 | 260.55 | 263.322 | 2631038.0 | 1.712 | 0.65% | 263.322 | 693325169.2 | 17760.0 | AMZN |
| 1986 | 2013-04-02 | 25.77 | 26.12 | 25.3 | 25.42 | 35124893.0 | -0.11 | -0.43% | 25.42 | 904220577.12 | 107077.0 | FB |
| 2041 | 2013-04-02 | 99.4 | 100.42 | 99.025 | 100.26 | 5136501.0 | 0.994 | 1.22% | 82.3622 | 513202156.02 | 24959.0 | MCD |
| 9771 | 2013-04-02 | 183.9 | 185.1799 | 176.1 | 176.69 | 4610979.0 | -0.82 | -3.15% | 25.2414 | 828031668.71 | 27837.0 | NFLX |
| 5432 | 2013-04-02 | 43.6 | 45.5 | 43.5101 | 44.34 | 6621439.0 | 0.41 | 0.93% | 44.34 | 294906438.13 | 28077.0 | TSLA |


### Quality-check

In [58]:
# Pass into `args` argument

STOCKS = ['FB', 'GOOG', 'AMZN', 'TRMT', 'TSLA', 'MCD', 'NFLX']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_INFO_NASDAQ = 'test_data/raw/stock_info_nasdaq'
TABLE_STOCK_INFO_NYSE = 'test_data/raw/stock_info_nyse'
TABLE_STOCK_PRICES = 'test_data/raw/prices'

In [62]:
if STOCKS is None or len(STOCKS) == 0:
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NASDAQ, table_type='csv')
    check_basic_quality(logger, DB_HOST, TABLE_STOCK_INFO_NYSE, table_type='csv')
check_basic_quality(logger, DB_HOST, TABLE_STOCK_PRICES, table_type='csv')

WARN:root:(SUCCESS) Table test_data/raw/prices has 10458 rows.


## 6. Combine Datasets

### Code

In [14]:
# Pass into `args` argument

YESTERDAY_DATE = '2019-12-12'
AWS_ACCESS_KEY_ID = config['AWS']['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = config['AWS']['AWS_SECRET_ACCESS_KEY']
DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_STOCK_PRICES = 'test_data/raw/prices'
TABLE_SHORT_INTERESTS_NASDAQ = 'test_data/raw/short_interests_nasdaq' 
TABLE_SHORT_INTERESTS_NYSE = 'test_data/raw/short_interests_nyse'
TABLE_SHORT_ANALYSIS = 'test_data/processed/short_analysis'

In [17]:
create_table = not(spark_table_exists(DB_HOST, TABLE_SHORT_ANALYSIS))

sdf_shorts = spark.read.parquet(DB_HOST+TABLE_SHORT_INTERESTS_NASDAQ, DB_HOST+TABLE_SHORT_INTERESTS_NYSE)
sdf_shorts = sdf_shorts.groupby(['Date', 'Symbol']) \
                 .agg(F.sum(sdf_shorts['ShortExemptVolume']).alias('short_exempt_volume'),
                      F.sum(sdf_shorts['ShortVolume']).alias('short_volume'),
                      F.sum(sdf_shorts['TotalVolume']).alias('total_volume'),
                      F.first(sdf_shorts['SourceURL']).alias('source_url')
                     )
sdf_prices = spark.read.csv(DB_HOST+TABLE_STOCK_PRICES, header=True, inferSchema=True) \
             .dropDuplicates(['date', 'symbol'])
sdf_prices = sdf_prices.withColumn('date', sdf_prices['date'].cast(T.DateType()))

if not create_table:
    sdf_shorts = sdf_shorts.filter(sdf_shorts['Date'] >= F.to_date(F.lit(YESTERDAY_DATE)))
    sdf_prices = sdf_prices.filter(sdf_prices['date'] >= F.to_date(F.lit(YESTERDAY_DATE)))

sdf_short_analysis = sdf_shorts.join(sdf_prices, (sdf_shorts['Date'] == sdf_prices['date']) & \
                                     (sdf_shorts['Symbol'] == sdf_prices['symbol']), how='inner') \
                               .drop(sdf_shorts['Date']).drop(sdf_shorts['Symbol'])

mode = 'overwrite'
if not create_table:
    logger.warn("Appending to table {}".format(DB_HOST+TABLE_SHORT_ANALYSIS))
    mode = 'append'
else:
    logger.warn("Creating table {}".format(DB_HOST+TABLE_SHORT_ANALYSIS))

sdf_short_analysis.write.mode(mode).parquet(DB_HOST+TABLE_SHORT_ANALYSIS)
logger.warn("done!")

WARN:root:Creating table test_data/processed/short_analysis
WARN:root:done!
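
The combine step sums each symbol's NASDAQ and NYSE short volumes per day and then inner-joins them with the prices, so days missing from either side are dropped. A simplified plain-Python sketch of the same logic on toy rows (not real market data); `combine` is a hypothetical function and keeps only a few of the columns:

```python
from collections import defaultdict


def combine(short_rows, price_rows):
    """Sum short volumes per (date, symbol), then inner-join with prices."""
    totals = defaultdict(lambda: {'short_volume': 0.0, 'total_volume': 0.0})
    for row in short_rows:  # one row per exchange, date and symbol
        key = (row['Date'], row['Symbol'])
        totals[key]['short_volume'] += row['ShortVolume']
        totals[key]['total_volume'] += row['TotalVolume']

    combined = []
    for price in price_rows:
        key = (price['date'], price['symbol'])
        if key in totals:  # inner join: prices without short data are dropped
            combined.append({**price, **totals[key]})
    return combined


shorts = [  # toy values
    {'Date': '2019-12-12', 'Symbol': 'AAA', 'ShortVolume': 100.0, 'TotalVolume': 200.0},  # NASDAQ
    {'Date': '2019-12-12', 'Symbol': 'AAA', 'ShortVolume': 50.0, 'TotalVolume': 100.0},   # NYSE
    {'Date': '2019-12-12', 'Symbol': 'BBB', 'ShortVolume': 10.0, 'TotalVolume': 40.0},    # no price row
]
prices = [
    {'date': '2019-12-12', 'symbol': 'AAA', 'close': 1.0},
    {'date': '2019-12-12', 'symbol': 'CCC', 'close': 2.0},  # no short row
]
print(combine(shorts, prices))
# [{'date': '2019-12-12', 'symbol': 'AAA', 'close': 1.0, 'short_volume': 150.0, 'total_volume': 300.0}]
```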


### Test

In [18]:
df_short = sdf_shorts.toPandas()
df_prices = sdf_prices.toPandas()
df_short_analysis = sdf_short_analysis.toPandas()

print(df_short.info())
print(df_prices.info())
print(df_short_analysis.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10796 entries, 0 to 10795
Data columns (total 6 columns):
Date                   10796 non-null object
Symbol                 10796 non-null object
short_exempt_volume    10796 non-null float64
short_volume           10796 non-null float64
total_volume           10796 non-null float64
source_url             10796 non-null object
dtypes: float64(3), object(3)
memory usage: 506.1+ KB
None
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10456 entries, 0 to 10455
Data columns (total 12 columns):
date        10455 non-null object
open        10455 non-null object
high        10455 non-null object
low         10455 non-null object
close       10455 non-null float64
volume      10455 non-null float64
changed     10455 non-null float64
changep     10455 non-null object
adjclose    10455 non-null float64
tradeval    10455 non-null object
tradevol    10455 non-null float64
symbol      10455 non-null object
dtypes: float64(5), object(7)
memory u
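The output above shows 10,796 aggregated short-interest rows against 10,456 price rows, so the inner join necessarily discards some keys. A minimal sketch for seeing what falls out, assuming the join keys have been collected as `(date, symbol)` tuples; `unmatched_keys` is a hypothetical helper, not part of the pipeline. In Spark itself, the same count can be obtained with a `how='left_anti'` join on the same condition followed by `.count()`, which avoids collecting keys to the driver.

```python
def unmatched_keys(left_keys, right_keys):
    """Summarise which join keys survive an inner join (hypothetical helper).

    Keys are hashable tuples such as (date, symbol).
    """
    left, right = set(left_keys), set(right_keys)
    return {
        'matched': len(left & right),        # keys present on both sides
        'only_left': sorted(left - right),   # e.g. short interests with no price
        'only_right': sorted(right - left),  # e.g. prices with no short interest
    }
```

For a small test table, the keys could be gathered with something like `[tuple(r) for r in sdf_shorts.select('Date', 'Symbol').collect()]`.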

In [19]:
spark.read.parquet(DB_HOST+TABLE_SHORT_ANALYSIS).toPandas().head(5)

| | short_exempt_volume | short_volume | total_volume | source_url | date | open | high | low | close | volume | changed | changep | adjclose | tradeval | tradevol | symbol |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 418961.0 | 753385.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... | 2013-05-01 | 215.92 | 217.389 | 211.65 | 212.91 | 2622654 | -0.451 | -1.46% | 30.4157 | 561027949.42 | 14673 | NFLX |
| 1 | 37360.0 | 6060485.0 | 20502608.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... | 2013-05-06 | 28.33 | 28.46 | 27.48 | 27.57 | 43862625 | -0.741 | -2.62% | 27.57 | 1219158766.94 | 120016 | FB |
| 2 | 0.0 | 635427.0 | 1388246.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... | 2013-05-06 | 209.63 | 212.45 | 204.02 | 210.69 | 4532918 | -0.394 | -1.29% | 30.0985 | 949314738.05 | 26760 | NFLX |
| 3 | 300.0 | 451077.0 | 952404.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... | 2013-05-09 | 258.73 | 263.55 | 256.88 | 260.16 | 2769255 | 1.48 | 0.57% | 260.16 | 723085024.78 | 17822 | AMZN |
| 4 | 0.0 | 342728.0 | 1335921.0 | https://www.quandl.com/api/v3/datasets/FINRA/F... | 2013-06-28 | 276.19 | 279.83 | 276.19 | 277.69 | 3193262 | 0.14 | 0.05% | 277.69 | 889020787.27 | 14332 | AMZN |


### Quality check

In [63]:
# Pass into `args` argument

DB_HOST = ''

# Table names: update to add '/' in the final code.
TABLE_SHORT_ANALYSIS = 'test_data/processed/short_analysis'

In [64]:
check_basic_quality(logger, DB_HOST, TABLE_SHORT_ANALYSIS)

WARN:root:(SUCCESS) Table test_data/processed/short_analysis has 10394 rows.
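Judging by its output, `check_basic_quality` reports a row count. The sketch below adds two row-level checks one would expect this table to pass: each `(date, symbol)` pair appears only once (prices are deduplicated and short interests grouped on those columns, though an incremental run that is repeated would append the same dates again), and `short_volume` never exceeds `total_volume`. The function `check_rows` is hypothetical and not one of the notebook's helpers; it takes rows as plain dicts. In Spark, the same checks could be written as `sdf.groupBy('date', 'symbol').count().filter(F.col('count') > 1)` and `sdf.filter(sdf['short_volume'] > sdf['total_volume'])`, each expected to return no rows.

```python
def check_rows(rows):
    """Return a list of problems found in short-analysis rows (hypothetical check).

    Assumed invariants, not taken from check_basic_quality:
    - each (date, symbol) pair appears at most once;
    - short_volume does not exceed total_volume.
    """
    problems = []
    seen = set()
    for i, row in enumerate(rows):
        key = (row['date'], row['symbol'])
        if key in seen:
            problems.append(f"row {i}: duplicate key {key}")
        seen.add(key)
        if row['short_volume'] > row['total_volume']:
            problems.append(f"row {i}: short_volume > total_volume")
    return problems
```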


## Final Dataset

In [9]:
sdf = spark.read.parquet(config['App']['DB_HOST']+config['App']['TABLE_SHORT_ANALYSIS'])

In [11]:
df = sdf.toPandas()

In [12]:
df.describe()

| | short_exempt_volume | short_volume | total_volume | close | volume | changed | adjclose | tradevol |
|---|---|---|---|---|---|---|---|---|
| count | 10548.0 | 10548.0 | 10548.0 | 10548.0 | 10548.0 | 10548.0 | 10548.0 | 10548.0 |
| mean | 11864.37 | 1403028.0 | 3208060.0 | 394.143509 | 9086290.0 | 0.389712 | 374.410873 | 66662.82 |
| std | 46190.0 | 2289725.0 | 5337051.0 | 422.500286 | 14486120.0 | 11.274337 | 429.671097 | 67526.94 |
| min | 0.0 | 1.0 | 1.0 | 3.94 | 53.0 | -139.36 | 3.6124 | 0.0 |
| 25% | 500.0 | 277236.8 | 695650.0 | 120.6275 | 2423723.0 | -1.401 | 99.175 | 26285.25 |
| 50% | 2183.0 | 632995.5 | 1455506.0 | 221.885 | 4423890.0 | 0.08 | 196.5202 | 42934.0 |
| 75% | 8187.75 | 1543929.0 | 3499728.0 | 496.1725 | 9444361.0 | 2.08925 | 428.165 | 87239.25 |
| max | 1653923.0 | 48877400.0 | 120162500.0 | 2039.51 | 365380600.0 | 558.46 | 2039.51 | 1312878.0 |


In [13]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10548 entries, 0 to 10547
Data columns (total 14 columns):
short_exempt_volume    10548 non-null float64
short_volume           10548 non-null float64
total_volume           10548 non-null float64
source_url             10548 non-null object
open                   10548 non-null object
high                   10548 non-null object
low                    10548 non-null object
close                  10548 non-null float64
volume                 10548 non-null int32
changed                10548 non-null float64
changep                10548 non-null object
adjclose               10548 non-null float64
tradeval               10548 non-null object
tradevol               10548 non-null int32
dtypes: float64(6), int32(2), object(6)
memory usage: 1.0+ MB


## References

- Apache Airflow tips and best practices: https://towardsdatascience.com/apache-airflow-tips-and-best-practices-ff64ce92ef8
- Write/store a DataFrame as a text file: https://stackoverflow.com/questions/44537889/write-store-dataframe-in-text-file
- Delete an HDFS path: https://stackoverflow.com/a/55952480/278191
- Delete an Amazon S3 path from Apache Spark: http://bigdatatech.taleia.software/2015/12/28/deleting-a-amazon-s3-path-from-apache-spark/
- Run a custom Java class in PySpark: https://stackoverflow.com/questions/33544105/running-custom-java-class-in-pyspark