In [None]:
# system/os/regex and basic math functions
import os
import re
import sys
import math
import json
import time
import string
import dateutil
from pathlib import Path
import datetime as dt
from itertools import chain

In [None]:
class adict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    Example::
        d = adict(a=1)
        d.b = 2      # same as d['b'] = 2
        d.a          # -> 1

    NOTE: the instance references itself (its ``__dict__`` is the instance), so it
    forms a reference cycle and is reclaimed by the cycle collector.
    """

    def __init__(self, *args, **kwargs):
        super(adict, self).__init__(*args, **kwargs)
        # Make the attribute namespace *be* the mapping, so attribute access and
        # item access share the same storage.
        self.__dict__ = self

In [None]:
# Set logging level and format.
# To override, define these before running this file:
#   LOG_LEVEL  - a level name such as "INFO" / "debug", or a level number such as logging.INFO
#   LOG_FORMAT - a logging format string
import logging
try:
    _delete_me = {'level': LOG_LEVEL if isinstance(LOG_LEVEL, int)
                  else getattr(logging, str(LOG_LEVEL).upper())}
except NameError:
    _delete_me = {'level':logging.WARNING}
    print('Set LOG_LEVEL="INFO" before running the import file to get moar output.')
try:
    _delete_me['format'] = LOG_FORMAT
except NameError:
    _delete_me['format'] = "%(levelname)s::%(message)s"
    print('Set LOG_FORMAT to change log format.')

logging.basicConfig(**_delete_me)
logger = logging.getLogger('notebook')
del _delete_me

import warnings
# Keep FutureWarnings out of the notebook output.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# IPython display convenience stuff
try:
    from IPython.display import HTML, display, display_html, display_javascript
    from IPython import __version__ as ipythonversion
    import ipywidgets
    print("IPython: {}".format(ipythonversion))
except ImportError:
    pass

In [None]:
try:
    # numpy for matrix algebra
    import numpy as np
    os.environ['NUMEXPR_MAX_THREADS'] = '20'
    print("Numpy (np): {}".format(np.version.full_version))
except ImportError:
    pass

In [None]:
try:
    # scipy for probability distributions and some statistical tests
    import scipy as sp
    import scipy.stats as stats
    print("Scipy (sp, stats): {}".format(sp.version.full_version))
except ImportError:
    pass

In [None]:
try:
    # pandas for data manipulation
    import pandas as pd
    print("Pandas (pd): {}".format(pd.__version__))

    def fmt_float(float_in, rstrip0s=re.compile(r'\.0+$')):
        """Float formatter for pandas display.

        Uses thousands separators. Shows 3 decimals when |float_in| < 1000 and none
        otherwise. A fractional part made only of zeros is dropped ('1.000' -> '1'),
        but other trailing zeros stay ('1.500').

        Any formatting error falls back to ``str(float_in)``.
        """
        try:
            decimals = 0 if abs(float_in) // 1000 else 3
            formatted = '{0:,.{1}f}'.format(float_in, decimals)
            return rstrip0s.sub('', formatted)
        except Exception:
            return str(float_in)
    # Display options. Use fully-qualified option names: pandas resolves partial
    # names by pattern and raises OptionError once a partial name becomes ambiguous.
    pd.set_option('display.float_format', fmt_float)
    pd.set_option('display.max_rows', 250)
    pd.set_option('display.max_columns', 250)
    pd.set_option('display.notebook_repr_html', True)
    try:
        # Not available in every pandas version; ignore it where it is unknown.
        pd.set_option('mode.nullable_dtypes', True)
    except Exception:
        pass

    def latex_format(num_in):
        """Format a number for a LaTeX table.

        Rules:
        - Zero gives "0".
        - |x| >= 100 gives an integer with thousands separators. This truncates:
          ``int()``, not rounding.
        - 1 <= |x| < 100 gives one decimal.
        - Smaller magnitudes give three decimals.
        - Inputs that cannot be converted with ``float()``, and non-finite values
          (inf/nan), are returned as ``str()`` of the value.
        """
        try:
            num_in = float(num_in)
        except (TypeError, ValueError):
            return str(num_in)
        # Check zero (and inf/nan) before taking the log: log10(0) is -inf, and
        # int(inf) would raise OverflowError.
        if num_in == 0:
            return "0"
        if not math.isfinite(num_in):
            return str(num_in)
        num_dig = math.log10(abs(num_in)) + 1
        if num_dig >= 3:
            return f"{int(num_in):,d}"
        if num_dig >= 1:
            return f"{num_in:2.1f}"
        return f"{num_in:1.3f}"

    def S(df, cols=None, keep_dups=False):
        """S splits strings, and if called with a df input, expands wildcard column names.

        ``cols`` (or ``df`` itself when it is a string) is split on whitespace.

        Tokens containing ``*`` (any run of characters) or ``?`` (one character) are
        glob patterns. They are matched case-insensitively against ``df.columns``, in
        column order, and each match must begin and end at a word boundary. All other
        characters in a pattern match literally.

        Duplicates are dropped (first occurrence kept) unless ``keep_dups`` is true.
        If ``cols`` is already a list, it is only de-duplicated.

        Example::
            S('gvkey datadate') # --> ['gvkey', 'datadate']
            df.S('gvk* datad* num*') # --> ['gvkey', 'datadate', 'num_words', 'num_sentences']
        """
        if isinstance(df, str):
            cols = df
        if isinstance(cols, str):
            expanded = []
            for token in cols.split():
                if '*' in token or '?' in token:
                    # Escape everything so e.g. '.' or '(' match literally, then turn the
                    # escaped glob wildcards back into their regex equivalents.
                    pattern = re.escape(token).replace(r'\*', '.*').replace(r'\?', '.')
                    matcher = re.compile(r'\b' + pattern + r'\b', re.I)
                    # str(): column labels need not be strings.
                    expanded.extend(c for c in df.columns if matcher.search(str(c)))
                else:
                    expanded.append(token)
            cols = expanded
        return cols if keep_dups else list(dict.fromkeys(cols))

    # monkeypatch S into DataFrame
    pd.DataFrame.S = S

    def hugetable(df, soft_max=5000, hard_max=100_000):
        """Display ``df`` as HTML with pandas' row/column display limits temporarily raised.

        Both ``display.max_rows`` and ``display.max_columns`` are set to
        ``min(soft_max, hard_max)`` while rendering. The previous values are restored
        afterwards, even if rendering raises.
        """
        limit = min(soft_max, hard_max)
        with pd.option_context('display.max_rows', limit, 'display.max_columns', limit):
            display_html(df)
except (ImportError, ModuleNotFoundError):
    pass

In [None]:
try:
    # matplotlib for plotting and pyplot for MATLAB-style API
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter
    print("MatPlotLib (mpl, plt): {}".format(mpl.__version__))
except ImportError:
    pass

In [None]:
try:
    # Seaborn for pretty plotting
    import seaborn as sns
    print("Seaborn (sns): {}".format(sns.__version__))
except ImportError:
    pass

In [None]:
try:
    # Scikit Learn for more regressions
    import sklearn as sk
    print("Scikit-Learn (sk): {}".format(sk.__version__))
except ImportError:
    pass

In [None]:
try:
    # statsmodels for econometrics
    import statsmodels.api as sm
    import statsmodels.formula.api as smf
    print("Statsmodels (sm): {}".format(sm.__version__))
except (ImportError, AttributeError):
    pass

In [None]:
try:
    # patsy for making formulas
    import patsy as pt
    print("Patsy (pt): {}".format(pt.__version__))
except ImportError:
    pass

In [None]:
try:
    # SQLAlchemy for relational db management
    import sqlalchemy as sa
    print("SQLAlchemy (sa): {}".format(sa.__version__))
except ImportError:
    pass

In [None]:
try:
    # Gensim for textual analysis
    import gensim
    print("Gensim: {}".format(gensim.__version__))
except ImportError:
    pass

In [None]:
try:
    # TQDM for progress bar outputs. tqdm.auto uses the notebook widget bar when it
    # is usable and falls back to the plain console bar otherwise.
    from tqdm.auto import tqdm
except ImportError:
    def tqdm(thing=None, *args, **kwargs):
        """No-op stand-in used when tqdm is not installed: return the iterable unchanged.

        Accepts the iterable positionally or as ``iterable=`` (tqdm's own keyword).
        All other arguments are ignored.
        """
        if thing is None:
            thing = kwargs.get('iterable')
        return thing

In [None]:
try:
    from bs4 import BeautifulSoup
except ImportError:
    pass

In [None]:
try:
    from pyedgar.utilities import edgarweb
except (ImportError, ModuleNotFoundError):
    # pyedgar is optional: substitute an object whose edgar_links() accepts any
    # arguments and returns an empty string, so callers keep working without it.
    class _o_(object):
        @staticmethod
        def edgar_links(*args, **kwargs):
            return ''

    edgarweb = _o_()

In [None]:
# Earliest/latest dates and common time deltas.
MIN_DATE, MAX_DATE = dt.datetime(1900, 1, 1), dt.datetime(2030, 1, 1)

TD_DAY = pd.Timedelta(days=1)
TD_YEAR = 365 * TD_DAY  # fixed 365-day year (leap days ignored)

In [None]:
# print("linkhead(df, n=5, title='', fields=None, cik='cik', accession='accession')")
def linkhead(df, n=5, title='', fields=None, cik='cik', accession='accession', return_df=False):
    """
    Display the first ``n`` rows of ``df`` as HTML, with an EDGAR links column when possible.

    A links column is added when both of these hold:
    - the ``cik`` column is among the shown columns;
    - ``df`` has an ``accession`` column (it need not be among ``fields``).

    The column is named 'links', or 'links0', 'links1', ... if that name is taken. It
    holds ``edgarweb.edgar_links(cik, accession)`` for each row. The table is rendered
    with ``escape=False`` so that link markup is kept, and with no column-width
    truncation.

    Parameters
    ----------
    df : pandas.DataFrame
    n : int
        Number of leading rows to show.
    title : str
        Optional heading rendered as ``<h4>`` above the table.
    fields : sequence of column labels, optional
        Columns to show; all columns when None or empty.
    cik, accession : str
        Names of the CIK and accession-number columns used to build links.
    return_df : bool
        If true, return the frame instead of displaying it.

    Returns
    -------
    pandas.DataFrame or None
        The head (with links column, if added) when ``return_df``. When ``df`` is
        empty, ``df`` itself is returned.
    """
    # Avoid ``fields or ...``: the truth value of a pandas Index is ambiguous.
    columns = df.columns if fields is None or len(fields) == 0 else fields

    if len(df) == 0:
        if return_df:
            return df
        display(df[columns])
        return

    head = df.head(n)
    dfn = head[columns].copy()

    if cik in dfn.columns and accession in head.columns:
        linkstr, i = 'links', 0
        while linkstr in dfn.columns:
            linkstr = 'links%d' % i
            i += 1
        dfn[linkstr] = [edgarweb.edgar_links(c, a) for c, a in zip(head[cik], head[accession])]

    html = f"<h4>{title}</h4>" if title else ''
    # option_context restores the previous max_colwidth even if rendering raises.
    with pd.option_context('display.max_colwidth', None):
        html += dfn.to_html(escape=False, index=False, na_rep="")
        if not return_df:
            display_html(html, raw=True)

    if return_df:
        return dfn

In [None]:
# print("timehist(dtseries_or_df, time_variable='year', y_tic_number=4, x_tic_skip=0, *args, **kwargs)")
def timehist(dtseries_or_df, time_variable='year',
             y_tic_number=4, x_tic_skip=0,
             width=.9, ax=None, skip_retick=None,
             label=None,
             *args, **kwargs):
    """
    Histogram (bar chart) of the number of observations per time period.

    The values to count come from the first of these that works:
      1. ``getattr(dtseries_or_df.dt, time_variable)`` (e.g. ``.dt.year``);
      2. ``dtseries_or_df[time_variable]``;
      3. ``dtseries_or_df`` itself.
    Their ``value_counts()``, sorted by index, are drawn with ``ax.bar``.

    Parameters
    ----------
    time_variable : str
        ``.dt`` attribute or column name to count by.
    y_tic_number : int
        Number of y ticks.
    x_tic_skip : int
        Label every ``x_tic_skip + 1``-th period on the x axis.
    width : float
        Bar width.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a figure is created with ``plt.figure(*args, **kwargs)``
        (``figsize`` defaults to (13, 2)). When given, ``**kwargs`` go to ``ax.bar``.
    skip_retick : bool, optional
        If true, skip seaborn styling and x/y re-ticking. Defaults to True when ``ax``
        is given and False otherwise.
    label : str, optional
        Legend label for the bars.

    Returns
    -------
    The Axes drawn on. Returns ``ax`` unchanged (possibly None) when there is nothing
    to plot, and None when there are more than 1000 distinct periods.
    """
    if skip_retick is None:
        skip_retick = ax is not None

    x_tic_skip += 1

    if not skip_retick:
        sns.set_style('darkgrid')
        sns.set_context('talk', rc={'patch.linewidth': 0, 'patch.edgecolor': 'k', 'patch.facecolor': 'k'})

    _d = dtseries_or_df
    try:
        _d = getattr(_d.dt, time_variable)
    except Exception:
        # Not datetime-like (or no such attribute): try a column lookup instead.
        try:
            _d = _d[time_variable]
        except Exception:
            pass
    _g = _d.value_counts().sort_index()
    if len(_g) > 1000:
        logger.error("ERROR: You are trying to plot something with too many levels. Don't do that.")
        return
    if _g.empty:
        # The axis-limit and tick arithmetic below would fail on NaN.
        logger.warning("timehist: no observations to plot.")
        return ax

    if ax is None:
        if 'figsize' not in kwargs:
            kwargs['figsize'] = (13,2)
        plt.figure(*args, **kwargs)
        ax = plt.gca()
        # kwargs were consumed by plt.figure; don't pass them on to ax.bar.
        kwargs = {}

    ax.bar(_g.index, _g, width=width, label=label, **kwargs)

    if not skip_retick:
        # Format and label X axis
        ax.set_xlim(left=_g.index.min()-0.5, right=_g.index.max()+0.5)
        _t = _g.index[::x_tic_skip]
        ax.set_xticks(_t)
        ax.set_xticklabels(map(str, _t), rotation='vertical')

        # Y ticks: y_tic_number evenly spaced values below a rounded-up maximum.
        tene = math.log10(_g.max())//1-1
        topnum = math.ceil(_g.max() / 10**tene)
        ax.set_yticks([(topnum * i // y_tic_number)*10**tene for i in range(y_tic_number, 0, -1)])

    return ax

In [None]:
# print("timeqtrhist(dtseries_or_df, dt_variable='datadate', y_tic_number=4, x_tic_skip=0, *args, **kwargs)")
def timeqtrhist(dtseries_or_df, dt_variable='datadate',
                y_tic_number=4, x_tic_skip=0,
                width=.23, ax=None, skip_retick=False,
                label=None,
                *args, **kwargs):
    """
    Histogram (bar chart) of the number of observations per calendar quarter.

    Uses ``dtseries_or_df[dt_variable]`` if that lookup works, and otherwise
    ``dtseries_or_df`` itself; the values must be datetime-like. Each quarter is drawn
    at x = year + quarter/4. A second bar series holding only fourth-quarter counts
    (carrying ``label``) is drawn on top so that year-ends stand out.

    Parameters
    ----------
    dt_variable : str
        Column holding the dates when a DataFrame is passed.
    y_tic_number : int
        Number of y ticks.
    x_tic_skip : int
        Major x ticks are placed every ``x_tic_skip + 1`` years.
    width : float
        Bar width (one quarter spans 0.25 on the x axis).
    ax : matplotlib Axes, optional
        Axes to draw on. If None, a figure is created with ``plt.figure(*args, **kwargs)``
        (``figsize`` defaults to (13, 2)).
    skip_retick : bool
        If true, leave x/y limits and ticks alone.
    label : str, optional
        Legend label for the fourth-quarter bars.

    Returns
    -------
    The Axes drawn on. Returns None if the input is not datetime-like, and ``ax``
    unchanged (possibly None) if there is nothing to plot.
    """
    x_tic_skip += 1
    sns.set_style('darkgrid')
    sns.set_context('talk', rc={'patch.linewidth': 0, 'patch.edgecolor': 'k', 'patch.facecolor': 'k'})
    _d = dtseries_or_df
    try:
        _d = _d[dt_variable].copy()
    except Exception:
        # Not a frame with that column: assume a date-time series was passed directly.
        pass
    try:
        _q = _d.dt.to_period('Q')
        _g = _q.value_counts().sort_index().reset_index()
        _g.columns = 'date nobs'.split()
        _g['xloc'] = _g.date.dt.year + _g.date.dt.quarter/4
        _g4 = _g.copy()
        _g4.loc[_g4.date.dt.quarter!=4, 'nobs'] = 0
    except Exception:
        logger.error("ERROR: Need to pass in a date-time series.")
        return

    if _g.empty:
        # The axis-limit and tick arithmetic below would fail on NaN.
        logger.warning("timeqtrhist: no observations to plot.")
        return ax

    if ax is None:
        if 'figsize' not in kwargs:
            kwargs['figsize'] = (13,2)
        plt.figure(*args, **kwargs)
        ax = plt.gca()

    # All quarters, then the Q4-only series (other quarters zeroed) drawn over them.
    ax.bar(_g.xloc,
           _g.nobs,
           width=width, label=None)
    ax.bar(_g4.xloc,
           _g4.nobs,
           width=width, label=label)

    if not skip_retick:
        ax.set_xlim(left=_g.xloc.min()-0.25, right=_g.xloc.max()+0.25)
        # Major ticks every x_tic_skip years, labelled as whole numbers and rotated.
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(x_tic_skip))
        ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.0f'))
        plt.setp(ax.get_xticklabels(), rotation=30)

        # Unlabelled minor tick at every quarter (default NullFormatter).
        ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(.25))

        # Y ticks: y_tic_number evenly spaced values below a rounded-up maximum.
        tene = math.log10(_g.nobs.max())//1-1
        topnum = math.ceil(_g.nobs.max() / 10**tene)
        ax.set_yticks([(topnum * i // y_tic_number)*10**tene for i in range(y_tic_number, 0, -1)])

    return ax

In [None]:
def savefig(file_name, make_name_unique=False, root_dir=None, **kwargs):
    """
    Save the current matplotlib figure under ``root_dir`` (default: the working directory).

    If ``file_name`` has no extension, '.png' is appended. PNGs are written at 300 DPI
    unless ``dpi`` is given.

    If ``make_name_unique`` is true, a ``_YYYY-MM-DD_HH-MM-SS`` timestamp (local time)
    is inserted before the extension.

    Remaining keyword arguments go to ``plt.savefig``. They override the defaults
    ``bbox_inches='tight'``, ``pad_inches=0.1`` and ``transparent=True``.
    Missing parent directories are created.

    Returns the path of the file that was written.
    """
    _ext = os.path.splitext(file_name)[1]
    if not _ext:
        _ext = '.png'
        file_name = f"{file_name}{_ext}"

    file_path = Path(root_dir or '.') / file_name

    if make_name_unique:
        # Use datetime (a date has no time, so H-M-S would always be 00-00-00).
        # Path has no ``with_extension``; rebuild the name with ``with_name``.
        stamp = f"{dt.datetime.now():%Y-%m-%d_%H-%M-%S}"
        file_path = file_path.with_name(f"{file_path.stem}_{stamp}{_ext}")

    default_kwargs = {
        'bbox_inches': 'tight',
        'pad_inches': 0.1,
        'transparent': True
    }
    if 'dpi' not in kwargs and _ext.lower() == '.png':
        default_kwargs['dpi'] = 300

    file_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(file_path, **{**default_kwargs, **kwargs})

    return file_path

In [None]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
def mbt_string_fmt(x, prefix='', suffix="", scale=1e6, decimals=0, fmt="{l_paren}{prefix}{x:,.{decimals}f}{mbt}{suffix}{r_paren}", zero_fmt="{prefix}0", **kwargs):
    """
    Format a number compactly with a magnitude suffix ('', K, M, B, T).

    How the value is built:
    - ``abs(x)`` is multiplied by ``scale``. The default 1e6 means ``x`` is given in millions.
    - The result is divided by 1000 until it is below 1000 or the largest suffix ('T')
      is reached.
    - It is then rendered with ``fmt``.

    Parentheses:
    - Values <= 0 get parentheses unless ``l_paren``/``r_paren`` are supplied.
    - A value that is exactly zero after scaling is rendered with ``zero_fmt`` instead.

    ``prefix``, ``suffix``, ``scale``, ``decimals`` and any extra keyword arguments
    (e.g. ``position`` from a matplotlib formatter) are available as fields in ``fmt``
    and ``zero_fmt``.
    """
    kwargs["prefix"] = prefix
    kwargs["suffix"] = suffix
    kwargs["scale"] = scale
    kwargs["decimals"] = decimals

    # Accounting-style negatives: "(5M)" rather than "-5M".
    if "l_paren" not in kwargs:
        kwargs["l_paren"] = "(" * bool(x <= 0)
    if "r_paren" not in kwargs:
        kwargs["r_paren"] = ")" * bool(x <= 0)
    x = abs(x) * scale

    if x == 0:
        return zero_fmt.format(**kwargs)

    suffixes = ['', 'K', 'M', 'B', 'T']
    for mbt in suffixes:
        # Stop at the largest suffix so huge values read "2,000T" instead of being
        # divided once more and mislabelled "2T".
        if x < 1000 or mbt == suffixes[-1]:
            break
        x /= 1000.0

    return fmt.format(x=x, mbt=mbt, **kwargs)

def mbt_ff(**kwargs):
    """Build a matplotlib ``FuncFormatter`` that labels ticks with ``mbt_string_fmt``.

    ``kwargs`` are forwarded to ``mbt_string_fmt`` for every tick. The tick position
    is passed as ``position``.
    """
    def _tick_label(x, p):
        return mbt_string_fmt(x, position=p, **kwargs)

    return FuncFormatter(_tick_label)

In [None]:
# project imports: walk up from the working directory (at most 10 levels) looking for
# a directory that contains a 'code' subdirectory. Put that 'code' directory on
# sys.path so the project's `src` package can be imported.
_code_dir = Path('.').absolute()
for _level in range(10):
    # is_dir(): a plain file named 'code' must not count.
    if (_code_dir / 'code').is_dir():
        _code_dir = _code_dir / 'code'
        if str(_code_dir) not in sys.path:
            sys.path.append(str(_code_dir))
        break
    _code_dir = _code_dir.parent
else:
    logger.error('Cannot find code directory, so likely will not import src modules!')

try:
    from src import ROOT_DIR, CODE_DIR, DATA_DIR

    FIG_DIR = Path('./figures').absolute()

    def save(*args, **kwargs):
        """Save the current figure into FIG_DIR via ``savefig``.

        Returns the saved path relative to ROOT_DIR. If FIG_DIR is not inside
        ROOT_DIR, the full path is returned instead.
        """
        saved = savefig(*args, root_dir=FIG_DIR, **kwargs)
        try:
            return saved.relative_to(ROOT_DIR)
        except ValueError:
            return saved

    _project_dirs = {'ROOT_DIR': ROOT_DIR, 'CODE_DIR': CODE_DIR, 'DATA_DIR': DATA_DIR, 'FIG_DIR': FIG_DIR}
    for _name, _path in _project_dirs.items():
        try:
            _shown = _path.relative_to(ROOT_DIR.parent)
        except ValueError:
            # FIG_DIR follows the working directory, which need not be inside the project.
            _shown = _path
        print(f"{_name}: {_shown}")
except ImportError:
    logger.error('Cannot find src module!')

In [None]:
# Two-digit sector code -> sector name (presumably GICS sectors -- confirm source).
SECTOR_DICT = dict([
    (10, "Energy"),
    (15, "Materials"),
    (20, "Industrials"),
    (25, "Consumer Discretionary"),
    (30, "Consumer Staples"),
    (35, "Health Care"),
    (40, "Financials"),
    (45, "Information Technology"),
    (50, "Communication Services"),
    (55, "Utilities"),
    (60, "Real Estate"),
])